@@ -13,6 +13,9 @@ from langchain.callbacks.tracers.base import BaseTracer, TracerException
 from langchain.callbacks.tracers.schemas import Run
 from langchain.schema import LLMResult

+SERIALIZED = {"id": ["llm"]}
+SERIALIZED_CHAT = {"id": ["chat_model"]}
+

 class FakeTracer(BaseTracer):
     """Fake tracer that records LangChain execution."""
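
The constants introduced here replace the ad-hoc {"name": ...} payloads used throughout these tests: {"id": [...]} mirrors the class-path identifiers emitted by the new langchain.load serialization (see the snapshot file added later in this diff). A minimal sketch of the callback pattern the tests exercise, assuming the module-level names from the hunk above (FakeTracer records runs on a `runs` list, per the assertions below):

    from uuid import uuid4

    from langchain.schema import LLMResult

    tracer = FakeTracer()
    run_id = uuid4()

    # Each traced run is bracketed by a start/end callback pair keyed by run_id.
    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=run_id)
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=run_id)

    assert len(tracer.runs) == 1
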
@@ -39,7 +42,7 @@ def test_tracer_llm_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs={"prompts": []},
         outputs=LLMResult(generations=[[]]),
         error=None,
@@ -47,7 +50,7 @@ def test_tracer_llm_run() -> None:
     )
     tracer = FakeTracer()

-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
     assert tracer.runs == [compare_run]

@@ -64,7 +67,7 @@ def test_tracer_chat_model_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "chat_model"},
+        serialized=SERIALIZED_CHAT,
         inputs=dict(prompts=[""]),
         outputs=LLMResult(generations=[[]]),
         error=None,
@@ -73,7 +76,7 @@ def test_tracer_chat_model_run() -> None:
     tracer = FakeTracer()
     manager = CallbackManager(handlers=[tracer])
     run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
+        serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
     )
     run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
@@ -100,7 +103,7 @@ def test_tracer_multiple_llm_runs() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]]),
         error=None,
@@ -110,7 +113,7 @@ def test_tracer_multiple_llm_runs() -> None:

     num_runs = 10
     for _ in range(num_runs):
-        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+        tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
         tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)

     assert tracer.runs == [compare_run] * num_runs
@@ -183,7 +186,7 @@ def test_tracer_nested_run() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=tool_uuid,
@@ -191,7 +194,7 @@ def test_tracer_nested_run() -> None:
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_tool_end("test", run_id=tool_uuid)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -235,7 +238,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]]),
         run_type="llm",
@@ -251,7 +254,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=4,
         child_execution_order=4,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]]),
         run_type="llm",
@@ -275,7 +278,7 @@ def test_tracer_llm_run_on_error() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=None,
         error=repr(exception),
@@ -283,7 +286,7 @@ def test_tracer_llm_run_on_error() -> None:
     )
     tracer = FakeTracer()

-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_error(exception, run_id=uuid)
     assert tracer.runs == [compare_run]

@@ -358,14 +361,14 @@ def test_tracer_nested_runs_on_error() -> None:
         serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -378,7 +381,7 @@ def test_tracer_nested_runs_on_error() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid3,
         parent_run_id=tool_uuid,
@@ -408,7 +411,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=2,
         child_execution_order=2,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         error=None,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]], llm_output=None),
@@ -422,7 +425,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         error=None,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]], llm_output=None),
@@ -450,7 +453,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=5,
         child_execution_order=5,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         error=repr(exception),
         inputs=dict(prompts=[]),
         outputs=None,
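
The nested-run hunks above check parent/child bookkeeping: a child run is attached to its parent through parent_run_id, while execution_order and child_execution_order record the traversal position. Sketched minimally with the same callback API (the UUID names are illustrative):

    from uuid import uuid4

    chain_id, llm_id = uuid4(), uuid4()

    tracer = FakeTracer()
    tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=chain_id)
    # The LLM run becomes a child of the chain run via parent_run_id.
    tracer.on_llm_start(
        serialized=SERIALIZED, prompts=[], run_id=llm_id, parent_run_id=chain_id
    )
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_id)
    tracer.on_chain_end(outputs={}, run_id=chain_id)
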
@@ -22,6 +22,9 @@ from langchain.schema import LLMResult

 TEST_SESSION_ID = 2023

+SERIALIZED = {"id": ["llm"]}
+SERIALIZED_CHAT = {"id": ["chat_model"]}
+

 def load_session(session_name: str) -> TracerSessionV1:
     """Load a tracing session."""
@@ -107,7 +110,7 @@ def test_tracer_llm_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -116,7 +119,7 @@ def test_tracer_llm_run() -> None:
     tracer = FakeTracer()

     tracer.new_session()
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
     assert tracer.runs == [compare_run]

@@ -133,7 +136,7 @@ def test_tracer_chat_model_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "chat_model"},
+        serialized=SERIALIZED_CHAT,
         prompts=[""],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -144,7 +147,7 @@ def test_tracer_chat_model_run() -> None:
     tracer.new_session()
     manager = CallbackManager(handlers=[tracer])
     run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
+        serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
     )
     run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
@@ -172,7 +175,7 @@ def test_tracer_multiple_llm_runs() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -183,7 +186,7 @@ def test_tracer_multiple_llm_runs() -> None:
     tracer.new_session()
     num_runs = 10
     for _ in range(num_runs):
-        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+        tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
         tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)

     assert tracer.runs == [compare_run] * num_runs
@@ -263,7 +266,7 @@ def test_tracer_nested_run() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=tool_uuid,
@@ -271,7 +274,7 @@ def test_tracer_nested_run() -> None:
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_tool_end("test", run_id=tool_uuid)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -319,7 +322,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -337,7 +340,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=4,
         child_execution_order=4,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -362,7 +365,7 @@ def test_tracer_llm_run_on_error() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=None,
         session_id=TEST_SESSION_ID,
@@ -371,7 +374,7 @@ def test_tracer_llm_run_on_error() -> None:
     tracer = FakeTracer()

     tracer.new_session()
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_error(exception, run_id=uuid)
     assert tracer.runs == [compare_run]

@@ -451,14 +454,14 @@ def test_tracer_nested_runs_on_error() -> None:
         serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -471,7 +474,7 @@ def test_tracer_nested_runs_on_error() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid3,
         parent_run_id=tool_uuid,
@@ -501,7 +504,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=2,
         child_execution_order=2,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         session_id=TEST_SESSION_ID,
         error=None,
         prompts=[],
@@ -515,7 +518,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         session_id=TEST_SESSION_ID,
         error=None,
         prompts=[],
@@ -547,7 +550,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=5,
         child_execution_order=5,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         session_id=TEST_SESSION_ID,
         error=repr(exception),
         prompts=[],
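
This second test module targets the session-based V1 tracer: note TracerSessionV1, the tracer.new_session() calls, and the flat prompts/response/session_id fields where the newer Run schema above uses inputs/outputs. The callback sequence is otherwise the same; minimally:

    run_id = uuid4()
    tracer = FakeTracer()
    tracer.new_session()  # V1 groups runs under a session (cf. TEST_SESSION_ID)
    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=run_id)
    tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=run_id)
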
tests/unit_tests/load/__snapshots__/test_dump.ambr (new file, 273 lines)
@@ -0,0 +1,273 @@
# serializer version: 1
# name: test_person
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "test_dump",
      "Person"
    ],
    "kwargs": {
      "secret": {
        "lc": 1,
        "type": "secret",
        "id": [
          "SECRET"
        ]
      },
      "you_can_see_me": "hello"
    }
  }
  '''
# ---
# name: test_person.1
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "test_dump",
      "SpecialPerson"
    ],
    "kwargs": {
      "another_secret": {
        "lc": 1,
        "type": "secret",
        "id": [
          "ANOTHER_SECRET"
        ]
      },
      "secret": {
        "lc": 1,
        "type": "secret",
        "id": [
          "SECRET"
        ]
      },
      "another_visible": "bye",
      "you_can_see_me": "hello"
    }
  }
  '''
# ---
# name: test_serialize_llmchain
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "langchain",
      "chains",
      "llm",
      "LLMChain"
    ],
    "kwargs": {
      "llm": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "llms",
          "openai",
          "OpenAI"
        ],
        "kwargs": {
          "model": "davinci",
          "temperature": 0.5,
          "openai_api_key": {
            "lc": 1,
            "type": "secret",
            "id": [
              "OPENAI_API_KEY"
            ]
          }
        }
      },
      "prompt": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "prompts",
          "prompt",
          "PromptTemplate"
        ],
        "kwargs": {
          "input_variables": [
            "name"
          ],
          "template": "hello {name}!",
          "template_format": "f-string"
        }
      }
    }
  }
  '''
# ---
# name: test_serialize_llmchain_chat
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "langchain",
      "chains",
      "llm",
      "LLMChain"
    ],
    "kwargs": {
      "llm": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "chat_models",
          "openai",
          "ChatOpenAI"
        ],
        "kwargs": {
          "model": "davinci",
          "temperature": 0.5,
          "openai_api_key": "hello"
        }
      },
      "prompt": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "prompts",
          "chat",
          "ChatPromptTemplate"
        ],
        "kwargs": {
          "input_variables": [
            "name"
          ],
          "messages": [
            {
              "lc": 1,
              "type": "constructor",
              "id": [
                "langchain",
                "prompts",
                "chat",
                "HumanMessagePromptTemplate"
              ],
              "kwargs": {
                "prompt": {
                  "lc": 1,
                  "type": "constructor",
                  "id": [
                    "langchain",
                    "prompts",
                    "prompt",
                    "PromptTemplate"
                  ],
                  "kwargs": {
                    "input_variables": [
                      "name"
                    ],
                    "template": "hello {name}!",
                    "template_format": "f-string"
                  }
                }
              }
            }
          ]
        }
      }
    }
  }
  '''
# ---
# name: test_serialize_llmchain_with_non_serializable_arg
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "langchain",
      "chains",
      "llm",
      "LLMChain"
    ],
    "kwargs": {
      "llm": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "llms",
          "openai",
          "OpenAI"
        ],
        "kwargs": {
          "model": "davinci",
          "temperature": 0.5,
          "openai_api_key": {
            "lc": 1,
            "type": "secret",
            "id": [
              "OPENAI_API_KEY"
            ]
          },
          "client": {
            "lc": 1,
            "type": "not_implemented",
            "id": [
              "openai",
              "api_resources",
              "completion",
              "Completion"
            ]
          }
        }
      },
      "prompt": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "prompts",
          "prompt",
          "PromptTemplate"
        ],
        "kwargs": {
          "input_variables": [
            "name"
          ],
          "template": "hello {name}!",
          "template_format": "f-string"
        }
      }
    }
  }
  '''
# ---
# name: test_serialize_openai_llm
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "langchain",
      "llms",
      "openai",
      "OpenAI"
    ],
    "kwargs": {
      "model": "davinci",
      "temperature": 0.7,
      "openai_api_key": {
        "lc": 1,
        "type": "secret",
        "id": [
          "OPENAI_API_KEY"
        ]
      }
    }
  }
  '''
# ---
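
The .ambr file above is a snapshot fixture for syrupy, which is added to the test-group dependencies at the end of this diff. A test asserts its output against the `snapshot` pytest fixture, and the stored entries are (re)generated by running pytest with syrupy's --snapshot-update flag. A minimal sketch of the mechanism (the test name and value are illustrative):

    from typing import Any


    def test_greeting(snapshot: Any) -> None:
        # syrupy's fixture compares the value against the entry stored
        # under this test's name in __snapshots__/<module>.ambr.
        assert "hello world" == snapshot
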
tests/unit_tests/load/test_dump.py (new file, 103 lines)
@@ -0,0 +1,103 @@
"""Test for Serializable base class"""

from typing import Any, Dict

import pytest

from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.llm import LLMChain
from langchain.chat_models.openai import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.serializable import Serializable
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.prompts.prompt import PromptTemplate


class Person(Serializable):
    secret: str

    you_can_see_me: str = "hello"

    @property
    def lc_serializable(self) -> bool:
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"secret": "SECRET"}

    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"you_can_see_me": self.you_can_see_me}


class SpecialPerson(Person):
    another_secret: str

    another_visible: str = "bye"

    # Gets merged with parent class's secrets
    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"another_secret": "ANOTHER_SECRET"}

    # Gets merged with parent class's attributes
    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"another_visible": self.another_visible}


class NotSerializable:
    pass


def test_person(snapshot: Any) -> None:
    p = Person(secret="hello")
    assert dumps(p, pretty=True) == snapshot
    sp = SpecialPerson(another_secret="Wooo", secret="Hmm")
    assert dumps(sp, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_openai_llm(snapshot: Any) -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        # This is excluded from serialization
        callbacks=[LangChainTracer()],
    )
    llm.temperature = 0.7  # this is reflected in serialization
    assert dumps(llm, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain(snapshot: Any) -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain_chat(snapshot: Any) -> None:
    llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = ChatPromptTemplate.from_messages(
        [HumanMessagePromptTemplate.from_template("hello {name}!")]
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        client=NotSerializable,
    )
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot
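
Read together with the snapshots above, dumps emits a versioned JSON envelope: "lc" (format version), "type" ("constructor", "secret", or "not_implemented"), a class-path "id", and the constructor "kwargs". Fields named in lc_secrets are masked as secret markers, while lc_attributes serialize in the clear. A short sketch using the Person class above:

    p = Person(secret="hello")
    print(dumps(p, pretty=True))
    # "secret" appears as {"lc": 1, "type": "secret", "id": ["SECRET"]},
    # while the lc_attributes entry "you_can_see_me" stays readable.
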
tests/unit_tests/load/test_load.py (new file, 54 lines)
@@ -0,0 +1,54 @@
"""Test for Serializable base class"""

import pytest

from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.prompts.prompt import PromptTemplate


class NotSerializable:
    pass


@pytest.mark.requires("openai")
def test_load_openai_llm() -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    llm_string = dumps(llm)
    llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})

    assert llm2 == llm
    assert dumps(llm2) == llm_string
    assert isinstance(llm2, OpenAI)


@pytest.mark.requires("openai")
def test_load_llmchain() -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    chain_string = dumps(chain)
    chain2 = loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})

    assert chain2 == chain
    assert dumps(chain2) == chain_string
    assert isinstance(chain2, LLMChain)
    assert isinstance(chain2.llm, OpenAI)
    assert isinstance(chain2.prompt, PromptTemplate)


@pytest.mark.requires("openai")
def test_load_llmchain_with_non_serializable_arg() -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        client=NotSerializable,
    )
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    chain_string = dumps(chain, pretty=True)
    with pytest.raises(NotImplementedError):
        loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
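
These tests pin down the round-trip contract: loads rebuilds the object from the JSON envelope, substituting each secret id through secrets_map, and raises NotImplementedError when it meets a "not_implemented" marker such as the client kwarg above. In brief, mirroring the tests:

    llm_string = dumps(llm)  # secrets masked as {"type": "secret", "id": [...]}
    llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
    assert llm2 == llm and dumps(llm2) == llm_string
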
@@ -72,6 +72,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
         "pytest-socket",
         "pytest-watcher",
         "responses",
+        "syrupy",
     ]