fix some tests

Author: Ankush Gola
Date: 2023-09-28 11:27:44 -07:00
parent 19e11a602d
commit 548ca264ec
3 changed files with 19 additions and 21 deletions


@@ -112,9 +112,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if metadata:
             kwargs.update({"metadata": metadata})
         if len(prompts) != 1:
-            raise ValueError(
-                "Tracer does not support multiple prompts for LLM runs."
-            )
+            raise ValueError("Tracer does not support multiple prompts for LLM runs.")
         llm_run = Run(
             id=run_id,
             parent_run_id=parent_run_id,

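The hunk above only reflows the ValueError onto one line; the substantive change (landed in parent 19e11a602d) is that a traced LLM run now records exactly one prompt under a singular "prompt" key instead of a "prompts" list. Below is a minimal sketch of that expected behavior, inferred from the updated test assertions: SketchTracer and the inputs={"prompt": prompts[0]} construction are assumptions, not the actual BaseTracer source.

import uuid

class SketchTracer:
    """Toy stand-in for BaseTracer that records runs as plain dicts."""

    def __init__(self) -> None:
        self.runs = []

    def on_llm_start(self, serialized, prompts, run_id, parent_run_id=None):
        # Mirrors the guard shown in the hunk above.
        if len(prompts) != 1:
            raise ValueError("Tracer does not support multiple prompts for LLM runs.")
        self.runs.append(
            {
                "id": run_id,
                "parent_run_id": parent_run_id,
                "serialized": serialized,
                # Assumed shape, inferred from the test assertions below:
                # the single prompt is stored under the singular "prompt" key.
                "inputs": {"prompt": prompts[0]},
            }
        )

tracer = SketchTracer()
tracer.on_llm_start(serialized={}, prompts=["test"], run_id=uuid.uuid4())
assert tracer.runs[0]["inputs"] == {"prompt": "test"}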

@@ -13,4 +13,4 @@ def test_collect_runs() -> None:
     assert cb.traced_runs
     assert len(cb.traced_runs) == 1
     assert isinstance(cb.traced_runs[0].id, uuid.UUID)
-    assert cb.traced_runs[0].inputs == {"prompts": ["hi"]}
+    assert cb.traced_runs[0].inputs == {"prompt": "hi"}

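Only the final assertions of test_collect_runs appear in the hunk; the setup lines are outside the diff. A plausible reconstruction of the whole test, assuming it drives a FakeListLLM inside langchain's collect_runs() context manager (the fixture actually used is not visible here):

import uuid

from langchain.callbacks.manager import collect_runs
from langchain.llms.fake import FakeListLLM

def test_collect_runs() -> None:
    llm = FakeListLLM(responses=["hello"])
    # collect_runs() yields a RunCollectorCallbackHandler that gathers
    # every run traced inside the context into cb.traced_runs.
    with collect_runs() as cb:
        llm.predict("hi")
    assert cb.traced_runs
    assert len(cb.traced_runs) == 1
    assert isinstance(cb.traced_runs[0].id, uuid.UUID)
    assert cb.traced_runs[0].inputs == {"prompt": "hi"}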

@@ -48,14 +48,14 @@ def test_tracer_llm_run() -> None:
         execution_order=1,
         child_execution_order=1,
         serialized=SERIALIZED,
-        inputs={"prompts": []},
+        inputs={"prompt": "test"},
         outputs=LLMResult(generations=[[]]),
         error=None,
         run_type="llm",
     )
     tracer = FakeTracer()
-    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=["test"], run_id=uuid)
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
     assert tracer.runs == [compare_run]
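FakeTracer, used throughout these tests, is defined earlier in the same test module and is not part of this diff; it is presumably the usual minimal BaseTracer subclass that persists every finished run into a list, roughly:

from typing import List

from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run

class FakeTracer(BaseTracer):
    """Records persisted runs so tests can compare them to expected Run objects."""

    def __init__(self) -> None:
        super().__init__()
        self.runs: List[Run] = []

    def _persist_run(self, run: Run) -> None:
        self.runs.append(run)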
@@ -81,7 +81,7 @@ def test_tracer_chat_model_run() -> None:
         execution_order=1,
         child_execution_order=1,
         serialized=SERIALIZED_CHAT,
-        inputs=dict(prompts=["Human: "]),
+        inputs=dict(prompt="Human: "),
         outputs=LLMResult(generations=[[]]),
         error=None,
         run_type="llm",
@@ -117,7 +117,7 @@ def test_tracer_multiple_llm_runs() -> None:
         execution_order=1,
         child_execution_order=1,
         serialized=SERIALIZED,
-        inputs=dict(prompts=[]),
+        inputs=dict(prompt="test"),
         outputs=LLMResult(generations=[[]]),
         error=None,
         run_type="llm",
@@ -126,7 +126,7 @@ def test_tracer_multiple_llm_runs() -> None:
     num_runs = 10
     for _ in range(num_runs):
-        tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
+        tracer.on_llm_start(serialized=SERIALIZED, prompts=["test"], run_id=uuid)
         tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
     assert tracer.runs == [compare_run] * num_runs
@@ -208,7 +208,7 @@ def test_tracer_nested_run() -> None:
     )
     tracer.on_llm_start(
         serialized=SERIALIZED,
-        prompts=[],
+        prompts=["test"],
         run_id=llm_uuid1,
         parent_run_id=tool_uuid,
     )
@@ -216,7 +216,7 @@ def test_tracer_nested_run() -> None:
     tracer.on_tool_end("test", run_id=tool_uuid)
     tracer.on_llm_start(
         serialized=SERIALIZED,
-        prompts=[],
+        prompts=["test"],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
     )
@@ -272,7 +272,7 @@ def test_tracer_nested_run() -> None:
             execution_order=3,
             child_execution_order=3,
             serialized=SERIALIZED,
-            inputs=dict(prompts=[]),
+            inputs=dict(prompt="test"),
             outputs=LLMResult(generations=[[]]),
             run_type="llm",
         )
@@ -292,7 +292,7 @@ def test_tracer_nested_run() -> None:
             execution_order=4,
             child_execution_order=4,
             serialized=SERIALIZED,
-            inputs=dict(prompts=[]),
+            inputs=dict(prompt="test"),
             outputs=LLMResult(generations=[[]]),
             run_type="llm",
         ),
@@ -320,14 +320,14 @@ def test_tracer_llm_run_on_error() -> None:
         execution_order=1,
         child_execution_order=1,
         serialized=SERIALIZED,
-        inputs=dict(prompts=[]),
+        inputs=dict(prompt="test"),
         outputs=None,
         error=repr(exception),
         run_type="llm",
     )
     tracer = FakeTracer()
-    tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
+    tracer.on_llm_start(serialized=SERIALIZED, prompts=["test"], run_id=uuid)
     tracer.on_llm_error(exception, run_id=uuid)
     assert tracer.runs == [compare_run]
@@ -411,14 +411,14 @@ def test_tracer_nested_runs_on_error() -> None:
     )
     tracer.on_llm_start(
         serialized=SERIALIZED,
-        prompts=[],
+        prompts=["test"],
         run_id=llm_uuid1,
         parent_run_id=chain_uuid,
     )
     tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
     tracer.on_llm_start(
         serialized=SERIALIZED,
-        prompts=[],
+        prompts=["test"],
         run_id=llm_uuid2,
         parent_run_id=chain_uuid,
     )
@@ -431,7 +431,7 @@ def test_tracer_nested_runs_on_error() -> None:
     )
     tracer.on_llm_start(
         serialized=SERIALIZED,
-        prompts=[],
+        prompts=["test"],
         run_id=llm_uuid3,
         parent_run_id=tool_uuid,
     )
@@ -470,7 +470,7 @@ def test_tracer_nested_runs_on_error() -> None:
             child_execution_order=2,
             serialized=SERIALIZED,
             error=None,
-            inputs=dict(prompts=[]),
+            inputs=dict(prompt="test"),
             outputs=LLMResult(generations=[[]], llm_output=None),
             run_type="llm",
         ),
@@ -488,7 +488,7 @@ def test_tracer_nested_runs_on_error() -> None:
             child_execution_order=3,
             serialized=SERIALIZED,
             error=None,
-            inputs=dict(prompts=[]),
+            inputs=dict(prompt="test"),
             outputs=LLMResult(generations=[[]], llm_output=None),
             run_type="llm",
         ),
@@ -524,7 +524,7 @@ def test_tracer_nested_runs_on_error() -> None:
             child_execution_order=5,
             serialized=SERIALIZED,
             error=repr(exception),
-            inputs=dict(prompts=[]),
+            inputs=dict(prompt="test"),
             outputs=None,
             run_type="llm",
         )