langchain/libs/experimental/langchain_experimental/llm_bash/base.py
Eugene Yurtsev b9f65e5038
experimental[patch]: Migrate pydantic extra to literals (#25194)
Migrate pydantic extra to literals

Upgrade to using a string literal to specify `extra`, which is the
recommended approach in pydantic 2.

This also works correctly in pydantic v1.

```python
from pydantic.v1 import BaseModel

class Foo(BaseModel, extra="forbid"):
    x: int

Foo(x=5, y=1)
```

And 


```python
from pydantic.v1 import BaseModel

class Foo(BaseModel):
    x: int

    class Config:
      extra = "forbid"

Foo(x=5, y=1)
```


## Enum -> literal using grit pattern:

```
engine marzano(0.1)
language python
or {
    `extra=Extra.allow` => `extra="allow"`,
    `extra=Extra.forbid` => `extra="forbid"`,
    `extra=Extra.ignore` => `extra="ignore"`
}
```

Re-sorted the attributes in `Config` and removed its docstring, in case we
need to go back and forth between pydantic v1 and v2 during the 0.3
release. (This will reduce merge conflicts.)


## Sort attributes in Config:

```
engine marzano(0.1)
language python


function sort($values) js {
    return $values.text.split(',').sort().join("\n");
}


class_definition($name, $body) as $C where {
    $name <: `Config`,
    $body <: block($statements),
    $values = [],
    $statements <: some bubble($values) assignment() as $A where {
        $values += $A
    },
    $body => sort($values),
}

```
2024-08-08 19:05:54 +00:00

127 lines
4.3 KiB
Python

"""Chain that interprets a prompt and executes bash operations."""
from __future__ import annotations
import logging
import warnings
from typing import Any, Dict, List, Optional
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.schema import BasePromptTemplate, OutputParserException
from langchain_core.callbacks.manager import CallbackManagerForChainRun
from langchain_core.language_models import BaseLanguageModel
from langchain_experimental.llm_bash.bash import BashProcess
from langchain_experimental.llm_bash.prompt import PROMPT
from langchain_experimental.pydantic_v1 import Field, root_validator
logger = logging.getLogger(__name__)
class LLMBashChain(Chain):
    """Chain that interprets a prompt and executes bash operations.

    The chain asks ``llm_chain`` to answer the input question, parses the
    LLM's text with the prompt's ``output_parser`` and hands the parsed
    command list to ``bash_process.run``.

    NOTE(review): the commands come straight from LLM output and are passed
    to ``BashProcess.run`` without further checks in this class -- treat them
    as untrusted and confirm the process runs in a suitably isolated
    environment.

    Example:
        .. code-block:: python

            from langchain_community.llms import OpenAI
            from langchain_experimental.llm_bash.base import LLMBashChain

            llm_bash = LLMBashChain.from_llm(OpenAI())
    """

    llm_chain: LLMChain
    llm: Optional[BaseLanguageModel] = None
    """[Deprecated] LLM wrapper to use."""
    input_key: str = "question"  #: :meta private:
    output_key: str = "answer"  #: :meta private:
    prompt: BasePromptTemplate = PROMPT
    """[Deprecated]"""
    bash_process: BashProcess = Field(default_factory=BashProcess)  #: :meta private:

    class Config:
        # Non-pydantic field types are accepted; unknown fields are rejected.
        arbitrary_types_allowed = True
        extra = "forbid"

    @root_validator(pre=True)
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Warn about the deprecated ``llm`` argument and adapt it.

        Runs before field validation. When ``llm`` is supplied, a warning is
        emitted; if no ``llm_chain`` was given and ``llm`` is not ``None``,
        an ``LLMChain`` is built from ``llm`` and the supplied ``prompt``
        (falling back to ``PROMPT``) and stored under ``llm_chain``.

        Args:
            values: Raw keyword arguments passed to the constructor.

        Returns:
            The (possibly augmented) ``values`` dict.
        """
        if "llm" in values:
            warnings.warn(
                "Directly instantiating an LLMBashChain with an llm is deprecated. "
                "Please instantiate with llm_chain or using the from_llm class method."
            )
            if "llm_chain" not in values and values["llm"] is not None:
                prompt = values.get("prompt", PROMPT)
                values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt)
        return values

    # TODO: move away from `root_validator` since it is deprecated in pydantic v2
    #       and causes mypy type-checking failures (hence the `type: ignore`)
    @root_validator  # type: ignore[call-overload]
    def validate_prompt(cls, values: Dict) -> Dict:
        """Ensure the prompt used by ``llm_chain`` has an output parser.

        Args:
            values: Field values collected so far.

        Returns:
            ``values`` unchanged.

        Raises:
            ValueError: If ``llm_chain.prompt.output_parser`` is ``None``.
        """
        llm_chain = values.get("llm_chain")
        if llm_chain is None:
            # This post root validator also runs when `llm_chain` is missing
            # or failed its own field validation, in which case the key is
            # absent. Indexing it would raise a bare KeyError and hide the
            # real validation error, so leave reporting to the field check.
            return values
        if llm_chain.prompt.output_parser is None:
            raise ValueError(
                "The prompt used by llm_chain is expected to have an output_parser."
            )
        return values

    @property
    def input_keys(self) -> List[str]:
        """Expect input key.

        :meta private:
        """
        return [self.input_key]

    @property
    def output_keys(self) -> List[str]:
        """Expect output key.

        :meta private:
        """
        return [self.output_key]

    def _call(
        self,
        inputs: Dict[str, Any],
        run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, str]:
        """Ask the LLM for commands, run them, and return the output.

        Args:
            inputs: Chain inputs; the question is read from
                ``inputs[self.input_key]``.
            run_manager: Callback manager for this run. A no-op manager is
                used when omitted.

        Returns:
            A dict mapping ``self.output_key`` to the value returned by
            ``self.bash_process.run``.

        Raises:
            OutputParserException: If the LLM's text cannot be parsed by the
                prompt's output parser. The error is reported through
                ``run_manager.on_chain_error`` before being re-raised.
        """
        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
        _run_manager.on_text(inputs[self.input_key], verbose=self.verbose)

        t = self.llm_chain.predict(
            question=inputs[self.input_key], callbacks=_run_manager.get_child()
        )
        # The raw (unstripped) LLM text is what gets reported to callbacks.
        _run_manager.on_text(t, color="green", verbose=self.verbose)
        t = t.strip()
        try:
            parser = self.llm_chain.prompt.output_parser
            command_list = parser.parse(t)  # type: ignore[union-attr]
        except OutputParserException as e:
            _run_manager.on_chain_error(e, verbose=self.verbose)
            raise

        if self.verbose:
            _run_manager.on_text("\nCode: ", verbose=self.verbose)
            _run_manager.on_text(
                str(command_list), color="yellow", verbose=self.verbose
            )
        output = self.bash_process.run(command_list)
        _run_manager.on_text("\nAnswer: ", verbose=self.verbose)
        _run_manager.on_text(output, color="yellow", verbose=self.verbose)
        return {self.output_key: output}

    @property
    def _chain_type(self) -> str:
        """Identifier string for this chain type."""
        return "llm_bash_chain"

    @classmethod
    def from_llm(
        cls,
        llm: BaseLanguageModel,
        prompt: BasePromptTemplate = PROMPT,
        **kwargs: Any,
    ) -> LLMBashChain:
        """Build an ``LLMBashChain`` from a language model.

        Args:
            llm: Language model used to create the inner ``LLMChain``.
            prompt: Prompt for the inner ``LLMChain``; defaults to ``PROMPT``.
            **kwargs: Extra keyword arguments forwarded to the constructor.

        Returns:
            A new chain whose ``llm_chain`` wraps ``llm`` and ``prompt``.
        """
        llm_chain = LLMChain(llm=llm, prompt=prompt)
        return cls(llm_chain=llm_chain, **kwargs)