core[patch], community[patch]: mark runnable context, lc load as beta (#15603)

This commit is contained in:
Bagatur
2024-01-05 17:54:26 -05:00
committed by GitHub
parent 75281af822
commit a7d023aaf0
11 changed files with 70 additions and 16 deletions

View File

@@ -56,7 +56,7 @@ from langchain_core.caches import RETURN_VAL_TYPE, BaseCache
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.llms import LLM, get_prompts
from langchain_core.load.dump import dumps
from langchain_core.load.load import loads
from langchain_core.load.load import _loads_suppress_warning
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.utils import get_from_env
@@ -149,7 +149,10 @@ def _loads_generations(generations_str: str) -> Union[RETURN_VAL_TYPE, None]:
RETURN_VAL_TYPE: A list of generations.
"""
try:
generations = [loads(_item_str) for _item_str in json.loads(generations_str)]
generations = [
_loads_suppress_warning(_item_str)
for _item_str in json.loads(generations_str)
]
return generations
except (json.JSONDecodeError, TypeError):
# deferring the (soft) handling to after the legacy-format attempt
@@ -224,7 +227,7 @@ class SQLAlchemyCache(BaseCache):
rows = session.execute(stmt).fetchall()
if rows:
try:
return [loads(row[0]) for row in rows]
return [_loads_suppress_warning(row[0]) for row in rows]
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
@@ -395,7 +398,7 @@ class RedisCache(BaseCache):
if results:
for _, text in results.items():
try:
generations.append(loads(text))
generations.append(_loads_suppress_warning(text))
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
@@ -535,7 +538,9 @@ class RedisSemanticCache(BaseCache):
if results:
for document in results:
try:
generations.extend(loads(document.metadata["return_val"]))
generations.extend(
_loads_suppress_warning(document.metadata["return_val"])
)
except Exception:
logger.warning(
"Retrieving a cache value that could not be deserialized "
@@ -1185,7 +1190,7 @@ class SQLAlchemyMd5Cache(BaseCache):
"""Look up based on prompt and llm_string."""
rows = self._search_rows(prompt, llm_string)
if rows:
return [loads(row[0]) for row in rows]
return [_loads_suppress_warning(row[0]) for row in rows]
return None
def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> None: