@@ -14,6 +14,7 @@ class DecoratedFunction:
1414 source : str
1515 lint_mode : bool
1616 lint_user_whitelist : list [str ]
17+ subdirectories : list [str ] | None = None
1718 filepath : Path | None = None
1819 parameters : list [tuple [str , str | None ]] = dataclasses .field (default_factory = list )
1920 arguments_type_schema : dict | None = None
@@ -83,6 +84,20 @@ class CodegenFunctionVisitor(ast.NodeVisitor):
8384 def __init__ (self ):
8485 self .functions : list [DecoratedFunction ] = []
8586
87+ def get_function_name (self , node : ast .Call ) -> str :
88+ keywords = {k .arg : k .value for k in node .keywords }
89+ if "name" in keywords :
90+ return ast .literal_eval (keywords ["name" ])
91+ return ast .literal_eval (node .args [0 ])
92+
93+ def get_subdirectories (self , node : ast .Call ) -> list [str ] | None :
94+ keywords = {k .arg : k .value for k in node .keywords }
95+ if "subdirectories" in keywords :
96+ return ast .literal_eval (keywords ["subdirectories" ])
97+ if len (node .args ) > 1 :
98+ return ast .literal_eval (node .args [1 ])
99+ return None
100+
86101 def get_function_body (self , node : ast .FunctionDef ) -> str :
87102 """Extract and unindent the function body."""
88103 # Get the start and end positions of the function body
@@ -178,7 +193,7 @@ def visit_FunctionDef(self, node):
178193 for decorator in node .decorator_list :
179194 if (
180195 isinstance (decorator , ast .Call )
181- and len (decorator .args ) >= 1
196+ and ( len (decorator .args ) > 0 or len ( decorator . keywords ) > 0 )
182197 and (
183198 # Check if it's a direct codegen.X call
184199 (isinstance (decorator .func , ast .Attribute ) and isinstance (decorator .func .value , ast .Name ) and decorator .func .value .id == "codegen" )
@@ -188,7 +203,8 @@ def visit_FunctionDef(self, node):
188203 )
189204 ):
190205 # Get the function name from the decorator argument
191- func_name = ast .literal_eval (decorator .args [0 ])
206+ func_name = self .get_function_name (decorator )
207+ subdirectories = self .get_subdirectories (decorator )
192208
193209 # Get additional metadata for webhook
194210 lint_mode = decorator .func .attr == "webhook"
@@ -201,7 +217,16 @@ def visit_FunctionDef(self, node):
201217 # Get just the function body, unindented
202218 body_source = self .get_function_body (node )
203219 parameters = self .get_function_parameters (node )
204- self .functions .append (DecoratedFunction (name = func_name , source = body_source , lint_mode = lint_mode , lint_user_whitelist = lint_user_whitelist , parameters = parameters ))
220+ self .functions .append (
221+ DecoratedFunction (
222+ name = func_name ,
223+ subdirectories = subdirectories ,
224+ source = body_source ,
225+ lint_mode = lint_mode ,
226+ lint_user_whitelist = lint_user_whitelist ,
227+ parameters = parameters ,
228+ )
229+ )
205230
206231 def _has_codegen_root (self , node ):
207232 """Recursively check if an AST node chain starts with codegen."""
0 commit comments