mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-06 05:08:20 +00:00
Harrison/handle stop tokens ai21 (#1077)
Co-authored-by: Andrew Huang <jhuang16888@gmail.com>
This commit is contained in:
parent
d8ed286200
commit
52753066ef
@ -1,5 +1,5 @@
|
|||||||
"""Wrapper around AI21 APIs."""
|
"""Wrapper around AI21 APIs."""
|
||||||
from typing import Any, Dict, List, Mapping, Optional
|
from typing import Any, Dict, List, Optional
|
||||||
|
|
||||||
import requests
|
import requests
|
||||||
from pydantic import BaseModel, Extra, root_validator
|
from pydantic import BaseModel, Extra, root_validator
|
||||||
@ -64,6 +64,8 @@ class AI21(LLM, BaseModel):
|
|||||||
|
|
||||||
ai21_api_key: Optional[str] = None
|
ai21_api_key: Optional[str] = None
|
||||||
|
|
||||||
|
stop: Optional[List[str]] = None
|
||||||
|
|
||||||
base_url: Optional[str] = None
|
base_url: Optional[str] = None
|
||||||
"""Base url to use, if None decides based on model name."""
|
"""Base url to use, if None decides based on model name."""
|
||||||
|
|
||||||
@ -80,7 +82,7 @@ class AI21(LLM, BaseModel):
|
|||||||
return values
|
return values
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _default_params(self) -> Mapping[str, Any]:
|
def _default_params(self) -> Dict[str, Any]:
|
||||||
"""Get the default parameters for calling AI21 API."""
|
"""Get the default parameters for calling AI21 API."""
|
||||||
return {
|
return {
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
@ -95,7 +97,7 @@ class AI21(LLM, BaseModel):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Dict[str, Any]:
|
||||||
"""Get the identifying parameters."""
|
"""Get the identifying parameters."""
|
||||||
return {**{"model": self.model}, **self._default_params}
|
return {**{"model": self.model}, **self._default_params}
|
||||||
|
|
||||||
@ -119,7 +121,11 @@ class AI21(LLM, BaseModel):
|
|||||||
|
|
||||||
response = ai21("Tell me a joke.")
|
response = ai21("Tell me a joke.")
|
||||||
"""
|
"""
|
||||||
if stop is None:
|
if self.stop is not None and stop is not None:
|
||||||
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
|
elif self.stop is not None:
|
||||||
|
stop = self.stop
|
||||||
|
elif stop is None:
|
||||||
stop = []
|
stop = []
|
||||||
if self.base_url is not None:
|
if self.base_url is not None:
|
||||||
base_url = self.base_url
|
base_url = self.base_url
|
||||||
|
Loading…
Reference in New Issue
Block a user