mirror of
https://github.com/hwchase17/langchain.git
synced 2025-05-31 20:19:43 +00:00
google-vertexai[patch]: Harrison/vertex function calling (#16223)
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
parent
6bc6d64a12
commit
f60f59d69f
@ -6,9 +6,9 @@ all: help
|
||||
# Define a variable for the test file path.
|
||||
TEST_FILE ?= tests/unit_tests/
|
||||
|
||||
test_integration: TEST_FILE = tests/integration_tests/
|
||||
integration_tests: TEST_FILE = tests/integration_tests/
|
||||
|
||||
test test_integration:
|
||||
test integration_tests:
|
||||
poetry run pytest $(TEST_FILE)
|
||||
|
||||
tests:
|
||||
|
@ -1,6 +1,8 @@
|
||||
from langchain_google_vertexai._enums import HarmBlockThreshold, HarmCategory
|
||||
from langchain_google_vertexai.chains import create_structured_runnable
|
||||
from langchain_google_vertexai.chat_models import ChatVertexAI
|
||||
from langchain_google_vertexai.embeddings import VertexAIEmbeddings
|
||||
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
|
||||
from langchain_google_vertexai.llms import VertexAI, VertexAIModelGarden
|
||||
|
||||
__all__ = [
|
||||
@ -10,4 +12,6 @@ __all__ = [
|
||||
"VertexAIModelGarden",
|
||||
"HarmBlockThreshold",
|
||||
"HarmCategory",
|
||||
"PydanticFunctionsOutputParser",
|
||||
"create_structured_runnable",
|
||||
]
|
||||
|
@ -0,0 +1,111 @@
|
||||
from typing import (
|
||||
Dict,
|
||||
Optional,
|
||||
Sequence,
|
||||
Type,
|
||||
Union,
|
||||
)
|
||||
|
||||
from langchain_core.output_parsers import (
|
||||
BaseGenerationOutputParser,
|
||||
BaseOutputParser,
|
||||
)
|
||||
from langchain_core.prompts import BasePromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.runnables import Runnable
|
||||
|
||||
from langchain_google_vertexai.functions_utils import PydanticFunctionsOutputParser
|
||||
|
||||
|
||||
def get_output_parser(
|
||||
functions: Sequence[Type[BaseModel]],
|
||||
) -> Union[BaseOutputParser, BaseGenerationOutputParser]:
|
||||
"""Get the appropriate function output parser given the user functions.
|
||||
|
||||
Args:
|
||||
functions: Sequence where element is a dictionary, a pydantic.BaseModel class,
|
||||
or a Python function. If a dictionary is passed in, it is assumed to
|
||||
already be a valid OpenAI function.
|
||||
|
||||
Returns:
|
||||
A PydanticFunctionsOutputParser
|
||||
"""
|
||||
function_names = [f.__name__ for f in functions]
|
||||
if len(functions) > 1:
|
||||
pydantic_schema: Union[Dict, Type[BaseModel]] = {
|
||||
name: fn for name, fn in zip(function_names, functions)
|
||||
}
|
||||
else:
|
||||
pydantic_schema = functions[0]
|
||||
output_parser: Union[
|
||||
BaseOutputParser, BaseGenerationOutputParser
|
||||
] = PydanticFunctionsOutputParser(pydantic_schema=pydantic_schema)
|
||||
return output_parser
|
||||
|
||||
|
||||
def create_structured_runnable(
|
||||
function: Union[Type[BaseModel], Sequence[Type[BaseModel]]],
|
||||
llm: Runnable,
|
||||
*,
|
||||
prompt: Optional[BasePromptTemplate] = None,
|
||||
) -> Runnable:
|
||||
"""Create a runnable sequence that uses OpenAI functions.
|
||||
|
||||
Args:
|
||||
function: Either a single pydantic.BaseModel class or a sequence
|
||||
of pydantic.BaseModels classes.
|
||||
For best results, pydantic.BaseModels
|
||||
should have descriptions of the parameters.
|
||||
llm: Language model to use,
|
||||
assumed to support the Google Vertex function-calling API.
|
||||
prompt: BasePromptTemplate to pass to the model.
|
||||
|
||||
Returns:
|
||||
A runnable sequence that will pass in the given functions to the model when run.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from langchain_google_vertexai import ChatVertexAI, create_structured_runnable
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field
|
||||
|
||||
|
||||
class RecordPerson(BaseModel):
|
||||
\"\"\"Record some identifying information about a person.\"\"\"
|
||||
|
||||
name: str = Field(..., description="The person's name")
|
||||
age: int = Field(..., description="The person's age")
|
||||
fav_food: Optional[str] = Field(None, description="The person's favorite food")
|
||||
|
||||
|
||||
class RecordDog(BaseModel):
|
||||
\"\"\"Record some identifying information about a dog.\"\"\"
|
||||
|
||||
name: str = Field(..., description="The dog's name")
|
||||
color: str = Field(..., description="The dog's color")
|
||||
fav_food: Optional[str] = Field(None, description="The dog's favorite food")
|
||||
|
||||
|
||||
llm = ChatVertexAI(model_name="gemini-pro")
|
||||
prompt = ChatPromptTemplate.from_template(\"\"\"
|
||||
You are a world class algorithm for recording entities.
|
||||
Make calls to the relevant function to record the entities in the following input: {input}
|
||||
Tip: Make sure to answer in the correct format\"\"\"
|
||||
)
|
||||
chain = create_structured_runnable([RecordPerson, RecordDog], llm, prompt=prompt)
|
||||
chain.invoke({"input": "Harry was a chubby brown beagle who loved chicken"})
|
||||
# -> RecordDog(name="Harry", color="brown", fav_food="chicken")
|
||||
""" # noqa: E501
|
||||
if not function:
|
||||
raise ValueError("Need to pass in at least one function. Received zero.")
|
||||
functions = function if isinstance(function, Sequence) else [function]
|
||||
output_parser = get_output_parser(functions)
|
||||
llm_with_functions = llm.bind(functions=functions)
|
||||
if prompt is None:
|
||||
initial_chain = llm_with_functions
|
||||
else:
|
||||
initial_chain = prompt | llm_with_functions
|
||||
return initial_chain | output_parser
|
@ -1,5 +1,10 @@
|
||||
from typing import List
|
||||
import json
|
||||
from typing import Dict, List, Type, Union
|
||||
|
||||
from langchain_core.exceptions import OutputParserException
|
||||
from langchain_core.output_parsers import BaseOutputParser
|
||||
from langchain_core.outputs import ChatGeneration, Generation
|
||||
from langchain_core.pydantic_v1 import BaseModel
|
||||
from langchain_core.tools import Tool
|
||||
from langchain_core.utils.function_calling import FunctionDescription
|
||||
from langchain_core.utils.json_schema import dereference_refs
|
||||
@ -11,6 +16,29 @@ from vertexai.preview.generative_models import (
|
||||
)
|
||||
|
||||
|
||||
def _format_pydantic_to_vertex_function(
|
||||
pydantic_model: Type[BaseModel],
|
||||
) -> FunctionDescription:
|
||||
schema = dereference_refs(pydantic_model.schema())
|
||||
schema.pop("definitions", None)
|
||||
|
||||
return {
|
||||
"name": schema["title"],
|
||||
"description": schema["description"],
|
||||
"parameters": {
|
||||
"properties": {
|
||||
k: {
|
||||
"type": v["type"],
|
||||
"description": v.get("description"),
|
||||
}
|
||||
for k, v in schema["properties"].items()
|
||||
},
|
||||
"required": schema["required"],
|
||||
"type": schema["type"],
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def _format_tool_to_vertex_function(tool: Tool) -> FunctionDescription:
|
||||
"Format tool into the Vertex function API."
|
||||
if tool.args_schema:
|
||||
@ -46,11 +74,81 @@ def _format_tool_to_vertex_function(tool: Tool) -> FunctionDescription:
|
||||
}
|
||||
|
||||
|
||||
def _format_tools_to_vertex_tool(tools: List[Tool]) -> List[VertexTool]:
|
||||
def _format_tools_to_vertex_tool(
|
||||
tools: List[Union[Tool, Type[BaseModel]]],
|
||||
) -> List[VertexTool]:
|
||||
"Format tool into the Vertex Tool instance."
|
||||
function_declarations = []
|
||||
for tool in tools:
|
||||
func = _format_tool_to_vertex_function(tool)
|
||||
if isinstance(tool, Tool):
|
||||
func = _format_tool_to_vertex_function(tool)
|
||||
else:
|
||||
func = _format_pydantic_to_vertex_function(tool)
|
||||
function_declarations.append(FunctionDeclaration(**func))
|
||||
|
||||
return [VertexTool(function_declarations=function_declarations)]
|
||||
|
||||
|
||||
class PydanticFunctionsOutputParser(BaseOutputParser):
|
||||
"""Parse an output as a pydantic object.
|
||||
|
||||
This parser is used to parse the output of a ChatModel that uses
|
||||
Google Vertex function format to invoke functions.
|
||||
|
||||
The parser extracts the function call invocation and matches
|
||||
them to the pydantic schema provided.
|
||||
|
||||
An exception will be raised if the function call does not match
|
||||
the provided schema.
|
||||
|
||||
Example:
|
||||
|
||||
... code-block:: python
|
||||
|
||||
message = AIMessage(
|
||||
content="This is a test message",
|
||||
additional_kwargs={
|
||||
"function_call": {
|
||||
"name": "cookie",
|
||||
"arguments": json.dumps({"name": "value", "age": 10}),
|
||||
}
|
||||
},
|
||||
)
|
||||
chat_generation = ChatGeneration(message=message)
|
||||
|
||||
class Cookie(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
|
||||
class Dog(BaseModel):
|
||||
species: str
|
||||
|
||||
# Full output
|
||||
parser = PydanticOutputFunctionsParser(
|
||||
pydantic_schema={"cookie": Cookie, "dog": Dog}
|
||||
)
|
||||
result = parser.parse_result([chat_generation])
|
||||
"""
|
||||
|
||||
pydantic_schema: Union[Type[BaseModel], Dict[str, Type[BaseModel]]]
|
||||
|
||||
def parse_result(
|
||||
self, result: List[Generation], *, partial: bool = False
|
||||
) -> BaseModel:
|
||||
if not isinstance(result[0], ChatGeneration):
|
||||
raise ValueError("This output parser only works on ChatGeneration output")
|
||||
message = result[0].message
|
||||
function_call = message.additional_kwargs.get("function_call", {})
|
||||
if function_call:
|
||||
function_name = function_call["name"]
|
||||
tool_input = function_call.get("arguments", {})
|
||||
if isinstance(self.pydantic_schema, dict):
|
||||
schema = self.pydantic_schema[function_name]
|
||||
else:
|
||||
schema = self.pydantic_schema
|
||||
return schema(**json.loads(tool_input))
|
||||
else:
|
||||
raise OutputParserException(f"Could not parse function call: {message}")
|
||||
|
||||
def parse(self, text: str) -> BaseModel:
|
||||
raise ValueError("Can only parse messages")
|
||||
|
@ -7,6 +7,8 @@ EXPECTED_ALL = [
|
||||
"VertexAIModelGarden",
|
||||
"HarmBlockThreshold",
|
||||
"HarmCategory",
|
||||
"PydanticFunctionsOutputParser",
|
||||
"create_structured_runnable",
|
||||
]
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user