commit 18af149e91 (parent 614cff89bc)
Author: Nuno Campos (committed by GitHub)
Date:   2023-06-11 23:51:28 +01:00
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>

27 changed files with 810 additions and 71 deletions


@@ -13,6 +13,9 @@ from langchain.callbacks.tracers.base import BaseTracer, TracerException
 from langchain.callbacks.tracers.schemas import Run
 from langchain.schema import LLMResult
 
+SERIALIZED = {"id": ["llm"]}
+SERIALIZED_CHAT = {"id": ["chat_model"]}
+
 
 class FakeTracer(BaseTracer):
     """Fake tracer that records LangChain execution."""
@@ -39,7 +42,7 @@ def test_tracer_llm_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs={"prompts": []},
         outputs=LLMResult(generations=[[]]),
         error=None,
@@ -47,7 +50,7 @@ def test_tracer_llm_run() -> None:
     )
     tracer = FakeTracer()
 
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
     assert tracer.runs == [compare_run]
 
@@ -64,7 +67,7 @@ def test_tracer_chat_model_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "chat_model"},
+        serialized=SERIALIZED_CHAT,
         inputs=dict(prompts=[""]),
         outputs=LLMResult(generations=[[]]),
         error=None,
@@ -73,7 +76,7 @@ def test_tracer_chat_model_run() -> None:
     tracer = FakeTracer()
     manager = CallbackManager(handlers=[tracer])
     run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
+        serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
     )
     run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
@@ -100,7 +103,7 @@ def test_tracer_multiple_llm_runs() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]]),
         error=None,
@@ -110,7 +113,7 @@ def test_tracer_multiple_llm_runs() -> None:
 
     num_runs = 10
     for _ in range(num_runs):
-        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+        tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
         tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
 
     assert tracer.runs == [compare_run] * num_runs
@@ -183,7 +186,7 @@ def test_tracer_nested_run() -> None:
         parent_run_id=chain_uuid,
     )
    tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=tool_uuid,
@@ -191,7 +194,7 @@ def test_tracer_nested_run() -> None:
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_tool_end("test", run_id=tool_uuid)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -235,7 +238,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]]),
         run_type="llm",
@@ -251,7 +254,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=4,
         child_execution_order=4,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]]),
         run_type="llm",
@@ -275,7 +278,7 @@ def test_tracer_llm_run_on_error() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         inputs=dict(prompts=[]),
         outputs=None,
         error=repr(exception),
@@ -283,7 +286,7 @@ def test_tracer_llm_run_on_error() -> None:
     )
     tracer = FakeTracer()
 
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_error(exception, run_id=uuid)
     assert tracer.runs == [compare_run]
@@ -358,14 +361,14 @@ def test_tracer_nested_runs_on_error() -> None:
         serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -378,7 +381,7 @@ def test_tracer_nested_runs_on_error() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid3,
         parent_run_id=tool_uuid,
@@ -408,7 +411,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=2,
         child_execution_order=2,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         error=None,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]], llm_output=None),
@@ -422,7 +425,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         error=None,
         inputs=dict(prompts=[]),
         outputs=LLMResult(generations=[[]], llm_output=None),
@@ -450,7 +453,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=5,
         child_execution_order=5,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         error=repr(exception),
         inputs=dict(prompts=[]),
         outputs=None,


@@ -22,6 +22,9 @@ from langchain.schema import LLMResult
 
 TEST_SESSION_ID = 2023
 
+SERIALIZED = {"id": ["llm"]}
+SERIALIZED_CHAT = {"id": ["chat_model"]}
+
 
 def load_session(session_name: str) -> TracerSessionV1:
     """Load a tracing session."""
@@ -107,7 +110,7 @@ def test_tracer_llm_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -116,7 +119,7 @@ def test_tracer_llm_run() -> None:
     tracer = FakeTracer()
     tracer.new_session()
 
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
     assert tracer.runs == [compare_run]
 
@@ -133,7 +136,7 @@ def test_tracer_chat_model_run() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "chat_model"},
+        serialized=SERIALIZED_CHAT,
         prompts=[""],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -144,7 +147,7 @@ def test_tracer_chat_model_run() -> None:
     tracer.new_session()
     manager = CallbackManager(handlers=[tracer])
     run_manager = manager.on_chat_model_start(
-        serialized={"name": "chat_model"}, messages=[[]], run_id=uuid
+        serialized=SERIALIZED_CHAT, messages=[[]], run_id=uuid
     )
     run_manager.on_llm_end(response=LLMResult(generations=[[]]))
     assert tracer.runs == [compare_run]
@@ -172,7 +175,7 @@ def test_tracer_multiple_llm_runs() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -183,7 +186,7 @@ def test_tracer_multiple_llm_runs() -> None:
     tracer.new_session()
     num_runs = 10
     for _ in range(num_runs):
-        tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+        tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
         tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
 
     assert tracer.runs == [compare_run] * num_runs
@@ -263,7 +266,7 @@ def test_tracer_nested_run() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=tool_uuid,
@@ -271,7 +274,7 @@ def test_tracer_nested_run() -> None:
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_tool_end("test", run_id=tool_uuid)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -319,7 +322,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -337,7 +340,7 @@ def test_tracer_nested_run() -> None:
         extra={},
         execution_order=4,
         child_execution_order=4,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=LLMResult(generations=[[]]),
         session_id=TEST_SESSION_ID,
@@ -362,7 +365,7 @@ def test_tracer_llm_run_on_error() -> None:
         extra={},
         execution_order=1,
         child_execution_order=1,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         response=None,
         session_id=TEST_SESSION_ID,
@@ -371,7 +374,7 @@ def test_tracer_llm_run_on_error() -> None:
     tracer = FakeTracer()
     tracer.new_session()
 
-    tracer.on_llm_start(serialized={"name": "llm"}, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
     tracer.on_llm_error(exception, run_id=uuid)
     assert tracer.runs == [compare_run]
@@ -451,14 +454,14 @@ def test_tracer_nested_runs_on_error() -> None:
         serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid1,
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
@@ -471,7 +474,7 @@ def test_tracer_nested_runs_on_error() -> None:
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_start(
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         prompts=[],
         run_id=llm_uuid3,
         parent_run_id=tool_uuid,
@@ -501,7 +504,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=2,
         child_execution_order=2,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         session_id=TEST_SESSION_ID,
         error=None,
         prompts=[],
@@ -515,7 +518,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=3,
         child_execution_order=3,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         session_id=TEST_SESSION_ID,
         error=None,
         prompts=[],
@@ -547,7 +550,7 @@ def test_tracer_nested_runs_on_error() -> None:
         extra={},
         execution_order=5,
         child_execution_order=5,
-        serialized={"name": "llm"},
+        serialized=SERIALIZED,
         session_id=TEST_SESSION_ID,
         error=repr(exception),
         prompts=[],


@@ -0,0 +1,273 @@
# serializer version: 1
# name: test_person
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "test_dump",
      "Person"
    ],
    "kwargs": {
      "secret": {
        "lc": 1,
        "type": "secret",
        "id": [
          "SECRET"
        ]
      },
      "you_can_see_me": "hello"
    }
  }
  '''
# ---
# name: test_person.1
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "test_dump",
      "SpecialPerson"
    ],
    "kwargs": {
      "another_secret": {
        "lc": 1,
        "type": "secret",
        "id": [
          "ANOTHER_SECRET"
        ]
      },
      "secret": {
        "lc": 1,
        "type": "secret",
        "id": [
          "SECRET"
        ]
      },
      "another_visible": "bye",
      "you_can_see_me": "hello"
    }
  }
  '''
# ---
# name: test_serialize_llmchain
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "langchain",
      "chains",
      "llm",
      "LLMChain"
    ],
    "kwargs": {
      "llm": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "llms",
          "openai",
          "OpenAI"
        ],
        "kwargs": {
          "model": "davinci",
          "temperature": 0.5,
          "openai_api_key": {
            "lc": 1,
            "type": "secret",
            "id": [
              "OPENAI_API_KEY"
            ]
          }
        }
      },
      "prompt": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "prompts",
          "prompt",
          "PromptTemplate"
        ],
        "kwargs": {
          "input_variables": [
            "name"
          ],
          "template": "hello {name}!",
          "template_format": "f-string"
        }
      }
    }
  }
  '''
# ---
# name: test_serialize_llmchain_chat
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "langchain",
      "chains",
      "llm",
      "LLMChain"
    ],
    "kwargs": {
      "llm": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "chat_models",
          "openai",
          "ChatOpenAI"
        ],
        "kwargs": {
          "model": "davinci",
          "temperature": 0.5,
          "openai_api_key": "hello"
        }
      },
      "prompt": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "prompts",
          "chat",
          "ChatPromptTemplate"
        ],
        "kwargs": {
          "input_variables": [
            "name"
          ],
          "messages": [
            {
              "lc": 1,
              "type": "constructor",
              "id": [
                "langchain",
                "prompts",
                "chat",
                "HumanMessagePromptTemplate"
              ],
              "kwargs": {
                "prompt": {
                  "lc": 1,
                  "type": "constructor",
                  "id": [
                    "langchain",
                    "prompts",
                    "prompt",
                    "PromptTemplate"
                  ],
                  "kwargs": {
                    "input_variables": [
                      "name"
                    ],
                    "template": "hello {name}!",
                    "template_format": "f-string"
                  }
                }
              }
            }
          ]
        }
      }
    }
  }
  '''
# ---
# name: test_serialize_llmchain_with_non_serializable_arg
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "langchain",
      "chains",
      "llm",
      "LLMChain"
    ],
    "kwargs": {
      "llm": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "llms",
          "openai",
          "OpenAI"
        ],
        "kwargs": {
          "model": "davinci",
          "temperature": 0.5,
          "openai_api_key": {
            "lc": 1,
            "type": "secret",
            "id": [
              "OPENAI_API_KEY"
            ]
          },
          "client": {
            "lc": 1,
            "type": "not_implemented",
            "id": [
              "openai",
              "api_resources",
              "completion",
              "Completion"
            ]
          }
        }
      },
      "prompt": {
        "lc": 1,
        "type": "constructor",
        "id": [
          "langchain",
          "prompts",
          "prompt",
          "PromptTemplate"
        ],
        "kwargs": {
          "input_variables": [
            "name"
          ],
          "template": "hello {name}!",
          "template_format": "f-string"
        }
      }
    }
  }
  '''
# ---
# name: test_serialize_openai_llm
  '''
  {
    "lc": 1,
    "type": "constructor",
    "id": [
      "langchain",
      "llms",
      "openai",
      "OpenAI"
    ],
    "kwargs": {
      "model": "davinci",
      "temperature": 0.7,
      "openai_api_key": {
        "lc": 1,
        "type": "secret",
        "id": [
          "OPENAI_API_KEY"
        ]
      }
    }
  }
  '''
# ---
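Every snapshot above follows the same envelope: "lc" is the serialization format version, "type" is one of "constructor", "secret", or "not_implemented", "id" is an import path (for constructor and not-implemented entries) or a secret's name, and "kwargs" holds the constructor arguments, themselves recursively serialized. As a quick illustration, the test_serialize_openai_llm snapshot corresponds to this plain Python dict (the variable name is ours):

openai_entry = {
    "lc": 1,  # serialization format version
    "type": "constructor",  # rebuilt at load time by importing and calling the class
    "id": ["langchain", "llms", "openai", "OpenAI"],  # import path of the class
    "kwargs": {
        "model": "davinci",
        "temperature": 0.7,
        # Secrets are stored as references, never as raw values:
        "openai_api_key": {"lc": 1, "type": "secret", "id": ["OPENAI_API_KEY"]},
    },
}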


@@ -0,0 +1,103 @@
"""Test for Serializable base class"""

from typing import Any, Dict

import pytest

from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.llm import LLMChain
from langchain.chat_models.openai import ChatOpenAI
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.serializable import Serializable
from langchain.prompts.chat import ChatPromptTemplate, HumanMessagePromptTemplate
from langchain.prompts.prompt import PromptTemplate


class Person(Serializable):
    secret: str

    you_can_see_me: str = "hello"

    @property
    def lc_serializable(self) -> bool:
        return True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"secret": "SECRET"}

    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"you_can_see_me": self.you_can_see_me}


class SpecialPerson(Person):
    another_secret: str

    another_visible: str = "bye"

    # Gets merged with parent class's secrets
    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"another_secret": "ANOTHER_SECRET"}

    # Gets merged with parent class's attributes
    @property
    def lc_attributes(self) -> Dict[str, str]:
        return {"another_visible": self.another_visible}


class NotSerializable:
    pass


def test_person(snapshot: Any) -> None:
    p = Person(secret="hello")
    assert dumps(p, pretty=True) == snapshot
    sp = SpecialPerson(another_secret="Wooo", secret="Hmm")
    assert dumps(sp, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_openai_llm(snapshot: Any) -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        # This is excluded from serialization
        callbacks=[LangChainTracer()],
    )
    llm.temperature = 0.7  # this is reflected in serialization
    assert dumps(llm, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain(snapshot: Any) -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain_chat(snapshot: Any) -> None:
    llm = ChatOpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = ChatPromptTemplate.from_messages(
        [HumanMessagePromptTemplate.from_template("hello {name}!")]
    )
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot


@pytest.mark.requires("openai")
def test_serialize_llmchain_with_non_serializable_arg(snapshot: Any) -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        client=NotSerializable,
    )
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    assert dumps(chain, pretty=True) == snapshot
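Note the division of labor these tests pin down: dumps does not fail on the non-serializable client argument but records a not_implemented placeholder (the "client" entry in the snapshot above), while loads is what refuses such payloads, as the load tests in the next file exercise. A minimal sketch of the dump side, reusing OpenAI, dumps, and NotSerializable from this file:

llm = OpenAI(
    model="davinci",
    temperature=0.5,
    openai_api_key="hello",
    client=NotSerializable,
)
# Serialization still succeeds; the client is recorded as a
# {"lc": 1, "type": "not_implemented", ...} marker instead of raising.
assert '"not_implemented"' in dumps(llm, pretty=True)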


@@ -0,0 +1,54 @@
"""Test for Serializable base class"""

import pytest

from langchain.chains.llm import LLMChain
from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads
from langchain.prompts.prompt import PromptTemplate


class NotSerializable:
    pass


@pytest.mark.requires("openai")
def test_load_openai_llm() -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    llm_string = dumps(llm)
    llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})

    assert llm2 == llm
    assert dumps(llm2) == llm_string
    assert isinstance(llm2, OpenAI)


@pytest.mark.requires("openai")
def test_load_llmchain() -> None:
    llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    chain_string = dumps(chain)
    chain2 = loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})

    assert chain2 == chain
    assert dumps(chain2) == chain_string
    assert isinstance(chain2, LLMChain)
    assert isinstance(chain2.llm, OpenAI)
    assert isinstance(chain2.prompt, PromptTemplate)


@pytest.mark.requires("openai")
def test_load_llmchain_with_non_serializable_arg() -> None:
    llm = OpenAI(
        model="davinci",
        temperature=0.5,
        openai_api_key="hello",
        client=NotSerializable,
    )
    prompt = PromptTemplate.from_template("hello {name}!")
    chain = LLMChain(llm=llm, prompt=prompt)
    chain_string = dumps(chain, pretty=True)
    with pytest.raises(NotImplementedError):
        loads(chain_string, secrets_map={"OPENAI_API_KEY": "hello"})
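The round-trip these tests exercise can be sketched outside pytest; a minimal example using only calls that appear in this file:

from langchain.llms.openai import OpenAI
from langchain.load.dump import dumps
from langchain.load.load import loads

# dumps() writes the API key as a {"type": "secret", "id": ["OPENAI_API_KEY"]}
# placeholder, so the resulting JSON string is safe to persist.
llm = OpenAI(model="davinci", temperature=0.5, openai_api_key="hello")
llm_string = dumps(llm)

# loads() resolves each secret id through secrets_map, so the reconstructed
# object compares equal to the original.
llm2 = loads(llm_string, secrets_map={"OPENAI_API_KEY": "hello"})
assert llm2 == llm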


@@ -72,6 +72,7 @@ def test_test_group_dependencies(poetry_conf: Mapping[str, Any]) -> None:
         "pytest-socket",
         "pytest-watcher",
         "responses",
+        "syrupy",
     ]
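
syrupy is the snapshot-testing plugin behind the snapshot fixture used in the new serialization tests, and the .ambr file above is its stored output. A minimal sketch of the workflow, assuming pytest with syrupy installed: run once with `pytest --snapshot-update` to record, then plain `pytest` to compare.

from typing import Any


def test_example(snapshot: Any) -> None:
    # syrupy records this value under __snapshots__/<module>.ambr when run
    # with --snapshot-update, and asserts equality against it afterwards.
    assert "some serialized output" == snapshot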