mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 14:43:07 +00:00
Post with child runs
This commit is contained in:
@@ -127,13 +127,22 @@ class LangChainTracer(BaseTracer):
|
||||
self.session = TracerSession(**r.json())
|
||||
return self.session
|
||||
|
||||
def _persist_run_nested(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
def _convert_to_create(self, run: Run) -> RunCreate:
|
||||
"""Convert a Run to a RunCreate."""
|
||||
session = self.ensure_session()
|
||||
child_runs = run.child_runs
|
||||
run_dict = run.dict()
|
||||
del run_dict["child_runs"]
|
||||
run_dict = run.dict(exclude_none=True, exclude={"child_runs", "parent_run_id"})
|
||||
run_create = RunCreate(**run_dict, session_id=session.id)
|
||||
run_create.child_runs = [
|
||||
self._convert_to_create(child_run) for child_run in child_runs
|
||||
]
|
||||
return run_create
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
run.reference_example_id = self.example_id
|
||||
# TODO: Post first then patch
|
||||
run_create = self._convert_to_create(run)
|
||||
try:
|
||||
response = requests.post(
|
||||
f"{self._endpoint}/runs",
|
||||
@@ -143,12 +152,3 @@ class LangChainTracer(BaseTracer):
|
||||
raise_for_status_with_text(response)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to persist run: {e}")
|
||||
for child_run in child_runs:
|
||||
child_run.parent_run_id = run.id
|
||||
self._persist_run_nested(child_run)
|
||||
|
||||
def _persist_run(self, run: Run) -> None:
|
||||
"""Persist a run."""
|
||||
run.reference_example_id = self.example_id
|
||||
# TODO: Post first then patch
|
||||
self._persist_run_nested(run)
|
||||
|
||||
@@ -134,6 +134,7 @@ class Run(RunBase):
|
||||
class RunCreate(RunBase):
|
||||
name: str
|
||||
session_id: UUID
|
||||
child_runs: List[RunCreate] = Field(default_factory=list)
|
||||
|
||||
@root_validator(pre=True)
|
||||
def add_runtime_env(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@@ -146,3 +147,5 @@ class RunCreate(RunBase):
|
||||
|
||||
ChainRun.update_forward_refs()
|
||||
ToolRun.update_forward_refs()
|
||||
Run.update_forward_refs()
|
||||
RunCreate.update_forward_refs()
|
||||
|
||||
@@ -99,7 +99,9 @@ def test_persist_run(
|
||||
for run in sample_runs:
|
||||
lang_chain_tracer_v2._end_trace(run)
|
||||
|
||||
assert post.call_count == 3
|
||||
assert post.call_count == sum(
|
||||
[run.parent_run_id is None for run in sample_runs]
|
||||
)
|
||||
assert get.call_count == 0
|
||||
|
||||
|
||||
@@ -121,14 +123,10 @@ def test_persist_run_with_example_id(
|
||||
lang_chain_tracer_v2.example_id = example_id
|
||||
lang_chain_tracer_v2._persist_run(chain_run)
|
||||
|
||||
assert post.call_count == 3
|
||||
assert post.call_count == 1
|
||||
assert get.call_count == 0
|
||||
posted_data = [
|
||||
json.loads(call_args[1]["data"]) for call_args in post.call_args_list
|
||||
]
|
||||
assert posted_data[0]["id"] == str(chain_run.id)
|
||||
assert posted_data[0]["reference_example_id"] == str(example_id)
|
||||
assert posted_data[1]["id"] == str(tool_run.id)
|
||||
assert not posted_data[1].get("reference_example_id")
|
||||
assert posted_data[2]["id"] == str(llm_run.id)
|
||||
assert not posted_data[2].get("reference_example_id")
|
||||
|
||||
Reference in New Issue
Block a user