mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 10:23:18 +00:00
google-genai[patch]: added parsing of function call / response (#17245)
This commit is contained in:
@@ -16,18 +16,18 @@ from typing import (
|
||||
Optional,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import google.ai.generativelanguage as glm
|
||||
import google.api_core
|
||||
|
||||
# TODO: remove ignore once the google package is published with types
|
||||
import google.generativeai as genai # type: ignore[import]
|
||||
import proto # type: ignore[import]
|
||||
import requests
|
||||
from google.ai.generativelanguage_v1beta import FunctionCall
|
||||
from langchain_core.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -35,13 +35,9 @@ from langchain_core.callbacks.manager import (
|
||||
from langchain_core.language_models.chat_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
@@ -327,15 +323,30 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
|
||||
continue
|
||||
elif isinstance(message, AIMessage):
|
||||
role = "model"
|
||||
# TODO: Handle AImessage with function call
|
||||
parts = _convert_to_parts(message.content)
|
||||
raw_function_call = message.additional_kwargs.get("function_call")
|
||||
if raw_function_call:
|
||||
function_call = glm.FunctionCall(
|
||||
{
|
||||
"name": raw_function_call["name"],
|
||||
"args": json.loads(raw_function_call["arguments"]),
|
||||
}
|
||||
)
|
||||
parts = [glm.Part(function_call=function_call)]
|
||||
else:
|
||||
parts = _convert_to_parts(message.content)
|
||||
elif isinstance(message, HumanMessage):
|
||||
role = "user"
|
||||
parts = _convert_to_parts(message.content)
|
||||
elif isinstance(message, FunctionMessage):
|
||||
role = "user"
|
||||
# TODO: Handle FunctionMessage
|
||||
parts = _convert_to_parts(message.content)
|
||||
parts = [
|
||||
glm.Part(
|
||||
function_response=glm.FunctionResponse(
|
||||
name=message.name,
|
||||
response=message.content,
|
||||
)
|
||||
)
|
||||
]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unexpected message with type {type(message)} at the position {i}."
|
||||
@@ -353,100 +364,44 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
|
||||
return messages
|
||||
|
||||
|
||||
def _retrieve_function_call_response(
|
||||
parts: List[genai.types.PartType],
|
||||
) -> Optional[Dict]:
|
||||
for idx, part in enumerate(parts):
|
||||
if part.function_call and part.function_call.name:
|
||||
fc: FunctionCall = part.function_call
|
||||
return {
|
||||
"function_call": {
|
||||
"name": fc.name,
|
||||
"arguments": json.dumps(
|
||||
dict(fc.args.items())
|
||||
), # dump to match other function calling llms for now
|
||||
}
|
||||
}
|
||||
return None
|
||||
def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
|
||||
first_part = response_candidate.content.parts[0]
|
||||
if first_part.function_call:
|
||||
function_call = proto.Message.to_dict(first_part.function_call)
|
||||
function_call["arguments"] = json.dumps(function_call.pop("args", {}))
|
||||
return AIMessage(content="", additional_kwargs={"function_call": function_call})
|
||||
else:
|
||||
parts = response_candidate.content.parts
|
||||
|
||||
|
||||
def _parts_to_content(
|
||||
parts: List[genai.types.PartType],
|
||||
) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
|
||||
"""Converts a list of Gemini API Part objects into a list of LangChain messages."""
|
||||
function_call_resp = _retrieve_function_call_response(parts)
|
||||
|
||||
if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
|
||||
# Simple text response. The typical response
|
||||
return parts[0].text, function_call_resp
|
||||
elif not parts:
|
||||
logger.warning("Gemini produced an empty response.")
|
||||
return "", function_call_resp
|
||||
messages: List[Union[Dict[Any, Any], str]] = []
|
||||
for part in parts:
|
||||
if part.text is not None:
|
||||
messages.append(
|
||||
{
|
||||
"type": "text",
|
||||
"text": part.text,
|
||||
}
|
||||
)
|
||||
else:
|
||||
# TODO: Handle inline_data if that's a thing?
|
||||
raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
|
||||
return messages, function_call_resp
|
||||
if len(parts) == 1 and parts[0].text:
|
||||
content: Union[str, List[Union[str, Dict]]] = parts[0].text
|
||||
else:
|
||||
content = [proto.Message.to_dict(part) for part in parts]
|
||||
return AIMessage(content=content, additional_kwargs={})
|
||||
|
||||
|
||||
def _response_to_result(
|
||||
response: genai.types.GenerateContentResponse,
|
||||
ai_msg_t: Type[BaseMessage] = AIMessage,
|
||||
human_msg_t: Type[BaseMessage] = HumanMessage,
|
||||
chat_msg_t: Type[BaseMessage] = ChatMessage,
|
||||
generation_t: Type[ChatGeneration] = ChatGeneration,
|
||||
response: glm.GenerateContentResponse,
|
||||
) -> ChatResult:
|
||||
"""Converts a PaLM API response into a LangChain ChatResult."""
|
||||
llm_output = {}
|
||||
if response.prompt_feedback:
|
||||
try:
|
||||
prompt_feedback = type(response.prompt_feedback).to_dict(
|
||||
response.prompt_feedback, use_integers_for_enums=False
|
||||
)
|
||||
llm_output["prompt_feedback"] = prompt_feedback
|
||||
except Exception as e:
|
||||
logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
|
||||
llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
|
||||
|
||||
generations: List[ChatGeneration] = []
|
||||
|
||||
role_map = {
|
||||
"model": ai_msg_t,
|
||||
"user": human_msg_t,
|
||||
}
|
||||
|
||||
for candidate in response.candidates:
|
||||
content = candidate.content
|
||||
parts_content, additional_kwargs = _parts_to_content(content.parts)
|
||||
if content.role not in role_map:
|
||||
logger.warning(
|
||||
f"Unrecognized role: {content.role}. Treating as a ChatMessage."
|
||||
)
|
||||
msg = chat_msg_t(
|
||||
content=parts_content,
|
||||
role=content.role,
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
else:
|
||||
msg = role_map[content.role](
|
||||
content=parts_content,
|
||||
additional_kwargs=additional_kwargs or {},
|
||||
)
|
||||
generation_info = {}
|
||||
if candidate.finish_reason:
|
||||
generation_info["finish_reason"] = candidate.finish_reason.name
|
||||
if candidate.safety_ratings:
|
||||
generation_info["safety_ratings"] = [
|
||||
type(rating).to_dict(rating) for rating in candidate.safety_ratings
|
||||
]
|
||||
generations.append(generation_t(message=msg, generation_info=generation_info))
|
||||
generation_info["safety_ratings"] = [
|
||||
proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
|
||||
for safety_rating in candidate.safety_ratings
|
||||
]
|
||||
generations.append(
|
||||
ChatGeneration(
|
||||
message=_parse_response_candidate(candidate),
|
||||
generation_info=generation_info,
|
||||
)
|
||||
)
|
||||
if not response.candidates:
|
||||
# Likely a "prompt feedback" violation (e.g., toxic input)
|
||||
# Raising an error would be different than how OpenAI handles it,
|
||||
@@ -455,7 +410,9 @@ def _response_to_result(
|
||||
"Gemini produced an empty response. Continuing with empty message\n"
|
||||
f"Feedback: {response.prompt_feedback}"
|
||||
)
|
||||
generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
|
||||
generations = [
|
||||
ChatGeneration(message=AIMessage(content=""), generation_info={})
|
||||
]
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
|
||||
@@ -616,13 +573,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
||||
stream=True,
|
||||
)
|
||||
for chunk in response:
|
||||
_chat_result = _response_to_result(
|
||||
chunk,
|
||||
ai_msg_t=AIMessageChunk,
|
||||
human_msg_t=HumanMessageChunk,
|
||||
chat_msg_t=ChatMessageChunk,
|
||||
generation_t=ChatGenerationChunk,
|
||||
)
|
||||
_chat_result = _response_to_result(chunk)
|
||||
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(gen.text)
|
||||
@@ -646,13 +597,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
||||
generation_method=chat.send_message_async,
|
||||
stream=True,
|
||||
):
|
||||
_chat_result = _response_to_result(
|
||||
chunk,
|
||||
ai_msg_t=AIMessageChunk,
|
||||
human_msg_t=HumanMessageChunk,
|
||||
chat_msg_t=ChatMessageChunk,
|
||||
generation_t=ChatGenerationChunk,
|
||||
)
|
||||
_chat_result = _response_to_result(chunk)
|
||||
gen = cast(ChatGenerationChunk, _chat_result.generations[0])
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(gen.text)
|
||||
|
Reference in New Issue
Block a user