community[patch]: OpenLLM Client Fixes + Added Timeout Parameter (#17478)

- The OpenLLM wrapper used an outdated method to extract the final text
output from the openllm client's response, which raised an error. It now
reads the text from `.outputs[0].text`.
- `OpenLLM._identifying_params` fetched the openllm client's configuration
and metadata through outdated attributes, which also raised an error; it
now uses the `_config` property and `_metadata.model_dump()`.
- Updated the `OpenLLM` docstring to use `llm.invoke(...)` instead of the
deprecated `llm(...)` call style.
- Added a `timeout` parameter (default `30`) that is passed through to the
underlying openllm client; see the usage sketch below.
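
A minimal usage sketch of the new parameter and the corrected call style (the server URL and timeout value here are illustrative, not defaults from this patch):

```python
from langchain_community.llms import OpenLLM

# Assumes a server already running via `openllm start`; URL and timeout
# are illustrative values.
llm = OpenLLM(server_url="http://localhost:3000", timeout=60)

# `.invoke(...)` replaces the deprecated `llm(...)` call style.
print(llm.invoke("What is the difference between a duck and a goose?"))
```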
Mohammad Mohtashim 2024-02-19 23:09:11 +05:00 committed by GitHub
parent 1d2aa19aee
commit 43dc5d3416

@@ -72,7 +72,7 @@ class OpenLLM(LLM):
             from langchain_community.llms import OpenLLM
             llm = OpenLLM(server_url='http://localhost:3000')
-            llm("What is the difference between a duck and a goose?")
+            llm.invoke("What is the difference between a duck and a goose?")
     """

     model_name: Optional[str] = None
@@ -82,6 +82,8 @@ class OpenLLM(LLM):
     See 'openllm models' for all available model variants."""
     server_url: Optional[str] = None
     """Optional server URL that currently runs a LLMServer with 'openllm start'."""
+    timeout: int = 30
+    """Timeout for the openllm client."""
     server_type: ServerType = "http"
     """Optional server type. Either 'http' or 'grpc'."""
     embedded: bool = True
@@ -125,6 +127,7 @@ class OpenLLM(LLM):
         *,
         model_id: Optional[str] = None,
         server_url: Optional[str] = None,
+        timeout: int = 30,
         server_type: Literal["grpc", "http"] = "http",
         embedded: bool = True,
         **llm_kwargs: Any,
@@ -149,11 +152,12 @@ class OpenLLM(LLM):
                 if server_type == "http"
                 else openllm.client.GrpcClient
             )
-            client = client_cls(server_url)
+            client = client_cls(server_url, timeout)
             super().__init__(
                 **{
                     "server_url": server_url,
+                    "timeout": timeout,
                     "server_type": server_type,
                     "llm_kwargs": llm_kwargs,
                 }
@@ -217,9 +221,9 @@ class OpenLLM(LLM):
     def _identifying_params(self) -> IdentifyingParams:
         """Get the identifying parameters."""
         if self._client is not None:
-            self.llm_kwargs.update(self._client._config())
-            model_name = self._client._metadata()["model_name"]
-            model_id = self._client._metadata()["model_id"]
+            self.llm_kwargs.update(self._client._config)
+            model_name = self._client._metadata.model_dump()["model_name"]
+            model_id = self._client._metadata.model_dump()["model_id"]
         else:
             if self._runner is None:
                 raise ValueError("Runner must be initialized.")
@@ -265,9 +269,11 @@ class OpenLLM(LLM):
                 self._identifying_params["model_name"], **copied
             )
             if self._client:
-                res = self._client.generate(
-                    prompt, **config.model_dump(flatten=True)
-                ).responses[0]
+                res = (
+                    self._client.generate(prompt, **config.model_dump(flatten=True))
+                    .outputs[0]
+                    .text
+                )
             else:
                 assert self._runner is not None
                 res = self._runner(prompt, **config.model_dump(flatten=True))
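
For reference, a minimal sketch of the corrected client calls in isolation. The constructor arguments and the `_config`/`_metadata`/`outputs` attributes below simply mirror what the diff above uses; they are assumptions about the openllm client API at the time of this patch and may differ in other versions:

```python
import openllm

# Mirrors `client_cls(server_url, timeout)` from the diff above
# (assumed openllm client API; version-dependent).
client = openllm.client.HTTPClient("http://localhost:3000", 30)

# New response handling: the generated text is on `outputs[0].text`.
# The old code read `.responses[0]`, which raised an error.
result = client.generate("What is the difference between a duck and a goose?")
text = result.outputs[0].text

# New identifying-params handling: `_config` is accessed as a property
# (no call), and `_metadata` is a pydantic model dumped to a dict.
config = client._config
metadata = client._metadata.model_dump()
model_name, model_id = metadata["model_name"], metadata["model_id"]
```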