mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-31 10:23:18 +00:00
openai adapters (#8988)
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com> Co-authored-by: Nuno Campos <nuno@boringbits.io>
This commit is contained in:
0
libs/langchain/langchain/adapters/__init__.py
Normal file
0
libs/langchain/langchain/adapters/__init__.py
Normal file
208
libs/langchain/langchain/adapters/openai.py
Normal file
208
libs/langchain/langchain/adapters/openai.py
Normal file
@@ -0,0 +1,208 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterable,
|
||||
List,
|
||||
Mapping,
|
||||
Sequence,
|
||||
Union,
|
||||
overload,
|
||||
)
|
||||
|
||||
from typing_extensions import Literal
|
||||
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
|
||||
|
||||
async def aenumerate(
|
||||
iterable: AsyncIterator[Any], start: int = 0
|
||||
) -> AsyncIterator[tuple[int, Any]]:
|
||||
"""Async version of enumerate."""
|
||||
i = start
|
||||
async for x in iterable:
|
||||
yield i, x
|
||||
i += 1
|
||||
|
||||
|
||||
def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
# Fix for azure
|
||||
# Also OpenAI returns None for tool invocations
|
||||
content = _dict.get("content", "") or ""
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
elif role == "function":
|
||||
return FunctionMessage(content=_dict["content"], name=_dict["name"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
message_dict: Dict[str, Any]
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
# If function call only, content is None not empty string
|
||||
if message_dict["content"] == "":
|
||||
message_dict["content"] = None
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
|
||||
def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
|
||||
"""Convert dictionaries representing OpenAI messages to LangChain format.
|
||||
|
||||
Args:
|
||||
messages: List of dictionaries representing OpenAI messages
|
||||
|
||||
Returns:
|
||||
List of LangChain BaseMessage objects.
|
||||
"""
|
||||
return [convert_dict_to_message(m) for m in messages]
|
||||
|
||||
|
||||
def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
|
||||
_dict: Dict[str, Any] = {}
|
||||
if isinstance(chunk, AIMessageChunk):
|
||||
if i == 0:
|
||||
# Only shows up in the first chunk
|
||||
_dict["role"] = "assistant"
|
||||
if "function_call" in chunk.additional_kwargs:
|
||||
_dict["function_call"] = chunk.additional_kwargs["function_call"]
|
||||
# If the first chunk is a function call, the content is not empty string,
|
||||
# not missing, but None.
|
||||
if i == 0:
|
||||
_dict["content"] = None
|
||||
else:
|
||||
_dict["content"] = chunk.content
|
||||
else:
|
||||
raise ValueError(f"Got unexpected streaming chunk type: {type(chunk)}")
|
||||
# This only happens at the end of streams, and OpenAI returns as empty dict
|
||||
if _dict == {"content": ""}:
|
||||
_dict = {}
|
||||
return {"choices": [{"delta": _dict}]}
|
||||
|
||||
|
||||
class ChatCompletion:
|
||||
@overload
|
||||
@staticmethod
|
||||
def create(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
...
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
def create(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: Literal[True],
|
||||
**kwargs: Any,
|
||||
) -> Iterable:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[dict, Iterable]:
|
||||
models = importlib.import_module("langchain.chat_models")
|
||||
model_cls = getattr(models, provider)
|
||||
model_config = model_cls(**kwargs)
|
||||
converted_messages = convert_openai_messages(messages)
|
||||
if not stream:
|
||||
result = model_config.invoke(converted_messages)
|
||||
return {"choices": [{"message": convert_message_to_dict(result)}]}
|
||||
else:
|
||||
return (
|
||||
_convert_message_chunk_to_delta(c, i)
|
||||
for i, c in enumerate(model_config.stream(converted_messages))
|
||||
)
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
async def acreate(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: Literal[False] = False,
|
||||
**kwargs: Any,
|
||||
) -> dict:
|
||||
...
|
||||
|
||||
@overload
|
||||
@staticmethod
|
||||
async def acreate(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: Literal[True],
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator:
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
async def acreate(
|
||||
messages: Sequence[Dict[str, Any]],
|
||||
*,
|
||||
provider: str = "ChatOpenAI",
|
||||
stream: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> Union[dict, AsyncIterator]:
|
||||
models = importlib.import_module("langchain.chat_models")
|
||||
model_cls = getattr(models, provider)
|
||||
model_config = model_cls(**kwargs)
|
||||
converted_messages = convert_openai_messages(messages)
|
||||
if not stream:
|
||||
result = await model_config.ainvoke(converted_messages)
|
||||
return {"choices": [{"message": convert_message_to_dict(result)}]}
|
||||
else:
|
||||
return (
|
||||
_convert_message_chunk_to_delta(c, i)
|
||||
async for i, c in aenumerate(model_config.astream(converted_messages))
|
||||
)
|
@@ -9,9 +9,9 @@ from typing import TYPE_CHECKING, Optional, Set
|
||||
import requests
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.adapters.openai import convert_message_to_dict
|
||||
from langchain.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
_convert_message_to_dict,
|
||||
_import_tiktoken,
|
||||
)
|
||||
from langchain.schema.messages import BaseMessage
|
||||
@@ -178,7 +178,7 @@ class ChatAnyscale(ChatOpenAI):
|
||||
tokens_per_message = 3
|
||||
tokens_per_name = 1
|
||||
num_tokens = 0
|
||||
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
||||
messages_dict = [convert_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
|
@@ -19,6 +19,7 @@ from typing import (
|
||||
|
||||
from pydantic import Field, root_validator
|
||||
|
||||
from langchain.adapters.openai import convert_dict_to_message, convert_message_to_dict
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
CallbackManagerForLLMRun,
|
||||
@@ -27,17 +28,12 @@ from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.llms.base import create_base_retry_decorator
|
||||
from langchain.schema import ChatGeneration, ChatResult
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain.schema.output import ChatGenerationChunk
|
||||
@@ -121,63 +117,6 @@ def _convert_delta_to_message_chunk(
|
||||
return default_class(content=content)
|
||||
|
||||
|
||||
def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
# Fix for azure
|
||||
# Also OpenAI returns None for tool invocations
|
||||
content = _dict.get("content", "") or ""
|
||||
if _dict.get("function_call"):
|
||||
additional_kwargs = {"function_call": dict(_dict["function_call"])}
|
||||
else:
|
||||
additional_kwargs = {}
|
||||
return AIMessage(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
elif role == "function":
|
||||
return FunctionMessage(content=_dict["content"], name=_dict["name"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
|
||||
def convert_openai_messages(messages: List[dict]) -> List[BaseMessage]:
|
||||
"""Convert dictionaries representing OpenAI messages to LangChain format.
|
||||
|
||||
Args:
|
||||
messages: List of dictionaries representing OpenAI messages
|
||||
|
||||
Returns:
|
||||
List of LangChain BaseMessage objects.
|
||||
"""
|
||||
return [_convert_dict_to_message(m) for m in messages]
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
if "function_call" in message.additional_kwargs:
|
||||
message_dict["function_call"] = message.additional_kwargs["function_call"]
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {
|
||||
"role": "function",
|
||||
"content": message.content,
|
||||
"name": message.name,
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
if "name" in message.additional_kwargs:
|
||||
message_dict["name"] = message.additional_kwargs["name"]
|
||||
return message_dict
|
||||
|
||||
|
||||
class ChatOpenAI(BaseChatModel):
|
||||
"""Wrapper around OpenAI Chat large language models.
|
||||
|
||||
@@ -411,13 +350,13 @@ class ChatOpenAI(BaseChatModel):
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
message_dicts = [convert_message_to_dict(m) for m in messages]
|
||||
return message_dicts, params
|
||||
|
||||
def _create_chat_result(self, response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for res in response["choices"]:
|
||||
message = _convert_dict_to_message(res["message"])
|
||||
message = convert_dict_to_message(res["message"])
|
||||
gen = ChatGeneration(
|
||||
message=message,
|
||||
generation_info=dict(finish_reason=res.get("finish_reason")),
|
||||
@@ -568,7 +507,7 @@ class ChatOpenAI(BaseChatModel):
|
||||
"information on how messages are converted to tokens."
|
||||
)
|
||||
num_tokens = 0
|
||||
messages_dict = [_convert_message_to_dict(m) for m in messages]
|
||||
messages_dict = [convert_message_to_dict(m) for m in messages]
|
||||
for message in messages_dict:
|
||||
num_tokens += tokens_per_message
|
||||
for key, value in message.items():
|
||||
|
107
libs/langchain/tests/integration_tests/adapters/test_openai.py
Normal file
107
libs/langchain/tests/integration_tests/adapters/test_openai.py
Normal file
@@ -0,0 +1,107 @@
|
||||
from typing import Any
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
|
||||
from langchain.adapters import openai as lcopenai
|
||||
|
||||
|
||||
def _test_no_stream(**kwargs: Any) -> None:
|
||||
result = openai.ChatCompletion.create(**kwargs)
|
||||
lc_result = lcopenai.ChatCompletion.create(**kwargs)
|
||||
if isinstance(lc_result, dict):
|
||||
if isinstance(result, dict):
|
||||
result_dict = result["choices"][0]["message"].to_dict_recursive()
|
||||
lc_result_dict = lc_result["choices"][0]["message"]
|
||||
assert result_dict == lc_result_dict
|
||||
return
|
||||
|
||||
|
||||
def _test_stream(**kwargs: Any) -> None:
|
||||
result = []
|
||||
for c in openai.ChatCompletion.create(**kwargs):
|
||||
result.append(c["choices"][0]["delta"].to_dict_recursive())
|
||||
|
||||
lc_result = []
|
||||
for c in lcopenai.ChatCompletion.create(**kwargs):
|
||||
lc_result.append(c["choices"][0]["delta"])
|
||||
assert result == lc_result
|
||||
|
||||
|
||||
async def _test_async(**kwargs: Any) -> None:
|
||||
result = await openai.ChatCompletion.acreate(**kwargs)
|
||||
lc_result = await lcopenai.ChatCompletion.acreate(**kwargs)
|
||||
if isinstance(lc_result, dict):
|
||||
if isinstance(result, dict):
|
||||
result_dict = result["choices"][0]["message"].to_dict_recursive()
|
||||
lc_result_dict = lc_result["choices"][0]["message"]
|
||||
assert result_dict == lc_result_dict
|
||||
return
|
||||
|
||||
|
||||
async def _test_astream(**kwargs: Any) -> None:
|
||||
result = []
|
||||
async for c in await openai.ChatCompletion.acreate(**kwargs):
|
||||
result.append(c["choices"][0]["delta"].to_dict_recursive())
|
||||
|
||||
lc_result = []
|
||||
async for c in await lcopenai.ChatCompletion.acreate(**kwargs):
|
||||
lc_result.append(c["choices"][0]["delta"])
|
||||
assert result == lc_result
|
||||
|
||||
|
||||
FUNCTIONS = [
|
||||
{
|
||||
"name": "get_current_weather",
|
||||
"description": "Get the current weather in a given location",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"location": {
|
||||
"type": "string",
|
||||
"description": "The city and state, e.g. San Francisco, CA",
|
||||
},
|
||||
"unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
|
||||
},
|
||||
"required": ["location"],
|
||||
},
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def _test_module(**kwargs: Any) -> None:
|
||||
_test_no_stream(**kwargs)
|
||||
await _test_async(**kwargs)
|
||||
_test_stream(stream=True, **kwargs)
|
||||
await _test_astream(stream=True, **kwargs)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_normal_call() -> None:
|
||||
await _test_module(
|
||||
messages=[{"role": "user", "content": "hi"}],
|
||||
model="gpt-3.5-turbo",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_function_calling() -> None:
|
||||
await _test_module(
|
||||
messages=[{"role": "user", "content": "whats the weather in boston"}],
|
||||
model="gpt-3.5-turbo",
|
||||
functions=FUNCTIONS,
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_answer_with_function_calling() -> None:
|
||||
await _test_module(
|
||||
messages=[
|
||||
{"role": "user", "content": "say hi, then whats the weather in boston"}
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
functions=FUNCTIONS,
|
||||
temperature=0,
|
||||
)
|
@@ -5,9 +5,9 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain.adapters.openai import convert_dict_to_message
|
||||
from langchain.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
_convert_dict_to_message,
|
||||
)
|
||||
from langchain.schema.messages import (
|
||||
AIMessage,
|
||||
@@ -20,7 +20,7 @@ from langchain.schema.messages import (
|
||||
def test_function_message_dict_to_function_message() -> None:
|
||||
content = json.dumps({"result": "Example #1"})
|
||||
name = "test_function"
|
||||
result = _convert_dict_to_message(
|
||||
result = convert_dict_to_message(
|
||||
{
|
||||
"role": "function",
|
||||
"name": name,
|
||||
@@ -34,21 +34,21 @@ def test_function_message_dict_to_function_message() -> None:
|
||||
|
||||
def test__convert_dict_to_message_human() -> None:
|
||||
message = {"role": "user", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
result = convert_dict_to_message(message)
|
||||
expected_output = HumanMessage(content="foo")
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test__convert_dict_to_message_ai() -> None:
|
||||
message = {"role": "assistant", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
result = convert_dict_to_message(message)
|
||||
expected_output = AIMessage(content="foo")
|
||||
assert result == expected_output
|
||||
|
||||
|
||||
def test__convert_dict_to_message_system() -> None:
|
||||
message = {"role": "system", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
result = convert_dict_to_message(message)
|
||||
expected_output = SystemMessage(content="foo")
|
||||
assert result == expected_output
|
||||
|
||||
|
Reference in New Issue
Block a user