mirror of
https://github.com/hwchase17/langchain.git
synced 2026-02-21 06:33:41 +00:00
Compare commits
1 Commits
langchain-
...
vwp/patch
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
92b10d98e0 |
@@ -18,6 +18,8 @@ class TracerException(Exception):
|
|||||||
class BaseTracer(BaseCallbackHandler, ABC):
|
class BaseTracer(BaseCallbackHandler, ABC):
|
||||||
"""Base interface for tracers."""
|
"""Base interface for tracers."""
|
||||||
|
|
||||||
|
_supports_patch = False
|
||||||
|
|
||||||
def __init__(self, **kwargs: Any) -> None:
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
self.run_map: Dict[str, Run] = {}
|
self.run_map: Dict[str, Run] = {}
|
||||||
@@ -30,9 +32,12 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
"""Add child run to a chain run or tool run."""
|
"""Add child run to a chain run or tool run."""
|
||||||
parent_run.child_runs.append(child_run)
|
parent_run.child_runs.append(child_run)
|
||||||
|
|
||||||
|
def _persist_partial_run(self, run: Run) -> None:
|
||||||
|
"""Persist a run on trace start."""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def _persist_run(self, run: Run) -> None:
|
def _persist_run(self, run: Run) -> None:
|
||||||
"""Persist a run."""
|
"""Persist or patch a run on end or error."""
|
||||||
|
|
||||||
def _start_trace(self, run: Run) -> None:
|
def _start_trace(self, run: Run) -> None:
|
||||||
"""Start a trace for a run."""
|
"""Start a trace for a run."""
|
||||||
@@ -45,12 +50,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
f"Parent run with UUID {run.parent_run_id} not found."
|
f"Parent run with UUID {run.parent_run_id} not found."
|
||||||
)
|
)
|
||||||
self.run_map[str(run.id)] = run
|
self.run_map[str(run.id)] = run
|
||||||
|
self._persist_partial_run(run)
|
||||||
|
|
||||||
def _end_trace(self, run: Run) -> None:
|
def _end_trace(self, run: Run) -> None:
|
||||||
"""End a trace for a run."""
|
"""End a trace for a run."""
|
||||||
if not run.parent_run_id:
|
if not run.parent_run_id:
|
||||||
self._persist_run(run)
|
self._persist_run(run)
|
||||||
else:
|
else:
|
||||||
|
if self._supports_patch:
|
||||||
|
self._persist_run(run)
|
||||||
parent_run = self.run_map.get(str(run.parent_run_id))
|
parent_run = self.run_map.get(str(run.parent_run_id))
|
||||||
if parent_run is None:
|
if parent_run is None:
|
||||||
raise TracerException(
|
raise TracerException(
|
||||||
@@ -58,6 +66,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
|||||||
)
|
)
|
||||||
if run.child_execution_order > parent_run.child_execution_order:
|
if run.child_execution_order > parent_run.child_execution_order:
|
||||||
parent_run.child_execution_order = run.child_execution_order
|
parent_run.child_execution_order = run.child_execution_order
|
||||||
|
|
||||||
self.run_map.pop(str(run.id))
|
self.run_map.pop(str(run.id))
|
||||||
|
|
||||||
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
|
def _get_execution_order(self, parent_run_id: Optional[str] = None) -> int:
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from langchain.callbacks.tracers.schemas import (
|
|||||||
Run,
|
Run,
|
||||||
RunCreate,
|
RunCreate,
|
||||||
RunTypeEnum,
|
RunTypeEnum,
|
||||||
|
RunUpdate,
|
||||||
TracerSession,
|
TracerSession,
|
||||||
TracerSessionCreate,
|
TracerSessionCreate,
|
||||||
)
|
)
|
||||||
@@ -53,6 +54,8 @@ def _get_tenant_id(
|
|||||||
class LangChainTracer(BaseTracer):
|
class LangChainTracer(BaseTracer):
|
||||||
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
|
"""An implementation of the BaseTracer that POSTs runs to the langchain endpoint on trace start and PATCHes them on trace end."""
|
||||||
|
|
||||||
|
_supports_patch = True
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
tenant_id: Optional[str] = None,
|
tenant_id: Optional[str] = None,
|
||||||
@@ -124,10 +127,13 @@ class LangChainTracer(BaseTracer):
|
|||||||
self.session = TracerSession(**r.json())
|
self.session = TracerSession(**r.json())
|
||||||
return self.session
|
return self.session
|
||||||
|
|
||||||
def _persist_run_nested(self, run: Run) -> None:
|
def _persist_partial_run(self, run: Run) -> None:
|
||||||
"""Persist a run."""
|
"""Persist a run on the start of a trace."""
|
||||||
|
if run.parent_run_id is None:
|
||||||
|
# If we are tracing examples in a dataset,
|
||||||
|
# relate only the top-level run to the example.
|
||||||
|
run.reference_example_id = self.example_id
|
||||||
session = self.ensure_session()
|
session = self.ensure_session()
|
||||||
child_runs = run.child_runs
|
|
||||||
run_dict = run.dict()
|
run_dict = run.dict()
|
||||||
del run_dict["child_runs"]
|
del run_dict["child_runs"]
|
||||||
run_create = RunCreate(**run_dict, session_id=session.id)
|
run_create = RunCreate(**run_dict, session_id=session.id)
|
||||||
@@ -140,12 +146,16 @@ class LangChainTracer(BaseTracer):
|
|||||||
raise_for_status_with_text(response)
|
raise_for_status_with_text(response)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to persist run: {e}")
|
logging.warning(f"Failed to persist run: {e}")
|
||||||
for child_run in child_runs:
|
|
||||||
child_run.parent_run_id = run.id
|
|
||||||
self._persist_run_nested(child_run)
|
|
||||||
|
|
||||||
def _persist_run(self, run: Run) -> None:
|
def _persist_run(self, run: Run) -> None:
|
||||||
"""Persist a run."""
|
"""Update a run on the trace end or error."""
|
||||||
run.reference_example_id = self.example_id
|
update_run = RunUpdate(**run.dict())
|
||||||
# TODO: Post first then patch
|
try:
|
||||||
self._persist_run_nested(run)
|
response = requests.patch(
|
||||||
|
f"{self._endpoint}/runs/{run.id}",
|
||||||
|
data=update_run.json(),
|
||||||
|
headers=self._headers,
|
||||||
|
)
|
||||||
|
raise_for_status_with_text(response)
|
||||||
|
except Exception as e:
|
||||||
|
logging.warning(f"Failed to update run: {e}")
|
||||||
|
|||||||
@@ -133,9 +133,22 @@ class Run(RunBase):
|
|||||||
|
|
||||||
|
|
||||||
class RunCreate(RunBase):
|
class RunCreate(RunBase):
|
||||||
|
"""Run schema for a create request."""
|
||||||
|
|
||||||
name: str
|
name: str
|
||||||
session_id: UUID
|
session_id: UUID
|
||||||
|
|
||||||
|
|
||||||
|
class RunUpdate(BaseModel):
|
||||||
|
"""Schema for a patch request to update a run."""
|
||||||
|
|
||||||
|
end_time: Optional[datetime.datetime]
|
||||||
|
extra: Optional[Dict]
|
||||||
|
error: Optional[str]
|
||||||
|
outputs: Optional[Dict]
|
||||||
|
parent_run_id: Optional[UUID]
|
||||||
|
reference_example_id: Optional[UUID]
|
||||||
|
|
||||||
|
|
||||||
ChainRun.update_forward_refs()
|
ChainRun.update_forward_refs()
|
||||||
ToolRun.update_forward_refs()
|
ToolRun.update_forward_refs()
|
||||||
|
|||||||
@@ -19,7 +19,7 @@ _TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
|
def lang_chain_tracer(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
|
||||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
||||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
||||||
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
|
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
|
||||||
@@ -29,7 +29,7 @@ def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
|
|||||||
|
|
||||||
# Mock a sample TracerSession object
|
# Mock a sample TracerSession object
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def sample_tracer_session_v2() -> TracerSession:
|
def sample_tracer_session() -> TracerSession:
|
||||||
return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
|
return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
|
||||||
|
|
||||||
|
|
||||||
@@ -84,28 +84,28 @@ def sample_runs() -> Tuple[Run, Run, Run]:
|
|||||||
|
|
||||||
|
|
||||||
def test_persist_run(
|
def test_persist_run(
|
||||||
lang_chain_tracer_v2: LangChainTracer,
|
lang_chain_tracer: LangChainTracer,
|
||||||
sample_tracer_session_v2: TracerSession,
|
sample_tracer_session: TracerSession,
|
||||||
sample_runs: Tuple[Run, Run, Run],
|
sample_runs: Tuple[Run, Run, Run],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that persist_run method calls requests.post once per method call."""
|
"""Test that persist_run method calls requests.patch once per method call."""
|
||||||
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
with patch(
|
||||||
"langchain.callbacks.tracers.langchain.requests.get"
|
"langchain.callbacks.tracers.langchain.requests.patch"
|
||||||
) as get:
|
) as req_patch, patch("langchain.callbacks.tracers.langchain.requests.get") as get:
|
||||||
post.return_value.raise_for_status.return_value = None
|
req_patch.return_value.raise_for_status.return_value = None
|
||||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
lang_chain_tracer.session = sample_tracer_session
|
||||||
for run in sample_runs:
|
for run in sample_runs:
|
||||||
lang_chain_tracer_v2.run_map[str(run.id)] = run
|
lang_chain_tracer.run_map[str(run.id)] = run
|
||||||
for run in sample_runs:
|
for run in sample_runs:
|
||||||
lang_chain_tracer_v2._end_trace(run)
|
lang_chain_tracer._end_trace(run)
|
||||||
|
|
||||||
assert post.call_count == 3
|
assert req_patch.call_count == 3
|
||||||
assert get.call_count == 0
|
assert get.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
def test_persist_run_with_example_id(
|
def test_persist_partial_run_with_example_id(
|
||||||
lang_chain_tracer_v2: LangChainTracer,
|
lang_chain_tracer: LangChainTracer,
|
||||||
sample_tracer_session_v2: TracerSession,
|
sample_tracer_session: TracerSession,
|
||||||
sample_runs: Tuple[Run, Run, Run],
|
sample_runs: Tuple[Run, Run, Run],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the example ID is assigned only to the parent run and not the children."""
|
"""Test the example ID is assigned only to the parent run and not the children."""
|
||||||
@@ -113,22 +113,51 @@ def test_persist_run_with_example_id(
|
|||||||
llm_run, chain_run, tool_run = sample_runs
|
llm_run, chain_run, tool_run = sample_runs
|
||||||
chain_run.child_runs = [tool_run]
|
chain_run.child_runs = [tool_run]
|
||||||
tool_run.child_runs = [llm_run]
|
tool_run.child_runs = [llm_run]
|
||||||
|
tool_run.parent_run_id = chain_run.id
|
||||||
|
llm_run.parent_run_id = tool_run.id
|
||||||
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
||||||
"langchain.callbacks.tracers.langchain.requests.get"
|
"langchain.callbacks.tracers.langchain.requests.get"
|
||||||
) as get:
|
) as get:
|
||||||
post.return_value.raise_for_status.return_value = None
|
post.return_value.raise_for_status.return_value = None
|
||||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
lang_chain_tracer.session = sample_tracer_session
|
||||||
lang_chain_tracer_v2.example_id = example_id
|
lang_chain_tracer.example_id = example_id
|
||||||
lang_chain_tracer_v2._persist_run(chain_run)
|
lang_chain_tracer._persist_partial_run(llm_run)
|
||||||
|
lang_chain_tracer._persist_partial_run(tool_run)
|
||||||
|
lang_chain_tracer._persist_partial_run(chain_run)
|
||||||
|
|
||||||
assert post.call_count == 3
|
assert post.call_count == 3
|
||||||
assert get.call_count == 0
|
assert get.call_count == 0
|
||||||
posted_data = [
|
posted_data = [
|
||||||
json.loads(call_args[1]["data"]) for call_args in post.call_args_list
|
json.loads(call_args[1]["data"]) for call_args in post.call_args_list
|
||||||
]
|
]
|
||||||
assert posted_data[0]["id"] == str(chain_run.id)
|
# Runs were posted in call order (llm, tool, chain); only the top-level chain run carries the example ID.
|
||||||
assert posted_data[0]["reference_example_id"] == str(example_id)
|
assert posted_data[0]["id"] == str(llm_run.id)
|
||||||
|
assert not posted_data[0].get("reference_example_id")
|
||||||
assert posted_data[1]["id"] == str(tool_run.id)
|
assert posted_data[1]["id"] == str(tool_run.id)
|
||||||
assert not posted_data[1].get("reference_example_id")
|
assert not posted_data[1].get("reference_example_id")
|
||||||
assert posted_data[2]["id"] == str(llm_run.id)
|
assert posted_data[2]["id"] == str(chain_run.id)
|
||||||
assert not posted_data[2].get("reference_example_id")
|
assert posted_data[2]["reference_example_id"] == str(example_id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_persist_run_with_example_id(
|
||||||
|
lang_chain_tracer: LangChainTracer,
|
||||||
|
sample_tracer_session: TracerSession,
|
||||||
|
sample_runs: Tuple[Run, Run, Run],
|
||||||
|
) -> None:
|
||||||
|
"""Test the persist / patch run is called with the correct ID."""
|
||||||
|
example_id = uuid4()
|
||||||
|
llm_run, chain_run, tool_run = sample_runs
|
||||||
|
chain_run.child_runs = [tool_run]
|
||||||
|
tool_run.child_runs = [llm_run]
|
||||||
|
with patch(
|
||||||
|
"langchain.callbacks.tracers.langchain.requests.patch"
|
||||||
|
) as patch_req, patch("langchain.callbacks.tracers.langchain.requests.get") as get:
|
||||||
|
patch_req.return_value.raise_for_status.return_value = None
|
||||||
|
lang_chain_tracer.session = sample_tracer_session
|
||||||
|
lang_chain_tracer.example_id = example_id
|
||||||
|
lang_chain_tracer._persist_run(chain_run)
|
||||||
|
|
||||||
|
assert patch_req.call_count == 1
|
||||||
|
assert get.call_count == 0
|
||||||
|
# Assert that the URL that was called ends with /runs/<run_id>
|
||||||
|
assert patch_req.call_args[0][0].endswith(f"/runs/{chain_run.id}")
|
||||||
|
|||||||
Reference in New Issue
Block a user