mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 11:55:21 +00:00
See the format here: https://platform.openai.com/docs/guides/function-calling/parallel-function-calling. The API expects a "name" argument, which we aren't providing by default. An alternative is to add the 'name' field directly to the message, if people prefer.
407 lines
12 KiB
Python
407 lines
12 KiB
Python
from __future__ import annotations
|
|
|
|
import importlib
|
|
from typing import (
|
|
Any,
|
|
AsyncIterator,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
Mapping,
|
|
Sequence,
|
|
Union,
|
|
overload,
|
|
)
|
|
|
|
from langchain_core.chat_sessions import ChatSession
|
|
from langchain_core.messages import (
|
|
AIMessage,
|
|
AIMessageChunk,
|
|
BaseMessage,
|
|
BaseMessageChunk,
|
|
ChatMessage,
|
|
FunctionMessage,
|
|
HumanMessage,
|
|
SystemMessage,
|
|
ToolMessage,
|
|
)
|
|
from langchain_core.pydantic_v1 import BaseModel
|
|
from typing_extensions import Literal
|
|
|
|
|
|
async def aenumerate(
    iterable: AsyncIterator[Any], start: int = 0
) -> AsyncIterator[tuple[int, Any]]:
    """Asynchronous counterpart of the built-in ``enumerate``.

    Args:
        iterable: The async iterator to walk.
        start: The index paired with the first item.

    Yields:
        ``(index, item)`` tuples; the index grows by one for each item.
    """
    index = start
    async for item in iterable:
        yield index, item
        index += 1
|
|
|
|
|
|
class IndexableBaseModel(BaseModel):
    """Allows a BaseModel to return its fields by string variable indexing"""

    def __getitem__(self, item: str) -> Any:
        """Return the attribute named ``item``, so ``model["x"]`` is ``model.x``.

        Args:
            item: Name of the attribute to read.

        Returns:
            Whatever ``getattr`` finds under that name.
        """
        # Subscript access is a thin alias for attribute access.
        return getattr(self, item)
|
|
|
|
|
|
# A single entry of a non-streaming response's ``choices`` list.
class Choice(IndexableBaseModel):
    # Dict form of the reply message (built by ``convert_message_to_dict``).
    message: dict
|
|
|
|
|
|
# Non-streaming response container returned by ``Completions.create``/``acreate``.
class ChatCompletions(IndexableBaseModel):
    # The response alternatives; the code in this file always supplies one.
    choices: List[Choice]
|
|
|
|
|
|
# A single entry of a streamed chunk's ``choices`` list.
class ChoiceChunk(IndexableBaseModel):
    # Incremental message fields (built by ``_convert_message_chunk``).
    delta: dict
|
|
|
|
|
|
# Streaming response container yielded by ``Completions.create``/``acreate``.
class ChatCompletionChunk(IndexableBaseModel):
    # The chunk alternatives; the code in this file always supplies one.
    choices: List[ChoiceChunk]
|
|
|
|
|
|
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
    """Build the LangChain message that corresponds to an OpenAI-style dict.

    The ``"role"`` entry selects the message class: ``"user"`` ->
    ``HumanMessage``, ``"assistant"`` -> ``AIMessage``, ``"system"`` ->
    ``SystemMessage``, ``"function"`` -> ``FunctionMessage``, ``"tool"`` ->
    ``ToolMessage``; any other role becomes a ``ChatMessage`` carrying that role.

    Args:
        _dict: The message dictionary. ``"role"`` is required, as is
            ``"content"`` for every role except ``"assistant"``.

    Returns:
        The LangChain message.

    Raises:
        KeyError: If a key required for the given role is missing.
    """
    role = _dict["role"]

    if role == "user":
        return HumanMessage(content=_dict["content"])

    if role == "assistant":
        # Azure fix; OpenAI also sends None as content for tool invocations,
        # so a missing/None/empty content is normalised to "".
        text = _dict.get("content", "") or ""
        extras: Dict = {}
        function_call = _dict.get("function_call")
        if function_call:
            extras["function_call"] = dict(function_call)
        tool_calls = _dict.get("tool_calls")
        if tool_calls:
            extras["tool_calls"] = tool_calls
        return AIMessage(content=text, additional_kwargs=extras)

    if role == "system":
        return SystemMessage(content=_dict["content"])

    if role == "function":
        return FunctionMessage(content=_dict["content"], name=_dict["name"])

    if role == "tool":
        # The optional "name" travels in additional_kwargs.
        tool_extras = {"name": _dict["name"]} if "name" in _dict else {}
        return ToolMessage(
            content=_dict["content"],
            tool_call_id=_dict["tool_call_id"],
            additional_kwargs=tool_extras,
        )

    # Unknown roles are preserved verbatim on a generic ChatMessage.
    return ChatMessage(content=_dict["content"], role=role)
|
|
|
|
|
|
def convert_message_to_dict(message: BaseMessage) -> dict:
    """Serialise a LangChain message into an OpenAI-style dictionary.

    Args:
        message: The LangChain message.

    Returns:
        A dict with ``"role"`` and ``"content"`` plus any role-specific keys
        (``"function_call"``, ``"tool_calls"``, ``"name"``, ``"tool_call_id"``).
        A ``"name"`` found in ``message.additional_kwargs`` is always copied
        into the result, overriding any name set earlier.

    Raises:
        TypeError: If ``message`` is not one of the supported message classes.
    """
    payload: Dict[str, Any]

    if isinstance(message, ChatMessage):
        payload = dict(role=message.role, content=message.content)
    elif isinstance(message, HumanMessage):
        payload = dict(role="user", content=message.content)
    elif isinstance(message, AIMessage):
        payload = dict(role="assistant", content=message.content)
        for call_key in ("function_call", "tool_calls"):
            if call_key not in message.additional_kwargs:
                continue
            payload[call_key] = message.additional_kwargs[call_key]
            # When the message only carries a call, content is None rather
            # than an empty string.
            if payload["content"] == "":
                payload["content"] = None
    elif isinstance(message, SystemMessage):
        payload = dict(role="system", content=message.content)
    elif isinstance(message, FunctionMessage):
        payload = dict(role="function", content=message.content, name=message.name)
    elif isinstance(message, ToolMessage):
        payload = dict(
            role="tool",
            content=message.content,
            tool_call_id=message.tool_call_id,
        )
    else:
        raise TypeError(f"Got unknown type {message}")

    if "name" in message.additional_kwargs:
        payload["name"] = message.additional_kwargs["name"]
    return payload
|
|
|
|
|
|
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
    """Translate a sequence of OpenAI message dicts into LangChain messages.

    Args:
        messages: List of dictionaries representing OpenAI messages

    Returns:
        List of LangChain BaseMessage objects, in the same order as the input.
    """
    converted: List[BaseMessage] = []
    for raw_message in messages:
        converted.append(convert_dict_to_message(raw_message))
    return converted
|
|
|
|
|
|
def _convert_message_chunk(chunk: BaseMessageChunk, i: int) -> dict:
    """Turn one streamed message chunk into an OpenAI-style delta dict.

    Args:
        chunk: The streamed chunk; only ``AIMessageChunk`` is supported.
        i: Position of the chunk in the stream (0 for the first chunk).

    Returns:
        The delta dictionary, or ``{}`` when the chunk carries nothing but
        empty content.

    Raises:
        ValueError: If ``chunk`` is not an ``AIMessageChunk``.
    """
    if not isinstance(chunk, AIMessageChunk):
        raise ValueError(f"Got unexpected streaming chunk type: {type(chunk)}")

    delta: Dict[str, Any] = {}
    is_first_chunk = i == 0

    if is_first_chunk:
        # The role only shows up in the first chunk.
        delta["role"] = "assistant"

    if "function_call" in chunk.additional_kwargs:
        delta["function_call"] = chunk.additional_kwargs["function_call"]
        # If the first chunk is a function call, the content is neither an
        # empty string nor missing: it is None.
        if is_first_chunk:
            delta["content"] = None
    else:
        delta["content"] = chunk.content

    # This only happens at the end of streams, where OpenAI returns an empty dict.
    return {} if delta == {"content": ""} else delta
|
|
|
|
|
|
def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
    """Wrap a converted chunk in the ``{"choices": [{"delta": ...}]}`` envelope.

    Args:
        chunk: The streamed message chunk.
        i: Position of the chunk in the stream.

    Returns:
        A dict with a single-element ``"choices"`` list holding the delta.
    """
    choice = {"delta": _convert_message_chunk(chunk, i)}
    return {"choices": [choice]}
|
|
|
|
|
|
class ChatCompletion:
    """Chat completion: run a LangChain chat model on OpenAI-style messages.

    Results are returned as plain dictionaries shaped like
    ``{"choices": [...]}``.
    """

    @overload
    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[False] = False,
        **kwargs: Any,
    ) -> dict:
        ...

    @overload
    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[True],
        **kwargs: Any,
    ) -> Iterable:
        ...

    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[dict, Iterable]:
        """Invoke (or stream from) a chat model synchronously.

        Args:
            messages: OpenAI-style message dictionaries.
            provider: Name of the model class looked up on
                ``langchain.chat_models``.
            stream: When true, return an iterator of chunks instead of a
                single result.
            **kwargs: Forwarded to the model class constructor.

        Returns:
            ``{"choices": [{"message": ...}]}`` when not streaming, otherwise
            a generator of ``{"choices": [{"delta": ...}]}`` dicts.
        """
        chat_models = importlib.import_module("langchain.chat_models")
        model = getattr(chat_models, provider)(**kwargs)
        lc_messages = convert_openai_messages(messages)

        if stream:
            chunks = model.stream(lc_messages)
            return (
                _convert_message_chunk_to_delta(chunk, index)
                for index, chunk in enumerate(chunks)
            )

        reply = model.invoke(lc_messages)
        return {"choices": [{"message": convert_message_to_dict(reply)}]}

    @overload
    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[False] = False,
        **kwargs: Any,
    ) -> dict:
        ...

    @overload
    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[True],
        **kwargs: Any,
    ) -> AsyncIterator:
        ...

    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[dict, AsyncIterator]:
        """Invoke (or stream from) a chat model asynchronously.

        Args:
            messages: OpenAI-style message dictionaries.
            provider: Name of the model class looked up on
                ``langchain.chat_models``.
            stream: When true, return an async iterator of chunks instead of
                a single result.
            **kwargs: Forwarded to the model class constructor.

        Returns:
            ``{"choices": [{"message": ...}]}`` when not streaming, otherwise
            an async generator of ``{"choices": [{"delta": ...}]}`` dicts.
        """
        chat_models = importlib.import_module("langchain.chat_models")
        model = getattr(chat_models, provider)(**kwargs)
        lc_messages = convert_openai_messages(messages)

        if stream:
            chunks = model.astream(lc_messages)
            return (
                _convert_message_chunk_to_delta(chunk, index)
                async for index, chunk in aenumerate(chunks)
            )

        reply = await model.ainvoke(lc_messages)
        return {"choices": [{"message": convert_message_to_dict(reply)}]}
|
|
|
|
|
|
def _has_assistant_message(session: ChatSession) -> bool:
    """Check whether a chat session contains at least one assistant message.

    Args:
        session: The chat session; its ``"messages"`` entry is scanned.

    Returns:
        True if any message is an ``AIMessage`` instance, False otherwise.

    Raises:
        KeyError: If the session has no ``"messages"`` entry.
    """
    # A generator (instead of a list) lets any() stop at the first assistant
    # message rather than type-checking every message up front.
    return any(isinstance(m, AIMessage) for m in session["messages"])
|
|
|
|
|
|
def convert_messages_for_finetuning(
    sessions: Iterable[ChatSession],
) -> List[List[dict]]:
    """Convert chat sessions into lists of message dicts for fine-tuning.

    Sessions without any assistant message are left out of the result.

    Args:
        sessions: The chat sessions.

    Returns:
        One list of message dictionaries per retained session, in input order.
    """
    finetuning_data: List[List[dict]] = []
    for session in sessions:
        if not _has_assistant_message(session):
            continue
        finetuning_data.append(
            [convert_message_to_dict(message) for message in session["messages"]]
        )
    return finetuning_data
|
|
|
|
|
|
class Completions:
    """Completion: run a LangChain chat model on OpenAI-style messages.

    Results are returned as ``ChatCompletions`` / ``ChatCompletionChunk``
    models rather than plain dictionaries.
    """

    @overload
    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[False] = False,
        **kwargs: Any,
    ) -> ChatCompletions:
        ...

    @overload
    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[True],
        **kwargs: Any,
    ) -> Iterable:
        ...

    @staticmethod
    def create(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[ChatCompletions, Iterable]:
        """Invoke (or stream from) a chat model synchronously.

        Args:
            messages: OpenAI-style message dictionaries.
            provider: Name of the model class looked up on
                ``langchain.chat_models``.
            stream: When true, return an iterator of chunks instead of a
                single result.
            **kwargs: Forwarded to the model class constructor.

        Returns:
            A ``ChatCompletions`` with one ``Choice`` when not streaming,
            otherwise a generator of ``ChatCompletionChunk`` objects.
        """
        chat_models = importlib.import_module("langchain.chat_models")
        model = getattr(chat_models, provider)(**kwargs)
        lc_messages = convert_openai_messages(messages)

        if stream:
            chunks = model.stream(lc_messages)
            return (
                ChatCompletionChunk(
                    choices=[ChoiceChunk(delta=_convert_message_chunk(chunk, index))]
                )
                for index, chunk in enumerate(chunks)
            )

        reply = model.invoke(lc_messages)
        return ChatCompletions(
            choices=[Choice(message=convert_message_to_dict(reply))]
        )

    @overload
    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[False] = False,
        **kwargs: Any,
    ) -> ChatCompletions:
        ...

    @overload
    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: Literal[True],
        **kwargs: Any,
    ) -> AsyncIterator:
        ...

    @staticmethod
    async def acreate(
        messages: Sequence[Dict[str, Any]],
        *,
        provider: str = "ChatOpenAI",
        stream: bool = False,
        **kwargs: Any,
    ) -> Union[ChatCompletions, AsyncIterator]:
        """Invoke (or stream from) a chat model asynchronously.

        Args:
            messages: OpenAI-style message dictionaries.
            provider: Name of the model class looked up on
                ``langchain.chat_models``.
            stream: When true, return an async iterator of chunks instead of
                a single result.
            **kwargs: Forwarded to the model class constructor.

        Returns:
            A ``ChatCompletions`` with one ``Choice`` when not streaming,
            otherwise an async generator of ``ChatCompletionChunk`` objects.
        """
        chat_models = importlib.import_module("langchain.chat_models")
        model = getattr(chat_models, provider)(**kwargs)
        lc_messages = convert_openai_messages(messages)

        if stream:
            chunks = model.astream(lc_messages)
            return (
                ChatCompletionChunk(
                    choices=[ChoiceChunk(delta=_convert_message_chunk(chunk, index))]
                )
                async for index, chunk in aenumerate(chunks)
            )

        reply = await model.ainvoke(lc_messages)
        return ChatCompletions(
            choices=[Choice(message=convert_message_to_dict(reply))]
        )
|
|
|
|
|
|
class Chat:
    """Holder object that exposes a ``Completions`` instance as ``completions``."""

    def __init__(self) -> None:
        """Create the ``completions`` attribute."""
        self.completions = Completions()
|
|
|
|
|
|
# Module-level ``Chat`` instance, so callers can use ``chat.completions.create(...)``.
chat = Chat()
|