"""``ChatParrotLinkV1`` implementation for standard-tests with v1 messages.

This module provides a test implementation of ``BaseChatModel`` that supports the
new v1 message format with content blocks.
"""

from collections.abc import AsyncIterator, Iterator
from typing import Any, Optional, cast

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.callbacks.manager import AsyncCallbackManagerForLLMRun
from langchain_core.messages.ai import UsageMetadata
from langchain_core.v1.chat_models import BaseChatModel
from langchain_core.v1.messages import AIMessage, AIMessageChunk, MessageV1
from pydantic import Field


class ChatParrotLinkV1(BaseChatModel):
    """A custom v1 chat model that echoes input with content blocks support.

    This model is designed for testing the v1 message format and content blocks.
    It echoes the first ``parrot_buffer_length`` characters of the input and
    returns them as proper v1 content blocks.

    Example:
        .. code-block:: python

            from langchain_core.v1.messages import HumanMessage

            model = ChatParrotLinkV1(parrot_buffer_length=10, model="parrot-v1")
            result = model.invoke([HumanMessage(content="hello world")])
            # Returns AIMessage with content blocks format
    """
model_name: str = Field(alias="model")
"""The name of the model."""
    # Standard model parameters, accepted for interface parity with real chat
    # models; they are not consulted by this echo implementation.
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    timeout: Optional[int] = None
    stop: Optional[list[str]] = None
    max_retries: int = 2
parrot_buffer_length: int = Field(default=50)
"""The number of characters from the last message to echo."""

    def _invoke(
        self,
        messages: list[MessageV1],
        **kwargs: Any,
    ) -> AIMessage:
        """Generate a response by echoing the input as content blocks.

        Args:
            messages: List of v1 messages to process.
            **kwargs: Additional generation parameters.

        Returns:
            AIMessage with content blocks format.
        """
_ = kwargs # Mark as used
if not messages:
return AIMessage("No input provided")
last_message = messages[-1]
# Extract text content from the message
text_content = ""
for block in last_message.content:
if isinstance(block, dict) and block.get("type") == "text":
text_content += str(block.get("text", ""))
# Echo the first parrot_buffer_length characters
echoed_text = text_content[: self.parrot_buffer_length]
# Calculate usage metadata
total_input_chars = sum(
len(str(msg.content))
if isinstance(msg.content, str)
else (
sum(len(str(block)) for block in msg.content)
if isinstance(msg.content, list)
else 0
)
for msg in messages
)
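        # Note: "tokens" here are character counts; for block content each
        # block's string representation is measured, so this is a rough proxy
        # rather than an exact text length.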
usage_metadata = UsageMetadata(
input_tokens=total_input_chars,
output_tokens=len(echoed_text),
total_tokens=total_input_chars + len(echoed_text),
)
return AIMessage(
content=echoed_text,
response_metadata=cast(
Any,
{
"model_name": self.model_name,
"time_in_seconds": 0.1,
},
),
usage_metadata=usage_metadata,
)

    def _stream(
        self,
        messages: list[MessageV1],
        stop: Optional[list[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[AIMessageChunk]:
        """Stream the response by yielding character chunks.

        Args:
            messages: List of v1 messages to process.
            stop: Stop sequences (unused in this implementation).
            run_manager: Callback manager for the LLM run.
            **kwargs: Additional generation parameters.

        Yields:
            AIMessageChunk objects with individual characters.
        """
_ = stop # Mark as used
_ = kwargs # Mark as used
if not messages:
yield AIMessageChunk("No input provided")
return
last_message = messages[-1]
        # Extract text from the message's content blocks
        text_content = ""
        for block in last_message.content:
if isinstance(block, dict) and block.get("type") == "text":
text_content += str(block.get("text", ""))
# Echo the first parrot_buffer_length characters
echoed_text = text_content[: self.parrot_buffer_length]
# Calculate total input for usage metadata
total_input_chars = sum(
len(str(msg.content))
if isinstance(msg.content, str)
else (
sum(len(str(block)) for block in msg.content)
if isinstance(msg.content, list)
else 0
)
for msg in messages
)
# Stream each character as a chunk
for i, char in enumerate(echoed_text):
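            # Attribute the input tokens only to the first chunk so that
            # summing usage across all streamed chunks matches the totals a
            # single non-streaming call would report.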
usage_metadata = UsageMetadata(
input_tokens=total_input_chars if i == 0 else 0,
output_tokens=1,
                total_tokens=(total_input_chars + 1) if i == 0 else 1,
)
chunk = AIMessageChunk(
content=char,
usage_metadata=usage_metadata,
)
if run_manager:
run_manager.on_llm_new_token(char, chunk=chunk)
yield chunk
# Final chunk with response metadata
final_chunk = AIMessageChunk(
content=[],
response_metadata=cast(
Any,
{
"model_name": self.model_name,
"time_in_seconds": 0.1,
},
),
)
yield final_chunk

    async def _ainvoke(
        self,
        messages: list[MessageV1],
        **kwargs: Any,
    ) -> AIMessage:
        """Async generate a response (delegates to sync implementation).

        Args:
            messages: List of v1 messages to process.
            **kwargs: Additional generation parameters.

        Returns:
            AIMessage with content blocks format.
        """
# For simplicity, delegate to sync implementation
return self._invoke(messages, **kwargs)

    async def _astream(
        self,
        messages: list[MessageV1],
        stop: Optional[list[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[AIMessageChunk]:
        """Async stream the response (delegates to sync implementation).

        Args:
            messages: List of v1 messages to process.
            stop: Stop sequences (unused in this implementation).
            run_manager: Async callback manager for the LLM run.
            **kwargs: Additional generation parameters.

        Yields:
            AIMessageChunk objects with individual characters.
        """
        # Delegate to the sync implementation. The async run manager is not
        # forwarded because the sync ``_stream`` expects a sync callback
        # manager, so streaming callbacks are skipped here.
        for chunk in self._stream(messages, stop, None, **kwargs):
            yield chunk

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "parrot-chat-model-v1"

    @property
    def _identifying_params(self) -> dict[str, Any]:
        """Return a dictionary of identifying parameters."""
return {
"model_name": self.model_name,
"parrot_buffer_length": self.parrot_buffer_length,
}

    def get_token_ids(self, text: str) -> list[int]:
        """Convert text to token IDs using simple character-based tokenization.

        For testing purposes, we use a simple approach where each character
        maps to its ASCII/Unicode code point.

        Args:
            text: The text to tokenize.

        Returns:
            List of token IDs (character code points).
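
        Example::

            >>> ChatParrotLinkV1(model="parrot-v1").get_token_ids("Hi")
            [72, 105]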
"""
return [ord(char) for char in text]

    def get_num_tokens(self, text: str) -> int:
        """Get the number of tokens in the text.

        Args:
            text: The text to count tokens for.

        Returns:
            Number of tokens (characters in this simple implementation).
        """
return len(text)
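

# A minimal usage sketch (assumptions beyond this file: ``HumanMessage`` is
# exported by ``langchain_core.v1.messages``, and v1 ``invoke``/``stream``
# accept a list of messages as in the class docstring example):
#
#     from langchain_core.v1.messages import HumanMessage
#
#     model = ChatParrotLinkV1(model="parrot-v1", parrot_buffer_length=5)
#
#     result = model.invoke([HumanMessage(content="hello world")])
#     # -> AIMessage echoing the first five characters ("hello") as a text
#     #    content block, with character-based usage_metadata attached.
#
#     for chunk in model.stream([HumanMessage(content="hello world")]):
#         ...  # one AIMessageChunk per character, then a final chunk that
#              # carries only response_metadata.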