mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-11 05:45:01 +00:00
py tracer fixes (#5377)
This commit is contained in:
parent
ce8b7a2a69
commit
1671c2afb2
@ -347,7 +347,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 12,
|
"execution_count": 7,
|
||||||
"id": "87027b0d-3a61-47cf-8a65-3002968be7f9",
|
"id": "87027b0d-3a61-47cf-8a65-3002968be7f9",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -356,13 +356,13 @@
|
|||||||
"source": [
|
"source": [
|
||||||
"import os\n",
|
"import os\n",
|
||||||
"os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
|
"os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
|
||||||
"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchainpro-api-gateway-12bfv6cf.uc.gateway.dev\" # Uncomment this line if you want to use the hosted version\n",
|
"# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.langchain.plus\" # Uncomment this line if you want to use the hosted version\n",
|
||||||
"# os.environ[\"LANGCHAIN_API_KEY\"] = \"<YOUR-LANGCHAINPLUS-API-KEY>\" # Uncomment this line if you want to use the hosted version."
|
"# os.environ[\"LANGCHAIN_API_KEY\"] = \"<YOUR-LANGCHAINPLUS-API-KEY>\" # Uncomment this line if you want to use the hosted version."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 13,
|
"execution_count": 8,
|
||||||
"id": "5b4f49a2-7d09-4601-a8ba-976f0517c64c",
|
"id": "5b4f49a2-7d09-4601-a8ba-976f0517c64c",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -379,7 +379,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 14,
|
"execution_count": 9,
|
||||||
"id": "029b4a57-dc49-49de-8f03-53c292144e09",
|
"id": "029b4a57-dc49-49de-8f03-53c292144e09",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -397,7 +397,7 @@
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"cell_type": "code",
|
"cell_type": "code",
|
||||||
"execution_count": 15,
|
"execution_count": 10,
|
||||||
"id": "91a85fb2-6027-4bd0-b1fe-2a3b3b79e2dd",
|
"id": "91a85fb2-6027-4bd0-b1fe-2a3b3b79e2dd",
|
||||||
"metadata": {
|
"metadata": {
|
||||||
"tags": []
|
"tags": []
|
||||||
@ -426,7 +426,7 @@
|
|||||||
"'1.0891804557407723'"
|
"'1.0891804557407723'"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
"execution_count": 15,
|
"execution_count": 10,
|
||||||
"metadata": {},
|
"metadata": {},
|
||||||
"output_type": "execute_result"
|
"output_type": "execute_result"
|
||||||
}
|
}
|
||||||
|
@ -3,24 +3,35 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from tenacity import retry, stop_after_attempt, wait_fixed
|
from requests.exceptions import HTTPError
|
||||||
|
from tenacity import (
|
||||||
|
before_sleep_log,
|
||||||
|
retry,
|
||||||
|
retry_if_exception_type,
|
||||||
|
stop_after_attempt,
|
||||||
|
wait_exponential,
|
||||||
|
)
|
||||||
|
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
from langchain.callbacks.tracers.schemas import (
|
from langchain.callbacks.tracers.schemas import (
|
||||||
Run,
|
Run,
|
||||||
RunCreate,
|
RunCreate,
|
||||||
RunTypeEnum,
|
RunTypeEnum,
|
||||||
|
RunUpdate,
|
||||||
TracerSession,
|
TracerSession,
|
||||||
TracerSessionCreate,
|
TracerSessionCreate,
|
||||||
)
|
)
|
||||||
from langchain.schema import BaseMessage, messages_to_dict
|
from langchain.schema import BaseMessage, messages_to_dict
|
||||||
from langchain.utils import raise_for_status_with_text
|
from langchain.utils import raise_for_status_with_text
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_headers() -> Dict[str, Any]:
|
def get_headers() -> Dict[str, Any]:
|
||||||
"""Get the headers for the LangChain API."""
|
"""Get the headers for the LangChain API."""
|
||||||
@ -34,7 +45,27 @@ def get_endpoint() -> str:
|
|||||||
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
|
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")
|
||||||
|
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
class LangChainTracerAPIError(Exception):
|
||||||
|
"""An error occurred while communicating with the LangChain API."""
|
||||||
|
|
||||||
|
|
||||||
|
class LangChainTracerUserError(Exception):
|
||||||
|
"""An error occurred while communicating with the LangChain API."""
|
||||||
|
|
||||||
|
|
||||||
|
class LangChainTracerError(Exception):
|
||||||
|
"""An error occurred while communicating with the LangChain API."""
|
||||||
|
|
||||||
|
|
||||||
|
retry_decorator = retry(
|
||||||
|
stop=stop_after_attempt(3),
|
||||||
|
wait=wait_exponential(multiplier=1, min=4, max=10),
|
||||||
|
retry=retry_if_exception_type(LangChainTracerAPIError),
|
||||||
|
before_sleep=before_sleep_log(logger, logging.WARNING),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
def _get_tenant_id(
|
def _get_tenant_id(
|
||||||
tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
|
tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -44,8 +75,24 @@ def _get_tenant_id(
|
|||||||
return tenant_id_
|
return tenant_id_
|
||||||
endpoint_ = endpoint or get_endpoint()
|
endpoint_ = endpoint or get_endpoint()
|
||||||
headers_ = headers or get_headers()
|
headers_ = headers or get_headers()
|
||||||
|
response = None
|
||||||
|
try:
|
||||||
response = requests.get(endpoint_ + "/tenants", headers=headers_)
|
response = requests.get(endpoint_ + "/tenants", headers=headers_)
|
||||||
raise_for_status_with_text(response)
|
raise_for_status_with_text(response)
|
||||||
|
except HTTPError as e:
|
||||||
|
if response is not None and response.status_code == 500:
|
||||||
|
raise LangChainTracerAPIError(
|
||||||
|
f"Failed to get tenant ID from LangChain API. {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise LangChainTracerUserError(
|
||||||
|
f"Failed to get tenant ID from LangChain API. {e}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise LangChainTracerError(
|
||||||
|
f"Failed to get tenant ID from LangChain API. {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
tenants: List[Dict[str, Any]] = response.json()
|
tenants: List[Dict[str, Any]] = response.json()
|
||||||
if not tenants:
|
if not tenants:
|
||||||
raise ValueError(f"No tenants found for URL {endpoint_}")
|
raise ValueError(f"No tenants found for URL {endpoint_}")
|
||||||
@ -72,6 +119,8 @@ class LangChainTracer(BaseTracer):
|
|||||||
self.example_id = example_id
|
self.example_id = example_id
|
||||||
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
|
self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
|
||||||
self.session_extra = session_extra
|
self.session_extra = session_extra
|
||||||
|
# set max_workers to 1 to process tasks in order
|
||||||
|
self.executor = ThreadPoolExecutor(max_workers=1)
|
||||||
|
|
||||||
def on_chat_model_start(
|
def on_chat_model_start(
|
||||||
self,
|
self,
|
||||||
@ -108,7 +157,7 @@ class LangChainTracer(BaseTracer):
|
|||||||
self.tenant_id = tenant_id
|
self.tenant_id = tenant_id
|
||||||
return tenant_id
|
return tenant_id
|
||||||
|
|
||||||
@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
|
@retry_decorator
|
||||||
def ensure_session(self) -> TracerSession:
|
def ensure_session(self) -> TracerSession:
|
||||||
"""Upsert a session."""
|
"""Upsert a session."""
|
||||||
if self.session is not None:
|
if self.session is not None:
|
||||||
@ -118,37 +167,124 @@ class LangChainTracer(BaseTracer):
|
|||||||
session_create = TracerSessionCreate(
|
session_create = TracerSessionCreate(
|
||||||
name=self.session_name, extra=self.session_extra, tenant_id=tenant_id
|
name=self.session_name, extra=self.session_extra, tenant_id=tenant_id
|
||||||
)
|
)
|
||||||
r = requests.post(
|
response = None
|
||||||
|
try:
|
||||||
|
response = requests.post(
|
||||||
url,
|
url,
|
||||||
data=session_create.json(),
|
data=session_create.json(),
|
||||||
headers=self._headers,
|
headers=self._headers,
|
||||||
)
|
)
|
||||||
raise_for_status_with_text(r)
|
response.raise_for_status()
|
||||||
self.session = TracerSession(**r.json())
|
except HTTPError as e:
|
||||||
|
if response is not None and response.status_code == 500:
|
||||||
|
raise LangChainTracerAPIError(
|
||||||
|
f"Failed to upsert session to LangChain API. {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise LangChainTracerUserError(
|
||||||
|
f"Failed to upsert session to LangChain API. {e}"
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
raise LangChainTracerError(
|
||||||
|
f"Failed to upsert session to LangChain API. {e}"
|
||||||
|
) from e
|
||||||
|
self.session = TracerSession(**response.json())
|
||||||
return self.session
|
return self.session
|
||||||
|
|
||||||
def _persist_run_nested(self, run: Run) -> None:
|
def _persist_run(self, run: Run) -> None:
|
||||||
|
"""Persist a run."""
|
||||||
|
|
||||||
|
@retry_decorator
|
||||||
|
def _persist_run_single(self, run: Run) -> None:
|
||||||
"""Persist a run."""
|
"""Persist a run."""
|
||||||
session = self.ensure_session()
|
session = self.ensure_session()
|
||||||
child_runs = run.child_runs
|
if run.parent_run_id is None:
|
||||||
|
run.reference_example_id = self.example_id
|
||||||
run_dict = run.dict()
|
run_dict = run.dict()
|
||||||
del run_dict["child_runs"]
|
del run_dict["child_runs"]
|
||||||
run_create = RunCreate(**run_dict, session_id=session.id)
|
run_create = RunCreate(**run_dict, session_id=session.id)
|
||||||
|
response = None
|
||||||
try:
|
try:
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
f"{self._endpoint}/runs",
|
f"{self._endpoint}/runs",
|
||||||
data=run_create.json(),
|
data=run_create.json(),
|
||||||
headers=self._headers,
|
headers=self._headers,
|
||||||
)
|
)
|
||||||
raise_for_status_with_text(response)
|
response.raise_for_status()
|
||||||
|
except HTTPError as e:
|
||||||
|
if response is not None and response.status_code == 500:
|
||||||
|
raise LangChainTracerAPIError(
|
||||||
|
f"Failed to upsert persist run to LangChain API. {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise LangChainTracerUserError(
|
||||||
|
f"Failed to persist run to LangChain API. {e}"
|
||||||
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning(f"Failed to persist run: {e}")
|
raise LangChainTracerError(
|
||||||
for child_run in child_runs:
|
f"Failed to persist run to LangChain API. {e}"
|
||||||
child_run.parent_run_id = run.id
|
) from e
|
||||||
self._persist_run_nested(child_run)
|
|
||||||
|
|
||||||
def _persist_run(self, run: Run) -> None:
|
@retry_decorator
|
||||||
"""Persist a run."""
|
def _update_run_single(self, run: Run) -> None:
|
||||||
run.reference_example_id = self.example_id
|
"""Update a run."""
|
||||||
# TODO: Post first then patch
|
run_update = RunUpdate(**run.dict())
|
||||||
self._persist_run_nested(run)
|
response = None
|
||||||
|
try:
|
||||||
|
response = requests.patch(
|
||||||
|
f"{self._endpoint}/runs/{run.id}",
|
||||||
|
data=run_update.json(),
|
||||||
|
headers=self._headers,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
except HTTPError as e:
|
||||||
|
if response is not None and response.status_code == 500:
|
||||||
|
raise LangChainTracerAPIError(
|
||||||
|
f"Failed to update run to LangChain API. {e}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise LangChainTracerUserError(f"Failed to run to LangChain API. {e}")
|
||||||
|
except Exception as e:
|
||||||
|
raise LangChainTracerError(
|
||||||
|
f"Failed to update run to LangChain API. {e}"
|
||||||
|
) from e
|
||||||
|
|
||||||
|
def _on_llm_start(self, run: Run) -> None:
|
||||||
|
"""Persist an LLM run."""
|
||||||
|
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||||
|
|
||||||
|
def _on_chat_model_start(self, run: Run) -> None:
|
||||||
|
"""Persist an LLM run."""
|
||||||
|
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||||
|
|
||||||
|
def _on_llm_end(self, run: Run) -> None:
|
||||||
|
"""Process the LLM Run."""
|
||||||
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
|
||||||
|
def _on_llm_error(self, run: Run) -> None:
|
||||||
|
"""Process the LLM Run upon error."""
|
||||||
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
|
||||||
|
def _on_chain_start(self, run: Run) -> None:
|
||||||
|
"""Process the Chain Run upon start."""
|
||||||
|
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||||
|
|
||||||
|
def _on_chain_end(self, run: Run) -> None:
|
||||||
|
"""Process the Chain Run."""
|
||||||
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
|
||||||
|
def _on_chain_error(self, run: Run) -> None:
|
||||||
|
"""Process the Chain Run upon error."""
|
||||||
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
|
||||||
|
def _on_tool_start(self, run: Run) -> None:
|
||||||
|
"""Process the Tool Run upon start."""
|
||||||
|
self.executor.submit(self._persist_run_single, run.copy(deep=True))
|
||||||
|
|
||||||
|
def _on_tool_end(self, run: Run) -> None:
|
||||||
|
"""Process the Tool Run."""
|
||||||
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
|
||||||
|
def _on_tool_error(self, run: Run) -> None:
|
||||||
|
"""Process the Tool Run upon error."""
|
||||||
|
self.executor.submit(self._update_run_single, run.copy(deep=True))
|
||||||
|
@ -91,6 +91,9 @@ class ToolRun(BaseRun):
|
|||||||
child_tool_runs: List[ToolRun] = Field(default_factory=list)
|
child_tool_runs: List[ToolRun] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# Begin V2 API Schemas
|
||||||
|
|
||||||
|
|
||||||
class RunTypeEnum(str, Enum):
|
class RunTypeEnum(str, Enum):
|
||||||
"""Enum for run types."""
|
"""Enum for run types."""
|
||||||
|
|
||||||
@ -105,7 +108,7 @@ class RunBase(BaseModel):
|
|||||||
id: Optional[UUID]
|
id: Optional[UUID]
|
||||||
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||||
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
|
||||||
extra: dict
|
extra: Optional[Dict[str, Any]] = None
|
||||||
error: Optional[str]
|
error: Optional[str]
|
||||||
execution_order: int
|
execution_order: int
|
||||||
child_execution_order: Optional[int]
|
child_execution_order: Optional[int]
|
||||||
@ -144,5 +147,13 @@ class RunCreate(RunBase):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class RunUpdate(BaseModel):
|
||||||
|
end_time: Optional[datetime.datetime]
|
||||||
|
error: Optional[str]
|
||||||
|
outputs: Optional[dict]
|
||||||
|
parent_run_id: Optional[UUID]
|
||||||
|
reference_example_id: Optional[UUID]
|
||||||
|
|
||||||
|
|
||||||
ChainRun.update_forward_refs()
|
ChainRun.update_forward_refs()
|
||||||
ToolRun.update_forward_refs()
|
ToolRun.update_forward_refs()
|
||||||
|
@ -8,6 +8,7 @@ from aiohttp import ClientSession
|
|||||||
from langchain.agents import AgentType, initialize_agent, load_tools
|
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||||
from langchain.callbacks import tracing_enabled
|
from langchain.callbacks import tracing_enabled
|
||||||
from langchain.callbacks.manager import tracing_v2_enabled
|
from langchain.callbacks.manager import tracing_v2_enabled
|
||||||
|
from langchain.chat_models import ChatOpenAI
|
||||||
from langchain.llms import OpenAI
|
from langchain.llms import OpenAI
|
||||||
|
|
||||||
questions = [
|
questions = [
|
||||||
@ -140,10 +141,10 @@ async def test_tracing_v2_environment_variable() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_tracing_v2_context_manager() -> None:
|
def test_tracing_v2_context_manager() -> None:
|
||||||
llm = OpenAI(temperature=0)
|
llm = ChatOpenAI(temperature=0)
|
||||||
tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
tools = load_tools(["llm-math", "serpapi"], llm=llm)
|
||||||
agent = initialize_agent(
|
agent = initialize_agent(
|
||||||
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||||
)
|
)
|
||||||
if "LANGCHAIN_TRACING_V2" in os.environ:
|
if "LANGCHAIN_TRACING_V2" in os.environ:
|
||||||
del os.environ["LANGCHAIN_TRACING_V2"]
|
del os.environ["LANGCHAIN_TRACING_V2"]
|
||||||
|
@ -1,134 +0,0 @@
|
|||||||
"""Test Tracer classes."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime
|
|
||||||
from typing import Tuple
|
|
||||||
from unittest.mock import patch
|
|
||||||
from uuid import UUID, uuid4
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from freezegun import freeze_time
|
|
||||||
|
|
||||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
|
||||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
|
|
||||||
from langchain.schema import LLMResult
|
|
||||||
|
|
||||||
_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
|
|
||||||
_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
|
|
||||||
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
|
|
||||||
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
|
|
||||||
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
|
|
||||||
tracer = LangChainTracer()
|
|
||||||
return tracer
|
|
||||||
|
|
||||||
|
|
||||||
# Mock a sample TracerSession object
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_tracer_session_v2() -> TracerSession:
|
|
||||||
return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
|
|
||||||
|
|
||||||
|
|
||||||
@freeze_time("2023-01-01")
|
|
||||||
@pytest.fixture
|
|
||||||
def sample_runs() -> Tuple[Run, Run, Run]:
|
|
||||||
llm_run = Run(
|
|
||||||
id="57a08cc4-73d2-4236-8370-549099d07fad",
|
|
||||||
name="llm_run",
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=1,
|
|
||||||
parent_run_id="57a08cc4-73d2-4236-8371-549099d07fad",
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
session_id=1,
|
|
||||||
inputs={"prompts": []},
|
|
||||||
outputs=LLMResult(generations=[[]]).dict(),
|
|
||||||
serialized={},
|
|
||||||
extra={},
|
|
||||||
run_type=RunTypeEnum.llm,
|
|
||||||
)
|
|
||||||
chain_run = Run(
|
|
||||||
id="57a08cc4-73d2-4236-8371-549099d07fad",
|
|
||||||
name="chain_run",
|
|
||||||
execution_order=1,
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
child_execution_order=1,
|
|
||||||
serialized={},
|
|
||||||
inputs={},
|
|
||||||
outputs={},
|
|
||||||
child_runs=[llm_run],
|
|
||||||
extra={},
|
|
||||||
run_type=RunTypeEnum.chain,
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_run = Run(
|
|
||||||
id="57a08cc4-73d2-4236-8372-549099d07fad",
|
|
||||||
name="tool_run",
|
|
||||||
execution_order=1,
|
|
||||||
child_execution_order=1,
|
|
||||||
inputs={"input": "test"},
|
|
||||||
start_time=datetime.utcnow(),
|
|
||||||
end_time=datetime.utcnow(),
|
|
||||||
outputs=None,
|
|
||||||
serialized={},
|
|
||||||
child_runs=[],
|
|
||||||
extra={},
|
|
||||||
run_type=RunTypeEnum.tool,
|
|
||||||
)
|
|
||||||
return llm_run, chain_run, tool_run
|
|
||||||
|
|
||||||
|
|
||||||
def test_persist_run(
|
|
||||||
lang_chain_tracer_v2: LangChainTracer,
|
|
||||||
sample_tracer_session_v2: TracerSession,
|
|
||||||
sample_runs: Tuple[Run, Run, Run],
|
|
||||||
) -> None:
|
|
||||||
"""Test that persist_run method calls requests.post once per method call."""
|
|
||||||
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
|
||||||
"langchain.callbacks.tracers.langchain.requests.get"
|
|
||||||
) as get:
|
|
||||||
post.return_value.raise_for_status.return_value = None
|
|
||||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
|
||||||
for run in sample_runs:
|
|
||||||
lang_chain_tracer_v2.run_map[str(run.id)] = run
|
|
||||||
for run in sample_runs:
|
|
||||||
lang_chain_tracer_v2._end_trace(run)
|
|
||||||
|
|
||||||
assert post.call_count == 3
|
|
||||||
assert get.call_count == 0
|
|
||||||
|
|
||||||
|
|
||||||
def test_persist_run_with_example_id(
|
|
||||||
lang_chain_tracer_v2: LangChainTracer,
|
|
||||||
sample_tracer_session_v2: TracerSession,
|
|
||||||
sample_runs: Tuple[Run, Run, Run],
|
|
||||||
) -> None:
|
|
||||||
"""Test the example ID is assigned only to the parent run and not the children."""
|
|
||||||
example_id = uuid4()
|
|
||||||
llm_run, chain_run, tool_run = sample_runs
|
|
||||||
chain_run.child_runs = [tool_run]
|
|
||||||
tool_run.child_runs = [llm_run]
|
|
||||||
with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
|
|
||||||
"langchain.callbacks.tracers.langchain.requests.get"
|
|
||||||
) as get:
|
|
||||||
post.return_value.raise_for_status.return_value = None
|
|
||||||
lang_chain_tracer_v2.session = sample_tracer_session_v2
|
|
||||||
lang_chain_tracer_v2.example_id = example_id
|
|
||||||
lang_chain_tracer_v2._persist_run(chain_run)
|
|
||||||
|
|
||||||
assert post.call_count == 3
|
|
||||||
assert get.call_count == 0
|
|
||||||
posted_data = [
|
|
||||||
json.loads(call_args[1]["data"]) for call_args in post.call_args_list
|
|
||||||
]
|
|
||||||
assert posted_data[0]["id"] == str(chain_run.id)
|
|
||||||
assert posted_data[0]["reference_example_id"] == str(example_id)
|
|
||||||
assert posted_data[1]["id"] == str(tool_run.id)
|
|
||||||
assert not posted_data[1].get("reference_example_id")
|
|
||||||
assert posted_data[2]["id"] == str(llm_run.id)
|
|
||||||
assert not posted_data[2].get("reference_example_id")
|
|
Loading…
Reference in New Issue
Block a user