diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index b195da2f..f0db255a 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,7 +6,7 @@ on: jobs: build: - runs-on: ubuntu-latest + runs-on: ubuntu-22.04 strategy: matrix: python-version: ["3.7", "3.8", "3.12"] diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index a5d41652..65f3f166 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -21,7 +21,7 @@ repos: # Python - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. - rev: v0.8.1 + rev: v0.8.2 hooks: - id: ruff args: ["--fix"] diff --git a/dpdata/stat.py b/dpdata/stat.py index 5ec39570..ed74c258 100644 --- a/dpdata/stat.py +++ b/dpdata/stat.py @@ -2,13 +2,14 @@ from abc import ABCMeta, abstractmethod from functools import lru_cache +from typing import Any import numpy as np from dpdata.system import LabeledSystem, MultiSystems -def mae(errors: np.ndarray) -> np.float64: +def mae(errors: np.ndarray) -> np.floating[Any]: """Compute the mean absolute error (MAE). Parameters @@ -18,13 +19,13 @@ def mae(errors: np.ndarray) -> np.float64: Returns ------- - np.float64 + floating[Any] mean absolute error (MAE) """ return np.mean(np.abs(errors)) -def rmse(errors: np.ndarray) -> np.float64: +def rmse(errors: np.ndarray) -> np.floating[Any]: """Compute the root mean squared error (RMSE). Parameters @@ -34,7 +35,7 @@ def rmse(errors: np.ndarray) -> np.float64: Returns ------- - np.float64 + floating[Any] root mean squared error (RMSE) """ return np.sqrt(np.mean(np.square(errors))) @@ -74,22 +75,22 @@ def f_errors(self) -> np.ndarray: """Force errors.""" @property - def e_mae(self) -> np.float64: + def e_mae(self) -> np.floating[Any]: """Energy MAE.""" return mae(self.e_errors) @property - def e_rmse(self) -> np.float64: + def e_rmse(self) -> np.floating[Any]: """Energy RMSE.""" return rmse(self.e_errors) @property - def f_mae(self) -> np.float64: + def f_mae(self) -> np.floating[Any]: """Force MAE.""" return mae(self.f_errors) @property - def f_rmse(self) -> np.float64: + def f_rmse(self) -> np.floating[Any]: """Force RMSE.""" return rmse(self.f_errors) diff --git a/dpdata/system.py b/dpdata/system.py index abe0a755..00172602 100644 --- a/dpdata/system.py +++ b/dpdata/system.py @@ -1049,6 +1049,7 @@ def remove_atom_names(self, atom_names: str | list[str]): atom_idx = self.data["atom_types"] == idx removed_atom_idx.append(atom_idx) picked_atom_idx = ~np.any(removed_atom_idx, axis=0) + assert not isinstance(picked_atom_idx, np.bool_) new_sys = self.pick_atom_idx(picked_atom_idx) # let's remove atom_names # firstly, rearrange atom_names and put these atom_names in the end diff --git a/tests/test_abacus_pw_scf.py b/tests/test_abacus_pw_scf.py index 20751f81..0d2bdef5 100644 --- a/tests/test_abacus_pw_scf.py +++ b/tests/test_abacus_pw_scf.py @@ -158,7 +158,7 @@ def test_noforcestress_job(self): # check below will not throw error system_ch4 = dpdata.LabeledSystem("abacus.scf", fmt="abacus/scf") # check the returned force is empty - self.assertFalse(system_ch4.data["forces"]) + self.assertFalse(system_ch4.data["forces"].size) self.assertTrue("virials" not in system_ch4.data) # test append self system_ch4.append(system_ch4)