Source code for langchain_experimental.llm_bash.base

"""Chain that interprets a prompt and executes bash operations."""
from __future__ import annotations

import logging
import warnings
from typing import Any, Dict, List, Optional

from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.schema import BasePromptTemplate, OutputParserException
from langchain.schema.language_model import BaseLanguageModel

from langchain_experimental.llm_bash.bash import BashProcess
from langchain_experimental.llm_bash.prompt import PROMPT
from langchain_experimental.pydantic_v1 import Extra, Field, root_validator

logger = logging.getLogger(__name__)


[docs]class LLMBashChain(Chain): """Chain that interprets a prompt and executes bash operations. Example: .. code-block:: python from langchain.chains import LLMBashChain from langchain.llms import OpenAI llm_bash = LLMBashChain.from_llm(OpenAI()) """ llm_chain: LLMChain llm: Optional[BaseLanguageModel] = None """[Deprecated] LLM wrapper to use.""" input_key: str = "question" #: :meta private: output_key: str = "answer" #: :meta private: prompt: BasePromptTemplate = PROMPT """[Deprecated]""" bash_process: BashProcess = Field(default_factory=BashProcess) #: :meta private: class Config: """Configuration for this pydantic object.""" extra = Extra.forbid arbitrary_types_allowed = True @root_validator(pre=True) def raise_deprecation(cls, values: Dict) -> Dict: if "llm" in values: warnings.warn( "Directly instantiating an LLMBashChain with an llm is deprecated. " "Please instantiate with llm_chain or using the from_llm class method." ) if "llm_chain" not in values and values["llm"] is not None: prompt = values.get("prompt", PROMPT) values["llm_chain"] = LLMChain(llm=values["llm"], prompt=prompt) return values # TODO: move away from `root_validator` since it is deprecated in pydantic v2 # and causes mypy type-checking failures (hence the `type: ignore`) @root_validator # type: ignore[call-overload] def validate_prompt(cls, values: Dict) -> Dict: if values["llm_chain"].prompt.output_parser is None: raise ValueError( "The prompt used by llm_chain is expected to have an output_parser." ) return values @property def input_keys(self) -> List[str]: """Expect input key. :meta private: """ return [self.input_key] @property def output_keys(self) -> List[str]: """Expect output key. :meta private: """ return [self.output_key] def _call( self, inputs: Dict[str, Any], run_manager: Optional[CallbackManagerForChainRun] = None, ) -> Dict[str, str]: _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager() _run_manager.on_text(inputs[self.input_key], verbose=self.verbose) t = self.llm_chain.predict( question=inputs[self.input_key], callbacks=_run_manager.get_child() ) _run_manager.on_text(t, color="green", verbose=self.verbose) t = t.strip() try: parser = self.llm_chain.prompt.output_parser command_list = parser.parse(t) # type: ignore[union-attr] except OutputParserException as e: _run_manager.on_chain_error(e, verbose=self.verbose) raise e if self.verbose: _run_manager.on_text("\nCode: ", verbose=self.verbose) _run_manager.on_text( str(command_list), color="yellow", verbose=self.verbose ) output = self.bash_process.run(command_list) _run_manager.on_text("\nAnswer: ", verbose=self.verbose) _run_manager.on_text(output, color="yellow", verbose=self.verbose) return {self.output_key: output} @property def _chain_type(self) -> str: return "llm_bash_chain"
[docs] @classmethod def from_llm( cls, llm: BaseLanguageModel, prompt: BasePromptTemplate = PROMPT, **kwargs: Any, ) -> LLMBashChain: llm_chain = LLMChain(llm=llm, prompt=prompt) return cls(llm_chain=llm_chain, **kwargs)