mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-25 16:13:25 +00:00
Add CometTracer (#13661)
Hi! I'm Alex, Python SDK Team Lead from [Comet](https://www.comet.com/site/). This PR contains our new integration between langchain and Comet - `CometTracer` class which uses new `comet_llm` python package for submitting data to Comet. No additional dependencies for the langchain package are required directly, but if the user wants to use `CometTracer`, `comet-llm>=2.0.0` should be installed. Otherwise an exception will be raised from `CometTracer.__init__`. A test for the feature is included. There is also an already existing callback (and .ipynb file with example) which ideally should be deprecated in favor of a new tracer. I wasn't sure how exactly you'd prefer to do it. For example we could open a separate PR for that. I'm open to your ideas :)
This commit is contained in:
parent
921c4b5597
commit
676a077c4e
138
libs/langchain/langchain/callbacks/tracers/comet.py
Normal file
138
libs/langchain/langchain/callbacks/tracers/comet.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
from types import ModuleType, SimpleNamespace
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, Dict
|
||||||
|
|
||||||
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from uuid import UUID
|
||||||
|
|
||||||
|
from comet_llm import Span
|
||||||
|
from comet_llm.chains.chain import Chain
|
||||||
|
|
||||||
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
|
|
||||||
|
|
||||||
|
def _get_run_type(run: "Run") -> str:
|
||||||
|
if isinstance(run.run_type, str):
|
||||||
|
return run.run_type
|
||||||
|
elif hasattr(run.run_type, "value"):
|
||||||
|
return run.run_type.value
|
||||||
|
else:
|
||||||
|
return str(run.run_type)
|
||||||
|
|
||||||
|
|
||||||
|
def import_comet_llm_api() -> SimpleNamespace:
|
||||||
|
"""Import comet_llm api and raise an error if it is not installed."""
|
||||||
|
try:
|
||||||
|
from comet_llm import (
|
||||||
|
experiment_info, # noqa: F401
|
||||||
|
flush, # noqa: F401
|
||||||
|
)
|
||||||
|
from comet_llm.chains import api as chain_api # noqa: F401
|
||||||
|
from comet_llm.chains import (
|
||||||
|
chain, # noqa: F401
|
||||||
|
span, # noqa: F401
|
||||||
|
)
|
||||||
|
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
"To use the CometTracer you need to have the "
|
||||||
|
"`comet_llm>=2.0.0` python package installed. Please install it with"
|
||||||
|
" `pip install -U comet_llm`"
|
||||||
|
)
|
||||||
|
return SimpleNamespace(
|
||||||
|
chain=chain,
|
||||||
|
span=span,
|
||||||
|
chain_api=chain_api,
|
||||||
|
experiment_info=experiment_info,
|
||||||
|
flush=flush,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CometTracer(BaseTracer):
|
||||||
|
def __init__(self, **kwargs: Any) -> None:
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
self._span_map: Dict["UUID", "Span"] = {}
|
||||||
|
self._chains_map: Dict["UUID", "Chain"] = {}
|
||||||
|
self._initialize_comet_modules()
|
||||||
|
|
||||||
|
def _initialize_comet_modules(self) -> None:
|
||||||
|
comet_llm_api = import_comet_llm_api()
|
||||||
|
self._chain: ModuleType = comet_llm_api.chain
|
||||||
|
self._span: ModuleType = comet_llm_api.span
|
||||||
|
self._chain_api: ModuleType = comet_llm_api.chain_api
|
||||||
|
self._experiment_info: ModuleType = comet_llm_api.experiment_info
|
||||||
|
self._flush: Callable[[], None] = comet_llm_api.flush
|
||||||
|
|
||||||
|
def _persist_run(self, run: "Run") -> None:
|
||||||
|
chain_ = self._chains_map[run.id]
|
||||||
|
chain_.set_outputs(outputs=run.outputs)
|
||||||
|
self._chain_api.log_chain(chain_)
|
||||||
|
|
||||||
|
def _process_start_trace(self, run: "Run") -> None:
|
||||||
|
if not run.parent_run_id:
|
||||||
|
# This is the first run, which maps to a chain
|
||||||
|
chain_: "Chain" = self._chain.Chain(
|
||||||
|
inputs=run.inputs,
|
||||||
|
metadata=None,
|
||||||
|
experiment_info=self._experiment_info.get(),
|
||||||
|
)
|
||||||
|
self._chains_map[run.id] = chain_
|
||||||
|
else:
|
||||||
|
span: "Span" = self._span.Span(
|
||||||
|
inputs=run.inputs,
|
||||||
|
category=_get_run_type(run),
|
||||||
|
metadata=run.extra,
|
||||||
|
name=run.name,
|
||||||
|
)
|
||||||
|
span.__api__start__(self._chains_map[run.parent_run_id])
|
||||||
|
self._chains_map[run.id] = self._chains_map[run.parent_run_id]
|
||||||
|
self._span_map[run.id] = span
|
||||||
|
|
||||||
|
def _process_end_trace(self, run: "Run") -> None:
|
||||||
|
if not run.parent_run_id:
|
||||||
|
pass
|
||||||
|
# Langchain will call _persist_run for us
|
||||||
|
else:
|
||||||
|
span = self._span_map[run.id]
|
||||||
|
span.set_outputs(outputs=run.outputs)
|
||||||
|
span.__api__end__()
|
||||||
|
|
||||||
|
def flush(self) -> None:
|
||||||
|
self._flush()
|
||||||
|
|
||||||
|
def _on_llm_start(self, run: "Run") -> None:
|
||||||
|
"""Process the LLM Run upon start."""
|
||||||
|
self._process_start_trace(run)
|
||||||
|
|
||||||
|
def _on_llm_end(self, run: "Run") -> None:
|
||||||
|
"""Process the LLM Run."""
|
||||||
|
self._process_end_trace(run)
|
||||||
|
|
||||||
|
def _on_llm_error(self, run: "Run") -> None:
|
||||||
|
"""Process the LLM Run upon error."""
|
||||||
|
self._process_end_trace(run)
|
||||||
|
|
||||||
|
def _on_chain_start(self, run: "Run") -> None:
|
||||||
|
"""Process the Chain Run upon start."""
|
||||||
|
self._process_start_trace(run)
|
||||||
|
|
||||||
|
def _on_chain_end(self, run: "Run") -> None:
|
||||||
|
"""Process the Chain Run."""
|
||||||
|
self._process_end_trace(run)
|
||||||
|
|
||||||
|
def _on_chain_error(self, run: "Run") -> None:
|
||||||
|
"""Process the Chain Run upon error."""
|
||||||
|
self._process_end_trace(run)
|
||||||
|
|
||||||
|
def _on_tool_start(self, run: "Run") -> None:
|
||||||
|
"""Process the Tool Run upon start."""
|
||||||
|
self._process_start_trace(run)
|
||||||
|
|
||||||
|
def _on_tool_end(self, run: "Run") -> None:
|
||||||
|
"""Process the Tool Run."""
|
||||||
|
self._process_end_trace(run)
|
||||||
|
|
||||||
|
def _on_tool_error(self, run: "Run") -> None:
|
||||||
|
"""Process the Tool Run upon error."""
|
||||||
|
self._process_end_trace(run)
|
@ -0,0 +1,97 @@
|
|||||||
|
import uuid
|
||||||
|
from types import SimpleNamespace
|
||||||
|
from unittest import mock
|
||||||
|
|
||||||
|
from langchain.callbacks.tracers import comet
|
||||||
|
from langchain.schema.output import LLMResult
|
||||||
|
|
||||||
|
|
||||||
|
def test_comet_tracer__trace_chain_with_single_span__happyflow() -> None:
|
||||||
|
# Setup mocks
|
||||||
|
chain_module_mock = mock.Mock()
|
||||||
|
chain_instance_mock = mock.Mock()
|
||||||
|
chain_module_mock.Chain.return_value = chain_instance_mock
|
||||||
|
|
||||||
|
span_module_mock = mock.Mock()
|
||||||
|
span_instance_mock = mock.MagicMock()
|
||||||
|
span_instance_mock.__api__start__ = mock.Mock()
|
||||||
|
span_instance_mock.__api__end__ = mock.Mock()
|
||||||
|
|
||||||
|
span_module_mock.Span.return_value = span_instance_mock
|
||||||
|
|
||||||
|
experiment_info_module_mock = mock.Mock()
|
||||||
|
experiment_info_module_mock.get.return_value = "the-experiment-info"
|
||||||
|
|
||||||
|
chain_api_module_mock = mock.Mock()
|
||||||
|
|
||||||
|
comet_ml_api_mock = SimpleNamespace(
|
||||||
|
chain=chain_module_mock,
|
||||||
|
span=span_module_mock,
|
||||||
|
experiment_info=experiment_info_module_mock,
|
||||||
|
chain_api=chain_api_module_mock,
|
||||||
|
flush="not-used-in-this-test",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create tracer
|
||||||
|
with mock.patch.object(
|
||||||
|
comet, "import_comet_llm_api", return_value=comet_ml_api_mock
|
||||||
|
):
|
||||||
|
tracer = comet.CometTracer()
|
||||||
|
|
||||||
|
run_id_1 = uuid.UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a")
|
||||||
|
run_id_2 = uuid.UUID("4f31216e-7c26-4027-a5fd-0bbf9ace17dc")
|
||||||
|
|
||||||
|
# Parent run
|
||||||
|
tracer.on_chain_start(
|
||||||
|
{"name": "chain-input"},
|
||||||
|
["chain-input-prompt"],
|
||||||
|
parent_run_id=None,
|
||||||
|
run_id=run_id_1,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that chain was created
|
||||||
|
chain_module_mock.Chain.assert_called_once_with(
|
||||||
|
inputs={"input": ["chain-input-prompt"]},
|
||||||
|
metadata=None,
|
||||||
|
experiment_info="the-experiment-info",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Child run
|
||||||
|
tracer.on_llm_start(
|
||||||
|
{"name": "span-input"},
|
||||||
|
["span-input-prompt"],
|
||||||
|
parent_run_id=run_id_1,
|
||||||
|
run_id=run_id_2,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that Span was created and attached to chain
|
||||||
|
span_module_mock.Span.assert_called_once_with(
|
||||||
|
inputs={"prompts": ["span-input-prompt"]},
|
||||||
|
category=mock.ANY,
|
||||||
|
metadata=mock.ANY,
|
||||||
|
name=mock.ANY,
|
||||||
|
)
|
||||||
|
span_instance_mock.__api__start__(chain_instance_mock)
|
||||||
|
|
||||||
|
# Child run end
|
||||||
|
tracer.on_llm_end(
|
||||||
|
LLMResult(generations=[], llm_output={"span-output-key": "span-output-value"}),
|
||||||
|
run_id=run_id_2,
|
||||||
|
)
|
||||||
|
# Check that Span outputs are set and span is ended
|
||||||
|
span_instance_mock.set_outputs.assert_called_once()
|
||||||
|
actual_span_outputs = span_instance_mock.set_outputs.call_args[1]["outputs"]
|
||||||
|
assert {
|
||||||
|
"llm_output": {"span-output-key": "span-output-value"},
|
||||||
|
"generations": [],
|
||||||
|
}.items() <= actual_span_outputs.items()
|
||||||
|
span_instance_mock.__api__end__()
|
||||||
|
|
||||||
|
# Parent run end
|
||||||
|
tracer.on_chain_end({"chain-output-key": "chain-output-value"}, run_id=run_id_1)
|
||||||
|
|
||||||
|
# Check that chain outputs are set and chain is logged
|
||||||
|
chain_instance_mock.set_outputs.assert_called_once()
|
||||||
|
actual_chain_outputs = chain_instance_mock.set_outputs.call_args[1]["outputs"]
|
||||||
|
assert ("chain-output-key", "chain-output-value") in actual_chain_outputs.items()
|
||||||
|
chain_api_module_mock.log_chain.assert_called_once_with(chain_instance_mock)
|
Loading…
Reference in New Issue
Block a user