mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-18 18:53:10 +00:00
RWKV: do not propagate model_state between calls (#2565)
RWKV is an RNN whose hidden state is part of its inference. However, that model state should not be carried over from one call to the next, and doing so is a bug. This change resets the state at the start of every invocation and no longer stores it on the instance.
This commit is contained in:
parent
7a4e1b72a8
commit
7bf5b0ccd3
@ -67,8 +67,6 @@ class RWKV(LLM, BaseModel):
|
||||
|
||||
pipeline: Any = None #: :meta private:
|
||||
|
||||
model_state: Any = None #: :meta private:
|
||||
|
||||
model_tokens: Any = None #: :meta private:
|
||||
|
||||
class Config:
|
||||
@ -145,7 +143,7 @@ class RWKV(LLM, BaseModel):
|
||||
tokens = self.tokenizer.encode(prompt).ids
|
||||
|
||||
logits = None
|
||||
state = self.model_state
|
||||
state = None
|
||||
|
||||
occurrence = {}
|
||||
|
||||
@ -178,8 +176,6 @@ class RWKV(LLM, BaseModel):
|
||||
+ occurrence[n] * self.penalty_alpha_frequency
|
||||
)
|
||||
|
||||
# Update state for future invocations
|
||||
self.model_state = state
|
||||
return decoded
|
||||
|
||||
def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
|
Loading…
Reference in New Issue
Block a user