mirror of
https://github.com/hwchase17/langchain.git
synced 2026-04-24 04:36:46 +00:00
fix some tests
This commit is contained in:
@@ -112,9 +112,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
if metadata:
|
||||
kwargs.update({"metadata": metadata})
|
||||
if len(prompts) != 1:
|
||||
raise ValueError(
|
||||
"Tracer does not support multiple prompts for LLM runs."
|
||||
)
|
||||
raise ValueError("Tracer does not support multiple prompts for LLM runs.")
|
||||
llm_run = Run(
|
||||
id=run_id,
|
||||
parent_run_id=parent_run_id,
|
||||
|
||||
@@ -13,4 +13,4 @@ def test_collect_runs() -> None:
|
||||
assert cb.traced_runs
|
||||
assert len(cb.traced_runs) == 1
|
||||
assert isinstance(cb.traced_runs[0].id, uuid.UUID)
|
||||
assert cb.traced_runs[0].inputs == {"prompts": ["hi"]}
|
||||
assert cb.traced_runs[0].inputs == {"prompt": "hi"}
|
||||
|
||||
@@ -48,14 +48,14 @@ def test_tracer_llm_run() -> None:
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized=SERIALIZED,
|
||||
inputs={"prompts": []},
|
||||
inputs={"prompt": "test"},
|
||||
outputs=LLMResult(generations=[[]]),
|
||||
error=None,
|
||||
run_type="llm",
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=["test"], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@@ -81,7 +81,7 @@ def test_tracer_chat_model_run() -> None:
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized=SERIALIZED_CHAT,
|
||||
inputs=dict(prompts=["Human: "]),
|
||||
inputs=dict(prompt="Human: "),
|
||||
outputs=LLMResult(generations=[[]]),
|
||||
error=None,
|
||||
run_type="llm",
|
||||
@@ -117,7 +117,7 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
inputs=dict(prompt="test"),
|
||||
outputs=LLMResult(generations=[[]]),
|
||||
error=None,
|
||||
run_type="llm",
|
||||
@@ -126,7 +126,7 @@ def test_tracer_multiple_llm_runs() -> None:
|
||||
|
||||
num_runs = 10
|
||||
for _ in range(num_runs):
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=["test"], run_id=uuid)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
|
||||
|
||||
assert tracer.runs == [compare_run] * num_runs
|
||||
@@ -208,7 +208,7 @@ def test_tracer_nested_run() -> None:
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
prompts=["test"],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=tool_uuid,
|
||||
)
|
||||
@@ -216,7 +216,7 @@ def test_tracer_nested_run() -> None:
|
||||
tracer.on_tool_end("test", run_id=tool_uuid)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
prompts=["test"],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
@@ -272,7 +272,7 @@ def test_tracer_nested_run() -> None:
|
||||
execution_order=3,
|
||||
child_execution_order=3,
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
inputs=dict(prompt="test"),
|
||||
outputs=LLMResult(generations=[[]]),
|
||||
run_type="llm",
|
||||
)
|
||||
@@ -292,7 +292,7 @@ def test_tracer_nested_run() -> None:
|
||||
execution_order=4,
|
||||
child_execution_order=4,
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
inputs=dict(prompt="test"),
|
||||
outputs=LLMResult(generations=[[]]),
|
||||
run_type="llm",
|
||||
),
|
||||
@@ -320,14 +320,14 @@ def test_tracer_llm_run_on_error() -> None:
|
||||
execution_order=1,
|
||||
child_execution_order=1,
|
||||
serialized=SERIALIZED,
|
||||
inputs=dict(prompts=[]),
|
||||
inputs=dict(prompt="test"),
|
||||
outputs=None,
|
||||
error=repr(exception),
|
||||
run_type="llm",
|
||||
)
|
||||
tracer = FakeTracer()
|
||||
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
|
||||
tracer.on_llm_start(serialized=SERIALIZED, prompts=["test"], run_id=uuid)
|
||||
tracer.on_llm_error(exception, run_id=uuid)
|
||||
assert tracer.runs == [compare_run]
|
||||
|
||||
@@ -411,14 +411,14 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
prompts=["test"],
|
||||
run_id=llm_uuid1,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
prompts=["test"],
|
||||
run_id=llm_uuid2,
|
||||
parent_run_id=chain_uuid,
|
||||
)
|
||||
@@ -431,7 +431,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
)
|
||||
tracer.on_llm_start(
|
||||
serialized=SERIALIZED,
|
||||
prompts=[],
|
||||
prompts=["test"],
|
||||
run_id=llm_uuid3,
|
||||
parent_run_id=tool_uuid,
|
||||
)
|
||||
@@ -470,7 +470,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
child_execution_order=2,
|
||||
serialized=SERIALIZED,
|
||||
error=None,
|
||||
inputs=dict(prompts=[]),
|
||||
inputs=dict(prompt="test"),
|
||||
outputs=LLMResult(generations=[[]], llm_output=None),
|
||||
run_type="llm",
|
||||
),
|
||||
@@ -488,7 +488,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
child_execution_order=3,
|
||||
serialized=SERIALIZED,
|
||||
error=None,
|
||||
inputs=dict(prompts=[]),
|
||||
inputs=dict(prompt="test"),
|
||||
outputs=LLMResult(generations=[[]], llm_output=None),
|
||||
run_type="llm",
|
||||
),
|
||||
@@ -524,7 +524,7 @@ def test_tracer_nested_runs_on_error() -> None:
|
||||
child_execution_order=5,
|
||||
serialized=SERIALIZED,
|
||||
error=repr(exception),
|
||||
inputs=dict(prompts=[]),
|
||||
inputs=dict(prompt="test"),
|
||||
outputs=None,
|
||||
run_type="llm",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user