mirror of https://github.com/hwchase17/langchain.git

google-genai[patch]: added parsing of function call / response (#17245)
@@ -16,18 +16,18 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
-    Type,
     Union,
     cast,
 )
 from urllib.parse import urlparse
 
+import google.ai.generativelanguage as glm
 import google.api_core
 
 # TODO: remove ignore once the google package is published with types
 import google.generativeai as genai  # type: ignore[import]
+import proto  # type: ignore[import]
 import requests
-from google.ai.generativelanguage_v1beta import FunctionCall
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -35,13 +35,9 @@ from langchain_core.callbacks.manager import (
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
-    AIMessageChunk,
     BaseMessage,
-    ChatMessage,
-    ChatMessageChunk,
     FunctionMessage,
     HumanMessage,
-    HumanMessageChunk,
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
@@ -327,15 +323,30 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
             continue
         elif isinstance(message, AIMessage):
             role = "model"
-            # TODO: Handle AImessage with function call
-            parts = _convert_to_parts(message.content)
+            raw_function_call = message.additional_kwargs.get("function_call")
+            if raw_function_call:
+                function_call = glm.FunctionCall(
+                    {
+                        "name": raw_function_call["name"],
+                        "args": json.loads(raw_function_call["arguments"]),
+                    }
+                )
+                parts = [glm.Part(function_call=function_call)]
+            else:
+                parts = _convert_to_parts(message.content)
         elif isinstance(message, HumanMessage):
             role = "user"
             parts = _convert_to_parts(message.content)
         elif isinstance(message, FunctionMessage):
             role = "user"
-            # TODO: Handle FunctionMessage
-            parts = _convert_to_parts(message.content)
+            parts = [
+                glm.Part(
+                    function_response=glm.FunctionResponse(
+                        name=message.name,
+                        response=message.content,
+                    )
+                )
+            ]
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
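For illustration, here is a minimal, standalone sketch of what the new AIMessage branch above does: given an AIMessage that stores an OpenAI-style function call in `additional_kwargs`, it rebuilds a `glm.FunctionCall` and wraps it in a `glm.Part` instead of converting the (empty) text content. The `get_weather` name and arguments below are invented for the example.

```python
import json

import google.ai.generativelanguage as glm
from langchain_core.messages import AIMessage

# Hypothetical assistant turn that previously requested a tool call.
ai_message = AIMessage(
    content="",
    additional_kwargs={
        "function_call": {
            "name": "get_weather",
            # Arguments are kept as a JSON string, matching other function-calling LLMs.
            "arguments": json.dumps({"city": "Paris"}),
        }
    },
)

# Mirrors the new branch: rebuild a glm.FunctionCall from additional_kwargs
# so it can be sent back to Gemini as a function_call part.
raw_function_call = ai_message.additional_kwargs["function_call"]
function_call = glm.FunctionCall(
    {
        "name": raw_function_call["name"],
        "args": json.loads(raw_function_call["arguments"]),
    }
)
parts = [glm.Part(function_call=function_call)]
```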
@@ -353,100 +364,44 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
     return messages
 
 
-def _retrieve_function_call_response(
-    parts: List[genai.types.PartType],
-) -> Optional[Dict]:
-    for idx, part in enumerate(parts):
-        if part.function_call and part.function_call.name:
-            fc: FunctionCall = part.function_call
-            return {
-                "function_call": {
-                    "name": fc.name,
-                    "arguments": json.dumps(
-                        dict(fc.args.items())
-                    ),  # dump to match other function calling llms for now
-                }
-            }
-    return None
-
-
-def _parts_to_content(
-    parts: List[genai.types.PartType],
-) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
-    """Converts a list of Gemini API Part objects into a list of LangChain messages."""
-    function_call_resp = _retrieve_function_call_response(parts)
-
-    if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
-        # Simple text response. The typical response
-        return parts[0].text, function_call_resp
-    elif not parts:
-        logger.warning("Gemini produced an empty response.")
-        return "", function_call_resp
-    messages: List[Union[Dict[Any, Any], str]] = []
-    for part in parts:
-        if part.text is not None:
-            messages.append(
-                {
-                    "type": "text",
-                    "text": part.text,
-                }
-            )
-        else:
-            # TODO: Handle inline_data if that's a thing?
-            raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
-    return messages, function_call_resp
+def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
+    first_part = response_candidate.content.parts[0]
+    if first_part.function_call:
+        function_call = proto.Message.to_dict(first_part.function_call)
+        function_call["arguments"] = json.dumps(function_call.pop("args", {}))
+        return AIMessage(content="", additional_kwargs={"function_call": function_call})
+    else:
+        parts = response_candidate.content.parts
+
+    if len(parts) == 1 and parts[0].text:
+        content: Union[str, List[Union[str, Dict]]] = parts[0].text
+    else:
+        content = [proto.Message.to_dict(part) for part in parts]
+    return AIMessage(content=content, additional_kwargs={})
 
 
 def _response_to_result(
-    response: genai.types.GenerateContentResponse,
-    ai_msg_t: Type[BaseMessage] = AIMessage,
-    human_msg_t: Type[BaseMessage] = HumanMessage,
-    chat_msg_t: Type[BaseMessage] = ChatMessage,
-    generation_t: Type[ChatGeneration] = ChatGeneration,
+    response: glm.GenerateContentResponse,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
-    llm_output = {}
-    if response.prompt_feedback:
-        try:
-            prompt_feedback = type(response.prompt_feedback).to_dict(
-                response.prompt_feedback, use_integers_for_enums=False
-            )
-            llm_output["prompt_feedback"] = prompt_feedback
-        except Exception as e:
-            logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
-
+    llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
     generations: List[ChatGeneration] = []
-
-    role_map = {
-        "model": ai_msg_t,
-        "user": human_msg_t,
-    }
-
     for candidate in response.candidates:
-        content = candidate.content
-        parts_content, additional_kwargs = _parts_to_content(content.parts)
-        if content.role not in role_map:
-            logger.warning(
-                f"Unrecognized role: {content.role}. Treating as a ChatMessage."
-            )
-            msg = chat_msg_t(
-                content=parts_content,
-                role=content.role,
-                additional_kwargs=additional_kwargs or {},
-            )
-        else:
-            msg = role_map[content.role](
-                content=parts_content,
-                additional_kwargs=additional_kwargs or {},
-            )
         generation_info = {}
         if candidate.finish_reason:
             generation_info["finish_reason"] = candidate.finish_reason.name
-        if candidate.safety_ratings:
-            generation_info["safety_ratings"] = [
-                type(rating).to_dict(rating) for rating in candidate.safety_ratings
-            ]
-        generations.append(generation_t(message=msg, generation_info=generation_info))
+        generation_info["safety_ratings"] = [
+            proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
+            for safety_rating in candidate.safety_ratings
+        ]
+        generations.append(
+            ChatGeneration(
+                message=_parse_response_candidate(candidate),
+                generation_info=generation_info,
+            )
+        )
     if not response.candidates:
         # Likely a "prompt feedback" violation (e.g., toxic input)
         # Raising an error would be different than how OpenAI handles it,
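The interesting case in the new `_parse_response_candidate` is a candidate whose first part carries a `function_call`: the call is converted to a dict with `proto.Message.to_dict`, and its `args` are re-serialized to a JSON string under `arguments`, so downstream consumers see the same payload shape as other function-calling chat models. A small sketch with a hand-built candidate (the `get_weather` payload is invented for illustration):

```python
import json

import google.ai.generativelanguage as glm
import proto

# A hand-built candidate, standing in for one entry of response.candidates.
candidate = glm.Candidate(
    content=glm.Content(
        role="model",
        parts=[
            glm.Part(
                function_call=glm.FunctionCall(
                    {"name": "get_weather", "args": {"city": "Paris"}}
                )
            )
        ],
    )
)

first_part = candidate.content.parts[0]
function_call = proto.Message.to_dict(first_part.function_call)
# Re-serialize args into an "arguments" JSON string, as the helper does.
function_call["arguments"] = json.dumps(function_call.pop("args", {}))
print(function_call)
# Expected shape: {'name': 'get_weather', 'arguments': '{"city": "Paris"}'}
```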
@@ -455,7 +410,9 @@ def _response_to_result(
             "Gemini produced an empty response. Continuing with empty message\n"
             f"Feedback: {response.prompt_feedback}"
         )
-        generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
+        generations = [
+            ChatGeneration(message=AIMessage(content=""), generation_info={})
+        ]
     return ChatResult(generations=generations, llm_output=llm_output)
 
 
@@ -616,13 +573,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
@@ -646,13 +597,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
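With the chunk-specific message and generation classes removed from `_response_to_result`, both the sync and async streaming paths now simply call `_response_to_result(chunk)` and cast the first generation. A usage sketch of the streaming path (assumes `langchain-google-genai` is installed and `GOOGLE_API_KEY` is set; `gemini-pro` is the model name used in the module docstring):

```python
from langchain_google_genai import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")

# Each streamed chunk below is produced by the simplified
# _response_to_result(chunk) call in _stream / _astream.
for chunk in llm.stream("In one sentence, what is a mirror?"):
    print(chunk.content, end="", flush=True)
```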