From 43dc5d34163829720ba46d2268db61909e7e5c6c Mon Sep 17 00:00:00 2001
From: Mohammad Mohtashim <45242107+keenborder786@users.noreply.github.com>
Date: Mon, 19 Feb 2024 23:09:11 +0500
Subject: [PATCH] community[patch]: OpenLLM Client Fixes + Added Timeout
 Parameter (#17478)

- OpenLLM was using an outdated method to get the final text output from the
  openllm client invocation, which raised an error; this is now corrected.
- OpenLLM `_identifying_params` read the openllm client configuration through
  outdated attributes, which also raised an error; it now uses the current
  attributes.
- Updated the docstring for OpenLLM.
- Added a timeout parameter that is passed to the underlying openllm client
  (a short usage sketch follows the patch).
---
 .../langchain_community/llms/openllm.py      | 22 ++++++++++++-------
 1 file changed, 14 insertions(+), 8 deletions(-)

diff --git a/libs/community/langchain_community/llms/openllm.py b/libs/community/langchain_community/llms/openllm.py
index afb5a18f9ba..fa3b03e1f98 100644
--- a/libs/community/langchain_community/llms/openllm.py
+++ b/libs/community/langchain_community/llms/openllm.py
@@ -72,7 +72,7 @@ class OpenLLM(LLM):
             from langchain_community.llms import OpenLLM
 
             llm = OpenLLM(server_url='http://localhost:3000')
-            llm("What is the difference between a duck and a goose?")
+            llm.invoke("What is the difference between a duck and a goose?")
     """
 
     model_name: Optional[str] = None
@@ -82,6 +82,8 @@ class OpenLLM(LLM):
     See 'openllm models' for all available model variants."""
     server_url: Optional[str] = None
     """Optional server URL that currently runs a LLMServer with 'openllm start'."""
+    timeout: int = 30
+    """Timeout for the openllm client."""
     server_type: ServerType = "http"
     """Optional server type. Either 'http' or 'grpc'."""
     embedded: bool = True
@@ -125,6 +127,7 @@ class OpenLLM(LLM):
         *,
         model_id: Optional[str] = None,
         server_url: Optional[str] = None,
+        timeout: int = 30,
         server_type: Literal["grpc", "http"] = "http",
         embedded: bool = True,
         **llm_kwargs: Any,
@@ -149,11 +152,12 @@
                 if server_type == "http"
                 else openllm.client.GrpcClient
             )
-            client = client_cls(server_url)
+            client = client_cls(server_url, timeout)
 
             super().__init__(
                 **{
                     "server_url": server_url,
+                    "timeout": timeout,
                     "server_type": server_type,
                     "llm_kwargs": llm_kwargs,
                 }
@@ -217,9 +221,9 @@
     def _identifying_params(self) -> IdentifyingParams:
         """Get the identifying parameters."""
         if self._client is not None:
-            self.llm_kwargs.update(self._client._config())
-            model_name = self._client._metadata()["model_name"]
-            model_id = self._client._metadata()["model_id"]
+            self.llm_kwargs.update(self._client._config)
+            model_name = self._client._metadata.model_dump()["model_name"]
+            model_id = self._client._metadata.model_dump()["model_id"]
         else:
             if self._runner is None:
                 raise ValueError("Runner must be initialized.")
@@ -265,9 +269,11 @@
             self._identifying_params["model_name"], **copied
         )
         if self._client:
-            res = self._client.generate(
-                prompt, **config.model_dump(flatten=True)
-            ).responses[0]
+            res = (
+                self._client.generate(prompt, **config.model_dump(flatten=True))
+                .outputs[0]
+                .text
+            )
         else:
             assert self._runner is not None
             res = self._runner(prompt, **config.model_dump(flatten=True))
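
Usage sketch for the new timeout parameter (not part of the patch): it assumes an
OpenLLM server is already running locally on port 3000, and the server URL, timeout
value, and prompt below are illustrative only.

    from langchain_community.llms import OpenLLM

    # The timeout keyword added by this patch is forwarded to the underlying
    # openllm client; it defaults to 30 when omitted.
    llm = OpenLLM(server_url="http://localhost:3000", timeout=60)
    print(llm.invoke("What is the difference between a duck and a goose?"))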