mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-04 04:07:54 +00:00
Allow user to modify the GPU and language settings when using NLP Cloud (#7985)
--------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
483f6c2fe3
commit
73d5cba308
@ -20,12 +20,16 @@ class NLPCloudEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
|
||||
model_name: str # Define model_name as a class attribute
|
||||
gpu: bool # Define gpu as a class attribute
|
||||
client: Any #: :meta private:
|
||||
|
||||
def __init__(
|
||||
self, model_name: str = "paraphrase-multilingual-mpnet-base-v2", **kwargs: Any
|
||||
self,
|
||||
model_name: str = "paraphrase-multilingual-mpnet-base-v2",
|
||||
gpu: bool = False,
|
||||
**kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(model_name=model_name, **kwargs)
|
||||
super().__init__(model_name=model_name, gpu=gpu, **kwargs)
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
@ -37,7 +41,7 @@ class NLPCloudEmbeddings(BaseModel, Embeddings):
|
||||
import nlpcloud
|
||||
|
||||
values["client"] = nlpcloud.Client(
|
||||
values["model_name"], nlpcloud_api_key, gpu=False, lang="en"
|
||||
values["model_name"], nlpcloud_api_key, gpu=values["gpu"], lang="en"
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
|
@ -17,12 +17,16 @@ class NLPCloud(LLM):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import NLPCloud
|
||||
nlpcloud = NLPCloud(model="gpt-neox-20b")
|
||||
nlpcloud = NLPCloud(model="finetuned-gpt-neox-20b")
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
model_name: str = "finetuned-gpt-neox-20b"
|
||||
"""Model name to use."""
|
||||
gpu: bool = True
|
||||
"""Whether to use a GPU or not"""
|
||||
lang: str = "en"
|
||||
"""Language to use (multilingual addon)"""
|
||||
temperature: float = 0.7
|
||||
"""What sampling temperature to use."""
|
||||
min_length: int = 1
|
||||
@ -71,7 +75,10 @@ class NLPCloud(LLM):
|
||||
import nlpcloud
|
||||
|
||||
values["client"] = nlpcloud.Client(
|
||||
values["model_name"], nlpcloud_api_key, gpu=True, lang="en"
|
||||
values["model_name"],
|
||||
nlpcloud_api_key,
|
||||
gpu=values["gpu"],
|
||||
lang=values["lang"],
|
||||
)
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@ -104,7 +111,12 @@ class NLPCloud(LLM):
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {**{"model_name": self.model_name}, **self._default_params}
|
||||
return {
|
||||
**{"model_name": self.model_name},
|
||||
**{"gpu": self.gpu},
|
||||
**{"lang": self.lang},
|
||||
**self._default_params,
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user