chore(community): update to OpenLLM 0.6 (#24609)
Update to OpenLLM 0.6, in which we decided to make use of OpenLLM's OpenAI-compatible endpoint. The OpenLLM integration thus becomes a thin wrapper around the OpenAI wrapper.

Signed-off-by: Aaron Pham <contact@aarnphm.xyz>
Co-authored-by: ccurme <chester.curme@gmail.com>
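For context, a minimal usage sketch of the rewritten wrapper. The server URL, placeholder API key, and model name below are illustrative assumptions, not values taken from this commit:

    # Minimal sketch: point the thin wrapper at the OpenAI-compatible endpoint
    # of a running OpenLLM 0.6 server (e.g. started with `openllm serve`).
    # base_url, api_key, and model here are assumptions for illustration.
    from langchain_community.llms import OpenLLM

    llm = OpenLLM(
        base_url="http://localhost:3000/v1",  # OpenLLM's OpenAI-compatible endpoint
        api_key="na",  # OpenLLM does not validate the key
        model="microsoft/Phi-3-mini-4k-instruct",  # whatever model the server hosts
    )
    print(llm.invoke("What is the difference between a duck and a goose?"))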
langchain_community/llms/__init__.py

@@ -1084,7 +1084,6 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "vertexai": _import_vertex,
         "vertexai_model_garden": _import_vertex_model_garden,
         "openllm": _import_openllm,
-        "openllm_client": _import_openllm,
         "outlines": _import_outlines,
         "vllm": _import_vllm,
         "vllm_openai": _import_vllm_openai,
langchain_community/llms/openllm.py

@@ -1,338 +1,42 @@
-from __future__ import annotations
-
-import copy
-import json
-import logging
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Dict,
-    List,
-    Literal,
-    Optional,
-    TypedDict,
-    Union,
-    overload,
-)
-
-from langchain_core.callbacks import (
-    AsyncCallbackManagerForLLMRun,
-    CallbackManagerForLLMRun,
-)
-from langchain_core.language_models.llms import LLM
-from pydantic import ConfigDict, PrivateAttr
-
-if TYPE_CHECKING:
-    import openllm
-
-
-ServerType = Literal["http", "grpc"]
-
-
-class IdentifyingParams(TypedDict):
-    """Parameters for identifying a model as a typed dict."""
-
-    model_name: str
-    model_id: Optional[str]
-    server_url: Optional[str]
-    server_type: Optional[ServerType]
-    embedded: bool
-    llm_kwargs: Dict[str, Any]
-
-
-logger = logging.getLogger(__name__)
-
-
-class OpenLLM(LLM):
-    """OpenLLM, supporting both in-process model
-    instance and remote OpenLLM servers.
-
-    To use, you should have the openllm library installed:
-
-    .. code-block:: bash
-
-        pip install openllm
-
-    Learn more at: https://github.com/bentoml/openllm
-
-    Example running an LLM model locally managed by OpenLLM:
-        .. code-block:: python
-
-            from langchain_community.llms import OpenLLM
-            llm = OpenLLM(
-                model_name='flan-t5',
-                model_id='google/flan-t5-large',
-            )
-            llm.invoke("What is the difference between a duck and a goose?")
-
-    For all available supported models, you can run 'openllm models'.
-
-    If you have a OpenLLM server running, you can also use it remotely:
-        .. code-block:: python
-
-            from langchain_community.llms import OpenLLM
-            llm = OpenLLM(server_url='http://localhost:3000')
-            llm.invoke("What is the difference between a duck and a goose?")
-    """
-
-    model_name: Optional[str] = None
-    """Model name to use. See 'openllm models' for all available models."""
-    model_id: Optional[str] = None
-    """Model Id to use. If not provided, will use the default model for the model name.
-    See 'openllm models' for all available model variants."""
-    server_url: Optional[str] = None
-    """Optional server URL that currently runs a LLMServer with 'openllm start'."""
-    timeout: int = 30
-    """"Time out for the openllm client"""
-    server_type: ServerType = "http"
-    """Optional server type. Either 'http' or 'grpc'."""
-    embedded: bool = True
-    """Initialize this LLM instance in current process by default. Should
-    only set to False when using in conjunction with BentoML Service."""
-    llm_kwargs: Dict[str, Any]
-    """Keyword arguments to be passed to openllm.LLM"""
-
-    _runner: Optional[openllm.LLMRunner] = PrivateAttr(default=None)
-    _client: Union[openllm.client.HTTPClient, openllm.client.GrpcClient, None] = (
-        PrivateAttr(default=None)
-    )
-
-    model_config = ConfigDict(
-        extra="forbid",
-    )
-
-    @overload
-    def __init__(
-        self,
-        model_name: Optional[str] = ...,
-        *,
-        model_id: Optional[str] = ...,
-        embedded: Literal[True, False] = ...,
-        **llm_kwargs: Any,
-    ) -> None: ...
-
-    @overload
-    def __init__(
-        self,
-        *,
-        server_url: str = ...,
-        server_type: Literal["grpc", "http"] = ...,
-        **llm_kwargs: Any,
-    ) -> None: ...
-
-    def __init__(
-        self,
-        model_name: Optional[str] = None,
-        *,
-        model_id: Optional[str] = None,
-        server_url: Optional[str] = None,
-        timeout: int = 30,
-        server_type: Literal["grpc", "http"] = "http",
-        embedded: bool = True,
-        **llm_kwargs: Any,
-    ):
-        try:
-            import openllm
-        except ImportError as e:
-            raise ImportError(
-                "Could not import openllm. Make sure to install it with "
-                "'pip install openllm.'"
-            ) from e
-
-        llm_kwargs = llm_kwargs or {}
-
-        if server_url is not None:
-            logger.debug("'server_url' is provided, returning a openllm.Client")
-            assert (
-                model_id is None and model_name is None
-            ), "'server_url' and {'model_id', 'model_name'} are mutually exclusive"
-            client_cls = (
-                openllm.client.HTTPClient
-                if server_type == "http"
-                else openllm.client.GrpcClient
-            )
-            client = client_cls(server_url, timeout)
-
-            super().__init__(
-                **{  # type: ignore[arg-type]
-                    "server_url": server_url,
-                    "timeout": timeout,
-                    "server_type": server_type,
-                    "llm_kwargs": llm_kwargs,
-                }
-            )
-            self._runner = None  # type: ignore
-            self._client = client
-        else:
-            assert model_name is not None, "Must provide 'model_name' or 'server_url'"
-            # since the LLM are relatively huge, we don't actually want to convert the
-            # Runner with embedded when running the server. Instead, we will only set
-            # the init_local here so that LangChain users can still use the LLM
-            # in-process. Wrt to BentoML users, setting embedded=False is the expected
-            # behaviour to invoke the runners remotely.
-            # We need to also enable ensure_available to download and setup the model.
-            runner = openllm.Runner(
-                model_name=model_name,
-                model_id=model_id,
-                init_local=embedded,
-                ensure_available=True,
-                **llm_kwargs,
-            )
-            super().__init__(
-                **{  # type: ignore[arg-type]
-                    "model_name": model_name,
-                    "model_id": model_id,
-                    "embedded": embedded,
-                    "llm_kwargs": llm_kwargs,
-                }
-            )
-            self._client = None  # type: ignore
-            self._runner = runner
-
-    @property
-    def runner(self) -> openllm.LLMRunner:
-        """
-        Get the underlying openllm.LLMRunner instance for integration with BentoML.
-
-        Example:
-        .. code-block:: python
-
-            llm = OpenLLM(
-                model_name='flan-t5',
-                model_id='google/flan-t5-large',
-                embedded=False,
-            )
-            tools = load_tools(["serpapi", "llm-math"], llm=llm)
-            agent = initialize_agent(
-                tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION
-            )
-            svc = bentoml.Service("langchain-openllm", runners=[llm.runner])
-
-            @svc.api(input=Text(), output=Text())
-            def chat(input_text: str):
-                return agent.run(input_text)
-        """
-        if self._runner is None:
-            raise ValueError("OpenLLM must be initialized locally with 'model_name'")
-        return self._runner
-
-    @property
-    def _identifying_params(self) -> IdentifyingParams:
-        """Get the identifying parameters."""
-        if self._client is not None:
-            self.llm_kwargs.update(self._client._config)
-            model_name = self._client._metadata.model_dump()["model_name"]
-            model_id = self._client._metadata.model_dump()["model_id"]
-        else:
-            if self._runner is None:
-                raise ValueError("Runner must be initialized.")
-            model_name = self.model_name
-            model_id = self.model_id
-            try:
-                self.llm_kwargs.update(
-                    json.loads(self._runner.identifying_params["configuration"])
-                )
-            except (TypeError, json.JSONDecodeError):
-                pass
-        return IdentifyingParams(
-            server_url=self.server_url,
-            server_type=self.server_type,
-            embedded=self.embedded,
-            llm_kwargs=self.llm_kwargs,
-            model_name=model_name,
-            model_id=model_id,
-        )
-
-    @property
-    def _llm_type(self) -> str:
-        return "openllm_client" if self._client else "openllm"
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        try:
-            import openllm
-        except ImportError as e:
-            raise ImportError(
-                "Could not import openllm. Make sure to install it with "
-                "'pip install openllm'."
-            ) from e
-
-        copied = copy.deepcopy(self.llm_kwargs)
-        copied.update(kwargs)
-        config = openllm.AutoConfig.for_model(
-            self._identifying_params["model_name"], **copied
-        )
-        if self._client:
-            res = (
-                self._client.generate(prompt, **config.model_dump(flatten=True))
-                .outputs[0]
-                .text
-            )
-        else:
-            assert self._runner is not None
-            res = self._runner(prompt, **config.model_dump(flatten=True))
-        if isinstance(res, dict) and "text" in res:
-            return res["text"]
-        elif isinstance(res, str):
-            return res
-        else:
-            raise ValueError(
-                "Expected result to be a dict with key 'text' or a string. "
-                f"Received {res}"
-            )
-
-    async def _acall(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        **kwargs: Any,
-    ) -> str:
-        try:
-            import openllm
-        except ImportError as e:
-            raise ImportError(
-                "Could not import openllm. Make sure to install it with "
-                "'pip install openllm'."
-            ) from e
-
-        copied = copy.deepcopy(self.llm_kwargs)
-        copied.update(kwargs)
-        config = openllm.AutoConfig.for_model(
-            self._identifying_params["model_name"], **copied
-        )
-        if self._client:
-            async_client = openllm.client.AsyncHTTPClient(self.server_url, self.timeout)
-            res = (
-                (await async_client.generate(prompt, **config.model_dump(flatten=True)))
-                .outputs[0]
-                .text
-            )
-        else:
-            assert self._runner is not None
-            (
-                prompt,
-                generate_kwargs,
-                postprocess_kwargs,
-            ) = self._runner.llm.sanitize_parameters(prompt, **kwargs)
-            generated_result = await self._runner.generate.async_run(
-                prompt, **generate_kwargs
-            )
-            res = self._runner.llm.postprocess_generate(
-                prompt, generated_result, **postprocess_kwargs
-            )
-
-        if isinstance(res, dict) and "text" in res:
-            return res["text"]
-        elif isinstance(res, str):
-            return res
-        else:
-            raise ValueError(
-                "Expected result to be a dict with key 'text' or a string. "
-                f"Received {res}"
-            )
+from typing import Any, Dict
+
+from langchain_community.llms.openai import BaseOpenAI
+from langchain_community.utils.openai import is_openai_v1
+
+
+class OpenLLM(BaseOpenAI):
+    """OpenAI's compatible API client for OpenLLM server
+
+    .. versionchanged:: 0.2.11
+
+        Changed in 0.2.11 to support OpenLLM 0.6. Now behaves similar to OpenAI wrapper.
+    """
+
+    @classmethod
+    def is_lc_serializable(cls) -> bool:
+        return False
+
+    @property
+    def _invocation_params(self) -> Dict[str, Any]:
+        """Get the parameters used to invoke the model."""
+
+        params: Dict[str, Any] = {
+            "model": self.model_name,
+            **self._default_params,
+            "logit_bias": None,
+        }
+        if not is_openai_v1():
+            params.update(
+                {
+                    "api_key": self.openai_api_key,
+                    "api_base": self.openai_api_base,
+                }
+            )
+
+        return params
+
+    @property
+    def _llm_type(self) -> str:
+        return "openllm"
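The _invocation_params override above is the one place the thin wrapper diverges from stock BaseOpenAI behavior: it forces "logit_bias" to None and, only on pre-1.0 openai clients, injects api_key/api_base into the per-request payload. A rough sketch of inspecting it; the concrete values come from BaseOpenAI defaults and are illustrative assumptions:

    # Illustration only: with openai>=1.0 installed, is_openai_v1() is True,
    # so credentials live on the client object and do not appear here.
    llm = OpenLLM(base_url="http://localhost:3000/v1", api_key="na")
    print(llm._invocation_params)
    # e.g. {'model': 'gpt-3.5-turbo-instruct', 'temperature': 0.7, ...,
    #       'logit_bias': None}  (illustrative default values)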
tests/integration_tests/llms/test_openllm.py

@@ -1,17 +1,13 @@
-"""Test OpenLLM wrapper."""
+"""Test OpenLLM API wrapper."""
+
+import pytest
 
 from langchain_community.llms.openllm import OpenLLM
 
 
-def test_openllm_llm_local() -> None:
-    llm = OpenLLM(model_name="flan-t5", model_id="google/flan-t5-small")
-    output = llm.invoke("Say foo:")
-    assert isinstance(output, str)
-
-
-def test_openllm_with_kwargs() -> None:
-    llm = OpenLLM(
-        model_name="flan-t5", model_id="google/flan-t5-small", temperature=0.84
-    )
-    output = llm.invoke("Say bar:")
+@pytest.mark.scheduled
+def test_openai_call() -> None:
+    """Test valid call to openai."""
+    llm = OpenLLM()
+    output = llm.invoke("Say something nice:")
     assert isinstance(output, str)
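Note that the rewritten test constructs OpenLLM() with no arguments, so it presumably relies on the standard OpenAI environment variables that BaseOpenAI reads. A sketch of the assumed setup for running it against a local server (values are illustrative):

    # Assumed environment for test_openai_call; variable names follow the
    # standard OpenAI wrapper conventions, values are illustrative.
    import os

    os.environ["OPENAI_API_BASE"] = "http://localhost:3000/v1"
    os.environ["OPENAI_API_KEY"] = "na"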