Source code for langchain_community.tools.e2b_data_analysis.tool

from __future__ import annotations

import ast
import json
import os
from io import StringIO
from sys import version_info
from typing import IO, TYPE_CHECKING, Any, Callable, List, Optional, Type

from langchain_core.callbacks import (
    AsyncCallbackManagerForToolRun,
    CallbackManager,
    CallbackManagerForToolRun,
)
from langchain_core.pydantic_v1 import BaseModel, Field, PrivateAttr

from langchain_community.tools import BaseTool, Tool
from langchain_community.tools.e2b_data_analysis.unparse import Unparser

if TYPE_CHECKING:
    from e2b import EnvVars
    from e2b.templates.data_analysis import Artifact

# Default description handed to the model for this tool. It is passed to the
# base tool constructor in E2BDataAnalysisTool.__init__ and later extended with
# the list of files uploaded to the sandbox.
base_description = """Evaluates python code in a sandbox environment. \
The environment is long running and exists across multiple executions. \
You must send the whole script every time and print your outputs. \
Script should be pure python code that can be evaluated. \
It should be in python format NOT markdown. \
The code should NOT be wrapped in backticks. \
All python packages including requests, matplotlib, scipy, numpy, pandas, \
etc are available. Create and display chart using `plt.show()`."""


def _unparse(tree: ast.AST) -> str:
    """Render an AST back into Python source text.

    When the interpreter's minor version is 9 or later the standard
    ``ast.unparse`` is used; otherwise the bundled ``Unparser`` writes the
    source into an in-memory buffer.
    """
    if version_info.minor >= 9:
        return ast.unparse(tree)  # type: ignore[attr-defined]

    # Older interpreters: let the vendored Unparser write into a text buffer.
    buffer = StringIO()
    Unparser(tree, file=buffer)
    rendered = buffer.getvalue()
    buffer.close()
    return rendered


def add_last_line_print(code: str) -> str:
    """Add print statement to the last line if it's missing.

    Sometimes, the LLM-generated code doesn't have `print(variable_name)`,
    instead the LLM tries to print the variable only by writing
    `variable_name` (as you would in REPL, for example).

    This function checks the AST of the generated Python code and wraps the
    last statement in `print(...)` when that statement is a bare expression
    that is not already a call to `print`.

    Args:
        code: Python source to inspect.

    Returns:
        The source regenerated from the (possibly modified) AST. If ``code``
        contains no statements at all, it is returned unchanged.

    Raises:
        SyntaxError: If ``code`` cannot be parsed by ``ast.parse``.
    """
    tree = ast.parse(code)

    if not tree.body:
        # Empty or comment-only script: there is no last statement to inspect,
        # and indexing ``tree.body[-1]`` would raise IndexError.
        return code

    node = tree.body[-1]
    if not isinstance(node, ast.Expr):
        # Assignments, loops, defs, ... are left exactly as written.
        return _unparse(tree)

    value = node.value
    already_printed = (
        isinstance(value, ast.Call)
        and isinstance(value.func, ast.Name)
        and value.func.id == "print"
    )
    if not already_printed:
        # Turn a trailing bare expression ``expr`` into ``print(expr)``.
        tree.body[-1] = ast.Expr(
            value=ast.Call(
                func=ast.Name(id="print", ctx=ast.Load()),
                args=[value],
                keywords=[],
            )
        )

    return _unparse(tree)
class UploadedFile(BaseModel):
    """Description of the uploaded path with its remote path.

    Instances are created by ``E2BDataAnalysisTool.upload_file`` and kept by
    the tool to describe the files available in the sandbox.
    """

    # Base name of the uploaded file (``os.path.basename(file.name)``).
    name: str
    # Path returned by the sandbox session when the file was uploaded.
    remote_path: str
    # Caller-supplied description; an empty string means "no description".
    description: str
class E2BDataAnalysisToolArguments(BaseModel):
    """Arguments for the E2BDataAnalysisTool.

    Used as the ``args_schema`` of ``E2BDataAnalysisTool``.
    """

    # The only argument: the Python source to run. The field is required
    # (``...``) and carries an example plus a description for the schema.
    python_code: str = Field(
        ...,
        example="print('Hello World')",
        description=(
            "The python script to be evaluated. "
            "The contents will be in main.py. "
            "It should not be in markdown format."
        ),
    )
class E2BDataAnalysisTool(BaseTool):
    """Tool for running python code in a sandboxed environment for data analysis.

    The tool owns an e2b ``DataAnalysis`` session that is created in
    ``__init__`` and released by ``close``. Files uploaded through
    ``upload_file`` are tracked, and a listing of them is kept at the end of
    ``description``.
    """

    name = "e2b_data_analysis"
    args_schema: Type[BaseModel] = E2BDataAnalysisToolArguments
    session: Any
    description: str
    _uploaded_files: List[UploadedFile] = PrivateAttr(default_factory=list)
    # Text most recently appended to ``description`` to list the uploaded
    # files. Remembered so the listing can be replaced instead of being
    # appended a second time on the next upload/removal.
    _files_suffix: str = PrivateAttr(default="")

    def __init__(
        self,
        api_key: Optional[str] = None,
        cwd: Optional[str] = None,
        env_vars: Optional[EnvVars] = None,
        on_stdout: Optional[Callable[[str], Any]] = None,
        on_stderr: Optional[Callable[[str], Any]] = None,
        on_artifact: Optional[Callable[[Artifact], Any]] = None,
        on_exit: Optional[Callable[[int], Any]] = None,
        **kwargs: Any,
    ):
        """Create the tool and open the sandbox session.

        Args:
            api_key: E2B API key forwarded to the session.
            cwd: Forwarded to the session as ``cwd``.
            env_vars: Forwarded to the session as ``env_vars``.
            on_stdout: Forwarded to the session as ``on_stdout``.
            on_stderr: Forwarded to the session as ``on_stderr``.
            on_artifact: Forwarded to the session as ``on_artifact``.
            on_exit: Forwarded to the session as ``on_exit``.
            **kwargs: Extra keyword arguments for the base tool constructor.

        Raises:
            ImportError: If the ``e2b`` package is not installed.
        """
        try:
            from e2b import DataAnalysis
        except ImportError as e:
            raise ImportError(
                "Unable to import e2b, please install with `pip install e2b`."
            ) from e

        # If no API key is provided, E2B will try to read it from the environment
        # variable E2B_API_KEY
        super().__init__(description=base_description, **kwargs)
        self.session = DataAnalysis(
            api_key=api_key,
            cwd=cwd,
            env_vars=env_vars,
            on_stdout=on_stdout,
            on_stderr=on_stderr,
            on_exit=on_exit,
            on_artifact=on_artifact,
        )

    def close(self) -> None:
        """Close the cloud sandbox and forget the uploaded files."""
        self._uploaded_files = []
        # Drop the now-stale file listing from the description as well.
        self._refresh_description()
        self.session.close()

    @property
    def uploaded_files_description(self) -> str:
        """Return a listing of the uploaded files, or "" when there are none."""
        if not self._uploaded_files:
            return ""
        lines = ["The following files available in the sandbox:"]
        for uploaded in self._uploaded_files:
            if uploaded.description == "":
                lines.append(f"- path: `{uploaded.remote_path}`")
            else:
                lines.append(
                    f"- path: `{uploaded.remote_path}` \n"
                    f" description: `{uploaded.description}`"
                )
        return "\n".join(lines)

    def _refresh_description(self) -> None:
        """Replace the file listing at the end of ``description``.

        The listing appended by the previous call is stripped first, so
        repeated uploads/removals do not pile up duplicate or outdated
        listings. Any other text in ``description`` is left untouched.
        """
        current = self.description
        previous = self._files_suffix
        if previous and current.endswith(previous):
            current = current[: -len(previous)]
        listing = self.uploaded_files_description
        self._files_suffix = "\n" + listing if listing else ""
        self.description = current + self._files_suffix

    def _run(
        self,
        python_code: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
        callbacks: Optional[CallbackManager] = None,
    ) -> str:
        """Run ``python_code`` in the sandbox session.

        The code is first passed through ``add_last_line_print``.

        Returns:
            A JSON string with the keys ``stdout``, ``stderr`` and
            ``artifacts`` (the names of the produced artifacts).
        """
        python_code = add_last_line_print(python_code)

        on_artifact = None
        if callbacks is not None:
            metadata = callbacks.metadata
            if isinstance(metadata, dict):
                # ``getattr`` on a dict never finds a key, so look the
                # handler up as a dictionary entry.
                on_artifact = metadata.get("on_artifact")
            else:
                on_artifact = getattr(metadata, "on_artifact", None)

        stdout, stderr, artifacts = self.session.run_python(
            python_code, on_artifact=on_artifact
        )

        out = {
            "stdout": stdout,
            "stderr": stderr,
            "artifacts": [artifact.name for artifact in artifacts],
        }
        return json.dumps(out)

    async def _arun(
        self,
        python_code: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str:
        """Async execution is not supported; always raises NotImplementedError."""
        raise NotImplementedError("e2b_data_analysis does not support async")

    def run_command(
        self,
        cmd: str,
    ) -> dict:
        """Run shell command in the sandbox.

        Returns:
            A dict with the keys ``stdout``, ``stderr`` and ``exit_code``
            taken from the finished process output.
        """
        proc = self.session.process.start(cmd)
        output = proc.wait()
        return {
            "stdout": output.stdout,
            "stderr": output.stderr,
            "exit_code": output.exit_code,
        }

    def install_python_packages(self, package_names: str | List[str]) -> None:
        """Install python packages in the sandbox."""
        self.session.install_python_packages(package_names)

    def install_system_packages(self, package_names: str | List[str]) -> None:
        """Install system packages (via apt) in the sandbox."""
        self.session.install_system_packages(package_names)

    def download_file(self, remote_path: str) -> bytes:
        """Download file from the sandbox."""
        return self.session.download_file(remote_path)

    def upload_file(self, file: IO, description: str) -> UploadedFile:
        """Upload file to the sandbox.

        The file is uploaded to the '/home/user/<filename>' path.

        Args:
            file: Open file object; its ``name`` attribute provides the
                recorded file name.
            description: Text describing the file; may be an empty string.

        Returns:
            The record describing the uploaded file. The tool's
            ``description`` is updated to list it.
        """
        remote_path = self.session.upload_file(file)

        uploaded = UploadedFile(
            name=os.path.basename(file.name),
            remote_path=remote_path,
            description=description,
        )
        self._uploaded_files.append(uploaded)
        self._refresh_description()
        return uploaded

    def remove_uploaded_file(self, uploaded_file: UploadedFile) -> None:
        """Remove uploaded file from the sandbox.

        Every tracked file with the same ``remote_path`` is dropped and the
        tool's ``description`` is updated accordingly.
        """
        self.session.filesystem.remove(uploaded_file.remote_path)
        self._uploaded_files = [
            f
            for f in self._uploaded_files
            if f.remote_path != uploaded_file.remote_path
        ]
        self._refresh_description()

    def as_tool(self) -> Tool:
        """Return a plain ``Tool`` built from this tool's current settings."""
        return Tool.from_function(
            func=self._run,
            name=self.name,
            description=self.description,
            args_schema=self.args_schema,
        )