mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-01 04:29:09 +00:00
Harrison/lintai21 (#114)
This commit is contained in:
parent
d8734ce5ad
commit
9f878e43d8
@ -9,6 +9,8 @@ from langchain.llms.base import LLM
|
|||||||
|
|
||||||
|
|
||||||
class AI21PenaltyData(BaseModel):
|
class AI21PenaltyData(BaseModel):
|
||||||
|
"""Parameters for AI21 penalty data."""
|
||||||
|
|
||||||
scale: int = 0
|
scale: int = 0
|
||||||
applyToWhitespaces: bool = True
|
applyToWhitespaces: bool = True
|
||||||
applyToPunctuations: bool = True
|
applyToPunctuations: bool = True
|
||||||
@ -20,7 +22,8 @@ class AI21PenaltyData(BaseModel):
|
|||||||
class AI21(BaseModel, LLM):
|
class AI21(BaseModel, LLM):
|
||||||
"""Wrapper around AI21 large language models.
|
"""Wrapper around AI21 large language models.
|
||||||
|
|
||||||
To use, you should have the environment variable ``AI21_API_KEY`` set with your API key.
|
To use, you should have the environment variable ``AI21_API_KEY``
|
||||||
|
set with your API key.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
.. code-block:: python
|
.. code-block:: python
|
||||||
@ -56,7 +59,7 @@ class AI21(BaseModel, LLM):
|
|||||||
numResults: int = 1
|
numResults: int = 1
|
||||||
"""How many completions to generate for each prompt."""
|
"""How many completions to generate for each prompt."""
|
||||||
|
|
||||||
logitBias: Dict[str, float] = None
|
logitBias: Optional[Dict[str, float]] = None
|
||||||
"""Adjust the probability of specific tokens being generated."""
|
"""Adjust the probability of specific tokens being generated."""
|
||||||
|
|
||||||
ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY")
|
ai21_api_key: Optional[str] = os.environ.get("AI21_API_KEY")
|
||||||
@ -123,10 +126,13 @@ class AI21(BaseModel, LLM):
|
|||||||
"prompt": prompt,
|
"prompt": prompt,
|
||||||
"stopSequences": stop,
|
"stopSequences": stop,
|
||||||
**self._default_params,
|
**self._default_params,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
optional_detail = response.json().get('error')
|
optional_detail = response.json().get("error")
|
||||||
raise ValueError(f'AI21 /complete call failed with status code {response.status_code}. Details: {optional_detail}')
|
raise ValueError(
|
||||||
response = response.json()
|
f"AI21 /complete call failed with status code {response.status_code}."
|
||||||
return response["completions"][0]["data"]["text"]
|
f" Details: {optional_detail}"
|
||||||
|
)
|
||||||
|
response_json = response.json()
|
||||||
|
return response_json["completions"][0]["data"]["text"]
|
||||||
|
2
setup.py
2
setup.py
@ -14,7 +14,7 @@ setup(
|
|||||||
version=__version__,
|
version=__version__,
|
||||||
packages=find_packages(),
|
packages=find_packages(),
|
||||||
description="Building applications with LLMs through composability",
|
description="Building applications with LLMs through composability",
|
||||||
install_requires=["pydantic", "sqlalchemy", "numpy"],
|
install_requires=["pydantic", "sqlalchemy", "numpy", "requests"],
|
||||||
long_description=long_description,
|
long_description=long_description,
|
||||||
license="MIT",
|
license="MIT",
|
||||||
url="https://github.com/hwchase17/langchain",
|
url="https://github.com/hwchase17/langchain",
|
||||||
|
@ -8,3 +8,4 @@ isort
|
|||||||
mypy
|
mypy
|
||||||
flake8
|
flake8
|
||||||
flake8-docstrings
|
flake8-docstrings
|
||||||
|
types-requests
|
||||||
|
Loading…
Reference in New Issue
Block a user