From 2b2dba6efdcaca7591b177724efceeadfbc9878d Mon Sep 17 00:00:00 2001
From: Aaron <29749331+aarnphm@users.noreply.github.com>
Date: Sun, 12 Nov 2023 04:20:45 -0500
Subject: [PATCH] chore: update client correctness

Signed-off-by: Aaron <29749331+aarnphm@users.noreply.github.com>
---
 docs/docs/integrations/llms/openllm.ipynb    | 11 ++++---
 docs/docs/integrations/providers/openllm.mdx |  4 +--
 libs/langchain/langchain/llms/openllm.py     | 34 +++++++++-----------
 3 files changed, 23 insertions(+), 26 deletions(-)

diff --git a/docs/docs/integrations/llms/openllm.ipynb b/docs/docs/integrations/llms/openllm.ipynb
index 6d85ea29b0d..bbeb6cfa79b 100644
--- a/docs/docs/integrations/llms/openllm.ipynb
+++ b/docs/docs/integrations/llms/openllm.ipynb
@@ -40,7 +40,7 @@
     "To start an LLM server, use `openllm start` command. For example, to start a dolly-v2 server, run the following command from a terminal:\n",
     "\n",
     "```bash\n",
-    "openllm start dolly-v2\n",
+    "openllm start facebook/opt-1.3b\n",
     "```\n",
     "\n",
     "\n",
@@ -84,8 +84,8 @@
     "from langchain.llms import OpenLLM\n",
     "\n",
     "llm = OpenLLM(\n",
-    "    model_name=\"dolly-v2\",\n",
-    "    model_id=\"databricks/dolly-v2-3b\",\n",
+    "    model_name=\"opt\",\n",
+    "    model_id=\"facebook/opt-250m\",\n",
     "    temperature=0.94,\n",
     "    repetition_penalty=1.2,\n",
     ")"
    ]
@@ -114,7 +114,8 @@
    }
   ],
   "source": [
-    "from langchain.prompts import PromptTemplate\nfrom langchain.chains import LLMChain\n",
+    "from langchain.prompts import PromptTemplate\n",
+    "from langchain.chains import LLMChain\n",
     "\n",
     "template = \"What is a good name for a company that makes {product}?\"\n",
     "\n",
@@ -151,7 +152,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.10"
+   "version": "3.9.16"
   }
  },
 "nbformat": 4,
diff --git a/docs/docs/integrations/providers/openllm.mdx b/docs/docs/integrations/providers/openllm.mdx
index 1f24af8ed22..0ce956a83c2 100644
--- a/docs/docs/integrations/providers/openllm.mdx
+++ b/docs/docs/integrations/providers/openllm.mdx
@@ -38,7 +38,7 @@ OpenLLM server can run either locally or on the cloud.
 To try it out locally, start an OpenLLM server:
 
 ```bash
-openllm start flan-t5
+openllm start facebook/opt-1.3b
 ```
 
 Wrapper usage:
@@ -59,7 +59,7 @@ running inference.
 
 ```python
 from langchain.llms import OpenLLM
 
-llm = OpenLLM(model_name="dolly-v2", model_id='databricks/dolly-v2-7b')
+llm = OpenLLM(model_id='HuggingFaceH4/zephyr-7b-alpha')
 llm("What is the difference between a duck and a goose? And why there are so many Goose in Canada?")
 ```
diff --git a/libs/langchain/langchain/llms/openllm.py b/libs/langchain/langchain/llms/openllm.py
index 932ab967e65..366211a1faf 100644
--- a/libs/langchain/langchain/llms/openllm.py
+++ b/libs/langchain/langchain/llms/openllm.py
@@ -3,6 +3,7 @@ from __future__ import annotations
 import copy
 import json
 import logging
+import warnings
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -11,7 +12,6 @@ from typing import (
     Literal,
     Optional,
     TypedDict,
-    Union,
     overload,
 )
 
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     import openllm
 
 
-ServerType = Literal["http", "grpc"]
+ServerType = Literal["http"]
 
 
 class IdentifyingParams(TypedDict):
@@ -91,9 +91,7 @@ class OpenLLM(LLM):
     """Keyword arguments to be passed to openllm.LLM"""
 
     _llm: Optional[openllm.LLM[Any, Any]] = PrivateAttr(default=None)
-    _client: Union[
-        openllm.client.HTTPClient, openllm.client.GrpcClient, None
-    ] = PrivateAttr(default=None)
+    _client: Optional[openllm.client.HTTPClient] = PrivateAttr(default=None)
 
     class Config:
         extra = "forbid"
@@ -114,7 +112,7 @@
         self,
         *,
         server_url: str = ...,
-        server_type: Literal["grpc", "http"] = ...,
+        server_type: Literal["http"] = ...,
         **llm_kwargs: Any,
     ) -> None:
         ...
@@ -125,7 +123,7 @@
         *,
         model_id: Optional[str] = None,
         server_url: Optional[str] = None,
-        server_type: Literal["grpc", "http"] = "http",
+        server_type: Literal["http"] = "http",
         embedded: bool = True,
         **llm_kwargs: Any,
     ):
@@ -144,11 +142,7 @@
             assert (
                 model_id is None and model_name is None
             ), "'server_url' and {'model_id', 'model_name'} are mutually exclusive"
-            client_cls = (
-                openllm.client.HTTPClient
-                if server_type == "http"
-                else openllm.client.GrpcClient
-            )
+            client_cls = openllm.client.HTTPClient
             client = client_cls(server_url)
 
             super().__init__(
@@ -170,7 +164,7 @@
             # in-process. Wrt to BentoML users, setting embedded=False is the expected
             # behaviour to invoke the runners remotely.
             # We need to also enable ensure_available to download and setup the model.
-            llm = openllm.LLM[Any, Any](model_id, llm_config=config)  # ensure_available will now always call
+            llm = openllm.LLM[Any, Any](model_id, llm_config=config)
             if embedded:
                 llm.runner.init_local(quiet=True)
             super().__init__(
@@ -207,6 +201,10 @@
                 def chat(input_text: str):
                     return agent.run(input_text)
         """
+        warnings.warn(
+            "'OpenLLM.runner' is deprecated, use 'OpenLLM.llm' instead",
+            DeprecationWarning,
+        )
         if self._llm is None:
             raise ValueError("OpenLLM must be initialized locally with 'model_name'")
         return self._llm.runner
@@ -222,9 +220,9 @@
     def _identifying_params(self) -> IdentifyingParams:
         """Get the identifying parameters."""
         if self._client is not None:
-            self.llm_kwargs.update(self._client._config())
-            model_name = self._client._metadata()["model_name"]
-            model_id = self._client._metadata()["model_id"]
+            self.llm_kwargs.update(self._client._config)
+            model_name = self._client._metadata["model_name"]
+            model_id = self._client._metadata["model_id"]
         else:
             if self._llm is None:
                 raise ValueError("LLM must be initialized.")
@@ -308,9 +306,7 @@
         )
         if self._client:
             async_client = openllm.client.AsyncHTTPClient(self.server_url)
-            res = await async_client.generate(
-                prompt, stop=stop, **config.model_dump(flatten=True)
-            )
+            res = await async_client.generate(prompt, llm_config=config, stop=stop)
         else:
             assert self._llm is not None
             res = await self._llm.generate(
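
Reviewer note: with gRPC support dropped, `OpenLLM` is left with exactly two construction paths, and the docs changes above exercise both. A minimal sketch of the two paths after this patch, reusing the parameters from the updated docs; the server address is an assumption (use whatever `openllm start facebook/opt-1.3b` reports on your machine):

```python
from langchain.llms import OpenLLM

# Remote path: talk to a running OpenLLM server over HTTP ("http" is now
# the only server_type). The URL below is assumed, not prescribed.
remote_llm = OpenLLM(server_url="http://localhost:3000", server_type="http")

# Local path: load the model in-process. embedded=True (the default)
# initializes the runner locally; extra kwargs flow into openllm.LLM.
local_llm = OpenLLM(
    model_id="facebook/opt-250m",
    temperature=0.94,
    repetition_penalty=1.2,
)

print(remote_llm("What is the difference between a duck and a goose?"))
```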