66#
77
88import argparse
9+ import re
10+ import shutil
911
12+ from fnmatch import fnmatch
1013from pathlib import Path
14+ from typing import Iterable , Match , Optional , Set
1115
1216def get_parsearg_base () -> argparse .ArgumentParser :
1317 """ Get base arguments for scripts building a TF-PSA-Crypto test driver """
@@ -23,10 +27,111 @@ def get_parsearg_base() -> argparse.ArgumentParser:
2327 help = "Test driver name (default: %(default)s)." )
2428 return parser
2529
30+ def iter_code_files (root : Path ) -> Iterable [Path ]:
31+ """
32+ Iterate over all "*.c" and "*.h" files found recursively under the `include`
33+ and `src` subdirectories of `root`.
34+ """
35+ for directory in ("include" , "src" ):
36+ directory_path = root / directory
37+ for ext in (".c" , ".h" ):
38+ yield from directory_path .rglob (f"*{ ext } " )
39+
2640class TestDriverGenerator :
2741 """A TF-PSA-Crypto test driver generator"""
2842 def __init__ (self , dst_dir : Path , driver : str ):
2943 self .dst_dir = dst_dir
3044 self .driver = driver
3145 # Path of 'dst_dir'/include/'driver'
3246 self .test_driver_include_dir = None #type: Path | None
47+
48+ def build_tree (self , src_dir : Path , exclude_files : Optional [Set [str ]] = None ) -> None :
49+ """
50+ Build a test driver tree from `src_dir`.
51+
52+ The source directory `src_dir` is expected to have the following structure:
53+ - an `include` directory
54+ - an `src` directory
55+ - the `include` directory contains exactly one subdirectory
56+
57+ Only the `include` and `src` directories from `src_dir` are used to build
58+ the test driver tree, and their directory structure is preserved.
59+
60+ Only "*.h" and "*.c" files are copied. Files whose names match any of the
61+ patterns in `exclude_files` are excluded.
62+
63+ The subdirectory inside `include` is renamed to `driver` in the test driver
64+ tree, and header inclusions are adjusted accordingly.
65+ """
66+ include = src_dir / "include"
67+ if not include .is_dir ():
68+ raise RuntimeError (f'Do not find "include" directory in { src_dir } ' )
69+
70+ src = src_dir / "src"
71+ if not src .is_dir ():
72+ raise RuntimeError (f'Do not find "src" directory in { src_dir } ' )
73+
74+ entries = list (include .iterdir ())
75+ if len (entries ) != 1 or not entries [0 ].is_dir ():
76+ raise RuntimeError (f"Found more than one directory in { include } " )
77+
78+ src_include_dir_name = entries [0 ].name
79+
80+ if (self .dst_dir / "include" ).exists ():
81+ shutil .rmtree (self .dst_dir / "include" )
82+
83+ if (self .dst_dir / "src" ).exists ():
84+ shutil .rmtree (self .dst_dir / "src" )
85+
86+ if exclude_files is None :
87+ exclude_files = set ()
88+
89+ for file in iter_code_files (src_dir ):
90+ if any (fnmatch (file .name , pattern ) for pattern in exclude_files ):
91+ continue
92+ dst = self .dst_dir / file .relative_to (src_dir )
93+ dst .parent .mkdir (parents = True , exist_ok = True )
94+ shutil .copy2 (file , dst )
95+
96+ self .test_driver_include_dir = self .dst_dir / "include" / self .driver
97+ (self .dst_dir / "include" / src_include_dir_name ).rename ( \
98+ self .test_driver_include_dir )
99+
100+ headers = {
101+ f .relative_to (self .test_driver_include_dir ).as_posix () \
102+ for f in self .test_driver_include_dir .rglob ("*.h" )
103+ }
104+ for f in iter_code_files (self .dst_dir ):
105+ self .__rewrite_inclusions_in_file (f , headers , \
106+ src_include_dir_name , self .driver )
107+
108+ @staticmethod
109+ def __rewrite_inclusions_in_file (file : Path , headers : Set [str ],
110+ src_include_dir : str , driver : str ,) -> None :
111+ """
112+ Rewrite `#include` directives in `file` that refer to `src_include_dir/...`
113+ so that they instead refer to `driver/...`.
114+
115+ For example:
116+ #include "mbedtls/private/aes.h"
117+ becomes:
118+ #include "libtestdriver1/private/aes.h"
119+ """
120+ include_line_re = re .compile (
121+ fr'^\s*#\s*include\s*([<"])\s*{ src_include_dir } /([^>"]+)\s*([>"])' ,
122+ re .MULTILINE
123+ )
124+ text = file .read_text (encoding = "utf-8" )
125+ changed = False
126+
127+ def repl (m : Match ) -> str :
128+ nonlocal changed
129+ header = m .group (2 )
130+ if header in headers :
131+ changed = True
132+ return f'#include { m .group (1 )} { driver } /{ header } { m .group (3 )} '
133+ return m .group (0 )
134+
135+ new_text = include_line_re .sub (repl , text )
136+ if changed :
137+ file .write_text (new_text , encoding = "utf-8" )
0 commit comments