mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-29 18:08:36 +00:00
google-genai[minor]: support functions call (#15146)
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
302989a2b1
commit
f87b38a559
@ -26,6 +26,7 @@ import google.api_core
|
|||||||
# TODO: remove ignore once the google package is published with types
|
# TODO: remove ignore once the google package is published with types
|
||||||
import google.generativeai as genai # type: ignore[import]
|
import google.generativeai as genai # type: ignore[import]
|
||||||
import requests
|
import requests
|
||||||
|
from google.ai.generativelanguage_v1beta import FunctionCall
|
||||||
from langchain_core.callbacks.manager import (
|
from langchain_core.callbacks.manager import (
|
||||||
AsyncCallbackManagerForLLMRun,
|
AsyncCallbackManagerForLLMRun,
|
||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
@ -341,16 +342,90 @@ llm = ChatGoogleGenerativeAI(model="gemini-pro", convert_system_message_to_human
|
|||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
def _retrieve_function_call_response(
|
||||||
|
parts: List[genai.types.PartType],
|
||||||
|
) -> Optional[Dict]:
|
||||||
|
for idx, part in enumerate(parts):
|
||||||
|
if part.function_call and part.function_call.name:
|
||||||
|
fc: FunctionCall = part.function_call
|
||||||
|
return {
|
||||||
|
"function_call": {
|
||||||
|
"name": fc.name,
|
||||||
|
"arguments": dict(fc.args.items()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_function_call_req(function_calls: Union[Dict, List[Dict]]) -> Dict:
|
||||||
|
function_declarations = []
|
||||||
|
if isinstance(function_calls, dict):
|
||||||
|
function_declarations.append(_convert_fc_type(function_calls))
|
||||||
|
else:
|
||||||
|
for fc in function_calls:
|
||||||
|
function_declarations.append(_convert_fc_type(fc))
|
||||||
|
return {
|
||||||
|
"function_declarations": function_declarations,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _convert_fc_type(fc: Dict) -> Dict:
|
||||||
|
# type_: "Type"
|
||||||
|
# format_: str
|
||||||
|
# description: str
|
||||||
|
# nullable: bool
|
||||||
|
# enum: MutableSequence[str]
|
||||||
|
# items: "Schema"
|
||||||
|
# properties: MutableMapping[str, "Schema"]
|
||||||
|
# required: MutableSequence[str]
|
||||||
|
if "parameters" in fc:
|
||||||
|
fc["parameters"] = _convert_fc_type(fc["parameters"])
|
||||||
|
if "properties" in fc:
|
||||||
|
for k, v in fc["properties"].items():
|
||||||
|
fc["properties"][k] = _convert_fc_type(v)
|
||||||
|
if "type" in fc:
|
||||||
|
# STRING = 1
|
||||||
|
# NUMBER = 2
|
||||||
|
# INTEGER = 3
|
||||||
|
# BOOLEAN = 4
|
||||||
|
# ARRAY = 5
|
||||||
|
# OBJECT = 6
|
||||||
|
if fc["type"] == "string":
|
||||||
|
fc["type_"] = 1
|
||||||
|
elif fc["type"] == "number":
|
||||||
|
fc["type_"] = 2
|
||||||
|
elif fc["type"] == "integer":
|
||||||
|
fc["type_"] = 3
|
||||||
|
elif fc["type"] == "boolean":
|
||||||
|
fc["type_"] = 4
|
||||||
|
elif fc["type"] == "array":
|
||||||
|
fc["type_"] = 5
|
||||||
|
elif fc["type"] == "object":
|
||||||
|
fc["type_"] = 6
|
||||||
|
del fc["type"]
|
||||||
|
if "format" in fc:
|
||||||
|
fc["format_"] = fc["format"]
|
||||||
|
del fc["format"]
|
||||||
|
|
||||||
|
for k, v in fc.items():
|
||||||
|
if isinstance(v, dict):
|
||||||
|
fc[k] = _convert_fc_type(v)
|
||||||
|
|
||||||
|
return fc
|
||||||
|
|
||||||
|
|
||||||
def _parts_to_content(
|
def _parts_to_content(
|
||||||
parts: List[genai.types.PartType],
|
parts: List[genai.types.PartType],
|
||||||
) -> Union[str, List[Union[Dict[Any, Any], str]]]:
|
) -> Tuple[Union[str, List[Union[Dict[Any, Any], str]]], Optional[Dict]]:
|
||||||
"""Converts a list of Gemini API Part objects into a list of LangChain messages."""
|
"""Converts a list of Gemini API Part objects into a list of LangChain messages."""
|
||||||
|
function_call_resp = _retrieve_function_call_response(parts)
|
||||||
|
|
||||||
if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
|
if len(parts) == 1 and parts[0].text is not None and not parts[0].inline_data:
|
||||||
# Simple text response. The typical response
|
# Simple text response. The typical response
|
||||||
return parts[0].text
|
return parts[0].text, function_call_resp
|
||||||
elif not parts:
|
elif not parts:
|
||||||
logger.warning("Gemini produced an empty response.")
|
logger.warning("Gemini produced an empty response.")
|
||||||
return ""
|
return "", function_call_resp
|
||||||
messages: List[Union[Dict[Any, Any], str]] = []
|
messages: List[Union[Dict[Any, Any], str]] = []
|
||||||
for part in parts:
|
for part in parts:
|
||||||
if part.text is not None:
|
if part.text is not None:
|
||||||
@ -363,7 +438,7 @@ def _parts_to_content(
|
|||||||
else:
|
else:
|
||||||
# TODO: Handle inline_data if that's a thing?
|
# TODO: Handle inline_data if that's a thing?
|
||||||
raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
|
raise ChatGoogleGenerativeAIError(f"Unexpected part type. {part}")
|
||||||
return messages
|
return messages, function_call_resp
|
||||||
|
|
||||||
|
|
||||||
def _response_to_result(
|
def _response_to_result(
|
||||||
@ -390,16 +465,24 @@ def _response_to_result(
|
|||||||
"model": ai_msg_t,
|
"model": ai_msg_t,
|
||||||
"user": human_msg_t,
|
"user": human_msg_t,
|
||||||
}
|
}
|
||||||
|
|
||||||
for candidate in response.candidates:
|
for candidate in response.candidates:
|
||||||
content = candidate.content
|
content = candidate.content
|
||||||
parts_content = _parts_to_content(content.parts)
|
parts_content, additional_kwargs = _parts_to_content(content.parts)
|
||||||
if content.role not in role_map:
|
if content.role not in role_map:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
f"Unrecognized role: {content.role}. Treating as a ChatMessage."
|
f"Unrecognized role: {content.role}. Treating as a ChatMessage."
|
||||||
)
|
)
|
||||||
msg = chat_msg_t(content=parts_content, role=content.role)
|
msg = chat_msg_t(
|
||||||
|
content=parts_content,
|
||||||
|
role=content.role,
|
||||||
|
additional_kwargs=additional_kwargs or {},
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
msg = role_map[content.role](content=parts_content)
|
msg = role_map[content.role](
|
||||||
|
content=parts_content,
|
||||||
|
additional_kwargs=additional_kwargs or {},
|
||||||
|
)
|
||||||
generation_info = {}
|
generation_info = {}
|
||||||
if candidate.finish_reason:
|
if candidate.finish_reason:
|
||||||
generation_info["finish_reason"] = candidate.finish_reason.name
|
generation_info["finish_reason"] = candidate.finish_reason.name
|
||||||
@ -527,7 +610,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
params, chat, message = self._prepare_chat(messages, stop=stop)
|
params, chat, message = self._prepare_chat(
|
||||||
|
messages,
|
||||||
|
stop=stop,
|
||||||
|
functions=kwargs.get("functions"),
|
||||||
|
)
|
||||||
response: genai.types.GenerateContentResponse = _chat_with_retry(
|
response: genai.types.GenerateContentResponse = _chat_with_retry(
|
||||||
content=message,
|
content=message,
|
||||||
**params,
|
**params,
|
||||||
@ -542,7 +629,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
params, chat, message = self._prepare_chat(messages, stop=stop)
|
params, chat, message = self._prepare_chat(
|
||||||
|
messages,
|
||||||
|
stop=stop,
|
||||||
|
functions=kwargs.get("functions"),
|
||||||
|
)
|
||||||
response: genai.types.GenerateContentResponse = await _achat_with_retry(
|
response: genai.types.GenerateContentResponse = await _achat_with_retry(
|
||||||
content=message,
|
content=message,
|
||||||
**params,
|
**params,
|
||||||
@ -557,7 +648,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Iterator[ChatGenerationChunk]:
|
) -> Iterator[ChatGenerationChunk]:
|
||||||
params, chat, message = self._prepare_chat(messages, stop=stop)
|
params, chat, message = self._prepare_chat(
|
||||||
|
messages,
|
||||||
|
stop=stop,
|
||||||
|
functions=kwargs.get("functions"),
|
||||||
|
)
|
||||||
response: genai.types.GenerateContentResponse = _chat_with_retry(
|
response: genai.types.GenerateContentResponse = _chat_with_retry(
|
||||||
content=message,
|
content=message,
|
||||||
**params,
|
**params,
|
||||||
@ -584,7 +679,11 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> AsyncIterator[ChatGenerationChunk]:
|
) -> AsyncIterator[ChatGenerationChunk]:
|
||||||
params, chat, message = self._prepare_chat(messages, stop=stop)
|
params, chat, message = self._prepare_chat(
|
||||||
|
messages,
|
||||||
|
stop=stop,
|
||||||
|
functions=kwargs.get("functions"),
|
||||||
|
)
|
||||||
async for chunk in await _achat_with_retry(
|
async for chunk in await _achat_with_retry(
|
||||||
content=message,
|
content=message,
|
||||||
**params,
|
**params,
|
||||||
@ -609,13 +708,19 @@ class ChatGoogleGenerativeAI(_BaseGoogleGenerativeAI, BaseChatModel):
|
|||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
|
) -> Tuple[Dict[str, Any], genai.ChatSession, genai.types.ContentDict]:
|
||||||
|
cli = self.client
|
||||||
|
functions = kwargs.pop("functions", None)
|
||||||
|
if functions:
|
||||||
|
tools = _convert_function_call_req(functions)
|
||||||
|
cli = genai.GenerativeModel(model_name=self.model, tools=tools)
|
||||||
|
|
||||||
params = self._prepare_params(stop, **kwargs)
|
params = self._prepare_params(stop, **kwargs)
|
||||||
history = _parse_chat_history(
|
history = _parse_chat_history(
|
||||||
messages,
|
messages,
|
||||||
convert_system_message_to_human=self.convert_system_message_to_human,
|
convert_system_message_to_human=self.convert_system_message_to_human,
|
||||||
)
|
)
|
||||||
message = history.pop()
|
message = history.pop()
|
||||||
chat = self.client.start_chat(history=history)
|
chat = cli.start_chat(history=history)
|
||||||
return params, chat, message
|
return params, chat, message
|
||||||
|
|
||||||
def get_num_tokens(self, text: str) -> int:
|
def get_num_tokens(self, text: str) -> int:
|
||||||
|
741
libs/partners/google-genai/poetry.lock
generated
741
libs/partners/google-genai/poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,34 @@
|
|||||||
|
"""Test ChatGoogleGenerativeAI function call."""
|
||||||
|
|
||||||
|
from langchain_google_genai.chat_models import (
|
||||||
|
ChatGoogleGenerativeAI,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_function_call() -> None:
|
||||||
|
functions = [
|
||||||
|
{
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Determine weather in my location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The city and state e.g. San Francisco, CA",
|
||||||
|
},
|
||||||
|
"unit": {"type": "string", "enum": ["c", "f"]},
|
||||||
|
},
|
||||||
|
"required": ["location"],
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
llm = ChatGoogleGenerativeAI(model="gemini-pro").bind(functions=functions)
|
||||||
|
res = llm.invoke("what weather is today in san francisco?")
|
||||||
|
assert res
|
||||||
|
assert res.additional_kwargs
|
||||||
|
assert "function_call" in res.additional_kwargs
|
||||||
|
assert "get_weather" == res.additional_kwargs["function_call"]["name"]
|
||||||
|
arguments = res.additional_kwargs["function_call"]["arguments"]
|
||||||
|
assert isinstance(arguments, dict)
|
||||||
|
assert "location" in arguments
|
Loading…
Reference in New Issue
Block a user