Skip to content

Commit 7af1472

Browse files
Merge pull request #661 from jedwin3210/edwin/issue_660
feat(util): add seed parameter for reproducible sampling
2 parents 0e9a6f0 + b465473 commit 7af1472

File tree

4 files changed

+39
-6
lines changed

4 files changed

+39
-6
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@ All notable changes to this project will be documented in this file.
44

55
This project adheres to [Semantic Versioning](http://semver.org/).
66

7+
### v0.18.1
8+
* Added optional `seed` parameter to `sample_proportions()` for reproducible results
9+
* Added optional `seed` parameter to `proportions_from_distribution()` for reproducible results
10+
* Fixed bug where `proportions_from_distribution()` ignored the `column_name` parameter
11+
712
### v0.18.0
813
* Improved color contrast for charts (gold, blue)
914
* Fixed make_array() so it doesn't auto-convert booleans to integers

datascience/util.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ def plot_normal_cdf(rbound=None, lbound=None, mean=0, sd=1):
126126
plot_cdf_area = plot_normal_cdf
127127

128128

129-
def sample_proportions(sample_size: int, probabilities):
129+
def sample_proportions(sample_size: int, probabilities, seed=None):
130130
"""Return the proportion of random draws for each outcome in a distribution.
131131
132132
This function is similar to np.random.Generator.multinomial, but returns proportions
@@ -137,15 +137,17 @@ def sample_proportions(sample_size: int, probabilities):
137137
138138
``probabilities``: An array of probabilities that forms a distribution.
139139
140+
``seed``: Optional seed for reproducibility. If None, results will be random.
141+
140142
Returns:
141143
An array with the same length as ``probability`` that sums to 1.
142144
"""
143-
rng = np.random.default_rng()
145+
rng = np.random.default_rng(seed)
144146
return rng.multinomial(sample_size, probabilities) / sample_size
145147

146148

147149
def proportions_from_distribution(table, label, sample_size,
148-
column_name='Random Sample'):
150+
column_name='Random Sample', seed=None):
149151
"""
150152
Adds a column named ``column_name`` containing the proportions of a random
151153
draw using the distribution in ``label``.
@@ -165,6 +167,8 @@ def proportions_from_distribution(table, label, sample_size,
165167
``column_name``: The name of the new column that contains the sampled
166168
proportions. Defaults to ``'Random Sample'``.
167169
170+
``seed``: Optional seed for reproducibility. If None, results will be random.
171+
168172
Returns:
169173
A copy of ``table`` with a column ``column_name`` containing the
170174
sampled proportions. The proportions will sum to 1.
@@ -173,8 +177,8 @@ def proportions_from_distribution(table, label, sample_size,
173177
``ValueError``: If the ``label`` is not in the table, or if
174178
``table.column(label)`` does not sum to 1.
175179
"""
176-
proportions = sample_proportions(sample_size, table.column(label))
177-
return table.with_column('Random Sample', proportions)
180+
proportions = sample_proportions(sample_size, table.column(label), seed)
181+
return table.with_column(column_name, proportions)
178182

179183

180184
def table_apply(table, func, subset=None):

datascience/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.18.0'
1+
__version__ = '0.18.1'

tests/test_util.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,30 @@ def test_proportions_from_distribution():
103103
assert [x in (0, 0.5, 1) for x in ds.sample_proportions(2, ds.make_array(.2, .3, .5))]
104104

105105

106+
def test_sample_proportions_seed():
107+
"""Test seed parameter and backward compatibility"""
108+
result1 = ds.sample_proportions(1000, [0.5, 0.5], seed=42)
109+
result2 = ds.sample_proportions(1000, [0.5, 0.5], seed=42)
110+
assert np.array_equal(result1, result2)
111+
112+
result3 = ds.sample_proportions(1000, [0.5, 0.5], seed=99)
113+
assert not np.array_equal(result1, result3)
114+
115+
116+
def test_proportions_from_distribution_seed_and_column_name():
117+
"""Test seed parameter and column_name bug fix"""
118+
t = ds.Table().with_column('probs', [0.6, 0.4])
119+
120+
result1 = ds.proportions_from_distribution(t, 'probs', 1000, seed=42)
121+
result2 = ds.proportions_from_distribution(t, 'probs', 1000, seed=42)
122+
assert np.array_equal(result1.column(1), result2.column(1))
123+
assert _round_eq(1, sum(result1.column(1)))
124+
125+
result3 = ds.proportions_from_distribution(t, 'probs', 1000, column_name='My Sample')
126+
assert 'My Sample' in result3.labels
127+
assert result3.num_columns == 2
128+
129+
106130
def test_is_non_string_iterable():
107131
is_string = 'hello'
108132
assert ds.is_non_string_iterable(is_string) == False

0 commit comments

Comments
 (0)