mirror of
https://github.com/hwchase17/langchain.git
synced 2025-04-28 11:55:21 +00:00
Harrison/lintai21 (#114)
This commit is contained in:
parent
d8734ce5ad
commit
9f878e43d8
@ -9,6 +9,8 @@ from langchain.llms.base import LLM
|
||||
|
||||
|
||||
class AI21PenaltyData(BaseModel):
|
||||
"""Parameters for AI21 penalty data."""
|
||||
|
||||
scale: int = 0
|
||||
applyToWhitespaces: bool = True
|
||||
applyToPunctuations: bool = True
|
||||
@ -20,7 +22,8 @@ class AI21PenaltyData(BaseModel):
|
||||
class AI21(BaseModel, LLM):
|
||||
"""Wrapper around AI21 large language models.
|
||||
|
||||
To use, you should have the environment variable ``AI21_API_KEY`` set with your API key.
|
||||
To use, you should have the environment variable ``AI21_API_KEY``
|
||||
set with your API key.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
@ -56,7 +59,7 @@ class AI21(BaseModel, LLM):
|
||||
numResults: int = 1
|
||||
"""How many completions to generate for each prompt."""
|
||||
|
||||
logitBias: Dict[str, float] = None
|
||||
logitBias: Optional[Dict[str, float]] = None
|
||||
"""Adjust the probability of specific tokens being generated."""
|
||||
|
||||
ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY")
|
||||
@ -123,10 +126,13 @@ class AI21(BaseModel, LLM):
|
||||
"prompt": prompt,
|
||||
"stopSequences": stop,
|
||||
**self._default_params,
|
||||
}
|
||||
},
|
||||
)
|
||||
if response.status_code != 200:
|
||||
optional_detail = response.json().get('error')
|
||||
raise ValueError(f'AI21 /complete call failed with status code {response.status_code}. Details: {optional_detail}')
|
||||
response = response.json()
|
||||
return response["completions"][0]["data"]["text"]
|
||||
optional_detail = response.json().get("error")
|
||||
raise ValueError(
|
||||
f"AI21 /complete call failed with status code {response.status_code}."
|
||||
f" Details: {optional_detail}"
|
||||
)
|
||||
response_json = response.json()
|
||||
return response_json["completions"][0]["data"]["text"]
|
||||
|
2
setup.py
2
setup.py
@ -14,7 +14,7 @@ setup(
|
||||
version=__version__,
|
||||
packages=find_packages(),
|
||||
description="Building applications with LLMs through composability",
|
||||
install_requires=["pydantic", "sqlalchemy", "numpy"],
|
||||
install_requires=["pydantic", "sqlalchemy", "numpy", "requests"],
|
||||
long_description=long_description,
|
||||
license="MIT",
|
||||
url="https://github.com/hwchase17/langchain",
|
||||
|
@ -8,3 +8,4 @@ isort
|
||||
mypy
|
||||
flake8
|
||||
flake8-docstrings
|
||||
types-requests
|
||||
|
Loading…
Reference in New Issue
Block a user