mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-08 22:15:08 +00:00
fix trubrics lint issue (#11202)
This commit is contained in:
parent
b738ccd91e
commit
8cd18a48e4
@ -2,10 +2,44 @@ import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langchain.adapters.openai import convert_message_to_dict
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.messages import BaseMessage
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
# If function call only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
|
||||
class TrubricsCallbackHandler(BaseCallbackHandler):
|
||||
@ -25,7 +59,7 @@ class TrubricsCallbackHandler(BaseCallbackHandler):
|
||||
project: str = "default",
|
||||
email: Optional[str] = None,
|
||||
password: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
try:
|
||||
@ -56,9 +90,9 @@ class TrubricsCallbackHandler(BaseCallbackHandler):
|
||||
self,
|
||||
serialized: Dict[str, Any],
|
||||
messages: List[List[BaseMessage]],
|
||||
**kwargs: Any
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
self.messages = [convert_message_to_dict(message) for message in messages[0]]
|
||||
self.messages = [_convert_message_to_dict(message) for message in messages[0]]
|
||||
self.prompt = self.messages[-1]["content"]
|
||||
|
||||
def on_llm_end(self, response: LLMResult, run_id: UUID, **kwargs: Any) -> None:
|
||||
|
Loading…
Reference in New Issue
Block a user