Compare commits

...

7 Commits

Author               SHA1        Message                                                           Date
William Fu-Hinthorn  812f35e69d  Update wandb                                                      2023-10-02 13:07:53 -07:00
Ankush Gola          2bc244a6a1  update snapshots                                                  2023-09-28 16:33:06 -07:00
Ankush Gola          5a839aa7d1  Merge branch 'ankush/delete_v1_tracer' into ankush/single-input  2023-09-28 16:11:26 -07:00
Ankush Gola          80dd05b802  delete v1 tracer                                                  2023-09-28 16:05:29 -07:00
Ankush Gola          548ca264ec  fix some tests                                                    2023-09-28 11:27:44 -07:00
Ankush Gola          19e11a602d  fix wording                                                       2023-09-27 17:21:59 -07:00
Ankush Gola          5afd3b2672  only allow single prompt or message batch                         2023-09-27 17:19:24 -07:00
15 changed files with 51 additions and 1053 deletions

View File

@@ -129,7 +129,6 @@
"from langchain.callbacks.base import BaseCallbackHandler\n",
"from langchain.schema import AgentAction\n",
"from langchain.agents import AgentType, initialize_agent, load_tools\n",
"from langchain.callbacks import tracing_enabled\n",
"from langchain.llms import OpenAI\n",
"\n",
"\n",
@@ -186,9 +185,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "venv",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "venv"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
@@ -200,7 +199,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.3"
"version": "3.11.5"
}
},
"nbformat": 4,

View File

@@ -23,7 +23,6 @@ from langchain.callbacks.llmonitor_callback import LLMonitorCallbackHandler
from langchain.callbacks.manager import (
collect_runs,
get_openai_callback,
-tracing_enabled,
tracing_v2_enabled,
wandb_tracing_enabled,
)
@@ -67,7 +66,6 @@ __all__ = [
"WandbCallbackHandler",
"WhyLabsCallbackHandler",
"get_openai_callback",
"tracing_enabled",
"tracing_v2_enabled",
"collect_runs",
"wandb_tracing_enabled",

View File

@@ -40,7 +40,6 @@ from langchain.callbacks.openai_info import OpenAICallbackHandler
from langchain.callbacks.stdout import StdOutCallbackHandler
from langchain.callbacks.tracers import run_collector
from langchain.callbacks.tracers.langchain import LangChainTracer
-from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
from langchain.callbacks.tracers.wandb import WandbTracer
from langchain.schema import (
@@ -60,11 +59,6 @@ logger = logging.getLogger(__name__)
openai_callback_var: ContextVar[Optional[OpenAICallbackHandler]] = ContextVar(
"openai_callback", default=None
)
-tracing_callback_var: ContextVar[
-Optional[LangChainTracerV1]
-] = ContextVar(  # noqa: E501
-"tracing_callback", default=None
-)
wandb_tracing_callback_var: ContextVar[
Optional[WandbTracer]
] = ContextVar( # noqa: E501
@@ -105,30 +99,6 @@ def get_openai_callback() -> Generator[OpenAICallbackHandler, None, None]:
openai_callback_var.set(None)
-@contextmanager
-def tracing_enabled(
-session_name: str = "default",
-) -> Generator[TracerSessionV1, None, None]:
-"""Get the Deprecated LangChainTracer in a context manager.
-Args:
-session_name (str, optional): The name of the session.
-Defaults to "default".
-Returns:
-TracerSessionV1: The LangChainTracer session.
-Example:
->>> with tracing_enabled() as session:
-...     # Use the LangChainTracer session
-"""
-cb = LangChainTracerV1()
-session = cast(TracerSessionV1, cb.load_session(session_name))
-tracing_callback_var.set(cb)
-yield session
-tracing_callback_var.set(None)
@contextmanager
def wandb_tracing_enabled(
session_name: str = "default",
@@ -1852,21 +1822,18 @@ def _configure(
callback_manager.add_metadata(inheritable_metadata or {})
callback_manager.add_metadata(local_metadata or {}, False)
-tracer = tracing_callback_var.get()
wandb_tracer = wandb_tracing_callback_var.get()
open_ai = openai_callback_var.get()
-tracing_enabled_ = (
-env_var_is_set("LANGCHAIN_TRACING")
-or tracer is not None
-or env_var_is_set("LANGCHAIN_HANDLER")
-)
wandb_tracing_enabled_ = (
env_var_is_set("LANGCHAIN_WANDB_TRACING") or wandb_tracer is not None
)
tracer_v2 = tracing_v2_callback_var.get()
tracing_v2_enabled_ = (
env_var_is_set("LANGCHAIN_TRACING_V2") or tracer_v2 is not None
env_var_is_set("LANGCHAIN_TRACING_V2")
or tracer_v2 is not None
or env_var_is_set("LANGCHAIN_TRACING")
)
tracer_project = os.environ.get(
"LANGCHAIN_PROJECT", os.environ.get("LANGCHAIN_SESSION", "default")
@@ -1876,7 +1843,6 @@ def _configure(
if (
verbose
or debug
-or tracing_enabled_
or tracing_v2_enabled_
or wandb_tracing_enabled_
or open_ai is not None
@@ -1894,16 +1860,6 @@ def _configure(
for handler in callback_manager.handlers
):
callback_manager.add_handler(ConsoleCallbackHandler(), True)
-if tracing_enabled_ and not any(
-isinstance(handler, LangChainTracerV1)
-for handler in callback_manager.handlers
-):
-if tracer:
-callback_manager.add_handler(tracer, True)
-else:
-handler = LangChainTracerV1()
-handler.load_session(tracer_project)
-callback_manager.add_handler(handler, True)
if wandb_tracing_enabled_ and not any(
isinstance(handler, WandbTracer) for handler in callback_manager.handlers
):
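
Condensing the `_configure` change: the v1 gate (`LANGCHAIN_TRACING`, `LANGCHAIN_HANDLER`, or a context-set `LangChainTracerV1`) disappears, and `LANGCHAIN_TRACING` is folded into the v2 check instead. A self-contained sketch of the resulting predicate; the exact semantics of `env_var_is_set` are an assumption here (treat empty/"0"/"false" as unset):

    import os
    from typing import Any, Optional

    def env_var_is_set(name: str) -> bool:
        # assumed semantics, mirroring langchain's helper in spirit
        return os.environ.get(name, "") not in ("", "0", "false", "False")

    def v2_tracing_requested(tracer_v2: Optional[Any]) -> bool:
        # Condensed from _configure above: the legacy LANGCHAIN_TRACING
        # variable now enables v2 tracing rather than the deleted v1 tracer.
        return (
            env_var_is_set("LANGCHAIN_TRACING_V2")
            or tracer_v2 is not None
            or env_var_is_set("LANGCHAIN_TRACING")
        )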

View File

@@ -1,7 +1,6 @@
"""Tracers that record execution of LangChain runs."""
from langchain.callbacks.tracers.langchain import LangChainTracer
-from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1
from langchain.callbacks.tracers.stdout import (
ConsoleCallbackHandler,
FunctionCallbackHandler,
@@ -10,7 +9,6 @@ from langchain.callbacks.tracers.wandb import WandbTracer
__all__ = [
"LangChainTracer",
"LangChainTracerV1",
"FunctionCallbackHandler",
"ConsoleCallbackHandler",
"WandbTracer",

View File

@@ -111,11 +111,15 @@ class BaseTracer(BaseCallbackHandler, ABC):
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
+if len(prompts) != 1:
+raise ValueError("Tracer does not support multiple prompts for LLM runs.")
llm_run = Run(
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"prompts": prompts},
# Send a single prompt since the CallbackManager will
# split up the list of prompts for us. Invariant: len(prompts) == 1
inputs={"prompt": prompts[0]},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
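
The new `len(prompts) == 1` invariant holds because, per the comment in the diff, the CallbackManager splits a multi-prompt call into one run per prompt before any tracer sees it. A rough illustration of that fan-out (not the actual CallbackManager code; `handlers` is any list of objects with an `on_llm_start` callback):

    from typing import Any, Dict, List
    from uuid import UUID, uuid4

    def fan_out_llm_start(
        handlers: List[Any], serialized: Dict[str, Any], prompts: List[str]
    ) -> List[UUID]:
        # One run id per prompt; every handler, including BaseTracer,
        # therefore receives a single-element prompt list.
        run_ids: List[UUID] = []
        for prompt in prompts:
            run_id = uuid4()
            for handler in handlers:
                handler.on_llm_start(serialized, [prompt], run_id=run_id)
            run_ids.append(run_id)
        return run_ids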

View File

@@ -12,7 +12,7 @@ from uuid import UUID
from langsmith import Client
from langchain.callbacks.tracers.base import BaseTracer
-from langchain.callbacks.tracers.schemas import Run, TracerSession
+from langchain.callbacks.tracers.schemas import Run
from langchain.env import get_runtime_environment
from langchain.load.dump import dumpd
from langchain.schema.messages import BaseMessage
@@ -64,7 +64,6 @@ class LangChainTracer(BaseTracer):
) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
-self.session: Optional[TracerSession] = None
self.example_id = (
UUID(example_id) if isinstance(example_id, str) else example_id
)
@@ -107,11 +106,17 @@ class LangChainTracer(BaseTracer):
start_time = datetime.utcnow()
if metadata:
kwargs.update({"metadata": metadata})
+if len(messages) != 1:
+raise ValueError(
+"LangChainTracer does not support multiple batches of messages."
+)
chat_model_run = Run(
id=run_id,
parent_run_id=parent_run_id,
serialized=serialized,
inputs={"messages": [[dumpd(msg) for msg in batch] for batch in messages]},
# Send a single batch of messages since the CallbackManager will
# split up the messages for us. Invariant: len(messages) == 1
inputs={"messages": [dumpd(msg) for msg in messages[0]]},
extra=kwargs,
events=[{"name": "start", "time": start_time}],
start_time=start_time,
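
The chat-model path mirrors the LLM change: with exactly one batch guaranteed, the stored inputs lose one level of nesting. A small sketch of the payload difference, reusing the `dumpd` serializer imported above:

    from langchain.load.dump import dumpd
    from langchain.schema.messages import HumanMessage

    batch = [HumanMessage(content="hi")]
    old_inputs = {"messages": [[dumpd(m) for m in batch]]}  # list of batches
    new_inputs = {"messages": [dumpd(m) for m in batch]}    # the single batch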

View File

@@ -1,185 +0,0 @@
from __future__ import annotations
import logging
import os
from typing import Any, Dict, Optional, Union
import requests
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import (
ChainRun,
LLMRun,
Run,
ToolRun,
TracerSession,
TracerSessionV1,
TracerSessionV1Base,
)
from langchain.schema.messages import get_buffer_string
from langchain.utils import raise_for_status_with_text
logger = logging.getLogger(__name__)
def get_headers() -> Dict[str, Any]:
"""Get the headers for the LangChain API."""
headers: Dict[str, Any] = {"Content-Type": "application/json"}
if os.getenv("LANGCHAIN_API_KEY"):
headers["x-api-key"] = os.getenv("LANGCHAIN_API_KEY")
return headers
def _get_endpoint() -> str:
return os.getenv("LANGCHAIN_ENDPOINT", "http://localhost:8000")
class LangChainTracerV1(BaseTracer):
"""An implementation of the SharedTracer that POSTS to the langchain endpoint."""
def __init__(self, **kwargs: Any) -> None:
"""Initialize the LangChain tracer."""
super().__init__(**kwargs)
self.session: Optional[TracerSessionV1] = None
self._endpoint = _get_endpoint()
self._headers = get_headers()
def _convert_to_v1_run(self, run: Run) -> Union[LLMRun, ChainRun, ToolRun]:
session = self.session or self.load_default_session()
if not isinstance(session, TracerSessionV1):
raise ValueError(
"LangChainTracerV1 is not compatible with"
f" session of type {type(session)}"
)
if run.run_type == "llm":
if "prompts" in run.inputs:
prompts = run.inputs["prompts"]
elif "messages" in run.inputs:
prompts = [get_buffer_string(batch) for batch in run.inputs["messages"]]
else:
raise ValueError("No prompts found in LLM run inputs")
return LLMRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
extra=run.extra,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
error=run.error,
prompts=prompts,
response=run.outputs if run.outputs else None,
)
if run.run_type == "chain":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ChainRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
inputs=run.inputs,
outputs=run.outputs,
error=run.error,
extra=run.extra,
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
child_chain_runs=[
run for run in child_runs if isinstance(run, ChainRun)
],
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
)
if run.run_type == "tool":
child_runs = [self._convert_to_v1_run(run) for run in run.child_runs]
return ToolRun(
uuid=str(run.id) if run.id else None,
parent_uuid=str(run.parent_run_id) if run.parent_run_id else None,
start_time=run.start_time,
end_time=run.end_time,
execution_order=run.execution_order,
child_execution_order=run.child_execution_order,
serialized=run.serialized,
session_id=session.id,
action=str(run.serialized),
tool_input=run.inputs.get("input", ""),
output=None if run.outputs is None else run.outputs.get("output"),
error=run.error,
extra=run.extra,
child_chain_runs=[
run for run in child_runs if isinstance(run, ChainRun)
],
child_tool_runs=[run for run in child_runs if isinstance(run, ToolRun)],
child_llm_runs=[run for run in child_runs if isinstance(run, LLMRun)],
)
raise ValueError(f"Unknown run type: {run.run_type}")
def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
if isinstance(run, Run):
v1_run = self._convert_to_v1_run(run)
else:
v1_run = run
if isinstance(v1_run, LLMRun):
endpoint = f"{self._endpoint}/llm-runs"
elif isinstance(v1_run, ChainRun):
endpoint = f"{self._endpoint}/chain-runs"
else:
endpoint = f"{self._endpoint}/tool-runs"
try:
response = requests.post(
endpoint,
data=v1_run.json(),
headers=self._headers,
)
raise_for_status_with_text(response)
except Exception as e:
logger.warning(f"Failed to persist run: {e}")
def _persist_session(
self, session_create: TracerSessionV1Base
) -> Union[TracerSessionV1, TracerSession]:
"""Persist a session."""
try:
r = requests.post(
f"{self._endpoint}/sessions",
data=session_create.json(),
headers=self._headers,
)
session = TracerSessionV1(id=r.json()["id"], **session_create.dict())
except Exception as e:
logger.warning(f"Failed to create session, using default session: {e}")
session = TracerSessionV1(id=1, **session_create.dict())
return session
def _load_session(self, session_name: Optional[str] = None) -> TracerSessionV1:
"""Load a session from the tracer."""
try:
url = f"{self._endpoint}/sessions"
if session_name:
url += f"?name={session_name}"
r = requests.get(url, headers=self._headers)
tracer_session = TracerSessionV1(**r.json()[0])
except Exception as e:
session_type = "default" if not session_name else session_name
logger.warning(
f"Failed to load {session_type} session, using empty session: {e}"
)
tracer_session = TracerSessionV1(id=1)
self.session = tracer_session
return tracer_session
def load_session(self, session_name: str) -> Union[TracerSessionV1, TracerSession]:
"""Load a session with the given name from the tracer."""
return self._load_session(session_name)
def load_default_session(self) -> Union[TracerSessionV1, TracerSession]:
"""Load the default tracing session and set it as the Tracer's session."""
return self._load_session("default")

View File

@@ -2,55 +2,11 @@
from __future__ import annotations
import datetime
-import warnings
from typing import Any, Dict, List, Optional
from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2
-from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from langchain.pydantic_v1 import BaseModel, Field, root_validator
from langchain.schema import LLMResult
-def RunTypeEnum() -> RunTypeEnumDep:
-"""RunTypeEnum."""
-warnings.warn(
-"RunTypeEnum is deprecated. Please directly use a string instead"
-" (e.g. 'llm', 'chain', 'tool').",
-DeprecationWarning,
-)
-return RunTypeEnumDep
-class TracerSessionV1Base(BaseModel):
-"""Base class for TracerSessionV1."""
-start_time: datetime.datetime = Field(default_factory=datetime.datetime.utcnow)
-name: Optional[str] = None
-extra: Optional[Dict[str, Any]] = None
-class TracerSessionV1Create(TracerSessionV1Base):
-"""Create class for TracerSessionV1."""
-class TracerSessionV1(TracerSessionV1Base):
-"""TracerSessionV1 schema."""
-id: int
-class TracerSessionBase(TracerSessionV1Base):
-"""Base class for TracerSession."""
-tenant_id: UUID
-class TracerSession(TracerSessionBase):
-"""TracerSessionV1 schema for the V2 API."""
-id: UUID
class BaseRun(BaseModel):
@@ -68,37 +24,6 @@ class BaseRun(BaseModel):
error: Optional[str] = None
-class LLMRun(BaseRun):
-"""Class for LLMRun."""
-prompts: List[str]
-response: Optional[LLMResult] = None
-class ChainRun(BaseRun):
-"""Class for ChainRun."""
-inputs: Dict[str, Any]
-outputs: Optional[Dict[str, Any]] = None
-child_llm_runs: List[LLMRun] = Field(default_factory=list)
-child_chain_runs: List[ChainRun] = Field(default_factory=list)
-child_tool_runs: List[ToolRun] = Field(default_factory=list)
-class ToolRun(BaseRun):
-"""Class for ToolRun."""
-tool_input: str
-output: Optional[str] = None
-action: str
-child_llm_runs: List[LLMRun] = Field(default_factory=list)
-child_chain_runs: List[ChainRun] = Field(default_factory=list)
-child_tool_runs: List[ToolRun] = Field(default_factory=list)
# Begin V2 API Schemas
class Run(BaseRunV2):
"""Run schema for the V2 API in the Tracer."""
@@ -118,20 +43,9 @@ class Run(BaseRunV2):
return values
-ChainRun.update_forward_refs()
-ToolRun.update_forward_refs()
Run.update_forward_refs()
__all__ = [
"BaseRun",
"ChainRun",
"LLMRun",
"Run",
"RunTypeEnum",
"ToolRun",
"TracerSession",
"TracerSessionBase",
"TracerSessionV1",
"TracerSessionV1Base",
"TracerSessionV1Create",
]

View File

@@ -99,19 +99,18 @@ class RunProcessor:
base_span.results = [
self.trace_tree.Result(
inputs={"prompt": prompt},
inputs={"prompt": run.inputs["prompt"]},
outputs={
f"gen_{g_i}": gen["text"]
for g_i, gen in enumerate(run.outputs["generations"][ndx])
for g_i, gen in enumerate(run.outputs["generations"][0])
}
if (
run.outputs is not None
and len(run.outputs["generations"]) > ndx
and len(run.outputs["generations"][ndx]) > 0
and len(run.outputs["generations"]) > 0
and len(run.outputs["generations"][0]) > 0
)
else None,
)
-for ndx, prompt in enumerate(run.inputs["prompts"] or [])
]
base_span.span_kind = self.trace_tree.SpanKind.LLM

View File

@@ -6,7 +6,6 @@ import pytest
from aiohttp import ClientSession
from langchain.agents import AgentType, initialize_agent, load_tools
-from langchain.callbacks import tracing_enabled
from langchain.callbacks.manager import (
atrace_as_chain_group,
trace_as_chain_group,
@@ -106,7 +105,7 @@ def test_tracing_context_manager() -> None:
)
if "LANGCHAIN_TRACING" in os.environ:
del os.environ["LANGCHAIN_TRACING"]
-with tracing_enabled() as session:
+with tracing_v2_enabled() as session:
assert session
agent.run(questions[0]) # this should be traced
@@ -125,7 +124,7 @@ async def test_tracing_context_manager_async() -> None:
# start a background task
task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
-with tracing_enabled() as session:
+with tracing_v2_enabled() as session:
assert session
tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
await asyncio.gather(*tasks)

View File

@@ -13,4 +13,4 @@ def test_collect_runs() -> None:
assert cb.traced_runs
assert len(cb.traced_runs) == 1
assert isinstance(cb.traced_runs[0].id, uuid.UUID)
-assert cb.traced_runs[0].inputs == {"prompts": ["hi"]}
+assert cb.traced_runs[0].inputs == {"prompt": "hi"}
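
End to end, the updated assertion shows the shape a consumer now sees. A hedged usage sketch; `FakeListLLM` is langchain's canned-response test LLM and stands in for a real model:

    from langchain.callbacks.manager import collect_runs
    from langchain.llms.fake import FakeListLLM

    llm = FakeListLLM(responses=["hello"])
    with collect_runs() as cb:
        llm.predict("hi")
    # single-prompt shape, no longer {"prompts": ["hi"]}
    assert cb.traced_runs[0].inputs == {"prompt": "hi"}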

View File

@@ -6,16 +6,7 @@ def test_public_api() -> None:
"""Test for changes in the public API."""
expected_all = [
"BaseRun",
"ChainRun",
"LLMRun",
"Run",
"RunTypeEnum",
"ToolRun",
"TracerSession",
"TracerSessionBase",
"TracerSessionV1",
"TracerSessionV1Base",
"TracerSessionV1Create",
]
assert sorted(schemas_all) == expected_all

View File

@@ -48,14 +48,14 @@ def test_tracer_llm_run() -> None:
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED,
inputs={"prompts": []},
inputs={"prompt": "test"},
outputs=LLMResult(generations=[[]]),
error=None,
run_type="llm",
)
tracer = FakeTracer()
-tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
+tracer.on_llm_start(serialized=SERIALIZED, prompts=["test"], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run]
@@ -81,7 +81,7 @@ def test_tracer_chat_model_run() -> None:
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED_CHAT,
inputs=dict(prompts=["Human: "]),
inputs=dict(prompt="Human: "),
outputs=LLMResult(generations=[[]]),
error=None,
run_type="llm",
@@ -117,7 +117,7 @@ def test_tracer_multiple_llm_runs() -> None:
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED,
-inputs=dict(prompts=[]),
+inputs=dict(prompt="test"),
outputs=LLMResult(generations=[[]]),
error=None,
run_type="llm",
@@ -126,7 +126,7 @@ def test_tracer_multiple_llm_runs() -> None:
num_runs = 10
for _ in range(num_runs):
-tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
+tracer.on_llm_start(serialized=SERIALIZED, prompts=["test"], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run] * num_runs
@@ -208,7 +208,7 @@ def test_tracer_nested_run() -> None:
)
tracer.on_llm_start(
serialized=SERIALIZED,
-prompts=[],
+prompts=["test"],
run_id=llm_uuid1,
parent_run_id=tool_uuid,
)
@@ -216,7 +216,7 @@ def test_tracer_nested_run() -> None:
tracer.on_tool_end("test", run_id=tool_uuid)
tracer.on_llm_start(
serialized=SERIALIZED,
-prompts=[],
+prompts=["test"],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
)
@@ -272,7 +272,7 @@ def test_tracer_nested_run() -> None:
execution_order=3,
child_execution_order=3,
serialized=SERIALIZED,
-inputs=dict(prompts=[]),
+inputs=dict(prompt="test"),
outputs=LLMResult(generations=[[]]),
run_type="llm",
)
@@ -292,7 +292,7 @@ def test_tracer_nested_run() -> None:
execution_order=4,
child_execution_order=4,
serialized=SERIALIZED,
-inputs=dict(prompts=[]),
+inputs=dict(prompt="test"),
outputs=LLMResult(generations=[[]]),
run_type="llm",
),
@@ -320,14 +320,14 @@ def test_tracer_llm_run_on_error() -> None:
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED,
-inputs=dict(prompts=[]),
+inputs=dict(prompt="test"),
outputs=None,
error=repr(exception),
run_type="llm",
)
tracer = FakeTracer()
-tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
+tracer.on_llm_start(serialized=SERIALIZED, prompts=["test"], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@@ -411,14 +411,14 @@ def test_tracer_nested_runs_on_error() -> None:
)
tracer.on_llm_start(
serialized=SERIALIZED,
-prompts=[],
+prompts=["test"],
run_id=llm_uuid1,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_llm_start(
serialized=SERIALIZED,
-prompts=[],
+prompts=["test"],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
)
@@ -431,7 +431,7 @@ def test_tracer_nested_runs_on_error() -> None:
)
tracer.on_llm_start(
serialized=SERIALIZED,
-prompts=[],
+prompts=["test"],
run_id=llm_uuid3,
parent_run_id=tool_uuid,
)
@@ -470,7 +470,7 @@ def test_tracer_nested_runs_on_error() -> None:
child_execution_order=2,
serialized=SERIALIZED,
error=None,
-inputs=dict(prompts=[]),
+inputs=dict(prompt="test"),
outputs=LLMResult(generations=[[]], llm_output=None),
run_type="llm",
),
@@ -488,7 +488,7 @@ def test_tracer_nested_runs_on_error() -> None:
child_execution_order=3,
serialized=SERIALIZED,
error=None,
-inputs=dict(prompts=[]),
+inputs=dict(prompt="test"),
outputs=LLMResult(generations=[[]], llm_output=None),
run_type="llm",
),
@@ -524,7 +524,7 @@ def test_tracer_nested_runs_on_error() -> None:
child_execution_order=5,
serialized=SERIALIZED,
error=repr(exception),
-inputs=dict(prompts=[]),
+inputs=dict(prompt="test"),
outputs=None,
run_type="llm",
)

View File

@@ -1,680 +0,0 @@
"""Test Tracer classes."""
from __future__ import annotations
from datetime import datetime
from typing import List, Optional, Union
from uuid import uuid4
import pytest
from freezegun import freeze_time
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.tracers.base import BaseTracer, TracerException
from langchain.callbacks.tracers.langchain_v1 import (
ChainRun,
LangChainTracerV1,
LLMRun,
ToolRun,
TracerSessionV1,
)
from langchain.callbacks.tracers.schemas import Run, TracerSessionV1Base
from langchain.schema import LLMResult
from langchain.schema.messages import HumanMessage
TEST_SESSION_ID = 2023
SERIALIZED = {"id": ["llm"]}
SERIALIZED_CHAT = {"id": ["chat_model"]}
def load_session(session_name: str) -> TracerSessionV1:
"""Load a tracing session."""
return TracerSessionV1(
id=TEST_SESSION_ID, name=session_name, start_time=datetime.utcnow()
)
def new_session(name: Optional[str] = None) -> TracerSessionV1:
"""Create a new tracing session."""
return TracerSessionV1(
id=TEST_SESSION_ID, name=name or "default", start_time=datetime.utcnow()
)
def _persist_session(session: TracerSessionV1Base) -> TracerSessionV1:
"""Persist a tracing session."""
return TracerSessionV1(**{**session.dict(), "id": TEST_SESSION_ID})
def load_default_session() -> TracerSessionV1:
"""Load a tracing session."""
return TracerSessionV1(
id=TEST_SESSION_ID, name="default", start_time=datetime.utcnow()
)
@pytest.fixture
def lang_chain_tracer_v1(monkeypatch: pytest.MonkeyPatch) -> LangChainTracerV1:
monkeypatch.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
monkeypatch.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
monkeypatch.setenv("LANGCHAIN_API_KEY", "foo")
tracer = LangChainTracerV1()
return tracer
class FakeTracer(BaseTracer):
"""Fake tracer that records LangChain execution."""
def __init__(self) -> None:
"""Initialize the tracer."""
super().__init__()
self.runs: List[Union[LLMRun, ChainRun, ToolRun]] = []
def _persist_run(self, run: Union[Run, LLMRun, ChainRun, ToolRun]) -> None:
"""Persist a run."""
if isinstance(run, Run):
with pytest.MonkeyPatch().context() as m:
m.setenv("LANGCHAIN_TENANT_ID", "test-tenant-id")
m.setenv("LANGCHAIN_ENDPOINT", "http://test-endpoint.com")
m.setenv("LANGCHAIN_API_KEY", "foo")
tracer = LangChainTracerV1()
tracer.load_default_session = load_default_session # type: ignore
run = tracer._convert_to_v1_run(run)
self.runs.append(run)
def _persist_session(self, session: TracerSessionV1Base) -> TracerSessionV1:
"""Persist a tracing session."""
return _persist_session(session)
def new_session(self, name: Optional[str] = None) -> TracerSessionV1:
"""Create a new tracing session."""
return new_session(name)
def load_session(self, session_name: str) -> TracerSessionV1:
"""Load a tracing session."""
return load_session(session_name)
def load_default_session(self) -> TracerSessionV1:
"""Load a tracing session."""
return load_default_session()
@freeze_time("2023-01-01")
def test_tracer_llm_run() -> None:
"""Test tracer on an LLM run."""
uuid = uuid4()
compare_run = LLMRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_chat_model_run() -> None:
"""Test tracer on a Chat Model run."""
tracer = FakeTracer()
tracer.new_session()
manager = CallbackManager(handlers=[tracer])
run_managers = manager.on_chat_model_start(
serialized=SERIALIZED_CHAT, messages=[[HumanMessage(content="")]]
)
compare_run = LLMRun(
uuid=str(run_managers[0].run_id),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED_CHAT,
prompts=["Human: "],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
for run_manager in run_managers:
run_manager.on_llm_end(response=LLMResult(generations=[[]]))
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_llm_run_errors_no_start() -> None:
"""Test tracer on an LLM run without a start."""
tracer = FakeTracer()
tracer.new_session()
with pytest.raises(TracerException):
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid4())
@freeze_time("2023-01-01")
def test_tracer_multiple_llm_runs() -> None:
"""Test the tracer with multiple runs."""
uuid = uuid4()
compare_run = LLMRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
num_runs = 10
for _ in range(num_runs):
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=uuid)
assert tracer.runs == [compare_run] * num_runs
@freeze_time("2023-01-01")
def test_tracer_chain_run() -> None:
"""Test tracer on a Chain run."""
uuid = uuid4()
compare_run = ChainRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "chain"},
inputs={},
outputs={},
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
tracer.on_chain_end(outputs={}, run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run() -> None:
"""Test tracer on a Tool run."""
uuid = uuid4()
compare_run = ToolRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "tool"},
tool_input="test",
output="test",
action="{'name': 'tool'}",
session_id=TEST_SESSION_ID,
error=None,
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
tracer.on_tool_end("test", run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_nested_run() -> None:
"""Test tracer on a nested run."""
tracer = FakeTracer()
tracer.new_session()
chain_uuid = uuid4()
tool_uuid = uuid4()
llm_uuid1 = uuid4()
llm_uuid2 = uuid4()
for _ in range(10):
tracer.on_chain_start(
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
)
tracer.on_tool_start(
serialized={"name": "tool"},
input_str="test",
run_id=tool_uuid,
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=tool_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_tool_end("test", run_id=tool_uuid)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
tracer.on_chain_end(outputs={}, run_id=chain_uuid)
compare_run = ChainRun(
uuid=str(chain_uuid),
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=4,
serialized={"name": "chain"},
inputs={},
outputs={},
session_id=TEST_SESSION_ID,
child_chain_runs=[],
child_tool_runs=[
ToolRun(
uuid=str(tool_uuid),
parent_uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=2,
child_execution_order=3,
serialized={"name": "tool"},
tool_input="test",
output="test",
action="{'name': 'tool'}",
session_id=TEST_SESSION_ID,
error=None,
child_chain_runs=[],
child_tool_runs=[],
child_llm_runs=[
LLMRun(
uuid=str(llm_uuid1),
parent_uuid=str(tool_uuid),
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=3,
child_execution_order=3,
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
)
],
),
],
child_llm_runs=[
LLMRun(
uuid=str(llm_uuid2),
parent_uuid=str(chain_uuid),
error=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=4,
child_execution_order=4,
serialized=SERIALIZED,
prompts=[],
response=LLMResult(generations=[[]]),
session_id=TEST_SESSION_ID,
),
],
)
assert tracer.runs[0] == compare_run
assert tracer.runs == [compare_run] * 10
@freeze_time("2023-01-01")
def test_tracer_llm_run_on_error() -> None:
"""Test tracer on an LLM run with an error."""
exception = Exception("test")
uuid = uuid4()
compare_run = LLMRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized=SERIALIZED,
prompts=[],
response=None,
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_llm_start(serialized=SERIALIZED, prompts=[], run_id=uuid)
tracer.on_llm_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_chain_run_on_error() -> None:
"""Test tracer on a Chain run with an error."""
exception = Exception("test")
uuid = uuid4()
compare_run = ChainRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "chain"},
inputs={},
outputs=None,
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_chain_start(serialized={"name": "chain"}, inputs={}, run_id=uuid)
tracer.on_chain_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_tool_run_on_error() -> None:
"""Test tracer on a Tool run with an error."""
exception = Exception("test")
uuid = uuid4()
compare_run = ToolRun(
uuid=str(uuid),
parent_uuid=None,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=1,
serialized={"name": "tool"},
tool_input="test",
output=None,
action="{'name': 'tool'}",
session_id=TEST_SESSION_ID,
error=repr(exception),
)
tracer = FakeTracer()
tracer.new_session()
tracer.on_tool_start(serialized={"name": "tool"}, input_str="test", run_id=uuid)
tracer.on_tool_error(exception, run_id=uuid)
assert tracer.runs == [compare_run]
@freeze_time("2023-01-01")
def test_tracer_nested_runs_on_error() -> None:
"""Test tracer on a nested run with an error."""
exception = Exception("test")
tracer = FakeTracer()
tracer.new_session()
chain_uuid = uuid4()
tool_uuid = uuid4()
llm_uuid1 = uuid4()
llm_uuid2 = uuid4()
llm_uuid3 = uuid4()
for _ in range(3):
tracer.on_chain_start(
serialized={"name": "chain"}, inputs={}, run_id=chain_uuid
)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid1,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid1)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid2,
parent_run_id=chain_uuid,
)
tracer.on_llm_end(response=LLMResult(generations=[[]]), run_id=llm_uuid2)
tracer.on_tool_start(
serialized={"name": "tool"},
input_str="test",
run_id=tool_uuid,
parent_run_id=chain_uuid,
)
tracer.on_llm_start(
serialized=SERIALIZED,
prompts=[],
run_id=llm_uuid3,
parent_run_id=tool_uuid,
)
tracer.on_llm_error(exception, run_id=llm_uuid3)
tracer.on_tool_error(exception, run_id=tool_uuid)
tracer.on_chain_error(exception, run_id=chain_uuid)
compare_run = ChainRun(
uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=1,
child_execution_order=5,
serialized={"name": "chain"},
session_id=TEST_SESSION_ID,
error=repr(exception),
inputs={},
outputs=None,
child_llm_runs=[
LLMRun(
uuid=str(llm_uuid1),
parent_uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=2,
child_execution_order=2,
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
LLMRun(
uuid=str(llm_uuid2),
parent_uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=3,
child_execution_order=3,
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=None,
prompts=[],
response=LLMResult(generations=[[]], llm_output=None),
),
],
child_chain_runs=[],
child_tool_runs=[
ToolRun(
uuid=str(tool_uuid),
parent_uuid=str(chain_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=4,
child_execution_order=5,
serialized={"name": "tool"},
session_id=TEST_SESSION_ID,
error=repr(exception),
tool_input="test",
output=None,
action="{'name': 'tool'}",
child_llm_runs=[
LLMRun(
uuid=str(llm_uuid3),
parent_uuid=str(tool_uuid),
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
extra={},
execution_order=5,
child_execution_order=5,
serialized=SERIALIZED,
session_id=TEST_SESSION_ID,
error=repr(exception),
prompts=[],
response=None,
)
],
child_chain_runs=[],
child_tool_runs=[],
),
],
)
assert tracer.runs == [compare_run] * 3
@pytest.fixture
def sample_tracer_session_v1() -> TracerSessionV1:
return TracerSessionV1(id=2, name="Sample session")
@freeze_time("2023-01-01")
def test_convert_run(
lang_chain_tracer_v1: LangChainTracerV1,
sample_tracer_session_v1: TracerSessionV1,
) -> None:
"""Test converting a run to a V1 run."""
llm_run = Run(
id="57a08cc4-73d2-4236-8370-549099d07fad",
name="llm_run",
execution_order=1,
child_execution_order=1,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
session_id=TEST_SESSION_ID,
inputs={"prompts": []},
outputs=LLMResult(generations=[[]]).dict(),
serialized={},
extra={},
run_type="llm",
)
chain_run = Run(
id="57a08cc4-73d2-4236-8371-549099d07fad",
name="chain_run",
execution_order=1,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
child_execution_order=1,
serialized={},
inputs={},
outputs={},
child_runs=[llm_run],
extra={},
run_type="chain",
)
tool_run = Run(
id="57a08cc4-73d2-4236-8372-549099d07fad",
name="tool_run",
execution_order=1,
child_execution_order=1,
inputs={"input": "test"},
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
outputs=None,
serialized={},
child_runs=[],
extra={},
run_type="tool",
)
expected_llm_run = LLMRun(
uuid="57a08cc4-73d2-4236-8370-549099d07fad",
name="llm_run",
execution_order=1,
child_execution_order=1,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
session_id=2,
prompts=[],
response=LLMResult(generations=[[]]),
serialized={},
extra={},
)
expected_chain_run = ChainRun(
uuid="57a08cc4-73d2-4236-8371-549099d07fad",
name="chain_run",
execution_order=1,
child_execution_order=1,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
session_id=2,
serialized={},
inputs={},
outputs={},
child_llm_runs=[expected_llm_run],
child_chain_runs=[],
child_tool_runs=[],
extra={},
)
expected_tool_run = ToolRun(
uuid="57a08cc4-73d2-4236-8372-549099d07fad",
name="tool_run",
execution_order=1,
child_execution_order=1,
session_id=2,
start_time=datetime.utcnow(),
end_time=datetime.utcnow(),
tool_input="test",
action="{}",
serialized={},
child_llm_runs=[],
child_chain_runs=[],
child_tool_runs=[],
extra={},
)
lang_chain_tracer_v1.session = sample_tracer_session_v1
converted_llm_run = lang_chain_tracer_v1._convert_to_v1_run(llm_run)
converted_chain_run = lang_chain_tracer_v1._convert_to_v1_run(chain_run)
converted_tool_run = lang_chain_tracer_v1._convert_to_v1_run(tool_run)
assert isinstance(converted_llm_run, LLMRun)
assert isinstance(converted_chain_run, ChainRun)
assert isinstance(converted_tool_run, ToolRun)
assert converted_llm_run == expected_llm_run
assert converted_tool_run == expected_tool_run
assert converted_chain_run == expected_chain_run

File diff suppressed because one or more lines are too long