From 8241faf1383b35aff1a3476f90911cd042b4eae7 Mon Sep 17 00:00:00 2001
From: Daniel Falbel
Date: Mon, 2 Feb 2026 16:04:34 -0300
Subject: [PATCH 1/2] Fix bfloat16 not recognized as floating point dtype

Add "BFloat16" to the is_floating_point check in torch_dtype class.
This fixes nn_module$to() failing when called with torch_bfloat16().

Also adds test coverage for is_floating_point across all dtype variants.

Co-Authored-By: Claude Opus 4.5
---
 R/dtype.R                   |  2 +-
 tests/testthat/test-dtype.R | 28 ++++++++++++++++++++++++++--
 2 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/R/dtype.R b/R/dtype.R
index b2b4427906..e5d9ca7ad2 100644
--- a/R/dtype.R
+++ b/R/dtype.R
@@ -14,7 +14,7 @@ torch_dtype <- R7Class(
   ),
   active = list(
     is_floating_point = function() {
-      if (cpp_dtype_to_string(self$ptr) %in% c("Float", "Double", "Half")) {
+      if (cpp_dtype_to_string(self$ptr) %in% c("Float", "Double", "Half", "BFloat16")) {
         TRUE
       } else {
         FALSE

diff --git a/tests/testthat/test-dtype.R b/tests/testthat/test-dtype.R
index f4ed8d35fa..32abd8a8fa 100644
--- a/tests/testthat/test-dtype.R
+++ b/tests/testthat/test-dtype.R
@@ -70,10 +70,34 @@ test_that("can set select devices using strings", {
 })
 
 test_that("error when comparing dtypes", {
-  
+
   expect_error(
     NULL == torch_float64(),
     "not a dtype"
   )
-  
+
+})
+
+test_that("is_floating_point works for all floating point types", {
+  # All standard floating point types should return TRUE
+
+  expect_true(torch_float()$is_floating_point)
+  expect_true(torch_float32()$is_floating_point)
+  expect_true(torch_float64()$is_floating_point)
+  expect_true(torch_double()$is_floating_point)
+  expect_true(torch_float16()$is_floating_point)
+  expect_true(torch_half()$is_floating_point)
+
+  expect_true(torch_bfloat16()$is_floating_point)
+
+  # Non-floating point types should return FALSE
+  expect_false(torch_int()$is_floating_point)
+  expect_false(torch_int32()$is_floating_point)
+  expect_false(torch_int64()$is_floating_point)
+  expect_false(torch_long()$is_floating_point)
+  expect_false(torch_int16()$is_floating_point)
+  expect_false(torch_short()$is_floating_point)
+  expect_false(torch_int8()$is_floating_point)
+  expect_false(torch_uint8()$is_floating_point)
+  expect_false(torch_bool()$is_floating_point)
 })
\ No newline at end of file

From 2f54fbbf1ceb5718a28bb118e337b9c6f69c98f0 Mon Sep 17 00:00:00 2001
From: Daniel Falbel
Date: Mon, 2 Feb 2026 17:18:43 -0300
Subject: [PATCH 2/2] Add test for module$to with bfloat16 dtype

Co-Authored-By: Claude Opus 4.5
---
 tests/testthat/test-nn.R | 8 ++++++++
 1 file changed, 8 insertions(+)

diff --git a/tests/testthat/test-nn.R b/tests/testthat/test-nn.R
index 50592b69ef..0dec01b1aa 100644
--- a/tests/testthat/test-nn.R
+++ b/tests/testthat/test-nn.R
@@ -176,6 +176,14 @@ test_that("to", {
   expect_equal(net$linear$bias$device$type, "cpu")
 })
 
+test_that("module$to works with bfloat16", {
+  net <- nn_linear(10, 10)
+  net$to(dtype = torch_bfloat16())
+
+  expect_true(net$weight$dtype == torch_bfloat16())
+  expect_true(net$bias$dtype == torch_bfloat16())
+})
+
 test_that("state_dict for modules", {
   Net <- nn_module(
     initialize = function() {
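
A minimal sketch of the end-to-end behavior this series fixes, assuming a
torch build with PATCH 1/2 applied (the output comments are illustrative,
not captured session logs):

    library(torch)

    # With "BFloat16" added to the check in R/dtype.R, the active
    # binding now reports bfloat16 as a floating point dtype;
    # before the fix this returned FALSE.
    torch_bfloat16()$is_floating_point
    #> [1] TRUE

    # As a consequence, nn_module$to() accepts the dtype and casts
    # the module's parameters, which is what the new test added to
    # tests/testthat/test-nn.R in PATCH 2/2 asserts.
    net <- nn_linear(10, 10)
    net$to(dtype = torch_bfloat16())
    net$weight$dtype == torch_bfloat16()
    #> [1] TRUE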