This commit is contained in:
Erick Friis 2024-11-21 19:37:17 -08:00
parent 46ea6722f4
commit bffca0d5c2
5 changed files with 307 additions and 54 deletions

View File

@ -90,12 +90,19 @@ from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import AIMessageChunk, BaseMessage, AIMessage
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field
class CustomChatModelAdvanced(BaseChatModel):
"""A custom chat model that echoes the first `n` characters of the input.
class ChatParrotLink(BaseChatModel):
"""A custom chat model that echoes the first `parrot_buffer_length` characters
of the input.
When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
@ -106,16 +113,21 @@ class CustomChatModelAdvanced(BaseChatModel):
.. code-block:: python
model = CustomChatModel(n=2)
model = ChatParrotLink(parrot_buffer_length=2, model="bird-brain-001")
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""
model_name: str
model_name: str = Field(alias="model")
"""The name of the model"""
n: int
parrot_buffer_length: int
"""The number of characters from the last message of the prompt to be echoed."""
temperature: Optional[float] = None
max_tokens: Optional[int] = None
timeout: Optional[int] = None
stop: Optional[List[str]] = None
max_retries: int = 2
def _generate(
self,
@ -142,13 +154,20 @@ class CustomChatModelAdvanced(BaseChatModel):
# Replace this with actual logic to generate a response from a list
# of messages.
last_message = messages[-1]
tokens = last_message.content[: self.n]
tokens = last_message.content[: self.parrot_buffer_length]
ct_input_tokens = sum(len(message.content) for message in messages)
ct_output_tokens = len(tokens)
message = AIMessage(
content=tokens,
additional_kwargs={}, # Used to add additional payload (e.g., function calling request)
additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata
"time_in_seconds": 3,
},
usage_metadata={
"input_tokens": ct_input_tokens,
"output_tokens": ct_output_tokens,
"total_tokens": ct_input_tokens + ct_output_tokens,
},
)
##
@ -180,10 +199,21 @@ class CustomChatModelAdvanced(BaseChatModel):
run_manager: A run manager with callbacks for the LLM.
"""
last_message = messages[-1]
tokens = last_message.content[: self.n]
tokens = str(last_message.content[: self.parrot_buffer_length])
ct_input_tokens = sum(len(message.content) for message in messages)
for token in tokens:
chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))
usage_metadata = UsageMetadata(
{
"input_tokens": ct_input_tokens,
"output_tokens": 1,
"total_tokens": ct_input_tokens + 1,
}
)
ct_input_tokens = 0
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
)
if run_manager:
# This is optional in newer versions of LangChain

View File

@ -48,7 +48,7 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 4,
"id": "c5046e6a-8b09-4a99-b6e6-7a605aac5738",
"metadata": {
"tags": []
@ -175,12 +175,19 @@
" CallbackManagerForLLMRun,\n",
")\n",
"from langchain_core.language_models import BaseChatModel\n",
"from langchain_core.messages import AIMessageChunk, BaseMessage, HumanMessage\n",
"from langchain_core.messages import (\n",
" AIMessage,\n",
" AIMessageChunk,\n",
" BaseMessage,\n",
")\n",
"from langchain_core.messages.ai import UsageMetadata\n",
"from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult\n",
"from pydantic import Field\n",
"\n",
"\n",
"class CustomChatModelAdvanced(BaseChatModel):\n",
" \"\"\"A custom chat model that echoes the first `n` characters of the input.\n",
"class ChatParrotLink(BaseChatModel):\n",
" \"\"\"A custom chat model that echoes the first `parrot_buffer_length` characters\n",
" of the input.\n",
"\n",
" When contributing an implementation to LangChain, carefully document\n",
" the model including the initialization parameters, include\n",
@ -191,16 +198,21 @@
"\n",
" .. code-block:: python\n",
"\n",
" model = CustomChatModel(n=2)\n",
" model = ChatParrotLink(parrot_buffer_length=2, model=\"bird-brain-001\")\n",
" result = model.invoke([HumanMessage(content=\"hello\")])\n",
" result = model.batch([[HumanMessage(content=\"hello\")],\n",
" [HumanMessage(content=\"world\")]])\n",
" \"\"\"\n",
"\n",
" model_name: str\n",
" model_name: str = Field(alias=\"model\")\n",
" \"\"\"The name of the model\"\"\"\n",
" n: int\n",
" parrot_buffer_length: int\n",
" \"\"\"The number of characters from the last message of the prompt to be echoed.\"\"\"\n",
" temperature: Optional[float] = None\n",
" max_tokens: Optional[int] = None\n",
" timeout: Optional[int] = None\n",
" stop: Optional[List[str]] = None\n",
" max_retries: int = 2\n",
"\n",
" def _generate(\n",
" self,\n",
@ -227,13 +239,20 @@
" # Replace this with actual logic to generate a response from a list\n",
" # of messages.\n",
" last_message = messages[-1]\n",
" tokens = last_message.content[: self.n]\n",
" tokens = last_message.content[: self.parrot_buffer_length]\n",
" ct_input_tokens = sum(len(message.content) for message in messages)\n",
" ct_output_tokens = len(tokens)\n",
" message = AIMessage(\n",
" content=tokens,\n",
" additional_kwargs={}, # Used to add additional payload (e.g., function calling request)\n",
" additional_kwargs={}, # Used to add additional payload to the message\n",
" response_metadata={ # Use for response metadata\n",
" \"time_in_seconds\": 3,\n",
" },\n",
" usage_metadata={\n",
" \"input_tokens\": ct_input_tokens,\n",
" \"output_tokens\": ct_output_tokens,\n",
" \"total_tokens\": ct_input_tokens + ct_output_tokens,\n",
" },\n",
" )\n",
" ##\n",
"\n",
@ -265,10 +284,21 @@
" run_manager: A run manager with callbacks for the LLM.\n",
" \"\"\"\n",
" last_message = messages[-1]\n",
" tokens = last_message.content[: self.n]\n",
" tokens = str(last_message.content[: self.parrot_buffer_length])\n",
" ct_input_tokens = sum(len(message.content) for message in messages)\n",
"\n",
" for token in tokens:\n",
" chunk = ChatGenerationChunk(message=AIMessageChunk(content=token))\n",
" usage_metadata = UsageMetadata(\n",
" {\n",
" \"input_tokens\": ct_input_tokens,\n",
" \"output_tokens\": 1,\n",
" \"total_tokens\": ct_input_tokens + 1,\n",
" }\n",
" )\n",
" ct_input_tokens = 0\n",
" chunk = ChatGenerationChunk(\n",
" message=AIMessageChunk(content=token, usage_metadata=usage_metadata)\n",
" )\n",
"\n",
" if run_manager:\n",
" # This is optional in newer versions of LangChain\n",
@ -320,7 +350,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 5,
"id": "27689f30-dcd2-466b-ba9d-f60b7d434110",
"metadata": {
"tags": []
@ -329,16 +359,16 @@
{
"data": {
"text/plain": [
"AIMessage(content='Meo', response_metadata={'time_in_seconds': 3}, id='run-ddb42bd6-4fdd-4bd2-8be5-e11b67d3ac29-0')"
"AIMessage(content='Meo', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-cf11aeb6-8ab6-43d7-8c68-c1ef89b6d78e-0', usage_metadata={'input_tokens': 26, 'output_tokens': 3, 'total_tokens': 29})"
]
},
"execution_count": 6,
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = CustomChatModelAdvanced(n=3, model_name=\"my_custom_model\")\n",
"model = ChatParrotLink(parrot_buffer_length=3, model=\"my_custom_model\")\n",
"\n",
"model.invoke(\n",
" [\n",
@ -351,7 +381,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 6,
"id": "406436df-31bf-466b-9c3d-39db9d6b6407",
"metadata": {
"tags": []
@ -360,10 +390,10 @@
{
"data": {
"text/plain": [
"AIMessage(content='hel', response_metadata={'time_in_seconds': 3}, id='run-4d3cc912-44aa-454b-977b-ca02be06c12e-0')"
"AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-618e5ed4-d611-4083-8cf1-c270726be8d9-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8})"
]
},
"execution_count": 7,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@ -374,7 +404,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 7,
"id": "a72ffa46-6004-41ef-bbe4-56fa17a029e2",
"metadata": {
"tags": []
@ -383,11 +413,11 @@
{
"data": {
"text/plain": [
"[AIMessage(content='hel', response_metadata={'time_in_seconds': 3}, id='run-9620e228-1912-4582-8aa1-176813afec49-0'),\n",
" AIMessage(content='goo', response_metadata={'time_in_seconds': 3}, id='run-1ce8cdf8-6f75-448e-82f7-1bb4a121df93-0')]"
"[AIMessage(content='hel', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-eea4ed7d-d750-48dc-90c0-7acca1ff388f-0', usage_metadata={'input_tokens': 5, 'output_tokens': 3, 'total_tokens': 8}),\n",
" AIMessage(content='goo', additional_kwargs={}, response_metadata={'time_in_seconds': 3}, id='run-07cfc5c1-3c62-485f-b1e0-3d46e1547287-0', usage_metadata={'input_tokens': 7, 'output_tokens': 3, 'total_tokens': 10})]"
]
},
"execution_count": 8,
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
@ -398,7 +428,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 8,
"id": "3633be2c-2ea0-42f9-a72f-3b5240690b55",
"metadata": {
"tags": []
@ -427,7 +457,7 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 9,
"id": "b7d73995-eeab-48c6-a7d8-32c98ba29fc2",
"metadata": {
"tags": []
@ -456,7 +486,7 @@
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 10,
"id": "17840eba-8ff4-4e73-8e4f-85f16eb1c9d0",
"metadata": {
"tags": []
@ -466,20 +496,12 @@
"name": "stdout",
"output_type": "stream",
"text": [
"{'event': 'on_chat_model_start', 'run_id': '125a2a16-b9cd-40de-aa08-8aa9180b07d0', 'name': 'CustomChatModelAdvanced', 'tags': [], 'metadata': {}, 'data': {'input': 'cat'}}\n",
"{'event': 'on_chat_model_stream', 'run_id': '125a2a16-b9cd-40de-aa08-8aa9180b07d0', 'tags': [], 'metadata': {}, 'name': 'CustomChatModelAdvanced', 'data': {'chunk': AIMessageChunk(content='c', id='run-125a2a16-b9cd-40de-aa08-8aa9180b07d0')}}\n",
"{'event': 'on_chat_model_stream', 'run_id': '125a2a16-b9cd-40de-aa08-8aa9180b07d0', 'tags': [], 'metadata': {}, 'name': 'CustomChatModelAdvanced', 'data': {'chunk': AIMessageChunk(content='a', id='run-125a2a16-b9cd-40de-aa08-8aa9180b07d0')}}\n",
"{'event': 'on_chat_model_stream', 'run_id': '125a2a16-b9cd-40de-aa08-8aa9180b07d0', 'tags': [], 'metadata': {}, 'name': 'CustomChatModelAdvanced', 'data': {'chunk': AIMessageChunk(content='t', id='run-125a2a16-b9cd-40de-aa08-8aa9180b07d0')}}\n",
"{'event': 'on_chat_model_stream', 'run_id': '125a2a16-b9cd-40de-aa08-8aa9180b07d0', 'tags': [], 'metadata': {}, 'name': 'CustomChatModelAdvanced', 'data': {'chunk': AIMessageChunk(content='', response_metadata={'time_in_sec': 3}, id='run-125a2a16-b9cd-40de-aa08-8aa9180b07d0')}}\n",
"{'event': 'on_chat_model_end', 'name': 'CustomChatModelAdvanced', 'run_id': '125a2a16-b9cd-40de-aa08-8aa9180b07d0', 'tags': [], 'metadata': {}, 'data': {'output': AIMessageChunk(content='cat', response_metadata={'time_in_sec': 3}, id='run-125a2a16-b9cd-40de-aa08-8aa9180b07d0')}}\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/eugene/src/langchain/libs/core/langchain_core/_api/beta_decorator.py:87: LangChainBetaWarning: This API is in beta and may change in the future.\n",
" warn_beta(\n"
"{'event': 'on_chat_model_start', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'name': 'ChatParrotLink', 'tags': [], 'metadata': {}, 'data': {'input': 'cat'}, 'parent_ids': []}\n",
"{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='c', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 3, 'output_tokens': 1, 'total_tokens': 4})}, 'parent_ids': []}\n",
"{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='a', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}\n",
"{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='t', additional_kwargs={}, response_metadata={}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 0, 'output_tokens': 1, 'total_tokens': 1})}, 'parent_ids': []}\n",
"{'event': 'on_chat_model_stream', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'name': 'ChatParrotLink', 'data': {'chunk': AIMessageChunk(content='', additional_kwargs={}, response_metadata={'time_in_sec': 3}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a')}, 'parent_ids': []}\n",
"{'event': 'on_chat_model_end', 'name': 'ChatParrotLink', 'run_id': '3f0b5501-5c78-45b3-92fc-8322a6a5024a', 'tags': [], 'metadata': {}, 'data': {'output': AIMessageChunk(content='cat', additional_kwargs={}, response_metadata={'time_in_sec': 3}, id='run-3f0b5501-5c78-45b3-92fc-8322a6a5024a', usage_metadata={'input_tokens': 3, 'output_tokens': 3, 'total_tokens': 6})}, 'parent_ids': []}\n"
]
}
],
@ -545,7 +567,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"display_name": ".venv",
"language": "python",
"name": "python3"
},
@ -559,7 +581,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
"version": "3.11.4"
}
},
"nbformat": 4,

View File

@ -493,9 +493,13 @@ class ChatModelIntegrationTests(ChatModelTests):
message=AIMessage(
content="Output text",
usage_metadata={
"input_tokens": 0,
"output_tokens": 240,
"total_tokens": 590,
"input_tokens": (
num_input_tokens if is_first_chunk else 0
),
"output_tokens": 11,
"total_tokens": (
11+num_input_tokens if is_first_chunk else 11
),
"input_token_details": {
"audio": 10,
"cache_creation": 200,

View File

@ -0,0 +1,167 @@
from typing import Any, Dict, Iterator, List, Optional
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field
class ChatParrotLink(BaseChatModel):
"""A custom chat model that echoes the first `parrot_buffer_length` characters
of the input.
When contributing an implementation to LangChain, carefully document
the model including the initialization parameters, include
an example of how to initialize the model and include any relevant
links to the underlying models documentation or API.
Example:
.. code-block:: python
model = ChatParrotLink(parrot_buffer_length=2, model="bird-brain-001")
result = model.invoke([HumanMessage(content="hello")])
result = model.batch([[HumanMessage(content="hello")],
[HumanMessage(content="world")]])
"""
model_name: str = Field(alias="model")
"""The name of the model"""
parrot_buffer_length: int
"""The number of characters from the last message of the prompt to be echoed."""
temperature: Optional[float] = None
max_tokens: Optional[int] = None
timeout: Optional[int] = None
stop: Optional[List[str]] = None
max_retries: int = 2
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Override the _generate method to implement the chat model logic.
This can be a call to an API, a call to a local model, or any other
implementation that generates a response to the input prompt.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
# Replace this with actual logic to generate a response from a list
# of messages.
last_message = messages[-1]
tokens = last_message.content[: self.parrot_buffer_length]
ct_input_tokens = sum(len(message.content) for message in messages)
ct_output_tokens = len(tokens)
message = AIMessage(
content=tokens,
additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata
"time_in_seconds": 3,
},
usage_metadata={
"input_tokens": ct_input_tokens,
"output_tokens": ct_output_tokens,
"total_tokens": ct_input_tokens + ct_output_tokens,
},
)
##
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model.
This method should be implemented if the model can generate output
in a streaming fashion. If the model does not support streaming,
do not implement it. In that case streaming requests will be automatically
handled by the _generate method.
Args:
messages: the prompt composed of a list of messages.
stop: a list of strings on which the model should stop generating.
If generation stops due to a stop token, the stop token itself
SHOULD BE INCLUDED as part of the output. This is not enforced
across models right now, but it's a good practice to follow since
it makes it much easier to parse the output of the model
downstream and understand why generation stopped.
run_manager: A run manager with callbacks for the LLM.
"""
last_message = messages[-1]
tokens = str(last_message.content[: self.parrot_buffer_length])
ct_input_tokens = sum(len(message.content) for message in messages)
for token in tokens:
usage_metadata = UsageMetadata(
{
"input_tokens": ct_input_tokens,
"output_tokens": 1,
"total_tokens": ct_input_tokens + 1,
}
)
ct_input_tokens = 0
chunk = ChatGenerationChunk(
message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
)
if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
# Let's add some other information (e.g., response metadata)
chunk = ChatGenerationChunk(
message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3})
)
if run_manager:
# This is optional in newer versions of LangChain
# The on_llm_new_token will be called automatically
run_manager.on_llm_new_token(token, chunk=chunk)
yield chunk
@property
def _llm_type(self) -> str:
"""Get the type of language model used by this chat model."""
return "echoing-chat-model-advanced"
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Return a dictionary of identifying parameters.
This information is used by the LangChain callback system, which
is used for tracing purposes make it possible to monitor LLMs.
"""
return {
# The model name allows users to specify custom token counting
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": self.model_name,
}

View File

@ -0,0 +1,30 @@
"""
Test the standard tests on the custom chat model in the docs
"""
from typing import Type
from langchain_tests.integration_tests import ChatModelIntegrationTests
from langchain_tests.unit_tests import ChatModelUnitTests
from .custom_chat_model import ChatParrotLink
class TestChatParrotLinkUnit(ChatModelUnitTests):
@property
def chat_model_class(self) -> Type[ChatParrotLink]:
return ChatParrotLink
@property
def chat_model_params(self) -> dict:
return {"model": "bird-brain-001", "temperature": 0, "parrot_buffer_length": 50}
class TestChatParrotLinkIntegration(ChatModelIntegrationTests):
@property
def chat_model_class(self) -> Type[ChatParrotLink]:
return ChatParrotLink
@property
def chat_model_params(self) -> dict:
return {"model": "bird-brain-001", "temperature": 0, "parrot_buffer_length": 50}