diff --git a/langchain/chat_models/__init__.py b/langchain/chat_models/__init__.py
index 2b1b9b3dfab..88bafc7e7a2 100644
--- a/langchain/chat_models/__init__.py
+++ b/langchain/chat_models/__init__.py
@@ -1,4 +1,5 @@
+from langchain.chat_models.azure_openai import AzureChatOpenAI
 from langchain.chat_models.openai import ChatOpenAI
 from langchain.chat_models.promptlayer_openai import PromptLayerChatOpenAI
 
-__all__ = ["ChatOpenAI", "PromptLayerChatOpenAI"]
+__all__ = ["ChatOpenAI", "AzureChatOpenAI", "PromptLayerChatOpenAI"]
diff --git a/langchain/chat_models/azure_openai.py b/langchain/chat_models/azure_openai.py
new file mode 100644
index 00000000000..a91ed21c576
--- /dev/null
+++ b/langchain/chat_models/azure_openai.py
@@ -0,0 +1,178 @@
+"""Azure OpenAI chat wrapper."""
+from __future__ import annotations
+
+import logging
+from typing import Any, Dict, List, Mapping, Optional, Tuple
+
+from pydantic import root_validator
+
+from langchain.chat_models.openai import (
+    ChatOpenAI,
+    acompletion_with_retry,
+)
+from langchain.schema import (
+    AIMessage,
+    BaseMessage,
+    ChatGeneration,
+    ChatResult,
+)
+from langchain.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__file__)
+
+
+def _create_chat_prompt(messages: List[BaseMessage]) -> str:
+    """Create a prompt for Azure OpenAI using ChatML."""
+    prompt = "\n".join([message.format_chatml() for message in messages])
+    return prompt + "\n<|im_start|>assistant\n"
+
+
+def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
+    generations = []
+    for res in response["choices"]:
+        message = AIMessage(content=res["text"])
+        gen = ChatGeneration(message=message)
+        generations.append(gen)
+    return ChatResult(generations=generations)
+
+
+class AzureChatOpenAI(ChatOpenAI):
+    """Wrapper around Azure OpenAI Chat large language models.
+
+    To use, you should have the ``openai`` python package installed, and the
+    following environment variables set:
+    - ``OPENAI_API_TYPE``
+    - ``OPENAI_API_KEY``
+    - ``OPENAI_API_BASE``
+    - ``OPENAI_API_VERSION``
+
+    Any parameters that are valid to be passed to the openai.create call can be
+    passed in, even if not explicitly saved on this class.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.chat_models import AzureChatOpenAI
+            openai = AzureChatOpenAI(deployment_name="")
+    """
+
+    deployment_name: str = ""
+    stop: List[str] = ["<|im_end|>"]
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exist in the environment."""
+        openai_api_key = get_from_dict_or_env(
+            values,
+            "openai_api_key",
+            "OPENAI_API_KEY",
+        )
+        openai_api_base = get_from_dict_or_env(
+            values,
+            "openai_api_base",
+            "OPENAI_API_BASE",
+        )
+        openai_api_version = get_from_dict_or_env(
+            values,
+            "openai_api_version",
+            "OPENAI_API_VERSION",
+        )
+        openai_api_type = get_from_dict_or_env(
+            values,
+            "openai_api_type",
+            "OPENAI_API_TYPE",
+        )
+        try:
+            import openai
+
+            openai.api_type = openai_api_type
+            openai.api_base = openai_api_base
+            openai.api_version = openai_api_version
+            openai.api_key = openai_api_key
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        try:
+            values["client"] = openai.Completion
+        except AttributeError:
+            raise ValueError(
+                "`openai` has no `Completion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
+            )
+        if values["n"] < 1:
+            raise ValueError("n must be at least 1.")
+        if values["n"] > 1 and values["streaming"]:
+            raise ValueError("n must be 1 when streaming.")
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling OpenAI API."""
+        return {
+            **super()._default_params,
+            "stop": self.stop,
+        }
+
+    def _generate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        prompt, params = self._create_prompt(messages, stop)
+        if self.streaming:
+            inner_completion = ""
+            params["stream"] = True
+            for stream_resp in self.completion_with_retry(prompt=prompt, **params):
+                token = stream_resp["choices"][0].get("text", "")
+                inner_completion += token
+                self.callback_manager.on_llm_new_token(
+                    token,
+                    verbose=self.verbose,
+                )
+            message = AIMessage(content=inner_completion)
+            return ChatResult(generations=[ChatGeneration(message=message)])
+        response = self.completion_with_retry(prompt=prompt, **params)
+        return _create_chat_result(response)
+
+    def _create_prompt(
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
+    ) -> Tuple[str, Dict[str, Any]]:
+        params: Dict[str, Any] = {
+            **{"model": self.model_name, "engine": self.deployment_name},
+            **self._default_params,
+        }
+        if stop is not None:
+            if "stop" in params:
+                raise ValueError("`stop` found in both the input and default params.")
+            params["stop"] = stop
+        prompt = _create_chat_prompt(messages)
+        return prompt, params
+
+    async def _agenerate(
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+    ) -> ChatResult:
+        prompt, params = self._create_prompt(messages, stop)
+        if self.streaming:
+            inner_completion = ""
+            params["stream"] = True
+            async for stream_resp in await acompletion_with_retry(
+                self, prompt=prompt, **params
+            ):
+                token = stream_resp["choices"][0].get("text", "")
+                inner_completion += token
+                if self.callback_manager.is_async:
+                    await self.callback_manager.on_llm_new_token(
+                        token,
+                        verbose=self.verbose,
+                    )
+                else:
+                    self.callback_manager.on_llm_new_token(
+                        token,
+                        verbose=self.verbose,
+                    )
+            message = AIMessage(content=inner_completion)
+            return ChatResult(generations=[ChatGeneration(message=message)])
+        else:
+            response = await acompletion_with_retry(self, prompt=prompt, **params)
+            return _create_chat_result(response)
diff --git a/langchain/schema.py b/langchain/schema.py
index 0a4e5ef5918..cb2d740342b 100644
--- a/langchain/schema.py
+++ b/langchain/schema.py
@@ -60,6 +60,9 @@ class BaseMessage(BaseModel):
     content: str
     additional_kwargs: dict = Field(default_factory=dict)
 
+    def format_chatml(self) -> str:
+        raise NotImplementedError()
+
     @property
     @abstractmethod
     def type(self) -> str:
@@ -69,6 +72,9 @@ class BaseMessage(BaseModel):
 class HumanMessage(BaseMessage):
     """Type of message that is spoken by the human."""
 
+    def format_chatml(self) -> str:
+        return f"<|im_start|>user\n{self.content}\n<|im_end|>"
+
     @property
     def type(self) -> str:
         """Type of the message, used for serialization."""
@@ -78,6 +84,9 @@ class HumanMessage(BaseMessage):
 class AIMessage(BaseMessage):
     """Type of message that is spoken by the AI."""
 
+    def format_chatml(self) -> str:
+        return f"<|im_start|>assistant\n{self.content}\n<|im_end|>"
+
     @property
     def type(self) -> str:
         """Type of the message, used for serialization."""
@@ -87,6 +96,9 @@ class AIMessage(BaseMessage):
 class SystemMessage(BaseMessage):
     """Type of message that is a system message."""
 
+    def format_chatml(self) -> str:
+        return f"<|im_start|>system\n{self.content}\n<|im_end|>"
+
     @property
     def type(self) -> str:
         """Type of the message, used for serialization."""
@@ -98,6 +110,9 @@ class ChatMessage(BaseMessage):
 
     role: str
 
+    def format_chatml(self) -> str:
+        return f"<|im_start|>{self.role}\n{self.content}\n<|im_end|>"
+
     @property
    def type(self) -> str:
        """Type of the message, used for serialization."""
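
Reviewer note: a minimal sketch of what this patch does end to end. The new format_chatml() methods on the schema message classes render each message as a ChatML block, and _create_chat_prompt() joins those blocks and appends an open assistant turn before the request goes to the Azure deployment through openai.Completion. The snippet below reproduces that prompt construction with the patch applied; the deployment name in the commented-out call is a hypothetical placeholder, and the four OPENAI_* environment variables from the class docstring must be set before the wrapper itself can be used.

    from langchain.schema import HumanMessage, SystemMessage

    messages = [
        SystemMessage(content="You are a helpful assistant."),
        HumanMessage(content="Summarize ChatML in one sentence."),
    ]

    # Same construction as _create_chat_prompt(): one ChatML block per
    # message, then an open assistant turn for the model to complete.
    prompt = "\n".join(m.format_chatml() for m in messages) + "\n<|im_start|>assistant\n"
    print(prompt)
    # <|im_start|>system
    # You are a helpful assistant.
    # <|im_end|>
    # <|im_start|>user
    # Summarize ChatML in one sentence.
    # <|im_end|>
    # <|im_start|>assistant

    # With OPENAI_API_TYPE, OPENAI_API_KEY, OPENAI_API_BASE, and
    # OPENAI_API_VERSION set, the wrapper is called like any other chat
    # model; "gpt-35-turbo" is a hypothetical deployment name:
    # from langchain.chat_models import AzureChatOpenAI
    # chat = AzureChatOpenAI(deployment_name="gpt-35-turbo")
    # response = chat(messages)  # AIMessage; "<|im_end|>" is the default stop

Note the default stop list on the class: because _default_params always injects "stop", passing an explicit stop argument to _generate/_agenerate raises the "found in both" ValueError, so stop sequences should be configured on the class itself.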