mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
add load()
deserializer function that bypasses need for json serialization (#7626)
There is already a `loads()` function which takes a JSON string and loads it using the Reviver But in the callbacks system, there is a `serialized` object that is passed in and that object is already a deserialized JSON-compatible object. This allows you to call `load(serialized)` and bypass intermediate JSON encoding. I found one other place in the code that benefited from this short-circuiting (string_run_evaluator.py) so I fixed that too. Tagging @baskaryan for general/utility stuff. <!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> --------- Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
6aee589eec
commit
f0b0c72d98
@ -55,7 +55,7 @@ class Reviver:
|
||||
raise ValueError(f"Invalid namespace: {value}")
|
||||
|
||||
# The root namespace "langchain" is not a valid identifier.
|
||||
if len(namespace) == 1:
|
||||
if len(namespace) == 1 and namespace[0] == "langchain":
|
||||
raise ValueError(f"Invalid namespace: {value}")
|
||||
|
||||
mod = importlib.import_module(".".join(namespace))
|
||||
@ -79,7 +79,8 @@ def loads(
|
||||
secrets_map: Optional[Dict[str, str]] = None,
|
||||
valid_namespaces: Optional[List[str]] = None,
|
||||
) -> Any:
|
||||
"""Load a JSON object from a string.
|
||||
"""Revive a LangChain class from a JSON string.
|
||||
Equivalent to `load(json.loads(text))`.
|
||||
|
||||
Args:
|
||||
text: The string to load.
|
||||
@ -88,6 +89,38 @@ def loads(
|
||||
to allow to be deserialized.
|
||||
|
||||
Returns:
|
||||
|
||||
Revived LangChain objects.
|
||||
"""
|
||||
return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))
|
||||
|
||||
|
||||
def load(
|
||||
obj: Any,
|
||||
*,
|
||||
secrets_map: Optional[Dict[str, str]] = None,
|
||||
valid_namespaces: Optional[List[str]] = None,
|
||||
) -> Any:
|
||||
"""Revive a LangChain class from a JSON object. Use this if you already
|
||||
have a parsed JSON object, eg. from `json.load` or `orjson.loads`.
|
||||
|
||||
Args:
|
||||
obj: The object to load.
|
||||
secrets_map: A map of secrets to load.
|
||||
valid_namespaces: A list of additional namespaces (modules)
|
||||
to allow to be deserialized.
|
||||
|
||||
Returns:
|
||||
Revived LangChain objects.
|
||||
"""
|
||||
reviver = Reviver(secrets_map, valid_namespaces)
|
||||
|
||||
def _load(obj: Any) -> Any:
|
||||
if isinstance(obj, dict):
|
||||
# Need to revive leaf nodes before reviving this node
|
||||
loaded_obj = {k: _load(v) for k, v in obj.items()}
|
||||
return reviver(loaded_obj)
|
||||
if isinstance(obj, list):
|
||||
return [_load(o) for o in obj]
|
||||
return obj
|
||||
|
||||
return _load(obj)
|
||||
|
@ -13,8 +13,8 @@ from langchain.callbacks.manager import (
|
||||
)
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.evaluation.schema import StringEvaluator
|
||||
from langchain.load.dump import dumps
|
||||
from langchain.load.load import loads
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.load.load import load
|
||||
from langchain.load.serializable import Serializable
|
||||
from langchain.schema import RUN_KEY, messages_from_dict
|
||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||
@ -25,7 +25,7 @@ def _get_messages_from_run_dict(messages: List[dict]) -> List[BaseMessage]:
|
||||
return []
|
||||
first_message = messages[0]
|
||||
if "lc" in first_message:
|
||||
return [loads(dumps(message)) for message in messages]
|
||||
return [load(dumpd(message)) for message in messages]
|
||||
else:
|
||||
return messages_from_dict(messages)
|
||||
|
||||
|
@ -4,8 +4,8 @@ import pytest
|
||||
|
||||
from langchain.chains.llm import LLMChain
|
||||
from langchain.llms.openai import OpenAI
|
||||
from langchain.load.dump import dumps
|
||||
from langchain.load.load import loads
|
||||
from langchain.load.dump import dumpd, dumps
|
||||
from langchain.load.load import load, loads
|
||||
from langchain.prompts.prompt import PromptTemplate
|
||||
|
||||
|
||||
@ -14,7 +14,7 @@ class NotSerializable:
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_openai_llm() -> None:
|
||||
def test_loads_openai_llm() -> None:
|
||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||
llm_string = dumps(llm)
|
||||
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
@ -25,7 +25,7 @@ def test_load_openai_llm() -> None:
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_llmchain() -> None:
|
||||
def test_loads_llmchain() -> None:
|
||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||
prompt = PromptTemplate.from_template("hello {name}!")
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
@ -40,7 +40,7 @@ def test_load_llmchain() -> None:
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_llmchain_env() -> None:
|
||||
def test_loads_llmchain_env() -> None:
|
||||
import os
|
||||
|
||||
has_env = "OPENAI_API_KEY" in os.environ
|
||||
@ -64,7 +64,7 @@ def test_load_llmchain_env() -> None:
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_llmchain_with_non_serializable_arg() -> None:
|
||||
def test_loads_llmchain_with_non_serializable_arg() -> None:
|
||||
llm = OpenAI(
|
||||
model="davinci",
|
||||
temperature=0.5,
|
||||
@ -76,3 +76,68 @@ def test_load_llmchain_with_non_serializable_arg() -> None:
|
||||
chain_string = dumps(chain, pretty=True)
|
||||
with pytest.raises(NotImplementedError):
|
||||
loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_openai_llm() -> None:
|
||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||
llm_obj = dumpd(llm)
|
||||
llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
|
||||
assert llm2 == llm
|
||||
assert dumpd(llm2) == llm_obj
|
||||
assert isinstance(llm2, OpenAI)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_llmchain() -> None:
|
||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||
prompt = PromptTemplate.from_template("hello {name}!")
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
chain_obj = dumpd(chain)
|
||||
chain2 = load(chain_obj, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
|
||||
assert chain2 == chain
|
||||
assert dumpd(chain2) == chain_obj
|
||||
assert isinstance(chain2, LLMChain)
|
||||
assert isinstance(chain2.llm, OpenAI)
|
||||
assert isinstance(chain2.prompt, PromptTemplate)
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_llmchain_env() -> None:
|
||||
import os
|
||||
|
||||
has_env = "OPENAI_API_KEY" in os.environ
|
||||
if not has_env:
|
||||
os.environ["OPENAI_API_KEY"] = "env_variable"
|
||||
|
||||
llm = OpenAI(model="davinci", temperature=0.5)
|
||||
prompt = PromptTemplate.from_template("hello {name}!")
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
chain_obj = dumpd(chain)
|
||||
chain2 = load(chain_obj)
|
||||
|
||||
assert chain2 == chain
|
||||
assert dumpd(chain2) == chain_obj
|
||||
assert isinstance(chain2, LLMChain)
|
||||
assert isinstance(chain2.llm, OpenAI)
|
||||
assert isinstance(chain2.prompt, PromptTemplate)
|
||||
|
||||
if not has_env:
|
||||
del os.environ["OPENAI_API_KEY"]
|
||||
|
||||
|
||||
@pytest.mark.requires("openai")
|
||||
def test_load_llmchain_with_non_serializable_arg() -> None:
|
||||
llm = OpenAI(
|
||||
model="davinci",
|
||||
temperature=0.5,
|
||||
openai_api_key="hello",
|
||||
client=NotSerializable,
|
||||
)
|
||||
prompt = PromptTemplate.from_template("hello {name}!")
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
chain_obj = dumpd(chain)
|
||||
with pytest.raises(NotImplementedError):
|
||||
load(chain_obj, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
|
Loading…
Reference in New Issue
Block a user