diff --git a/extension/training/module/test/training_module_test.cpp b/extension/training/module/test/training_module_test.cpp index 16ff87bc022..29d9bcf5842 100644 --- a/extension/training/module/test/training_module_test.cpp +++ b/extension/training/module/test/training_module_test.cpp @@ -199,3 +199,58 @@ TEST_F(TrainingModuleTest, DataExternalConstantsTest) { ASSERT_EQ(attributes.find("b")->second.sizes()[0], 2); ASSERT_EQ(attributes.find("b")->second.dim(), 2); } + +TEST_F(TrainingModuleTest, UnloadMethodTest) { + const char* ptd_path = std::getenv("ET_MODULE_TRAIN_DATA_PATH"); + Result data_map_loader_res = FileDataLoader::from(ptd_path); + ASSERT_EQ(data_map_loader_res.error(), Error::Ok); + + auto data_map_loader = + std::make_unique( + std::move(data_map_loader_res.get())); + + const char* pte_path = std::getenv("ET_MODULE_TRAIN_PROGRAM_PATH"); + Result pte_loader_res = FileDataLoader::from(pte_path); + ASSERT_EQ(pte_loader_res.error(), Error::Ok); + + auto pte_loader = std::make_unique( + std::move(pte_loader_res.get())); + + auto mod = executorch::extension::training::TrainingModule( + std::move(pte_loader), + nullptr, + nullptr, + nullptr, + std::move(data_map_loader)); + + auto parameters_res = mod.named_parameters("forward"); + ASSERT_EQ(parameters_res.error(), Error::Ok); + auto& parameters = parameters_res.get(); + + ASSERT_NEAR( + parameters_res.get() + .find("linear.bias") + ->second.const_data_ptr()[0], + 0.1528, + 0.0001); + + // mock training + auto linear_bias_ptr = + parameters.find("linear.bias")->second.mutable_data_ptr(); + linear_bias_ptr[0] += 0.5; + ASSERT_NEAR( + parameters.find("linear.bias")->second.const_data_ptr()[0], + 0.6528, + 0.0001); + + mod.unload_method("forward"); + + auto new_parameters_res = mod.named_parameters("forward"); + ASSERT_EQ(new_parameters_res.error(), Error::Ok); + ASSERT_NEAR( + new_parameters_res.get() + .find("linear.bias") + ->second.const_data_ptr()[0], + 0.1528, + 0.0001); +} diff --git a/extension/training/module/training_module.h b/extension/training/module/training_module.h index 7dd380d2709..146eb61bcb7 100644 --- a/extension/training/module/training_module.h +++ b/extension/training/module/training_module.h @@ -49,6 +49,15 @@ class ET_EXPERIMENTAL TrainingModule final explicit TrainingModule(Module&&) = delete; TrainingModule& operator=(Module&&) = delete; + // Redefine to erase the tensors pointing to the released memory. + inline bool unload_method(const std::string& method_name) { + method_named_gradients_.erase(method_name); + method_named_parameters_.erase(method_name); + method_named_attributes_.erase(method_name); + + return methods_.erase(method_name); + } + /** * Execute a specific method with the given input and retrieve output. Only * valid if the specified method is a joint graph. Loads the program and