[BUG] Each iteration produces the same sampling result #7

@gpzlx1

Description

The _SampleSubIndicesKernelFusedWithReplace kernel uses a fixed seed, so every call produces almost the same sampling result. This can lead to poor training accuracy.

template <typename IdType>
__global__ void _SampleSubIndicesKernelFusedWithReplace(IdType* sub_indices,
                                             IdType* indptr, IdType* indices,
                                             IdType* sub_indptr,
                                             IdType* column_ids, int64_t size) {
  int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
  const uint64_t random_seed = 7777777;  // There's a problem here
  curandState rng;
  curand_init(random_seed * gridDim.x + blockIdx.x, threadIdx.x, 0, &rng);
  while (row < size) {
    int64_t col = column_ids[row];
    int64_t in_start = indptr[col];
    int64_t out_start = sub_indptr[row];
    int64_t degree = indptr[col + 1] - indptr[col];
    int64_t fanout = sub_indptr[row + 1] - sub_indptr[row];
    int64_t tid = threadIdx.x;
    while (tid < fanout) {
      // Sequential Sampling
      const int64_t edge = tid % degree;
      // Random Sampling
      // const int64_t edge = curand(&rng) % degree;
      sub_indices[out_start + tid] = indices[in_start + edge];
      tid += blockDim.x;
    }
    row += gridDim.x * blockDim.y;
  }
}
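
One possible direction, as a minimal sketch rather than the project's actual fix: pass the seed in as a kernel argument so the caller can change it on every call, and enable the random pick instead of `tid % degree`. The names `SampleWithReplaceSeeded`, `to_device`, `sample_once` and `count_differences`, the toy graph, and the launch shape are all hypothetical and only serve the illustration. The sketch also assumes that the curand subsequence should be unique per thread in the block (the kernel above passes only `threadIdx.x`), and it skips columns with zero degree to avoid a modulo by zero.

```cuda
#include <cstdint>
#include <cstdio>
#include <vector>
#include <cuda_runtime.h>
#include <curand_kernel.h>

// Same loop structure as the reported kernel, but the seed is an argument
// (hypothetical change) and the random pick is enabled.
template <typename IdType>
__global__ void SampleWithReplaceSeeded(IdType* sub_indices, const IdType* indptr,
                                        const IdType* indices, const IdType* sub_indptr,
                                        const IdType* column_ids, int64_t size,
                                        uint64_t random_seed) {
  int64_t row = blockIdx.x * blockDim.y + threadIdx.y;
  curandState rng;
  // Assumption: give every thread in the block its own subsequence.
  curand_init(random_seed * gridDim.x + blockIdx.x,
              threadIdx.y * blockDim.x + threadIdx.x, 0, &rng);
  while (row < size) {
    int64_t col = column_ids[row];
    int64_t in_start = indptr[col];
    int64_t out_start = sub_indptr[row];
    int64_t degree = indptr[col + 1] - indptr[col];
    int64_t fanout = sub_indptr[row + 1] - sub_indptr[row];
    int64_t tid = threadIdx.x;
    while (degree > 0 && tid < fanout) {
      // Random sampling with replacement.
      const int64_t edge = curand(&rng) % degree;
      sub_indices[out_start + tid] = indices[in_start + edge];
      tid += blockDim.x;
    }
    row += gridDim.x * blockDim.y;
  }
}

// Hypothetical helper: copy a host vector into a fresh device buffer.
template <typename T>
T* to_device(const std::vector<T>& host) {
  T* dev = nullptr;
  cudaMalloc(&dev, host.size() * sizeof(T));
  cudaMemcpy(dev, host.data(), host.size() * sizeof(T), cudaMemcpyHostToDevice);
  return dev;
}

// Hypothetical helper: sample a toy 2-column CSC graph once with the given seed.
std::vector<int64_t> sample_once(uint64_t seed) {
  std::vector<int64_t> indptr = {0, 8, 16};
  std::vector<int64_t> indices(16);
  for (int i = 0; i < 16; ++i) indices[i] = (i < 8 ? 10 : 12) + i;  // 10..17, 20..27
  std::vector<int64_t> column_ids = {0, 1};
  std::vector<int64_t> sub_indptr = {0, 32, 64};  // fanout of 32 per sampled column
  std::vector<int64_t> out(64, -1);

  int64_t* d_indptr = to_device(indptr);
  int64_t* d_indices = to_device(indices);
  int64_t* d_cols = to_device(column_ids);
  int64_t* d_sub_indptr = to_device(sub_indptr);
  int64_t* d_out = to_device(out);

  dim3 block(16, 2);  // 16 threads per row, 2 rows per block
  SampleWithReplaceSeeded<int64_t><<<1, block>>>(d_out, d_indptr, d_indices,
                                                 d_sub_indptr, d_cols, 2, seed);
  // The blocking copy also waits for the kernel to finish.
  cudaMemcpy(out.data(), d_out, out.size() * sizeof(int64_t), cudaMemcpyDeviceToHost);

  cudaFree(d_indptr); cudaFree(d_indices); cudaFree(d_cols);
  cudaFree(d_sub_indptr); cudaFree(d_out);
  return out;
}

// Number of output slots that differ between two calls.
int count_differences(uint64_t seed_a, uint64_t seed_b) {
  std::vector<int64_t> a = sample_once(seed_a);
  std::vector<int64_t> b = sample_once(seed_b);
  int diff = 0;
  for (size_t i = 0; i < a.size(); ++i) diff += (a[i] != b[i]);
  return diff;
}

int main() {
  // Reusing a seed reproduces the same samples; changing it should not.
  printf("same seed:      %d differing slots\n", count_differences(7777777, 7777777));
  printf("different seed: %d differing slots\n", count_differences(7777777, 7777778));
  return 0;
}
```

With this shape, the caller decides where the per-call seed comes from, for example a host-side random number generator or an iteration counter. Which source fits best depends on whether runs need to be reproducible.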
