community: add reka chat model integration (#27379)

This commit is contained in:
alex shengzhi li
2024-11-15 10:37:14 -08:00
committed by GitHub
parent d3252b7417
commit 39fcb476fd
7 changed files with 1629 additions and 0 deletions


@@ -0,0 +1,593 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"---\n",
"sidebar_label: Reka\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ChatReka\n",
"\n",
"This notebook provides a quick overview for getting started with Reka [chat models](../../concepts/chat_models.mdx). \n",
"\n",
"Reka has several chat models. You can find information about their latest models and their costs, context windows, and supported input types in the [Reka docs](https://docs.reka.ai/available-models).\n",
"\n",
"\n",
"\n",
"\n",
"## Overview\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatReka] | [langchain_community](https://python.langchain.com/api_reference/community/index.html) | ✅ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ✅ | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | \n",
"\n",
"## Setup\n",
"\n",
"To access Reka models you'll need to create an Reka developer account, get an API key, and install the `langchain_community` integration package and the reka python package via 'pip install reka-api'.\n",
"\n",
"### Credentials\n",
"\n",
"Head to https://platform.reka.ai/ to sign up for Reka and generate an API key. Once you've done this set the REKA_API_KEY environment variable:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"REKA_API_KEY\"] = getpass.getpass(\"Enter your Reka API key: \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Optional: use LangSmith to trace the execution of the model."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"The LangChain __ModuleName__ integration lives in the `langchain_community` package:"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"%pip install -qU langchain_community reka-api"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instantiation"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models import ChatReka\n",
"\n",
"model = ChatReka()"
]
},
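{
"cell_type": "markdown",
"metadata": {},
"source": [
"You can also configure the model name and sampling parameters at construction time. The fields below are defined on the `ChatReka` class in this integration; the values shown are just illustrative choices:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch using fields defined on ChatReka; values are examples.\n",
"model = ChatReka(\n",
"    model=\"reka-flash\",  # the default model\n",
"    temperature=0.7,\n",
"    max_tokens=256,\n",
")"
]
},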
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' Hello! How can I help you today? If you have a question, need assistance, or just want to chat, feel free to let me know. Have a great day!\\n\\n', additional_kwargs={}, response_metadata={}, id='run-61522ec2-0587-4fd5-a492-5b205fd8860c-0')"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model.invoke(\"hi\")"
]
},
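{
"cell_type": "markdown",
"metadata": {},
"source": [
"`ChatReka` also supports token-level streaming through the standard `.stream` interface. A minimal sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Stream the response token by token (sketch; output will vary)\n",
"for chunk in model.stream(\"Tell me a one-sentence story.\"):\n",
"    print(chunk.content, end=\"\")"
]
},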
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Images input "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" The image shows an indoor setting with no visible windows or natural light, and there are no indicators of weather conditions. The focus is on a cat sitting on a computer keyboard, and the background includes a computer monitor and various office supplies.\n"
]
}
],
"source": [
"from langchain_core.messages import HumanMessage\n",
"\n",
"image_url = \"https://v0.docs.reka.ai/_images/000000245576.jpg\"\n",
"\n",
"message = HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\": \"describe the weather in this image\"},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\"url\": image_url},\n",
" },\n",
" ],\n",
")\n",
"response = model.invoke([message])\n",
"print(response.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Multiple images as input"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" The first image features two German Shepherds, one adult and one puppy, in a vibrant, lush green setting. The adult dog is carrying a large stick in its mouth, running through what appears to be a grassy field, with the puppy following close behind. Both dogs exhibit striking physical characteristics typical of the breed, such as pointed ears and dense fur.\n",
"\n",
"The second image shows a close-up of a single cat with striking blue eyes, likely a breed like the Siberian or Maine Coon, in a natural outdoor setting. The cat's fur is lighter, possibly a mix of white and gray, and it has a more subdued expression compared to the dogs. The background is blurred, suggesting a focus on the cat's face.\n",
"\n",
"Overall, the differences lie in the subjects (two dogs vs. one cat), the setting (lush, vibrant grassy field vs. a more muted outdoor background), and the overall mood and activity depicted (playful and active vs. serene and focused).\n"
]
}
],
"source": [
"message = HumanMessage(\n",
" content=[\n",
" {\"type\": \"text\", \"text\": \"What are the difference between the two images? \"},\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": \"https://cdn.pixabay.com/photo/2019/07/23/13/51/shepherd-dog-4357790_1280.jpg\"\n",
" },\n",
" },\n",
" {\n",
" \"type\": \"image_url\",\n",
" \"image_url\": {\n",
" \"url\": \"https://cdn.pixabay.com/photo/2024/02/17/00/18/cat-8578562_1280.jpg\"\n",
" },\n",
" },\n",
" ],\n",
")\n",
"response = model.invoke([message])\n",
"print(response.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chaining"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=' Ich liebe Programmieren.\\n\\n', additional_kwargs={}, response_metadata={}, id='run-ffc4ace1-b73a-4fb3-ad0f-57e60a0f9b8d-0')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | model\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Use use with tavtly api search"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Tool use and agent creation\n",
"\n",
"## Define the tools\n",
"\n",
"We first need to create the tools we want to use. Our main tool of choice will be Tavily - a search engine. We have a built-in tool in LangChain to easily use Tavily search engine as tool.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"os.environ[\"TAVILY_API_KEY\"] = getpass.getpass(\"Enter your Tavily API key: \")"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[{'url': 'https://www.weatherapi.com/', 'content': \"{'location': {'name': 'San Francisco', 'region': 'California', 'country': 'United States of America', 'lat': 37.775, 'lon': -122.4183, 'tz_id': 'America/Los_Angeles', 'localtime_epoch': 1730484342, 'localtime': '2024-11-01 11:05'}, 'current': {'last_updated_epoch': 1730484000, 'last_updated': '2024-11-01 11:00', 'temp_c': 11.1, 'temp_f': 52.0, 'is_day': 1, 'condition': {'text': 'Mist', 'icon': '//cdn.weatherapi.com/weather/64x64/day/143.png', 'code': 1030}, 'wind_mph': 2.9, 'wind_kph': 4.7, 'wind_degree': 247, 'wind_dir': 'WSW', 'pressure_mb': 1019.0, 'pressure_in': 30.08, 'precip_mm': 0.0, 'precip_in': 0.0, 'humidity': 100, 'cloud': 100, 'feelslike_c': 11.1, 'feelslike_f': 52.0, 'windchill_c': 10.3, 'windchill_f': 50.5, 'heatindex_c': 10.8, 'heatindex_f': 51.5, 'dewpoint_c': 10.4, 'dewpoint_f': 50.6, 'vis_km': 2.8, 'vis_miles': 1.0, 'uv': 3.0, 'gust_mph': 3.8, 'gust_kph': 6.1}}\"}, {'url': 'https://weatherspark.com/h/m/557/2024/1/Historical-Weather-in-January-2024-in-San-Francisco-California-United-States', 'content': 'San Francisco Temperature History January 2024\\nHourly Temperature in January 2024 in San Francisco\\nCompare San Francisco to another city:\\nCloud Cover in January 2024 in San Francisco\\nDaily Precipitation in January 2024 in San Francisco\\nObserved Weather in January 2024 in San Francisco\\nHours of Daylight and Twilight in January 2024 in San Francisco\\nSunrise & Sunset with Twilight in January 2024 in San Francisco\\nSolar Elevation and Azimuth in January 2024 in San Francisco\\nMoon Rise, Set & Phases in January 2024 in San Francisco\\nHumidity Comfort Levels in January 2024 in San Francisco\\nWind Speed in January 2024 in San Francisco\\nHourly Wind Speed in January 2024 in San Francisco\\nHourly Wind Direction in 2024 in San Francisco\\nAtmospheric Pressure in January 2024 in San Francisco\\nData Sources\\n See all nearby weather stations\\nLatest Report — 1:56 PM\\nFri, Jan 12, 2024\\xa0\\xa0\\xa0\\xa04 min ago\\xa0\\xa0\\xa0\\xa0UTC 21:56\\nCall Sign KSFO\\nTemp.\\n54.0°F\\nPrecipitation\\nNo Report\\nWind\\n8.1 mph\\nCloud Cover\\nMostly Cloudy\\n14,000 ft\\nRaw: KSFO 122156Z 08007KT 10SM FEW030 SCT050 BKN140 12/07 A3022 While having the tremendous advantages of temporal and spatial completeness, these reconstructions: (1) are based on computer models that may have model-based errors, (2) are coarsely sampled on a 50 km grid and are therefore unable to reconstruct the local variations of many microclimates, and (3) have particular difficulty with the weather in some coastal areas, especially small islands.\\n We further caution that our travel scores are only as good as the data that underpin them, that weather conditions at any given location and time are unpredictable and variable, and that the definition of the scores reflects a particular set of preferences that may not agree with those of any particular reader.\\n January 2024 Weather History in San Francisco California, United States\\nThe data for this report comes from the San Francisco International Airport.'}]\n"
]
}
],
"source": [
"from langchain_community.tools.tavily_search import TavilySearchResults\n",
"\n",
"search = TavilySearchResults(max_results=2)\n",
"search_results = search.invoke(\"what is the weather in SF\")\n",
"print(search_results)\n",
"# If we want, we can create other tools.\n",
"# Once we have all the tools we want, we can put them in a list that we will reference later.\n",
"tools = [search]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now see what it is like to enable this model to do tool calling. In order to enable that we use .bind_tools to give the language model knowledge of these tools\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"model_with_tools = model.bind_tools(tools)"
]
},
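{
"cell_type": "markdown",
"metadata": {},
"source": [
"`bind_tools` also accepts a `tool_choice` parameter; the options below come from this integration's `bind_tools` signature: `\"auto\"` (the default) lets the model decide whether to invoke a tool, `\"none\"` disables tool calling, and `\"tool\"` forces the model to invoke one of the bound tools. A minimal sketch:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Force a tool call instead of letting the model decide (sketch)\n",
"model_forced = model.bind_tools(tools, tool_choice=\"tool\")"
]
},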
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now call the model. Let's first call it with a normal message, and see how it responds. We can look at both the content field as well as the tool_calls field.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ContentString: Hello! How can I help you today? If you have a question or need information on a specific topic, feel free to ask. Just type your search query and I'll do my best to assist using the available function.\n",
"\n",
"\n",
"ToolCalls: []\n"
]
}
],
"source": [
"from langchain_core.messages import HumanMessage\n",
"\n",
"response = model_with_tools.invoke([HumanMessage(content=\"Hi!\")])\n",
"\n",
"print(f\"ContentString: {response.content}\")\n",
"print(f\"ToolCalls: {response.tool_calls}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now, let's try calling it with some input that would expect a tool to be called.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ContentString: \n",
"ToolCalls: [{'name': 'tavily_search_results_json', 'args': {'query': 'weather in SF'}, 'id': '2548c622-3553-42df-8220-39fde0632bdb', 'type': 'tool_call'}]\n"
]
}
],
"source": [
"response = model_with_tools.invoke([HumanMessage(content=\"What's the weather in SF?\")])\n",
"\n",
"print(f\"ContentString: {response.content}\")\n",
"print(f\"ToolCalls: {response.tool_calls}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can see that there's now no text content, but there is a tool call! It wants us to call the Tavily Search tool.\n",
"\n",
"This isn't calling that tool yet - it's just telling us to. In order to actually call it, we'll want to create our agent."
]
},
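{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before building a full agent, here is a minimal sketch of doing that loop by hand: run each requested tool call with the Tavily tool, pass the result back as a `ToolMessage`, and ask the model for a final answer. The variable names are ours; the message flow follows the tool-calling pattern shown above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_core.messages import ToolMessage\n",
"\n",
"messages = [HumanMessage(content=\"What's the weather in SF?\")]\n",
"ai_msg = model_with_tools.invoke(messages)\n",
"messages.append(ai_msg)\n",
"\n",
"# Execute each requested tool call and feed the output back to the model\n",
"for tool_call in ai_msg.tool_calls:\n",
"    tool_output = search.invoke(tool_call[\"args\"])\n",
"    messages.append(\n",
"        ToolMessage(content=str(tool_output), tool_call_id=tool_call[\"id\"])\n",
"    )\n",
"\n",
"final_response = model_with_tools.invoke(messages)\n",
"print(final_response.content)"
]
},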
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Create the agent"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Now that we have defined the tools and the LLM, we can create the agent. We will be using LangGraph to construct the agent. Currently, we are using a high level interface to construct the agent, but the nice thing about LangGraph is that this high-level interface is backed by a low-level, highly controllable API in case you want to modify the agent logic.\n",
"\n",
"Now, we can initialize the agent with the LLM and the tools.\n",
"\n",
"Note that we are passing in the model, not model_with_tools. That is because `create_react_agent` will call `.bind_tools` for us under the hood."
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from langgraph.prebuilt import create_react_agent\n",
"\n",
"agent_executor = create_react_agent(model, tools)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's now try it out on an example where it should be invoking the tool"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[HumanMessage(content='hi!', additional_kwargs={}, response_metadata={}, id='0ab1f3c7-9079-42d4-8a8a-13af5f6c226b'),\n",
" AIMessage(content=' Hello! How can I help you today? If you have a question or need information on a specific topic, feel free to ask. For example, you can start with a search query like \"latest news on climate change\" or \"biography of Albert Einstein\".\\n\\n', additional_kwargs={}, response_metadata={}, id='run-276d9dcd-13f3-481d-b562-8fe3962d9ba1-0')]"
]
},
"execution_count": 15,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response = agent_executor.invoke({\"messages\": [HumanMessage(content=\"hi!\")]})\n",
"\n",
"response[\"messages\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In order to see exactly what is happening under the hood (and to make sure it's not calling a tool) we can take a look at the LangSmith trace: https://smith.langchain.com/public/2372d9c5-855a-45ee-80f2-94b63493563d/r"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[HumanMessage(content='whats the weather in sf?', additional_kwargs={}, response_metadata={}, id='af276c61-3df7-4241-8cb0-81d1f1477bb3'),\n",
" AIMessage(content='', additional_kwargs={'tool_calls': [{'id': '86da84b8-0d44-444f-8448-7f134f9afa41', 'type': 'function', 'function': {'name': 'tavily_search_results_json', 'arguments': '{\"query\": \"weather in SF\"}'}}]}, response_metadata={}, id='run-abe1b8e2-98a6-4f69-8f95-278ac8c141ff-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'weather in SF'}, 'id': '86da84b8-0d44-444f-8448-7f134f9afa41', 'type': 'tool_call'}]),\n",
" ToolMessage(content='[{\"url\": \"https://www.weatherapi.com/\", \"content\": \"{\\'location\\': {\\'name\\': \\'San Francisco\\', \\'region\\': \\'California\\', \\'country\\': \\'United States of America\\', \\'lat\\': 37.775, \\'lon\\': -122.4183, \\'tz_id\\': \\'America/Los_Angeles\\', \\'localtime_epoch\\': 1730483436, \\'localtime\\': \\'2024-11-01 10:50\\'}, \\'current\\': {\\'last_updated_epoch\\': 1730483100, \\'last_updated\\': \\'2024-11-01 10:45\\', \\'temp_c\\': 11.4, \\'temp_f\\': 52.5, \\'is_day\\': 1, \\'condition\\': {\\'text\\': \\'Mist\\', \\'icon\\': \\'//cdn.weatherapi.com/weather/64x64/day/143.png\\', \\'code\\': 1030}, \\'wind_mph\\': 2.2, \\'wind_kph\\': 3.6, \\'wind_degree\\': 237, \\'wind_dir\\': \\'WSW\\', \\'pressure_mb\\': 1019.0, \\'pressure_in\\': 30.08, \\'precip_mm\\': 0.0, \\'precip_in\\': 0.0, \\'humidity\\': 100, \\'cloud\\': 100, \\'feelslike_c\\': 11.8, \\'feelslike_f\\': 53.2, \\'windchill_c\\': 11.2, \\'windchill_f\\': 52.1, \\'heatindex_c\\': 11.7, \\'heatindex_f\\': 53.0, \\'dewpoint_c\\': 10.1, \\'dewpoint_f\\': 50.1, \\'vis_km\\': 2.8, \\'vis_miles\\': 1.0, \\'uv\\': 3.0, \\'gust_mph\\': 3.0, \\'gust_kph\\': 4.9}}\"}, {\"url\": \"https://www.timeanddate.com/weather/@z-us-94134/ext\", \"content\": \"Forecasted weather conditions the coming 2 weeks for San Francisco. Sign in. News. News Home; Astronomy News; Time Zone News ... 01 pm: Mon Nov 11: 60 / 53 °F: Tstorms early. Broken clouds. 54 °F: 19 mph: ↑: 70%: 58%: 0.20\\\\\" 0 (Low) 6:46 am: 5:00 pm * Updated Monday, October 28, 2024 2:24:10 pm San Francisco time - Weather by CustomWeather\"}]', name='tavily_search_results_json', id='de8c8d78-ae24-4a8a-9c73-795c1e4fdd41', tool_call_id='86da84b8-0d44-444f-8448-7f134f9afa41', artifact={'query': 'weather in SF', 'follow_up_questions': None, 'answer': None, 'images': [], 'results': [{'title': 'Weather in San Francisco', 'url': 'https://www.weatherapi.com/', 'content': \"{'location': {'name': 'San Francisco', 'region': 'California', 'country': 'United States of America', 'lat': 37.775, 'lon': -122.4183, 'tz_id': 'America/Los_Angeles', 'localtime_epoch': 1730483436, 'localtime': '2024-11-01 10:50'}, 'current': {'last_updated_epoch': 1730483100, 'last_updated': '2024-11-01 10:45', 'temp_c': 11.4, 'temp_f': 52.5, 'is_day': 1, 'condition': {'text': 'Mist', 'icon': '//cdn.weatherapi.com/weather/64x64/day/143.png', 'code': 1030}, 'wind_mph': 2.2, 'wind_kph': 3.6, 'wind_degree': 237, 'wind_dir': 'WSW', 'pressure_mb': 1019.0, 'pressure_in': 30.08, 'precip_mm': 0.0, 'precip_in': 0.0, 'humidity': 100, 'cloud': 100, 'feelslike_c': 11.8, 'feelslike_f': 53.2, 'windchill_c': 11.2, 'windchill_f': 52.1, 'heatindex_c': 11.7, 'heatindex_f': 53.0, 'dewpoint_c': 10.1, 'dewpoint_f': 50.1, 'vis_km': 2.8, 'vis_miles': 1.0, 'uv': 3.0, 'gust_mph': 3.0, 'gust_kph': 4.9}}\", 'score': 0.9989501, 'raw_content': None}, {'title': 'San Francisco, USA 14 day weather forecast - timeanddate.com', 'url': 'https://www.timeanddate.com/weather/@z-us-94134/ext', 'content': 'Forecasted weather conditions the coming 2 weeks for San Francisco. Sign in. News. News Home; Astronomy News; Time Zone News ... 01 pm: Mon Nov 11: 60 / 53 °F: Tstorms early. Broken clouds. 54 °F: 19 mph: ↑: 70%: 58%: 0.20\" 0 (Low) 6:46 am: 5:00 pm * Updated Monday, October 28, 2024 2:24:10 pm San Francisco time - Weather by CustomWeather', 'score': 0.9938309, 'raw_content': None}], 'response_time': 3.56}),\n",
" AIMessage(content=' The current weather in San Francisco is mist with a temperature of 11.4°C (52.5°F). There is a 100% humidity and the wind is blowing at 2.2 mph from the WSW direction. The forecast for the coming weeks shows a mix of cloudy and partly cloudy days with some chances of thunderstorms. Temperatures are expected to range between 53°F and 60°F.\\n\\n', additional_kwargs={}, response_metadata={}, id='run-de4207d6-e8e8-4382-ad16-4de0dcf0812a-0')]"
]
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"response = agent_executor.invoke(\n",
" {\"messages\": [HumanMessage(content=\"whats the weather in sf?\")]}\n",
")\n",
"response[\"messages\"]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can check out the LangSmith trace to make sure it's calling the search tool effectively.\n",
"\n",
"https://smith.langchain.com/public/013ef704-654b-4447-8428-637b343d646e/r"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We've seen how the agent can be called with `.invoke` to get a final response. If the agent executes multiple steps, this may take a while. To show intermediate progress, we can stream back messages as they occur.\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'agent': {'messages': [AIMessage(content='', additional_kwargs={'tool_calls': [{'id': '2457d3ea-f001-4b8c-a1ed-3dc3d1381639', 'type': 'function', 'function': {'name': 'tavily_search_results_json', 'arguments': '{\"query\": \"weather in San Francisco\"}'}}]}, response_metadata={}, id='run-0363deab-84d2-4319-bb1e-b55b47fe2274-0', tool_calls=[{'name': 'tavily_search_results_json', 'args': {'query': 'weather in San Francisco'}, 'id': '2457d3ea-f001-4b8c-a1ed-3dc3d1381639', 'type': 'tool_call'}])]}}\n",
"----\n",
"{'tools': {'messages': [ToolMessage(content='[{\"url\": \"https://www.weatherapi.com/\", \"content\": \"{\\'location\\': {\\'name\\': \\'San Francisco\\', \\'region\\': \\'California\\', \\'country\\': \\'United States of America\\', \\'lat\\': 37.775, \\'lon\\': -122.4183, \\'tz_id\\': \\'America/Los_Angeles\\', \\'localtime_epoch\\': 1730483636, \\'localtime\\': \\'2024-11-01 10:53\\'}, \\'current\\': {\\'last_updated_epoch\\': 1730483100, \\'last_updated\\': \\'2024-11-01 10:45\\', \\'temp_c\\': 11.4, \\'temp_f\\': 52.5, \\'is_day\\': 1, \\'condition\\': {\\'text\\': \\'Mist\\', \\'icon\\': \\'//cdn.weatherapi.com/weather/64x64/day/143.png\\', \\'code\\': 1030}, \\'wind_mph\\': 2.2, \\'wind_kph\\': 3.6, \\'wind_degree\\': 237, \\'wind_dir\\': \\'WSW\\', \\'pressure_mb\\': 1019.0, \\'pressure_in\\': 30.08, \\'precip_mm\\': 0.0, \\'precip_in\\': 0.0, \\'humidity\\': 100, \\'cloud\\': 100, \\'feelslike_c\\': 11.8, \\'feelslike_f\\': 53.2, \\'windchill_c\\': 11.2, \\'windchill_f\\': 52.1, \\'heatindex_c\\': 11.7, \\'heatindex_f\\': 53.0, \\'dewpoint_c\\': 10.1, \\'dewpoint_f\\': 50.1, \\'vis_km\\': 2.8, \\'vis_miles\\': 1.0, \\'uv\\': 3.0, \\'gust_mph\\': 3.0, \\'gust_kph\\': 4.9}}\"}, {\"url\": \"https://weather.com/weather/monthly/l/69bedc6a5b6e977993fb3e5344e3c06d8bc36a1fb6754c3ddfb5310a3c6d6c87\", \"content\": \"Weather.com brings you the most accurate monthly weather forecast for San Francisco, CA with average/record and high/low temperatures, precipitation and more. ... 11. 66 ° 55 ° 12. 69 ° 60\"}]', name='tavily_search_results_json', id='e675f99b-130f-4e98-8477-badd45938d9d', tool_call_id='2457d3ea-f001-4b8c-a1ed-3dc3d1381639', artifact={'query': 'weather in San Francisco', 'follow_up_questions': None, 'answer': None, 'images': [], 'results': [{'title': 'Weather in San Francisco', 'url': 'https://www.weatherapi.com/', 'content': \"{'location': {'name': 'San Francisco', 'region': 'California', 'country': 'United States of America', 'lat': 37.775, 'lon': -122.4183, 'tz_id': 'America/Los_Angeles', 'localtime_epoch': 1730483636, 'localtime': '2024-11-01 10:53'}, 'current': {'last_updated_epoch': 1730483100, 'last_updated': '2024-11-01 10:45', 'temp_c': 11.4, 'temp_f': 52.5, 'is_day': 1, 'condition': {'text': 'Mist', 'icon': '//cdn.weatherapi.com/weather/64x64/day/143.png', 'code': 1030}, 'wind_mph': 2.2, 'wind_kph': 3.6, 'wind_degree': 237, 'wind_dir': 'WSW', 'pressure_mb': 1019.0, 'pressure_in': 30.08, 'precip_mm': 0.0, 'precip_in': 0.0, 'humidity': 100, 'cloud': 100, 'feelslike_c': 11.8, 'feelslike_f': 53.2, 'windchill_c': 11.2, 'windchill_f': 52.1, 'heatindex_c': 11.7, 'heatindex_f': 53.0, 'dewpoint_c': 10.1, 'dewpoint_f': 50.1, 'vis_km': 2.8, 'vis_miles': 1.0, 'uv': 3.0, 'gust_mph': 3.0, 'gust_kph': 4.9}}\", 'score': 0.9968992, 'raw_content': None}, {'title': 'Monthly Weather Forecast for San Francisco, CA - weather.com', 'url': 'https://weather.com/weather/monthly/l/69bedc6a5b6e977993fb3e5344e3c06d8bc36a1fb6754c3ddfb5310a3c6d6c87', 'content': 'Weather.com brings you the most accurate monthly weather forecast for San Francisco, CA with average/record and high/low temperatures, precipitation and more. ... 11. 66 ° 55 ° 12. 69 ° 60', 'score': 0.97644573, 'raw_content': None}], 'response_time': 3.16})]}}\n",
"----\n",
"{'agent': {'messages': [AIMessage(content=' The current weather in San Francisco is misty with a temperature of 11.4°C (52.5°F). The wind is blowing at 2.2 mph (3.6 kph) from the WSW direction. The humidity is at 100%, and the visibility is 2.8 km (1.0 miles). The monthly forecast shows average temperatures ranging from 55°F to 66°F (13°C to 19°C) with some precipitation expected.\\n\\n', additional_kwargs={}, response_metadata={}, id='run-99ccf444-d286-4244-a5a5-7b1b511153a6-0')]}}\n",
"----\n"
]
}
],
"source": [
"for chunk in agent_executor.stream(\n",
" {\"messages\": [HumanMessage(content=\"whats the weather in sf?\")]}\n",
"):\n",
" print(chunk)\n",
" print(\"----\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## API reference"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"https://docs.reka.ai/quick-start"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "langchain_reka",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -86,6 +86,7 @@ telethon>=1.28.5,<2
 tidb-vector>=0.0.3,<1.0.0
 timescale-vector==0.0.1
 tqdm>=4.48.0
+tiktoken>=0.8.0
 tree-sitter>=0.20.2,<0.21
 tree-sitter-languages>=1.8.0,<2
 upstash-redis>=1.1.0,<2


@@ -155,6 +155,9 @@ if TYPE_CHECKING:
     from langchain_community.chat_models.promptlayer_openai import (
         PromptLayerChatOpenAI,
     )
+    from langchain_community.chat_models.reka import (
+        ChatReka,
+    )
     from langchain_community.chat_models.sambanova import (
         ChatSambaNovaCloud,
         ChatSambaStudio,
@@ -226,6 +229,7 @@ __all__ = [
     "ChatOllama",
     "ChatOpenAI",
     "ChatPerplexity",
+    "ChatReka",
     "ChatPremAI",
     "ChatSambaNovaCloud",
     "ChatSambaStudio",
@@ -290,6 +294,7 @@ _module_lookup = {
     "ChatOCIModelDeploymentTGI": "langchain_community.chat_models.oci_data_science",
     "ChatOllama": "langchain_community.chat_models.ollama",
     "ChatOpenAI": "langchain_community.chat_models.openai",
+    "ChatReka": "langchain_community.chat_models.reka",
     "ChatPerplexity": "langchain_community.chat_models.perplexity",
     "ChatSambaNovaCloud": "langchain_community.chat_models.sambanova",
     "ChatSambaStudio": "langchain_community.chat_models.sambanova",


@@ -0,0 +1,435 @@
import json
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, ConfigDict, Field, model_validator
DEFAULT_REKA_MODEL = "reka-flash"
ContentType = Union[str, List[Union[str, Dict[str, Any]]]]
def process_content_item(item: Dict[str, Any]) -> Dict[str, Any]:
"""Process a single content item."""
if item["type"] == "image_url":
image_url = item["image_url"]
if isinstance(image_url, dict) and "url" in image_url:
# If it's in LangChain format, extract the URL value
item["image_url"] = image_url["url"]
return item
def process_content(content: ContentType) -> List[Dict[str, Any]]:
"""Process content to handle both text and media inputs,
returning a list of content items."""
if isinstance(content, str):
return [{"type": "text", "text": content}]
elif isinstance(content, list):
result = []
for item in content:
if isinstance(item, str):
result.append({"type": "text", "text": item})
elif isinstance(item, dict):
result.append(process_content_item(item))
else:
raise ValueError(f"Invalid content item format: {item}")
return result
else:
raise ValueError("Invalid content format")
def convert_to_reka_messages(messages: List[BaseMessage]) -> List[Dict[str, Any]]:
"""Convert LangChain messages to Reka message format."""
reka_messages: List[Dict[str, Any]] = []
system_message: Optional[str] = None
for message in messages:
if isinstance(message, SystemMessage):
if system_message is None:
if isinstance(message.content, str):
system_message = message.content
else:
raise TypeError("SystemMessage content must be a string.")
else:
raise ValueError("Multiple system messages are not supported.")
elif isinstance(message, HumanMessage):
processed_content = process_content(message.content)
if system_message:
if (
processed_content
and isinstance(processed_content[0], dict)
and processed_content[0].get("type") == "text"
and "text" in processed_content[0]
):
processed_content[0]["text"] = (
f"{system_message}\n{processed_content[0]['text']}"
)
else:
processed_content.insert(
0, {"type": "text", "text": system_message}
)
system_message = None
reka_messages.append({"role": "user", "content": processed_content})
elif isinstance(message, AIMessage):
reka_message: Dict[str, Any] = {"role": "assistant"}
if message.content:
processed_content = process_content(message.content)
reka_message["content"] = processed_content
if "tool_calls" in message.additional_kwargs:
tool_calls = message.additional_kwargs["tool_calls"]
formatted_tool_calls = []
for tool_call in tool_calls:
formatted_tool_call = {
"id": tool_call["id"],
"name": tool_call["function"]["name"],
"parameters": json.loads(tool_call["function"]["arguments"]),
}
formatted_tool_calls.append(formatted_tool_call)
reka_message["tool_calls"] = formatted_tool_calls
reka_messages.append(reka_message)
elif isinstance(message, ToolMessage):
content_list: List[Dict[str, Any]] = []
content_list.append(
{
"tool_call_id": message.tool_call_id,
"output": json.dumps({"status": message.content}),
}
)
reka_messages.append(
{
"role": "tool_output",
"content": content_list,
}
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
return reka_messages
class ChatReka(BaseChatModel):
"""Reka chat large language models."""
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
model: str = Field(default=DEFAULT_REKA_MODEL)
"""Reka model name to use, e.g. ``reka-flash``."""
max_tokens: int = Field(default=256)
"""Maximum number of tokens to generate."""
temperature: Optional[float] = None
"""Sampling temperature; the API default is used when None."""
streaming: bool = False
"""Whether to stream responses token by token."""
default_request_timeout: Optional[float] = None
"""Timeout for requests to the Reka API, in seconds."""
max_retries: int = 2
"""Maximum number of retries to make when a request fails."""
reka_api_key: Optional[str] = None
"""Reka API key; read from the ``REKA_API_KEY`` env var when not provided."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Additional keyword arguments passed through to the Reka API."""
model_config = ConfigDict(extra="forbid")
token_counter: Optional[
Callable[[Union[str, BaseMessage, List[BaseMessage]]], int]
] = None
"""Optional custom token-counting function; tiktoken's ``cl100k_base`` encoding is used when None."""
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that API key and Python package exist in the environment."""
reka_api_key = values.get("reka_api_key")
reka_api_key = get_from_dict_or_env(
{"reka_api_key": reka_api_key}, "reka_api_key", "REKA_API_KEY"
)
values["reka_api_key"] = reka_api_key
try:
# Import reka libraries here
from reka.client import AsyncReka, Reka
values["client"] = Reka(
api_key=reka_api_key,
)
values["async_client"] = AsyncReka(
api_key=reka_api_key,
)
except ImportError:
raise ImportError(
"Could not import Reka Python package. "
"Please install it with `pip install reka-api`."
)
return values
@property
def _default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling Reka API."""
params = {
"model": self.model,
"max_tokens": self.max_tokens,
}
if self.temperature is not None:
params["temperature"] = self.temperature
return {**params, **self.model_kwargs}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "reka-chat"
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
reka_messages = convert_to_reka_messages(messages)
params = {**self._default_params, **kwargs}
if stop:
params["stop"] = stop
stream = self.client.chat.create_stream(messages=reka_messages, **params)
for chunk in stream:
content = chunk.responses[0].chunk.content
chat_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
if run_manager:
run_manager.on_llm_new_token(content, chunk=chat_chunk)
yield chat_chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
reka_messages = convert_to_reka_messages(messages)
params = {**self._default_params, **kwargs}
if stop:
params["stop"] = stop
stream = self.async_client.chat.create_stream(messages=reka_messages, **params)
async for chunk in stream:
content = chunk.responses[0].chunk.content
chat_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
if run_manager:
await run_manager.on_llm_new_token(content, chunk=chat_chunk)
yield chat_chunk
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
return generate_from_stream(
self._stream(messages, stop=stop, run_manager=run_manager, **kwargs)
)
reka_messages = convert_to_reka_messages(messages)
params = {**self._default_params, **kwargs}
if stop:
params["stop"] = stop
response = self.client.chat.create(messages=reka_messages, **params)
if response.responses[0].message.tool_calls:
tool_calls = response.responses[0].message.tool_calls
message = AIMessage(
content="", # Empty string instead of None
additional_kwargs={
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.parameters),
},
}
for tc in tool_calls
]
},
)
else:
content = response.responses[0].message.content
# Ensure content is never None
message = AIMessage(content=content if content is not None else "")
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
return await agenerate_from_stream(
self._astream(messages, stop=stop, run_manager=run_manager, **kwargs)
)
reka_messages = convert_to_reka_messages(messages)
params = {**self._default_params, **kwargs}
if stop:
params["stop"] = stop
response = await self.async_client.chat.create(messages=reka_messages, **params)
if response.responses[0].message.tool_calls:
tool_calls = response.responses[0].message.tool_calls
message = AIMessage(
content="", # Empty string instead of None
additional_kwargs={
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.parameters),
},
}
for tc in tool_calls
]
},
)
else:
content = response.responses[0].message.content
# Ensure content is never None
message = AIMessage(content=content if content is not None else "")
return ChatResult(generations=[ChatGeneration(message=message)])
def get_num_tokens(self, input: Union[str, BaseMessage, List[BaseMessage]]) -> int:
"""Calculate number of tokens.
Args:
input: Either a string, a single BaseMessage, or a list of BaseMessages.
Returns:
int: Number of tokens in the input.
Raises:
ImportError: If tiktoken is not installed.
ValueError: If message content is not a string.
"""
if self.token_counter is not None:
return self.token_counter(input)
try:
import tiktoken
except ImportError:
raise ImportError(
"Could not import tiktoken python package. "
"Please install it with `pip install tiktoken`."
)
encoding = tiktoken.get_encoding("cl100k_base")
if isinstance(input, str):
return len(encoding.encode(input))
elif isinstance(input, BaseMessage):
content = input.content
if not isinstance(content, str):
raise ValueError(
f"Message content must be a string, got {type(content)}"
)
return len(encoding.encode(content))
elif isinstance(input, list):
total = 0
for msg in input:
content = msg.content
if not isinstance(content, str):
raise ValueError(
f"Message content must be a string, got {type(content)}"
)
total += len(encoding.encode(content))
return total
else:
raise TypeError(f"Unsupported input type: {type(input)}")
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: str = "auto",
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
The `tool_choice` parameter controls how the model uses the tools you pass.
There are three available options:
- `"auto"`: Lets the model decide whether or not to invoke a tool. This is the
recommended way to do function calling with our models.
- `"none"`: Disables tool calling. In this case, even if you pass tools to
the model, the model will not invoke any tools.
- `"tool"`: Forces the model to invoke one or more of the tools it has
been passed.
Args:
tools: A list of tool definitions to bind to this chat model.
Supports any tool definition handled by
:meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
tool_choice: Controls how the model uses the tools you pass.
Options are "auto", "none", or "tool". Defaults to "auto".
strict:
If True, model output is guaranteed to exactly match the JSON Schema
provided in the tool definition.
If False, input schema will not be validated
and model output will not be validated.
If None, ``strict`` argument will not
be passed to the model.
kwargs: Any additional parameters are passed directly to the model.
Returns:
Runnable: An executable chain or component.
"""
formatted_tools = [
convert_to_openai_tool(tool, strict=strict) for tool in tools
]
# Ensure tool_choice is one of the allowed options
if tool_choice not in ("auto", "none", "tool"):
raise ValueError(
f"Invalid tool_choice '{tool_choice}' provided. "
"Tool choice must be one of: 'auto', 'none', or 'tool'."
)
# Map tool_choice to the parameter expected by the Reka API
kwargs["tool_choice"] = tool_choice
# Pass the tools and updated kwargs to the model
formatted_tools = [tool["function"] for tool in formatted_tools]
return super().bind(tools=formatted_tools, **kwargs)


@@ -0,0 +1,222 @@
"""Test Reka API wrapper."""
import logging
from typing import List
import pytest
from langchain_core.messages import (
AIMessage,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, LLMResult
from langchain_community.chat_models.reka import ChatReka
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_call() -> None:
"""Test a simple call to Reka."""
chat = ChatReka(model="reka-flash", verbose=True)
message = HumanMessage(content="Hello")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
logger.debug(f"Response content: {response.content}")
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_generate() -> None:
"""Test the generate method of Reka."""
chat = ChatReka(model="reka-flash", verbose=True)
chat_messages: List[List[BaseMessage]] = [
[HumanMessage(content="How many toes do dogs have?")]
]
messages_copy = [messages.copy() for messages in chat_messages]
result: LLMResult = chat.generate(chat_messages)
assert isinstance(result, LLMResult)
for response in result.generations[0]:
assert isinstance(response, ChatGeneration)
assert isinstance(response.text, str)
assert response.text == response.message.content
logger.debug(f"Generated response: {response.text}")
assert chat_messages == messages_copy
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_streaming() -> None:
"""Test streaming tokens from Reka."""
chat = ChatReka(model="reka-flash", streaming=True, verbose=True)
message = HumanMessage(content="Tell me a story.")
response = chat.invoke([message])
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
logger.debug(f"Streaming response content: {response.content}")
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_streaming_callback() -> None:
"""Test that streaming correctly invokes callbacks."""
callback_handler = FakeCallbackHandler()
chat = ChatReka(
model="reka-flash",
streaming=True,
callbacks=[callback_handler],
verbose=True,
)
message = HumanMessage(content="Write me a sentence with 10 words.")
chat.invoke([message])
assert callback_handler.llm_streams > 1
logger.debug(f"Number of LLM streams: {callback_handler.llm_streams}")
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
async def test_reka_async_streaming_callback() -> None:
"""Test asynchronous streaming with callbacks."""
callback_handler = FakeCallbackHandler()
chat = ChatReka(
model="reka-flash",
streaming=True,
callbacks=[callback_handler],
verbose=True,
)
chat_messages: List[BaseMessage] = [
HumanMessage(content="How many toes do dogs have?")
]
result: LLMResult = await chat.agenerate([chat_messages])
assert callback_handler.llm_streams > 1
assert isinstance(result, LLMResult)
for response in result.generations[0]:
assert isinstance(response, ChatGeneration)
assert isinstance(response.text, str)
assert response.text == response.message.content
logger.debug(f"Async generated response: {response.text}")
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_tool_usage_integration() -> None:
"""Test tool usage with Reka API integration."""
# Initialize the ChatReka model with tools and verbose logging
chat_reka = ChatReka(model="reka-flash", verbose=True)
tools = [
{
"type": "function",
"function": {
"name": "get_product_availability",
"description": (
"Determine whether a product is currently in stock given "
"a product ID."
),
"parameters": {
"type": "object",
"properties": {
"product_id": {
"type": "string",
"description": (
"The unique product ID to check availability for"
),
},
},
"required": ["product_id"],
},
},
},
]
chat_reka_with_tools = chat_reka.bind_tools(tools)
# Start a conversation
messages: List[BaseMessage] = [
HumanMessage(content="Is product A12345 in stock right now?")
]
# Get the initial response
response = chat_reka_with_tools.invoke(messages)
assert isinstance(response, AIMessage)
logger.debug(f"Initial AI message: {response.content}")
# Check if the model wants to use a tool
if "tool_calls" in response.additional_kwargs:
tool_calls = response.additional_kwargs["tool_calls"]
for tool_call in tool_calls:
function_name = tool_call["function"]["name"]
arguments = tool_call["function"]["arguments"]
logger.debug(
f"Tool call requested: {function_name} with arguments {arguments}"
)
# Simulate executing the tool
tool_output = "AVAILABLE"
tool_message = ToolMessage(
content=tool_output, tool_call_id=tool_call["id"]
)
messages.append(response)
messages.append(tool_message)
final_response = chat_reka_with_tools.invoke(messages)
assert isinstance(final_response, AIMessage)
logger.debug(f"Final AI message: {final_response.content}")
# Assert that the response message is non-empty
assert final_response.content, "The final response content is empty."
else:
pytest.fail("The model did not request a tool.")
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_system_message() -> None:
"""Test Reka with system message."""
chat = ChatReka(model="reka-flash", verbose=True)
messages = [
SystemMessage(content="You are a helpful AI that speaks like Shakespeare."),
HumanMessage(content="Tell me about the weather today."),
]
response = chat.invoke(messages)
assert isinstance(response, AIMessage)
assert isinstance(response.content, str)
logger.debug(f"Response with system message: {response.content}")
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_system_message_multi_turn() -> None:
"""Test multi-turn conversation with system message."""
chat = ChatReka(model="reka-flash", verbose=True)
messages = [
SystemMessage(content="You are a math tutor who explains concepts simply."),
HumanMessage(content="What is a prime number?"),
]
# First turn
response1 = chat.invoke(messages)
assert isinstance(response1, AIMessage)
messages.append(response1)
# Second turn
messages.append(HumanMessage(content="Can you give me an example?"))
response2 = chat.invoke(messages)
assert isinstance(response2, AIMessage)
logger.debug(f"First response: {response1.content}")
logger.debug(f"Second response: {response2.content}")


@@ -45,6 +45,7 @@ EXPECTED_ALL = [
     "ChatVertexAI",
     "ChatYandexGPT",
     "ChatYuan2",
+    "ChatReka",
     "ChatZhipuAI",
     "ErnieBotChat",
     "FakeListChatModel",


@@ -0,0 +1,372 @@
import json
import os
from typing import Any, Dict, List
from unittest.mock import MagicMock, patch
import pytest
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from pydantic import ValidationError
from langchain_community.chat_models import ChatReka
from langchain_community.chat_models.reka import (
convert_to_reka_messages,
process_content,
)
os.environ["REKA_API_KEY"] = "dummy_key"
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_model_param() -> None:
llm = ChatReka(model="reka-flash")
assert llm.model == "reka-flash"
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_model_kwargs() -> None:
llm = ChatReka(model_kwargs={"foo": "bar"})
assert llm.model_kwargs == {"foo": "bar"}
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_incorrect_field() -> None:
"""Test that providing an incorrect field raises ValidationError."""
with pytest.raises(ValidationError):
ChatReka(unknown_field="bar") # type: ignore
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_initialization() -> None:
"""Test Reka initialization."""
# Verify that ChatReka can be initialized using a secret key provided
# as a parameter rather than an environment variable.
ChatReka(model="reka-flash", reka_api_key="test_key")
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
@pytest.mark.parametrize(
("content", "expected"),
[
("Hello", [{"type": "text", "text": "Hello"}]),
(
[
{"type": "text", "text": "Describe this image"},
{
"type": "image_url",
"image_url": "https://example.com/image.jpg",
},
],
[
{"type": "text", "text": "Describe this image"},
{"type": "image_url", "image_url": "https://example.com/image.jpg"},
],
),
(
[
{"type": "text", "text": "Hello"},
{
"type": "image_url",
"image_url": {"url": "https://example.com/image.jpg"},
},
],
[
{"type": "text", "text": "Hello"},
{"type": "image_url", "image_url": "https://example.com/image.jpg"},
],
),
],
)
def test_process_content(content: Any, expected: List[Dict[str, Any]]) -> None:
result = process_content(content)
assert result == expected
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
@pytest.mark.parametrize(
("messages", "expected"),
[
(
[HumanMessage(content="Hello")],
[{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
),
(
[
HumanMessage(
content=[
{"type": "text", "text": "Describe this image"},
{
"type": "image_url",
"image_url": "https://example.com/image.jpg",
},
]
),
AIMessage(content="It's a beautiful landscape."),
],
[
{
"role": "user",
"content": [
{"type": "text", "text": "Describe this image"},
{
"type": "image_url",
"image_url": "https://example.com/image.jpg",
},
],
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "It's a beautiful landscape."}
],
},
],
),
],
)
def test_convert_to_reka_messages(
messages: List[BaseMessage], expected: List[Dict[str, Any]]
) -> None:
result = convert_to_reka_messages(messages)
assert result == expected
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_streaming() -> None:
llm = ChatReka(streaming=True)
assert llm.streaming is True
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_temperature() -> None:
llm = ChatReka(temperature=0.5)
assert llm.temperature == 0.5
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_max_tokens() -> None:
llm = ChatReka(max_tokens=100)
assert llm.max_tokens == 100
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_default_params() -> None:
llm = ChatReka()
assert llm._default_params == {
"max_tokens": 256,
"model": "reka-flash",
}
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_identifying_params() -> None:
"""Test that ChatReka identifies its default parameters correctly."""
chat = ChatReka(model="reka-flash", temperature=0.7, max_tokens=256)
expected_params = {
"model": "reka-flash",
"temperature": 0.7,
"max_tokens": 256,
}
assert chat._default_params == expected_params
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_llm_type() -> None:
llm = ChatReka()
assert llm._llm_type == "reka-chat"
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_reka_tool_use_with_mocked_response() -> None:
with patch("reka.client.Reka") as MockReka:
# Mock the Reka client
mock_client = MockReka.return_value
mock_chat = MagicMock()
mock_client.chat = mock_chat
mock_response = MagicMock()
mock_message = MagicMock()
mock_tool_call = MagicMock()
mock_tool_call.id = "tool_call_1"
mock_tool_call.name = "search_tool"
mock_tool_call.parameters = {"query": "LangChain"}
mock_message.tool_calls = [mock_tool_call]
mock_message.content = None
mock_response.responses = [MagicMock(message=mock_message)]
mock_chat.create.return_value = mock_response
llm = ChatReka()
messages: List[BaseMessage] = [HumanMessage(content="Tell me about LangChain")]
result = llm._generate(messages)
assert len(result.generations) == 1
ai_message = result.generations[0].message
assert ai_message.content == ""
assert "tool_calls" in ai_message.additional_kwargs
tool_calls = ai_message.additional_kwargs["tool_calls"]
assert len(tool_calls) == 1
assert tool_calls[0]["id"] == "tool_call_1"
assert tool_calls[0]["function"]["name"] == "search_tool"
assert tool_calls[0]["function"]["arguments"] == json.dumps(
{"query": "LangChain"}
)
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
@pytest.mark.parametrize(
("messages", "expected"),
[
# Test single system message
(
[
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hello"),
],
[
{
"role": "user",
"content": [
{"type": "text", "text": "You are a helpful assistant.\nHello"}
],
}
],
),
# Test system message with multiple messages
(
[
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="What is 2+2?"),
AIMessage(content="4"),
HumanMessage(content="Thanks!"),
],
[
{
"role": "user",
"content": [
{
"type": "text",
"text": "You are a helpful assistant.\nWhat is 2+2?",
}
],
},
{"role": "assistant", "content": [{"type": "text", "text": "4"}]},
{"role": "user", "content": [{"type": "text", "text": "Thanks!"}]},
],
),
# Test system message with media content
(
[
SystemMessage(content="Hi."),
HumanMessage(
content=[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": "https://example.com/image.jpg",
},
]
),
],
[
{
"role": "user",
"content": [
{
"type": "text",
"text": "Hi.\nWhat's in this image?",
},
{
"type": "image_url",
"image_url": "https://example.com/image.jpg",
},
],
},
],
),
],
)
def test_system_message_handling(
messages: List[BaseMessage], expected: List[Dict[str, Any]]
) -> None:
"""Test that system messages are handled correctly."""
result = convert_to_reka_messages(messages)
assert result == expected
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_multiple_system_messages_error() -> None:
"""Test that multiple system messages raise an error."""
messages = [
SystemMessage(content="System message 1"),
SystemMessage(content="System message 2"),
HumanMessage(content="Hello"),
]
with pytest.raises(ValueError, match="Multiple system messages are not supported."):
convert_to_reka_messages(messages)
@pytest.mark.skip(
reason="Dependency conflict w/ other dependencies for urllib3 versions."
)
def test_get_num_tokens() -> None:
"""Test that token counting works correctly for different input types."""
llm = ChatReka()
import tiktoken
encoding = tiktoken.get_encoding("cl100k_base")
# Test string input
text = "What is the weather like today?"
expected_tokens = len(encoding.encode(text))
assert llm.get_num_tokens(text) == expected_tokens
# Test BaseMessage input
message = HumanMessage(content="What is the weather like today?")
assert isinstance(message.content, str)
expected_tokens = len(encoding.encode(message.content))
assert llm.get_num_tokens(message) == expected_tokens
# Test List[BaseMessage] input
messages = [
SystemMessage(content="You are a helpful assistant."),
HumanMessage(content="Hi!"),
AIMessage(content="Hello! How can I help you today?"),
]
expected_tokens = sum(
len(encoding.encode(msg.content))
for msg in messages
if isinstance(msg.content, str)
)
assert llm.get_num_tokens(messages) == expected_tokens
# Test empty message list
assert llm.get_num_tokens([]) == 0