[LLMonitorCallbackHandler] Various improvements (#13151)

Small improvements to the LLMonitor callback handler, such as better
support for non-OpenAI models.


---------

Co-authored-by: vincelwt <vince@lyser.io>
This commit is contained in:
Hugues Chocart 2023-11-17 08:39:36 +01:00 committed by GitHub
parent c1b041c188
commit 35e04f204b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,7 +4,7 @@ import os
import traceback import traceback
import warnings import warnings
from contextvars import ContextVar from contextvars import ContextVar
from typing import Any, Dict, List, Literal, Union from typing import Any, Dict, List, Union, cast
from uuid import UUID from uuid import UUID
import requests import requests
@ -15,11 +15,30 @@ from langchain.schema.agent import AgentAction, AgentFinish
from langchain.schema.messages import BaseMessage from langchain.schema.messages import BaseMessage
from langchain.schema.output import LLMResult from langchain.schema.output import LLMResult
logger = logging.getLogger(__name__)
DEFAULT_API_URL = "https://app.llmonitor.com" DEFAULT_API_URL = "https://app.llmonitor.com"
user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None) user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None)
user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None) user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None)
# Names of model invocation parameters that are copied, when present and not
# None, into the `extra` payload sent with LLM start events. Each entry must
# match the key used by the model provider exactly, otherwise the lookup
# silently finds nothing (the previous "frequence_penalty" spelling never
# matched the real `frequency_penalty` parameter).
PARAMS_TO_CAPTURE = [
"temperature",
"top_p",
"top_k",
"stop",
"presence_penalty",
"frequency_penalty",
"seed",
"function_call",
"functions",
"tools",
"tool_choice",
"response_format",
"max_tokens",
"logit_bias",
]
class UserContextManager: class UserContextManager:
"""Context manager for LLMonitor user context.""" """Context manager for LLMonitor user context."""
@ -66,6 +85,10 @@ def _parse_input(raw_input: Any) -> Any:
if not raw_input: if not raw_input:
return None return None
# if it's an array of 1, just parse the first element
if isinstance(raw_input, list) and len(raw_input) == 1:
return _parse_input(raw_input[0])
if not isinstance(raw_input, dict): if not isinstance(raw_input, dict):
return _serialize(raw_input) return _serialize(raw_input)
@ -115,17 +138,11 @@ def _parse_output(raw_output: dict) -> Any:
def _parse_lc_role( def _parse_lc_role(
role: str, role: str,
) -> Union[Literal["user", "ai", "system", "function"], None]: ) -> str:
if role == "human": if role == "human":
return "user" return "user"
elif role == "ai":
return "ai"
elif role == "system":
return "system"
elif role == "function":
return "function"
else: else:
return None return role
def _get_user_id(metadata: Any) -> Any: def _get_user_id(metadata: Any) -> Any:
@ -148,13 +165,15 @@ def _get_user_props(metadata: Any) -> Any:
def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]: def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]:
keys = ["function_call", "tool_calls", "tool_call_id", "name"]
parsed = {"text": message.content, "role": _parse_lc_role(message.type)} parsed = {"text": message.content, "role": _parse_lc_role(message.type)}
parsed.update(
function_call = (message.additional_kwargs or {}).get("function_call") {
key: cast(Any, message.additional_kwargs.get(key))
if function_call is not None: for key in keys
parsed["functionCall"] = function_call if message.additional_kwargs.get(key) is not None
}
)
return parsed return parsed
@ -213,19 +232,20 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
self.__track_event = llmonitor.track_event self.__track_event = llmonitor.track_event
except ImportError: except ImportError:
warnings.warn( logger.warning(
"""[LLMonitor] To use the LLMonitor callback handler you need to """[LLMonitor] To use the LLMonitor callback handler you need to
have the `llmonitor` Python package installed. Please install it have the `llmonitor` Python package installed. Please install it
with `pip install llmonitor`""" with `pip install llmonitor`"""
) )
self.__has_valid_config = False self.__has_valid_config = False
return
if parse(self.__llmonitor_version) < parse("0.0.20"): if parse(self.__llmonitor_version) < parse("0.0.32"):
warnings.warn( logger.warning(
f"""[LLMonitor] The installed `llmonitor` version is f"""[LLMonitor] The installed `llmonitor` version is
{self.__llmonitor_version} but `LLMonitorCallbackHandler` requires {self.__llmonitor_version}
at least version 0.0.20 upgrade `llmonitor` with `pip install but `LLMonitorCallbackHandler` requires at least version 0.0.32
--upgrade llmonitor`""" upgrade `llmonitor` with `pip install --upgrade llmonitor`"""
) )
self.__has_valid_config = False self.__has_valid_config = False
@ -236,9 +256,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
_app_id = app_id or os.getenv("LLMONITOR_APP_ID") _app_id = app_id or os.getenv("LLMONITOR_APP_ID")
if _app_id is None: if _app_id is None:
warnings.warn( logger.warning(
"""[LLMonitor] app_id must be provided either as an argument or as """[LLMonitor] app_id must be provided either as an argument or
an environment variable""" as an environment variable"""
) )
self.__has_valid_config = False self.__has_valid_config = False
else: else:
@ -252,7 +272,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
if not res.ok: if not res.ok:
raise ConnectionError() raise ConnectionError()
except Exception: except Exception:
warnings.warn( logger.warning(
f"""[LLMonitor] Could not connect to the LLMonitor API at f"""[LLMonitor] Could not connect to the LLMonitor API at
{self.__api_url}""" {self.__api_url}"""
) )
@ -273,7 +293,27 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
try: try:
user_id = _get_user_id(metadata) user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata) user_props = _get_user_props(metadata)
name = kwargs.get("invocation_params", {}).get("model_name")
params = kwargs.get("invocation_params", {})
params.update(
serialized.get("kwargs", {})
) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
name = (
params.get("model")
or params.get("model_name")
or params.get("model_id")
)
if not name and "anthropic" in params.get("_type"):
name = "claude-2"
extra = {
param: params.get(param)
for param in PARAMS_TO_CAPTURE
if params.get(param) is not None
}
input = _parse_input(prompts) input = _parse_input(prompts)
self.__track_event( self.__track_event(
@ -285,8 +325,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
name=name, name=name,
input=input, input=input,
tags=tags, tags=tags,
extra=extra,
metadata=metadata, metadata=metadata,
user_props=user_props, user_props=user_props,
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_llm_start: {e}") warnings.warn(f"[LLMonitor] An error occurred in on_llm_start: {e}")
@ -304,10 +346,31 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
) -> Any: ) -> Any:
if self.__has_valid_config is False: if self.__has_valid_config is False:
return return
try: try:
user_id = _get_user_id(metadata) user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata) user_props = _get_user_props(metadata)
name = kwargs.get("invocation_params", {}).get("model_name")
params = kwargs.get("invocation_params", {})
params.update(
serialized.get("kwargs", {})
) # Sometimes, for example with ChatAnthropic, `invocation_params` is empty
name = (
params.get("model")
or params.get("model_name")
or params.get("model_id")
)
if not name and "anthropic" in params.get("_type"):
name = "claude-2"
extra = {
param: params.get(param)
for param in PARAMS_TO_CAPTURE
if params.get(param) is not None
}
input = _parse_lc_messages(messages[0]) input = _parse_lc_messages(messages[0])
self.__track_event( self.__track_event(
@ -319,13 +382,13 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
name=name, name=name,
input=input, input=input,
tags=tags, tags=tags,
extra=extra,
metadata=metadata, metadata=metadata,
user_props=user_props, user_props=user_props,
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
logging.warning( logger.error(f"[LLMonitor] An error occurred in on_chat_model_start: {e}")
f"[LLMonitor] An error occurred in on_chat_model_start: {e}"
)
def on_llm_end( def on_llm_end(
self, self,
@ -340,25 +403,18 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
try: try:
token_usage = (response.llm_output or {}).get("token_usage", {}) token_usage = (response.llm_output or {}).get("token_usage", {})
parsed_output = [
{ parsed_output: Any = [
"text": generation.text, _parse_lc_message(generation.message)
"role": "ai",
**(
{
"functionCall": generation.message.additional_kwargs[
"function_call"
]
}
if hasattr(generation, "message") if hasattr(generation, "message")
and hasattr(generation.message, "additional_kwargs") else generation.text
and "function_call" in generation.message.additional_kwargs
else {}
),
}
for generation in response.generations[0] for generation in response.generations[0]
] ]
# if it's an array of 1, just parse the first element
if len(parsed_output) == 1:
parsed_output = parsed_output[0]
self.__track_event( self.__track_event(
"llm", "llm",
"end", "end",
@ -369,9 +425,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
"prompt": token_usage.get("prompt_tokens"), "prompt": token_usage.get("prompt_tokens"),
"completion": token_usage.get("completion_tokens"), "completion": token_usage.get("completion_tokens"),
}, },
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_llm_end: {e}") logger.error(f"[LLMonitor] An error occurred in on_llm_end: {e}")
def on_tool_start( def on_tool_start(
self, self,
@ -402,9 +459,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
tags=tags, tags=tags,
metadata=metadata, metadata=metadata,
user_props=user_props, user_props=user_props,
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_tool_start: {e}") logger.error(f"[LLMonitor] An error occurred in on_tool_start: {e}")
def on_tool_end( def on_tool_end(
self, self,
@ -424,9 +482,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id), run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None, parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output, output=output,
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_tool_end: {e}") logger.error(f"[LLMonitor] An error occurred in on_tool_end: {e}")
def on_chain_start( def on_chain_start(
self, self,
@ -473,9 +532,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
tags=tags, tags=tags,
metadata=metadata, metadata=metadata,
user_props=user_props, user_props=user_props,
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
warnings.warn(f"[LLMonitor] An error occurred in on_chain_start: {e}") logger.error(f"[LLMonitor] An error occurred in on_chain_start: {e}")
def on_chain_end( def on_chain_end(
self, self,
@ -496,9 +556,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id), run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None, parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output, output=output,
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_chain_end: {e}") logger.error(f"[LLMonitor] An error occurred in on_chain_end: {e}")
def on_agent_action( def on_agent_action(
self, self,
@ -521,9 +582,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
parent_run_id=str(parent_run_id) if parent_run_id else None, parent_run_id=str(parent_run_id) if parent_run_id else None,
name=name, name=name,
input=input, input=input,
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_agent_action: {e}") logger.error(f"[LLMonitor] An error occurred in on_agent_action: {e}")
def on_agent_finish( def on_agent_finish(
self, self,
@ -544,9 +606,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id), run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None, parent_run_id=str(parent_run_id) if parent_run_id else None,
output=output, output=output,
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_agent_finish: {e}") logger.error(f"[LLMonitor] An error occurred in on_agent_finish: {e}")
def on_chain_error( def on_chain_error(
self, self,
@ -565,9 +628,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id), run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None, parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()}, error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_chain_error: {e}") logger.error(f"[LLMonitor] An error occurred in on_chain_error: {e}")
def on_tool_error( def on_tool_error(
self, self,
@ -586,9 +650,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id), run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None, parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()}, error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_tool_error: {e}") logger.error(f"[LLMonitor] An error occurred in on_tool_error: {e}")
def on_llm_error( def on_llm_error(
self, self,
@ -607,9 +672,10 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
run_id=str(run_id), run_id=str(run_id),
parent_run_id=str(parent_run_id) if parent_run_id else None, parent_run_id=str(parent_run_id) if parent_run_id else None,
error={"message": str(error), "stack": traceback.format_exc()}, error={"message": str(error), "stack": traceback.format_exc()},
app_id=self.__app_id,
) )
except Exception as e: except Exception as e:
logging.warning(f"[LLMonitor] An error occurred in on_llm_error: {e}") logger.error(f"[LLMonitor] An error occurred in on_llm_error: {e}")
__all__ = ["LLMonitorCallbackHandler", "identify"] __all__ = ["LLMonitorCallbackHandler", "identify"]