Compare commits

...

1 Commit

Author SHA1 Message Date
vowelparrot
92b10d98e0 Add Patch to stream runs 2023-05-14 19:42:30 -07:00
4 changed files with 94 additions and 33 deletions

View File

@@ -18,6 +18,8 @@ class TracerException(Exception):
class BaseTracer(BaseCallbackHandler, ABC):
"""Base interface for tracers."""
_supports_patch = False
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
self.run_map: Dict[str, Run] = {}
@@ -30,9 +32,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
"""Add child run to a chain run or tool run."""
parent_run.child_runs.append(child_run)
def _persist_partial_run(self, run: Run) -> None:
"""Persist a run on trace start."""
@abstractmethod
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
"""Persist or patch a run on end or error."""
def _start_trace(self, run: Run) -> None:
"""Start a trace for a run."""
@@ -45,12 +50,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
f"Parent run with UUID {run.parent_run_id} not found."
)
self.run_map[str(run.id)] = run
self._persist_partial_run(run)
def _end_trace(self, run: Run) -> None:
"""End a trace for a run."""
if not run.parent_run_id:
self._persist_run(run)
else:
if self._supports_patch:
self._persist_run(run)
parent_run = self.run_map.get(str(run.parent_run_id))
if parent_run is None:
raise TracerException(
@@ -58,6 +66,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
)
if run.child_execution_order > parent_run.child_execution_order:
parent_run.child_execution_order = run.child_execution_order
self.run_map.pop(str(run.id))
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:

View File

@@ -14,6 +14,7 @@ from langchain.callbacks.tracers.schemas import (
Run,
RunCreate,
RunTypeEnum,
RunUpdate,
TracerSession,
TracerSessionCreate,
)
@@ -53,6 +54,8 @@ def _get_tenant_id(
class LangChainTracer(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
_supports_patch = True
def __init__(
self,
tenant_id: Optional[str] = None,
@@ -124,10 +127,13 @@ class LangChainTracer(BaseTracer):
self.session = TracerSession(**r.json())
return self.session
def _persist_run_nested(self, run: Run) -> None:
"""Persist a run."""
def _persist_partial_run(self, run: Run) -> None:
"""Persist a run on the start of a trace."""
if run.parent_run_id is None:
# If we are tracing examples in a dataset,
# relate only the top-level run to the example.
run.reference_example_id = self.example_id
session = self.ensure_session()
child_runs = run.child_runs
run_dict = run.dict()
del run_dict["child_runs"]
run_create = RunCreate(**run_dict, session_id=session.id)
@@ -140,12 +146,16 @@ class LangChainTracer(BaseTracer):
raise_for_status_with_text(response)
except Exception as e:
logging.warning(f"Failed to persist run: {e}")
for child_run in child_runs:
child_run.parent_run_id = run.id
self._persist_run_nested(child_run)
def _persist_run(self, run: Run) -> None:
"""Persist a run."""
run.reference_example_id = self.example_id
# TODO: Post first then patch
self._persist_run_nested(run)
"""Update a run on the trace end or error."""
update_run = RunUpdate(**run.dict())
try:
response = requests.patch(
f"{self._endpoint}/runs/{run.id}",
data=update_run.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
except Exception as e:
logging.warning(f"Failed to update run: {e}")

View File

@@ -133,9 +133,22 @@ class Run(RunBase):
class RunCreate(RunBase):
"""Run schema for a create request."""
name: str
session_id: UUID
class RunUpdate(BaseModel):
"""Schema for a patch request to update a run."""
end_time: Optional[datetime.datetime]
extra: Optional[Dict]
error: Optional[str]
outputs: Optional[Dict]
parent_run_id: Optional[UUID]
reference_example_id: Optional[UUID]
ChainRun.update_forward_refs()
ToolRun.update_forward_refs()

View File

@@ -19,7 +19,7 @@ _TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
@pytest.fixture
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
def lang_chain_tracer(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
@@ -29,7 +29,7 @@ def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
# Mock a sample TracerSession object
@pytest.fixture
def sample_tracer_session_v2() -> TracerSession:
def sample_tracer_session() -> TracerSession:
return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
@@ -84,28 +84,28 @@ def sample_runs() -> Tuple[Run, Run, Run]:
def test_persist_run(
lang_chain_tracer_v2: LangChainTracer,
sample_tracer_session_v2: TracerSession,
lang_chain_tracer: LangChainTracer,
sample_tracer_session: TracerSession,
sample_runs: Tuple[Run, Run, Run],
) -> None:
"""Test that persist_run method calls requests.post once per method call."""
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
"langchain.callbacks.tracers.langchain.requests.get"
) as get:
post.return_value.raise_for_status.return_value = None
lang_chain_tracer_v2.session = sample_tracer_session_v2
with patch(
"langchain.callbacks.tracers.langchain.requests.patch"
) as req_patch, patch("langchain.callbacks.tracers.langchain.requests.get") as get:
req_patch.return_value.raise_for_status.return_value = None
lang_chain_tracer.session = sample_tracer_session
for run in sample_runs:
lang_chain_tracer_v2.run_map[str(run.id)] = run
lang_chain_tracer.run_map[str(run.id)] = run
for run in sample_runs:
lang_chain_tracer_v2._end_trace(run)
lang_chain_tracer._end_trace(run)
assert post.call_count == 3
assert req_patch.call_count == 3
assert get.call_count == 0
def test_persist_run_with_example_id(
lang_chain_tracer_v2: LangChainTracer,
sample_tracer_session_v2: TracerSession,
def test_persist_partial_run_with_example_id(
lang_chain_tracer: LangChainTracer,
sample_tracer_session: TracerSession,
sample_runs: Tuple[Run, Run, Run],
) -> None:
"""Test the example ID is assigned only to the parent run and not the children."""
@@ -113,22 +113,51 @@ def test_persist_run_with_example_id(
llm_run, chain_run, tool_run = sample_runs
chain_run.child_runs = [tool_run]
tool_run.child_runs = [llm_run]
tool_run.parent_run_id = chain_run.id
llm_run.parent_run_id = tool_run.id
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
"langchain.callbacks.tracers.langchain.requests.get"
) as get:
post.return_value.raise_for_status.return_value = None
lang_chain_tracer_v2.session = sample_tracer_session_v2
lang_chain_tracer_v2.example_id = example_id
lang_chain_tracer_v2._persist_run(chain_run)
lang_chain_tracer.session = sample_tracer_session
lang_chain_tracer.example_id = example_id
lang_chain_tracer._persist_partial_run(llm_run)
lang_chain_tracer._persist_partial_run(tool_run)
lang_chain_tracer._persist_partial_run(chain_run)
assert post.call_count == 3
assert get.call_count == 0
posted_data = [
json.loads(call_args[1]["data"]) for call_args in post.call_args_list
]
assert posted_data[0]["id"] == str(chain_run.id)
assert posted_data[0]["reference_example_id"] == str(example_id)
# Assert that the URL that was called ends with /runs/<run_id>
assert posted_data[0]["id"] == str(llm_run.id)
assert not posted_data[0].get("reference_example_id")
assert posted_data[1]["id"] == str(tool_run.id)
assert not posted_data[1].get("reference_example_id")
assert posted_data[2]["id"] == str(llm_run.id)
assert not posted_data[2].get("reference_example_id")
assert posted_data[2]["id"] == str(chain_run.id)
assert posted_data[2]["reference_example_id"] == str(example_id)
def test_persist_run_with_example_id(
lang_chain_tracer: LangChainTracer,
sample_tracer_session: TracerSession,
sample_runs: Tuple[Run, Run, Run],
) -> None:
"""Test the persist / patch run is called with the correct ID."""
example_id = uuid4()
llm_run, chain_run, tool_run = sample_runs
chain_run.child_runs = [tool_run]
tool_run.child_runs = [llm_run]
with patch(
"langchain.callbacks.tracers.langchain.requests.patch"
) as patch_req, patch("langchain.callbacks.tracers.langchain.requests.get") as get:
patch_req.return_value.raise_for_status.return_value = None
lang_chain_tracer.session = sample_tracer_session
lang_chain_tracer.example_id = example_id
lang_chain_tracer._persist_run(chain_run)
assert patch_req.call_count == 1
assert get.call_count == 0
# Assert that the URL that was called ends with /runs/<run_id>
assert patch_req.call_args[0][0].endswith(f"/runs/{chain_run.id}")