Compare commits

...

20 Commits

Author SHA1 Message Date
Bagatur
f8052401dd fmt 2024-04-01 18:14:19 -07:00
Bagatur
02b36f7722 fmt 2024-04-01 18:10:58 -07:00
Aaron
3d94cfdaf0 fix: lint issues
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-12-12 01:48:39 -05:00
Aaron
7a4a259988 merge: remote-tracking branch 'upstream/master' into chore/migrate-to-new-api 2023-12-12 01:46:17 -05:00
Aaron
882aea7574 fix: lint
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-26 22:07:49 -05:00
Aaron
450f358557 fix: correct getitem on sync item
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-26 18:48:29 -05:00
Aaron
341449261b fix: conform with styles and lint
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-26 05:10:00 -05:00
Aaron
ad4f4cce41 fix: correct setup for configuration
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-26 05:04:06 -05:00
Aaron
f67b996173 fix: types
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-26 04:59:16 -05:00
Aaron Pham
13855f37bc Merge branch 'master' into chore/migrate-to-new-api 2023-11-26 04:55:06 -05:00
Aaron Pham
23c2b30511 Merge branch 'master' into chore/migrate-to-new-api 2023-11-14 18:51:57 -05:00
Aaron
0f281fbde3 chore: update example notebook
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-12 17:56:38 -05:00
Aaron
78681daf6f chore(client): add async_client object
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-12 17:50:41 -05:00
Aaron
2b2dba6efd chore: update client correctness
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-12 04:20:45 -05:00
Aaron
feef8bddcf revert: Revert to auto check for available models eagerly
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-08 13:42:37 -05:00
Aaron
32f3e04537 revert: move back to explicit api
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-08 13:32:55 -05:00
Aaron Pham
ffd3f4e4c5 Merge branch 'langchain-ai:master' into chore/migrate-to-new-api 2023-11-08 13:17:53 -05:00
Aaron
3e4c7e5b35 chore: update API for always calling pretrained check to model store
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-08 13:16:42 -05:00
Aaron
36d24d1cae chore: fix linter issue
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-06 19:22:37 -05:00
Aaron
a2fb1d7608 chore(openllm): simplified interface with new API
Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
2023-11-06 19:19:28 -05:00
3 changed files with 69 additions and 105 deletions

View File (OpenLLM example notebook)

@@ -40,7 +40,7 @@
 "To start an LLM server, use `openllm start` command. For example, to start a dolly-v2 server, run the following command from a terminal:\n",
 "\n",
 "```bash\n",
-"openllm start dolly-v2\n",
+"openllm start facebook/opt-1.3b\n",
 "```\n",
 "\n",
 "\n",
@@ -84,8 +84,7 @@
 "from langchain_community.llms import OpenLLM\n",
 "\n",
 "llm = OpenLLM(\n",
-"    model_name=\"dolly-v2\",\n",
-"    model_id=\"databricks/dolly-v2-3b\",\n",
+"    model_id=\"facebook/opt-250m\",\n",
 "    temperature=0.94,\n",
 "    repetition_penalty=1.2,\n",
 ")"
@@ -152,7 +151,7 @@
 "name": "python",
 "nbconvert_exporter": "python",
 "pygments_lexer": "ipython3",
-"version": "3.10.10"
+"version": "3.9.16"
 }
 },
 "nbformat": 4,

View File (OpenLLM provider documentation)

@@ -38,7 +38,7 @@ OpenLLM server can run either locally or on the cloud.
 To try it out locally, start an OpenLLM server:
 ```bash
-openllm start flan-t5
+openllm start facebook/opt-1.3b
 ```
 Wrapper usage:
@@ -59,7 +59,7 @@ running inference.
 ```python
 from langchain_community.llms import OpenLLM
-llm = OpenLLM(model_name="dolly-v2", model_id='databricks/dolly-v2-7b')
+llm = OpenLLM(model_id='HuggingFaceH4/zephyr-7b-alpha')
 llm("What is the difference between a duck and a goose? And why there are so many Goose in Canada?")
 ```
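The provider doc keeps the two-step flow: start a server, then point the wrapper at it. A sketch of the client path, assuming a server already started with `openllm start facebook/opt-1.3b`; the URL and port are an assumed local default, not something stated in this diff:

```python
# Sketch of client mode; assumes an OpenLLM server is already running locally.
# `http://localhost:3000` is an illustrative address, not taken from the diff.
from langchain_community.llms import OpenLLM

# `server_url` is mutually exclusive with `model_id`/`model_name`
# (enforced by the constructor asserts in the Python diff below).
llm = OpenLLM(server_url="http://localhost:3000", timeout=60)
print(llm("What is the difference between a duck and a goose?"))
```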

View File (langchain_community OpenLLM wrapper)

@@ -3,6 +3,7 @@ from __future__ import annotations
 import copy
 import json
 import logging
+import warnings
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -11,7 +12,6 @@ from typing import (
     Literal,
     Optional,
     TypedDict,
-    Union,
     overload,
 )
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     import openllm

-ServerType = Literal["http", "grpc"]
+ServerType = Literal["http"]

 class IdentifyingParams(TypedDict):
@@ -92,10 +92,9 @@ class OpenLLM(LLM):
     llm_kwargs: Dict[str, Any]
     """Keyword arguments to be passed to openllm.LLM"""

-    _runner: Optional[openllm.LLMRunner] = PrivateAttr(default=None)
-    _client: Union[
-        openllm.client.HTTPClient, openllm.client.GrpcClient, None
-    ] = PrivateAttr(default=None)
+    _llm: Optional[openllm.LLM[Any, Any]] = PrivateAttr(default=None)
+    _client: Optional[openllm.HTTPClient] = PrivateAttr(default=None)
+    _async_client: Optional[openllm.AsyncHTTPClient] = PrivateAttr(default=None)

     class Config:
         extra = "forbid"
@@ -116,7 +115,7 @@
         self,
         *,
         server_url: str = ...,
-        server_type: Literal["grpc", "http"] = ...,
+        server_type: Literal["http"] = ...,
         **llm_kwargs: Any,
     ) -> None:
         ...
@@ -128,7 +127,7 @@
         model_id: Optional[str] = None,
         server_url: Optional[str] = None,
         timeout: int = 30,
-        server_type: Literal["grpc", "http"] = "http",
+        server_type: Literal["http"] = "http",
         embedded: bool = True,
         **llm_kwargs: Any,
     ):
@@ -147,38 +146,31 @@
             assert (
                 model_id is None and model_name is None
             ), "'server_url' and {'model_id', 'model_name'} are mutually exclusive"
-            client_cls = (
-                openllm.client.HTTPClient
-                if server_type == "http"
-                else openllm.client.GrpcClient
-            )
-            client = client_cls(server_url, timeout)
-
             super().__init__(
                 **{
                     "server_url": server_url,
                     "timeout": timeout,
-                    "server_type": server_type,
+                    "server_type": "http",
                     "llm_kwargs": llm_kwargs,
                 }
             )
-            self._runner = None  # type: ignore
-            self._client = client
+            self._llm = None  # type: ignore
+            self._client = openllm.HTTPClient(server_url, timeout=timeout)
+            self._async_client = openllm.AsyncHTTPClient(server_url, timeout=timeout)
         else:
-            assert model_name is not None, "Must provide 'model_name' or 'server_url'"
-            # since the LLM are relatively huge, we don't actually want to convert the
-            # Runner with embedded when running the server. Instead, we will only set
-            # the init_local here so that LangChain users can still use the LLM
-            # in-process. Wrt to BentoML users, setting embedded=False is the expected
-            # behaviour to invoke the runners remotely.
-            # We need to also enable ensure_available to download and setup the model.
-            runner = openllm.Runner(
-                model_name=model_name,
-                model_id=model_id,
-                init_local=embedded,
-                ensure_available=True,
-                **llm_kwargs,
-            )
+            if model_name is None:  # supports not passing model_name
+                assert model_id is not None, "Must provide 'model_id' or 'server_url'"
+                llm = openllm.LLM[Any, Any](model_id, embedded=embedded)
+            else:
+                assert (
+                    model_name is not None
+                ), "Must provide 'model_name' or 'server_url'"
+                config = openllm.AutoConfig.for_model(model_name, **llm_kwargs)
+                model_id = model_id or config["default_id"]
+                llm = openllm.LLM[Any, Any](
+                    model_id, llm_config=config, embedded=embedded
+                )
             super().__init__(
                 **{
                     "model_name": model_name,
@@ -188,10 +180,11 @@
                 }
             )
             self._client = None  # type: ignore
-            self._runner = runner
+            self._async_client = None  # type: ignore
+            self._llm = llm

     @property
-    def runner(self) -> openllm.LLMRunner:
+    def runner(self) -> openllm.LLMRunner[Any, Any]:
         """
         Get the underlying openllm.LLMRunner instance for integration with BentoML.
@@ -213,31 +206,42 @@
             def chat(input_text: str):
                 return agent.run(input_text)
         """
-        if self._runner is None:
+        warnings.warn(
+            "'OpenLLM.runner' is deprecated, use 'OpenLLM.llm' instead",
+            DeprecationWarning,
+        )
+        if self._llm is None:
             raise ValueError("OpenLLM must be initialized locally with 'model_name'")
-        return self._runner
+        return self._llm.runner

     @property
+    def llm(self) -> openllm.LLM[Any, Any]:
+        """Get the underlying openllm.LLM instance."""
+        if self._llm is None:
+            raise ValueError("OpenLLM must be initialized locally with 'model_name'")
+        return self._llm
+
+    @property
     def _identifying_params(self) -> IdentifyingParams:
         """Get the identifying parameters."""
         if self._client is not None:
             self.llm_kwargs.update(self._client._config)
-            model_name = self._client._metadata.model_dump()["model_name"]
-            model_id = self._client._metadata.model_dump()["model_id"]
+            model_name = self._client["model_name"]
+            model_id = self._client["model_id"]
         else:
-            if self._runner is None:
-                raise ValueError("Runner must be initialized.")
-            model_name = self.model_name
+            if self._llm is None:
+                raise ValueError("LLM must be initialized.")
+            model_name = self.model_name or ""
             model_id = self.model_id
             try:
                 self.llm_kwargs.update(
-                    json.loads(self._runner.identifying_params["configuration"])
+                    json.loads(self._llm.identifying_params["configuration"])
                 )
             except (TypeError, json.JSONDecodeError):
                 pass
         return IdentifyingParams(
             server_url=self.server_url,
-            server_type=self.server_type,
+            server_type="http",
             embedded=self.embedded,
             llm_kwargs=self.llm_kwargs,
             model_name=model_name,
@@ -255,36 +259,21 @@
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        try:
-            import openllm
-        except ImportError as e:
-            raise ImportError(
-                "Could not import openllm. Make sure to install it with "
-                "'pip install openllm'."
-            ) from e
+        import asyncio

         copied = copy.deepcopy(self.llm_kwargs)
         copied.update(kwargs)
-        config = openllm.AutoConfig.for_model(
-            self._identifying_params["model_name"], **copied
-        )
         if self._client:
-            res = (
-                self._client.generate(prompt, **config.model_dump(flatten=True))
-                .outputs[0]
-                .text
-            )
+            res = self._client.generate(prompt, llm_config=copied, stop=stop)
         else:
-            assert self._runner is not None
-            res = self._runner(prompt, **config.model_dump(flatten=True))
-        if isinstance(res, dict) and "text" in res:
-            return res["text"]
-        elif isinstance(res, str):
-            return res
+            assert self._llm is not None
+            res = asyncio.run(self._llm.generate(prompt, stop=stop, **copied))
+        if hasattr(res, "outputs"):
+            return res.outputs[0].text
         else:
             raise ValueError(
-                "Expected result to be a dict with key 'text' or a string. "
-                f"Received {res}"
+                "Expected result to be either a 'openllm.GenerationOutput' or "
+                f"'openllm_client.Response' output. Received '{res}' instead"
             )

     async def _acall(
@@ -294,44 +283,20 @@
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        try:
-            import openllm
-        except ImportError as e:
-            raise ImportError(
-                "Could not import openllm. Make sure to install it with "
-                "'pip install openllm'."
-            ) from e
-
         copied = copy.deepcopy(self.llm_kwargs)
         copied.update(kwargs)
-        config = openllm.AutoConfig.for_model(
-            self._identifying_params["model_name"], **copied
-        )
-        if self._client:
-            async_client = openllm.client.AsyncHTTPClient(self.server_url)
-            res = (
-                await async_client.generate(prompt, **config.model_dump(flatten=True))
-            ).responses[0]
+        if self._async_client:
+            res = await self._async_client.generate(
+                prompt, llm_config=copied, stop=stop
+            )
         else:
-            assert self._runner is not None
-            (
-                prompt,
-                generate_kwargs,
-                postprocess_kwargs,
-            ) = self._runner.llm.sanitize_parameters(prompt, **kwargs)
-            generated_result = await self._runner.generate.async_run(
-                prompt, **generate_kwargs
-            )
-            res = self._runner.llm.postprocess_generate(
-                prompt, generated_result, **postprocess_kwargs
-            )
+            assert self._llm is not None
+            res = await self._llm.generate(prompt, stop=stop, **copied)
-        if isinstance(res, dict) and "text" in res:
-            return res["text"]
-        elif isinstance(res, str):
-            return res
+        if hasattr(res, "outputs"):
+            return res.outputs[0].text
         else:
             raise ValueError(
-                "Expected result to be a dict with key 'text' or a string. "
-                f"Received {res}"
+                "Expected result to be either a 'openllm.GenerationOutput' or "
+                f"'openllm_client.Response' output. Received '{res}' instead"
            )
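With `_acall` now backed by a persistent `AsyncHTTPClient` (instead of a throwaway client constructed per call), the async path can be exercised through LangChain's standard `agenerate` entry point. A sketch of that usage, reusing the same illustrative server URL as above; the URL is an assumption, not part of this diff:

```python
# Sketch of the reworked async path: in client mode `_acall` awaits
# `AsyncHTTPClient.generate`; in embedded mode it awaits `openllm.LLM.generate`.
import asyncio

from langchain_community.llms import OpenLLM


async def main() -> None:
    # Illustrative local server address; start one with `openllm start facebook/opt-1.3b`.
    llm = OpenLLM(server_url="http://localhost:3000", timeout=60)
    result = await llm.agenerate(["What is the difference between a duck and a goose?"])
    print(result.generations[0][0].text)


asyncio.run(main())
```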