community[patch]: Modify LLMs/Anyscale to work with OpenAI API v1 (#14206)

- **Description:**
1. Modify LLMs/Anyscale to work with OpenAI API v1
2. Get rid of `openai_`-prefixed variables in chat_models/ChatAnyscale
3. Rename `anyscale_api_base` to `anyscale_base_url` to follow the OpenAI naming convention (reverted)
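For reference, a minimal usage sketch of the two classes after this change. This is illustrative, not part of the commit: the key value is a placeholder, and the model names come from the defaults in the diff below.

```python
# Hedged usage sketch for this PR; assumes the openai v1 SDK is installed and
# that the key is valid. Key/model values are placeholders.
from langchain_community.chat_models import ChatAnyscale
from langchain_community.llms import Anyscale

# Chat model: provider-prefixed fields replace the openai_-prefixed ones.
chat = ChatAnyscale(
    anyscale_api_key="esecret-placeholder",  # or set ANYSCALE_API_KEY
    model_name="meta-llama/Llama-2-7b-chat-hf",
)

# Completion LLM: after this change only models in COMPLETION_MODELS are
# accepted; the default is "Meta-Llama/Llama-Guard-7b".
llm = Anyscale(anyscale_api_key="esecret-placeholder")
```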

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
kYLe authored on 2024-02-09 15:11:18 -08:00 (committed by GitHub)
commit c9999557bf, parent 24c0bab57b
3 changed files with 147 additions and 114 deletions

libs/community/langchain_community/chat_models/anyscale.py

@@ -23,7 +23,6 @@ if TYPE_CHECKING:
 logger = logging.getLogger(__name__)
 
 DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1"
 DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf"
@@ -60,7 +59,7 @@ class ChatAnyscale(ChatOpenAI):
     def is_lc_serializable(cls) -> bool:
         return False
 
-    anyscale_api_key: SecretStr
+    anyscale_api_key: SecretStr = Field(default=None)
     """AnyScale Endpoints API keys."""
     model_name: str = Field(default=DEFAULT_MODEL, alias="model")
     """Model name to use."""
@@ -102,14 +101,9 @@ class ChatAnyscale(ChatOpenAI):
         return {model["id"] for model in models_response.json()["data"]}
 
-    @root_validator(pre=True)
-    def validate_environment_override(cls, values: dict) -> dict:
+    @root_validator()
+    def validate_environment(cls, values: dict) -> dict:
         """Validate that api key and python package exists in environment."""
-        values["openai_api_key"] = get_from_dict_or_env(
-            values,
-            "anyscale_api_key",
-            "ANYSCALE_API_KEY",
-        )
         values["anyscale_api_key"] = convert_to_secret_str(
             get_from_dict_or_env(
                 values,
@@ -117,7 +111,7 @@
                 "ANYSCALE_API_KEY",
             )
         )
-        values["openai_api_base"] = get_from_dict_or_env(
+        values["anyscale_api_base"] = get_from_dict_or_env(
             values,
             "anyscale_api_base",
             "ANYSCALE_API_BASE",
@@ -140,8 +134,8 @@
         try:
             if is_openai_v1():
                 client_params = {
-                    "api_key": values["openai_api_key"],
-                    "base_url": values["openai_api_base"],
+                    "api_key": values["anyscale_api_key"].get_secret_value(),
+                    "base_url": values["anyscale_api_base"],
                     # To do: future support
                     # "organization": values["openai_organization"],
                     # "timeout": values["request_timeout"],
@@ -152,6 +146,8 @@
                 }
                 values["client"] = openai.OpenAI(**client_params).chat.completions
             else:
+                values["openai_api_base"] = values["anyscale_api_base"]
+                values["openai_api_key"] = values["anyscale_api_key"].get_secret_value()
                 values["client"] = openai.ChatCompletion
         except AttributeError as exc:
             raise ValueError(
@@ -164,10 +160,9 @@
             values["model_name"] = DEFAULT_MODEL
         model_name = values["model_name"]
 
         available_models = cls.get_available_models(
-            values["openai_api_key"],
-            values["openai_api_base"],
+            values["anyscale_api_key"].get_secret_value(),
+            values["anyscale_api_base"],
         )
 
         if model_name not in available_models:
@@ -197,9 +192,8 @@
     def get_num_tokens_from_messages(self, messages: list[BaseMessage]) -> int:
         """Calculate num tokens with tiktoken package.
 
-        Official documentation: https://github.com/openai/openai-cookbook/blob/
-        main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
+        Official documentation: https://github.com/openai/openai-cookbook/blob/main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb
+        """
         if sys.version_info[1] <= 7:
             return super().get_num_tokens_from_messages(messages)
         model, encoding = self._get_encoding_model()

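Both branches of the validator above follow the usual v1-versus-legacy client selection pattern. A standalone sketch of that pattern, simplified from the diff: `hasattr(openai, "OpenAI")` stands in for the `is_openai_v1()` helper, and the function name is illustrative.

```python
# Minimal sketch of the v1-vs-legacy client selection used in the validator
# above. Assumes the `openai` package is installed; names are illustrative.
import openai

def make_chat_client(api_key: str, base_url: str):
    if hasattr(openai, "OpenAI"):  # openai>=1.0 exposes a client class
        # v1: an explicit client object scoped to the Anyscale endpoint
        client = openai.OpenAI(api_key=api_key, base_url=base_url)
        return client.chat.completions
    # pre-v1: module-level globals configure the legacy singleton
    openai.api_key = api_key
    openai.api_base = base_url
    return openai.ChatCompletion
```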
libs/community/langchain_community/llms/anyscale.py

@@ -1,15 +1,11 @@
 """Wrapper around Anyscale Endpoint"""
 from typing import (
     Any,
-    AsyncIterator,
     Dict,
-    Iterator,
     List,
-    Mapping,
     Optional,
     Set,
-    Tuple,
     cast,
 )
 
 from langchain_core.callbacks import (
@@ -25,6 +21,13 @@ from langchain_community.llms.openai import (
     acompletion_with_retry,
     completion_with_retry,
 )
+from langchain_community.utils.openai import is_openai_v1
+
+DEFAULT_BASE_URL = "https://api.endpoints.anyscale.com/v1"
+DEFAULT_MODEL = "Meta-Llama/Llama-Guard-7b"
+
+# Completion models supported by Anyscale Endpoints
+COMPLETION_MODELS = ["Meta-Llama/Llama-Guard-7b"]
 
 
 def update_token_usage(
@@ -64,16 +67,14 @@ def create_llm_result(
 class Anyscale(BaseOpenAI):
     """Anyscale large language models.
 
-    To use, you should have the environment variable ``ANYSCALE_API_BASE`` and
-    ``ANYSCALE_API_KEY``set with your Anyscale Endpoint, or pass it as a named
-    parameter to the constructor.
+    To use, you should have the environment variable ``ANYSCALE_API_KEY`` set with your
+    Anyscale Endpoint API key, or pass it as a named parameter to the constructor.
+    To use with an Anyscale Private Endpoint, please also set ``ANYSCALE_BASE_URL``.
 
     Example:
         .. code-block:: python
 
-            from langchain_community.llms import Anyscale
-            anyscalellm = Anyscale(anyscale_api_base="ANYSCALE_API_BASE",
-                                   anyscale_api_key="ANYSCALE_API_KEY",
-                                   model_name="meta-llama/Llama-2-7b-chat-hf")
+            from langchain.llms import Anyscale
+            anyscalellm = Anyscale(anyscale_api_key="ANYSCALE_API_KEY")
 
             # To leverage Ray for parallel processing
             @ray.remote(num_cpus=1)
             def send_query(llm, text):
@@ -84,8 +85,9 @@ class Anyscale(BaseOpenAI):
     """
 
     """Key word arguments to pass to the model."""
-    anyscale_api_base: Optional[str] = None
-    anyscale_api_key: Optional[SecretStr] = None
+    anyscale_api_base: str = Field(default=DEFAULT_BASE_URL)
+    anyscale_api_key: SecretStr = Field(default=None)
+    model_name: str = Field(default=DEFAULT_MODEL)
 
     prefix_messages: List = Field(default_factory=list)
@@ -97,17 +99,47 @@
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key and python package exists in environment."""
         values["anyscale_api_base"] = get_from_dict_or_env(
-            values, "anyscale_api_base", "ANYSCALE_API_BASE"
+            values,
+            "anyscale_api_base",
+            "ANYSCALE_API_BASE",
+            default=DEFAULT_BASE_URL,
         )
         values["anyscale_api_key"] = convert_to_secret_str(
             get_from_dict_or_env(values, "anyscale_api_key", "ANYSCALE_API_KEY")
         )
+        values["model_name"] = get_from_dict_or_env(
+            values,
+            "model_name",
+            "MODEL_NAME",
+            default=DEFAULT_MODEL,
+        )
+        if values["model_name"] not in COMPLETION_MODELS:
+            raise ValueError(
+                "langchain_community.llm.Anyscale ONLY works \
+                with completion models. For chat models, please use \
+                langchain_community.chat_model.ChatAnyscale"
+            )
         try:
             import openai
 
-            ## Always create ChatComplete client, replacing the legacy Complete client
-            values["client"] = openai.ChatCompletion
+            if is_openai_v1():
+                client_params = {
+                    "api_key": values["anyscale_api_key"].get_secret_value(),
+                    "base_url": values["anyscale_api_base"],
+                    # To do: future support
+                    # "organization": values["openai_organization"],
+                    # "timeout": values["request_timeout"],
+                    # "max_retries": values["max_retries"],
+                    # "default_headers": values["default_headers"],
+                    # "default_query": values["default_query"],
+                    # "http_client": values["http_client"],
+                }
+                values["client"] = openai.OpenAI(**client_params).completions
+            else:
+                values["openai_api_base"] = values["anyscale_api_base"]
+                values["openai_api_key"] = values["anyscale_api_key"].get_secret_value()
+                values["client"] = openai.Completion
         except ImportError:
             raise ImportError(
                 "Could not import openai python package. "
@@ -132,70 +164,22 @@ class Anyscale(BaseOpenAI):
     def _invocation_params(self) -> Dict[str, Any]:
         """Get the parameters used to invoke the model."""
         openai_creds: Dict[str, Any] = {
-            "api_key": cast(SecretStr, self.anyscale_api_key).get_secret_value(),
+            "model": self.model_name,
         }
-        return {**openai_creds, **{"model": self.model_name}, **super()._default_params}
+        if not is_openai_v1():
+            openai_creds.update(
+                {
+                    "api_key": self.anyscale_api_key.get_secret_value(),
+                    "api_base": self.anyscale_api_base,
+                }
+            )
+        return {**openai_creds, **super()._invocation_params}
 
     @property
     def _llm_type(self) -> str:
         """Return type of llm."""
         return "Anyscale LLM"
 
-    def _get_chat_messages(
-        self, prompts: List[str], stop: Optional[List[str]] = None
-    ) -> Tuple:
-        if len(prompts) > 1:
-            raise ValueError(
-                f"Anyscale currently only supports single prompt, got {prompts}"
-            )
-        messages = self.prefix_messages + [{"role": "user", "content": prompts[0]}]
-        params: Dict[str, Any] = self._invocation_params
-        if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
-            params["stop"] = stop
-        if params.get("max_tokens") == -1:
-            # for Chat api, omitting max_tokens is equivalent to having no limit
-            del params["max_tokens"]
-        return messages, params
-
-    def _stream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> Iterator[GenerationChunk]:
-        messages, params = self._get_chat_messages([prompt], stop)
-        params = {**params, **kwargs, "stream": True}
-        for stream_resp in completion_with_retry(
-            self, messages=messages, run_manager=run_manager, **params
-        ):
-            token = stream_resp["choices"][0]["delta"].get("content", "")
-            chunk = GenerationChunk(text=token)
-            yield chunk
-            if run_manager:
-                run_manager.on_llm_new_token(token, chunk=chunk)
-
-    async def _astream(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> AsyncIterator[GenerationChunk]:
-        messages, params = self._get_chat_messages([prompt], stop)
-        params = {**params, **kwargs, "stream": True}
-        async for stream_resp in await acompletion_with_retry(
-            self, messages=messages, run_manager=run_manager, **params
-        ):
-            token = stream_resp["choices"][0]["delta"].get("content", "")
-            chunk = GenerationChunk(text=token)
-            yield chunk
-            if run_manager:
-                await run_manager.on_llm_new_token(token, chunk=chunk)
-
     def _generate(
         self,
         prompts: List[str],
@@ -203,13 +187,37 @@
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        """Call out to OpenAI's endpoint with k unique prompts.
+
+        Args:
+            prompts: The prompts to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The full LLM output.
+
+        Example:
+            .. code-block:: python
+
+                response = openai.generate(["Tell me a joke."])
+        """
+        # TODO: write a unit test for this
         params = self._invocation_params
         params = {**params, **kwargs}
+        sub_prompts = self.get_sub_prompts(params, prompts, stop)
         choices = []
         token_usage: Dict[str, int] = {}
+        # Get the token usage from the response.
+        # Includes prompt, completion, and total tokens used.
+        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
-        for prompt in prompts:
+        system_fingerprint: Optional[str] = None
+        for _prompts in sub_prompts:
             if self.streaming:
+                if len(_prompts) > 1:
+                    raise ValueError("Cannot stream results with multiple prompts.")
                 generation: Optional[GenerationChunk] = None
-                for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
                     if generation is None:
                         generation = chunk
                     else:
@@ -217,7 +225,7 @@
                 assert generation is not None
                 choices.append(
                     {
-                        "message": {"content": generation.text},
+                        "text": generation.text,
                         "finish_reason": generation.generation_info.get("finish_reason")
                         if generation.generation_info
                         else None,
@@ -226,16 +234,30 @@
                         else None,
                     }
                 )
             else:
-                messages, params = self._get_chat_messages([prompt], stop)
-                params = {**params, **kwargs}
                 response = completion_with_retry(
-                    self, messages=messages, run_manager=run_manager, **params
+                    ## This is the ONLY change from BaseOpenAI()._generate()
+                    self,
+                    prompt=_prompts[0],
+                    run_manager=run_manager,
+                    **params,
                 )
+                if not isinstance(response, dict):
+                    # V1 client returns the response in a Pydantic object instead of
+                    # a dict. For the transition period, we deep convert it to dict.
+                    response = response.dict()
                 choices.extend(response["choices"])
                 update_token_usage(_keys, response, token_usage)
-        return create_llm_result(choices, prompts, token_usage, self.model_name)
+            if not system_fingerprint:
+                system_fingerprint = response.get("system_fingerprint")
+        return self.create_llm_result(
+            choices,
+            prompts,
+            params,
+            token_usage,
+            system_fingerprint=system_fingerprint,
+        )
 
     async def _agenerate(
         self,
@@ -244,14 +266,25 @@
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> LLMResult:
+        """Call out to OpenAI's endpoint async with k unique prompts."""
         params = self._invocation_params
         params = {**params, **kwargs}
+        sub_prompts = self.get_sub_prompts(params, prompts, stop)
         choices = []
         token_usage: Dict[str, int] = {}
+        # Get the token usage from the response.
+        # Includes prompt, completion, and total tokens used.
+        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
-        for prompt in prompts:
-            messages = self.prefix_messages + [{"role": "user", "content": prompt}]
+        system_fingerprint: Optional[str] = None
+        for _prompts in sub_prompts:
             if self.streaming:
+                if len(_prompts) > 1:
+                    raise ValueError("Cannot stream results with multiple prompts.")
                 generation: Optional[GenerationChunk] = None
-                async for chunk in self._astream(prompt, stop, run_manager, **kwargs):
+                async for chunk in self._astream(
+                    _prompts[0], stop, run_manager, **kwargs
+                ):
                     if generation is None:
                         generation = chunk
                     else:
@@ -259,7 +292,7 @@
                 assert generation is not None
                 choices.append(
                     {
-                        "message": {"content": generation.text},
+                        "text": generation.text,
                         "finish_reason": generation.generation_info.get("finish_reason")
                         if generation.generation_info
                         else None,
@@ -269,11 +302,21 @@
                     }
                 )
             else:
-                messages, params = self._get_chat_messages([prompt], stop)
-                params = {**params, **kwargs}
                 response = await acompletion_with_retry(
-                    self, messages=messages, run_manager=run_manager, **params
+                    ## This is the ONLY change from BaseOpenAI()._agenerate()
+                    self,
+                    prompt=_prompts[0],
+                    run_manager=run_manager,
+                    **params,
                 )
+                if not isinstance(response, dict):
+                    response = response.dict()
                 choices.extend(response["choices"])
                 update_token_usage(_keys, response, token_usage)
-        return create_llm_result(choices, prompts, token_usage, self.model_name)
+        return self.create_llm_result(
+            choices,
+            prompts,
+            params,
+            token_usage,
+            system_fingerprint=system_fingerprint,
+        )

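A note on the `response.dict()` shim that appears in both `_generate` and `_agenerate` above: the v1 SDK returns Pydantic models where the legacy SDK returned plain dicts. A hedged illustration of that normalization; the helper name is a placeholder, not part of the commit.

```python
# Illustrative normalization of a response object, mirroring the shim above.
# `resp` stands in for whatever completion_with_retry returned.
from typing import Any, Dict

def to_dict(resp: Any) -> Dict[str, Any]:
    if isinstance(resp, dict):  # legacy (<1.0) SDK already returns dicts
        return resp
    return resp.dict()  # v1 responses are Pydantic models exposing .dict()

# Downstream code can then index uniformly, e.g. to_dict(resp)["choices"].
```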
libs/community/tests/unit_tests/llms/test_anyscale.py

@@ -8,9 +8,7 @@ from langchain_community.llms.anyscale import Anyscale
 @pytest.mark.requires("openai")
 def test_api_key_is_secret_string() -> None:
-    llm = Anyscale(
-        anyscale_api_key="secret-api-key", anyscale_api_base="test", model_name="test"
-    )
+    llm = Anyscale(anyscale_api_key="secret-api-key", anyscale_api_base="test")
     assert isinstance(llm.anyscale_api_key, SecretStr)
@@ -20,7 +18,7 @@ def test_api_key_masked_when_passed_from_env(
 ) -> None:
     """Test initialization with an API key provided via an env variable"""
     monkeypatch.setenv("ANYSCALE_API_KEY", "secret-api-key")
-    llm = Anyscale(anyscale_api_base="test", model_name="test")
+    llm = Anyscale(anyscale_api_base="test")
 
     print(llm.anyscale_api_key, end="")
     captured = capsys.readouterr()
@@ -32,9 +30,7 @@ def test_api_key_masked_when_passed_via_constructor(
     capsys: CaptureFixture,
 ) -> None:
     """Test initialization with an API key provided via the initializer"""
-    llm = Anyscale(
-        anyscale_api_key="secret-api-key", anyscale_api_base="test", model_name="test"
-    )
+    llm = Anyscale(anyscale_api_key="secret-api-key", anyscale_api_base="test")
 
     print(llm.anyscale_api_key, end="")
     captured = capsys.readouterr()
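These assertions lean on Pydantic's `SecretStr` masking. A minimal illustration of the behavior the tests check, assuming Pydantic v1 (which LangChain pinned at the time):

```python
# Sketch of SecretStr masking, as exercised by the tests above (Pydantic v1).
from pydantic import SecretStr

key = SecretStr("secret-api-key")
print(key)                     # prints ********** (masked)
print(key.get_secret_value())  # prints secret-api-key (explicit unmasking)
```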