BUG: more core fixes (#13665)

Fix some circular deps:
- move PromptValue into a top-level module (`langchain_core.prompt_values`), because both PromptTemplates and OutputParsers import it
- move tracer context vars to `tracers.context` and import them inside functions in `callbacks.manager`
- add core import tests
Bagatur committed 2023-11-21 15:15:48 -08:00 (commit c61e30632e, parent 59df16ab92)
99 changed files with 1000 additions and 576 deletions
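
For reference, the net effect of these moves on import paths, a sketch derived from the diffs below (old paths shown commented out; not part of the commit itself):

# Before: imports that participated in the circular dependencies
# from langchain_core.prompts import PromptValue, ChatPromptValue, StringPromptValue
# from langchain_core.callbacks.manager import tracing_v2_enabled, collect_runs, register_configure_hook

# After: dedicated top-level modules with no heavyweight dependencies
from langchain_core.prompt_values import (
    ChatPromptValue,
    ChatPromptValueConcrete,
    PromptValue,
    StringPromptValue,
)
from langchain_core.tracers.context import (
    collect_runs,
    register_configure_hook,
    tracing_enabled,
    tracing_v2_enabled,
)
from langchain_core.utils.env import env_var_is_set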

@@ -63,6 +63,46 @@ jobs:
langchain-core-location: ../core
secrets: inherit
# It's possible that langchain works fine with the latest *published* langchain-core,
# but is broken with the langchain-core on `master`.
#
# We want to catch situations like that *before* releasing a new langchain-core, hence this test.
test-with-latest-langchain-core:
runs-on: ubuntu-latest
defaults:
run:
working-directory: ${{ env.WORKDIR }}
strategy:
matrix:
python-version:
- "3.8"
- "3.9"
- "3.10"
- "3.11"
name: test with unpublished langchain-core - Python ${{ matrix.python-version }}
steps:
- uses: actions/checkout@v4
- name: Set up Python ${{ matrix.python-version }} + Poetry ${{ env.POETRY_VERSION }}
uses: "./.github/actions/poetry_setup"
with:
python-version: ${{ matrix.python-version }}
poetry-version: ${{ env.POETRY_VERSION }}
working-directory: ${{ env.WORKDIR }}
cache-key: unpublished-langchain-core
- name: Install dependencies
shell: bash
run: |
echo "Running tests with unpublished langchain, installing dependencies with poetry..."
poetry install
echo "Editably installing langchain-core outside of poetry, to avoid messing up lockfile..."
poetry run pip install -e ../core
- name: Run tests
run: make test
extended-tests:
runs-on: ubuntu-latest
defaults:

@@ -28,8 +28,6 @@ from langchain_core.callbacks.manager import (
CallbackManagerForToolRun,
ParentRunManager,
RunManager,
env_var_is_set,
register_configure_hook,
)
from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
@@ -64,6 +62,4 @@ __all__ = [
"AsyncCallbackManagerForChainGroup",
"StdOutCallbackHandler",
"StreamingStdOutCallbackHandler",
"env_var_is_set",
"register_configure_hook",
]

@@ -3,13 +3,10 @@ from __future__ import annotations
import asyncio
import functools
import logging
import os
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager, contextmanager
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
AsyncGenerator,
Coroutine,
@@ -18,7 +15,6 @@ from typing import (
List,
Optional,
Sequence,
Tuple,
Type,
TypeVar,
Union,
@@ -26,14 +22,10 @@ from typing import (
)
from uuid import UUID
from langsmith import utils as ls_utils
from langsmith.run_helpers import get_run_tree_context
from tenacity import RetryCallState
from langchain_core.agents import (
AgentAction,
AgentFinish,
)
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.callbacks.base import (
BaseCallbackHandler,
BaseCallbackManager,
@@ -48,32 +40,10 @@ from langchain_core.callbacks.stdout import StdOutCallbackHandler
from langchain_core.documents import Document
from langchain_core.messages import BaseMessage, get_buffer_string
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk, LLMResult
from langchain_core.tracers import run_collector
from langchain_core.tracers.langchain import (
LangChainTracer,
)
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.schemas import TracerSessionV1
from langchain_core.tracers.stdout import ConsoleCallbackHandler
if TYPE_CHECKING:
from langsmith import Client as LangSmithClient
from langchain_core.utils.env import env_var_is_set
logger = logging.getLogger(__name__)
tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501
"tracing_callback", default=None
)
tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
"tracing_callback_v2", default=None
)
run_collector_var: ContextVar[
Optional[run_collector.RunCollectorCallbackHandler]
] = ContextVar( # noqa: E501
"run_collector", default=None
)
def _get_debug() -> bool:
from langchain_core.globals import get_debug
@@ -81,123 +51,6 @@ def _get_debug() -> bool:
return get_debug()
@contextmanager
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSessionV1, None, None]:
"""Get the Deprecated LangChainTracer in a context manager.
Args:
session_name (str, optional): The name of the session.
Defaults to "default".
Returns:
TracerSessionV1: The LangChainTracer session.
Example:
>>> with tracing_enabled() as session:
... # Use the LangChainTracer session
"""
cb = LangChainTracerV1()
session = cast(TracerSessionV1, cb.load_session(session_name))
try:
tracing_callback_var.set(cb)
yield session
finally:
tracing_callback_var.set(None)
@contextmanager
def tracing_v2_enabled(
project_name: Optional[str] = None,
*,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
client: Optional[LangSmithClient] = None,
) -> Generator[LangChainTracer, None, None]:
"""Instruct LangChain to log all runs in context to LangSmith.
Args:
project_name (str, optional): The name of the project.
Defaults to "default".
example_id (str or UUID, optional): The ID of the example.
Defaults to None.
tags (List[str], optional): The tags to add to the run.
Defaults to None.
Returns:
None
Example:
>>> with tracing_v2_enabled():
... # LangChain code will automatically be traced
You can use this to fetch the LangSmith run URL:
>>> with tracing_v2_enabled() as cb:
... chain.invoke("foo")
... run_url = cb.get_run_url()
"""
if isinstance(example_id, str):
example_id = UUID(example_id)
cb = LangChainTracer(
example_id=example_id,
project_name=project_name,
tags=tags,
client=client,
)
try:
tracing_v2_callback_var.set(cb)
yield cb
finally:
tracing_v2_callback_var.set(None)
@contextmanager
def collect_runs() -> Generator[run_collector.RunCollectorCallbackHandler, None, None]:
"""Collect all run traces in context.
Returns:
run_collector.RunCollectorCallbackHandler: The run collector callback handler.
Example:
>>> with collect_runs() as runs_cb:
chain.invoke("foo")
run_id = runs_cb.traced_runs[0].id
"""
cb = run_collector.RunCollectorCallbackHandler()
run_collector_var.set(cb)
yield cb
run_collector_var.set(None)
def _get_trace_callbacks(
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None,
) -> Callbacks:
if _tracing_v2_is_enabled():
project_name_ = project_name or _get_tracer_project()
tracer = tracing_v2_callback_var.get() or LangChainTracer(
project_name=project_name_,
example_id=example_id,
)
if callback_manager is None:
cb = cast(Callbacks, [tracer])
else:
if not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
callback_manager.add_handler(tracer, True)
# If it already has a LangChainTracer, we don't need to add another one.
# this would likely mess up the trace hierarchy.
cb = callback_manager
else:
cb = None
return cb
@contextmanager
def trace_as_chain_group(
group_name: str,
@@ -239,6 +92,8 @@ def trace_as_chain_group(
res = llm.predict(llm_input, callbacks=manager)
manager.on_chain_end({"output": res})
""" # noqa: E501
from langchain_core.tracers.context import _get_trace_callbacks
cb = _get_trace_callbacks(
project_name, example_id, callback_manager=callback_manager
)
@@ -310,6 +165,8 @@ async def atrace_as_chain_group(
res = await llm.apredict(llm_input, callbacks=manager)
await manager.on_chain_end({"output": res})
""" # noqa: E501
from langchain_core.tracers.context import _get_trace_callbacks
cb = _get_trace_callbacks(
project_name, example_id, callback_manager=callback_manager
)
@@ -1850,88 +1707,9 @@ class AsyncCallbackManagerForChainGroup(AsyncCallbackManager):
T = TypeVar("T", CallbackManager, AsyncCallbackManager)
def env_var_is_set(env_var: str) -> bool:
"""Check if an environment variable is set.
Args:
env_var (str): The name of the environment variable.
Returns:
bool: True if the environment variable is set, False otherwise.
"""
return env_var in os.environ and os.environ[env_var] not in (
"",
"0",
"false",
"False",
)
def _tracing_v2_is_enabled() -> bool:
return (
env_var_is_set("LANGCHAIN_TRACING_V2")
or tracing_v2_callback_var.get() is not None
or get_run_tree_context() is not None
)
def _get_tracer_project() -> str:
run_tree = get_run_tree_context()
return getattr(
run_tree,
"session_name",
getattr(
# Note, if people are trying to nest @traceable functions and the
# tracing_v2_enabled context manager, this will likely mess up the
# tree structure.
tracing_v2_callback_var.get(),
"project",
# Have to set this to a string even though it always will return
# a string because `get_tracer_project` technically can return
# None, but only when a specific argument is supplied.
# Therefore, this just tricks the mypy type checker
str(ls_utils.get_tracer_project()),
),
)
_configure_hooks: List[
Tuple[
ContextVar[Optional[BaseCallbackHandler]],
bool,
Optional[Type[BaseCallbackHandler]],
Optional[str],
]
] = []
H = TypeVar("H", bound=BaseCallbackHandler, covariant=True)
def register_configure_hook(
context_var: ContextVar[Optional[Any]],
inheritable: bool,
handle_class: Optional[Type[BaseCallbackHandler]] = None,
env_var: Optional[str] = None,
) -> None:
if env_var is not None and handle_class is None:
raise ValueError(
"If env_var is set, handle_class must also be set to a non-None value."
)
_configure_hooks.append(
(
# the typings of ContextVar do not have the generic arg set as covariant
# so we have to cast it
cast(ContextVar[Optional[BaseCallbackHandler]], context_var),
inheritable,
handle_class,
env_var,
)
)
register_configure_hook(run_collector_var, False)
def _configure(
callback_manager_cls: Type[T],
inheritable_callbacks: Callbacks = None,
@@ -1962,6 +1740,14 @@ def _configure(
Returns:
T: The configured callback manager.
"""
from langchain_core.tracers.context import (
_configure_hooks,
_get_tracer_project,
_tracing_v2_is_enabled,
tracing_callback_var,
tracing_v2_callback_var,
)
run_tree = get_run_tree_context()
parent_run_id = None if run_tree is None else getattr(run_tree, "id")
callback_manager = callback_manager_cls(handlers=[], parent_run_id=parent_run_id)
@@ -2009,6 +1795,10 @@ def _configure(
tracer_project = _get_tracer_project()
debug = _get_debug()
if verbose or debug or tracing_enabled_ or tracing_v2_enabled_:
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.stdout import ConsoleCallbackHandler
if verbose and not any(
isinstance(handler, StdOutCallbackHandler)
for handler in callback_manager.handlers

@@ -19,7 +19,7 @@ _llm_cache: Optional["BaseCache"] = None
def set_verbose(value: bool) -> None:
"""Set a new value for the `verbose` global setting."""
try:
import langchain
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
@@ -47,7 +47,7 @@ def set_verbose(value: bool) -> None:
def get_verbose() -> bool:
"""Get the value of the `verbose` global setting."""
try:
import langchain
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
@@ -80,7 +80,7 @@ def get_verbose() -> bool:
def set_debug(value: bool) -> None:
"""Set a new value for the `debug` global setting."""
try:
import langchain
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
@@ -106,7 +106,7 @@ def set_debug(value: bool) -> None:
def get_debug() -> bool:
"""Get the value of the `debug` global setting."""
try:
import langchain
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
@@ -137,7 +137,7 @@ def get_debug() -> bool:
def set_llm_cache(value: Optional["BaseCache"]) -> None:
"""Set a new LLM cache, overwriting the previous value, if any."""
try:
import langchain
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.
@@ -165,7 +165,7 @@ def set_llm_cache(value: Optional["BaseCache"]) -> None:
def get_llm_cache() -> "BaseCache":
"""Get the value of the `llm_cache` global setting."""
try:
import langchain
import langchain # type: ignore[import]
# We're about to run some deprecated code, don't report warnings from it.
# The user called the correct (non-deprecated) code path and shouldn't get warnings.

@@ -1,4 +1,8 @@
from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput
from langchain_core.language_models.base import (
BaseLanguageModel,
LanguageModelInput,
get_tokenizer,
)
from langchain_core.language_models.chat_models import BaseChatModel, SimpleChatModel
from langchain_core.language_models.llms import LLM, BaseLLM
@@ -9,4 +13,5 @@ __all__ = [
"BaseLLM",
"LLM",
"LanguageModelInput",
"get_tokenizer",
]

@@ -17,7 +17,7 @@ from typing_extensions import TypeAlias
from langchain_core.messages import AnyMessage, BaseMessage, get_buffer_string
from langchain_core.outputs import LLMResult
from langchain_core.prompts import PromptValue
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import RunnableSerializable
from langchain_core.utils import get_pydantic_field_names
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
@lru_cache(maxsize=None) # Cache the tokenizer
def get_tokenizer() -> Any:
try:
from transformers import GPT2TokenizerFast
from transformers import GPT2TokenizerFast # type: ignore[import]
except ImportError:
raise ImportError(
"Could not import transformers python package. "
@@ -74,8 +74,10 @@ class BaseLanguageModel(
@property
def InputType(self) -> TypeAlias:
"""Get the input type for this runnable."""
from langchain_core.prompts.chat import ChatPromptValueConcrete
from langchain_core.prompts.string import StringPromptValue
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
StringPromptValue,
)
# This is a version of LanguageModelInput which replaces the abstract
# base class BaseMessage with a union of its subclasses, which makes

@@ -39,7 +39,7 @@ from langchain_core.outputs import (
LLMResult,
RunInfo,
)
from langchain_core.prompts import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.runnables import RunnableConfig

@@ -50,7 +50,7 @@ from langchain_core.language_models.base import BaseLanguageModel, LanguageModel
from langchain_core.load import dumpd
from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string
from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo
from langchain_core.prompts import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue
from langchain_core.pydantic_v1 import Field, root_validator, validator
from langchain_core.runnables import RunnableConfig, get_config_list

@@ -9,7 +9,7 @@ from langchain_core.output_parsers.list import (
MarkdownListOutputParser,
NumberedListOutputParser,
)
from langchain_core.output_parsers.str import StrOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.output_parsers.transform import (
BaseCumulativeTransformOutputParser,
BaseTransformOutputParser,

@@ -17,11 +17,8 @@ from typing import (
from typing_extensions import get_args
from langchain_core.messages import AnyMessage, BaseMessage
from langchain_core.outputs import (
ChatGeneration,
Generation,
)
from langchain_core.prompts.value import PromptValue
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.prompt_values import PromptValue
from langchain_core.runnables import RunnableConfig, RunnableSerializable
T = TypeVar("T")

@@ -0,0 +1,76 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Literal, Sequence
from langchain_core.load.serializable import Serializable
from langchain_core.messages import (
AnyMessage,
BaseMessage,
HumanMessage,
get_buffer_string,
)
class PromptValue(Serializable, ABC):
"""Base abstract class for inputs to any language model.
PromptValues can be converted to both LLM (pure text-generation) inputs and
ChatModel inputs.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@abstractmethod
def to_string(self) -> str:
"""Return prompt value as string."""
@abstractmethod
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""
class StringPromptValue(PromptValue):
"""String prompt value."""
text: str
"""Prompt text."""
type: Literal["StringPromptValue"] = "StringPromptValue"
def to_string(self) -> str:
"""Return prompt as string."""
return self.text
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
class ChatPromptValue(PromptValue):
"""Chat prompt value.
A type of a prompt value that is built from messages.
"""
messages: Sequence[BaseMessage]
"""List of messages."""
def to_string(self) -> str:
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of messages."""
return list(self.messages)
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""
messages: Sequence[AnyMessage]
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
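
A minimal usage sketch of the relocated classes (hypothetical values, matching the definitions in the new module above):

from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue

sv = StringPromptValue(text="hello")
assert sv.to_string() == "hello"
assert sv.to_messages() == [HumanMessage(content="hello")]

cv = ChatPromptValue(messages=[HumanMessage(content="hi"), AIMessage(content="hello")])
assert cv.to_messages()[0].content == "hi"
assert cv.to_string() == "Human: hi\nAI: hello"  # get_buffer_string's default role prefixes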

@@ -23,9 +23,6 @@ from multiple components. Prompt classes and functions make constructing
AIMessagePromptTemplate
SystemMessagePromptTemplate
PromptValue --> StringPromptValue
ChatPromptValue
""" # noqa: E501
from langchain_core.prompts.base import BasePromptTemplate, format_document
from langchain_core.prompts.chat import (
@@ -33,8 +30,6 @@ from langchain_core.prompts.chat import (
BaseChatPromptTemplate,
ChatMessagePromptTemplate,
ChatPromptTemplate,
ChatPromptValue,
ChatPromptValueConcrete,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
@@ -47,7 +42,13 @@ from langchain_core.prompts.few_shot_with_templates import FewShotPromptWithTemp
from langchain_core.prompts.loading import load_prompt
from langchain_core.prompts.pipeline import PipelinePromptTemplate
from langchain_core.prompts.prompt import Prompt, PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate, StringPromptValue
from langchain_core.prompts.string import (
StringPromptTemplate,
check_valid_template,
get_template_variables,
jinja2_formatter,
validate_jinja2,
)
__all__ = [
"AIMessagePromptTemplate",
@@ -55,8 +56,6 @@ __all__ = [
"BasePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"ChatPromptValue",
"ChatPromptValueConcrete",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"FewShotChatMessagePromptTemplate",
@@ -65,12 +64,12 @@ __all__ = [
"PipelinePromptTemplate",
"Prompt",
"PromptTemplate",
"PromptValue",
"StringPromptValue",
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"format_document",
"check_valid_template",
"get_template_variables",
"jinja2_formatter",
"validate_jinja2",
]
from langchain_core.prompts.value import PromptValue

@@ -8,8 +8,8 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, Type, Union
import yaml
from langchain_core.documents import Document
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts.value import PromptValue
from langchain_core.output_parsers.base import BaseOutputParser
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import BaseModel, Field, create_model, root_validator
from langchain_core.runnables import RunnableConfig, RunnableSerializable
@@ -40,8 +40,10 @@ class BasePromptTemplate(RunnableSerializable[Dict, PromptValue], ABC):
@property
def OutputType(self) -> Any:
from langchain_core.prompts.chat import ChatPromptValueConcrete
from langchain_core.prompts.string import StringPromptValue
from langchain_core.prompt_values import (
ChatPromptValueConcrete,
StringPromptValue,
)
return Union[StringPromptValue, ChatPromptValueConcrete]

@@ -8,7 +8,6 @@ from typing import (
Callable,
Dict,
List,
Literal,
Sequence,
Set,
Tuple,
@@ -27,12 +26,11 @@ from langchain_core.messages import (
ChatMessage,
HumanMessage,
SystemMessage,
get_buffer_string,
)
from langchain_core.prompt_values import ChatPromptValue, PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompts.string import StringPromptTemplate
from langchain_core.prompts.value import PromptValue
from langchain_core.pydantic_v1 import Field, root_validator
@@ -277,33 +275,6 @@ class SystemMessagePromptTemplate(BaseStringMessagePromptTemplate):
return SystemMessage(content=text, additional_kwargs=self.additional_kwargs)
class ChatPromptValue(PromptValue):
"""Chat prompt value.
A type of a prompt value that is built from messages.
"""
messages: Sequence[BaseMessage]
"""List of messages."""
def to_string(self) -> str:
"""Return prompt as string."""
return get_buffer_string(self.messages)
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of messages."""
return list(self.messages)
class ChatPromptValueConcrete(ChatPromptValue):
"""Chat prompt value which explicitly lists out the message types it accepts.
For use in external schemas."""
messages: Sequence[AnyMessage]
type: Literal["ChatPromptValueConcrete"] = "ChatPromptValueConcrete"
class BaseChatPromptTemplate(BasePromptTemplate, ABC):
"""Base class for chat prompt templates."""

@@ -6,7 +6,7 @@ from typing import Callable, Dict, Union
import yaml
from langchain_core.output_parsers import StrOutputParser
from langchain_core.output_parsers.string import StrOutputParser
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.few_shot import FewShotPromptTemplate
from langchain_core.prompts.prompt import PromptTemplate

@@ -1,8 +1,8 @@
from typing import Any, Dict, List, Tuple
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.chat import BaseChatPromptTemplate
from langchain_core.prompts.value import PromptValue
from langchain_core.pydantic_v1 import root_validator

@@ -4,11 +4,10 @@ from __future__ import annotations
import warnings
from abc import ABC
from string import Formatter
from typing import Any, Callable, Dict, List, Literal, Set
from typing import Any, Callable, Dict, List, Set
from langchain_core.messages import BaseMessage, HumanMessage
from langchain_core.prompt_values import PromptValue, StringPromptValue
from langchain_core.prompts.base import BasePromptTemplate
from langchain_core.prompts.value import PromptValue
from langchain_core.utils.formatting import formatter
@@ -149,22 +148,6 @@ def get_template_variables(template: str, template_format: str) -> List[str]:
return sorted(input_variables)
class StringPromptValue(PromptValue):
"""String prompt value."""
text: str
"""Prompt text."""
type: Literal["StringPromptValue"] = "StringPromptValue"
def to_string(self) -> str:
"""Return prompt as string."""
return self.text
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as messages."""
return [HumanMessage(content=self.text)]
class StringPromptTemplate(BasePromptTemplate, ABC):
"""String prompt that exposes the format method, returning a prompt."""

@@ -1,28 +0,0 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List
from langchain_core.load.serializable import Serializable
from langchain_core.messages import BaseMessage
class PromptValue(Serializable, ABC):
"""Base abstract class for inputs to any language model.
PromptValues can be converted to both LLM (pure text-generation) inputs and
ChatModel inputs.
"""
@classmethod
def is_lc_serializable(cls) -> bool:
"""Return whether this class is serializable."""
return True
@abstractmethod
def to_string(self) -> str:
"""Return prompt value as string."""
@abstractmethod
def to_messages(self) -> List[BaseMessage]:
"""Return prompt as a list of Messages."""

@@ -9,7 +9,7 @@ from uuid import UUID
from tenacity import RetryCallState
from langchain_core.callbacks import BaseCallbackHandler
from langchain_core.callbacks.base import BaseCallbackHandler
from langchain_core.documents import Document
from langchain_core.exceptions import TracerException
from langchain_core.load import dumpd

@@ -0,0 +1,226 @@
from __future__ import annotations
from contextlib import contextmanager
from contextvars import ContextVar
from typing import (
TYPE_CHECKING,
Any,
Generator,
List,
Optional,
Tuple,
Type,
Union,
cast,
)
from uuid import UUID
from langsmith import utils as ls_utils
from langsmith.run_helpers import get_run_tree_context
from langchain_core.tracers.langchain import LangChainTracer
from langchain_core.tracers.langchain_v1 import LangChainTracerV1
from langchain_core.tracers.run_collector import RunCollectorCallbackHandler
from langchain_core.tracers.schemas import TracerSessionV1
from langchain_core.utils.env import env_var_is_set
if TYPE_CHECKING:
from langsmith import Client as LangSmithClient
from langchain_core.callbacks.base import BaseCallbackHandler, Callbacks
from langchain_core.callbacks.manager import AsyncCallbackManager, CallbackManager
tracing_callback_var: ContextVar[Optional[LangChainTracerV1]] = ContextVar( # noqa: E501
"tracing_callback", default=None
)
tracing_v2_callback_var: ContextVar[Optional[LangChainTracer]] = ContextVar( # noqa: E501
"tracing_callback_v2", default=None
)
run_collector_var: ContextVar[Optional[RunCollectorCallbackHandler]] = ContextVar( # noqa: E501
"run_collector", default=None
)
@contextmanager
def tracing_enabled(
session_name: str = "default",
) -> Generator[TracerSessionV1, None, None]:
"""Get the Deprecated LangChainTracer in a context manager.
Args:
session_name (str, optional): The name of the session.
Defaults to "default".
Returns:
TracerSessionV1: The LangChainTracer session.
Example:
>>> with tracing_enabled() as session:
... # Use the LangChainTracer session
"""
cb = LangChainTracerV1()
session = cast(TracerSessionV1, cb.load_session(session_name))
try:
tracing_callback_var.set(cb)
yield session
finally:
tracing_callback_var.set(None)
@contextmanager
def tracing_v2_enabled(
project_name: Optional[str] = None,
*,
example_id: Optional[Union[str, UUID]] = None,
tags: Optional[List[str]] = None,
client: Optional[LangSmithClient] = None,
) -> Generator[LangChainTracer, None, None]:
"""Instruct LangChain to log all runs in context to LangSmith.
Args:
project_name (str, optional): The name of the project.
Defaults to "default".
example_id (str or UUID, optional): The ID of the example.
Defaults to None.
tags (List[str], optional): The tags to add to the run.
Defaults to None.
Returns:
None
Example:
>>> with tracing_v2_enabled():
... # LangChain code will automatically be traced
You can use this to fetch the LangSmith run URL:
>>> with tracing_v2_enabled() as cb:
... chain.invoke("foo")
... run_url = cb.get_run_url()
"""
if isinstance(example_id, str):
example_id = UUID(example_id)
cb = LangChainTracer(
example_id=example_id,
project_name=project_name,
tags=tags,
client=client,
)
try:
tracing_v2_callback_var.set(cb)
yield cb
finally:
tracing_v2_callback_var.set(None)
@contextmanager
def collect_runs() -> Generator[RunCollectorCallbackHandler, None, None]:
"""Collect all run traces in context.
Returns:
run_collector.RunCollectorCallbackHandler: The run collector callback handler.
Example:
>>> with collect_runs() as runs_cb:
chain.invoke("foo")
run_id = runs_cb.traced_runs[0].id
"""
cb = RunCollectorCallbackHandler()
run_collector_var.set(cb)
yield cb
run_collector_var.set(None)
def _get_trace_callbacks(
project_name: Optional[str] = None,
example_id: Optional[Union[str, UUID]] = None,
callback_manager: Optional[Union[CallbackManager, AsyncCallbackManager]] = None,
) -> Callbacks:
if _tracing_v2_is_enabled():
project_name_ = project_name or _get_tracer_project()
tracer = tracing_v2_callback_var.get() or LangChainTracer(
project_name=project_name_,
example_id=example_id,
)
if callback_manager is None:
from langchain_core.callbacks.base import Callbacks
cb = cast(Callbacks, [tracer])
else:
if not any(
isinstance(handler, LangChainTracer)
for handler in callback_manager.handlers
):
callback_manager.add_handler(tracer, True)
# If it already has a LangChainTracer, we don't need to add another one.
# this would likely mess up the trace hierarchy.
cb = callback_manager
else:
cb = None
return cb
def _tracing_v2_is_enabled() -> bool:
return (
env_var_is_set("LANGCHAIN_TRACING_V2")
or tracing_v2_callback_var.get() is not None
or get_run_tree_context() is not None
)
def _get_tracer_project() -> str:
run_tree = get_run_tree_context()
return getattr(
run_tree,
"session_name",
getattr(
# Note, if people are trying to nest @traceable functions and the
# tracing_v2_enabled context manager, this will likely mess up the
# tree structure.
tracing_v2_callback_var.get(),
"project",
# Have to set this to a string even though it always will return
# a string because `get_tracer_project` technically can return
# None, but only when a specific argument is supplied.
# Therefore, this just tricks the mypy type checker
str(ls_utils.get_tracer_project()),
),
)
_configure_hooks: List[
Tuple[
ContextVar[Optional[BaseCallbackHandler]],
bool,
Optional[Type[BaseCallbackHandler]],
Optional[str],
]
] = []
def register_configure_hook(
context_var: ContextVar[Optional[Any]],
inheritable: bool,
handle_class: Optional[Type[BaseCallbackHandler]] = None,
env_var: Optional[str] = None,
) -> None:
if env_var is not None and handle_class is None:
raise ValueError(
"If env_var is set, handle_class must also be set to a non-None value."
)
from langchain_core.callbacks.base import BaseCallbackHandler
_configure_hooks.append(
(
# the typings of ContextVar do not have the generic arg set as covariant
# so we have to cast it
cast(ContextVar[Optional[BaseCallbackHandler]], context_var),
inheritable,
handle_class,
env_var,
)
)
register_configure_hook(run_collector_var, False)
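
Sanity-check usage of the relocated context managers, a sketch mirroring the docstring examples above (`chain` stands in for any runnable and is hypothetical here):

from langchain_core.tracers.context import collect_runs, tracing_v2_enabled

with collect_runs() as runs_cb:
    chain.invoke("foo")                 # runs are captured by the run collector
    run_id = runs_cb.traced_runs[0].id

with tracing_v2_enabled() as cb:        # requires LangSmith credentials
    chain.invoke("foo")
    run_url = cb.get_run_url()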

@@ -11,9 +11,9 @@ from uuid import UUID
import langsmith
from langsmith.evaluation.evaluator import EvaluationResult, EvaluationResults
from langchain_core.callbacks import manager
from langchain_core.tracers import langchain as langchain_tracer
from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.context import tracing_v2_enabled
from langchain_core.tracers.langchain import _get_executor
from langchain_core.tracers.schemas import Run
@@ -115,7 +115,7 @@ class EvaluatorCallbackHandler(BaseTracer):
if self.project_name is None:
eval_result = self.client.evaluate_run(run, evaluator)
eval_results = [eval_result]
with manager.tracing_v2_enabled(
with tracing_v2_enabled(
project_name=self.project_name, tags=["eval"], client=self.client
) as cb:
reference_example = (

@@ -15,7 +15,7 @@ from typing import (
)
from uuid import UUID
import jsonpatch
import jsonpatch # type: ignore[import]
from anyio import create_memory_object_stream
from langchain_core.load import load

@@ -0,0 +1,20 @@
from __future__ import annotations
import os
def env_var_is_set(env_var: str) -> bool:
"""Check if an environment variable is set.
Args:
env_var (str): The name of the environment variable.
Returns:
bool: True if the environment variable is set, False otherwise.
"""
return env_var in os.environ and os.environ[env_var] not in (
"",
"0",
"false",
"False",
)
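
A short illustration of the helper's semantics, namely that the values "", "0", "false" and "False" count as unset (a sketch, not part of the diff; the variable name is hypothetical):

import os
from langchain_core.utils.env import env_var_is_set

os.environ["MY_FLAG"] = "true"
assert env_var_is_set("MY_FLAG")
os.environ["MY_FLAG"] = "0"
assert not env_var_is_set("MY_FLAG")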

@@ -1,6 +1,6 @@
[tool.poetry]
name = "langchain-core"
version = "0.0.2"
version = "0.0.3"
description = "Building applications with LLMs through composability"
authors = []
license = "MIT"
@@ -51,7 +51,6 @@ select = [
]
[tool.mypy]
ignore_missing_imports = "True"
disallow_untyped_defs = "True"
exclude = ["notebooks", "examples", "example_data", "langchain_core/pydantic"]

@@ -0,0 +1,37 @@
from langchain_core.callbacks import __all__
EXPECTED_ALL = [
"RetrieverManagerMixin",
"LLMManagerMixin",
"ChainManagerMixin",
"ToolManagerMixin",
"Callbacks",
"CallbackManagerMixin",
"RunManagerMixin",
"BaseCallbackHandler",
"AsyncCallbackHandler",
"BaseCallbackManager",
"BaseRunManager",
"RunManager",
"ParentRunManager",
"AsyncRunManager",
"AsyncParentRunManager",
"CallbackManagerForLLMRun",
"AsyncCallbackManagerForLLMRun",
"CallbackManagerForChainRun",
"AsyncCallbackManagerForChainRun",
"CallbackManagerForToolRun",
"AsyncCallbackManagerForToolRun",
"CallbackManagerForRetrieverRun",
"AsyncCallbackManagerForRetrieverRun",
"CallbackManager",
"CallbackManagerForChainGroup",
"AsyncCallbackManager",
"AsyncCallbackManagerForChainGroup",
"StdOutCallbackHandler",
"StreamingStdOutCallbackHandler",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,13 @@
from langchain_core.example_selectors import __all__
EXPECTED_ALL = [
"BaseExampleSelector",
"LengthBasedExampleSelector",
"MaxMarginalRelevanceExampleSelector",
"SemanticSimilarityExampleSelector",
"sorted_values",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,15 @@
from langchain_core.language_models import __all__
EXPECTED_ALL = [
"BaseLanguageModel",
"BaseChatModel",
"SimpleChatModel",
"BaseLLM",
"LLM",
"LanguageModelInput",
"get_tokenizer",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,7 @@
from langchain_core.load import __all__
EXPECTED_ALL = ["dumpd", "dumps", "load", "loads", "Serializable"]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,28 @@
from langchain_core.messages import __all__
EXPECTED_ALL = [
"AIMessage",
"AIMessageChunk",
"AnyMessage",
"BaseMessage",
"BaseMessageChunk",
"ChatMessage",
"ChatMessageChunk",
"FunctionMessage",
"FunctionMessageChunk",
"HumanMessage",
"HumanMessageChunk",
"SystemMessage",
"SystemMessageChunk",
"ToolMessage",
"ToolMessageChunk",
"get_buffer_string",
"messages_from_dict",
"messages_to_dict",
"message_to_dict",
"merge_content",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,18 @@
from langchain_core.output_parsers import __all__
EXPECTED_ALL = [
"BaseLLMOutputParser",
"BaseGenerationOutputParser",
"BaseOutputParser",
"ListOutputParser",
"CommaSeparatedListOutputParser",
"NumberedListOutputParser",
"MarkdownListOutputParser",
"StrOutputParser",
"BaseTransformOutputParser",
"BaseCumulativeTransformOutputParser",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,15 @@
from langchain_core.outputs import __all__
EXPECTED_ALL = [
"ChatGeneration",
"ChatGenerationChunk",
"ChatResult",
"Generation",
"GenerationChunk",
"LLMResult",
"RunInfo",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -10,6 +10,7 @@ from langchain_core.messages import (
SystemMessage,
get_buffer_string,
)
from langchain_core.prompt_values import ChatPromptValue
from langchain_core.prompts import PromptTemplate
from langchain_core.prompts.chat import (
AIMessagePromptTemplate,
@@ -17,7 +18,6 @@ from langchain_core.prompts.chat import (
ChatMessage,
ChatMessagePromptTemplate,
ChatPromptTemplate,
ChatPromptValue,
HumanMessagePromptTemplate,
SystemMessagePromptTemplate,
_convert_to_message,

@@ -6,14 +6,10 @@ EXPECTED_ALL = [
"BasePromptTemplate",
"ChatMessagePromptTemplate",
"ChatPromptTemplate",
"ChatPromptValueConcrete",
"FewShotPromptTemplate",
"FewShotPromptWithTemplates",
"FewShotChatMessagePromptTemplate",
"format_document",
"ChatPromptValue",
"PromptValue",
"StringPromptValue",
"HumanMessagePromptTemplate",
"MessagesPlaceholder",
"PipelinePromptTemplate",
@@ -22,6 +18,10 @@ EXPECTED_ALL = [
"StringPromptTemplate",
"SystemMessagePromptTemplate",
"load_prompt",
"check_valid_template",
"get_template_variables",
"jinja2_formatter",
"validate_jinja2",
]

File diff suppressed because one or more lines are too long

@@ -0,0 +1,28 @@
from langchain_core.runnables import __all__
EXPECTED_ALL = [
"ConfigurableField",
"ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption",
"patch_config",
"RouterInput",
"RouterRunnable",
"Runnable",
"RunnableSerializable",
"RunnableBinding",
"RunnableBranch",
"RunnableConfig",
"RunnableGenerator",
"RunnableLambda",
"RunnableMap",
"RunnableParallel",
"RunnablePassthrough",
"RunnableSequence",
"RunnableWithFallbacks",
"get_config_list",
"add",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -23,7 +23,6 @@ from typing_extensions import TypedDict
from langchain_core.callbacks.manager import (
Callbacks,
atrace_as_chain_group,
collect_runs,
trace_as_chain_group,
)
from langchain_core.documents import Document
@@ -39,13 +38,12 @@ from langchain_core.output_parsers import (
CommaSeparatedListOutputParser,
StrOutputParser,
)
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_core.prompts import (
ChatPromptTemplate,
ChatPromptValue,
HumanMessagePromptTemplate,
MessagesPlaceholder,
PromptTemplate,
StringPromptValue,
SystemMessagePromptTemplate,
)
from langchain_core.pydantic_v1 import BaseModel
@@ -75,6 +73,7 @@ from langchain_core.tracers import (
RunLog,
RunLogPatch,
)
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.chat_model import FakeListChatModel
from tests.unit_tests.fake.llm import FakeListLLM, FakeStreamingListLLM

@@ -0,0 +1,15 @@
from langchain_core.tracers import __all__
EXPECTED_ALL = [
"BaseTracer",
"EvaluatorCallbackHandler",
"LangChainTracer",
"ConsoleCallbackHandler",
"Run",
"RunLog",
"RunLogPatch",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -2,12 +2,12 @@
import uuid
from langchain.callbacks import collect_runs
from tests.unit_tests.llms.fake_llm import FakeLLM
from langchain_core.tracers.context import collect_runs
from tests.unit_tests.fake.llm import FakeListLLM
def test_collect_runs() -> None:
llm = FakeLLM(queries={"hi": "hello"}, sequential_responses=True)
llm = FakeListLLM(responses=["hello"])
with collect_runs() as cb:
llm.predict("hi")
assert cb.traced_runs

@@ -2,9 +2,8 @@ import uuid
from typing import Any, Callable, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.prompts.base import StringPromptValue
from langchain.prompts.chat import ChatPromptValue
from langchain.schema import AIMessage, HumanMessage
from langchain_core.prompt_values import ChatPromptValue, StringPromptValue
from langchain_experimental.comprehend_moderation.pii import ComprehendPII
from langchain_experimental.comprehend_moderation.prompt_safety import (

@@ -29,14 +29,16 @@ from langchain_core.callbacks.manager import (
RunManager,
ahandle_event,
atrace_as_chain_group,
collect_runs,
env_var_is_set,
handle_event,
register_configure_hook,
trace_as_chain_group,
)
from langchain_core.tracers.context import (
collect_runs,
register_configure_hook,
tracing_enabled,
tracing_v2_enabled,
)
from langchain_core.utils.env import env_var_is_set
from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.tracers.wandb import WandbTracer
@@ -122,6 +124,6 @@ __all__ = [
"trace_as_chain_group",
"handle_event",
"ahandle_event",
"env_var_is_set",
"Callbacks",
"env_var_is_set",
]

@@ -12,8 +12,8 @@ from langchain_core.load.dump import dumpd
from langchain_core.messages import BaseMessage
from langchain_core.output_parsers import BaseLLMOutputParser, StrOutputParser
from langchain_core.outputs import ChatGeneration, Generation, LLMResult
from langchain_core.prompts import BasePromptTemplate, PromptValue
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
from langchain_core.pydantic_v1 import Extra, Field
from langchain_core.runnables import (
Runnable,

@@ -9,7 +9,7 @@ from langchain_core.messages import (
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.prompts import PromptValue
from langchain_core.prompt_values import PromptValue
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,

@@ -13,7 +13,7 @@ from typing import (
from langchain_core.language_models import BaseLanguageModel
from langchain_core.outputs import GenerationChunk
from langchain_core.prompts import PromptValue
from langchain_core.prompt_values import PromptValue
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
check_package_version,

@@ -5,8 +5,8 @@ from typing import Any, TypeVar
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import BaseLanguageModel
from langchain_core.output_parsers import BaseOutputParser
from langchain_core.prompts import BasePromptTemplate, PromptValue
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import BasePromptTemplate, PromptTemplate
NAIVE_COMPLETION_RETRY = """Prompt:
{prompt}

@@ -1,7 +1,6 @@
from langchain_core.prompts import BasePromptTemplate
from langchain_core.prompts.base import (
from langchain_core.prompts import (
BasePromptTemplate,
StringPromptTemplate,
StringPromptValue,
check_valid_template,
get_template_variables,
jinja2_formatter,
@@ -13,7 +12,6 @@ __all__ = [
"validate_jinja2",
"check_valid_template",
"get_template_variables",
"StringPromptValue",
"StringPromptTemplate",
"BasePromptTemplate",
]

@@ -5,8 +5,6 @@ from langchain_core.prompts.chat import (
BaseStringMessagePromptTemplate,
ChatMessagePromptTemplate,
ChatPromptTemplate,
ChatPromptValue,
ChatPromptValueConcrete,
HumanMessagePromptTemplate,
MessagesPlaceholder,
SystemMessagePromptTemplate,
@@ -20,8 +18,6 @@ __all__ = [
"HumanMessagePromptTemplate",
"AIMessagePromptTemplate",
"SystemMessagePromptTemplate",
"ChatPromptValue",
"ChatPromptValueConcrete",
"BaseChatPromptTemplate",
"ChatPromptTemplate",
]

@@ -31,12 +31,16 @@ from langchain_core.outputs import (
LLMResult,
RunInfo,
)
from langchain_core.prompts import BasePromptTemplate, PromptValue, format_document
from langchain_core.prompt_values import PromptValue
from langchain_core.prompts import BasePromptTemplate, format_document
from langchain_core.retrievers import BaseRetriever
from langchain_core.stores import BaseStore
RUN_KEY = "__run"
# Backwards compatibility.
Memory = BaseMemory
_message_to_dict = message_to_dict
__all__ = [
"BaseCache",
@@ -56,6 +60,7 @@ __all__ = [
"messages_from_dict",
"messages_to_dict",
"message_to_dict",
"_message_to_dict",
"_message_from_dict",
"get_buffer_string",
"RunInfo",

@@ -16,14 +16,16 @@ from langchain_core.callbacks.manager import (
CallbackManagerForToolRun,
ParentRunManager,
RunManager,
collect_runs,
env_var_is_set,
handle_event,
register_configure_hook,
trace_as_chain_group,
)
from langchain_core.tracers.context import (
collect_runs,
register_configure_hook,
tracing_enabled,
tracing_v2_enabled,
)
from langchain_core.utils.env import env_var_is_set
__all__ = [
"tracing_enabled",
@@ -48,6 +50,6 @@ __all__ = [
"CallbackManagerForChainGroup",
"AsyncCallbackManager",
"AsyncCallbackManagerForChainGroup",
"env_var_is_set",
"register_configure_hook",
"env_var_is_set",
]

@@ -1,3 +1,4 @@
from langchain_core.documents import BaseDocumentTransformer, Document
from langchain_core.document_transformers import BaseDocumentTransformer
from langchain_core.documents import Document
__all__ = ["Document", "BaseDocumentTransformer"]

@@ -1,3 +1,4 @@
from langchain_core.language_models import BaseLanguageModel, get_tokenizer
from langchain_core.language_models.base import _get_token_ids_default_method
__all__ = ["get_tokenizer", "BaseLanguageModel"]
__all__ = ["get_tokenizer", "BaseLanguageModel", "_get_token_ids_default_method"]

@@ -13,12 +13,17 @@ from langchain_core.messages import (
SystemMessageChunk,
ToolMessage,
ToolMessageChunk,
_message_from_dict,
get_buffer_string,
merge_content,
message_to_dict,
messages_from_dict,
messages_to_dict,
)
# Backwards compatibility.
_message_to_dict = message_to_dict
__all__ = [
"get_buffer_string",
"BaseMessage",
@@ -38,4 +43,7 @@ __all__ = [
"ChatMessageChunk",
"messages_to_dict",
"messages_from_dict",
"_message_to_dict",
"_message_from_dict",
"message_to_dict",
]

@@ -1,19 +1,23 @@
from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import (
BaseCumulativeTransformOutputParser,
BaseGenerationOutputParser,
BaseLLMOutputParser,
BaseOutputParser,
BaseTransformOutputParser,
OutputParserException,
StrOutputParser,
)
# Backwards compatibility.
NoOpOutputParser = StrOutputParser
__all__ = [
"BaseLLMOutputParser",
"BaseGenerationOutputParser",
"BaseOutputParser",
"BaseTransformOutputParser",
"BaseCumulativeTransformOutputParser",
"NoOpOutputParser",
"StrOutputParser",
"OutputParserException",
]

@@ -1,3 +1,3 @@
from langchain_core.prompts import PromptValue
from langchain_core.prompt_values import PromptValue
__all__ = ["PromptValue"]

@@ -12,6 +12,9 @@ from langchain_core.runnables.base import (
coerce_to_runnable,
)
# Backwards compatibility.
RunnableMap = RunnableParallel
__all__ = [
"Runnable",
"RunnableSerializable",
@@ -23,5 +26,6 @@ __all__ = [
"RunnableEach",
"RunnableBindingBase",
"RunnableBinding",
"RunnableMap",
"coerce_to_runnable",
]

@@ -1,6 +1,7 @@
from langchain_core.runnables.config import (
EmptyDict,
RunnableConfig,
acall_func_with_variable_args,
call_func_with_variable_args,
ensure_config,
get_async_callback_manager_for_config,
@@ -18,6 +19,7 @@ __all__ = [
"get_config_list",
"patch_config",
"merge_configs",
"acall_func_with_variable_args",
"call_func_with_variable_args",
"get_callback_manager_for_config",
"get_async_callback_manager_for_config",

@@ -1,7 +1,8 @@
from langchain_core.runnables.passthrough import (
RunnableAssign,
RunnablePassthrough,
aidentity,
identity,
)
__all__ = ["identity", "RunnablePassthrough", "RunnableAssign"]
__all__ = ["aidentity", "identity", "RunnablePassthrough", "RunnableAssign"]

@@ -8,9 +8,12 @@ from langchain_core.runnables.utils import (
IsFunctionArgDict,
IsLocalDict,
SupportsAdd,
aadd,
accepts_config,
accepts_run_manager,
add,
gated_coro,
gather_with_concurrency,
get_function_first_arg_dict_keys,
get_lambda_source,
get_unique_config_specs,
@@ -34,4 +37,7 @@ __all__ = [
"ConfigurableFieldMultiOption",
"ConfigurableFieldSpec",
"get_unique_config_specs",
"aadd",
"gated_coro",
"gather_with_concurrency",
]

@@ -4128,13 +4128,13 @@ tests = ["pandas (>=1.4)", "pytest", "pytest-asyncio", "pytest-mock"]
[[package]]
name = "langchain-core"
version = "0.0.1"
version = "0.0.2"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
{file = "langchain_core-0.0.1-py3-none-any.whl", hash = "sha256:cad923dd3bc39cd9fe24b9d6a9799c97719aeaafc9b19509fe1347109fcb65b3"},
{file = "langchain_core-0.0.1.tar.gz", hash = "sha256:488b72223e14849bf9588ed677a999b282904d1d5e1f81d12767ee1024220724"},
{file = "langchain_core-0.0.2-py3-none-any.whl", hash = "sha256:dd448c7887c24105761a0763bee5ec5b072de905b4fbd83d693e7a181fd63208"},
{file = "langchain_core-0.0.2.tar.gz", hash = "sha256:772600dfe3e707adb9055ed96797763b0de32b394eac8325bf609a7e5929dda2"},
]
[package.dependencies]
@@ -6515,6 +6515,8 @@ files = [
{file = "psycopg2-2.9.9-cp310-cp310-win_amd64.whl", hash = "sha256:426f9f29bde126913a20a96ff8ce7d73fd8a216cfb323b1f04da402d452853c3"},
{file = "psycopg2-2.9.9-cp311-cp311-win32.whl", hash = "sha256:ade01303ccf7ae12c356a5e10911c9e1c51136003a9a1d92f7aa9d010fb98372"},
{file = "psycopg2-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:121081ea2e76729acfb0673ff33755e8703d45e926e416cb59bae3a86c6a4981"},
{file = "psycopg2-2.9.9-cp312-cp312-win32.whl", hash = "sha256:d735786acc7dd25815e89cc4ad529a43af779db2e25aa7c626de864127e5a024"},
{file = "psycopg2-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:a7653d00b732afb6fc597e29c50ad28087dcb4fbfb28e86092277a559ae4e693"},
{file = "psycopg2-2.9.9-cp37-cp37m-win32.whl", hash = "sha256:5e0d98cade4f0e0304d7d6f25bbfbc5bd186e07b38eac65379309c4ca3193efa"},
{file = "psycopg2-2.9.9-cp37-cp37m-win_amd64.whl", hash = "sha256:7e2dacf8b009a1c1e843b5213a87f7c544b2b042476ed7755be813eaf4e8347a"},
{file = "psycopg2-2.9.9-cp38-cp38-win32.whl", hash = "sha256:ff432630e510709564c01dafdbe996cb552e0b9f3f065eb89bdce5bd31fabf4c"},
@@ -6557,6 +6559,7 @@ files = [
{file = "psycopg2_binary-2.9.9-cp311-cp311-win32.whl", hash = "sha256:dc4926288b2a3e9fd7b50dc6a1909a13bbdadfc67d93f3374d984e56f885579d"},
{file = "psycopg2_binary-2.9.9-cp311-cp311-win_amd64.whl", hash = "sha256:b76bedd166805480ab069612119ea636f5ab8f8771e640ae103e05a4aae3e417"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:8532fd6e6e2dc57bcb3bc90b079c60de896d2128c5d9d6f24a63875a95a088cf"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:b0605eaed3eb239e87df0d5e3c6489daae3f7388d455d0c0b4df899519c6a38d"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f8544b092a29a6ddd72f3556a9fcf249ec412e10ad28be6a0c0d948924f2212"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2d423c8d8a3c82d08fe8af900ad5b613ce3632a1249fd6a223941d0735fce493"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2e5afae772c00980525f6d6ecf7cbca55676296b580c0e6abb407f15f3706996"},
@@ -6565,6 +6568,8 @@ files = [
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:cb16c65dcb648d0a43a2521f2f0a2300f40639f6f8c1ecbc662141e4e3e1ee07"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:911dda9c487075abd54e644ccdf5e5c16773470a6a5d3826fda76699410066fb"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:57fede879f08d23c85140a360c6a77709113efd1c993923c59fde17aa27599fe"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win32.whl", hash = "sha256:64cf30263844fa208851ebb13b0732ce674d8ec6a0c86a4e160495d299ba3c93"},
{file = "psycopg2_binary-2.9.9-cp312-cp312-win_amd64.whl", hash = "sha256:81ff62668af011f9a48787564ab7eded4e9fb17a4a6a74af5ffa6a457400d2ab"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2293b001e319ab0d869d660a704942c9e2cce19745262a8aba2115ef41a0a42a"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:03ef7df18daf2c4c07e2695e8cfd5ee7f748a1d54d802330985a78d2a5a6dca9"},
{file = "psycopg2_binary-2.9.9-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a602ea5aff39bb9fac6308e9c9d82b9a35c2bf288e184a816002c9fae930b77"},
@@ -11080,4 +11085,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "6781b828c8bc2b08ece3cbf82be799dc0e361b9f6f9c204ddefcfee70ab0db8b"
content-hash = "3660cb9106129dcf1a2fe157c1bdfe567923fc9ed408672bcbcb70f4345f1285"

@@ -12,7 +12,7 @@ langchain-server = "langchain.server:main"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.0.1"
langchain-core = "^0.0.2"
pydantic = ">=1,<3"
SQLAlchemy = ">=1.4,<3"
requests = "^2"

@@ -2,7 +2,6 @@ import os
import time
from typing import Optional
from cassandra.cluster import Cluster
from langchain_core.messages import AIMessage, HumanMessage
from langchain.memory import ConversationBufferMemory
@@ -16,6 +15,8 @@ def _chat_message_history(
drop: bool = True,
ttl_seconds: Optional[int] = None,
) -> CassandraChatMessageHistory:
from cassandra.cluster import Cluster
keyspace = "cmh_test_keyspace"
table_name = "cmh_test_table"
# get db connection

@@ -2,8 +2,6 @@
import time
from typing import List, Optional, Type
from cassandra.cluster import Cluster
from langchain.docstore.document import Document
from langchain.vectorstores import Cassandra
from tests.integration_tests.vectorstores.fake_embeddings import (
@@ -19,6 +17,8 @@ def _vectorstore_from_texts(
embedding_class: Type[Embeddings] = ConsistentFakeEmbeddings,
drop: bool = True,
) -> Cassandra:
from cassandra.cluster import Cluster
keyspace = "vector_test_keyspace"
table_name = "vector_test_table"
# get db connection
@@ -154,12 +154,3 @@ def test_cassandra_delete() -> None:
time.sleep(0.3)
output = docsearch.similarity_search("foo", k=10)
assert len(output) == 0
# if __name__ == "__main__":
# test_cassandra()
# test_cassandra_with_score()
# test_cassandra_max_marginal_relevance_search()
# test_cassandra_add_extra()
# test_cassandra_no_drop()
# test_cassandra_delete()

@@ -0,0 +1,20 @@
from langchain.schema.runnable.base import __all__
EXPECTED_ALL = [
"Runnable",
"RunnableBinding",
"RunnableBindingBase",
"RunnableEach",
"RunnableEachBase",
"RunnableGenerator",
"RunnableLambda",
"RunnableMap",
"RunnableParallel",
"RunnableSequence",
"RunnableSerializable",
"coerce_to_runnable",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,7 @@
from langchain.schema.runnable.branch import __all__
EXPECTED_ALL = ["RunnableBranch"]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,19 @@
from langchain.schema.runnable.config import __all__
EXPECTED_ALL = [
"EmptyDict",
"RunnableConfig",
"acall_func_with_variable_args",
"call_func_with_variable_args",
"ensure_config",
"get_async_callback_manager_for_config",
"get_callback_manager_for_config",
"get_config_list",
"get_executor_for_config",
"merge_configs",
"patch_config",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,13 @@
from langchain.schema.runnable.configurable import __all__
EXPECTED_ALL = [
"DynamicRunnable",
"RunnableConfigurableAlternatives",
"RunnableConfigurableFields",
"StrEnum",
"make_options_spec",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,7 @@
from langchain.schema.runnable.fallbacks import __all__
EXPECTED_ALL = ["RunnableWithFallbacks"]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,7 @@
from langchain.schema.runnable.history import __all__
EXPECTED_ALL = ["RunnableWithMessageHistory"]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,26 @@
from langchain.schema.runnable import __all__
EXPECTED_ALL = [
"ConfigurableField",
"ConfigurableFieldSingleOption",
"ConfigurableFieldMultiOption",
"patch_config",
"RouterInput",
"RouterRunnable",
"Runnable",
"RunnableSerializable",
"RunnableBinding",
"RunnableBranch",
"RunnableConfig",
"RunnableGenerator",
"RunnableLambda",
"RunnableMap",
"RunnableParallel",
"RunnablePassthrough",
"RunnableSequence",
"RunnableWithFallbacks",
]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,7 @@
from langchain.schema.runnable.passthrough import __all__
EXPECTED_ALL = ["RunnableAssign", "RunnablePassthrough", "aidentity", "identity"]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,7 @@
from langchain.schema.runnable.retry import __all__
EXPECTED_ALL = ["RunnableRetry"]
def test_all_imports() -> None:
assert set(__all__) == set(EXPECTED_ALL)

@@ -0,0 +1,7 @@
from langchain.schema.runnable.router import __all__

EXPECTED_ALL = ["RouterInput", "RouterRunnable"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,27 @@
from langchain.schema.runnable.utils import __all__

EXPECTED_ALL = [
    "AddableDict",
    "ConfigurableField",
    "ConfigurableFieldMultiOption",
    "ConfigurableFieldSingleOption",
    "ConfigurableFieldSpec",
    "GetLambdaSource",
    "IsFunctionArgDict",
    "IsLocalDict",
    "SupportsAdd",
    "aadd",
    "accepts_config",
    "accepts_run_manager",
    "add",
    "gated_coro",
    "gather_with_concurrency",
    "get_function_first_arg_dict_keys",
    "get_lambda_source",
    "get_unique_config_specs",
    "indent_lines_after_first",
]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.agent import __all__

EXPECTED_ALL = ["AgentAction", "AgentActionMessageLog", "AgentFinish"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.cache import __all__

EXPECTED_ALL = ["BaseCache"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.chat import __all__

EXPECTED_ALL = ["ChatSession"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.chat_history import __all__

EXPECTED_ALL = ["BaseChatMessageHistory"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.document import __all__

EXPECTED_ALL = ["BaseDocumentTransformer", "Document"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.embeddings import __all__

EXPECTED_ALL = ["Embeddings"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.exceptions import __all__

EXPECTED_ALL = ["LangChainException"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -18,6 +18,7 @@ EXPECTED_ALL = [
"messages_from_dict",
"messages_to_dict",
"message_to_dict",
"_message_to_dict",
"_message_from_dict",
"get_buffer_string",
"RunInfo",

View File

@ -0,0 +1,7 @@
from langchain.schema.language_model import __all__

EXPECTED_ALL = ["BaseLanguageModel", "_get_token_ids_default_method", "get_tokenizer"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.memory import __all__

EXPECTED_ALL = ["BaseMemory"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -1,101 +1,29 @@
import pytest
from langchain_core.messages import (
    AIMessageChunk,
    ChatMessageChunk,
    FunctionMessageChunk,
    HumanMessageChunk,
)

from langchain.schema.messages import __all__

EXPECTED_ALL = [
    "AIMessage",
    "AIMessageChunk",
    "BaseMessage",
    "BaseMessageChunk",
    "ChatMessage",
    "ChatMessageChunk",
    "FunctionMessage",
    "FunctionMessageChunk",
    "HumanMessage",
    "HumanMessageChunk",
    "SystemMessage",
    "SystemMessageChunk",
    "ToolMessage",
    "ToolMessageChunk",
    "_message_from_dict",
    "_message_to_dict",
    "message_to_dict",
    "get_buffer_string",
    "merge_content",
    "messages_from_dict",
    "messages_to_dict",
]


def test_message_chunks() -> None:
    assert AIMessageChunk(content="I am") + AIMessageChunk(
        content=" indeed."
    ) == AIMessageChunk(
        content="I am indeed."
    ), "MessageChunk + MessageChunk should be a MessageChunk"
    assert (
        AIMessageChunk(content="I am") + HumanMessageChunk(content=" indeed.")
        == AIMessageChunk(content="I am indeed.")
    ), "MessageChunk + MessageChunk should be a MessageChunk of the same class as the left side"  # noqa: E501
    assert (
        AIMessageChunk(content="", additional_kwargs={"foo": "bar"})
        + AIMessageChunk(content="", additional_kwargs={"baz": "foo"})
        == AIMessageChunk(content="", additional_kwargs={"foo": "bar", "baz": "foo"})
    ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs"  # noqa: E501
    assert (
        AIMessageChunk(
            content="", additional_kwargs={"function_call": {"name": "web_search"}}
        )
        + AIMessageChunk(
            content="", additional_kwargs={"function_call": {"arguments": "{\n"}}
        )
        + AIMessageChunk(
            content="",
            additional_kwargs={
                "function_call": {"arguments": ' "query": "turtles"\n}'}
            },
        )
        == AIMessageChunk(
            content="",
            additional_kwargs={
                "function_call": {
                    "name": "web_search",
                    "arguments": '{\n "query": "turtles"\n}',
                }
            },
        )
    ), "MessageChunk + MessageChunk should be a MessageChunk with merged additional_kwargs"  # noqa: E501


def test_chat_message_chunks() -> None:
    assert ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
        role="User", content=" indeed."
    ) == ChatMessageChunk(
        role="User", content="I am indeed."
    ), "ChatMessageChunk + ChatMessageChunk should be a ChatMessageChunk"
    with pytest.raises(ValueError):
        ChatMessageChunk(role="User", content="I am") + ChatMessageChunk(
            role="Assistant", content=" indeed."
        )
    assert (
        ChatMessageChunk(role="User", content="I am")
        + AIMessageChunk(content=" indeed.")
        == ChatMessageChunk(role="User", content="I am indeed.")
    ), "ChatMessageChunk + other MessageChunk should be a ChatMessageChunk with the left side's role"  # noqa: E501
    assert AIMessageChunk(content="I am") + ChatMessageChunk(
        role="User", content=" indeed."
    ) == AIMessageChunk(
        content="I am indeed."
    ), "Other MessageChunk + ChatMessageChunk should be a MessageChunk of the same class as the left side"  # noqa: E501


def test_function_message_chunks() -> None:
    assert FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
        name="hello", content=" indeed."
    ) == FunctionMessageChunk(
        name="hello", content="I am indeed."
    ), "FunctionMessageChunk + FunctionMessageChunk should be a FunctionMessageChunk"
    with pytest.raises(ValueError):
        FunctionMessageChunk(name="hello", content="I am") + FunctionMessageChunk(
            name="bye", content=" indeed."
        )


def test_ai_message_chunks() -> None:
    assert AIMessageChunk(example=True, content="I am") + AIMessageChunk(
        example=True, content=" indeed."
    ) == AIMessageChunk(
        example=True, content="I am indeed."
    ), "AIMessageChunk + AIMessageChunk should be an AIMessageChunk"
    with pytest.raises(ValueError):
        AIMessageChunk(example=True, content="I am") + AIMessageChunk(
            example=False, content=" indeed."
        )


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)
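The chunk tests above pin down the `+` semantics that streaming aggregation relies on: content concatenates and additional_kwargs merge. A minimal usage sketch built only on the behavior asserted above:

from langchain_core.messages import AIMessageChunk

# Accumulate a stream of chunks into one message chunk.
chunks = [
    AIMessageChunk(content="I am", additional_kwargs={"foo": "bar"}),
    AIMessageChunk(content=" indeed.", additional_kwargs={"baz": "qux"}),
]
aggregate = chunks[0]
for chunk in chunks[1:]:
    aggregate = aggregate + chunk

assert aggregate == AIMessageChunk(
    content="I am indeed.", additional_kwargs={"foo": "bar", "baz": "qux"}
)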

View File

@ -1,60 +1,15 @@
from langchain_core.messages import HumanMessageChunk
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk

from langchain.schema.output import __all__

EXPECTED_ALL = [
    "ChatGeneration",
    "ChatGenerationChunk",
    "ChatResult",
    "Generation",
    "GenerationChunk",
    "LLMResult",
    "RunInfo",
]


def test_generation_chunk() -> None:
    assert GenerationChunk(text="Hello, ") + GenerationChunk(
        text="world!"
    ) == GenerationChunk(
        text="Hello, world!"
    ), "GenerationChunk + GenerationChunk should be a GenerationChunk"
    assert (
        GenerationChunk(text="Hello, ")
        + GenerationChunk(text="world!", generation_info={"foo": "bar"})
        == GenerationChunk(text="Hello, world!", generation_info={"foo": "bar"})
    ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info"  # noqa: E501
    assert (
        GenerationChunk(text="Hello, ")
        + GenerationChunk(text="world!", generation_info={"foo": "bar"})
        + GenerationChunk(text="!", generation_info={"baz": "foo"})
        == GenerationChunk(
            text="Hello, world!!", generation_info={"foo": "bar", "baz": "foo"}
        )
    ), "GenerationChunk + GenerationChunk should be a GenerationChunk with merged generation_info"  # noqa: E501


def test_chat_generation_chunk() -> None:
    assert ChatGenerationChunk(
        message=HumanMessageChunk(content="Hello, ")
    ) + ChatGenerationChunk(
        message=HumanMessageChunk(content="world!")
    ) == ChatGenerationChunk(
        message=HumanMessageChunk(content="Hello, world!")
    ), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk"
    assert (
        ChatGenerationChunk(message=HumanMessageChunk(content="Hello, "))
        + ChatGenerationChunk(
            message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
        )
        == ChatGenerationChunk(
            message=HumanMessageChunk(content="Hello, world!"),
            generation_info={"foo": "bar"},
        )
    ), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk with merged generation_info"  # noqa: E501
    assert (
        ChatGenerationChunk(message=HumanMessageChunk(content="Hello, "))
        + ChatGenerationChunk(
            message=HumanMessageChunk(content="world!"), generation_info={"foo": "bar"}
        )
        + ChatGenerationChunk(
            message=HumanMessageChunk(content="!"), generation_info={"baz": "foo"}
        )
        == ChatGenerationChunk(
            message=HumanMessageChunk(content="Hello, world!!"),
            generation_info={"foo": "bar", "baz": "foo"},
        )
    ), "ChatGenerationChunk + ChatGenerationChunk should be a ChatGenerationChunk with merged generation_info"  # noqa: E501


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,16 @@
from langchain.schema.output_parser import __all__

EXPECTED_ALL = [
    "BaseCumulativeTransformOutputParser",
    "BaseGenerationOutputParser",
    "BaseLLMOutputParser",
    "BaseOutputParser",
    "BaseTransformOutputParser",
    "NoOpOutputParser",
    "OutputParserException",
    "StrOutputParser",
]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.prompt import __all__

EXPECTED_ALL = ["PromptValue"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.prompt_template import __all__

EXPECTED_ALL = ["BasePromptTemplate", "format_document"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.retriever import __all__

EXPECTED_ALL = ["BaseRetriever"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.storage import __all__

EXPECTED_ALL = ["BaseStore"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -0,0 +1,7 @@
from langchain.schema.vectorstore import __all__

EXPECTED_ALL = ["VectorStore", "VectorStoreRetriever"]


def test_all_imports() -> None:
    assert set(__all__) == set(EXPECTED_ALL)

View File

@ -21,7 +21,7 @@ from langchain_core.messages import (
    messages_to_dict,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, Generation
from langchain_core.prompts import ChatPromptValueConcrete, StringPromptValue
from langchain_core.prompt_values import ChatPromptValueConcrete, StringPromptValue
from langchain_core.pydantic_v1 import BaseModel, ValidationError
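This last hunk is the direct consequence of relocating the prompt value classes: they now import from langchain_core.prompt_values rather than langchain_core.prompts. A hedged usage sketch of the relocated class (treat the exact surface as an assumption if your version differs):

from langchain_core.prompt_values import StringPromptValue

# PromptValue implementations now live in langchain_core.prompt_values,
# which both prompt templates and output parsers can import without a cycle.
value = StringPromptValue(text="hello")
assert value.to_string() == "hello"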