core[patch], community[patch]: mark runnable context, lc load as beta (#15603)

This commit is contained in:
Bagatur 2024-01-05 17:54:26 -05:00 committed by GitHub
parent 75281af822
commit a7d023aaf0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 70 additions and 16 deletions

View File

@ -56,7 +56,7 @@ from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import LLM, get_prompts from langchain_core.language_models.llms import LLM, get_prompts
from langchain_core.load.dump import dumps from langchain_core.load.dump import dumps
from langchain_core.load.load import loads from langchain_core.load.load import _loads_suppress_warning
from langchain_core.outputs import ChatGeneration, Generation from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils import get_from_env from langchain_core.utils import get_from_env
@ -149,7 +149,10 @@ def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
RETURN_VAL_TYPE: A list of generations. RETURN_VAL_TYPE: A list of generations.
""" """
try: try:
generations = [loads(_item_str) for _item_str in json.loads(generations_str)] generations = [
_loads_suppress_warning(_item_str)
for _item_str in json.loads(generations_str)
]
return generations return generations
except (json.JSONDecodeError, TypeError): except (json.JSONDecodeError, TypeError):
# deferring the (soft) handling to after the legacy-format attempt # deferring the (soft) handling to after the legacy-format attempt
@ -224,7 +227,7 @@ class SQLAlchemyCache(BaseCache):
rows = session.execute(stmt).fetchall() rows = session.execute(stmt).fetchall()
if rows: if rows:
try: try:
return [loads(row[0]) for row in rows] return [_loads_suppress_warning(row[0]) for row in rows]
except Exception: except Exception:
logger.warning( logger.warning(
"Retrieving a cache value that could not be deserialized " "Retrieving a cache value that could not be deserialized "
@ -395,7 +398,7 @@ class RedisCache(BaseCache):
if results: if results:
for _, text in results.items(): for _, text in results.items():
try: try:
generations.append(loads(text)) generations.append(_loads_suppress_warning(text))
except Exception: except Exception:
logger.warning( logger.warning(
"Retrieving a cache value that could not be deserialized " "Retrieving a cache value that could not be deserialized "
@ -535,7 +538,9 @@ class RedisSemanticCache(BaseCache):
if results: if results:
for document in results: for document in results:
try: try:
generations.extend(loads(document.metadata["return_val"])) generations.extend(
_loads_suppress_warning(document.metadata["return_val"])
)
except Exception: except Exception:
logger.warning( logger.warning(
"Retrieving a cache value that could not be deserialized " "Retrieving a cache value that could not be deserialized "
@ -1185,7 +1190,7 @@ class SQLAlchemyMd5Cache(BaseCache):
"""Look up based on prompt and llm_string.""" """Look up based on prompt and llm_string."""
rows = self._search_rows(prompt, llm_string) rows = self._search_rows(prompt, llm_string)
if rows: if rows:
return [loads(row[0]) for row in rows] return [_loads_suppress_warning(row[0]) for row in rows]
return None return None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None:

View File

@ -4,7 +4,7 @@ import logging
from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union, cast from typing import TYPE_CHECKING, Dict, Iterable, Iterator, List, Optional, Union, cast
from langchain_core.chat_sessions import ChatSession from langchain_core.chat_sessions import ChatSession
from langchain_core.load import load from langchain_core.load.load import _load_suppress_warning
from langchain_community.chat_loaders.base import BaseChatLoader from langchain_community.chat_loaders.base import BaseChatLoader
@ -66,8 +66,10 @@ class LangSmithRunChatLoader(BaseChatLoader):
raise ValueError(f"Run has no 'messages' inputs. Got {llm_run.inputs}") raise ValueError(f"Run has no 'messages' inputs. Got {llm_run.inputs}")
if not llm_run.outputs: if not llm_run.outputs:
raise ValueError("Cannot convert pending run") raise ValueError("Cannot convert pending run")
messages = load(llm_run.inputs)["messages"] messages = _load_suppress_warning(llm_run.inputs)["messages"]
message_chunk = load(llm_run.outputs)["generations"][0]["message"] message_chunk = _load_suppress_warning(llm_run.outputs)["generations"][0][
"message"
]
return ChatSession(messages=messages + [message_chunk]) return ChatSession(messages=messages + [message_chunk])
@staticmethod @staticmethod

View File

@ -1,7 +1,15 @@
from importlib import metadata from importlib import metadata
from langchain_core._api import (
surface_langchain_beta_warnings,
surface_langchain_deprecation_warnings,
)
try: try:
__version__ = metadata.version(__package__) __version__ = metadata.version(__package__)
except metadata.PackageNotFoundError: except metadata.PackageNotFoundError:
# Case where package metadata is not available. # Case where package metadata is not available.
__version__ = "" __version__ = ""
surface_langchain_deprecation_warnings()
surface_langchain_beta_warnings()

View File

@ -8,7 +8,12 @@ This module is only relevant for LangChain developers, not for users.
in your own code. We may change the API at any time with no warning. in your own code. We may change the API at any time with no warning.
""" """
from .beta_decorator import (
LangChainBetaWarning,
beta,
suppress_langchain_beta_warning,
surface_langchain_beta_warnings,
)
from .deprecation import ( from .deprecation import (
LangChainDeprecationWarning, LangChainDeprecationWarning,
deprecated, deprecated,
@ -20,9 +25,13 @@ from .path import as_import_path, get_relative_path
__all__ = [ __all__ = [
"as_import_path", "as_import_path",
"beta",
"deprecated", "deprecated",
"get_relative_path", "get_relative_path",
"LangChainBetaWarning",
"LangChainDeprecationWarning", "LangChainDeprecationWarning",
"suppress_langchain_beta_warning",
"surface_langchain_beta_warnings",
"suppress_langchain_deprecation_warning", "suppress_langchain_deprecation_warning",
"surface_langchain_deprecation_warnings", "surface_langchain_deprecation_warnings",
"warn_deprecated", "warn_deprecated",

View File

@ -18,6 +18,7 @@ from typing import (
Union, Union,
) )
from langchain_core._api.beta_decorator import beta
from langchain_core.runnables.base import ( from langchain_core.runnables.base import (
Runnable, Runnable,
RunnableSerializable, RunnableSerializable,
@ -156,6 +157,7 @@ def config_with_context(
return _config_with_context(config, steps, _setter, _getter, threading.Event) return _config_with_context(config, steps, _setter, _getter, threading.Event)
@beta()
class ContextGet(RunnableSerializable): class ContextGet(RunnableSerializable):
"""Get a context value.""" """Get a context value."""
@ -219,6 +221,7 @@ def _coerce_set_value(value: SetValue) -> Runnable[Input, Output]:
return coerce_to_runnable(value) return coerce_to_runnable(value)
@beta()
class ContextSet(RunnableSerializable): class ContextSet(RunnableSerializable):
"""Set a context value.""" """Set a context value."""

View File

@ -3,6 +3,7 @@ import json
import os import os
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from langchain_core._api import beta, suppress_langchain_beta_warning
from langchain_core.load.mapping import ( from langchain_core.load.mapping import (
OLD_PROMPT_TEMPLATE_FORMATS, OLD_PROMPT_TEMPLATE_FORMATS,
SERIALIZABLE_MAPPING, SERIALIZABLE_MAPPING,
@ -102,6 +103,7 @@ class Reviver:
return value return value
@beta()
def loads( def loads(
text: str, text: str,
*, *,
@ -123,6 +125,17 @@ def loads(
return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces)) return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))
def _loads_suppress_warning(
text: str,
*,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
) -> Any:
with suppress_langchain_beta_warning():
return loads(text, secrets_map=secrets_map, valid_namespaces=valid_namespaces)
@beta()
def load( def load(
obj: Any, obj: Any,
*, *,
@ -153,3 +166,13 @@ def load(
return obj return obj
return _load(obj) return _load(obj)
def _load_suppress_warning(
obj: Any,
*,
secrets_map: Optional[Dict[str, str]] = None,
valid_namespaces: Optional[List[str]] = None,
) -> Any:
with suppress_langchain_beta_warning():
return load(obj, secrets_map=secrets_map, valid_namespaces=valid_namespaces)

View File

@ -14,7 +14,7 @@ from typing import (
) )
from langchain_core.chat_history import BaseChatMessageHistory from langchain_core.chat_history import BaseChatMessageHistory
from langchain_core.load import load from langchain_core.load.load import _load_suppress_warning
from langchain_core.pydantic_v1 import BaseModel, create_model from langchain_core.pydantic_v1 import BaseModel, create_model
from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda from langchain_core.runnables.base import Runnable, RunnableBindingBase, RunnableLambda
from langchain_core.runnables.config import run_in_executor from langchain_core.runnables.config import run_in_executor
@ -337,7 +337,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
hist = config["configurable"]["message_history"] hist = config["configurable"]["message_history"]
# Get the input messages # Get the input messages
inputs = load(run.inputs) inputs = _load_suppress_warning(run.inputs)
input_val = inputs[self.input_messages_key or "input"] input_val = inputs[self.input_messages_key or "input"]
input_messages = self._get_input_messages(input_val) input_messages = self._get_input_messages(input_val)
@ -348,7 +348,7 @@ class RunnableWithMessageHistory(RunnableBindingBase):
input_messages = input_messages[len(historic_messages) :] input_messages = input_messages[len(historic_messages) :]
# Get the output messages # Get the output messages
output_val = load(run.outputs) output_val = _load_suppress_warning(run.outputs)
output_messages = self._get_output_messages(output_val) output_messages = self._get_output_messages(output_val)
for m in input_messages + output_messages: for m in input_messages + output_messages:

View File

@ -19,7 +19,7 @@ from uuid import UUID
import jsonpatch # type: ignore[import] import jsonpatch # type: ignore[import]
from anyio import create_memory_object_stream from anyio import create_memory_object_stream
from langchain_core.load import load from langchain_core.load.load import _load_suppress_warning
from langchain_core.outputs import ChatGenerationChunk, GenerationChunk from langchain_core.outputs import ChatGenerationChunk, GenerationChunk
from langchain_core.tracers.base import BaseTracer from langchain_core.tracers.base import BaseTracer
from langchain_core.tracers.schemas import Run from langchain_core.tracers.schemas import Run
@ -267,7 +267,7 @@ class LogStreamCallbackHandler(BaseTracer):
"op": "add", "op": "add",
"path": f"/logs/{index}/final_output", "path": f"/logs/{index}/final_output",
# to undo the dumpd done by some runnables / tracer / etc # to undo the dumpd done by some runnables / tracer / etc
"value": load(run.outputs), "value": _load_suppress_warning(run.outputs),
}, },
{ {
"op": "add", "op": "add",

View File

@ -3,7 +3,7 @@ from typing import Any, Dict
import pytest import pytest
from langchain_core._api.beta import beta, warn_beta from langchain_core._api.beta_decorator import beta, warn_beta
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel

View File

@ -1,8 +1,12 @@
from langchain_core._api import __all__ from langchain_core._api import __all__
EXPECTED_ALL = [ EXPECTED_ALL = [
"beta",
"deprecated", "deprecated",
"LangChainBetaWarning",
"LangChainDeprecationWarning", "LangChainDeprecationWarning",
"suppress_langchain_beta_warning",
"surface_langchain_beta_warnings",
"suppress_langchain_deprecation_warning", "suppress_langchain_deprecation_warning",
"surface_langchain_deprecation_warnings", "surface_langchain_deprecation_warnings",
"warn_deprecated", "warn_deprecated",