Skip to content

Commit fd7c3f5

Browse files
add function.WithTransformsBasis
1 parent 4335892 commit fd7c3f5

File tree

2 files changed

+48
-0
lines changed

2 files changed

+48
-0
lines changed

nutils/function.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4379,6 +4379,41 @@ def get_support(self, dof):
43794379
supp2 = self._basis2.get_support(dof2)
43804380
return (supp1[:,_] * len(self._basis2.transforms) + supp2[_,:]).ravel()
43814381

4382+
class WithTransformsBasis(Basis):
4383+
'''Replace the transforms sequence of a basis.
4384+
4385+
Parameters
4386+
----------
4387+
parent : :class:`Basis`
4388+
The basis to wrap.
4389+
transforms : :class:`nutils.transformseq.Transforms`
4390+
The new transforms sequence.
4391+
'''
4392+
4393+
@types.apply_annotations
4394+
def __init__(self, parent:strictbasis, transforms:transformseq.stricttransforms, trans:types.strict[TransformChain]):
4395+
self._parent = parent
4396+
assert len(self._parent.transforms) == len(transforms)
4397+
super().__init__(ndofs=parent.ndofs, transforms=transforms, ndims=parent.ndimsdomain, trans=trans)
4398+
4399+
def get_support(self, dof):
4400+
return self._parent.get_support(dof)
4401+
4402+
def get_dofs(self, ielem):
4403+
return self._parent.get_dofs(ielem)
4404+
4405+
def get_coefficients(self, ielem):
4406+
return self._parent.get_coefficients(ielem)
4407+
4408+
def f_ndofs(self, index):
4409+
return self._parent.f_ndofs(index)
4410+
4411+
def f_dofs(self, index):
4412+
return self._parent.f_dofs(index)
4413+
4414+
def f_coefficients(self, index):
4415+
return self._parent.f_coefficients(index)
4416+
43824417
class DisjointUnionBasis(Basis):
43834418

43844419
__slots__ = '_bases', '_dofsplits', '_elemsplits'

tests/test_function.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1229,6 +1229,19 @@ def setUp(self):
12291229
self.checkndofs = 3
12301230
super().setUp()
12311231

1232+
class WithTransformsBasis(CommonBasis, TestCase):
1233+
def setUp(self):
1234+
root = function.Root('X', 0)
1235+
self.roots = root,
1236+
self.checkcoeffs = [[1],[2,3],[4,5],[6]]
1237+
self.checkdofs = [[0],[2,3],[1,3],[2]]
1238+
self.checkndofs = 4
1239+
parent_transforms = transformseq.PlainTransforms([(transform.Identifier(0,k),) for k in 'abcd'], 0, 0)
1240+
parent = function.PlainBasis(self.checkcoeffs, self.checkdofs, 4, parent_transforms, 0, function.SelectChain(self.roots))
1241+
transforms = transformseq.PlainTransforms([(transform.Identifier(0,k),) for k in 'efgh'], 0, 0)
1242+
self.basis = function.WithTransformsBasis(parent, transforms, function.SelectChain(self.roots))
1243+
super().setUp()
1244+
12321245
class DisjointUnionBasis(CommonBasis, TestCase):
12331246
def setUp(self):
12341247
root = function.Root('X', 0)

0 commit comments

Comments
 (0)