Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ggml/src/ggml-cuda/cpy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
char * src1_ddc = (char *) src1->data;

const bool contiguous_srcs = ggml_is_contiguous(src0) && ggml_is_contiguous(src1);
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) && src0->ne[3] == 1;
const bool can_be_transposed = nb01 == (int64_t)ggml_element_size(src0) &&
src0->ne[3] == 1 && nb02 == ne00 * ne01 * (int64_t)ggml_element_size(src0);

if (src0->type == src1->type && contiguous_srcs) {
GGML_ASSERT(ggml_nbytes(src0) == ggml_nbytes(src1));
Expand Down
32 changes: 32 additions & 0 deletions tests/test-backend-ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2800,6 +2800,33 @@ struct test_cont : public test_case {
}
};

struct test_irregular_cont : public test_case {
const ggml_type type;
const std::array<int64_t, 4> ne;

std::string vars() override {
return VARS_TO_STR2(type, ne);
}

test_irregular_cont(ggml_type type = GGML_TYPE_F32,
std::array<int64_t, 4> ne = {1, 4, 2, 1})
: type(type), ne(ne) {}

ggml_tensor * build_graph(ggml_context * ctx) override {
ggml_tensor * src = ggml_new_tensor(ctx, type, 4, ne.data());
ggml_set_param(src);
ggml_set_name(src, "src");

ggml_tensor * view = ggml_view_4d(ctx, src, src->ne[0], 1, src->ne[2], src->ne[3],
src->nb[1], src->nb[2], src->nb[3], src->nb[0] * (src->ne[1] - 1));

ggml_tensor * out = ggml_cont(ctx, view);
ggml_set_name(out, "out");

return out;
}
};

// GGML_OP_ADD
// GGML_OP_SUB
// GGML_OP_MUL
Expand Down Expand Up @@ -6956,6 +6983,11 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 1, 3 ,5}));
test_cases.emplace_back(new test_cont(GGML_TYPE_BF16, {2, 3, 5 ,7}));

test_cases.emplace_back(new test_irregular_cont());
test_cases.emplace_back(new test_irregular_cont(GGML_TYPE_F32, {1, 8, 17, 1}));
test_cases.emplace_back(new test_irregular_cont(GGML_TYPE_BF16, {1, 4, 2, 1}));
test_cases.emplace_back(new test_irregular_cont(GGML_TYPE_BF16, {1, 8, 17, 1}));

auto add_test_bin_bcast = [&](ggml_type type, std::array<int64_t, 4> ne, std::array<int, 4> nr) {
for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) {
test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr));
Expand Down
Loading