From 278ef0bdcf56e4a6d4009cd415a6a0d64ee0cf79 Mon Sep 17 00:00:00 2001 From: Jacob Lee Date: Wed, 23 Aug 2023 13:02:26 -0700 Subject: [PATCH] Adds ChatOllama (#9628) @rlancemartin --------- Co-authored-by: Adilkhan Sarsen <54854336+adolkhan@users.noreply.github.com> Co-authored-by: Kim Minjong Co-authored-by: Harrison Chase Co-authored-by: Lance Martin Co-authored-by: Bagatur --- docs/extras/integrations/chat/ollama.ipynb | 382 ++++++++++++++++++ .../langchain/callbacks/streaming_stdout.py | 9 + .../langchain/chat_models/__init__.py | 2 + .../langchain/chat_models/anthropic.py | 2 +- .../langchain/langchain/chat_models/ollama.py | 122 ++++++ libs/langchain/langchain/llms/ollama.py | 49 ++- 6 files changed, 550 insertions(+), 16 deletions(-) create mode 100644 docs/extras/integrations/chat/ollama.ipynb create mode 100644 libs/langchain/langchain/chat_models/ollama.py diff --git a/docs/extras/integrations/chat/ollama.ipynb b/docs/extras/integrations/chat/ollama.ipynb new file mode 100644 index 00000000000..41a90405b78 --- /dev/null +++ b/docs/extras/integrations/chat/ollama.ipynb @@ -0,0 +1,382 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Ollama\n", + "\n", + "[Ollama](https://ollama.ai/) allows you to run open-source large language models, such as LLaMA2, locally.\n", + "\n", + "Ollama bundles model weights, configuration, and data into a single package, defined by a Modelfile. \n", + "\n", + "It optimizes setup and configuration details, including GPU usage.\n", + "\n", + "For a complete list of supported models and model variants, see the [Ollama model library](https://ollama.ai/library).\n", + "\n", + "## Setup\n", + "\n", + "First, follow [these instructions](https://github.com/jmorganca/ollama) to set up and run a local Ollama instance:\n", + "\n", + "* [Download](https://ollama.ai/download)\n", + "* Fetch a model via `ollama pull <model family>`\n", + "* e.g., for LLaMA2 7B: `ollama pull llama2`\n", + "* This will download the most basic version of the model (e.g., minimum # parameters and 4-bit quantization)\n", + "* On Mac, it will download to:\n", + "\n", + "`~/.ollama/models/manifests/registry.ollama.ai/library/<model family>/latest`\n", + "\n", + "* And we can specify a particular version, e.g., `ollama pull vicuna:13b-v1.5-16k-q4_0`\n", + "* The file is here with the model version in place of `latest`\n", + "\n", + "`~/.ollama/models/manifests/registry.ollama.ai/library/vicuna/13b-v1.5-16k-q4_0`\n", + "\n", + "You can easily access models in a few ways:\n", + "\n", + "1/ if the app is running:\n", + "* All of your local models are automatically served on `localhost:11434`\n", + "* Select your model when setting `llm = Ollama(..., model=\"<model family>:<version>\")`\n", + "* If you set `llm = Ollama(..., model=\"<model family>\")` without a version, it will look for `latest`\n", + "\n", + "## Usage\n", + "\n", + "The retrieval QA example below uses the LLaMA-2 chat prompt format and assumes an existing `vectorstore` retriever built over your documents." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# QA prompt (LLaMA-2 chat format)\n", + "from langchain.prompts import PromptTemplate\n", + "\n", + "template = \"\"\"[INST] <<SYS>> Use the following pieces of context to answer the question at the end. \n", + "If you don't know the answer, just say that you don't know, don't try to make up an answer. \n", + "Use three sentences maximum and keep the answer as concise as possible. 
<>\n", + "{context}\n", + "Question: {question}\n", + "Helpful Answer:[/INST]\"\"\"\n", + "QA_CHAIN_PROMPT = PromptTemplate(\n", + " input_variables=[\"context\", \"question\"],\n", + " template=template,\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [], + "source": [ + "# Chat model\n", + "from langchain.chat_models import ChatOllama\n", + "from langchain.callbacks.manager import CallbackManager\n", + "from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler\n", + "chat_model = ChatOllama(model=\"llama2:13b-chat\",\n", + " verbose=True,\n", + " callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "# QA chain\n", + "from langchain.chains import RetrievalQA\n", + "qa_chain = RetrievalQA.from_chain_type(\n", + " chat_model,\n", + " retriever=vectorstore.as_retriever(),\n", + " chain_type_kwargs={\"prompt\": QA_CHAIN_PROMPT},\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Based on the provided context, there are three approaches to task decomposition for AI agents:\n", + "\n", + "1. LLM with simple prompting, such as \"Steps for XYZ.\" or \"What are the subgoals for achieving XYZ?\"\n", + "2. Task-specific instructions, such as \"Write a story outline\" for writing a novel.\n", + "3. Human inputs." + ] + } + ], + "source": [ + "question = \"What are the various approaches to Task Decomposition for AI Agents?\"\n", + "result = qa_chain({\"query\": question})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "You can also get logging for tokens." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + " Based on the given context, here is the answer to the question \"What are the approaches to Task Decomposition?\"\n", + "\n", + "There are three approaches to task decomposition:\n", + "\n", + "1. LLM with simple prompting, such as \"Steps for XYZ.\" or \"What are the subgoals for achieving XYZ?\"\n", + "2. Using task-specific instructions, like \"Write a story outline\" for writing a novel.\n", + "3. 
With human inputs.{'model': 'llama2:13b-chat', 'created_at': '2023-08-23T15:37:51.469127Z', 'done': True, 'context': [1, 29871, 1, 29961, 25580, 29962, 518, 25580, 29962, 518, 25580, 29962, 3532, 14816, 29903, 6778, 4803, 278, 1494, 12785, 310, 3030, 304, 1234, 278, 1139, 472, 278, 1095, 29889, 29871, 13, 3644, 366, 1016, 29915, 29873, 1073, 278, 1234, 29892, 925, 1827, 393, 366, 1016, 29915, 29873, 1073, 29892, 1016, 29915, 29873, 1018, 304, 1207, 701, 385, 1234, 29889, 29871, 13, 11403, 2211, 25260, 7472, 322, 3013, 278, 1234, 408, 3022, 895, 408, 1950, 29889, 529, 829, 14816, 29903, 6778, 13, 5398, 26227, 508, 367, 2309, 313, 29896, 29897, 491, 365, 26369, 411, 2560, 9508, 292, 763, 376, 7789, 567, 363, 1060, 29979, 29999, 7790, 29876, 29896, 19602, 376, 5618, 526, 278, 1014, 1484, 1338, 363, 3657, 15387, 1060, 29979, 29999, 29973, 613, 313, 29906, 29897, 491, 773, 3414, 29899, 14940, 11994, 29936, 321, 29889, 29887, 29889, 376, 6113, 263, 5828, 27887, 1213, 363, 5007, 263, 9554, 29892, 470, 313, 29941, 29897, 411, 5199, 10970, 29889, 13, 13, 5398, 26227, 508, 367, 2309, 313, 29896, 29897, 491, 365, 26369, 411, 2560, 9508, 292, 763, 376, 7789, 567, 363, 1060, 29979, 29999, 7790, 29876, 29896, 19602, 376, 5618, 526, 278, 1014, 1484, 1338, 363, 3657, 15387, 1060, 29979, 29999, 29973, 613, 313, 29906, 29897, 491, 773, 3414, 29899, 14940, 11994, 29936, 321, 29889, 29887, 29889, 376, 6113, 263, 5828, 27887, 1213, 363, 5007, 263, 9554, 29892, 470, 313, 29941, 29897, 411, 5199, 10970, 29889, 13, 13, 1451, 16047, 267, 297, 1472, 29899, 8489, 18987, 322, 3414, 26227, 29901, 1858, 9450, 975, 263, 3309, 29891, 4955, 322, 17583, 3902, 8253, 278, 1650, 2913, 3933, 18066, 292, 29889, 365, 26369, 29879, 21117, 304, 10365, 13900, 746, 20050, 411, 15668, 4436, 29892, 3907, 963, 3109, 16424, 9401, 304, 25618, 1058, 5110, 515, 14260, 322, 1059, 29889, 13, 13, 1451, 16047, 267, 297, 1472, 29899, 8489, 18987, 322, 3414, 26227, 29901, 1858, 9450, 975, 263, 3309, 29891, 4955, 322, 17583, 3902, 8253, 278, 1650, 2913, 3933, 18066, 292, 29889, 365, 26369, 29879, 21117, 304, 10365, 13900, 746, 20050, 411, 15668, 4436, 29892, 3907, 963, 3109, 16424, 9401, 304, 25618, 1058, 5110, 515, 14260, 322, 1059, 29889, 13, 16492, 29901, 1724, 526, 278, 13501, 304, 9330, 897, 510, 3283, 29973, 13, 29648, 1319, 673, 10834, 29914, 25580, 29962, 518, 29914, 25580, 29962, 518, 29914, 25580, 29962, 29871, 16564, 373, 278, 2183, 3030, 29892, 1244, 338, 278, 1234, 304, 278, 1139, 376, 5618, 526, 278, 13501, 304, 9330, 897, 510, 3283, 3026, 13, 13, 8439, 526, 2211, 13501, 304, 3414, 26227, 29901, 13, 13, 29896, 29889, 365, 26369, 411, 2560, 9508, 292, 29892, 1316, 408, 376, 7789, 567, 363, 1060, 29979, 29999, 1213, 470, 376, 5618, 526, 278, 1014, 1484, 1338, 363, 3657, 15387, 1060, 29979, 29999, 3026, 13, 29906, 29889, 5293, 3414, 29899, 14940, 11994, 29892, 763, 376, 6113, 263, 5828, 27887, 29908, 363, 5007, 263, 9554, 29889, 13, 29941, 29889, 2973, 5199, 10970, 29889, 2], 'total_duration': 9514823750, 'load_duration': 795542, 'sample_count': 99, 'sample_duration': 68732000, 'prompt_eval_count': 146, 'prompt_eval_duration': 6206275000, 'eval_count': 98, 'eval_duration': 3229641000}\n" + ] + } + ], + "source": [ + "from langchain.schema import LLMResult\n", + "from langchain.callbacks.base import BaseCallbackHandler\n", + "\n", + "class GenerationStatisticsCallback(BaseCallbackHandler):\n", + " def on_llm_end(self, response: LLMResult, **kwargs) -> None:\n", + " print(response.generations[0][0].generation_info)\n", + " \n", + 
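"# Combine token streaming to stdout with the statistics callback defined above\n", + 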
"callback_manager = CallbackManager([StreamingStdOutCallbackHandler(), GenerationStatisticsCallback()])\n", + "\n", + "chat_model = ChatOllama(model=\"llama2:13b-chat\",\n", + " verbose=True,\n", + " callback_manager=callback_manager)\n", + "\n", + "qa_chain = RetrievalQA.from_chain_type(\n", + " chat_model,\n", + " retriever=vectorstore.as_retriever(),\n", + " chain_type_kwargs={\"prompt\": QA_CHAIN_PROMPT},\n", + ")\n", + "\n", + "question = \"What are the approaches to Task Decomposition?\"\n", + "result = qa_chain({\"query\": question})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`eval_count` / (`eval_duration`/10e9) gets `tok / s`" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "30.343929867127645" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "98 / (3229641000/1000/1000/1000)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/langchain/langchain/callbacks/streaming_stdout.py b/libs/langchain/langchain/callbacks/streaming_stdout.py index 4acde4cebf0..2c71bc769c9 100644 --- a/libs/langchain/langchain/callbacks/streaming_stdout.py +++ b/libs/langchain/langchain/callbacks/streaming_stdout.py @@ -4,6 +4,7 @@ from typing import Any, Dict, List, Union from langchain.callbacks.base import BaseCallbackHandler from langchain.schema import AgentAction, AgentFinish, LLMResult +from langchain.schema.messages import BaseMessage class StreamingStdOutCallbackHandler(BaseCallbackHandler): @@ -14,6 +15,14 @@ class StreamingStdOutCallbackHandler(BaseCallbackHandler): ) -> None: """Run when LLM starts running.""" + def on_chat_model_start( + self, + serialized: Dict[str, Any], + messages: List[List[BaseMessage]], + **kwargs: Any + ) -> None: + """Run when LLM starts running.""" + def on_llm_new_token(self, token: str, **kwargs: Any) -> None: """Run on new LLM token. 
Only available when streaming is enabled.""" sys.stdout.write(token) diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index 0f26a852b8e..ee21a2377eb 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -27,6 +27,7 @@ from langchain.chat_models.human import HumanInputChatModel from langchain.chat_models.jinachat import JinaChat from langchain.chat_models.litellm import ChatLiteLLM from langchain.chat_models.mlflow_ai_gateway import ChatMLflowAIGateway +from langchain.chat_models.ollama import ChatOllama from langchain.chat_models.openai import ChatOpenAI from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI from langchain.chat_models.vertexai import ChatVertexAI @@ -39,6 +40,7 @@ __all__ = [ "ChatAnthropic", "ChatGooglePalm", "ChatMLflowAIGateway", + "ChatOllama", "ChatVertexAI", "JinaChat", "HumanInputChatModel", diff --git a/libs/langchain/langchain/chat_models/anthropic.py b/libs/langchain/langchain/chat_models/anthropic.py index e201e382f98..ef1da63196e 100644 --- a/libs/langchain/langchain/chat_models/anthropic.py +++ b/libs/langchain/langchain/chat_models/anthropic.py @@ -32,7 +32,7 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): .. code-block:: python import anthropic - from langchain.llms import Anthropic + from langchain.chat_models import ChatAnthropic model = ChatAnthropic(model="", anthropic_api_key="my-api-key") """ diff --git a/libs/langchain/langchain/chat_models/ollama.py b/libs/langchain/langchain/chat_models/ollama.py new file mode 100644 index 00000000000..a1fdcd0accd --- /dev/null +++ b/libs/langchain/langchain/chat_models/ollama.py @@ -0,0 +1,122 @@ +import json +from typing import Any, Iterator, List, Optional + +from langchain.callbacks.manager import ( + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import BaseChatModel +from langchain.llms.ollama import _OllamaCommon +from langchain.schema import ChatResult +from langchain.schema.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain.schema.output import ChatGeneration, ChatGenerationChunk + + +def _stream_response_to_chat_generation_chunk( + stream_response: str, +) -> ChatGenerationChunk: + """Convert a stream response to a generation chunk.""" + parsed_response = json.loads(stream_response) + generation_info = parsed_response if parsed_response.get("done") is True else None + return ChatGenerationChunk( + message=AIMessageChunk(content=parsed_response.get("response", "")), + generation_info=generation_info, + ) + + +class ChatOllama(BaseChatModel, _OllamaCommon): + """Ollama locally runs large language models. + + To use, follow the instructions at https://ollama.ai/. + + Example: + .. 
code-block:: python + + from langchain.chat_models import ChatOllama + ollama = ChatOllama(model="llama2") + """ + + @property + def _llm_type(self) -> str: + """Return type of chat model.""" + return "ollama-chat" + + @property + def lc_serializable(self) -> bool: + return True + + def _format_message_as_text(self, message: BaseMessage) -> str: + if isinstance(message, ChatMessage): + message_text = f"\n\n{message.role.capitalize()}: {message.content}" + elif isinstance(message, HumanMessage): + message_text = f"[INST] {message.content} [/INST]" + elif isinstance(message, AIMessage): + message_text = f"{message.content}" + elif isinstance(message, SystemMessage): + message_text = f"<<SYS>> {message.content} <</SYS>>" + else: + raise ValueError(f"Got unknown type {message}") + return message_text + + def _format_messages_as_text(self, messages: List[BaseMessage]) -> str: + return "\n".join( + [self._format_message_as_text(message) for message in messages] + ) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """Call out to Ollama's generate endpoint. + + Args: + messages: The list of base messages to pass into the model. + stop: Optional list of stop words to use when generating. + + Returns: + Chat generations from the model + + Example: + .. code-block:: python + + response = ollama([ + HumanMessage(content="Tell me about the history of AI") + ]) + """ + + prompt = self._format_messages_as_text(messages) + final_chunk = super()._stream_with_aggregation( + prompt, stop=stop, run_manager=run_manager, verbose=self.verbose, **kwargs + ) + chat_generation = ChatGeneration( + message=AIMessage(content=final_chunk.text), + generation_info=final_chunk.generation_info, + ) + return ChatResult(generations=[chat_generation]) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + prompt = self._format_messages_as_text(messages) + for stream_resp in self._create_stream(prompt, stop, **kwargs): + if stream_resp: + chunk = _stream_response_to_chat_generation_chunk(stream_resp) + yield chunk + if run_manager: + run_manager.on_llm_new_token( + chunk.text, + verbose=self.verbose, + ) diff --git a/libs/langchain/langchain/llms/ollama.py b/libs/langchain/langchain/llms/ollama.py index ec132c08b88..527bbcbc44b 100644 --- a/libs/langchain/langchain/llms/ollama.py +++ b/libs/langchain/langchain/llms/ollama.py @@ -144,9 +144,35 @@ class _OllamaCommon(BaseLanguageModel): ) return response.iter_lines(decode_unicode=True) + def _stream_with_aggregation( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + verbose: bool = False, + **kwargs: Any, + ) -> GenerationChunk: + final_chunk: Optional[GenerationChunk] = None + for stream_resp in self._create_stream(prompt, stop, **kwargs): + if stream_resp: + chunk = _stream_response_to_generation_chunk(stream_resp) + if final_chunk is None: + final_chunk = chunk + else: + final_chunk += chunk + if run_manager: + run_manager.on_llm_new_token( + chunk.text, + verbose=verbose, + ) + if final_chunk is None: + raise ValueError("No data received from Ollama stream.") + + return final_chunk + class Ollama(BaseLLM, _OllamaCommon): - """Ollama locally run large language models. + """Ollama locally runs large language models. 
To use, follow the instructions at https://ollama.ai/. @@ -191,20 +217,13 @@ class Ollama(BaseLLM, _OllamaCommon): # TODO: add caching here. generations = [] for prompt in prompts: - final_chunk: Optional[GenerationChunk] = None - for stream_resp in self._create_stream(prompt, stop, **kwargs): - if stream_resp: - chunk = _stream_response_to_generation_chunk(stream_resp) - if final_chunk is None: - final_chunk = chunk - else: - final_chunk += chunk - if run_manager: - run_manager.on_llm_new_token( - chunk.text, - verbose=self.verbose, - ) - + final_chunk = super()._stream_with_aggregation( + prompt, + stop=stop, + run_manager=run_manager, + verbose=self.verbose, + **kwargs, + ) generations.append([final_chunk]) return LLMResult(generations=generations)
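For reference, a minimal usage sketch of the new ChatOllama class, assuming a local Ollama server on the default localhost:11434 and a previously pulled llama2 model:

    from langchain.callbacks.manager import CallbackManager
    from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
    from langchain.chat_models import ChatOllama
    from langchain.schema import HumanMessage

    # Stream tokens to stdout as they are generated by the local model.
    chat_model = ChatOllama(
        model="llama2",
        callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]),
    )

    # Messages are converted to the LLaMA-2 [INST]/<<SYS>> prompt format by
    # _format_messages_as_text before being sent to Ollama's generate endpoint.
    messages = [HumanMessage(content="Tell me about the history of AI")]
    response = chat_model(messages)
    print(response.content)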