langchain/libs/community/langchain_community/chat_models/everlyai.py
Bagatur a0c2281540
infra: update mypy 1.10, ruff 0.5 (#23721)
```python
"""python scripts/update_mypy_ruff.py"""
import glob
import tomllib
from pathlib import Path

import toml
import subprocess
import re

# Directory one level above the folder that contains this script; the
# ``libs/**/pyproject.toml`` search in ``main`` is rooted here.
ROOT_DIR = Path(__file__).parents[1]


# One mypy error line looks like ``<path>:<line>: error: <message>  [<code>]``.
# Raw string: the original non-raw literal relied on invalid escape sequences.
_MYPY_ERROR_RE = re.compile(r"^(.*):(\d+): error:.*\[(.*)\]")


def _collect_mypy_errors(log_lines):
    """Map ``(path, line_no)`` to the list of mypy error codes reported there."""
    to_ignore = {}
    for log_line in log_lines:
        found = _MYPY_ERROR_RE.match(log_line)
        if found is None:
            continue
        error_path, line_no, error_type = found.groups()
        to_ignore.setdefault((error_path, line_no), []).append(error_type)
    return to_ignore


def _append_type_ignores(cwd, to_ignore):
    """Append ``# type: ignore[...]`` to every flagged line of files under ``cwd``."""
    for (error_path, line_no), error_types in to_ignore.items():
        all_errors = ", ".join(error_types)
        full_path = f"{cwd}/{error_path}"
        try:
            with open(full_path, "r") as f:
                file_lines = f.readlines()
        except FileNotFoundError:
            # mypy reported a path that does not resolve under cwd; skip it.
            continue
        index = int(line_no) - 1
        # Strip only a trailing newline so a final line without one keeps
        # its last character.
        file_lines[index] = (
            file_lines[index].rstrip("\n") + f"  # type: ignore[{all_errors}]\n"
        )
        with open(full_path, "w") as f:
            f.write("".join(file_lines))


def main():
    """Bump mypy and ruff in each ``libs/**/pyproject.toml`` and silence new errors.

    For every pyproject file found under ``ROOT_DIR``:

    1. Set ``mypy = "^1.10"`` in the ``typing`` group and ``ruff = "^0.5"`` in
       the ``lint`` group. Files lacking either group's ``dependencies`` table
       are skipped without being rewritten.
    2. Re-lock, install, and run mypy in the package directory.
    3. Append a ``# type: ignore[<codes>]`` comment to each line mypy flagged.
    4. Run ruff formatting and import sorting.
    """
    pattern = str(ROOT_DIR / "libs/**/pyproject.toml")
    for pyproject_path in glob.glob(pattern, recursive=True):
        print(pyproject_path)
        with open(pyproject_path, "rb") as f:
            pyproject = tomllib.load(f)
        try:
            groups = pyproject["tool"]["poetry"]["group"]
            groups["typing"]["dependencies"]["mypy"] = "^1.10"
            groups["lint"]["dependencies"]["ruff"] = "^0.5"
        except KeyError:
            continue
        with open(pyproject_path, "w") as f:
            toml.dump(pyproject, f)

        cwd = str(Path(pyproject_path).parent)
        # Fixed command string (no external input). The ``;`` chaining means
        # each step runs even if the previous one failed.
        completed = subprocess.run(
            "poetry lock --no-update; poetry install --with typing; poetry run mypy . --no-color",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )

        to_ignore = _collect_mypy_errors(completed.stdout.split("\n"))
        print(len(to_ignore))
        _append_type_ignores(cwd, to_ignore)

        subprocess.run(
            "poetry run ruff format .; poetry run ruff --select I --fix .",
            cwd=cwd,
            shell=True,
            capture_output=True,
            text=True,
        )


# Run the update only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()

```
2024-07-03 10:33:27 -07:00

161 lines
5.5 KiB
Python

"""EverlyAI Endpoints chat wrapper. Relies heavily on ChatOpenAI."""
from __future__ import annotations
import logging
import sys
from typing import TYPE_CHECKING, Dict, Optional, Set
from langchain_core.messages import BaseMessage
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_community.adapters.openai import convert_message_to_dict
from langchain_community.chat_models.openai import (
ChatOpenAI,
_import_tiktoken,
)
if TYPE_CHECKING:
import tiktoken
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Default base URL for the EverlyAI hosted API.
DEFAULT_API_BASE = "https://everlyai.xyz/hosted"
# Model used when the caller does not supply a model name.
DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
class ChatEverlyAI(ChatOpenAI):
    """`EverlyAI` Chat large language models.

    To use, you should have the ``openai`` python package installed, and the
    environment variable ``EVERLYAI_API_KEY`` set with your API key.
    Alternatively, you can use the everlyai_api_key keyword argument.

    Any parameters that are valid to be passed to the `openai.create` call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain_community.chat_models import ChatEverlyAI
            chat = ChatEverlyAI(model_name="meta-llama/Llama-2-7b-chat-hf")
    """

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "everlyai-chat"

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map the secret constructor argument to its environment variable."""
        return {"everlyai_api_key": "EVERLYAI_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this class is not serializable."""
        return False

    everlyai_api_key: Optional[str] = None
    """EverlyAI Endpoints API keys."""
    model_name: str = Field(default=DEFAULT_MODEL, alias="model")
    """Model name to use."""
    everlyai_api_base: str = DEFAULT_API_BASE
    """Base URL path for API requests."""
    available_models: Optional[Set[str]] = None
    """Available models from EverlyAI API."""

    @staticmethod
    def get_available_models() -> Set[str]:
        """Get available models from EverlyAI API."""
        # EverlyAI doesn't yet support dynamically query for available models.
        return {
            "meta-llama/Llama-2-7b-chat-hf",
            "meta-llama/Llama-2-13b-chat-hf-quantized",
        }

    @root_validator(pre=True)
    def validate_environment_override(cls, values: dict) -> dict:
        """Validate that api key and python package exists in environment.

        Fills in the OpenAI-compatible settings (``openai_api_key``,
        ``openai_api_base``, ``client``) from the EverlyAI-specific inputs and
        checks the requested model against ``get_available_models()``.

        Args:
            values: Raw constructor arguments (this validator runs with
                ``pre=True``).

        Returns:
            The same ``values`` dict, updated in place.

        Raises:
            ImportError: If the ``openai`` package cannot be imported.
            ValueError: If ``openai`` has no ``ChatCompletion`` attribute, or
                if ``model_name`` is not one of the available models.
        """
        values["openai_api_key"] = get_from_dict_or_env(
            values,
            "everlyai_api_key",
            "EVERLYAI_API_KEY",
        )
        # Honor a caller-supplied ``everlyai_api_base``; previously the field
        # was declared but the default URL was always used.
        values["openai_api_base"] = (
            values.get("everlyai_api_base") or DEFAULT_API_BASE
        )

        try:
            import openai
        except ImportError as e:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`.",
            ) from e
        try:
            values["client"] = openai.ChatCompletion
        except AttributeError as exc:
            raise ValueError(
                "`openai` has no `ChatCompletion` attribute, this is likely "
                "due to an old version of the openai package. Try upgrading it "
                "with `pip install --upgrade openai`.",
            ) from exc

        # NOTE(review): only the "model_name" key is inspected here. A value
        # passed through the field's "model" alias does not appear to be
        # checked against the available models -- confirm whether intended.
        if "model_name" not in values:
            values["model_name"] = DEFAULT_MODEL
        model_name = values["model_name"]

        available_models = cls.get_available_models()
        if model_name not in available_models:
            raise ValueError(
                f"Model name {model_name} not found in available models: "
                f"{available_models}.",
            )
        values["available_models"] = available_models
        return values

    def _get_encoding_model(self) -> tuple[str, tiktoken.Encoding]:
        """Return the model label and the tiktoken encoding used for counting.

        The label is ``tiktoken_model_name`` when set, otherwise ``model_name``.
        It becomes ``"cl100k_base"`` when the fallback encoding is used.
        """
        tiktoken_ = _import_tiktoken()
        if self.tiktoken_model_name is not None:
            model = self.tiktoken_model_name
        else:
            model = self.model_name
        # The encoding is always looked up for "gpt-3.5-turbo-0301",
        # independent of ``model``.
        try:
            encoding = tiktoken_.encoding_for_model("gpt-3.5-turbo-0301")
        except KeyError:
            logger.warning("Warning: model not found. Using cl100k_base encoding.")
            model = "cl100k_base"
            encoding = tiktoken_.get_encoding(model)
        return model, encoding

    def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
        """Calculate num tokens with tiktoken package.

        Official documentation: https://github.com/openai/openai-cookbook/blob/
        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb

        Args:
            messages: Chat messages to count tokens for.

        Returns:
            Total token count: 3 per message, the encoded length of every
            value in each message dict, 1 extra per ``name`` key, plus 3 for
            the reply primer.
        """
        # On Python 3.7 and older, defer to the parent implementation.
        if sys.version_info < (3, 8):
            return super().get_num_tokens_from_messages(messages)
        _, encoding = self._get_encoding_model()
        tokens_per_message = 3
        tokens_per_name = 1
        num_tokens = 0
        messages_dict = [convert_message_to_dict(m) for m in messages]
        for message in messages_dict:
            num_tokens += tokens_per_message
            for key, value in message.items():
                # Cast str(value) in case the message value is not a string
                # This occurs with function messages
                num_tokens += len(encoding.encode(str(value)))
                if key == "name":
                    num_tokens += tokens_per_name
        # every reply is primed with <im_start>assistant
        num_tokens += 3
        return num_tokens