Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion src/MeasureTheory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,6 @@ import MeasureBase:
smf,
invsmf,
massof
import MeasureBase: ≪
using MeasureBase: BoundedInts, BoundedReals, CountingBase, IntegerDomain, IntegerNumbers
using MeasureBase: weightedmeasure, restrict
using MeasureBase: AbstractTransitionKernel
Expand Down
16 changes: 8 additions & 8 deletions src/combinators/transforms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ insupport(d::Pushforward, x) = insupport(d.μ, inverse(d.f)(x))
Pushforward(f, μ) = Pushforward(f, μ, True())

function Pretty.tile(pf::Pushforward{<:TV.CallableTransform})
Pretty.list_layout(Pretty.tile.([pf.f.t, pf.μ, pf.logjac]); prefix = :Pushforward)
Pretty.list_layout(Pretty.tile.([pf.f.x, pf.μ, pf.logjac]); prefix = :Pushforward)
end

function Pretty.tile(pf::Pushforward)
Expand All @@ -31,7 +31,7 @@ Pullback(f, ν) = Pullback(f, ν, True())
insupport(d::Pullback, x) = insupport(d.ν, d.f(x))

function Pretty.tile(pf::Pullback{<:TV.CallableTransform})
Pretty.list_layout(Pretty.tile.([pf.f.t, pf.ν, pf.logjac]); prefix = :Pullback)
Pretty.list_layout(Pretty.tile.([pf.f.x, pf.ν, pf.logjac]); prefix = :Pullback)
end

function Pretty.tile(pf::Pullback)
Expand All @@ -41,7 +41,7 @@ end
@inline function logdensity_def(pb::Pullback{F,M,True}, x) where {F<:CallableTransform,M}
f = pb.f
ν = pb.ν
y, logJ = TV.transform_and_logjac(f.t, x)
y, logJ = TV.transform_and_logjac(f.x, x)
return logdensity_def(ν, y) + logJ
end

Expand All @@ -55,8 +55,8 @@ end
@inline function logdensity_def(pf::Pushforward{F,M,True}, y) where {F<:CallableTransform,M}
f = pf.f
μ = pf.μ
x = TV.inverse(f.t)(y)
_, logJ = TV.transform_and_logjac(f.t, x)
x = TV.inverse(f.x)(y)
_, logJ = TV.transform_and_logjac(f.x, x)
return logdensity_def(μ, x) - logJ
end

Expand All @@ -66,7 +66,7 @@ end
) where {F<:CallableTransform,M}
f = pf.f
μ = pf.μ
x = TV.inverse(f.t)(y)
x = TV.inverse(f.x)(y)
return logdensity_def(μ, x)
end

Expand All @@ -76,9 +76,9 @@ function Pushforward(f::AbstractTransform, ν, logjac = True())
Pushforward(TV.transform(f), ν, logjac)
end

Pullback(f::CallableInverse, ν, logjac = True()) = Pushforward(TV.transform(f.t), ν, logjac)
Pullback(f::CallableInverse, ν, logjac = True()) = Pushforward(TV.transform(f.x), ν, logjac)

Pushforward(f::CallableInverse, ν, logjac = True()) = Pullback(TV.transform(f.t), ν, logjac)
Pushforward(f::CallableInverse, ν, logjac = True()) = Pullback(TV.transform(f.x), ν, logjac)

Base.rand(rng::AbstractRNG, T::Type, ν::Pushforward) = ν.f(rand(rng, T, ν.μ))

Expand Down
9 changes: 5 additions & 4 deletions src/distributions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,16 @@ const Dists = Distributions
error("Not implemented:\nas($d)")
end

using TransformVariables: ShiftedExp, ScaledShiftedLogistic
using TransformVariables: TVShift, TVExp, TVNeg, TVScale, TVLogistic

function asTransform(supp::Dists.RealInterval)
(lb, ub) = (supp.lb, supp.ub)

(lb, ub) == (-Inf, Inf) && (return asℝ)
isinf(ub) && return ShiftedExp(true, lb)
isinf(lb) && return ShiftedExp(false, lb)
return ScaledShiftedLogistic(ub - lb, lb)
isinf(ub) && return TVShift(lb) ∘ TVExp()
isinf(lb) && return TVShift(lb) ∘ TVNeg() ∘ TVExp()
shift, scale = promote(lb, ub - lb)
return TVShift(shift) ∘ TVScale(scale) ∘ TVLogistic()
end

as(μ::AbstractMeasure, _data::NamedTuple) = as(μ)
Expand Down
4 changes: 2 additions & 2 deletions src/parameterized.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ Return a transformation for a given parameterized measure subject to the named t

```
julia> asparams(Binomial{(:p,)}, (n=10,))
TransformVariables.TransformTuple{NamedTuple{(:p,), Tuple{TransformVariables.ScaledShiftedLogistic{Float64}}}}((p = as𝕀,), 1)
TransformVariables.TransformTuple{NamedTuple{(:p,), Tuple{TransformVariables.CompositeScalarTransform{Tuple{TVShift{Float64}, TVScale{Float64}, TVLogistic}}}}}((p = as𝕀,), 1)
```

------------
Expand All @@ -38,7 +38,7 @@ Return a transformation with no constraints. For example,

```
julia> asparams(Normal{(:μ,:σ)})
TransformVariables.TransformTuple{NamedTuple{(:μ, :σ), Tuple{TransformVariables.Identity, TransformVariables.ShiftedExp{true, Float64}}}}((μ = asℝ, σ = asℝ₊), 2)
TransformVariables.TransformTuple{NamedTuple{(:μ, :σ), Tuple{TransformVariables.Identity, TransformVariables.CompositeScalarTransform{Tuple{TVShift{Float64}, TVExp}}}}}((μ = asℝ, σ = asℝ₊), 2)
```
"""
function asparams end
Expand Down
6 changes: 3 additions & 3 deletions src/transforms/ordered.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ addlogjac(::TV.NoLogJac, Δℓ) = TV.NoLogJac()

using MappedArrays

bounds(t::TV.ShiftedExp{true}) = (t.shift, TV.∞)
bounds(t::TV.ShiftedExp{false}) = (-TV.∞, t.shift)
bounds(t::TV.ScaledShiftedLogistic) = (t.shift, t.scale + t.shift)
bounds(t::TV.CompositeScalarTransform{Tuple{TVShift{T}, TVExp}}) where {T} = (t.transforms[1].shift, TV.∞)
bounds(t::TV.CompositeScalarTransform{Tuple{TVShift{T}, TVNeg, TVExp}}) where {T} = (-TV.∞, t.transforms[1].shift)
bounds(t::TV.TransformVariables.CompositeScalarTransform{Tuple{TVShift{T}, TVScale{T}, TVLogistic}}) where {T} = (t.shift, t.scale + t.shift)
bounds(::TV.Identity) = (-TV.∞, TV.∞)

const OrderedΔx = -8.0
Expand Down
Loading