Skip to content

Commit 3000126

Browse files
committed
added DirichletDistribution and test
1 parent 41ddbad commit 3000126

File tree

3 files changed

+69
-0
lines changed

3 files changed

+69
-0
lines changed

doc/source/apiref/distributions.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ Specific Distributions
5656

5757
.. autoclass:: BetaBinomialDistribution
5858
:members:
59+
60+
.. autoclass:: DirichletDistribution
61+
:members:
5962

6063
.. autoclass:: GammaDistribution
6164
:members:

src/qinfer/distributions.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@
6666
'SlantedNormalDistribution',
6767
'LogNormalDistribution',
6868
'BetaDistribution',
69+
'DirichletDistribution',
6970
'BetaBinomialDistribution',
7071
'GammaDistribution',
7172
'GinibreUniform',
@@ -996,6 +997,31 @@ def n_rvs(self):
996997

997998
def sample(self, n=1):
998999
return self.dist.rvs(size=n)[:, np.newaxis]
1000+
1001+
class DirichletDistribution(Distribution):
1002+
r"""
1003+
The dirichlet distribution, whose pdf at :math:`x` is proportional to
1004+
:math:`\prod_i x_i^{\alpha_i-1}`.
1005+
1006+
:param alpha: The list of concentration parameters.
1007+
"""
1008+
def __init__(self, alpha):
1009+
self._alpha = np.array(alpha)
1010+
if self.alpha.ndim != 1:
1011+
raise ValueError('The input alpha must be a 1D list of concentration parameters.')
1012+
1013+
self._dist = st.dirichlet(alpha=self.alpha)
1014+
1015+
@property
1016+
def alpha(self):
1017+
return self._alpha
1018+
1019+
@property
1020+
def n_rvs(self):
1021+
return self._alpha.size
1022+
1023+
def sample(self, n=1):
1024+
return self._dist.rvs(size=n)
9991025

10001026
class BetaBinomialDistribution(Distribution):
10011027
r"""

src/qinfer/tests/test_distributions.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -337,6 +337,46 @@ def test_betabinomial_n_rvs(self):
337337
"""
338338
dist = BetaBinomialDistribution(10, alpha=10,beta=42)
339339
assert(dist.n_rvs == 1)
340+
341+
class TestDirichletDistribution(DerandomizedTestCase):
342+
"""
343+
Tests ``DirichletDistribution``
344+
"""
345+
346+
## TEST METHODS ##
347+
348+
def test_dirichlet_moments(self):
349+
"""
350+
Distributions: Checks that the beta distribution has the right
351+
moments, with either of the two input formats
352+
"""
353+
alpha = [1,2,3,4]
354+
alpha_np = np.array(alpha)
355+
alpha_0 = alpha_np.sum()
356+
mean = alpha_np / alpha_0
357+
var = alpha_np * (alpha_0 - alpha_np) / (alpha_0 **2 * (alpha_0+1))
358+
359+
dist = DirichletDistribution(alpha)
360+
samples = dist.sample(100000)
361+
362+
assert samples.shape == (100000, alpha_np.size)
363+
assert_almost_equal(samples.mean(axis=0), mean, 2)
364+
assert_almost_equal(samples.var(axis=0), var, 2)
365+
366+
alpha = np.array([8,7,5,2,2])
367+
alpha_np = np.array(alpha)
368+
alpha_0 = alpha_np.sum()
369+
mean = alpha_np / alpha_0
370+
var = alpha_np * (alpha_0 - alpha_np) / (alpha_0 **2 * (alpha_0+1))
371+
372+
dist = DirichletDistribution(alpha)
373+
samples = dist.sample(100000)
374+
375+
assert samples.shape == (100000, alpha_np.size)
376+
assert_almost_equal(samples.mean(axis=0), mean, 2)
377+
assert_almost_equal(samples.var(axis=0), var, 2)
378+
379+
340380

341381
class TestGammaDistribution(DerandomizedTestCase):
342382
"""

0 commit comments

Comments
 (0)