LLMonitor Callback handler: fix bug (#11128)

Here is a small bug fix for the LLMonitor callback handler. I've also
added user identification capabilities.
This commit is contained in:
Hugues 2023-09-29 00:00:38 +02:00 committed by GitHub
parent e9b51513e9
commit b599f91e33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 88 additions and 13 deletions

View File

@ -37,10 +37,10 @@ llm = OpenAI(
callbacks=[handler], callbacks=[handler],
) )
chat = ChatOpenAI( chat = ChatOpenAI(callbacks=[handler])
callbacks=[handler],
metadata={"userId": "123"}, # you can assign user ids to models in the metadata llm("Tell me a joke")
)
``` ```
## Usage with chains and agents ## Usage with chains and agents
@ -100,6 +100,18 @@ agent.run(
) )
``` ```
## User Tracking
User tracking allows you to identify your users, track their cost, conversations and more.
```python
from langchain.callbacks.llmonitor_callback import LLMonitorCallbackHandler, identify
with identify("user-123"):
llm("Tell me a joke")
with identify("user-456", user_props={"email": "user456@test.com"}):
agen.run("Who is Leo DiCaprio's girlfriend?")
```
## Support ## Support
For any question or issue with integration you can reach out to the LLMonitor team on [Discord](http://discord.com/invite/8PafSG58kK) or via [email](mailto:vince@llmonitor.com). For any question or issue with integration you can reach out to the LLMonitor team on [Discord](http://discord.com/invite/8PafSG58kK) or via [email](mailto:vince@llmonitor.com).

View File

@ -1,5 +1,6 @@
import os import os
import traceback import traceback
from contextvars import ContextVar
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Literal, Union from typing import Any, Dict, List, Literal, Union
from uuid import UUID from uuid import UUID
@ -13,6 +14,26 @@ from langchain.schema.output import LLMResult
DEFAULT_API_URL = "https://app.llmonitor.com" DEFAULT_API_URL = "https://app.llmonitor.com"
user_ctx = ContextVar[Union[str, None]]("user_ctx", default=None)
user_props_ctx = ContextVar[Union[str, None]]("user_props_ctx", default=None)
class UserContextManager:
def __init__(self, user_id: str, user_props: Any = None) -> None:
user_ctx.set(user_id)
user_props_ctx.set(user_props)
def __enter__(self) -> Any:
pass
def __exit__(self, exc_type: Any, exc_value: Any, exc_tb: Any) -> Any:
user_ctx.set(None)
user_props_ctx.set(None)
def identify(user_id: str, user_props: Any = None) -> UserContextManager:
return UserContextManager(user_id, user_props)
def _serialize(obj: Any) -> Union[Dict[str, Any], List[Any], Any]: def _serialize(obj: Any) -> Union[Dict[str, Any], List[Any], Any]:
if hasattr(obj, "to_json"): if hasattr(obj, "to_json"):
@ -94,13 +115,24 @@ def _parse_lc_role(
def _get_user_id(metadata: Any) -> Any: def _get_user_id(metadata: Any) -> Any:
if user_ctx.get() is not None:
return user_ctx.get()
metadata = metadata or {} metadata = metadata or {}
user_id = metadata.get("user_id") user_id = metadata.get("user_id")
if user_id is None: if user_id is None:
user_id = metadata.get("userId") user_id = metadata.get("userId") # legacy, to delete in the future
return user_id return user_id
def _get_user_props(metadata: Any) -> Any:
if user_props_ctx.get() is not None:
return user_props_ctx.get()
metadata = metadata or {}
return metadata.get("user_props")
def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]: def _parse_lc_message(message: BaseMessage) -> Dict[str, Any]:
parsed = {"text": message.content, "role": _parse_lc_role(message.type)} parsed = {"text": message.content, "role": _parse_lc_role(message.type)}
@ -198,10 +230,13 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
metadata: Union[Dict[str, Any], None] = None, metadata: Union[Dict[str, Any], None] = None,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
event = { event = {
"event": "start", "event": "start",
"type": "llm", "type": "llm",
"userId": (metadata or {}).get("userId"), "userId": user_id,
"runId": str(run_id), "runId": str(run_id),
"parentRunId": str(parent_run_id) if parent_run_id else None, "parentRunId": str(parent_run_id) if parent_run_id else None,
"input": _parse_input(prompts), "input": _parse_input(prompts),
@ -209,6 +244,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
"tags": tags, "tags": tags,
"metadata": metadata, "metadata": metadata,
} }
if user_props:
event["userProps"] = user_props
self.__send_event(event) self.__send_event(event)
def on_chat_model_start( def on_chat_model_start(
@ -223,6 +261,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
**kwargs: Any, **kwargs: Any,
) -> Any: ) -> Any:
user_id = _get_user_id(metadata) user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
event = { event = {
"event": "start", "event": "start",
@ -235,6 +274,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
"tags": tags, "tags": tags,
"metadata": metadata, "metadata": metadata,
} }
if user_props:
event["userProps"] = user_props
self.__send_event(event) self.__send_event(event)
def on_llm_end( def on_llm_end(
@ -247,12 +289,24 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
) -> None: ) -> None:
token_usage = (response.llm_output or {}).get("token_usage", {}) token_usage = (response.llm_output or {}).get("token_usage", {})
parsed_output = _parse_lc_messages( parsed_output = [
map( {
lambda o: o.message if hasattr(o, "message") else None, "text": generation.text,
response.generations[0], "role": "ai",
) **(
) {
"functionCall": generation.message.additional_kwargs[
"function_call"
]
}
if hasattr(generation, "message")
and hasattr(generation.message, "additional_kwargs")
and "function_call" in generation.message.additional_kwargs
else {}
),
}
for generation in response.generations[0]
]
event = { event = {
"event": "end", "event": "end",
@ -279,6 +333,8 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
user_id = _get_user_id(metadata) user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
event = { event = {
"event": "start", "event": "start",
"type": "tool", "type": "tool",
@ -290,6 +346,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
"tags": tags, "tags": tags,
"metadata": metadata, "metadata": metadata,
} }
if user_props:
event["userProps"] = user_props
self.__send_event(event) self.__send_event(event)
def on_tool_end( def on_tool_end(
@ -339,6 +398,7 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
type = "chain" type = "chain"
user_id = _get_user_id(metadata) user_id = _get_user_id(metadata)
user_props = _get_user_props(metadata)
event = { event = {
"event": "start", "event": "start",
@ -351,6 +411,9 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
"metadata": metadata, "metadata": metadata,
"name": name, "name": name,
} }
if user_props:
event["userProps"] = user_props
self.__send_event(event) self.__send_event(event)
def on_chain_end( def on_chain_end(
@ -456,4 +519,4 @@ class LLMonitorCallbackHandler(BaseCallbackHandler):
self.__send_event(event) self.__send_event(event)
__all__ = ["LLMonitorCallbackHandler"] __all__ = ["LLMonitorCallbackHandler", "identify"]