4040import numpy as np
4141import pytensor .tensor as pt
4242
43- from pytensor import scan
43+ from pytensor import graph_replace , scan
4444from pytensor .gradient import jacobian
4545from pytensor .graph .basic import Apply , Variable
4646from pytensor .graph .fg import FunctionGraph
@@ -163,6 +163,8 @@ def __str__(self):
163163class MeasurableTransform (MeasurableElemwise ):
164164 """A placeholder used to specify a log-likelihood for a transformed measurable variable."""
165165
166+ __props__ = ("scalar_op" , "inplace_pattern" , "is_discrete" )
167+
166168 valid_scalar_types = (
167169 Exp ,
168170 Log ,
@@ -187,16 +189,55 @@ class MeasurableTransform(MeasurableElemwise):
187189 transform_elemwise : Transform
188190 measurable_input_idx : int
189191
190- def __init__ (self , * args , transform : Transform , measurable_input_idx : int , ** kwargs ):
192+ def __init__ (
193+ self , * args , transform : Transform , measurable_input_idx : int , is_discrete : bool , ** kwargs
194+ ):
191195 self .transform_elemwise = transform
192196 self .measurable_input_idx = measurable_input_idx
197+ self .is_discrete = is_discrete
193198 super ().__init__ (* args , ** kwargs )
194199
195200
201+ def abs_logprob (op , value , x , ** kwargs ):
202+ """Compute the log-CDF graph for an absolute value transformation.
203+
204+ For `Y = |X|`, we have `PDF_Y(y) = PDF_Y(-y) + PDF_Y(y)`.
205+ Except for discrete distributions where there's a special case `P(Y=0) = P(X=0)`.
206+ """
207+ logprob_pos = _logprob_helper (x , value )
208+ logprob_neg = graph_replace (logprob_pos , {value : - value })
209+ if op .is_discrete :
210+ logprob = pt .switch (
211+ pt .eq (value , 0 ),
212+ logprob_pos ,
213+ pt .logaddexp (logprob_pos , logprob_neg ),
214+ )
215+ else :
216+ logprob = pt .logaddexp (logprob_pos , logprob_neg )
217+ logprob = pt .where (value < 0 , - np .inf , logprob )
218+ return logprob
219+
220+
221+ def abs_logcdf (op , value , x , ** kwargs ):
222+ """Compute the log-CDF graph for an absolute value transformation.
223+
224+ For `Y = |X|`, we have `CDF_Y(y) = P(|X| <= y) = P(-y <= X <= y) = CDF_X(y) - CDF_X(-y)`.
225+ """
226+ logcdf_pos = _logcdf_helper (x , value )
227+ neg_value = - value - 1 if op .is_discrete else - value
228+ logcdf_neg = graph_replace (logcdf_pos , {value : neg_value })
229+ logcdf = logdiffexp (logcdf_pos , logcdf_neg )
230+ logcdf = pt .where (value < 0 , - np .inf , logcdf )
231+ return logcdf
232+
233+
196234@_logprob .register (MeasurableTransform )
197235def measurable_transform_logprob (op : MeasurableTransform , values , * inputs , ** kwargs ):
198236 """Compute the log-probability graph for a `MeasurabeTransform`."""
199237 # TODO: Could other rewrites affect the order of inputs?
238+ if isinstance (op .scalar_op , Abs ):
239+ return abs_logprob (op , values [0 ], * inputs , ** kwargs )
240+
200241 (value ,) = values
201242 other_inputs = list (inputs )
202243 measurable_input = other_inputs .pop (op .measurable_input_idx )
@@ -207,6 +248,11 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
207248
208249 # Some transformations, like squaring may produce multiple backward values
209250 if isinstance (backward_value , tuple ):
251+ if op .is_discrete :
252+ # Discrete variables tend to have the tricky x=0 case, get out if we don't have a custom implementation
253+ raise NotImplementedError (
254+ "Logprob of transformed discrete variables with non-injective transforms not implemented"
255+ )
210256 input_logprob = pt .logaddexp (
211257 * (
212258 _logprob_helper (measurable_input , backward_val , ** kwargs )
@@ -225,8 +271,11 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
225271 ndim_supp = value .ndim - input_logprob .ndim
226272 jacobian = jacobian .sum (axis = tuple (range (- ndim_supp , 0 )))
227273
274+ # Discrete transformations do not need the jacobian adjustment
275+ logprob = input_logprob if op .is_discrete else input_logprob + jacobian
276+
228277 # The jacobian is used to ensure a value in the supported domain was provided
229- return pt .switch (pt .isnan (jacobian ), - np .inf , input_logprob + jacobian )
278+ return pt .switch (pt .isnan (jacobian ), - np .inf , logprob )
230279
231280
232281MONOTONICALLY_INCREASING_OPS = (Exp , Log , Add , Sinh , Tanh , ArcSinh , ArcCosh , ArcTanh , Erf , Sigmoid )
@@ -236,6 +285,10 @@ def measurable_transform_logprob(op: MeasurableTransform, values, *inputs, **kwa
236285@_logcdf .register (MeasurableTransform )
237286def measurable_transform_logcdf (op : MeasurableTransform , value , * inputs , ** kwargs ):
238287 """Compute the log-CDF graph for a `MeasurabeTransform`."""
288+ if isinstance (op .scalar_op , Abs ):
289+ # Special case for absolute value transformation
290+ return abs_logcdf (op , value , * inputs , ** kwargs )
291+
239292 other_inputs = list (inputs )
240293 measurable_input = other_inputs .pop (op .measurable_input_idx )
241294 backward_value = op .transform_elemwise .backward (value , * other_inputs )
@@ -245,10 +298,8 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
245298 if isinstance (backward_value , tuple ):
246299 raise NotImplementedError
247300
248- is_discrete = measurable_input .type .dtype .startswith ("int" )
249-
250301 logcdf = _logcdf_helper (measurable_input , backward_value )
251- if is_discrete :
302+ if op . is_discrete :
252303 logccdf = pt .log1mexp (_logcdf_helper (measurable_input , backward_value - 1 ))
253304 else :
254305 logccdf = pt .log1mexp (logcdf )
@@ -275,9 +326,6 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
275326 # We don't know if this Op is monotonically increasing/decreasing
276327 raise NotImplementedError
277328
278- if is_discrete :
279- return logcdf
280-
281329 # The jacobian is used to ensure a value in the supported domain was provided
282330 jacobian = op .transform_elemwise .log_jac_det (value , * other_inputs )
283331 return pt .switch (pt .isnan (jacobian ), - np .inf , logcdf )
@@ -286,13 +334,12 @@ def measurable_transform_logcdf(op: MeasurableTransform, value, *inputs, **kwarg
286334@_icdf .register (MeasurableTransform )
287335def measurable_transform_icdf (op : MeasurableTransform , value , * inputs , ** kwargs ):
288336 """Compute the inverse CDF graph for a `MeasurabeTransform`."""
337+ if op .is_discrete :
338+ raise NotImplementedError ("icdf of transformed discrete variables not implemented" )
339+
289340 other_inputs = list (inputs )
290341 measurable_input = other_inputs .pop (op .measurable_input_idx )
291342
292- # Do not apply rewrite to discrete variables
293- if measurable_input .type .dtype .startswith ("int" ):
294- raise NotImplementedError ("icdf of transformed discrete variables not implemented" )
295-
296343 if isinstance (op .scalar_op , MONOTONICALLY_INCREASING_OPS ):
297344 pass
298345 elif isinstance (op .scalar_op , MONOTONICALLY_DECREASING_OPS ):
@@ -323,7 +370,7 @@ def measurable_transform_icdf(op: MeasurableTransform, value, *inputs, **kwargs)
323370 # Fail if transformation is not injective
324371 # A TensorVariable is returned in 1-to-1 inversions, and a tuple in 1-to-many
325372 if isinstance (op .transform_elemwise .backward (icdf , * other_inputs ), tuple ):
326- raise NotImplementedError
373+ raise NotImplementedError ( "icdf of non-injective transformations not implemented" )
327374
328375 return icdf
329376
@@ -481,15 +528,22 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Varia
481528 [measurable_input ] = measurable_inputs
482529 [measurable_output ] = node .outputs
483530
484- # Do not apply rewrite to discrete variables except for their addition and negation
485- if measurable_input .type .dtype .startswith ("int" ):
531+ # Do not apply rewrite to discrete variables except if:
532+ # 1. Operation retains a discrete output
533+ # 2. Operation doesn't create holes in the support
534+ # Reason:
535+ # 1. Due to a limitation in our IR we don't know the type of the MeasurableVariable
536+ # We don't want to make other rewrites think they are dealing with continuous variables when they are not
537+ # 2. We don't want to add cumbersome within-domain checks
538+ is_discrete = measurable_input .type .dtype .startswith ("int" )
539+ if is_discrete :
540+ if not measurable_output .type .dtype .startswith ("int" ):
541+ return None
486542 if not (
487- find_negated_var (measurable_output ) is not None or isinstance (node .op .scalar_op , Add )
543+ isinstance (node .op .scalar_op , Add | Abs )
544+ or find_negated_var (measurable_output ) is not None
488545 ):
489546 return None
490- # Do not allow rewrite if output is cast to a float, because we don't have meta-info on the type of the MeasurableVariable
491- if not measurable_output .type .dtype .startswith ("int" ):
492- return None
493547
494548 # Check that other inputs are not potentially measurable, in which case this rewrite
495549 # would be invalid
@@ -545,6 +599,7 @@ def find_measurable_transforms(fgraph: FunctionGraph, node: Apply) -> list[Varia
545599 scalar_op = scalar_op ,
546600 transform = transform ,
547601 measurable_input_idx = measurable_input_idx ,
602+ is_discrete = is_discrete ,
548603 )
549604 transform_out = transform_op .make_node (* transform_inputs ).default_output ()
550605 return [transform_out ]
0 commit comments