From f0b0c72d98d5cf4511339b5d5d1e4fb7559cc9ab Mon Sep 17 00:00:00 2001 From: Alec Flett Date: Fri, 4 Aug 2023 01:49:41 -0700 Subject: [PATCH] add `load()` deserializer function that bypasses need for json serialization (#7626) There is already a `loads()` function which takes a JSON string and loads it using the Reviver But in the callbacks system, there is a `serialized` object that is passed in and that object is already a deserialized JSON-compatible object. This allows you to call `load(serialized)` and bypass intermediate JSON encoding. I found one other place in the code that benefited from this short-circuiting (string_run_evaluator.py) so I fixed that too. Tagging @baskaryan for general/utility stuff. --------- Co-authored-by: Nuno Campos --- libs/langchain/langchain/load/load.py | 39 +++++++++- .../smith/evaluation/string_run_evaluator.py | 6 +- .../tests/unit_tests/load/test_load.py | 77 +++++++++++++++++-- 3 files changed, 110 insertions(+), 12 deletions(-) diff --git a/libs/langchain/langchain/load/load.py b/libs/langchain/langchain/load/load.py index fe3653d5503..5d8b7ccd33e 100644 --- a/libs/langchain/langchain/load/load.py +++ b/libs/langchain/langchain/load/load.py @@ -55,7 +55,7 @@ class Reviver: raise ValueError(f"Invalid namespace: {value}") # The root namespace "langchain" is not a valid identifier. - if len(namespace) == 1: + if len(namespace) == 1 and namespace[0] == "langchain": raise ValueError(f"Invalid namespace: {value}") mod = importlib.import_module(".".join(namespace)) @@ -79,7 +79,8 @@ def loads( secrets_map: Optional[Dict[str, str]] = None, valid_namespaces: Optional[List[str]] = None, ) -> Any: - """Load a JSON object from a string. + """Revive a LangChain class from a JSON string. + Equivalent to `load(json.loads(text))`. Args: text: The string to load. @@ -88,6 +89,38 @@ def loads( to allow to be deserialized. Returns: - + Revived LangChain objects. """ return json.loads(text, object_hook=Reviver(secrets_map, valid_namespaces)) + + +def load( + obj: Any, + *, + secrets_map: Optional[Dict[str, str]] = None, + valid_namespaces: Optional[List[str]] = None, +) -> Any: + """Revive a LangChain class from a JSON object. Use this if you already + have a parsed JSON object, eg. from `json.load` or `orjson.loads`. + + Args: + obj: The object to load. + secrets_map: A map of secrets to load. + valid_namespaces: A list of additional namespaces (modules) + to allow to be deserialized. + + Returns: + Revived LangChain objects. 
+ """ + reviver = Reviver(secrets_map, valid_namespaces) + + def _load(obj: Any) -> Any: + if isinstance(obj, dict): + # Need to revive leaf nodes before reviving this node + loaded_obj = {k: _load(v) for k, v in obj.items()} + return reviver(loaded_obj) + if isinstance(obj, list): + return [_load(o) for o in obj] + return obj + + return _load(obj) diff --git a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py index 41fcbe8e3fa..4016482c191 100644 --- a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py +++ b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py @@ -13,8 +13,8 @@ from langchain.callbacks.manager import ( ) from langchain.chains.base import Chain from langchain.evaluation.schema import StringEvaluator -from langchain.load.dump import dumps -from langchain.load.load import loads +from langchain.load.dump import dumpd +from langchain.load.load import load from langchain.load.serializable import Serializable from langchain.schema import RUN_KEY, messages_from_dict from langchain.schema.messages import BaseMessage, get_buffer_string @@ -25,7 +25,7 @@ def _get_messages_from_run_dict(messages: List[dict]) -> List[BaseMessage]: return [] first_message = messages[0] if "lc" in first_message: - return [loads(dumps(message)) for message in messages] + return [load(dumpd(message)) for message in messages] else: return messages_from_dict(messages) diff --git a/libs/langchain/tests/unit_tests/load/test_load.py b/libs/langchain/tests/unit_tests/load/test_load.py index a4713106f37..38310b15bfc 100644 --- a/libs/langchain/tests/unit_tests/load/test_load.py +++ b/libs/langchain/tests/unit_tests/load/test_load.py @@ -4,8 +4,8 @@ import pytest from langchain.chains.llm import LLMChain from langchain.llms.openai import OpenAI -from langchain.load.dump import dumps -from langchain.load.load import loads +from langchain.load.dump import dumpd, dumps +from langchain.load.load import load, loads from langchain.prompts.prompt import PromptTemplate @@ -14,7 +14,7 @@ class NotSerializable: @pytest.mark.requires("openai") -def test_load_openai_llm() -> None: +def test_loads_openai_llm() -> None: llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") llm_string = dumps(llm) llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"}) @@ -25,7 +25,7 @@ def test_load_openai_llm() -> None: @pytest.mark.requires("openai") -def test_load_llmchain() -> None: +def test_loads_llmchain() -> None: llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") prompt = PromptTemplate.from_template("hello {name}!") chain = LLMChain(llm=llm, prompt=prompt) @@ -40,7 +40,7 @@ def test_load_llmchain() -> None: @pytest.mark.requires("openai") -def test_load_llmchain_env() -> None: +def test_loads_llmchain_env() -> None: import os has_env = "OPENAI_API_KEY" in os.environ @@ -64,7 +64,7 @@ def test_load_llmchain_env() -> None: @pytest.mark.requires("openai") -def test_load_llmchain_with_non_serializable_arg() -> None: +def test_loads_llmchain_with_non_serializable_arg() -> None: llm = OpenAI( model="davinci", temperature=0.5, @@ -76,3 +76,68 @@ def test_load_llmchain_with_non_serializable_arg() -> None: chain_string = dumps(chain, pretty=True) with pytest.raises(NotImplementedError): loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"}) + + +@pytest.mark.requires("openai") +def test_load_openai_llm() -> None: + llm = OpenAI(model="davinci", temperature=0.5, 
openai_api_key="hello") + llm_obj = dumpd(llm) + llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"}) + + assert llm2 == llm + assert dumpd(llm2) == llm_obj + assert isinstance(llm2, OpenAI) + + +@pytest.mark.requires("openai") +def test_load_llmchain() -> None: + llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") + prompt = PromptTemplate.from_template("hello {name}!") + chain = LLMChain(llm=llm, prompt=prompt) + chain_obj = dumpd(chain) + chain2 = load(chain_obj, secrets_map={"OPENAI_API_KEY": "hello"}) + + assert chain2 == chain + assert dumpd(chain2) == chain_obj + assert isinstance(chain2, LLMChain) + assert isinstance(chain2.llm, OpenAI) + assert isinstance(chain2.prompt, PromptTemplate) + + +@pytest.mark.requires("openai") +def test_load_llmchain_env() -> None: + import os + + has_env = "OPENAI_API_KEY" in os.environ + if not has_env: + os.environ["OPENAI_API_KEY"] = "env_variable" + + llm = OpenAI(model="davinci", temperature=0.5) + prompt = PromptTemplate.from_template("hello {name}!") + chain = LLMChain(llm=llm, prompt=prompt) + chain_obj = dumpd(chain) + chain2 = load(chain_obj) + + assert chain2 == chain + assert dumpd(chain2) == chain_obj + assert isinstance(chain2, LLMChain) + assert isinstance(chain2.llm, OpenAI) + assert isinstance(chain2.prompt, PromptTemplate) + + if not has_env: + del os.environ["OPENAI_API_KEY"] + + +@pytest.mark.requires("openai") +def test_load_llmchain_with_non_serializable_arg() -> None: + llm = OpenAI( + model="davinci", + temperature=0.5, + openai_api_key="hello", + client=NotSerializable, + ) + prompt = PromptTemplate.from_template("hello {name}!") + chain = LLMChain(llm=llm, prompt=prompt) + chain_obj = dumpd(chain) + with pytest.raises(NotImplementedError): + load(chain_obj, secrets_map={"OPENAI_API_KEY": "hello"})
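

For context, a minimal before/after sketch of the new entry point, mirroring the added tests (the model name and API key below are placeholder values):

    from langchain.llms.openai import OpenAI
    from langchain.load.dump import dumpd, dumps
    from langchain.load.load import load, loads

    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")

    # Before: round-trip through a JSON string.
    revived = loads(dumps(llm), secrets_map={"OPENAI_API_KEY": "hello"})

    # After: dumpd() already yields a JSON-compatible dict, so load() can
    # revive it directly, skipping the intermediate json.dumps()/json.loads().
    serialized = dumpd(llm)
    revived = load(serialized, secrets_map={"OPENAI_API_KEY": "hello"})
    assert revived == llm

In the callbacks path the `serialized` dict arrives already parsed, so `load(serialized)` avoids the encode/decode round-trip entirely.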