@@ -137,13 +137,41 @@ def test_forward(self):
137
137
self .assertEqual (x .shape , u .shape )
138
138
self .assertEqual (log_jac .shape , (u .shape [0 ],))
139
139
140
+ def test_forward_out_of_bounds (self ):
141
+ # Test forward transformation with out-of-bounds u values
142
+ u = torch .tensor (
143
+ [[1.5 , 0.5 ], [- 0.1 , 0.5 ]], dtype = torch .float64
144
+ ) # Out-of-bounds values
145
+ x , log_jac = self .vegas .forward (u )
146
+
147
+ # Check that out-of-bounds x values are clamped to grid boundaries
148
+ self .assertTrue (torch .all (x >= 0.0 ))
149
+ self .assertTrue (torch .all (x <= 1.0 ))
150
+
151
+ # Check log determinant adjustment for out-of-bounds cases
152
+ self .assertEqual (log_jac .shape , (u .shape [0 ],))
153
+
140
154
def test_inverse (self ):
141
155
# Test inverse transformation
142
156
x = torch .tensor ([[0.1 , 0.2 ], [0.3 , 0.4 ]], dtype = torch .float64 )
143
157
u , log_jac = self .vegas .inverse (x )
144
158
self .assertEqual (u .shape , x .shape )
145
159
self .assertEqual (log_jac .shape , (x .shape [0 ],))
146
160
161
+ def test_inverse_out_of_bounds (self ):
162
+ # Test inverse transformation with out-of-bounds x values
163
+ x = torch .tensor (
164
+ [[1.5 , 0.5 ], [- 0.1 , 0.5 ]], dtype = torch .float64
165
+ ) # Out-of-bounds values
166
+ u , log_jac = self .vegas .inverse (x )
167
+
168
+ # Check that out-of-bounds u values are clamped to [0, 1]
169
+ self .assertTrue (torch .all (u >= 0.0 ))
170
+ self .assertTrue (torch .all (u <= 1.0 ))
171
+
172
+ # Check log determinant adjustment for out-of-bounds cases
173
+ self .assertEqual (log_jac .shape , (x .shape [0 ],))
174
+
147
175
def test_train (self ):
148
176
# Test training the Vegas class
149
177
def f (x , fx ):
0 commit comments