From 273e9bf2964035c0c4838315dbe84c0532f5ab53 Mon Sep 17 00:00:00 2001
From: Eric Zhu
Date: Wed, 22 Mar 2023 19:36:51 -0700
Subject: [PATCH] Simplify AzureChatOpenAI implementation. (#1902)

Change the AzureChatOpenAI class implementation, as Azure just added
support for the chat completion API. See:
https://learn.microsoft.com/en-us/azure/cognitive-services/openai/how-to/chatgpt?pivots=programming-language-chat-completions.
This should make the code much simpler.
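For reference, a minimal usage sketch of the simplified class
(illustrative only, not part of the patch; it reuses the `35-turbo-dev`
deployment name from the new docstring and assumes `OPENAI_API_KEY` and
`OPENAI_API_BASE` are set in the environment):

    from langchain.chat_models import AzureChatOpenAI
    from langchain.schema import HumanMessage

    chat = AzureChatOpenAI(
        deployment_name="35-turbo-dev",
        openai_api_version="2023-03-15-preview",
    )
    # The wrapper now routes this through openai.ChatCompletion,
    # passing engine=deployment_name via _default_params.
    print(chat([HumanMessage(content="Hello!")]))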
---
 langchain/chat_models/azure_openai.py | 123 ++++++--------------------
 langchain/schema.py                   |  15 ----
 2 files changed, 25 insertions(+), 113 deletions(-)

diff --git a/langchain/chat_models/azure_openai.py b/langchain/chat_models/azure_openai.py
index a91ed21c576..37f00d5017e 100644
--- a/langchain/chat_models/azure_openai.py
+++ b/langchain/chat_models/azure_openai.py
@@ -2,62 +2,50 @@ from __future__ import annotations
 
 import logging
-from typing import Any, Dict, List, Mapping, Optional, Tuple
+from typing import Any, Dict
 
 from pydantic import root_validator
 
 from langchain.chat_models.openai import (
     ChatOpenAI,
-    acompletion_with_retry,
-)
-from langchain.schema import (
-    AIMessage,
-    BaseMessage,
-    ChatGeneration,
-    ChatResult,
 )
 from langchain.utils import get_from_dict_or_env
 
 logger = logging.getLogger(__file__)
 
 
-def _create_chat_prompt(messages: List[BaseMessage]) -> str:
-    """Create a prompt for Azure OpenAI using ChatML."""
-    prompt = "\n".join([message.format_chatml() for message in messages])
-    return prompt + "\n<|im_start|>assistant\n"
-
-
-def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
-    generations = []
-    for res in response["choices"]:
-        message = AIMessage(content=res["text"])
-        gen = ChatGeneration(message=message)
-        generations.append(gen)
-    return ChatResult(generations=generations)
-
-
 class AzureChatOpenAI(ChatOpenAI):
-    """Wrapper around Azure OpenAI Chat large language models.
+    """Wrapper around the Azure OpenAI Chat Completion API. To use this class you
+    must have a deployed model on Azure OpenAI. Use `deployment_name` in the
+    constructor to refer to the "Model deployment name" in the Azure portal.
 
-    To use, you should have the ``openai`` python package installed, and the
-    following environment variables set:
-    - ``OPENAI_API_TYPE``
+    In addition, you should have the ``openai`` python package installed, and the
+    following environment variables set, or passed to the constructor in lower case:
+    - ``OPENAI_API_TYPE`` (default: ``azure``)
     - ``OPENAI_API_KEY``
     - ``OPENAI_API_BASE``
     - ``OPENAI_API_VERSION``
 
+    For example, if you have `gpt-35-turbo` deployed, with the deployment name
+    `35-turbo-dev`, the constructor should look like:
+
+    .. code-block:: python
+        AzureChatOpenAI(
+            deployment_name="35-turbo-dev",
+            openai_api_version="2023-03-15-preview",
+        )
+
+    Be aware that the API version may change.
+
     Any parameters that are valid to be passed to the openai.create call can be passed
     in, even if not explicitly saved on this class.
-
-    Example:
-        .. code-block:: python
-
-            from langchain.chat_models import AzureChatOpenAI
-            openai = AzureChatOpenAI(deployment_name="")
     """
 
     deployment_name: str = ""
-    stop: List[str] = ["<|im_end|>"]
+    openai_api_type: str = "azure"
+    openai_api_base: str = ""
+    openai_api_version: str = ""
+    openai_api_key: str = ""
 
     @root_validator()
     def validate_environment(cls, values: Dict) -> Dict:
@@ -95,10 +83,10 @@ class AzureChatOpenAI(ChatOpenAI):
                 "Please install it with `pip install openai`."
             )
         try:
-            values["client"] = openai.Completion
+            values["client"] = openai.ChatCompletion
         except AttributeError:
             raise ValueError(
-                "`openai` has no `Completion` attribute, this is likely "
+                "`openai` has no `ChatCompletion` attribute, this is likely "
                 "due to an old version of the openai package. Try upgrading it "
                 "with `pip install --upgrade openai`."
             )
@@ -113,66 +101,5 @@ class AzureChatOpenAI(ChatOpenAI):
         """Get the default parameters for calling OpenAI API."""
         return {
             **super()._default_params,
-            "stop": self.stop,
+            "engine": self.deployment_name,
         }
-
-    def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        prompt, params = self._create_prompt(messages, stop)
-        if self.streaming:
-            inner_completion = ""
-            params["stream"] = True
-            for stream_resp in self.completion_with_retry(prompt=prompt, **params):
-                token = stream_resp["choices"][0]["delta"].get("text", "")
-                inner_completion += token
-                self.callback_manager.on_llm_new_token(
-                    token,
-                    verbose=self.verbose,
-                )
-            message = AIMessage(content=inner_completion)
-            return ChatResult(generations=[ChatGeneration(message=message)])
-        response = self.completion_with_retry(prompt=prompt, **params)
-        return _create_chat_result(response)
-
-    def _create_prompt(
-        self, messages: List[BaseMessage], stop: Optional[List[str]]
-    ) -> Tuple[str, Dict[str, Any]]:
-        params: Dict[str, Any] = {
-            **{"model": self.model_name, "engine": self.deployment_name},
-            **self._default_params,
-        }
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-        prompt = _create_chat_prompt(messages)
-        return prompt, params
-
-    async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
-    ) -> ChatResult:
-        prompt, params = self._create_prompt(messages, stop)
-        if self.streaming:
-            inner_completion = ""
-            params["stream"] = True
-            async for stream_resp in await acompletion_with_retry(
-                self, prompt=prompt, **params
-            ):
-                token = stream_resp["choices"][0]["delta"].get("text", "")
-                inner_completion += token
-                if self.callback_manager.is_async:
-                    await self.callback_manager.on_llm_new_token(
-                        token,
-                        verbose=self.verbose,
-                    )
-                else:
-                    self.callback_manager.on_llm_new_token(
-                        token,
-                        verbose=self.verbose,
-                    )
-            message = AIMessage(content=inner_completion)
-            return ChatResult(generations=[ChatGeneration(message=message)])
-        else:
-            response = await acompletion_with_retry(self, prompt=prompt, **params)
-            return _create_chat_result(response)
diff --git a/langchain/schema.py b/langchain/schema.py
index 84840bcdbca..b42a640b8b1 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -60,9 +60,6 @@ class BaseMessage(BaseModel):
     content: str
     additional_kwargs: dict = Field(default_factory=dict)
 
-    def format_chatml(self) -> str:
-        raise NotImplementedError()
-
     @property
     @abstractmethod
     def type(self) -> str:
@@ -72,9 +69,6 @@ class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""
 
-    def format_chatml(self) -> str:
-        return f"<|im_start|>user\n{self.content}\n<|im_end|>"
-
     @property
     def type(self) -> str:
         """Type of the message, used for serialization."""
@@ -84,9 +78,6 @@ class AIMessage(BaseMessage):
     """Type of message that is spoken by the AI."""
 
-    def format_chatml(self) -> str:
-        return f"<|im_start|>assistant\n{self.content}\n<|im_end|>"
-
     @property
     def type(self) -> str:
         """Type of the message, used for serialization."""
@@ -96,9 +87,6 @@ class AIMessage(BaseMessage):
 class SystemMessage(BaseMessage):
     """Type of message that is a system message."""
 
-    def format_chatml(self) -> str:
-        return f"<|im_start|>system\n{self.content}\n<|im_end|>"
-
     @property
     def type(self) -> str:
         """Type of the message, used for serialization."""
@@ -110,9 +98,6 @@ class ChatMessage(BaseMessage):
 
     role: str
 
-    def format_chatml(self) -> str:
-        return f"<|im_start|>{self.role}\n{self.content}\n<|im_end|>"
-
     @property
     def type(self) -> str:
         """Type of the message, used for serialization."""
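Note (not part of the patch): with this change the wrapper delegates to
Azure's native chat completion endpoint instead of assembling a ChatML
prompt by hand, which is why format_chatml and the custom _generate
paths can be deleted. Roughly, the request it now issues corresponds to
the following sketch against the pre-1.0 `openai` package; the endpoint
URL is a placeholder, and the version and deployment values are the
docstring's examples:

    import openai

    openai.api_type = "azure"
    openai.api_base = "https://<your-resource>.openai.azure.com"  # OPENAI_API_BASE
    openai.api_version = "2023-03-15-preview"                     # OPENAI_API_VERSION
    openai.api_key = "..."                                        # OPENAI_API_KEY

    # `engine` is the Azure deployment name, wired in via _default_params.
    response = openai.ChatCompletion.create(
        engine="35-turbo-dev",
        messages=[{"role": "user", "content": "Hello!"}],
    )
    print(response["choices"][0]["message"]["content"])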