|
9 | 9 | import typing |
10 | 10 | from abc import ABC, abstractmethod |
11 | 11 | from enum import Enum |
12 | | -from typing import Any, Dict, List, Optional, Set |
| 12 | +from typing import Any, Dict, List, Set |
13 | 13 |
|
14 | 14 | import torch |
15 | 15 | from executorch.backends.aoti.passes.replace_view_copy_with_view import ( |
@@ -91,39 +91,24 @@ def collect_unsupported_fallback_kernels(cls, missing_fallback_kernels: Set[str] |
91 | 91 | ) |
92 | 92 |
|
93 | 93 | def generate_c_shim_extern_kernel_call_and_collect_unsupported_kernels( |
94 | | - self, |
95 | | - kernel: str, |
96 | | - args: list[str], |
97 | | - device: str, |
98 | | - *, |
99 | | - debug_args: Optional[list[str]] = None, |
100 | | - debug_handle: Optional[int] = None, |
101 | | - ): |
| 94 | + self, kernel: str, *args: Any, **kwargs: Any |
| 95 | + ) -> None: |
102 | 96 | if kernel not in supported_kernels: |
103 | 97 | missing_fallback_kernels.add(kernel) |
104 | 98 |
|
105 | | - original_generate_c_shim_extern_kernel_call( |
106 | | - self, |
107 | | - kernel, |
108 | | - args, |
109 | | - device, |
110 | | - debug_args=debug_args, |
111 | | - debug_handle=debug_handle, |
| 99 | + return original_generate_c_shim_extern_kernel_call( |
| 100 | + self, kernel, *args, **kwargs |
112 | 101 | ) |
113 | 102 |
|
114 | 103 | def generate_fallback_kernel_with_runtime_lookup_aot_and_collect_unsupported_kernels( |
115 | | - self, |
116 | | - op_overload, |
117 | | - raw_args, |
118 | | - output_args, |
119 | | - raw_outputs, |
120 | | - ): |
| 104 | + self, op_overload: Any, *args: Any, **kwargs: Any |
| 105 | + ) -> None: |
121 | 106 | kernel_name = getattr(op_overload, "_name", str(op_overload)) |
122 | 107 | if kernel_name not in supported_kernels: |
123 | 108 | missing_fallback_kernels.add(kernel_name) |
124 | 109 |
|
125 | | - original_generate_fallback_kernel_with_runtime_lookup_aot( |
126 | | - self, op_overload, raw_args, output_args, raw_outputs |
| 110 | + return original_generate_fallback_kernel_with_runtime_lookup_aot( |
| 111 | + self, op_overload, *args, **kwargs |
127 | 112 | ) |
128 | 113 |
|
129 | 114 | CppWrapperCpu.generate_c_shim_extern_kernel_call = ( |
|
0 commit comments