mirror of
https://github.com/hwchase17/langchain.git
synced 2026-06-09 18:50:33 +00:00
fix(core): serialization patch (#34455)
- `allowed_objects` kwarg in `load` - escape lc-ser formatted dicts on `dump` - fix for jinja2 --------- Co-authored-by: Mason Daugherty <github@mdrxy.com>
This commit is contained in:
@@ -1360,7 +1360,7 @@ def test_structured_outputs_parser() -> None:
|
||||
partial(_oai_structured_outputs_parser, schema=GenerateUsername)
|
||||
)
|
||||
serialized = dumps(llm_output)
|
||||
deserialized = loads(serialized)
|
||||
deserialized = loads(serialized, allowed_objects=[ChatGeneration, AIMessage])
|
||||
assert isinstance(deserialized, ChatGeneration)
|
||||
result = output_parser.invoke(cast(AIMessage, deserialized.message))
|
||||
assert result == parsed_response
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from langchain_core.load import dumpd, dumps, load, loads
|
||||
from langchain_core.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
|
||||
from langchain_core.prompts.prompt import PromptTemplate
|
||||
from langchain_core.runnables import RunnableSequence
|
||||
|
||||
from langchain_openai import ChatOpenAI, OpenAI
|
||||
|
||||
@@ -6,7 +9,11 @@ from langchain_openai import ChatOpenAI, OpenAI
|
||||
def test_loads_openai_llm() -> None:
|
||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello", top_p=0.8) # type: ignore[call-arg]
|
||||
llm_string = dumps(llm)
|
||||
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
llm2 = loads(
|
||||
llm_string,
|
||||
secrets_map={"OPENAI_API_KEY": "hello"},
|
||||
allowed_objects=[OpenAI],
|
||||
)
|
||||
|
||||
assert llm2.dict() == llm.dict()
|
||||
llm_string_2 = dumps(llm2)
|
||||
@@ -17,7 +24,11 @@ def test_loads_openai_llm() -> None:
|
||||
def test_load_openai_llm() -> None:
|
||||
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg]
|
||||
llm_obj = dumpd(llm)
|
||||
llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
llm2 = load(
|
||||
llm_obj,
|
||||
secrets_map={"OPENAI_API_KEY": "hello"},
|
||||
allowed_objects=[OpenAI],
|
||||
)
|
||||
|
||||
assert llm2.dict() == llm.dict()
|
||||
assert dumpd(llm2) == llm_obj
|
||||
@@ -27,7 +38,11 @@ def test_load_openai_llm() -> None:
|
||||
def test_loads_openai_chat() -> None:
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg]
|
||||
llm_string = dumps(llm)
|
||||
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
llm2 = loads(
|
||||
llm_string,
|
||||
secrets_map={"OPENAI_API_KEY": "hello"},
|
||||
allowed_objects=[ChatOpenAI],
|
||||
)
|
||||
|
||||
assert llm2.dict() == llm.dict()
|
||||
llm_string_2 = dumps(llm2)
|
||||
@@ -38,8 +53,85 @@ def test_loads_openai_chat() -> None:
|
||||
def test_load_openai_chat() -> None:
|
||||
llm = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg]
|
||||
llm_obj = dumpd(llm)
|
||||
llm2 = load(llm_obj, secrets_map={"OPENAI_API_KEY": "hello"})
|
||||
llm2 = load(
|
||||
llm_obj,
|
||||
secrets_map={"OPENAI_API_KEY": "hello"},
|
||||
allowed_objects=[ChatOpenAI],
|
||||
)
|
||||
|
||||
assert llm2.dict() == llm.dict()
|
||||
assert dumpd(llm2) == llm_obj
|
||||
assert isinstance(llm2, ChatOpenAI)
|
||||
|
||||
|
||||
def test_loads_runnable_sequence_prompt_model() -> None:
|
||||
"""Test serialization/deserialization of a chain:
|
||||
|
||||
`prompt | model (RunnableSequence)`
|
||||
"""
|
||||
prompt = ChatPromptTemplate.from_messages([("user", "Hello, {name}!")])
|
||||
model = ChatOpenAI(model="gpt-4o-mini", temperature=0.5, openai_api_key="hello") # type: ignore[call-arg]
|
||||
chain = prompt | model
|
||||
|
||||
# Verify the chain is a RunnableSequence
|
||||
assert isinstance(chain, RunnableSequence)
|
||||
|
||||
# Serialize
|
||||
chain_string = dumps(chain)
|
||||
|
||||
# Deserialize
|
||||
# (ChatPromptTemplate contains HumanMessagePromptTemplate and PromptTemplate)
|
||||
chain2 = loads(
|
||||
chain_string,
|
||||
secrets_map={"OPENAI_API_KEY": "hello"},
|
||||
allowed_objects=[
|
||||
RunnableSequence,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
PromptTemplate,
|
||||
ChatOpenAI,
|
||||
],
|
||||
)
|
||||
|
||||
# Verify structure
|
||||
assert isinstance(chain2, RunnableSequence)
|
||||
assert isinstance(chain2.first, ChatPromptTemplate)
|
||||
assert isinstance(chain2.last, ChatOpenAI)
|
||||
|
||||
# Verify round-trip serialization
|
||||
assert dumps(chain2) == chain_string
|
||||
|
||||
|
||||
def test_load_runnable_sequence_prompt_model() -> None:
|
||||
"""Test load() with a chain:
|
||||
|
||||
`prompt | model (RunnableSequence)`.
|
||||
"""
|
||||
prompt = ChatPromptTemplate.from_messages([("user", "Tell me about {topic}")])
|
||||
model = ChatOpenAI(model="gpt-4o-mini", temperature=0.7, openai_api_key="hello") # type: ignore[call-arg]
|
||||
chain = prompt | model
|
||||
|
||||
# Serialize
|
||||
chain_obj = dumpd(chain)
|
||||
|
||||
# Deserialize
|
||||
# (ChatPromptTemplate contains HumanMessagePromptTemplate and PromptTemplate)
|
||||
chain2 = load(
|
||||
chain_obj,
|
||||
secrets_map={"OPENAI_API_KEY": "hello"},
|
||||
allowed_objects=[
|
||||
RunnableSequence,
|
||||
ChatPromptTemplate,
|
||||
HumanMessagePromptTemplate,
|
||||
PromptTemplate,
|
||||
ChatOpenAI,
|
||||
],
|
||||
)
|
||||
|
||||
# Verify structure
|
||||
assert isinstance(chain2, RunnableSequence)
|
||||
assert isinstance(chain2.first, ChatPromptTemplate)
|
||||
assert isinstance(chain2.last, ChatOpenAI)
|
||||
|
||||
# Verify round-trip serialization
|
||||
assert dumpd(chain2) == chain_obj
|
||||
|
||||
Reference in New Issue
Block a user