mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-19 11:08:55 +00:00
Add additional VertexAI Params (#5837)
## Changes - Added the `stop` param to the `_VertexAICommon` class so it can be set at llm initialization ## Example Usage ```python VertexAI( # ... temperature=0.15, max_output_tokens=128, top_p=1, top_k=40, stop=["\n```"], ) ``` ## Possible Reviewers - @hwchase17 - @agola11
This commit is contained in:
parent
76fcd96dae
commit
9f4b720a63
@ -29,6 +29,8 @@ class _VertexAICommon(BaseModel):
|
|||||||
top_k: int = 40
|
top_k: int = 40
|
||||||
"How the model selects tokens for output, the next token is selected from "
|
"How the model selects tokens for output, the next token is selected from "
|
||||||
"among the top-k most probable tokens."
|
"among the top-k most probable tokens."
|
||||||
|
stop: Optional[List[str]] = None
|
||||||
|
"Optional list of stop words to use when generating."
|
||||||
project: Optional[str] = None
|
project: Optional[str] = None
|
||||||
"The default GCP project to use when making Vertex API calls."
|
"The default GCP project to use when making Vertex API calls."
|
||||||
location: str = "us-central1"
|
location: str = "us-central1"
|
||||||
@ -48,11 +50,13 @@ class _VertexAICommon(BaseModel):
|
|||||||
}
|
}
|
||||||
return {**base_params}
|
return {**base_params}
|
||||||
|
|
||||||
def _predict(self, prompt: str, stop: Optional[List[str]]) -> str:
|
def _predict(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||||
res = self.client.predict(prompt, **self._default_params)
|
res = self.client.predict(prompt, **self._default_params)
|
||||||
return self._enforce_stop_words(res.text, stop)
|
return self._enforce_stop_words(res.text, stop)
|
||||||
|
|
||||||
def _enforce_stop_words(self, text: str, stop: Optional[List[str]]) -> str:
|
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
|
||||||
|
if stop is None and self.stop is not None:
|
||||||
|
stop = self.stop
|
||||||
if stop:
|
if stop:
|
||||||
return enforce_stop_tokens(text, stop)
|
return enforce_stop_tokens(text, stop)
|
||||||
return text
|
return text
|
||||||
|
Loading…
Reference in New Issue
Block a user