diff --git a/libs/partners/google-genai/langchain_google_genai/chat_models.py b/libs/partners/google-genai/langchain_google_genai/chat_models.py
index b69ff5824f0..8d56e983c87 100644
--- a/libs/partners/google-genai/langchain_google_genai/chat_models.py
+++ b/libs/partners/google-genai/langchain_google_genai/chat_models.py
@@ -16,18 +16,18 @@ from typing import (
     Optional,
     Sequence,
     Tuple,
-    Type,
    Union,
    cast,
 )
 from urllib.parse import urlparse
 
+import google.ai.generativelanguage as glm
 import google.api_core
 
 # TODO: remove ignore once the google package is published with types
 import google.generativeai as genai  # type: ignore[import]
+import proto  # type: ignore[import]
 import requests
-from google.ai.generativelanguage_v1beta import FunctionCall
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -35,13 +35,9 @@ from langchain_core.callbacks.manager import (
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
-    AIMessageChunk,
     BaseMessage,
-    ChatMessage,
-    ChatMessageChunk,
     FunctionMessage,
     HumanMessage,
-    HumanMessageChunk,
     SystemMessage,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
@@ -327,15 +323,30 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
             continue
         elif isinstance(message, AIMessage):
             role = "model"
-            # TODO: Handle AImessage with function call
-            parts = _convert_to_parts(message.content)
+            raw_function_call = message.additional_kwargs.get("function_call")
+            if raw_function_call:
+                function_call = glm.FunctionCall(
+                    {
+                        "name": raw_function_call["name"],
+                        "args": json.loads(raw_function_call["arguments"]),
+                    }
+                )
+                parts = [glm.Part(function_call=function_call)]
+            else:
+                parts = _convert_to_parts(message.content)
         elif isinstance(message, HumanMessage):
             role = "user"
             parts = _convert_to_parts(message.content)
         elif isinstance(message, FunctionMessage):
             role = "user"
-            # TODO: Handle FunctionMessage
-            parts = _convert_to_parts(message.content)
+            parts = [
+                glm.Part(
+                    function_response=glm.FunctionResponse(
+                        name=message.name,
+                        response=message.content,
+                    )
+                )
+            ]
         else:
             raise ValueError(
                 f"Unexpected message with type {type(message)} at the position {i}."
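For context, a minimal sketch (not part of the diff; the `get_weather` tool and its arguments are hypothetical) of the message shapes the two new `_parse_chat_history` branches consume — an OpenAI-style `function_call` dict on the `AIMessage`, and a `FunctionMessage` carrying the tool result:

```python
import json

from langchain_core.messages import AIMessage, FunctionMessage, HumanMessage

history = [
    HumanMessage(content="What's the weather in Paris?"),
    # Handled by the new AIMessage branch: the JSON-encoded "arguments"
    # string is json.loads-ed into the glm.FunctionCall args.
    AIMessage(
        content="",
        additional_kwargs={
            "function_call": {
                "name": "get_weather",  # hypothetical tool name
                "arguments": json.dumps({"location": "Paris"}),
            }
        },
    ),
    # Handled by the new FunctionMessage branch: wrapped in a
    # glm.FunctionResponse part and sent back with role "user".
    FunctionMessage(name="get_weather", content="22C and sunny"),
]
```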
@@ -353,100 +364,44 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
     return messages
 
 
-def _retrieve_function_call_response(
-    parts: List[genai.types.PartType],
-) -> Optional[Dict]:
-    for idx, part in enumerate(parts):
-        if part.function_call and part.function_call.name:
-            fc: FunctionCall = part.function_call
-            return {
-                "function_call": {
-                    "name": fc.name,
-                    "arguments": json.dumps(
-                        dict(fc.args.items())
-                    ),  # dump to match other function calling llms for now
-                }
-            }
-    return None
+def _parse_response_candidate(response_candidate: glm.Candidate) -> AIMessage:
+    first_part = response_candidate.content.parts[0]
+    if first_part.function_call:
+        function_call = proto.Message.to_dict(first_part.function_call)
+        function_call["arguments"] = json.dumps(function_call.pop("args", {}))
+        return AIMessage(content="", additional_kwargs={"function_call": function_call})
+    else:
+        parts = response_candidate.content.parts
 
-
-def _parts_to_content(
-    parts: List[genai.types.PartType],
-) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
-    """Converts a list of Gemini API Part objects into a list of LangChain messages."""
-    function_call_resp = _retrieve_function_call_response(parts)
-
-    if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
-        # Simple text response. The typical response
-        return parts[0].text, function_call_resp
-    elif not parts:
-        logger.warning("Gemini produced an empty response.")
-        return "", function_call_resp
-    messages: List[Union[Dict[Any, Any], str]] = []
-    for part in parts:
-        if part.text is not None:
-            messages.append(
-                {
-                    "type": "text",
-                    "text": part.text,
-                }
-            )
-        else:
-            # TODO: Handle inline_data if that's a thing?
-            raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
-    return messages, function_call_resp
+    if len(parts) == 1 and parts[0].text:
+        content: Union[str, List[Union[str, Dict]]] = parts[0].text
+    else:
+        content = [proto.Message.to_dict(part) for part in parts]
+    return AIMessage(content=content, additional_kwargs={})
 
 
 def _response_to_result(
-    response: genai.types.GenerateContentResponse,
-    ai_msg_t: Type[BaseMessage] = AIMessage,
-    human_msg_t: Type[BaseMessage] = HumanMessage,
-    chat_msg_t: Type[BaseMessage] = ChatMessage,
-    generation_t: Type[ChatGeneration] = ChatGeneration,
+    response: glm.GenerateContentResponse,
 ) -> ChatResult:
     """Converts a PaLM API response into a LangChain ChatResult."""
-    llm_output = {}
-    if response.prompt_feedback:
-        try:
-            prompt_feedback = type(response.prompt_feedback).to_dict(
-                response.prompt_feedback, use_integers_for_enums=False
-            )
-            llm_output["prompt_feedback"] = prompt_feedback
-        except Exception as e:
-            logger.debug(f"Unable to convert prompt_feedback to dict: {e}")
+    llm_output = {"prompt_feedback": proto.Message.to_dict(response.prompt_feedback)}
 
     generations: List[ChatGeneration] = []
-    role_map = {
-        "model": ai_msg_t,
-        "user": human_msg_t,
-    }
     for candidate in response.candidates:
-        content = candidate.content
-        parts_content, additional_kwargs = _parts_to_content(content.parts)
-        if content.role not in role_map:
-            logger.warning(
-                f"Unrecognized role: {content.role}. Treating as a ChatMessage."
-            )
-            msg = chat_msg_t(
-                content=parts_content,
-                role=content.role,
-                additional_kwargs=additional_kwargs or {},
-            )
-        else:
-            msg = role_map[content.role](
-                content=parts_content,
-                additional_kwargs=additional_kwargs or {},
-            )
         generation_info = {}
         if candidate.finish_reason:
             generation_info["finish_reason"] = candidate.finish_reason.name
-        if candidate.safety_ratings:
-            generation_info["safety_ratings"] = [
-                type(rating).to_dict(rating) for rating in candidate.safety_ratings
-            ]
-        generations.append(generation_t(message=msg, generation_info=generation_info))
+        generation_info["safety_ratings"] = [
+            proto.Message.to_dict(safety_rating, use_integers_for_enums=False)
+            for safety_rating in candidate.safety_ratings
+        ]
+        generations.append(
+            ChatGeneration(
+                message=_parse_response_candidate(candidate),
+                generation_info=generation_info,
+            )
+        )
     if not response.candidates:
         # Likely a "prompt feedback" violation (e.g., toxic input)
         # Raising an error would be different than how OpenAI handles it,
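To illustrate the new parsing path, a minimal sketch (assuming the `google.ai.generativelanguage` and `proto` packages imported above; the `get_weather` call is hypothetical) of how `_parse_response_candidate` flattens a function-call part via `proto.Message.to_dict`:

```python
import json

import google.ai.generativelanguage as glm
import proto

# A hand-built candidate standing in for what the API would return.
candidate = glm.Candidate(
    content=glm.Content(
        role="model",
        parts=[
            glm.Part(
                function_call=glm.FunctionCall(
                    name="get_weather", args={"location": "Paris"}
                )
            )
        ],
    )
)

first_part = candidate.content.parts[0]
function_call = proto.Message.to_dict(first_part.function_call)
# e.g. {"name": "get_weather", "args": {"location": "Paris"}}
function_call["arguments"] = json.dumps(function_call.pop("args", {}))
# Now the OpenAI-style shape other function-calling LLMs emit:
# {"name": "get_weather", "arguments": '{"location": "Paris"}'}
```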
@@ -455,7 +410,9 @@ def _response_to_result(
             "Gemini produced an empty response. Continuing with empty message\n"
             f"Feedback: {response.prompt_feedback}"
         )
-        generations = [generation_t(message=ai_msg_t(content=""), generation_info={})]
+        generations = [
+            ChatGeneration(message=AIMessage(content=""), generation_info={})
+        ]
     return ChatResult(generations=generations, llm_output=llm_output)
 
 
@@ -616,13 +573,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             stream=True,
         )
         for chunk in response:
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 run_manager.on_llm_new_token(gen.text)
@@ -646,13 +597,7 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
             generation_method=chat.send_message_async,
             stream=True,
         ):
-            _chat_result = _response_to_result(
-                chunk,
-                ai_msg_t=AIMessageChunk,
-                human_msg_t=HumanMessageChunk,
-                chat_msg_t=ChatMessageChunk,
-                generation_t=ChatGenerationChunk,
-            )
+            _chat_result = _response_to_result(chunk)
             gen = cast(ChatGenerationChunk, _chat_result.generations[0])
             if run_manager:
                 await run_manager.on_llm_new_token(gen.text)
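As a quick sanity check of the consolidated conversion, a hypothetical snippet exercising the module's private helper with a hand-built response (not an API call):

```python
import google.ai.generativelanguage as glm

from langchain_google_genai.chat_models import _response_to_result

response = glm.GenerateContentResponse(
    candidates=[
        glm.Candidate(
            content=glm.Content(role="model", parts=[glm.Part(text="Hi!")]),
            finish_reason=glm.Candidate.FinishReason.STOP,
        )
    ]
)
result = _response_to_result(response)
print(result.generations[0].message.content)  # -> "Hi!"
print(result.generations[0].generation_info)
# -> {"finish_reason": "STOP", "safety_ratings": []}
print(result.llm_output)
# -> {"prompt_feedback": {...}} (defaults, since none was set)
```

With the chunk-type keyword arguments gone, the same helper now backs `_generate`, `_stream`, and `_astream`, so streamed chunks carry the same `AIMessage` payload (including any `function_call` in `additional_kwargs`) as non-streamed results.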