docs: Update documentation for custom LLMs (#19972)

Update documentation for customizing LLMs
Eugene Yurtsev 2024-04-11 12:21:27 -04:00 committed by GitHub
parent 799714c629
commit 653489a1a9
2 changed files with 429 additions and 58 deletions


@ -9,44 +9,73 @@
"\n",
"This notebook goes over how to create a custom LLM wrapper, in case you want to use your own LLM or a different wrapper than one that is supported in LangChain.\n",
"\n",
"Wrapping your LLM with the standard `LLM` interface allow you to use your LLM in existing LangChain programs with minimal code modifications!\n",
"\n",
"As an bonus, your LLM will automatically become a LangChain `Runnable` and will benefit from some optimizations out of the box, async support, the `astream_events` API, etc.\n",
"\n",
"## Implementation\n",
"\n",
"There are only two required things that a custom LLM needs to implement:\n",
"\n",
"- A `_call` method that takes in a string, some optional stop words, and returns a string.\n",
"- A `_llm_type` property that returns a string. Used for logging purposes only.\n",
"\n",
"There is a second optional thing it can implement:\n",
"| Method | Description |\n",
"|---------------|---------------------------------------------------------------------------|\n",
"| `_call` | Takes in a string and some optional stop words, and returns a string. Used by `invoke`. |\n",
"| `_llm_type` | A property that returns a string, used for logging purposes only. \n",
"\n",
"- An `_identifying_params` property that is used to help with printing of this class. Should return a dictionary.\n",
"\n",
"Let's implement a very simple custom LLM that just returns the first n characters of the input."
"\n",
"Optional implementations: \n",
"\n",
"\n",
"| Method | Description |\n",
"|----------------------|-----------------------------------------------------------------------------------------------------------|\n",
"| `_identifying_params` | Used to help with identifying the model and printing the LLM; should return a dictionary. This is a **@property**. |\n",
"| `_acall` | Provides an async native implementation of `_call`, used by `ainvoke`. |\n",
"| `_stream` | Method to stream the output token by token. |\n",
"| `_astream` | Provides an async native implementation of `_stream`; in newer LangChain versions, defaults to `_stream`. |\n",
"\n",
"\n",
"\n",
"Let's implement a simple custom LLM that just returns the first n characters of the input."
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a65696a0",
"metadata": {},
"execution_count": 1,
"id": "2e9bb32f-6fd1-46ac-b32f-d175663710c0",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"from typing import Any, List, Mapping, Optional\n",
"from typing import Any, Dict, Iterator, List, Mapping, Optional\n",
"\n",
"from langchain_core.callbacks.manager import CallbackManagerForLLMRun\n",
"from langchain_core.language_models.llms import LLM"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "d5ceff02",
"metadata": {},
"outputs": [],
"source": [
"class CustomLLM(LLM):\n",
" n: int\n",
"from langchain_core.language_models.llms import LLM\n",
"from langchain_core.outputs import GenerationChunk\n",
"\n",
" @property\n",
" def _llm_type(self) -> str:\n",
" return \"custom\"\n",
"\n",
"class CustomLLM(LLM):\n",
" \"\"\"A custom chat model that echoes the first `n` characters of the input.\n",
"\n",
" When contributing an implementation to LangChain, carefully document\n",
" the model including the initialization parameters, include\n",
" an example of how to initialize the model and include any relevant\n",
" links to the underlying models documentation or API.\n",
"\n",
" Example:\n",
"\n",
" .. code-block:: python\n",
"\n",
" model = CustomChatModel(n=2)\n",
" result = model.invoke([HumanMessage(content=\"hello\")])\n",
" result = model.batch([[HumanMessage(content=\"hello\")],\n",
" [HumanMessage(content=\"world\")]])\n",
" \"\"\"\n",
"\n",
" n: int\n",
" \"\"\"The number of characters from the last message of the prompt to be echoed.\"\"\"\n",
"\n",
" def _call(\n",
" self,\n",
@ -55,47 +84,133 @@
" run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
" **kwargs: Any,\n",
" ) -> str:\n",
" \"\"\"Run the LLM on the given input.\n",
"\n",
" Override this method to implement the LLM logic.\n",
"\n",
" Args:\n",
" prompt: The prompt to generate from.\n",
" stop: Stop words to use when generating. Model output is cut off at the\n",
" first occurrence of any of the stop substrings.\n",
" If stop tokens are not supported consider raising NotImplementedError.\n",
" run_manager: Callback manager for the run.\n",
" **kwargs: Arbitrary additional keyword arguments. These are usually passed\n",
" to the model provider API call.\n",
"\n",
" Returns:\n",
" The model output as a string. Actual completions SHOULD NOT include the prompt.\n",
" \"\"\"\n",
" if stop is not None:\n",
" raise ValueError(\"stop kwargs are not permitted.\")\n",
" return prompt[: self.n]\n",
"\n",
" def _stream(\n",
" self,\n",
" prompt: str,\n",
" stop: Optional[List[str]] = None,\n",
" run_manager: Optional[CallbackManagerForLLMRun] = None,\n",
" **kwargs: Any,\n",
" ) -> Iterator[GenerationChunk]:\n",
" \"\"\"Stream the LLM on the given prompt.\n",
"\n",
" This method should be overridden by subclasses that support streaming.\n",
"\n",
" If not implemented, the default behavior of calls to stream will be to\n",
" fallback to the non-streaming version of the model and return\n",
" the output as a single chunk.\n",
"\n",
" Args:\n",
" prompt: The prompt to generate from.\n",
" stop: Stop words to use when generating. Model output is cut off at the\n",
" first occurrence of any of these substrings.\n",
" run_manager: Callback manager for the run.\n",
" **kwargs: Arbitrary additional keyword arguments. These are usually passed\n",
" to the model provider API call.\n",
"\n",
" Returns:\n",
" An iterator of GenerationChunks.\n",
" \"\"\"\n",
" for char in prompt[: self.n]:\n",
" chunk = GenerationChunk(text=char)\n",
" if run_manager:\n",
" run_manager.on_llm_new_token(chunk.text, chunk=chunk)\n",
"\n",
" yield chunk\n",
"\n",
" @property\n",
" def _identifying_params(self) -> Mapping[str, Any]:\n",
" \"\"\"Get the identifying parameters.\"\"\"\n",
" return {\"n\": self.n}"
" def _identifying_params(self) -> Dict[str, Any]:\n",
" \"\"\"Return a dictionary of identifying parameters.\"\"\"\n",
" return {\n",
" # The model name allows users to specify custom token counting\n",
" # rules in LLM monitoring applications (e.g., in LangSmith users\n",
" # can provide per token pricing for their model and monitor\n",
" # costs for the given LLM.)\n",
" \"model_name\": \"CustomChatModel\",\n",
" }\n",
"\n",
" @property\n",
" def _llm_type(self) -> str:\n",
" \"\"\"Get the type of language model used by this chat model. Used for logging purposes only.\"\"\"\n",
" return \"custom\""
]
},
{
"cell_type": "markdown",
"id": "714dede0",
"metadata": {},
"id": "f614fb7b-e476-4d81-821b-57a2ebebe21c",
"metadata": {
"tags": []
},
"source": [
"We can now use this as an any other LLM."
"### Let's test it 🧪"
]
},
{
"cell_type": "markdown",
"id": "e3feae15-4afc-49f4-8542-93867d4ea769",
"metadata": {
"tags": []
},
"source": [
"This LLM will implement the standard `Runnable` interface of LangChain which many of the LangChain abstractions support!"
]
},
{
"cell_type": "code",
"execution_count": 10,
"id": "10e5ece6",
"metadata": {},
"outputs": [],
"execution_count": 2,
"id": "dfff4a95-99b2-4dba-b80d-9c3855046ef1",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1mCustomLLM\u001b[0m\n",
"Params: {'model_name': 'CustomChatModel'}\n"
]
}
],
"source": [
"llm = CustomLLM(n=10)"
"llm = CustomLLM(n=5)\n",
"print(llm)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 3,
"id": "8cd49199",
"metadata": {},
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'This is a '"
"'This '"
]
},
"execution_count": 11,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
@ -105,39 +220,209 @@
]
},
{
"cell_type": "markdown",
"id": "bbfebea1",
"metadata": {},
"cell_type": "code",
"execution_count": 4,
"id": "511b3cb1-9c6f-49b6-9002-a2ec490632b0",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"'world'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"We can also print the LLM and see its custom print."
"await llm.ainvoke(\"world\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "d9d5bec2-d60a-4ebd-a97d-ac32c98ab02f",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"['woof ', 'meow ']"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.batch([\"woof woof woof\", \"meow meow meow\"])"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "9c33fa19",
"metadata": {},
"id": "fe246b29-7a93-4bef-8861-389445598c25",
"metadata": {
"tags": []
},
"outputs": [
{
"data": {
"text/plain": [
"['woof ', 'meow ']"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"await llm.abatch([\"woof woof woof\", \"meow meow meow\"])"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "3a67c38f-b83b-4eb9-a231-441c55ee8c82",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1mCustomLLM\u001b[0m\n",
"Params: {'n': 10}\n"
"h|e|l|l|o|"
]
}
],
"source": [
"print(llm)"
"async for token in llm.astream(\"hello\"):\n",
" print(token, end=\"|\", flush=True)"
]
},
{
"cell_type": "markdown",
"id": "b62c282b-3a35-4529-aac4-2c2f0916790e",
"metadata": {},
"source": [
"Let's confirm that in integrates nicely with other `LangChain` APIs."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6dac3f47",
"metadata": {},
"execution_count": 15,
"id": "d5578e74-7fa8-4673-afee-7a59d442aaff",
"metadata": {
"tags": []
},
"outputs": [],
"source": []
"source": [
"from langchain_core.prompts import ChatPromptTemplate"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "672ff664-8673-4832-9f4f-335253880141",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"prompt = ChatPromptTemplate.from_messages(\n",
" [(\"system\", \"you are a bot\"), (\"human\", \"{input}\")]\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "c400538a-9146-4c93-9fac-293d8f9ca6bf",
"metadata": {
"tags": []
},
"outputs": [],
"source": [
"llm = CustomLLM(n=7)\n",
"chain = prompt | llm"
]
},
{
"cell_type": "code",
"execution_count": 18,
"id": "080964af-3e2d-4573-85cb-0d7cc58a6f42",
"metadata": {
"tags": []
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'event': 'on_chain_start', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'name': 'RunnableSequence', 'tags': [], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}\n",
"{'event': 'on_prompt_start', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}}}\n",
"{'event': 'on_prompt_end', 'name': 'ChatPromptTemplate', 'run_id': '7e996251-a926-4344-809e-c425a9846d21', 'tags': ['seq:step:1'], 'metadata': {}, 'data': {'input': {'input': 'hello there!'}, 'output': ChatPromptValue(messages=[SystemMessage(content='you are a bot'), HumanMessage(content='hello there!')])}}\n",
"{'event': 'on_llm_start', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'input': {'prompts': ['System: you are a bot\\nHuman: hello there!']}}}\n",
"{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'S'}}\n",
"{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'S'}}\n",
"{'event': 'on_llm_stream', 'name': 'CustomLLM', 'run_id': 'a8766beb-10f4-41de-8750-3ea7cf0ca7e2', 'tags': ['seq:step:2'], 'metadata': {}, 'data': {'chunk': 'y'}}\n",
"{'event': 'on_chain_stream', 'run_id': '05f24b4f-7ea3-4fb6-8417-3aa21633462f', 'tags': [], 'metadata': {}, 'name': 'RunnableSequence', 'data': {'chunk': 'y'}}\n"
]
}
],
"source": [
"idx = 0\n",
"async for event in chain.astream_events({\"input\": \"hello there!\"}, version=\"v1\"):\n",
" print(event)\n",
" idx += 1\n",
" if idx > 7:\n",
" # Truncate\n",
" break"
]
},
{
"cell_type": "markdown",
"id": "a85e848a-5316-4318-b770-3f8fd34f4231",
"metadata": {},
"source": [
"## Contributing\n",
"\n",
"We appreciate all chat model integration contributions. \n",
"\n",
"Here's a checklist to help make sure your contribution gets added to LangChain:\n",
"\n",
"Documentation:\n",
"\n",
"* The model contains doc-strings for all initialization arguments, as these will be surfaced in the [APIReference](https://api.python.langchain.com/en/stable/langchain_api_reference.html).\n",
"* The class doc-string for the model contains a link to the model API if the model is powered by a service.\n",
"\n",
"Tests:\n",
"\n",
"* [ ] Add unit or integration tests to the overridden methods. Verify that `invoke`, `ainvoke`, `batch`, `stream` work if you've over-ridden the corresponding code.\n",
"\n",
"Streaming (if you're implementing it):\n",
"\n",
"* [ ] Make sure to invoke the `on_llm_new_token` callback\n",
"* [ ] `on_llm_new_token` is invoked BEFORE yielding the chunk\n",
"\n",
"Stop Token Behavior:\n",
"\n",
"* [ ] Stop token should be respected\n",
"* [ ] Stop token should be INCLUDED as part of the response\n",
"\n",
"Secret API Keys:\n",
"\n",
"* [ ] If your model connects to an API it will likely accept API keys as part of its initialization. Use Pydantic's `SecretStr` type for secrets, so they don't get accidentally printed out when folks print the model."
]
}
],
"metadata": {
@ -156,7 +441,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.1"
"version": "3.11.4"
}
},
"nbformat": 4,


@ -557,6 +557,25 @@ class BaseLLM(BaseLanguageModel[str], ABC):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[GenerationChunk]:
"""Stream the LLM on the given prompt.
This method should be overridden by subclasses that support streaming.
If not implemented, the default behavior of calls to stream will be to
fall back to the non-streaming version of the model and return
the output as a single chunk.
Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An iterator of GenerationChunks.
"""
raise NotImplementedError()
async def _astream(
@ -566,6 +585,23 @@ class BaseLLM(BaseLanguageModel[str], ABC):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[GenerationChunk]:
"""An async version of the _stream method.
The default implementation uses the synchronous _stream method and wraps it in
an async iterator. Subclasses that need to provide a true async implementation
should override this method.
Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of these substrings.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
An async iterator of GenerationChunks.
"""
iterator = await run_in_executor(
None,
self._stream,
@ -1182,10 +1218,28 @@ class BaseLLM(BaseLanguageModel[str], ABC):
class LLM(BaseLLM):
"""Base LLM abstract class.
"""This class exposes a simple interface for implementing a custom LLM.
The purpose of this class is to expose a simpler interface for working
with LLMs, rather than expect the user to implement the full _generate method.
You should subclass this class and implement the following:
- `_call` method: Run the LLM on the given prompt and input (used by `invoke`).
- `_identifying_params` property: Return a dictionary of the identifying parameters.
This is critical for caching and tracing purposes. The identifying parameters
are a dict that identifies the LLM.
It should mostly include a `model_name`.
Optional: Override the following methods to provide more optimizations:
- `_acall`: Provide a native async version of the `_call` method.
If not provided, will delegate to the synchronous version using
`run_in_executor`. (Used by `ainvoke`).
- `_stream`: Stream the LLM on the given prompt and input.
`stream` will use `_stream` if provided, otherwise it
will use `_call` and the output will arrive in one chunk.
- `_astream`: Override to provide a native async version of the `_stream` method.
`astream` will use `_astream` if provided, otherwise it will implement
a fallback behavior that will use `_stream` if `_stream` is implemented,
and use `_acall` if `_stream` is not implemented.
"""
@abstractmethod
@ -1196,7 +1250,22 @@ class LLM(BaseLLM):
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given prompt and input."""
"""Run the LLM on the given input.
Override this method to implement the LLM logic.
Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of the stop substrings.
If stop tokens are not supported consider raising NotImplementedError.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
The model output as a string. SHOULD NOT include the prompt.
"""
async def _acall(
self,
@ -1205,7 +1274,24 @@ class LLM(BaseLLM):
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given prompt and input."""
"""Async version of the _call method.
The default implementation delegates to the synchronous _call method using
`run_in_executor`. Subclasses that need to provide a true async implementation
should override this method to reduce the overhead of using `run_in_executor`.
Args:
prompt: The prompt to generate from.
stop: Stop words to use when generating. Model output is cut off at the
first occurrence of any of the stop substrings.
If stop tokens are not supported consider raising NotImplementedError.
run_manager: Callback manager for the run.
**kwargs: Arbitrary additional keyword arguments. These are usually passed
to the model provider API call.
Returns:
The model output as a string. SHOULD NOT include the prompt.
"""
return await run_in_executor(
None,
self._call,