mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 23:29:21 +00:00
add load()
deserializer function that bypasses need for json serialization (#7626)
There is already a `loads()` function which takes a JSON string and loads it using the Reviver But in the callbacks system, there is a `serialized` object that is passed in and that object is already a deserialized JSON-compatible object. This allows you to call `load(serialized)` and bypass intermediate JSON encoding. I found one other place in the code that benefited from this short-circuiting (string_run_evaluator.py) so I fixed that too. Tagging @baskaryan for general/utility stuff. <!-- Thank you for contributing to LangChain! Replace this comment with: - Description: a description of the change, - Issue: the issue # it fixes (if applicable), - Dependencies: any dependencies required for this change, - Tag maintainer: for a quicker response, tag the relevant maintainer (see below), - Twitter handle: we announce bigger features on Twitter. If your PR gets announced and you'd like a mention, we'll gladly shout you out! If you're adding a new integration, please include: 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. Maintainer responsibilities: - General / Misc / if you don't know who to tag: @baskaryan - DataLoaders / VectorStores / Retrievers: @rlancemartin, @eyurtsev - Models / Prompts: @hwchase17, @baskaryan - Memory: @hwchase17 - Agents / Tools / Toolkits: @hinthornw - Tracing / Callbacks: @agola11 - Async: @agola11 If no one reviews your PR within a few days, feel free to @-mention the same people again. See contribution guidelines for more information on how to write/run tests, lint, etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> --------- Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
parent
6aee589eec
commit
f0b0c72d98
@ -55,7 +55,7 @@ class Reviver:
|
|||||||
raise ValueError(f"Invalid namespace: {value}")
|
raise ValueError(f"Invalid namespace: {value}")
|
||||||
|
|
||||||
# The root namespace "langchain" is not a valid identifier.
|
# The root namespace "langchain" is not a valid identifier.
|
||||||
if len(namespace) == 1:
|
if len(namespace) == 1 and namespace[0] == "langchain":
|
||||||
raise ValueError(f"Invalid namespace: {value}")
|
raise ValueError(f"Invalid namespace: {value}")
|
||||||
|
|
||||||
mod = importlib.import_module(".".join(namespace))
|
mod = importlib.import_module(".".join(namespace))
|
||||||
@ -79,7 +79,8 @@ def loads(
|
|||||||
secrets_map: Optional[Dict[str, str]] = None,
|
secrets_map: Optional[Dict[str, str]] = None,
|
||||||
valid_namespaces: Optional[List[str]] = None,
|
valid_namespaces: Optional[List[str]] = None,
|
||||||
) -> Any:
|
) -> Any:
|
||||||
"""Load a JSON object from a string.
|
"""Revive a LangChain class from a JSON string.
|
||||||
|
Equivalent to `load(json.loads(text))`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
text: The string to load.
|
text: The string to load.
|
||||||
@ -88,6 +89,38 @@ def loads(
|
|||||||
to allow to be deserialized.
|
to allow to be deserialized.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
Revived LangChain objects.
|
||||||
"""
|
"""
|
||||||
return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))
|
return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces))
|
||||||
|
|
||||||
|
|
||||||
|
def load(
|
||||||
|
obj: Any,
|
||||||
|
*,
|
||||||
|
secrets_map: Optional[Dict[str, str]] = None,
|
||||||
|
valid_namespaces: Optional[List[str]] = None,
|
||||||
|
) -> Any:
|
||||||
|
"""Revive a LangChain class from a JSON object. Use this if you already
|
||||||
|
have a parsed JSON object, eg. from `json.load` or `orjson.loads`.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
obj: The object to load.
|
||||||
|
secrets_map: A map of secrets to load.
|
||||||
|
valid_namespaces: A list of additional namespaces (modules)
|
||||||
|
to allow to be deserialized.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Revived LangChain objects.
|
||||||
|
"""
|
||||||
|
reviver = Reviver(secrets_map, valid_namespaces)
|
||||||
|
|
||||||
|
def _load(obj: Any) -> Any:
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
# Need to revive leaf nodes before reviving this node
|
||||||
|
loaded_obj = {k: _load(v) for k, v in obj.items()}
|
||||||
|
return reviver(loaded_obj)
|
||||||
|
if isinstance(obj, list):
|
||||||
|
return [_load(o) for o in obj]
|
||||||
|
return obj
|
||||||
|
|
||||||
|
return _load(obj)
|
||||||
|
@ -13,8 +13,8 @@ from langchain.callbacks.manager import (
|
|||||||
)
|
)
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.evaluation.schema import StringEvaluator
|
from langchain.evaluation.schema import StringEvaluator
|
||||||
from langchain.load.dump import dumps
|
from langchain.load.dump import dumpd
|
||||||
from langchain.load.load import loads
|
from langchain.load.load import load
|
||||||
from langchain.load.serializable import Serializable
|
from langchain.load.serializable import Serializable
|
||||||
from langchain.schema import RUN_KEY, messages_from_dict
|
from langchain.schema import RUN_KEY, messages_from_dict
|
||||||
from langchain.schema.messages import BaseMessage, get_buffer_string
|
from langchain.schema.messages import BaseMessage, get_buffer_string
|
||||||
@ -25,7 +25,7 @@ def _get_messages_from_run_dict(messages: List[dict]) -> List[BaseMessage]:
|
|||||||
return []
|
return []
|
||||||
first_message = messages[0]
|
first_message = messages[0]
|
||||||
if "lc" in first_message:
|
if "lc" in first_message:
|
||||||
return [loads(dumps(message)) for message in messages]
|
return [load(dumpd(message)) for message in messages]
|
||||||
else:
|
else:
|
||||||
return messages_from_dict(messages)
|
return messages_from_dict(messages)
|
||||||
|
|
||||||
|
@ -4,8 +4,8 @@ import pytest
|
|||||||
|
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.llms.openai import OpenAI
|
from langchain.llms.openai import OpenAI
|
||||||
from langchain.load.dump import dumps
|
from langchain.load.dump import dumpd, dumps
|
||||||
from langchain.load.load import loads
|
from langchain.load.load import load, loads
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
|
||||||
|
|
||||||
@ -14,7 +14,7 @@ class NotSerializable:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("openai")
|
@pytest.mark.requires("openai")
|
||||||
def test_load_openai_llm() -> None:
|
def test_loads_openai_llm() -> None:
|
||||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||||
llm_string = dumps(llm)
|
llm_string = dumps(llm)
|
||||||
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||||
@ -25,7 +25,7 @@ def test_load_openai_llm() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("openai")
|
@pytest.mark.requires("openai")
|
||||||
def test_load_llmchain() -> None:
|
def test_loads_llmchain() -> None:
|
||||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||||
prompt = PromptTemplate.from_template("hello {name}!")
|
prompt = PromptTemplate.from_template("hello {name}!")
|
||||||
chain = LLMChain(llm=llm, prompt=prompt)
|
chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
@ -40,7 +40,7 @@ def test_load_llmchain() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("openai")
|
@pytest.mark.requires("openai")
|
||||||
def test_load_llmchain_env() -> None:
|
def test_loads_llmchain_env() -> None:
|
||||||
import os
|
import os
|
||||||
|
|
||||||
has_env = "OPENAI_API_KEY" in os.environ
|
has_env = "OPENAI_API_KEY" in os.environ
|
||||||
@ -64,7 +64,7 @@ def test_load_llmchain_env() -> None:
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.requires("openai")
|
@pytest.mark.requires("openai")
|
||||||
def test_load_llmchain_with_non_serializable_arg() -> None:
|
def test_loads_llmchain_with_non_serializable_arg() -> None:
|
||||||
llm = OpenAI(
|
llm = OpenAI(
|
||||||
model="davinci",
|
model="davinci",
|
||||||
temperature=0.5,
|
temperature=0.5,
|
||||||
@ -76,3 +76,68 @@ def test_load_llmchain_with_non_serializable_arg() -> None:
|
|||||||
chain_string = dumps(chain, pretty=True)
|
chain_string = dumps(chain, pretty=True)
|
||||||
with pytest.raises(NotImplementedError):
|
with pytest.raises(NotImplementedError):
|
||||||
loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_load_openai_llm() -> None:
|
||||||
|
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||||
|
llm_obj = dumpd(llm)
|
||||||
|
llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||||
|
|
||||||
|
assert llm2 == llm
|
||||||
|
assert dumpd(llm2) == llm_obj
|
||||||
|
assert isinstance(llm2, OpenAI)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_load_llmchain() -> None:
|
||||||
|
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
|
||||||
|
prompt = PromptTemplate.from_template("hello {name}!")
|
||||||
|
chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
chain_obj = dumpd(chain)
|
||||||
|
chain2 = load(chain_obj, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||||
|
|
||||||
|
assert chain2 == chain
|
||||||
|
assert dumpd(chain2) == chain_obj
|
||||||
|
assert isinstance(chain2, LLMChain)
|
||||||
|
assert isinstance(chain2.llm, OpenAI)
|
||||||
|
assert isinstance(chain2.prompt, PromptTemplate)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_load_llmchain_env() -> None:
|
||||||
|
import os
|
||||||
|
|
||||||
|
has_env = "OPENAI_API_KEY" in os.environ
|
||||||
|
if not has_env:
|
||||||
|
os.environ["OPENAI_API_KEY"] = "env_variable"
|
||||||
|
|
||||||
|
llm = OpenAI(model="davinci", temperature=0.5)
|
||||||
|
prompt = PromptTemplate.from_template("hello {name}!")
|
||||||
|
chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
chain_obj = dumpd(chain)
|
||||||
|
chain2 = load(chain_obj)
|
||||||
|
|
||||||
|
assert chain2 == chain
|
||||||
|
assert dumpd(chain2) == chain_obj
|
||||||
|
assert isinstance(chain2, LLMChain)
|
||||||
|
assert isinstance(chain2.llm, OpenAI)
|
||||||
|
assert isinstance(chain2.prompt, PromptTemplate)
|
||||||
|
|
||||||
|
if not has_env:
|
||||||
|
del os.environ["OPENAI_API_KEY"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.requires("openai")
|
||||||
|
def test_load_llmchain_with_non_serializable_arg() -> None:
|
||||||
|
llm = OpenAI(
|
||||||
|
model="davinci",
|
||||||
|
temperature=0.5,
|
||||||
|
openai_api_key="hello",
|
||||||
|
client=NotSerializable,
|
||||||
|
)
|
||||||
|
prompt = PromptTemplate.from_template("hello {name}!")
|
||||||
|
chain = LLMChain(llm=llm, prompt=prompt)
|
||||||
|
chain_obj = dumpd(chain)
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
load(chain_obj, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||||
|
Loading…
Reference in New Issue
Block a user