Skip to content

Commit 57fb92a

Browse files
bojakeOceania2018
authored andcommitted
Fixed the choice method so that it respects the seed of the random state it is running in.
Fixed the Seed attribute from NumPyRandom so that it returns the seed value that was specified in the ctor. Added unit tests for testing the seed state's consistency.
1 parent 49c4c1a commit 57fb92a

File tree

3 files changed

+115
-4
lines changed

3 files changed

+115
-4
lines changed

src/NumSharp.Core/RandomSampling/np.random.choice.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ public partial class NumPyRandom
1313
public NDArray choice(NDArray arr, Shape shape = default, bool replace = true, double[] probabilities = null)
1414
{
1515
int arrSize = arr.size;
16-
NDArray mask = np.random.choice(arrSize, shape, probabilities: probabilities);
16+
NDArray mask = choice(arrSize, shape, probabilities: probabilities);
1717
return arr[mask];
1818
}
1919

@@ -35,14 +35,14 @@ public NDArray choice(int a, Shape shape = default, bool replace = true, double[
3535

3636
if (probabilities == null)
3737
{
38-
idx = np.random.randint(0, arr.size, shape);
38+
idx = randint(0, arr.size, shape);
3939
}
4040
else
4141
{
4242

4343
NDArray cdf = np.cumsum(probabilities);
4444
cdf /= cdf[cdf.size - 1];
45-
NDArray uniformSamples = np.random.uniform(0, 1, (int[]) shape);
45+
NDArray uniformSamples = uniform(0, 1, (int[]) shape);
4646
idx = np.searchsorted(cdf, uniformSamples);
4747
}
4848

src/NumSharp.Core/RandomSampling/np.random.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@ protected internal NumPyRandom(NativeRandomState nativeRandomState)
2222
set_state(nativeRandomState);
2323
}
2424

25-
protected internal NumPyRandom(int seed) : this(new Randomizer(seed)) { }
25+
protected internal NumPyRandom(int seed) : this(new Randomizer(seed)) {
26+
Seed = seed;
27+
}
2628

2729
protected internal NumPyRandom() : this(new Randomizer()) { }
2830

@@ -62,6 +64,7 @@ public NumPyRandom RandomState(NativeRandomState state)
6264
/// </summary>
6365
public void seed(int seed)
6466
{
67+
Seed = seed;
6568
randomizer = new Randomizer(seed);
6669
}
6770

Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Diagnostics;
5+
using System.Text;
6+
7+
namespace NumSharp.UnitTest.RandomSampling
8+
{
9+
/// <summary>
10+
/// The random seed tests are only supposed to test the consistent output from the random state when the
11+
/// same seed is applied. No testing of the actual output from the random state is expected here. Just test
12+
/// the consistent output after repeatedly setting the same seed value.
13+
/// </summary>
14+
[TestClass]
15+
public class NpRandomSeedTests : TestClass
16+
{
17+
[TestMethod]
18+
public void SeedTest()
19+
{
20+
NumPyRandom rando = np.random.RandomState(1000);
21+
Assert.AreEqual(1000, rando.Seed, "The seed value given in the ctor does not match the seed value attribute.");
22+
}
23+
24+
[TestMethod]
25+
public void UniformOneSample()
26+
{
27+
NumPyRandom rando = np.random.RandomState(1000);
28+
// Generate a uniform random sample from np.arange(5) of size 1:
29+
// This is equivalent to np.random.randint(0,5,1)
30+
int low = 0;
31+
int high = 5;
32+
33+
// Start with the known, which is what we expect to see every time
34+
NDArray actual = rando.choice(high);
35+
for (int i = 0; i < 10; i++) {
36+
rando.seed(1000);
37+
NDArray test = rando.choice(high); // Not specifying size means 1 single value is wanted
38+
Assert.AreEqual(actual, test, "Inconsistent random result with same seed. Expected the value to be equal every time.");
39+
}
40+
41+
}
42+
43+
[TestMethod]
44+
public void UniformMultipleSample()
45+
{
46+
NumPyRandom rando = np.random.RandomState(1000);
47+
// Generate a uniform random sample from np.arange(5) of size 3:
48+
// This is equivalent to np.random.randint(0,5,3)
49+
int low = 0;
50+
int high = 5;
51+
int nrSamples = 3;
52+
53+
NDArray actual = rando.choice(high, (Shape)nrSamples);
54+
55+
for (int i = 0; i < 10; i++) {
56+
rando.seed(1000);
57+
NDArray test = rando.choice(high, (Shape)nrSamples);
58+
for (int j = 0; j < actual.size; j++) {
59+
Assert.AreEqual(actual.GetAtIndex<int>(j), test.GetAtIndex<int>(j), "Inconsistent choice sampling with the same seed. Expected the results to always be the same.");
60+
}
61+
}
62+
}
63+
64+
[TestMethod]
65+
public void NonUniformSample()
66+
{
67+
NumPyRandom rando = np.random.RandomState(1000);
68+
// Generate a non-uniform random sample from np.arange(5) of size 3:
69+
int low = 0;
70+
int high = 5;
71+
int nrSamples = 3;
72+
double[] probabilities = new double[] { 0.1, 0, 0.3, 0.6, 0 };
73+
74+
NDArray actual = rando.choice(5, (Shape)nrSamples, probabilities: probabilities);
75+
76+
for (int i = 0; i < 10; i++) {
77+
rando.seed(1000);
78+
NDArray test = rando.choice(5, (Shape)nrSamples, probabilities: probabilities);
79+
for (int j = 0; j < actual.size; j++) {
80+
Assert.AreEqual(actual.GetAtIndex<int>(j), test.GetAtIndex<int>(j), "Inconsistent choice sampling with the same seed. Expected the results to always be the same.");
81+
}
82+
}
83+
}
84+
85+
86+
[TestMethod]
87+
public void IntegerArraySample()
88+
{
89+
NumPyRandom rando = np.random.RandomState(1000);
90+
int nrSamples = 5;
91+
92+
NDArray int_arr = new int[] { 42, 96, 3, 101 };
93+
double[] probabilities = new double[] { 0.5, 0.1, 0.0, 0.3 };
94+
95+
NDArray actual = rando.choice(int_arr, (Shape)nrSamples, probabilities: probabilities);
96+
97+
for (int i = 0; i < 10; i++)
98+
{
99+
rando.seed(1000);
100+
NDArray test = rando.choice(int_arr, (Shape)nrSamples, probabilities: probabilities);
101+
for (int j = 0; j < actual.size; j++)
102+
{
103+
Assert.AreEqual(actual.GetAtIndex<int>(j), test.GetAtIndex<int>(j), "Inconsistent choice sampling with the same seed. Expected the results to always be the same.");
104+
}
105+
}
106+
}
107+
}
108+
}

0 commit comments

Comments
 (0)