From 617a4e617b1c9a71a89effa996e651e3fc0084e1 Mon Sep 17 00:00:00 2001 From: Alex Sherstinsky Date: Fri, 30 Aug 2024 12:41:42 -0700 Subject: [PATCH] community: Fix a bug in handling kwargs overwrites in Predibase integration, and update the documentation. (#25893) Thank you for contributing to LangChain! - [x] **PR title**: "package: description" - Where "package" is whichever of langchain, community, core, experimental, etc. is being modified. Use "docs: ..." for purely docs changes, "templates: ..." for template changes, "infra: ..." for CI changes. - Example: "community: add foobar LLM" - [x] **PR message**: ***Delete this entire checklist*** and replace with - **Description:** a description of the change - **Issue:** the issue # it fixes, if applicable - **Dependencies:** any dependencies required for this change - **Twitter handle:** if your PR gets announced, and you'd like a mention, we'll gladly shout you out! - [x] **Add tests and docs**: If you're adding a new integration, please include 1. a test for the integration, preferably unit tests that do not rely on network access, 2. an example notebook showing its use. It lives in `docs/docs/integrations` directory. - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/ Additional guidelines: - Make sure optional dependencies are imported within a function. - Please do not add dependencies to pyproject.toml files (even optional ones) unless they are required for unit tests. - Most PRs should not touch more than one package. - Changes should be backwards compatible. - If you are adding something to community, do not re-import it in langchain. If no one reviews your PR within a few days, please @-mention one of baskaryan, efriis, eyurtsev, ccurme, vbarda, hwchase17. --- docs/docs/integrations/llms/predibase.ipynb | 33 +++++++++++- docs/docs/integrations/providers/predibase.md | 51 +++++++++++++++++-- .../langchain_community/llms/predibase.py | 2 +- 3 files changed, 80 insertions(+), 6 deletions(-) diff --git a/docs/docs/integrations/llms/predibase.ipynb b/docs/docs/integrations/llms/predibase.ipynb index 82ffeddda12..5c352792be5 100644 --- a/docs/docs/integrations/llms/predibase.ipynb +++ b/docs/docs/integrations/llms/predibase.ipynb @@ -70,6 +70,10 @@ " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"e2e_nlg\",\n", " adapter_version=1,\n", + " **{\n", + " \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n", + " \"max_new_tokens\": 5, # default is 256\n", + " },\n", ")" ] }, @@ -87,6 +91,10 @@ " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"predibase/e2e_nlg\",\n", + " **{\n", + " \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n", + " \"max_new_tokens\": 5, # default is 256\n", + " },\n", ")" ] }, @@ -96,7 +104,11 @@ "metadata": {}, "outputs": [], "source": [ - "response = model.invoke(\"Can you recommend me a nice dry wine?\")\n", + "# Optionally use `kwargs` to dynamically overwrite \"generate()\" settings.\n", + "response = model.invoke(\n", + " \"Can you recommend me a nice dry wine?\",\n", + " **{\"temperature\": 0.5, \"max_new_tokens\": 1024},\n", + ")\n", "print(response)" ] }, @@ -127,6 +139,10 @@ " model=\"mistral-7b\",\n", " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", + " **{\n", + " \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n", + " \"max_new_tokens\": 5, # default is 256\n", + " },\n", ")" ] }, @@ -147,6 +163,10 @@ " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"e2e_nlg\",\n", " adapter_version=1,\n", + " **{\n", + " \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n", + " \"max_new_tokens\": 5, # default is 256\n", + " },\n", ")" ] }, @@ -162,6 +182,10 @@ " predibase_api_key=os.environ.get(\"PREDIBASE_API_TOKEN\"),\n", " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"predibase/e2e_nlg\",\n", + " **{\n", + " \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n", + " \"max_new_tokens\": 5, # default is 256\n", + " },\n", ")" ] }, @@ -259,6 +283,10 @@ " predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted)\n", " adapter_id=\"my-finetuned-adapter-id\", # Supports both, Predibase-hosted and HuggingFace-hosted adapter repositories.\n", " adapter_version=1, # required for Predibase-hosted adapters (ignored for HuggingFace-hosted adapters)\n", + " **{\n", + " \"api_token\": os.environ.get(\"HUGGING_FACE_HUB_TOKEN\"),\n", + " \"max_new_tokens\": 5, # default is 256\n", + " },\n", ")\n", "# replace my-base-LLM with the name of your choice of a serverless base model in Predibase" ] @@ -269,7 +297,8 @@ "metadata": {}, "outputs": [], "source": [ - "# response = model.invoke(\"Can you help categorize the following emails into positive, negative, and neutral?\")" + "# Optionally use `kwargs` to dynamically overwrite \"generate()\" settings.\n", + "# response = model.invoke(\"Can you help categorize the following emails into positive, negative, and neutral?\", **{\"temperature\": 0.5, \"max_new_tokens\": 1024})" ] } ], diff --git a/docs/docs/integrations/providers/predibase.md b/docs/docs/integrations/providers/predibase.md index 7ba380d1331..dba020c8832 100644 --- a/docs/docs/integrations/providers/predibase.md +++ b/docs/docs/integrations/providers/predibase.md @@ -21,9 +21,24 @@ model = Predibase( model="mistral-7b", predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"), predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted) + """ + Optionally use `model_kwargs` to set new default "generate()" settings. For example: + { + "api_token": os.environ.get("HUGGING_FACE_HUB_TOKEN"), + "max_new_tokens": 5, # default is 256 + } + """ + **model_kwargs, ) -response = model.invoke("Can you recommend me a nice dry wine?") +""" +Optionally use `kwargs` to dynamically overwrite "generate()" settings. For example: +{ + "temperature": 0.5, # default is the value in model_kwargs or 0.1 (initialization default) + "max_new_tokens": 1024, # default is the value in model_kwargs or 256 (initialization default) +} +""" +response = model.invoke("Can you recommend me a nice dry wine?", **kwargs) print(response) ``` @@ -42,9 +57,24 @@ model = Predibase( predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted) adapter_id="e2e_nlg", adapter_version=1, + """ + Optionally use `model_kwargs` to set new default "generate()" settings. For example: + { + "api_token": os.environ.get("HUGGING_FACE_HUB_TOKEN"), + "max_new_tokens": 5, # default is 256 + } + """ + **model_kwargs, ) -response = model.invoke("Can you recommend me a nice dry wine?") +""" +Optionally use `kwargs` to dynamically overwrite "generate()" settings. For example: +{ + "temperature": 0.5, # default is the value in model_kwargs or 0.1 (initialization default) + "max_new_tokens": 1024, # default is the value in model_kwargs or 256 (initialization default) +} +""" +response = model.invoke("Can you recommend me a nice dry wine?", **kwargs) print(response) ``` @@ -62,8 +92,23 @@ model = Predibase( predibase_api_key=os.environ.get("PREDIBASE_API_TOKEN"), predibase_sdk_version=None, # optional parameter (defaults to the latest Predibase SDK version if omitted) adapter_id="predibase/e2e_nlg", + """ + Optionally use `model_kwargs` to set new default "generate()" settings. For example: + { + "api_token": os.environ.get("HUGGING_FACE_HUB_TOKEN"), + "max_new_tokens": 5, # default is 256 + } + """ + **model_kwargs, ) -response = model.invoke("Can you recommend me a nice dry wine?") +""" +Optionally use `kwargs` to dynamically overwrite "generate()" settings. For example: +{ + "temperature": 0.5, # default is the value in model_kwargs or 0.1 (initialization default) + "max_new_tokens": 1024, # default is the value in model_kwargs or 256 (initialization default) +} +""" +response = model.invoke("Can you recommend me a nice dry wine?", **kwargs) print(response) ``` diff --git a/libs/community/langchain_community/llms/predibase.py b/libs/community/langchain_community/llms/predibase.py index 65aca86bcff..398c169b4a2 100644 --- a/libs/community/langchain_community/llms/predibase.py +++ b/libs/community/langchain_community/llms/predibase.py @@ -50,8 +50,8 @@ class Predibase(LLM): **kwargs: Any, ) -> str: options: Dict[str, Union[str, float]] = { - **(self.model_kwargs or {}), **self.default_options_for_generation, + **(self.model_kwargs or {}), **(kwargs or {}), } if self._is_deprecated_sdk_version():