Skip to content

Commit d7be303

Browse files
POI-WXyangbofun
authored andcommitted
[Ascend] Wx/fix dtype cast bug of adamw op (#1260)
* fix dtype cast bug of adamw op
1 parent 2c4e94a commit d7be303

File tree

3 files changed

+7
-16
lines changed

3 files changed

+7
-16
lines changed

adaptor/codegen/gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@
6565
"IndexPut": ["out"],
6666
"Adadelta": ["input", "grad", "square_avg", "acc_delta"],
6767
"IndexBackward": ["zeros_like_input"],
68+
"AdamW": ["param", "exp_avg", "exp_avg_sq", "max_exp_avg_sq"],
6869
}
6970

7071

impl/ascend/convert_config.yaml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
- common_config:
22
layout: NCHW
33

4+
- diopiAdamW:
5+
dtype: (float64)->float32
6+
layout: ND
7+
48
- diopiSoftmax:
59
layout: ND
610

@@ -447,8 +451,8 @@
447451

448452
- diopiMaxPool2dWithIndices:
449453
tensor_dtype:
450-
indices: (int64)->int32
454+
indices: (int64)->int32
451455

452456
- diopiMaxPool2dBackward:
453457
tensor_dtype:
454-
indices: (int64)->int32
458+
indices: (int64)->int32

impl/ascend/device_configs.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1435,20 +1435,6 @@
14351435
skip_all = True
14361436
),
14371437

1438-
'adam': dict(
1439-
name=['adamw'],
1440-
tensor_para=dict(
1441-
args=[
1442-
{
1443-
"ins": ['param'],
1444-
# float64 not supported yet on ascend
1445-
# temporarily skip all test cases due to software stack version
1446-
"dtype": [Skip(np.float16), Skip(np.float32), Skip(np.float64)],
1447-
},
1448-
]
1449-
),
1450-
),
1451-
14521438
# temporarily skip all test cases for flash_attention_varlen due to the version of software stack on ascend
14531439
'flash_attention_varlen': dict(
14541440
name=['flash_attention_varlen'],

0 commit comments

Comments
 (0)