
Commit a1c6ee9

Add xla random generator. (#9539)
This is the first PR for #9159. It only adds the generator, without using it anywhere yet. The comment linked in #9159 (comment) outlines the steps for the entire change.
1 parent 4199865 commit a1c6ee9
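To make the scope concrete, here is a minimal sketch of how the new generator is exercised, mirroring the unit tests added in this commit (test/cpp/test_xla_generator.cpp). The standalone main() wrapper and the includes are assumptions for illustration, not part of the change.

```cpp
#include <torch/torch.h>

#include "torch_xla/csrc/xla_generator.h"

int main() {
  // Build a generator for XLA device 0, as the new tests do.
  at::Generator gen = at::make_generator<at::XLAGeneratorImpl>(0);

  // Seed and offset are plain 64-bit values held in XLAGeneratorState.
  gen.set_current_seed(12345);
  gen.set_offset(0);

  // get_state() serializes seed and offset into a 16-byte CPU byte tensor;
  // set_state() restores them on another generator.
  at::Tensor state = gen.get_state();
  at::Generator other = at::make_generator<at::XLAGeneratorImpl>(1);
  other.set_state(state);  // other.current_seed() is now 12345

  return 0;
}
```

Nothing in the runtime consumes the generator yet; this commit only adds the class, its build targets, and its tests.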

8 files changed: +266, -3 lines

.github/scripts/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ function run_torch_xla_cpp_tests() {
     "test_tensor"
     # disable test_xla_backend_intf since it is flaky on upstream
     #"test_xla_backend_intf"
+    "test_xla_generator"
     "test_xla_sharding"
     "test_runtime"
     "test_status_dont_show_cpp_stacktraces"

BUILD

Lines changed: 4 additions & 3 deletions
@@ -72,15 +72,16 @@ test_suite(
         "//test/cpp:test_aten_xla_tensor_4",
         "//test/cpp:test_aten_xla_tensor_5",
         "//test/cpp:test_aten_xla_tensor_6",
+        "//test/cpp:test_debug_macros",
         "//test/cpp:test_ir",
         "//test/cpp:test_lazy",
         "//test/cpp:test_replication",
-        "//test/cpp:test_tensor",
-        "//test/cpp:test_xla_sharding",
         "//test/cpp:test_runtime",
         "//test/cpp:test_status_dont_show_cpp_stacktraces",
         "//test/cpp:test_status_show_cpp_stacktraces",
-        "//test/cpp:test_debug_macros",
+        "//test/cpp:test_tensor",
+        "//test/cpp:test_xla_generator",
+        "//test/cpp:test_xla_sharding",
         "//torch_xla/csrc/runtime:pjrt_computation_client_test",
         # "//torch_xla/csrc/runtime:ifrt_computation_client_test",
     ],

test/cpp/BUILD

Lines changed: 12 additions & 0 deletions
@@ -204,3 +204,15 @@ ptxla_cc_test(
         "@com_google_googletest//:gtest_main",
     ],
 )
+
+ptxla_cc_test(
+    name = "test_xla_generator",
+    srcs = ["test_xla_generator.cpp"],
+    deps = [
+        ":cpp_test_util",
+        ":torch_xla_test",
+        "//torch_xla/csrc:tensor",
+        "//torch_xla/csrc:aten_cuda_functions",
+        "@com_google_googletest//:gtest_main",
+    ],
+)

test/cpp/run_tests.sh

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ if [[ "$RUN_CPP_TESTS" == "cpp_tests" ]]; then
     # disable test_xla_backend_intf since it is flaky on upstream
     #"test_xla_backend_intf"
     "test_xla_sharding"
+    "test_xla_generator"
     "test_runtime"
     "test_status_dont_show_cpp_stacktraces"
     "test_status_show_cpp_stacktraces"

test/cpp/test_xla_generator.cpp

Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
#include <gtest/gtest.h>
#include <torch/torch.h>

#include "test/cpp/torch_xla_test.h"
#include "torch_xla/csrc/xla_generator.h"

namespace torch_xla {
namespace cpp_test {

// Test fixture for XLAGenerator tests
class XLAGeneratorTest : public ::torch_xla::cpp_test::TorchXlaTest {
 protected:
  void SetUp() {
    // Create a generator for XLA device 0
    gen_ = at::make_generator<at::XLAGeneratorImpl>(0);
  }

  at::Generator gen_;
};

TEST_F(XLAGeneratorTest, Constructor) {
  // Check that the generator was created for the correct device
  ASSERT_EQ(gen_.device().type(), at::DeviceType::XLA);
  ASSERT_EQ(gen_.device().index(), 0);

  // Check that the initial seed is 0
  ASSERT_EQ(gen_.current_seed(), 0);
}

TEST_F(XLAGeneratorTest, Seed) {
  // Test setting and getting the current seed
  uint64_t seed_val = 12345;
  gen_.set_current_seed(seed_val);
  ASSERT_EQ(gen_.current_seed(), seed_val);

  // Test the seed() method, which should set a non-deterministic seed
  uint64_t old_seed = gen_.current_seed();
  uint64_t new_seed = gen_.seed();
  // The new seed should be different from the old one and set as the current
  // seed
  ASSERT_NE(new_seed, old_seed);
  ASSERT_EQ(gen_.current_seed(), new_seed);
}

TEST_F(XLAGeneratorTest, GetAndSetState) {
  uint64_t seed_val = 98765;
  uint64_t offset_val = 0;

  // Set seed and offset on the original generator
  gen_.set_current_seed(seed_val);
  gen_.set_offset(offset_val);

  // Get the state from the original generator
  at::Tensor state_tensor = gen_.get_state();

  // Create a new generator
  auto new_gen = at::make_generator<at::XLAGeneratorImpl>(1);
  ASSERT_NE(new_gen.current_seed(), seed_val);

  // Set the state of the new generator
  new_gen.set_state(state_tensor);

  // Verify the state of the new generator
  ASSERT_EQ(new_gen.current_seed(), seed_val);
  ASSERT_EQ(new_gen.get_offset(), offset_val);
}

TEST_F(XLAGeneratorTest, SetStateValidation) {
  // Test that set_state throws with incorrect tensor properties
  auto new_gen = at::make_generator<at::XLAGeneratorImpl>(0);

  // Incorrect size
  auto wrong_size_tensor = at::empty({10}, at::kByte);
  EXPECT_THROW(new_gen.set_state(wrong_size_tensor), c10::Error);

  // Incorrect dtype
  auto wrong_dtype_tensor = at::empty({16}, at::kInt);
  EXPECT_THROW(new_gen.set_state(wrong_dtype_tensor), c10::Error);
}

TEST_F(XLAGeneratorTest, Clone) {
  uint64_t seed_val = 1;
  uint64_t offset_val = 0;

  // Set state on the original generator
  gen_.set_current_seed(seed_val);
  gen_.set_offset(offset_val);

  // Clone the generator
  auto cloned_gen = gen_.clone();

  // Verify that the cloned generator has the same state but is a different
  // object
  ASSERT_NE(std::addressof(cloned_gen), std::addressof(gen_));
  ASSERT_EQ(cloned_gen.device(), gen_.device());
  ASSERT_EQ(cloned_gen.current_seed(), gen_.current_seed());
  ASSERT_EQ(cloned_gen.get_offset(), offset_val);

  // Modify the original generator's seed and check that the clone is unaffected
  gen_.set_current_seed(9999);
  ASSERT_EQ(cloned_gen.current_seed(), seed_val);
  ASSERT_NE(cloned_gen.current_seed(), gen_.current_seed());
}

}  // namespace cpp_test
}  // namespace torch_xla

torch_xla/csrc/BUILD

Lines changed: 2 additions & 0 deletions
@@ -64,6 +64,7 @@ ptxla_cc_library(
         "torch_util.cpp",
         "view.cpp",
         "xla_backend_impl.cpp",
+        "xla_generator.cpp",
         "xla_graph_executor.cpp",
         "xla_lower_util.cpp",
         "xla_op_builder.cpp",
@@ -107,6 +108,7 @@ ptxla_cc_library(
         "torch_util.h",
         "view.h",
         "xla_backend_impl.h",
+        "xla_generator.h",
         "xla_graph_executor.h",
         "xla_lower_util.h",
         "xla_op_builder.h",

torch_xla/csrc/xla_generator.cpp

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
#include "xla_generator.h"

#include <ATen/Functions.h>
#include <ATen/core/ScalarType.h>
#include <ATen/core/Tensor.h>
#include <c10/core/Device.h>
#include <c10/core/DeviceType.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/intrusive_ptr.h>

#include <cstring>

namespace at {

XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index)
    : c10::GeneratorImpl{Device(DeviceType::XLA, device_index),
                         DispatchKeySet(c10::DispatchKey::XLA)} {
  state_ = c10::make_intrusive<XLAGeneratorState>();
}

XLAGeneratorImpl::XLAGeneratorImpl(DeviceIndex device_index,
                                   c10::intrusive_ptr<XLAGeneratorState> state)
    : c10::GeneratorImpl{Device(DeviceType::XLA, device_index),
                         DispatchKeySet(c10::DispatchKey::XLA)},
      state_(std::move(state)) {}

DeviceType XLAGeneratorImpl::device_type() { return DeviceType::XLA; }

std::shared_ptr<XLAGeneratorImpl> XLAGeneratorImpl::clone() const {
  return std::shared_ptr<XLAGeneratorImpl>(clone_impl());
}

XLAGeneratorImpl* XLAGeneratorImpl::clone_impl() const {
  return new XLAGeneratorImpl(device_.index(), state_->clone());
}

void XLAGeneratorImpl::set_current_seed(uint64_t seed) { state_->seed_ = seed; }

uint64_t XLAGeneratorImpl::current_seed() const { return state_->seed_; }

uint64_t XLAGeneratorImpl::seed() {
  uint64_t random = c10::detail::getNonDeterministicRandom(true);
  set_current_seed(random);
  return random;
}

void XLAGeneratorImpl::set_offset(uint64_t offset) { state_->offset_ = offset; }

uint64_t XLAGeneratorImpl::get_offset() const { return state_->offset_; }

/* Serialize the generator state into a CPU tensor. */
c10::intrusive_ptr<c10::TensorImpl> XLAGeneratorImpl::get_state() const {
  static const size_t seed_size = sizeof(uint64_t);
  static const size_t offset_size = sizeof(uint64_t);
  static const size_t total_size = seed_size + offset_size;

  auto state_tensor =
      at::empty({(int64_t)total_size},
                at::TensorOptions().dtype(at::kByte).device(at::kCPU));
  uint8_t* data_ptr = state_tensor.data_ptr<uint8_t>();
  memcpy(data_ptr, &state_->seed_, seed_size);
  memcpy(data_ptr + seed_size, &state_->offset_, offset_size);
  return state_tensor.getIntrusivePtr();
}

void XLAGeneratorImpl::set_state(const c10::TensorImpl& new_state) {
  static const size_t seed_size = sizeof(uint64_t);
  static const size_t offset_size = sizeof(uint64_t);
  static const size_t total_size = seed_size + offset_size;

  TORCH_CHECK(new_state.numel() == total_size,
              "The given state must be a byte tensor of size ", total_size,
              ", but was size ", new_state.numel());
  TORCH_CHECK(new_state.dtype() == at::kByte,
              "The given state must be a byte tensor, but was ",
              new_state.dtype());
  TORCH_CHECK(new_state.is_cpu(), "The given state must be a CPU tensor");

  auto new_rng_state = new_state.data_dtype_initialized<uint8_t>();
  memcpy(&state_->seed_, new_rng_state, seed_size);
  memcpy(&state_->offset_, new_rng_state + seed_size, offset_size);
}

}  // namespace at
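For reference, the state layout that get_state() produces and set_state() expects is a 16-byte CPU byte tensor: bytes [0, 8) hold the seed and bytes [8, 16) hold the offset, both copied with memcpy in host byte order. The helper below shows one way to read such a tensor back; DecodeXlaGeneratorState is a hypothetical name for illustration, not something this commit defines.

```cpp
#include <cstdint>
#include <cstring>
#include <utility>

#include <torch/torch.h>

// Hypothetical helper: decode the 16-byte state tensor produced by
// XLAGeneratorImpl::get_state(). Bytes [0, 8) are the seed, bytes [8, 16)
// are the offset, both in host byte order.
std::pair<uint64_t, uint64_t> DecodeXlaGeneratorState(const at::Tensor& state) {
  TORCH_CHECK(state.numel() == 16, "expected a 16-byte state tensor");
  TORCH_CHECK(state.dtype() == at::kByte && state.is_cpu(),
              "expected a CPU byte tensor");
  uint64_t seed = 0;
  uint64_t offset = 0;
  const uint8_t* data = state.data_ptr<uint8_t>();
  std::memcpy(&seed, data, sizeof(seed));
  std::memcpy(&offset, data + sizeof(seed), sizeof(offset));
  return {seed, offset};
}
```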

torch_xla/csrc/xla_generator.h

Lines changed: 56 additions & 0 deletions
@@ -0,0 +1,56 @@
#pragma once

#include <ATen/core/Generator.h>
#include <ATen/core/Tensor.h>
#include <c10/util/intrusive_ptr.h>

#include <cstdint>

namespace at {

// Holds the actual state variables for the XLA generator.
struct XLAGeneratorState : c10::intrusive_ptr_target {
  uint64_t seed_ = 0;
  uint64_t offset_ = 0;

  // Constructor
  XLAGeneratorState(uint64_t seed = 0, uint64_t offset = 0)
      : seed_(seed), offset_(offset) {}

  // Cloning method
  c10::intrusive_ptr<XLAGeneratorState> clone() {
    return c10::make_intrusive<XLAGeneratorState>(seed_, offset_);
  }
};

struct TORCH_API XLAGeneratorImpl : public c10::GeneratorImpl {
  // Constructors
  XLAGeneratorImpl(DeviceIndex device_index = -1);
  XLAGeneratorImpl(DeviceIndex device_index,
                   c10::intrusive_ptr<XLAGeneratorState> state);
  ~XLAGeneratorImpl() override = default;

  // Cloning support
  std::shared_ptr<XLAGeneratorImpl> clone() const;

  // --- Core Virtual Methods to Override ---
  void set_current_seed(uint64_t seed) override;
  uint64_t current_seed() const override;
  uint64_t seed() override;
  void set_offset(uint64_t offset) override;
  uint64_t get_offset() const override;
  c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
  void set_state(const c10::TensorImpl& new_state) override;

  // --- Additional Methods ---
  static c10::DeviceType device_type();

 private:
  // Private clone implementation
  XLAGeneratorImpl* clone_impl() const override;

  // The actual state is held in a separate, cloneable object.
  c10::intrusive_ptr<XLAGeneratorState> state_;
};

}  // namespace at
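The header splits the generator into a thin XLAGeneratorImpl and a reference-counted XLAGeneratorState, so cloning only copies the two integers and a pre-built state can be handed to the two-argument constructor. A small sketch of that constructor in use follows; MakeSeededXlaGenerator is a hypothetical helper for illustration, not something this commit defines.

```cpp
#include <cstdint>
#include <utility>

#include <ATen/core/Generator.h>
#include <c10/util/intrusive_ptr.h>

#include "torch_xla/csrc/xla_generator.h"

// Hypothetical helper: create an XLA generator that starts from a known
// seed/offset by constructing the state object directly and passing it to
// the two-argument XLAGeneratorImpl constructor.
at::Generator MakeSeededXlaGenerator(c10::DeviceIndex device_index,
                                     uint64_t seed, uint64_t offset) {
  auto state = c10::make_intrusive<at::XLAGeneratorState>(seed, offset);
  return at::make_generator<at::XLAGeneratorImpl>(device_index,
                                                  std::move(state));
}
```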
