diff --git a/lyncs_quda/clover_field.py b/lyncs_quda/clover_field.py index 4dfe3fa..958a969 100644 --- a/lyncs_quda/clover_field.py +++ b/lyncs_quda/clover_field.py @@ -8,12 +8,13 @@ import numpy from cppyy.gbl.std import vector +from functools import cache from lyncs_cppyy import make_shared, to_pointer from .lib import lib, cupy from .lattice_field import LatticeField from .gauge_field import GaugeField -from .enums import QudaParity +from .enums import QudaParity, QudaTwistFlavorType, QudaCloverFieldOrder, QudaFieldCreate # TODO list # We want dimension of (cu/num)py array to reflect parity and order @@ -32,12 +33,12 @@ class CloverField(LatticeField): * Only rho is mutable. To change other params, a new instance should be created * QUDA convention for clover field := 1+i ( kappa csw )/4 sigma_mu,nu F_mu,nu (<-sigma_mu,nu: spinor tensor) * so that sigma_mu,nu = i[g_mu, g_nu], F_mu,nu = (Q_mu,nu - Q_nu,mu)/8 (1/2 is missing from sigma_mu,nu) - * Apparently, an input to QUDA clover object, coeff = kappa*csw - * wihout a normalization factor of 1/4 or 1/32 (suggested in interface_quda.cpp) + * Apparently, an input to QUDA clover object, coeff = kappa*csw + * wihout a normalization factor of 1/4 or 1/32 (suggested in interface_quda.cpp) """ def __new__(cls, fmunu, **kwargs): - #TODO: get dofs and local dims from kwargs, instead of getting them + # TODO: get dofs and local dims from kwargs, instead of getting them # from self.shape assuming that it has the form (dofs, local_dims) if isinstance(fmunu, CloverField): return fmunu @@ -52,14 +53,14 @@ def __new__(cls, fmunu, **kwargs): field = fmunu else: fmunu = GaugeField(fmunu) - - if not is_clover: # not copying from a clover-field array + + if not is_clover: # not copying from a clover-field array idof = int((fmunu.ncol * fmunu.ndims) ** 2 / 2) prec = fmunu.dtype field = fmunu.backend.empty((idof,) + fmunu.dims, dtype=prec) - + return super().__new__(cls, field, **kwargs) - + def __init__( self, obj, @@ -70,7 +71,7 @@ def __init__( eps2=0, rho=0, computeTrLog=False, - **kwargs + **kwargs, ): # WARNING: ndarray object is not supposed to be view-casted to CloverField object # except in __new__, for which __init__ will be called subsequently, @@ -88,12 +89,8 @@ def __init__( empty=True, ) self._fmunu = obj.compute_fmunu() - self._direct = ( - False # Here, it is a flag to indicate whether the field has been computed - ) - self._inverse = ( - False # Here, it is a flag to indicate whether the field has been computed - ) + self._direct = False # Here, it is a flag to indicate whether the field has been computed + self._inverse = False # Here, it is a flag to indicate whether the field has been computed self.coeff = coeff self._twisted = twisted self._twist_flavor = tf @@ -107,12 +104,14 @@ def __init__( elif isinstance(obj, self.backend.ndarray): pass else: - raise ValueError("The input is expected to be ndarray or LatticeField object") + raise ValueError( + "The input is expected to be ndarray or LatticeField object" + ) def _prepare(self, field, copy=False, check=False, **kwargs): # When CloverField object prepares its input, the input is assumed to be of CloverField return super()._prepare(field, copy=copy, check=check, is_clover=True, **kwargs) - + # naming suggestion: native_view? default_* saved for dofs+lattice? def default_view(self): N = 1 if self.order == "FLOAT2" else 4 @@ -126,6 +125,7 @@ def twisted(self): return self._twisted @property + @QudaTwistFlavorType def twist_flavor(self): return self._twist_flavor @@ -153,16 +153,22 @@ def rho(self, val): self._rho = val @property + @QudaCloverFieldOrder def order(self): "Data order of the field" if self.precision == "double": return "FLOAT2" return "FLOAT4" - @property - def quda_order(self): - "Quda enum for data order of the field" - return getattr(lib, f"QUDA_{self.order}_CLOVER_ORDER") + @staticmethod + @cache + def _clv_params(param, **kwargs): + "Call wrapper to cache param structures" + params = lib.CloverFieldParam() + lib.copy_struct(params, param) + for key, val in kwargs.items(): + setattr(params, key, val) + return params @property def quda_params(self): @@ -178,20 +184,21 @@ def quda_params(self): an alias to inverse. not really sure what this is, but does not work properly when reconstruct==True """ - params = lib.CloverFieldParam() - lib.copy_struct(params, super().quda_params) - params.inverse = True - params.clover = to_pointer(self.ptr) - params.cloverInv = to_pointer(self._cloverInv.ptr) - params.coeff = self.coeff - params.twisted = self.twisted - params.twist_flavor = getattr(lib, f"QUDA_TWIST_{self.twist_flavor}") - params.mu2 = self.mu2 - params.epsilon2 = self.eps2 - params.rho = self.rho - params.order = self.quda_order - params.create = lib.QUDA_REFERENCE_FIELD_CREATE - params.location = self.quda_location + params = self._clv_params( + super().quda_params, + inverse=True, + clover=to_pointer(self.ptr), + cloverInv=to_pointer(self._cloverInv.ptr), + coeff=self.coeff, + twisted=self.twisted, + twist_flavor=int(self.twist_flavor), + mu2=self.mu2, + epsilon2=self.eps2, + rho=self.rho, + order=int(self.order), + create=int(QudaFieldCreate["reference"]), + location=int(self.location), + ) return params @property @@ -230,7 +237,7 @@ def trLog(self): def is_native(self): "Whether the field is native for Quda" - return lib.clover.isNative(self.quda_order, self.quda_precision) + return lib.clover.isNative(int(self.order), int(self.precision)) @property def ncol(self): @@ -354,11 +361,7 @@ def computeCloverForce(self, gauge, force, D, vxs, vps, mult=2, coeffs=None): u = gauge.extended_field(sites=R) if gauge.precision == "double": u = gauge.prepare_in(gauge, reconstruct="NO").extended_field(sites=R) - lib.cloverDerivative( - force.quda_field, u, oprodEx, 1.0, getattr(lib, "QUDA_ODD_PARITY") - ) - lib.cloverDerivative( - force.quda_field, u, oprodEx, 1.0, getattr(lib, "QUDA_EVEN_PARITY") - ) + lib.cloverDerivative(force.quda_field, u, oprodEx, 1.0, int(QudaParity["ODD"])) + lib.cloverDerivative(force.quda_field, u, oprodEx, 1.0, int(QudaParity["EVEN"])) return force diff --git a/lyncs_quda/dirac.py b/lyncs_quda/dirac.py index 5de9d98..b978990 100644 --- a/lyncs_quda/dirac.py +++ b/lyncs_quda/dirac.py @@ -15,7 +15,12 @@ from .clover_field import CloverField from .spinor_field import spinor from .lib import lib -from .enums import QudaPrecision +from .enums import ( + QudaDiracType, + QudaMatPCType, + QudaDagType, + QudaParity, + ) @dataclass(frozen=True) @@ -48,6 +53,7 @@ def __post_init__(self): # TODO: Support more Dirac types # Unsupported: DomainWall(4D/PC), Mobius(PC/Eofa), (Improved)Staggered(KD/PC), GaugeLaplace(PC), GaugeCovDev @property + @QudaDiracType def type(self): "Type of the operator" PC = "PC" if not self.full else "" @@ -62,11 +68,7 @@ def type(self): return "TWISTED_CLOVER" + PC @property - def quda_type(self): - "Quda enum for quda dslash type" - return getattr(lib, f"QUDA_{self.type}_DIRAC") - - @property + @QudaMatPCType def matPCtype(self): if self.full: return "INVALID" @@ -74,10 +76,6 @@ def matPCtype(self): symm = "_ASYMMETRIC" if not self.symm else "" return f"{parity}_{parity}{symm}" - @property - def quda_matPCtype(self): - return getattr(lib, f"QUDA_MATPC_{self.matPCtype}") - @property def is_coarse(self): "Whether is a coarse operator" @@ -88,26 +86,22 @@ def precision(self): return self.gauge.precision @property + @QudaDagType def dagger(self): "If the operator is daggered" return "NO" - @property - def quda_dagger(self): - "Quda enum for if the operator is dagger" - return getattr(lib, f"QUDA_DAG_{self.dagger}") - @property def quda_params(self): params = lib.DiracParam() - params.type = self.quda_type + params.type = int(self.type) params.kappa = self.kappa params.m5 = self.m5 params.Ls = self.Ls params.mu = self.mu params.epsilon = self.epsilon - params.dagger = self.quda_dagger - params.matpcType = self.quda_matPCtype + params.dagger = int(self.dagger) + params.matpcType = int(self.matPCtype) # Needs to prevent the gauge field to get destroyed # now we store QUDA gauge object in _quda, but it @@ -299,16 +293,16 @@ def force(self, *phis, out=None, mult=2, coeffs=None, **params): D.quda_dirac.Dslash( xs[-1].quda_field.Odd(), xs[-1].quda_field.Even(), - getattr(lib, "QUDA_ODD_PARITY"), + int(QudaParity["ODD"]), ) D.quda_dirac.M(ps[i].quda_field.Even(), xs[-1].quda_field.Even()) D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_YES")) D.quda_dirac.Dslash( ps[i].quda_field.Odd(), ps[i].quda_field.Even(), - getattr(lib, "QUDA_ODD_PARITY"), + int(QudaParity["ODD"]), ) - D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_NO")) + D.quda_dirac.Dagger(int(QudaDagType["NO"])) else: # Even-odd preconditioned case (i.e., PC in Dirac.type): # use only odd part of phi @@ -317,16 +311,16 @@ def force(self, *phis, out=None, mult=2, coeffs=None, **params): D.quda_dirac.Dslash( xs[-1].quda_field.Even(), xs[-1].quda_field.Odd(), - getattr(lib, "QUDA_EVEN_PARITY"), + int(QudaParity["EVEN"]), ) D.quda_dirac.M(ps[i].quda_field.Odd(), xs[-1].quda_field.Odd()) D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_YES")) D.quda_dirac.Dslash( ps[i].quda_field.Even(), ps[i].quda_field.Odd(), - getattr(lib, "QUDA_EVEN_PARITY"), + int(QudaParity["EVEN"]), ) - D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_NO")) + D.quda_dirac.Dagger((int(QudaDagType["NO"]))) for i in range(n): xs[i].apply_gamma5() @@ -404,11 +398,12 @@ def shift(self, value): @property def precision(self): "The precision of the operator (same as the gauge field)" - return QudaPrecision[self._prec] + return self._prec @property + @QudaMatPCType def mat_PC(self): - return QudaMatPCType[self.quda.getMatPCType()] + return self.quda.getMatPCType() @property def flops(self): diff --git a/lyncs_quda/enum.py b/lyncs_quda/enum.py index 419fddd..4b4d4e3 100644 --- a/lyncs_quda/enum.py +++ b/lyncs_quda/enum.py @@ -26,6 +26,18 @@ def __eq__(self, other): return self.cls is other.cls and int(self) == int(other) return False + def __ne__(self, other): + return not ( + self == other + ) # TODO: perhaps better to insert if-cond for NotImpelented + + def __contains__(self, other): + if isinstance(other, str): + return self.cls.clean(other) in str(self) + if isinstance(other, int): + return self.to_string(other) in str(self) + return False + class EnumMeta(type): "Metaclass for enum types" @@ -42,18 +54,19 @@ def items(cls): "List of enum items" return cls._values.items() - def clean(cls, key): + def clean(cls, rep): + # should turn everything into upper for consistency "Strips away prefix and suffix from key" "See enums.py to find what is prefix and suffix for a given enum value" - if isinstance(key, EnumValue): - key = str(key) - if isinstance(key, str): - key = key.lower() - if cls._prefix and key.startswith(cls._prefix): - key = key[len(cls._prefix) :] - if cls._suffix and key.endswith(cls._suffix): - key = key[: -len(cls._suffix)] - return key + if isinstance(rep, EnumValue): + rep = str(rep) + if isinstance(rep, str): + rep = rep.lower() + if cls._prefix and rep.startswith(cls._prefix): + rep = rep[len(cls._prefix) :] + if cls._suffix and rep.endswith(cls._suffix): + rep = rep[: -len(cls._suffix)] + return rep def to_string(cls, rep): "Returns the key representative of the given enum value" @@ -103,17 +116,26 @@ class Enum(metaclass=EnumMeta): _suffix = "" _values = {} - def __init__(self, key, default=None, callback=None): - self.key = key + def __init__(self, fnc, lpath=None, default=None, callback=None): + # fnc is supposed to return either a stripped key name or value of + # the corresponding QUDA enum type + self.fnc = fnc + self.lpath = lpath self.default = default self.callback = callback + def __call__(self, instance): + # intended for property.fget, which then invokes + # property.__get__(self, obj, objtype=None) + return EnumValue(type(self), self.fnc(instance)) + + # not meant to be a stnadard descriptor, c.f., solver.py def __get__(self, instance, owner): if instance is None: raise AttributeError out = instance - for key in self.key.split("."): + for key in self.lpath.split("."): out = getattr(out, key) return type(self)[out] @@ -126,9 +148,9 @@ def __set__(self, instance, new): new = int(type(self)[new]) out = instance - for key in self.key.split(".")[:-1]: + for key in self.lpath.split(".")[:-1]: out = getattr(out, key) - key = self.key.split(".")[-1] + key = self.lpath.split(".")[-1] old = int(getattr(out, key)) setattr(out, key, new) diff --git a/lyncs_quda/gauge_field.py b/lyncs_quda/gauge_field.py index 20a4934..2494e13 100644 --- a/lyncs_quda/gauge_field.py +++ b/lyncs_quda/gauge_field.py @@ -82,9 +82,13 @@ def _check_field(self, field=None): super()._check_field(field) if field is None: field = self - dofs = field.shape[:-self.ndims] - pdofs = prod(dofs) if dofs[0] not in self._geometry_values[1] or dofs[0]==1 else prod(dofs[1:]) - pdofs *= 2**(self.iscomplex) + dofs = field.shape[: -self.ndims] + pdofs = ( + prod(dofs) + if dofs[0] not in self._geometry_values[1] or dofs[0] == 1 + else prod(dofs[1:]) + ) + pdofs *= 2 ** (self.iscomplex) if not (pdofs in (12, 8, 10) or sqrt(pdofs / 2).is_integer()): raise TypeError(f"Unrecognized field dofs {dofs}") @@ -164,6 +168,7 @@ def dofs_per_link(self): return dofs @property + @QudaReconstructType def reconstruct(self): "Reconstruct type of the field" dofs = self.dofs_per_link @@ -177,11 +182,6 @@ def reconstruct(self): return "NO" return "INVALID" - @property - def quda_reconstruct(self): - "Quda enum for reconstruct type of the field" - return int(QudaReconstructType[self.reconstruct]) - @property def ncol(self): "Number of colors" @@ -193,6 +193,7 @@ def ncol(self): return 3 @property + @QudaGaugeFieldOrder def order(self): "Data order of the field" dofs = self.dofs_per_link @@ -202,11 +203,6 @@ def order(self): return "FLOAT4" return "FLOAT2" - @property - def quda_order(self): - "Quda enum for data order of the field" - return int(QudaGaugeFieldOrder[self.order]) - @property def _geometry_values(self): return ( @@ -215,6 +211,7 @@ def _geometry_values(self): ) @property + @QudaFieldGeometry def geometry(self): """ Geometry of the field @@ -236,11 +233,6 @@ def nlinks(self): return self.dofs[0] return 1 - @property - def quda_geometry(self): - "Quda enum for geometry of the field" - return int(QudaFieldGeometry[self.geometry]) - @property def is_coarse(self): "Whether is a coarse gauge field" @@ -260,16 +252,13 @@ def is_momentum(self, value): self._quda = None @property + @QudaTboundary def t_boundary(self): "Boundary conditions in time" return "PERIODIC_T" @property - def quda_t_boundary(self): - "Quda enum for boundary conditions in time" - return int(QudaTboundary[self.t_boundary]) - - @property + @QudaLinkType def link_type(self): "Type of the links" if self.is_coarse: @@ -278,11 +267,6 @@ def link_type(self): return "MOMENTUM" return "SU3" - @property - def quda_link_type(self): - "Quda enum for link type" - return int(QudaLinkType[self.link_type]) - @staticmethod @cache def _quda_params(*args, **kwargs): @@ -300,16 +284,16 @@ def quda_params(self): # TODO: Allow control on QudaGaugeFixed, i_mu, nFace, anisotropy, tadpole, compute_fat_link_max, params = self._quda_params( self.quda_dims, - self.quda_precision, - self.quda_reconstruct, + int(self.precision), + int(self.reconstruct), self.pad, - self.quda_geometry, - self.quda_ghost_exchange, - location=self.quda_location, - link_type=self.quda_link_type, + int(self.geometry), + int(self.ghost_exchange), + location=int(self.location), + link_type=int(self.link_type), create=int(QudaFieldCreate["reference"]), - t_boundary=self.quda_t_boundary, - order=self.quda_order, + t_boundary=int(self.t_boundary), + order=int(self.order), nColor=self.ncol, ) params.gauge = to_pointer(self.ptr) @@ -326,7 +310,7 @@ def quda_field(self): def is_native(self): "Whether the field is native for Quda" return lib.gauge.isNative( - self.quda_order, self.quda_precision, self.quda_reconstruct + int(self.order), int(self.precision), int(self.reconstruct) ) def extended_field(self, sites=1): @@ -567,7 +551,6 @@ def plaquette(self): raise NotImplementedError( "The underlying QUDA function will not work without GPU" ) - if self.geometry != "VECTOR": raise TypeError("This gauge object needs to have VECTOR geometry") plaq = lib.plaquette(self.extended_field(1)) @@ -613,7 +596,7 @@ def topological_charge_density(self, density=None): self = self.compute_fmunu() charge = numpy.zeros(4, dtype="double") if density is None: - density = self.new(dofs=(1,), dtype=self.precision) + density = self.new(dofs=(1,), dtype=str(self.precision)) lib.computeQChargeDensity(charge[:3], charge[3:], density.ptr, self.quda_field) return charge[3], tuple(charge[:3]), density diff --git a/lyncs_quda/lattice_field.py b/lyncs_quda/lattice_field.py index 11e205d..cdd57fd 100644 --- a/lyncs_quda/lattice_field.py +++ b/lyncs_quda/lattice_field.py @@ -184,7 +184,7 @@ def copy(self, other=None, out=None, **kwargs): # * check if this dose not cause any bugs if it overwrites ndarray.copy # - For now, we cast self to ndarray before performing ndarray methods like flatten # - the second arg should be "order='C'" to match the signiture? - + # check=False => here any output is accepted out = self.prepare_out(out, check=False, **kwargs) if other is None: @@ -256,9 +256,9 @@ def prepare_in(self, fields, **kwargs): return self.prepare(fields, **kwargs) _children = {} - + def __new__(cls, field, **kwargs): - #TODO: get dofs and local dims from kwargs, instead of getting them + # TODO: get dofs and local dims from kwargs, instead of getting them # from self.shape assuming that it has the form (dofs, local_dims) if isinstance(field, cls): @@ -269,14 +269,14 @@ def __new__(cls, field, **kwargs): ) parent = numpy.ndarray if isinstance(field, numpy.ndarray) else cupy.ndarray if (cls, parent) not in cls._children: - cls._children[(cls,parent)] = type(cls.__name__+"ext",(cls, parent), {}) - obj = field.view(type=cls._children[(cls,parent)]) - #self._dims = kwargs.get("dims", self.shape[-self.ndims :]) - #self._dofs = kwargs.get("dofs", field.shape[: -self.ndims]) - + cls._children[(cls, parent)] = type(cls.__name__ + "ext", (cls, parent), {}) + obj = field.view(type=cls._children[(cls, parent)]) + # self._dims = kwargs.get("dims", self.shape[-self.ndims :]) + # self._dofs = kwargs.get("dofs", field.shape[: -self.ndims]) + return obj - #field check should be performed + # field check should be performed def __array_finalize__(self, obj): "Support for __array_finalize__ standard" # Note: this is called when creating a temporary, possibly causing @@ -285,10 +285,11 @@ def __array_finalize__(self, obj): # self: newly created instance # obj: input instance self._check_field(obj) - if obj is None: return # can be removed; we are not bypassing ndarray.__new__ + if obj is None: + return # can be removed; we are not bypassing ndarray.__new__ # This will require some care when we use attr for dims and dofs in _check_field - #self._dims = kwargs.get("dims", self.shape[-self.ndims :]) - #self._dofs = kwargs.get("dofs", field.shape[: -self.ndims]) + # self._dims = kwargs.get("dims", self.shape[-self.ndims :]) + # self._dofs = kwargs.get("dofs", field.shape[: -self.ndims]) self.__init__(obj, comm=getattr(obj, "comm", None)) def __init__(self, field, comm=None, **kwargs): @@ -308,7 +309,7 @@ def _check_field(self, field=None): raise TypeError("Field is stored on a different device than the quda lib") if len(field.shape) < 4: raise ValueError("A lattice field should not have shape smaller than 4") - + def activate(self): "Activates the current field. To be called before using the object in quda" "to make sure the communicator is set for MPI" @@ -332,7 +333,7 @@ def complex_view(self): def float_view(self): "Returns a complex view of the field" - #don't need to upcast if we keep dofs and dims as attributes + # don't need to upcast if we keep dofs and dims as attributes if not self.iscomplex: return self.view(type=self.backend.ndarray) return self.view(get_float_dtype(self.dtype), self.backend.ndarray) @@ -358,17 +359,13 @@ def device(self): @property def device_id(self): return getattr(self.device, "id", None) - + @property + @QudaFieldLocation def location(self): "Memory location of the field (CPU or CUDA)" return "CPU" if isinstance(self, numpy.ndarray) else "CUDA" - @property - def quda_location(self): - "Quda enum for memory location of the field (CPU or CUDA)" - return int(QudaFieldLocation[self.location]) - @property def ndims(self): "Number of lattice dimensions" @@ -412,25 +409,17 @@ def isreal(self): return self.backend.isrealobj(self) @property + @QudaPrecision def precision(self): "Field data type precision" return get_precision(self.dtype) @property - def quda_precision(self): - "Quda enum for field data type precision" - return int(QudaPrecision[self.precision]) - - @property + @QudaGhostExchange def ghost_exchange(self): "Ghost exchange" return "NO" - @property - def quda_ghost_exchange(self): - "Quda enum for ghost exchange" - return int(QudaGhostExchange[self.ghost_exchange]) - @property def pad(self): "Memory padding" @@ -456,9 +445,9 @@ def quda_params(self): self.ndims, self.quda_dims, self.pad, - self.quda_location, - self.quda_precision, - self.quda_ghost_exchange, + int(self.location), + int(self.precision), + int(self.ghost_exchange), ) # ? this assumes: mem_type(QUDA_MEMORY_DEVICE), diff --git a/lyncs_quda/solver.py b/lyncs_quda/solver.py index 43e6adc..dbb4fe9 100644 --- a/lyncs_quda/solver.py +++ b/lyncs_quda/solver.py @@ -7,7 +7,7 @@ "Solver", ] -from functools import wraps +from functools import wraps, cache from warnings import warn from lyncs_cppyy import nullptr, make_shared from .dirac import Dirac, DiracMatrix @@ -107,7 +107,7 @@ def mat(self, mat): if not isinstance(mat, DiracMatrix): raise TypeError("mat should be an instance of Dirac or DiracMatrix") self._mat = mat - self._params.precision = int(QudaPrecision[mat.precision]) + self._params.precision = int(mat.precision) # we should not call this method after setting the below fields self._mat_sloppy = None self._mat_precon = None @@ -131,7 +131,7 @@ def mat_sloppy(self): return self._get_mat("_mat_sloppy", self.precision_sloppy) precision_sloppy = QudaPrecision( - "_params.precision_sloppy", default=lambda self: self.precision + None, lpath="_params.precision_sloppy", default=lambda self: self.precision ) @property @@ -139,7 +139,9 @@ def mat_precon(self): return self._get_mat("_mat_precon", self.precision_precondition) precision_precondition = QudaPrecision( - "_params.precision_precondition", default=lambda self: self.precision + None, + lpath="_params.precision_precondition", + default=lambda self: self.precision, ) @property @@ -147,7 +149,7 @@ def mat_eig(self): return self._get_mat("_mat_eig", self.precision_eigensolver) precision_eigensolver = QudaPrecision( - "_params.precision_eigensolver", default=lambda self: self.precision + None, lpath="_params.precision_eigensolver", default=lambda self: self.precision ) @property @@ -162,8 +164,10 @@ def profiler(self, value): raise TypeError self._profiler = value - inv_type = QudaInverterType("_params.inv_type") - inv_type_precondition = QudaInverterType("_params.inv_type_precondition") + inv_type = QudaInverterType(None, lpath="_params.inv_type") + inv_type_precondition = QudaInverterType( + None, lpath="_params.inv_type_precondition" + ) @property def preconditioner(self): @@ -173,7 +177,7 @@ def preconditioner(self): def preconditioner(self, value): if value is None: self._precon = None - self._params.inv_type_precondition = lib.QUDA_INVALID_INVERTER + self._params.inv_type_precondition = int(QudaInverterType["INVALID"]) self._params.preconditioner = nullptr else: raise NotImplementedError @@ -184,10 +188,12 @@ def _update_return_residual(self, old, new): self._params.preserve_source = not new return_residual = QudaBoolean( - "_params.return_residual", callback=_update_return_residual + None, lpath="_params.return_residual", callback=_update_return_residual ) - residual_type = QudaResidualType("_params.residual_type", default="L2_RELATIVE") + residual_type = QudaResidualType( + None, lpath="_params.residual_type", default="L2_RELATIVE" + ) @property def quda(self): diff --git a/lyncs_quda/spinor_field.py b/lyncs_quda/spinor_field.py index 0187053..9889b43 100644 --- a/lyncs_quda/spinor_field.py +++ b/lyncs_quda/spinor_field.py @@ -8,12 +8,21 @@ "SpinorField", ] -from functools import reduce +from functools import reduce, cache from time import time from lyncs_cppyy import make_shared from lyncs_cppyy.ll import to_pointer from .lib import lib from .lattice_field import LatticeField +from .enums import ( + QudaGammaBasis, + QudaFieldOrder, + QudaTwistFlavorType, + QudaSiteOrder, + QudaPCType, + QudaFieldCreate, + QudaNoiseType, + ) """ NOTE: @@ -74,6 +83,7 @@ def nvec(self): return 1 @property + @QudaGammaBasis def gamma_basis(self): "Gamma basis in use" return self._gamma_basis @@ -90,11 +100,7 @@ def gamma_basis(self, value): self._gamma_basis = value.upper() @property - def quda_gamma_basis(self): - "Quda enum for gamma basis in use" - return getattr(lib, f"QUDA_{self.gamma_basis}_GAMMA_BASIS") - - @property + @QudaFieldOrder def order(self): "Data order of the field" if self.precision in ["single", "half"] and self.nspin == 4: @@ -103,21 +109,13 @@ def order(self): return "FLOAT2" @property - def quda_order(self): - "Quda enum for data order of the field" - return getattr(lib, f"QUDA_{self.order}_FIELD_ORDER") - - @property + @QudaTwistFlavorType def twist_flavor(self): "Twist flavor of the field" return "SINGLET" @property - def quda_twist_flavor(self): - "Quda enum for twist flavor of the field" - return getattr(lib, f"QUDA_TWIST_{self.twist_flavor}") - - @property + @QudaSiteOrder def site_order(self): "Site order in use" return self._site_order @@ -141,32 +139,38 @@ def site_order(self, value): self._site_order = value @property - def quda_site_order(self): - "Quda enum for site order in use" - return getattr(lib, f"QUDA_{self.site_order}_SITE_ORDER") - - @property - def quda_pc_type(self): + @QudaPCType + def pc_type(self): "Select checkerboard preconditioning method" - return getattr(lib, f"QUDA_{self.ndims}D_PC") + return f"{self.ndims}D_PC" + + @staticmethod + @cache + def _spc_params(param, **kwargs): + "Call wrapper to cache param structures" + params = lib.ColorSpinorParam() + lib.copy_struct(params, param) + for key, val in kwargs.items(): + setattr(params, key, val) + return params @property def quda_params(self): "Returns and instance of quda::ColorSpinorParams" - params = lib.ColorSpinorParam() - lib.copy_struct(params, super().quda_params) - params.nColor = self.ncolor - params.nSpin = self.nspin - params.nVec = self.nvec - params.gammaBasis = self.quda_gamma_basis - params.pc_type = self.quda_pc_type - params.twistFlavor = self.quda_twist_flavor - - params.v = to_pointer(self.ptr) - params.create = lib.QUDA_REFERENCE_FIELD_CREATE - params.location = self.quda_location - params.fieldOrder = self.quda_order - params.siteOrder = self.quda_site_order + params = self._spc_params( + super().quda_params, + nColor=self.ncolor, + nSpin=self.nspin, + nVec=self.nvec, + gammaBasis=int(self.gamma_basis), + pc_type=int(self.pc_type), + twistFlavor=int(self.twist_flavor), + v=to_pointer(self.ptr), + create=int(QudaFieldCreate["reference"]), + location=int(self.location), + fieldOrder=int(self.order), + siteOrder=int(self.site_order), + ) return params @property @@ -180,7 +184,7 @@ def quda_field(self): def is_native(self): "Whether the field is native for Quda" return lib.colorspinor.isNative( - self.quda_order, self.quda_precision, self.nspin, self.ncolor + int(self.order), int(self.precision), self.nspin, self.ncolor ) def zero(self): @@ -190,12 +194,12 @@ def zero(self): def gaussian(self, seed=None): "Generates a random gaussian noise spinor" seed = seed or int(time() * 1e9) - lib.spinorNoise(self.quda_field, seed, lib.QUDA_NOISE_GAUSS) + lib.spinorNoise(self.quda_field, seed, int(QudaNoiseType["GAUSS"])) def uniform(self, seed=None): "Generates a random uniform noise spinor" seed = seed or int(time() * 1e9) - lib.spinorNoise(self.quda_field, seed, lib.QUDA_NOISE_UNIFORM) + lib.spinorNoise(self.quda_field, seed, int(QudaNoiseType["UNIFORM"])) def gamma5(self, out=None): "Returns the vector transformed by gamma5" diff --git a/test/test_clover_field.py b/test/test_clover_field.py index 6019237..fa5d6e8 100644 --- a/test/test_clover_field.py +++ b/test/test_clover_field.py @@ -44,14 +44,14 @@ def test_params(lib, lattice, device, dtype): assert params.mu2 == clv.mu2 assert params.epsilon2 == clv.eps2 assert params.rho == clv.rho - assert params.order == clv.quda_order + assert params.order == clv.order assert params.create == lib.QUDA_REFERENCE_FIELD_CREATE - assert params.location == clv.quda_location - assert params.Precision() == clv.quda_precision + assert params.location == clv.location + assert params.Precision() == clv.precision assert params.nDim == clv.ndims assert tuple(params.x)[: clv.ndims] == clv.dims assert params.pad == clv.pad - assert params.ghostExchange == clv.quda_ghost_exchange + assert params.ghostExchange == clv.ghost_exchange @dtype_loop # enables dtype diff --git a/test/test_dirac.py b/test/test_dirac.py index 6917b01..7bc560c 100644 --- a/test/test_dirac.py +++ b/test/test_dirac.py @@ -29,7 +29,7 @@ def test_params(lib, lattice, device, dtype): dirac = gf.Dirac() params = dirac.quda_params assert dirac.precision == gf.precision - assert params.type == dirac.quda_type + assert params.type == dirac.type assert params.kappa == dirac.kappa assert params.m5 == dirac.m5 assert params.Ls == dirac.Ls diff --git a/test/test_field.py b/test/test_field.py index 868eea9..a600455 100644 --- a/test/test_field.py +++ b/test/test_field.py @@ -2,6 +2,7 @@ from pytest import raises import numpy as np import cupy as cp + shape = (4, 3, 4, 4, 4, 4) @@ -46,7 +47,7 @@ def test_numpy(): field /= 1 assert field2 == field - + def test_cupy(): field = LatticeField(cp.zeros(shape)) assert field.location == "CUDA" diff --git a/test/test_gauge_field.py b/test/test_gauge_field.py index 88ebd3f..949f75e 100644 --- a/test/test_gauge_field.py +++ b/test/test_gauge_field.py @@ -10,7 +10,7 @@ epsilon_loop, ) from lyncs_cppyy.ll import addressof -from lyncs_utils import isclose#, allclose +from lyncs_utils import isclose # , allclose @lattice_loop @@ -30,18 +30,18 @@ def test_params(lib, lattice, device, dtype): assert gf.is_native() assert params.nColor == 3 assert params.nFace == 0 - assert params.reconstruct == gf.quda_reconstruct - assert params.location == gf.quda_location - assert params.order == gf.quda_order - assert params.t_boundary == gf.quda_t_boundary - assert params.link_type == gf.quda_link_type - assert params.geometry == gf.quda_geometry + assert params.reconstruct == gf.reconstruct + assert params.location == gf.location + assert params.order == gf.order + assert params.t_boundary == gf.t_boundary + assert params.link_type == gf.link_type + assert params.geometry == gf.geometry assert addressof(params.gauge) == gf.ptr - assert params.Precision() == gf.quda_precision + assert params.Precision() == gf.precision assert params.nDim == gf.ndims assert tuple(params.x)[: gf.ndims] == gf.dims assert params.pad == gf.pad - assert params.ghostExchange == gf.quda_ghost_exchange + assert params.ghostExchange == gf.ghost_exchange @dtype_loop # enables dtype @@ -234,17 +234,22 @@ def test_force(lib, lattice, device, epsilon): zeros = getattr(gf, path + "_field")(coeffs=0, force=True) assert zeros == 0 + from lyncs_utils import isiterable from collections.abc import Mapping + + def values(dct): "Calls values, if available, or dict.values" try: return dct.values() except AttributeError: return dict.values(dct) + + def allclose(left, right, **kwargs): if isinstance(left, cp.ndarray) and isinstance(right, cp.ndarray): - return np.allclose(left,right) + return np.allclose(left, right) if isinstance(left, cp.ndarray) and not isinstance(right, cp.ndarray): left = [left] * len(right) if not isinstance(left, cp.ndarray) and isinstance(right, cp.ndarray): @@ -264,6 +269,7 @@ def allclose(left, right, **kwargs): pairs = zip(left, right) return all((allclose(*pair, **kwargs) for pair in pairs)) + # @dtype_loop # enables dtype @device_loop # enables device @lattice_loop # enables lattice diff --git a/test/test_spinor.py b/test/test_spinor.py index 9ad4a21..7b26505 100644 --- a/test/test_spinor.py +++ b/test/test_spinor.py @@ -31,18 +31,18 @@ def test_params(lib, lattice, device, dtype): assert params.nColor == sf.ncolor assert params.nSpin == sf.nspin assert params.nVec == sf.nvec - assert params.gammaBasis == sf.quda_gamma_basis - assert params.pc_type == sf.quda_pc_type + assert params.gammaBasis == sf.gamma_basis + assert params.pc_type == sf.pc_type - assert params.location == sf.quda_location - assert params.fieldOrder == sf.quda_order - assert params.siteOrder == sf.quda_site_order + assert params.location == sf.location + assert params.fieldOrder == sf.order + assert params.siteOrder == sf.site_order assert addressof(params.v) == sf.ptr - assert params.Precision() == sf.quda_precision + assert params.Precision() == sf.precision assert params.nDim == sf.ndims assert tuple(params.x)[: sf.ndims] == sf.dims assert params.pad == sf.pad - assert params.ghostExchange == sf.quda_ghost_exchange + assert params.ghostExchange == sf.ghost_exchange @dtype_loop # enables dtype