From e8a84b05a4f311b01b471d427ed0899f79d40865 Mon Sep 17 00:00:00 2001 From: Jacob Mansdorfer <90076431+jmansdorfer@users.noreply.github.com> Date: Mon, 21 Apr 2025 17:01:57 -0400 Subject: [PATCH] Community: Adding tool calling and some new parameters to the langchain-predictionguard docs. (#30953) - [x] **PR message**: - **Description:** Updates the documentation for the langchain-predictionguard package, adding tool calling functionality and some new parameters. --- .../integrations/chat/predictionguard.ipynb | 163 ++++++++++++++++-- 1 file changed, 151 insertions(+), 12 deletions(-) diff --git a/docs/docs/integrations/chat/predictionguard.ipynb b/docs/docs/integrations/chat/predictionguard.ipynb index 101339cdb63..6d06a7dd5ad 100644 --- a/docs/docs/integrations/chat/predictionguard.ipynb +++ b/docs/docs/integrations/chat/predictionguard.ipynb @@ -57,8 +57,8 @@ { "metadata": { "ExecuteTime": { - "end_time": "2024-11-08T19:44:51.390231Z", - "start_time": "2024-11-08T19:44:51.387945Z" + "end_time": "2025-04-21T18:23:30.746350Z", + "start_time": "2025-04-21T18:23:30.744744Z" } }, "cell_type": "code", @@ -70,7 +70,7 @@ ], "id": "fa57fba89276da13", "outputs": [], - "execution_count": 1 + "execution_count": 2 }, { "metadata": {}, @@ -82,12 +82,25 @@ "id": "87dc1742af7b053" }, { - "metadata": {}, + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-21T18:23:33.359278Z", + "start_time": "2025-04-21T18:23:32.853207Z" + } + }, "cell_type": "code", "source": "%pip install -qU langchain-predictionguard", "id": "b816ae8553cba021", - "outputs": [], - "execution_count": null + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Note: you may need to restart the kernel to use updated packages.\n" + ] + } + ], + "execution_count": 3 }, { "cell_type": "markdown", @@ -103,13 +116,13 @@ "metadata": { "id": "2xe8JEUwA7_y", "ExecuteTime": { - "end_time": "2024-11-08T19:44:53.950653Z", - "start_time": "2024-11-08T19:44:53.488694Z" + "end_time": "2025-04-21T18:23:39.812675Z", + "start_time": "2025-04-21T18:23:39.666881Z" } }, "source": "from langchain_predictionguard import ChatPredictionGuard", "outputs": [], - "execution_count": 2 + "execution_count": 4 }, { "cell_type": "code", @@ -117,8 +130,8 @@ "metadata": { "id": "Ua7Mw1N4HcER", "ExecuteTime": { - "end_time": "2024-11-08T19:44:54.890695Z", - "start_time": "2024-11-08T19:44:54.502846Z" + "end_time": "2025-04-21T18:23:41.590296Z", + "start_time": "2025-04-21T18:23:41.253237Z" } }, "source": [ @@ -126,7 +139,7 @@ "chat = ChatPredictionGuard(model=\"Hermes-3-Llama-3.1-8B\")" ], "outputs": [], - "execution_count": 3 + "execution_count": 5 }, { "metadata": {}, @@ -221,6 +234,132 @@ ], "execution_count": 6 }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "## Tool Calling\n", + "\n", + "Prediction Guard has a tool calling API that lets you describe tools and their arguments, which enables the model return a JSON object with a tool to call and the inputs to that tool. Tool-calling is very useful for building tool-using chains and agents, and for getting structured outputs from models more generally.\n" + ], + "id": "1227780d6e6728ba" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### ChatPredictionGuard.bind_tools()\n", + "\n", + "Using `ChatPredictionGuard.bind_tools()`, you can pass in Pydantic classes, dict schemas, and Langchain tools as tools to the model, which are then reformatted to allow for use by the model." + ], + "id": "23446aa52e01d1ba" + }, + { + "metadata": {}, + "cell_type": "code", + "outputs": [], + "execution_count": null, + "source": [ + "from pydantic import BaseModel, Field\n", + "\n", + "\n", + "class GetWeather(BaseModel):\n", + " \"\"\"Get the current weather in a given location\"\"\"\n", + "\n", + " location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n", + "\n", + "\n", + "class GetPopulation(BaseModel):\n", + " \"\"\"Get the current population in a given location\"\"\"\n", + "\n", + " location: str = Field(..., description=\"The city and state, e.g. San Francisco, CA\")\n", + "\n", + "\n", + "llm_with_tools = chat.bind_tools(\n", + " [GetWeather, GetPopulation]\n", + " # strict = True # enforce tool args schema is respected\n", + ")" + ], + "id": "135efb0bfc5916c1" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-21T18:42:41.834079Z", + "start_time": "2025-04-21T18:42:40.289095Z" + } + }, + "cell_type": "code", + "source": [ + "ai_msg = llm_with_tools.invoke(\n", + " \"Which city is hotter today and which is bigger: LA or NY?\"\n", + ")\n", + "ai_msg" + ], + "id": "8136f19a8836cd58", + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'chatcmpl-tool-b1204a3c70b44cd8802579df48df0c8c', 'type': 'function', 'index': 0, 'function': {'name': 'GetWeather', 'arguments': '{\"location\": \"Los Angeles, CA\"}'}}, {'id': 'chatcmpl-tool-e299116c05bf4ce498cd6042928ae080', 'type': 'function', 'index': 0, 'function': {'name': 'GetWeather', 'arguments': '{\"location\": \"New York, NY\"}'}}, {'id': 'chatcmpl-tool-19502a60f30348669ffbac00ff503388', 'type': 'function', 'index': 0, 'function': {'name': 'GetPopulation', 'arguments': '{\"location\": \"Los Angeles, CA\"}'}}, {'id': 'chatcmpl-tool-4b8d56ef067f447795d9146a56e43510', 'type': 'function', 'index': 0, 'function': {'name': 'GetPopulation', 'arguments': '{\"location\": \"New York, NY\"}'}}]}, response_metadata={}, id='run-4630cfa9-4e95-42dd-8e4a-45db78180a10-0', tool_calls=[{'name': 'GetWeather', 'args': {'location': 'Los Angeles, CA'}, 'id': 'chatcmpl-tool-b1204a3c70b44cd8802579df48df0c8c', 'type': 'tool_call'}, {'name': 'GetWeather', 'args': {'location': 'New York, NY'}, 'id': 'chatcmpl-tool-e299116c05bf4ce498cd6042928ae080', 'type': 'tool_call'}, {'name': 'GetPopulation', 'args': {'location': 'Los Angeles, CA'}, 'id': 'chatcmpl-tool-19502a60f30348669ffbac00ff503388', 'type': 'tool_call'}, {'name': 'GetPopulation', 'args': {'location': 'New York, NY'}, 'id': 'chatcmpl-tool-4b8d56ef067f447795d9146a56e43510', 'type': 'tool_call'}])" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 7 + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": [ + "### AIMessage.tool_calls\n", + "\n", + "Notice that the AIMessage has a tool_calls attribute. This contains in a standardized ToolCall format that is model-provider agnostic." + ], + "id": "84f405c45a35abe5" + }, + { + "metadata": { + "ExecuteTime": { + "end_time": "2025-04-21T18:43:00.429453Z", + "start_time": "2025-04-21T18:43:00.426399Z" + } + }, + "cell_type": "code", + "source": "ai_msg.tool_calls", + "id": "bdcee85475019719", + "outputs": [ + { + "data": { + "text/plain": [ + "[{'name': 'GetWeather',\n", + " 'args': {'location': 'Los Angeles, CA'},\n", + " 'id': 'chatcmpl-tool-b1204a3c70b44cd8802579df48df0c8c',\n", + " 'type': 'tool_call'},\n", + " {'name': 'GetWeather',\n", + " 'args': {'location': 'New York, NY'},\n", + " 'id': 'chatcmpl-tool-e299116c05bf4ce498cd6042928ae080',\n", + " 'type': 'tool_call'},\n", + " {'name': 'GetPopulation',\n", + " 'args': {'location': 'Los Angeles, CA'},\n", + " 'id': 'chatcmpl-tool-19502a60f30348669ffbac00ff503388',\n", + " 'type': 'tool_call'},\n", + " {'name': 'GetPopulation',\n", + " 'args': {'location': 'New York, NY'},\n", + " 'id': 'chatcmpl-tool-4b8d56ef067f447795d9146a56e43510',\n", + " 'type': 'tool_call'}]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "execution_count": 8 + }, { "cell_type": "markdown", "id": "ff1b51a8",