mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
feat(integrations): Add WandbTracer (#4521)
# WandbTracer This PR adds the `WandbTracer` and deprecates the existing `WandbCallbackHandler`. Added an example notebook under the docs section alongside the `LangchainTracer` Here's an example [colab](https://colab.research.google.com/drive/1pY13ym8ENEZ8Fh7nA99ILk2GcdUQu0jR?usp=sharing) with the same notebook and the [trace](https://wandb.ai/parambharat/langchain-tracing/runs/8i45cst6) generated from the colab run Co-authored-by: Bharat Ramanathan <ramanathan.parameshwaran@gohuddl.com> Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
This commit is contained in:
parent
373ad49157
commit
22603d19e0
238
docs/integrations/agent_with_wandb_tracing.ipynb
Normal file
238
docs/integrations/agent_with_wandb_tracing.ipynb
Normal file
File diff suppressed because one or more lines are too long
@ -10,7 +10,9 @@
|
|||||||
"\n",
|
"\n",
|
||||||
"Run in Colab: https://colab.research.google.com/drive/1DXH4beT4HFaRKy_Vm4PoxhXVDRf7Ym8L?usp=sharing\n",
|
"Run in Colab: https://colab.research.google.com/drive/1DXH4beT4HFaRKy_Vm4PoxhXVDRf7Ym8L?usp=sharing\n",
|
||||||
"\n",
|
"\n",
|
||||||
"View Report: https://wandb.ai/a-sh0ts/langchain_callback_demo/reports/Prompt-Engineering-LLMs-with-LangChain-and-W-B--VmlldzozNjk1NTUw#👋-how-to-build-a-callback-in-langchain-for-better-prompt-engineering"
|
"View Report: https://wandb.ai/a-sh0ts/langchain_callback_demo/reports/Prompt-Engineering-LLMs-with-LangChain-and-W-B--VmlldzozNjk1NTUw#👋-how-to-build-a-callback-in-langchain-for-better-prompt-engineering\n",
|
||||||
|
"\n",
|
||||||
|
"**Note**: _the `WandbCallbackHandler` is being deprecated in favour of the `WandbTracer`_ . In future please use the `WandbTracer` as it is more flexible and allows for more granular logging. To know more about the `WandbTracer` refer to the agent_with_wandb_tracing.ipynb notebook in docs or use the following [colab](https://colab.research.google.com/drive/1pY13ym8ENEZ8Fh7nA99ILk2GcdUQu0jR?usp=sharing)."
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -107,7 +109,7 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mharrison-chase\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
|
"\u001B[34m\u001B[1mwandb\u001B[0m: Currently logged in as: \u001B[33mharrison-chase\u001B[0m. Use \u001B[1m`wandb login --relogin`\u001B[0m to force relogin\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@ -174,7 +176,7 @@
|
|||||||
"name": "stderr",
|
"name": "stderr",
|
||||||
"output_type": "stream",
|
"output_type": "stream",
|
||||||
"text": [
|
"text": [
|
||||||
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m The wandb callback is currently in beta and is subject to change based on updates to `langchain`. Please report any issues to https://github.com/wandb/wandb/issues with the tag `langchain`.\n"
|
"\u001B[34m\u001B[1mwandb\u001B[0m: \u001B[33mWARNING\u001B[0m The wandb callback is currently in beta and is subject to change based on updates to `langchain`. Please report any issues to https://github.com/wandb/wandb/issues with the tag `langchain`.\n"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
@ -521,20 +523,20 @@
|
|||||||
"text": [
|
"text": [
|
||||||
"\n",
|
"\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Entering new AgentExecutor chain...\u001b[0m\n",
|
"\u001B[1m> Entering new AgentExecutor chain...\u001B[0m\n",
|
||||||
"\u001b[32;1m\u001b[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n",
|
"\u001B[32;1m\u001B[1;3m I need to find out who Leo DiCaprio's girlfriend is and then calculate her age raised to the 0.43 power.\n",
|
||||||
"Action: Search\n",
|
"Action: Search\n",
|
||||||
"Action Input: \"Leo DiCaprio girlfriend\"\u001b[0m\n",
|
"Action Input: \"Leo DiCaprio girlfriend\"\u001B[0m\n",
|
||||||
"Observation: \u001b[36;1m\u001b[1;3mDiCaprio had a steady girlfriend in Camila Morrone. He had been with the model turned actress for nearly five years, as they were first said to be dating at the end of 2017. And the now 26-year-old Morrone is no stranger to Hollywood.\u001b[0m\n",
|
"Observation: \u001B[36;1m\u001B[1;3mDiCaprio had a steady girlfriend in Camila Morrone. He had been with the model turned actress for nearly five years, as they were first said to be dating at the end of 2017. And the now 26-year-old Morrone is no stranger to Hollywood.\u001B[0m\n",
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I need to calculate her age raised to the 0.43 power.\n",
|
"Thought:\u001B[32;1m\u001B[1;3m I need to calculate her age raised to the 0.43 power.\n",
|
||||||
"Action: Calculator\n",
|
"Action: Calculator\n",
|
||||||
"Action Input: 26^0.43\u001b[0m\n",
|
"Action Input: 26^0.43\u001B[0m\n",
|
||||||
"Observation: \u001b[33;1m\u001b[1;3mAnswer: 4.059182145592686\n",
|
"Observation: \u001B[33;1m\u001B[1;3mAnswer: 4.059182145592686\n",
|
||||||
"\u001b[0m\n",
|
"\u001B[0m\n",
|
||||||
"Thought:\u001b[32;1m\u001b[1;3m I now know the final answer.\n",
|
"Thought:\u001B[32;1m\u001B[1;3m I now know the final answer.\n",
|
||||||
"Final Answer: Leo DiCaprio's girlfriend is Camila Morrone and her current age raised to the 0.43 power is 4.059182145592686.\u001b[0m\n",
|
"Final Answer: Leo DiCaprio's girlfriend is Camila Morrone and her current age raised to the 0.43 power is 4.059182145592686.\u001B[0m\n",
|
||||||
"\n",
|
"\n",
|
||||||
"\u001b[1m> Finished chain.\u001b[0m\n"
|
"\u001B[1m> Finished chain.\u001B[0m\n"
|
||||||
]
|
]
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -6,6 +6,7 @@ from langchain.callbacks.comet_ml_callback import CometCallbackHandler
|
|||||||
from langchain.callbacks.manager import (
|
from langchain.callbacks.manager import (
|
||||||
get_openai_callback,
|
get_openai_callback,
|
||||||
tracing_enabled,
|
tracing_enabled,
|
||||||
|
wandb_tracing_enabled,
|
||||||
)
|
)
|
||||||
from langchain.callbacks.mlflow_callback import MlflowCallbackHandler
|
from langchain.callbacks.mlflow_callback import MlflowCallbackHandler
|
||||||
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
from langchain.callbacks.openai_info import OpenAICallbackHandler
|
||||||
@ -26,4 +27,5 @@ __all__ = [
|
|||||||
"AsyncIteratorCallbackHandler",
|
"AsyncIteratorCallbackHandler",
|
||||||
"get_openai_callback",
|
"get_openai_callback",
|
||||||
"tracing_enabled",
|
"tracing_enabled",
|
||||||
|
"wandb_tracing_enabled",
|
||||||
]
|
]
|
||||||
|
@ -25,6 +25,7 @@ from langchain.callbacks.tracers.langchain import LangChainTracer
|
|||||||
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
|
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1, TracerSessionV1
|
||||||
from langchain.callbacks.tracers.schemas import TracerSession
|
from langchain.callbacks.tracers.schemas import TracerSession
|
||||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||||
|
from langchain.callbacks.tracers.wandb import WandbTracer
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
AgentAction,
|
AgentAction,
|
||||||
AgentFinish,
|
AgentFinish,
|
||||||
@ -44,6 +45,12 @@ tracing_callback_var: ContextVar[
|
|||||||
] = ContextVar( # noqa: E501
|
] = ContextVar( # noqa: E501
|
||||||
"tracing_callback", default=None
|
"tracing_callback", default=None
|
||||||
)
|
)
|
||||||
|
wandb_tracing_callback_var: ContextVar[
|
||||||
|
Optional[WandbTracer]
|
||||||
|
] = ContextVar( # noqa: E501
|
||||||
|
"tracing_wandb_callback", default=None
|
||||||
|
)
|
||||||
|
|
||||||
tracing_v2_callback_var: ContextVar[
|
tracing_v2_callback_var: ContextVar[
|
||||||
Optional[LangChainTracer]
|
Optional[LangChainTracer]
|
||||||
] = ContextVar( # noqa: E501
|
] = ContextVar( # noqa: E501
|
||||||
@ -76,6 +83,17 @@ def tracing_enabled(
|
|||||||
tracing_callback_var.set(None)
|
tracing_callback_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def wandb_tracing_enabled(
|
||||||
|
session_name: str = "default",
|
||||||
|
) -> Generator[None, None, None]:
|
||||||
|
"""Get WandbTracer in a context manager."""
|
||||||
|
cb = WandbTracer()
|
||||||
|
wandb_tracing_callback_var.set(cb)
|
||||||
|
yield None
|
||||||
|
wandb_tracing_callback_var.set(None)
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
def tracing_v2_enabled(
|
def tracing_v2_enabled(
|
||||||
session_name: Optional[str] = None,
|
session_name: Optional[str] = None,
|
||||||
@ -831,12 +849,17 @@ def _configure(
|
|||||||
callback_manager.add_handler(handler, False)
|
callback_manager.add_handler(handler, False)
|
||||||
|
|
||||||
tracer = tracing_callback_var.get()
|
tracer = tracing_callback_var.get()
|
||||||
|
wandb_tracer = wandb_tracing_callback_var.get()
|
||||||
open_ai = openai_callback_var.get()
|
open_ai = openai_callback_var.get()
|
||||||
tracing_enabled_ = (
|
tracing_enabled_ = (
|
||||||
os.environ.get("LANGCHAIN_TRACING") is not None
|
os.environ.get("LANGCHAIN_TRACING") is not None
|
||||||
or tracer is not None
|
or tracer is not None
|
||||||
or os.environ.get("LANGCHAIN_HANDLER") is not None
|
or os.environ.get("LANGCHAIN_HANDLER") is not None
|
||||||
)
|
)
|
||||||
|
wandb_tracing_enabled_ = (
|
||||||
|
os.environ.get("LANGCHAIN_WANDB_TRACING") is not None
|
||||||
|
or wandb_tracer is not None
|
||||||
|
)
|
||||||
|
|
||||||
tracer_v2 = tracing_v2_callback_var.get()
|
tracer_v2 = tracing_v2_callback_var.get()
|
||||||
tracing_v2_enabled_ = (
|
tracing_v2_enabled_ = (
|
||||||
@ -851,6 +874,7 @@ def _configure(
|
|||||||
or debug
|
or debug
|
||||||
or tracing_enabled_
|
or tracing_enabled_
|
||||||
or tracing_v2_enabled_
|
or tracing_v2_enabled_
|
||||||
|
or wandb_tracing_enabled_
|
||||||
or open_ai is not None
|
or open_ai is not None
|
||||||
):
|
):
|
||||||
if verbose and not any(
|
if verbose and not any(
|
||||||
@ -876,6 +900,14 @@ def _configure(
|
|||||||
handler = LangChainTracerV1()
|
handler = LangChainTracerV1()
|
||||||
handler.load_session(tracer_session)
|
handler.load_session(tracer_session)
|
||||||
callback_manager.add_handler(handler, True)
|
callback_manager.add_handler(handler, True)
|
||||||
|
if wandb_tracing_enabled_ and not any(
|
||||||
|
isinstance(handler, WandbTracer) for handler in callback_manager.handlers
|
||||||
|
):
|
||||||
|
if wandb_tracer:
|
||||||
|
callback_manager.add_handler(wandb_tracer, True)
|
||||||
|
else:
|
||||||
|
handler = WandbTracer()
|
||||||
|
callback_manager.add_handler(handler, True)
|
||||||
if tracing_v2_enabled_ and not any(
|
if tracing_v2_enabled_ and not any(
|
||||||
isinstance(handler, LangChainTracer)
|
isinstance(handler, LangChainTracer)
|
||||||
for handler in callback_manager.handlers
|
for handler in callback_manager.handlers
|
||||||
|
@ -3,5 +3,11 @@
|
|||||||
from langchain.callbacks.tracers.langchain import LangChainTracer
|
from langchain.callbacks.tracers.langchain import LangChainTracer
|
||||||
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1
|
from langchain.callbacks.tracers.langchain_v1 import LangChainTracerV1
|
||||||
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
from langchain.callbacks.tracers.stdout import ConsoleCallbackHandler
|
||||||
|
from langchain.callbacks.tracers.wandb import WandbTracer
|
||||||
|
|
||||||
__all__ = ["LangChainTracer", "LangChainTracerV1", "ConsoleCallbackHandler"]
|
__all__ = [
|
||||||
|
"LangChainTracer",
|
||||||
|
"LangChainTracerV1",
|
||||||
|
"ConsoleCallbackHandler",
|
||||||
|
"WandbTracer",
|
||||||
|
]
|
||||||
|
265
langchain/callbacks/tracers/wandb.py
Normal file
265
langchain/callbacks/tracers/wandb.py
Normal file
@ -0,0 +1,265 @@
|
|||||||
|
"""A Tracer Implementation that records activity to Weights & Biases."""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import (
|
||||||
|
TYPE_CHECKING,
|
||||||
|
Any,
|
||||||
|
Dict,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Sequence,
|
||||||
|
TypedDict,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
|
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from wandb import Settings as WBSettings
|
||||||
|
from wandb.sdk.data_types import trace_tree
|
||||||
|
from wandb.sdk.lib.paths import StrPath
|
||||||
|
from wandb.wandb_run import Run as WBRun
|
||||||
|
|
||||||
|
|
||||||
|
PRINT_WARNINGS = True
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_lc_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
||||||
|
if run.run_type == RunTypeEnum.llm:
|
||||||
|
return _convert_llm_run_to_wb_span(trace_tree, run)
|
||||||
|
elif run.run_type == RunTypeEnum.chain:
|
||||||
|
return _convert_chain_run_to_wb_span(trace_tree, run)
|
||||||
|
elif run.run_type == RunTypeEnum.tool:
|
||||||
|
return _convert_tool_run_to_wb_span(trace_tree, run)
|
||||||
|
else:
|
||||||
|
return _convert_run_to_wb_span(trace_tree, run)
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_llm_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
||||||
|
base_span = _convert_run_to_wb_span(trace_tree, run)
|
||||||
|
|
||||||
|
base_span.results = [
|
||||||
|
trace_tree.Result(
|
||||||
|
inputs={"prompt": prompt},
|
||||||
|
outputs={
|
||||||
|
f"gen_{g_i}": gen["text"]
|
||||||
|
for g_i, gen in enumerate(run.outputs["generations"][ndx])
|
||||||
|
}
|
||||||
|
if (
|
||||||
|
run.outputs is not None
|
||||||
|
and len(run.outputs["generations"]) > ndx
|
||||||
|
and len(run.outputs["generations"][ndx]) > 0
|
||||||
|
)
|
||||||
|
else None,
|
||||||
|
)
|
||||||
|
for ndx, prompt in enumerate(run.inputs["prompts"] or [])
|
||||||
|
]
|
||||||
|
base_span.span_kind = trace_tree.SpanKind.LLM
|
||||||
|
|
||||||
|
return base_span
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_chain_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
||||||
|
base_span = _convert_run_to_wb_span(trace_tree, run)
|
||||||
|
|
||||||
|
base_span.results = [trace_tree.Result(inputs=run.inputs, outputs=run.outputs)]
|
||||||
|
base_span.child_spans = [
|
||||||
|
_convert_lc_run_to_wb_span(trace_tree, child_run)
|
||||||
|
for child_run in run.child_runs
|
||||||
|
]
|
||||||
|
base_span.span_kind = (
|
||||||
|
trace_tree.SpanKind.AGENT
|
||||||
|
if "agent" in run.serialized.get("name", "").lower()
|
||||||
|
else trace_tree.SpanKind.CHAIN
|
||||||
|
)
|
||||||
|
|
||||||
|
return base_span
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_tool_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
||||||
|
base_span = _convert_run_to_wb_span(trace_tree, run)
|
||||||
|
base_span.results = [trace_tree.Result(inputs=run.inputs, outputs=run.outputs)]
|
||||||
|
base_span.child_spans = [
|
||||||
|
_convert_lc_run_to_wb_span(trace_tree, child_run)
|
||||||
|
for child_run in run.child_runs
|
||||||
|
]
|
||||||
|
base_span.span_kind = trace_tree.SpanKind.TOOL
|
||||||
|
|
||||||
|
return base_span
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_run_to_wb_span(trace_tree: Any, run: Run) -> trace_tree.Span:
|
||||||
|
attributes = {**run.extra} if run.extra else {}
|
||||||
|
attributes["execution_order"] = run.execution_order
|
||||||
|
|
||||||
|
return trace_tree.Span(
|
||||||
|
span_id=str(run.id) if run.id is not None else None,
|
||||||
|
name=run.serialized.get("name"),
|
||||||
|
start_time_ms=int(run.start_time.timestamp() * 1000),
|
||||||
|
end_time_ms=int(run.end_time.timestamp() * 1000),
|
||||||
|
status_code=trace_tree.StatusCode.SUCCESS
|
||||||
|
if run.error is None
|
||||||
|
else trace_tree.StatusCode.ERROR,
|
||||||
|
status_message=run.error,
|
||||||
|
attributes=attributes,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_type_with_kind(data: Any) -> Any:
|
||||||
|
if isinstance(data, dict):
|
||||||
|
# W&B TraceTree expects "_kind" instead of "_type" since `_type` is special
|
||||||
|
# in W&B.
|
||||||
|
if "_type" in data:
|
||||||
|
_type = data.pop("_type")
|
||||||
|
data["_kind"] = _type
|
||||||
|
return {k: _replace_type_with_kind(v) for k, v in data.items()}
|
||||||
|
elif isinstance(data, list):
|
||||||
|
return [_replace_type_with_kind(v) for v in data]
|
||||||
|
elif isinstance(data, tuple):
|
||||||
|
return tuple(_replace_type_with_kind(v) for v in data)
|
||||||
|
elif isinstance(data, set):
|
||||||
|
return {_replace_type_with_kind(v) for v in data}
|
||||||
|
else:
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
class WandbRunArgs(TypedDict):
|
||||||
|
job_type: Optional[str]
|
||||||
|
dir: Optional[StrPath]
|
||||||
|
config: Union[Dict, str, None]
|
||||||
|
project: Optional[str]
|
||||||
|
entity: Optional[str]
|
||||||
|
reinit: Optional[bool]
|
||||||
|
tags: Optional[Sequence]
|
||||||
|
group: Optional[str]
|
||||||
|
name: Optional[str]
|
||||||
|
notes: Optional[str]
|
||||||
|
magic: Optional[Union[dict, str, bool]]
|
||||||
|
config_exclude_keys: Optional[List[str]]
|
||||||
|
config_include_keys: Optional[List[str]]
|
||||||
|
anonymous: Optional[str]
|
||||||
|
mode: Optional[str]
|
||||||
|
allow_val_change: Optional[bool]
|
||||||
|
resume: Optional[Union[bool, str]]
|
||||||
|
force: Optional[bool]
|
||||||
|
tensorboard: Optional[bool]
|
||||||
|
sync_tensorboard: Optional[bool]
|
||||||
|
monitor_gym: Optional[bool]
|
||||||
|
save_code: Optional[bool]
|
||||||
|
id: Optional[str]
|
||||||
|
settings: Union[WBSettings, Dict[str, Any], None]
|
||||||
|
|
||||||
|
|
||||||
|
class WandbTracer(BaseTracer):
|
||||||
|
"""Callback Handler that logs to Weights and Biases.
|
||||||
|
|
||||||
|
This handler will log the model architecture and run traces to Weights and Biases.
|
||||||
|
This will ensure that all LangChain activity is logged to W&B.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_run: Optional[WBRun] = None
|
||||||
|
_run_args: Optional[WandbRunArgs] = None
|
||||||
|
|
||||||
|
def __init__(self, run_args: Optional[WandbRunArgs] = None, **kwargs: Any) -> None:
|
||||||
|
"""Initializes the WandbTracer.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
run_args: (dict, optional) Arguments to pass to `wandb.init()`. If not
|
||||||
|
provided, `wandb.init()` will be called with no arguments. Please
|
||||||
|
refer to the `wandb.init` for more details.
|
||||||
|
|
||||||
|
To use W&B to monitor all LangChain activity, add this tracer like any other
|
||||||
|
LangChain callback:
|
||||||
|
```
|
||||||
|
from wandb.integration.langchain import WandbTracer
|
||||||
|
|
||||||
|
tracer = WandbTracer()
|
||||||
|
chain = LLMChain(llm, callbacks=[tracer])
|
||||||
|
# ...end of notebook / script:
|
||||||
|
tracer.finish()
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
super().__init__(**kwargs)
|
||||||
|
try:
|
||||||
|
import wandb
|
||||||
|
from wandb.sdk.data_types import trace_tree
|
||||||
|
except ImportError as e:
|
||||||
|
raise ImportError(
|
||||||
|
"Could not import wandb python package."
|
||||||
|
"Please install it with `pip install wandb`."
|
||||||
|
) from e
|
||||||
|
self._wandb = wandb
|
||||||
|
self._trace_tree = trace_tree
|
||||||
|
self._run_args = run_args
|
||||||
|
self._ensure_run(should_print_url=(wandb.run is None))
|
||||||
|
|
||||||
|
def finish(self) -> None:
|
||||||
|
"""Waits for all asynchronous processes to finish and data to upload.
|
||||||
|
|
||||||
|
Proxy for `wandb.finish()`.
|
||||||
|
"""
|
||||||
|
self._wandb.finish()
|
||||||
|
|
||||||
|
def _log_trace_from_run(self, run: Run) -> None:
|
||||||
|
"""Logs a LangChain Run to W*B as a W&B Trace."""
|
||||||
|
self._ensure_run()
|
||||||
|
|
||||||
|
try:
|
||||||
|
root_span = _convert_lc_run_to_wb_span(self._trace_tree, run)
|
||||||
|
except Exception as e:
|
||||||
|
if PRINT_WARNINGS:
|
||||||
|
self._wandb.termwarn(
|
||||||
|
f"Skipping trace saving - unable to safely convert LangChain Run "
|
||||||
|
f"into W&B Trace due to: {e}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
model_dict = None
|
||||||
|
|
||||||
|
# TODO: Add something like this once we have a way to get the clean serialized
|
||||||
|
# parent dict from a run:
|
||||||
|
# serialized_parent = safely_get_span_producing_model(run)
|
||||||
|
# if serialized_parent is not None:
|
||||||
|
# model_dict = safely_convert_model_to_dict(serialized_parent)
|
||||||
|
|
||||||
|
model_trace = self._trace_tree.WBTraceTree(
|
||||||
|
root_span=root_span,
|
||||||
|
model_dict=model_dict,
|
||||||
|
)
|
||||||
|
if self._wandb.run is not None:
|
||||||
|
self._wandb.run.log({"langchain_trace": model_trace})
|
||||||
|
|
||||||
|
def _ensure_run(self, should_print_url: bool = False) -> None:
|
||||||
|
"""Ensures an active W&B run exists.
|
||||||
|
|
||||||
|
If not, will start a new run with the provided run_args.
|
||||||
|
"""
|
||||||
|
if self._wandb.run is None:
|
||||||
|
# Make a shallow copy of the run args, so we don't modify the original
|
||||||
|
run_args = self._run_args or {} # type: ignore
|
||||||
|
run_args: dict = {**run_args} # type: ignore
|
||||||
|
|
||||||
|
# Prefer to run in silent mode since W&B has a lot of output
|
||||||
|
# which can be undesirable when dealing with text-based models.
|
||||||
|
if "settings" not in run_args: # type: ignore
|
||||||
|
run_args["settings"] = {"silent": True} # type: ignore
|
||||||
|
|
||||||
|
# Start the run and add the stream table
|
||||||
|
self._wandb.init(**run_args)
|
||||||
|
if self._wandb.run is not None:
|
||||||
|
if should_print_url:
|
||||||
|
run_url = self._wandb.run.settings.run_url
|
||||||
|
self._wandb.termlog(
|
||||||
|
f"Streaming LangChain activity to W&B at {run_url}\n"
|
||||||
|
"`WandbTracer` is currently in beta.\n"
|
||||||
|
"Please report any issues to "
|
||||||
|
"https://github.com/wandb/wandb/issues with the tag "
|
||||||
|
"`langchain`."
|
||||||
|
)
|
||||||
|
|
||||||
|
self._wandb.run._label(repo="langchain")
|
||||||
|
|
||||||
|
def _persist_run(self, run: "Run") -> None:
|
||||||
|
"""Persist a run."""
|
||||||
|
self._log_trace_from_run(run)
|
@ -200,9 +200,9 @@ class WandbCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
|
|||||||
notes=self.notes,
|
notes=self.notes,
|
||||||
)
|
)
|
||||||
warning = (
|
warning = (
|
||||||
"The wandb callback is currently in beta and is subject to change "
|
"DEPRECATION: The `WandbCallbackHandler` will soon be deprecated in favor "
|
||||||
"based on updates to `langchain`. Please report any issues to "
|
"of the `WandbTracer`. Please update your code to use the `WandbTracer` "
|
||||||
"https://github.com/wandb/wandb/issues with the tag `langchain`."
|
"instead."
|
||||||
)
|
)
|
||||||
wandb.termwarn(
|
wandb.termwarn(
|
||||||
warning,
|
warning,
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Generic utility functions."""
|
"""Generic utility functions."""
|
||||||
import contextlib
|
import contextlib
|
||||||
import datetime
|
import datetime
|
||||||
|
import importlib
|
||||||
import os
|
import os
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||||
|
|
||||||
@ -115,3 +116,18 @@ def mock_now(dt_value): # type: ignore
|
|||||||
yield datetime.datetime
|
yield datetime.datetime
|
||||||
finally:
|
finally:
|
||||||
datetime.datetime = real_datetime
|
datetime.datetime = real_datetime
|
||||||
|
|
||||||
|
|
||||||
|
def guard_import(
|
||||||
|
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
|
||||||
|
) -> Any:
|
||||||
|
"""Dynamically imports a module and raises a helpful exception if the module is not
|
||||||
|
installed."""
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_name, package)
|
||||||
|
except ImportError:
|
||||||
|
raise ImportError(
|
||||||
|
f"Could not import {module_name} python package. "
|
||||||
|
f"Please install it with `pip install {pip_name or module_name}`."
|
||||||
|
)
|
||||||
|
return module
|
||||||
|
@ -3,7 +3,6 @@
|
|||||||
The vector store can be persisted in json, bson or parquet format.
|
The vector store can be persisted in json, bson or parquet format.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import importlib
|
|
||||||
import json
|
import json
|
||||||
import math
|
import math
|
||||||
import os
|
import os
|
||||||
@ -13,6 +12,7 @@ from uuid import uuid4
|
|||||||
|
|
||||||
from langchain.docstore.document import Document
|
from langchain.docstore.document import Document
|
||||||
from langchain.embeddings.base import Embeddings
|
from langchain.embeddings.base import Embeddings
|
||||||
|
from langchain.utils import guard_import
|
||||||
from langchain.vectorstores.base import VectorStore
|
from langchain.vectorstores.base import VectorStore
|
||||||
from langchain.vectorstores.utils import maximal_marginal_relevance
|
from langchain.vectorstores.utils import maximal_marginal_relevance
|
||||||
|
|
||||||
@ -20,21 +20,6 @@ DEFAULT_K = 4 # Number of Documents to return.
|
|||||||
DEFAULT_FETCH_K = 20 # Number of Documents to initially fetch during MMR search.
|
DEFAULT_FETCH_K = 20 # Number of Documents to initially fetch during MMR search.
|
||||||
|
|
||||||
|
|
||||||
def guard_import(
|
|
||||||
module_name: str, *, pip_name: Optional[str] = None, package: Optional[str] = None
|
|
||||||
) -> Any:
|
|
||||||
"""Dynamically imports a module and raises a helpful exception if the module is not
|
|
||||||
installed."""
|
|
||||||
try:
|
|
||||||
module = importlib.import_module(module_name, package)
|
|
||||||
except ImportError:
|
|
||||||
raise ImportError(
|
|
||||||
f"Could not import {module_name} python package. "
|
|
||||||
f"Please install it with `pip install {pip_name or module_name}`."
|
|
||||||
)
|
|
||||||
return module
|
|
||||||
|
|
||||||
|
|
||||||
class BaseSerializer(ABC):
|
class BaseSerializer(ABC):
|
||||||
"""Abstract base class for saving and loading data."""
|
"""Abstract base class for saving and loading data."""
|
||||||
|
|
||||||
|
117
tests/integration_tests/callbacks/test_wandb_tracer.py
Normal file
117
tests/integration_tests/callbacks/test_wandb_tracer.py
Normal file
@ -0,0 +1,117 @@
|
|||||||
|
"""Integration tests for the langchain tracer module."""
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from aiohttp import ClientSession
|
||||||
|
|
||||||
|
from langchain.agents import AgentType, initialize_agent, load_tools
|
||||||
|
from langchain.callbacks.manager import wandb_tracing_enabled
|
||||||
|
from langchain.llms import OpenAI
|
||||||
|
|
||||||
|
questions = [
|
||||||
|
(
|
||||||
|
"Who won the US Open men's final in 2019? "
|
||||||
|
"What is his age raised to the 0.334 power?"
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Who is Olivia Wilde's boyfriend? "
|
||||||
|
"What is his current age raised to the 0.23 power?"
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Who won the most recent formula 1 grand prix? "
|
||||||
|
"What is their age raised to the 0.23 power?"
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"Who won the US Open women's final in 2019? "
|
||||||
|
"What is her age raised to the 0.34 power?"
|
||||||
|
),
|
||||||
|
("Who is Beyonce's husband? " "What is his age raised to the 0.19 power?"),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracing_sequential() -> None:
|
||||||
|
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
|
||||||
|
os.environ["WANDB_PROJECT"] = "langchain-tracing"
|
||||||
|
|
||||||
|
for q in questions[:3]:
|
||||||
|
llm = OpenAI(temperature=0)
|
||||||
|
tools = load_tools(
|
||||||
|
["llm-math", "serpapi"],
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
agent = initialize_agent(
|
||||||
|
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||||
|
)
|
||||||
|
agent.run(q)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracing_session_env_var() -> None:
|
||||||
|
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
|
||||||
|
|
||||||
|
llm = OpenAI(temperature=0)
|
||||||
|
tools = load_tools(
|
||||||
|
["llm-math", "serpapi"],
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
agent = initialize_agent(
|
||||||
|
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||||
|
)
|
||||||
|
agent.run(questions[0])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tracing_concurrent() -> None:
|
||||||
|
os.environ["LANGCHAIN_WANDB_TRACING"] = "true"
|
||||||
|
aiosession = ClientSession()
|
||||||
|
llm = OpenAI(temperature=0)
|
||||||
|
async_tools = load_tools(
|
||||||
|
["llm-math", "serpapi"],
|
||||||
|
llm=llm,
|
||||||
|
aiosession=aiosession,
|
||||||
|
)
|
||||||
|
agent = initialize_agent(
|
||||||
|
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||||
|
)
|
||||||
|
tasks = [agent.arun(q) for q in questions[:3]]
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
await aiosession.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_tracing_context_manager() -> None:
|
||||||
|
llm = OpenAI(temperature=0)
|
||||||
|
tools = load_tools(
|
||||||
|
["llm-math", "serpapi"],
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
agent = initialize_agent(
|
||||||
|
tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||||
|
)
|
||||||
|
if "LANGCHAIN_WANDB_TRACING" in os.environ:
|
||||||
|
del os.environ["LANGCHAIN_WANDB_TRACING"]
|
||||||
|
with wandb_tracing_enabled():
|
||||||
|
agent.run(questions[0]) # this should be traced
|
||||||
|
|
||||||
|
agent.run(questions[0]) # this should not be traced
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tracing_context_manager_async() -> None:
|
||||||
|
llm = OpenAI(temperature=0)
|
||||||
|
async_tools = load_tools(
|
||||||
|
["llm-math", "serpapi"],
|
||||||
|
llm=llm,
|
||||||
|
)
|
||||||
|
agent = initialize_agent(
|
||||||
|
async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
|
||||||
|
)
|
||||||
|
if "LANGCHAIN_WANDB_TRACING" in os.environ:
|
||||||
|
del os.environ["LANGCHAIN_TRACING"]
|
||||||
|
|
||||||
|
# start a background task
|
||||||
|
task = asyncio.create_task(agent.arun(questions[0])) # this should not be traced
|
||||||
|
with wandb_tracing_enabled():
|
||||||
|
tasks = [agent.arun(q) for q in questions[1:4]] # these should be traced
|
||||||
|
await asyncio.gather(*tasks)
|
||||||
|
|
||||||
|
await task
|
Loading…
Reference in New Issue
Block a user