From 676a077c4e7b1cfa0a3ddeee40cc10ff86d57f57 Mon Sep 17 00:00:00 2001 From: Aliaksandr Kuzmik <98702584+alexkuzmik@users.noreply.github.com> Date: Tue, 5 Dec 2023 01:46:48 +0100 Subject: [PATCH] Add CometTracer (#13661) Hi! I'm Alex, Python SDK Team Lead from [Comet](https://www.comet.com/site/). This PR contains our new integration between langchain and Comet - `CometTracer` class which uses new `comet_llm` python package for submitting data to Comet. No additional dependencies for the langchain package are required directly, but if the user wants to use `CometTracer`, `comet-llm>=2.0.0` should be installed. Otherwise an exception will be raised from `CometTracer.__init__`. A test for the feature is included. There is also an already existing callback (and .ipynb file with example) which ideally should be deprecated in favor of a new tracer. I wasn't sure how exactly you'd prefer to do it. For example we could open a separate PR for that. I'm open to your ideas :) --- .../langchain/callbacks/tracers/comet.py | 138 ++++++++++++++++++ .../callbacks/tracers/test_comet.py | 97 ++++++++++++ 2 files changed, 235 insertions(+) create mode 100644 libs/langchain/langchain/callbacks/tracers/comet.py create mode 100644 libs/langchain/tests/unit_tests/callbacks/tracers/test_comet.py diff --git a/libs/langchain/langchain/callbacks/tracers/comet.py b/libs/langchain/langchain/callbacks/tracers/comet.py new file mode 100644 index 00000000000..bfe7bb44342 --- /dev/null +++ b/libs/langchain/langchain/callbacks/tracers/comet.py @@ -0,0 +1,138 @@ +from types import ModuleType, SimpleNamespace +from typing import TYPE_CHECKING, Any, Callable, Dict + +from langchain.callbacks.tracers.base import BaseTracer + +if TYPE_CHECKING: + from uuid import UUID + + from comet_llm import Span + from comet_llm.chains.chain import Chain + + from langchain.callbacks.tracers.schemas import Run + + +def _get_run_type(run: "Run") -> str: + if isinstance(run.run_type, str): + return run.run_type + elif hasattr(run.run_type, "value"): + return run.run_type.value + else: + return str(run.run_type) + + +def import_comet_llm_api() -> SimpleNamespace: + """Import comet_llm api and raise an error if it is not installed.""" + try: + from comet_llm import ( + experiment_info, # noqa: F401 + flush, # noqa: F401 + ) + from comet_llm.chains import api as chain_api # noqa: F401 + from comet_llm.chains import ( + chain, # noqa: F401 + span, # noqa: F401 + ) + + except ImportError: + raise ImportError( + "To use the CometTracer you need to have the " + "`comet_llm>=2.0.0` python package installed. Please install it with" + " `pip install -U comet_llm`" + ) + return SimpleNamespace( + chain=chain, + span=span, + chain_api=chain_api, + experiment_info=experiment_info, + flush=flush, + ) + + +class CometTracer(BaseTracer): + def __init__(self, **kwargs: Any) -> None: + super().__init__(**kwargs) + self._span_map: Dict["UUID", "Span"] = {} + self._chains_map: Dict["UUID", "Chain"] = {} + self._initialize_comet_modules() + + def _initialize_comet_modules(self) -> None: + comet_llm_api = import_comet_llm_api() + self._chain: ModuleType = comet_llm_api.chain + self._span: ModuleType = comet_llm_api.span + self._chain_api: ModuleType = comet_llm_api.chain_api + self._experiment_info: ModuleType = comet_llm_api.experiment_info + self._flush: Callable[[], None] = comet_llm_api.flush + + def _persist_run(self, run: "Run") -> None: + chain_ = self._chains_map[run.id] + chain_.set_outputs(outputs=run.outputs) + self._chain_api.log_chain(chain_) + + def _process_start_trace(self, run: "Run") -> None: + if not run.parent_run_id: + # This is the first run, which maps to a chain + chain_: "Chain" = self._chain.Chain( + inputs=run.inputs, + metadata=None, + experiment_info=self._experiment_info.get(), + ) + self._chains_map[run.id] = chain_ + else: + span: "Span" = self._span.Span( + inputs=run.inputs, + category=_get_run_type(run), + metadata=run.extra, + name=run.name, + ) + span.__api__start__(self._chains_map[run.parent_run_id]) + self._chains_map[run.id] = self._chains_map[run.parent_run_id] + self._span_map[run.id] = span + + def _process_end_trace(self, run: "Run") -> None: + if not run.parent_run_id: + pass + # Langchain will call _persist_run for us + else: + span = self._span_map[run.id] + span.set_outputs(outputs=run.outputs) + span.__api__end__() + + def flush(self) -> None: + self._flush() + + def _on_llm_start(self, run: "Run") -> None: + """Process the LLM Run upon start.""" + self._process_start_trace(run) + + def _on_llm_end(self, run: "Run") -> None: + """Process the LLM Run.""" + self._process_end_trace(run) + + def _on_llm_error(self, run: "Run") -> None: + """Process the LLM Run upon error.""" + self._process_end_trace(run) + + def _on_chain_start(self, run: "Run") -> None: + """Process the Chain Run upon start.""" + self._process_start_trace(run) + + def _on_chain_end(self, run: "Run") -> None: + """Process the Chain Run.""" + self._process_end_trace(run) + + def _on_chain_error(self, run: "Run") -> None: + """Process the Chain Run upon error.""" + self._process_end_trace(run) + + def _on_tool_start(self, run: "Run") -> None: + """Process the Tool Run upon start.""" + self._process_start_trace(run) + + def _on_tool_end(self, run: "Run") -> None: + """Process the Tool Run.""" + self._process_end_trace(run) + + def _on_tool_error(self, run: "Run") -> None: + """Process the Tool Run upon error.""" + self._process_end_trace(run) diff --git a/libs/langchain/tests/unit_tests/callbacks/tracers/test_comet.py b/libs/langchain/tests/unit_tests/callbacks/tracers/test_comet.py new file mode 100644 index 00000000000..537bc64e455 --- /dev/null +++ b/libs/langchain/tests/unit_tests/callbacks/tracers/test_comet.py @@ -0,0 +1,97 @@ +import uuid +from types import SimpleNamespace +from unittest import mock + +from langchain.callbacks.tracers import comet +from langchain.schema.output import LLMResult + + +def test_comet_tracer__trace_chain_with_single_span__happyflow() -> None: + # Setup mocks + chain_module_mock = mock.Mock() + chain_instance_mock = mock.Mock() + chain_module_mock.Chain.return_value = chain_instance_mock + + span_module_mock = mock.Mock() + span_instance_mock = mock.MagicMock() + span_instance_mock.__api__start__ = mock.Mock() + span_instance_mock.__api__end__ = mock.Mock() + + span_module_mock.Span.return_value = span_instance_mock + + experiment_info_module_mock = mock.Mock() + experiment_info_module_mock.get.return_value = "the-experiment-info" + + chain_api_module_mock = mock.Mock() + + comet_ml_api_mock = SimpleNamespace( + chain=chain_module_mock, + span=span_module_mock, + experiment_info=experiment_info_module_mock, + chain_api=chain_api_module_mock, + flush="not-used-in-this-test", + ) + + # Create tracer + with mock.patch.object( + comet, "import_comet_llm_api", return_value=comet_ml_api_mock + ): + tracer = comet.CometTracer() + + run_id_1 = uuid.UUID("9d878ab3-e5ca-4218-aef6-44cbdc90160a") + run_id_2 = uuid.UUID("4f31216e-7c26-4027-a5fd-0bbf9ace17dc") + + # Parent run + tracer.on_chain_start( + {"name": "chain-input"}, + ["chain-input-prompt"], + parent_run_id=None, + run_id=run_id_1, + ) + + # Check that chain was created + chain_module_mock.Chain.assert_called_once_with( + inputs={"input": ["chain-input-prompt"]}, + metadata=None, + experiment_info="the-experiment-info", + ) + + # Child run + tracer.on_llm_start( + {"name": "span-input"}, + ["span-input-prompt"], + parent_run_id=run_id_1, + run_id=run_id_2, + ) + + # Check that Span was created and attached to chain + span_module_mock.Span.assert_called_once_with( + inputs={"prompts": ["span-input-prompt"]}, + category=mock.ANY, + metadata=mock.ANY, + name=mock.ANY, + ) + span_instance_mock.__api__start__(chain_instance_mock) + + # Child run end + tracer.on_llm_end( + LLMResult(generations=[], llm_output={"span-output-key": "span-output-value"}), + run_id=run_id_2, + ) + # Check that Span outputs are set and span is ended + span_instance_mock.set_outputs.assert_called_once() + actual_span_outputs = span_instance_mock.set_outputs.call_args[1]["outputs"] + assert { + "llm_output": {"span-output-key": "span-output-value"}, + "generations": [], + }.items() <= actual_span_outputs.items() + span_instance_mock.__api__end__() + + # Parent run end + tracer.on_chain_end({"chain-output-key": "chain-output-value"}, run_id=run_id_1) + + # Check that chain outputs are set and chain is logged + chain_instance_mock.set_outputs.assert_called_once() + actual_chain_outputs = chain_instance_mock.set_outputs.call_args[1]["outputs"] + assert ("chain-output-key", "chain-output-value") in actual_chain_outputs.items() + chain_api_module_mock.log_chain.assert_called_once_with(chain_instance_mock)