|
1 | 1 | import os |
2 | 2 | import shutil |
3 | 3 | import unittest |
4 | | -from pathlib import Path |
5 | | - |
6 | 4 | import numpy as np |
7 | 5 | import torch |
8 | 6 |
|
@@ -224,6 +222,118 @@ def test_vision_reference(self): |
224 | 222 | measure_steps=1, |
225 | 223 | ) |
226 | 224 |
|
| 225 | + def test_parameter_manager_onehot_generic(self): |
| 226 | + test_configs = [ |
| 227 | + { |
| 228 | + 'supernet': 'ofa_mbv3_d234_e346_k357_w1.2', |
| 229 | + 'pymoo_vector': [ |
| 230 | + 1, 2, 2, 2, 2, 2, 1, 2, 0, 2, 1, 1, 0, 0, 1, |
| 231 | + 1, 2, 2, 1, 0, 1, 1, 2, 1, 0, 1, 0, 2, 2, 2, |
| 232 | + 0, 0, 2, 2, 2, 2, 1, 1, 2, 1, 2, 0, 2, 0, 0, |
| 233 | + ], |
| 234 | + 'onehot_vector_expected': [ |
| 235 | + 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, |
| 236 | + 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, |
| 237 | + 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 0, 1, |
| 238 | + 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, |
| 239 | + 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, |
| 240 | + 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, |
| 241 | + 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 1, 0, 0, |
| 242 | + 0, 0, 1, 1, 0, 0, 1, 0, 0, |
| 243 | + ], |
| 244 | + }, |
| 245 | + { |
| 246 | + 'supernet': 'transformer_lt_wmt_en_de', |
| 247 | + 'pymoo_vector': [ |
| 248 | + 0, 0, 2, 1, 0, 0, 0, 2, 0, 2, 2, 2, 0, 2, 3, 0, 0, 0, 0, 0, |
| 249 | + 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1], |
| 250 | + 'onehot_vector_expected': [ |
| 251 | + 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 0, |
| 252 | + 0, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, |
| 253 | + 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, |
| 254 | + 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, |
| 255 | + 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, |
| 256 | + ], |
| 257 | + }, |
| 258 | + ] |
| 259 | + for test_config in test_configs: |
| 260 | + nas_agent = NAS('dynas_fake.yaml') |
| 261 | + search_algorithm, supernet = 'nsga2', test_config['supernet'] |
| 262 | + config = NASConfig(approach='dynas', search_algorithm=search_algorithm) |
| 263 | + config.dynas.supernet = supernet |
| 264 | + nas_agent = NAS(config) |
| 265 | + nas_agent.init_for_search() |
| 266 | + |
| 267 | + onehot_vector = nas_agent.supernet_manager.onehot_generic(in_array=test_config['pymoo_vector']) |
| 268 | + self.assertListEqual(list(onehot_vector), test_config['onehot_vector_expected']) |
| 269 | + |
| 270 | + def test_parameter_manager_translate2param(self): |
| 271 | + test_configs = [ |
| 272 | + { |
| 273 | + 'supernet': 'ofa_mbv3_d234_e346_k357_w1.2', |
| 274 | + 'pymoo_vector': [ |
| 275 | + 1, 2, 2, 2, 2, 2, 1, 2, 0, 2, 1, 1, 0, 0, 1, |
| 276 | + 1, 2, 2, 1, 0, 1, 1, 2, 1, 0, 1, 0, 2, 2, 2, |
| 277 | + 0, 0, 2, 2, 2, 2, 1, 1, 2, 1, 2, 0, 2, 0, 0, |
| 278 | + ], |
| 279 | + 'param_dict_expected': { |
| 280 | + 'd': [4, 2, 4, 2, 2], |
| 281 | + 'e': [4, 4, 6, 4, 3, 4, 3, 6, 6, 6, 3, 3, 6, 6, 6, 6, 4, 4, 6, 4], |
| 282 | + 'ks': [5, 7, 7, 7, 7, 7, 5, 7, 3, 7, 5, 5, 3, 3, 5, 5, 7, 7, 5, 3], |
| 283 | + }, |
| 284 | + }, |
| 285 | + { |
| 286 | + 'supernet': 'transformer_lt_wmt_en_de', |
| 287 | + 'pymoo_vector': [0, 0, 2, 1, 0, 0, 0, 2, 0, 2, 2, 2, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1], |
| 288 | + 'param_dict_expected': {'encoder_embed_dim': [640], 'decoder_embed_dim': [640], 'encoder_ffn_embed_dim': [1024, 2048, 3072, 3072, 3072, 1024], 'decoder_ffn_embed_dim': [3072, 1024, 1024, 1024, 3072, 1024], 'decoder_layer_num': [3], 'encoder_self_attention_heads': [8, 8, 8, 8, 8, 8], 'decoder_self_attention_heads': [8, 4, 8, 4, 4, 8], 'decoder_ende_attention_heads': [8, 4, 4, 4, 8, 8], 'decoder_arbitrary_ende_attn': [1, -1, -1, 1, -1, 1]}, |
| 289 | + } |
| 290 | + ] |
| 291 | + |
| 292 | + for test_config in test_configs: |
| 293 | + nas_agent = NAS('dynas_fake.yaml') |
| 294 | + search_algorithm, supernet = 'nsga2', test_config['supernet'] |
| 295 | + config = NASConfig(approach='dynas', search_algorithm=search_algorithm) |
| 296 | + config.dynas.supernet = supernet |
| 297 | + nas_agent = NAS(config) |
| 298 | + nas_agent.init_for_search() |
| 299 | + |
| 300 | + param_dict = nas_agent.supernet_manager.translate2param(test_config['pymoo_vector']) |
| 301 | + |
| 302 | + self.assertDictEqual(param_dict, test_config['param_dict_expected']) |
| 303 | + |
| 304 | + |
| 305 | + def test_parameter_manager_translate2pymoo(self): |
| 306 | + test_configs = [ |
| 307 | + { |
| 308 | + 'supernet': 'ofa_mbv3_d234_e346_k357_w1.2', |
| 309 | + 'param_dict': { |
| 310 | + 'd': [4, 2, 4, 2, 2], |
| 311 | + 'e': [4, 4, 6, 4, 3, 4, 3, 6, 6, 6, 3, 3, 6, 6, 6, 6, 4, 4, 6, 4], |
| 312 | + 'ks': [5, 7, 7, 7, 7, 7, 5, 7, 3, 7, 5, 5, 3, 3, 5, 5, 7, 7, 5, 3], |
| 313 | + }, |
| 314 | + 'pymoo_vector_expected': [ |
| 315 | + 1, 2, 2, 2, 2, 2, 1, 2, 0, 2, 1, 1, 0, 0, 1, |
| 316 | + 1, 2, 2, 1, 0, 1, 1, 2, 1, 0, 1, 0, 2, 2, 2, |
| 317 | + 0, 0, 2, 2, 2, 2, 1, 1, 2, 1, 2, 0, 2, 0, 0, |
| 318 | + ], |
| 319 | + }, |
| 320 | + { |
| 321 | + 'supernet': 'transformer_lt_wmt_en_de', |
| 322 | + 'param_dict': {'encoder_embed_dim': [640], 'decoder_embed_dim': [640], 'encoder_ffn_embed_dim': [1024, 2048, 3072, 3072, 3072, 1024], 'decoder_ffn_embed_dim': [3072, 1024, 1024, 1024, 3072, 1024], 'decoder_layer_num': [3], 'encoder_self_attention_heads': [8, 8, 8, 8, 8, 8], 'decoder_self_attention_heads': [8, 4, 8, 4, 4, 8], 'decoder_ende_attention_heads': [8, 4, 4, 4, 8, 8], 'decoder_arbitrary_ende_attn': [1, -1, -1, 1, -1, 1]}, |
| 323 | + 'pymoo_vector_expected': [0, 0, 2, 1, 0, 0, 0, 2, 0, 2, 2, 2, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 0, 0, 1, 0, 1], |
| 324 | + } |
| 325 | + ] |
| 326 | + for test_config in test_configs: |
| 327 | + nas_agent = NAS('dynas_fake.yaml') |
| 328 | + search_algorithm, supernet = 'nsga2', test_config['supernet'] |
| 329 | + config = NASConfig(approach='dynas', search_algorithm=search_algorithm) |
| 330 | + config.dynas.supernet = supernet |
| 331 | + nas_agent = NAS(config) |
| 332 | + nas_agent.init_for_search() |
| 333 | + |
| 334 | + pymoo_vector = nas_agent.supernet_manager.translate2pymoo(test_config['param_dict']) |
| 335 | + self.assertListEqual(pymoo_vector, test_config['pymoo_vector_expected']) |
| 336 | + |
227 | 337 |
|
228 | 338 | if __name__ == "__main__": |
229 | 339 | unittest.main() |
0 commit comments