diff --git a/lib/gpt/core/object_type/base.py b/lib/gpt/core/object_type/base.py index 6eeee02ae..82c3411e0 100644 --- a/lib/gpt/core/object_type/base.py +++ b/lib/gpt/core/object_type/base.py @@ -30,6 +30,7 @@ class ot_base: data_alias = None # ot can be cast as fundamental type data_alias (such as SU(3) -> 3x3 matrix) mtab = {} # x's multiplication table for x * y rmtab = {} # y's multiplication table for x * y + atab = {} # addition lookup table # only vectors shall define otab/itab otab = None # x's outer product multiplication table for x * adj(y) @@ -46,3 +47,13 @@ def data_otype(self): def is_self_dual(self): return False + + def automatic_embedding(self, other): + if isinstance(other, complex): + return ot_singlet() + return None + + def explicit_cast(self, other): + if isinstance(other, ot_base): + return other + return None diff --git a/tests/core/expr.py b/tests/core/expr.py new file mode 100644 index 000000000..7d192b7dc --- /dev/null +++ b/tests/core/expr.py @@ -0,0 +1,34 @@ +import gpt as g +import numpy as np + +def test_factor_unary(): + grid = g.grid([8, 8, 8, 16], g.double) + v = g.vspincolor(grid) + adj_v = g.adj(v) + result = g(adj_v + v) + assert result.otype.__name__ == "ot_vector_spin_color(4,3)" + +def test_addition_with_complex(): + grid = g.grid([8, 8, 8, 16], g.double) + singlet = g.singlet(grid) + result = g(singlet + 2.0j) + assert result.otype.__name__ == "ot_singlet" + +def test_automatic_embedding(): + grid = g.grid([8, 8, 8, 16], g.double) + singlet = g.singlet(grid) + result = g(singlet + 2.0j) + assert result.otype.__name__ == "ot_singlet" + +def test_explicit_casting(): + grid = g.grid([8, 8, 8, 16], g.double) + singlet = g.singlet(grid) + casted = g.convert(singlet, g.ot_singlet()) + assert casted.otype.__name__ == "ot_singlet" + +if __name__ == "__main__": + test_factor_unary() + test_addition_with_complex() + test_automatic_embedding() + test_explicit_casting() + print("All tests passed.")