community[patch]: OpenLLM Client Fixes + Added Timeout Parameter (#17478)

- OpenLLM was using an outdated method to extract the final text output from
the openllm client's generate call, which raised an error. This is now
corrected.
- OpenLLM's `_identifying_params` was reading the openllm client's
configuration through outdated attributes, which also raised an error.
- Updated the docstring for OpenLLM to use `llm.invoke(...)`.
- Added a `timeout` parameter that is passed through to the underlying openllm
client (see the usage sketch below).
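
A minimal usage sketch of the new parameter and the updated invocation style,
assuming an OpenLLM server started with 'openllm start' is reachable at the
URL below (URL, timeout value, and prompt are illustrative):

    from langchain_community.llms import OpenLLM

    # Illustrative server URL; `timeout` is forwarded to the underlying
    # openllm client (HTTPClient or GrpcClient).
    llm = OpenLLM(server_url="http://localhost:3000", timeout=60)

    # `invoke` replaces the bare `llm(...)` call shown in the old docstring.
    print(llm.invoke("What is the difference between a duck and a goose?"))
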
Mohammad Mohtashim 2024-02-19 23:09:11 +05:00 committed by GitHub
parent 1d2aa19aee
commit 43dc5d3416

libs/community/langchain_community/llms/openllm.py

@@ -72,7 +72,7 @@ class OpenLLM(LLM):
 
             from langchain_community.llms import OpenLLM
             llm = OpenLLM(server_url='http://localhost:3000')
-            llm("What is the difference between a duck and a goose?")
+            llm.invoke("What is the difference between a duck and a goose?")
     """
 
     model_name: Optional[str] = None
@ -82,6 +82,8 @@ class OpenLLM(LLM):
See 'openllm models' for all available model variants.""" See 'openllm models' for all available model variants."""
server_url: Optional[str] = None server_url: Optional[str] = None
"""Optional server URL that currently runs a LLMServer with 'openllm start'.""" """Optional server URL that currently runs a LLMServer with 'openllm start'."""
timeout: int = 30
""""Time out for the openllm client"""
server_type: ServerType = "http" server_type: ServerType = "http"
"""Optional server type. Either 'http' or 'grpc'.""" """Optional server type. Either 'http' or 'grpc'."""
embedded: bool = True embedded: bool = True
@@ -125,6 +127,7 @@ class OpenLLM(LLM):
         *,
         model_id: Optional[str] = None,
         server_url: Optional[str] = None,
+        timeout: int = 30,
         server_type: Literal["grpc", "http"] = "http",
         embedded: bool = True,
         **llm_kwargs: Any,
@@ -149,11 +152,12 @@
                 if server_type == "http"
                 else openllm.client.GrpcClient
             )
-            client = client_cls(server_url)
+            client = client_cls(server_url, timeout)
 
             super().__init__(
                 **{
                     "server_url": server_url,
+                    "timeout": timeout,
                     "server_type": server_type,
                     "llm_kwargs": llm_kwargs,
                 }
@@ -217,9 +221,9 @@
     def _identifying_params(self) -> IdentifyingParams:
         """Get the identifying parameters."""
         if self._client is not None:
-            self.llm_kwargs.update(self._client._config())
-            model_name = self._client._metadata()["model_name"]
-            model_id = self._client._metadata()["model_id"]
+            self.llm_kwargs.update(self._client._config)
+            model_name = self._client._metadata.model_dump()["model_name"]
+            model_id = self._client._metadata.model_dump()["model_id"]
         else:
             if self._runner is None:
                 raise ValueError("Runner must be initialized.")
@@ -265,9 +269,11 @@
             self._identifying_params["model_name"], **copied
         )
         if self._client:
-            res = self._client.generate(
-                prompt, **config.model_dump(flatten=True)
-            ).responses[0]
+            res = (
+                self._client.generate(prompt, **config.model_dump(flatten=True))
+                .outputs[0]
+                .text
+            )
         else:
             assert self._runner is not None
             res = self._runner(prompt, **config.model_dump(flatten=True))
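
The corrected `generate` handling in the last hunk assumes the openllm client
returns a generation object whose `outputs` list holds chunks with a `.text`
field, rather than the old `.responses` list. A minimal sketch of that shape,
using illustrative stand-in classes rather than openllm's real types:

    from dataclasses import dataclass
    from typing import List

    # Stand-ins modelling only the attributes the fixed code reads.
    @dataclass
    class CompletionChunk:
        text: str

    @dataclass
    class GenerationOutput:
        outputs: List[CompletionChunk]

    res = GenerationOutput(outputs=[CompletionChunk(text="Ducks quack; geese honk.")])

    # Extraction path now used in _call: text of the first output.
    final_text = res.outputs[0].text
    print(final_text)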