# base.py
# This file contains the base distribution classes for Monte Carlo integration methods.
# It defines foundational classes for sampling distributions and transformations.

import torch
from torch import nn
import numpy as np
import sys
from MCintegration.utils import get_device

# Constants for numerical stability
MINVAL = 10 ** (sys.float_info.min_10_exp + 50)  # Small but safe non-zero value
MAXVAL = 10 ** (sys.float_info.max_10_exp - 50)  # Large but safe value
EPSILON = 1e-16  # Small value to ensure numerical stability
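
# MINVAL/MAXVAL are presumably intended for clamping intermediate quantities
# away from 0 and inf before logs or divisions elsewhere in the package; a
# sketch of that use (the call site is an assumption, not something done here):
#
#     safe = torch.clamp(values, min=MINVAL, max=MAXVAL)
#     log_safe = torch.log(safe)  # finite even if `values` underflowed to 0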


class BaseDistribution(nn.Module):
    """
    Base distribution class for flow-based models.

    This is an abstract base class that provides structure for probability
    distributions used in Monte Carlo integration. Parameters do not depend
    on target variables (unlike a VAE encoder).
    """

    def __init__(self, dim, device="cpu", dtype=torch.float32):
        """
        Initialize BaseDistribution.

        Args:
            dim (int): Dimensionality of the distribution
            device (str or torch.device): Device to use for computation
            dtype (torch.dtype): Data type for computations
        """
        super().__init__()
        self.dtype = dtype
        self.dim = dim
        self.device = device

    def sample(self, batch_size=1, **kwargs):
        """
        Sample from the base distribution.

        Args:
            batch_size (int): Number of samples to draw
            **kwargs: Additional arguments

        Returns:
            tuple: (samples, log_det_jacobian)

        Raises:
            NotImplementedError: This is an abstract method
        """
        raise NotImplementedError

    def sample_with_detJ(self, batch_size=1, **kwargs):
        """
        Sample from the base distribution with the Jacobian determinant
        (not its log).

        Args:
            batch_size (int): Number of samples to draw
            **kwargs: Additional arguments

        Returns:
            tuple: (samples, det_jacobian)
        """
        u, detJ = self.sample(batch_size, **kwargs)
        detJ.exp_()  # Convert log_det to det in place
        return u, detJ
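
# A minimal sketch of a concrete subclass (hypothetical; not part of this
# module), showing the contract `sample` must satisfy: return the points
# together with the log-Jacobian that `sample_with_detJ` exponentiates.
# For importance sampling, log_det_jacobian is -log p(u):
#
#     class Normal(BaseDistribution):
#         def sample(self, batch_size=1, **kwargs):
#             u = torch.randn((batch_size, self.dim),
#                             device=self.device, dtype=self.dtype)
#             log_p = -0.5 * (u**2).sum(dim=1) \
#                 - 0.5 * self.dim * np.log(2 * np.pi)
#             return u, -log_p  # log_detJ = -log p(u)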


class Uniform(BaseDistribution):
    """
    Multivariate uniform distribution over the unit hypercube [0,1]^dim.
    """

    def __init__(self, dim, device="cpu", dtype=torch.float32):
        """
        Initialize Uniform distribution.

        Args:
            dim (int): Dimensionality of the distribution
            device (str or torch.device): Device to use for computation
            dtype (torch.dtype): Data type for computations
        """
        super().__init__(dim, device, dtype)

    def sample(self, batch_size=1, **kwargs):
        """
        Sample from the uniform distribution over [0,1]^dim.

        Args:
            batch_size (int): Number of samples to draw
            **kwargs: Additional arguments

        Returns:
            tuple: (uniform samples, log_det_jacobian=0)
        """
        # torch.manual_seed(0) # test seed
        u = torch.rand((batch_size, self.dim),
                       device=self.device, dtype=self.dtype)
        log_detJ = torch.zeros(
            batch_size, device=self.device, dtype=self.dtype)
        return u, log_detJ
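
# Example usage (a sketch; the shapes follow from the code above):
#
#     dist = Uniform(dim=3)
#     u, log_detJ = dist.sample(batch_size=5)  # u: (5, 3), log_detJ: (5,)
#     u, detJ = dist.sample_with_detJ(batch_size=5)  # detJ = exp(0) = 1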


class LinearMap(nn.Module):
    """
    Linear transformation map of the form x = u * A + b.

    Maps points from one space to another via elementwise scaling and offset.
    """

    def __init__(self, A, b, device=None, dtype=torch.float32):
        """
        Initialize LinearMap with scaling A and offset b.

        Args:
            A (list, numpy.ndarray, or torch.Tensor): Scaling factors
            b (list, numpy.ndarray, or torch.Tensor): Offset values
            device (str or torch.device): Device to use for computation
            dtype (torch.dtype): Data type for computations
        """
        if device is None:
            self.device = get_device()
        else:
            self.device = device
        self.dtype = dtype
        super().__init__()

        if isinstance(A, (list, np.ndarray)):
            self.A = torch.tensor(A, dtype=self.dtype, device=self.device)
        elif isinstance(A, torch.Tensor):
            self.A = A.to(dtype=self.dtype, device=self.device)
        else:
            raise ValueError(
                "'A' must be a list, numpy array, or torch tensor.")

        if isinstance(b, (list, np.ndarray)):
            self.b = torch.tensor(b, dtype=self.dtype, device=self.device)
        elif isinstance(b, torch.Tensor):
            self.b = b.to(dtype=self.dtype, device=self.device)
        else:
            raise ValueError(
                "'b' must be a list, numpy array, or torch tensor.")

        # Pre-compute the Jacobian determinant for efficiency
        self._detJ = torch.prod(self.A)

    def forward(self, u):
        """
        Apply the forward transformation: x = u * A + b.

        Args:
            u (torch.Tensor): Input points

        Returns:
            tuple: (transformed points, log_det_jacobian)
        """
        return u * self.A + self.b, torch.log(self._detJ.repeat(u.shape[0]))
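
    # Note: the map acts elementwise, so its Jacobian is diag(A) and
    # det J = prod(A). The torch.log(self._detJ) above is therefore
    # sum(log A_i) (assuming positive scaling factors), broadcast to the
    # batch via repeat.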

    def forward_with_detJ(self, u):
        """
        Apply the forward transformation with the Jacobian determinant
        (not its log).

        Args:
            u (torch.Tensor): Input points

        Returns:
            tuple: (transformed points, det_jacobian)
        """
        u, detJ = self.forward(u)
        detJ.exp_()  # Convert log_det to det in place
        return u, detJ

    def inverse(self, x):
        """
        Apply the inverse transformation: u = (x - b) / A.

        Args:
            x (torch.Tensor): Input points

        Returns:
            tuple: (transformed points, log_det_jacobian)
        """
        return (x - self.b) / self.A, torch.log(self._detJ.repeat(x.shape[0]))
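

# A minimal usage sketch (illustrative values only): draw uniform samples on
# [0,1]^2 and map them to the box [-1,3] x [0,2] with x = u * A + b, where A
# holds the side lengths and b the lower corners.
if __name__ == "__main__":
    dist = Uniform(dim=2)
    u, log_detJ = dist.sample(batch_size=4)  # u in [0,1]^2
    linear_map = LinearMap(A=[4.0, 2.0], b=[-1.0, 0.0])
    x, log_detJ_map = linear_map.forward(u)  # x in [-1,3] x [0,2]
    total_log_detJ = log_detJ + log_detJ_map  # log|detJ| of the composed map
    print(x, total_log_detJ.exp())  # det = 4 * 2 = 8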