google-genai[minor]: support function calls (#15146)

Co-authored-by: Erick Friis <erick@langchain.dev>
chyroc authored 2024-02-08 04:09:30 +08:00, committed by GitHub
parent 302989a2b1
commit f87b38a559
3 changed files with 520 additions and 384 deletions


@@ -26,6 +26,7 @@ import google.api_core
 # TODO: remove ignore once the google package is published with types
 import google.generativeai as genai  # type: ignore[import]
 import requests
+from google.ai.generativelanguage_v1beta import FunctionCall
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -341,16 +342,90 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
     return messages


+def _retrieve_function_call_response(
+    parts: List[genai.types.PartType],
+) -> Optional[Dict]:
+    for idx, part in enumerate(parts):
+        if part.function_call and part.function_call.name:
+            fc: FunctionCall = part.function_call
+            return {
+                "function_call": {
+                    "name": fc.name,
+                    "arguments": dict(fc.args.items()),
+                }
+            }
+    return None
+
+
+def _convert_function_call_req(function_calls: Union[Dict, List[Dict]]) -> Dict:
+    function_declarations = []
+    if isinstance(function_calls, dict):
+        function_declarations.append(_convert_fc_type(function_calls))
+    else:
+        for fc in function_calls:
+            function_declarations.append(_convert_fc_type(fc))
+    return {
+        "function_declarations": function_declarations,
+    }
+
+
+def _convert_fc_type(fc: Dict) -> Dict:
+    # type_: "Type"
+    # format_: str
+    # description: str
+    # nullable: bool
+    # enum: MutableSequence[str]
+    # items: "Schema"
+    # properties: MutableMapping[str, "Schema"]
+    # required: MutableSequence[str]
+    if "parameters" in fc:
+        fc["parameters"] = _convert_fc_type(fc["parameters"])
+    if "properties" in fc:
+        for k, v in fc["properties"].items():
+            fc["properties"][k] = _convert_fc_type(v)
+    if "type" in fc:
+        # STRING = 1
+        # NUMBER = 2
+        # INTEGER = 3
+        # BOOLEAN = 4
+        # ARRAY = 5
+        # OBJECT = 6
+        if fc["type"] == "string":
+            fc["type_"] = 1
+        elif fc["type"] == "number":
+            fc["type_"] = 2
+        elif fc["type"] == "integer":
+            fc["type_"] = 3
+        elif fc["type"] == "boolean":
+            fc["type_"] = 4
+        elif fc["type"] == "array":
+            fc["type_"] = 5
+        elif fc["type"] == "object":
+            fc["type_"] = 6
+        del fc["type"]
+    if "format" in fc:
+        fc["format_"] = fc["format"]
+        del fc["format"]
+    for k, v in fc.items():
+        if isinstance(v, dict):
+            fc[k] = _convert_fc_type(v)
+    return fc
+
+
 def _parts_to_content(
     parts: List[genai.types.PartType],
-) -> Union[str, List[Union[Dict[Any, Any], str]]]:
+) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
     """Converts a list of Gemini API Part objects into a list of LangChain messages."""
+    function_call_resp = _retrieve_function_call_response(parts)
+
     if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
         # Simple text response. The typical response
-        return parts[0].text
+        return parts[0].text, function_call_resp
     elif not parts:
         logger.warning("Gemini produced an empty response.")
-        return ""
+        return "", function_call_resp
     messages: List[Union[Dict[Any, Any], str]] = []
     for part in parts:
         if part.text is not None:
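
For illustration, here is what the request-side conversion produces for a typical OpenAI-style function spec. This is a hedged sketch derived from _convert_fc_type above, not captured library output, and the spec itself is made up:

    spec = {
        "name": "get_weather",
        "description": "Determine weather in my location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "City and state"},
            },
            "required": ["location"],
        },
    }
    # _convert_function_call_req(spec) should yield a dict using the protobuf
    # field names: JSON Schema "type" strings become Type enum ints
    # (STRING=1 ... OBJECT=6) under "type_", and "format" becomes "format_":
    # {
    #     "function_declarations": [
    #         {
    #             "name": "get_weather",
    #             "description": "Determine weather in my location",
    #             "parameters": {
    #                 "type_": 6,
    #                 "properties": {
    #                     "location": {"type_": 1, "description": "City and state"},
    #                 },
    #                 "required": ["location"],
    #             },
    #         }
    #     ]
    # }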
@@ -363,7 +438,7 @@ def _parts_to_content(
         else:
             # TODO: Handle inline_data if that's a thing?
             raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
-    return messages
+    return messages, function_call_resp


 def _response_to_result(
@@ -390,16 +465,24 @@ def _response_to_result(
         "model": ai_msg_t,
         "user": human_msg_t,
     }
+
     for candidate in response.candidates:
         content = candidate.content
-        parts_content = _parts_to_content(content.parts)
+        parts_content, additional_kwargs = _parts_to_content(content.parts)
         if content.role not in role_map:
             logger.warning(
                 f"Unrecognized role: {content.role}. Treating as a ChatMessage."
             )
-            msg = chat_msg_t(content=parts_content, role=content.role)
+            msg = chat_msg_t(
+                content=parts_content,
+                role=content.role,
+                additional_kwargs=additional_kwargs or {},
+            )
         else:
-            msg = role_map[content.role](content=parts_content)
+            msg = role_map[content.role](
+                content=parts_content,
+                additional_kwargs=additional_kwargs or {},
+            )
         generation_info = {}
         if candidate.finish_reason:
             generation_info["finish_reason"] = candidate.finish_reason.name
@@ -527,7 +610,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         response: genai.types.GenerateContentResponse = _chat_with_retry(
             content=message,
             **params,
@@ -542,7 +629,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         response: genai.types.GenerateContentResponse = await _achat_with_retry(
             content=message,
             **params,
@@ -557,7 +648,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         response: genai.types.GenerateContentResponse = _chat_with_retry(
             content=message,
             **params,
@@ -584,7 +679,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         async for chunk in await _achat_with_retry(
             content=message,
             **params,
@@ -609,13 +708,19 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
+        cli = self.client
+        functions = kwargs.pop("functions", None)
+        if functions:
+            tools = _convert_function_call_req(functions)
+            cli = genai.GenerativeModel(model_name=self.model, tools=tools)
+
         params = self._prepare_params(stop, **kwargs)
         history = _parse_chat_history(
             messages,
             convert_system_message_to_human=self.convert_system_message_to_human,
         )
         message = history.pop()
-        chat = self.client.start_chat(history=history)
+        chat = cli.start_chat(history=history)
         return params, chat, message

     def get_num_tokens(self, text: str) -> int:
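
In short, when functions are passed, _prepare_chat swaps the default client for a per-call model configured with tools. A condensed sketch of that path, reusing the spec dict from above and omitting params and history handling:

    import google.generativeai as genai

    tools = _convert_function_call_req(spec)  # {"function_declarations": [...]}
    model = genai.GenerativeModel(model_name="gemini-pro", tools=tools)
    chat = model.start_chat(history=[])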

File diff suppressed because it is too large.


@@ -0,0 +1,34 @@
+"""Test ChatGoogleGenerativeAI function call."""
+
+from langchain_google_genai.chat_models import (
+    ChatGoogleGenerativeAI,
+)
+
+
+def test_function_call() -> None:
+    functions = [
+        {
+            "name": "get_weather",
+            "description": "Determine weather in my location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state e.g. San Francisco, CA",
+                    },
+                    "unit": {"type": "string", "enum": ["c", "f"]},
+                },
+                "required": ["location"],
+            },
+        }
+    ]
+    llm = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=functions)
+    res = llm.invoke("what weather is today in san francisco?")
+    assert res
+    assert res.additional_kwargs
+    assert "function_call" in res.additional_kwargs
+    assert "get_weather" == res.additional_kwargs["function_call"]["name"]
+    arguments = res.additional_kwargs["function_call"]["arguments"]
+    assert isinstance(arguments, dict)
+    assert "location" in arguments