diff --git a/docs/docs/integrations/chat/ai21.ipynb b/docs/docs/integrations/chat/ai21.ipynb index efdb43cab42..38429aa9153 100644 --- a/docs/docs/integrations/chat/ai21.ipynb +++ b/docs/docs/integrations/chat/ai21.ipynb @@ -50,18 +50,18 @@ }, { "cell_type": "code", - "execution_count": null, "id": "62e0dbc3", "metadata": { "tags": [] }, - "outputs": [], "source": [ "import os\n", "from getpass import getpass\n", "\n", "os.environ[\"AI21_API_KEY\"] = getpass()" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -73,14 +73,14 @@ }, { "cell_type": "code", - "execution_count": null, "id": "7c2e19d3-7c58-4470-9e1a-718b27a32056", "metadata": {}, - "outputs": [], "source": [ "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n", "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -115,15 +115,15 @@ }, { "cell_type": "code", - "execution_count": 2, "id": "c40756fb-cbf8-4d44-a293-3989d707237e", "metadata": {}, - "outputs": [], "source": [ "from langchain_ai21 import ChatAI21\n", "\n", "llm = ChatAI21(model=\"jamba-instruct\", temperature=0)" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -135,21 +135,8 @@ }, { "cell_type": "code", - "execution_count": 3, "id": "46b982dc-5d8a-46da-a711-81c03ccd6adc", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "AIMessage(content=\"J'adore programmer.\", id='run-2e8d16d6-a06e-45cb-8d0c-1c8208645033-0')" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ "messages = [\n", " (\n", @@ -160,7 +147,9 @@ "]\n", "ai_msg = llm.invoke(messages)\n", "ai_msg" - ] + ], + "outputs": [], + "execution_count": null }, { "cell_type": "markdown", @@ -174,7 +163,6 @@ }, { "cell_type": "code", - "execution_count": 4, "id": "39353473fce5dd2e", "metadata": { "collapsed": false, @@ -182,18 
+170,6 @@ "outputs_hidden": false } }, - "outputs": [ - { - "data": { - "text/plain": [ - "AIMessage(content='Ich liebe das Programmieren.', id='run-e1bd82dc-1a7e-4b2e-bde9-ac995929ac0f-0')" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ "from langchain_core.prompts import ChatPromptTemplate\n", "\n", @@ -215,7 +191,95 @@ " \"input\": \"I love programming.\",\n", " }\n", ")" - ] + ], + "outputs": [], + "execution_count": null + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "# Tool Calls / Function Calling", + "id": "39c0ccd229927eab" + }, + { + "metadata": {}, + "cell_type": "markdown", + "source": "This example shows how to use tool calling with AI21 models:", + "id": "2bf6b40be07fe2d4" + }, + { + "metadata": {}, + "cell_type": "code", + "source": [ + "import os\n", + "from getpass import getpass\n", + "\n", + "from langchain_ai21.chat_models import ChatAI21\n", + "from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage\n", + "from langchain_core.tools import tool\n", + "from langchain_core.utils.function_calling import convert_to_openai_tool\n", + "\n", + "os.environ[\"AI21_API_KEY\"] = getpass()\n", + "\n", + "\n", + "@tool\n", + "def get_weather(location: str, date: str) -> str:\n", + " \"\"\"Provide the weather for the specified location on the given date.\"\"\"\n", + " if location == \"New York\" and date == \"2024-12-05\":\n", + " return \"25 celsius\"\n", + " elif location == \"New York\" and date == \"2024-12-06\":\n", + " return \"27 celsius\"\n", + " elif location == \"London\" and date == \"2024-12-05\":\n", + " return \"22 celsius\"\n", + " return \"32 celsius\"\n", + "\n", + "\n", + "llm = ChatAI21(model=\"jamba-1.5-mini\")\n", + "\n", + "llm_with_tools = llm.bind_tools([convert_to_openai_tool(get_weather)])\n", + "\n", + "chat_messages = [\n", + " SystemMessage(\n", + " content=\"You are a helpful assistant. 
You can use the provided tools \"\n", + " \"to assist with various tasks and provide accurate information\"\n", + " )\n", + "]\n", + "\n", + "human_messages = [\n", + " HumanMessage(\n", + " content=\"What is the forecast for the weather in New York on December 5, 2024?\"\n", + " ),\n", + " HumanMessage(content=\"And what about the 2024-12-06?\"),\n", + " HumanMessage(content=\"OK, thank you.\"),\n", + " HumanMessage(content=\"What is the expected weather in London on December 5, 2024?\"),\n", + "]\n", + "\n", + "\n", + "for human_message in human_messages:\n", + " print(f\"User: {human_message.content}\")\n", + " chat_messages.append(human_message)\n", + " response = llm_with_tools.invoke(chat_messages)\n", + " chat_messages.append(response)\n", + " if response.tool_calls:\n", + " tool_call = response.tool_calls[0]\n", + " if tool_call[\"name\"] == \"get_weather\":\n", + " weather = get_weather.invoke(\n", + " {\n", + " \"location\": tool_call[\"args\"][\"location\"],\n", + " \"date\": tool_call[\"args\"][\"date\"],\n", + " }\n", + " )\n", + " chat_messages.append(\n", + " ToolMessage(content=weather, tool_call_id=tool_call[\"id\"])\n", + " )\n", + " llm_answer = llm_with_tools.invoke(chat_messages)\n", + " print(f\"Assistant: {llm_answer.content}\")\n", + " else:\n", + " print(f\"Assistant: {response.content}\")" + ], + "id": "a181a28df77120fb", + "outputs": [], + "execution_count": null }, { "cell_type": "markdown",