1
+ #include < gtest/gtest.h>
2
+ #include < torch/torch.h>
3
+
4
+ #include " test/cpp/torch_xla_test.h"
5
+ #include " torch_xla/csrc/xla_generator.h"
6
+
7
+ namespace torch_xla {
8
+ namespace cpp_test {
9
+
10
+ // Test fixture for XLAGenerator tests
11
+ class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest {
12
+ protected:
13
+ void SetUp () {
14
+ // Create a generator for XLA device 0
15
+ gen_ = at::make_generator<at::XLAGeneratorImpl>(0 );
16
+ }
17
+
18
+ at::Generator gen_;
19
+ };
20
+
21
+ TEST_F (XLAGeneratorTest, Constructor) {
22
+ // Check that the generator was created for the correct device
23
+ ASSERT_EQ (gen_.device ().type (), at::DeviceType::XLA);
24
+ ASSERT_EQ (gen_.device ().index (), 0 );
25
+
26
+ // Check that the initial seed is 0
27
+ ASSERT_EQ (gen_.current_seed (), 0 );
28
+ }
29
+
30
+ TEST_F (XLAGeneratorTest, Seed) {
31
+ // Test setting and getting the current seed
32
+ uint64_t seed_val = 12345 ;
33
+ gen_.set_current_seed (seed_val);
34
+ ASSERT_EQ (gen_.current_seed (), seed_val);
35
+
36
+ // Test the seed() method, which should set a non-deterministic seed
37
+ uint64_t old_seed = gen_.current_seed ();
38
+ uint64_t new_seed = gen_.seed ();
39
+ // The new seed should be different from the old one and set as the current
40
+ // seed
41
+ ASSERT_NE (new_seed, old_seed);
42
+ ASSERT_EQ (gen_.current_seed (), new_seed);
43
+ }
44
+
45
+ TEST_F (XLAGeneratorTest, GetAndSetState) {
46
+ uint64_t seed_val = 98765 ;
47
+ uint64_t offset_val = 0 ;
48
+
49
+ // Set seed and offset on the original generator
50
+ gen_.set_current_seed (seed_val);
51
+ gen_.set_offset (offset_val);
52
+
53
+ // Get the state from the original generator
54
+ at::Tensor state_tensor = gen_.get_state ();
55
+
56
+ // Create a new generator
57
+ auto new_gen = at::make_generator<at::XLAGeneratorImpl>(1 );
58
+ ASSERT_NE (new_gen.current_seed (), seed_val);
59
+
60
+ // Set the state of the new generator
61
+ new_gen.set_state (state_tensor);
62
+
63
+ // Verify the state of the new generator
64
+ ASSERT_EQ (new_gen.current_seed (), seed_val);
65
+ ASSERT_EQ (new_gen.get_offset (), offset_val);
66
+ }
67
+
68
+ TEST_F (XLAGeneratorTest, SetStateValidation) {
69
+ // Test that set_state throws with incorrect tensor properties
70
+ auto new_gen = at::make_generator<at::XLAGeneratorImpl>(0 );
71
+
72
+ // Incorrect size
73
+ auto wrong_size_tensor = at::empty ({10 }, at::kByte );
74
+ EXPECT_THROW (new_gen.set_state (wrong_size_tensor), c10::Error);
75
+
76
+ // Incorrect dtype
77
+ auto wrong_dtype_tensor = at::empty ({16 }, at::kInt );
78
+ EXPECT_THROW (new_gen.set_state (wrong_dtype_tensor), c10::Error);
79
+ }
80
+
81
+ TEST_F (XLAGeneratorTest, Clone) {
82
+ uint64_t seed_val = 1 ;
83
+ uint64_t offset_val = 0 ;
84
+
85
+ // Set state on the original generator
86
+ gen_.set_current_seed (seed_val);
87
+ gen_.set_offset (offset_val);
88
+
89
+ // Clone the generator
90
+ auto cloned_gen = gen_.clone ();
91
+
92
+ // Verify that the cloned generator has the same state but is a different
93
+ // object
94
+ ASSERT_NE (std::addressof (cloned_gen), std::addressof (gen_));
95
+ ASSERT_EQ (cloned_gen.device (), gen_.device ());
96
+ ASSERT_EQ (cloned_gen.current_seed (), gen_.current_seed ());
97
+ ASSERT_EQ (cloned_gen.get_offset (), offset_val);
98
+
99
+ // Modify the original generator's seed and check that the clone is unaffected
100
+ gen_.set_current_seed (9999 );
101
+ ASSERT_EQ (cloned_gen.current_seed (), seed_val);
102
+ ASSERT_NE (cloned_gen.current_seed (), gen_.current_seed ());
103
+ }
104
+
105
+ } // namespace cpp_test
106
+ } // namespace torch_xla
0 commit comments