mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 04:38:26 +00:00
Simplify AzureChatOpenAI implementation. (#1902)
Change AzureChatOpenAI class implementation as Azure just added support for chat completion API. See: https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/chatgpt?pivots=programming-language-chat-completions. This should make the code much simpler.
This commit is contained in:
parent
f155d9d3ec
commit
273e9bf296
@ -2,62 +2,50 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, Dict, List, Mapping, Optional, Tuple
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import root_validator
|
||||
|
||||
from langchain.chat_models.openai import (
|
||||
ChatOpenAI,
|
||||
acompletion_with_retry,
|
||||
)
|
||||
from langchain.schema import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatGeneration,
|
||||
ChatResult,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__file__)
|
||||
|
||||
|
||||
def _create_chat_prompt(messages: List[BaseMessage]) -> str:
|
||||
"""Create a prompt for Azure OpenAI using ChatML."""
|
||||
prompt = "\n".join([message.format_chatml() for message in messages])
|
||||
return prompt + "\n<|im_start|>assistant\n"
|
||||
|
||||
|
||||
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
|
||||
generations = []
|
||||
for res in response["choices"]:
|
||||
message = AIMessage(content=res["text"])
|
||||
gen = ChatGeneration(message=message)
|
||||
generations.append(gen)
|
||||
return ChatResult(generations=generations)
|
||||
|
||||
|
||||
class AzureChatOpenAI(ChatOpenAI):
|
||||
"""Wrapper around Azure OpenAI Chat large language models.
|
||||
"""Wrapper around Azure OpenAI Chat Completion API. To use this class you
|
||||
must have a deployed model on Azure OpenAI. Use `deployment_name` in the
|
||||
constructor to refer to the "Model deployment name" in the Azure portal.
|
||||
|
||||
To use, you should have the ``openai`` python package installed, and the
|
||||
following environment variables set:
|
||||
- ``OPENAI_API_TYPE``
|
||||
In addition, you should have the ``openai`` python package installed, and the
|
||||
following environment variables set or passed in constructor in lower case:
|
||||
- ``OPENAI_API_TYPE`` (default: ``azure``)
|
||||
- ``OPENAI_API_KEY``
|
||||
- ``OPENAI_API_BASE``
|
||||
- ``OPENAI_API_VERSION``
|
||||
|
||||
For exmaple, if you have `gpt-35-turbo` deployed, with the deployment name
|
||||
`35-turbo-dev`, the constructor should look like:
|
||||
|
||||
.. code-block:: python
|
||||
AzureChatOpenAI(
|
||||
deployment_name="35-turbo-dev",
|
||||
openai_api_version="2023-03-15-preview",
|
||||
)
|
||||
|
||||
Be aware the API version may change.
|
||||
|
||||
Any parameters that are valid to be passed to the openai.create call can be passed
|
||||
in, even if not explicitly saved on this class.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.chat_models import AzureChatOpenAI
|
||||
openai = AzureChatOpenAI(deployment_name="<your deployment name>")
|
||||
"""
|
||||
|
||||
deployment_name: str = ""
|
||||
stop: List[str] = ["<|im_end|>"]
|
||||
openai_api_type: str = "azure"
|
||||
openai_api_base: str = ""
|
||||
openai_api_version: str = ""
|
||||
openai_api_key: str = ""
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@ -95,10 +83,10 @@ class AzureChatOpenAI(ChatOpenAI):
|
||||
"Please it install it with `pip install openai`."
|
||||
)
|
||||
try:
|
||||
values["client"] = openai.Completion
|
||||
values["client"] = openai.ChatCompletion
|
||||
except AttributeError:
|
||||
raise ValueError(
|
||||
"`openai` has no `Completion` attribute, this is likely "
|
||||
"`openai` has no `ChatCompletion` attribute, this is likely "
|
||||
"due to an old version of the openai package. Try upgrading it "
|
||||
"with `pip install --upgrade openai`."
|
||||
)
|
||||
@ -113,66 +101,5 @@ class AzureChatOpenAI(ChatOpenAI):
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
**super()._default_params,
|
||||
"stop": self.stop,
|
||||
"engine": self.deployment_name,
|
||||
}
|
||||
|
||||
def _generate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
) -> ChatResult:
|
||||
prompt, params = self._create_prompt(messages, stop)
|
||||
if self.streaming:
|
||||
inner_completion = ""
|
||||
params["stream"] = True
|
||||
for stream_resp in self.completion_with_retry(prompt=prompt, **params):
|
||||
token = stream_resp["choices"][0]["delta"].get("text", "")
|
||||
inner_completion += token
|
||||
self.callback_manager.on_llm_new_token(
|
||||
token,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
message = AIMessage(content=inner_completion)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
response = self.completion_with_retry(prompt=prompt, **params)
|
||||
return _create_chat_result(response)
|
||||
|
||||
def _create_prompt(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||
) -> Tuple[str, Dict[str, Any]]:
|
||||
params: Dict[str, Any] = {
|
||||
**{"model": self.model_name, "engine": self.deployment_name},
|
||||
**self._default_params,
|
||||
}
|
||||
if stop is not None:
|
||||
if "stop" in params:
|
||||
raise ValueError("`stop` found in both the input and default params.")
|
||||
params["stop"] = stop
|
||||
prompt = _create_chat_prompt(messages)
|
||||
return prompt, params
|
||||
|
||||
async def _agenerate(
|
||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||
) -> ChatResult:
|
||||
prompt, params = self._create_prompt(messages, stop)
|
||||
if self.streaming:
|
||||
inner_completion = ""
|
||||
params["stream"] = True
|
||||
async for stream_resp in await acompletion_with_retry(
|
||||
self, prompt=prompt, **params
|
||||
):
|
||||
token = stream_resp["choices"][0]["delta"].get("text", "")
|
||||
inner_completion += token
|
||||
if self.callback_manager.is_async:
|
||||
await self.callback_manager.on_llm_new_token(
|
||||
token,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
else:
|
||||
self.callback_manager.on_llm_new_token(
|
||||
token,
|
||||
verbose=self.verbose,
|
||||
)
|
||||
message = AIMessage(content=inner_completion)
|
||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||
else:
|
||||
response = await acompletion_with_retry(self, prompt=prompt, **params)
|
||||
return _create_chat_result(response)
|
||||
|
@ -60,9 +60,6 @@ class BaseMessage(BaseModel):
|
||||
content: str
|
||||
additional_kwargs: dict = Field(default_factory=dict)
|
||||
|
||||
def format_chatml(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type(self) -> str:
|
||||
@ -72,9 +69,6 @@ class BaseMessage(BaseModel):
|
||||
class HumanMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the human."""
|
||||
|
||||
def format_chatml(self) -> str:
|
||||
return f"<|im_start|>user\n{self.content}\n<|im_end|>"
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
@ -84,9 +78,6 @@ class HumanMessage(BaseMessage):
|
||||
class AIMessage(BaseMessage):
|
||||
"""Type of message that is spoken by the AI."""
|
||||
|
||||
def format_chatml(self) -> str:
|
||||
return f"<|im_start|>assistant\n{self.content}\n<|im_end|>"
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
@ -96,9 +87,6 @@ class AIMessage(BaseMessage):
|
||||
class SystemMessage(BaseMessage):
|
||||
"""Type of message that is a system message."""
|
||||
|
||||
def format_chatml(self) -> str:
|
||||
return f"<|im_start|>system\n{self.content}\n<|im_end|>"
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
@ -110,9 +98,6 @@ class ChatMessage(BaseMessage):
|
||||
|
||||
role: str
|
||||
|
||||
def format_chatml(self) -> str:
|
||||
return f"<|im_start|>{self.role}\n{self.content}\n<|im_end|>"
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
"""Type of the message, used for serialization."""
|
||||
|
Loading…
Reference in New Issue
Block a user