from typing import Any, Dict, Iterator, List, Optional

from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field

class ChatParrotLink(BaseChatModel):
    """A custom chat model that echoes the first `parrot_buffer_length` characters
    of the input.

    When contributing an implementation to LangChain, carefully document the model,
    including its initialization parameters. Include an example of how to initialize
    the model and any relevant links to the underlying model's documentation or API.

    Example:

        .. code-block:: python

            model = ChatParrotLink(parrot_buffer_length=2, model="bird-brain-001")
            result = model.invoke([HumanMessage(content="hello")])
            result = model.batch([[HumanMessage(content="hello")],
                                  [HumanMessage(content="world")]])

    """

    model_name: str = Field(alias="model")
    """The name of the model."""
    parrot_buffer_length: int
    """The number of characters from the last message of the prompt to be echoed."""
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None
    timeout: Optional[int] = None
    stop: Optional[List[str]] = None
    max_retries: int = 2

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Override the _generate method to implement the chat model logic.

        This can be a call to an API, a call to a local model, or any other
        implementation that generates a response to the input prompt.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        # Replace this with actual logic to generate a response from a list
        # of messages.
        last_message = messages[-1]
        tokens = last_message.content[: self.parrot_buffer_length]
        ct_input_tokens = sum(len(message.content) for message in messages)
        ct_output_tokens = len(tokens)
        message = AIMessage(
            content=tokens,
            additional_kwargs={},  # Used to add additional payload to the message
            response_metadata={  # Use for response metadata
                "time_in_seconds": 3,
            },
            usage_metadata={
                "input_tokens": ct_input_tokens,
                "output_tokens": ct_output_tokens,
                "total_tokens": ct_input_tokens + ct_output_tokens,
            },
        )

        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Stream the output of the model.

        This method should be implemented if the model can generate output
        in a streaming fashion. If the model does not support streaming,
        do not implement it. In that case streaming requests will be automatically
        handled by the _generate method.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        last_message = messages[-1]
        tokens = str(last_message.content[: self.parrot_buffer_length])
        ct_input_tokens = sum(len(message.content) for message in messages)

        for token in tokens:
            usage_metadata = UsageMetadata(
                {
                    "input_tokens": ct_input_tokens,
                    "output_tokens": 1,
                    "total_tokens": ct_input_tokens + 1,
                }
            )
            ct_input_tokens = 0
            chunk = ChatGenerationChunk(
                message=AIMessageChunk(content=token, usage_metadata=usage_metadata)
            )

            if run_manager:
                # This is optional in newer versions of LangChain
                # The on_llm_new_token will be called automatically
                run_manager.on_llm_new_token(token, chunk=chunk)

            yield chunk

        # Let's add some other information (e.g., response metadata)
        chunk = ChatGenerationChunk(
            message=AIMessageChunk(content="", response_metadata={"time_in_sec": 3})
        )
        if run_manager:
            # This is optional in newer versions of LangChain
            # The on_llm_new_token will be called automatically
            run_manager.on_llm_new_token(token, chunk=chunk)
        yield chunk

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "echoing-chat-model-advanced"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which
        is used for tracing purposes, making it possible to monitor LLMs.
        """
        return {
            # The model name allows users to specify custom token counting
            # rules in LLM monitoring applications (e.g., in LangSmith users
            # can provide per token pricing for their model and monitor
            # costs for the given LLM.)
            "model_name": self.model_name,
        }
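

# A minimal usage sketch, assuming `langchain-core` (and `pydantic`) is installed.
# It exercises the ChatParrotLink class defined above; the model name
# "bird-brain-001" and the buffer length of 3 are arbitrary placeholder values.
if __name__ == "__main__":
    from langchain_core.messages import HumanMessage

    model = ChatParrotLink(parrot_buffer_length=3, model="bird-brain-001")

    # invoke() routes through _generate and echoes the first 3 characters
    # of the last message.
    result = model.invoke([HumanMessage(content="hello parrot")])
    print(result.content)  # "hel"
    print(result.usage_metadata)  # token counts populated by _generate

    # stream() routes through _stream and yields one chunk per echoed character,
    # followed by a final metadata-only chunk.
    for chunk in model.stream([HumanMessage(content="hello parrot")]):
        print(chunk.content, end="|")
    print()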