google-genai[minor]: support function calls (#15146)
Co-authored-by: Erick Friis <erick@langchain.dev>
parent 302989a2b1
commit f87b38a559
libs/partners/google-genai/langchain_google_genai/chat_models.py

@@ -26,6 +26,7 @@ import google.api_core
 # TODO: remove ignore once the google package is published with types
 import google.generativeai as genai  # type: ignore[import]
 import requests
+from google.ai.generativelanguage_v1beta import FunctionCall
 from langchain_core.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
     CallbackManagerForLLMRun,
@@ -341,16 +342,90 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
     return messages


+def _retrieve_function_call_response(
+    parts: List[genai.types.PartType],
+) -> Optional[Dict]:
+    for idx, part in enumerate(parts):
+        if part.function_call and part.function_call.name:
+            fc: FunctionCall = part.function_call
+            return {
+                "function_call": {
+                    "name": fc.name,
+                    "arguments": dict(fc.args.items()),
+                }
+            }
+    return None
+
+
+def _convert_function_call_req(function_calls: Union[Dict, List[Dict]]) -> Dict:
+    function_declarations = []
+    if isinstance(function_calls, dict):
+        function_declarations.append(_convert_fc_type(function_calls))
+    else:
+        for fc in function_calls:
+            function_declarations.append(_convert_fc_type(fc))
+    return {
+        "function_declarations": function_declarations,
+    }
+
+
+def _convert_fc_type(fc: Dict) -> Dict:
+    # type_: "Type"
+    # format_: str
+    # description: str
+    # nullable: bool
+    # enum: MutableSequence[str]
+    # items: "Schema"
+    # properties: MutableMapping[str, "Schema"]
+    # required: MutableSequence[str]
+    if "parameters" in fc:
+        fc["parameters"] = _convert_fc_type(fc["parameters"])
+    if "properties" in fc:
+        for k, v in fc["properties"].items():
+            fc["properties"][k] = _convert_fc_type(v)
+    if "type" in fc:
+        # STRING = 1
+        # NUMBER = 2
+        # INTEGER = 3
+        # BOOLEAN = 4
+        # ARRAY = 5
+        # OBJECT = 6
+        if fc["type"] == "string":
+            fc["type_"] = 1
+        elif fc["type"] == "number":
+            fc["type_"] = 2
+        elif fc["type"] == "integer":
+            fc["type_"] = 3
+        elif fc["type"] == "boolean":
+            fc["type_"] = 4
+        elif fc["type"] == "array":
+            fc["type_"] = 5
+        elif fc["type"] == "object":
+            fc["type_"] = 6
+        del fc["type"]
+    if "format" in fc:
+        fc["format_"] = fc["format"]
+        del fc["format"]
+
+    for k, v in fc.items():
+        if isinstance(v, dict):
+            fc[k] = _convert_fc_type(v)
+
+    return fc
+
+
 def _parts_to_content(
     parts: List[genai.types.PartType],
-) -> Union[str, List[Union[Dict[Any, Any], str]]]:
+) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
     """Converts a list of Gemini API Part objects into a list of LangChain messages."""
+    function_call_resp = _retrieve_function_call_response(parts)
+
     if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
         # Simple text response. The typical response
-        return parts[0].text
+        return parts[0].text, function_call_resp
     elif not parts:
         logger.warning("Gemini produced an empty response.")
-        return ""
+        return "", function_call_resp
     messages: List[Union[Dict[Any, Any], str]] = []
     for part in parts:
         if part.text is not None:
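For context on the helpers above: _convert_fc_type rewrites an OpenAI-style JSON-schema function definition into the keyword shape the google.ai.generativelanguage proto types expect, renaming "type"/"format" to "type_"/"format_" and replacing JSON-schema type names with the proto Type enum integers listed in the comments. A minimal sketch of what that produces; the input dict here is illustrative, not taken from the PR:

# Illustrative OpenAI-style definition (hypothetical values).
weather_fn = {
    "name": "get_weather",
    "description": "Determine weather in my location",
    "parameters": {
        "type": "object",
        "properties": {"location": {"type": "string"}},
        "required": ["location"],
    },
}

tools = _convert_function_call_req(weather_fn)
# tools == {
#     "function_declarations": [{
#         "name": "get_weather",
#         "description": "Determine weather in my location",
#         "parameters": {
#             "properties": {"location": {"type_": 1}},  # STRING
#             "required": ["location"],
#             "type_": 6,                                # OBJECT
#         },
#     }]
# }

Note that _convert_fc_type mutates its argument in place, so a caller that needs the original schema afterwards should pass in a copy.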
@@ -363,7 +438,7 @@ def _parts_to_content(
         else:
             # TODO: Handle inline_data if that's a thing?
             raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
-    return messages
+    return messages, function_call_resp


 def _response_to_result(
@@ -390,16 +465,24 @@ def _response_to_result(
         "model": ai_msg_t,
         "user": human_msg_t,
     }

     for candidate in response.candidates:
         content = candidate.content
-        parts_content = _parts_to_content(content.parts)
+        parts_content, additional_kwargs = _parts_to_content(content.parts)
         if content.role not in role_map:
             logger.warning(
                 f"Unrecognized role: {content.role}. Treating as a ChatMessage."
             )
-            msg = chat_msg_t(content=parts_content, role=content.role)
+            msg = chat_msg_t(
+                content=parts_content,
+                role=content.role,
+                additional_kwargs=additional_kwargs or {},
+            )
         else:
-            msg = role_map[content.role](content=parts_content)
+            msg = role_map[content.role](
+                content=parts_content,
+                additional_kwargs=additional_kwargs or {},
+            )
         generation_info = {}
         if candidate.finish_reason:
             generation_info["finish_reason"] = candidate.finish_reason.name
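The net effect of threading additional_kwargs through here is that a Gemini function-call part lands exactly where OpenAI-style callers already look for it. A sketch of the resulting message, with illustrative values (the variable name "result" is hypothetical, standing in for the ChatResult built by this function):

msg = result.generations[0].message
# msg.additional_kwargs ==
# {"function_call": {"name": "get_weather",
#                    "arguments": {"location": "San Francisco, CA"}}}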
@@ -527,7 +610,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         response: genai.types.GenerateContentResponse = _chat_with_retry(
             content=message,
             **params,
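Since "functions" rides in through **kwargs, it can be supplied per call or bound once on the model; the same four-line change is repeated in the async and streaming variants below. A sketch, assuming GOOGLE_API_KEY is set and "functions" holds definitions like the one in the new test at the bottom:

from langchain_google_genai.chat_models import ChatGoogleGenerativeAI

llm = ChatGoogleGenerativeAI(model="gemini-pro")
# Per invocation...
res = llm.invoke("what is the weather in SF?", functions=functions)
# ...or bound once and reused (the style the new test uses):
res = llm.bind(functions=functions).invoke("what is the weather in SF?")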
@@ -542,7 +629,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         response: genai.types.GenerateContentResponse = await _achat_with_retry(
             content=message,
             **params,
@@ -557,7 +648,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         response: genai.types.GenerateContentResponse = _chat_with_retry(
             content=message,
             **params,
@@ -584,7 +679,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        params, chat, message = self._prepare_chat(messages, stop=stop)
+        params, chat, message = self._prepare_chat(
+            messages,
+            stop=stop,
+            functions=kwargs.get("functions"),
+        )
         async for chunk in await _achat_with_retry(
             content=message,
             **params,
@@ -609,13 +708,19 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
         stop: Optional[List[str]] = None,
         **kwargs: Any,
     ) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
+        cli = self.client
+        functions = kwargs.pop("functions", None)
+        if functions:
+            tools = _convert_function_call_req(functions)
+            cli = genai.GenerativeModel(model_name=self.model, tools=tools)
+
         params = self._prepare_params(stop, **kwargs)
         history = _parse_chat_history(
             messages,
             convert_system_message_to_human=self.convert_system_message_to_human,
         )
         message = history.pop()
-        chat = self.client.start_chat(history=history)
+        chat = cli.start_chat(history=history)
         return params, chat, message

     def get_num_tokens(self, text: str) -> int:
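When "functions" is supplied, _prepare_chat builds a one-off genai.GenerativeModel carrying the converted tools instead of reusing the prebuilt self.client; without "functions" the behavior is unchanged. Roughly equivalent to doing this against the google-generativeai SDK directly (a sketch under that assumption; "functions" and "history" stand in for the values prepared above):

import google.generativeai as genai

tools = _convert_function_call_req(functions)  # helper added by this PR
model = genai.GenerativeModel(model_name="gemini-pro", tools=tools)
chat = model.start_chat(history=history)       # history: parsed prior turns
response = chat.send_message("what weather is today in san francisco?")

Note that the ad-hoc model is constructed from self.model and the tools alone.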
libs/partners/google-genai/poetry.lock (generated; 741 lines changed; diff suppressed because it is too large)
@@ -0,0 +1,34 @@
+"""Test ChatGoogleGenerativeAI function call."""
+
+from langchain_google_genai.chat_models import (
+    ChatGoogleGenerativeAI,
+)
+
+
+def test_function_call() -> None:
+    functions = [
+        {
+            "name": "get_weather",
+            "description": "Determine weather in my location",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    "location": {
+                        "type": "string",
+                        "description": "The city and state e.g. San Francisco, CA",
+                    },
+                    "unit": {"type": "string", "enum": ["c", "f"]},
+                },
+                "required": ["location"],
+            },
+        }
+    ]
+    llm = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=functions)
+    res = llm.invoke("what weather is today in san francisco?")
+    assert res
+    assert res.additional_kwargs
+    assert "function_call" in res.additional_kwargs
+    assert "get_weather" == res.additional_kwargs["function_call"]["name"]
+    arguments = res.additional_kwargs["function_call"]["arguments"]
+    assert isinstance(arguments, dict)
+    assert "location" in arguments
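This is a live-API test: it calls the real Gemini endpoint and so needs valid Google credentials. Its assertions pin down the contract established above; on a passing run, res.additional_kwargs carries the shape produced by _retrieve_function_call_response (argument values depend on the model):

# Expected shape (values model-dependent):
# res.additional_kwargs == {
#     "function_call": {
#         "name": "get_weather",
#         "arguments": {"location": "san francisco"},
#     }
# }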