diff --git a/docs/docs/modules/model_io/llms/custom_llm.ipynb b/docs/docs/modules/model_io/llms/custom_llm.ipynb index 1c9a60c000d..da8735ffd1c 100644 --- a/docs/docs/modules/model_io/llms/custom_llm.ipynb +++ b/docs/docs/modules/model_io/llms/custom_llm.ipynb @@ -9,44 +9,73 @@ "\n", "This notebook goes over how to create a custom LLM wrapper, in case you want to use your own LLM or a different wrapper than one that is supported in LangChain.\n", "\n", + "Wrapping your LLM with the standard `LLM` interface allows you to use your LLM in existing LangChain programs with minimal code modifications!\n", + "\n", + "As a bonus, your LLM will automatically become a LangChain `Runnable` and will benefit from some out-of-the-box optimizations, async support, the `astream_events` API, and more.\n", + "\n", + "## Implementation\n", + "\n", "There are only two required things that a custom LLM needs to implement:\n", "\n", - "- A `_call` method that takes in a string, some optional stop words, and returns a string.\n", - "- A `_llm_type` property that returns a string. Used for logging purposes only.\n", "\n", - "There is a second optional thing it can implement:\n", + "| Method | Description |\n", + "|---------------|---------------------------------------------------------------------------|\n", + "| `_call` | Takes in a string and some optional stop words, and returns a string. Used by `invoke`. |\n", + "| `_llm_type` | A property that returns a string, used for logging purposes only. |\n", "\n", - "- An `_identifying_params` property that is used to help with printing of this class. Should return a dictionary.\n", "\n", - "Let's implement a very simple custom LLM that just returns the first n characters of the input." + "\n", + "Optional implementations:\n", + "\n", + "\n", + "| Method | Description |\n", + "|----------------------|-----------------------------------------------------------------------------------------------------------|\n", + "| `_identifying_params` | Used to help with identifying the model and printing the LLM; should return a dictionary. This is a **@property**. |\n", + "| `_acall` | Provides a native async implementation of `_call`, used by `ainvoke`. |\n", + "| `_stream` | Method to stream the output token by token. |\n", + "| `_astream` | Provides a native async implementation of `_stream`; in newer LangChain versions, defaults to `_stream`. |\n", + "\n", + "\n", + "\n", + "Let's implement a simple custom LLM that just returns the first n characters of the input." 
] }, { "cell_type": "code", - "execution_count": 2, - "id": "a65696a0", - "metadata": {}, + "execution_count": 1, + "id": "2e9bb32f-6fd1-46ac-b32f-d175663710c0", + "metadata": { + "tags": [] + }, "outputs": [], "source": [ - "from typing import Any, List, Mapping, Optional\n", + "from typing import Any, Dict, Iterator, List, Mapping, Optional\n", "\n", "from langchain_core.callbacks.manager import CallbackManagerForLLMRun\n", - "from langchain_core.language_models.llms import LLM" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "d5ceff02", - "metadata": {}, - "outputs": [], - "source": [ - "class CustomLLM(LLM):\n", - " n: int\n", + "from langchain_core.language_models.llms import LLM\n", + "from langchain_core.outputs import GenerationChunk\n", "\n", - " @property\n", - " def _llm_type(self) -> str:\n", - " return \"custom\"\n", + "\n", + "class CustomLLM(LLM):\n", + " \"\"\"A custom chat model that echoes the first `n` characters of the input.\n", + "\n", + " When contributing an implementation to LangChain, carefully document\n", + " the model including the initialization parameters, include\n", + " an example of how to initialize the model and include any relevant\n", + " links to the underlying models documentation or API.\n", + "\n", + " Example:\n", + "\n", + " .. code-block:: python\n", + "\n", + " model = CustomChatModel(n=2)\n", + " result = model.invoke([HumanMessage(content=\"hello\")])\n", + " result = model.batch([[HumanMessage(content=\"hello\")],\n", + " [HumanMessage(content=\"world\")]])\n", + " \"\"\"\n", + "\n", + " n: int\n", + " \"\"\"The number of characters from the last message of the prompt to be echoed.\"\"\"\n", "\n", " def _call(\n", " self,\n", @@ -55,47 +84,133 @@ " run_manager: Optional[CallbackManagerForLLMRun] = None,\n", " **kwargs: Any,\n", " ) -> str:\n", + " \"\"\"Run the LLM on the given input.\n", + "\n", + " Override this method to implement the LLM logic.\n", + "\n", + " Args:\n", + " prompt: The prompt to generate from.\n", + " stop: Stop words to use when generating. Model output is cut off at the\n", + " first occurrence of any of the stop substrings.\n", + " If stop tokens are not supported consider raising NotImplementedError.\n", + " run_manager: Callback manager for the run.\n", + " **kwargs: Arbitrary additional keyword arguments. These are usually passed\n", + " to the model provider API call.\n", + "\n", + " Returns:\n", + " The model output as a string. Actual completions SHOULD NOT include the prompt.\n", + " \"\"\"\n", " if stop is not None:\n", " raise ValueError(\"stop kwargs are not permitted.\")\n", " return prompt[: self.n]\n", "\n", + " def _stream(\n", + " self,\n", + " prompt: str,\n", + " stop: Optional[List[str]] = None,\n", + " run_manager: Optional[CallbackManagerForLLMRun] = None,\n", + " **kwargs: Any,\n", + " ) -> Iterator[GenerationChunk]:\n", + " \"\"\"Stream the LLM on the given prompt.\n", + "\n", + " This method should be overridden by subclasses that support streaming.\n", + "\n", + " If not implemented, the default behavior of calls to stream will be to\n", + " fallback to the non-streaming version of the model and return\n", + " the output as a single chunk.\n", + "\n", + " Args:\n", + " prompt: The prompt to generate from.\n", + " stop: Stop words to use when generating. Model output is cut off at the\n", + " first occurrence of any of these substrings.\n", + " run_manager: Callback manager for the run.\n", + " **kwargs: Arbitrary additional keyword arguments. 
These are usually passed\n", + " to the model provider API call.\n", + "\n", + " Returns:\n", + " An iterator of GenerationChunks.\n", + " \"\"\"\n", + " for char in prompt[: self.n]:\n", + " chunk = GenerationChunk(text=char)\n", + " if run_manager:\n", + " run_manager.on_llm_new_token(chunk.text, chunk=chunk)\n", + "\n", + " yield chunk\n", + "\n", " @property\n", - " def _identifying_params(self) -> Mapping[str, Any]:\n", - " \"\"\"Get the identifying parameters.\"\"\"\n", - " return {\"n\": self.n}" + " def _identifying_params(self) -> Dict[str, Any]:\n", + " \"\"\"Return a dictionary of identifying parameters.\"\"\"\n", + " return {\n", + " # The model name allows users to specify custom token counting\n", + " # rules in LLM monitoring applications (e.g., in LangSmith users\n", + " # can provide per token pricing for their model and monitor\n", + " # costs for the given LLM.)\n", + " \"model_name\": \"CustomChatModel\",\n", + " }\n", + "\n", + " @property\n", + " def _llm_type(self) -> str:\n", + " \"\"\"Get the type of language model used by this LLM. Used for logging purposes only.\"\"\"\n", + " return \"custom\"" ] }, { "cell_type": "markdown", - "id": "714dede0", - "metadata": {}, + "id": "f614fb7b-e476-4d81-821b-57a2ebebe21c", + "metadata": { + "tags": [] + }, "source": [ - "We can now use this as an any other LLM." + "### Let's test it 🧪" ] }, { "cell_type": "markdown", "id": "e3feae15-4afc-49f4-8542-93867d4ea769", "metadata": { "tags": [] }, "source": [ + "This LLM implements the standard `Runnable` interface of LangChain, which many LangChain abstractions support!" ] }, { "cell_type": "code", - "execution_count": 10, - "id": "10e5ece6", - "metadata": {}, - "outputs": [], + "execution_count": 2, + "id": "dfff4a95-99b2-4dba-b80d-9c3855046ef1", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1mCustomLLM\u001b[0m\n", + "Params: {'model_name': 'CustomChatModel'}\n" + ] + } + ], "source": [ - "llm = CustomLLM(n=10)" + "llm = CustomLLM(n=5)\n", + "print(llm)" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 3, "id": "8cd49199", - "metadata": {}, + "metadata": { + "tags": [] + }, "outputs": [ { "data": { "text/plain": [ - "'This is a '" + "'This '" ] }, - "execution_count": 11, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -105,39 +220,209 @@ ] }, { - "cell_type": "markdown", - "id": "bbfebea1", - "metadata": {}, + "cell_type": "code", + "execution_count": 4, + "id": "511b3cb1-9c6f-49b6-9002-a2ec490632b0", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "'world'" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], "source": [ - "We can also print the LLM and see its custom print." 
+ "await llm.ainvoke(\"world\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "d9d5bec2-d60a-4ebd-a97d-ac32c98ab02f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['woof ', 'meow ']" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llm.batch([\"woof woof woof\", \"meow meow meow\"])" ] }, { "cell_type": "code", "execution_count": 6, - "id": "9c33fa19", - "metadata": {}, + "id": "fe246b29-7a93-4bef-8861-389445598c25", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "data": { + "text/plain": [ + "['woof ', 'meow ']" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "await llm.abatch([\"woof woof woof\", \"meow meow meow\"])" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "3a67c38f-b83b-4eb9-a231-441c55ee8c82", + "metadata": { + "tags": [] + }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1mCustomLLM\u001b[0m\n", - "Params: {'n': 10}\n" + "h|e|l|l|o|" ] } ], "source": [ - "print(llm)" + "async for token in llm.astream(\"hello\"):\n", + " print(token, end=\"|\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "id": "b62c282b-3a35-4529-aac4-2c2f0916790e", + "metadata": {}, + "source": [ + "Let's confirm that in integrates nicely with other `LangChain` APIs." ] }, { "cell_type": "code", - "execution_count": null, - "id": "6dac3f47", - "metadata": {}, + "execution_count": 15, + "id": "d5578e74-7fa8-4673-afee-7a59d442aaff", + "metadata": { + "tags": [] + }, "outputs": [], - "source": [] + "source": [ + "from langchain_core.prompts import ChatPromptTemplate" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "672ff664-8673-4832-9f4f-335253880141", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "prompt = ChatPromptTemplate.from_messages(\n", + " [(\"system\", \"you are a bot\"), (\"human\", \"{input}\")]\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "c400538a-9146-4c93-9fac-293d8f9ca6bf", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "llm = CustomLLM(n=7)\n", + "chain = prompt | llm" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "080964af-3e2d-4573-85cb-0d7cc58a6f42", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'event': 'on_chain_start', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}\n", + "{'event': 'on_prompt_start', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}\n", + "{'event': 'on_prompt_end', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}, 'output': ChatPromptValue(messages=[SystemMessage(content='you are a bot'), HumanMessage(content='hello there!')])}}\n", + "{'event': 'on_llm_start', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'input': {'prompts': ['System: you are a bot\\nHuman: hello there!']}}}\n", + "{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': 
{'chunk': 'S'}}\n", + "{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'S'}}\n", + "{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'y'}}\n", + "{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'y'}}\n" + ] + } + ], + "source": [ + "idx = 0\n", + "async for event in chain.astream_events({\"input\": \"hello there!\"}, version=\"v1\"):\n", + " print(event)\n", + " idx += 1\n", + " if idx > 7:\n", + " # Truncate\n", + " break" + ] + }, + { + "cell_type": "markdown", + "id": "a85e848a-5316-4318-b770-3f8fd34f4231", + "metadata": {}, + "source": [ + "## Contributing\n", + "\n", + "We appreciate all custom LLM integration contributions. \n", + "\n", + "Here's a checklist to help make sure your contribution gets added to LangChain:\n", + "\n", + "Documentation:\n", + "\n", + "* The model contains doc-strings for all initialization arguments, as these will be surfaced in the [API Reference](https://api.python.langchain.com/en/stable/langchain_api_reference.html).\n", + "* The class doc-string for the model contains a link to the model API if the model is powered by a service.\n", + "\n", + "Tests:\n", + "\n", + "* [ ] Add unit or integration tests to the overridden methods. Verify that `invoke`, `ainvoke`, `batch`, `stream` work if you've overridden the corresponding code.\n", + "\n", + "Streaming (if you're implementing it):\n", + "\n", + "* [ ] Make sure to invoke the `on_llm_new_token` callback\n", + "* [ ] `on_llm_new_token` is invoked BEFORE yielding the chunk\n", + "\n", + "Stop Token Behavior:\n", + "\n", + "* [ ] Stop token should be respected\n", + "* [ ] Stop token should be INCLUDED as part of the response\n", + "\n", + "Secret API Keys:\n", + "\n", + "* [ ] If your model connects to an API it will likely accept API keys as part of its initialization. Use Pydantic's `SecretStr` type for secrets, so they don't get accidentally printed out when folks print the model." + ] } ], "metadata": { @@ -156,7 +441,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.1" + "version": "3.11.4" } }, "nbformat": 4, diff --git a/libs/core/langchain_core/language_models/llms.py b/libs/core/langchain_core/language_models/llms.py index fc741dbf448..e307e8035c3 100644 --- a/libs/core/langchain_core/language_models/llms.py +++ b/libs/core/langchain_core/language_models/llms.py @@ -557,6 +557,25 @@ class BaseLLM(BaseLanguageModel[str], ABC): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: + """Stream the LLM on the given prompt. + + This method should be overridden by subclasses that support streaming. + + If not implemented, the default behavior of calls to stream will be to + fall back to the non-streaming version of the model and return + the output as a single chunk. + + Args: + prompt: The prompt to generate from. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + run_manager: Callback manager for the run. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + An iterator of GenerationChunks. 
+ """ raise NotImplementedError() async def _astream( @@ -566,6 +585,23 @@ class BaseLLM(BaseLanguageModel[str], ABC): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> AsyncIterator[GenerationChunk]: + """An async version of the _stream method. + + The default implementation uses the synchronous _stream method and wraps it in + an async iterator. Subclasses that need to provide a true async implementation + should override this method. + + Args: + prompt: The prompt to generate from. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of these substrings. + run_manager: Callback manager for the run. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + An async iterator of GenerationChunks. + """ iterator = await run_in_executor( None, self._stream, @@ -1182,10 +1218,28 @@ class BaseLLM(BaseLanguageModel[str], ABC): class LLM(BaseLLM): - """Base LLM abstract class. + """This class exposes a simple interface for implementing a custom LLM. - The purpose of this class is to expose a simpler interface for working - with LLMs, rather than expect the user to implement the full _generate method. + You should subclass this class and implement the following: + + - `_call` method: Run the LLM on the given prompt and input (used by `invoke`). + - `_identifying_params` property: Return a dictionary of the identifying parameters + This is critical for caching and tracing purposes. Identifying parameters + is a dict that identifies the LLM. + It should mostly include a `model_name`. + + Optional: Override the following methods to provide more optimizations: + + - `_acall`: Provide a native async version of the `_call` method. + If not provided, will delegate to the synchronous version using + `run_in_executor`. (Used by `ainvoke`). + - `_stream`: Stream the LLM on the given prompt and input. + `stream` will use `_stream` if provided, otherwise it + use `_call` and output will arrive in one chunk. + - `_astream`: Override to provide a native async version of the `_stream` method. + `astream` will use `_astream` if provided, otherwise it will implement + a fallback behavior that will use `_stream` if `_stream` is implemented, + and use `_acall` if `_stream` is not implemented. """ @abstractmethod @@ -1196,7 +1250,22 @@ class LLM(BaseLLM): run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: - """Run the LLM on the given prompt and input.""" + """Run the LLM on the given input. + + Override this method to implement the LLM logic. + + Args: + prompt: The prompt to generate from. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of the stop substrings. + If stop tokens are not supported consider raising NotImplementedError. + run_manager: Callback manager for the run. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + The model output as a string. SHOULD NOT include the prompt. + """ async def _acall( self, @@ -1205,7 +1274,24 @@ class LLM(BaseLLM): run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: - """Run the LLM on the given prompt and input.""" + """Async version of the _call method. + + The default implementation delegates to the synchronous _call method using + `run_in_executor`. 
Subclasses that need to provide a true async implementation + should override this method to reduce the overhead of using `run_in_executor`. + + Args: + prompt: The prompt to generate from. + stop: Stop words to use when generating. Model output is cut off at the + first occurrence of any of the stop substrings. + If stop tokens are not supported consider raising NotImplementedError. + run_manager: Callback manager for the run. + **kwargs: Arbitrary additional keyword arguments. These are usually passed + to the model provider API call. + + Returns: + The model output as a string. SHOULD NOT include the prompt. + """ return await run_in_executor( None, self._call,
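The docstrings above describe how `astream` falls back to the synchronous `_stream` (run in an executor) when no native async implementation is provided. The notebook in this patch only implements the synchronous `_stream`, so here is a minimal, purely illustrative sketch of what a native `_astream` override could look like for that `CustomLLM`; the subclass name is hypothetical, and it assumes `AsyncCallbackManagerForLLMRun` from `langchain_core.callbacks.manager`:

```python
from typing import Any, AsyncIterator, List, Optional

from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
from langchain_core.outputs import GenerationChunk


class CustomLLMWithAsyncStream(CustomLLM):
    """Illustrative subclass of the notebook's CustomLLM with a native async stream."""

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        # Mirror the synchronous _stream: emit one character at a time.
        for char in prompt[: self.n]:
            chunk = GenerationChunk(text=char)
            if run_manager:
                # The async callback manager's on_llm_new_token is awaited,
                # and it is invoked BEFORE the chunk is yielded.
                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk
```

With an override like this, `astream` and `astream_events` no longer need the executor-based fallback described in the docstrings above.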