community[major], core[patch], langchain[patch], experimental[patch]: Create langchain-community (#14463)
Moved the following modules to new package langchain-community in a backwards compatible fashion:

```
mv langchain/langchain/adapters community/langchain_community
mv langchain/langchain/callbacks community/langchain_community/callbacks
mv langchain/langchain/chat_loaders community/langchain_community
mv langchain/langchain/chat_models community/langchain_community
mv langchain/langchain/document_loaders community/langchain_community
mv langchain/langchain/docstore community/langchain_community
mv langchain/langchain/document_transformers community/langchain_community
mv langchain/langchain/embeddings community/langchain_community
mv langchain/langchain/graphs community/langchain_community
mv langchain/langchain/llms community/langchain_community
mv langchain/langchain/memory/chat_message_histories community/langchain_community
mv langchain/langchain/retrievers community/langchain_community
mv langchain/langchain/storage community/langchain_community
mv langchain/langchain/tools community/langchain_community
mv langchain/langchain/utilities community/langchain_community
mv langchain/langchain/vectorstores community/langchain_community
mv langchain/langchain/agents/agent_toolkits community/langchain_community
mv langchain/langchain/cache.py community/langchain_community
```

Moved the following to core:

```
mv langchain/langchain/utils/json_schema.py core/langchain_core/utils
mv langchain/langchain/utils/html.py core/langchain_core/utils
mv langchain/langchain/utils/strings.py core/langchain_core/utils
cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py
rm langchain/langchain/utils/env.py
```

See .scripts/community_split/script_integrations.sh for all changes
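Since the split is advertised as backwards compatible, old import paths should keep resolving. A minimal sanity-check sketch, assuming the legacy `langchain.llms` path simply re-exports the community class:

```python
# New canonical location after the split:
from langchain_community.llms import RWKV

# Legacy location, expected to keep working for backwards compatibility:
from langchain.llms import RWKV as LegacyRWKV

# If the old path re-exports the community class (the assumption here),
# both names point at the same object.
assert RWKV is LegacyRWKV
```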
234  libs/community/langchain_community/llms/rwkv.py  (new file)
@@ -0,0 +1,234 @@
"""RWKV models.

Based on https://github.com/saharNooby/rwkv.cpp/blob/master/rwkv/chat_with_bot.py
         https://github.com/BlinkDL/ChatRWKV/blob/main/v2/chat.py
"""
from typing import Any, Dict, List, Mapping, Optional, Set

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import BaseModel, Extra, root_validator

from langchain_community.llms.utils import enforce_stop_tokens


class RWKV(LLM, BaseModel):
    """RWKV language models.

    To use, you should have the ``rwkv`` python package installed, the
    pre-trained model file, and the model's config information.

    Example:
        .. code-block:: python

            from langchain_community.llms import RWKV
            model = RWKV(model="./models/rwkv-3b-fp16.bin", strategy="cpu fp32")

            # Simplest invocation
            response = model("Once upon a time, ")
    """

    model: str
    """Path to the pre-trained RWKV model file."""

    tokens_path: str
    """Path to the RWKV tokens file."""

    strategy: str = "cpu fp32"
    """Strategy for running the model, e.g. "cpu fp32"."""
    rwkv_verbose: bool = True
    """Print debug information."""

    temperature: float = 1.0
    """The temperature to use for sampling."""

    top_p: float = 0.5
    """The top-p value to use for sampling."""

    penalty_alpha_frequency: float = 0.4
    """Positive values penalize new tokens based on their existing frequency
    in the text so far, decreasing the model's likelihood to repeat the same
    line verbatim."""

    penalty_alpha_presence: float = 0.4
    """Positive values penalize new tokens based on whether they appear
    in the text so far, increasing the model's likelihood to talk about
    new topics."""

    CHUNK_LEN: int = 256
    """Batch size for prompt processing."""

    max_tokens_per_generation: int = 256
    """Maximum number of tokens to generate."""

    client: Any = None  #: :meta private:

    tokenizer: Any = None  #: :meta private:

    pipeline: Any = None  #: :meta private:

    model_tokens: Any = None  #: :meta private:

    model_state: Any = None  #: :meta private:

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid
    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters."""
        return {
            "verbose": self.verbose,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "penalty_alpha_frequency": self.penalty_alpha_frequency,
            "penalty_alpha_presence": self.penalty_alpha_presence,
            "CHUNK_LEN": self.CHUNK_LEN,
            "max_tokens_per_generation": self.max_tokens_per_generation,
        }
    @staticmethod
    def _rwkv_param_names() -> Set[str]:
        """Get the parameter names forwarded to the rwkv client."""
        return {
            "verbose",
        }
    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in the environment."""
        try:
            import tokenizers
        except ImportError:
            raise ImportError(
                "Could not import tokenizers python package. "
                "Please install it with `pip install tokenizers`."
            )
        try:
            from rwkv.model import RWKV as RWKVMODEL
            from rwkv.utils import PIPELINE

            values["tokenizer"] = tokenizers.Tokenizer.from_file(values["tokens_path"])

            rwkv_keys = cls._rwkv_param_names()
            model_kwargs = {k: v for k, v in values.items() if k in rwkv_keys}
            model_kwargs["verbose"] = values["rwkv_verbose"]
            values["client"] = RWKVMODEL(
                values["model"], strategy=values["strategy"], **model_kwargs
            )
            values["pipeline"] = PIPELINE(values["client"], values["tokens_path"])

        except ImportError:
            raise ImportError(
                "Could not import rwkv python package. "
                "Please install it with `pip install rwkv`."
            )
        return values
    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model": self.model,
            **self._default_params,
            **{k: v for k, v in self.__dict__.items() if k in RWKV._rwkv_param_names()},
        }

    @property
    def _llm_type(self) -> str:
        """Return the type of llm."""
        return "rwkv"
    def run_rnn(self, _tokens: List[int], newline_adj: int = 0) -> Any:
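        """Feed tokens through the RNN in CHUNK_LEN batches and return the
        logits for the next token. Suppresses immediate repetition of common
        punctuation (AVOID_REPEAT) and lets callers shift the probability of
        a newline via ``newline_adj``.
        """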
        AVOID_REPEAT_TOKENS = []
        AVOID_REPEAT = ",:?!"
        for i in AVOID_REPEAT:
            dd = self.pipeline.encode(i)
            assert len(dd) == 1
            AVOID_REPEAT_TOKENS += dd

        tokens = [int(x) for x in _tokens]
        self.model_tokens += tokens

        out: Any = None

        while len(tokens) > 0:
            out, self.model_state = self.client.forward(
                tokens[: self.CHUNK_LEN], self.model_state
            )
            tokens = tokens[self.CHUNK_LEN :]
        END_OF_LINE = 187
        out[END_OF_LINE] += newline_adj  # adjust \n probability

        if self.model_tokens[-1] in AVOID_REPEAT_TOKENS:
            out[self.model_tokens[-1]] = -999999999
        return out
    def rwkv_generate(self, prompt: str) -> str:
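        """Run the sampling loop: encode the prompt, then repeatedly sample a
        token (applying presence/frequency penalties to the logits), feed it
        back through the RNN, and accumulate the decoded text until
        END_OF_TEXT or the token budget is exhausted.
        """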
        self.model_state = None
        self.model_tokens = []
        logits = self.run_rnn(self.tokenizer.encode(prompt).ids)
        begin = len(self.model_tokens)
        out_last = begin

        occurrence: Dict = {}

        decoded = ""
        for i in range(self.max_tokens_per_generation):
            for n in occurrence:
                logits[n] -= (
                    self.penalty_alpha_presence
                    + occurrence[n] * self.penalty_alpha_frequency
                )
            token = self.pipeline.sample_logits(
                logits, temperature=self.temperature, top_p=self.top_p
            )

            END_OF_TEXT = 0
            if token == END_OF_TEXT:
                break
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1

            logits = self.run_rnn([token])
            xxx = self.tokenizer.decode(self.model_tokens[out_last:])
            if "\ufffd" not in xxx:  # avoid utf-8 display issues
                decoded += xxx
                out_last = begin + i + 1
                if i >= self.max_tokens_per_generation - 100:
                    break

        return decoded
    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        r"""RWKV generation

        Args:
            prompt: The prompt to pass into the model.
            stop: A list of strings to stop generation when encountered.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                prompt = "Once upon a time, "
                response = model(prompt)
        """
        text = self.rwkv_generate(prompt)

        if stop is not None:
            text = enforce_stop_tokens(text, stop)
        return text
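For context, a minimal usage sketch of the class added above. The file paths are hypothetical placeholders; you need the `rwkv` and `tokenizers` packages installed, local model weights, and a matching tokenizer JSON:

```python
from langchain_community.llms import RWKV

# Hypothetical local paths; substitute your own weights and tokenizer file.
model = RWKV(
    model="./models/rwkv-3b-fp16.bin",          # pre-trained weights (placeholder)
    tokens_path="./models/20B_tokenizer.json",  # tokenizer file (placeholder)
    strategy="cpu fp32",
    max_tokens_per_generation=128,
)

# Generated text is truncated at the first stop string, if one appears.
print(model("Once upon a time, ", stop=["\n\n"]))
```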