docs: updated CHATNVIDIA notebooks (#24584)

Updated notebook for tool calling support in chat models
This commit is contained in:
Daniel Glogowski 2024-07-25 06:22:53 -07:00 committed by GitHub
parent d6631919f4
commit 221486687a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -302,9 +302,6 @@
"\n",
"NVIDIA also supports multimodal inputs, meaning you can provide both images and text for the model to reason over. An example model supporting multimodal inputs is `nvidia/neva-22b`.\n",
"\n",
"\n",
"These models accept LangChain's standard image formats, and accept `labels`, similar to the Steering LLMs above. In addition to `creativity`, `complexity`, and `verbosity`, these models support a `quality` toggle.\n",
"\n",
"Below is an example use:"
]
},
@ -447,92 +444,6 @@
"llm.invoke(f'What\\'s in this image?\\n<img src=\"{base64_with_mime_type}\" />')"
]
},
{
"cell_type": "markdown",
"id": "3e61d868",
"metadata": {},
"source": [
"#### **Advanced Use Case:** Forcing Payload \n",
"\n",
"You may notice that some newer models may have strong parameter expectations that the LangChain connector may not support by default. For example, we cannot invoke the [Kosmos](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/ai-foundation/models/kosmos-2) model at the time of this notebook's latest release due to the lack of a streaming argument on the server side: "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "d143e0d6",
"metadata": {},
"outputs": [],
"source": [
"from langchain_nvidia_ai_endpoints import ChatNVIDIA\n",
"\n",
"kosmos = ChatNVIDIA(model=\"microsoft/kosmos-2\")\n",
"\n",
"from langchain_core.messages import HumanMessage\n",
"\n",
"# kosmos.invoke(\n",
"# [\n",
"# HumanMessage(\n",
"# content=[\n",
"# {\"type\": \"text\", \"text\": \"Describe this image:\"},\n",
"# {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n",
"# ]\n",
"# )\n",
"# ]\n",
"# )\n",
"\n",
"# Exception: [422] Unprocessable Entity\n",
"# body -> stream\n",
"# Extra inputs are not permitted (type=extra_forbidden)\n",
"# RequestID: 35538c9a-4b45-4616-8b75-7ef816fccf38"
]
},
{
"cell_type": "markdown",
"id": "1e230b70",
"metadata": {},
"source": [
"For a simple use case like this, we can actually try to force the payload argument of our underlying client by specifying the `payload_fn` function as follows: "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "0925b2b1",
"metadata": {},
"outputs": [],
"source": [
"def drop_streaming_key(d):\n",
" \"\"\"Takes in payload dictionary, outputs new payload dictionary\"\"\"\n",
" if \"stream\" in d:\n",
" d.pop(\"stream\")\n",
" return d\n",
"\n",
"\n",
"## Override the payload passthrough. Default is to pass through the payload as is.\n",
"kosmos = ChatNVIDIA(model=\"microsoft/kosmos-2\")\n",
"kosmos.client.payload_fn = drop_streaming_key\n",
"\n",
"kosmos.invoke(\n",
" [\n",
" HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\": \"Describe this image:\"},\n",
" {\"type\": \"image_url\", \"image_url\": {\"url\": image_url}},\n",
" ]\n",
" )\n",
" ]\n",
")"
]
},
{
"cell_type": "markdown",
"id": "fe6e1758",
"metadata": {},
"source": [
"For more advanced or custom use-cases (i.e. supporting the diffusion models), you may be interested in leveraging the `NVEModel` client as a requests backbone. The `NVIDIAEmbeddings` class is a good source of inspiration for this. "
]
},
{
"cell_type": "markdown",
"id": "137662a6",
@ -540,7 +451,7 @@
"id": "137662a6"
},
"source": [
"## Example usage within RunnableWithMessageHistory "
"## Example usage within a RunnableWithMessageHistory"
]
},
{
@ -630,14 +541,14 @@
{
"cell_type": "code",
"execution_count": null,
"id": "uHIMZxVSVNBC",
"id": "LyD1xVKmVSs4",
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 284
"height": 350
},
"id": "uHIMZxVSVNBC",
"outputId": "79acc89d-a820-4f2c-bac2-afe99da95580"
"id": "LyD1xVKmVSs4",
"outputId": "a1714513-a8fd-4d14-f974-233e39d5c4f5"
},
"outputs": [],
"source": [
@ -646,6 +557,79 @@
" config=config,\n",
")"
]
},
{
"cell_type": "markdown",
"id": "f3cbbba0",
"metadata": {},
"source": [
"## Tool calling\n",
"\n",
"Starting in v0.2, `ChatNVIDIA` supports [bind_tools](https://api.python.langchain.com/en/latest/language_models/langchain_core.language_models.chat_models.BaseChatModel.html#langchain_core.language_models.chat_models.BaseChatModel.bind_tools).\n",
"\n",
"`ChatNVIDIA` provides integration with the variety of models on [build.nvidia.com](https://build.nvidia.com) as well as local NIMs. Not all these models are trained for tool calling. Be sure to select a model that does have tool calling for your experimention and applications."
]
},
{
"cell_type": "markdown",
"id": "6f7b535e",
"metadata": {},
"source": [
"You can get a list of models that are known to support tool calling with,"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e36c8911",
"metadata": {},
"outputs": [],
"source": [
"tool_models = [\n",
" model for model in ChatNVIDIA.get_available_models() if model.supports_tools\n",
"]\n",
"tool_models"
]
},
{
"cell_type": "markdown",
"id": "b01d75a7",
"metadata": {},
"source": [
"With a tool capable model,"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "bd54f174",
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.pydantic_v1 import Field\n",
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool\n",
"def get_current_weather(\n",
" location: str = Field(..., description=\"The location to get the weather for.\"),\n",
"):\n",
" \"\"\"Get the current weather for a location.\"\"\"\n",
" ...\n",
"\n",
"\n",
"llm = ChatNVIDIA(model=tool_models[0].id).bind_tools(tools=[get_current_weather])\n",
"response = llm.invoke(\"What is the weather in Boston?\")\n",
"response.tool_calls"
]
},
{
"cell_type": "markdown",
"id": "e08df68c",
"metadata": {},
"source": [
"See [How to use chat models to call tools](https://python.langchain.com/v0.2/docs/how_to/tool_calling/) for additional examples."
]
}
],
"metadata": {
@ -667,7 +651,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.2"
"version": "3.10.13"
}
},
"nbformat": 4,