Mirror of https://github.com/hwchase17/langchain.git (synced 2025-07-31 00:29:57 +00:00)

Commit 1671c2afb2: py tracer fixes (#5377)
Parent: ce8b7a2a69
@@ -347,7 +347,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 12,
+   "execution_count": 7,
    "id": "87027b0d-3a61-47cf-8a65-3002968be7f9",
    "metadata": {
     "tags": []
@@ -356,13 +356,13 @@
    "source": [
     "import os\n",
     "os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
-    "# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://langchainpro-api-gateway-12bfv6cf.uc.gateway.dev\" # Uncomment this line if you want to use the hosted version\n",
+    "# os.environ[\"LANGCHAIN_ENDPOINT\"] = \"https://api.langchain.plus\" # Uncomment this line if you want to use the hosted version\n",
     "# os.environ[\"LANGCHAIN_API_KEY\"] = \"<YOUR-LANGCHAINPLUS-API-KEY>\" # Uncomment this line if you want to use the hosted version."
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 13,
+   "execution_count": 8,
    "id": "5b4f49a2-7d09-4601-a8ba-976f0517c64c",
    "metadata": {
     "tags": []
@@ -379,7 +379,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 14,
+   "execution_count": 9,
    "id": "029b4a57-dc49-49de-8f03-53c292144e09",
    "metadata": {
     "tags": []
@@ -397,7 +397,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 15,
+   "execution_count": 10,
    "id": "91a85fb2-6027-4bd0-b1fe-2a3b3b79e2dd",
    "metadata": {
     "tags": []
@@ -426,7 +426,7 @@
       "'1.0891804557407723'"
      ]
     },
-    "execution_count": 15,
+    "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
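The notebook cells above point the hosted endpoint at https://api.langchain.plus and renumber execution counts. As plain Python, the configuration the updated cells land on looks roughly like this (a minimal sketch; the endpoint and key values are placeholders copied from the cell):

import os

# Turn on the v2 tracing protocol for this process.
os.environ["LANGCHAIN_TRACING_V2"] = "true"

# Optional: target the hosted service instead of a local tracer server.
# os.environ["LANGCHAIN_ENDPOINT"] = "https://api.langchain.plus"
# os.environ["LANGCHAIN_API_KEY"] = "<YOUR-LANGCHAINPLUS-API-KEY>"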
@@ -3,24 +3,35 @@ from __future__ import annotations

+import logging
 import os
 from concurrent.futures import ThreadPoolExecutor
 from datetime import datetime
 from typing import Any, Dict, List, Optional
 from uuid import UUID

 import requests
-from tenacity import retry, stop_after_attempt, wait_fixed
+from requests.exceptions import HTTPError
+from tenacity import (
+    before_sleep_log,
+    retry,
+    retry_if_exception_type,
+    stop_after_attempt,
+    wait_exponential,
+)

 from langchain.callbacks.tracers.base import BaseTracer
 from langchain.callbacks.tracers.schemas import (
     Run,
     RunCreate,
     RunTypeEnum,
+    RunUpdate,
     TracerSession,
     TracerSessionCreate,
 )
 from langchain.schema import BaseMessage, messages_to_dict
 from langchain.utils import raise_for_status_with_text

+logger = logging.getLogger(__name__)
+

 def get_headers() -> Dict[str, Any]:
     """Get the headers for the LangChain API."""
@@ -34,7 +45,27 @@ def get_endpoint() -> str:
     return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:1984")


-@retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
+class LangChainTracerAPIError(Exception):
+    """An error occurred while communicating with the LangChain API."""
+
+
+class LangChainTracerUserError(Exception):
+    """An error occurred while communicating with the LangChain API."""
+
+
+class LangChainTracerError(Exception):
+    """An error occurred while communicating with the LangChain API."""
+
+
+retry_decorator = retry(
+    stop=stop_after_attempt(3),
+    wait=wait_exponential(multiplier=1, min=4, max=10),
+    retry=retry_if_exception_type(LangChainTracerAPIError),
+    before_sleep=before_sleep_log(logger, logging.WARNING),
+)
+
+
+@retry_decorator
 def _get_tenant_id(
     tenant_id: Optional[str], endpoint: Optional[str], headers: Optional[dict]
 ) -> str:
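The shared retry_decorator above replaces the old fixed 0.5s wait: it retries only LangChainTracerAPIError, backs off exponentially between 4 and 10 seconds for up to three attempts, and logs a warning before each sleep. A minimal, self-contained sketch of the same policy (the function flaky_call and its failure counter are illustrative, not part of the commit):

import logging

from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

logger = logging.getLogger(__name__)


class TransientAPIError(Exception):
    """Stand-in for LangChainTracerAPIError: the only retryable error."""


_calls = {"n": 0}


@retry(
    stop=stop_after_attempt(3),  # give up after three attempts
    wait=wait_exponential(multiplier=1, min=4, max=10),  # back off 4s..10s
    retry=retry_if_exception_type(TransientAPIError),  # other errors propagate
    before_sleep=before_sleep_log(logger, logging.WARNING),
)
def flaky_call() -> str:
    _calls["n"] += 1
    if _calls["n"] < 3:
        raise TransientAPIError("server returned 500")
    return "ok"


print(flaky_call())  # succeeds on the third attempt; two warnings are logged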
@@ -44,8 +75,24 @@ def _get_tenant_id(
         return tenant_id_
     endpoint_ = endpoint or get_endpoint()
     headers_ = headers or get_headers()
-    response = requests.get(endpoint_ + "/tenants", headers=headers_)
-    raise_for_status_with_text(response)
+    response = None
+    try:
+        response = requests.get(endpoint_ + "/tenants", headers=headers_)
+        raise_for_status_with_text(response)
+    except HTTPError as e:
+        if response is not None and response.status_code == 500:
+            raise LangChainTracerAPIError(
+                f"Failed to get tenant ID from LangChain API. {e}"
+            )
+        else:
+            raise LangChainTracerUserError(
+                f"Failed to get tenant ID from LangChain API. {e}"
+            )
+    except Exception as e:
+        raise LangChainTracerError(
+            f"Failed to get tenant ID from LangChain API. {e}"
+        ) from e
+
     tenants: List[Dict[str, Any]] = response.json()
     if not tenants:
         raise ValueError(f"No tenants found for URL {endpoint_}")
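_get_tenant_id now triages failures: an HTTP 500 becomes the retryable LangChainTracerAPIError, any other HTTP status becomes the non-retryable LangChainTracerUserError, and anything else becomes LangChainTracerError. A generic sketch of the same triage under those assumptions (the helper name get_json and the exception names here are illustrative):

from typing import Any, Dict

import requests
from requests.exceptions import HTTPError


class RetryableAPIError(Exception):
    """Server-side failure (HTTP 500): safe to retry."""


class UserError(Exception):
    """Client-side failure (auth, bad request, ...): retrying will not help."""


def get_json(url: str, headers: Dict[str, str]) -> Any:
    response = None
    try:
        response = requests.get(url, headers=headers)
        response.raise_for_status()
    except HTTPError as e:
        # Only 500s are treated as transient; everything else is the caller's fault.
        if response is not None and response.status_code == 500:
            raise RetryableAPIError(f"GET {url} failed. {e}")
        raise UserError(f"GET {url} failed. {e}")
    return response.json()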
@@ -72,6 +119,8 @@ class LangChainTracer(BaseTracer):
         self.example_id = example_id
         self.session_name = session_name or os.getenv("LANGCHAIN_SESSION", "default")
         self.session_extra = session_extra
+        # set max_workers to 1 to process tasks in order
+        self.executor = ThreadPoolExecutor(max_workers=1)

     def on_chat_model_start(
         self,
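With max_workers=1, submitted uploads leave the caller's thread but still execute strictly in submission order, so a run's create request can never be overtaken by its update. A small sketch of that ordering guarantee:

from concurrent.futures import ThreadPoolExecutor

executor = ThreadPoolExecutor(max_workers=1)
order = []

executor.submit(order.append, "create run")  # e.g. POST /runs
executor.submit(order.append, "update run")  # e.g. PATCH /runs/{id}
executor.shutdown(wait=True)

print(order)  # ['create run', 'update run'] -- the update cannot overtake the create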
@@ -108,7 +157,7 @@ class LangChainTracer(BaseTracer):
         self.tenant_id = tenant_id
         return tenant_id

-    @retry(stop=stop_after_attempt(3), wait=wait_fixed(0.5))
+    @retry_decorator
     def ensure_session(self) -> TracerSession:
         """Upsert a session."""
         if self.session is not None:
@@ -118,37 +167,124 @@ class LangChainTracer(BaseTracer):
         session_create = TracerSessionCreate(
             name=self.session_name, extra=self.session_extra, tenant_id=tenant_id
         )
-        r = requests.post(
-            url,
-            data=session_create.json(),
-            headers=self._headers,
-        )
-        raise_for_status_with_text(r)
-        self.session = TracerSession(**r.json())
+        response = None
+        try:
+            response = requests.post(
+                url,
+                data=session_create.json(),
+                headers=self._headers,
+            )
+            response.raise_for_status()
+        except HTTPError as e:
+            if response is not None and response.status_code == 500:
+                raise LangChainTracerAPIError(
+                    f"Failed to upsert session to LangChain API. {e}"
+                )
+            else:
+                raise LangChainTracerUserError(
+                    f"Failed to upsert session to LangChain API. {e}"
+                )
+        except Exception as e:
+            raise LangChainTracerError(
+                f"Failed to upsert session to LangChain API. {e}"
+            ) from e
+        self.session = TracerSession(**response.json())
         return self.session

-    def _persist_run_nested(self, run: Run) -> None:
+    def _persist_run(self, run: Run) -> None:
         """Persist a run."""
+
+    @retry_decorator
+    def _persist_run_single(self, run: Run) -> None:
+        """Persist a run."""
         session = self.ensure_session()
-        child_runs = run.child_runs
+        if run.parent_run_id is None:
+            run.reference_example_id = self.example_id
         run_dict = run.dict()
         del run_dict["child_runs"]
         run_create = RunCreate(**run_dict, session_id=session.id)
+        response = None
         try:
             response = requests.post(
                 f"{self._endpoint}/runs",
                 data=run_create.json(),
                 headers=self._headers,
             )
-            raise_for_status_with_text(response)
+            response.raise_for_status()
+        except HTTPError as e:
+            if response is not None and response.status_code == 500:
+                raise LangChainTracerAPIError(
+                    f"Failed to upsert persist run to LangChain API. {e}"
+                )
+            else:
+                raise LangChainTracerUserError(
+                    f"Failed to persist run to LangChain API. {e}"
+                )
         except Exception as e:
-            logging.warning(f"Failed to persist run: {e}")
-        for child_run in child_runs:
-            child_run.parent_run_id = run.id
-            self._persist_run_nested(child_run)
+            raise LangChainTracerError(
+                f"Failed to persist run to LangChain API. {e}"
+            ) from e

-    def _persist_run(self, run: Run) -> None:
-        """Persist a run."""
-        run.reference_example_id = self.example_id
-        # TODO: Post first then patch
-        self._persist_run_nested(run)
+    @retry_decorator
+    def _update_run_single(self, run: Run) -> None:
+        """Update a run."""
+        run_update = RunUpdate(**run.dict())
+        response = None
+        try:
+            response = requests.patch(
+                f"{self._endpoint}/runs/{run.id}",
+                data=run_update.json(),
+                headers=self._headers,
+            )
+            response.raise_for_status()
+        except HTTPError as e:
+            if response is not None and response.status_code == 500:
+                raise LangChainTracerAPIError(
+                    f"Failed to update run to LangChain API. {e}"
+                )
+            else:
+                raise LangChainTracerUserError(f"Failed to run to LangChain API. {e}")
+        except Exception as e:
+            raise LangChainTracerError(
+                f"Failed to update run to LangChain API. {e}"
+            ) from e
+
+    def _on_llm_start(self, run: Run) -> None:
+        """Persist an LLM run."""
+        self.executor.submit(self._persist_run_single, run.copy(deep=True))
+
+    def _on_chat_model_start(self, run: Run) -> None:
+        """Persist an LLM run."""
+        self.executor.submit(self._persist_run_single, run.copy(deep=True))
+
+    def _on_llm_end(self, run: Run) -> None:
+        """Process the LLM Run."""
+        self.executor.submit(self._update_run_single, run.copy(deep=True))
+
+    def _on_llm_error(self, run: Run) -> None:
+        """Process the LLM Run upon error."""
+        self.executor.submit(self._update_run_single, run.copy(deep=True))
+
+    def _on_chain_start(self, run: Run) -> None:
+        """Process the Chain Run upon start."""
+        self.executor.submit(self._persist_run_single, run.copy(deep=True))
+
+    def _on_chain_end(self, run: Run) -> None:
+        """Process the Chain Run."""
+        self.executor.submit(self._update_run_single, run.copy(deep=True))
+
+    def _on_chain_error(self, run: Run) -> None:
+        """Process the Chain Run upon error."""
+        self.executor.submit(self._update_run_single, run.copy(deep=True))
+
+    def _on_tool_start(self, run: Run) -> None:
+        """Process the Tool Run upon start."""
+        self.executor.submit(self._persist_run_single, run.copy(deep=True))
+
+    def _on_tool_end(self, run: Run) -> None:
+        """Process the Tool Run."""
+        self.executor.submit(self._update_run_single, run.copy(deep=True))
+
+    def _on_tool_error(self, run: Run) -> None:
+        """Process the Tool Run upon error."""
+        self.executor.submit(self._update_run_single, run.copy(deep=True))
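Taken together, the _on_* handlers implement a post-then-patch lifecycle: start events enqueue _persist_run_single (a POST that creates the run), end and error events enqueue _update_run_single (a PATCH that fills in outputs or the error), and every handler hands the executor a run.copy(deep=True) snapshot so later mutations cannot race the background upload. A hypothetical miniature of that flow (MiniRun and MiniTracer are illustrative stand-ins, not the commit's classes):

import copy
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from typing import Optional


@dataclass
class MiniRun:
    id: str
    outputs: Optional[dict] = None


class MiniTracer:
    def __init__(self) -> None:
        # One worker: uploads happen off-thread but in submission order.
        self.executor = ThreadPoolExecutor(max_workers=1)

    def _persist_run_single(self, run: MiniRun) -> None:
        print(f"POST /runs {run.id}")  # create the run record

    def _update_run_single(self, run: MiniRun) -> None:
        print(f"PATCH /runs/{run.id} outputs={run.outputs}")  # fill in results

    def on_start(self, run: MiniRun) -> None:
        # Snapshot before queueing, as the tracer does with run.copy(deep=True).
        self.executor.submit(self._persist_run_single, copy.deepcopy(run))

    def on_end(self, run: MiniRun) -> None:
        self.executor.submit(self._update_run_single, copy.deepcopy(run))


tracer = MiniTracer()
run = MiniRun(id="run-1")
tracer.on_start(run)            # POST happens with outputs=None
run.outputs = {"text": "42"}    # caller keeps mutating its own object
tracer.on_end(run)              # PATCH carries the outputs
tracer.executor.shutdown(wait=True)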
@@ -91,6 +91,9 @@ class ToolRun(BaseRun):
     child_tool_runs: List[ToolRun] = Field(default_factory=list)


+# Begin V2 API Schemas
+
+
 class RunTypeEnum(str, Enum):
     """Enum for run types."""

@@ -105,7 +108,7 @@ class RunBase(BaseModel):
     id: Optional[UUID]
     start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
     end_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
-    extra: dict
+    extra: Optional[Dict[str, Any]] = None
     error: Optional[str]
     execution_order: int
     child_execution_order: Optional[int]
@@ -144,5 +147,13 @@ class RunCreate(RunBase):
         return values


+class RunUpdate(BaseModel):
+    end_time: Optional[datetime.datetime]
+    error: Optional[str]
+    outputs: Optional[dict]
+    parent_run_id: Optional[UUID]
+    reference_example_id: Optional[UUID]
+
+
 ChainRun.update_forward_refs()
 ToolRun.update_forward_refs()
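RunUpdate is a partial-update ("patch") schema: every field is optional, so a client can serialize only what changed when issuing the PATCH. A sketch in the pydantic-v1 style this code base uses (RunUpdateExample is an illustrative stand-in):

import datetime
from typing import Optional
from uuid import UUID

from pydantic import BaseModel  # pydantic v1, as used by langchain at the time


class RunUpdateExample(BaseModel):
    end_time: Optional[datetime.datetime]
    error: Optional[str]
    outputs: Optional[dict]
    parent_run_id: Optional[UUID]
    reference_example_id: Optional[UUID]


patch = RunUpdateExample(end_time=datetime.datetime.utcnow(), outputs={"text": "42"})
# Unset fields default to None and can be dropped from the PATCH body:
print(patch.json(exclude_none=True))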
@@ -8,6 +8,7 @@ from aiohttp import ClientSession
 from langchain.agents import AgentType, initialize_agent, load_tools
 from langchain.callbacks import tracing_enabled
+from langchain.callbacks.manager import tracing_v2_enabled
 from langchain.chat_models import ChatOpenAI
 from langchain.llms import OpenAI

 questions = [
@@ -140,10 +141,10 @@ async def test_tracing_v2_environment_variable() -> None:


 def test_tracing_v2_context_manager() -> None:
-    llm = OpenAI(temperature=0)
+    llm = ChatOpenAI(temperature=0)
     tools = load_tools(["llm-math", "serpapi"], llm=llm)
     agent = initialize_agent(
-        tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
+        tools, llm, agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION, verbose=True
     )
     if "LANGCHAIN_TRACING_V2" in os.environ:
         del os.environ["LANGCHAIN_TRACING_V2"]
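The test module now imports tracing_v2_enabled from langchain.callbacks.manager. Assuming it mirrors the shape of the v1 tracing_enabled helper it sits beside, usage would look roughly like this (the session_name value is a placeholder, and agent and questions come from the surrounding test module):

from langchain.callbacks.manager import tracing_v2_enabled

# Runs started inside the block are traced to the v2 endpoint.
with tracing_v2_enabled(session_name="my-test-session"):
    agent.run(questions[0])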
@@ -1,134 +0,0 @@
-"""Test Tracer classes."""
-from __future__ import annotations
-
-import json
-from datetime import datetime
-from typing import Tuple
-from unittest.mock import patch
-from uuid import UUID, uuid4
-
-import pytest
-from freezegun import freeze_time
-
-from langchain.callbacks.tracers.langchain import LangChainTracer
-from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
-from langchain.schema import LLMResult
-
-_SESSION_ID = UUID("4fbf7c55-2727-4711-8964-d821ed4d4e2a")
-_TENANT_ID = UUID("57a08cc4-73d2-4236-8378-549099d07fad")
-
-
-@pytest.fixture
-def lang_chain_tracer_v2(monkeypatch: pytest.MonkeyPatch) -> LangChainTracer:
-    monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
-    monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
-    monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
-    tracer = LangChainTracer()
-    return tracer
-
-
-# Mock a sample TracerSession object
-@pytest.fixture
-def sample_tracer_session_v2() -> TracerSession:
-    return TracerSession(id=_SESSION_ID, name="Sample session", tenant_id=_TENANT_ID)
-
-
-@freeze_time("2023-01-01")
-@pytest.fixture
-def sample_runs() -> Tuple[Run, Run, Run]:
-    llm_run = Run(
-        id="57a08cc4-73d2-4236-8370-549099d07fad",
-        name="llm_run",
-        execution_order=1,
-        child_execution_order=1,
-        parent_run_id="57a08cc4-73d2-4236-8371-549099d07fad",
-        start_time=datetime.utcnow(),
-        end_time=datetime.utcnow(),
-        session_id=1,
-        inputs={"prompts": []},
-        outputs=LLMResult(generations=[[]]).dict(),
-        serialized={},
-        extra={},
-        run_type=RunTypeEnum.llm,
-    )
-    chain_run = Run(
-        id="57a08cc4-73d2-4236-8371-549099d07fad",
-        name="chain_run",
-        execution_order=1,
-        start_time=datetime.utcnow(),
-        end_time=datetime.utcnow(),
-        child_execution_order=1,
-        serialized={},
-        inputs={},
-        outputs={},
-        child_runs=[llm_run],
-        extra={},
-        run_type=RunTypeEnum.chain,
-    )
-
-    tool_run = Run(
-        id="57a08cc4-73d2-4236-8372-549099d07fad",
-        name="tool_run",
-        execution_order=1,
-        child_execution_order=1,
-        inputs={"input": "test"},
-        start_time=datetime.utcnow(),
-        end_time=datetime.utcnow(),
-        outputs=None,
-        serialized={},
-        child_runs=[],
-        extra={},
-        run_type=RunTypeEnum.tool,
-    )
-    return llm_run, chain_run, tool_run
-
-
-def test_persist_run(
-    lang_chain_tracer_v2: LangChainTracer,
-    sample_tracer_session_v2: TracerSession,
-    sample_runs: Tuple[Run, Run, Run],
-) -> None:
-    """Test that persist_run method calls requests.post once per method call."""
-    with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
-        "langchain.callbacks.tracers.langchain.requests.get"
-    ) as get:
-        post.return_value.raise_for_status.return_value = None
-        lang_chain_tracer_v2.session = sample_tracer_session_v2
-        for run in sample_runs:
-            lang_chain_tracer_v2.run_map[str(run.id)] = run
-        for run in sample_runs:
-            lang_chain_tracer_v2._end_trace(run)
-
-        assert post.call_count == 3
-        assert get.call_count == 0
-
-
-def test_persist_run_with_example_id(
-    lang_chain_tracer_v2: LangChainTracer,
-    sample_tracer_session_v2: TracerSession,
-    sample_runs: Tuple[Run, Run, Run],
-) -> None:
-    """Test the example ID is assigned only to the parent run and not the children."""
-    example_id = uuid4()
-    llm_run, chain_run, tool_run = sample_runs
-    chain_run.child_runs = [tool_run]
-    tool_run.child_runs = [llm_run]
-    with patch("langchain.callbacks.tracers.langchain.requests.post") as post, patch(
-        "langchain.callbacks.tracers.langchain.requests.get"
-    ) as get:
-        post.return_value.raise_for_status.return_value = None
-        lang_chain_tracer_v2.session = sample_tracer_session_v2
-        lang_chain_tracer_v2.example_id = example_id
-        lang_chain_tracer_v2._persist_run(chain_run)
-
-        assert post.call_count == 3
-        assert get.call_count == 0
-        posted_data = [
-            json.loads(call_args[1]["data"]) for call_args in post.call_args_list
-        ]
-        assert posted_data[0]["id"] == str(chain_run.id)
-        assert posted_data[0]["reference_example_id"] == str(example_id)
-        assert posted_data[1]["id"] == str(tool_run.id)
-        assert not posted_data[1].get("reference_example_id")
-        assert posted_data[2]["id"] == str(llm_run.id)
-        assert not posted_data[2].get("reference_example_id")