diff --git a/cookbook/selecting_llms_based_on_context_length.ipynb b/cookbook/selecting_llms_based_on_context_length.ipynb
new file mode 100644
index 00000000000..a166acffc1d
--- /dev/null
+++ b/cookbook/selecting_llms_based_on_context_length.ipynb
@@ -0,0 +1,197 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "e93283d1",
+   "metadata": {},
+   "source": [
+    "# Selecting LLMs based on Context Length\n",
+    "\n",
+    "Different LLMs have different context lengths. As an immediate and practical example, OpenAI has two versions of GPT-3.5-Turbo: one with a 4k-token context window, another with 16k. This notebook shows how to route between them based on the token length of the input."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "id": "cc453450",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.prompts import PromptTemplate\n",
+    "from langchain.schema.prompt import PromptValue\n",
+    "from langchain.chat_models import ChatOpenAI\n",
+    "from langchain.schema.output_parser import StrOutputParser"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "1cec6a10",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "short_context_model = ChatOpenAI(model=\"gpt-3.5-turbo\")\n",
+    "long_context_model = ChatOpenAI(model=\"gpt-3.5-turbo-16k\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "772da153",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_context_length(prompt: PromptValue):\n",
+    "    # Both models share the same tokenizer, so either can count tokens.\n",
+    "    messages = prompt.to_messages()\n",
+    "    tokens = short_context_model.get_num_tokens_from_messages(messages)\n",
+    "    return tokens"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "db771e20",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "prompt = PromptTemplate.from_template(\"Summarize this passage: {context}\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 20,
+   "id": "af057e2f",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def choose_model(prompt: PromptValue):\n",
+    "    # NOTE: 30 tokens is an artificially low threshold so this demo can\n",
+    "    # exercise both branches; see the more realistic cutoff below.\n",
+    "    context_len = get_context_length(prompt)\n",
+    "    if context_len < 30:\n",
+    "        print(\"short model\")\n",
+    "        return short_context_model\n",
+    "    else:\n",
+    "        print(\"long model\")\n",
+    "        return long_context_model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "id": "84f3e07d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chain = prompt | choose_model | StrOutputParser()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 26,
+   "id": "d8b14f8f",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "short model\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'The passage mentions that a frog visited a pond.'"
+      ]
+     },
+     "execution_count": 26,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chain.invoke({\"context\": \"a frog went to a pond\"})"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 27,
+   "id": "70ebd3dd",
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "long model\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'The passage describes a frog that moved from one pond to another and perched on a log.'"
+      ]
+     },
+     "execution_count": 27,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chain.invoke({\"context\": \"a frog went to a pond and sat on a log and went to a different pond\"})"
+   ]
+  },
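+  {
+   "cell_type": "markdown",
+   "id": "b7d42c91",
+   "metadata": {},
+   "source": [
+    "The 30-token cutoff above is deliberately tiny so the demo can exercise both branches. A more realistic router would compare the prompt against the short model's actual context window, leaving headroom for the completion. The sketch below assumes gpt-3.5-turbo's 4,096-token window and a hypothetical 500-token reserve; `choose_model_by_window` and both constants are illustrative, not official limits."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "c8e53da2",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "SHORT_MODEL_CONTEXT_WINDOW = 4096  # gpt-3.5-turbo's 4k window\n",
+    "COMPLETION_HEADROOM = 500  # assumed reserve for the model's answer\n",
+    "\n",
+    "\n",
+    "def choose_model_by_window(prompt: PromptValue):\n",
+    "    # Route to the long-context model only when the prompt won't fit.\n",
+    "    if get_context_length(prompt) < SHORT_MODEL_CONTEXT_WINDOW - COMPLETION_HEADROOM:\n",
+    "        return short_context_model\n",
+    "    return long_context_model\n",
+    "\n",
+    "\n",
+    "window_chain = prompt | choose_model_by_window | StrOutputParser()"
+   ]
+  }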
"id": "a7e29fef", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.1" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}