diff --git a/R/dtype.R b/R/dtype.R index b2b4427906..e5d9ca7ad2 100644 --- a/R/dtype.R +++ b/R/dtype.R @@ -14,7 +14,7 @@ torch_dtype <- R7Class( ), active = list( is_floating_point = function() { - if (cpp_dtype_to_string(self$ptr) %in% c("Float", "Double", "Half")) { + if (cpp_dtype_to_string(self$ptr) %in% c("Float", "Double", "Half", "BFloat16")) { TRUE } else { FALSE diff --git a/tests/testthat/test-dtype.R b/tests/testthat/test-dtype.R index f4ed8d35fa..32abd8a8fa 100644 --- a/tests/testthat/test-dtype.R +++ b/tests/testthat/test-dtype.R @@ -70,10 +70,34 @@ test_that("can set select devices using strings", { }) test_that("error when comparing dtypes", { - + expect_error( NULL == torch_float64(), "not a dtype" ) - + +}) + +test_that("is_floating_point works for all floating point types", { + # All standard floating point types should return TRUE + + expect_true(torch_float()$is_floating_point) + expect_true(torch_float32()$is_floating_point) + expect_true(torch_float64()$is_floating_point) + expect_true(torch_double()$is_floating_point) + expect_true(torch_float16()$is_floating_point) + expect_true(torch_half()$is_floating_point) + + expect_true(torch_bfloat16()$is_floating_point) + + # Non-floating point types should return FALSE + expect_false(torch_int()$is_floating_point) + expect_false(torch_int32()$is_floating_point) + expect_false(torch_int64()$is_floating_point) + expect_false(torch_long()$is_floating_point) + expect_false(torch_int16()$is_floating_point) + expect_false(torch_short()$is_floating_point) + expect_false(torch_int8()$is_floating_point) + expect_false(torch_uint8()$is_floating_point) + expect_false(torch_bool()$is_floating_point) }) \ No newline at end of file diff --git a/tests/testthat/test-nn.R b/tests/testthat/test-nn.R index 50592b69ef..0dec01b1aa 100644 --- a/tests/testthat/test-nn.R +++ b/tests/testthat/test-nn.R @@ -176,6 +176,14 @@ test_that("to", { expect_equal(net$linear$bias$device$type, "cpu") }) +test_that("module$to works with bfloat16", { + net <- nn_linear(10, 10) + net$to(dtype = torch_bfloat16()) + + expect_true(net$weight$dtype == torch_bfloat16()) + expect_true(net$bias$dtype == torch_bfloat16()) +}) + test_that("state_dict for modules", { Net <- nn_module( initialize = function() {