chore: update client correctness

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
Aaron
2023-11-12 04:20:45 -05:00
parent feef8bddcf
commit 2b2dba6efd
3 changed files with 23 additions and 26 deletions

View File

@@ -40,7 +40,7 @@
"To start an LLM server, use `openllm start` command. For example, to start a dolly-v2 server, run the following command from a terminal:\n",
"\n",
"```bash\n",
"openllm start dolly-v2\n",
"openllm start facebook/opt-1.3b\n",
"```\n",
"\n",
"\n",
@@ -84,8 +84,8 @@
"from langchain.llms import OpenLLM\n",
"\n",
"llm = OpenLLM(\n",
" model_name=\"dolly-v2\",\n",
" model_id=\"databricks/dolly-v2-3b\",\n",
" model_name=\"opt\",\n",
" model_id=\"facebook/opt-250m\",\n",
" temperature=0.94,\n",
" repetition_penalty=1.2,\n",
")"
@@ -114,7 +114,8 @@
 }
 ],
 "source": [
-"from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain\n",
+"from langchain.prompts import PromptTemplate\n",
+"from langchain.chains import LLMChain\n",
 "\n",
 "template = \"What is a good name for a company that makes {product}?\"\n",
 "\n",
@@ -151,7 +152,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.9.16"
}
},
"nbformat": 4,

View File

@@ -38,7 +38,7 @@ OpenLLM server can run either locally or on the cloud.
 To try it out locally, start an OpenLLM server:

 ```bash
-openllm start flan-t5
+openllm start facebook/opt-1.3b
 ```

 Wrapper usage:
@@ -59,7 +59,7 @@ running inference.
 ```python
 from langchain.llms import OpenLLM

-llm = OpenLLM(model_name="dolly-v2", model_id='databricks/dolly-v2-7b')
+llm = OpenLLM(model_id='HuggingFaceH4/zephyr-7b-alpha')
 llm("What is the difference between a duck and a goose? And why there are so many Goose in Canada?")
 ```
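With model_name gone from the example, the model is selected by model_id alone. Since the section notes that the server can also run remotely, a companion sketch of the non-embedded path (both parameters appear later in this patch; the combination here is illustrative):

```python
from langchain.llms import OpenLLM

# embedded=False defers model startup to BentoML runners instead of
# loading the weights in-process (see the embedded flag in this patch).
llm = OpenLLM(model_id="HuggingFaceH4/zephyr-7b-alpha", embedded=False)
```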

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
 import copy
 import json
 import logging
+import warnings
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -11,7 +12,6 @@ from typing import (
     Literal,
     Optional,
     TypedDict,
-    Union,
     overload,
 )
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     import openllm

-ServerType = Literal["http", "grpc"]
+ServerType = Literal["http"]


 class IdentifyingParams(TypedDict):
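Narrowing ServerType to a single literal means stale "grpc" call sites now fail type checking instead of erroring at runtime. A self-contained illustration of that narrowing (the function name is illustrative, not from the patch):

```python
from typing import Literal

ServerType = Literal["http"]  # the "grpc" member was removed

def make_client(server_type: ServerType = "http") -> None:
    """Illustrative stand-in for the constructor's server_type parameter."""

make_client("http")  # accepted
make_client("grpc")  # rejected by a type checker such as mypy
```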
@@ -91,9 +91,7 @@ class OpenLLM(LLM):
"""Keyword arguments to be passed to openllm.LLM"""
_llm: Optional[openllm.LLM[Any, Any]] = PrivateAttr(default=None)
_client: Union[
openllm.client.HTTPClient, openllm.client.GrpcClient, None
] = PrivateAttr(default=None)
_client: Optional[openllm.client.HTTPClient] = PrivateAttr(default=None)
class Config:
extra = "forbid"
@@ -114,7 +112,7 @@ class OpenLLM(LLM):
         self,
         *,
         server_url: str = ...,
-        server_type: Literal["grpc", "http"] = ...,
+        server_type: Literal["http"] = ...,
         **llm_kwargs: Any,
     ) -> None:
         ...
@@ -125,7 +123,7 @@ class OpenLLM(LLM):
         *,
         model_id: Optional[str] = None,
         server_url: Optional[str] = None,
-        server_type: Literal["grpc", "http"] = "http",
+        server_type: Literal["http"] = "http",
         embedded: bool = True,
         **llm_kwargs: Any,
     ):
@@ -144,11 +142,7 @@ class OpenLLM(LLM):
             assert (
                 model_id is None and model_name is None
             ), "'server_url' and {'model_id', 'model_name'} are mutually exclusive"
-            client_cls = (
-                openllm.client.HTTPClient
-                if server_type == "http"
-                else openllm.client.GrpcClient
-            )
+            client_cls = openllm.client.HTTPClient
             client = client_cls(server_url)

             super().__init__(
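With the gRPC branch removed, client selection collapses to one constructor. The equivalent straight-line call, with a placeholder URL:

```python
import openllm

# HTTPClient(server_url) is the signature exercised by this patch;
# the URL below is a placeholder, not taken from the diff.
client = openllm.client.HTTPClient("http://localhost:3000")
```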
@@ -170,7 +164,7 @@ class OpenLLM(LLM):
             # in-process. Wrt to BentoML users, setting embedded=False is the expected
             # behaviour to invoke the runners remotely.
             # We need to also enable ensure_available to download and setup the model.
-            llm = openllm.LLM[Any, Any](model_id, llm_config=config)  # ensure_available will now always call
+            llm = openllm.LLM[Any, Any](model_id, llm_config=config)
             if embedded:
                 llm.runner.init_local(quiet=True)
             super().__init__(
@@ -207,6 +201,10 @@ class OpenLLM(LLM):
             def chat(input_text: str):
                 return agent.run(input_text)
         """
+        warnings.warn(
+            "'OpenLLM.runner' is deprecated, use 'OpenLLM.llm' instead",
+            DeprecationWarning,
+        )
         if self._llm is None:
             raise ValueError("OpenLLM must be initialized locally with 'model_name'")
         return self._llm.runner
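The property now emits a DeprecationWarning before returning the runner, so existing callers keep working while being nudged toward OpenLLM.llm. Downstream code can surface or silence the warning with the stdlib machinery; a small sketch (the llm variable is assumed to be a locally initialized wrapper):

```python
import warnings

# Promote the deprecation to an error, e.g. in a test suite, to flush
# out remaining OpenLLM.runner call sites.
with warnings.catch_warnings():
    warnings.simplefilter("error", DeprecationWarning)
    runner = llm.runner  # now raises instead of merely warning
```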
@@ -222,9 +220,9 @@ class OpenLLM(LLM):
     def _identifying_params(self) -> IdentifyingParams:
         """Get the identifying parameters."""
         if self._client is not None:
-            self.llm_kwargs.update(self._client._config())
-            model_name = self._client._metadata()["model_name"]
-            model_id = self._client._metadata()["model_id"]
+            self.llm_kwargs.update(self._client._config)
+            model_name = self._client._metadata["model_name"]
+            model_id = self._client._metadata["model_id"]
         else:
             if self._llm is None:
                 raise ValueError("LLM must be initialized.")
@@ -308,9 +306,7 @@ class OpenLLM(LLM):
         )
         if self._client:
             async_client = openllm.client.AsyncHTTPClient(self.server_url)
-            res = await async_client.generate(
-                prompt, stop=stop, **config.model_dump(flatten=True)
-            )
+            res = await async_client.generate(prompt, llm_config=config, stop=stop)
         else:
             assert self._llm is not None
             res = await self._llm.generate(
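The rewritten call passes the config object through as llm_config instead of flattening it into keyword arguments, matching the newer client signature. From user code, this async path is reached through LangChain's standard entry point; a hedged end-to-end sketch with a placeholder server URL:

```python
import asyncio

from langchain.llms import OpenLLM

async def main() -> None:
    # agenerate routes through the AsyncHTTPClient branch shown above.
    llm = OpenLLM(server_url="http://localhost:3000")
    result = await llm.agenerate(["What is a good name for a sock company?"])
    print(result.generations[0][0].text)

asyncio.run(main())
```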