Skip to content

Commit 55db846

Browse files
committed
update maps test
1 parent c21b49b commit 55db846

File tree

1 file changed

+28
-0
lines changed

1 file changed

+28
-0
lines changed

MCintegration/maps_test.py

+28
Original file line numberDiff line numberDiff line change
@@ -137,13 +137,41 @@ def test_forward(self):
137137
self.assertEqual(x.shape, u.shape)
138138
self.assertEqual(log_jac.shape, (u.shape[0],))
139139

140+
def test_forward_out_of_bounds(self):
141+
# Test forward transformation with out-of-bounds u values
142+
u = torch.tensor(
143+
[[1.5, 0.5], [-0.1, 0.5]], dtype=torch.float64
144+
) # Out-of-bounds values
145+
x, log_jac = self.vegas.forward(u)
146+
147+
# Check that out-of-bounds x values are clamped to grid boundaries
148+
self.assertTrue(torch.all(x >= 0.0))
149+
self.assertTrue(torch.all(x <= 1.0))
150+
151+
# Check log determinant adjustment for out-of-bounds cases
152+
self.assertEqual(log_jac.shape, (u.shape[0],))
153+
140154
def test_inverse(self):
141155
# Test inverse transformation
142156
x = torch.tensor([[0.1, 0.2], [0.3, 0.4]], dtype=torch.float64)
143157
u, log_jac = self.vegas.inverse(x)
144158
self.assertEqual(u.shape, x.shape)
145159
self.assertEqual(log_jac.shape, (x.shape[0],))
146160

161+
def test_inverse_out_of_bounds(self):
162+
# Test inverse transformation with out-of-bounds x values
163+
x = torch.tensor(
164+
[[1.5, 0.5], [-0.1, 0.5]], dtype=torch.float64
165+
) # Out-of-bounds values
166+
u, log_jac = self.vegas.inverse(x)
167+
168+
# Check that out-of-bounds u values are clamped to [0, 1]
169+
self.assertTrue(torch.all(u >= 0.0))
170+
self.assertTrue(torch.all(u <= 1.0))
171+
172+
# Check log determinant adjustment for out-of-bounds cases
173+
self.assertEqual(log_jac.shape, (x.shape[0],))
174+
147175
def test_train(self):
148176
# Test training the Vegas class
149177
def f(x, fx):

0 commit comments

Comments
 (0)