Add openai v2 adapter (#14063)
### Description
Starting from [openai version 1.0.0](17ac677995 (module-level-client)), the camel-case form `openai.ChatCompletion` is no longer supported; it has been renamed to the lowercase `openai.chat.completions`. In addition, the returned object only supports attribute access instead of index access:
```python
import openai

# optional; defaults to `os.environ['OPENAI_API_KEY']`
openai.api_key = '...'

# all client options can be configured just like the `OpenAI` instantiation counterpart
openai.base_url = "https://..."
openai.default_headers = {"x-foo": "true"}

completion = openai.chat.completions.create(
    model="gpt-4",
    messages=[
        {
            "role": "user",
            "content": "How do I output all files in a directory using Python?",
        },
    ],
)
print(completion.choices[0].message.content)
```
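A minimal sketch of what breaks (assumes `OPENAI_API_KEY` is set): the v1 response is a pydantic object, so the old v0-style index access fails.

```python
import openai

completion = openai.chat.completions.create(
    model="gpt-4",
    messages=[{"role": "user", "content": "hi"}],
)

# v1: attribute access works on the returned pydantic object
print(completion.choices[0].message.content)

# v0-style index access now raises
# TypeError: 'ChatCompletion' object is not subscriptable
completion["choices"][0]["message"]["content"]
```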
So I implemented a compatibility adapter that supports both attribute access and index access:
```python
In [1]: from langchain.adapters import openai as lc_openai
   ...: messages = [{"role": "user", "content": "hi"}]

In [2]: result = lc_openai.chat.completions.create(
   ...:     messages=messages, model="gpt-3.5-turbo", temperature=0
   ...: )

In [3]: result.choices[0].message
Out[3]: {'role': 'assistant', 'content': 'Hello! How can I assist you today?'}

In [4]: result["choices"][0]["message"]
Out[4]: {'role': 'assistant', 'content': 'Hello! How can I assist you today?'}

In [5]: result = await lc_openai.chat.completions.acreate(
   ...:     messages=messages, model="gpt-3.5-turbo", temperature=0
   ...: )

In [6]: result.choices[0].message
Out[6]: {'role': 'assistant', 'content': 'Hello! How can I assist you today?'}

In [7]: result["choices"][0]["message"]
Out[7]: {'role': 'assistant', 'content': 'Hello! How can I assist you today?'}

In [8]: for rs in lc_openai.chat.completions.create(
   ...:     messages=messages, model="gpt-3.5-turbo", temperature=0, stream=True
   ...: ):
   ...:     print(rs.choices[0].delta)
   ...:     print(rs["choices"][0]["delta"])
   ...:
{'role': 'assistant', 'content': ''}
{'role': 'assistant', 'content': ''}
{'content': 'Hello'}
{'content': 'Hello'}
{'content': '!'}
{'content': '!'}

In [20]: async for rs in await lc_openai.chat.completions.acreate(
    ...:     messages=messages, model="gpt-3.5-turbo", temperature=0, stream=True
    ...: ):
    ...:     print(rs.choices[0].delta)
    ...:     print(rs["choices"][0]["delta"])
    ...:
{'role': 'assistant', 'content': ''}
{'role': 'assistant', 'content': ''}
{'content': 'Hello'}
{'content': 'Hello'}
{'content': '!'}
{'content': '!'}
...
```
### Twitter handle
[lin_bob57617](https://twitter.com/lin_bob57617)
The change, in `langchain/adapters/openai.py`:

```diff
@@ -25,6 +25,7 @@ from langchain_core.messages import (
     SystemMessage,
     ToolMessage,
 )
+from langchain_core.pydantic_v1 import BaseModel
 from typing_extensions import Literal
 
 
```
```diff
@@ -38,6 +39,29 @@ async def aenumerate(
         i += 1
 
 
+class IndexableBaseModel(BaseModel):
+    """Allows a BaseModel to return its fields by string variable indexing"""
+
+    def __getitem__(self, item: str) -> Any:
+        return getattr(self, item)
+
+
+class Choice(IndexableBaseModel):
+    message: dict
+
+
+class ChatCompletions(IndexableBaseModel):
+    choices: List[Choice]
+
+
+class ChoiceChunk(IndexableBaseModel):
+    delta: dict
+
+
+class ChatCompletionChunk(IndexableBaseModel):
+    choices: List[ChoiceChunk]
+
+
 def convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     """Convert a dictionary to a LangChain message.
 
```
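`IndexableBaseModel` is the whole trick: string indexing is forwarded to `getattr`, so every field of the result models is reachable both as an attribute and as a key. A minimal standalone sketch of that behavior (same classes as above, plus a small demo at the end):

```python
from typing import Any, List

from langchain_core.pydantic_v1 import BaseModel


class IndexableBaseModel(BaseModel):
    """Allows a BaseModel to return its fields by string variable indexing."""

    def __getitem__(self, item: str) -> Any:
        # index access is just attribute access in disguise
        return getattr(self, item)


class Choice(IndexableBaseModel):
    message: dict


class ChatCompletions(IndexableBaseModel):
    choices: List[Choice]


result = ChatCompletions(
    choices=[Choice(message={"role": "assistant", "content": "hi"})]
)
# both spellings resolve to the same field
assert result.choices[0].message == result["choices"][0]["message"]
```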
```diff
@@ -129,7 +153,7 @@ def convert_openai_messages(messages: Sequence[Dict[str, Any]]) -> List[BaseMessage]:
     return [convert_dict_to_message(m) for m in messages]
 
 
-def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
+def _convert_message_chunk(chunk: BaseMessageChunk, i: int) -> dict:
     _dict: Dict[str, Any] = {}
     if isinstance(chunk, AIMessageChunk):
         if i == 0:
```
```diff
@@ -148,6 +172,11 @@ def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
     # This only happens at the end of streams, and OpenAI returns as empty dict
     if _dict == {"content": ""}:
         _dict = {}
     return _dict
+
+
+def _convert_message_chunk_to_delta(chunk: BaseMessageChunk, i: int) -> Dict[str, Any]:
+    _dict = _convert_message_chunk(chunk, i)
+    return {"choices": [{"delta": _dict}]}
 
 
```
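To make the shapes concrete, here is roughly what these two helpers yield for a short stream of `AIMessageChunk`s (illustrative values, assuming the helpers above are in scope; the dicts mirror the streamed deltas in the session shown earlier):

```python
from langchain_core.messages import AIMessageChunk

# chunks as a chat model would stream them
chunks = [AIMessageChunk(content=""), AIMessageChunk(content="Hello")]

for i, c in enumerate(chunks):
    print(_convert_message_chunk(c, i))
# the first chunk (i == 0) additionally carries the role:
#   {'role': 'assistant', 'content': ''}
#   {'content': 'Hello'}

print(_convert_message_chunk_to_delta(chunks[1], 1))
# {'choices': [{'delta': {'content': 'Hello'}}]}
```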
```diff
@@ -262,3 +291,109 @@ def convert_messages_for_finetuning(
         for session in sessions
         if _has_assistant_message(session)
     ]
+
+
+class Completions:
+    """Completion."""
+
+    @overload
+    @staticmethod
+    def create(
+        messages: Sequence[Dict[str, Any]],
+        *,
+        provider: str = "ChatOpenAI",
+        stream: Literal[False] = False,
+        **kwargs: Any,
+    ) -> ChatCompletions:
+        ...
+
+    @overload
+    @staticmethod
+    def create(
+        messages: Sequence[Dict[str, Any]],
+        *,
+        provider: str = "ChatOpenAI",
+        stream: Literal[True],
+        **kwargs: Any,
+    ) -> Iterable:
+        ...
+
+    @staticmethod
+    def create(
+        messages: Sequence[Dict[str, Any]],
+        *,
+        provider: str = "ChatOpenAI",
+        stream: bool = False,
+        **kwargs: Any,
+    ) -> Union[ChatCompletions, Iterable]:
+        models = importlib.import_module("langchain.chat_models")
+        model_cls = getattr(models, provider)
+        model_config = model_cls(**kwargs)
+        converted_messages = convert_openai_messages(messages)
+        if not stream:
+            result = model_config.invoke(converted_messages)
+            return ChatCompletions(
+                choices=[Choice(message=convert_message_to_dict(result))]
+            )
+        else:
+            return (
+                ChatCompletionChunk(
+                    choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
+                )
+                for i, c in enumerate(model_config.stream(converted_messages))
+            )
+
+    @overload
+    @staticmethod
+    async def acreate(
+        messages: Sequence[Dict[str, Any]],
+        *,
+        provider: str = "ChatOpenAI",
+        stream: Literal[False] = False,
+        **kwargs: Any,
+    ) -> ChatCompletions:
+        ...
+
+    @overload
+    @staticmethod
+    async def acreate(
+        messages: Sequence[Dict[str, Any]],
+        *,
+        provider: str = "ChatOpenAI",
+        stream: Literal[True],
+        **kwargs: Any,
+    ) -> AsyncIterator:
+        ...
+
+    @staticmethod
+    async def acreate(
+        messages: Sequence[Dict[str, Any]],
+        *,
+        provider: str = "ChatOpenAI",
+        stream: bool = False,
+        **kwargs: Any,
+    ) -> Union[ChatCompletions, AsyncIterator]:
+        models = importlib.import_module("langchain.chat_models")
+        model_cls = getattr(models, provider)
+        model_config = model_cls(**kwargs)
+        converted_messages = convert_openai_messages(messages)
+        if not stream:
+            result = await model_config.ainvoke(converted_messages)
+            return ChatCompletions(
+                choices=[Choice(message=convert_message_to_dict(result))]
+            )
+        else:
+            return (
+                ChatCompletionChunk(
+                    choices=[ChoiceChunk(delta=_convert_message_chunk(c, i))]
+                )
+                async for i, c in aenumerate(model_config.astream(converted_messages))
+            )
+
+
+class Chat:
+    def __init__(self) -> None:
+        self.completions = Completions()
+
+
+chat = Chat()
```