Compare commits

...

1 Commits

Author SHA1 Message Date
vowelparrot
0bdf280469 Post with child runs 2023-05-28 16:03:40 -07:00
3 changed files with 20 additions and 19 deletions

View File

@@ -127,13 +127,22 @@ class LangChainTracer(BaseTracer):
self.session = TracerSession(**r.json())
return self.session
def _persist_run_nested(self, run: Run) -> None:
"""Persist a run."""
def _convert_to_create(self, run: Run) -> RunCreate:
"""Convert a Run to a RunCreate."""
session = self.ensure_session()
child_runs = run.child_runs
run_dict = run.dict()
del run_dict["child_runs"]
run_dict = run.dict(exclude_none=True, exclude={"child_runs", "parent_run_id"})
run_create = RunCreate(**run_dict, session_id=session.id)
run_create.child_runs = [
self._convert_to_create(child_run) for child_run in child_runs
]
return run_create
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
run.reference_example_id = self.example_id
# TODO: Post first then patch
run_create = self._convert_to_create(run)
try:
response = requests.post(
f"{self._endpoint}/runs",
@@ -143,12 +152,3 @@ class LangChainTracer(BaseTracer):
raise_for_status_with_text(response)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
for child_run in child_runs:
child_run.parent_run_id = run.id
self._persist_run_nested(child_run)
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
run.reference_example_id = self.example_id
# TODO: Post first then patch
self._persist_run_nested(run)

View File

@@ -134,6 +134,7 @@ class Run(RunBase):
class RunCreate(RunBase):
name: str
session_id: UUID
child_runs: List[RunCreate] = Field(default_factory=list)
@root_validator(pre=True)
def add_runtime_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@@ -146,3 +147,5 @@ class RunCreate(RunBase):
ChainRun.update_forward_refs()
ToolRun.update_forward_refs()
Run.update_forward_refs()
RunCreate.update_forward_refs()

View File

@@ -99,7 +99,9 @@ def test_persist_run(
for run in sample_runs:
lang_chain_tracer_v2._end_trace(run)
assert post.call_count == 3
assert post.call_count == sum(
[run.parent_run_id is None for run in sample_runs]
)
assert get.call_count == 0
@@ -121,14 +123,10 @@ def test_persist_run_with_example_id(
lang_chain_tracer_v2.example_id = example_id
lang_chain_tracer_v2._persist_run(chain_run)
assert post.call_count == 3
assert post.call_count == 1
assert get.call_count == 0
posted_data = [
json.loads(call_args[1]["data"]) for call_args in post.call_args_list
]
assert posted_data[0]["id"] == str(chain_run.id)
assert posted_data[0]["reference_example_id"] == str(example_id)
assert posted_data[1]["id"] == str(tool_run.id)
assert not posted_data[1].get("reference_example_id")
assert posted_data[2]["id"] == str(llm_run.id)
assert not posted_data[2].get("reference_example_id")