diff --git a/abcfold/boltz/check_install.py b/abcfold/boltz/check_install.py index 660b3ff..ff66de4 100644 --- a/abcfold/boltz/check_install.py +++ b/abcfold/boltz/check_install.py @@ -19,6 +19,11 @@ def check_boltz(): ) as proc: stdout, stderr = proc.communicate() if proc.returncode != 0: + if "Package(s) not found:" in stderr.decode(): + + raise ModuleNotFoundError( + "Boltz package not found." + ) raise subprocess.CalledProcessError(proc.returncode, cmd, stderr) version = None @@ -41,7 +46,10 @@ def check_boltz(): "pip", "install", f"boltz=={BOLTZ_VERSION}", + "cuequivariance_torch", + "cuequivariance_ops_torch-cu12", "--no-cache-dir", + ] logger.info("Running %s", " ".join(cmd)) diff --git a/abcfold/chai1/check_install.py b/abcfold/chai1/check_install.py index 883143d..13fa2b7 100644 --- a/abcfold/chai1/check_install.py +++ b/abcfold/chai1/check_install.py @@ -5,6 +5,7 @@ logger = logging.getLogger("logger") +CHAI_VERSION = "0.6.1" CHAI_VERSION = "0.6.1" @@ -20,6 +21,11 @@ def check_chai1(): ) as proc: stdout, stderr = proc.communicate() if proc.returncode != 0: + if "Package(s) not found:" in stderr.decode(): + + raise ModuleNotFoundError( + "Chai_lab package not found." + ) raise subprocess.CalledProcessError(proc.returncode, cmd, stderr) version = None @@ -51,15 +57,42 @@ def check_chai1(): ] cmd.append("--no-deps") if no_deps else None logger.info("Running %s", " ".join(cmd)) - with subprocess.Popen( - cmd, - stdout=sys.stdout, - stderr=subprocess.PIPE, - ) as proc: - proc.wait() - if proc.returncode != 0: - if proc.stderr: - logger.error(proc.stderr.read().decode()) - raise subprocess.CalledProcessError(proc.returncode, proc.args) + run_command_using_sys(cmd) + if no_deps: + cmd = [ + sys.executable, + "-m", + "pip", + "install", + "antipickle", + "typer", + "jaxtyping", + "beartype", + "pandera", + "matplotlib", + ] + logger.info("Installing dependencies: %s", " ".join(cmd)) + run_command_using_sys(cmd) + except Exception as e: + logger.error("Error while checking or installing chai_lab: %s", e) + raise ImportError( + "chai_lab package is not installed. " + "Please install it using `pip install chai_lab`." + ) from e logger.info(f"Running Chai version: {CHAI_VERSION}") + + +def run_command_using_sys(command: list[str]) -> None: + """Run a command using sys.executable.""" + logger.info("Running command: %s", " ".join(command)) + with subprocess.Popen( + command, + stdout=sys.stdout, + stderr=subprocess.PIPE, + ) as proc: + proc.wait() + if proc.returncode != 0: + if proc.stderr: + logger.error(proc.stderr.read().decode()) + raise subprocess.CalledProcessError(proc.returncode, proc.args)