# langchain/libs/community/langchain_community/chat_models/reka.py
import json
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models import LanguageModelInput
from langchain_core.language_models.chat_models import (
BaseChatModel,
agenerate_from_stream,
generate_from_stream,
)
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
HumanMessage,
SystemMessage,
ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import get_from_dict_or_env
from langchain_core.utils.function_calling import convert_to_openai_tool
from pydantic import BaseModel, ConfigDict, Field, model_validator
DEFAULT_REKA_MODEL = "reka-flash"
ContentType = Union[str, List[Union[str, Dict[str, Any]]]]
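# A ContentType value is either a plain string or a list mixing strings and
# typed dicts, e.g. "hi" or [{"type": "text", "text": "hi"}].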
def process_content_item(item: Dict[str, Any]) -> Dict[str, Any]:
"""Process a single content item."""
if item["type"] == "image_url":
image_url = item["image_url"]
if isinstance(image_url, dict) and "url" in image_url:
# If it's in LangChain format, extract the URL value
item["image_url"] = image_url["url"]
return item
def process_content(content: ContentType) -> List[Dict[str, Any]]:
"""Process content to handle both text and media inputs,
returning a list of content items."""
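
    For example, the string ``"hello"`` becomes
    ``[{"type": "text", "text": "hello"}]``, and a LangChain-style image item
    ``{"type": "image_url", "image_url": {"url": "https://..."}}`` is
    flattened to ``{"type": "image_url", "image_url": "https://..."}``.
    """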
if isinstance(content, str):
return [{"type": "text", "text": content}]
elif isinstance(content, list):
result = []
for item in content:
if isinstance(item, str):
result.append({"type": "text", "text": item})
elif isinstance(item, dict):
result.append(process_content_item(item))
else:
raise ValueError(f"Invalid content item format: {item}")
return result
else:
raise ValueError("Invalid content format")
def convert_to_reka_messages(messages: List[BaseMessage]) -> List[Dict[str, Any]]:
"""Convert LangChain messages to Reka message format."""
reka_messages: List[Dict[str, Any]] = []
system_message: Optional[str] = None
for message in messages:
if isinstance(message, SystemMessage):
if system_message is None:
if isinstance(message.content, str):
system_message = message.content
else:
raise TypeError("SystemMessage content must be a string.")
else:
raise ValueError("Multiple system messages are not supported.")
elif isinstance(message, HumanMessage):
processed_content = process_content(message.content)
if system_message:
if (
processed_content
and isinstance(processed_content[0], dict)
and processed_content[0].get("type") == "text"
and "text" in processed_content[0]
):
processed_content[0]["text"] = (
f"{system_message}\n{processed_content[0]['text']}"
)
else:
processed_content.insert(
0, {"type": "text", "text": system_message}
)
system_message = None
reka_messages.append({"role": "user", "content": processed_content})
elif isinstance(message, AIMessage):
reka_message: Dict[str, Any] = {"role": "assistant"}
if message.content:
processed_content = process_content(message.content)
reka_message["content"] = processed_content
if "tool_calls" in message.additional_kwargs:
tool_calls = message.additional_kwargs["tool_calls"]
formatted_tool_calls = []
for tool_call in tool_calls:
formatted_tool_call = {
"id": tool_call["id"],
"name": tool_call["function"]["name"],
"parameters": json.loads(tool_call["function"]["arguments"]),
}
formatted_tool_calls.append(formatted_tool_call)
reka_message["tool_calls"] = formatted_tool_calls
reka_messages.append(reka_message)
elif isinstance(message, ToolMessage):
content_list: List[Dict[str, Any]] = []
content_list.append(
{
"tool_call_id": message.tool_call_id,
"output": json.dumps({"status": message.content}),
}
)
reka_messages.append(
{
"role": "tool_output",
"content": content_list,
}
)
else:
raise ValueError(f"Unsupported message type: {type(message)}")
return reka_messages
class ChatReka(BaseChatModel):
"""Reka chat large language models."""
client: Any = None #: :meta private:
async_client: Any = None #: :meta private:
model: str = Field(default=DEFAULT_REKA_MODEL)
max_tokens: int = Field(default=256)
temperature: Optional[float] = None
streaming: bool = False
default_request_timeout: Optional[float] = None
max_retries: int = 2
reka_api_key: Optional[str] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
model_config = ConfigDict(extra="forbid")
token_counter: Optional[
Callable[[Union[str, BaseMessage, List[BaseMessage]]], int]
] = None
@model_validator(mode="before")
@classmethod
def validate_environment(cls, values: Dict[str, Any]) -> Dict[str, Any]:
"""Validate that API key and Python package exist in the environment."""
reka_api_key = values.get("reka_api_key")
reka_api_key = get_from_dict_or_env(
{"reka_api_key": reka_api_key}, "reka_api_key", "REKA_API_KEY"
)
values["reka_api_key"] = reka_api_key
try:
            # Lazily import the Reka SDK so reka-api stays an optional dependency.
from reka.client import AsyncReka, Reka
values["client"] = Reka(
api_key=reka_api_key,
)
values["async_client"] = AsyncReka(
api_key=reka_api_key,
)
except ImportError:
raise ImportError(
"Could not import Reka Python package. "
"Please install it with `pip install reka-api`."
)
return values
@property
def _default_params(self) -> Mapping[str, Any]:
"""Get the default parameters for calling Reka API."""
        params: Dict[str, Any] = {
"model": self.model,
"max_tokens": self.max_tokens,
}
if self.temperature is not None:
params["temperature"] = self.temperature
return {**params, **self.model_kwargs}
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
return "reka-chat"
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
reka_messages = convert_to_reka_messages(messages)
params = {**self._default_params, **kwargs}
if stop:
params["stop"] = stop
stream = self.client.chat.create_stream(messages=reka_messages, **params)
for chunk in stream:
content = chunk.responses[0].chunk.content
chat_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
if run_manager:
run_manager.on_llm_new_token(content, chunk=chat_chunk)
yield chat_chunk
async def _astream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> AsyncIterator[ChatGenerationChunk]:
reka_messages = convert_to_reka_messages(messages)
params = {**self._default_params, **kwargs}
if stop:
params["stop"] = stop
stream = self.async_client.chat.create_stream(messages=reka_messages, **params)
async for chunk in stream:
content = chunk.responses[0].chunk.content
chat_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
if run_manager:
await run_manager.on_llm_new_token(content, chunk=chat_chunk)
yield chat_chunk
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
return generate_from_stream(
self._stream(messages, stop=stop, run_manager=run_manager, **kwargs)
)
reka_messages = convert_to_reka_messages(messages)
params = {**self._default_params, **kwargs}
if stop:
params["stop"] = stop
response = self.client.chat.create(messages=reka_messages, **params)
if response.responses[0].message.tool_calls:
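            # Surface tool calls in OpenAI-style additional_kwargs so
            # downstream agent tooling can consume them uniformly.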
tool_calls = response.responses[0].message.tool_calls
message = AIMessage(
content="", # Empty string instead of None
additional_kwargs={
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.parameters),
},
}
for tc in tool_calls
]
},
)
else:
content = response.responses[0].message.content
# Ensure content is never None
message = AIMessage(content=content if content is not None else "")
return ChatResult(generations=[ChatGeneration(message=message)])
async def _agenerate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
return await agenerate_from_stream(
self._astream(messages, stop=stop, run_manager=run_manager, **kwargs)
)
reka_messages = convert_to_reka_messages(messages)
params = {**self._default_params, **kwargs}
if stop:
params["stop"] = stop
response = await self.async_client.chat.create(messages=reka_messages, **params)
if response.responses[0].message.tool_calls:
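            # Same OpenAI-style tool-call surfacing as the sync _generate path.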
tool_calls = response.responses[0].message.tool_calls
message = AIMessage(
content="", # Empty string instead of None
additional_kwargs={
"tool_calls": [
{
"id": tc.id,
"type": "function",
"function": {
"name": tc.name,
"arguments": json.dumps(tc.parameters),
},
}
for tc in tool_calls
]
},
)
else:
content = response.responses[0].message.content
# Ensure content is never None
message = AIMessage(content=content if content is not None else "")
return ChatResult(generations=[ChatGeneration(message=message)])
def get_num_tokens(self, input: Union[str, BaseMessage, List[BaseMessage]]) -> int:
"""Calculate number of tokens.
Args:
input: Either a string, a single BaseMessage, or a list of BaseMessages.
Returns:
int: Number of tokens in the input.
Raises:
            ImportError: If tiktoken is not installed.
            ValueError: If message content is not a string.
            TypeError: If the input type is unsupported.
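
        Note:
            Without a custom ``token_counter``, counts are computed with
            tiktoken's ``cl100k_base`` encoding, which only approximates
            Reka's own tokenizer.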
"""
if self.token_counter is not None:
return self.token_counter(input)
try:
import tiktoken
except ImportError:
raise ImportError(
"Could not import tiktoken python package. "
"Please install it with `pip install tiktoken`."
)
encoding = tiktoken.get_encoding("cl100k_base")
if isinstance(input, str):
return len(encoding.encode(input))
elif isinstance(input, BaseMessage):
content = input.content
if not isinstance(content, str):
raise ValueError(
f"Message content must be a string, got {type(content)}"
)
return len(encoding.encode(content))
elif isinstance(input, list):
total = 0
for msg in input:
content = msg.content
if not isinstance(content, str):
raise ValueError(
f"Message content must be a string, got {type(content)}"
)
total += len(encoding.encode(content))
return total
else:
raise TypeError(f"Unsupported input type: {type(input)}")
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
*,
tool_choice: str = "auto",
strict: Optional[bool] = None,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model.
The `tool_choice` parameter controls how the model uses the tools you pass.
There are three available options:
- `"auto"`: Lets the model decide whether or not to invoke a tool. This is the
recommended way to do function calling with our models.
- `"none"`: Disables tool calling. In this case, even if you pass tools to
the model, the model will not invoke any tools.
- `"tool"`: Forces the model to invoke one or more of the tools it has
been passed.
Args:
tools: A list of tool definitions to bind to this chat model.
Supports any tool definition handled by
:meth:`langchain_core.utils.function_calling.convert_to_openai_tool`.
tool_choice: Controls how the model uses the tools you pass.
Options are "auto", "none", or "tool". Defaults to "auto".
            strict:
                If True, model output is guaranteed to exactly match the JSON
                Schema provided in the tool definition. If False, neither the
                input schema nor the model output is validated. If None, the
                ``strict`` argument is not passed to the model.
kwargs: Any additional parameters are passed directly to the model.
Returns:
Runnable: An executable chain or component.
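
        A minimal usage sketch, assuming ``chat`` is a ``ChatReka`` instance
        (``get_weather`` is a hypothetical tool):

        .. code-block:: python

            from langchain_core.tools import tool

            @tool
            def get_weather(city: str) -> str:
                '''Return the (pretend) current weather for a city.'''
                return f"Sunny in {city}"

            chat_with_tools = chat.bind_tools([get_weather])
            ai_msg = chat_with_tools.invoke("What's the weather in Paris?")
            ai_msg.additional_kwargs.get("tool_calls")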
"""
formatted_tools = [
convert_to_openai_tool(tool, strict=strict) for tool in tools
]
# Ensure tool_choice is one of the allowed options
if tool_choice not in ("auto", "none", "tool"):
raise ValueError(
f"Invalid tool_choice '{tool_choice}' provided. "
"Tool choice must be one of: 'auto', 'none', or 'tool'."
)
# Map tool_choice to the parameter expected by the Reka API
kwargs["tool_choice"] = tool_choice
        # Reka expects bare function schemas, so unwrap the OpenAI-style
        # {"type": "function", "function": {...}} wrappers before binding.
        formatted_tools = [tool["function"] for tool in formatted_tools]
return super().bind(tools=formatted_tools, **kwargs)