google-genai[patch]: added parsing of function call / response (#17245)

Author: Leonid Kuligin (committed by GitHub)
Date: 2024-02-08 22:34:46 +01:00
parent a210a8bc53
commit 1862900078

@@ -16,18 +16,18 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
-    Type,
     Union,
     cast,
 )
 from urllib.parse import urlparse

+import google.ai.generativelanguage as glm
 import google.api_core
 # TODO: remove ignore once the google package is published with types
 import google.generativeai as genai  # type: ignore[import]
+import proto  # type: ignore[import]
 import requests
-from google.ai.generativelanguage_v1beta import FunctionCall
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -35,13 +35,9 @@ from langchain_core.callbacks.manager import (
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
-    AIMessageChunk,
     BaseMessage,
-    ChatMessage,
-    ChatMessageChunk,
     FunctionMessage,
     HumanMessage,
-    HumanMessageChunk,
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
@@ -327,15 +323,30 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
             continue
         elif isinstance(message, AIMessage):
             role = "model"
-            # TODO: Handle AImessage with function call
-            parts = _convert_to_parts(message.content)
+            raw_function_call = message.additional_kwargs.get("function_call")
+            if raw_function_call:
+                function_call = glm.FunctionCall(
+                    {
+                        "name": raw_function_call["name"],
+                        "args": json.loads(raw_function_call["arguments"]),
+                    }
+                )
+                parts = [glm.Part(function_call=function_call)]
+            else:
+                parts = _convert_to_parts(message.content)
         elif isinstance(message, HumanMessage):
             role = "user"
             parts = _convert_to_parts(message.content)
         elif isinstance(message, FunctionMessage):
             role = "user"
-            # TODO: Handle FunctionMessage
-            parts = _convert_to_parts(message.content)
+            parts = [
+                glm.Part(
+                    function_response=glm.FunctionResponse(
+                        name=message.name,
+                        response=message.content,
+                    )
+                )
+            ]
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
@@ -353,100 +364,44 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
     return messages


-def _retrieve_function_call_response(
-    parts: List[genai.types.PartType],
-) -> Optional[Dict]:
-    for idx, part in enumerate(parts):
-        if part.function_call and part.function_call.name:
-            fc: FunctionCall = part.function_call
-            return {
-                "function_call": {
-                    "name": fc.name,
-                    "arguments": json.dumps(
-                        dict(fc.args.items())
-                    ),  # dump to match other function calling llms for now
-                }
-            }
-    return None
-
-
-def _parts_to_content(
-    parts: List[genai.types.PartType],
-) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
-    """Converts a list of Gemini API Part objects into a list of LangChain messages."""
-    function_call_resp = _retrieve_function_call_response(parts)
-    if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
-        # Simple text response. The typical response
-        return parts[0].text, function_call_resp
-    elif not parts:
-        logger.warning("Gemini produced an empty response.")
-        return "", function_call_resp
-    messages: List[Union[Dict[Any, Any], str]] = []
-    for part in parts:
-        if part.text is not None:
-            messages.append(
-                {
-                    "type": "text",
-                    "text": part.text,
-                }
-            )
-        else:
-            # TODO: Handle inline_data if that's a thing?
-            raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
-    return messages, function_call_resp
+def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
+    first_part = response_candidate.content.parts[0]
+    if first_part.function_call:
+        function_call = proto.Message.to_dict(first_part.function_call)
+        function_call["arguments"] = json.dumps(function_call.pop("args", {}))
+        return AIMessage(content="", additional_kwargs={"function_call": function_call})
+    else:
+        parts = response_candidate.content.parts
+
+    if len(parts) == 1 and parts[0].text:
+        content: Union[str, List[Union[str, Dict]]] = parts[0].text
+    else:
+        content = [proto.Message.to_dict(part) for part in parts]
+    return AIMessage(content=content, additional_kwargs={})


 def _response_to_result(
-    response: genai.types.GenerateContentResponse,
-    ai_msg_t: Type[BaseMessage] = AIMessage,
-    human_msg_t: Type[BaseMessage] = HumanMessage,
-    chat_msg_t: Type[BaseMessage] = ChatMessage,
-    generation_t: Type[ChatGeneration] = ChatGeneration,
+    response: glm.GenerateContentResponse,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
-    llm_output = {}
-    if response.prompt_feedback:
-        try:
-            prompt_feedback = type(response.prompt_feedback).to_dict(
-                response.prompt_feedback, use_integers_for_enums=False
-            )
-            llm_output["prompt_feedback"] = prompt_feedback
-        except Exception as e:
-            logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
+    llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}

     generations: List[ChatGeneration] = []

-    role_map = {
-        "model": ai_msg_t,
-        "user": human_msg_t,
-    }
     for candidate in response.candidates:
-        content = candidate.content
-        parts_content, additional_kwargs = _parts_to_content(content.parts)
-        if content.role not in role_map:
-            logger.warning(
-                f"Unrecognized role: {content.role}. Treating as a ChatMessage."
-            )
-            msg = chat_msg_t(
-                content=parts_content,
-                role=content.role,
-                additional_kwargs=additional_kwargs or {},
-            )
-        else:
-            msg = role_map[content.role](
-                content=parts_content,
-                additional_kwargs=additional_kwargs or {},
-            )
         generation_info = {}
         if candidate.finish_reason:
             generation_info["finish_reason"] = candidate.finish_reason.name
-        if candidate.safety_ratings:
-            generation_info["safety_ratings"] = [
-                type(rating).to_dict(rating) for rating in candidate.safety_ratings
-            ]
-        generations.append(generation_t(message=msg, generation_info=generation_info))
+        generation_info["safety_ratings"] = [
+            proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
+            for safety_rating in candidate.safety_ratings
+        ]
+        generations.append(
+            ChatGeneration(
+                message=_parse_response_candidate(candidate),
+                generation_info=generation_info,
+            )
+        )
     if not response.candidates:
         # Likely a "prompt feedback" violation (e.g., toxic input)
         # Raising an error would be different than how OpenAI handles it,
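
`_parse_response_candidate` performs the inverse mapping: `proto.Message.to_dict` flattens the proto `FunctionCall` into a plain dict, and its `args` mapping is re-serialized to a JSON string so the resulting `additional_kwargs` matches the shape other function-calling LLM wrappers produce. A small sketch of that round-trip (candidate contents are invented):

    import json

    import google.ai.generativelanguage as glm
    import proto  # type: ignore[import]

    candidate = glm.Candidate(
        content=glm.Content(
            role="model",
            parts=[
                glm.Part(
                    function_call=glm.FunctionCall(
                        {"name": "get_weather", "args": {"city": "Berlin"}}
                    )
                )
            ],
        )
    )

    # The same steps the new helper runs on the first part:
    function_call = proto.Message.to_dict(candidate.content.parts[0].function_call)
    function_call["arguments"] = json.dumps(function_call.pop("args", {}))
    # function_call == {"name": "get_weather", "arguments": '{"city": "Berlin"}'}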
@@ -455,7 +410,9 @@ def _response_to_result(
             "Gemini produced an empty response. Continuing with empty message\n"
             f"Feedback: {response.prompt_feedback}"
         )
-        generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
+        generations = [
+            ChatGeneration(message=AIMessage(content=""), generation_info={})
+        ]
     return ChatResult(generations=generations, llm_output=llm_output)
@@ -616,13 +573,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
@@ -646,13 +597,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
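
Taken together, the two conversions let a complete tool-use exchange round-trip through `ChatGoogleGenerativeAI`. An end-to-end sketch using the same invented `get_weather` tool, replaying the call/response pair as chat history (assumes GOOGLE_API_KEY is set in the environment):

    import json

    from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage
    from langchain_google_genai import ChatGoogleGenerativeAI

    llm = ChatGoogleGenerativeAI(model="gemini-pro")

    history = [
        HumanMessage(content="What's the weather in Berlin?"),
        # Earlier model turn that asked for the tool; after this commit it
        # is converted into a glm.FunctionCall part under the "model" role.
        AIMessage(
            content="",
            additional_kwargs={
                "function_call": {
                    "name": "get_weather",
                    "arguments": json.dumps({"city": "Berlin"}),
                }
            },
        ),
        # The tool's output, converted into a glm.FunctionResponse part.
        FunctionMessage(name="get_weather", content="22C and sunny"),
    ]

    # The model can now ground its final answer in the tool result.
    print(llm.invoke(history).content)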