Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 45 additions & 42 deletions lyncs_quda/clover_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,13 @@

import numpy
from cppyy.gbl.std import vector
from functools import cache

from lyncs_cppyy import make_shared, to_pointer
from .lib import lib, cupy
from .lattice_field import LatticeField
from .gauge_field import GaugeField
from .enums import QudaParity
from .enums import QudaParity, QudaTwistFlavorType, QudaCloverFieldOrder, QudaFieldCreate

# TODO list
# We want dimension of (cu/num)py array to reflect parity and order
Expand All @@ -32,12 +33,12 @@ class CloverField(LatticeField):
* Only rho is mutable. To change other params, a new instance should be created
* QUDA convention for clover field := 1+i ( kappa csw )/4 sigma_mu,nu F_mu,nu (<-sigma_mu,nu: spinor tensor)
* so that sigma_mu,nu = i[g_mu, g_nu], F_mu,nu = (Q_mu,nu - Q_nu,mu)/8 (1/2 is missing from sigma_mu,nu)
* Apparently, an input to QUDA clover object, coeff = kappa*csw
* wihout a normalization factor of 1/4 or 1/32 (suggested in interface_quda.cpp)
* Apparently, an input to QUDA clover object, coeff = kappa*csw
* wihout a normalization factor of 1/4 or 1/32 (suggested in interface_quda.cpp)
"""

def __new__(cls, fmunu, **kwargs):
#TODO: get dofs and local dims from kwargs, instead of getting them
# TODO: get dofs and local dims from kwargs, instead of getting them
# from self.shape assuming that it has the form (dofs, local_dims)
if isinstance(fmunu, CloverField):
return fmunu
Expand All @@ -52,14 +53,14 @@ def __new__(cls, fmunu, **kwargs):
field = fmunu
else:
fmunu = GaugeField(fmunu)
if not is_clover: # not copying from a clover-field array

if not is_clover: # not copying from a clover-field array
idof = int((fmunu.ncol * fmunu.ndims) ** 2 / 2)
prec = fmunu.dtype
field = fmunu.backend.empty((idof,) + fmunu.dims, dtype=prec)

return super().__new__(cls, field, **kwargs)

def __init__(
self,
obj,
Expand All @@ -70,7 +71,7 @@ def __init__(
eps2=0,
rho=0,
computeTrLog=False,
**kwargs
**kwargs,
):
# WARNING: ndarray object is not supposed to be view-casted to CloverField object
# except in __new__, for which __init__ will be called subsequently,
Expand All @@ -88,12 +89,8 @@ def __init__(
empty=True,
)
self._fmunu = obj.compute_fmunu()
self._direct = (
False # Here, it is a flag to indicate whether the field has been computed
)
self._inverse = (
False # Here, it is a flag to indicate whether the field has been computed
)
self._direct = False # Here, it is a flag to indicate whether the field has been computed
self._inverse = False # Here, it is a flag to indicate whether the field has been computed
self.coeff = coeff
self._twisted = twisted
self._twist_flavor = tf
Expand All @@ -107,12 +104,14 @@ def __init__(
elif isinstance(obj, self.backend.ndarray):
pass
else:
raise ValueError("The input is expected to be ndarray or LatticeField object")
raise ValueError(
"The input is expected to be ndarray or LatticeField object"
)

def _prepare(self, field, copy=False, check=False, **kwargs):
# When CloverField object prepares its input, the input is assumed to be of CloverField
return super()._prepare(field, copy=copy, check=check, is_clover=True, **kwargs)

# naming suggestion: native_view? default_* saved for dofs+lattice?
def default_view(self):
N = 1 if self.order == "FLOAT2" else 4
Expand All @@ -126,6 +125,7 @@ def twisted(self):
return self._twisted

@property
@QudaTwistFlavorType
def twist_flavor(self):
return self._twist_flavor

Expand Down Expand Up @@ -153,16 +153,22 @@ def rho(self, val):
self._rho = val

@property
@QudaCloverFieldOrder
def order(self):
"Data order of the field"
if self.precision == "double":
return "FLOAT2"
return "FLOAT4"

@property
def quda_order(self):
"Quda enum for data order of the field"
return getattr(lib, f"QUDA_{self.order}_CLOVER_ORDER")
@staticmethod
@cache
def _clv_params(param, **kwargs):
"Call wrapper to cache param structures"
params = lib.CloverFieldParam()
lib.copy_struct(params, param)
for key, val in kwargs.items():
setattr(params, key, val)
return params

@property
def quda_params(self):
Expand All @@ -178,20 +184,21 @@ def quda_params(self):
an alias to inverse. not really sure what this is, but does
not work properly when reconstruct==True
"""
params = lib.CloverFieldParam()
lib.copy_struct(params, super().quda_params)
params.inverse = True
params.clover = to_pointer(self.ptr)
params.cloverInv = to_pointer(self._cloverInv.ptr)
params.coeff = self.coeff
params.twisted = self.twisted
params.twist_flavor = getattr(lib, f"QUDA_TWIST_{self.twist_flavor}")
params.mu2 = self.mu2
params.epsilon2 = self.eps2
params.rho = self.rho
params.order = self.quda_order
params.create = lib.QUDA_REFERENCE_FIELD_CREATE
params.location = self.quda_location
params = self._clv_params(
super().quda_params,
inverse=True,
clover=to_pointer(self.ptr),
cloverInv=to_pointer(self._cloverInv.ptr),
coeff=self.coeff,
twisted=self.twisted,
twist_flavor=int(self.twist_flavor),
mu2=self.mu2,
epsilon2=self.eps2,
rho=self.rho,
order=int(self.order),
create=int(QudaFieldCreate["reference"]),
location=int(self.location),
)
return params

@property
Expand Down Expand Up @@ -230,7 +237,7 @@ def trLog(self):

def is_native(self):
"Whether the field is native for Quda"
return lib.clover.isNative(self.quda_order, self.quda_precision)
return lib.clover.isNative(int(self.order), int(self.precision))

@property
def ncol(self):
Expand Down Expand Up @@ -354,11 +361,7 @@ def computeCloverForce(self, gauge, force, D, vxs, vps, mult=2, coeffs=None):
u = gauge.extended_field(sites=R)
if gauge.precision == "double":
u = gauge.prepare_in(gauge, reconstruct="NO").extended_field(sites=R)
lib.cloverDerivative(
force.quda_field, u, oprodEx, 1.0, getattr(lib, "QUDA_ODD_PARITY")
)
lib.cloverDerivative(
force.quda_field, u, oprodEx, 1.0, getattr(lib, "QUDA_EVEN_PARITY")
)
lib.cloverDerivative(force.quda_field, u, oprodEx, 1.0, int(QudaParity["ODD"]))
lib.cloverDerivative(force.quda_field, u, oprodEx, 1.0, int(QudaParity["EVEN"]))

return force
47 changes: 21 additions & 26 deletions lyncs_quda/dirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@
from .clover_field import CloverField
from .spinor_field import spinor
from .lib import lib
from .enums import QudaPrecision
from .enums import (
QudaDiracType,
QudaMatPCType,
QudaDagType,
QudaParity,
)


@dataclass(frozen=True)
Expand Down Expand Up @@ -48,6 +53,7 @@ def __post_init__(self):
# TODO: Support more Dirac types
# Unsupported: DomainWall(4D/PC), Mobius(PC/Eofa), (Improved)Staggered(KD/PC), GaugeLaplace(PC), GaugeCovDev
@property
@QudaDiracType
def type(self):
"Type of the operator"
PC = "PC" if not self.full else ""
Expand All @@ -62,22 +68,14 @@ def type(self):
return "TWISTED_CLOVER" + PC

@property
def quda_type(self):
"Quda enum for quda dslash type"
return getattr(lib, f"QUDA_{self.type}_DIRAC")

@property
@QudaMatPCType
def matPCtype(self):
if self.full:
return "INVALID"
parity = "EVEN" if self.even else "ODD"
symm = "_ASYMMETRIC" if not self.symm else ""
return f"{parity}_{parity}{symm}"

@property
def quda_matPCtype(self):
return getattr(lib, f"QUDA_MATPC_{self.matPCtype}")

@property
def is_coarse(self):
"Whether is a coarse operator"
Expand All @@ -88,26 +86,22 @@ def precision(self):
return self.gauge.precision

@property
@QudaDagType
def dagger(self):
"If the operator is daggered"
return "NO"

@property
def quda_dagger(self):
"Quda enum for if the operator is dagger"
return getattr(lib, f"QUDA_DAG_{self.dagger}")

@property
def quda_params(self):
params = lib.DiracParam()
params.type = self.quda_type
params.type = int(self.type)
params.kappa = self.kappa
params.m5 = self.m5
params.Ls = self.Ls
params.mu = self.mu
params.epsilon = self.epsilon
params.dagger = self.quda_dagger
params.matpcType = self.quda_matPCtype
params.dagger = int(self.dagger)
params.matpcType = int(self.matPCtype)

# Needs to prevent the gauge field to get destroyed
# now we store QUDA gauge object in _quda, but it
Expand Down Expand Up @@ -299,16 +293,16 @@ def force(self, *phis, out=None, mult=2, coeffs=None, **params):
D.quda_dirac.Dslash(
xs[-1].quda_field.Odd(),
xs[-1].quda_field.Even(),
getattr(lib, "QUDA_ODD_PARITY"),
int(QudaParity["ODD"]),
)
D.quda_dirac.M(ps[i].quda_field.Even(), xs[-1].quda_field.Even())
D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_YES"))
D.quda_dirac.Dslash(
ps[i].quda_field.Odd(),
ps[i].quda_field.Even(),
getattr(lib, "QUDA_ODD_PARITY"),
int(QudaParity["ODD"]),
)
D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_NO"))
D.quda_dirac.Dagger(int(QudaDagType["NO"]))
else:
# Even-odd preconditioned case (i.e., PC in Dirac.type):
# use only odd part of phi
Expand All @@ -317,16 +311,16 @@ def force(self, *phis, out=None, mult=2, coeffs=None, **params):
D.quda_dirac.Dslash(
xs[-1].quda_field.Even(),
xs[-1].quda_field.Odd(),
getattr(lib, "QUDA_EVEN_PARITY"),
int(QudaParity["EVEN"]),
)
D.quda_dirac.M(ps[i].quda_field.Odd(), xs[-1].quda_field.Odd())
D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_YES"))
D.quda_dirac.Dslash(
ps[i].quda_field.Even(),
ps[i].quda_field.Odd(),
getattr(lib, "QUDA_EVEN_PARITY"),
int(QudaParity["EVEN"]),
)
D.quda_dirac.Dagger(getattr(lib, "QUDA_DAG_NO"))
D.quda_dirac.Dagger((int(QudaDagType["NO"])))

for i in range(n):
xs[i].apply_gamma5()
Expand Down Expand Up @@ -404,11 +398,12 @@ def shift(self, value):
@property
def precision(self):
"The precision of the operator (same as the gauge field)"
return QudaPrecision[self._prec]
return self._prec

@property
@QudaMatPCType
def mat_PC(self):
return QudaMatPCType[self.quda.getMatPCType()]
return self.quda.getMatPCType()

@property
def flops(self):
Expand Down
Loading