Skip to content

Commit 4eeffe8

Browse files
test_driver.py: Add method to build the test driver tree
Signed-off-by: Ronald Cron <ronald.cron@arm.com>
1 parent 75540ee commit 4eeffe8

File tree

1 file changed

+105
-0
lines changed

1 file changed

+105
-0
lines changed

scripts/mbedtls_framework/test_driver.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
#
77

88
import argparse
9+
import re
10+
import shutil
911

12+
from fnmatch import fnmatch
1013
from pathlib import Path
14+
from typing import Iterable, Match, Optional, Set
1115

1216
def get_parsearg_base() -> argparse.ArgumentParser:
1317
""" Get base arguments for scripts building a TF-PSA-Crypto test driver """
@@ -23,10 +27,111 @@ def get_parsearg_base() -> argparse.ArgumentParser:
2327
help="Test driver name (default: %(default)s).")
2428
return parser
2529

30+
def iter_code_files(root: Path) -> Iterable[Path]:
31+
"""
32+
Iterate over all "*.c" and "*.h" files found recursively under the `include`
33+
and `src` subdirectories of `root`.
34+
"""
35+
for directory in ("include", "src"):
36+
directory_path = root / directory
37+
for ext in (".c", ".h"):
38+
yield from directory_path.rglob(f"*{ext}")
39+
2640
class TestDriverGenerator:
2741
"""A TF-PSA-Crypto test driver generator"""
2842
def __init__(self, dst_dir: Path, driver: str):
2943
self.dst_dir = dst_dir
3044
self.driver = driver
3145
# Path of 'dst_dir'/include/'driver'
3246
self.test_driver_include_dir = None #type: Path | None
47+
48+
def build_tree(self, src_dir: Path, exclude_files: Optional[Set[str]] = None) -> None:
49+
"""
50+
Build a test driver tree from `src_dir`.
51+
52+
The source directory `src_dir` is expected to have the following structure:
53+
- an `include` directory
54+
- an `src` directory
55+
- the `include` directory contains exactly one subdirectory
56+
57+
Only the `include` and `src` directories from `src_dir` are used to build
58+
the test driver tree, and their directory structure is preserved.
59+
60+
Only "*.h" and "*.c" files are copied. Files whose names match any of the
61+
patterns in `exclude_files` are excluded.
62+
63+
The subdirectory inside `include` is renamed to `driver` in the test driver
64+
tree, and header inclusions are adjusted accordingly.
65+
"""
66+
include = src_dir / "include"
67+
if not include.is_dir():
68+
raise RuntimeError(f'Do not find "include" directory in {src_dir}')
69+
70+
src = src_dir / "src"
71+
if not src.is_dir():
72+
raise RuntimeError(f'Do not find "src" directory in {src_dir}')
73+
74+
entries = list(include.iterdir())
75+
if len(entries) != 1 or not entries[0].is_dir():
76+
raise RuntimeError(f"Found more than one directory in {include}")
77+
78+
src_include_dir_name = entries[0].name
79+
80+
if (self.dst_dir / "include").exists():
81+
shutil.rmtree(self.dst_dir / "include")
82+
83+
if (self.dst_dir / "src").exists():
84+
shutil.rmtree(self.dst_dir / "src")
85+
86+
if exclude_files is None:
87+
exclude_files = set()
88+
89+
for file in iter_code_files(src_dir):
90+
if any(fnmatch(file.name, pattern) for pattern in exclude_files):
91+
continue
92+
dst = self.dst_dir / file.relative_to(src_dir)
93+
dst.parent.mkdir(parents=True, exist_ok=True)
94+
shutil.copy2(file, dst)
95+
96+
self.test_driver_include_dir = self.dst_dir / "include" / self.driver
97+
(self.dst_dir / "include" / src_include_dir_name).rename( \
98+
self.test_driver_include_dir)
99+
100+
headers = {
101+
f.relative_to(self.test_driver_include_dir).as_posix() \
102+
for f in self.test_driver_include_dir.rglob("*.h")
103+
}
104+
for f in iter_code_files(self.dst_dir):
105+
self.__rewrite_inclusions_in_file(f, headers, \
106+
src_include_dir_name, self.driver)
107+
108+
@staticmethod
109+
def __rewrite_inclusions_in_file(file: Path, headers: Set[str],
110+
src_include_dir: str, driver: str,) -> None:
111+
"""
112+
Rewrite `#include` directives in `file` that refer to `src_include_dir/...`
113+
so that they instead refer to `driver/...`.
114+
115+
For example:
116+
#include "mbedtls/private/aes.h"
117+
becomes:
118+
#include "libtestdriver1/private/aes.h"
119+
"""
120+
include_line_re = re.compile(
121+
fr'^\s*#\s*include\s*([<"])\s*{src_include_dir}/([^>"]+)\s*([>"])',
122+
re.MULTILINE
123+
)
124+
text = file.read_text(encoding="utf-8")
125+
changed = False
126+
127+
def repl(m: Match) -> str:
128+
nonlocal changed
129+
header = m.group(2)
130+
if header in headers:
131+
changed = True
132+
return f'#include {m.group(1)}{driver}/{header}{m.group(3)}'
133+
return m.group(0)
134+
135+
new_text = include_line_re.sub(repl, text)
136+
if changed:
137+
file.write_text(new_text, encoding="utf-8")

0 commit comments

Comments
 (0)