diff --git a/.github/workflows/mkdocs.yml b/.github/workflows/mkdocs.yml new file mode 100644 index 0000000..ce91ff7 --- /dev/null +++ b/.github/workflows/mkdocs.yml @@ -0,0 +1,15 @@ +name: build-mkdocs +on: + push: + branches: + - main +jobs: + deploy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.11 + - run: pip install mkdocs mkdocs-material mkdocstrings[python] mkdocs-git-revision-date-localized-plugin mkdocs-git-committers-plugin-2 + - run: mkdocs gh-deploy --force --clean --verbose diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..3316597 --- /dev/null +++ b/.gitignore @@ -0,0 +1,163 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+
+# Project specific
+dataset/
diff --git a/README.md b/README.md
index 22297e4..ecf0f72 100644
--- a/README.md
+++ b/README.md
@@ -1,19 +1,31 @@
 # mmushell
-MMUShell OS-Agnostic Memory Forensics Tool
-Proof of concept for techniques developed by Andrea Oliveri and Davide Balzarotti in
+## Description
 
-["In the Land of MMUs: Multiarchitecture OS-Agnostic Virtual Memory Forensics"](https://doi.org/10.1145/3528102)
+MMUShell is an OS-agnostic memory forensics tool, a proof of concept for the techniques developed by Andrea Oliveri and Davide Balzarotti in ["In the Land of MMUs: Multiarchitecture OS-Agnostic Virtual Memory Forensics"](https://doi.org/10.1145/3528102).
 
-Installation:
+The first step required to perform any analysis of a physical memory image is the reconstruction of the virtual address spaces, which allows translating virtual addresses to their corresponding physical offsets. However, this phase is often overlooked, and the challenges related to it are rarely discussed in the literature. Practical tools solve the problem by using a set of custom heuristics tailored to a small number of well-known operating systems (OSs) running on a few architectures.
+
+In the paper, we look for the first time at all the different ways the virtual-to-physical translation can be operated in 10 different CPU architectures. In each case, we study the inviolable constraints imposed by the memory management unit, which can be used to build signatures to recover the required data structures from memory without any knowledge about the running OS.
+
+This tool allows experimenting with the extraction of virtual address spaces, showing the challenges of performing an OS-agnostic virtual-to-physical address translation in real-world scenarios.
+It was tested on a large set of 26 different OSs and 6 architectures, with a use case on a real hardware device.
+
+## Quick installation
+
+On a standard Linux distribution:
+```shell
+$ python -m venv --system-site-packages --symlinks venv
+$ venv/bin/pip install -r requirements.txt
 ```
-pip install -r requirements.txt
+
+On Nix/NixOS:
+```shell
+$ nix develop
+# or with direnv
+$ direnv allow .
 ```
 
-Usage:
-- Dump all the RAM areas of the machine that you want to analyze in raw format, one file per physical memory area.
-- Create a YAML file describing the hardware configuration of the machine (see the examples available in the dataset).
-- ```mmushell machine.yaml```
-- Use the interactive shell to find MMU registers, Radix-Trees, Hash tables etc. and explore them.
The ```help``` command lists all the possible actions available for the selected CPU architecture. -- [Here](https://www.s3.eurecom.fr/datasets/datasets_old_www/mmushell_dataset.tar) part of the dataset containing the memory dumps of the OSs used in the paper (only the open-source ones, due to license restrictions). -- ```/qemu/``` contains the patch for QEMU 5.0.0 in order to collect the ground truth values of the MMU registers during OSs execution. +## Documentation + +You can find the updated documentation [here](https://memoscopy.github.io/mmushell), where you will find tutorials, how-to guides, references and explanations on this project. diff --git a/converter.py b/converter.py index 563fa31..afdc8bc 100755 --- a/converter.py +++ b/converter.py @@ -1,13 +1,19 @@ #!/usr/bin/env python3 import yaml +import json import argparse + from sortedcontainers import SortedDict -import json + def main(): parser = argparse.ArgumentParser() - parser.add_argument("MACHINE_CONFIG", help="YAML file describing the machine", type=argparse.FileType("r")) + parser.add_argument( + "MACHINE_CONFIG", + help="YAML file describing the machine", + type=argparse.FileType("r"), + ) args = parser.parse_args() # Load machine config @@ -21,153 +27,166 @@ def main(): # Call the exporter exporter(machine_config) + def exporter(machine_config): - """Convert dump set into an ELF file containg the physical address space""" - - architecture = machine_config["cpu"]["architecture"] - bits = machine_config["cpu"]["bits"] - endianness = machine_config["cpu"]["endianness"] - prefix = machine_config["memspace"]["ram"][0]["dumpfile"].split(".")[0] - - with open(prefix + ".elf", "wb") as elf_fd: - # Create the ELF header and write it on the file - machine_data = { - "QEMUArchitecture": architecture, - "Uptime": -1, - "CPURegisters": [], - "MemoryMappedDevices": [(x["start"], x["end"] + 1) for x in machine_config["memspace"]["not_ram"]], - "MMUMode": machine_config["mmu"]["mode"] - } - - # Create ELF main header - if architecture == "aarch64": - e_machine = 0xB7 - elif architecture == "arm": - e_machine = 0x28 - elif architecture == "riscv": - e_machine = 0xF3 - elif architecture == "intel": - if bits == 64: - e_machine = 0x3E - else: - e_machine = 0x03 - machine_data["CPUSpecifics"] = {"MAXPHYADDR": machine_config["cpu"]["processor_features"]["m_phy"]} + """Convert dump set into an ELF file containg the physical address space""" + + architecture = machine_config["cpu"]["architecture"] + bits = machine_config["cpu"]["bits"] + endianness = machine_config["cpu"]["endianness"] + prefix = machine_config["memspace"]["ram"][0]["dumpfile"].split(".")[0] + + with open(prefix + ".elf", "wb") as elf_fd: + # Create the ELF header and write it on the file + machine_data = { + "QEMUArchitecture": architecture, + "Uptime": -1, + "CPURegisters": [], + "MemoryMappedDevices": [ + (x["start"], x["end"] + 1) + for x in machine_config["memspace"]["not_ram"] + ], + "MMUMode": machine_config["mmu"]["mode"], + } + + # Create ELF main header + if architecture == "aarch64": + e_machine = 0xB7 + elif architecture == "arm": + e_machine = 0x28 + elif architecture == "riscv": + e_machine = 0xF3 + elif architecture == "intel": + if bits == 64: + e_machine = 0x3E else: - raise Exception("Unsupported architecture") - - e_ehsize = 0x40 - e_phentsize = 0x38 - elf_h = bytearray(e_ehsize) - elf_h[0x00:0x04] = b'\x7fELF' # Magic - elf_h[0x04] = 2 # Elf type - elf_h[0x05] = 1 if endianness == "little" else 2 # Endianness - elf_h[0x06] = 1 # Version - elf_h[0x10:0x12] 
= 0x4.to_bytes(2, endianness) # e_type - elf_h[0x12:0x14] = e_machine.to_bytes(2, endianness) # e_machine - elf_h[0x14:0x18] = 0x1.to_bytes(4, endianness) # e_version - elf_h[0x34:0x36] = e_ehsize.to_bytes(2, endianness) # e_ehsize - elf_h[0x36:0x38] = e_phentsize.to_bytes(2, endianness) # e_phentsize - elf_fd.write(elf_h) - - regions = SortedDict() - for region in machine_config["memspace"]["ram"]: - regions[(region["start"], region["end"] + 1)] = region["dumpfile"] - for region in machine_config["memspace"]["not_ram"]: - regions[(region["start"], region["end"] + 1)] = None - - # Write segments in the new file and fill the program header - p_offset = len(elf_h) - offset2p_offset = {} - - for (begin, end), dump_file in regions.items(): - # Not write not RAM regions - if dump_file is None: - offset2p_offset[(begin, end)] = -1 - continue - - # Write physical RAM regions - offset2p_offset[(begin, end)] = p_offset - with open(dump_file, "rb") as region_fd: - elf_fd.write(region_fd.read()) - p_offset += (end - begin) - - # Create FOSSIL NOTE segment style - pad = 4 - name = "FOSSIL" - n_type = 0xDEADC0DE - name_b = name.encode() - name_b += b"\x00" - namesz = len(name_b).to_bytes(pad, endianness) - name_b += bytes(pad - (len(name_b) % pad)) - - descr_b = json.dumps(machine_data).encode() - descr_b += b"\x00" - descr_b += bytes(pad - (len(descr_b) % pad)) - descrsz = len(descr_b).to_bytes(pad, endianness) - - machine_note = namesz + descrsz + n_type.to_bytes(pad, endianness) + name_b + descr_b - len_machine_note = len(machine_note) - elf_fd.write(machine_note) - - # Create the program header - # Add FOSSIL NOTE entry style - p_header = bytes() - note_entry = bytearray(e_phentsize) - note_entry[0x00:0x04] = 0x4.to_bytes(4, endianness) # p_type - note_entry[0x08:0x10] = p_offset.to_bytes(8, endianness) # p_offset - note_entry[0x20:0x28] = len_machine_note.to_bytes(8, endianness) # p_filesz - - p_offset += len_machine_note - e_phoff = p_offset - p_header += note_entry - - # Add all the segments (ignoring not in RAM pages) - for (begin, end), offset in offset2p_offset.items(): - if offset == -1: - p_filesz = 0 - pmask = 6 - offset = 0 - else: - p_filesz = end - begin - pmask = 7 - - segment_entry = bytearray(e_phentsize) - segment_entry[0x00:0x04] = 0x1.to_bytes(4, endianness) # p_type - segment_entry[0x04:0x08] = pmask.to_bytes(4, endianness) # p_flags - segment_entry[0x10:0x18] = begin.to_bytes(8, endianness) # p_vaddr - segment_entry[0x18:0x20] = begin.to_bytes(8, endianness) # p_paddr Original offset - segment_entry[0x28:0x30] = (end - begin).to_bytes(8, endianness) # p_memsz - segment_entry[0x08:0x10] = offset.to_bytes(8, endianness) # p_offset - segment_entry[0x20:0x28] = p_filesz.to_bytes(8, endianness) # p_filesz - - p_header += segment_entry - - # Write the segment header - elf_fd.write(p_header) - s_header_pos = elf_fd.tell() # Last position written (used if we need to write segment header) - e_phnum = len(regions) + 1 - - # Modify the ELF header to point to program header - elf_fd.seek(0x20) - elf_fd.write(e_phoff.to_bytes(8, endianness)) # e_phoff - - # If we have more than 65535 segments we have create a special Section entry contains the - # number of program entry (as specified in ELF64 specifications) - if e_phnum < 65536: - elf_fd.seek(0x38) - elf_fd.write(e_phnum.to_bytes(2, endianness)) # e_phnum + e_machine = 0x03 + machine_data["CPUSpecifics"] = { + "MAXPHYADDR": machine_config["cpu"]["processor_features"]["m_phy"] + } + else: + raise Exception("Unsupported architecture") + + 
e_ehsize = 0x40 + e_phentsize = 0x38 + elf_h = bytearray(e_ehsize) + elf_h[0x00:0x04] = b"\x7fELF" # Magic + elf_h[0x04] = 2 # Elf type + elf_h[0x05] = 1 if endianness == "little" else 2 # Endianness + elf_h[0x06] = 1 # Version + elf_h[0x10:0x12] = 0x4.to_bytes(2, endianness) # e_type + elf_h[0x12:0x14] = e_machine.to_bytes(2, endianness) # e_machine + elf_h[0x14:0x18] = 0x1.to_bytes(4, endianness) # e_version + elf_h[0x34:0x36] = e_ehsize.to_bytes(2, endianness) # e_ehsize + elf_h[0x36:0x38] = e_phentsize.to_bytes(2, endianness) # e_phentsize + elf_fd.write(elf_h) + + regions = SortedDict() + for region in machine_config["memspace"]["ram"]: + regions[(region["start"], region["end"] + 1)] = region["dumpfile"] + for region in machine_config["memspace"]["not_ram"]: + regions[(region["start"], region["end"] + 1)] = None + + # Write segments in the new file and fill the program header + p_offset = len(elf_h) + offset2p_offset = {} + + for (begin, end), dump_file in regions.items(): + # Not write not RAM regions + if dump_file is None: + offset2p_offset[(begin, end)] = -1 + continue + + # Write physical RAM regions + offset2p_offset[(begin, end)] = p_offset + with open(dump_file, "rb") as region_fd: + elf_fd.write(region_fd.read()) + p_offset += end - begin + + # Create FOSSIL NOTE segment style + pad = 4 + name = "FOSSIL" + n_type = 0xDEADC0DE + name_b = name.encode() + name_b += b"\x00" + namesz = len(name_b).to_bytes(pad, endianness) + name_b += bytes(pad - (len(name_b) % pad)) + + descr_b = json.dumps(machine_data).encode() + descr_b += b"\x00" + descr_b += bytes(pad - (len(descr_b) % pad)) + descrsz = len(descr_b).to_bytes(pad, endianness) + + machine_note = ( + namesz + descrsz + n_type.to_bytes(pad, endianness) + name_b + descr_b + ) + len_machine_note = len(machine_note) + elf_fd.write(machine_note) + + # Create the program header + # Add FOSSIL NOTE entry style + p_header = bytes() + note_entry = bytearray(e_phentsize) + note_entry[0x00:0x04] = 0x4.to_bytes(4, endianness) # p_type + note_entry[0x08:0x10] = p_offset.to_bytes(8, endianness) # p_offset + note_entry[0x20:0x28] = len_machine_note.to_bytes(8, endianness) # p_filesz + + p_offset += len_machine_note + e_phoff = p_offset + p_header += note_entry + + # Add all the segments (ignoring not in RAM pages) + for (begin, end), offset in offset2p_offset.items(): + if offset == -1: + p_filesz = 0 + pmask = 6 + offset = 0 else: - elf_fd.seek(0x28) - elf_fd.write(s_header_pos.to_bytes(8, endianness)) # e_shoff - elf_fd.seek(0x38) - elf_fd.write(0xFFFF.to_bytes(2, endianness)) # e_phnum - elf_fd.write(0x40.to_bytes(2, endianness)) # e_shentsize - elf_fd.write(0x1.to_bytes(2, endianness)) # e_shnum - - section_entry = bytearray(0x40) - section_entry[0x2C:0x30] = e_phnum.to_bytes(4, endianness) # sh_info - elf_fd.seek(s_header_pos) - elf_fd.write(section_entry) - -if __name__ == '__main__': + p_filesz = end - begin + pmask = 7 + + segment_entry = bytearray(e_phentsize) + segment_entry[0x00:0x04] = 0x1.to_bytes(4, endianness) # p_type + segment_entry[0x04:0x08] = pmask.to_bytes(4, endianness) # p_flags + segment_entry[0x10:0x18] = begin.to_bytes(8, endianness) # p_vaddr + segment_entry[0x18:0x20] = begin.to_bytes( + 8, endianness + ) # p_paddr Original offset + segment_entry[0x28:0x30] = (end - begin).to_bytes(8, endianness) # p_memsz + segment_entry[0x08:0x10] = offset.to_bytes(8, endianness) # p_offset + segment_entry[0x20:0x28] = p_filesz.to_bytes(8, endianness) # p_filesz + + p_header += segment_entry + + # Write the segment header + 
+        elf_fd.write(p_header)
+        s_header_pos = (
+            elf_fd.tell()
+        )  # Last position written (used if we need to write the section header)
+        e_phnum = len(regions) + 1
+
+        # Modify the ELF header to point to the program header
+        elf_fd.seek(0x20)
+        elf_fd.write(e_phoff.to_bytes(8, endianness))  # e_phoff
+
+        # If we have more than 65535 segments we have to create a special Section entry
+        # containing the number of program entries (as specified in the ELF64 specification)
+        if e_phnum < 65536:
+            elf_fd.seek(0x38)
+            elf_fd.write(e_phnum.to_bytes(2, endianness))  # e_phnum
+        else:
+            elf_fd.seek(0x28)
+            elf_fd.write(s_header_pos.to_bytes(8, endianness))  # e_shoff
+            elf_fd.seek(0x38)
+            elf_fd.write(0xFFFF.to_bytes(2, endianness))  # e_phnum
+            elf_fd.write(0x40.to_bytes(2, endianness))  # e_shentsize
+            elf_fd.write(0x1.to_bytes(2, endianness))  # e_shnum
+
+            section_entry = bytearray(0x40)
+            section_entry[0x2C:0x30] = e_phnum.to_bytes(4, endianness)  # sh_info
+            elf_fd.seek(s_header_pos)
+            elf_fd.write(section_entry)
+
+
+if __name__ == "__main__":
     main()
diff --git a/docs/explanation.md b/docs/explanation.md
new file mode 100644
index 0000000..2a9e4fd
--- /dev/null
+++ b/docs/explanation.md
@@ -0,0 +1,22 @@
+
+
+The first step required to perform any analysis of a physical memory image is the reconstruction of the virtual address spaces, which allows translating virtual addresses to their corresponding physical offsets. However, this phase is often overlooked, and the challenges related to it are rarely discussed in the literature. Practical tools solve the problem by using a set of custom heuristics tailored to a small number of well-known operating systems (OSs) running on a few architectures.
+
+In the paper, we look for the first time at all the different ways the virtual-to-physical translation can be operated in 10 different CPU architectures. In each case, we study the inviolable constraints imposed by the memory management unit, which can be used to build signatures to recover the required data structures from memory without any knowledge about the running OS.
+
+This tool allows experimenting with the extraction of virtual address spaces, showing the challenges of performing an OS-agnostic virtual-to-physical address translation in real-world scenarios.
+It was tested on a large set of 26 different OSs and 6 architectures, with a use case on a real hardware device.
diff --git a/docs/how-to-guides.md b/docs/how-to-guides.md
new file mode 100644
index 0000000..abcda5f
--- /dev/null
+++ b/docs/how-to-guides.md
@@ -0,0 +1,40 @@
+
diff --git a/docs/index.md b/docs/index.md
new file mode 100644
index 0000000..16ca97c
--- /dev/null
+++ b/docs/index.md
@@ -0,0 +1,17 @@
+This site contains the project documentation for the
+`mmushell` project, an OS-agnostic memory forensics tool, a proof of concept for the techniques developed by Andrea Oliveri and Davide Balzarotti in ["In the Land of MMUs: Multiarchitecture OS-Agnostic Virtual Memory Forensics"](https://doi.org/10.1145/3528102).
+
+## Table of contents
+
+The documentation follows the best practice for
+project documentation as described by Daniele Procida
+in the [Diátaxis documentation framework](https://diataxis.fr/)
+and consists of four separate parts:
+
+1. [Tutorials](tutorials.md)
+2. [How-To Guides](how-to-guides.md)
+3. [Reference](reference/reference.md)
+4. [Explanation](explanation.md)
+
+Quickly find what you're looking for depending on
+your use case by looking at the different pages.
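The ELF produced by `converter.py` above stores the machine description in a custom `FOSSIL` note segment (note name `FOSSIL`, type `0xDEADC0DE`, descriptor = the `machine_data` JSON, NUL-terminated and padded to 4 bytes). The following is a minimal read-back sketch using pyelftools, mirroring the parsing logic of the `exporter.py` removed later in this patch; the file name passed on the command line is only an example, and depending on the pyelftools version `n_desc` may be returned as `str` or `bytes`.

```python
#!/usr/bin/env python3
# Sketch: decode the machine_data JSON that converter.py embeds in the
# FOSSIL NOTE segment of the generated ELF (same logic as the removed
# exporter.py). Usage example: python read_note.py linux.elf
import json
import sys

from elftools.elf.elffile import ELFFile
from elftools.elf.segments import NoteSegment


def read_fossil_note(path):
    with open(path, "rb") as fd:
        for segm in ELFFile(fd).iter_segments():
            if not isinstance(segm, NoteSegment):
                continue
            for note in segm.iter_notes():
                # converter.py writes n_name "FOSSIL" and n_type 0xDEADC0DE
                if note["n_name"] != "FOSSIL" or note["n_type"] != 0xDEADC0DE:
                    continue
                desc = note["n_desc"]
                if isinstance(desc, bytes):  # pyelftools version dependent
                    desc = desc.decode()
                return json.loads(desc.rstrip("\x00"))
    return None


if __name__ == "__main__":
    print(read_fossil_note(sys.argv[1]))
```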
diff --git a/docs/reference/architectures/aarch64.md b/docs/reference/architectures/aarch64.md
new file mode 100644
index 0000000..fa74dcf
--- /dev/null
+++ b/docs/reference/architectures/aarch64.md
@@ -0,0 +1,3 @@
+## AArch64
+
+::: mmushell.architectures.aarch64
diff --git a/docs/reference/architectures/arm.md b/docs/reference/architectures/arm.md
new file mode 100644
index 0000000..657b406
--- /dev/null
+++ b/docs/reference/architectures/arm.md
@@ -0,0 +1,3 @@
+## ARM
+
+::: mmushell.architectures.arm
diff --git a/docs/reference/architectures/generic.md b/docs/reference/architectures/generic.md
new file mode 100644
index 0000000..3136b19
--- /dev/null
+++ b/docs/reference/architectures/generic.md
@@ -0,0 +1,3 @@
+## Generic
+
+::: mmushell.architectures.generic
diff --git a/docs/reference/architectures/intel.md b/docs/reference/architectures/intel.md
new file mode 100644
index 0000000..b3a650e
--- /dev/null
+++ b/docs/reference/architectures/intel.md
@@ -0,0 +1,3 @@
+## Intel
+
+::: mmushell.architectures.intel
diff --git a/docs/reference/architectures/mips.md b/docs/reference/architectures/mips.md
new file mode 100644
index 0000000..19d8b9a
--- /dev/null
+++ b/docs/reference/architectures/mips.md
@@ -0,0 +1,3 @@
+## MIPS
+
+::: mmushell.architectures.mips
diff --git a/docs/reference/architectures/ppc.md b/docs/reference/architectures/ppc.md
new file mode 100644
index 0000000..096aac0
--- /dev/null
+++ b/docs/reference/architectures/ppc.md
@@ -0,0 +1,3 @@
+## PowerPC
+
+::: mmushell.architectures.ppc
diff --git a/docs/reference/architectures/riscv.md b/docs/reference/architectures/riscv.md
new file mode 100644
index 0000000..5f50552
--- /dev/null
+++ b/docs/reference/architectures/riscv.md
@@ -0,0 +1,3 @@
+## RISC-V
+
+::: mmushell.architectures.riscv
diff --git a/docs/reference/exporter.md b/docs/reference/exporter.md
new file mode 100644
index 0000000..34fd3d3
--- /dev/null
+++ b/docs/reference/exporter.md
@@ -0,0 +1,3 @@
+## Exporter
+
+::: mmushell.exporter
diff --git a/docs/reference/mmushell.md b/docs/reference/mmushell.md
new file mode 100644
index 0000000..b60f66c
--- /dev/null
+++ b/docs/reference/mmushell.md
@@ -0,0 +1,3 @@
+## MMUShell
+
+::: mmushell.mmushell
diff --git a/docs/reference/reference.md b/docs/reference/reference.md
new file mode 100644
index 0000000..75c8be0
--- /dev/null
+++ b/docs/reference/reference.md
@@ -0,0 +1,5 @@
+
+
diff --git a/docs/tutorials.md b/docs/tutorials.md
new file mode 100644
index 0000000..6e348f9
--- /dev/null
+++ b/docs/tutorials.md
@@ -0,0 +1,180 @@
+## Organisation
+
+- `mmushell/architectures/` : the various architecture parsers, plus a generic one
+- `mmushell/mmushell.py` : the main script, which reconstructs virtual address spaces from a memory dump; more instructions below
+- `mmushell/exporter.py` : a POC showing a possible use of these techniques to perform a preliminary analysis of a dump, by exporting each virtual address space as a self-contained ELF core dump file. See the section [TOWARDS OS AGNOSTIC MEMORY FORENSICS](https://www.s3.eurecom.fr/docs/tops22_oliveri.pdf).
+- `converter.py` : exports a dump to be used in [Fossil](https://github.com/eurecom-s3/fossil). It adds the CPU registers and converts the kernel physical address space into a virtual address space.
  **Note**: you can ignore this script, it is not part of mmushell
+- `qemu/` : contains the scripts and patch necessary to get ground truth register values from an emulated system
+
+## Quick installation
+
+On a standard Linux distribution:
+```shell
+$ python -m venv --system-site-packages --symlinks venv
+$ venv/bin/pip install -r requirements.txt
+```
+
+On Nix/NixOS:
+```shell
+$ nix develop
+# or with direnv
+$ direnv allow .
+```
+
+## Usage
+
+### Dataset
+
+[Here](https://www.s3.eurecom.fr/datasets/datasets_old_www/mmushell_dataset.tar) you can find part of the dataset containing the memory dumps of the OSs used in the paper (only the open-source ones, due to license restrictions).
+
+Each archive contains a minimum of 4 files and requires at least 4GB of free space once decompressed:
+
+- `XXX.regs` : contains the values of the registers collected by QEMU during the execution (the ground truth), in pickle format, to be used (optionally) with the `--gtruth` option of mmushell
+
+- `XXX.yaml` : contains the hardware configuration of the machine which has run the OS, a YAML file, to be used as the argument of mmushell
+
+- `XXX.dump.Y` : a chunk of the RAM dump of the machine
+
+- `XXX.lzma` : an mmushell session file containing the output of mmushell, in pickled LZMA format, to be used (optionally) with the `--session` option of mmushell
+
+Using `XXX.lzma` avoids re-executing the parsing and data structure reconstruction phases, saving time!
+
+### CLI
+
+MMUShell must be run in the folder containing the dump/configuration files, as all the paths are relative.
+
+!!! warning
+    Some OSs require a minimum of 32GB of RAM to be parsed (the Intel 32-bit ones, in particular HaikuOS) or a minimum of 1 hour of execution (independently of the number of CPU cores; Intel 32/PAE/IA64 OSs).
+
+    Consider using the session files for them to save time.
+
+Help:
+
+```shell
+$ mmushell.py
+usage: mmushell.py [-h] [--gtruth GTRUTH] [--session SESSION] [--debug] MACHINE_CONFIG
+mmushell.py: error: the following arguments are required: MACHINE_CONFIG
+```
+
+1. Dump all the RAM areas of the machine that you want to analyze in raw format, one file per physical memory area.
+2. Create a YAML file describing the hardware configuration of the machine (see the examples available in the dataset).
+   The format is the following (a sketch for checking such a file against its dump files is given at the end of this page):
+   ```yaml
+   cpu:
+     # Architecture type
+     architecture: (aarch64|arm|intel|mips|ppc|riscv)
+
+     # Endianness level
+     endianness: (big|little)
+
+     # Bits used by architecture
+     bits: (32|64)
+
+   mmu:
+     # MMU mode, varies between architectures
+     mode: (ia64|ia32|pae|sv32|sv39|sv48|ppc32|mips32|Short|Long) # any class that inherits from MMU
+     #       ^^^^^^^^^^^^  ^^^^^^^^^^^^^^  ^^^^^  ^^^^^^  ^^^^^^^^^^
+     #           intel          riscv       ppc    mips       arm
+
+   memspace:
+     # Physical RAM space region
+     ram:
+       - start: 0x0000000080000000 # ram start address
+         end: 0x000000017fffffff # ram end address
+         dumpfile: linux.dump
+
+     # Physical memory regions that are not RAM
+     # Example: reserved regions for MMIO, ROM, ... See https://en.wikipedia.org/wiki/Memory-mapped_I/O_and_port-mapped_I/O#Examples
+     # These regions are needed because page tables also map these special physical addresses, so the CPU can use the associated
+     # virtual addresses to write to or read from them. We need to distinguish them, otherwise we could misinterpret some page tables as data pages.
+     not_ram:
+       - start: 0x0000000000000000
+         end: 0x0000000000011fff
+
+       - start: 0x0000000000100000
+         end: 0x0000000000101023
+       # ...
+   ```
+
+3. Launch mmushell with the configuration file.
+   Example with the provided RISC-V SV39 memory dump:
+   ```shell
+   $ mmushell.py dataset/riscv/sv39/linux/linux.yaml
+   MMUShell. Type help or ? to list commands.
+
+   [MMUShell riscv]# ?
+
+   Documented commands (type help <topic>):
+   ========================================
+   exit              help     parse_memory  show_radix_trees
+   find_radix_trees  ipython  save_data     show_table
+   ```
+   Use the interactive shell to find MMU registers, Radix-Trees, Hash tables etc. and explore them. The `help` command lists all the possible actions available for the selected CPU architecture.
+
+### Ground truth
+
+`XXXX_gtruth` commands are available only if you load a `XXX.regs` file, as they compare the results found with the ground truth.
+Each of these commands has an equivalent command which shows only the results found by MMUShell, without comparing them with the ground truth.
+
+!!! note
+    The folder `qemu/` contains the patch for QEMU 5.0.0 used to collect the ground truth values of the MMU registers during OS execution. Please read the corresponding README in `qemu/README.md`.
+
+### Notes and procedures
+
+As the available mmushell commands differ from one architecture to another, here are the steps to be performed, in order, for each architecture.
+
+!!! note
+    Steps prefixed with "*" are necessary only if you don't use session files (e.g. `XXX.lzma`).
+
+**RISC-V**
+
+1. *`parse_memory` -> find MMU tables
+2. *`find_radix_trees` -> reconstruct radix trees
+3. `show_radix_trees_gtruth` -> compare the radix trees found with the ground truth
+
+**MIPS**
+
+1. *`parse_memory` -> find MMU opcodes
+2. *`find_registers_values` -> perform dataflow analysis and retrieve register values
+3. `show_registers_gtruth` -> compare the registers found with the ground truth
+
+**PowerPC**
+
+1. *`parse_memory` -> find MMU opcodes and hash tables
+2. *`find_registers_values` -> perform dataflow analysis and retrieve register values
+3. `show_hashtables_gtruth` -> compare the hash table found with the ground truth
+4. `show_registers_gtruth` -> compare the registers found with the ground truth
+
+???+ note
+    For **Linux Debian**, **Mac OS X**, **MorphOS** : `show_hashtables_gtruth` shows another Hash Table not retrieved by MMUShell, but it is a table used only during startup (as shown by the timestamp) and we ignore it because it is not used during normal OS operation.
+
+    For **Mac OS X** : ignore the BAT register values in `show_registers_gtruth`, as it uses different values for each process (as shown by the ground truth); the false positive results are purely a coincidence.
+
+**Intel**
+
+1. *`parse_memory` -> look for tables and IDTs
+2. *`find_radix_trees` -> reconstruct radix trees (could be slow)
+3. *`show_idtrs_gtruth` -> show a comparison between true and found IDTs (note: some OSs define multiple IDTs, one per core). We deliberately ignore the IDT used during the boot phase (see the PowerPC notes)
+4. `show_radix_trees_gtruth XXXX` -> where XXXX must be the PHYSICAL address of the last IDT used by the system, as shown by `show_idtrs_gtruth`. Shows a comparison between true and found radix trees which resolve the IDT PHYSICAL ADDRESS XXXX (obtained by the previous command; for our statistics we always use a true positive one)
+
+???+ note
+    For **BarrelfishOS**, **OmniOS** : these allocate a different IDT for every single CPU core. Some processes are core-specific and are able to resolve only the IDT of the same core. For each IDT found, MMUShell shows different processes as FN. They are not real false negatives: they are the per-core processes which cannot resolve that specific IDT, but they are found by MMUShell using one of the other IDTs.
+
+    For **RCore** : if you enter the physical address of the real IDT used by the system in `show_radix_trees_gtruth`, MMUShell does not show any entry, because it has found a different IDT (a FP) and has no valid radix trees for the real IDT. Please use `show_idtrs` to show the IDT found (YYY) and `show_radix_trees XXXX` to show the associated radix trees (all FP).
+
+**ARM**
+
+1. *`find_registers_values` -> perform dataflow analysis to recover the TTBCR value
+2. *`show_registers_gtruth` -> compare the values retrieved with the ground truth
+3. *`set_ttbcr XXX` -> use the REAL value of the TTBCR shown by the previous command
+4. *`find_tables` -> find MMU tables
+5. *`find_radix_trees` -> reconstruct radix trees
+6. `show_radix_trees_gtruth` -> compare the radix trees found with the ground truth
+
+**AArch64**
+
+1. *`find_registers_values` -> perform dataflow analysis to recover the TCR value
+2. *`show_registers_gtruth` -> compare the values retrieved with the ground truth
+3. *`set_tcr XXX` -> use the REAL value of the TCR shown by the previous command
+4. *`find_tables` -> find MMU tables
+5. *`find_radix_trees` -> reconstruct radix trees
+6. `show_radix_trees_gtruth` -> compare the radix trees found with the ground truth
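+
+### Checking a machine configuration
+
+Before launching mmushell on your own dump, it can help to verify that the YAML configuration and the dump files are consistent. The helper below is a minimal sketch and is not part of mmushell: it assumes the YAML layout documented above, with inclusive `end` addresses, and must be run from the folder containing the dump files, since paths are relative.
+
+```python
+#!/usr/bin/env python3
+# Hypothetical helper (not shipped with mmushell): sanity-check a machine
+# configuration file. Usage example: python check_config.py linux.yaml
+import os
+import sys
+
+import yaml
+
+
+def check_config(path):
+    with open(path) as fd:
+        cfg = yaml.safe_load(fd)
+
+    # The three documented top-level sections must be present
+    for key in ("cpu", "mmu", "memspace"):
+        assert key in cfg, f"missing '{key}' section"
+
+    # Each RAM region needs a dump file whose size matches the region
+    # (end addresses are inclusive, hence the +1)
+    for region in cfg["memspace"]["ram"]:
+        size = region["end"] - region["start"] + 1
+        dump = region["dumpfile"]
+        assert os.path.isfile(dump), f"missing dump file {dump}"
+        real = os.path.getsize(dump)
+        if real != size:
+            print(f"warning: {dump} is {real} bytes, region is {size}")
+
+    # not_ram regions only need consistent bounds
+    for region in cfg["memspace"]["not_ram"]:
+        assert region["start"] <= region["end"], "inverted not_ram region"
+
+    print(f"{path}: OK ({cfg['cpu']['architecture']}, {cfg['mmu']['mode']})")
+
+
+if __name__ == "__main__":
+    check_config(sys.argv[1])
+```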
diff --git a/exporter.py b/exporter.py
deleted file mode 100755
index c0a2c86..0000000
--- a/exporter.py
+++ /dev/null
@@ -1,1002 +0,0 @@
-#!/usr/bin/env python3
-
-import numpy as np
-from elftools.elf.elffile import ELFFile
-from elftools.elf.segments import NoteSegment
-import json
-from collections import defaultdict
-from struct import iter_unpack
-from tqdm import tqdm
-from bisect import bisect
-import argparse
-import json
-from compress_pickle import load as load_c
-from pickle import load
-import traceback
-
-def main():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("PHY_ELF", help="Dump file in ELF format", type=str)
-    parser.add_argument("MMU_DATA", help="List of DTBs and MMU configuration registers", type=argparse.FileType("rb"))
-    args = parser.parse_args()
-
-    # Load session file
-    try:
-        mmu_data = load(args.MMU_DATA)
-    except Exception as e:
-        print(f"Error: {e}")
-        exit(1)
-
-    # Load ELF file
-    elf_dump = ELFDump(args.PHY_ELF)
-
-    # Dump processes
-    for idx, process_mmu_data in enumerate(tqdm(mmu_data)):
-        try:
-            virtspace = get_virtspace(elf_dump, process_mmu_data)
-            virtspace.export_virtual_memory_elf(f"process.{idx}.elf")
-        except Exception as e:
-            print(f"Error during process exporting: {e}")
-            # print(traceback.format_exc())
-
-class IMSimple:
-    """Fast search in intervals (begin) (end)"""
-    def __init__(self, keys, values):
-        self.keys = keys
-        self.values = values
-
-    def __getitem__(self, x):
-        idx = bisect(self.keys, x) - 1
-        begin = self.keys[idx]
-        if begin <= x < self.values[idx]:
-            return x - begin
-        else:
-            return -1
-
-    def contains(self, x, size):
-        idx = bisect(self.keys, x) - 1
-        begin = self.keys[idx]
-        end = self.values[idx]
-        if not(begin <= x < end) or x + size >= end:
-            return -1
-        else:
-            return x - begin
-
-    def get_values(self):
-        return zip(self.keys, self.values)
-
-    def get_extremes(self):
-        return self.keys[0], self.values[-1]
-
-class IMData:
-    """Fast search in intervals (begin), (end, associated data)"""
-    def __init__(self, keys, values):
-        self.keys = keys
-        self.values = values
-
-    def __getitem__(self, x):
-        idx = bisect(self.keys, x) - 1
-        begin =
self.keys[idx] - end, data = self.values[idx] - if begin <= x < end: - return data - else: - return -1 - - def contains(self, x, size): - idx = bisect(self.keys, x) - 1 - begin = self.keys[idx] - end, data = self.values[idx] - if not(begin <= x < end) or x + size >= end: - return -1 - else: - return data - - def get_values(self): - return zip(self.keys, self.values) - - def get_extremes(self): - return self.keys[0], self.values[-1][0] - -class IMOffsets: - """Fast search in intervals (begin), (end, associated offset)""" - def __init__(self, keys, values): - self.keys = keys - self.values = values - - def __getitem__(self, x): - idx = bisect(self.keys, x) - 1 - begin = self.keys[idx] - end, data = self.values[idx] - if begin <= x < end: - return x - begin + data - else: - return -1 - - def contains(self, x, size): - """Return the maximum size and the list of intervals""" - idx = bisect(self.keys, x) - 1 - begin = self.keys[idx] - end, data = self.values[idx] - if not(begin <= x < end): - return 0, [] - - intervals = [(x, min(end - x, size), x - begin + data)] - if end - x >= size: - return size, intervals - - # The address space requested is bigger than a single interval - start = end - remaining = size - (end - x) - idx += 1 - print(start, remaining, idx) - while idx < len(self.values): - begin = self.keys[idx] - end, data = self.values[idx] - - # Virtual addresses must be contigous - if begin != start: - return size - remaining, intervals - - interval_size = min(end - begin, remaining) - intervals.append((start, interval_size, data)) - remaining -= interval_size - if not remaining: - return size, intervals - start += interval_size - idx += 1 - - def get_values(self): - return zip(self.keys, self.values) - - def get_extremes(self): - return self.keys[0], self.values[-1][0] - - -class IMOverlapping: - """Fast search in overlapping intervals (begin), (end, [associated - offsets])""" - - def __init__(self, intervals): - limit2changes = defaultdict(lambda: ([], [])) - for idx, (l, r, v) in enumerate(intervals): - assert l < r - limit2changes[l][0].append(v) - limit2changes[r][1].append(v) - self.limits, changes = zip(*sorted(limit2changes.items())) - - self.results = [[]] - s = set() - offsets = {} - res = [] - for idx, (arrivals, departures) in enumerate(changes): - - s.difference_update(departures) - for i in departures: - offsets.pop(i) - - for i in s: - offsets[i] += (self.limits[idx] - self.limits[idx - 1]) - - s.update(arrivals) - for i in arrivals: - offsets[i] = 0 - - res.clear() - for k,v in offsets.items(): - res.extend([i + v for i in k]) - self.results.append(res.copy()) - - def __getitem__(self, x): - idx = bisect(self.limits, x) - k = x - self.limits[idx - 1] - return [k + p for p in self.results[idx]] - - def get_values(self): - return zip(self.limits, self.results) - - -class ELFDump: - def __init__(self, elf_filename): - self.filename = elf_filename - self.machine_data = {} - self.p2o = None # Physical to RAM (ELF offset) - self.o2p = None # RAM (ELF offset) to Physical - self.p2mmd = None # Physical to Memory Mapped Devices (ELF offset) - self.elf_buf = np.zeros(0, dtype=np.byte) - self.elf_filename = elf_filename - - with open(self.elf_filename, "rb") as elf_fd: - - # Load the ELF in memory - self.elf_buf = np.fromfile(elf_fd, dtype=np.byte) - elf_fd.seek(0) - - # Parse the ELF file - self.__read_elf_file(elf_fd) - - def __read_elf_file(self, elf_fd): - """Parse the dump in ELF format""" - o2p_list = [] - p2o_list = [] - p2mmd_list = [] - elf_file = ELFFile(elf_fd) - - for 
segm in elf_file.iter_segments(): - - # NOTES - if isinstance(segm, NoteSegment): - for note in segm.iter_notes(): - - # Ignore NOTE genrated by other softwares - if note["n_name"] != "FOSSIL": - continue - - # At moment only one type of note - if note["n_type"] != 0xdeadc0de: - continue - - # Suppose only one deadcode note - self.machine_data = json.loads(note["n_desc"].rstrip("\x00")) - self.machine_data["Endianness"] = "little" if elf_file.header["e_ident"].EI_DATA == "ELFDATA2LSB" else "big" - self.machine_data["Architecture"] = "_".join(elf_file.header["e_machine"].split("_")[1:]) - else: - # Fill arrays needed to translate physical addresses to file offsets - r_start = segm["p_vaddr"] - r_end = r_start + segm["p_memsz"] - - if segm["p_filesz"]: - p_offset = segm["p_offset"] - p2o_list.append((r_start, (r_end, p_offset))) - o2p_list.append((p_offset, (p_offset + (r_end - r_start), r_start))) - else: - # device_name = "" # UNUSED - for device in self.machine_data["MemoryMappedDevices"]: # Possible because NOTES always the first segment - if device[0] == r_start: - # device_name = device[1] # UNUSED - break - p2mmd_list.append((r_start, r_end)) - - # Debug - # self.p2o_list = p2o_list - # self.o2p_list = o2p_list - # self.p2mmd_list = p2mmd_list - - # Compact intervals - p2o_list = self._compact_intervals(p2o_list) - o2p_list = self._compact_intervals(o2p_list) - p2mmd_list = self._compact_intervals_simple(p2mmd_list) - - self.p2o = IMOffsets(*list(zip(*sorted(p2o_list)))) - self.o2p = IMOffsets(*list(zip(*sorted(o2p_list)))) - self.p2mmd = IMSimple(*list(zip(*sorted(p2mmd_list)))) - - def _compact_intervals_simple(self, intervals): - """Compact intervals if pointer values are contiguos""" - fused_intervals = [] - prev_begin = prev_end = -1 - for interval in intervals: - begin, end = interval - if prev_end == begin: - prev_end = end - else: - fused_intervals.append((prev_begin, prev_end)) - prev_begin = begin - prev_end = end - - if prev_begin != begin: - fused_intervals.append((prev_begin, prev_end)) - else: - fused_intervals.append((begin, end)) - - return fused_intervals[1:] - - def _compact_intervals(self, intervals): - """Compact intervals if pointer and pointed values are contigous""" - fused_intervals = [] - prev_begin = prev_end = prev_phy = -1 - for interval in intervals: - begin, (end, phy) = interval - if prev_end == begin and prev_phy + (prev_end - prev_begin) == phy: - prev_end = end - else: - fused_intervals.append((prev_begin, (prev_end, prev_phy))) - prev_begin = begin - prev_end = end - prev_phy = phy - - if prev_begin != begin: - fused_intervals.append((prev_begin, (prev_end, prev_phy))) - else: - fused_intervals.append((begin, (end, phy))) - - return fused_intervals[1:] - - def in_ram(self, paddr, size=1): - """Return True if the interval is completely in RAM""" - return self.p2o.contains(paddr, size)[0] == size - - def in_mmd(self, paddr, size=1): - """Return True if the interval is completely in Memory mapped devices space""" - return True if self.p2mmd.contains(paddr, size) != -1 else False - - def get_data(self, paddr, size): - """Return the data at physical address (interval)""" - size_available, intervals = self.p2o.contains(paddr, size) - if size_available != size: - return bytes() - - ret = bytearray() - for interval in intervals: - _, interval_size, offset = interval - ret.extend(self.elf_buf[offset:offset+interval_size].tobytes()) - - return ret - - def get_data_raw(self, offset, size=1): - """Return the data at the offset in the ELF (interval)""" - return 
self.elf_buf[offset:offset+size].tobytes() - - def get_machine_data(self): - """Return a dict containing machine configuration""" - return self.machine_data - - def get_ram_regions(self): - """Return all the RAM regions of the machine and the associated offset""" - return self.p2o.get_values() - - def get_mmd_regions(self): - """Return all the Memory mapped devices intervals of the machine and the associated offset""" - return self.p2mmd.get_values() - -def get_virtspace(phy, mmu_values): - """Return a virtspace from a physical one""" - architecture = phy.get_machine_data()["Architecture"].lower() - if "riscv" in architecture: - return RISCVTranslator.factory(phy, mmu_values) - elif "x86" in architecture or "386" in architecture: - return IntelTranslator.factory(phy, mmu_values) - else: - raise Exception("Unknown architecture") - -class AddressTranslator: - def __init__(self, dtb, phy): - self.dtb = dtb - self.phy = phy - - # Set machine specifics - if self.wordsize == 4: - self.word_type = np.uint32 - if self.phy.machine_data["Endianness"] == "big": - self.word_fmt = ">u4" - else: - self.word_fmt = " physical mappings""" - - table = self.phy.get_data(table_addr, self.table_sizes[lvl]) - if not table: - print(f"Table {hex(table_addr)} size:{self.table_sizes[lvl]} at level {lvl} not in RAM") - return - - for index, entry in enumerate(iter_unpack(self.unpack_fmt, table)): - is_valid, pmask, phy_addr, page_size = self._read_entry(index, entry[0], lvl) - - if not is_valid: - continue - - virt_addr = prefix | (index << self.shifts[lvl]) - pmask = upmask + pmask - - if (lvl == self.total_levels - 1) or page_size: # Last radix level or Leaf - - # Ignore pages not in RAM (some OSs map more RAM than available) and not memory mapped devices - in_ram = self.phy.in_ram(phy_addr, page_size) - in_mmd = self.phy.in_mmd(phy_addr, page_size) - if not in_ram and not in_mmd: - continue - - permissions = self._reconstruct_permissions(pmask) - virt_addr = self._finalize_virt_addr(virt_addr, permissions) - mapping[permissions].append((virt_addr, page_size, phy_addr, in_mmd)) - - # Add only RAM address to the reverse translation P2V - if in_ram and not in_mmd: - if permissions not in reverse_mapping: - reverse_mapping[permissions] = defaultdict(list) - reverse_mapping[permissions][(phy_addr, page_size)].append(virt_addr) - else: - # Lower level entry - self._explore_radixtree(phy_addr, mapping, reverse_mapping, lvl=lvl+1, prefix=virt_addr, upmask=pmask) - - def _compact_intervals_virt_offset(self, intervals): - """Compact intervals if virtual addresses and offsets values are - contigous (virt -> offset)""" - fused_intervals = [] - prev_begin = prev_end = prev_offset = -1 - for interval in intervals: - begin, end, phy, _ = interval - - offset = self.phy.p2o[phy] - if offset == -1: - continue - - if prev_end == begin and prev_offset + (prev_end - prev_begin) == offset: - prev_end = end - else: - fused_intervals.append((prev_begin, (prev_end, prev_offset))) - prev_begin = begin - prev_end = end - prev_offset = offset - - if prev_begin != begin: - fused_intervals.append((prev_begin, (prev_end, prev_offset))) - else: - offset = self.phy.p2o[phy] - if offset == -1: - print(f"ERROR!! 
{phy}") - else: - fused_intervals.append((begin, (end, offset))) - return fused_intervals[1:] - - def _compact_intervals_permissions(self, intervals): - """Compact intervals if virtual addresses are contigous and permissions are equals""" - fused_intervals = [] - prev_begin = prev_end = -1 - prev_pmask = (0, 0) - for interval in intervals: - begin, end, _, pmask = interval - if prev_end == begin and prev_pmask == pmask: - prev_end = end - else: - fused_intervals.append((prev_begin, (prev_end, prev_pmask))) - prev_begin = begin - prev_end = end - prev_pmask = pmask - - if prev_begin != begin: - fused_intervals.append((prev_begin, (prev_end, prev_pmask))) - else: - fused_intervals.append((begin, (end, pmask))) - - return fused_intervals[1:] - - def _reconstruct_mappings(self, table_addr, upmask): - # Explore the radix tree - mapping = defaultdict(list) - reverse_mapping = {} - self._explore_radixtree(table_addr, mapping, reverse_mapping, upmask=upmask) - - # Needed for ELF virtual mapping reconstruction - self.reverse_mapping = reverse_mapping - self.mapping = mapping - - # Collect all intervals (start, end+1, phy_page, pmask) - intervals = [] - for pmask, mapping_p in mapping.items(): - if pmask[1] == 0: # Ignore user not accessible pages - print(pmask) - continue - intervals.extend([(x[0], x[0]+x[1], x[2], pmask) for x in mapping_p if not x[3]]) # Ignore MMD - intervals.sort() - - if not intervals: - raise Exception - # Fuse intervals in order to reduce the number of elements to speed up - fused_intervals_v2o = self._compact_intervals_virt_offset(intervals) - fused_intervals_permissions = self._compact_intervals_permissions(intervals) - - # Offset to virtual is impossible to compact in a easy way due to the - # multiple-to-one mapping. We order the array and use bisection to find - # the possible results and a partial - intervals_o2v = [] - for pmasks, d in reverse_mapping.items(): - if pmasks[1] != 0: # Ignore user accessible pages - continue - for k, v in d.items(): - # We have to translate phy -> offset - offset = self.phy.p2o[k[0]] - if offset == -1: # Ignore unresolvable pages - continue - intervals_o2v.append((offset, k[1]+offset, tuple(v))) - intervals_o2v.sort() - - # Fill resolution objects - self.v2o = IMOffsets(*list(zip(*fused_intervals_v2o))) - self.o2v = IMOverlapping(intervals_o2v) - self.pmasks = IMData(*list(zip(*fused_intervals_permissions))) - - def export_virtual_memory_elf(self, elf_filename): - """Create an ELF file containg the virtual address space of the process""" - with open(elf_filename, "wb") as elf_fd: - # Create the ELF header and write it on the file - machine_data = self.phy.get_machine_data() - endianness = machine_data["Endianness"] - machine = machine_data["Architecture"].lower() - - # Create ELF main header - if "aarch64" in machine: - e_machine = 0xB7 - elif "arm" in machine: - e_machine = 0x28 - elif "riscv" in machine: - e_machine = 0xF3 - elif "x86_64" in machine: - e_machine = 0x3E - elif "386" in machine: - e_machine = 0x03 - else: - raise Exception("Unknown architecture") - - e_ehsize = 0x40 - e_phentsize = 0x38 - elf_h = bytearray(e_ehsize) - elf_h[0x00:0x04] = b'\x7fELF' # Magic - elf_h[0x04] = 2 # Elf type - elf_h[0x05] = 1 if endianness == "little" else 2 # Endianness - elf_h[0x06] = 1 # Version - elf_h[0x10:0x12] = 0x4.to_bytes(2, endianness) # e_type - elf_h[0x12:0x14] = e_machine.to_bytes(2, endianness) # e_machine - elf_h[0x14:0x18] = 0x1.to_bytes(4, endianness) # e_version - elf_h[0x34:0x36] = e_ehsize.to_bytes(2, endianness) # e_ehsize 
- elf_h[0x36:0x38] = e_phentsize.to_bytes(2, endianness) # e_phentsize - elf_fd.write(elf_h) - - # For each pmask try to compact intervals in order to reduce the number of segments - intervals = defaultdict(list) - for (kpmask, pmask), intervals_list in self.mapping.items(): - print(kpmask, pmask) - - if pmask == 0: # Ignore pages not accessible by the process - continue - - intervals[pmask].extend([(x[0], x[0]+x[1], x[2]) for x in intervals_list if not x[3]]) # Ignore MMD - intervals[pmask].sort() - - if len(intervals[pmask]) == 0: - intervals.pop(pmask) - continue - - # Compact them - fused_intervals = [] - prev_begin = prev_end = prev_offset = -1 - for interval in intervals[pmask]: - begin, end, phy = interval - - offset = self.phy.p2o[phy] - if offset == -1: - continue - - if prev_end == begin and prev_offset + (prev_end - prev_begin) == offset: - prev_end = end - else: - fused_intervals.append([prev_begin, prev_end, prev_offset]) - prev_begin = begin - prev_end = end - prev_offset = offset - - if prev_begin != begin: - fused_intervals.append([prev_begin, prev_end, prev_offset]) - else: - offset = self.phy.p2o[phy] - if offset == -1: - print(f"ERROR!! {phy}") - else: - fused_intervals.append([begin, end, offset]) - intervals[pmask] = sorted(fused_intervals[1:], key=lambda x: x[1] - x[0], reverse=True) - - # Write segments in the new file and fill the program header - p_offset = len(elf_h) - offset2p_offset = {} # Slow but more easy to implement (best way: a tree sort structure able to be updated) - e_phnum = 0 - - for pmask, interval_list in intervals.items(): - e_phnum += len(interval_list) - for idx, interval in enumerate(interval_list): - begin, end, offset = interval - size = end - begin - if offset not in offset2p_offset: - elf_fd.write(self.phy.get_data_raw(offset, size)) - if not self.phy.get_data_raw(offset, size): - print(hex(offset), hex(size)) - new_offset = p_offset - p_offset += size - for page_idx in range(0, size, self.minimum_page): - offset2p_offset[offset + page_idx] = new_offset + page_idx - else: - new_offset = offset2p_offset[offset] - interval_list[idx].append(new_offset) # Assign the new offset in the dest file - - # Create the program header containing all the segments (ignoring not in RAM pages) - e_phoff = elf_fd.tell() - p_header = bytes() - for pmask, interval_list in intervals.items(): - for begin, end, offset, p_offset in interval_list: - p_filesz = end - begin - - # Back convert offset to physical page - p_addr = self.phy.o2p[offset] - assert p_addr != -1 - - segment_entry = bytearray(e_phentsize) - segment_entry[0x00:0x04] = 0x1.to_bytes(4, endianness) # p_type - segment_entry[0x04:0x08] = pmask.to_bytes(4, endianness) # p_flags - segment_entry[0x10:0x18] = begin.to_bytes(8, endianness) # p_vaddr - segment_entry[0x18:0x20] = p_addr.to_bytes(8, endianness) # p_paddr Original physical address - segment_entry[0x28:0x30] = p_filesz.to_bytes(8, endianness) # p_memsz - segment_entry[0x08:0x10] = p_offset.to_bytes(8, endianness) # p_offset - segment_entry[0x20:0x28] = p_filesz.to_bytes(8, endianness) # p_filesz - - p_header += segment_entry - - # Write the segment header - elf_fd.write(p_header) - s_header_pos = elf_fd.tell() # Last position written (used if we need to write segment header) - - # Modify the ELF header to point to program header - elf_fd.seek(0x20) - elf_fd.write(e_phoff.to_bytes(8, endianness)) # e_phoff - - # If we have more than 65535 segments we have create a special Section entry contains the - # number of program entry (as specified in 
ELF64 specifications) - if e_phnum < 65536: - elf_fd.seek(0x38) - elf_fd.write(e_phnum.to_bytes(2, endianness)) # e_phnum - else: - elf_fd.seek(0x28) - elf_fd.write(s_header_pos.to_bytes(8, endianness)) # e_shoff - elf_fd.seek(0x38) - elf_fd.write(0xFFFF.to_bytes(2, endianness)) # e_phnum - elf_fd.write(0x40.to_bytes(2, endianness)) # e_shentsize - elf_fd.write(0x1.to_bytes(2, endianness)) # e_shnum - - section_entry = bytearray(0x40) - section_entry[0x2C:0x30] = e_phnum.to_bytes(4, endianness) # sh_info - elf_fd.seek(s_header_pos) - elf_fd.write(section_entry) - - -class IntelTranslator(AddressTranslator): - @staticmethod - def derive_mmu_settings(mmu_class, regs_dict, mphy): - if mmu_class is IntelAMD64: - dtb = ((regs_dict["cr3"] >> 12) & ((1 << (mphy - 12)) - 1)) << 12 - - elif mmu_class is IntelIA32: - dtb = ((regs_dict["cr3"] >> 12) & (1 << 20) - 1) << 12 - mphy = min(mphy, 40) - - else: - raise NotImplementedError - - return {"dtb": dtb, - "wp": True, - "ac": False, - "nxe": True, - "smep": False, - "smap": False, - "mphy": mphy - } - - @staticmethod - def derive_translator_class(mmu_mode): - if mmu_mode == "ia64": - return IntelAMD64 - elif mmu_mode == "pae": - return NotImplementedError - elif mmu_mode == "ia32": - return IntelIA32 - else: - raise NotImplementedError - - @staticmethod - def factory(phy, mmu_values): - machine_data = phy.get_machine_data() - mmu_mode = machine_data["MMUMode"] - mphy = machine_data["CPUSpecifics"]["MAXPHYADDR"] - - translator_c = IntelTranslator.derive_translator_class(mmu_mode) - mmu_settings = IntelTranslator.derive_mmu_settings(translator_c, mmu_values, mphy) - return translator_c(phy=phy, **mmu_settings) - - - def __init__(self, dtb, phy, mphy, wp=False, ac=False, nxe=False, smap=False, smep=False): - super(IntelTranslator, self).__init__(dtb, phy) - self.mphy = mphy - self.wp = wp - self.ac = ac # UNUSED by Fossil - self.smap = smap - self.nxe = nxe - self.smep = smep - self.minimum_page = 0x1000 - - print("Creating resolution trees...") - self._reconstruct_mappings(self.dtb, upmask=[[False, True, True]]) - - def _finalize_virt_addr(self, virt_addr, permissions): - return virt_addr - - -class IntelIA32(IntelTranslator): - def __init__(self, dtb, phy, mphy, wp=True, ac=False, nxe=False, smap=False, smep=False): - self.unpack_fmt = "> 12) & ((1 << 20) - 1)) << 12 - return True, perms_flags, addr, 0 - - # Leaf - else: - if lvl == 0: - addr = (((entry >> 13) & ((1 << (self.mphy - 32)) - 1)) << 32) | (((entry >> 22) & ((1 << 10) - 1)) << 22) - else: - addr = ((entry >> 12) & ((1 << 20) - 1)) << 12 - return True, perms_flags, addr, 1 << self.shifts[lvl] - - def _reconstruct_permissions(self, pmask): - k_flags, w_flags, _ = zip(*pmask) - - # Kernel page in user mode - if any(k_flags): - r = True - w = all(w_flags) if self.wp else True - return r << 2 | w << 1 | 1, 0 - - # User page in user mode - else: - r = True - w = all(w_flags) - return 0, r << 2 | w << 1 | 1 - -class IntelAMD64(IntelTranslator): - def __init__(self, dtb, phy, mphy, wp=True, ac=False, nxe=True, smap=False, smep=False): - self.unpack_fmt = "> 12) & ((1 << (self.mphy - 12)) - 1)) << 12 - return True, perms_flags, addr, 0 - - # Leaf - else: - addr = ((entry >> self.shifts[lvl]) & ((1 << (self.mphy - self.shifts[lvl])) - 1)) << self.shifts[lvl] - return True, perms_flags, addr, 1 << self.shifts[lvl] - - def _reconstruct_permissions(self, pmask): - k_flags, w_flags, x_flags = zip(*pmask) - - # Kernel page in user mode - if any(k_flags): - r = True - w = all(w_flags) if self.wp else 
True - x = all(x_flags) if self.nxe else True - - return r << 2 | w << 1 | int(x), 0 - - # User page in user mode - else: - r = True - w = all(w_flags) - x = all(x_flags) if self.nxe else True - - return 0, r << 2 | w << 1 | int(x) - - def _finalize_virt_addr(self, virt_addr, permissions): - # Canonical address form - if virt_addr & 0x800000000000: - return self.prefix | virt_addr - else: - return virt_addr - - -class RISCVTranslator(AddressTranslator): - @staticmethod - def derive_mmu_settings(mmu_class, regs_dict): - - dtb = regs_dict["satp"] - return {"dtb": dtb, - "Sum": False, - "mxr": False - } - - @staticmethod - def derive_translator_class(mmu_mode): - if mmu_mode == "sv39": - return RISCVSV39 - else: - return RISCVSV32 - @staticmethod - def factory(phy, mmu_values): - machine_data = phy.get_machine_data() - mmu_mode = machine_data["MMUMode"] - translator_c = RISCVTranslator.derive_translator_class(mmu_mode) - mmu_settings = RISCVTranslator.derive_mmu_settings(translator_c, mmu_values) - return translator_c(phy=phy, **mmu_settings) - - - def __init__(self, dtb, phy, Sum=True, mxr=True): - super(RISCVTranslator, self).__init__(dtb, phy) - self.Sum = Sum - self.mxr = mxr - self.minimum_page = 0x1000 - - print("Creating resolution trees...") - self._reconstruct_mappings(self.dtb, upmask=[[False, True, True, True]]) - - def _finalize_virt_addr(self, virt_addr, permissions): - return virt_addr - - def _reconstruct_permissions(self, pmask): - k_flag, r_flag, w_flag, x_flag = pmask[-1] # No hierarchy - - r = r_flag - if self.mxr: - r |= x_flag - - w = w_flag - x = x_flag - - # Kernel page in user mode - if k_flag: - return r << 2 | w << 1 | int(x), 0 - - # User page in user mode - else: - return 0, r << 2 | w << 1 | int(x) - - -class RISCVSV32(RISCVTranslator): - def __init__(self, dtb, phy, Sum, mxr): - self.unpack_fmt = "> 10) & ((1 << 22) - 1)) << 12 - # Leaf - if r or w or x or lvl == 1: - return True, perms_flags, addr, 1 << self.shifts[lvl] - else: - # Upper tables pointers - return True, perms_flags, addr, 0 - - -class RISCVSV39(RISCVTranslator): - def __init__(self, dtb, phy, Sum, mxr): - self.unpack_fmt = "> 10) & ((1 << 44) - 1)) << 12 - # Leaf - if r or w or x or lvl == 2: - return True, perms_flags, addr, 1 << self.shifts[lvl] - else: - # Upper tables pointers - return True, perms_flags, addr, 0 - -if __name__ == '__main__': - main() diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..ef92065 --- /dev/null +++ b/flake.lock @@ -0,0 +1,647 @@ +{ + "nodes": { + "cachix": { + "inputs": { + "devenv": "devenv_2", + "flake-compat": "flake-compat_2", + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "pre-commit-hooks": "pre-commit-hooks" + }, + "locked": { + "lastModified": 1712055811, + "narHash": "sha256-7FcfMm5A/f02yyzuavJe06zLa9hcMHsagE28ADcmQvk=", + "owner": "cachix", + "repo": "cachix", + "rev": "02e38da89851ec7fec3356a5c04bc8349cae0e30", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "cachix", + "type": "github" + } + }, + "devenv": { + "inputs": { + "cachix": "cachix", + "flake-compat": "flake-compat_4", + "nix": "nix_2", + "nixpkgs": "nixpkgs_2", + "pre-commit-hooks": "pre-commit-hooks_2" + }, + "locked": { + "lastModified": 1713005873, + "narHash": "sha256-3DFCO/hK8h0Rs14t7uPr95gVHtKKADkjBh+JjSon8Aw=", + "owner": "cachix", + "repo": "devenv", + "rev": "8e882058b4602b70093d1fbff57755db09e89f11", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "devenv", + "type": "github" + } + }, + "devenv_2": { + "inputs": { + 
"flake-compat": [ + "devenv", + "cachix", + "flake-compat" + ], + "nix": "nix", + "nixpkgs": "nixpkgs", + "poetry2nix": "poetry2nix", + "pre-commit-hooks": [ + "devenv", + "cachix", + "pre-commit-hooks" + ] + }, + "locked": { + "lastModified": 1708704632, + "narHash": "sha256-w+dOIW60FKMaHI1q5714CSibk99JfYxm0CzTinYWr+Q=", + "owner": "cachix", + "repo": "devenv", + "rev": "2ee4450b0f4b95a1b90f2eb5ffea98b90e48c196", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "python-rewrite", + "repo": "devenv", + "type": "github" + } + }, + "fenix": { + "inputs": { + "nixpkgs": "nixpkgs_3", + "rust-analyzer-src": "rust-analyzer-src" + }, + "locked": { + "lastModified": 1713081423, + "narHash": "sha256-ZIWbIbbNsJoOdd/8a0uvT6Mn7s9x7noF9Bz8VEesc4Q=", + "owner": "nix-community", + "repo": "fenix", + "rev": "3a25d7927aa54299ef734b0d010c785857231183", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "fenix", + "type": "github" + } + }, + "flake-compat": { + "flake": false, + "locked": { + "lastModified": 1673956053, + "narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_2": { + "flake": false, + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_3": { + "flake": false, + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_4": { + "flake": false, + "locked": { + "lastModified": 1696426674, + "narHash": "sha256-kvjfFW7WAETZlt09AgDn1MrtKzP7t90Vf7vypd3OL1U=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "0f9255e01c2351cc7d116c072cb317785dd33b33", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-compat_5": { + "flake": false, + "locked": { + "lastModified": 1673956053, + "narHash": "sha256-4gtG9iQuiKITOjNQQeQIpoIB6b16fm+504Ch3sNKLd8=", + "owner": "edolstra", + "repo": "flake-compat", + "rev": "35bb57c0c8d8b62bbfd284272c928ceb64ddbde9", + "type": "github" + }, + "original": { + "owner": "edolstra", + "repo": "flake-compat", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1689068808, + "narHash": "sha256-6ixXo3wt24N/melDWjq70UuHQLxGV8jZvooRanIHXw0=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "919d646de7be200f3bf08cb76ae1f09402b6f9b4", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_2": { + "inputs": { + "systems": "systems_2" + }, + "locked": { + "lastModified": 1701680307, + "narHash": "sha256-kAuep2h5ajznlPMD9rnQyffWG8EM/C73lejGofXvdM8=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "4022d587cbbfd70fe950c1e2083a02621806a725", + "type": "github" + }, + "original": { + "owner": 
"numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "flake-utils_3": { + "inputs": { + "systems": "systems_3" + }, + "locked": { + "lastModified": 1710146030, + "narHash": "sha256-SZ5L6eA7HJ/nmkzGG7/ISclqe6oZdOZTNoesiInkXPQ=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "b1d9ab70662946ef0850d488da1c9019f3a9752a", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "gitignore": { + "inputs": { + "nixpkgs": [ + "devenv", + "cachix", + "pre-commit-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1703887061, + "narHash": "sha256-gGPa9qWNc6eCXT/+Z5/zMkyYOuRZqeFZBDbopNZQkuY=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "43e1aa1308018f37118e34d3a9cb4f5e75dc11d5", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "gitignore_2": { + "inputs": { + "nixpkgs": [ + "devenv", + "pre-commit-hooks", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1709087332, + "narHash": "sha256-HG2cCnktfHsKV0s4XW83gU3F57gaTljL9KNSuG6bnQs=", + "owner": "hercules-ci", + "repo": "gitignore.nix", + "rev": "637db329424fd7e46cf4185293b9cc8c88c95394", + "type": "github" + }, + "original": { + "owner": "hercules-ci", + "repo": "gitignore.nix", + "type": "github" + } + }, + "nix": { + "inputs": { + "flake-compat": "flake-compat", + "nixpkgs": [ + "devenv", + "cachix", + "devenv", + "nixpkgs" + ], + "nixpkgs-regression": "nixpkgs-regression" + }, + "locked": { + "lastModified": 1708577783, + "narHash": "sha256-92xq7eXlxIT5zFNccLpjiP7sdQqQI30Gyui2p/PfKZM=", + "owner": "domenkozar", + "repo": "nix", + "rev": "ecd0af0c1f56de32cbad14daa1d82a132bf298f8", + "type": "github" + }, + "original": { + "owner": "domenkozar", + "ref": "devenv-2.21", + "repo": "nix", + "type": "github" + } + }, + "nix-github-actions": { + "inputs": { + "nixpkgs": [ + "devenv", + "cachix", + "devenv", + "poetry2nix", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1688870561, + "narHash": "sha256-4UYkifnPEw1nAzqqPOTL2MvWtm3sNGw1UTYTalkTcGY=", + "owner": "nix-community", + "repo": "nix-github-actions", + "rev": "165b1650b753316aa7f1787f3005a8d2da0f5301", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "nix-github-actions", + "type": "github" + } + }, + "nix_2": { + "inputs": { + "flake-compat": "flake-compat_5", + "nixpkgs": [ + "devenv", + "nixpkgs" + ], + "nixpkgs-regression": "nixpkgs-regression_2" + }, + "locked": { + "lastModified": 1712911606, + "narHash": "sha256-BGvBhepCufsjcUkXnEEXhEVjwdJAwPglCC2+bInc794=", + "owner": "domenkozar", + "repo": "nix", + "rev": "b24a9318ea3f3600c1e24b4a00691ee912d4de12", + "type": "github" + }, + "original": { + "owner": "domenkozar", + "ref": "devenv-2.21", + "repo": "nix", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1692808169, + "narHash": "sha256-x9Opq06rIiwdwGeK2Ykj69dNc2IvUH1fY55Wm7atwrE=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "9201b5ff357e781bf014d0330d18555695df7ba8", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixpkgs-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-regression": { + "locked": { + "lastModified": 1643052045, + "narHash": "sha256-uGJ0VXIhWKGXxkeNnq4TvV3CIOkUJ3PAoLZ3HMzNVMw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": 
"215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + } + }, + "nixpkgs-regression_2": { + "locked": { + "lastModified": 1643052045, + "narHash": "sha256-uGJ0VXIhWKGXxkeNnq4TvV3CIOkUJ3PAoLZ3HMzNVMw=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + }, + "original": { + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "215d4d0fd80ca5163643b03a33fde804a29cc1e2", + "type": "github" + } + }, + "nixpkgs-stable": { + "locked": { + "lastModified": 1704874635, + "narHash": "sha256-YWuCrtsty5vVZvu+7BchAxmcYzTMfolSPP5io8+WYCg=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "3dc440faeee9e889fe2d1b4d25ad0f430d449356", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-23.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs-stable_2": { + "locked": { + "lastModified": 1710695816, + "narHash": "sha256-3Eh7fhEID17pv9ZxrPwCLfqXnYP006RKzSs0JptsN84=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "614b4613980a522ba49f0d194531beddbb7220d3", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-23.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_2": { + "locked": { + "lastModified": 1710796454, + "narHash": "sha256-lQlICw60RhH8sHTDD/tJiiJrlAfNn8FDI9c+7G2F0SE=", + "owner": "cachix", + "repo": "devenv-nixpkgs", + "rev": "06fb0f1c643aee3ae6838dda3b37ef0abc3c763b", + "type": "github" + }, + "original": { + "owner": "cachix", + "ref": "rolling", + "repo": "devenv-nixpkgs", + "type": "github" + } + }, + "nixpkgs_3": { + "locked": { + "lastModified": 1712791164, + "narHash": "sha256-3sbWO1mbpWsLepZGbWaMovSO7ndZeFqDSdX0hZ9nVyw=", + "owner": "nixos", + "repo": "nixpkgs", + "rev": "1042fd8b148a9105f3c0aca3a6177fd1d9360ba5", + "type": "github" + }, + "original": { + "owner": "nixos", + "ref": "nixos-unstable", + "repo": "nixpkgs", + "type": "github" + } + }, + "nixpkgs_4": { + "locked": { + "lastModified": 1712867921, + "narHash": "sha256-edTFV4KldkCMdViC/rmpJa7oLIU8SE/S35lh/ukC7bg=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "51651a540816273b67bc4dedea2d37d116c5f7fe", + "type": "github" + }, + "original": { + "owner": "NixOS", + "ref": "nixos-23.11", + "repo": "nixpkgs", + "type": "github" + } + }, + "poetry2nix": { + "inputs": { + "flake-utils": "flake-utils", + "nix-github-actions": "nix-github-actions", + "nixpkgs": [ + "devenv", + "cachix", + "devenv", + "nixpkgs" + ] + }, + "locked": { + "lastModified": 1692876271, + "narHash": "sha256-IXfZEkI0Mal5y1jr6IRWMqK8GW2/f28xJenZIPQqkY0=", + "owner": "nix-community", + "repo": "poetry2nix", + "rev": "d5006be9c2c2417dafb2e2e5034d83fabd207ee3", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "poetry2nix", + "type": "github" + } + }, + "pre-commit-hooks": { + "inputs": { + "flake-compat": "flake-compat_3", + "flake-utils": "flake-utils_2", + "gitignore": "gitignore", + "nixpkgs": [ + "devenv", + "cachix", + "nixpkgs" + ], + "nixpkgs-stable": "nixpkgs-stable" + }, + "locked": { + "lastModified": 1708018599, + "narHash": "sha256-M+Ng6+SePmA8g06CmUZWi1AjG2tFBX9WCXElBHEKnyM=", + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "rev": "5df5a70ad7575f6601d91f0efec95dd9bc619431", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "type": "github" + } + }, + "pre-commit-hooks_2": { + "inputs": { + "flake-compat": [ + "devenv", + "flake-compat" + ], + "flake-utils": "flake-utils_3", + "gitignore": "gitignore_2", + "nixpkgs": [ + 
"devenv", + "nixpkgs" + ], + "nixpkgs-stable": "nixpkgs-stable_2" + }, + "locked": { + "lastModified": 1712897695, + "narHash": "sha256-nMirxrGteNAl9sWiOhoN5tIHyjBbVi5e2tgZUgZlK3Y=", + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "rev": "40e6053ecb65fcbf12863338a6dcefb3f55f1bf8", + "type": "github" + }, + "original": { + "owner": "cachix", + "repo": "pre-commit-hooks.nix", + "type": "github" + } + }, + "root": { + "inputs": { + "devenv": "devenv", + "fenix": "fenix", + "nixpkgs": "nixpkgs_4", + "systems": "systems_4" + } + }, + "rust-analyzer-src": { + "flake": false, + "locked": { + "lastModified": 1713048903, + "narHash": "sha256-Yw/V2S+yRkCgDOBgKbdD00m2voAc+UIrTx2Xqcry3Ns=", + "owner": "rust-lang", + "repo": "rust-analyzer", + "rev": "beb205f347d676ec3dc6e6d13793a7c814f0a417", + "type": "github" + }, + "original": { + "owner": "rust-lang", + "ref": "nightly", + "repo": "rust-analyzer", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_2": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_3": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + }, + "systems_4": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + "owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..4bbf1d4 --- /dev/null +++ b/flake.nix @@ -0,0 +1,54 @@ +{ + inputs = { + nixpkgs.url = "github:NixOS/nixpkgs/nixos-23.11"; + systems.url = "github:nix-systems/default"; + devenv.url = "github:cachix/devenv"; + fenix.url = "github:nix-community/fenix"; + }; + + nixConfig = { + extra-trusted-public-keys = "devenv.cachix.org-1:w1cLUi8dv3hnoSPGAuibQv+f9TZLr6cv/Hm9XgU50cw="; + extra-substituters = "https://devenv.cachix.org"; + }; + + outputs = + { self + , nixpkgs + , devenv + , systems + , ... 
+ } @ inputs: + let + forEachSystem = nixpkgs.lib.genAttrs (import systems); + in + { + packages = forEachSystem (system: { + devenv-up = self.devShells.${system}.default.config.procfileScript; + }); + + devShells = + forEachSystem + (system: + let + pkgs = nixpkgs.legacyPackages.${system}; + in + { + default = devenv.lib.mkShell { + inherit inputs pkgs; + modules = [ + { + languages.python = { + enable = true; + venv = { + enable = true; + requirements = builtins.readFile ./requirements.txt; + }; + }; + + packages = with pkgs; [ python311Packages.numpy ]; + } + ]; + }; + }); + }; +} diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..d905c29 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,56 @@ +site_name: MMUShell +site_url: https://memoscopy.github.io/mmushell +site_description: Documentation for MMUShell + +repo_name: Memoscopy/mmushell +repo_url: https://github.com/Memoscopy/mmushell + +edit_uri: edit/main/docs/ + +theme: + name: material + features: + - search.suggest + - search.highlight + - search.share + - content.action.edit + - navigation.instant + - navigation.tabs + - navigation.tabs.sticky + - navigation.sections + - navigation.path + - navigation.top + icon: + repo: fontawesome/brands/github + +plugins: + - mkdocstrings + - search + - git-revision-date-localized: + enable_creation_date: true + - git-committers: + repository: Memoscopy/mmushell + branch: main + +markdown_extensions: + - admonition + - pymdownx.details + - pymdownx.superfences + +nav: + - index.md + - tutorials.md + - how-to-guides.md + - Reference: + - Index: reference/reference.md + - MMUShell: reference/mmushell.md + - Exporter: reference/exporter.md + - Architectures: + - Generic: reference/architectures/generic.md + - AArch64: reference/architectures/aarch64.md + - ARM: reference/architectures/arm.md + - RISC-V: reference/architectures/riscv.md + - Intel: reference/architectures/intel.md + - MIPS: reference/architectures/mips.md + - PowerPC: reference/architectures/ppc.md + - explanation.md diff --git a/mmushell.py b/mmushell.py deleted file mode 100755 index e0020bf..0000000 --- a/mmushell.py +++ /dev/null @@ -1,127 +0,0 @@ -#!/usr/bin/env python3 -import argparse -import logging -import importlib -import yaml -from cerberus import Validator - -# Set logging configuration -logger = logging.getLogger(__name__) - -# Schema for YAML configuration file -machine_yaml_schema = { - 'cpu': { - 'required': True, - 'type': 'dict', - 'schema': { - 'architecture': {'required': True, 'type': 'string', 'min': 1}, - 'endianness': {'required': True, 'type': 'string', 'min': 1}, - 'bits': {'required': True, 'type': 'integer', 'allowed': [32, 64]}, - 'processor_features': {'required': False, 'type': 'dict'}, - 'registers_values': { - 'required': False, - 'type': 'dict', - 'keysrules': {'type': 'string', 'min': 1}, - 'valuesrules': {'type': 'integer'} - } - } - }, - 'mmu': { - 'required': True, - 'type': 'dict', - 'schema': { - 'mode': {'required': True, 'type': 'string', 'min': 1} - } - }, - 'memspace': { - 'required': True, - 'type': 'dict', - 'schema': { - 'ram': { - 'required': True, - 'type': 'list', - 'minlength': 1, - 'schema': { - 'type': 'dict', - 'schema': { - 'start': {'required': True, 'type': 'integer', 'min': 0, 'max': 0xFFFFFFFFFFFFFFFF}, - 'end': {'required': True, 'type': 'integer', 'min': 0, 'max': 0xFFFFFFFFFFFFFFFF}, - 'dumpfile': {'required': True, 'type': 'string', 'min': 0} - } - } - }, - 'not_ram': { - 'required': True, - 'type': 'list', - 'minlength': 1, - 'schema': { - 'type': 
'dict',
-                    'schema': {
-                        'start': {'required': True, 'type': 'integer', 'min': 0, 'max': 0xFFFFFFFFFFFFFFFF},
-                        'end': {'required': True, 'type': 'integer', 'min': 0, 'max': 0xFFFFFFFFFFFFFFFF},
-                    }
-                }
-            }
-        }
-    }
-}
-
-def main():
-    # Parse arguments
-    parser = argparse.ArgumentParser()
-    parser.add_argument("MACHINE_CONFIG", help="YAML file describing the machine", type=argparse.FileType("r"))
-    parser.add_argument("--gtruth", help="Ground truth from QEMU registers", type=argparse.FileType("rb", 0), default=None)
-    parser.add_argument("--session", help="Data file of a previous MMUShell session", type=str, default=None)
-    parser.add_argument("--debug", help="Enable debug output", action="store_true", default=False)
-    args = parser.parse_args()
-
-    # Set logging system
-    fmt = "%(msg)s"
-    if args.debug:
-        logging.basicConfig(level=logging.DEBUG, format=fmt)
-    else:
-        logging.basicConfig(level=logging.INFO, format=fmt)
-
-    # Load the machine configuration YAML file
-    try:
-        machine_config = yaml.load(args.MACHINE_CONFIG, Loader=yaml.FullLoader)
-        args.MACHINE_CONFIG.close()
-    except Exception as e:
-        logger.fatal("Malformed YAML file: {}".format(e))
-        exit(1)
-
-    # Validate YAML schema
-    yaml_validator = Validator(allow_unknown=True)
-    if not yaml_validator.validate(machine_config, machine_yaml_schema):
-        logger.fatal("Invalid YAML file. Error:" + str(yaml_validator.errors))
-        exit(1)
-
-    # Create the Machine class
-    try:
-        architecture_module = importlib.import_module("architectures." + machine_config["cpu"]["architecture"])
-    except ModuleNotFoundError:
-        logger.fatal("Unkown architecture!")
-        exit(1)
-
-    # Create a Machine starting from the parsed configuration
-    machine = architecture_module.Machine.from_machine_config(machine_config)
-
-    # Launch the interactive shell
-    if args.gtruth:
-        shell = architecture_module.MMUShellGTruth(machine=machine)
-    else:
-        shell = architecture_module.MMUShell(machine=machine)
-
-    # Load ground truth (if passed)
-    if args.gtruth:
-        shell.load_gtruth(args.gtruth)
-
-    # Load previous data (if passed)
-    if args.session:
-        shell.reload_data_from_file(args.session)
-
-    shell.cmdloop()
-
-
-if __name__ == '__main__':
-    main()
diff --git a/mmushell/__init__.py b/mmushell/__init__.py
new file mode 100644
index 0000000..4e05a47
--- /dev/null
+++ b/mmushell/__init__.py
@@ -0,0 +1,8 @@
+"""
+OS-Agnostic memory forensics tool
+
+Modules exported by this package:
+
+- `mmushell`: main script that reconstructs virtual address spaces from a memory dump.
+- `exporter`: a proof of concept that supports preliminary analysis of a dump by exporting each virtual address space as a self-contained ELF core file.
+"""
diff --git a/mmushell/architectures/__init__.py b/mmushell/architectures/__init__.py
new file mode 100644
index 0000000..533a8f8
--- /dev/null
+++ b/mmushell/architectures/__init__.py
@@ -0,0 +1,13 @@
+"""
+This package contains the architecture-specific modules used by the `mmushell` script.
+
+Modules exported by this package:
+
+- `generic`: Architecture-agnostic generic implementation for CPU and memory management related functions and classes.
+- `aarch64`: AArch64-specific functions and classes.
+- `arm`: ARM-specific functions and classes.
+- `intel`: Intel-specific functions and classes.
+- `mips`: MIPS-specific functions and classes.
+- `ppc`: PowerPC-specific functions and classes.
+- `riscv`: RISC-V-specific functions and classes.
+""" diff --git a/architectures/aarch64.py b/mmushell/architectures/aarch64.py similarity index 66% rename from architectures/aarch64.py rename to mmushell/architectures/aarch64.py index eabf989..11f2fba 100644 --- a/architectures/aarch64.py +++ b/mmushell/architectures/aarch64.py @@ -1,3 +1,7 @@ +import logging +import portion +import multiprocessing as mp + from architectures.generic import Machine as MachineDefault from architectures.generic import CPU as CPUDefault from architectures.generic import PhysicalMemory as PhysicalMemoryDefault @@ -5,29 +9,29 @@ from architectures.generic import TableEntry, PageTable, MMURadix, PAS, RadixTree from architectures.generic import CPUReg, VAS from architectures.generic import MMU as MMUDefault -import logging -from collections import defaultdict, deque + from miasm.analysis.machine import Machine as MIASMMachine from miasm.core.bin_stream import bin_stream_vm from miasm.core.locationdb import LocationDB + +from more_itertools import divide +from collections import defaultdict, deque from prettytable import PrettyTable +from dataclasses import dataclass +from IPython import embed +from random import uniform +from struct import iter_unpack, unpack from time import sleep from tqdm import tqdm from copy import deepcopy, copy -from random import uniform -from struct import iter_unpack, unpack -from dataclasses import dataclass -import multiprocessing as mp -# import cProfile -import portion -from more_itertools import divide -from IPython import embed logger = logging.getLogger(__name__) -def _dummy_f(): # Workaround pickle defaultdict + +def _dummy_f(): # Workaround pickle defaultdict return defaultdict(set) + # For AArch64 TCR.TxSZ and TCR.TGx control the size of the address space in mode x (0 user, 1 kernel) and the size of the granule (the data # page). The structure of the radix tree is very complicated and depends on both TxSZ and TGx (see get_trees_struct()). 
The tree can have # a variable number of levels depending on TGx and TxSZ (the TTBRx_EL1 points to the first level available in the tree) @@ -55,14 +59,16 @@ class Data: class CPURegAArch64(CPUReg): @classmethod def get_register_obj(cls, reg_name, value): - return globals()[reg_name](value) + return globals()[reg_name](value) class TCR_EL1(CPURegAArch64): def is_valid(self, value): - if CPU.extract_bits(value, 6, 1) != 0 or \ - CPU.extract_bits(value, 35, 1) != 0 or \ - CPU.extract_bits(value, 59, 5) != 0: + if ( + CPU.extract_bits(value, 6, 1) != 0 + or CPU.extract_bits(value, 35, 1) != 0 + or CPU.extract_bits(value, 59, 5) != 0 + ): return False else: return True @@ -80,7 +86,12 @@ def __init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return self.t0sz == other.t0sz and self.t1sz == other.t1sz and self.tg0 == other.tg0 and self.tg1 == other.tg1 + return ( + self.t0sz == other.t0sz + and self.t1sz == other.t1sz + and self.tg0 == other.tg0 + and self.tg1 == other.tg1 + ) def count_fields_equals(self, other): tot = 0 @@ -152,7 +163,7 @@ def _get_tree_struct(self, granule, size_offset): elif 28 <= size_offset <= 38: t = (2, 1 << (42 - size_offset)) else: - t= (3, 1 << (53 - size_offset)) + t = (3, 1 << (53 - size_offset)) elif granule == 65536: if 12 <= size_offset <= 21: t = (1, 1 << (25 - size_offset)) @@ -167,8 +178,12 @@ def _get_tree_struct(self, granule, size_offset): def get_trees_struct(self): ret = {"kernel": None, "user": None} - ret["kernel"] = self._get_tree_struct(self.get_kernel_granule(), self.get_kernel_size_offset()) - ret["user"] = self._get_tree_struct(self.get_user_granule(), self.get_user_size_offset()) + ret["kernel"] = self._get_tree_struct( + self.get_kernel_granule(), self.get_kernel_size_offset() + ) + ret["user"] = self._get_tree_struct( + self.get_user_granule(), self.get_user_size_offset() + ) # WORKAROUND: some OS do not set correctly the register values, using invalid one on real hw... 
if ret["kernel"]["top_table_size"] == -1: @@ -214,7 +229,6 @@ def _get_radix_base(self, value): x = self._calculate_x() return CPU.extract_bits(value, x, 47 - x + 1) << x - def __init__(self, value): self.value = value if self.is_valid(value): @@ -231,24 +245,41 @@ def is_mmu_equivalent_to(self, other): def __repr__(self): return f"{self.reg_name} {hex(self.value)} => ASID:{hex(self.asid)}, Address:{hex(self.address)}, CnP: {self.cnp}" + class TTBR0_EL1(TTBR): reg_name = "TTBR0_EL1" mode = "user" + class TTBR1_EL1(TTBR): reg_name = "TTBR1_EL1" mode = "kernel" + ##################################################################### # 64 bit entries and page table ##################################################################### + class TEntry64(TableEntry): entry_size = 8 entry_name = "TEntry64" size = 0 - labels = ["Address:", "Attributes:", "Secure:", "Permissions:", "Shareability:", - "Accessed:", "Global:", "Block:", "Guarded:", "Dirty:", "Continous:", "Kernel exec:", "Exec:"] + labels = [ + "Address:", + "Attributes:", + "Secure:", + "Permissions:", + "Shareability:", + "Accessed:", + "Global:", + "Block:", + "Guarded:", + "Dirty:", + "Continous:", + "Kernel exec:", + "Exec:", + ] addr_fmt = "0x{:016x}" def __init__(self, address, lower_flags, upper_flags): @@ -261,22 +292,25 @@ def __hash__(self): def __repr__(self): e_resume = self.entry_resume_stringified() - return str([self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))]) + return str( + [self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))] + ) def entry_resume(self): - return [self.address, - self.extract_attributes(), - self.is_secure_entry(), - self.extract_permissions(), - self.extract_shareability(), - self.is_accessed_entry(), - self.is_global_entry(), - self.is_block_entry(), - self.is_guarded_entry(), - self.is_dirty_entry(), - self.is_continuous_entry(), - self.is_kernel_executable_entry(), - self.is_executable_entry() + return [ + self.address, + self.extract_attributes(), + self.is_secure_entry(), + self.extract_permissions(), + self.extract_shareability(), + self.is_accessed_entry(), + self.is_global_entry(), + self.is_block_entry(), + self.is_guarded_entry(), + self.is_dirty_entry(), + self.is_continuous_entry(), + self.is_kernel_executable_entry(), + self.is_executable_entry(), ] def entry_resume_stringified(self): @@ -295,28 +329,38 @@ def is_supervisor_entry(self): # Lower attributes (Block and Pages) def extract_attributes(self): return MMU.extract_bits(self.lower_flags, 0, 3) + def is_secure_entry(self): return not bool(MMU.extract_bits(self.lower_flags, 3, 1)) + def extract_permissions(self): return MMU.extract_bits(self.lower_flags, 4, 2) + def extract_shareability(self): return MMU.extract_bits(self.lower_flags, 6, 2) + def is_accessed_entry(self): return bool(MMU.extract_bits(self.lower_flags, 8, 1)) + def is_global_entry(self): return not bool(MMU.extract_bits(self.lower_flags, 9, 1)) + def is_block_entry(self): return not bool(MMU.extract_bits(self.lower_flags, 14, 1)) # Upper attributes (Block and Pages) def is_guarded_entry(self): return bool(MMU.extract_bits(self.upper_flags, 0, 1)) + def is_dirty_entry(self): return bool(MMU.extract_bits(self.upper_flags, 1, 1)) + def is_continuous_entry(self): return bool(MMU.extract_bits(self.upper_flags, 2, 1)) + def is_kernel_executable_entry(self): return not bool(MMU.extract_bits(self.upper_flags, 3, 1)) + def is_executable_entry(self): return not bool(MMU.extract_bits(self.upper_flags, 4, 1)) @@ -332,7 +376,15 @@ def 
get_permissions(self): permissions = self.extract_permissions() u = bool(permissions & 0x1) w = not bool(permissions & 0x2) - return (True, w, self.is_kernel_executable_entry(), u, u and w, self.is_executable_entry()) + return ( + True, + w, + self.is_kernel_executable_entry(), + u, + u and w, + self.is_executable_entry(), + ) + class PTP(TEntry64): entry_name = "PTP" @@ -344,28 +396,38 @@ def is_supervisor_entry(self): # Lower attributes def extract_attributes(self): return 0 + def is_secure_entry(self): return not bool(MMU.extract_bits(self.upper_flags, 13, 1)) + def extract_permissions(self): return MMU.extract_bits(self.upper_flags, 11, 2) + def extract_shareability(self): return 0 + def is_accessed_entry(self): return False + def is_global_entry(self): return False + def is_block_entry(self): return False # Upper attributes def is_guarded_entry(self): return False + def is_dirty_entry(self): return False + def is_continuous_entry(self): return False + def is_kernel_executable_entry(self): return not bool(MMU.extract_bits(self.upper_flags, 9, 1)) + def is_executable_entry(self): return not bool(MMU.extract_bits(self.upper_flags, 10, 1)) @@ -374,44 +436,69 @@ def get_permissions(self): u = not bool(permissions & 0x1) w = not bool(permissions & 0x2) - return (True, w, self.is_kernel_executable_entry(), u, u and w, self.is_executable_entry()) + return ( + True, + w, + self.is_kernel_executable_entry(), + u, + u and w, + self.is_executable_entry(), + ) + # Page table pointers class PTP_4KB(PTP): entry_name = "PTP_4KB" + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 12, 36) << 12 + class PTP_4KB_L0(PTP_4KB): entry_name = "PTP_4KB_L0" + + class PTP_4KB_L1(PTP_4KB): entry_name = "PTP_4KB_L1" + + class PTP_4KB_L2(PTP_4KB): entry_name = "PTP_4KB_L2" + class PTP_16KB(PTP): entry_name = "PTP_16KB" + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 14, 34) << 14 + class PTP_16KB_L0(PTP_16KB): entry_name = "PTP_16KB_L0" + + class PTP_16KB_L1(PTP_16KB): entry_name = "PTP_16KB_L1" + + class PTP_16KB_L2(PTP_16KB): entry_name = "PTP_16KB_L2" class PTP_64KB(PTP): entry_name = "PTP_64KB" + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 16, 32) << 16 + class PTP_64KB_L0(PTP_64KB): entry_name = "PTP_64KB_L0" + + class PTP_64KB_L1(PTP_64KB): entry_name = "PTP_64KB_L1" @@ -420,61 +507,93 @@ class PTP_64KB_L1(PTP_64KB): class PTBLOCK_L1_4KB(TEntry64): entry_name = "PTBLOCK_L1_4KB" size = 1024 * 1024 * 1024 + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 30, 18) << 30 + class PTBLOCK_L2_4KB(TEntry64): entry_name = "PTBLOCK_L2_4KB" size = 2 * 1024 * 1024 + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 21, 27) << 21 + class PTBLOCK_L2_16KB(TEntry64): entry_name = "PTBLOCK_L2_16KB" size = 32 * 1024 * 1024 + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 25, 23) << 25 + class PTBLOCK_L2_64KB(TEntry64): entry_name = "PTBLOCK_L2_64KB" size = 512 * 1024 * 1024 + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 29, 19) << 29 + # Page pointers class PTPAGE_4KB(TEntry64): entry_name = "PTPAGE_4KB" size = 4 * 1024 + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 12, 36) << 12 + class PTPAGE_16KB(TEntry64): entry_name = "PTPAGE_16KB" size = 16 * 1024 + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 14, 34) << 14 + class PTPAGE_64KB(TEntry64): entry_name = "PTPAGE_64KB" size = 64 * 1024 + @staticmethod def 
extract_addr(entry): - return MMU.extract_bits(entry, 12, 4) << 48 | MMU.extract_bits(entry, 16, 32) << 16 + return ( + MMU.extract_bits(entry, 12, 4) << 48 | MMU.extract_bits(entry, 16, 32) << 16 + ) + class ReservedEntry(TEntry64): entry_name = "Reserved" size = 0 + class PageTableAArch64(PageTable): entry_size = 8 - table_fields = ["Entry address", "Pointed address", "Attributes", "Secure", "Permsissions", "Shareability", - "Accessed", "Global", "Block", "Guarded", "Dirty", "Continous", "Kernel exec", "Exec", "Classes"] + table_fields = [ + "Entry address", + "Pointed address", + "Attributes", + "Secure", + "Permsissions", + "Shareability", + "Accessed", + "Global", + "Block", + "Guarded", + "Dirty", + "Continous", + "Kernel exec", + "Exec", + "Classes", + ] addr_fmt = "0x{:016x}" def __repr__(self): @@ -484,9 +603,13 @@ def __repr__(self): for entry_class in self.entries: for entry_idx, entry_obj in self.entries[entry_class].items(): entry_addr = self.address + (entry_idx * self.entry_size) - table.add_row([self.addr_fmt.format(entry_addr)] + entry_obj.entry_resume_stringified() + [entry_class.entry_name]) + table.add_row( + [self.addr_fmt.format(entry_addr)] + + entry_obj.entry_resume_stringified() + + [entry_class.entry_name] + ) - table.sortby="Entry address" + table.sortby = "Entry address" return str(table) @@ -497,8 +620,7 @@ class PhysicalMemory(PhysicalMemoryDefault): class CPU(CPUDefault): @classmethod def from_cpu_config(cls, cpu_config, **kwargs): - return CPUAArch64(cpu_config) - + return CPUAArch64(cpu_config) def __init__(self, features): super(CPU, self).__init__(features) @@ -517,7 +639,7 @@ def __init__(self, features): class CPUAArch64(CPU): def __init__(self, features): super(CPUAArch64, self).__init__(features) - self.processor_features["opcode_to_mmu_regs"] = { + self.processor_features["opcode_to_mmu_regs"] = { (1, 0, 2, 0, 2): "TCR_EL1", (1, 5, 2, 0, 2): "TCR_EL1", (1, 6, 2, 0, 2): "TCR_EL3", @@ -525,22 +647,23 @@ def __init__(self, features): (1, 5, 2, 0, 0): "TTBR0_EL1", (1, 0, 2, 0, 1): "TTBR1_EL1", (1, 5, 2, 0, 1): "TTBR1_EL1", - (1, 0, 5, 2, 0): "ESR_EL1", # Read - (1, 0, 6, 0, 0): "FAR_EL1", # Read - (1, 5, 4, 0, 1): "ELR_EL1", # Read - (1, 0, 1, 0, 0): "SCTLR_EL1", # R/W + (1, 0, 5, 2, 0): "ESR_EL1", # Read + (1, 0, 6, 0, 0): "FAR_EL1", # Read + (1, 5, 4, 0, 1): "ELR_EL1", # Read + (1, 0, 1, 0, 0): "SCTLR_EL1", # R/W (1, 5, 1, 0, 0): "SCTLR_EL1", (1, 6, 1, 0, 0): "SCTLR_EL3", - (1, 4, 4, 1, 0): "SP_EL1", # Write - (1, 5, 4, 0, 0): "SPSR_EL1", # R/W + (1, 4, 4, 1, 0): "SP_EL1", # Write + (1, 5, 4, 0, 0): "SPSR_EL1", # R/W (1, 6, 4, 0, 0): "SPSR_EL3", - (1, 0, 12, 0, 0): "VBAR_EL1", # Write - (1, 6, 12, 0, 0): "VBAR_EL3", # Write - (1, 0, 13, 0, 1): "CONTEXTIDR_EL1", # R/W - (1, 0, 0, 7, 0): "ID_AA64MMFR0_EL1" # Read - + (1, 0, 12, 0, 0): "VBAR_EL1", # Write + (1, 6, 12, 0, 0): "VBAR_EL3", # Write + (1, 0, 13, 0, 1): "CONTEXTIDR_EL1", # R/W + (1, 0, 0, 7, 0): "ID_AA64MMFR0_EL1", # Read } - self.processor_features["opcode_to_gregs"] = ["X{}".format(i) for i in range(31)] + self.processor_features["opcode_to_gregs"] = [ + "X{}".format(i) for i in range(31) + ] CPU.processor_features = self.processor_features CPU.registers_values = self.registers_values @@ -549,24 +672,24 @@ def parse_opcode(self, instr, page_addr, offset): # Collect locations of opcodes # RET and ERET - if CPUAArch64.extract_bits(instr, 0, 5) == 0 and \ - CPUAArch64.extract_bits(instr, 10, 22) == 0b1101011001011111000000: - return {page_addr + offset: {"register": "", - "instruction": "RET" - 
}} + if ( + CPUAArch64.extract_bits(instr, 0, 5) == 0 + and CPUAArch64.extract_bits(instr, 10, 22) == 0b1101011001011111000000 + ): + return {page_addr + offset: {"register": "", "instruction": "RET"}} - elif CPUAArch64.extract_bits(instr, 0, 32) == 0b11010110100111110000001111100000: - return {page_addr + offset: {"register": "", - "instruction": "ERET" - }} + elif ( + CPUAArch64.extract_bits(instr, 0, 32) == 0b11010110100111110000001111100000 + ): + return {page_addr + offset: {"register": "", "instruction": "ERET"}} # BLR/BR - elif CPUAArch64.extract_bits(instr, 0, 5) == 0b00000 and \ - CPUAArch64.extract_bits(instr, 22, 10) == 0b1101011000 and \ - CPUAArch64.extract_bits(instr, 10, 11) == 0b11111000000: - return {page_addr + offset: {"register": "", - "instruction": "BLR" - }} + elif ( + CPUAArch64.extract_bits(instr, 0, 5) == 0b00000 + and CPUAArch64.extract_bits(instr, 22, 10) == 0b1101011000 + and CPUAArch64.extract_bits(instr, 10, 11) == 0b11111000000 + ): + return {page_addr + offset: {"register": "", "instruction": "BLR"}} # MSR opcode for MMU registers (write on MMU register) elif CPUAArch64.extract_bits(instr, 20, 12) == 0b110101010001: @@ -584,11 +707,15 @@ def parse_opcode(self, instr, page_addr, offset): if reg_idx in self.processor_features["opcode_to_mmu_regs"]: mmu_regs = self.processor_features["opcode_to_mmu_regs"][reg_idx] rt = self.processor_features["opcode_to_gregs"][rt] - return {page_addr + offset: {"register": mmu_regs, - "gpr": [rt], - "f_addr": -1, - "instruction": "MSR" - }} + return { + page_addr + + offset: { + "register": mmu_regs, + "gpr": [rt], + "f_addr": -1, + "instruction": "MSR", + } + } # MRS opcode for MMU registers (read from MMU register) elif CPUAArch64.extract_bits(instr, 20, 12) == 0b110101010011: @@ -607,11 +734,15 @@ def parse_opcode(self, instr, page_addr, offset): mmu_regs = self.processor_features["opcode_to_mmu_regs"][reg_idx] rt = self.processor_features["opcode_to_gregs"][rt] - return {page_addr + offset: {"register": mmu_regs, - "gpr": [rt], - "f_addr": -1, - "instruction": "MRS" - }} + return { + page_addr + + offset: { + "register": mmu_regs, + "gpr": [rt], + "f_addr": -1, + "instruction": "MRS", + } + } else: return {} return {} @@ -624,7 +755,7 @@ def identify_functions_start(self, addreses): mdis.dontdis_retcall = False instr_len = self.processor_features["instr_len"] - logger = logging.getLogger('asmblock') + logger = logging.getLogger("asmblock") logger.disabled = True for addr in tqdm(addreses): @@ -635,7 +766,6 @@ def identify_functions_start(self, addreses): # Maximum 10000 instructions instructions = 0 while True and instructions <= 10000: - # Stop if found an invalid instruction try: asmcode = mdis.dis_instr(cur_addr) @@ -653,9 +783,21 @@ def identify_functions_start(self, addreses): break # JMPs and special opcodes - elif asmcode.name in ["BL", "BLR", "BRK", "HLT", "HVC", "SMC", - "SVC", "DCPS1", "DCPS2", "DCPS3", - "DRPS", "WFE", "WFI"]: + elif asmcode.name in [ + "BL", + "BLR", + "BRK", + "HLT", + "HVC", + "SMC", + "SVC", + "DCPS1", + "DCPS2", + "DCPS3", + "DRPS", + "WFE", + "WFI", + ]: cur_addr += instr_len break @@ -682,6 +824,7 @@ def get_miasm_machine(self): # MMU Modes ################################################################# + class MMU(MMURadix): PAGE_SIZE = 0 @@ -708,7 +851,16 @@ class LONG(MMU): map_entries_to_shifts = {"kernel": [], "user": []} map_reserved_entries_to_levels = {"kernel": [], "user": []} - def reconstruct_table(self, mode, frame_addr, frame_size, table_levels, table_size, table_entries, 
empty_entries): + def reconstruct_table( + self, + mode, + frame_addr, + frame_size, + table_levels, + table_size, + table_entries, + empty_entries, + ): # Reconstruct table_levels tables, empty tables and data_pages of a given size frame_d = defaultdict(dict) page_tables = defaultdict(dict) @@ -720,25 +872,35 @@ def reconstruct_table(self, mode, frame_addr, frame_size, table_levels, table_si frame_d.clear() # Count the empty entries - entry_addresses = set(range(table_addr, table_addr + table_size, MMU.entries_size)) + entry_addresses = set( + range(table_addr, table_addr + table_size, MMU.entries_size) + ) empty_count = len(entry_addresses.intersection(empty_entries)) # Reconstruct the content of the table candidate for entry_addr in entry_addresses.intersection(table_entries.keys()): entry_idx = (entry_addr - table_addr) // MMU.entries_size for entry_type in table_entries[entry_addr]: - frame_d[entry_type][entry_idx] = table_entries[entry_addr][entry_type] + frame_d[entry_type][entry_idx] = table_entries[entry_addr][ + entry_type + ] # Classify the frame - pt_classes = set(self.classify_frame(frame_d, empty_count, int(table_size // MMU.entries_size), mode=mode)) + pt_classes = set( + self.classify_frame( + frame_d, empty_count, int(table_size // MMU.entries_size), mode=mode + ) + ) - if -1 in pt_classes: # Empty + if -1 in pt_classes: # Empty empty_tables.append(table_addr) - elif -2 in pt_classes: # Data + elif -2 in pt_classes: # Data data_pages.append(table_addr) elif table_levels.intersection(pt_classes): levels = sorted(table_levels.intersection(pt_classes)) - table_obj = self.page_table_class(table_addr, table_size, deepcopy(frame_d), levels) + table_obj = self.page_table_class( + table_addr, table_size, deepcopy(frame_d), levels + ) for level in levels: page_tables[level][table_addr] = table_obj else: @@ -755,13 +917,18 @@ def aggregate_frames(self, frames, frame_size, page_size): if frame_addr % page_size != 0: continue - if all([(frame_addr + idx * frame_size) in frames for idx in range(1, frame_per_page)]): + if all( + [ + (frame_addr + idx * frame_size) in frames + for idx in range(1, frame_per_page) + ] + ): pages.append(frame_addr) return pages def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): - sleep(uniform(pidx, pidx+1) // 1000) + sleep(uniform(pidx, pidx + 1) // 1000) mm = copy(self.machine.memory) mm.reopen() @@ -769,12 +936,12 @@ def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): # parse all the records in a frame of 64KB and reconstruct all the different tables # Prepare thread local dictionaries in which collect data - data_pages = {"user": [], - "kernel": []} - empty_tables = {"user": [], - "kernel": []} - page_tables = {"user": [{} for i in range(self.radix_levels["user"])], - "kernel": [{} for i in range(self.radix_levels["kernel"])]} + data_pages = {"user": [], "kernel": []} + empty_tables = {"user": [], "kernel": []} + page_tables = { + "user": [{} for i in range(self.radix_levels["user"])], + "kernel": [{} for i in range(self.radix_levels["kernel"])], + } tcr = kwargs["tcr"] trees_struct = tcr.get_trees_struct() @@ -783,17 +950,23 @@ def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): table_entries = defaultdict(dict) empty_entries = set() total_elems, iterator = addresses - for frame_addr in tqdm(iterator, position=-pidx, total=total_elems, leave=False): + for frame_addr in tqdm( + iterator, position=-pidx, total=total_elems, leave=False + ): frame_buf = mm.get_data(frame_addr, frame_size) 
table_entries.clear() empty_entries.clear() # Unpack entries - for entry_idx, entry in enumerate(iter_unpack(self.paging_unpack_format, frame_buf)): + for entry_idx, entry in enumerate( + iter_unpack(self.paging_unpack_format, frame_buf) + ): entry_addr = frame_addr + entry_idx * MMU.entries_size - entry_classes = self.classify_entry(frame_addr, entry[0]) # In this case frame_addr is not used + entry_classes = self.classify_entry( + frame_addr, entry[0] + ) # In this case frame_addr is not used # Data entry if None in entry_classes: @@ -810,15 +983,46 @@ def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): table_entries[entry_addr][entry_type] = entry_obj # Reconstruct kernel tables - self.reconstruct_all_tables("kernel", trees_struct, frame_addr, frame_size, table_entries, empty_entries, page_tables, data_pages, empty_tables) + self.reconstruct_all_tables( + "kernel", + trees_struct, + frame_addr, + frame_size, + table_entries, + empty_entries, + page_tables, + data_pages, + empty_tables, + ) # Reconstruct user tables only if the radix tree have a different shape if trees_struct["kernel"] != trees_struct["user"]: - self.reconstruct_all_tables("user", trees_struct, frame_addr, frame_size, table_entries, empty_entries, page_tables, data_pages, empty_tables) + self.reconstruct_all_tables( + "user", + trees_struct, + frame_addr, + frame_size, + table_entries, + empty_entries, + page_tables, + data_pages, + empty_tables, + ) return page_tables, data_pages, empty_tables - def reconstruct_all_tables(self, mode, tree_struct, frame_addr, frame_size, table_entries, empty_entries, page_tables, data_pages, empty_tables): + def reconstruct_all_tables( + self, + mode, + tree_struct, + frame_addr, + frame_size, + table_entries, + empty_entries, + page_tables, + data_pages, + empty_tables, + ): granule = tree_struct[mode]["granule"] total_levels = tree_struct[mode]["total_levels"] top_table_size = tree_struct[mode]["top_table_size"] @@ -826,13 +1030,29 @@ def reconstruct_all_tables(self, mode, tree_struct, frame_addr, frame_size, tabl # Top table has a different size, must be parsed separately if granule != top_table_size: candidate_levels = list(range(1, total_levels)) - t, _, _ = self.reconstruct_table(mode, frame_addr, frame_size, [0], top_table_size, table_entries, empty_entries) + t, _, _ = self.reconstruct_table( + mode, + frame_addr, + frame_size, + [0], + top_table_size, + table_entries, + empty_entries, + ) page_tables[mode][0].update(t[0]) else: candidate_levels = list(range(total_levels)) # Look for other levels - t, d, e = self.reconstruct_table(mode, frame_addr, frame_size, candidate_levels, granule, table_entries, empty_entries) + t, d, e = self.reconstruct_table( + mode, + frame_addr, + frame_size, + candidate_levels, + granule, + table_entries, + empty_entries, + ) for level in t: page_tables[mode][level].update(t[level]) data_pages[mode].extend(d) @@ -848,16 +1068,16 @@ def classify_entry(self, page_addr, entry): # Block or RESERVED for PTL2 elif class_bits == 0b01: - # For L2 tables this type of entry is RESERVED and treated as EMPTY classification.append(ReservedEntry(0, 0, 0)) # SH bits has one configuration reserved (0b01) # At least 17:20 must be 0 - if (MMU.extract_bits(entry, 8, 2) != 0b01 and - not MMU.extract_bits(entry, 12, 4) and - not MMU.extract_bits(entry, 17, 4)): - + if ( + MMU.extract_bits(entry, 8, 2) != 0b01 + and not MMU.extract_bits(entry, 12, 4) + and not MMU.extract_bits(entry, 17, 4) + ): lower_flags = TEntry64.extract_lower_flags(entry) 
upper_flags = TEntry64.extract_upper_flags(entry) @@ -867,15 +1087,21 @@ def classify_entry(self, page_addr, entry): if not MMU.extract_bits(entry, 17, 8): addr = PTBLOCK_L2_16KB.extract_addr(addr) - classification.append(PTBLOCK_L2_16KB(addr, lower_flags, upper_flags)) + classification.append( + PTBLOCK_L2_16KB(addr, lower_flags, upper_flags) + ) if not MMU.extract_bits(entry, 17, 12): addr = PTBLOCK_L2_64KB.extract_addr(addr) - classification.append(PTBLOCK_L2_64KB(addr, lower_flags, upper_flags)) + classification.append( + PTBLOCK_L2_64KB(addr, lower_flags, upper_flags) + ) if not MMU.extract_bits(entry, 17, 13): addr = PTBLOCK_L1_4KB.extract_addr(addr) - classification.append(PTBLOCK_L1_4KB(addr, lower_flags, upper_flags)) + classification.append( + PTBLOCK_L1_4KB(addr, lower_flags, upper_flags) + ) # Page or Pointer else: @@ -920,28 +1146,26 @@ def classify_entry(self, page_addr, entry): return [None] return classification + class MMUShell(MMUShellDefault): - def __init__(self, completekey='tab', stdin=None, stdout=None, machine={}): + def __init__(self, completekey="tab", stdin=None, stdout=None, machine={}): super(MMUShell, self).__init__(completekey, stdin, stdout, machine) if not self.data: - self.data = Data(is_tables_found=False, - is_radix_found=False, - is_registers_found=False, - opcodes={}, - regs_values={}, - page_tables={"user": defaultdict(dict), - "kernel": defaultdict(dict)}, - data_pages={"user": [], - "kernel": []}, - empty_tables={"user": [], - "kernel": []}, - reverse_map_tables = {"user": None, - "kernel": None}, - reverse_map_pages = {"user": None, - "kernel": None}, - used_tcr=None, - ttbrs=defaultdict(dict)) + self.data = Data( + is_tables_found=False, + is_radix_found=False, + is_registers_found=False, + opcodes={}, + regs_values={}, + page_tables={"user": defaultdict(dict), "kernel": defaultdict(dict)}, + data_pages={"user": [], "kernel": []}, + empty_tables={"user": [], "kernel": []}, + reverse_map_tables={"user": None, "kernel": None}, + reverse_map_pages={"user": None, "kernel": None}, + used_tcr=None, + ttbrs=defaultdict(dict), + ) def reload_data_from_file(self, data_filename): super(MMUShell, self).reload_data_from_file(data_filename) @@ -958,7 +1182,9 @@ def do_find_registers_values(self, arg): return logger.info("Look for opcodes related to MMU setup...") - parallel_results = self.machine.apply_parallel(65536, self.machine.cpu.parse_opcodes_parallel) + parallel_results = self.machine.apply_parallel( + 65536, self.machine.cpu.parse_opcodes_parallel + ) opcodes = {} logger.info("Reaggregate threads data...") @@ -968,8 +1194,12 @@ def do_find_registers_values(self, arg): self.data.opcodes = opcodes # Filter to look only for opcodes which write on MMU register only and not read from them or from other registers - filter_f = lambda it: True if it[1]["register"] == "TCR_EL1" and it[1]["instruction"] == "MSR" else False - mmu_wr_opcodes = {k: v for k,v in filter(filter_f, opcodes.items())} + filter_f = ( + lambda it: True + if it[1]["register"] == "TCR_EL1" and it[1]["instruction"] == "MSR" + else False + ) + mmu_wr_opcodes = {k: v for k, v in filter(filter_f, opcodes.items())} logging.info("Use heuristics to find function addresses...") logging.info("This analysis could be extremely slow!") @@ -978,13 +1208,20 @@ def do_find_registers_values(self, arg): logging.info("Identify register values using data flow analysis...") # We use data flow analysis and merge the results - dataflow_values = self.machine.cpu.find_registers_values_dataflow(mmu_wr_opcodes) + 
dataflow_values = self.machine.cpu.find_registers_values_dataflow( + mmu_wr_opcodes + ) filtered_values = defaultdict(set) for register, values in dataflow_values.items(): for value in values: reg_obj = CPURegAArch64.get_register_obj(register, value) - if reg_obj.valid and not any([val_obj.is_mmu_equivalent_to(reg_obj) for val_obj in filtered_values[register]]): + if reg_obj.valid and not any( + [ + val_obj.is_mmu_equivalent_to(reg_obj) + for val_obj in filtered_values[register] + ] + ): filtered_values[register].add(reg_obj) self.data.regs_values = filtered_values @@ -1034,8 +1271,12 @@ def do_set_tcr(self, args): top_table_size = trees_struct[mode]["top_table_size"] LONG.radix_levels[mode] = total_levels - LONG.map_level_to_table_size[mode] = [top_table_size] + ([granule] * (total_levels - 1)) - LONG.map_reserved_entries_to_levels[mode] = [[] for i in range(total_levels - 1)] + [[ReservedEntry]] + LONG.map_level_to_table_size[mode] = [top_table_size] + ( + [granule] * (total_levels - 1) + ) + LONG.map_reserved_entries_to_levels[mode] = [ + [] for i in range(total_levels - 1) + ] + [[ReservedEntry]] if granule == 4096: if total_levels == 1: @@ -1043,17 +1284,55 @@ def do_set_tcr(self, args): LONG.map_ptr_entries_to_levels[mode] = [None] LONG.map_entries_to_shifts[mode] = {PTPAGE_4KB: 12} elif total_levels == 2: - LONG.map_datapages_entries_to_levels[mode] = [[PTBLOCK_L2_4KB], [PTPAGE_4KB]] + LONG.map_datapages_entries_to_levels[mode] = [ + [PTBLOCK_L2_4KB], + [PTPAGE_4KB], + ] LONG.map_ptr_entries_to_levels[mode] = [PTP_4KB_L0, None] - LONG.map_entries_to_shifts[mode] = {PTP_4KB_L0: 21, PTPAGE_4KB: 12, PTBLOCK_L2_4KB: 21} + LONG.map_entries_to_shifts[mode] = { + PTP_4KB_L0: 21, + PTPAGE_4KB: 12, + PTBLOCK_L2_4KB: 21, + } elif total_levels == 3: - LONG.map_datapages_entries_to_levels[mode] = [[PTBLOCK_L1_4KB], [PTBLOCK_L2_4KB], [PTPAGE_4KB]] - LONG.map_ptr_entries_to_levels[mode] = [PTP_4KB_L0, PTP_4KB_L1, None] - LONG.map_entries_to_shifts[mode] = {PTP_4KB_L0: 30, PTP_4KB_L1: 21, PTPAGE_4KB: 12, PTBLOCK_L2_4KB: 21, PTBLOCK_L1_4KB: 30} + LONG.map_datapages_entries_to_levels[mode] = [ + [PTBLOCK_L1_4KB], + [PTBLOCK_L2_4KB], + [PTPAGE_4KB], + ] + LONG.map_ptr_entries_to_levels[mode] = [ + PTP_4KB_L0, + PTP_4KB_L1, + None, + ] + LONG.map_entries_to_shifts[mode] = { + PTP_4KB_L0: 30, + PTP_4KB_L1: 21, + PTPAGE_4KB: 12, + PTBLOCK_L2_4KB: 21, + PTBLOCK_L1_4KB: 30, + } else: - LONG.map_datapages_entries_to_levels[mode] = [[None], [PTBLOCK_L1_4KB], [PTBLOCK_L2_4KB], [PTPAGE_4KB]] - LONG.map_ptr_entries_to_levels[mode] = [PTP_4KB_L0, PTP_4KB_L1, PTP_4KB_L2, None] - LONG.map_entries_to_shifts[mode] = {PTP_4KB_L0: 39, PTP_4KB_L1: 30, PTP_4KB_L2: 21,PTPAGE_4KB: 12, PTBLOCK_L2_4KB: 21, PTBLOCK_L1_4KB: 30} + LONG.map_datapages_entries_to_levels[mode] = [ + [None], + [PTBLOCK_L1_4KB], + [PTBLOCK_L2_4KB], + [PTPAGE_4KB], + ] + LONG.map_ptr_entries_to_levels[mode] = [ + PTP_4KB_L0, + PTP_4KB_L1, + PTP_4KB_L2, + None, + ] + LONG.map_entries_to_shifts[mode] = { + PTP_4KB_L0: 39, + PTP_4KB_L1: 30, + PTP_4KB_L2: 21, + PTPAGE_4KB: 12, + PTBLOCK_L2_4KB: 21, + PTBLOCK_L1_4KB: 30, + } elif granule == 16384: if total_levels == 1: @@ -1061,31 +1340,89 @@ def do_set_tcr(self, args): LONG.map_ptr_entries_to_levels[mode] = [None] LONG.map_entries_to_shifts[mode] = {PTPAGE_16KB: 14} elif total_levels == 2: - LONG.map_datapages_entries_to_levels[mode] = [[PTBLOCK_L2_16KB], [PTPAGE_16KB]] + LONG.map_datapages_entries_to_levels[mode] = [ + [PTBLOCK_L2_16KB], + [PTPAGE_16KB], + ] LONG.map_ptr_entries_to_levels[mode] = 
[PTP_16KB_L0, None] - LONG.map_entries_to_shifts[mode] = {PTP_16KB_L0: 25, PTPAGE_16KB: 14, PTBLOCK_L2_16KB: 25} + LONG.map_entries_to_shifts[mode] = { + PTP_16KB_L0: 25, + PTPAGE_16KB: 14, + PTBLOCK_L2_16KB: 25, + } elif total_levels == 3: - LONG.map_datapages_entries_to_levels[mode] = [[None], [PTBLOCK_L2_16KB], [PTPAGE_16KB]] - LONG.map_ptr_entries_to_levels[mode] = [PTP_16KB_L0, PTP_16KB_L1, None] - LONG.map_entries_to_shifts[mode] = {PTP_16KB_L0: 36, PTP_16KB_L1: 25, PTPAGE_16KB: 14, PTBLOCK_L2_16KB: 25} + LONG.map_datapages_entries_to_levels[mode] = [ + [None], + [PTBLOCK_L2_16KB], + [PTPAGE_16KB], + ] + LONG.map_ptr_entries_to_levels[mode] = [ + PTP_16KB_L0, + PTP_16KB_L1, + None, + ] + LONG.map_entries_to_shifts[mode] = { + PTP_16KB_L0: 36, + PTP_16KB_L1: 25, + PTPAGE_16KB: 14, + PTBLOCK_L2_16KB: 25, + } else: - LONG.map_datapages_entries_to_levels[mode] = [[None], [None], [PTBLOCK_L2_16KB], [PTPAGE_16KB]] - LONG.map_ptr_entries_to_levels[mode] = [PTP_16KB_L0, PTP_16KB_L2, None] - LONG.map_entries_to_shifts[mode] = {PTP_16KB_L0: 47, PTP_16KB_L1: 36, PTP_16KB_L2: 25, PTPAGE_16KB: 14, PTBLOCK_L2_16KB: 25} + LONG.map_datapages_entries_to_levels[mode] = [ + [None], + [None], + [PTBLOCK_L2_16KB], + [PTPAGE_16KB], + ] + LONG.map_ptr_entries_to_levels[mode] = [ + PTP_16KB_L0, + PTP_16KB_L2, + None, + ] + LONG.map_entries_to_shifts[mode] = { + PTP_16KB_L0: 47, + PTP_16KB_L1: 36, + PTP_16KB_L2: 25, + PTPAGE_16KB: 14, + PTBLOCK_L2_16KB: 25, + } else: if total_levels == 1: LONG.map_datapages_entries_to_levels[mode] = [PTPAGE_64KB] LONG.map_ptr_entries_to_levels[mode] = [None] - LONG.map_entries_to_shifts[mode] = {PTPAGE_64KB: 16, PTBLOCK_L2_64KB: 29} + LONG.map_entries_to_shifts[mode] = { + PTPAGE_64KB: 16, + PTBLOCK_L2_64KB: 29, + } elif total_levels == 2: - LONG.map_datapages_entries_to_levels[mode] = [[PTBLOCK_L2_64KB], [PTPAGE_64KB]] + LONG.map_datapages_entries_to_levels[mode] = [ + [PTBLOCK_L2_64KB], + [PTPAGE_64KB], + ] LONG.map_ptr_entries_to_levels[mode] = [PTP_64KB_L0, None] - LONG.map_entries_to_shifts[mode] = {PTP_64KB_L0: 29, PTPAGE_64KB: 16, PTBLOCK_L2_16KB: 29} + LONG.map_entries_to_shifts[mode] = { + PTP_64KB_L0: 29, + PTPAGE_64KB: 16, + PTBLOCK_L2_16KB: 29, + } else: - LONG.map_datapages_entries_to_levels[mode] = [[None], [PTBLOCK_L2_64KB], [PTPAGE_64KB]] - LONG.map_ptr_entries_to_levels[mode] = [PTP_64KB_L0, PTP_64KB_L1, None] - LONG.map_entries_to_shifts[mode] = {PTP_64KB_L0:42, PTP_64KB_L1:29, PTPAGE_64KB: 16, PTBLOCK_L2_64KB: 29} + LONG.map_datapages_entries_to_levels[mode] = [ + [None], + [PTBLOCK_L2_64KB], + [PTPAGE_64KB], + ] + LONG.map_ptr_entries_to_levels[mode] = [ + PTP_64KB_L0, + PTP_64KB_L1, + None, + ] + LONG.map_entries_to_shifts[mode] = { + PTP_64KB_L0: 42, + PTP_64KB_L1: 29, + PTPAGE_64KB: 16, + PTBLOCK_L2_64KB: 29, + } def do_find_tables(self, args): """Find MMU tables in memory""" @@ -1096,22 +1433,30 @@ def do_find_tables(self, args): # Delete all the previous table data if self.data.is_tables_found: - self.data.page_tables= { "user": defaultdict(dict), - "kernel": defaultdict(dict)} - self.data.data_pages={ "user": [], - "kernel": []} - self.data.empty_tables={"user": [], - "kernel": []} + self.data.page_tables = { + "user": defaultdict(dict), + "kernel": defaultdict(dict), + } + self.data.data_pages = {"user": [], "kernel": []} + self.data.empty_tables = {"user": [], "kernel": []} self.data.reverse_map_tables = {} self.data.reverse_map_pages = {} # WORKAROUND: initialize here because unpickable! 
- self.data.reverse_map_pages = {"kernel": defaultdict(_dummy_f), "user": defaultdict(_dummy_f)} - self.data.reverse_map_tables = {"kernel": defaultdict(_dummy_f), "user": defaultdict(_dummy_f)} + self.data.reverse_map_pages = { + "kernel": defaultdict(_dummy_f), + "user": defaultdict(_dummy_f), + } + self.data.reverse_map_tables = { + "kernel": defaultdict(_dummy_f), + "user": defaultdict(_dummy_f), + } # Parse memory in chunk of 64KiB logger.info("Look for paging tables...") - parallel_results = self.machine.apply_parallel(65536, self.machine.mmu.parse_parallel_frame, tcr=tcr) + parallel_results = self.machine.apply_parallel( + 65536, self.machine.mmu.parse_parallel_frame, tcr=tcr + ) logger.info("Reaggregate threads data...") for result in parallel_results: page_tables, data_pages, empty_tables = result.get() @@ -1134,10 +1479,16 @@ def do_find_tables(self, args): ptr_class = self.machine.mmu.map_ptr_entries_to_levels[mode][lvl] referenced_nxt = [] for table_addr in list(self.data.page_tables[mode][lvl].keys()): - for entry_obj in self.data.page_tables[mode][lvl][table_addr].entries[ptr_class].values(): - if entry_obj.address not in self.data.page_tables[mode][lvl + 1] and \ - entry_obj.address not in self.data.empty_tables[mode]: - + for entry_obj in ( + self.data.page_tables[mode][lvl][table_addr] + .entries[ptr_class] + .values() + ): + if ( + entry_obj.address + not in self.data.page_tables[mode][lvl + 1] + and entry_obj.address not in self.data.empty_tables[mode] + ): # Remove the table self.data.page_tables[mode][lvl].pop(table_addr) break @@ -1147,33 +1498,42 @@ def do_find_tables(self, args): # Remove table not referenced by upper levels referenced_nxt = set(referenced_nxt) - for table_addr in set(self.data.page_tables[mode][lvl + 1].keys()).difference(referenced_nxt): + for table_addr in set( + self.data.page_tables[mode][lvl + 1].keys() + ).difference(referenced_nxt): self.data.page_tables[mode][lvl + 1].pop(table_addr) logger.info("Fill reverse maps...") for mode in ["user", "kernel"]: for lvl in range(0, self.machine.mmu.radix_levels[mode]): ptr_class = self.machine.mmu.map_ptr_entries_to_levels[mode][lvl] - page_classes = self.machine.mmu.map_datapages_entries_to_levels[mode][lvl] + page_classes = self.machine.mmu.map_datapages_entries_to_levels[mode][ + lvl + ] for table_addr, table_obj in self.data.page_tables[mode][lvl].items(): for entry_obj in table_obj.entries[ptr_class].values(): - self.data.reverse_map_tables[mode][lvl][entry_obj.address].add(table_obj.address) + self.data.reverse_map_tables[mode][lvl][entry_obj.address].add( + table_obj.address + ) for page_class in page_classes: for entry_obj in table_obj.entries[page_class].values(): - self.data.reverse_map_pages[mode][lvl][entry_obj.address].add(table_obj.address) + self.data.reverse_map_pages[mode][lvl][ + entry_obj.address + ].add(table_obj.address) # If kernel and user space use the same configuration, copy kernel data to user trees_struct = tcr.get_trees_struct() if trees_struct["kernel"] == trees_struct["user"]: self.data.page_tables["user"] = self.data.page_tables["kernel"] self.data.reverse_map_pages["user"] = self.data.reverse_map_pages["kernel"] - self.data.reverse_map_tables["user"] = self.data.reverse_map_tables["kernel"] + self.data.reverse_map_tables["user"] = self.data.reverse_map_tables[ + "kernel" + ] self.data.data_pages["user"] = self.data.data_pages["kernel"] self.data.empty_tables["user"] = self.data.empty_tables["kernel"] self.data.is_tables_found = True - def do_show_table(self, args): 
"""Show MMU table at chosen address. Usage: show_table ADDRESS (user, kernel) [level size]""" if not self.data.used_tcr: @@ -1202,13 +1562,14 @@ def do_show_table(self, args): mode = args[1] if len(args) == 4: - try: lvl = self.parse_int(args[2]) if lvl > (self.machine.mmu.radix_levels[mode] - 1): raise ValueError except ValueError: - logger.warning(f"Level must be an integer between 0 and {self.machine.mmu.radix_levels[mode] - 1}") + logger.warning( + f"Level must be an integer between 0 and {self.machine.mmu.radix_levels[mode] - 1}" + ) return trees_struct = LONG.tcr.get_trees_struct() @@ -1223,7 +1584,9 @@ def do_show_table(self, args): try: table_size = self.parse_int(args[3]) if table_size not in valid_sizes[mode][lvl]: - logging.warning(f"Size not allowed for choosen level! Valid sizes are:{valid_sizes[mode][lvl]}") + logging.warning( + f"Size not allowed for choosen level! Valid sizes are:{valid_sizes[mode][lvl]}" + ) return except ValueError: logger.warning("Invalid size value") @@ -1233,7 +1596,9 @@ def do_show_table(self, args): lvl = -1 table_buff = self.machine.memory.get_data(addr, table_size) - invalids, pt_classes, table_obj = self.machine.mmu.parse_frame(table_buff, addr, table_size, lvl, mode=mode) + invalids, pt_classes, table_obj = self.machine.mmu.parse_frame( + table_buff, addr, table_size, lvl, mode=mode + ) print(table_obj) print(f"Invalid entries: {invalids} Table levels: {pt_classes}") @@ -1251,7 +1616,9 @@ def do_find_radix_trees(self, args): self.data.ttbrs.clear() # Some table level was not found... - if not len(self.data.page_tables["kernel"][0]) and not len(self.data.page_tables["user"][0]): + if not len(self.data.page_tables["kernel"][0]) and not len( + self.data.page_tables["user"][0] + ): logger.warning("OOPS... no tables in first level... 
Wrong MMU mode?") return @@ -1261,36 +1628,61 @@ def do_find_radix_trees(self, args): # Collect opcodes opcode_classes = defaultdict(list) for opcode_addr, opcode_data in self.data.opcodes.items(): - opcode_classes[(opcode_data["instruction"], opcode_data["register"])].append(opcode_addr) + opcode_classes[ + (opcode_data["instruction"], opcode_data["register"]) + ].append(opcode_addr) # Find all TTBR1_EL1 which contain interrupt related opcodes logging.info("Find TTBR1_EL1 candidates...") - int_opcode_addrs = opcode_classes[("MRS", "ESR_EL1")] + opcode_classes[("MRS", "FAR_EL1")] + opcode_classes[("MRS", "ELR_EL1")] + int_opcode_addrs = ( + opcode_classes[("MRS", "ESR_EL1")] + + opcode_classes[("MRS", "FAR_EL1")] + + opcode_classes[("MRS", "ELR_EL1")] + ) already_explored = set() for opcode_addr in int_opcode_addrs: - - derived_addresses = self.machine.mmu.derive_page_address(opcode_addr, mode="kernel") + derived_addresses = self.machine.mmu.derive_page_address( + opcode_addr, mode="kernel" + ) for derived_address in derived_addresses: - if derived_address in already_explored: continue lvl, addr = derived_address - ttbrs_candidates["kernel"].extend(self.radix_roots_from_data_page(lvl, addr, self.data.reverse_map_pages["kernel"], self.data.reverse_map_tables["kernel"])) + ttbrs_candidates["kernel"].extend( + self.radix_roots_from_data_page( + lvl, + addr, + self.data.reverse_map_pages["kernel"], + self.data.reverse_map_tables["kernel"], + ) + ) already_explored.add(derived_address) - ttbrs_candidates["kernel"] = list(set(ttbrs_candidates["kernel"]).intersection(self.data.page_tables["kernel"][0].keys())) + ttbrs_candidates["kernel"] = list( + set(ttbrs_candidates["kernel"]).intersection( + self.data.page_tables["kernel"][0].keys() + ) + ) # Filter kernel candidates for ERET and write on MMU registers logger.info("Filtering TTBR1_EL1 candidates...") - mmu_w_opcode_addrs = opcode_classes[("MSR", "TCR_EL1")] + opcode_classes[("MSR", "TTBR0_EL1")] + mmu_w_opcode_addrs = ( + opcode_classes[("MSR", "TCR_EL1")] + opcode_classes[("MSR", "TTBR0_EL1")] + ) phy_cache = defaultdict(dict) - ttbrs_filtered = {"kernel":{}, "user":{}} + ttbrs_filtered = {"kernel": {}, "user": {}} virt_cache = defaultdict(dict) for candidate in tqdm(ttbrs_candidates["kernel"]): - # Calculate physpace and discard empty ones - consistency, pas = self.physpace(candidate, self.data.page_tables["kernel"], self.data.empty_tables["kernel"], mode="kernel", hierarchical=True, cache=phy_cache) + consistency, pas = self.physpace( + candidate, + self.data.page_tables["kernel"], + self.data.empty_tables["kernel"], + mode="kernel", + hierarchical=True, + cache=phy_cache, + ) # Discard inconsistent one if not consistency: @@ -1318,8 +1710,17 @@ def do_find_radix_trees(self, args): else: continue - vas = self.virtspace(candidate, mode="kernel", hierarchical=True, cache=virt_cache) - radix_tree = RadixTree(candidate, trees_struct["kernel"]["total_levels"], pas, vas, kernel=True, user=False) + vas = self.virtspace( + candidate, mode="kernel", hierarchical=True, cache=virt_cache + ) + radix_tree = RadixTree( + candidate, + trees_struct["kernel"]["total_levels"], + pas, + vas, + kernel=True, + user=False, + ) ttbrs_filtered["kernel"][candidate] = radix_tree # Find all TTBR0_EL1 which contain at least one RET instruction @@ -1327,25 +1728,42 @@ def do_find_radix_trees(self, args): virt_cache.clear() logging.info("Find TTBR0_EL1 candidates...") for opcode_addr in opcode_classes[("RET", "")]: - - derived_addresses = 
self.machine.mmu.derive_page_address(opcode_addr, mode="user") + derived_addresses = self.machine.mmu.derive_page_address( + opcode_addr, mode="user" + ) for derived_address in derived_addresses: - if derived_address in already_explored: continue lvl, addr = derived_address - ttbrs_candidates["user"].extend(self.radix_roots_from_data_page(lvl, addr, self.data.reverse_map_pages["user"], self.data.reverse_map_tables["user"])) + ttbrs_candidates["user"].extend( + self.radix_roots_from_data_page( + lvl, + addr, + self.data.reverse_map_pages["user"], + self.data.reverse_map_tables["user"], + ) + ) already_explored.add(derived_address) - ttbrs_candidates["user"] = list(set(ttbrs_candidates["user"]).intersection(self.data.page_tables["user"][0].keys())) + ttbrs_candidates["user"] = list( + set(ttbrs_candidates["user"]).intersection( + self.data.page_tables["user"][0].keys() + ) + ) logger.info("Filtering TTBR0_EL1 candidates...") phy_cache = defaultdict(dict) for candidate in tqdm(ttbrs_candidates["user"]): - # Calculate physpace and discard empty ones - consistency, pas = self.physpace(candidate, self.data.page_tables["user"], self.data.empty_tables["user"], mode="user", hierarchical=True, cache=phy_cache) + consistency, pas = self.physpace( + candidate, + self.data.page_tables["user"], + self.data.empty_tables["user"], + mode="user", + hierarchical=True, + cache=phy_cache, + ) # Discard inconsistent one if not consistency: @@ -1365,8 +1783,17 @@ def do_find_radix_trees(self, args): else: continue - vas = self.virtspace(candidate, mode="user", hierarchical=True, cache=virt_cache) - radix_tree = RadixTree(candidate, trees_struct["user"]["total_levels"], pas, vas, kernel=False, user=True) + vas = self.virtspace( + candidate, mode="user", hierarchical=True, cache=virt_cache + ) + radix_tree = RadixTree( + candidate, + trees_struct["user"]["total_levels"], + pas, + vas, + kernel=False, + user=True, + ) ttbrs_filtered["user"][candidate] = radix_tree self.data.ttbrs = ttbrs_filtered @@ -1378,15 +1805,24 @@ def do_show_radix_trees(self, args): logging.info("Please, find them first!") return - labels = ["Radix address", "Total levels", "Kernel size (Bytes)", "User size (Bytes)", "Kernel"] + labels = [ + "Radix address", + "Total levels", + "Kernel size (Bytes)", + "User size (Bytes)", + "Kernel", + ] table = PrettyTable() table.field_names = labels for mode in ["kernel", "user"]: for ttbr in self.data.ttbrs[mode].values(): - table.add_row(ttbr.entry_resume_stringified() + ["X" if mode == "kernel" else ""]) - table.sortby="Radix address" + table.add_row( + ttbr.entry_resume_stringified() + ["X" if mode == "kernel" else ""] + ) + table.sortby = "Radix address" print(table) + class MMUShellGTruth(MMUShell): def do_show_registers_gtruth(self, args): """Compare TCR values found with the ground truth""" @@ -1402,25 +1838,29 @@ def do_show_registers_gtruth(self, args): if value not in all_tcrs or (value_info[1] > all_tcrs[value][1]): all_tcrs[value] = (value_info[0], value_info[1]) - - last_tcr = TCR_EL1(sorted(all_tcrs.keys(), key=lambda x: all_tcrs[x][1], reverse=True)[0]) + last_tcr = TCR_EL1( + sorted(all_tcrs.keys(), key=lambda x: all_tcrs[x][1], reverse=True)[0] + ) tcr_fields_equals = {} for value_found_obj in self.data.regs_values["TCR_EL1"]: - tcr_fields_equals[value_found_obj] = value_found_obj.count_fields_equals(last_tcr) - k_sorted = sorted(tcr_fields_equals.keys(), key=lambda x: tcr_fields_equals[x], reverse=True) + tcr_fields_equals[value_found_obj] = value_found_obj.count_fields_equals( + 
last_tcr
+        )
+        k_sorted = sorted(
+            tcr_fields_equals.keys(), key=lambda x: tcr_fields_equals[x], reverse=True
+        )

        if not k_sorted:
            print(f"Correct TCR_EL1 value: {last_tcr}")
            print("TCR_EL1 fields found:... 0/4")
-            print("FP: {}".format(str(len(self.data.regs_values["TCR_EL1"])) ))
+            print("FP: {}".format(str(len(self.data.regs_values["TCR_EL1"]))))
            return
        else:
            tcr_found = k_sorted[0]
            correct_fields_found = tcr_fields_equals[tcr_found]
            print(f"Correct TCR_EL1 value: {last_tcr}, Found: {tcr_found}")
            print("TCR_EL1 fields found:... {}/4".format(correct_fields_found))
-            print("FP: {}".format(str(len(self.data.regs_values["TCR_EL1"]) - 1) ))
-
+            print("FP: {}".format(str(len(self.data.regs_values["TCR_EL1"]) - 1)))

    def do_show_radix_trees_gtruth(self, args):
        """Compare radix trees found with the ground truth"""
@@ -1438,14 +1878,22 @@ def do_show_radix_trees_gtruth(self, args):
        # Collect opcodes
        opcode_classes = defaultdict(list)
        for opcode_addr, opcode_data in self.data.opcodes.items():
-            opcode_classes[(opcode_data["instruction"], opcode_data["register"])].append(opcode_addr)
+            opcode_classes[
+                (opcode_data["instruction"], opcode_data["register"])
+            ].append(opcode_addr)

        # Kernel radix trees
        # Filtering using the same criteria used by the algorithm; however, we test only candidates which are possible
        # False Negatives because the intersection must always pass the check!
-        mmu_w_opcode_addrs = opcode_classes[("MSR", "TCR_EL1")] + opcode_classes[("MSR", "TTBR0_EL1")] + opcode_classes[("MSR", "TTBR1_EL1")]
-
-        kernel_radix_trees = False # Some AArch64 machines do not have TTBR1_EL1 but only TTBR0_EL1
+        mmu_w_opcode_addrs = (
+            opcode_classes[("MSR", "TCR_EL1")]
+            + opcode_classes[("MSR", "TTBR0_EL1")]
+            + opcode_classes[("MSR", "TTBR1_EL1")]
+        )
+
+        kernel_radix_trees = (
+            False  # Some AArch64 machines do not have TTBR1_EL1 but only TTBR0_EL1
+        )
        for key in ["TTBR1_EL1", "TTBR1_EL1_S"]:
            for value, data in tqdm(self.gtruth.get(key, {}).items()):
                ttbr = TTBR1_EL1(value)
@@ -1469,42 +1917,54 @@ def do_show_radix_trees_gtruth(self, args):
                kernel_radix_trees = True

        if kernel_radix_trees:
-            tps = sorted(set(ttbr1s.keys()).intersection(set(self.data.ttbrs["kernel"].keys())))
-            fps = sorted(set(self.data.ttbrs["kernel"].keys()).difference(set(ttbr1s.keys())))
-            fns_candidates = set(ttbr1s.keys()).difference(set(self.data.ttbrs["kernel"].keys()))
+            tps = sorted(
+                set(ttbr1s.keys()).intersection(set(self.data.ttbrs["kernel"].keys()))
+            )
+            fps = sorted(
+                set(self.data.ttbrs["kernel"].keys()).difference(set(ttbr1s.keys()))
+            )
+            fns_candidates = set(ttbr1s.keys()).difference(
+                set(self.data.ttbrs["kernel"].keys())
+            )
            fns = []

            # Check False negatives
            for candidate in tqdm(fns_candidates):
-                # Calculate physpace and discard empty ones
-                consistency, pas = self.physpace(candidate, self.data.page_tables["kernel"], self.data.empty_tables["kernel"], mode="kernel", hierarchical=True, cache=ttbr1_phy_cache)
+                # Calculate physpace and discard empty ones
+                consistency, pas = self.physpace(
+                    candidate,
+                    self.data.page_tables["kernel"],
+                    self.data.empty_tables["kernel"],
+                    mode="kernel",
+                    hierarchical=True,
+                    cache=ttbr1_phy_cache,
+                )

-                # Discard inconsistent one
-                if not consistency:
-                    continue
+                # Discard inconsistent one
+                if not consistency:
+                    continue

-                # Check if at least one ERET opcode in physical address space
-                for 
opcode_addr in opcode_classes[("ERET", "")]: + if pas.is_in_kernel_space(opcode_addr): + break + else: + continue - # WARNING! We cannot filter for user_size = 0 due to TCR_EL1.E0PD1 ! - # Check if at least one MMU opcode in physical address space - for opcode_addr in mmu_w_opcode_addrs: - if pas.is_in_kernel_space(opcode_addr): - break - else: - continue + # WARNING! We cannot filter for user_size = 0 due to TCR_EL1.E0PD1 ! + # Check if at least one MMU opcode in physical address space + for opcode_addr in mmu_w_opcode_addrs: + if pas.is_in_kernel_space(opcode_addr): + break + else: + continue - fns.append(candidate) + fns.append(candidate) fns.sort() # User radix trees for key in ["TTBR0_EL1", "TTBR0_EL1_S"]: for value, data in tqdm(self.gtruth.get(key, {}).items()): - ttbr = TTBR0_EL1(value) try: @@ -1536,35 +1996,42 @@ def do_show_radix_trees_gtruth(self, args): # Filter FN fnsu = [] for candidate in tqdm(fnsu_candidates): - # Calculate physpace and discard empty ones - consistency, pas = self.physpace(candidate, self.data.page_tables["user"], self.data.empty_tables["user"], mode="user", hierarchical=True, cache=ttbr0_phy_cache) + # Calculate physpace and discard empty ones + consistency, pas = self.physpace( + candidate, + self.data.page_tables["user"], + self.data.empty_tables["user"], + mode="user", + hierarchical=True, + cache=ttbr0_phy_cache, + ) - # Discard inconsistent one - if not consistency: - continue + # Discard inconsistent one + if not consistency: + continue - # At least a page must be R or W in usermode - for perms in pas.space: - if perms[3] or perms[4]: - break - else: - continue + # At least a page must be R or W in usermode + for perms in pas.space: + if perms[3] or perms[4]: + break + else: + continue - # Check if at least one BLR opcode in physical address space - for opcode_addr in opcode_classes[("BLR", "")]: - if opcode_addr in pas: - break - else: - continue + # Check if at least one BLR opcode in physical address space + for opcode_addr in opcode_classes[("BLR", "")]: + if opcode_addr in pas: + break + else: + continue - # Check if at least one RET opcode in physical address space - for opcode_addr in opcode_classes[("RET", "")]: - if opcode_addr in pas: - break - else: - continue + # Check if at least one RET opcode in physical address space + for opcode_addr in opcode_classes[("RET", "")]: + if opcode_addr in pas: + break + else: + continue - fnsu.append(candidate) + fnsu.append(candidate) fnsu.sort() # Show results @@ -1575,47 +2042,33 @@ def do_show_radix_trees_gtruth(self, args): if kernel_radix_trees: umode = "U" for tp in sorted(tps): - table.add_row([hex(tp), - "X", - "K", - kernel_regs[tp][1][0], kernel_regs[tp][1][1]]) + table.add_row( + [hex(tp), "X", "K", kernel_regs[tp][1][0], kernel_regs[tp][1][1]] + ) for fn in sorted(fns): - table.add_row([hex(fn), - "", - "K", - kernel_regs[fn][1][0], kernel_regs[fn][1][1]]) + table.add_row( + [hex(fn), "", "K", kernel_regs[fn][1][0], kernel_regs[fn][1][1]] + ) for fp in sorted(fps): - table.add_row([hex(fp), - "False positive", - "K", - "", ""]) + table.add_row([hex(fp), "False positive", "K", "", ""]) else: umode = "K" # User for tp in sorted(tpsu): - table.add_row([hex(tp), - "X", - umode, - ttbr0s[tp][1][0], ttbr0s[tp][1][1]]) + table.add_row([hex(tp), "X", umode, ttbr0s[tp][1][0], ttbr0s[tp][1][1]]) for fn in sorted(fnsu): - table.add_row([hex(fn), - "", - umode, - ttbr0s[fn][1][0], ttbr0s[fn][1][1]]) + table.add_row([hex(fn), "", umode, ttbr0s[fn][1][0], ttbr0s[fn][1][1]]) for fp in sorted(fpsu): - 
table.add_row([hex(fp),
-                           "False positive",
-                           umode,
-                           "", ""])
+            table.add_row([hex(fp), "False positive", umode, "", ""])

        print(table)
        if kernel_radix_trees:
            print(f"TP:{len(tps)} FN:{len(fns)} FP:{len(fps)}")
            print(f"USER TP:{len(tpsu)} FN:{len(fnsu)} FP:{len(fpsu)}")
        else:
-            print(f"TP:{len(tpsu)} FN:{len(fnsu)} FP:{len(fpsu)}")
\ No newline at end of file
+            print(f"TP:{len(tpsu)} FN:{len(fnsu)} FP:{len(fpsu)}")
diff --git a/architectures/arm.py b/mmushell/architectures/arm.py
similarity index 66%
rename from architectures/arm.py
rename to mmushell/architectures/arm.py
index fa407fe..7d1594f 100644
--- a/architectures/arm.py
+++ b/mmushell/architectures/arm.py
@@ -1,3 +1,7 @@
+import logging
+import portion
+import multiprocessing as mp
+
 from architectures.generic import Machine as MachineDefault
 from architectures.generic import CPU as CPUDefault
 from architectures.generic import PhysicalMemory as PhysicalMemoryDefault
@@ -5,21 +9,20 @@ from architectures.generic import TableEntry, PageTable, MMURadix, PAS, RadixTree
 from architectures.generic import CPUReg, VAS
 from architectures.generic import MMU as MMUDefault
-import logging
-from collections import defaultdict, deque
+
 from miasm.analysis.machine import Machine as MIASMMachine
 from miasm.core.bin_stream import bin_stream_vm
 from miasm.core.locationdb import LocationDB
 from prettytable import PrettyTable
+
+from collections import defaultdict, deque
+from dataclasses import dataclass
+from random import uniform
+from struct import iter_unpack, unpack
 from time import sleep
 from tqdm import tqdm
 from copy import deepcopy, copy
-from random import uniform
-from struct import iter_unpack, unpack
-from dataclasses import dataclass
-import multiprocessing as mp
-# import cProfile
-import portion
+
 from more_itertools import divide
 from IPython import embed

@@ -40,6 +43,7 @@
# PXN is present also on PTP short entries and it's hierarchical!
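# (A sketch of what "hierarchical" means here, matching virtspace_short()
# further below: the kernel-exec flag is ANDed along the walk, so a PXN bit
# set in a first-level PTP entry also forbids kernel execution for every page
# mapped under it, i.e.
#     kx = page_entry.is_kernel_executable_entry() and ptp_entry.is_kernel_executable_entry())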
########################################################

+
class VASShort(VAS):
    def __repr__(self):
        s = ""
@@ -50,6 +54,7 @@ def __repr__(self):
                s += f"\t[{hex(interval.lower)}, {hex(interval.upper)}]\n"
        return s

+
@dataclass
class Data:
    is_tables_found: bool
@@ -69,20 +74,22 @@ class Data:
class CPURegARM32(CPUReg):
    @classmethod
    def get_register_obj(cls, reg_name, value):
-        return globals()[reg_name](value)
+        return globals()[reg_name](value)


class SCTLR(CPURegARM32):
    def is_valid(self, value):
-        if CPU.extract_bits(value, 31, 1) or \
-           CPU.extract_bits(value, 26, 1) or \
-           CPU.extract_bits(value, 15, 1) or \
-           CPU.extract_bits(value, 8, 2) or \
-           not CPU.extract_bits(value, 23, 1) or \
-           not CPU.extract_bits(value, 18, 1) or \
-           not CPU.extract_bits(value, 16, 1) or \
-           not CPU.extract_bits(value, 6, 1) or \
-           CPU.extract_bits(value, 3, 2) != 3:
+        if (
+            CPU.extract_bits(value, 31, 1)
+            or CPU.extract_bits(value, 26, 1)
+            or CPU.extract_bits(value, 15, 1)
+            or CPU.extract_bits(value, 8, 2)
+            or not CPU.extract_bits(value, 23, 1)
+            or not CPU.extract_bits(value, 18, 1)
+            or not CPU.extract_bits(value, 16, 1)
+            or not CPU.extract_bits(value, 6, 1)
+            or CPU.extract_bits(value, 3, 2) != 3
+        ):
            return False
        else:
            return True
@@ -101,15 +108,20 @@ def __init__(self, value):
            self.valid = False

    def is_mmu_equivalent_to(self, other):
-        return self.m == other.m and self.afe == other.afe and self.tre == other.tre and self.ee == other.ee
+        return (
+            self.m == other.m
+            and self.afe == other.afe
+            and self.tre == other.tre
+            and self.ee == other.ee
+        )

    def __repr__(self):
        return f"SCTLR {hex(self.value)} => HA:{hex(self.ha)}, AFE:{hex(self.afe)}, TRE:{hex(self.tre)}, EE:{hex(self.ee)} M:{hex(self.m)}"

+
class TTBCR(CPURegARM32):
    def is_valid(self, value):
-        if CPU.extract_bits(value, 6, 25) or \
-           CPU.extract_bits(value, 3, 1):
+        if CPU.extract_bits(value, 6, 25) or CPU.extract_bits(value, 3, 1):
            return False
        else:
            return True
@@ -155,7 +167,9 @@ def __init__(self, value):
        self.value = value
        if self.is_valid(value):
            self.valid = True
-            self.irgn = (CPU.extract_bits(value, 1, 1) << 1) | CPU.extract_bits(value, 6, 1)
+            self.irgn = (CPU.extract_bits(value, 1, 1) << 1) | CPU.extract_bits(
+                value, 6, 1
+            )
            self.s = CPU.extract_bits(value, 1, 1)
            self.imp = CPU.extract_bits(value, 2, 1)
            self.rgn = CPU.extract_bits(value, 3, 2)
@@ -181,7 +195,9 @@ def __init__(self, value):
        self.value = value
        if self.is_valid(value):
            self.valid = True
-            self.irgn = (CPU.extract_bits(value, 1, 1) << 1) | CPU.extract_bits(value, 6, 1)
+            self.irgn = (CPU.extract_bits(value, 1, 1) << 1) | CPU.extract_bits(
+                value, 6, 1
+            )
            self.s = CPU.extract_bits(value, 1, 1)
            self.imp = CPU.extract_bits(value, 2, 1)
            self.rgn = CPU.extract_bits(value, 3, 2)
@@ -196,16 +212,28 @@ def is_mmu_equivalent_to(self, other):

    def __repr__(self):
        return f"TTBR1 {hex(self.value)} => Address:{hex(self.address)}, IRGN:{hex(self.irgn)}, S:{hex(self.s)}, IMP:{hex(self.imp)}, RGN:{hex(self.rgn)}, NOS:{hex(self.nos)}"

+
#####################################################################
# 32 bit entries and page table
#####################################################################

+
class TEntry32(TableEntry):
    entry_size = 4
    entry_name = "TEntry32"
    size = 0
-    labels = ["Address:", "TEX:", "Cacheble:", "Bufferable:", "Permsissions:"
-              "Exec:", "Kernel exec:", "Secure:", "Domain:", "Shared:", "Global:"]
+    labels = [
+        "Address:",
+        "TEX:",
+        "Cacheable:",
+        "Bufferable:",
+        "Permissions:",
+        "Exec:",
+        "Kernel exec:",
+        "Secure:",
+        "Domain:",
"Shared:", + "Global:", + ] addr_fmt = "0x{:08x}" def __hash__(self): @@ -213,21 +241,24 @@ def __hash__(self): def __repr__(self): e_resume = self.entry_resume_stringified() - return str([self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))]) + return str( + [self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))] + ) def entry_resume(self): - return [self.address, - self.extract_tex(), - self.is_cachable_entry(), - self.is_bufferable_entry(), - self.extract_permissions(), - self.is_executable_entry(), - self.is_kernel_executable_entry(), - self.is_secure_entry(), - self.extract_domain(), - self.is_shared_entry(), - self.is_global_entry() - ] + return [ + self.address, + self.extract_tex(), + self.is_cachable_entry(), + self.is_bufferable_entry(), + self.extract_permissions(), + self.is_executable_entry(), + self.is_kernel_executable_entry(), + self.is_secure_entry(), + self.extract_domain(), + self.is_shared_entry(), + self.is_global_entry(), + ] def entry_resume_stringified(self): res = self.entry_resume() @@ -256,7 +287,15 @@ def permissions_mode_21(self): def get_permissions(self): kr, kw, r, w = self.permissions_mode_21() - return (kr, kw, self.is_executable_entry(), r, w, self.is_kernel_executable_entry()) + return ( + kr, + kw, + self.is_executable_entry(), + r, + w, + self.is_kernel_executable_entry(), + ) + class PTP(TEntry32): entry_name = "PTP32" @@ -264,58 +303,83 @@ class PTP(TEntry32): def extract_tex(self): return 0 + def is_cachable_entry(self): return "Ign." + def is_bufferable_entry(self): return "Ign." + def is_executable_entry(self): return True + def extract_permissions(self): return 0 + def is_kernel_executable_entry(self): return not bool(MMU.extract_bits(self.flags, 2, 1)) + def is_secure_entry(self): return not bool(MMU.extract_bits(self.flags, 3, 1)) + def extract_domain(self): return 0 + def is_shared_entry(self): return "Ign." + def is_global_entry(self): return "Ign." 
+ @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 10, 22) << 10 + @staticmethod def extract_flags(entry): return MMU.extract_bits(entry, 0, 10) + class PTSECTION(TEntry32): entry_name = "PTSECTION" size = 1024 * 1024 def extract_tex(self): return MMU.extract_bits(self.flags, 12, 2) + def is_cachable_entry(self): return bool(MMU.extract_bits(self.flags, 3, 1)) + def is_bufferable_entry(self): return bool(MMU.extract_bits(self.flags, 2, 1)) + def extract_permissions(self): - return (MMU.extract_bits(self.flags, 15, 1) << 2) | MMU.extract_bits(self.flags, 10, 2) + return (MMU.extract_bits(self.flags, 15, 1) << 2) | MMU.extract_bits( + self.flags, 10, 2 + ) + def is_executable_entry(self): return not bool(MMU.extract_bits(self.flags, 4, 1)) + def is_kernel_executable_entry(self): return not bool(MMU.extract_bits(self.flags, 0, 1)) + def is_secure_entry(self): return not bool(MMU.extract_bits(self.flags, 19, 1)) + def extract_domain(self): return MMU.extract_bits(self.flags, 5, 4) + def is_shared_entry(self): return bool(MMU.extract_bits(self.flags, 16, 1)) + def is_global_entry(self): return not bool(MMU.extract_bits(self.flags, 17, 1)) + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 20, 12) << 20 + @staticmethod def extract_flags(entry): return MMU.extract_bits(entry, 0, 20) @@ -327,100 +391,155 @@ class PTSUPERSECTION(TEntry32): def extract_tex(self): return MMU.extract_bits(self.flags, 12, 2) + def is_cachable_entry(self): return bool(MMU.extract_bits(self.flags, 3, 1)) + def is_bufferable_entry(self): return bool(MMU.extract_bits(self.flags, 2, 1)) + def extract_permissions(self): - return (MMU.extract_bits(self.flags, 15, 1) << 2) | MMU.extract_bits(self.flags, 10, 2) + return (MMU.extract_bits(self.flags, 15, 1) << 2) | MMU.extract_bits( + self.flags, 10, 2 + ) + def is_executable_entry(self): return not bool(MMU.extract_bits(self.flags, 4, 1)) + def is_kernel_executable_entry(self): return not bool(MMU.extract_bits(self.flags, 0, 1)) + def is_secure_entry(self): return not bool(MMU.extract_bits(self.flags, 19, 1)) + def extract_domain(self): return 0 + def is_shared_entry(self): return bool(MMU.extract_bits(self.flags, 16, 1)) + def is_global_entry(self): return not bool(MMU.extract_bits(self.flags, 17, 1)) + @staticmethod def extract_addr(entry): addr = MMU.extract_bits(entry, 24, 8) << 24 addr = addr | (MMU.extract_bits(entry, 20, 4) << 32) addr = addr | (MMU.extract_bits(entry, 5, 4) << 36) return addr + @staticmethod def extract_flags(entry): return MMU.extract_bits(entry, 0, 20) + class PTLARGE(TEntry32): entry_name = "PTLARGE" size = 1024 * 64 def extract_tex(self): return MMU.extract_bits(self.flags, 12, 2) + def is_cachable_entry(self): return bool(MMU.extract_bits(self.flags, 3, 1)) + def is_bufferable_entry(self): return bool(MMU.extract_bits(self.flags, 2, 1)) + def extract_permissions(self): - return (MMU.extract_bits(self.flags, 9, 1) << 2) | MMU.extract_bits(self.flags, 4, 2) + return (MMU.extract_bits(self.flags, 9, 1) << 2) | MMU.extract_bits( + self.flags, 4, 2 + ) + def is_executable_entry(self): return not bool(MMU.extract_bits(self.flags, 15, 1)) + def is_kernel_executable_entry(self): return True + def is_secure_entry(self): return False + def extract_domain(self): return 0 + def is_shared_entry(self): return bool(MMU.extract_bits(self.flags, 10, 1)) + def is_global_entry(self): return not bool(MMU.extract_bits(self.flags, 11, 1)) + @staticmethod def extract_addr(entry): return MMU.extract_bits(entry, 16, 16) << 16 + 
@staticmethod
    def extract_flags(entry):
        return MMU.extract_bits(entry, 0, 16)

+
class PTSMALL(TEntry32):
    entry_name = "PTSMALL"
    size = 1024 * 4

    def extract_tex(self):
        return MMU.extract_bits(self.flags, 6, 2)
+
    def is_cachable_entry(self):
        return bool(MMU.extract_bits(self.flags, 3, 1))
+
    def is_bufferable_entry(self):
        return bool(MMU.extract_bits(self.flags, 2, 1))
+
    def extract_permissions(self):
-        return (MMU.extract_bits(self.flags, 9, 1) << 2 )| MMU.extract_bits(self.flags, 4, 2)
+        return (MMU.extract_bits(self.flags, 9, 1) << 2) | MMU.extract_bits(
+            self.flags, 4, 2
+        )
+
    def is_executable_entry(self):
        return not bool(MMU.extract_bits(self.flags, 0, 1))
+
    def is_kernel_executable_entry(self):
        return True
+
    def is_secure_entry(self):
        return False
+
    def extract_domain(self):
        return 0
+
    def is_shared_entry(self):
        return bool(MMU.extract_bits(self.flags, 10, 1))
+
    def is_global_entry(self):
        return not bool(MMU.extract_bits(self.flags, 11, 1))
+
    @staticmethod
    def extract_addr(entry):
        return MMU.extract_bits(entry, 12, 20) << 12
+
    @staticmethod
    def extract_flags(entry):
        return MMU.extract_bits(entry, 0, 12)

+
class PageTableARM32(PageTable):
    entry_size = 4
-    table_fields = ["Entry address", "Pointed address", "TEX", "Cacheble", "Bufferable",
-                    "Permissions", "Exec", "Kernel exec", "Secure", "Domain", "Shared", "Global", "Classes"]
+    table_fields = [
+        "Entry address",
+        "Pointed address",
+        "TEX",
+        "Cacheable",
+        "Bufferable",
+        "Permissions",
+        "Exec",
+        "Kernel exec",
+        "Secure",
+        "Domain",
+        "Shared",
+        "Global",
+        "Classes",
+    ]
    addr_fmt = "0x{:08x}"

    def __repr__(self):
@@ -430,9 +549,13 @@ def __repr__(self):
        for entry_class in self.entries:
            for entry_idx, entry_obj in self.entries[entry_class].items():
                entry_addr = self.address + (entry_idx * self.entry_size)
-                table.add_row([self.addr_fmt.format(entry_addr)] + entry_obj.entry_resume_stringified() + [entry_class.entry_name])
+                table.add_row(
+                    [self.addr_fmt.format(entry_addr)]
+                    + entry_obj.entry_resume_stringified()
+                    + [entry_class.entry_name]
+                )

-        table.sortby="Entry address"
+        table.sortby = "Entry address"
        return str(table)

@@ -443,8 +566,7 @@ class PhysicalMemory(PhysicalMemoryDefault):
class CPU(CPUDefault):
    @classmethod
    def from_cpu_config(cls, cpu_config, **kwargs):
-        return CPUARM32(cpu_config)
-
+        return CPUARM32(cpu_config)

    def __init__(self, features):
        super(CPU, self).__init__(features)
@@ -469,54 +591,75 @@ def __init__(self, features):
            (2, 0, 0, 1): "TTBR1",
            (2, 0, 0, 2): "TTBCR",
            (5, 0, 0, 0): "DFSR",
-            (5, 0, 0, 1): "IFSR"
+            (5, 0, 0, 1): "IFSR",
        }
-        self.processor_features["opcode_to_gregs"] = ["R{}".format(i) for i in range(16)]
+        self.processor_features["opcode_to_gregs"] = [
+            "R{}".format(i) for i in range(16)
+        ]

        CPU.processor_features = self.processor_features
        CPU.registers_values = self.registers_values

    def parse_opcode(self, instr, page_addr, offset):
-        # Collect locations of opcodes
-        if CPUARM32.extract_bits(instr, 24, 4) == 0b1110 and \
-           CPUARM32.extract_bits(instr, 4, 1) == 1 and \
-           CPUARM32.extract_bits(instr, 8, 4) == 0b1111:
-
-            opc1 = CPUARM32.extract_bits(instr, 21, 3)
-            crn = CPUARM32.extract_bits(instr, 16, 4)
-            rt = self.processor_features["opcode_to_gregs"][CPUARM32.extract_bits(instr, 12, 4)]
-            opc2 = CPUARM32.extract_bits(instr, 5, 3)
-            crm = CPUARM32.extract_bits(instr, 0, 4)
-            mmu_regs = self.processor_features["opcode_to_mmu_regs"]
-
-            # MRC XXX, YYY (Read from Coprocessor register)
-            if CPUARM32.extract_bits(instr, 20, 1) == 1:
-                if (crn, opc1, crm, opc2) in mmu_regs:
-                    if 
mmu_regs[(crn, opc1, crm, opc2)] not in ["IFSR", "TTBR0", "TTBR1", "DFSR"]: - return {} - - return {page_addr + offset: {"register": mmu_regs[(crn, opc1, crm, opc2)], - "gpr": [rt], - "f_addr": -1, - "f_parents": set(), - "instruction": "MRC" - } - } - - # MCR XXX, YYY (Write to Coprocessor register) - else: - if (crn, opc1, crm, opc2) in mmu_regs: - if mmu_regs[(crn, opc1, crm, opc2)] not in ["TTBR0", "TTBR1", "TTBCR", "SCTLR"]: - return {} - - return {page_addr + offset: {"register": mmu_regs[(crn, opc1, crm, opc2)], - "gpr": [rt], - "f_addr": -1, - "f_parents": set(), - "instruction": "MCR" - } - } - return {} + # Collect locations of opcodes + if ( + CPUARM32.extract_bits(instr, 24, 4) == 0b1110 + and CPUARM32.extract_bits(instr, 4, 1) == 1 + and CPUARM32.extract_bits(instr, 8, 4) == 0b1111 + ): + opc1 = CPUARM32.extract_bits(instr, 21, 3) + crn = CPUARM32.extract_bits(instr, 16, 4) + rt = self.processor_features["opcode_to_gregs"][ + CPUARM32.extract_bits(instr, 12, 4) + ] + opc2 = CPUARM32.extract_bits(instr, 5, 3) + crm = CPUARM32.extract_bits(instr, 0, 4) + mmu_regs = self.processor_features["opcode_to_mmu_regs"] + + # MRC XXX, YYY (Read from Coprocessor register) + if CPUARM32.extract_bits(instr, 20, 1) == 1: + if (crn, opc1, crm, opc2) in mmu_regs: + if mmu_regs[(crn, opc1, crm, opc2)] not in [ + "IFSR", + "TTBR0", + "TTBR1", + "DFSR", + ]: + return {} + + return { + page_addr + + offset: { + "register": mmu_regs[(crn, opc1, crm, opc2)], + "gpr": [rt], + "f_addr": -1, + "f_parents": set(), + "instruction": "MRC", + } + } + + # MCR XXX, YYY (Write to Coprocessor register) + else: + if (crn, opc1, crm, opc2) in mmu_regs: + if mmu_regs[(crn, opc1, crm, opc2)] not in [ + "TTBR0", + "TTBR1", + "TTBCR", + "SCTLR", + ]: + return {} + + return { + page_addr + + offset: { + "register": mmu_regs[(crn, opc1, crm, opc2)], + "gpr": [rt], + "f_addr": -1, + "f_parents": set(), + "instruction": "MCR", + } + } + return {} def identify_functions_start(self, addreses): machine = self.machine.get_miasm_machine() @@ -526,7 +669,7 @@ def identify_functions_start(self, addreses): mdis.dontdis_retcall = False instr_len = self.processor_features["instr_len"] - logger = logging.getLogger('asmblock') + logger = logging.getLogger("asmblock") logger.disabled = True for addr in tqdm(addreses): @@ -537,7 +680,6 @@ def identify_functions_start(self, addreses): # Maximum 10000 instructions instructions = 0 while True and instructions <= 10000: - # Stop if found an invalid instruction try: asmcode = mdis.dis_instr(cur_addr) @@ -561,7 +703,7 @@ def identify_functions_start(self, addreses): # Branch! 
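                # (Sketch of the heuristic used here: an unconditional branch
                # is taken to mark the end of the enclosing function, so the
                # scan advances past it and stops, and the address reached is
                # kept as a candidate function start.)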
elif asmcode.name in ["B", "BX", "BXJ", "BL", "BLX"]: - #if asmcode.arg2str(asmcode.args[0]) in ["LR", "R14"]: + # if asmcode.arg2str(asmcode.args[0]) in ["LR", "R14"]: cur_addr += instr_len break @@ -606,6 +748,7 @@ def get_miasm_machine(self): # MMU Modes ################################################################# + class MMU(MMURadix): PAGE_SIZE = 4096 @@ -626,19 +769,25 @@ def __init__(self, mmu_config): class SHORT(MMU): map_ptr_entries_to_levels = {"global": [PTP, None]} - map_datapages_entries_to_levels = {"global": [[PTSECTION, PTSUPERSECTION], [PTLARGE, PTSMALL]]} + map_datapages_entries_to_levels = { + "global": [[PTSECTION, PTSUPERSECTION], [PTLARGE, PTSMALL]] + } ttbcr_n = 0 map_level_to_table_size = {"global": [0, 4096]} - map_entries_to_shifts = {"global": { - PTP: 10, - PTSECTION: 20, - PTSUPERSECTION: 24, - PTLARGE: 16, - PTSMALL: 12 - }} + map_entries_to_shifts = { + "global": {PTP: 10, PTSECTION: 20, PTSUPERSECTION: 24, PTLARGE: 16, PTSMALL: 12} + } map_reserved_entries_to_levels = {"global": [[], []]} - def reconstruct_table(self, frame_addr, frame_size, table_level, table_size, table_entries, empty_entries): + def reconstruct_table( + self, + frame_addr, + frame_size, + table_level, + table_size, + table_entries, + empty_entries, + ): # Reconstruct table_level tables, empty tables and data_pages of a given size frame_d = defaultdict(dict) page_tables = {} @@ -648,24 +797,32 @@ def reconstruct_table(self, frame_addr, frame_size, table_level, table_size, tab frame_d.clear() # Count the empty entries - entry_addresses = set(range(table_addr, table_addr + table_size, MMU.entries_size)) + entry_addresses = set( + range(table_addr, table_addr + table_size, MMU.entries_size) + ) empty_count = len(entry_addresses.intersection(empty_entries)) # Reconstruct the content of the table candidate for entry_addr in entry_addresses.intersection(table_entries.keys()): entry_idx = (entry_addr - table_addr) // MMU.entries_size for entry_type in table_entries[entry_addr]: - frame_d[entry_type][entry_idx] = table_entries[entry_addr][entry_type] + frame_d[entry_type][entry_idx] = table_entries[entry_addr][ + entry_type + ] # Classify the frame - pt_classes = self.classify_frame(frame_d, empty_count, int(table_size // MMU.entries_size)) + pt_classes = self.classify_frame( + frame_d, empty_count, int(table_size // MMU.entries_size) + ) - if -1 in pt_classes: # Empty + if -1 in pt_classes: # Empty empty_tables.append(table_addr) - elif -2 in pt_classes: # Data + elif -2 in pt_classes: # Data data_pages.append(table_addr) elif table_level in pt_classes: - table_obj = self.page_table_class(table_addr, table_size, deepcopy(frame_d), [table_level]) + table_obj = self.page_table_class( + table_addr, table_size, deepcopy(frame_d), [table_level] + ) page_tables[table_addr] = table_obj else: continue @@ -681,13 +838,18 @@ def aggregate_frames(self, frames, frame_size, page_size): if frame_addr % page_size != 0: continue - if all([(frame_addr + idx * frame_size) in frames for idx in range(1, frame_per_page)]): + if all( + [ + (frame_addr + idx * frame_size) in frames + for idx in range(1, frame_per_page) + ] + ): pages.append(frame_addr) return pages def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): - sleep(uniform(pidx, pidx+1) // 1000) + sleep(uniform(pidx, pidx + 1) // 1000) mm = copy(self.machine.memory) mm.reopen() @@ -706,24 +868,32 @@ def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): # Prepare thread local dictionaries in which collect data 
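        # (In brief: each worker reopens its own handle on the memory dump,
        # classifies every entry-sized value in its frames, then tries to
        # rebuild tables of every size used by the short-descriptor format:
        # PTL1 tables, kernel PTL0 tables and, when their sizes differ because
        # TTBCR.N != 0, the shorter user PTL0 tables.)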
data_frames = []
        empty_tables = []
-        page_tables = {"user": [{} for i in range(self.radix_levels["global"])],
-                       "kernel": [{} for i in range(self.radix_levels["global"])]}
+        page_tables = {
+            "user": [{} for i in range(self.radix_levels["global"])],
+            "kernel": [{} for i in range(self.radix_levels["global"])],
+        }

        # Cycle over every frame
        table_entries = defaultdict(dict)
        empty_entries = set()
        total_elems, iterator = addresses
-        for frame_addr in tqdm(iterator, position=-pidx, total=total_elems, leave=False):
+        for frame_addr in tqdm(
+            iterator, position=-pidx, total=total_elems, leave=False
+        ):
            frame_buf = mm.get_data(frame_addr, frame_size)
            table_entries.clear()
            empty_entries.clear()

            # Unpack entries
-            for entry_idx, entry in enumerate(iter_unpack(self.paging_unpack_format, frame_buf)):
+            for entry_idx, entry in enumerate(
+                iter_unpack(self.paging_unpack_format, frame_buf)
+            ):
                entry_addr = frame_addr + entry_idx * MMU.entries_size
-                entry_classes = self.classify_entry(frame_addr, entry[0]) # In this case frame_addr is not used
+                entry_classes = self.classify_entry(
+                    frame_addr, entry[0]
+                )  # In this case frame_addr is not used

                # Data entry
                if None in entry_classes:
@@ -741,18 +911,34 @@ def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs):
            empty_entries = set(empty_entries)

            # Look for PTL1 tables and empty tables
-            t, d, e = self.reconstruct_table(frame_addr, frame_size, 1, PTL1_TABLE_SIZE, table_entries, empty_entries)
+            t, d, e = self.reconstruct_table(
+                frame_addr, frame_size, 1, PTL1_TABLE_SIZE, table_entries, empty_entries
+            )
            page_tables["kernel"][1].update(t)
            data_frames.extend(d)
            empty_tables.extend(e)

            # Look for PTL0 Kernel tables
-            t, _, _ = self.reconstruct_table(frame_addr, frame_size, 0, PTL0_KERNEL_TABLE_SIZE, table_entries, empty_entries)
+            t, _, _ = self.reconstruct_table(
+                frame_addr,
+                frame_size,
+                0,
+                PTL0_KERNEL_TABLE_SIZE,
+                table_entries,
+                empty_entries,
+            )
            page_tables["kernel"][0].update(t)

            # Look for PTL0 User tables
            if PTL0_USER_TABLE_SIZE != PTL0_KERNEL_TABLE_SIZE:
-                t, _, _ = self.reconstruct_table(frame_addr, frame_size, 0, PTL0_USER_TABLE_SIZE, table_entries, empty_entries)
+                t, _, _ = self.reconstruct_table(
+                    frame_addr,
+                    frame_size,
+                    0,
+                    PTL0_USER_TABLE_SIZE,
+                    table_entries,
+                    empty_entries,
+                )
                page_tables["user"][0].update(t)

        # Reconstruct data_pages
@@ -769,7 +955,9 @@ def classify_entry(self, page_addr, entry):

        elif class_bits == 0b01:
            # BIT 9 (which is IMPLEMENTATION DEFINED) has to be zero
-            if not MMU.extract_bits(entry, 9, 1): # WARNING! Specs require bit 4 = 0 but OSs do not respect it... :/
+            if not MMU.extract_bits(
+                entry, 9, 1
+            ):  # WARNING! Specs require bit 4 = 0 but OSs do not respect it... :/
                addr = PTP.extract_addr(entry)
                if addr not in self.machine.memory.physpace["not_valid_regions"]:
                    flags = PTP.extract_flags(entry)
@@ -781,7 +969,11 @@ def classify_entry(self, page_addr, entry):
                    flags = PTLARGE.extract_flags(entry)
                    ptlarge_obj = PTLARGE(addr, flags)

-                    if not (ptlarge_obj.extract_tex() == 1 and not ptlarge_obj.is_bufferable_entry() and ptlarge_obj.is_cachable_entry()):
+                    if not (
+                        ptlarge_obj.extract_tex() == 1
+                        and not ptlarge_obj.is_bufferable_entry()
+                        and ptlarge_obj.is_cachable_entry()
+                    ):
                        classification.append(ptlarge_obj)
            else:
                # If bit 9 is 0 it can be a Section/Supersection (bit 9 is IMPLEMENTATION DEFINED)
@@ -790,13 +982,18 @@ def classify_entry(self, page_addr, entry):
                    flags = PTSECTION.extract_flags(entry)
                    section_obj = PTSECTION(addr, flags)
                    # Some values of TEX, C and B are RESERVED
-                    if not (section_obj.extract_tex() == 1 and not section_obj.is_bufferable_entry() and section_obj.is_cachable_entry()):
-
+                    if not (
+                        section_obj.extract_tex() == 1
+                        and not section_obj.is_bufferable_entry()
+                        and section_obj.is_cachable_entry()
+                    ):
                        # BIT 18 discriminates between Section (0) and Supersection (1)
                        if MMU.extract_bits(entry, 18, 1):
                            super_section_addr = PTSUPERSECTION.extract_addr(entry)
                            super_section_flags = PTSUPERSECTION.extract_flags(entry)
-                            classification.append(PTSUPERSECTION(super_section_addr, super_section_flags))
+                            classification.append(
+                                PTSUPERSECTION(super_section_addr, super_section_flags)
+                            )
                        else:
                            classification.append(section_obj)

@@ -805,33 +1002,62 @@ def classify_entry(self, page_addr, entry):
                    addr = PTSMALL.extract_addr(entry)
                    flags = PTSMALL.extract_flags(entry)
                    entry_obj = PTSMALL(addr, flags)
-                    if not (entry_obj.extract_tex() == 1 and not entry_obj.is_bufferable_entry() and entry_obj.is_cachable_entry()):
+                    if not (
+                        entry_obj.extract_tex() == 1
+                        and not entry_obj.is_bufferable_entry()
+                        and entry_obj.is_cachable_entry()
+                    ):
                        classification.append(entry_obj)

        if not classification:  # No valid class found
            return [None]
        return classification

+
class MMUShell(MMUShellDefault):
-    def __init__(self, completekey='tab', stdin=None, stdout=None, machine={}):
+    def __init__(self, completekey="tab", stdin=None, stdout=None, machine={}):
        super(MMUShell, self).__init__(completekey, stdin, stdout, machine)

        if not self.data:
-            self.data = Data(is_tables_found=False,
-                             is_radix_found=False,
-                             is_registers_found=False,
-                             opcodes={},
-                             regs_values={},
-                             page_tables={"user": [{} for i in range(self.machine.mmu.radix_levels["global"])],
-                                          "kernel": [{} for i in range(self.machine.mmu.radix_levels["global"])]},
-                             data_pages=[],
-                             empty_tables=[],
-                             reverse_map_tables = {"user": [defaultdict(set) for i in range(self.machine.mmu.radix_levels["global"])],
-                                                   "kernel": [defaultdict(set) for i in range(self.machine.mmu.radix_levels["global"])]},
-                             reverse_map_pages = {"user": [defaultdict(set) for i in range(self.machine.mmu.radix_levels["global"])],
-                                                  "kernel": [defaultdict(set) for i in range(self.machine.mmu.radix_levels["global"])]},
-                             used_ttbcr=None,
-                             ttbrs=defaultdict(dict))
+            self.data = Data(
+                is_tables_found=False,
+                is_radix_found=False,
+                is_registers_found=False,
+                opcodes={},
+                regs_values={},
+                page_tables={
+                    "user": [
+                        {} for i in range(self.machine.mmu.radix_levels["global"])
+                    ],
+                    "kernel": [
+                        {} for i in range(self.machine.mmu.radix_levels["global"])
+                    ],
+                },
+                data_pages=[],
+                empty_tables=[],
+                reverse_map_tables={
+                    "user": [
+                        defaultdict(set)
+                        for i in range(self.machine.mmu.radix_levels["global"])
+                    ],
+                    "kernel": [
+                        defaultdict(set)
+                        for i in range(self.machine.mmu.radix_levels["global"])
+                    ],
+                },
+                reverse_map_pages={
+                    "user": [
+                        defaultdict(set)
+                        for i in range(self.machine.mmu.radix_levels["global"])
+                    ],
+                    "kernel": [
+                        defaultdict(set)
+                        for i in range(self.machine.mmu.radix_levels["global"])
+                    ],
+                },
+                used_ttbcr=None,
+                ttbrs=defaultdict(dict),
+            )

    def reload_data_from_file(self, data_filename):
        super(MMUShell, self).reload_data_from_file(data_filename)
@@ -845,7 +1071,9 @@ def do_find_registers_values(self, arg):
            return

        logger.info("Look for opcodes related to MMU setup...")
-        parallel_results = self.machine.apply_parallel(self.machine.mmu.PAGE_SIZE, self.machine.cpu.parse_opcodes_parallel)
+        parallel_results = self.machine.apply_parallel(
+            self.machine.mmu.PAGE_SIZE, self.machine.cpu.parse_opcodes_parallel
+        )

        opcodes = {}
        logger.info("Reaggregate threads data...")
@@ -855,8 +1083,12 @@ def do_find_registers_values(self, arg):
        self.data.opcodes = opcodes

        # Filter to keep only opcodes which write to MMU registers, not ones which read from them or from other registers
-        filter_f = lambda it: True if it[1]["register"] == "TTBCR" and it[1]["instruction"] == "MCR" else False
-        mmu_wr_opcodes = {k: v for k,v in filter(filter_f, opcodes.items())}
+        filter_f = (
+            lambda it: True
+            if it[1]["register"] == "TTBCR" and it[1]["instruction"] == "MCR"
+            else False
+        )
+        mmu_wr_opcodes = {k: v for k, v in filter(filter_f, opcodes.items())}

        logging.info("Use heuristics to find function addresses...")
        logging.info("This analysis could be extremely slow!")
@@ -865,18 +1097,29 @@ def do_find_registers_values(self, arg):

        logging.info("Identify register values using data flow analysis...")
        # We use data flow analysis and merge the results
-        dataflow_values = self.machine.cpu.find_registers_values_dataflow(mmu_wr_opcodes)
+        dataflow_values = self.machine.cpu.find_registers_values_dataflow(
+            mmu_wr_opcodes
+        )
        filtered_values = defaultdict(set)
        for register, values in dataflow_values.items():
            for value in values:
                reg_obj = CPURegARM32.get_register_obj(register, value)
-                if reg_obj.valid and not any([val_obj.is_mmu_equivalent_to(reg_obj) for val_obj in filtered_values[register]]):
+                if reg_obj.valid and not any(
+                    [
+                        val_obj.is_mmu_equivalent_to(reg_obj)
+                        for val_obj in filtered_values[register]
+                    ]
+                ):
                    filtered_values[register].add(reg_obj)

        # Add default values
-        reg_obj = CPURegARM32.get_register_obj("TTBCR", self.machine.cpu.registers_values["TTBCR"])
-        if reg_obj.valid and all([not reg_obj.is_mmu_equivalent_to(x) for x in filtered_values["TTBCR"]]):
+        reg_obj = CPURegARM32.get_register_obj(
+            "TTBCR", self.machine.cpu.registers_values["TTBCR"]
+        )
+        if reg_obj.valid and all(
+            [not reg_obj.is_mmu_equivalent_to(x) for x in filtered_values["TTBCR"]]
+        ):
            filtered_values["TTBCR"].add(reg_obj)

        self.data.regs_values = filtered_values
@@ -936,14 +1179,20 @@ def do_find_tables(self, args):

        # Parse memory in chunks of 16KiB
        PTL0_USER_TABLE_SIZE = ttbcr.get_ptl0_user_table_size()
        logger.info("Look for paging tables...")
-        parallel_results = self.machine.apply_parallel(16384, self.machine.mmu.parse_parallel_frame, ptl0_u_size=PTL0_USER_TABLE_SIZE)
+        parallel_results = self.machine.apply_parallel(
+            16384,
+            self.machine.mmu.parse_parallel_frame,
+            ptl0_u_size=PTL0_USER_TABLE_SIZE,
+        )

        logger.info("Reaggregate threads data...")
        for result in parallel_results:
            page_tables, data_pages, empty_tables = result.get()
            for level in range(self.machine.mmu.radix_levels["global"]):
                self.data.page_tables["user"][level].update(page_tables["user"][level])
-                self.data.page_tables["kernel"][level].update(page_tables["kernel"][level])
+                self.data.page_tables["kernel"][level].update(
+                    page_tables["kernel"][level]
+                )
            self.data.data_pages.extend(data_pages)
            self.data.empty_tables.extend(empty_tables)
@@ -957,7 +1206,9 @@ def do_find_tables(self, args):
        if PTL0_USER_TABLE_SIZE != 16384:
            modes.append("user")
            fps["user"] = []
-            self.data.page_tables["user"][1] = deepcopy(self.data.page_tables["kernel"][1])
+            self.data.page_tables["user"][1] = deepcopy(
+                self.data.page_tables["kernel"][1]
+            )

        # Remove all tables which point to a nonexistent table of the lower level
        logger.info("Reduce false positives...")
@@ -966,10 +1217,15 @@ def do_find_tables(self, args):
        for mode in modes:
            referenced_nxt = []
            for table_addr in list(self.data.page_tables[mode][0].keys()):
-                for entry_obj in self.data.page_tables[mode][0][table_addr].entries[ptr_class].values():
-                    if entry_obj.address not in self.data.page_tables[mode][1] and \
-                       entry_obj.address not in self.data.empty_tables:
-
+                for entry_obj in (
+                    self.data.page_tables[mode][0][table_addr]
+                    .entries[ptr_class]
+                    .values()
+                ):
+                    if (
+                        entry_obj.address not in self.data.page_tables[mode][1]
+                        and entry_obj.address not in self.data.empty_tables
+                    ):
                        # Remove the table
                        self.data.page_tables[mode][0].pop(table_addr)
                        break
@@ -979,24 +1235,31 @@ def do_find_tables(self, args):

            # Remove tables not referenced by upper levels
            referenced_nxt = set(referenced_nxt)
-            for table_addr in set(self.data.page_tables[mode][1].keys()).difference(referenced_nxt):
+            for table_addr in set(self.data.page_tables[mode][1].keys()).difference(
+                referenced_nxt
+            ):
                self.data.page_tables[mode][1].pop(table_addr)

        logger.info("Fill reverse maps...")
        for mode in modes:
            for lvl in range(0, self.machine.mmu.radix_levels["global"]):
                ptr_class = self.machine.mmu.map_ptr_entries_to_levels["global"][lvl]
-                page_classes = self.machine.mmu.map_datapages_entries_to_levels["global"][lvl]
+                page_classes = self.machine.mmu.map_datapages_entries_to_levels[
+                    "global"
+                ][lvl]
                for table_addr, table_obj in self.data.page_tables[mode][lvl].items():
                    for entry_obj in table_obj.entries[ptr_class].values():
-                        self.data.reverse_map_tables[mode][lvl][entry_obj.address].add(table_obj.address)
+                        self.data.reverse_map_tables[mode][lvl][entry_obj.address].add(
+                            table_obj.address
+                        )
                    for page_class in page_classes:
                        for entry_obj in table_obj.entries[page_class].values():
-                            self.data.reverse_map_pages[mode][lvl][entry_obj.address].add(table_obj.address)
+                            self.data.reverse_map_pages[mode][lvl][
+                                entry_obj.address
+                            ].add(table_obj.address)

        self.data.is_tables_found = True

-
    def do_show_table(self, args):
        """Show an MMU table at a chosen address. Usage: show_table ADDRESS [level table size]"""
        if not self.data.used_ttbcr:
@@ -1027,13 +1290,19 @@ def do_show_table(self, args):
            if lvl > self.machine.mmu.radix_levels["global"] - 1:
                raise ValueError
        except ValueError:
-            logger.warning("Level must be an integer between 0 and {}".format(str(self.machine.mmu.radix_levels["global"] - 1)))
+            logger.warning(
+                "Level must be an integer between 0 and {}".format(
+                    str(self.machine.mmu.radix_levels["global"] - 1)
+                )
+            )
            return

        try:
            table_size = self.parse_int(args[2])
            if table_size not in valid_sizes[lvl]:
-                logging.warning(f"Size not allowed for choosen level! Valid sizes are:{valid_sizes[lvl]}")
+                logging.warning(
+                    f"Size not allowed for chosen level! 
Valid sizes are:{valid_sizes[lvl]}" + ) return except ValueError: logger.warning("Invalid size value") @@ -1043,11 +1312,15 @@ def do_show_table(self, args): lvl = -1 table_buff = self.machine.memory.get_data(addr, table_size) - invalids, pt_classes, table_obj = self.machine.mmu.parse_frame(table_buff, addr, table_size, lvl) + invalids, pt_classes, table_obj = self.machine.mmu.parse_frame( + table_buff, addr, table_size, lvl + ) print(table_obj) print(f"Invalid entries: {invalids} Table levels: {pt_classes}") - def virtspace_short(self, addr, page_tables, lvl=0, prefix=0, ukx=True, cache=defaultdict(dict)): + def virtspace_short( + self, addr, page_tables, lvl=0, prefix=0, ukx=True, cache=defaultdict(dict) + ): """Recursively reconstruct virtual address space for SHORT mode""" virtspace = VASShort() @@ -1058,21 +1331,29 @@ def virtspace_short(self, addr, page_tables, lvl=0, prefix=0, ukx=True, cache=de if lvl == self.machine.mmu.radix_levels["global"] - 1: for data_class in data_classes: shift = self.machine.mmu.map_entries_to_shifts["global"][data_class] - for entry_idx, entry in page_tables[lvl][addr].entries[data_class].items(): + for entry_idx, entry in ( + page_tables[lvl][addr].entries[data_class].items() + ): permissions = entry.extract_permissions() kx = entry.is_kernel_executable_entry() and ukx x = entry.is_executable_entry() virt_addr = prefix | (entry_idx << shift) - virtspace[(permissions, x, kx)] |= portion.closedopen(virt_addr, virt_addr + entry.size) - cache[lvl][addr][(permissions, x, kx)] |= portion.closedopen(virt_addr, virt_addr + entry.size) + virtspace[(permissions, x, kx)] |= portion.closedopen( + virt_addr, virt_addr + entry.size + ) + cache[lvl][addr][(permissions, x, kx)] |= portion.closedopen( + virt_addr, virt_addr + entry.size + ) return virtspace else: if ptr_class in page_tables[lvl][addr].entries: shift = self.machine.mmu.map_entries_to_shifts["global"][ptr_class] - for entry_idx, entry in page_tables[lvl][addr].entries[ptr_class].items(): + for entry_idx, entry in ( + page_tables[lvl][addr].entries[ptr_class].items() + ): if entry.address not in page_tables[lvl + 1]: continue else: @@ -1082,7 +1363,14 @@ def virtspace_short(self, addr, page_tables, lvl=0, prefix=0, ukx=True, cache=de x = entry.is_executable_entry() virt_addr = prefix | (entry_idx << shift) - low_virts = self.virtspace_short(entry.address, page_tables, lvl + 1, virt_addr, kx, cache=cache) + low_virts = self.virtspace_short( + entry.address, + page_tables, + lvl + 1, + virt_addr, + kx, + cache=cache, + ) else: low_virts = cache[lvl + 1][entry.address] @@ -1091,20 +1379,28 @@ def virtspace_short(self, addr, page_tables, lvl=0, prefix=0, ukx=True, cache=de cache[lvl][addr][perm] |= virts_fragment for data_class in data_classes: - if data_class in page_tables[lvl][addr].entries and data_class is not None: + if ( + data_class in page_tables[lvl][addr].entries + and data_class is not None + ): shift = self.machine.mmu.map_entries_to_shifts["global"][data_class] - for entry_idx, entry in page_tables[lvl][addr].entries[data_class].items(): + for entry_idx, entry in ( + page_tables[lvl][addr].entries[data_class].items() + ): permissions = entry.extract_permissions() kx = entry.is_kernel_executable_entry() and ukx x = entry.is_executable_entry() virt_addr = prefix | (entry_idx << shift) - virtspace[(permissions, x, kx)] |= portion.closedopen(virt_addr, virt_addr + entry.size) - cache[lvl][addr][(permissions, x, kx)] |= portion.closedopen(virt_addr, virt_addr + entry.size) + virtspace[(permissions, x, 
kx)] |= portion.closedopen(
+                        virt_addr, virt_addr + entry.size
+                    )
+                    cache[lvl][addr][(permissions, x, kx)] |= portion.closedopen(
+                        virt_addr, virt_addr + entry.size
+                    )

        return virtspace

-
    def do_find_radix_trees(self, args):
        """Reconstruct radix trees"""
        if not self.data.is_tables_found:
@@ -1140,46 +1436,87 @@ def do_find_radix_trees(self, args):

        for mode in modes:
            already_explored = set()
-            for page_addr in tqdm(self.data.data_pages.union(self.data.empty_tables).union(not_ram_pages)):
+            for page_addr in tqdm(
+                self.data.data_pages.union(self.data.empty_tables).union(not_ram_pages)
+            ):
                derived_addresses = self.machine.mmu.derive_page_address(page_addr)
                for derived_address in derived_addresses:
                    if derived_address in already_explored:
                        continue

                    lvl, addr = derived_address
-                    candidates[mode].extend(self.radix_roots_from_data_page(lvl, addr, self.data.reverse_map_pages[mode], self.data.reverse_map_tables[mode]))
+                    candidates[mode].extend(
+                        self.radix_roots_from_data_page(
+                            lvl,
+                            addr,
+                            self.data.reverse_map_pages[mode],
+                            self.data.reverse_map_tables[mode],
+                        )
+                    )
                    already_explored.add(derived_address)

-            candidates[mode] = list(set(candidates[mode]).intersection(self.data.page_tables[mode][0].keys()))
+            candidates[mode] = list(
+                set(candidates[mode]).intersection(
+                    self.data.page_tables[mode][0].keys()
+                )
+            )
            candidates[mode].sort()

        # Collect interrupt/paging opcodes
-        filter_f_read = lambda it: True if it[1]["register"] in ["DFSR", "IFSR"] and it[1]["instruction"] == "MRC" else False
-        kernel_opcodes_read = [x[0] for x in filter(filter_f_read, self.data.opcodes.items())]
-        filter_f_write = lambda it: True if it[1]["register"] in ["TTBR0", "TTBCR"] and it[1]["instruction"] == "MCR" else False
-        kernel_opcodes_write = [x[0] for x in filter(filter_f_write, self.data.opcodes.items())]
+        filter_f_read = (
+            lambda it: True
+            if it[1]["register"] in ["DFSR", "IFSR"] and it[1]["instruction"] == "MRC"
+            else False
+        )
+        kernel_opcodes_read = [
+            x[0] for x in filter(filter_f_read, self.data.opcodes.items())
+        ]
+        filter_f_write = (
+            lambda it: True
+            if it[1]["register"] in ["TTBR0", "TTBCR"] and it[1]["instruction"] == "MCR"
+            else False
+        )
+        kernel_opcodes_write = [
+            x[0] for x in filter(filter_f_write, self.data.opcodes.items())
+        ]

        logging.info("Filtering candidates...")
        filtered = {"kernel": {}, "user": {}}
        for mode in modes:
-            physpace_cache = defaultdict(dict) # We need to use different caches for user and kernel modes
+            physpace_cache = defaultdict(
+                dict
+            )  # We need to use different caches for user and kernel modes
            virtspace_cache = defaultdict(dict)
            for candidate in tqdm(candidates[mode]):
-                consistency, pas = self.physpace(candidate, self.data.page_tables[mode], self.data.empty_tables,cache=physpace_cache)
+                consistency, pas = self.physpace(
+                    candidate,
+                    self.data.page_tables[mode],
+                    self.data.empty_tables,
+                    cache=physpace_cache,
+                )

                # Ignore inconsistent radix-trees or ones which map zero spaces
-                if not consistency or \
-                   (pas.get_kernel_size() == 0 and pas.get_user_size() == 0):
+                if not consistency or (
+                    pas.get_kernel_size() == 0 and pas.get_user_size() == 0
+                ):
                    continue

                # Look for kernel trees able to map at least one interrupt/paging related opcode
                if mode == "kernel":
                    # We also check user pages (when ttbr1 is not used!) because user pages are always accessible by the kernel too!
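                    # (Context: when TTBCR.N == 0 a single TTBR0 tree translates
                    # the whole address space, so the MMU-write-opcode test below
                    # is only enforced for split configurations, i.e. when
                    # SHORT.ttbcr_n != 0.)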
- def do_find_radix_trees(self, args): """Reconstruct radix trees""" if not self.data.is_tables_found: @@ -1140,46 +1436,87 @@ def do_find_radix_trees(self, args): for mode in modes: already_explored = set() - for page_addr in tqdm(self.data.data_pages.union(self.data.empty_tables).union(not_ram_pages)): + for page_addr in tqdm( + self.data.data_pages.union(self.data.empty_tables).union(not_ram_pages) + ): derived_addresses = self.machine.mmu.derive_page_address(page_addr) for derived_address in derived_addresses: if derived_address in already_explored: continue lvl, addr = derived_address - candidates[mode].extend(self.radix_roots_from_data_page(lvl, addr, self.data.reverse_map_pages[mode], self.data.reverse_map_tables[mode])) + candidates[mode].extend( + self.radix_roots_from_data_page( + lvl, + addr, + self.data.reverse_map_pages[mode], + self.data.reverse_map_tables[mode], + ) + ) already_explored.add(derived_address) - candidates[mode] = list(set(candidates[mode]).intersection(self.data.page_tables[mode][0].keys())) + candidates[mode] = list( + set(candidates[mode]).intersection( + self.data.page_tables[mode][0].keys() + ) + ) candidates[mode].sort() # Collect interrupt/paging opcodes - filter_f_read = lambda it: True if it[1]["register"] in ["DFSR", "IFSR"] and it[1]["instruction"] == "MRC" else False - kernel_opcodes_read = [x[0] for x in filter(filter_f_read, self.data.opcodes.items())] - filter_f_write = lambda it: True if it[1]["register"] in ["TTBR0", "TTBCR"] and it[1]["instruction"] == "MCR" else False - kernel_opcodes_write = [x[0] for x in filter(filter_f_write, self.data.opcodes.items())] + filter_f_read = ( + lambda it: True + if it[1]["register"] in ["DFSR", "IFSR"] and it[1]["instruction"] == "MRC" + else False + ) + kernel_opcodes_read = [ + x[0] for x in filter(filter_f_read, self.data.opcodes.items()) + ] + filter_f_write = ( + lambda it: True + if it[1]["register"] in ["TTBR0", "TTBCR"] and it[1]["instruction"] == "MCR" + else False + ) + kernel_opcodes_write = [ + x[0] for x in filter(filter_f_write, self.data.opcodes.items()) + ] logging.info("Filtering candidates...") filtered = {"kernel": {}, "user": {}} for mode in modes: - physpace_cache = defaultdict(dict) # We need to use different caches for user and kernel modes + physpace_cache = defaultdict( + dict + ) # We need to use different caches for user and kernel modes virtspace_cache = defaultdict(dict) for candidate in tqdm(candidates[mode]): - consistency, pas = self.physpace(candidate, self.data.page_tables[mode], self.data.empty_tables,cache=physpace_cache) + consistency, pas = self.physpace( + candidate, + self.data.page_tables[mode], + self.data.empty_tables, + cache=physpace_cache, + ) # Ignore inconsistent radix-trees and trees which map zero space - if not consistency or \ - (pas.get_kernel_size() == 0 and pas.get_user_size() == 0): + if not consistency or ( + pas.get_kernel_size() == 0 and pas.get_user_size() == 0 + ): continue # Look for kernel trees able to map at least one interrupt/paging related opcode if mode == "kernel": # We also check user pages (when ttbr1 is not used!) because user pages are always accessible by the kernel too! - if not any([opcode_addr in pas for opcode_addr in kernel_opcodes_read]) or \ - (SHORT.ttbcr_n !=0 and not any([opcode_addr in pas for opcode_addr in kernel_opcodes_write])): + if not any( + [opcode_addr in pas for opcode_addr in kernel_opcodes_read] + ) or ( + SHORT.ttbcr_n != 0 + and not any( + [opcode_addr in pas for opcode_addr in kernel_opcodes_write] + ) + ): continue - vas = self.virtspace_short(candidate, self.data.page_tables[mode], cache=virtspace_cache) + vas = self.virtspace_short( + candidate, self.data.page_tables[mode], cache=virtspace_cache + ) # At least one kernel executable page must exist for _, _, kx in vas: @@ -1188,7 +1525,9 @@ else: continue - radix_tree = RadixTree(candidate, 0, pas, vas, kernel=True, user=False) + radix_tree = RadixTree( + candidate, 0, pas, vas, kernel=True, user=False + ) filtered[mode][candidate] = radix_tree else: @@ -1197,7 +1536,9 @@ continue # At least one executable page must exist - vas = self.virtspace_short(candidate, self.data.page_tables[mode], cache=virtspace_cache) + vas = self.virtspace_short( + candidate, self.data.page_tables[mode], cache=virtspace_cache + ) for _, x, _ in vas: if x: break @@ -1211,7 +1552,9 @@ else: continue - radix_tree = RadixTree(candidate, 0, pas, vas, kernel=False, user=True) + radix_tree = RadixTree( + candidate, 0, pas, vas, kernel=False, user=True + ) filtered[mode][candidate] = radix_tree self.data.ttbrs = filtered @@ -1223,15 +1566,24 @@ def do_show_radix_trees(self, args): logging.info("Please, find them first!") return - labels = ["Radix address", "First level", "Kernel size (Bytes)", "User size (Bytes)", "Kernel"] + labels = [ + "Radix address", + "First level", + "Kernel size (Bytes)", + "User size (Bytes)", + "Kernel", + ] table = PrettyTable() table.field_names = labels for mode in ["kernel", "user"]: for ttbr in self.data.ttbrs[mode].values(): - table.add_row(ttbr.entry_resume_stringified() + ["X" if mode == "kernel" else ""]) - table.sortby="Radix address" + table.add_row( + ttbr.entry_resume_stringified() + ["X" if mode == "kernel" else ""] + ) + table.sortby = "Radix address" print(table) + class MMUShellGTruth(MMUShell): def do_show_registers_gtruth(self, args): """Compare TTBCR register values found with the ground truth""" @@ -1240,33 +1592,42 @@ def do_show_registers_gtruth(self, args): return # Check if the last value of TTBCR was found - all_ttbcrs = {} # QEMU export TTBCR inside various registers as TTBCR, TTBCR, TCR_S etc. due to it's capability to emulate different ARM/ARM64 systems + all_ttbcrs = ( + {} + ) # QEMU exports TTBCR inside various registers (TTBCR, TCR_S, etc.) due to its ability to emulate different ARM/ARM64 systems
for reg_name, value_data in self.gtruth.items(): if "TCR" in reg_name or "TTBCR" in reg_name: for value, value_info in value_data.items(): - if value not in all_ttbcrs or (value_info[1] > all_ttbcrs[value][1]): + if value not in all_ttbcrs or ( + value_info[1] > all_ttbcrs[value][1] + ): all_ttbcrs[value] = (value_info[0], value_info[1]) - - last_ttbcr = TTBCR(sorted(all_ttbcrs.keys(), key=lambda x: all_ttbcrs[x][1], reverse=True)[0]) + last_ttbcr = TTBCR( + sorted(all_ttbcrs.keys(), key=lambda x: all_ttbcrs[x][1], reverse=True)[0] + ) ttbcr_fields_equals = {} for value_found_obj in self.data.regs_values["TTBCR"]: - ttbcr_fields_equals[value_found_obj] = value_found_obj.count_fields_equals(last_ttbcr) - k_sorted = sorted(ttbcr_fields_equals.keys(), key=lambda x: ttbcr_fields_equals[x], reverse=True) + ttbcr_fields_equals[value_found_obj] = value_found_obj.count_fields_equals( + last_ttbcr + ) + k_sorted = sorted( + ttbcr_fields_equals.keys(), + key=lambda x: ttbcr_fields_equals[x], + reverse=True, + ) tcr_found = k_sorted[0] correct_fields_found = ttbcr_fields_equals[tcr_found] if correct_fields_found: print(f"Correct TTBCR value: {last_ttbcr}, Found: {tcr_found}") print("TTBCR fields found:... {}/2".format(correct_fields_found)) - print("FP: {}".format(str(len(self.data.regs_values["TTBCR"]) - 1) )) + print("FP: {}".format(str(len(self.data.regs_values["TTBCR"]) - 1))) else: print(f"Correct TTBCR value: {last_ttbcr}") print("TTBCR fields found:... 0/2") - print("FP: {}".format(str(len(self.data.regs_values["TTBCR"])) )) - - + print("FP: {}".format(str(len(self.data.regs_values["TTBCR"]))))
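Since the checks in this class hinge on TTBCR.N, a quick refresher may help: on ARMv7-A with the short-descriptor format, TTBCR.N decides how the 32-bit virtual address space is split between the two table base registers. The helper below is a back-of-the-envelope illustration of that rule (not part of mmushell):

```python
# With N == 0 every address is translated through TTBR0; with N > 0,
# addresses below 2^(32 - N) go through TTBR0 and the rest through TTBR1.
def ttbr_for_vaddr(vaddr, ttbcr_n):
    if ttbcr_n == 0:
        return "TTBR0"
    boundary = 1 << (32 - ttbcr_n)
    return "TTBR0" if vaddr < boundary else "TTBR1"

assert ttbr_for_vaddr(0x00400000, 2) == "TTBR0"  # below the 1 GB boundary
assert ttbr_for_vaddr(0xC0000000, 2) == "TTBR1"  # upper (kernel) part
```

This is why a zero SHORT.ttbcr_n forces a single TTBR0 tree to map both kernel and user space, while a non-zero N lets the ground-truth comparison treat TTBR0 trees as user-only.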
self.data.page_tables["user"], cache=virt_cache + ) for _, x, _ in virtspace: if x: break else: continue - # Filter for at least a writable page + # Filter for at least a writable page for p, _, _ in virtspace: if p & 0b10 == 2: break @@ -1323,16 +1703,28 @@ def do_show_radix_trees_gtruth(self, args): if ttbr.address not in self.data.page_tables["kernel"][0]: continue - consistency, pas = self.physpace(ttbr.address, self.data.page_tables["kernel"], self.data.empty_tables, cache=ttbr0_phy_cache) - if not consistency or (pas.get_kernel_size() == 0 and pas.get_user_size() == 0): + consistency, pas = self.physpace( + ttbr.address, + self.data.page_tables["kernel"], + self.data.empty_tables, + cache=ttbr0_phy_cache, + ) + if not consistency or ( + pas.get_kernel_size() == 0 and pas.get_user_size() == 0 + ): continue - if not any([opcode_addr in pas for opcode_addr in kernel_opcodes_read]) or \ - not any([opcode_addr in pas for opcode_addr in kernel_opcodes_write]): + if not any( + [opcode_addr in pas for opcode_addr in kernel_opcodes_read] + ) or not any( + [opcode_addr in pas for opcode_addr in kernel_opcodes_write] + ): continue # At least a kernel executable page must be exist - virtspace = self.virtspace_short(ttbr.address, self.data.page_tables["kernel"], cache=virt_cache) + virtspace = self.virtspace_short( + ttbr.address, self.data.page_tables["kernel"], cache=virt_cache + ) for _, _, kx in virtspace: if kx: break @@ -1353,15 +1745,26 @@ def do_show_radix_trees_gtruth(self, args): if ttbr.address not in self.data.page_tables["kernel"][0]: continue - consistency, pas = self.physpace(ttbr.address, self.data.page_tables["kernel"], self.data.empty_tables, cache=ttbr1_phy_cache) - if not consistency or (pas.get_kernel_size() == 0 and pas.get_user_size() == 0): + consistency, pas = self.physpace( + ttbr.address, + self.data.page_tables["kernel"], + self.data.empty_tables, + cache=ttbr1_phy_cache, + ) + if not consistency or ( + pas.get_kernel_size() == 0 and pas.get_user_size() == 0 + ): continue - if not any([opcode_addr in pas for opcode_addr in kernel_opcodes_read]): + if not any( + [opcode_addr in pas for opcode_addr in kernel_opcodes_read] + ): continue # At least a kernel executable page must be exist - virtspace = self.virtspace_short(ttbr.address, self.data.page_tables["kernel"], cache=virt_cache) + virtspace = self.virtspace_short( + ttbr.address, self.data.page_tables["kernel"], cache=virt_cache + ) for _, _, kx in virtspace: if kx: break @@ -1372,16 +1775,34 @@ def do_show_radix_trees_gtruth(self, args): # True positives, false negatives, false positives if SHORT.ttbcr_n == 0: - tps = sorted(set(ttbr0s.keys()).intersection(set(self.data.ttbrs["kernel"].keys()))) - fns = sorted(set(ttbr0s.keys()).difference(set(self.data.ttbrs["kernel"].keys()))) - fps = sorted(set(self.data.ttbrs["kernel"].keys()).difference(set(ttbr0s.keys()))) + tps = sorted( + set(ttbr0s.keys()).intersection(set(self.data.ttbrs["kernel"].keys())) + ) + fns = sorted( + set(ttbr0s.keys()).difference(set(self.data.ttbrs["kernel"].keys())) + ) + fps = sorted( + set(self.data.ttbrs["kernel"].keys()).difference(set(ttbr0s.keys())) + ) else: - tps = sorted(set(ttbr1s.keys()).intersection(set(self.data.ttbrs["kernel"].keys()))) - fns = sorted(set(ttbr1s.keys()).difference(set(self.data.ttbrs["kernel"].keys()))) - fps = sorted(set(self.data.ttbrs["kernel"].keys()).difference(set(ttbr1s.keys()))) - tpsu = sorted(set(ttbr0s.keys()).intersection(set(self.data.ttbrs["user"].keys()))) - fnsu = 
sorted(set(ttbr0s.keys()).difference(set(self.data.ttbrs["user"].keys()))) - fpsu = sorted(set(self.data.ttbrs["user"].keys()).difference(set(ttbr0s.keys()))) + tps = sorted( + set(ttbr1s.keys()).intersection(set(self.data.ttbrs["kernel"].keys())) + ) + fns = sorted( + set(ttbr1s.keys()).difference(set(self.data.ttbrs["kernel"].keys())) + ) + fps = sorted( + set(self.data.ttbrs["kernel"].keys()).difference(set(ttbr1s.keys())) + ) + tpsu = sorted( + set(ttbr0s.keys()).intersection(set(self.data.ttbrs["user"].keys())) + ) + fnsu = sorted( + set(ttbr0s.keys()).difference(set(self.data.ttbrs["user"].keys())) + ) + fpsu = sorted( + set(self.data.ttbrs["user"].keys()).difference(set(ttbr0s.keys())) + ) # Show results table = PrettyTable() @@ -1394,42 +1815,28 @@ def do_show_radix_trees_gtruth(self, args): mode = "K" for tp in sorted(tps): - table.add_row([hex(tp), - "X", - mode, - kernel_regs[tp][1][0], kernel_regs[tp][1][1]]) + table.add_row( + [hex(tp), "X", mode, kernel_regs[tp][1][0], kernel_regs[tp][1][1]] + ) for fn in sorted(fns): - table.add_row([hex(fn), - "", - mode, - kernel_regs[fn][1][0], kernel_regs[fn][1][1]]) + table.add_row( + [hex(fn), "", mode, kernel_regs[fn][1][0], kernel_regs[fn][1][1]] + ) for fp in sorted(fps): - table.add_row([hex(fp), - "False positive", - mode, - "", ""]) + table.add_row([hex(fp), "False positive", mode, "", ""]) # User if SHORT.ttbcr_n != 0: for tp in sorted(tpsu): - table.add_row([hex(tp), - "X", - "U", - ttbr0s[tp][1][0], ttbr0s[tp][1][1]]) + table.add_row([hex(tp), "X", "U", ttbr0s[tp][1][0], ttbr0s[tp][1][1]]) for fn in sorted(fnsu): - table.add_row([hex(fn), - "", - "U", - ttbr0s[fn][1][0], ttbr0s[fn][1][1]]) + table.add_row([hex(fn), "", "U", ttbr0s[fn][1][0], ttbr0s[fn][1][1]]) for fp in sorted(fpsu): - table.add_row([hex(fp), - "False positive", - "U", - "", ""]) + table.add_row([hex(fp), "False positive", "U", "", ""]) print(table) print(f"TP:{len(tps)} FN:{len(fns)} FP:{len(fps)}") diff --git a/architectures/generic.py b/mmushell/architectures/generic.py similarity index 50% rename from architectures/generic.py rename to mmushell/architectures/generic.py index b755b47..d29bd21 100644 --- a/architectures/generic.py +++ b/mmushell/architectures/generic.py @@ -1,5 +1,12 @@ +import sys +import gc +import os +import random +import logging import portion -from mmap import mmap, MAP_SHARED, PROT_READ, MADV_HUGEPAGE +import importlib +import multiprocessing as mp + from miasm.jitter.VmMngr import Vm from miasm.jitter.csts import PAGE_READ, PAGE_WRITE, PAGE_EXEC from miasm.core.locationdb import LocationDB @@ -9,38 +16,33 @@ from miasm.expression.expression import ExprInt, ExprId from miasm.core.bin_stream import bin_stream_file from miasm.analysis.depgraph import DependencyGraph + +from mmap import mmap, MAP_SHARED, PROT_READ, MADV_HUGEPAGE from tqdm import tqdm from itertools import chain -import logging from compress_pickle import load, dump from cmd import Cmd -import sys -import importlib from IPython import embed -import multiprocessing as mp -import random from collections import defaultdict, deque -import sys +from dataclasses import dataclass, field from copy import deepcopy from pickle import load as Load -from dataclasses import dataclass, field from enum import IntEnum from typing import Any, Dict from time import sleep from random import uniform from struct import iter_unpack, unpack from copy import deepcopy, copy -import sys -import gc -import os + logger = logging.getLogger(__name__) + # Disable print() from MIASM class DisableLogs: 
def __enter__(self): self._original_stdout = sys.stdout - sys.stdout = open(os.devnull, 'w') + sys.stdout = open(os.devnull, "w") def __exit__(self, exc_type, exc_val, exc_tb): sys.stdout.close() @@ -48,13 +50,33 @@ def __exit__(self, exc_type, exc_val, exc_tb): class CPUReg: - def is_valid(self, value): + """Represents a CPU register""" + + def is_valid(self, value) -> bool: + """Check if the value is valid for the register + + What counts as a valid value is very specific to the architecture, so this function is meant to be overridden by child classes + + Args: + value: the value to check + + Returns: + True if the value is valid, False otherwise + """ return True def is_mmu_equivalent_to(self, other_reg): + """Check if the register is equivalent to another register + + Args: + other_reg: the other register to compare + + See child classes for more details + """ raise NotImplementedError def __hash__(self): + """Hash of the value contained in the register""" return hash(self.value) def __eq__(self, other): @@ -62,27 +84,94 @@ def __eq__(self, other): class TableEntry: + """Represents a Page Table Entry + + It holds the mapping between a virtual address of a page and the address of a physical frame. + There is also auxiliary information about the page such as a present bit, a dirty or modified bit, address space or process ID information, amongst others. + + Aimed to be inherited by child classes to represent different architectures + + Attributes: + address: the virtual address of the page + flags: the flags of the page + """ + def __init__(self, address, flags, *args): + """Initialize the Table Entry + + Args: + address: the virtual address of the page + flags: the flags of the page + *args: additional arguments + """ self.address = address self.flags = flags class PageTable: + """Represents a Page Table + + A page table is the data structure used by a virtual memory system in a computer operating system to store the mapping between virtual addresses and physical addresses. + + Aimed to be inherited by child classes to represent different architectures + + Attributes: + address: the address of the page table + size: the size of the page table + entries: the entries of the page table + levels: the levels of the page table + """ + entry_size = 0 + def __init__(self, address, size, entries, levels, *args): + """Initialize the Page Table + + Args: + address: the address of the page table + size: the size of the page table + entries: the entries of the page table + levels: the levels of the page table + *args: additional arguments + """ self.address = address self.size = size self.entries = entries self.levels = levels def apply_on_entries(self, f, args): + """Run a function on all the entries of the page table. + + The provided function should take an entry and the arguments as input and return the result of the operation. + + Args: + f: the function to apply + args: the arguments to pass to the function + + Returns: + a list with the results of the function applied to all the entries + """ res = [] for entry in self.entries: res.append(f(entry, args)) return res + def perms_bool_to_string(kr, kw, kx, r, w, x): - perm_s = "R" if kr else "-" + """Convert a set of permissions from boolean to string + + Args: + kr: read permission for the kernel + kw: write permission for the kernel + kx: execute permission for the kernel + r: read permission for the user + w: write permission for the user + x: execute permission for the user + + Returns: + a string with the permissions + """ + perm_s = "R" if kr else "-" perm_s += "W" if kw else "-" perm_s += "X" if kx else "-" perm_s += "r" if r else "-" @@ -90,10 +179,55 @@ def perms_bool_to_string(kr, kw, kx, r, w, x): perm_s += "w" if w else "-" perm_s += "x" if x else "-" return perm_s
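A quick sanity check of the permission encoding above; the function is repeated verbatim so the snippet runs standalone (kernel permissions in upper case, user permissions in lower case):

```python
def perms_bool_to_string(kr, kw, kx, r, w, x):
    # Kernel R/W/X first, then user r/w/x.
    perm_s = "R" if kr else "-"
    perm_s += "W" if kw else "-"
    perm_s += "X" if kx else "-"
    perm_s += "r" if r else "-"
    perm_s += "w" if w else "-"
    perm_s += "x" if x else "-"
    return perm_s

# Kernel read/write, user read-only:
assert perms_bool_to_string(True, True, False, True, False, False) == "RW-r--"
```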
+ class RadixTree: - labels = ["Radix address", "First level", "Kernel size (Bytes)", "User size (Bytes)"] + """Represents a Radix Tree + + Radix trees maintain a hierarchical representation of the SAS. Each tree is composed of N-1 levels of + directory tables, each containing entries that either point to tables of the lower level or to a physical + memory page (huge pages), whose size depends on the level itself. + The final level is composed of page tables that point only to same-size physical memory pages. + The tree root is the physical address of the Level 0 directory table and it uniquely identifies the SAS + and, consequently, the process to which it is assigned. This address is stored in a special system register + (here generically called RADIX_ROOT) by the operating system and it is used by the MMU to locate the radix + tree when it starts a new translation. + The translation performed by the MMU starts from the root table pointed to by the address contained in RADIX_ROOT: + the MMU divides the segmented address into two parts: a prefix and a page offset. + The prefix is divided into a series of N chunks that represent the hierarchical sequence of indexes to be + used to locate the entry inside a table of the corresponding level. This process ends when an entry points to + a physical page. At this point, the MMU returns the page offset extracted from the segmented + address concatenated with the physical page address found in the last page table entry. 
+ + Aimed to be inherited by child classes to represent different architectures + + Attributes: + top_table: the address of the top table + init_level: the initial level of the radix tree + pas: the Physical Address Space + vas: the Virtual Address Space + kernel: if the kernel space is enabled + user: if the user space is enabled + """ + + labels = [ + "Radix address", + "First level", + "Kernel size (Bytes)", + "User size (Bytes)", + ] addr_fmt = "0x{:016x}" + def __init__(self, top_table, init_level, pas, vas, kernel=True, user=True): + """Initialize the Radix Tree + + Args: + top_table: the address of the top table + init_level: the initial level of the radix tree + pas: the Physical Address Space + vas: the Virtual Address Space + kernel: if the kernel space is enabled + user: if the user space is enabled + """ self.top_table = top_table self.init_level = init_level self.pas = pas @@ -102,29 +236,60 @@ def __init__(self, top_table, init_level, pas, vas, kernel=True, user=True): self.user = user def __repr__(self): + """String representation of the Radix Tree""" e_resume = self.entry_resume_stringified() - return str([self.labels[i] + ": " + str(e_resume[i]) for i in range(len(self.labels))]) - - def entry_resume(self): - return [self.top_table, - self.init_level, - self.pas.get_kernel_size(), - self.pas.get_user_size() - ] - - def entry_resume_stringified(self): + return str( + [self.labels[i] + ": " + str(e_resume[i]) for i in range(len(self.labels))] + ) + + def entry_resume(self) -> list: + """Get the resume of the Radix Tree + + Returns: + a list with the resume of the Radix Tree + """ + return [ + self.top_table, + self.init_level, + self.pas.get_kernel_size(), + self.pas.get_user_size(), + ] + + def entry_resume_stringified(self) -> list: + """Get the resume of the Radix Tree as string + + Returns: + a list with the resume of the Radix Tree as string + """ res = self.entry_resume() res[0] = self.addr_fmt.format(res[0]) for idx, r in enumerate(res[1:], start=1): res[idx] = str(r) return res + class VAS(defaultdict): + """Represents a Virtual Address Space + + The Virtual Address Space (VAS) is a hierarchical data structure that represents the SAS of a process. + The VAS is composed of a set of entries, each one representing a set of contiguous virtual addresses + with the same permissions. The VAS is organized in a hierarchical way, where each entry is identified by + a set of permissions and contains a set of intervals of contiguous virtual addresses. + + Aimed to be inherited by child classes to represent different architectures + + Attributes: + default_factory: the default interval for the VAS + + Note: the default_factory is initialized as an empty interval to avoid the need of checking if a key is present before accessing it + """ + def __init__(self, *args, **kwargs): super(VAS, self).__init__() self.default_factory = portion.empty def __repr__(self): + """String representation of the VAS""" s = "" for k in self: k_str = perms_bool_to_string(*k) @@ -134,6 +299,12 @@ def __repr__(self): return s def hierarchical_extend(self, other, uperms): + """Extend this VAS with another VAS + + Args: + other: the other VAS to extend with + uperms: the permissions to use + """ for perm in other: new_perm = [] for i in range(6): @@ -144,10 +315,29 @@ def hierarchical_extend(self, other, uperms): @dataclass class PAS: + """Represents a Physical Address Space + + The Physical Address Space (PAS) is a hierarchical data structure that represents the SAS of the physical memory. 
+ The PAS is composed of a set of entries, each one representing a set of contiguous physical addresses with the same permissions. + The PAS is organized in a hierarchical way, where each entry is identified by a set of permissions and contains a set of intervals of contiguous physical addresses. + + Attributes: + space: for each permission tuple, the mapping from physical page address to mapped size + space_size: the total mapped size for each permission tuple + + Note: these attributes are initialized as defaultdicts to avoid having to check whether a key is present before accessing them + """ + space: Dict = field(default_factory=lambda: defaultdict(dict)) space_size: Dict = field(default_factory=lambda: defaultdict(int)) def hierarchical_extend(self, other, uperms): + """Extend this PAS with another PAS + + Args: + other: the other PAS to extend with + uperms: the permissions to use + """ for perm in other.space: new_perm = [] for i in range(6): @@ -157,6 +347,14 @@ self.space_size[new_perm] += other.space_size[perm] def __contains__(self, key): + """Check if a key is present in the PAS + + Args: + key: the key to check + + Returns: + True if the key is present, False otherwise + """ for addresses in self.space.values(): for address in addresses: if address <= key < address + addresses[address]: @@ -164,6 +362,14 @@ return False def is_in_kernel_space(self, key): + """Check if a key is in the kernel space + + Args: + key: the key to check + + Returns: + True if the key is in the kernel space, False otherwise + """ for perms, addresses in self.space.items(): if perms[0] or perms[1] or perms[2]: for address in addresses: @@ -172,6 +378,14 @@ return False def is_in_kernel_x_space(self, key): + """Check if a key is in the kernel executable space + + Args: + key: the key to check + + Returns: + True if the key is in the kernel executable space, False otherwise + """ for perms, addresses in self.space.items(): if perms[0] and perms[2]: for address in addresses: @@ -180,39 +394,78 @@ return False def is_in_user_space(self, key): + """Check if a key is in the user space + + Args: + key: the key to check + + Returns: + True if the key is in the user space, False otherwise + """ for perms, addresses in self.space.items(): - if not(perms[0] or perms[1] or perms[2]): + if not (perms[0] or perms[1] or perms[2]): for address in addresses: if address <= key < address + addresses[address]: return True return False def __repr__(self): + """String representation of the PAS""" ret = "" for perm in self.space: - symb = lambda x,s: s if x else "-" - ret += "{}{}{} {}{}{}: {}\n".format(symb(perm[0], "R"),symb(perm[1], "W"),symb(perm[2], "X"),symb(perm[3], "R"),symb(perm[4], "W"),symb(perm[5], "X"), self.space_size[perm]) + symb = lambda x, s: s if x else "-" + ret += "{}{}{} {}{}{}: {}\n".format( + symb(perm[0], "R"), + symb(perm[1], "W"), + symb(perm[2], "X"), + symb(perm[3], "R"), + symb(perm[4], "W"), + symb(perm[5], "X"), + self.space_size[perm], + ) return ret def get_kernel_size(self): + """Get the size of the kernel space""" size = 0 for perm in self.space: - if not(perm[3] or perm[4] or perm[5]): + if not (perm[3] or perm[4] or perm[5]): size += self.space_size[perm] return size def get_user_size(self): + """Get the size of the user space""" size = 0 for perm in self.space: if perm[3] or perm[4] or perm[5]: size += self.space_size[perm] return size
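To see the containment test above in action, here is a hypothetical, self-contained mock of the PAS space layout (a permission tuple mapping page start addresses to sizes), mirroring the logic of __contains__:

```python
from collections import defaultdict

# Hypothetical PAS content: one kernel RWX page and one user r-- page.
space = defaultdict(dict)
space[(True, True, True, False, False, False)][0x80000000] = 0x1000
space[(False, False, False, True, False, False)][0x00100000] = 0x1000

def contains(space, key):
    # Same interval test as PAS.__contains__: start <= key < start + size.
    return any(
        start <= key < start + size
        for pages in space.values()
        for start, size in pages.items()
    )

assert contains(space, 0x80000FFF)      # last byte of the kernel page
assert not contains(space, 0x80001000)  # first byte past it
```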
+ class Machine: + """Represents a generic machine + + A machine is composed of a CPU, an MMU and a memory. It is able to parse the memory in parallel and to extract the dataflow of the registers. + + Attributes: + cpu: the CPU of the machine + mmu: the MMU of the machine + memory: the memory of the machine + gtruth: the ground truth of the machine + data: the data of the machine + """ + @classmethod def from_machine_config(cls, machine_config, **kwargs): - """Create a machine starting from a YAML file descriptor""" + """Create a machine starting from a YAML file descriptor + + Args: + machine_config: the YAML file descriptor + **kwargs: additional arguments + + Returns: + a new Machine object + """ # Check no intersection between memory regions ram_portion = portion.empty() for region_dict in machine_config["memspace"]["ram"]: @@ -223,14 +476,18 @@ ram_portion = ram_portion.union(region_portion) # Module to use - architecture_module = importlib.import_module("architectures." + machine_config["cpu"]["architecture"]) + architecture_module = importlib.import_module( + "architectures." + machine_config["cpu"]["architecture"] + ) # Create CPU cpu = architecture_module.CPU.from_cpu_config(machine_config["cpu"]) # Create MMU try: - mmu_class = getattr(architecture_module, machine_config["mmu"]["mode"].upper()) + mmu_class = getattr( + architecture_module, machine_config["mmu"]["mode"].upper() + ) except AttributeError: logger.fatal("Unknown MMU mode!") exit(1) @@ -242,6 +499,14 @@ return architecture_module.Machine(cpu, mmu, memory, **kwargs)
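As a rough illustration of the configuration this factory consumes, here is a hypothetical machine description written as the Python dict that YAML parsing would produce. The keys mirror the lookups in the code above; the concrete values (architecture, addresses, file names) are invented, and real configurations may need further architecture-specific CPU keys (see the examples shipped in the dataset):

```python
# Hypothetical input for Machine.from_machine_config(); in practice this
# dict is the result of parsing a machine YAML file.
machine_config = {
    "cpu": {
        "architecture": "arm",  # resolves to architectures/arm.py
        "bits": 32,
        "endianness": "little",
    },
    "mmu": {"mode": "short"},  # looked up as the SHORT class via getattr()
    "memspace": {
        "ram": [
            {"start": 0x80000000, "end": 0x8FFFFFFF, "dumpfile": "ram.dump"},
        ],
        "not_ram": [
            {"start": 0x00000000, "end": 0x0000FFFF},
        ],
    },
}
```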
def __init__(self, cpu, mmu, memory, **kwargs): + """Initialize the Machine + + Args: + cpu: the CPU of the machine + mmu: the MMU of the machine + memory: the memory of the machine + **kwargs: additional arguments + """ self.cpu = cpu self.mmu = mmu self.memory = memory @@ -252,13 +517,35 @@ self.memory.machine = self def get_miasm_machine(self): + """Get the Miasm machine + + Aimed to be overloaded by child classes + + Returns: + the Miasm machine + """ return None def __del__(self): self.memory.close() - def apply_parallel(self, frame_size, parallel_func, iterators=None, max_address=-1, **kwargs): - """Apply parallel_func using multiple core to frame_size chunks of RAM or iterators arguments""" + def apply_parallel( + self, frame_size, parallel_func, iterators=None, max_address=-1, **kwargs + ) -> list: + """Run parallel_func on frame_size chunks of RAM (or on the iterators arguments) using multiple cores + + Only used in child classes for parsing memory + + Args: + frame_size: the size of the frame to parse + parallel_func: the function to run in parallel + iterators: the iterators to use + max_address: the maximum address to parse + **kwargs: additional arguments + + Returns: + a list of the results of parallel_func + """ # Prepare the pool logger.info("Parsing memory...") @@ -267,31 +554,69 @@ if iterators is None: # Create iterators for parallel execution - _, addresses_iterators = self.memory.get_addresses(frame_size, cpus=cpus, max_address=max_address) + _, addresses_iterators = self.memory.get_addresses( + frame_size, cpus=cpus, max_address=max_address + ) else: addresses_iterators = iterators # GO! - parsing_results_async = [pool.apply_async(parallel_func, - args=(addresses_iterator, frame_size, pidx), kwds=kwargs) - for pidx, addresses_iterator in enumerate(addresses_iterators)] + parsing_results_async = [ + pool.apply_async( + parallel_func, args=(addresses_iterator, frame_size, pidx), kwds=kwargs + ) + for pidx, addresses_iterator in enumerate(addresses_iterators) + ] pool.close() pool.join() - print("\n") # Workaround for tqdm + print("\n") # Workaround for tqdm return parsing_results_async + class CPU: + """Represents a generic CPU + + A CPU is able to parse opcodes and to find the dataflow of the registers. + + Aimed to be inherited by child classes to represent different architectures + + Attributes: + architecture: the architecture of the CPU + bits: the bits of the CPU + endianness: the endianness of the CPU + processor_features: the processor features of the CPU + registers_values: the registers values of the CPU + opcode_to_mmu_regs: the mapping between opcodes and MMU registers + opcode_to_gregs: the mapping between opcodes and general registers + machine: the machine of the CPU + """ + opcode_to_mmu_regs = None opcode_to_gregs = None @classmethod def from_cpu_config(cls, cpu_config, **kwargs): + """Create a CPU starting from a YAML file descriptor + + Args: + cpu_config: the YAML file descriptor + **kwargs: additional arguments + + Returns: + a new CPU object + """ return CPU(cpu_config) machine = None + def __init__(self, params): + """Initialize the CPU + + Args: + params: the parameters of the CPU + """ self.architecture = params["architecture"] self.bits = params["bits"] self.endianness = params["endianness"] @@ -300,21 +625,77 @@ def __init__(self, params): @staticmethod def extract_bits_little(entry, pos, n): + """Extract bits from an entry in little endian + + Args: + entry: the entry to extract bits from + pos: the position of the bits + n: the number of bits to extract + + Returns: + the extracted bits + """ return (entry >> pos) & ((1 << n) - 1) @staticmethod def extract_bits_big(entry, pos, n): + """Extract bits from an entry in big endian + + Args: + entry: the entry to extract bits from + pos: the position of the bits + n: the number of bits to extract + + Returns: + the extracted bits + """ return (entry >> (32 - pos - n)) & ((1 << n) - 1) @staticmethod def extract_bits_big64(entry, pos, n): + """Extract bits from an entry in big endian 64 bits + + Args: + entry: the entry to extract bits from + pos: the position of the bits + n: the number of bits to extract + + Returns: + the extracted bits + """ return (entry >> (64 - pos - n)) & ((1 << n) - 1)
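A worked example of the bit-extraction helpers above (the values are arbitrary):

```python
def extract_bits_little(entry, pos, n):
    # Take n bits starting at bit position pos (LSB = position 0).
    return (entry >> pos) & ((1 << n) - 1)

def extract_bits_big(entry, pos, n):
    # Same extraction, but pos counts down from the most significant
    # bit of a 32-bit word.
    return (entry >> (32 - pos - n)) & ((1 << n) - 1)

assert extract_bits_little(0xABCD, 4, 8) == 0xBC  # bits 4..11
assert extract_bits_big(0x80000000, 0, 1) == 1    # the top bit
```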
def parse_opcode(self, buff, page_addr, offset): + """Parse an opcode + + Aimed to be overloaded by child classes + + Args: + buff: the buffer to parse + page_addr: the address of the page + offset: the offset of the opcode + + Returns: + the parsed opcode + """ raise NotImplementedError - def parse_opcodes_parallel(self, addresses, frame_size, pidx, **kwargs): - sleep(uniform(pidx, pidx+1) // 1000) + def parse_opcodes_parallel(self, addresses, frame_size, pidx, **kwargs) -> dict: + """Parse opcodes in parallel + + Every process sleeps for a random delay in order to desynchronize access to the disk and maximize the throughput + + Args: + addresses: the addresses to parse + frame_size: the size of the frame + pidx: the process index + **kwargs: additional arguments + + Returns: + a dictionary of parsed opcodes + """ + # Every process sleeps for a random delay in order to desynchronize access to the disk and maximize the throughput + sleep(uniform(pidx, pidx + 1) / 1000) opcodes = {} mm = copy(self.machine.memory) @@ -322,35 +703,57 @@ # Cycle over every frame total_elems, iterator = addresses - for frame_addr in tqdm(iterator, position=-pidx, total=total_elems, leave=False): - frame_buf = mm.get_data(frame_addr, frame_size) # We parse memory in PAGE_SIZE chunks - - for idx, opcode in enumerate(iter_unpack(self.processor_features["opcode_unpack_fmt"], frame_buf)): + for frame_addr in tqdm( + iterator, position=-pidx, total=total_elems, leave=False + ): + frame_buf = mm.get_data( + frame_addr, frame_size + ) # We parse memory in PAGE_SIZE chunks + + for idx, opcode in enumerate( + iter_unpack(self.processor_features["opcode_unpack_fmt"], frame_buf) + ): opcode = opcode[0] - opcodes.update(self.parse_opcode(opcode, frame_addr, idx * self.processor_features["instr_len"])) + opcodes.update( + self.parse_opcode( + opcode, frame_addr, idx * self.processor_features["instr_len"] + ) + ) return opcodes - def find_registers_values_dataflow(self, opcodes, zero_registers=[]): + def find_registers_values_dataflow(self, opcodes, zero_registers=[]) -> dict: + """Find the dataflow of the registers + + Args: + opcodes: the opcodes to analyze + zero_registers: the registers to ignore + + Returns: + a dictionary with the dataflow of the registers + """ # Miasm requires us to define a memory() function to access the underlying # memory layer during the Python translation # WORKAROUND: memory() does not permit more than 2 args... endianness = self.endianness + def memory(addr, size): return int.from_bytes(self.machine.memory.get_data(addr, size), endianness) machine = self.machine.get_miasm_machine() vm = self.machine.memory.get_miasm_vmmngr() - mdis = machine.dis_engine(bin_stream_vm(vm), dont_dis_nulstart_bloc=False, loc_db=LocationDB()) + mdis = machine.dis_engine( + bin_stream_vm(vm), dont_dis_nulstart_bloc=False, loc_db=LocationDB() + ) ir_arch = machine.ira(mdis.loc_db) py_transl = TranslatorPython() # Disable MIASM logging - logging.getLogger('asmblock').disabled = True + logging.getLogger("asmblock").disabled = True registers_values = defaultdict(set) - # We use a stack data structure (deque) in order to manage also parent functions (EXPERIMENTAL not implemented here) + # We use a stack data structure (deque) in order to also manage parent functions (EXPERIMENTAL, not implemented here) instr_deque = deque([(addr, opcodes[addr]) for addr in opcodes]) while len(instr_deque): instr_addr, instr_data = instr_deque.pop() @@ -398,12 +801,16 @@ def memory(addr, size): # Recreate default CPU config registers state and general registers to look for bits = self.bits gp_registers = [ExprId(reg, bits) for reg in instr_data["gpr"]] - init_ctx = {ExprId(name.upper(), bits): ExprInt(value, bits) for name, value in self.registers_values.items()} + init_ctx = { + ExprId(name.upper(), bits): ExprInt(value, bits) + for name, value in self.registers_values.items() + } # Generate solutions loops = 0 - for sol_nb, sol in enumerate(dg.get(current_block.loc_key, gp_registers, assignblk_index, set())): - + for sol_nb, sol in enumerate( + dg.get(current_block.loc_key, gp_registers, assignblk_index, set()) + ): # The solution contains a loop; we permit only a maximum of 10 solutions with loops... 
if sol.has_loop: loops += 1 @@ -444,24 +851,120 @@ def memory(addr, size): self.machine.memory.free_miasm_memory() return registers_values + class MMU: + """Represents a generic MMU + + The MMU is the hardware device that converts the virtual addresses used by the processor to physical addresses. + To accomplish this task, the MMU needs to be configured by using special system registers, while in-memory + structures that maintain the virtual-to-physical mapping have to be defined and continuously updated by the + operating system. + When the MMU fails to resolve a requested virtual address, it raises an interrupt to signal the OS to update + the in-memory related structures. + It is important to note that the MMU demands strict conformity of the shape and topology of the in-memory structure + to the ISA and MMU configuration requirements. Otherwise, it raises an exception and aborts the address translation. + The translation process can involve up to two separate translations, both accomplished by the MMU: segmentation, + which converts virtual to segmented addresses, and paging, which converts segmented addresses to physical ones. + Some architectures use either one or the other, while others use both. + In general, when the system boots, the MMU is virtually disabled and all the virtual addresses are identically + transformed to physical ones. This allows the OS to start in a simplified memory environment and gives it time to + properly configure and enable the MMU before spawning other processes. + Since address translation is a performance bottleneck, the latest translated addresses are cached in a few but low- + latency hardware structures called Translation Lookaside Buffers (TLBs). Before starting a translation, the MMU checks + if TLBs contain an already resolved virtual address and, if so, it returns directly the corresponding physical addresses. 
+ + Aimed to be inherited by child classes to represent different architectures + + Attributes: + mmu_config: the configuration of the MMU + PAGE_SIZE: the size of the page + extract_bits: the function to extract bits + paging_unpack_format: the format of the paging + page_table_class: the class of the page table + radix_levels: the levels of the radix tree + top_prefix: the top prefix of the radix tree + entries_size: the size of the entries + map_ptr_entries_to_levels: the mapping between pointer entries and levels + map_datapages_entries_to_levels: the mapping between data pages entries and levels + map_level_to_table_size: the mapping between levels and table sizes + map_entries_to_shifts: the mapping between entries and shifts + map_reserved_entries_to_levels: the mapping between reserved entries and levels + machine: the machine of the MMU + """ + machine = None + def __init__(self, mmu_config): + """Initialize the MMU + + Args: + mmu_config: the configuration of the MMU + """ self.mmu_config = mmu_config @staticmethod def extract_bits_little(entry, pos, n): + """Extract bits from an entry in little endian + + Args: + entry: the entry to extract bits from + pos: the position of the bits + n: the number of bits to extract + + Returns: + the extracted bits + """ return (entry >> pos) & ((1 << n) - 1) @staticmethod def extract_bits_big(entry, pos, n): + """Extract bits from an entry in big endian + + Args: + entry: the entry to extract bits from + pos: the position of the bits + n: the number of bits to extract + + Returns: + the extracted bits + """ return (entry >> (32 - pos - n)) & ((1 << n) - 1) @staticmethod def extract_bits_big64(entry, pos, n): + """Extract bits from an entry in big endian 64 bits + + Args: + entry: the entry to extract bits from + pos: the position of the bits + n: the number of bits to extract + + Returns: + the extracted bits + """ return (entry >> (64 - pos - n)) & ((1 << n) - 1) + class MMURadix(MMU): + """Represents a MMU that uses a Radix Tree + + Aimed to be inherited by child classes to represent different architectures + + Attributes: + PAGE_SIZE: the size of the page + extract_bits: the function to extract bits + paging_unpack_format: the format of the paging + page_table_class: the class of the page table + radix_levels: the levels of the radix tree + top_prefix: the top prefix of the radix tree + entries_size: the size of the entries + map_ptr_entries_to_levels: the mapping between pointer entries and levels + map_datapages_entries_to_levels: the mapping between data pages entries and levels + map_level_to_table_size: the mapping between levels and table sizes + map_entries_to_shifts: the mapping between entries and shifts + map_reserved_entries_to_levels: the mapping between reserved entries and levels + """ + PAGE_SIZE = 0 extract_bits = None paging_unpack_format = "" @@ -476,22 +979,56 @@ class MMURadix(MMU): map_reserved_entries_to_levels = {} def classify_entry(self, page_addr, entry): + """Classify an entry + + Aimed to be overloaded by child classes + + Args: + page_addr: the address of the page + entry: the entry to classify + + Returns: + the classified entry + """ raise NotImplementedError def derive_page_address(self, addr, mode="global"): - # Derive the addresses of pages containing the address + """Derive the addresses of pages containing the address + + Args: + addr: the address to derive + mode: the mode to use + + Returns: + the addresses of pages containing the address + """ addrs = [] - for lvl in range(self.radix_levels[mode] - 1 , 
-1, -1): + for lvl in range(self.radix_levels[mode] - 1, -1, -1): for entry_class in self.map_datapages_entries_to_levels[mode][lvl]: if entry_class is not None: shift = self.map_entries_to_shifts[mode][entry_class] addrs.append((lvl, (addr >> shift) << shift)) return addrs
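A small worked example of the rounding performed by derive_page_address above; the shift values (21 for 2 MB pages, 12 for 4 KB pages) are illustrative:

```python
# For every page size a data page could have, derive_page_address() rounds
# the address down to the enclosing page boundary with a shift pair.
addr = 0x12345678
for lvl, shift in [(0, 21), (1, 12)]:  # e.g. 2 MB and 4 KB data pages
    page_base = (addr >> shift) << shift
    print(lvl, hex(page_base))
# 0 0x12200000  <- 2 MB-aligned page containing addr
# 1 0x12345000  <- 4 KB-aligned page containing addr
```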
mode="g else: for entry_obj in entry_classes: entry_type = type(entry_obj) - if type(entry_obj) in data_classes or \ - type(entry_obj) is ptr_class or \ - type(entry_obj) in reseved_classes: + if ( + type(entry_obj) in data_classes + or type(entry_obj) is ptr_class + or type(entry_obj) in reseved_classes + ): frame_d[entry_type][entry_idx] = entry_obj break else: invalids += 1 # Classify the frame - pt_classes = self.classify_frame(frame_d, empty_entries, int(frame_size // self.page_table_class.entry_size), mode=mode) - - if -1 in pt_classes or -2 in pt_classes: # EMPTY or DATA + pt_classes = self.classify_frame( + frame_d, + empty_entries, + int(frame_size // self.page_table_class.entry_size), + mode=mode, + ) + + if -1 in pt_classes or -2 in pt_classes: # EMPTY or DATA table_obj = None else: - table_obj = self.page_table_class(frame_addr, frame_size, frame_d, pt_classes) + table_obj = self.page_table_class( + frame_addr, frame_size, frame_d, pt_classes + ) return invalids, pt_classes, table_obj - def classify_frame(self, frame_d, empty_c, entries_per_frame, mode="global", ): - + def classify_frame( + self, + frame_d, + empty_c, + entries_per_frame, + mode="global", + ) -> list: + """Classify a frame + + Args: + frame_d: the buffer of the frame + empty_c: the number of empty entries + entries_per_frame: the number of entries per frame + mode: the mode to use + + Returns: + a list containing the classes of the page tables + """ if empty_c == entries_per_frame: - return [-1] # EMPTY + return [-1] # EMPTY # For each level check if a table is a valid candidate frame_classes = [] @@ -593,20 +1181,37 @@ def classify_frame(self, frame_d, empty_c, entries_per_frame, mode="global", ): frame_classes.append(level) if not frame_classes: - return [-2] # DATA + return [-2] # DATA else: return frame_classes + class PhysicalMemory: + """Represents the physical memory of a machine + + The physical memory is composed by a set of memory regions, each one representing a set of contiguous physical addresses with the same permissions. 
+ + Attributes: + _is_opened: a flag to check if the memory is opened + _miasm_vm: the Miasm VM + _memregions: the memory regions + _memsize: the size of the memory + physpace: the physical space + raw_configuration: the raw configuration of the memory + """ + machine = None def __deepcopy__(self, memo): - return PhysicalMemory(self.raw_configuration) + """Deepcopy the PhysicalMemory object""" + return PhysicalMemory(self.raw_configuration) def __copy__(self): + """Copy the PhysicalMemory object""" return PhysicalMemory(self.raw_configuration) def __getstate__(self): + """Get the state of the PhysicalMemory object""" self.close() if self._miasm_vm: del self._miasm_vm @@ -616,24 +1221,26 @@ def __getstate__(self): return state def __setstate__(self, state): + """Set the state of the PhysicalMemory object""" self.__dict__.update(state) self.reopen() def __init__(self, regions_defs): + """Initialize the PhysicalMemory + + Args: + regions_defs: the regions definitions + """ self._is_opened = False self._miasm_vm = None self._memregions = [] self._memsize = 0 - self.physpace = { - "ram": portion.empty(), - "not_ram": portion.empty() - } + self.physpace = {"ram": portion.empty(), "not_ram": portion.empty()} self.raw_configuration = regions_defs # Load dump RAM files try: for region_def in regions_defs["ram"]: - # Load the dump file for a memory region fd = open(region_def["dumpfile"], "rb") mm = mmap(fd.fileno(), 0, MAP_SHARED, PROT_READ) @@ -642,16 +1249,27 @@ def __init__(self, regions_defs): region_size = len(mm) if region_size != region_def["end"] - region_def["start"] + 1: - raise IOError("Declared size {} is different from real size {} for: {}".format(region_def["end"] - region_def["start"] + 1, region_size, region_def["dumpfile"])) - - self._memregions.append({"filename": region_def["dumpfile"], - "fd": fd, - "mmap": mm, - "size": region_size, - "start": region_def["start"], - "end": region_def["end"] - }) - self.physpace["ram"] |= portion.closed(region_def["start"], region_def["end"]) + raise IOError( + "Declared size {} is different from real size {} for: {}".format( + region_def["end"] - region_def["start"] + 1, + region_size, + region_def["dumpfile"], + ) + ) + + self._memregions.append( + { + "filename": region_def["dumpfile"], + "fd": fd, + "mmap": mm, + "size": region_size, + "start": region_def["start"], + "end": region_def["end"], + } + ) + self.physpace["ram"] |= portion.closed( + region_def["start"], region_def["end"] + ) self._memregions.sort(key=lambda x: x["start"]) self._is_opened = True @@ -662,22 +1280,31 @@ def __init__(self, regions_defs): # Load not RAM regions for region_def in regions_defs.get("not_ram", []): - self.physpace["not_ram"] |= portion.closed(region_def["start"], region_def["end"]) - self.physpace["not_valid_regions"] = self.physpace["not_ram"].difference(self.physpace["ram"]) - self.physpace["defined_regions"] = self.physpace["not_ram"] | self.physpace["ram"] + self.physpace["not_ram"] |= portion.closed( + region_def["start"], region_def["end"] + ) + self.physpace["not_valid_regions"] = self.physpace["not_ram"].difference( + self.physpace["ram"] + ) + self.physpace["defined_regions"] = ( + self.physpace["not_ram"] | self.physpace["ram"] + ) def __del__(self): self.close() def __len__(self): + """Get the size of the memory""" return self._memsize def __contains__(self, key): + """Check if a key is in the memory""" if not isinstance(key, int): raise TypeError return key in self.physpace["ram"] def close(self): + """Close the memory""" for region in 
self._memregions: try: if region["mmap"] is not None: @@ -693,6 +1320,7 @@ def close(self): self._is_opened = False def reopen(self): + """Reopen the memory""" for region in self._memregions: region["fd"] = open(region["filename"], "rb") region["mmap"] = mmap(region["fd"].fileno(), 0, MAP_SHARED, PROT_READ) @@ -700,13 +1328,29 @@ def reopen(self): self._is_opened = True def get_data(self, start, size): + """Get data from the memory""" for region in self._memregions: if region["start"] <= start <= region["end"]: - return region["mmap"][start-region["start"]:start-region["start"]+size] + return region["mmap"][ + start - region["start"] : start - region["start"] + size + ] return bytearray() def get_addresses(self, size, align_offset=0, cpus=1, max_address=-1): - """Return a list contains tuples (for a total of cpus tuples). Each tuple contains the len of the iterator, and an iterator over part of all the addresses aligned to align_offset and distanced by size present in RAM""" + """Get the addresses of the memory + + Return a list containing tuples (for a total of cpus tuples). + Each tuple contains the length of the iterator and an iterator over part of all the addresses present in the RAM regions, aligned to align_offset and spaced by size. + + Args: + size: the size of the addresses + align_offset: the alignment offset + cpus: the number of cpus + max_address: the maximum address + + Returns: + a tuple containing the total elements and the addresses + """ if size == 0: return 0, [] @@ -741,17 +1385,19 @@ def get_addresses(self, size, align_offset=0, cpus=1, max_address=-1): vcpus = cpus first_elem = region_start + align_offset - multi_ranges[0].append(range(first_elem, region_start + range_size, size)) + multi_ranges[0].append( + range(first_elem, region_start + range_size, size) + ) prev_last = region_start + range_size - for i in range(1, vcpus-1): + for i in range(1, vcpus - 1): r = (prev_last - align_offset - region_start) % size if r == 0: first_elem = prev_last else: first_elem = prev_last + (size - r) - last_elem = (i+1) * range_size + region_start + last_elem = (i + 1) * range_size + region_start multi_ranges[i].append(range(first_elem, last_elem, size)) prev_last = last_elem @@ -778,30 +1424,41 @@ def get_addresses(self, size, align_offset=0, cpus=1, max_address=-1): return total_elems, multi_ranges def get_miasm_vmmngr(self): - """Load each RAM file in a MIASM virtual memory region""" + """Load each RAM file in a MIASM virtual memory region + + Returns: + the Miasm VM + """ if self._miasm_vm is not None: return self._miasm_vm vm = Vm() for region_def in self._memregions: - vm.add_memory_page(region_def["start"], PAGE_READ | PAGE_WRITE | PAGE_EXEC, - region_def["fd"].read(), region_def["filename"]) + vm.add_memory_page( + region_def["start"], + PAGE_READ | PAGE_WRITE | PAGE_EXEC, + region_def["fd"].read(), + region_def["filename"], + ) region_def["fd"].seek(0) self._miasm_vm = vm return self._miasm_vm def get_memregions(self): + """Get the memory regions""" return self._memregions def free_miasm_memory(self): + """Free the Miasm memory""" if self._miasm_vm: self._miasm_vm = None gc.collect()
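To visualize the splitting performed by get_addresses above, here is a deliberately simplified, hypothetical version for a single region: it carves the aligned scan addresses of one RAM range into per-worker range() iterators. The real method additionally handles alignment remainders at chunk borders, multiple regions and a max_address cut-off.

```python
def split_region(start, end, step, workers):
    # Simplified: assumes (end - start) is a multiple of step * workers.
    chunk = (end - start) // workers
    return [
        range(start + i * chunk, start + (i + 1) * chunk, step)
        for i in range(workers)
    ]

parts = split_region(0x0, 0x8000, 0x1000, 2)
assert list(parts[0]) == [0x0000, 0x1000, 0x2000, 0x3000]
assert list(parts[1]) == [0x4000, 0x5000, 0x6000, 0x7000]
```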
+ class MMUShell(Cmd): - intro = 'MMUShell. Type help or ? to list commands.\n' + intro = "MMUShell. Type help or ? to list commands.\n" - def __init__(self, completekey='tab', stdin=None, stdout=None, machine={}): + def __init__(self, completekey="tab", stdin=None, stdout=None, machine={}): super(MMUShell, self).__init__(completekey, stdin, stdout) self.machine = machine self.prompt = "[MMUShell " + self.machine.cpu.architecture + "]# " @@ -816,10 +1473,10 @@ def reload_data_from_file(self, data_filename): except Exception as e: logger.fatal("Fatal error loading session data! Error:{}".format(e)) import traceback + print(traceback.print_exc()) exit(1) - def load_gtruth(self, gtruth_fd): try: self.gtruth = Load(gtruth_fd) @@ -830,7 +1487,7 @@ def load_gtruth(self, gtruth_fd): def do_exit(self, arg): """Exit :)""" - logger.info('Bye! :)') + logger.info("Bye! :)") return True def do_save_data(self, arg): @@ -866,7 +1523,9 @@ def parse_int(self, value): else: return int(value, 10) - def radix_roots_from_data_page(self, pg_lvl, pg_addr, rev_map_pages, rev_map_tables): + def radix_roots_from_data_page( + self, pg_lvl, pg_addr, rev_map_pages, rev_map_tables + ): # For a page address pointed to by tables of level 'pg_lvl', find all the radix roots of the trees containing it level_tables = set() @@ -874,7 +1533,11 @@ # Collect all tables at level 'pg_lvl' which point to that page level_tables.update(rev_map_pages[pg_lvl][pg_addr]) - logger.debug("radix_roots_from_data_pages: level_tables found {} for pg_addr {}".format(len(level_tables), hex(pg_addr))) + logger.debug( + "radix_roots_from_data_pages: level_tables found {} for pg_addr {}".format( + len(level_tables), hex(pg_addr) + ) + ) # Climb the tree in order to find the top table for tree_lvl in range(pg_lvl - 1, -1, -1): @@ -883,11 +1546,25 @@ level_tables = prev_level_tables prev_level_tables = set() - logger.debug("radix_roots_from_data_pages: level_tables found {} for pg_addr {}".format(len(level_tables), hex(pg_addr))) + logger.debug( + "radix_roots_from_data_pages: level_tables found {} for pg_addr {}".format( + len(level_tables), hex(pg_addr) + ) + ) return set(level_tables) - def physpace(self, addr, page_tables, empty_tables, lvl=0, uperms=(True,)*6, hierarchical=False, mode="global", cache=defaultdict(dict)): + def physpace( + self, + addr, + page_tables, + empty_tables, + lvl=0, + uperms=(True,) * 6, + hierarchical=False, + mode="global", + cache=defaultdict(dict), + ): """Recursively evaluate the consistency and return the kernel/user physical space addressed""" pas = PAS() data_classes = self.machine.mmu.map_datapages_entries_to_levels[mode][lvl] @@ -904,23 +1581,38 @@ cache[lvl][addr] = (True, pas) return True, pas - else: # Superior levels + else: # Upper levels ptr_class = self.machine.mmu.map_ptr_entries_to_levels[mode][lvl] if ptr_class in page_tables[lvl][addr].entries: for entry in page_tables[lvl][addr].entries[ptr_class].values(): if entry.address not in page_tables[lvl + 1]: - if entry.address not in empty_tables: # It is not an empty table! + if ( + entry.address not in empty_tables + ): # It is not an empty table! 
+ logging.debug( + f"physpace() radix: {hex(addr)} parent level: {lvl} table: {hex(entry.address)} invalid" + ) cache[lvl][addr] = (False, None) return False, None else: if entry.address not in cache[lvl + 1]: - low_cons, low_pas = self.physpace(entry.address, page_tables, empty_tables, lvl + 1, uperms=uperms, hierarchical=hierarchical, mode=mode, cache=cache) + low_cons, low_pas = self.physpace( + entry.address, + page_tables, + empty_tables, + lvl + 1, + uperms=uperms, + hierarchical=hierarchical, + mode=mode, + cache=cache, + ) else: low_cons, low_pas = cache[lvl + 1][entry.address] if not low_cons: - logging.debug(f"physpace() radix: {hex(addr)} parent level: {lvl} table: {hex(entry.address)} invalid") + logging.debug( + f"physpace() radix: {hex(addr)} parent level: {lvl} table: {hex(entry.address)} invalid" + ) cache[lvl][addr] = (False, None) return False, None @@ -930,7 +1622,10 @@ def physpace(self, addr, page_tables, empty_tables, lvl=0, uperms=(True,)*6, hie pas.hierarchical_extend(low_pas, (True,) * 6) for data_class in data_classes: - if data_class in page_tables[lvl][addr].entries and data_class is not None: + if ( + data_class in page_tables[lvl][addr].entries + and data_class is not None + ): for entry in page_tables[lvl][addr].entries[data_class].values(): perms = entry.get_permissions() pas.space[perms][entry.address] = entry.size @@ -944,7 +1639,6 @@ def resolve_vaddr(self, cr3, vaddr, mode="global"): # Split in possible table resolution paths for splitted_addr in self.machine.mmu.split_vaddr(vaddr): - current_table_addr = cr3 requested_steps = len(splitted_addr) - 1 resolution_steps = 0 @@ -953,14 +1647,27 @@ def resolve_vaddr(self, cr3, vaddr, mode="global"): level_class, entry_idx = idx_t # Missing valid table, no valid resolution path if current_table_addr not in self.data.page_tables["global"][level_idx]: - logging.debug(f"resolve_vaddr() Missing table {hex(current_table_addr)}") + logging.debug( + f"resolve_vaddr() Missing table {hex(current_table_addr)}" + ) logging.debug("resolve_vaddr() RESOLUTION PATH FAILED! 
########") break - logging.debug(f"resolve_vaddr(): Resolution path Lvl: {level_class} Table: {hex(current_table_addr)} Entry addr: {hex( current_table_addr + self.machine.mmu.page_table_class.entry_size * entry_idx)}") + logging.debug( + f"resolve_vaddr(): Resolution path Lvl: {level_class} Table: {hex(current_table_addr)} Entry addr: {hex( current_table_addr + self.machine.mmu.page_table_class.entry_size * entry_idx)}" + ) # Find valid entry in table - if entry_idx in self.data.page_tables["global"][level_idx][current_table_addr].entries[level_class]: - current_table_addr = self.data.page_tables["global"][level_idx][current_table_addr].entries[level_class][entry_idx].address + if ( + entry_idx + in self.data.page_tables["global"][level_idx][ + current_table_addr + ].entries[level_class] + ): + current_table_addr = ( + self.data.page_tables["global"][level_idx][current_table_addr] + .entries[level_class][entry_idx] + .address + ) resolution_steps += 1 @@ -975,7 +1682,16 @@ def resolve_vaddr(self, cr3, vaddr, mode="global"): else: return -1 - def virtspace(self, addr, lvl=0, prefix=0, uperms=(True,)*6, hierarchical=False, mode="global", cache=defaultdict(dict)): + def virtspace( + self, + addr, + lvl=0, + prefix=0, + uperms=(True,) * 6, + hierarchical=False, + mode="global", + cache=defaultdict(dict), + ): """Recursively reconstruct virtual address space""" virtspace = VAS() @@ -986,10 +1702,14 @@ def virtspace(self, addr, lvl=0, prefix=0, uperms=(True,)*6, hierarchical=False, if lvl == self.machine.mmu.radix_levels[mode] - 1: for data_class in data_classes: shift = self.machine.mmu.map_entries_to_shifts[mode][data_class] - for entry_idx, entry in self.data.page_tables[mode][lvl][addr].entries[data_class].items(): + for entry_idx, entry in ( + self.data.page_tables[mode][lvl][addr].entries[data_class].items() + ): permissions = entry.get_permissions() virt_addr = prefix | (entry_idx << shift) - virtspace[permissions] |= portion.closedopen(virt_addr, virt_addr + entry.size) + virtspace[permissions] |= portion.closedopen( + virt_addr, virt_addr + entry.size + ) cache[lvl][addr] = virtspace return virtspace @@ -997,7 +1717,9 @@ def virtspace(self, addr, lvl=0, prefix=0, uperms=(True,)*6, hierarchical=False, else: if ptr_class in self.data.page_tables[mode][lvl][addr].entries: shift = self.machine.mmu.map_entries_to_shifts[mode][ptr_class] - for entry_idx, entry in self.data.page_tables[mode][lvl][addr].entries[ptr_class].items(): + for entry_idx, entry in ( + self.data.page_tables[mode][lvl][addr].entries[ptr_class].items() + ): if entry.address not in self.data.page_tables[mode][lvl + 1]: continue else: @@ -1005,7 +1727,15 @@ def virtspace(self, addr, lvl=0, prefix=0, uperms=(True,)*6, hierarchical=False, if entry.address not in cache[lvl + 1]: virt_addr = prefix | (entry_idx << shift) - low_virts = self.virtspace(entry.address, lvl + 1, virt_addr, permissions, hierarchical=hierarchical, mode=mode, cache=cache) + low_virts = self.virtspace( + entry.address, + lvl + 1, + virt_addr, + permissions, + hierarchical=hierarchical, + mode=mode, + cache=cache, + ) else: low_virts = cache[lvl + 1][entry.address] @@ -1015,12 +1745,21 @@ def virtspace(self, addr, lvl=0, prefix=0, uperms=(True,)*6, hierarchical=False, virtspace.hierarchical_extend(low_virts, (True,) * 6) for data_class in data_classes: - if data_class in self.data.page_tables[mode][lvl][addr].entries and data_class is not None: + if ( + data_class in self.data.page_tables[mode][lvl][addr].entries + and data_class is not None + ): shift = 
self.machine.mmu.map_entries_to_shifts[mode][data_class] - for entry_idx, entry in self.data.page_tables[mode][lvl][addr].entries[data_class].items(): + for entry_idx, entry in ( + self.data.page_tables[mode][lvl][addr] + .entries[data_class] + .items() + ): permissions = entry.get_permissions() virt_addr = prefix | (entry_idx << shift) - virtspace[permissions] |= portion.closedopen(virt_addr, virt_addr + entry.size) + virtspace[permissions] |= portion.closedopen( + virt_addr, virt_addr + entry.size + ) cache[lvl][addr] = virtspace return virtspace diff --git a/architectures/intel.py b/mmushell/architectures/intel.py similarity index 72% rename from architectures/intel.py rename to mmushell/architectures/intel.py index 7452595..8ba8e53 100644 --- a/architectures/intel.py +++ b/mmushell/architectures/intel.py @@ -1,25 +1,26 @@ -from json import dump +import logging +import portion +import multiprocessing as mp + from architectures.generic import Machine as MachineDefault from architectures.generic import CPU as CPUDefault from architectures.generic import PhysicalMemory as PhysicalMemoryDefault from architectures.generic import MMUShell as MMUShellDefault from architectures.generic import TableEntry, PageTable, MMURadix, PAS, RadixTree from architectures.generic import CPUReg -import logging + +from miasm.analysis.machine import Machine as MIASMMachine + +from more_itertools import divide +from dataclasses import dataclass from collections import defaultdict, deque from prettytable import PrettyTable +from random import uniform +from struct import iter_unpack, unpack +from json import dump from time import sleep from tqdm import tqdm from copy import deepcopy, copy -from random import uniform -from struct import iter_unpack, unpack -from dataclasses import dataclass -import multiprocessing as mp -# import cProfile -import portion -from more_itertools import divide -from miasm.analysis.machine import Machine as MIASMMachine -# from IPython import embed logger = logging.getLogger(__name__) @@ -45,7 +46,9 @@ def __init__(self, idtr): self.valid = True def __repr__(self): - return f"IDTR: {hex(self.value)} => Address: {hex(self.address)}, Size:{self.size}" + return ( + f"IDTR: {hex(self.value)} => Address: {hex(self.address)}, Size:{self.size}" + ) class CR3_32(CPUReg): @@ -111,9 +114,14 @@ def parse_idt(self, addr): entries = [] IDTable = self.processor_features["idt_table_class"] for idx in range(256): - entry_buff = self.machine.memory.get_data(addr + idx * self.processor_features["idt_entry_size"], self.processor_features["idt_entry_size"]) + entry_buff = self.machine.memory.get_data( + addr + idx * self.processor_features["idt_entry_size"], + self.processor_features["idt_entry_size"], + ) try: - raw_entry = unpack(self.processor_features["idt_unpack_format"], entry_buff) + raw_entry = unpack( + self.processor_features["idt_unpack_format"], entry_buff + ) except Exception: break entry_idt = self.classify_idt_entry(raw_entry) @@ -123,18 +131,29 @@ def parse_idt(self, addr): else: entries.append(entry_idt) - return(IDTable(addr, len(entries) * self.processor_features["idt_entry_size"], entries)) + return IDTable( + addr, len(entries) * self.processor_features["idt_entry_size"], entries + ) def find_idt_tables(self): # Look for IDT only in pointed pages idt_entry_size = self.processor_features["idt_entry_size"] # Workaround to reduce memory fingerprint - iterators = [self.machine.memory.get_addresses(idt_entry_size, align_offset=i) - for pidx, i in enumerate(range(0, idt_entry_size, 4))] - - 
pool = mp.Pool(idt_entry_size // 4, initializer=tqdm.set_lock, initargs=(mp.Lock(),)) - idt_candidates_async = [pool.apply_async(self.find_idt_tables_parallel, args=(iterators[pidx], pidx)) for pidx, i in enumerate(range(0, idt_entry_size, 4))] + iterators = [ + self.machine.memory.get_addresses(idt_entry_size, align_offset=i) + for pidx, i in enumerate(range(0, idt_entry_size, 4)) + ] + + pool = mp.Pool( + idt_entry_size // 4, initializer=tqdm.set_lock, initargs=(mp.Lock(),) + ) + idt_candidates_async = [ + pool.apply_async( + self.find_idt_tables_parallel, args=(iterators[pidx], pidx) + ) + for pidx, i in enumerate(range(0, idt_entry_size, 4)) + ] pool.close() pool.join() @@ -143,12 +162,12 @@ def find_idt_tables(self): for res in idt_candidates_async: idts.extend(res.get()) - print("\n") # Workaround TQDM + print("\n") # Workaround TQDM return idts def find_idt_tables_parallel(self, addresses_it, pidx): # Random sleep to desyncronize accesses to disk - sleep(uniform(pidx, pidx+1) // 1000) + sleep(uniform(pidx, pidx + 1) // 1000) idt_candidates = [] idt_under_analysis = deque(maxlen=256) @@ -159,7 +178,9 @@ def find_idt_tables_parallel(self, addresses_it, pidx): idt_unpack_format = self.processor_features["idt_unpack_format"] naddresses, addresses = addresses_it[1] - for addr in tqdm(addresses, total=naddresses, unit="tables", position=-pidx, leave=False): + for addr in tqdm( + addresses, total=naddresses, unit="tables", position=-pidx, leave=False + ): # parsing table machinery, we risk to lose some tables however... table_buff = mm.get_data(addr, idt_entry_size) @@ -177,21 +198,28 @@ def find_idt_tables_parallel(self, addresses_it, pidx): # If the entry is invalid finalize the current IDT under analysis if entry_idt is None: - if len(idt_under_analysis) > 0: if self.validate_idt(idt_under_analysis): - idt_candidates.append(IDTable(addr - len(idt_under_analysis) * idt_entry_size, - len(idt_under_analysis) * idt_entry_size, - tuple(deepcopy(idt_under_analysis)))) + idt_candidates.append( + IDTable( + addr - len(idt_under_analysis) * idt_entry_size, + len(idt_under_analysis) * idt_entry_size, + tuple(deepcopy(idt_under_analysis)), + ) + ) idt_under_analysis.clear() else: # Check if the candidate has reach the maximum size if len(idt_under_analysis) == 256: if self.validate_idt(idt_under_analysis): - idt_candidates.append(IDTable(addr - len(idt_under_analysis) * idt_entry_size, - len(idt_under_analysis) * idt_entry_size, - tuple(deepcopy(idt_under_analysis)))) + idt_candidates.append( + IDTable( + addr - len(idt_under_analysis) * idt_entry_size, + len(idt_under_analysis) * idt_entry_size, + tuple(deepcopy(idt_under_analysis)), + ) + ) idt_under_analysis.append(entry_idt) @@ -215,14 +243,18 @@ def __init__(self, features): # Check validity of m_phy m_phy = self.processor_features.get("m_phy", -1) if m_phy <= 0 or m_phy >= 52: - logging.fatal("m_phy must be positive and less then 40 in IA32 mode MMU modes") + logging.fatal( + "m_phy must be positive and less then 40 in IA32 mode MMU modes" + ) exit(1) CPU.m_phy = m_phy def classify_idt_entry(self, entry): # 32bit IDT entries p = CPU.extract_bits(entry[1], 15, 1) - offset = (CPU.extract_bits(entry[1], 16, 16) << 16) + CPU.extract_bits(entry[0], 0, 16) + offset = (CPU.extract_bits(entry[1], 16, 16) << 16) + CPU.extract_bits( + entry[0], 0, 16 + ) s_selector = CPU.extract_bits(entry[0], 16, 16) e_type = CPU.extract_bits(entry[1], 8, 5) dpl = CPU.extract_bits(entry[1], 13, 2) @@ -257,7 +289,6 @@ def classify_idt_entry(self, entry): def 
validate_idt(self, candidate): # Check if minimum interrupt handlers are defined for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14]: - try: if not candidate[i].p: return False @@ -286,14 +317,19 @@ def __init__(self, features): # Check validity of m_phy m_phy = self.processor_features.get("m_phy", -1) if m_phy <= 0 or m_phy >= 52: - logging.fatal("m_phy must be positive and less then 52 in PAE/IA64 MMU modes") + logging.fatal( + "m_phy must be positive and less then 52 in PAE/IA64 MMU modes" + ) exit(1) CPU.m_phy = m_phy def classify_idt_entry(self, entry): s_selector = CPU.extract_bits(entry[0], 16, 16) - offset = (entry[2] << 32) + \ - (CPU.extract_bits(entry[1], 16, 16) << 16) + CPU.extract_bits(entry[0], 0, 16) + offset = ( + (entry[2] << 32) + + (CPU.extract_bits(entry[1], 16, 16) << 16) + + CPU.extract_bits(entry[0], 0, 16) + ) ist = CPU.extract_bits(entry[1], 0, 3) dpl = CPU.extract_bits(entry[1], 13, 2) e_type = CPU.extract_bits(entry[1], 8, 4) @@ -326,7 +362,6 @@ def classify_idt_entry(self, entry): def validate_idt(self, candidate): # Check if minimum interrupt handlers are defined for i in [0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 16, 17, 18, 19]: - try: if not candidate[i].p: return False @@ -349,13 +384,22 @@ def validate_idt(self, candidate): return False return True + ##################################################################### # 32 bit entries and page table ##################################################################### + class IDTEntry32: entry_name = "Empty" - labels = ["Present:", "Type:", "Interrupt address:", "Segment:", "DPL:", "Gate size:"] + labels = [ + "Present:", + "Type:", + "Interrupt address:", + "Segment:", + "DPL:", + "Gate size:", + ] addr_fmt = "0x{:08x}" def __init__(self, offset, segment, typ, dpl, p): @@ -370,7 +414,9 @@ def __hash__(self): def __repr__(self): e_resume = self.entry_resume_stringified() - return str([self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))]) + return str( + [self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))] + ) def entry_resume_stringified(self): res = self.entry_resume() @@ -380,44 +426,57 @@ def entry_resume_stringified(self): return res def entry_resume(self): - return [self.p, self.entry_name, "", "", "", "" ] + return [self.p, self.entry_name, "", "", "", ""] class IDTInterruptEntry32(IDTEntry32): entry_name = "Interrupt" - labels = ["Present:","Type:", "Interrupt address:", "Segment:", "DPL:", "Gate size:"] + labels = [ + "Present:", + "Type:", + "Interrupt address:", + "Segment:", + "DPL:", + "Gate size:", + ] def entry_resume(self): - return [self.p, - self.entry_name, - self.offset, - self.segment, - self.dpl, - self.gate_size() - ] + return [ + self.p, + self.entry_name, + self.offset, + self.segment, + self.dpl, + self.gate_size(), + ] def gate_size(self): return 32 if CPU.extract_bits(self.type, 3, 1) else 16 + class IDTTrapEntry32(IDTInterruptEntry32): entry_name = "Trap" + class IDTTaskEntry32(IDTEntry32): entry_name = "Task" labels = ["Present:", "Type:", "TSS Segment:", "DPL:"] def entry_resume(self): - return [self.p, - self.entry_name, - "Res.", - self.segment, - self.dpl, - "Ign." 
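+        # Task gates dispatch through the TSS selector, so the offset field
+        # is reserved ("Res.") and the gate-size column does not apply ("Ign.").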
- ] + return [self.p, self.entry_name, "Res.", self.segment, self.dpl, "Ign."] + class IDTable32: table_name = "Interrupt table" - table_fields = ["Entry ID", "Present", "Type", "Interrupt address", "Segment", "DPL", "Gate size"] + table_fields = [ + "Entry ID", + "Present", + "Type", + "Interrupt address", + "Segment", + "DPL", + "Gate size", + ] addr_fmt = "0x{:08x}" def __hash__(self): @@ -441,12 +500,22 @@ def __repr__(self): def __len__(self): return self.size + class TEntry32(TableEntry): entry_size = 4 entry_name = "TEntry32" size = 0 - labels = ["Address:", "Global:", "PAT:", "Dirty:", "Accessed:", "PCD:", - "PWT:", "Kernel:", "Writable:"] + labels = [ + "Address:", + "Global:", + "PAT:", + "Dirty:", + "Accessed:", + "PCD:", + "PWT:", + "Kernel:", + "Writable:", + ] addr_fmt = "0x{:08x}" def __hash__(self): @@ -454,19 +523,22 @@ def __hash__(self): def __repr__(self): e_resume = self.entry_resume_stringified() - return str([self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))]) + return str( + [self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))] + ) def entry_resume(self): - return [self.address, - self.is_global_set(), - self.is_pat_set(), - self.is_dirty_entry(), - self.is_accessed_entry(), - self.is_pcd_set(), - self.is_pwt_set(), - self.is_supervisor_entry(), - self.is_writeble_entry() - ] + return [ + self.address, + self.is_global_set(), + self.is_pat_set(), + self.is_dirty_entry(), + self.is_accessed_entry(), + self.is_pcd_set(), + self.is_pwt_set(), + self.is_supervisor_entry(), + self.is_writeble_entry(), + ] def entry_resume_stringified(self): res = self.entry_resume() @@ -513,6 +585,7 @@ def get_permissions(self): else: return (False, False, False) + perms + class PTE4KB32(TEntry32): entry_name = "PTE4KB32" size = 1024 * 4 @@ -551,9 +624,19 @@ def is_global_set(self): class PageTableIntel32(PageTable): entry_size = 4 - table_fields = ["Entry address", "Pointed address", "Global", "PAT", - "Dirty", "Accessed", "PCD", "PWT", - "Supervisor", "Writable", "Class"] + table_fields = [ + "Entry address", + "Pointed address", + "Global", + "PAT", + "Dirty", + "Accessed", + "PCD", + "PWT", + "Supervisor", + "Writable", + "Class", + ] addr_fmt = "0x{:08x}" def __repr__(self): @@ -563,11 +646,16 @@ def __repr__(self): for entry_class in self.entries: for entry_idx, entry_obj in self.entries[entry_class].items(): entry_addr = self.address + (entry_idx * self.entry_size) - table.add_row([self.addr_fmt.format(entry_addr)] + entry_obj.entry_resume_stringified() + [entry_class.entry_name]) + table.add_row( + [self.addr_fmt.format(entry_addr)] + + entry_obj.entry_resume_stringified() + + [entry_class.entry_name] + ) - table.sortby="Entry address" + table.sortby = "Entry address" return str(table) + ##################################################################### # 64 bit entries and page table ##################################################################### @@ -589,13 +677,7 @@ class IDTInterruptEntry64(IDTEntry64): entry_name = "Interupt" def entry_resume(self): - return [self.p, - self.entry_name, - self.offset, - self.segment, - self.dpl, - self.ist - ] + return [self.p, self.entry_name, self.offset, self.segment, self.dpl, self.ist] class IDTTrapEntry64(IDTInterruptEntry64): @@ -604,7 +686,15 @@ class IDTTrapEntry64(IDTInterruptEntry64): class IDTable64(IDTable32): table_name = "Interrupt table" - table_fields = ["Entry ID", "Present", "Type", "Interrupt address", "Segment", "DPL", "IST"] + table_fields = [ + "Entry ID", + "Present", 
+ "Type", + "Interrupt address", + "Segment", + "DPL", + "IST", + ] addr_fmt = "0x{:016x}" def __repr__(self): @@ -622,8 +712,19 @@ class TEntry64(TEntry32): entry_size = 8 entry_name = "TEntry64" size = 0 - labels = ["Address:", "NX:", "Prot. key:", "Global:", "PAT:", "Dirty:", - "Accessed:", "PCD:", "PWT:", "Kernel:", "Writable:"] + labels = [ + "Address:", + "NX:", + "Prot. key:", + "Global:", + "PAT:", + "Dirty:", + "Accessed:", + "PCD:", + "PWT:", + "Kernel:", + "Writable:", + ] addr_fmt = "0x{:016x}" def __init__(self, address, flags, *args): @@ -632,21 +733,24 @@ def __init__(self, address, flags, *args): def __repr__(self): e_resume = self.entry_resume_stringified() - return str([self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))]) + return str( + [self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))] + ) def entry_resume(self): - return [self.address, - self.is_executable_entry(), - self.prot_key(), - self.is_global_set(), - self.is_pat_set(), - self.is_dirty_entry(), - self.is_accessed_entry(), - self.is_pcd_set(), - self.is_pwt_set(), - self.is_supervisor_entry(), - self.is_writeble_entry() - ] + return [ + self.address, + self.is_executable_entry(), + self.prot_key(), + self.is_global_set(), + self.is_pat_set(), + self.is_dirty_entry(), + self.is_accessed_entry(), + self.is_pcd_set(), + self.is_pwt_set(), + self.is_supervisor_entry(), + self.is_writeble_entry(), + ] def is_executable_entry(self): return not bool(MMU.extract_bits(self.upper_flags, 5, 1)) @@ -710,20 +814,25 @@ def extract_addr(entry): class PTE4KBPAE(PTE4KB64): entry_name = "PTE4KBPAE" + def prot_key(self): return "Res." + class PDE2MBPAE(PDE2MB): entry_name = "PDE2MBPAE" + def prot_key(self): return "Res." class PDEPAE(PDE64): entry_name = "PDEPAE" + def prot_key(self): return "Res." + class PDPTEPAE(TEntry64): entry_name = "PDPTEPAE" size = 0 @@ -774,14 +883,28 @@ class PML4E(TPE64): class PageTableIntel64(PageTableIntel32): entry_size = 8 - table_fields = ["Entry address", "Pointed address", "NX", "Prot. key", - "Global", "PAT", "Dirty", "Accessed", "PCD", "PWT", - "Supervisor", "Writable", "Classes"] + table_fields = [ + "Entry address", + "Pointed address", + "NX", + "Prot. 
key", + "Global", + "PAT", + "Dirty", + "Accessed", + "PCD", + "PWT", + "Supervisor", + "Writable", + "Classes", + ] + ################################################################# # MMU Modes ################################################################# + class MMU(MMURadix): PAGE_SIZE = 4096 extract_bits = MMURadix.extract_bits_little @@ -791,6 +914,7 @@ class MMU(MMURadix): top_prefix = 0 entries_size = 4 + ################################################################ # MMU Modes ################################################################ @@ -802,11 +926,7 @@ class IA32(MMU): map_ptr_entries_to_levels = {"global": [PDE32, None]} map_datapages_entries_to_levels = {"global": [[PDE4MB], [PTE4KB32]]} map_level_to_table_size = {"global": [4096, 4096]} - map_entries_to_shifts = {"global": { - PDE32: 22, - PDE4MB: 22, - PTE4KB32: 12 - }} + map_entries_to_shifts = {"global": {PDE32: 22, PDE4MB: 22, PTE4KB32: 12}} cr3_class = CR3_32 map_reserved_entries_to_levels = {"global": [[], []]} @@ -839,10 +959,11 @@ def classify_entry_pt_only(self, page_addr, entry): if not MMU.extract_bits(entry, 0, 1): return [False] else: - return [PTE4KB32(PTE4KB32.extract_addr(entry), MMU.extract_bits(entry, 0, 13))] + return [ + PTE4KB32(PTE4KB32.extract_addr(entry), MMU.extract_bits(entry, 0, 13)) + ] def classify_entry_full(self, page_addr, entry): - # If BIT P=0 is EMPTY if not MMU.extract_bits(entry, 0, 1): return [False] @@ -853,12 +974,12 @@ def classify_entry_full(self, page_addr, entry): # is the only filter to discard entries # but we prefer to not use it (more, more general!) # ----------------------------------- - #if is_dirty_entry and not is_accessed_entry: + # if is_dirty_entry and not is_accessed_entry: # return [None] # ----------------------------------- # Extract flags and address - addr = PTE4KB32.extract_addr(entry) # This is also the PDE32 + addr = PTE4KB32.extract_addr(entry) # This is also the PDE32 flags = MMU.extract_bits(entry, 0, 13) # Check BIT 7 @@ -884,11 +1005,17 @@ def extend_prefix(self, prefix, entry_idx, entry_class): return prefix | (entry_idx << 12) def split_vaddr(self, vaddr): - return (((PDE32, MMU.extract_bits(vaddr, 22, 10)), - (PTE4KB32, MMU.extract_bits(vaddr, 12, 10)), - ("OFFSET", MMU.extract_bits(vaddr, 0, 12))), \ - ((PDE4MB, MMU.extract_bits(vaddr, 22, 10)), - ("OFFSET", MMU.extract_bits(vaddr, 0, 22)))) + return ( + ( + (PDE32, MMU.extract_bits(vaddr, 22, 10)), + (PTE4KB32, MMU.extract_bits(vaddr, 12, 10)), + ("OFFSET", MMU.extract_bits(vaddr, 0, 12)), + ), + ( + (PDE4MB, MMU.extract_bits(vaddr, 22, 10)), + ("OFFSET", MMU.extract_bits(vaddr, 0, 22)), + ), + ) class PAE(MMU): @@ -896,20 +1023,16 @@ class PAE(MMU): page_table_class = PageTableIntel64 radix_levels = {"global": 3} top_prefix = 0x0 - map_ptr_entries_to_levels = {"global": [PDPTEPAE, PDEPAE, None] } + map_ptr_entries_to_levels = {"global": [PDPTEPAE, PDEPAE, None]} map_datapages_entries_to_levels = {"global": [[None], [PDE2MBPAE], [PTE4KBPAE]]} map_level_to_table_size = {"global": [32, 4096, 4096]} - map_entries_to_shifts = {"global": { - PDPTEPAE: 30, - PDEPAE: 21, - PDE2MBPAE: 21, - PTE4KBPAE: 12 - }} + map_entries_to_shifts = { + "global": {PDPTEPAE: 30, PDEPAE: 21, PDE2MBPAE: 21, PTE4KBPAE: 12} + } cr3_class = CR3_PAE map_reserved_entries_to_levels = {"global": [[], [], []]} def classify_entry(self, page_addr, entry): - # Check BIT 0: must be 1 for a valid entry if not MMU.extract_bits(entry, 0, 1): return [False] @@ -924,7 +1047,6 @@ def classify_entry(self, page_addr, entry): # Check 
BIT 7 if MMU.extract_bits(entry, 7, 1): - # BIT 7 = 1 # Check BITS 20:13 if not all 0 it's a PTE4KB if MMU.extract_bits(entry, 13, 8): @@ -932,8 +1054,10 @@ def classify_entry(self, page_addr, entry): # It can be a PDE2MB addr_2mb = PDE2MBPAE.extract_addr(entry) - return [PTE4KBPAE(addr_4k, flags, flags2), - PDE2MBPAE(addr_2mb, flags, flags2)] + return [ + PTE4KBPAE(addr_4k, flags, flags2), + PDE2MBPAE(addr_2mb, flags, flags2), + ] # BIT 7 = 0 ret = [PTE4KBPAE(addr_4k, flags, flags2)] @@ -941,10 +1065,12 @@ def classify_entry(self, page_addr, entry): ret.append(PDEPAE(PDEPAE.extract_addr(entry), flags, flags2)) # Check BITS 1,2,5,6,8,63: if they are not all 0 it cannot be a PDPTE - if not (MMU.extract_bits(entry, 1, 2) or \ - MMU.extract_bits(entry, 5, 2) or \ - MMU.extract_bits(entry, 8, 1) or \ - MMU.extract_bits(entry, 63, 1)): + if not ( + MMU.extract_bits(entry, 1, 2) + or MMU.extract_bits(entry, 5, 2) + or MMU.extract_bits(entry, 8, 1) + or MMU.extract_bits(entry, 63, 1) + ): ret.append(PDPTEPAE(PDPTEPAE.extract_addr(entry), flags, flags2)) return ret @@ -953,7 +1079,7 @@ def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): # The top table is composed by only 4 entries so we parse them directly here in a special way # Every process sleep a random delay in order to desincronize access to disk and maximixe the throuput - sleep(uniform(pidx, pidx+1) // 1000) + sleep(uniform(pidx, pidx + 1) // 1000) data_pages = [] empty_tables = [] @@ -965,7 +1091,9 @@ def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): # Cicle over every frame total_elems, iterator = addresses - for frame_addr in tqdm(iterator, position=-pidx, total=total_elems, leave=False): + for frame_addr in tqdm( + iterator, position=-pidx, total=total_elems, leave=False + ): frame_buf = mm.get_data(frame_addr, self.PAGE_SIZE) empty_entries = 0 @@ -974,7 +1102,9 @@ def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): frame_d.clear() # Unpack records inside the frame - for entry_idx, entry in enumerate(iter_unpack(self.paging_unpack_format, frame_buf)): + for entry_idx, entry in enumerate( + iter_unpack(self.paging_unpack_format, frame_buf) + ): entry = entry[0] # Every four entry we can have a new PDPT table @@ -1003,9 +1133,19 @@ def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): # Validate the PDPT if 4 aligned entries are parsed and add it to the page_tables if (entry_idx + 1) % 4 == 0: - if pdpt_empty_entries != 4 and pdpt_empty_entries + len(pdpt_table[PDPTEPAE]) == 4: - pdpt_table_obj = PageTableIntel64(frame_addr + (entry_idx - 3) * 8, 4, deepcopy(pdpt_table), [0]) - page_tables[0][pdpt_table_obj.address] = deepcopy(pdpt_table_obj) + if ( + pdpt_empty_entries != 4 + and pdpt_empty_entries + len(pdpt_table[PDPTEPAE]) == 4 + ): + pdpt_table_obj = PageTableIntel64( + frame_addr + (entry_idx - 3) * 8, + 4, + deepcopy(pdpt_table), + [0], + ) + page_tables[0][pdpt_table_obj.address] = deepcopy( + pdpt_table_obj + ) # Add the entries only if the table is not already marked as invalid if not is_invalid: @@ -1023,15 +1163,21 @@ def parse_parallel_frame(self, addresses, frame_size, pidx, **kwargs): continue # Classify the frame - pt_classes = self.classify_frame(frame_d, empty_entries, int(frame_size // self.page_table_class.entry_size)) + pt_classes = self.classify_frame( + frame_d, + empty_entries, + int(frame_size // self.page_table_class.entry_size), + ) - if -1 in pt_classes: # EMPTY or DATA + if -1 in pt_classes: # EMPTY or DATA 
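+            # A -1 class marks a frame made only of empty entries (a -2 class,
+            # handled below, marks a data page); empty frames are tracked so
+            # that tables pointing to them still pass the consistency check.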
empty_tables.append(frame_addr) elif -2 in pt_classes: data_pages.append(frame_addr) else: for pt_class in pt_classes: - table_obj = self.page_table_class(frame_addr, self.PAGE_SIZE, frame_d, pt_classes) + table_obj = self.page_table_class( + frame_addr, self.PAGE_SIZE, frame_d, pt_classes + ) page_tables[pt_class][frame_addr] = deepcopy(table_obj) return page_tables, data_pages, empty_tables @@ -1048,13 +1194,17 @@ def extend_prefix(self, prefix, entry_idx, entry_class): return prefix | (entry_idx << 12) def split_vaddr(self, vaddr): - return ((PDPTEPAE, MMU.extract_bits(vaddr, 30, 9)), - (PDEPAE, MMU.extract_bits(vaddr, 21, 9)), - (PTE4KBPAE, MMU.extract_bits(vaddr, 12, 9)), - ("OFFSET", MMU.extract_bits(vaddr, 0, 12))), \ - ((PDPTEPAE, MMU.extract_bits(vaddr, 30, 9)), - (PDE2MBPAE, MMU.extract_bits(vaddr, 21, 9)), - ("OFFSET", MMU.extract_bits(vaddr, 0, 21))) + return ( + (PDPTEPAE, MMU.extract_bits(vaddr, 30, 9)), + (PDEPAE, MMU.extract_bits(vaddr, 21, 9)), + (PTE4KBPAE, MMU.extract_bits(vaddr, 12, 9)), + ("OFFSET", MMU.extract_bits(vaddr, 0, 12)), + ), ( + (PDPTEPAE, MMU.extract_bits(vaddr, 30, 9)), + (PDE2MBPAE, MMU.extract_bits(vaddr, 21, 9)), + ("OFFSET", MMU.extract_bits(vaddr, 0, 21)), + ) + class IA64(MMU): paging_unpack_format = " (self.machine.mmu.radix_levels["global"] - 1): raise ValueError except ValueError: - logger.warning("Level must be an integer between 0 and {}".format(str(self.machine.mmu.radix_levels["global"] - 1))) + logger.warning( + "Level must be an integer between 0 and {}".format( + str(self.machine.mmu.radix_levels["global"] - 1) + ) + ) return if lvl == -1: @@ -1223,7 +1405,9 @@ def do_show_table(self, args): else: table_size = self.machine.mmu.map_level_to_table_size["global"][lvl] table_buff = self.machine.memory.get_data(addr, table_size) - invalids, pt_classes, table_obj = self.machine.mmu.parse_frame(table_buff, addr, table_size, lvl) + invalids, pt_classes, table_obj = self.machine.mmu.parse_frame( + table_buff, addr, table_size, lvl + ) print(table_obj) print(f"Invalid entries: {invalids} Table levels: {pt_classes}") @@ -1234,7 +1418,9 @@ def parse_memory_ia32(self): # Look for only PD, then use that to find only PT logger.info("Look for page directories..") self.machine.mmu.classify_entry = self.machine.mmu.classify_entry_pd_only - parallel_results = self.machine.apply_parallel(self.machine.mmu.PAGE_SIZE, self.machine.mmu.parse_parallel_frame) + parallel_results = self.machine.apply_parallel( + self.machine.mmu.PAGE_SIZE, self.machine.mmu.parse_parallel_frame + ) logger.info("Reaggregate threads data...") for result in parallel_results: @@ -1258,8 +1444,15 @@ def parse_memory_ia32(self): data = self.data self.data = None - iterators = [(len(y), y) for y in [list(x) for x in divide(mp.cpu_count(), pt_candidates)]] - parsing_results_async = self.machine.apply_parallel(self.machine.mmu.PAGE_SIZE, self.machine.mmu.parse_parallel_frame, iterators=iterators) + iterators = [ + (len(y), y) + for y in [list(x) for x in divide(mp.cpu_count(), pt_candidates)] + ] + parsing_results_async = self.machine.apply_parallel( + self.machine.mmu.PAGE_SIZE, + self.machine.mmu.parse_parallel_frame, + iterators=iterators, + ) # Restore previous data and set classify_entry to full version self.data = data @@ -1276,21 +1469,28 @@ def parse_memory_ia32(self): # Remove PT from data pages (in the first phase the alogrith has classified PT as data pages, now that # we know which PT is a true one, they must be removed from data pages) self.data.data_pages = 
set(self.data.data_pages) - self.data.data_pages.difference_update(self.data.page_tables["global"][1].keys()) + self.data.data_pages.difference_update( + self.data.page_tables["global"][1].keys() + ) self.data.empty_tables = set(self.data.empty_tables) logger.info("Reduce false positives...") # Remove all tables which point to inexistent table of lower level for lvl in range(self.machine.mmu.radix_levels["global"] - 1): - ptr_class = self.machine.mmu.map_ptr_entries_to_levels["global"][lvl] referenced_nxt = [] for table_addr in list(self.data.page_tables["global"][lvl].keys()): - for entry_obj in self.data.page_tables["global"][lvl][table_addr].entries[ptr_class].values(): - if entry_obj.address not in self.data.page_tables["global"][lvl + 1] and \ - entry_obj.address not in self.data.empty_tables: - + for entry_obj in ( + self.data.page_tables["global"][lvl][table_addr] + .entries[ptr_class] + .values() + ): + if ( + entry_obj.address + not in self.data.page_tables["global"][lvl + 1] + and entry_obj.address not in self.data.empty_tables + ): # Remove the table self.data.page_tables["global"][lvl].pop(table_addr) break @@ -1300,18 +1500,28 @@ def parse_memory_ia32(self): # Remove table not referenced by upper levels referenced_nxt = set(referenced_nxt) - for table_addr in set(self.data.page_tables["global"][lvl + 1].keys()).difference(referenced_nxt): + for table_addr in set( + self.data.page_tables["global"][lvl + 1].keys() + ).difference(referenced_nxt): self.data.page_tables["global"][lvl + 1].pop(table_addr) logger.info("Fill reverse maps...") for lvl in range(0, self.machine.mmu.radix_levels["global"]): ptr_class = self.machine.mmu.map_ptr_entries_to_levels["global"][lvl] - page_class = self.machine.mmu.map_datapages_entries_to_levels["global"][lvl][0] # Trick! Only one dataclass per level + page_class = self.machine.mmu.map_datapages_entries_to_levels["global"][ + lvl + ][ + 0 + ] # Trick! 
Only one dataclass per level for table_addr, table_obj in self.data.page_tables["global"][lvl].items(): for entry_obj in table_obj.entries[ptr_class].values(): - self.data.reverse_map_tables[lvl][entry_obj.address].add(table_obj.address) + self.data.reverse_map_tables[lvl][entry_obj.address].add( + table_obj.address + ) for entry_obj in table_obj.entries[page_class].values(): - self.data.reverse_map_pages[lvl][entry_obj.address].add(table_obj.address) + self.data.reverse_map_pages[lvl][entry_obj.address].add( + table_obj.address + ) logger.info("Look for interrupt tables...") self.data.idts = self.machine.cpu.find_idt_tables() @@ -1322,7 +1532,9 @@ def parse_memory_pae(self): def parse_memory_ia64(self): logger.info("Look for paging tables...") - parallel_results = self.machine.apply_parallel(self.machine.mmu.PAGE_SIZE, self.machine.mmu.parse_parallel_frame) + parallel_results = self.machine.apply_parallel( + self.machine.mmu.PAGE_SIZE, self.machine.mmu.parse_parallel_frame + ) logger.info("Reaggregate threads data...") for result in parallel_results: page_tables, data_pages, empty_tables = result.get() @@ -1339,15 +1551,20 @@ def parse_memory_ia64(self): logger.info("Reduce false positives...") # Remove all tables which point to inexistent table of lower level for lvl in range(self.machine.mmu.radix_levels["global"] - 1): - ptr_class = self.machine.mmu.map_ptr_entries_to_levels["global"][lvl] referenced_nxt = [] for table_addr in list(self.data.page_tables["global"][lvl].keys()): - for entry_obj in self.data.page_tables["global"][lvl][table_addr].entries[ptr_class].values(): - if entry_obj.address not in self.data.page_tables["global"][lvl + 1] and \ - entry_obj.address not in self.data.empty_tables: - + for entry_obj in ( + self.data.page_tables["global"][lvl][table_addr] + .entries[ptr_class] + .values() + ): + if ( + entry_obj.address + not in self.data.page_tables["global"][lvl + 1] + and entry_obj.address not in self.data.empty_tables + ): # Remove the table self.data.page_tables["global"][lvl].pop(table_addr) break @@ -1357,18 +1574,28 @@ def parse_memory_ia64(self): # Remove table not referenced by upper levels referenced_nxt = set(referenced_nxt) - for table_addr in set(self.data.page_tables["global"][lvl + 1].keys()).difference(referenced_nxt): + for table_addr in set( + self.data.page_tables["global"][lvl + 1].keys() + ).difference(referenced_nxt): self.data.page_tables["global"][lvl + 1].pop(table_addr) logger.info("Fill reverse maps...") for lvl in range(0, self.machine.mmu.radix_levels["global"]): ptr_class = self.machine.mmu.map_ptr_entries_to_levels["global"][lvl] - page_class = self.machine.mmu.map_datapages_entries_to_levels["global"][lvl][0] # Trick! Only one dataclass per level + page_class = self.machine.mmu.map_datapages_entries_to_levels["global"][ + lvl + ][ + 0 + ] # Trick! 
Only one dataclass per level for table_addr, table_obj in self.data.page_tables["global"][lvl].items(): for entry_obj in table_obj.entries[ptr_class].values(): - self.data.reverse_map_tables[lvl][entry_obj.address].add(table_obj.address) + self.data.reverse_map_tables[lvl][entry_obj.address].add( + table_obj.address + ) for entry_obj in table_obj.entries[page_class].values(): - self.data.reverse_map_pages[lvl][entry_obj.address].add(table_obj.address) + self.data.reverse_map_pages[lvl][entry_obj.address].add( + table_obj.address + ) logger.info("Look for interrupt tables...") self.data.idts = self.machine.cpu.find_idt_tables() @@ -1408,18 +1635,33 @@ def do_find_radix_trees(self, args): if derived_address in already_explored: continue lvl, addr = derived_address - cr3_candidates.extend(self.radix_roots_from_data_page(lvl, addr, self.data.reverse_map_pages, self.data.reverse_map_tables)) + cr3_candidates.extend( + self.radix_roots_from_data_page( + lvl, + addr, + self.data.reverse_map_pages, + self.data.reverse_map_tables, + ) + ) already_explored.add(derived_address) - cr3_candidates = list(set(cr3_candidates).intersection(self.data.page_tables["global"][0].keys())) + cr3_candidates = list( + set(cr3_candidates).intersection( + self.data.page_tables["global"][0].keys() + ) + ) # Refine dataset and use a fake IDT table cr3s = {-1: {}} logger.info("Filter candidates...") for cr3 in tqdm(cr3_candidates): - # Obtain radix tree infos - consistency, pas = self.physpace(cr3, self.data.page_tables["global"], self.data.empty_tables, hierarchical=True) + consistency, pas = self.physpace( + cr3, + self.data.page_tables["global"], + self.data.empty_tables, + hierarchical=True, + ) # Only consistent trees are valid if not consistency: @@ -1429,7 +1671,9 @@ def do_find_radix_trees(self, args): if pas.get_kernel_size() == pas.get_user_size() == 0: continue - vas = self.virtspace(cr3, 0, self.machine.mmu.top_prefix, hierarchical=True) + vas = self.virtspace( + cr3, 0, self.machine.mmu.top_prefix, hierarchical=True + ) cr3s[-1][cr3] = RadixTree(cr3, 0, pas, vas) self.data.cr3s = cr3s @@ -1445,18 +1689,31 @@ def do_find_radix_trees(self, args): # Collect all possible CR3: a valid CR3 must be able to address the IDT logger.info("Collect all valids CR3s...") - idt_pg_addresses = self.machine.mmu.derive_page_address(idt_obj.address >> 12 << 12) + idt_pg_addresses = self.machine.mmu.derive_page_address( + idt_obj.address >> 12 << 12 + ) for level, addr in idt_pg_addresses: - cr3_candidates.update(self.radix_roots_from_data_page(level, addr, self.data.reverse_map_pages, self.data.reverse_map_tables)) - cr3_candidates = list(cr3_candidates.intersection(self.data.page_tables["global"][0].keys())) - logger.info("Number of possible CR3s for IDT located at {}:{}". 
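+            # Only radix roots that both map the IDT page and exist among the
+            # level-0 tables remain plausible CR3 values for this IDT.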
- format(hex(idt_obj.address), len(cr3_candidates))) + cr3_candidates.update( + self.radix_roots_from_data_page( + level, + addr, + self.data.reverse_map_pages, + self.data.reverse_map_tables, + ) + ) + cr3_candidates = list( + cr3_candidates.intersection(self.data.page_tables["global"][0].keys()) + ) + logger.info( + "Number of possible CR3s for IDT located at {}:{}".format( + hex(idt_obj.address), len(cr3_candidates) + ) + ) # Collect the page containig each virtual addresses defined inside interrupt handlers handlers_pages = set() for handler in idt_obj.entries: - # Task Entry does not point to interrupt hanlder if isinstance(handler, IDTTaskEntry32): continue @@ -1474,7 +1731,9 @@ def do_find_radix_trees(self, args): for vaddr in handlers_pages: paddr = self.resolve_vaddr(cr3_candidate, vaddr) if paddr == -1: - logging.debug(f"find_radix_trees(): {hex(cr3_candidate)} failed to solve {hex(vaddr)}") + logging.debug( + f"find_radix_trees(): {hex(cr3_candidate)} failed to solve {hex(vaddr)}" + ) errors += 1 cr3s_for_idt.append([cr3_candidate, errors]) @@ -1486,7 +1745,11 @@ def do_find_radix_trees(self, args): # Save only CR3s which resolv the max number of addresses cr3s_for_idt.sort(key=lambda x: (x[1], x[0])) max_value = cr3s_for_idt[0][1] - logger.debug("Interrupt pages: {}, Maximum pages resolved: {}".format(len(handlers_pages), len(handlers_pages) - max_value)) + logger.debug( + "Interrupt pages: {}, Maximum pages resolved: {}".format( + len(handlers_pages), len(handlers_pages) - max_value + ) + ) # Consider only CR3 which resolve the maximum number of interrupt pages for cr3 in cr3s_for_idt: @@ -1494,7 +1757,12 @@ def do_find_radix_trees(self, args): break # Extract an approximation of the kernel and user physical address space - consistency, pas = self.physpace(cr3[0], self.data.page_tables["global"], self.data.empty_tables, hierarchical=True) + consistency, pas = self.physpace( + cr3[0], + self.data.page_tables["global"], + self.data.empty_tables, + hierarchical=True, + ) # Only consistent trees are valid if not consistency: @@ -1504,7 +1772,9 @@ def do_find_radix_trees(self, args): if pas.get_kernel_size() == pas.get_user_size() == 0: continue - vas = self.virtspace(cr3[0], 0, self.machine.mmu.top_prefix, hierarchical=True) + vas = self.virtspace( + cr3[0], 0, self.machine.mmu.top_prefix, hierarchical=True + ) cr3s[idt_obj.address][cr3[0]] = RadixTree(cr3[0], 0, pas, vas) self.data.cr3s = cr3s @@ -1524,8 +1794,8 @@ def do_show_radix_trees(self, args): idt_addr = self.parse_int(args[0]) if not self.data.idts: - logging.info("No IDT found by MMUShell") - idt_addr = -1 + logging.info("No IDT found by MMUShell") + idt_addr = -1 else: for idt in self.data.idts: if idt_addr == idt.address: @@ -1535,14 +1805,20 @@ def do_show_radix_trees(self, args): return # Show results - labels = ["Radix address", "First level", "Kernel size (Bytes)", "User size (Bytes)"] + labels = [ + "Radix address", + "First level", + "Kernel size (Bytes)", + "User size (Bytes)", + ] table = PrettyTable() table.field_names = labels for cr3 in self.data.cr3s[idt_addr].values(): table.add_row(cr3.entry_resume_stringified()) - table.sortby="Radix address" + table.sortby = "Radix address" print(table) + class MMUShellGTruth(MMUShell): def do_show_idtrs_gtruth(self, args): """Compare IDTs found with the ground truth""" @@ -1563,8 +1839,15 @@ def do_show_idtrs_gtruth(self, args): # Validate CR3 if cr3_obj.address not in self.data.page_tables["global"][0]: continue - consistency, pas = self.physpace(cr3_obj.address, 
self.data.page_tables["global"], self.data.empty_tables, hierarchical=True) - if not consistency or (not pas.get_kernel_size() and not pas.get_user_size()): + consistency, pas = self.physpace( + cr3_obj.address, + self.data.page_tables["global"], + self.data.empty_tables, + hierarchical=True, + ) + if not consistency or ( + not pas.get_kernel_size() and not pas.get_user_size() + ): continue else: valid_cr3_obj = cr3_obj @@ -1578,7 +1861,13 @@ def do_show_idtrs_gtruth(self, args): # Resolve the IDT virtual address table = PrettyTable() - table.field_names = ["Virtual address", "Physical address", "Found", "First seen", "Last seen"] + table.field_names = [ + "Virtual address", + "Physical address", + "Found", + "First seen", + "Last seen", + ] tp = 0 unresolved = 0 @@ -1591,24 +1880,41 @@ def do_show_idtrs_gtruth(self, args): # Not solved by the CR3... if paddr == -1: unresolved += 1 - table.add_row([hex(idtr_obj.address), "?", "?", self.gtruth["IDTR"][idtr_obj.value][0], self.gtruth["IDTR"][idtr_obj.value][1]]) + table.add_row( + [ + hex(idtr_obj.address), + "?", + "?", + self.gtruth["IDTR"][idtr_obj.value][0], + self.gtruth["IDTR"][idtr_obj.value][1], + ] + ) else: if paddr in idts: tp += 1 found = "X" else: found = "" - table.add_row([hex(idtr_obj.address), hex(paddr), found, self.gtruth["IDTR"][idtr_obj.value][0], self.gtruth["IDTR"][idtr_obj.value][1]]) + table.add_row( + [ + hex(idtr_obj.address), + hex(paddr), + found, + self.gtruth["IDTR"][idtr_obj.value][0], + self.gtruth["IDTR"][idtr_obj.value][1], + ] + ) print(f"Use CR3 address: {hex(valid_cr3_obj.address)}") print(table) print(f"TP:{tp} FP:{len(idts) - tp} Unresolved: {unresolved}") - + # Export results for next analysis if len(args) == 2 and args[1] == "export": from pickle import dump as dump_p + with open("dump.mmu", "wb") as f: - results = [{"cr3":tp} for tp in sorted(tps)] + results = [{"cr3": tp} for tp in sorted(tps)] dump_p(results, f) def do_show_radix_trees_gtruth(self, args): @@ -1646,7 +1952,12 @@ def do_show_radix_trees_gtruth(self, args): return # Collect all valid CR3 - latest_idt_va_used = IDTR(sorted(list(self.gtruth["IDTR"].keys()), key=lambda x: self.gtruth["IDTR"][x][1])[-1]) + latest_idt_va_used = IDTR( + sorted( + list(self.gtruth["IDTR"].keys()), + key=lambda x: self.gtruth["IDTR"][x][1], + )[-1] + ) idts = {} cr3_errors = defaultdict(list) @@ -1654,14 +1965,23 @@ def do_show_radix_trees_gtruth(self, args): cr3_obj = self.machine.mmu.cr3_class(cr3) if cr3_obj.address not in self.data.page_tables["global"][0]: continue - consistency, pas = self.physpace(cr3_obj.address, self.data.page_tables["global"], self.data.empty_tables, hierarchical=True) - if not consistency or (not pas.get_kernel_size() and not pas.get_user_size()): + consistency, pas = self.physpace( + cr3_obj.address, + self.data.page_tables["global"], + self.data.empty_tables, + hierarchical=True, + ) + if not consistency or ( + not pas.get_kernel_size() and not pas.get_user_size() + ): continue # Check if they are able to address the IDT table - derived_addresses = self.machine.mmu.derive_page_address(idt_addr >> 12 << 12) + derived_addresses = self.machine.mmu.derive_page_address( + idt_addr >> 12 << 12 + ) if not any([x[1] in pas for x in derived_addresses]): - continue # Trick! Only one dataclass per level + continue # Trick! 
Only one dataclass per level # Check if the CR3 is able to resolve the latest IDTR value used # (we check this for simplicity instead of the VA associated with the selected IDT) @@ -1704,22 +2024,27 @@ def do_show_radix_trees_gtruth(self, args): table = PrettyTable() table.field_names = ["Address", "Found", "First seen", "Last seen"] for tp in sorted(tps): - table.add_row([hex(tp), - "X", - self.gtruth["CR3"][valid_cr3s[tp].value][0], - self.gtruth["CR3"][valid_cr3s[tp].value][1]]) + table.add_row( + [ + hex(tp), + "X", + self.gtruth["CR3"][valid_cr3s[tp].value][0], + self.gtruth["CR3"][valid_cr3s[tp].value][1], + ] + ) for fn in sorted(fns): - table.add_row([hex(fn), - "", - self.gtruth["CR3"][valid_cr3s[fn].value][0], - self.gtruth["CR3"][valid_cr3s[fn].value][1]]) + table.add_row( + [ + hex(fn), + "", + self.gtruth["CR3"][valid_cr3s[fn].value][0], + self.gtruth["CR3"][valid_cr3s[fn].value][1], + ] + ) for fp in sorted(fps): - table.add_row([hex(fp), - "False positive", - "", - ""]) + table.add_row([hex(fp), "False positive", "", ""]) print(table) print(f"TP:{len(tps)} FN:{len(fns)} FP:{len(fps)}") @@ -1727,6 +2052,7 @@ def do_show_radix_trees_gtruth(self, args): # Export results for next analysis if len(args) == 2 and args[1] == "export": from pickle import dump as dump_p + with open("dump.mmu", "wb") as f: - results = [{"cr3":tp} for tp in sorted(tps)] + results = [{"cr3": tp} for tp in sorted(tps)] dump_p(results, f) diff --git a/architectures/mips.py b/mmushell/architectures/mips.py similarity index 70% rename from architectures/mips.py rename to mmushell/architectures/mips.py index 7f73495..2a45365 100644 --- a/architectures/mips.py +++ b/mmushell/architectures/mips.py @@ -1,22 +1,25 @@ +import logging + from architectures.generic import Machine as MachineDefault from architectures.generic import CPU as CPUDefault from architectures.generic import PhysicalMemory as PhysicalMemoryDefault from architectures.generic import MMUShell as MMUShellDefault from architectures.generic import MMU as MMUDefault from architectures.generic import CPUReg -import logging -from prettytable import PrettyTable -from dataclasses import dataclass -from tqdm import tqdm -from struct import unpack -from collections import defaultdict + from miasm.analysis.machine import Machine as MIASMMachine from miasm.core.bin_stream import bin_stream_vm from miasm.core.locationdb import LocationDB from miasm.jitter.VmMngr import Vm from miasm.jitter.csts import PAGE_READ, PAGE_WRITE, PAGE_EXEC -from copy import deepcopy + +from prettytable import PrettyTable +from dataclasses import dataclass +from collections import defaultdict +from struct import unpack from pprint import pprint +from tqdm import tqdm +from copy import deepcopy logger = logging.getLogger(__name__) @@ -34,14 +37,16 @@ def __init__(self, cpu, mmu, memory, **kwargs): super(Machine, self).__init__(cpu, mmu, memory, **kwargs) def get_miasm_machine(self): - mn_s = "mips" + str(self.cpu.bits) + ("b" if self.cpu.endianness == "big" else "l") + mn_s = ( + "mips" + str(self.cpu.bits) + ("b" if self.cpu.endianness == "big" else "l") + ) return MIASMMachine(mn_s) class CPURegMIPS(CPUReg): @classmethod def get_register_obj(cls, reg_name, value): - return globals()[reg_name](value) + return globals()[reg_name](value) class ContextConfig(CPURegMIPS): @@ -53,7 +58,7 @@ def is_valid(self, value): digit_changes = 0 for i in range(1, len(digits)): - if digits[i-1] != digits[i]: + if digits[i - 1] != digits[i]: digit_changes += 1 if digits[0] == "1" or digits[-1] == 
"1": @@ -73,12 +78,17 @@ def is_mmu_equivalent_to(self, other): return self.valid == other.valid and self.VirtualIndex == other.VirtualIndex def __repr__(self): - return f"ContextConfig {hex(self.value)} => VirtualIndex:{hex(self.VirtualIndex)}" + return ( + f"ContextConfig {hex(self.value)} => VirtualIndex:{hex(self.VirtualIndex)}" + ) class PageMask(CPURegMIPS): def is_valid(self, value): - if CPU.extract_bits(value, 0, 11) != 0x0 or CPU.extract_bits(value, 29, 3) != 0x0: + if ( + CPU.extract_bits(value, 0, 11) != 0x0 + or CPU.extract_bits(value, 29, 3) != 0x0 + ): return False MaskX = CPU.extract_bits(value, 11, 2) @@ -93,7 +103,7 @@ def _is_a_valid_mask(self, value): digit_changes = 0 for i in range(1, len(digits)): - if digits[i-1] != digits[i]: + if digits[i - 1] != digits[i]: digit_changes += 1 return digit_changes <= 1 @@ -107,16 +117,22 @@ def __init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return self.valid == other.valid and \ - self.Mask == other.Mask and \ - self.MaskX == other.MaskX + return ( + self.valid == other.valid + and self.Mask == other.Mask + and self.MaskX == other.MaskX + ) def __repr__(self): return f"PageMask {hex(self.value)} => Mask:{hex(self.Mask)}, MaskX:{hex(self.MaskX)}" + class PageGrain(CPURegMIPS): def is_valid(self, value): - return not(CPU.extract_bits(value, 5, 3) != 0x0 or CPU.extract_bits(value, 13, 13) != 0x0) + return not ( + CPU.extract_bits(value, 5, 3) != 0x0 + or CPU.extract_bits(value, 13, 13) != 0x0 + ) def __init__(self, value): self.value = value @@ -134,18 +150,25 @@ def __init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return (self.valid == other.valid and \ - self.RIE == other.RIE and \ - self.XIE == other.XIE and \ - self.ELPA == other.ELPA and \ - self.ESP == other.ESP) + return ( + self.valid == other.valid + and self.RIE == other.RIE + and self.XIE == other.XIE + and self.ELPA == other.ELPA + and self.ESP == other.ESP + ) def __repr__(self): return f"PageGrain {hex(self.value)} => RIE:{hex(self.RIE)}, XIE:{hex(self.XIE)}, ELPA:{hex(self.ELPA)}, ESP:{hex(self.ESP)}" + class SegCtl(CPURegMIPS): def is_valid(self, value): - return not(CPU.extract_bits(value, 7, 2) != 0x0 or CPU.extract_bits(value, 23, 2) != 0x0) + return not ( + CPU.extract_bits(value, 7, 2) != 0x0 + or CPU.extract_bits(value, 23, 2) != 0x0 + ) + class SegCtl0(SegCtl): def __init__(self, value): @@ -158,12 +181,17 @@ def __init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return (self.valid == other.valid and \ - self.CFG0 == other.CFG0 and \ - self.CFG1 == other.CFG1) + return ( + self.valid == other.valid + and self.CFG0 == other.CFG0 + and self.CFG1 == other.CFG1 + ) def __repr__(self): - return f"SegCtl0 {hex(self.value)} => CFG0:{hex(self.CFG0)}, CFG1:{hex(self.CFG1)}" + return ( + f"SegCtl0 {hex(self.value)} => CFG0:{hex(self.CFG0)}, CFG1:{hex(self.CFG1)}" + ) + class SegCtl1(SegCtl): def __init__(self, value): @@ -176,12 +204,16 @@ def __init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return (self.valid == other.valid and \ - self.CFG2 == other.CFG2 and \ - self.CFG3 == other.CFG3) + return ( + self.valid == other.valid + and self.CFG2 == other.CFG2 + and self.CFG3 == other.CFG3 + ) def __repr__(self): - return f"SegCtl1 {hex(self.value)} => CFG2:{hex(self.CFG2)}, CFG3:{hex(self.CFG3)}" + return ( + f"SegCtl1 {hex(self.value)} => CFG2:{hex(self.CFG2)}, CFG3:{hex(self.CFG3)}" + ) class SegCtl2(SegCtl): @@ -195,12 +227,16 @@ def 
__init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return (self.valid == other.valid and \ - self.CFG4 == other.CFG4 and \ - self.CFG5 == other.CFG5) + return ( + self.valid == other.valid + and self.CFG4 == other.CFG4 + and self.CFG5 == other.CFG5 + ) def __repr__(self): - return f"SegCtl2 {hex(self.value)} => CFG4:{hex(self.CFG4)}, CFG5:{hex(self.CFG5)}" + return ( + f"SegCtl2 {hex(self.value)} => CFG4:{hex(self.CFG4)}, CFG5:{hex(self.CFG5)}" + ) class PWBase(CPURegMIPS): @@ -218,13 +254,15 @@ def is_mmu_equivalent_to(self, other): def __repr__(self): return f"PWBase {hex(self.value)} => PWBase:{hex(self.PWBase)}" + class PWField(CPURegMIPS): def is_valid(self, value): - if CPU.processor_features["R6_CPU"] and \ - (CPU.extract_bits(value, 24, 6) < 12 or \ - CPU.extract_bits(value, 18, 6) < 12 or \ - CPU.extract_bits(value, 12, 6) < 12 or \ - CPU.extract_bits(value, 6, 6) < 12): + if CPU.processor_features["R6_CPU"] and ( + CPU.extract_bits(value, 24, 6) < 12 + or CPU.extract_bits(value, 18, 6) < 12 + or CPU.extract_bits(value, 12, 6) < 12 + or CPU.extract_bits(value, 6, 6) < 12 + ): return False return CPU.extract_bits(value, 30, 2) == 0x0 @@ -242,12 +280,14 @@ def __init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return self.valid == other.valid and \ - self.PTEI == other.PTEI and \ - self.PTI == other.PTI and \ - self.MDI == other.MDI and \ - self.UDI == other.UDI and \ - self.GDI == other.GDI + return ( + self.valid == other.valid + and self.PTEI == other.PTEI + and self.PTI == other.PTI + and self.MDI == other.MDI + and self.UDI == other.UDI + and self.GDI == other.GDI + ) def __repr__(self): return f"PWField {hex(self.value)} => PTEI:{hex(self.PTEI)}, PTI:{hex(self.PTI)}, MDI:{hex(self.MDI)}, UDI:{hex(self.UDI)}, GDI:{hex(self.GDI)}" @@ -274,17 +314,20 @@ def __init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return self.valid == other.valid and \ - self.PTEW == other.PTEW and \ - self.PTW == other.PTW and \ - self.MDW == other.MDW and \ - self.UDW == other.UDW and \ - self.GDW == other.GDW and \ - self.PS == other.PS + return ( + self.valid == other.valid + and self.PTEW == other.PTEW + and self.PTW == other.PTW + and self.MDW == other.MDW + and self.UDW == other.UDW + and self.GDW == other.GDW + and self.PS == other.PS + ) def __repr__(self): return f"PWSize {hex(self.value)} => PTEW:{hex(self.PTEW)}, PTW:{hex(self.PTW)}, MDW:{hex(self.MDW)}, UDW:{hex(self.UDW)}, GDW:{hex(self.GDW)}, PS:{hex(self.PS)}" + class Wired(CPURegMIPS): def is_valid(self, value): return CPU.extract_bits(value, 0, 16) <= CPU.extract_bits(value, 16, 16) @@ -304,6 +347,7 @@ def is_mmu_equivalent_to(self, other): def __repr__(self): return f"Wired {hex(self.value)} => Wired:{self.Wired}, Limit:{self.Limit}" + class PWCtl(CPURegMIPS): def is_valid(self, value): return CPU.extract_bits(value, 8, 23) == 0x0 @@ -320,11 +364,13 @@ def __init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return self.valid == other.valid and \ - self.Psn == other.Psn and \ - self.HugePg == other.HugePg and \ - self.DPH == other.DPH and \ - self.PWEn == other.PWEn + return ( + self.valid == other.valid + and self.Psn == other.Psn + and self.HugePg == other.HugePg + and self.DPH == other.DPH + and self.PWEn == other.PWEn + ) def __repr__(self): return f"PWCtl {hex(self.value)} => Psn:{hex(self.Psn)}, HugePg:{hex(self.HugePg)}, DPH:{hex(self.DPH)}, PWEn:{hex(self.PWEn)}" @@ -332,8 +378,9 @@ def __repr__(self): 
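+# Heuristic filter for CP0 Config (register 16, select 0) values: a value is
+# accepted below only if bit 31 is set and bits 4-6 are zero.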
class Config(CPURegMIPS): def is_valid(self, value): - return CPU.extract_bits(value, 4, 3) == 0x0 and \ - CPU.extract_bits(value, 31, 1) == 1 + return ( + CPU.extract_bits(value, 4, 3) == 0x0 and CPU.extract_bits(value, 31, 1) == 1 + ) def __init__(self, value): self.value = value @@ -352,20 +399,25 @@ def __init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return (self.valid == other.valid and \ - self.K0 == other.K0 and \ - self.MT == other.MT and \ - self.KU == other.KU and \ - self.K23 == other.K23) + return ( + self.valid == other.valid + and self.K0 == other.K0 + and self.MT == other.MT + and self.KU == other.KU + and self.K23 == other.K23 + ) def __repr__(self): return f"Config {hex(self.value)} => K0:{hex(self.K0)}, MT:{hex(self.MT)}, KU:{hex(self.KU)}, K23:{hex(self.K23)}" + class Config5(CPURegMIPS): def is_valid(self, value): - return not(CPU.extract_bits(value, 1, 1) != 0x0 or \ - CPU.extract_bits(value, 12, 1) != 0x0 or \ - CPU.extract_bits(value, 14, 13) != 0x0) + return not ( + CPU.extract_bits(value, 1, 1) != 0x0 + or CPU.extract_bits(value, 12, 1) != 0x0 + or CPU.extract_bits(value, 14, 13) != 0x0 + ) def __init__(self, value): self.value = value @@ -392,14 +444,17 @@ def __init__(self, value): self.valid = False def is_mmu_equivalent_to(self, other): - return (self.valid == other.valid and \ - self.MRP == other.MRP and \ - self.MVH == other.MVH and \ - self.EVA == other.EVA) + return ( + self.valid == other.valid + and self.MRP == other.MRP + and self.MVH == other.MVH + and self.EVA == other.EVA + ) def __repr__(self): return f"Config5 {hex(self.value)} => MRP:{hex(self.MRP)}, MVH:{hex(self.MVH)}, EVA:{hex(self.EVA)}" + class CPU(CPUDefault): @classmethod def from_cpu_config(cls, cpu_config, **kwargs): @@ -417,73 +472,113 @@ def __init__(self, features): self.processor_features["opcode_unpack_fmt"] = "= 2) + self.processor_features["R6_CPU"] = ( + CPU.extract_bits_little(self.registers_values["Config"], 10, 3) >= 2 + ) CPU.endianness = self.endianness CPU.processor_features = self.processor_features CPU.registers_values = self.registers_values CPU.extract_bits = CPU.extract_bits_little + class CPUMips32(CPU): def __init__(self, features): super(CPUMips32, self).__init__(features) self.processor_features["ksegs"] = { - "Kseg0": (0x80000000, 0x20000000), # Each Kseg segment start address and size - "Kseg1": (0xA0000000, 0x20000000) - } + "Kseg0": ( + 0x80000000, + 0x20000000, + ), # Each Kseg segment start address and size + "Kseg1": (0xA0000000, 0x20000000), + } self.processor_features["kern_code_phys_end"] = 0x20000000 self.processor_features["opcode_to_mmu_regs"] = { - (4, 1): "ContextConfig", - (5, 0): "PageMask", - (5, 1): "PageGrain", - (5, 2): "SegCtl0", - (5, 3): "SegCtl1", - (5, 4): "SegCtl2", - (5, 5): "PWBase", - (5, 6): "PWField", - (5, 7): "PWSize", - (6, 0): "Wired", - (6, 6): "PWCtl", - (16, 0): "Config", - (16, 5): "Config5", - # (4, 0): "Context", # Registers not used in our analisys - # (16, 4): "Config4", - # (15, 1): "EBase", - # (31, 2): "KScratch0", - # (31, 3): "KScratch1", - # (31, 4): "KScratch2", - # (31, 5): "KScratch3", - # (31, 6): "KScratch4", - # (31, 7): "KScratch5" + (4, 1): "ContextConfig", + (5, 0): "PageMask", + (5, 1): "PageGrain", + (5, 2): "SegCtl0", + (5, 3): "SegCtl1", + (5, 4): "SegCtl2", + (5, 5): "PWBase", + (5, 6): "PWField", + (5, 7): "PWSize", + (6, 0): "Wired", + (6, 6): "PWCtl", + (16, 0): "Config", + (16, 5): "Config5", + # (4, 0): "Context", # Registers not used in our analisys + # (16, 4): 
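Editor's note: the `(rd, sel)` pairs in `opcode_to_mmu_regs` are matched by `parse_opcode()` below against MTC0 instructions, whose fixed opcode fields identify writes to CP0 registers. A hypothetical worked example (values chosen for illustration): assemble the word for `mtc0 $k0, $5, 2` (CP0 register 5, select 2, which maps to `SegCtl0`) and decode it the same way the parser does:

```python
# General-purpose register names, in opcode order (same table as the code).
GREGS = ["ZERO", "AT", "V0", "V1", "A0", "A1", "A2", "A3",
         "T0", "T1", "T2", "T3", "T4", "T5", "T6", "T7",
         "S0", "S1", "S2", "S3", "S4", "S5", "S6", "S7",
         "T8", "T9", "K0", "K1", "GP", "SP", "FP", "RA"]

def extract_bits(v, pos, n):
    return (v >> pos) & ((1 << n) - 1)

# COP0 opcode (bits 26-31), MT sub-opcode (bits 21-25), rt, rd, sel
instr = (0b010000 << 26) | (0b00100 << 21) | (26 << 16) | (5 << 11) | 2

assert extract_bits(instr, 21, 11) == 0b01000000100  # COP0|MT signature
assert extract_bits(instr, 3, 8) == 0                # reserved bits must be zero
sel = extract_bits(instr, 0, 3)
rd = extract_bits(instr, 11, 5)
gr = GREGS[extract_bits(instr, 16, 5)]
assert (rd, sel) == (5, 2) and gr == "K0"            # "SegCtl0" written via $k0
```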
"Config4", + # (15, 1): "EBase", + # (31, 2): "KScratch0", + # (31, 3): "KScratch1", + # (31, 4): "KScratch2", + # (31, 5): "KScratch3", + # (31, 6): "KScratch4", + # (31, 7): "KScratch5" } - self.processor_features["opcode_to_gregs"] = ["ZERO", "AT", "V0", "V1", "A0", "A1", "A2", "A3", - "T0", "T1", "T2", "T3", "T4", "T5", "T6", "T7", - "S0", "S1", "S2", "S3", "S4", "S5", "S6", "S7", - "T8", "T9", "K0", "K1", "GP", "SP", "FP", "RA"] + self.processor_features["opcode_to_gregs"] = [ + "ZERO", + "AT", + "V0", + "V1", + "A0", + "A1", + "A2", + "A3", + "T0", + "T1", + "T2", + "T3", + "T4", + "T5", + "T6", + "T7", + "S0", + "S1", + "S2", + "S3", + "S4", + "S5", + "S6", + "S7", + "T8", + "T9", + "K0", + "K1", + "GP", + "SP", + "FP", + "RA", + ] CPU.processor_features = self.processor_features CPU.registers_values = self.registers_values def parse_opcode(self, instr, page_addr, offset): opcodes = {} # Collect MTC0 instructions for MMU registers - if CPUMips32.extract_bits(instr, 21, 11) == 0b01000000100 and \ - CPUMips32.extract_bits(instr, 3, 8) == 0x0: - + if ( + CPUMips32.extract_bits(instr, 21, 11) == 0b01000000100 + and CPUMips32.extract_bits(instr, 3, 8) == 0x0 + ): sel = CPUMips32.extract_bits(instr, 0, 3) rd = CPUMips32.extract_bits(instr, 11, 5) - gr = CPUMips32.processor_features["opcode_to_gregs"][CPUMips32.extract_bits(instr, 16, 5)] + gr = CPUMips32.processor_features["opcode_to_gregs"][ + CPUMips32.extract_bits(instr, 16, 5) + ] # For each address collect which coprocessor register in involved and which general register it is used to load a value if (rd, sel) in self.processor_features["opcode_to_mmu_regs"]: phy_addr = page_addr + offset mmu_reg = self.processor_features["opcode_to_mmu_regs"][rd, sel] for kseg_start, kseg_size in CPU.processor_features["ksegs"].values(): - opcodes[phy_addr + kseg_start] = {"register": mmu_reg, - "gpr": [gr], - "f_addr": -1, - "f_parents": set(), - "instruction": "MTC0" - } + opcodes[phy_addr + kseg_start] = { + "register": mmu_reg, + "gpr": [gr], + "f_addr": -1, + "f_parents": set(), + "instruction": "MTC0", + } return opcodes def identify_functions_start(self, addreses): @@ -495,7 +590,7 @@ def identify_functions_start(self, addreses): instr_len = self.processor_features["instr_len"] # Disable MIASM logging - logger = logging.getLogger('asmblock') + logger = logging.getLogger("asmblock") logger.disabled = True for addr in tqdm(addreses): @@ -506,7 +601,6 @@ def identify_functions_start(self, addreses): # Maximum 10000 instructions instructions = 0 while True and instructions <= 10000: - try: asmcode = mdis.dis_instr(cur_addr) @@ -516,7 +610,21 @@ def identify_functions_start(self, addreses): break # JR RA/JR.HB RA/J/JIC/B - elif asmcode.name in ["B", "J", "JIC", "JR", "JR.HB", "BAL", "BALC", "BC", "JALR", "JALR.HB","JALX","JIALC", "JIC"]: + elif asmcode.name in [ + "B", + "J", + "JIC", + "JR", + "JR.HB", + "BAL", + "BALC", + "BC", + "JALR", + "JALR.HB", + "JALX", + "JIALC", + "JIC", + ]: cur_addr += instr_len * 2 break @@ -534,6 +642,7 @@ def identify_functions_start(self, addreses): del vm + class PhysicalMemory(PhysicalMemoryDefault): def get_miasm_vmmngr(self): if self._miasm_vm is not None: @@ -544,10 +653,16 @@ def get_miasm_vmmngr(self): # MIASM to see both of them for region_def in tqdm(self._memregions): if region_def["start"] == 0: - for kseg_name, kseg_addr_size in CPU.processor_features["ksegs"].items(): + for kseg_name, kseg_addr_size in CPU.processor_features[ + "ksegs" + ].items(): kseg_addr, kseg_size = kseg_addr_size - 
vm.add_memory_page(region_def["start"] + kseg_addr, PAGE_READ | PAGE_WRITE | PAGE_EXEC, - region_def["fd"].read(kseg_size), kseg_name) + vm.add_memory_page( + region_def["start"] + kseg_addr, + PAGE_READ | PAGE_WRITE | PAGE_EXEC, + region_def["fd"].read(kseg_size), + kseg_name, + ) region_def["fd"].seek(0) break @@ -565,14 +680,16 @@ class MIPS32(MMU): class MMUShell(MMUShellDefault): - def __init__(self, completekey='tab', stdin=None, stdout=None, machine={}): + def __init__(self, completekey="tab", stdin=None, stdout=None, machine={}): super(MMUShell, self).__init__(completekey, stdin, stdout, machine) if not self.data: - self.data = Data(is_mem_parsed = False, - is_registers_found = False, - opcodes = {}, - regs_values = {}) + self.data = Data( + is_mem_parsed=False, + is_registers_found=False, + opcodes={}, + regs_values={}, + ) def do_parse_memory(self, args): """Find MMU related opcodes in dump""" @@ -585,7 +702,11 @@ def do_parse_memory(self, args): def parse_memory(self): logger.info("Look for opcodes related to MMU setup...") - parallel_results = self.machine.apply_parallel(self.machine.mmu.PAGE_SIZE, self.machine.cpu.parse_opcodes_parallel, max_address=self.machine.cpu.processor_features["kern_code_phys_end"]) + parallel_results = self.machine.apply_parallel( + self.machine.mmu.PAGE_SIZE, + self.machine.cpu.parse_opcodes_parallel, + max_address=self.machine.cpu.processor_features["kern_code_phys_end"], + ) opcodes = {} logger.info("Reaggregate threads data...") @@ -612,7 +733,9 @@ def do_find_registers_values(self, arg): logging.info("Identify register values using data flow analysis...") # We use data flow analysis and merge the results - dataflow_values = self.machine.cpu.find_registers_values_dataflow(self.data.opcodes, zero_registers=["ZERO"]) + dataflow_values = self.machine.cpu.find_registers_values_dataflow( + self.data.opcodes, zero_registers=["ZERO"] + ) filtered_values = defaultdict(set) for register, values in dataflow_values.items(): @@ -623,11 +746,18 @@ def do_find_registers_values(self, arg): # Add default values for register, value in self.machine.cpu.registers_values.items(): - if register not in self.machine.cpu.processor_features["opcode_to_mmu_regs"].values(): + if ( + register + not in self.machine.cpu.processor_features[ + "opcode_to_mmu_regs" + ].values() + ): continue reg_obj = CPURegMIPS.get_register_obj(register, value) - if reg_obj.valid and all([not reg_obj.is_mmu_equivalent_to(x) for x in filtered_values[register]]): + if reg_obj.valid and all( + [not reg_obj.is_mmu_equivalent_to(x) for x in filtered_values[register]] + ): filtered_values[register].add(reg_obj) self.data.regs_values = filtered_values @@ -656,11 +786,17 @@ def do_show_registers_gtruth(self, args): gvalues = {} for reg_name in mmu_regs: if reg_name in self.gtruth: - last_reg_value = sorted(self.gtruth[reg_name].keys(), key=lambda x: self.gtruth[reg_name][x][1])[-1] - gvalues[reg_name] = CPURegMIPS.get_register_obj(reg_name, last_reg_value) + last_reg_value = sorted( + self.gtruth[reg_name].keys(), + key=lambda x: self.gtruth[reg_name][x][1], + )[-1] + gvalues[reg_name] = CPURegMIPS.get_register_obj( + reg_name, last_reg_value + ) elif reg_name in self.machine.cpu.registers_values: - gvalues[reg_name] = CPURegMIPS.get_register_obj(reg_name, self.machine.cpu.registers_values[reg_name]) - + gvalues[reg_name] = CPURegMIPS.get_register_obj( + reg_name, self.machine.cpu.registers_values[reg_name] + ) tps = defaultdict(list) fps = defaultdict(list) @@ -674,7 +810,6 @@ def 
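Editor's note: `get_miasm_vmmngr()` above registers the same physical region twice because Kseg0 (cached) and Kseg1 (uncached) are fixed, unmapped 512 MB windows onto low physical memory on MIPS32, so one physical address has two kernel-mode virtual aliases. A sketch:

```python
# Each kseg is (virtual base, window size), as in processor_features["ksegs"].
KSEGS = {"Kseg0": (0x80000000, 0x20000000), "Kseg1": (0xA0000000, 0x20000000)}

def kseg_aliases(phys_addr):
    # A physical address below 512 MB is visible at base + phys_addr
    # in every kseg window.
    return [base + phys_addr for base, size in KSEGS.values() if phys_addr < size]

assert kseg_aliases(0x1FC00) == [0x8001FC00, 0xA001FC00]
```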
do_show_registers_gtruth(self, args): tmp_fps = [] tmp_fps_count = 0 for found_value in self.data.regs_values[register]: - if register_obj.is_mmu_equivalent_to(found_value): # Count only one TP per register if register not in tps: @@ -682,18 +817,19 @@ def do_show_registers_gtruth(self, args): tps[register].append(found_value) else: # Count only FP not equivalent among them - if all([not found_value.is_mmu_equivalent_to(x) for x in fps[register]]): + if all( + [not found_value.is_mmu_equivalent_to(x) for x in fps[register]] + ): tmp_fps_count += 1 tmp_fps.append(found_value) # Add false negatives if register not in tps: fns[register] = register_obj - else: # Add false positives only if it is not a false negative + else: # Add false positives only if it is not a false negative fps[register] = tmp_fps fps_count += tmp_fps_count - print("\nTrue positives") pprint(tps) diff --git a/architectures/ppc.py b/mmushell/architectures/ppc.py similarity index 74% rename from architectures/ppc.py rename to mmushell/architectures/ppc.py index be635bd..bff3c0f 100644 --- a/architectures/ppc.py +++ b/mmushell/architectures/ppc.py @@ -1,28 +1,32 @@ +import logging +import portion + from architectures.generic import Machine as MachineDefault from architectures.generic import CPU as CPUDefault from architectures.generic import PhysicalMemory as PhysicalMemoryDefault from architectures.generic import MMUShell as MMUShellDefault from architectures.generic import MMU as MMUDefault from architectures.generic import CPUReg -import logging -from prettytable import PrettyTable -from dataclasses import dataclass -from tqdm import tqdm -from struct import unpack, iter_unpack -from collections import defaultdict + from miasm.analysis.machine import Machine as MIASMMachine from miasm.core.bin_stream import bin_stream_vm from miasm.core.locationdb import LocationDB -from copy import deepcopy + +from prettytable import PrettyTable +from dataclasses import dataclass +from collections import defaultdict +from struct import unpack, iter_unpack +from random import uniform from pprint import pprint +from tqdm import tqdm +from copy import deepcopy from time import sleep -from random import uniform from copy import deepcopy, copy from math import log2 -import portion logger = logging.getLogger(__name__) + @dataclass class Data: is_mem_parsed: bool @@ -35,7 +39,6 @@ class Data: class CPURegPPC(CPUReg): @classmethod def get_register_obj(cls, reg_name, value): - # It exists multiple BAT registers if "BAT" in reg_name: if "U" in reg_name: @@ -91,7 +94,10 @@ def __hash__(self): class BATL(CPURegPPC): def is_valid(self, value): - return CPU.extract_bits(value, 15, 10) == 0x0 and CPU.extract_bits(value, 29, 1) == 0 + return ( + CPU.extract_bits(value, 15, 10) == 0x0 + and CPU.extract_bits(value, 29, 1) == 0 + ) def __init__(self, value, name): self.bat_name = name @@ -112,9 +118,19 @@ def __eq__(self, other): def __hash__(self): return hash((self.value, self.bat_name)) + class PTE32: entry_name = "PTE32" - labels = ["Address:", "VSID:", "RPN:", "API:" "Secondary hash:", "Referenced:", "Changed:", "WIMG:", "PP:"] + labels = [ + "Address:", + "VSID:", + "RPN:", + "API:" "Secondary hash:", + "Referenced:", + "Changed:", + "WIMG:", + "PP:", + ] size = 4 addr_fmt = "0x{:08x}" @@ -134,19 +150,22 @@ def __hash__(self): def __repr__(self): e_resume = self.entry_resume_stringified() - return str([self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))]) + return str( + [self.labels[i] + " " + str(e_resume[i]) for i in 
range(len(self.labels))] + ) def entry_resume(self): - return [self.address, - hex(self.vsid), - hex(self.rpn), - hex(self.api), - bool(self.h), - bool(self.r), - bool(self.c), - bin(self.wimg), - hex(self.pp) - ] + return [ + self.address, + hex(self.vsid), + hex(self.rpn), + hex(self.api), + bool(self.h), + bool(self.r), + bool(self.c), + bin(self.wimg), + hex(self.pp), + ] def entry_resume_stringified(self): res = self.entry_resume() @@ -162,7 +181,17 @@ def __init__(self, address, size, ptegs): self.size = size self.ptegs = ptegs - table_fields = ["Entry address", "VSID", "RPN", "API", "Secondary hash","Referenced", "Changed", "WIMG", "PP"] + table_fields = [ + "Entry address", + "VSID", + "RPN", + "API", + "Secondary hash", + "Referenced", + "Changed", + "WIMG", + "PP", + ] addr_fmt = "0x{:08x}" def __repr__(self): @@ -175,13 +204,14 @@ def __repr__(self): entry_resume[0] = self.addr_fmt.format(entry_resume[0]) table.add_row(entry_resume) - table.sortby="Entry address" + table.sortby = "Entry address" return str(table) class PhysicalMemory(PhysicalMemoryDefault): pass + class CPU(CPUDefault): @classmethod def from_cpu_config(cls, cpu_config, **kwargs): @@ -209,35 +239,72 @@ class CPUPPC32(CPU): def __init__(self, features): super(CPUPPC32, self).__init__(features) self.processor_features["opcode_to_mmu_regs"] = { - 0: "SR0", 1: "SR1", 2: "SR2", 3: "SR3", 4: "SR4", 5: "SR5", 6: "SR6", - 7: "SR7", 8: "SR8", 9: "SR9", 10: "SR10", 11: "SR11", 12: "SR12",13: "SR13", 14: "SR14", 15: "SR15", - 25: "SDR1", 528: "IBAT0U", 529: "IBAT0L", 530: "IBAT1U", 531: "IBAT1L", 532: "IBAT2U", 533: "IBAT2L", - 534: "IBAT3U", 535: "IBAT3L", 536: "DBAT0U", 537: "DBAT0L", 538: "DBAT1U", 539: "DBAT1L", 540: "DBAT2U", - 541: "DBAT2L", 542: "DBAT3U", 543: "DBAT3L", + 0: "SR0", + 1: "SR1", + 2: "SR2", + 3: "SR3", + 4: "SR4", + 5: "SR5", + 6: "SR6", + 7: "SR7", + 8: "SR8", + 9: "SR9", + 10: "SR10", + 11: "SR11", + 12: "SR12", + 13: "SR13", + 14: "SR14", + 15: "SR15", + 25: "SDR1", + 528: "IBAT0U", + 529: "IBAT0L", + 530: "IBAT1U", + 531: "IBAT1L", + 532: "IBAT2U", + 533: "IBAT2L", + 534: "IBAT3U", + 535: "IBAT3L", + 536: "DBAT0U", + 537: "DBAT0L", + 538: "DBAT1U", + 539: "DBAT1L", + 540: "DBAT2U", + 541: "DBAT2L", + 542: "DBAT3U", + 543: "DBAT3L", } - self.processor_features["opcode_to_gregs"] = ["R{}".format(str(i)) for i in range(32)] + self.processor_features["opcode_to_gregs"] = [ + "R{}".format(str(i)) for i in range(32) + ] CPU.processor_features = self.processor_features CPU.registers_values = self.registers_values def parse_opcode(self, instr, page_addr, offset): # Exclude all possible instructions which are not compatible with MTSPR, MTSR, MTSRIN - if CPUPPC32.extract_bits(instr, 31, 1) != 0 or CPUPPC32.extract_bits(instr, 0, 6) != 31: + if ( + CPUPPC32.extract_bits(instr, 31, 1) != 0 + or CPUPPC32.extract_bits(instr, 0, 6) != 31 + ): return {} # Look for MTSPR (SDR1 and BATs) if CPUPPC32.extract_bits(instr, 21, 10) == 467: - spr = (CPUPPC32.extract_bits(instr, 16, 5) << 5) + CPUPPC32.extract_bits(instr, 11, 5) + spr = (CPUPPC32.extract_bits(instr, 16, 5) << 5) + CPUPPC32.extract_bits( + instr, 11, 5 + ) if spr == 25 or 528 <= spr <= 543: gr = CPUPPC32.extract_bits(instr, 6, 5) addr = page_addr + offset - return {addr: {"register": self.processor_features["opcode_to_mmu_regs"][spr], - "gpr": [self.processor_features["opcode_to_gregs"][gr]], - "f_addr": -1, - "f_parents": set(), - "instruction": "MTSPR" - } - } + return { + addr: { + "register": self.processor_features["opcode_to_mmu_regs"][spr], + "gpr": 
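Editor's note: in `parse_opcode()` above, the 10-bit SPR number of an MTSPR instruction is encoded as two swapped 5-bit halves, so the parser reassembles it before the SDR1/BAT range check. A small sketch of that reassembly with two known SPR numbers:

```python
def spr_number(hi5: int, lo5: int) -> int:
    # Mirrors: spr = (extract_bits(instr, 16, 5) << 5) + extract_bits(instr, 11, 5)
    return (hi5 << 5) | lo5

assert spr_number(0, 25) == 25    # SDR1
assert spr_number(16, 16) == 528  # IBAT0U = 0b10000_10000
```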
[self.processor_features["opcode_to_gregs"][gr]], + "f_addr": -1, + "f_parents": set(), + "instruction": "MTSPR", + } + } return {} def identify_functions_start(self, addreses): @@ -248,7 +315,7 @@ def identify_functions_start(self, addreses): mdis.dontdis_retcall = False instr_len = self.processor_features["instr_len"] - logger = logging.getLogger('asmblock') + logger = logging.getLogger("asmblock") logger.disabled = True for addr in tqdm(addreses): @@ -259,16 +326,28 @@ def identify_functions_start(self, addreses): # Maximum 10000 instructions instructions = 0 while True and instructions <= 10000: - # Stop if found an invalid instruction try: asmcode = mdis.dis_instr(cur_addr) # RET: BLR, BLRL, BCTRL # JMP: B, BA, BCTR - if asmcode.name in ["BA", "BCTR", "BLR", "BLRL", "BCTRL", "BL", - "BLA", "BCA", "BCL", "BCLA", "BCLR", - "BCLRL", "BCCTR", "BCCTRL"]: + if asmcode.name in [ + "BA", + "BCTR", + "BLR", + "BLRL", + "BCTRL", + "BL", + "BLA", + "BCA", + "BCL", + "BCLA", + "BCLR", + "BCLRL", + "BCCTR", + "BCCTRL", + ]: cur_addr += instr_len break @@ -288,14 +367,19 @@ def identify_functions_start(self, addreses): addreses[addr]["f_addr"] = cur_addr del vm + class Machine(MachineDefault): def get_miasm_machine(self): - mn_s = "ppc" + str(self.cpu.bits) + ("b" if self.cpu.endianness == "big" else "l") + mn_s = ( + "ppc" + str(self.cpu.bits) + ("b" if self.cpu.endianness == "big" else "l") + ) return MIASMMachine(mn_s) + class MMU(MMUDefault): pass + class PPC32(MMU): PAGE_SIZE = 4096 HTABLE_MIN_BIT_SIZE = 16 @@ -313,7 +397,7 @@ def __init__(self, mmu_config): def parse_htable_opcodes_parallel(self, addresses, frame_size, pidx, **kwargs): # Parse hash tables fragments and opcodes at the same time - sleep(uniform(pidx, pidx+1) // 1000) + sleep(uniform(pidx, pidx + 1) // 1000) opcodes = {} mm = copy(self.machine.memory) @@ -334,16 +418,24 @@ def parse_htable_opcodes_parallel(self, addresses, frame_size, pidx, **kwargs): fragments.append(frame_obj) # Parse opcodes - for idx, opcode in enumerate(iter_unpack(self.machine.cpu.processor_features["opcode_unpack_fmt"], frame_buf)): + for idx, opcode in enumerate( + iter_unpack( + self.machine.cpu.processor_features["opcode_unpack_fmt"], frame_buf + ) + ): opcode = opcode[0] offset = idx * instr_len - opcodes.update(self.machine.cpu.parse_opcode(opcode, frame_addr, offset)) + opcodes.update( + self.machine.cpu.parse_opcode(opcode, frame_addr, offset) + ) return fragments, opcodes def collect_htable_framents_opcodes(self): logger.info("Look for hash tables fragments and opcodes...") - parallel_results = self.machine.apply_parallel(self.machine.mmu.HTABLE_MIN_SIZE, self.parse_htable_opcodes_parallel) + parallel_results = self.machine.apply_parallel( + self.machine.mmu.HTABLE_MIN_SIZE, self.parse_htable_opcodes_parallel + ) opcodes = {} htables = defaultdict(list) @@ -369,16 +461,17 @@ def parse_hash_table(self, frame_buf, frame_size, frame_addr): return None # No duplicates allowed in a PTEG - pteg_addr = (frame_addr + entry_idx * 8) - ((frame_addr + entry_idx * 8) % 64) + pteg_addr = (frame_addr + entry_idx * 8) - ( + (frame_addr + entry_idx * 8) % 64 + ) if pteg_addr in ptegs and entry_obj in ptegs[pteg_addr]: - return None + return None ptegs[pteg_addr].add(entry_obj) return HashTable(frame_addr, frame_size, ptegs) def classify_htable_entry(self, entry, entry_addr): - # If BIT 0 Word 0 = 0 is EMPTY if not PPC32.extract_bits(entry[0], 0, 1): return False @@ -420,8 +513,9 @@ def glue_htable_fragments(self, fragments): low_frames = htables[1 << 
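Editor's note: the PTEG bookkeeping in `parse_hash_table()` relies on the PowerPC 32-bit hash table geometry: a PTEG is a 64-byte group of eight 8-byte PTEs, so an entry's group address is its own address rounded down to a 64-byte boundary. A sketch of that alignment:

```python
PTE_SIZE, PTEG_SIZE = 8, 64

def pteg_address(frame_addr: int, entry_idx: int) -> int:
    # Round the entry address down to the start of its 64-byte PTEG,
    # as parse_hash_table() does before checking for duplicates.
    entry_addr = frame_addr + entry_idx * PTE_SIZE
    return entry_addr - (entry_addr % PTEG_SIZE)

assert pteg_address(0x10000, 0) == 0x10000
assert pteg_address(0x10000, 7) == 0x10000   # entries 0-7 share one PTEG
assert pteg_address(0x10000, 8) == 0x10040   # entry 8 starts the next group
```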
self.HTABLE_MIN_BIT_SIZE] # Starting from fragments with lower size, check if it is possible to form bigger ones aggregating two halves - for i in tqdm(range(self.HTABLE_MIN_BIT_SIZE + 1, self.HTABLE_MAX_BIT_SIZE + 1)): - + for i in tqdm( + range(self.HTABLE_MIN_BIT_SIZE + 1, self.HTABLE_MAX_BIT_SIZE + 1) + ): htable_size = 1 << i low_frames_size = 1 << (i - 1) @@ -430,15 +524,20 @@ def glue_htable_fragments(self, fragments): continue # Check if the other half is present and aggregate them - if htables[low_frames_size][htable_idx + 1].address == htable.address + low_frames_size: + if ( + htables[low_frames_size][htable_idx + 1].address + == htable.address + low_frames_size + ): nxt_htable = htables[low_frames_size][htable_idx + 1] pteg_c = deepcopy(htable.ptegs) pteg_c.update(nxt_htable.ptegs) - htables[htable_size].append(HashTable(address=htable.address, - size=htable_size, - ptegs=pteg_c)) + htables[htable_size].append( + HashTable( + address=htable.address, size=htable_size, ptegs=pteg_c + ) + ) htables[htable_size].sort(key=lambda x: x.address) if htables[htable_size]: @@ -458,14 +557,13 @@ def glue_htable_fragments(self, fragments): def filter_htables(self, htables): logging.info("Filtering...") - final_candidates = defaultdict(list) #deepcopy(htables) + final_candidates = defaultdict(list) # deepcopy(htables) already_visited = portion.empty() entropies = [] # Start from table of big size for table_size in reversed(list(htables.keys())): for table_obj in tqdm(htables[table_size]): - # If a valid bigger table contains the little one remove the little one if table_obj.address in already_visited: continue @@ -486,7 +584,9 @@ def filter_htables(self, htables): total_rpn += 1 # If the hash validation fails for some PTE discard the table - if not self.validate_entry_by_hash(pte, table_obj.address, table_obj.size, pte.address): + if not self.validate_entry_by_hash( + pte, table_obj.address, table_obj.size, pte.address + ): raise UserWarning except UserWarning: continue @@ -504,10 +604,12 @@ def filter_htables(self, htables): # Calculate RPN entropy starting from RPN probabilities for the table table_entropy = 0 for rpn_count in rpn_probabilities.values(): - table_entropy -= rpn_count/total_rpn * log2(rpn_count/total_rpn) + table_entropy -= rpn_count / total_rpn * log2(rpn_count / total_rpn) entropies.append([table_size, table_obj, table_entropy]) - already_visited |= (portion.closedopen(table_obj.address, table_obj.address + table_obj.size)) + already_visited |= portion.closedopen( + table_obj.address, table_obj.address + table_obj.size + ) final_candidates[table_size].append(table_obj) # HEURISTIC: Filter for RPN entropy: cut-off at 80% of the maximum entropy, @@ -520,8 +622,9 @@ def filter_htables(self, htables): return final_candidates - def validate_entry_by_hash(self, entry_obj, htable_addr, htable_size, pteg_addr_entry): - + def validate_entry_by_hash( + self, entry_obj, htable_addr, htable_size, pteg_addr_entry + ): # We have only a part of the page index (9 bit) vsid_reduced = CPUPPC32.extract_bits_little(entry_obj.vsid, 10, 9) @@ -545,16 +648,17 @@ def validate_entry_by_hash(self, entry_obj, htable_addr, htable_size, pteg_addr_ class MMUShell(MMUShellDefault): - def __init__(self, completekey='tab', stdin=None, stdout=None, machine={}): + def __init__(self, completekey="tab", stdin=None, stdout=None, machine={}): super(MMUShell, self).__init__(completekey, stdin, stdout, machine) if not self.data: - self.data = Data(is_mem_parsed = False, - is_registers_found = False, - opcodes = 
{},
-                              regs_values = {},
-                              htables = {}
-                              )
+            self.data = Data(
+                is_mem_parsed=False,
+                is_registers_found=False,
+                opcodes={},
+                regs_values={},
+                htables={},
+            )
 
     def do_parse_memory(self, args):
         """Parse memory to find MMU-related opcodes and hash tables"""
@@ -567,7 +671,10 @@ def do_parse_memory(self, args):
 
     def parse_memory(self):
         # Collect opcodes and hash table fragments of the minimum size
-        fragments, self.data.opcodes = self.machine.mmu.collect_htable_framents_opcodes()
+        (
+            fragments,
+            self.data.opcodes,
+        ) = self.machine.mmu.collect_htable_framents_opcodes()
 
         # Glue hash table fragments together
         htables = self.machine.mmu.glue_htable_fragments(fragments)
@@ -608,7 +715,9 @@ def do_find_registers_values(self, arg):
         logging.info("Identify register values using data flow analysis...")
 
         # We use data flow analysis and merge the results
-        dataflow_values = self.machine.cpu.find_registers_values_dataflow(self.data.opcodes, zero_registers=["ZERO"])
+        dataflow_values = self.machine.cpu.find_registers_values_dataflow(
+            self.data.opcodes, zero_registers=["ZERO"]
+        )
 
         filtered_values = defaultdict(set)
         for register, values in dataflow_values.items():
@@ -619,11 +728,18 @@ def do_find_registers_values(self, arg):
 
         # Add default values
         for register, value in self.machine.cpu.registers_values.items():
-            if register not in self.machine.cpu.processor_features["opcode_to_mmu_regs"].values():
+            if (
+                register
+                not in self.machine.cpu.processor_features[
+                    "opcode_to_mmu_regs"
+                ].values()
+            ):
                 continue
 
             reg_obj = CPURegPPC.get_register_obj(register, value)
-            if reg_obj.valid and all([not reg_obj.is_mmu_equivalent_to(x) for x in filtered_values[register]]):
+            if reg_obj.valid and all(
+                [not reg_obj.is_mmu_equivalent_to(x) for x in filtered_values[register]]
+            ):
                 filtered_values[register].add(reg_obj)
 
         self.data.regs_values = filtered_values
@@ -640,8 +756,8 @@ def do_show_registers(self, args):
             print(register)
 
     def do_show_hashtable(self, arg):
-        'Parse a Hash Table of a given size'
-        'Usage: show_hashtable ADDRESS size'
+        """Parse a hash table of a given size.
+        Usage: show_hashtable ADDRESS SIZE"""
 
         arg = arg.split()
         if len(arg) < 2:
@@ -663,7 +779,18 @@ def do_show_hashtable(self, arg):
             logging.error("Address not in memory address space")
             return
 
-        valid_sizes = [65536, 131072, 262144, 524288, 1048576, 2097152, 4194304, 8388608, 16777216, 33554432]
+        valid_sizes = [
+            65536,
+            131072,
+            262144,
+            524288,
+            1048576,
+            2097152,
+            4194304,
+            8388608,
+            16777216,
+            33554432,
+        ]
        if size not in valid_sizes:
            logging.error(f"Invalid table size. 
Valid sizes are {valid_sizes}") return @@ -675,8 +802,8 @@ def do_show_hashtable(self, arg): else: print(table) -class MMUShellGTruth(MMUShell): +class MMUShellGTruth(MMUShell): def do_show_hashtables_gtruth(self, args): """Show hash tables found and compare them with the ground truth""" if not self.data.is_mem_parsed: @@ -684,7 +811,14 @@ def do_show_hashtables_gtruth(self, args): return table = PrettyTable() - table.field_names = ["Address", "Size", "Found", "Correct size", "First seen", "Last seen"] + table.field_names = [ + "Address", + "Size", + "Found", + "Correct size", + "First seen", + "Last seen", + ] # Collect valid true table valids = {} @@ -692,17 +826,42 @@ def do_show_hashtables_gtruth(self, args): sdr1_obj = SDR1(sdr1_value) if not sdr1_obj.valid: continue - valids[sdr1_obj.address] = [sdr1_obj.size, sdr1_data["first_seen"], sdr1_data["last_seen"]] + valids[sdr1_obj.address] = [ + sdr1_obj.size, + sdr1_data["first_seen"], + sdr1_data["last_seen"], + ] # MMUShell found values found = {} for size in self.data.htables: for table_obj in self.data.htables[size]: - found[table_obj.address] = [table_obj.size, "False positive", "False positive"] + found[table_obj.address] = [ + table_obj.size, + "False positive", + "False positive", + ] already_visited = set() for k, v in valids.items(): - table.add_row([hex(k), hex(v[0]), "X" if k in found else "", "X" if v[0] == found.get(k, [None,])[0] else "", v[1], v[2]]) + table.add_row( + [ + hex(k), + hex(v[0]), + "X" if k in found else "", + "X" + if v[0] + == found.get( + k, + [ + None, + ], + )[0] + else "", + v[1], + v[2], + ] + ) already_visited.add((k, v[0])) fps = 0 @@ -715,7 +874,6 @@ def do_show_hashtables_gtruth(self, args): print(table) print(f"FP: {fps}") - def do_show_registers_gtruth(self, args): """Show registers value retrieved and compare with the ground truth""" if not self.data.is_registers_found: @@ -723,19 +881,41 @@ def do_show_registers_gtruth(self, args): return # Check if the last value of SDR1 was found - last_sdr1 = SDR1(sorted(self.gtruth["SDR1"].keys(), key=lambda x: self.gtruth["SDR1"][x]["last_seen"], reverse=True)[0]) + last_sdr1 = SDR1( + sorted( + self.gtruth["SDR1"].keys(), + key=lambda x: self.gtruth["SDR1"][x]["last_seen"], + reverse=True, + )[0] + ) print(f"Correct SDR1 value: {last_sdr1}") - print("SDR1 correct value... {}FOUND".format("" if last_sdr1 in self.data.regs_values["SDR1"] else "NOT ")) + print( + "SDR1 correct value... {}FOUND".format( + "" if last_sdr1 in self.data.regs_values["SDR1"] else "NOT " + ) + ) # Found last BAT registers used by the system bats_found = {} for t in ["I", "D"]: for i in range(4): reg_name = t + "BAT" + str(i) - batu_v, batl_v = sorted(self.gtruth[reg_name].keys(), key=lambda x: self.gtruth[reg_name][x][1], reverse=True)[0] + batu_v, batl_v = sorted( + self.gtruth[reg_name].keys(), + key=lambda x: self.gtruth[reg_name][x][1], + reverse=True, + )[0] bats_found[reg_name + "U"] = BATU(batu_v, reg_name + "U") bats_found[reg_name + "L"] = BATL(batl_v, reg_name + "L") # Check if values are found for reg_name in bats_found: - print("{} correct value... {}FOUND\t\t{}".format(reg_name, "" if bats_found[reg_name] in self.data.regs_values[reg_name] else "NOT ", bats_found[reg_name])) + print( + "{} correct value... 
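Editor's note: for reference, the RPN-entropy cut-off applied by `filter_htables()` earlier is plain Shannon entropy over the histogram of physical page numbers mapped by a candidate table; real hash tables spread mappings over many frames, while false positives repeat a few. A toy sketch with made-up RPN lists:

```python
from collections import Counter
from math import log2

def rpn_entropy(rpns):
    # Shannon entropy (in bits) of the RPN distribution, mirroring the
    # per-table computation in filter_htables().
    counts = Counter(rpns)
    total = len(rpns)
    return -sum(c / total * log2(c / total) for c in counts.values())

assert rpn_entropy([0x100] * 8) == 0.0  # degenerate: one frame repeated
assert rpn_entropy(range(8)) == 3.0     # 8 distinct frames -> log2(8) bits
```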
{}FOUND\t\t{}".format( + reg_name, + "" + if bats_found[reg_name] in self.data.regs_values[reg_name] + else "NOT ", + bats_found[reg_name], + ) + ) diff --git a/architectures/riscv.py b/mmushell/architectures/riscv.py similarity index 69% rename from architectures/riscv.py rename to mmushell/architectures/riscv.py index d58e64d..203aab4 100644 --- a/architectures/riscv.py +++ b/mmushell/architectures/riscv.py @@ -1,25 +1,27 @@ +import logging +import portion +import multiprocessing as mp + from architectures.generic import Machine as MachineDefault from architectures.generic import CPU as CPUDefault from architectures.generic import PhysicalMemory as PhysicalMemoryDefault from architectures.generic import MMUShell as MMUShellDefault from architectures.generic import TableEntry, PageTable, MMURadix, PAS, RadixTree -import logging + +from more_itertools import divide from collections import defaultdict, deque from prettytable import PrettyTable +from dataclasses import dataclass +from IPython import embed +from random import uniform +from struct import iter_unpack, unpack from time import sleep from tqdm import tqdm from copy import deepcopy, copy -from random import uniform -from struct import iter_unpack, unpack -from dataclasses import dataclass -import multiprocessing as mp -# import cProfile -import portion -from more_itertools import divide -from IPython import embed logger = logging.getLogger(__name__) + @dataclass class Data: is_mem_parsed: bool @@ -31,6 +33,7 @@ class Data: reverse_map_pages: list satps: dict + class SATP: def __init__(self, satp): self.satp = satp @@ -41,6 +44,7 @@ def __init__(self, satp): def __repr__(self): print(f"Mode:{self.mode}, ASID:{self.asid}, Address: {hex(self.address)}") + class Machine(MachineDefault): def __init__(self, cpu, mmu, memory, **kwargs): super(Machine, self).__init__(cpu, mmu, memory, **kwargs) @@ -69,18 +73,28 @@ class CPU32(CPU): class CPU64(CPU): - pass + pass + ##################################################################### # 32 bit entries and page table ##################################################################### + class TEntry32(TableEntry): entry_size = 4 entry_name = "TEntry32" size = 0 - labels = ["Address:", "Dirty:", "Accessed:", "Global:", - "User:", "Readable:", "Writable:", "Exectuable:"] + labels = [ + "Address:", + "Dirty:", + "Accessed:", + "Global:", + "User:", + "Readable:", + "Writable:", + "Exectuable:", + ] addr_fmt = "0x{:08x}" def __hash__(self): @@ -88,18 +102,21 @@ def __hash__(self): def __repr__(self): e_resume = self.entry_resume_stringified() - return str([self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))]) + return str( + [self.labels[i] + " " + str(e_resume[i]) for i in range(len(self.labels))] + ) def entry_resume(self): - return [self.address, - self.is_dirty_entry(), - self.is_accessed_entry(), - self.is_global_entry(), - not self.is_supervisor_entry(), - self.is_readable_entry(), - self.is_writeble_entry(), - self.is_executable_entry() - ] + return [ + self.address, + self.is_dirty_entry(), + self.is_accessed_entry(), + self.is_global_entry(), + not self.is_supervisor_entry(), + self.is_readable_entry(), + self.is_writeble_entry(), + self.is_executable_entry(), + ] def entry_resume_stringified(self): res = self.entry_resume() @@ -134,31 +151,50 @@ def extract_addr(entry): return MMU.extract_bits(entry, 10, 21) << 12 def get_permissions(self): - perms = (self.is_readable_entry(), self.is_writeble_entry(), self.is_executable_entry()) + perms = ( + 
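Editor's note: the `SATP` wrapper above carries the three fields of the RISC-V address-translation CSR. A minimal decoding sketch, assuming the RV64 layout (MODE in bits 60-63, ASID in bits 44-59, root-table PPN in bits 0-43; the RV32 register packs the same fields differently):

```python
def decode_satp64(satp: int):
    mode = (satp >> 60) & 0xF                   # 8 = SV39, 9 = SV48
    asid = (satp >> 44) & 0xFFFF
    address = (satp & ((1 << 44) - 1)) << 12    # PPN -> physical address
    return mode, asid, address

# SV39 root table at physical address 0x80010000, ASID 0
assert decode_satp64((8 << 60) | 0x80010) == (8, 0, 0x80010000)
```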
self.is_readable_entry(), + self.is_writeble_entry(), + self.is_executable_entry(), + ) if self.is_supervisor_entry(): return perms + (False, False, False) else: return (False, False, False) + perms + class PTE4KB32(TEntry32): entry_name = "PTE4KB32" size = 1024 * 4 + class PTE4MB(TEntry32): entry_name = "PTE4MB" size = 1024 * 1024 * 4 + class PTP32(TEntry32): entry_name = "PTP32" size = 0 + class PTP32L0(PTP32): entry_name = "PTP32L0" + class PageTableSV32(PageTable): entry_size = 4 - table_fields = ["Entry address", "Pointed address", "Dirty", "Accessed", "Global", - "User", "Readable", "Writable", "Exectuable", "Class"] + table_fields = [ + "Entry address", + "Pointed address", + "Dirty", + "Accessed", + "Global", + "User", + "Readable", + "Writable", + "Exectuable", + "Class", + ] addr_fmt = "0x{:08x}" def __repr__(self): @@ -168,15 +204,21 @@ def __repr__(self): for entry_class in self.entries: for entry_idx, entry_obj in self.entries[entry_class].items(): entry_addr = self.address + (entry_idx * self.entry_size) - table.add_row([self.addr_fmt.format(entry_addr)] + entry_obj.entry_resume_stringified() + [entry_class.entry_name]) + table.add_row( + [self.addr_fmt.format(entry_addr)] + + entry_obj.entry_resume_stringified() + + [entry_class.entry_name] + ) - table.sortby="Entry address" + table.sortby = "Entry address" return str(table) + ##################################################################### # 64 bit entries and page table ##################################################################### + class TEntry64(TEntry32): entry_size = 8 entry_name = "TEntry64" @@ -192,43 +234,54 @@ class PTE4KB64(TEntry64): entry_name = "PTE4KB64" size = 1024 * 4 + class PTE2MB(TEntry64): entry_name = "PTE2MB" size = 1024 * 1024 * 2 + class PTE1GB(TEntry64): entry_name = "PTE1GB" size = 1024 * 1024 * 1024 + class PTE512GB(TEntry64): entry_name = "PTE512GB" size = 1024 * 1024 * 1024 * 512 + class PTP64(TEntry64): entry_name = "PTP64" size = 0 + class PTP64L0(PTP64): entry_name = "PTP64L0" + class PTP64L1(PTP64): entry_name = "PTP64L1" + class PTP64L2(PTP64): entry_name = "PTP64L2" + class PageTableSV39(PageTableSV32): entry_size = 8 addr_fmt = "0x{:016x}" + class PageTableSV48(PageTableSV32): entry_size = 8 addr_fmt = "0x{:016x}" + ################################################################# # MMU Modes ################################################################# + class MMU(MMURadix): PAGE_SIZE = 4096 extract_bits = MMURadix.extract_bits_little @@ -247,11 +300,7 @@ class SV32(MMU): map_ptr_entries_to_levels = {"global": [PTP32L0, None]} map_datapages_entries_to_levels = {"global": [[PTE4MB], [PTE4KB32]]} map_level_to_table_size = {"global": [4096, 4096]} - map_entries_to_shifts = {"global": { - PTP32L0: 22, - PTE4MB: 22, - PTE4KB32: 12 - }} + map_entries_to_shifts = {"global": {PTP32L0: 22, PTE4MB: 22, PTE4KB32: 12}} map_reserved_entries_to_levels = {"global": [[], []]} def return_not_leaf_entry(self, entry): @@ -278,11 +327,9 @@ def classify_entry(self, page_addr, entry): # If R,W,X are 0 it can be a non-leaf entry if not MMU.extract_bits(entry, 1, 3): - # Pointer entry # D,A,U, must be 0 - if MMU.extract_bits(entry, 4, 1) or \ - MMU.extract_bits(entry, 6, 2): + if MMU.extract_bits(entry, 4, 1) or MMU.extract_bits(entry, 6, 2): return [None] else: return self.return_not_leaf_entry(entry) @@ -303,27 +350,30 @@ def extend_prefix(self, prefix, entry_idx, entry_class): return prefix | (entry_idx << 12) def split_vaddr(self, vaddr): - return (((PTP32L0, MMU.extract_bits(vaddr, 22, 
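Editor's note: a toy illustration of the `classify_entry()` rules above for SV32: bits 1-3 (R, W, X) all clear mark a pointer to the next table level, and a well-formed pointer must also have U (bit 4), A (bit 6) and D (bit 7) clear. This sketch checks only those constraints, not the full classification:

```python
def bits(v, pos, n):
    return (v >> pos) & ((1 << n) - 1)

def classify_sv32(entry: int) -> str:
    if not bits(entry, 1, 3):                       # R = W = X = 0
        if bits(entry, 4, 1) or bits(entry, 6, 2):  # U, A or D set: malformed
            return "invalid"
        return "pointer"
    return "leaf"

assert classify_sv32(0b0000000001) == "pointer"  # V=1 only
assert classify_sv32(0b0011001111) == "leaf"     # V,R,W,X,A,D set
assert classify_sv32(0b0001010001) == "invalid"  # pointer with U/A bits set
```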
10)), - (PTE4KB32, MMU.extract_bits(vaddr, 12, 10)), - ("OFFSET", MMU.extract_bits(vaddr, 0, 12))), \ - ((PTE4MB, MMU.extract_bits(vaddr, 22, 10)), - ("OFFSET", MMU.extract_bits(vaddr, 0, 22)))) + return ( + ( + (PTP32L0, MMU.extract_bits(vaddr, 22, 10)), + (PTE4KB32, MMU.extract_bits(vaddr, 12, 10)), + ("OFFSET", MMU.extract_bits(vaddr, 0, 12)), + ), + ( + (PTE4MB, MMU.extract_bits(vaddr, 22, 10)), + ("OFFSET", MMU.extract_bits(vaddr, 0, 22)), + ), + ) + class SV39(SV32): paging_unpack_format = " self.machine.mmu.radix_levels["global"] - 1: raise ValueError except ValueError: - logger.warning("Level must be an integer between 0 and {}".format(str(self.machine.mmu.radix_levels["global"] - 1))) + logger.warning( + "Level must be an integer between 0 and {}".format( + str(self.machine.mmu.radix_levels["global"] - 1) + ) + ) return if lvl == -1: @@ -509,13 +600,17 @@ def do_show_table(self, args): else: table_size = self.machine.mmu.map_level_to_table_size["global"][lvl] table_buff = self.machine.memory.get_data(addr, self.machine.mmu.PAGE_SIZE) - invalids, pt_classes, table_obj = self.machine.mmu.parse_frame(table_buff, addr, table_size, lvl) + invalids, pt_classes, table_obj = self.machine.mmu.parse_frame( + table_buff, addr, table_size, lvl + ) print(table_obj) print(f"Invalid entries: {invalids} Table levels: {pt_classes}") def parse_memory(self): logger.info("Look for paging tables...") - parallel_results = self.machine.apply_parallel(self.machine.mmu.PAGE_SIZE, self.machine.mmu.parse_parallel_frame) + parallel_results = self.machine.apply_parallel( + self.machine.mmu.PAGE_SIZE, self.machine.mmu.parse_parallel_frame + ) logger.info("Reaggregate threads data...") for result in parallel_results: page_tables, data_pages, empty_tables = result.get() @@ -532,15 +627,20 @@ def parse_memory(self): logger.info("Reduce false positives...") # Remove all tables which point to inexistent table of lower level for lvl in range(self.machine.mmu.radix_levels["global"] - 1): - ptr_class = self.machine.mmu.map_ptr_entries_to_levels["global"][lvl] referenced_nxt = [] for table_addr in list(self.data.page_tables["global"][lvl].keys()): - for entry_obj in self.data.page_tables["global"][lvl][table_addr].entries[ptr_class].values(): - if entry_obj.address not in self.data.page_tables["global"][lvl + 1] and \ - entry_obj.address not in self.data.empty_tables: - + for entry_obj in ( + self.data.page_tables["global"][lvl][table_addr] + .entries[ptr_class] + .values() + ): + if ( + entry_obj.address + not in self.data.page_tables["global"][lvl + 1] + and entry_obj.address not in self.data.empty_tables + ): # Remove the table self.data.page_tables["global"][lvl].pop(table_addr) break @@ -550,18 +650,28 @@ def parse_memory(self): # Remove table not referenced by upper levels referenced_nxt = set(referenced_nxt) - for table_addr in set(self.data.page_tables["global"][lvl + 1].keys()).difference(referenced_nxt): + for table_addr in set( + self.data.page_tables["global"][lvl + 1].keys() + ).difference(referenced_nxt): self.data.page_tables["global"][lvl + 1].pop(table_addr) logger.info("Fill reverse maps...") for lvl in range(0, self.machine.mmu.radix_levels["global"]): ptr_class = self.machine.mmu.map_ptr_entries_to_levels["global"][lvl] - page_class = self.machine.mmu.map_datapages_entries_to_levels["global"][lvl][0] # Trick: it has only one dataclass per level + page_class = self.machine.mmu.map_datapages_entries_to_levels["global"][ + lvl + ][ + 0 + ] # Trick: it has only one dataclass per level for table_addr, 
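Editor's note: a worked example of the SV32 decomposition returned by `split_vaddr()` above: a 32-bit virtual address breaks into VPN1 (bits 22-31), VPN0 (bits 12-21) and a 12-bit offset for the 4 KiB case:

```python
def split_sv32(vaddr: int):
    # (VPN1, VPN0, page offset), as in the 4 KiB branch of split_vaddr().
    return (vaddr >> 22) & 0x3FF, (vaddr >> 12) & 0x3FF, vaddr & 0xFFF

vpn1, vpn0, offset = split_sv32(0xC0403ABC)
assert (vpn1, vpn0, offset) == (0x301, 0x3, 0xABC)
```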
table_obj in self.data.page_tables["global"][lvl].items(): for entry_obj in table_obj.entries[ptr_class].values(): - self.data.reverse_map_tables[lvl][entry_obj.address].add(table_obj.address) + self.data.reverse_map_tables[lvl][entry_obj.address].add( + table_obj.address + ) for entry_obj in table_obj.entries[page_class].values(): - self.data.reverse_map_pages[lvl][entry_obj.address].add(table_obj.address) + self.data.reverse_map_pages[lvl][entry_obj.address].add( + table_obj.address + ) def do_find_radix_trees(self, args): """Reconstruct radix trees""" @@ -584,17 +694,27 @@ def do_find_radix_trees(self, args): if derived_address in already_explored: continue lvl, addr = derived_address - candidates.extend(self.radix_roots_from_data_page(lvl, addr, self.data.reverse_map_pages, self.data.reverse_map_tables)) + candidates.extend( + self.radix_roots_from_data_page( + lvl, + addr, + self.data.reverse_map_pages, + self.data.reverse_map_tables, + ) + ) already_explored.add(derived_address) - candidates = list(set(candidates).intersection(self.data.page_tables["global"][0].keys())) + candidates = list( + set(candidates).intersection(self.data.page_tables["global"][0].keys()) + ) candidates.sort() logger.info("Filter candidates...") satps = {} for candidate in tqdm(candidates): - # Obtain radix tree infos - consistency, pas = self.physpace(candidate, self.data.page_tables["global"], self.data.empty_tables) + consistency, pas = self.physpace( + candidate, self.data.page_tables["global"], self.data.empty_tables + ) # Only consistent trees are valid if not consistency: @@ -616,14 +736,20 @@ def do_show_radix_trees(self, args): logging.info("Please, find them first!") return - labels = ["Radix address", "First level", "Kernel size (Bytes)", "User size (Bytes)"] + labels = [ + "Radix address", + "First level", + "Kernel size (Bytes)", + "User size (Bytes)", + ] table = PrettyTable() table.field_names = labels for satp in self.data.satps.values(): table.add_row(satp.entry_resume_stringified()) - table.sortby="Radix address" + table.sortby = "Radix address" print(table) + class MMUShellGTruth(MMUShell): def do_show_radix_trees_gtruth(self, args): """Compare found radix trees with the gound truth""" @@ -642,8 +768,14 @@ def do_show_radix_trees_gtruth(self, args): if new_satp.address not in self.data.page_tables["global"][0]: continue - consistency, pas = self.physpace(new_satp.address, self.data.page_tables["global"], self.data.empty_tables) - if not consistency or (not pas.get_kernel_size() and not pas.get_user_size()): + consistency, pas = self.physpace( + new_satp.address, + self.data.page_tables["global"], + self.data.empty_tables, + ) + if not consistency or ( + not pas.get_kernel_size() and not pas.get_user_size() + ): continue satp_tp[new_satp.address] = new_satp @@ -656,22 +788,27 @@ def do_show_radix_trees_gtruth(self, args): table = PrettyTable() table.field_names = ["Address", "Found", "First seen", "Last seen"] for tp in sorted(tps): - table.add_row([hex(tp), - "X", - self.gtruth["SATP"][satp_tp[tp].satp][0], - self.gtruth["SATP"][satp_tp[tp].satp][1]]) + table.add_row( + [ + hex(tp), + "X", + self.gtruth["SATP"][satp_tp[tp].satp][0], + self.gtruth["SATP"][satp_tp[tp].satp][1], + ] + ) for fn in sorted(fns): - table.add_row([hex(fn), - "", - self.gtruth["SATP"][satp_tp[fn].satp][0], - self.gtruth["SATP"][satp_tp[fn].satp][1]]) + table.add_row( + [ + hex(fn), + "", + self.gtruth["SATP"][satp_tp[fn].satp][0], + self.gtruth["SATP"][satp_tp[fn].satp][1], + ] + ) for fp in sorted(fps): - 
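Editor's note: a conceptual sketch of the candidate recovery behind `do_find_radix_trees()`: starting from a physical page known to be mapped, follow the reverse maps upward (page, then mapping tables, then their parents) until level 0; any level-0 table reached is a radix-root candidate. The maps below are hypothetical toys, structured loosely like `self.data.reverse_map_pages` and `self.data.reverse_map_tables`:

```python
reverse_map_pages = {2: {0x9F000: {0x7000}}}   # data page -> level-2 tables mapping it
reverse_map_tables = {1: {0x7000: {0x5000}},   # level-2 table -> level-1 parents
                      0: {0x5000: {0x1000}}}   # level-1 table -> level-0 roots

def radix_roots(page_addr):
    frontier = reverse_map_pages[2][page_addr]
    for lvl in (1, 0):                         # walk up toward the root level
        frontier = {p for t in frontier for p in reverse_map_tables[lvl].get(t, ())}
    return frontier

assert radix_roots(0x9F000) == {0x1000}
```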
table.add_row([hex(fp), - "False positive", - "", - ""]) + table.add_row([hex(fp), "False positive", "", ""]) print(table) print(f"TP:{len(tps)} FN:{len(fns)} FP:{len(fps)}") @@ -679,6 +816,7 @@ def do_show_radix_trees_gtruth(self, args): # Export results for next analysis if len(args): from pickle import dump as dump_p + with open("dump.mmu", "wb") as f: - results = [{"satp":tp} for tp in sorted(tps)] - dump_p(results, f) \ No newline at end of file + results = [{"satp": tp} for tp in sorted(tps)] + dump_p(results, f) diff --git a/mmushell/exporter.py b/mmushell/exporter.py new file mode 100755 index 0000000..c30e468 --- /dev/null +++ b/mmushell/exporter.py @@ -0,0 +1,1554 @@ +#!/usr/bin/env python3 + +import json +import argparse +import traceback +import numpy as np + +from elftools.elf.elffile import ELFFile +from elftools.elf.segments import NoteSegment +from compress_pickle import load as load_c +from collections import defaultdict +from bisect import bisect +from pickle import load +from struct import iter_unpack +from tqdm import tqdm + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("PHY_ELF", help="Dump file in ELF format", type=str) + parser.add_argument( + "MMU_DATA", + help="List of DTBs and MMU configuration registers", + type=argparse.FileType("rb"), + ) + args = parser.parse_args() + + # Load session file + try: + mmu_data = load(args.MMU_DATA) + except Exception as e: + print(f"Error: {e}") + exit(1) + + # Load ELF file + elf_dump = ELFDump(args.PHY_ELF) + + # Dump processes + for idx, process_mmu_data in enumerate(tqdm(mmu_data)): + try: + virtspace = get_virtspace(elf_dump, process_mmu_data) + virtspace.export_virtual_memory_elf(f"process.{idx}.elf") + except Exception as e: + print(f"Error during process exporting: {e}") + # print(traceback.format_exc()) + + +class IMSimple: + """Fast search in intervals (begin), (end). + + Description: + - Represents a class for efficient search in intervals. + + Attributes: + - keys: List of interval beginnings. + - values: List of interval ends. + + Methods: + - __init__(self, keys, values): Constructor method to initialize IMSimple instance. + - __getitem__(self, x): Method to get the difference between x and the nearest interval beginning. + - contains(self, x, size): Method to check if a given address and size fit within intervals. + - get_values(self): Method to get values of intervals. + - get_extremes(self): Method to get the first and last interval boundaries. + + Purpose: + - To provide fast search functionality in intervals. + """ + + def __init__(self, keys, values): + self.keys = keys + self.values = values + + def __getitem__(self, x): + idx = bisect(self.keys, x) - 1 + begin = self.keys[idx] + if begin <= x < self.values[idx]: + return x - begin + else: + return -1 + + def contains(self, x, size): + idx = bisect(self.keys, x) - 1 + begin = self.keys[idx] + end = self.values[idx] + if not (begin <= x < end) or x + size >= end: + return -1 + else: + return x - begin + + def get_values(self): + return zip(self.keys, self.values) + + def get_extremes(self): + return self.keys[0], self.values[-1] + + +class IMData: + """Fast search in intervals (begin), (end, associated data). + + Description: + - Represents a class for efficient search in intervals with associated data. + + Attributes: + - keys: List of interval beginnings. + - values: List of tuples containing interval ends and associated data. + + Methods: + - __init__(self, keys, values): Constructor method to initialize IMSimple instance. 
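Editor's note: a minimal sketch of the bisect-based lookup that `IMSimple` (and the other `IM*` classes below) build on: `keys` holds sorted interval starts, `values` the matching ends, and `bisect` finds the only interval that could contain `x`:

```python
from bisect import bisect

keys = [0x1000, 0x8000]    # interval begins, sorted
values = [0x2000, 0x9000]  # matching interval ends

def offset_in_interval(x):
    # Locate the candidate interval, then verify x really lies inside it.
    idx = bisect(keys, x) - 1
    return x - keys[idx] if keys[idx] <= x < values[idx] else -1

assert offset_in_interval(0x1800) == 0x800  # inside [0x1000, 0x2000)
assert offset_in_interval(0x3000) == -1     # in the gap between intervals
```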
+    - __getitem__(self, x): Method to get the associated data for a specific point.
+    - contains(self, x, size): Method to check if a given address and size fit within intervals.
+    - get_values(self): Method to get values of intervals.
+    - get_extremes(self): Method to get the first and last interval boundaries.
+
+    Purpose:
+    - To provide fast search functionality in intervals with associated data.
+    """
+
+    def __init__(self, keys, values):
+        self.keys = keys
+        self.values = values
+
+    def __getitem__(self, x):
+        """Return the data associated with the interval containing x.
+
+        Args:
+        - x: The value to look up.
+
+        Returns:
+        - The data associated with the interval containing x; -1 if x falls outside every interval.
+
+        Technical Explanation:
+        - The method uses binary search (bisect) to locate the only candidate interval, then verifies that x actually lies inside it before returning the associated data.
+        """
+        idx = bisect(self.keys, x) - 1
+        begin = self.keys[idx]
+        end, data = self.values[idx]
+        if begin <= x < end:
+            return data
+        else:
+            return -1
+
+    def contains(self, x, size):
+        idx = bisect(self.keys, x) - 1
+        begin = self.keys[idx]
+        end, data = self.values[idx]
+        if not (begin <= x < end) or x + size >= end:
+            return -1
+        else:
+            return data
+
+    def get_values(self):
+        return zip(self.keys, self.values)
+
+    def get_extremes(self):
+        return self.keys[0], self.values[-1][0]
+
+
+class IMOffsets:
+    """Fast search in intervals (begin), (end, associated offset).
+
+    Description:
+    - Represents a class for efficient search in intervals.
+    - Initializes with keys and values representing interval boundaries and associated offsets.
+
+    Attributes:
+    - keys: List of interval beginnings.
+    - values: List of tuples containing interval ends and associated offsets.
+
+    Methods:
+    - __init__(self, keys, values): Constructor method to initialize IMOffsets instance.
+    - __getitem__(self, x): Method to get the offset associated with a specific point.
+    - contains(self, x, size): Method to check if a given address and size fit within intervals.
+    - get_values(self): Method to get values of intervals.
+    - get_extremes(self): Method to get the first and last interval boundaries.
+
+    Purpose:
+    - To provide fast search functionality in intervals with associated offsets.
+    """
+
+    def __init__(self, keys, values):
+        self.keys = keys
+        self.values = values
+
+    def __getitem__(self, x):
+        idx = bisect(self.keys, x) - 1
+        begin = self.keys[idx]
+        end, data = self.values[idx]
+        if begin <= x < end:
+            return x - begin + data
+        else:
+            return -1
+
+    def contains(self, x, size):
+        """Return the maximum size and the list of intervals.
+
+        Args:
+        - x: The value to check if it falls within any interval.
+        - size: The size of the interval to find. 
+ + Returns: + - Maximum size available within intervals and the list of intervals. + + Description: + - The method checks if the value x falls within any interval in the mapping. + - If x is within an interval, it calculates the maximum size available starting from x and returns the list of intervals that cover the requested size. + + Purpose: + - To find intervals containing a specific value and to determine the maximum size available within those intervals. + + Technical Explanation: + - The method utilizes binary search to find the index of the interval containing the value x. + - It then iterates through the intervals starting from the identified index and calculates the maximum size available. + - The method returns the maximum size and the list of intervals that cover the requested size. + """ + idx = bisect(self.keys, x) - 1 + begin = self.keys[idx] + end, data = self.values[idx] + if not (begin <= x < end): + return 0, [] + + intervals = [(x, min(end - x, size), x - begin + data)] + if end - x >= size: + return size, intervals + + # The address space requested is bigger than a single interval + start = end + remaining = size - (end - x) + idx += 1 + print(start, remaining, idx) + while idx < len(self.values): + begin = self.keys[idx] + end, data = self.values[idx] + + # Virtual addresses must be contigous + if begin != start: + return size - remaining, intervals + + interval_size = min(end - begin, remaining) + intervals.append((start, interval_size, data)) + remaining -= interval_size + if not remaining: + return size, intervals + start += interval_size + idx += 1 + + def get_values(self): + return zip(self.keys, self.values) + + def get_extremes(self): + return self.keys[0], self.values[-1][0] + + +class IMOverlapping: + """Fast search in overlapping intervals (begin), (end, [associated offsets]). + + Description: + - Represents a class for efficiently searching overlapping intervals. + - Initializes with a list of intervals. + + Attributes: + - intervals: List of intervals (begin), (end, [associated offsets]). + - limits: Sorted list of interval limits. + - results: List of results corresponding to intervals. + + Methods: + - __init__(self, intervals): Constructor method to initialize IMOverlapping instance. + - __getitem__(self, x): Method to get values associated with a specific point. + - get_values(self): Method to get values of intervals. + + Purpose: + - To provide fast search functionality in overlapping intervals.) + """ + + def __init__(self, intervals): + """Initialize the class instance with a list of intervals. + Args: + - intervals: A list of intervals represented as tuples (l, r, v), where: + - l: The left endpoint of the interval. + - r: The right endpoint of the interval. + - v: The value associated with the interval. + + Description: + - The constructor initializes the class instance with a list of intervals. + - It organizes the intervals based on their left endpoints and computes the changes in offsets. + - The results are stored for efficient access during interval queries. + + Purpose: + - To initialize the class instance with a list of intervals and precompute results for interval queries. + + Technical Explanation: + - The constructor takes a list of intervals and sorts them based on their left endpoints. + - It computes the changes in offsets for each interval and stores the results for efficient access during interval queries. 
+ + Example: + ```python + intervals = [(0, 3, [10]), (1, 4, [20]), (2, 5, [30])] + instance = ClassName(intervals) ``` + """ + limit2changes = defaultdict(lambda: ([], [])) + for idx, (l, r, v) in enumerate(intervals): + assert l < r + limit2changes[l][0].append(v) + limit2changes[r][1].append(v) + self.limits, changes = zip(*sorted(limit2changes.items())) + + self.results = [[]] + s = set() + offsets = {} + res = [] + for idx, (arrivals, departures) in enumerate(changes): + s.difference_update(departures) + for i in departures: + offsets.pop(i) + + for i in s: + offsets[i] += self.limits[idx] - self.limits[idx - 1] + + s.update(arrivals) + for i in arrivals: + offsets[i] = 0 + + res.clear() + for k, v in offsets.items(): + res.extend([i + v for i in k]) + self.results.append(res.copy()) + + def __getitem__(self, x): + idx = bisect(self.limits, x) + k = x - self.limits[idx - 1] + return [k + p for p in self.results[idx]] + + def get_values(self): + return zip(self.limits, self.results) + + +class ELFDump: + """Represents a class for parsing ELF files and extracting machine data. + + Description: + - ELFDump is a class designed to read and parse ELF files, extracting necessary information such as machine data, endianness, architecture, and memory mapped devices. + + Attributes: + - filename: Name of the ELF file. + - machine_data: Dictionary containing machine configuration extracted from ELF segments. + - p2o: Mapping of physical addresses to file offsets. + - o2p: Mapping of file offsets to physical addresses. + - p2mmd: Mapping of physical addresses to Memory Mapped Devices (MMD) intervals. + - elf_buf: Buffer containing the contents of the ELF file. + - elf_filename: Name of the ELF file. + + Methods: + - __init__(self, elf_filename): Constructor method to initialize ELFDump instance. + - __read_elf_file(self, elf_fd): Reads and parses the ELF file from the provided file descriptor. + - _compact_intervals_simple(self, intervals): Compacts intervals for contiguous pointer values. + - _compact_intervals(self, intervals): Compacts intervals for contiguous pointer and pointed values. + - in_ram(self, paddr, size=1): Returns True if the interval is completely in RAM. + - in_mmd(self, paddr, size=1): Returns True if the interval is completely in Memory mapped devices space. + - get_data(self, paddr, size): Returns the data at physical address (interval). + - get_data_raw(self, offset, size=1): Returns the data at the offset in the ELF (interval). + - get_machine_data(self): Returns a dict containing machine configuration. + - get_ram_regions(self): Returns all the RAM regions of the machine and the associated offset. + - get_mmd_regions(self): Returns all the Memory mapped devices intervals of the machine and the associated offset. + + Purpose: + - To parse ELF files, extract machine data, and provide methods for accessing data and information from the ELF file. 
+
+
+class ELFDump:
+    """Parser for ELF memory dumps.
+
+    Description:
+    - Reads an ELF dump, extracting the machine configuration (architecture,
+      endianness, memory mapped devices) from its NOTE segments and building
+      the mappings between physical addresses and file offsets.
+
+    Attributes:
+    - filename: Name of the ELF file.
+    - machine_data: Dictionary containing the machine configuration extracted from the ELF segments.
+    - p2o: Mapping of physical addresses to file offsets.
+    - o2p: Mapping of file offsets to physical addresses.
+    - p2mmd: Mapping of physical addresses to Memory Mapped Devices (MMD) intervals.
+    - elf_buf: Buffer containing the contents of the ELF file.
+    """
+
+    def __init__(self, elf_filename):
+        self.filename = elf_filename
+        self.machine_data = {}
+        self.p2o = None  # Physical to RAM (ELF offset)
+        self.o2p = None  # RAM (ELF offset) to Physical
+        self.p2mmd = None  # Physical to Memory Mapped Devices (ELF offset)
+        self.elf_buf = np.zeros(0, dtype=np.byte)
+        self.elf_filename = elf_filename
+
+        with open(self.elf_filename, "rb") as elf_fd:
+            # Load the ELF in memory
+            self.elf_buf = np.fromfile(elf_fd, dtype=np.byte)
+            elf_fd.seek(0)
+
+            # Parse the ELF file
+            self.__read_elf_file(elf_fd)
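+
+    # Illustrative usage (hypothetical file name):
+    #   phy = ELFDump("machine.elf")
+    #   phy.get_machine_data()["Architecture"]  # e.g. "x86_64"
+    #   phy.get_data(0x1000, 8)                 # 8 bytes if backed by RAM, b"" otherwise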
+
+    def __read_elf_file(self, elf_fd):
+        """Read and parse the dump in ELF format.
+
+        Args:
+            elf_fd: File descriptor of the ELF file to be read and parsed.
+
+        Description:
+        - Extracts machine data, endianness, and architecture information from
+          the ELF NOTE segments.
+        - For every loadable segment, fills the arrays used to translate
+          physical addresses to file offsets (and back), then compacts them
+          for efficient processing.
+        """
+        o2p_list = []
+        p2o_list = []
+        p2mmd_list = []
+        elf_file = ELFFile(elf_fd)
+
+        for segm in elf_file.iter_segments():
+            # NOTES
+            if isinstance(segm, NoteSegment):
+                for note in segm.iter_notes():
+                    # Ignore NOTEs generated by other software
+                    if note["n_name"] != "FOSSIL":
+                        continue
+
+                    # At the moment there is only one type of note
+                    if note["n_type"] != 0xDEADC0DE:
+                        continue
+
+                    # Assume only one 0xDEADC0DE note
+                    self.machine_data = json.loads(note["n_desc"].rstrip("\x00"))
+                    self.machine_data["Endianness"] = (
+                        "little"
+                        if elf_file.header["e_ident"].EI_DATA == "ELFDATA2LSB"
+                        else "big"
+                    )
+                    self.machine_data["Architecture"] = "_".join(
+                        elf_file.header["e_machine"].split("_")[1:]
+                    )
+            else:
+                # Fill arrays needed to translate physical addresses to file offsets
+                r_start = segm["p_vaddr"]
+                r_end = r_start + segm["p_memsz"]
+
+                if segm["p_filesz"]:
+                    p_offset = segm["p_offset"]
+                    p2o_list.append((r_start, (r_end, p_offset)))
+                    o2p_list.append((p_offset, (p_offset + (r_end - r_start), r_start)))
+                else:
+                    # device_name = ""  # UNUSED
+                    for device in self.machine_data[
+                        "MemoryMappedDevices"
+                    ]:  # Possible because NOTES is always the first segment
+                        if device[0] == r_start:
+                            # device_name = device[1]  # UNUSED
+                            break
+                    p2mmd_list.append((r_start, r_end))
+
+        # Debug
+        # self.p2o_list = p2o_list
+        # self.o2p_list = o2p_list
+        # self.p2mmd_list = p2mmd_list
+
+        # Compact intervals
+        p2o_list = self._compact_intervals(p2o_list)
+        o2p_list = self._compact_intervals(o2p_list)
+        p2mmd_list = self._compact_intervals_simple(p2mmd_list)
+
+        self.p2o = IMOffsets(*list(zip(*sorted(p2o_list))))
+        self.o2p = IMOffsets(*list(zip(*sorted(o2p_list))))
+        self.p2mmd = IMSimple(*list(zip(*sorted(p2mmd_list))))
+
+    def _compact_intervals_simple(self, intervals):
+        """Compact intervals whose endpoints are contiguous.
+
+        Args:
+            intervals: List of (begin, end) intervals to be compacted.
+
+        Description:
+        - Iterates through the (sorted) intervals and merges each interval
+          whose begin coincides with the end of the previous one, so that runs
+          of adjacent intervals collapse into a single interval.
+
+        Example:
+            Input:  [(0, 1), (1, 3), (4, 6), (6, 8)]
+            Output: [(0, 3), (4, 8)]
+        """
+        if not intervals:
+            return []
+
+        fused_intervals = []
+        prev_begin = prev_end = -1
+        for interval in intervals:
+            begin, end = interval
+            if prev_end == begin:
+                prev_end = end
+            else:
+                fused_intervals.append((prev_begin, prev_end))
+                prev_begin = begin
+                prev_end = end
+
+        if prev_begin != begin:
+            fused_intervals.append((prev_begin, prev_end))
+        else:
+            fused_intervals.append((begin, end))
+
+        return fused_intervals[1:]
+
+    def _compact_intervals(self, intervals):
+        """Compact intervals whose endpoints and pointed values are contiguous.
+
+        Args:
+            intervals: List of intervals to be compacted, where each interval
+                is a tuple (begin, (end, phy)) with phy the address associated
+                with begin.
+
+        Description:
+        - Same as _compact_intervals_simple(), but two intervals are merged
+          only if both the interval endpoints and the associated addresses are
+          contiguous.
+
+        Example:
+            Input:  [(0, (1, 100)), (1, (3, 101)), (4, (6, 102)), (6, (8, 104))]
+            Output: [(0, (3, 100)), (4, (8, 102))]
+        """
+        if not intervals:
+            return []
+
+        fused_intervals = []
+        prev_begin = prev_end = prev_phy = -1
+        for interval in intervals:
+            begin, (end, phy) = interval
+            if prev_end == begin and prev_phy + (prev_end - prev_begin) == phy:
+                prev_end = end
+            else:
+                fused_intervals.append((prev_begin, (prev_end, prev_phy)))
+                prev_begin = begin
+                prev_end = end
+                prev_phy = phy
+
+        if prev_begin != begin:
+            fused_intervals.append((prev_begin, (prev_end, prev_phy)))
+        else:
+            fused_intervals.append((begin, (end, phy)))
+
+        return fused_intervals[1:]
+
+    def in_ram(self, paddr, size=1):
+        """Return True if the interval is completely in RAM"""
+        return self.p2o.contains(paddr, size)[0] == size
+
+    def in_mmd(self, paddr, size=1):
+        """Return True if the interval is completely in Memory mapped devices space"""
+        return self.p2mmd.contains(paddr, size) != -1
+
+    def get_data(self, paddr, size):
+        """Return the data at physical address (interval)"""
+        size_available, intervals = self.p2o.contains(paddr, size)
+        if size_available != size:
+            return bytes()
+
+        ret = bytearray()
+        for interval in intervals:
+            _, interval_size, offset = interval
+            ret.extend(self.elf_buf[offset : offset + interval_size].tobytes())
+
+        return ret
+
+    def get_data_raw(self, offset, size=1):
+        """Return the data at the offset in the ELF (interval)"""
+        return self.elf_buf[offset : offset + size].tobytes()
+
+    def get_machine_data(self):
+        """Return a dict containing the machine configuration"""
+        return self.machine_data
+
+    def get_ram_regions(self):
+        """Return all the RAM regions of the machine and the associated offsets"""
+        return self.p2o.get_values()
+
+    def get_mmd_regions(self):
+        """Return all the Memory mapped devices intervals of the machine and the associated offsets"""
+        return self.p2mmd.get_values()
+
+
+def get_virtspace(phy, mmu_values):
+    """Get a Virtual Address Space from a Physical Address Space.
+
+    Args:
+        phy: Physical address space object.
+        mmu_values: MMU (Memory Management Unit) register values associated
+            with the physical address space.
+
+    Returns:
+        A virtual address space object, built by the translator matching the
+        architecture of the physical address space.
+
+    Raises:
+        Exception: If the architecture of the physical address space is unknown.
+    """
+    architecture = phy.get_machine_data()["Architecture"].lower()
+    if "riscv" in architecture:
+        return RISCVTranslator.factory(phy, mmu_values)
+    elif "x86" in architecture or "386" in architecture:
+        return IntelTranslator.factory(phy, mmu_values)
+    else:
+        raise Exception("Unknown architecture")
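+
+
+# Illustrative end-to-end usage (hypothetical file name and register value):
+#   phy = ELFDump("machine.elf")
+#   virt = get_virtspace(phy, {"cr3": 0x1aa000})
+#   virt.export_virtual_memory_elf("machine_virtual.elf")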
+
+
+class AddressTranslator:
+    """Base class for address translation.
+
+    Attributes:
+    - dtb: Address of the top-level page table (the radix-tree root, e.g. the
+      value derived from CR3 on x86 or satp on RISC-V).
+    - phy: Physical address space (ELFDump instance).
+    - wordsize: Size of the machine word (4 or 8 bytes).
+    - word_type: Data type of the machine word (np.uint32 or np.uint64).
+    - word_fmt: Format of the machine word for packing and unpacking.
+    - v2o: Mapping of virtual addresses to file offsets.
+    - o2v: Mapping of file offsets to virtual addresses.
+    - pmasks: Permission masks.
+    - minimum_page: Minimum page size.
+    """
+
+    def __init__(self, dtb, phy):
+        """Initialize the AddressTranslator instance.
+
+        Args:
+        - dtb: Address of the top-level page table.
+        - phy: Physical address space (ELFDump instance).
+
+        Description:
+        - Sets the machine-specific attributes word_type and word_fmt based on
+          the machine word size and endianness.
+        """
+        self.dtb = dtb
+        self.phy = phy
+
+        # Set machine specifics
+        if self.wordsize == 4:
+            self.word_type = np.uint32
+            if self.phy.machine_data["Endianness"] == "big":
+                self.word_fmt = ">u4"
+            else:
+                self.word_fmt = "<u4"
+        else:
+            self.word_type = np.uint64
+            if self.phy.machine_data["Endianness"] == "big":
+                self.word_fmt = ">u8"
+            else:
+                self.word_fmt = "<u8"
+
+    def _explore_radixtree(
+        self, table_addr, mapping, reverse_mapping, lvl=0, prefix=0, upmask=[]
+    ):
+        """Explore the radix tree to obtain the virtual <-> physical mappings.
+
+        Args:
+            table_addr: The address of the radix tree table.
+            mapping: A dictionary to store virtual <-> physical mappings.
+            reverse_mapping: A dictionary to store reverse mappings from physical to virtual addresses.
+            lvl (optional): The level of the radix tree being explored (default is 0).
+            prefix (optional): The prefix used for constructing virtual addresses (default is 0).
+            upmask (optional): The permission mask accumulated from the upper levels (default is an empty list).
+
+        Description:
+        - Retrieves the table at table_addr and iterates through its entries,
+          skipping the invalid ones.
+        - For each valid entry it computes the virtual address and permission
+          mask; for leaves (or entries at the last level) pointing into RAM or
+          memory mapped devices it records the mapping, otherwise it recurses
+          into the lower-level table.
+        """
+        table = self.phy.get_data(table_addr, self.table_sizes[lvl])
+        if not table:
+            print(
+                f"Table {hex(table_addr)} size:{self.table_sizes[lvl]} at level {lvl} not in RAM"
+            )
+            return
+
+        for index, entry in enumerate(iter_unpack(self.unpack_fmt, table)):
+            is_valid, pmask, phy_addr, page_size = self._read_entry(
+                index, entry[0], lvl
+            )
+
+            if not is_valid:
+                continue
+
+            virt_addr = prefix | (index << self.shifts[lvl])
+            pmask = upmask + pmask
+
+            if (lvl == self.total_levels - 1) or page_size:  # Last radix level or Leaf
+                # Ignore pages not in RAM (some OSs map more RAM than available) and not memory mapped devices
+                in_ram = self.phy.in_ram(phy_addr, page_size)
+                in_mmd = self.phy.in_mmd(phy_addr, page_size)
+                if not in_ram and not in_mmd:
+                    continue
+
+                permissions = self._reconstruct_permissions(pmask)
+                virt_addr = self._finalize_virt_addr(virt_addr, permissions)
+                mapping[permissions].append((virt_addr, page_size, phy_addr, in_mmd))
+
+                # Add only RAM addresses to the reverse translation P2V
+                if in_ram and not in_mmd:
+                    if permissions not in reverse_mapping:
+                        reverse_mapping[permissions] = defaultdict(list)
+                    reverse_mapping[permissions][(phy_addr, page_size)].append(
+                        virt_addr
+                    )
+            else:
+                # Lower level entry
+                self._explore_radixtree(
+                    phy_addr,
+                    mapping,
+                    reverse_mapping,
+                    lvl=lvl + 1,
+                    prefix=virt_addr,
+                    upmask=pmask,
+                )
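+
+    # Illustrative composition of a virtual address during the walk
+    # (hypothetical AMD64-style shifts [39, 30, 21, 12]): following entry 3 at
+    # level 0 and reaching a leaf at entry 5 of level 3 yields
+    #   prefix    = 3 << 39            # 0x18000000000
+    #   virt_addr = prefix | (5 << 12) # 0x18000005000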
+
+    def _compact_intervals_virt_offset(self, intervals):
+        """Compact intervals whose virtual addresses and ELF offsets are contiguous (virt -> offset).
+
+        Args:
+            intervals: List of intervals (begin, end, phy, pmask) containing
+                virtual addresses, end addresses, physical pages, and
+                permission masks.
+
+        Description:
+        - Translates each physical page to its ELF offset and merges intervals
+          that are contiguous both in the virtual address space and in the
+          file, reducing the number of elements to speed up lookups.
+        """
+        fused_intervals = []
+        prev_begin = prev_end = prev_offset = -1
+        for interval in intervals:
+            begin, end, phy, _ = interval
+
+            offset = self.phy.p2o[phy]
+            if offset == -1:
+                continue
+
+            if prev_end == begin and prev_offset + (prev_end - prev_begin) == offset:
+                prev_end = end
+            else:
+                fused_intervals.append((prev_begin, (prev_end, prev_offset)))
+                prev_begin = begin
+                prev_end = end
+                prev_offset = offset
+
+        if prev_begin != begin:
+            fused_intervals.append((prev_begin, (prev_end, prev_offset)))
+        else:
+            offset = self.phy.p2o[phy]
+            if offset == -1:
+                print(f"ERROR!! {phy}")
+            else:
+                fused_intervals.append((begin, (end, offset)))
+        return fused_intervals[1:]
+
+    def _compact_intervals_permissions(self, intervals):
+        """Compact intervals whose virtual addresses are contiguous and whose permissions are equal.
+
+        Args:
+            intervals: List of intervals (begin, end, phy, pmask) containing
+                virtual addresses, end addresses, physical pages, and
+                permission masks.
+
+        Description:
+        - Merges contiguous intervals carrying the same permission mask,
+          reducing the number of elements to speed up lookups.
+        """
+        fused_intervals = []
+        prev_begin = prev_end = -1
+        prev_pmask = (0, 0)
+        for interval in intervals:
+            begin, end, _, pmask = interval
+            if prev_end == begin and prev_pmask == pmask:
+                prev_end = end
+            else:
+                fused_intervals.append((prev_begin, (prev_end, prev_pmask)))
+                prev_begin = begin
+                prev_end = end
+                prev_pmask = pmask
+
+        if prev_begin != begin:
+            fused_intervals.append((prev_begin, (prev_end, prev_pmask)))
+        else:
+            fused_intervals.append((begin, (end, pmask)))
+
+        return fused_intervals[1:]
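+
+    # Illustrative fusion (hypothetical values): two virtually adjacent ranges
+    # with the same permission mask collapse into one
+    #   [(0x1000, 0x2000, phy0, (0, 6)), (0x2000, 0x3000, phy1, (0, 6))]
+    #   -> [(0x1000, (0x3000, (0, 6)))]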
+
+    def _reconstruct_mappings(self, table_addr, upmask):
+        """Reconstruct the virtual memory mappings starting from the radix-tree root.
+
+        Args:
+            table_addr: The address of the radix tree root table.
+            upmask: The initial permission mask.
+
+        Description:
+        - Explores the radix tree, populating `mapping` and `reverse_mapping`
+          (also kept as attributes for the ELF virtual mapping reconstruction).
+        - Collects the intervals of the pages, fuses them to reduce the number
+          of elements, and translates physical offsets to virtual addresses.
+        - Builds the resolution objects v2o (IMOffsets), o2v (IMOverlapping)
+          and pmasks (IMData) from the sorted intervals.
+        """
+        # Explore the radix tree
+        mapping = defaultdict(list)
+        reverse_mapping = {}
+        self._explore_radixtree(table_addr, mapping, reverse_mapping, upmask=upmask)
+
+        # Needed for ELF virtual mapping reconstruction
+        self.reverse_mapping = reverse_mapping
+        self.mapping = mapping
+
+        # Collect all intervals (start, end+1, phy_page, pmask)
+        intervals = []
+        for pmask, mapping_p in mapping.items():
+            if pmask[1] == 0:  # Ignore pages not accessible in user mode
+                continue
+            intervals.extend(
+                [(x[0], x[0] + x[1], x[2], pmask) for x in mapping_p if not x[3]]
+            )  # Ignore MMD
+        intervals.sort()
+
+        if not intervals:
+            raise Exception("No user-accessible mappings found")
+
+        # Fuse intervals in order to reduce the number of elements and speed up lookups
+        fused_intervals_v2o = self._compact_intervals_virt_offset(intervals)
+        fused_intervals_permissions = self._compact_intervals_permissions(intervals)
+
+        # Offset to virtual is impossible to compact in an easy way due to the
+        # multiple-to-one mapping: we sort the array and use bisection to find
+        # the possible results
+        intervals_o2v = []
+        for pmasks, d in reverse_mapping.items():
+            if pmasks[1] != 0:  # Ignore user accessible pages
+                continue
+            for k, v in d.items():
+                # We have to translate phy -> offset
+                offset = self.phy.p2o[k[0]]
+                if offset == -1:  # Ignore unresolvable pages
+                    continue
+                intervals_o2v.append((offset, k[1] + offset, tuple(v)))
+        intervals_o2v.sort()
+
+        # Fill resolution objects
+        self.v2o = IMOffsets(*list(zip(*fused_intervals_v2o)))
+        self.o2v = IMOverlapping(intervals_o2v)
+        self.pmasks = IMData(*list(zip(*fused_intervals_permissions)))
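+
+    # Once the resolution trees exist, queries look like this (illustrative,
+    # hypothetical values):
+    #   self.v2o.contains(0x400000, 0x2000)  # -> (covered size, ELF offset intervals)
+    #   self.o2v[elf_offset]                 # -> virtual addresses mapping that offset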
+
+    def export_virtual_memory_elf(self, elf_filename):
+        """Create an ELF file containing the virtual address space of the process.
+
+        Args:
+            elf_filename: Name of the file to create.
+
+        Description:
+        - Writes the ELF header, compacts the memory mapping intervals to
+          minimize the number of segments, writes the segments and their
+          program header, and finally patches the ELF header so that it points
+          to the program header.
+        """
+        with open(elf_filename, "wb") as elf_fd:
+            # Create the ELF header and write it on the file
+            machine_data = self.phy.get_machine_data()
+            endianness = machine_data["Endianness"]
+            machine = machine_data["Architecture"].lower()
+
+            # Create ELF main header
+            if "aarch64" in machine:
+                e_machine = 0xB7
+            elif "arm" in machine:
+                e_machine = 0x28
+            elif "riscv" in machine:
+                e_machine = 0xF3
+            elif "x86_64" in machine:
+                e_machine = 0x3E
+            elif "386" in machine:
+                e_machine = 0x03
+            else:
+                raise Exception("Unknown architecture")
+
+            e_ehsize = 0x40
+            e_phentsize = 0x38
+            elf_h = bytearray(e_ehsize)
+            elf_h[0x00:0x04] = b"\x7fELF"  # Magic
+            elf_h[0x04] = 2  # Elf type
+            elf_h[0x05] = 1 if endianness == "little" else 2  # Endianness
+            elf_h[0x06] = 1  # Version
+            elf_h[0x10:0x12] = 0x4.to_bytes(2, endianness)  # e_type
+            elf_h[0x12:0x14] = e_machine.to_bytes(2, endianness)  # e_machine
+            elf_h[0x14:0x18] = 0x1.to_bytes(4, endianness)  # e_version
+            elf_h[0x34:0x36] = e_ehsize.to_bytes(2, endianness)  # e_ehsize
+            elf_h[0x36:0x38] = e_phentsize.to_bytes(2, endianness)  # e_phentsize
+            elf_fd.write(elf_h)
+
+            # For each pmask try to compact intervals in order to reduce the number of segments
+            intervals = defaultdict(list)
+            for (kpmask, pmask), intervals_list in self.mapping.items():
+                if pmask == 0:  # Ignore pages not accessible by the process
+                    continue
+
+                intervals[pmask].extend(
+                    [(x[0], x[0] + x[1], x[2]) for x in intervals_list if not x[3]]
+                )  # Ignore MMD
+                intervals[pmask].sort()
+
+                if len(intervals[pmask]) == 0:
+                    intervals.pop(pmask)
+                    continue
+
+                # Compact them
+                fused_intervals = []
+                prev_begin = prev_end = prev_offset = -1
+                for interval in intervals[pmask]:
+                    begin, end, phy = interval
+
+                    offset = self.phy.p2o[phy]
+                    if offset == -1:
+                        continue
+
+                    if (
+                        prev_end == begin
+                        and prev_offset + (prev_end - prev_begin) == offset
+                    ):
+                        prev_end = end
+                    else:
+                        fused_intervals.append([prev_begin, prev_end, prev_offset])
+                        prev_begin = begin
+                        prev_end = end
+                        prev_offset = offset
+
+                if prev_begin != begin:
+                    fused_intervals.append([prev_begin, prev_end, prev_offset])
+                else:
+                    offset = self.phy.p2o[phy]
+                    if offset == -1:
+                        print(f"ERROR!! {phy}")
+                    else:
+                        fused_intervals.append([begin, end, offset])
+                intervals[pmask] = sorted(
+                    fused_intervals[1:], key=lambda x: x[1] - x[0], reverse=True
+                )
+
+            # Write segments in the new file and fill the program header
+            p_offset = len(elf_h)
+            offset2p_offset = (
+                {}
+            )  # Slow but easier to implement (a better way: an updatable sorted tree structure)
+            e_phnum = 0
+
+            for pmask, interval_list in intervals.items():
+                e_phnum += len(interval_list)
+                for idx, interval in enumerate(interval_list):
+                    begin, end, offset = interval
+                    size = end - begin
+                    if offset not in offset2p_offset:
+                        data = self.phy.get_data_raw(offset, size)
+                        if not data:
+                            print(hex(offset), hex(size))
+                        elf_fd.write(data)
+                        new_offset = p_offset
+                        p_offset += size
+                        for page_idx in range(0, size, self.minimum_page):
+                            offset2p_offset[offset + page_idx] = new_offset + page_idx
+                    else:
+                        new_offset = offset2p_offset[offset]
+                    interval_list[idx].append(
+                        new_offset
+                    )  # Assign the new offset in the dest file
+
+            # Create the program header containing all the segments (ignoring not in RAM pages)
+            e_phoff = elf_fd.tell()
+            p_header = bytes()
+            for pmask, interval_list in intervals.items():
+                for begin, end, offset, p_offset in interval_list:
+                    p_filesz = end - begin
+
+                    # Back convert offset to physical page
+                    p_addr = self.phy.o2p[offset]
+                    assert p_addr != -1
+
+                    segment_entry = bytearray(e_phentsize)
+                    segment_entry[0x00:0x04] = 0x1.to_bytes(4, endianness)  # p_type
+                    segment_entry[0x04:0x08] = pmask.to_bytes(4, endianness)  # p_flags
+                    segment_entry[0x10:0x18] = begin.to_bytes(8, endianness)  # p_vaddr
+                    segment_entry[0x18:0x20] = p_addr.to_bytes(
+                        8, endianness
+                    )  # p_paddr Original physical address
+                    segment_entry[0x28:0x30] = p_filesz.to_bytes(
+                        8, endianness
+                    )  # p_memsz
+                    segment_entry[0x08:0x10] = p_offset.to_bytes(
+                        8, endianness
+                    )  # p_offset
+                    segment_entry[0x20:0x28] = p_filesz.to_bytes(
+                        8, endianness
+                    )  # p_filesz
+
+                    p_header += segment_entry
+
+            # Write the segment header
+            elf_fd.write(p_header)
+            s_header_pos = (
+                elf_fd.tell()
+            )  # Last position written (used if we need to write a section header)
+
+            # Modify the ELF header to point to the program header
+            elf_fd.seek(0x20)
+            elf_fd.write(e_phoff.to_bytes(8, endianness))  # e_phoff
+
+            # If we have more than 65535 segments we have to create a special
+            # section entry containing the real number of program header
+            # entries (as specified by the ELF64 specification)
+            if e_phnum < 65536:
+                elf_fd.seek(0x38)
+                elf_fd.write(e_phnum.to_bytes(2, endianness))  # e_phnum
+            else:
+                elf_fd.seek(0x28)
+                elf_fd.write(s_header_pos.to_bytes(8, endianness))  # e_shoff
+                elf_fd.seek(0x38)
+                elf_fd.write(0xFFFF.to_bytes(2, endianness))  # e_phnum
+                elf_fd.write(0x40.to_bytes(2, endianness))  # e_shentsize
+                elf_fd.write(0x1.to_bytes(2, endianness))  # e_shnum
+
+                section_entry = bytearray(0x40)
+                section_entry[0x2C:0x30] = e_phnum.to_bytes(4, endianness)  # sh_info
+                elf_fd.seek(s_header_pos)
+                elf_fd.write(section_entry)
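+
+
+# The exported file is a regular ELF and can be inspected, e.g., with
+# pyelftools (illustrative):
+#   from elftools.elf.elffile import ELFFile
+#   with open("machine_virtual.elf", "rb") as fd:
+#       for seg in ELFFile(fd).iter_segments():
+#           print(hex(seg["p_vaddr"]), hex(seg["p_filesz"]), seg["p_flags"])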
+
+
+class IntelTranslator(AddressTranslator):
+    @staticmethod
+    def derive_mmu_settings(mmu_class, regs_dict, mphy):
+        if mmu_class is IntelAMD64:
+            dtb = ((regs_dict["cr3"] >> 12) & ((1 << (mphy - 12)) - 1)) << 12
+
+        elif mmu_class is IntelIA32:
+            dtb = ((regs_dict["cr3"] >> 12) & (1 << 20) - 1) << 12
+            mphy = min(mphy, 40)
+
+        else:
+            raise NotImplementedError
+
+        return {
+            "dtb": dtb,
+            "wp": True,
+            "ac": False,
+            "nxe": True,
+            "smep": False,
+            "smap": False,
+            "mphy": mphy,
+        }
+
+    @staticmethod
+    def derive_translator_class(mmu_mode):
+        if mmu_mode == "ia64":
+            return IntelAMD64
+        elif mmu_mode == "pae":
+            raise NotImplementedError
+        elif mmu_mode == "ia32":
+            return IntelIA32
+        else:
+            raise NotImplementedError
+
+    @staticmethod
+    def factory(phy, mmu_values):
+        machine_data = phy.get_machine_data()
+        mmu_mode = machine_data["MMUMode"]
+        mphy = machine_data["CPUSpecifics"]["MAXPHYADDR"]
+
+        translator_c = IntelTranslator.derive_translator_class(mmu_mode)
+        mmu_settings = IntelTranslator.derive_mmu_settings(
+            translator_c, mmu_values, mphy
+        )
+        return translator_c(phy=phy, **mmu_settings)
+
+    def __init__(
+        self, dtb, phy, mphy, wp=False, ac=False, nxe=False, smap=False, smep=False
+    ):
+        super(IntelTranslator, self).__init__(dtb, phy)
+        self.mphy = mphy
+        self.wp = wp
+        self.ac = ac  # UNUSED by Fossil
+        self.smap = smap
+        self.nxe = nxe
+        self.smep = smep
+        self.minimum_page = 0x1000
+
+        print("Creating resolution trees...")
+        self._reconstruct_mappings(self.dtb, upmask=[[False, True, True]])
+
+    def _finalize_virt_addr(self, virt_addr, permissions):
+        return virt_addr
+
+
+class IntelIA32(IntelTranslator):
+    def __init__(
+        self, dtb, phy, mphy, wp=True, ac=False, nxe=False, smap=False, smep=False
+    ):
+        self.unpack_fmt = "<I"
+        self.total_levels = 2
+        self.prefix = 0
+        self.table_sizes = [0x1000, 0x1000]
+        self.shifts = [22, 12]
+        self.wordsize = 4
+        super(IntelIA32, self).__init__(dtb, phy, mphy, wp, ac, nxe, smap, smep)
+
+    def _read_entry(self, idx, entry, lvl):
+        # Empty entry
+        if not (entry & 0x1):
+            return False, 0, 0, 0
+        else:
+            # [Kernel, Writable, Executable] (no NX bit in pure IA-32)
+            perms_flags = [[not bool(entry & 0x4), bool(entry & 0x2), True]]
+            page_size = bool(entry & 0x80) if lvl == 0 else False
+
+            # Upper tables pointers
+            if lvl == 0 and not page_size:
+                addr = ((entry >> 12) & ((1 << 20) - 1)) << 12
+                return True, perms_flags, addr, 0
+
+            # Leaf
+            else:
+                if lvl == 0:
+                    addr = (((entry >> 13) & ((1 << (self.mphy - 32)) - 1)) << 32) | (
+                        ((entry >> 22) & ((1 << 10) - 1)) << 22
+                    )
+                else:
+                    addr = ((entry >> 12) & ((1 << 20) - 1)) << 12
+                return True, perms_flags, addr, 1 << self.shifts[lvl]
+
+    def _reconstruct_permissions(self, pmask):
+        k_flags, w_flags, _ = zip(*pmask)
+
+        # Kernel page in user mode
+        if any(k_flags):
+            r = True
+            w = all(w_flags) if self.wp else True
+            return r << 2 | w << 1 | 1, 0
+
+        # User page in user mode
+        else:
+            r = True
+            w = all(w_flags)
+            return 0, r << 2 | w << 1 | 1
+
+
+class IntelAMD64(IntelTranslator):
+    def __init__(
+        self, dtb, phy, mphy, wp=True, ac=False, nxe=True, smap=False, smep=False
+    ):
+        self.unpack_fmt = "<Q"
+        self.total_levels = 4
+        self.prefix = 0xFFFF000000000000
+        self.table_sizes = [0x1000, 0x1000, 0x1000, 0x1000]
+        self.shifts = [39, 30, 21, 12]
+        self.wordsize = 8
+        super(IntelAMD64, self).__init__(dtb, phy, mphy, wp, ac, nxe, smap, smep)
+
+    def _read_entry(self, idx, entry, lvl):
+        # Empty entry
+        if not (entry & 0x1):
+            return False, 0, 0, 0
+        else:
+            # [Kernel, Writable, Executable]
+            perms_flags = [
+                [
+                    not bool(entry & 0x4),
+                    bool(entry & 0x2),
+                    not bool(entry & 0x8000000000000000),
+                ]
+            ]
+            # The PS bit is meaningful only for PDPT and PD entries
+            page_size = bool(entry & 0x80) if lvl in (1, 2) else False
+
+            # Upper tables pointers
+            if lvl != self.total_levels - 1 and not page_size:
+                addr = ((entry >> 12) & ((1 << (self.mphy - 12)) - 1)) << 12
+                return True, perms_flags, addr, 0
+
+            # Leaf
+            else:
+                addr = (
+                    (entry >> self.shifts[lvl])
+                    & ((1 << (self.mphy - self.shifts[lvl])) - 1)
+                ) << self.shifts[lvl]
+                return True, perms_flags, addr, 1 << self.shifts[lvl]
+
+    def _reconstruct_permissions(self, pmask):
+        k_flags, w_flags, x_flags = zip(*pmask)
+
+        # Kernel page in user mode
+        if any(k_flags):
+            r = True
+            w = all(w_flags) if self.wp else True
+            x = all(x_flags) if self.nxe else True
+
+            return r << 2 | w << 1 | int(x), 0
+
+        # User page in user mode
+        else:
+            r = True
+            w = all(w_flags)
+            x = all(x_flags) if self.nxe else True
+
+            return 0, r << 2 | w << 1 | int(x)
+
+    def _finalize_virt_addr(self, virt_addr, permissions):
+        # Canonical address form
+        if virt_addr & 0x800000000000:
+            return self.prefix | virt_addr
+        else:
+            return virt_addr
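+
+
+# Illustrative dtb derivation (hypothetical register value): with
+# MAXPHYADDR = 46, derive_mmu_settings() masks CR3 down to the page-aligned
+# root of the radix tree:
+#   cr3 = 0x1aa018  ->  dtb = ((cr3 >> 12) & ((1 << 34) - 1)) << 12 = 0x1aa000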
+
+
+class RISCVTranslator(AddressTranslator):
+    """Translator for converting physical addresses to virtual addresses on the RISC-V architecture.
+
+    Description:
+    - Provides methods for deriving the MMU settings, selecting the translator
+      class matching the MMU mode, creating translator instances, and
+      reconstructing the virtual address mappings.
+    """
+
+    @staticmethod
+    def derive_mmu_settings(mmu_class, regs_dict):
+        """Derive the MMU settings from the MMU class and the register dictionary.
+
+        Returns the MMU settings, including the DTB (translation base register)
+        and the control flags.
+        """
+        dtb = regs_dict["satp"]
+        return {"dtb": dtb, "Sum": False, "mxr": False}
+
+    @staticmethod
+    def derive_translator_class(mmu_mode):
+        if mmu_mode == "sv39":
+            return RISCVSV39
+        else:
+            return RISCVSV32
+
+    @staticmethod
+    def factory(phy, mmu_values):
+        machine_data = phy.get_machine_data()
+        mmu_mode = machine_data["MMUMode"]
+        translator_c = RISCVTranslator.derive_translator_class(mmu_mode)
+        mmu_settings = RISCVTranslator.derive_mmu_settings(translator_c, mmu_values)
+        return translator_c(phy=phy, **mmu_settings)
+
+    def __init__(self, dtb, phy, Sum=True, mxr=True):
+        super(RISCVTranslator, self).__init__(dtb, phy)
+        self.Sum = Sum
+        self.mxr = mxr
+        self.minimum_page = 0x1000
+
+        print("Creating resolution trees...")
+        self._reconstruct_mappings(self.dtb, upmask=[[False, True, True, True]])
+
+    def _finalize_virt_addr(self, virt_addr, permissions):
+        return virt_addr
+
+    def _reconstruct_permissions(self, pmask):
+        """Reconstruct the page permissions from the permission mask.
+
+        Extracts the permission flags (read, write, execute) of the leaf entry
+        (RISC-V does not combine permissions hierarchically) and returns the
+        permission bits for kernel and user mode.
+        """
+        k_flag, r_flag, w_flag, x_flag = pmask[-1]  # No hierarchy
+
+        r = r_flag
+        if self.mxr:
+            r |= x_flag
+
+        w = w_flag
+        x = x_flag
+
+        # Kernel page in user mode
+        if k_flag:
+            return r << 2 | w << 1 | int(x), 0
+
+        # User page in user mode
+        else:
+            return 0, r << 2 | w << 1 | int(x)
+
+
+class RISCVSV32(RISCVTranslator):
+    """Translator for the RISC-V Sv32 addressing mode.
+
+    Attributes:
+    - unpack_fmt: Format string for unpacking the page table entries.
+    - total_levels: Total levels of the radix tree.
+    - prefix: Prefix for virtual addresses.
+    - table_sizes: Sizes of the radix tree tables at each level.
+    - shifts: Shift values used to compose virtual addresses.
+    - wordsize: Size of a machine word in bytes.
+    """
+
+    def __init__(self, dtb, phy, Sum, mxr):
+        self.unpack_fmt = "<I"
+        self.total_levels = 2
+        self.prefix = 0
+        self.table_sizes = [0x1000, 0x1000]
+        self.shifts = [22, 12]
+        self.wordsize = 4
+        super(RISCVSV32, self).__init__(dtb, phy, Sum, mxr)
+
+    def _read_entry(self, idx, entry, lvl):
+        # Empty entry
+        if not (entry & 0x1):
+            return False, 0, 0, 0
+
+        # R, W, X, U bits
+        r = bool(entry & 0x2)
+        w = bool(entry & 0x4)
+        x = bool(entry & 0x8)
+        k = not bool(entry & 0x10)  # U bit clear -> supervisor-only page
+        perms_flags = [[k, r, w, x]]
+
+        addr = ((entry >> 10) & ((1 << 22) - 1)) << 12
+        # Leaf
+        if r or w or x or lvl == 1:
+            return True, perms_flags, addr, 1 << self.shifts[lvl]
+        else:
+            # Upper tables pointers
+            return True, perms_flags, addr, 0
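+
+
+# Illustrative Sv32 PTE decode (hypothetical value): entry = 0x2000001F has
+# V|R|W|X|U set, i.e. a user-accessible RWX leaf, and maps
+#   addr = ((0x2000001F >> 10) & ((1 << 22) - 1)) << 12 = 0x80000000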
+
+
+class RISCVSV39(RISCVTranslator):
+    def __init__(self, dtb, phy, Sum, mxr):
+        """Initialize the RISCVSV39 instance with the attributes for Sv39 address translation.
+
+        Args:
+        - dtb: Address of the top-level page table.
+        - phy: Physical memory instance.
+        - Sum: SUM bit (permit Supervisor access to User Memory).
+        - mxr: MXR bit (Make eXecutable Readable).
+
+        Attributes:
+        - unpack_fmt: Format string for unpacking the page table entries.
+        - total_levels: Total levels of the page table hierarchy.
+        - prefix: Prefix value for address translation.
+        - table_sizes: Sizes of the page tables at each level.
+        - shifts: Shift values used to compose virtual addresses.
+        - wordsize: Size of the machine word.
+        """
+        self.unpack_fmt = "<Q"
+        self.total_levels = 3
+        self.prefix = 0
+        self.table_sizes = [0x1000, 0x1000, 0x1000]
+        self.shifts = [30, 21, 12]
+        self.wordsize = 8
+        super(RISCVSV39, self).__init__(dtb, phy, Sum, mxr)
+
+    def _read_entry(self, idx, entry, lvl):
+        # Empty entry
+        if not (entry & 0x1):
+            return False, 0, 0, 0
+
+        # R, W, X, U bits
+        r = bool(entry & 0x2)
+        w = bool(entry & 0x4)
+        x = bool(entry & 0x8)
+        k = not bool(entry & 0x10)  # U bit clear -> supervisor-only page
+        perms_flags = [[k, r, w, x]]
+
+        addr = ((entry >> 10) & ((1 << 44) - 1)) << 12
+        # Leaf
+        if r or w or x or lvl == 2:
+            return True, perms_flags, addr, 1 << self.shifts[lvl]
+        else:
+            # Upper tables pointers
+            return True, perms_flags, addr, 0
+
+
+if __name__ == "__main__":
+    main()
diff --git a/mmushell/mmushell.py b/mmushell/mmushell.py
new file mode 100755
index 0000000..9f414c9
--- /dev/null
+++ b/mmushell/mmushell.py
@@ -0,0 +1,165 @@
+#!/usr/bin/env python3
+import argparse
+import logging
+import importlib
+import yaml
+
+from cerberus import Validator
+
+# Set logging configuration
+logger = logging.getLogger(__name__)
+
+# Schema for YAML configuration file
+machine_yaml_schema = {
+    "cpu": {
+        "required": True,
+        "type": "dict",
+        "schema": {
+            "architecture": {"required": True, "type": "string", "min": 1},
+            "endianness": {"required": True, "type": "string", "min": 1},
+            "bits": {"required": True, "type": "integer", "allowed": [32, 64]},
+            "processor_features": {"required": False, "type": "dict"},
+            "registers_values": {
+                "required": False,
+                "type": "dict",
+                "keysrules": {"type": "string", "min": 1},
+                "valuesrules": {"type": "integer"},
+            },
+        },
+    },
+    "mmu": {
+        "required": True,
+        "type": "dict",
+        "schema": {"mode": {"required": True, "type": "string", "min": 1}},
+    },
+    "memspace": {
+        "required": True,
+        "type": "dict",
+        "schema": {
+            "ram": {
+                "required": True,
+                "type": "list",
+                "minlength": 1,
+                "schema": {
+                    "type": "dict",
+                    "schema": {
+                        "start": {
+                            "required": True,
+                            "type": "integer",
+                            "min": 0,
+                            "max": 0xFFFFFFFFFFFFFFFF,
+                        },
+                        "end": {
+                            "required": True,
+                            "type": "integer",
+                            "min": 0,
+                            "max": 0xFFFFFFFFFFFFFFFF,
+                        },
+                        "dumpfile": {"required": True, "type": "string", "min": 0},
+                    },
+                },
+            },
+            "not_ram": {
+                "required": True,
+                "type": "list",
+                "minlength": 1,
+                "schema": {
+                    "type": "dict",
+                    "schema": {
+                        "start": {
+                            "required": True,
+                            "type": "integer",
+                            "min": 0,
+                            "max": 0xFFFFFFFFFFFFFFFF,
+                        },
+                        "end": {
+                            "required": True,
+                            "type": "integer",
+                            "min": 0,
+                            "max": 0xFFFFFFFFFFFFFFFF,
+                        },
+                    },
+                },
+            },
+        },
+    },
+}
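+
+# A machine configuration matching this schema looks like (illustrative
+# values; see the dataset for real examples):
+#
+#   cpu:
+#     architecture: riscv
+#     endianness: little
+#     bits: 64
+#   mmu:
+#     mode: sv39
+#   memspace:
+#     ram:
+#       - start: 0x80000000
+#         end: 0x9fffffff
+#         dumpfile: dump.raw
+#     not_ram:
+#       - start: 0x0
+#         end: 0x3fff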
+
+
+def main():
+    # Parse arguments
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "MACHINE_CONFIG",
+        help="YAML file describing the machine",
+        type=argparse.FileType("r"),
+    )
+    parser.add_argument(
+        "--gtruth",
+        help="Ground truth from QEMU registers",
+        type=argparse.FileType("rb", 0),
+        default=None,
+    )
+    parser.add_argument(
+        "--session",
+        help="Data file of a previous MMUShell session",
+        type=str,
+        default=None,
+    )
+    parser.add_argument(
+        "--debug", help="Enable debug output", action="store_true", default=False
+    )
+    args = parser.parse_args()
+
+    # Set logging system
+    fmt = "%(msg)s"
+    if args.debug:
+        logging.basicConfig(level=logging.DEBUG, format=fmt)
+    else:
+        logging.basicConfig(level=logging.INFO, format=fmt)
+
+    # Load the machine configuration YAML file
+    try:
+        machine_config = yaml.load(args.MACHINE_CONFIG, Loader=yaml.FullLoader)
+        args.MACHINE_CONFIG.close()
+    except Exception as e:
+        logger.fatal("Malformed YAML file: {}".format(e))
+        exit(1)
+
+    # Validate the YAML schema
+    yaml_validator = Validator(allow_unknown=True)
+    if not yaml_validator.validate(machine_config, machine_yaml_schema):
+        logger.fatal("Invalid YAML file. Error: " + str(yaml_validator.errors))
+        exit(1)
+
+    # Import the module for the requested architecture
+    try:
+        architecture_module = importlib.import_module(
+            "architectures." + machine_config["cpu"]["architecture"]
+        )
+    except ModuleNotFoundError:
+        logger.fatal("Unknown architecture!")
+        exit(1)
+
+    # Create a Machine starting from the parsed configuration
+    machine = architecture_module.Machine.from_machine_config(machine_config)
+
+    # Launch the interactive shell
+    if args.gtruth:
+        shell = architecture_module.MMUShellGTruth(machine=machine)
+    else:
+        shell = architecture_module.MMUShell(machine=machine)
+
+    # Load ground truth (if passed)
+    if args.gtruth:
+        shell.load_gtruth(args.gtruth)
+
+    # Load previous data (if passed)
+    if args.session:
+        shell.reload_data_from_file(args.session)
+
+    shell.cmdloop()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/qemu/README b/qemu/README
deleted file mode 100644
index 32fc922..0000000
--- a/qemu/README
+++ /dev/null
@@ -1,5 +0,0 @@
-INSTALL:
-If you use Debian/Ubuntu please install
-    build-essential git pkg-config libgtk-3-dev python3 python3-dev python3-pip python3-venv
-
-To run qemu_logger.py or the run_qemu script use ALWAYS the Python 3 virtual env "mmushell_venv" created by the installation script
diff --git a/qemu/README.md b/qemu/README.md
new file mode 100644
index 0000000..6d2d736
--- /dev/null
+++ b/qemu/README.md
@@ -0,0 +1,13 @@
+## Description
+
+The patch `qemu_v5.0.0.patch` makes it possible to collect the ground-truth MMU values of running virtual machines together with their memory dumps.
+
+## Installation
+
+If you use Debian/Ubuntu please install the following packages:
+
+```
+build-essential git pkg-config libgtk-3-dev python3 python3-dev python3-pip python3-venv
+```
+
+Remember to always run `qemu_logger.py` or the `run_qemu` script in the Python 3 venv created by the installer script (`build_qemu`).
diff --git a/qemu/build_qemu b/qemu/build_qemu
index a685930..69b16e1 100755
--- a/qemu/build_qemu
+++ b/qemu/build_qemu
@@ -16,9 +16,9 @@ make CFLAGS="-Warray-bounds=0" -j 8
 cd ../../../
 
 # Create the virtualenv and install dependencies
-if [ ! -d "mmushell_venv" ] ; then
-    python3 -m venv mmushell_venv
+if [ ! 
-d "venv" ] ; then + python3 -m venv venv fi -source ./mmushell_venv/bin/activate +source ./venv/bin/activate pip3 install -r qemu/requirements.txt diff --git a/qemu/qemu_logger.py b/qemu/qemu_logger.py index 52545dc..4355123 100755 --- a/qemu/qemu_logger.py +++ b/qemu/qemu_logger.py @@ -1,17 +1,19 @@ #!/usr/bin/env python3 import os import errno -from collections import defaultdict -import argparse import pickle -from qmp import QEMUMonitorProtocol -from signal import signal, SIGINT -from threading import Timer +import argparse + +from collections import defaultdict from datetime import datetime +from threading import Timer +from signal import signal, SIGINT from copy import deepcopy +from qmp import QEMUMonitorProtocol start_time_g = 0 + class CPU: def __init__(self, mem_base_addr, dump_file_name, debug): self.regs = defaultdict(dict) @@ -37,9 +39,14 @@ def dump_memory(self, qmonitor): # Grab the memory size and dump the memory res = qmonitor.cmd("query-memory-size-summary") memory_size = res["return"]["base-memory"] - qmonitor.cmd("pmemsave", {"val": self.mem_base_addr, - "size": memory_size, - "filename": self.dump_file_name + ".dump"}) + qmonitor.cmd( + "pmemsave", + { + "val": self.mem_base_addr, + "size": memory_size, + "filename": self.dump_file_name + ".dump", + }, + ) class Intel(CPU): @@ -67,15 +74,23 @@ def dump_memory(self, qmonitor): memory_size = res["return"]["base-memory"] # Dump different chunks of memory - qmonitor.cmd("pmemsave", {"val": 0x0, - "size": min(memory_size, 0xC0000000), - "filename": self.dump_file_name + ".dump.0" - }) + qmonitor.cmd( + "pmemsave", + { + "val": 0x0, + "size": min(memory_size, 0xC0000000), + "filename": self.dump_file_name + ".dump.0", + }, + ) if memory_size >= 0xC0000000: - qmonitor.cmd("pmemsave", {"val": 0x100000000, - "size": memory_size - 0xC0000000, - "filename": self.dump_file_name + ".dump.1" - }) + qmonitor.cmd( + "pmemsave", + { + "val": 0x100000000, + "size": memory_size - 0xC0000000, + "filename": self.dump_file_name + ".dump.1", + }, + ) class IntelQ35(Intel): @@ -84,23 +99,45 @@ def dump_memory(self, qmonitor): memory_size = res["return"]["base-memory"] # Dump different chunks of memory - qmonitor.cmd("pmemsave", {"val": 0x0, - "size": min(memory_size, 0x80000000), - "filename": self.dump_file_name + ".dump.0" - }) + qmonitor.cmd( + "pmemsave", + { + "val": 0x0, + "size": min(memory_size, 0x80000000), + "filename": self.dump_file_name + ".dump.0", + }, + ) if memory_size >= 0x80000000: - qmonitor.cmd("pmemsave", {"val": 0x100000000, - "size": memory_size - 0x80000000, - "filename": self.dump_file_name + ".dump.1" - }) + qmonitor.cmd( + "pmemsave", + { + "val": 0x100000000, + "size": memory_size - 0x80000000, + "filename": self.dump_file_name + ".dump.1", + }, + ) class PPC(CPU): def __init__(self, mem_base_addr, dump_file_name, debug): super(PPC, self).__init__(mem_base_addr, dump_file_name, debug) - dict_proto = {"U": {"value": 0, "modified": False}, "L": {"value": 0, "modified": False}} - self._BATS = {x: deepcopy(dict_proto) for x in - ["DBAT0", "DBAT1", "DBAT2", "DBAT3", "IBAT0", "IBAT1", "IBAT2", "IBAT3"]} + dict_proto = { + "U": {"value": 0, "modified": False}, + "L": {"value": 0, "modified": False}, + } + self._BATS = { + x: deepcopy(dict_proto) + for x in [ + "DBAT0", + "DBAT1", + "DBAT2", + "DBAT3", + "IBAT0", + "IBAT1", + "IBAT2", + "IBAT3", + ] + } def parse_log_row(self, data_log, start_time): if self.debug: @@ -115,20 +152,33 @@ def parse_log_row(self, data_log, start_time): _, vsid = keys_values[1].split("=") if 
reg_value not in self.regs[reg_name]: if vsid.strip() == "-1": - self.regs[reg_name][reg_value] = {"first_seen": time_now, "last_seen": time_now, "vsids": {}} + self.regs[reg_name][reg_value] = { + "first_seen": time_now, + "last_seen": time_now, + "vsids": {}, + } else: vsid = int(vsid.strip(), 16) - self.regs[reg_name][reg_value] = {"first_seen": time_now, "last_seen": time_now, - "vsids": {vsid: (time_now, time_now)}} + self.regs[reg_name][reg_value] = { + "first_seen": time_now, + "last_seen": time_now, + "vsids": {vsid: (time_now, time_now)}, + } else: self.regs[reg_name][reg_value]["last_seen"] = time_now if vsid.strip() != "-1": vsid = int(vsid.strip(), 16) if vsid not in self.regs[reg_name][reg_value]["vsids"]: - self.regs[reg_name][reg_value]["vsids"][vsid] = (time_now, time_now) + self.regs[reg_name][reg_value]["vsids"][vsid] = ( + time_now, + time_now, + ) else: first_seen = self.regs[reg_name][reg_value]["vsids"][vsid][0] - self.regs[reg_name][reg_value]["vsids"][vsid] = (first_seen, time_now) + self.regs[reg_name][reg_value]["vsids"][vsid] = ( + first_seen, + time_now, + ) pass elif "BAT" in reg_name: reg_group = reg_name[0:5] @@ -138,8 +188,14 @@ def parse_log_row(self, data_log, start_time): self._BATS[reg_group][reg_part]["value"] = reg_value self._BATS[reg_group][reg_part]["modified"] = True - if self._BATS[reg_group]["U"]["modified"] and self._BATS[reg_group]["L"]["modified"]: - regs_values = (self._BATS[reg_group]["U"]["value"], self._BATS[reg_group]["L"]["value"]) + if ( + self._BATS[reg_group]["U"]["modified"] + and self._BATS[reg_group]["L"]["modified"] + ): + regs_values = ( + self._BATS[reg_group]["U"]["value"], + self._BATS[reg_group]["L"]["value"], + ) if regs_values not in self.regs[reg_group]: self.regs[reg_group][regs_values] = (time_now, time_now) else: @@ -164,17 +220,21 @@ class Arm(CPU): class ArmVirtSecure(CPU): def __init__(self, mem_base_addr, dump_file_name, debug): super(ArmVirtSecure, self).__init__(mem_base_addr, dump_file_name, debug) - self.secure_mem_base_addr = 0xe000000 + self.secure_mem_base_addr = 0xE000000 self.secure_memory_size = 0x01000000 def dump_memory(self, qmonitor): super(ArmVirtSecure, self).dump_memory(qmonitor) # Dump also secure memory - qmonitor.cmd("pmemsave", {"val": self.secure_mem_base_addr, - "size": self.secure_memory_size, - "filename": self.dump_file_name + "_secure.dump" - }) + qmonitor.cmd( + "pmemsave", + { + "val": self.secure_mem_base_addr, + "size": self.secure_memory_size, + "filename": self.dump_file_name + "_secure.dump", + }, + ) class ARM_integratorcp(Arm): @@ -182,37 +242,47 @@ def dump_memory(self, qmonitor): res = qmonitor.cmd("query-memory-size-summary") memory_size = res["return"]["base-memory"] - memory_chunks = [(0x0000000000000000, 0x000000000fffffff), - (0x0000000010800000, 0x0000000012ffffff), - (0x0000000013001000, 0x0000000013ffffff), - (0x0000000014800000, 0x0000000014ffffff), - (0x0000000015001000, 0x0000000015ffffff), - (0x0000000016001000, 0x0000000016ffffff), - (0x0000000017001000, 0x0000000017ffffff), - (0x0000000018001000, 0x0000000018ffffff), - (0x0000000019001000, 0x0000000019ffffff), - (0x000000001b000000, 0x000000001bffffff), - (0x000000001c001000, 0x000000001cffffff), - (0x000000001d001000, 0x00000000bfffffff), - (0x00000000c0001000, 0x00000000c7ffffff), - (0x00000000c8000010, 0x00000000c9ffffff), - (0x00000000ca800000, 0x00000000caffffff)] + memory_chunks = [ + (0x0000000000000000, 0x000000000FFFFFFF), + (0x0000000010800000, 0x0000000012FFFFFF), + (0x0000000013001000, 
0x0000000013FFFFFF), + (0x0000000014800000, 0x0000000014FFFFFF), + (0x0000000015001000, 0x0000000015FFFFFF), + (0x0000000016001000, 0x0000000016FFFFFF), + (0x0000000017001000, 0x0000000017FFFFFF), + (0x0000000018001000, 0x0000000018FFFFFF), + (0x0000000019001000, 0x0000000019FFFFFF), + (0x000000001B000000, 0x000000001BFFFFFF), + (0x000000001C001000, 0x000000001CFFFFFF), + (0x000000001D001000, 0x00000000BFFFFFFF), + (0x00000000C0001000, 0x00000000C7FFFFFF), + (0x00000000C8000010, 0x00000000C9FFFFFF), + (0x00000000CA800000, 0x00000000CAFFFFFF), + ] dumped_size = 0 i = 0 for i, chunk in enumerate(memory_chunks): dumped_chunk_size = min(memory_size - dumped_size, chunk[1] - chunk[0] + 1) - qmonitor.cmd("pmemsave", {"val": chunk[0], - "size": dumped_chunk_size, - "filename": self.dump_file_name + ".dump." + str(i) - }) + qmonitor.cmd( + "pmemsave", + { + "val": chunk[0], + "size": dumped_chunk_size, + "filename": self.dump_file_name + ".dump." + str(i), + }, + ) dumped_size += dumped_chunk_size if dumped_size < memory_size: - qmonitor.cmd("pmemsave", {"val": 0xcb800000, - "size": memory_size - dumped_size, - "filename": self.dump_file_name + ".dump." + str(i+1) - }) + qmonitor.cmd( + "pmemsave", + { + "val": 0xCB800000, + "size": memory_size - dumped_size, + "filename": self.dump_file_name + ".dump." + str(i + 1), + }, + ) class ARM_raspi3(Arm): @@ -220,46 +290,56 @@ def dump_memory(self, qmonitor): res = qmonitor.cmd("query-memory-size-summary") memory_size = res["return"]["base-memory"] - memory_chunks = [(0x0, 0x3f002fff), - (0x3f003020, 0x3f006fff), - (0x3f008000, 0x3f00b1ff), - (0x3f00b440, 0x3f00b7ff), - (0x3f00bc00, 0x3f0fffff), - (0x3f101000, 0x3f101fff), - (0x3f103000, 0x3f103fff), - (0x3f104010, 0x3f1fffff), - (0x3f203100, 0x3f203fff), - (0x3f204020, 0x3f204fff), - (0x3f205020, 0x3f20efff), - (0x3f20f080, 0x3f211fff), - (0x3f212008, 0x3f213fff), - (0x3f214100, 0x3f214fff), - (0x3f215100, 0x3f2fffff), - (0x3f300100, 0x3f5fffff), - (0x3f600100, 0x3f803fff), - (0x3f804020, 0x3f804fff), - (0x3f805020, 0x3f8fffff), - (0x3f908000, 0x3f90ffff), - (0x3f918000, 0x3f97ffff), - (0x3f981000, 0x3fdfffff), - (0x3fe00100, 0x3fe04fff), - (0x3fe05100, 0x3fffffff)] + memory_chunks = [ + (0x0, 0x3F002FFF), + (0x3F003020, 0x3F006FFF), + (0x3F008000, 0x3F00B1FF), + (0x3F00B440, 0x3F00B7FF), + (0x3F00BC00, 0x3F0FFFFF), + (0x3F101000, 0x3F101FFF), + (0x3F103000, 0x3F103FFF), + (0x3F104010, 0x3F1FFFFF), + (0x3F203100, 0x3F203FFF), + (0x3F204020, 0x3F204FFF), + (0x3F205020, 0x3F20EFFF), + (0x3F20F080, 0x3F211FFF), + (0x3F212008, 0x3F213FFF), + (0x3F214100, 0x3F214FFF), + (0x3F215100, 0x3F2FFFFF), + (0x3F300100, 0x3F5FFFFF), + (0x3F600100, 0x3F803FFF), + (0x3F804020, 0x3F804FFF), + (0x3F805020, 0x3F8FFFFF), + (0x3F908000, 0x3F90FFFF), + (0x3F918000, 0x3F97FFFF), + (0x3F981000, 0x3FDFFFFF), + (0x3FE00100, 0x3FE04FFF), + (0x3FE05100, 0x3FFFFFFF), + ] dumped_size = 0 i = 0 for i, chunk in enumerate(memory_chunks): dumped_chunk_size = min(memory_size - dumped_size, chunk[1] - chunk[0] + 1) - qmonitor.cmd("pmemsave", {"val": chunk[0], - "size": dumped_chunk_size, - "filename": self.dump_file_name + ".dump." + str(i) - }) + qmonitor.cmd( + "pmemsave", + { + "val": chunk[0], + "size": dumped_chunk_size, + "filename": self.dump_file_name + ".dump." + str(i), + }, + ) dumped_size += dumped_chunk_size if dumped_size < memory_size: - qmonitor.cmd("pmemsave", {"val": 0xcb800000, - "size": memory_size - dumped_size, - "filename": self.dump_file_name + ".dump." 
+ str(i+1) - }) + qmonitor.cmd( + "pmemsave", + { + "val": 0xCB800000, + "size": memory_size - dumped_size, + "filename": self.dump_file_name + ".dump." + str(i + 1), + }, + ) class RISCV(CPU): @@ -276,15 +356,23 @@ def dump_memory(self, qmonitor): memory_size = res["return"]["base-memory"] # Dump different chunks of memory - qmonitor.cmd("pmemsave", {"val": 0x0, - "size": min(memory_size, 0x10000000), - "filename": self.dump_file_name + ".dump.0" - }) + qmonitor.cmd( + "pmemsave", + { + "val": 0x0, + "size": min(memory_size, 0x10000000), + "filename": self.dump_file_name + ".dump.0", + }, + ) if memory_size >= 0x10000000: - qmonitor.cmd("pmemsave", {"val": 0x20000000, - "size": memory_size - 0x10000000, - "filename": self.dump_file_name + ".dump.1" - }) + qmonitor.cmd( + "pmemsave", + { + "val": 0x20000000, + "size": memory_size - 0x10000000, + "filename": self.dump_file_name + ".dump.1", + }, + ) class MIPS_mipssim(MIPS): @@ -293,15 +381,23 @@ def dump_memory(self, qmonitor): memory_size = res["return"]["base-memory"] # Dump different chunks of memory - qmonitor.cmd("pmemsave", {"val": 0x0, - "size": min(memory_size, 0x1fc00000), - "filename": self.dump_file_name + ".dump.0" - }) - if memory_size >= 0x1fc00000: - qmonitor.cmd("pmemsave", {"val": 0x20000000, - "size": min(memory_size - 0x1fc00000, 0xe0000000), - "filename": self.dump_file_name + ".dump.1" - }) + qmonitor.cmd( + "pmemsave", + { + "val": 0x0, + "size": min(memory_size, 0x1FC00000), + "filename": self.dump_file_name + ".dump.0", + }, + ) + if memory_size >= 0x1FC00000: + qmonitor.cmd( + "pmemsave", + { + "val": 0x20000000, + "size": min(memory_size - 0x1FC00000, 0xE0000000), + "filename": self.dump_file_name + ".dump.1", + }, + ) class POWER(CPU): @@ -318,20 +414,33 @@ def parse_log_row(self, data_log, start_time): _, vsid = keys_values[1].split("=") if reg_value not in self.regs[reg_name]: if vsid.strip() == "-1": - self.regs[reg_name][reg_value] = {"first_seen": time_now, "last_seen": time_now, "vsids": {}} + self.regs[reg_name][reg_value] = { + "first_seen": time_now, + "last_seen": time_now, + "vsids": {}, + } else: vsid = int(vsid.strip(), 16) - self.regs[reg_name][reg_value] = {"first_seen": time_now, "last_seen": time_now, - "vsids": {vsid: (time_now, time_now)}} + self.regs[reg_name][reg_value] = { + "first_seen": time_now, + "last_seen": time_now, + "vsids": {vsid: (time_now, time_now)}, + } else: self.regs[reg_name][reg_value]["last_seen"] = time_now if vsid.strip() != "-1": vsid = int(vsid.strip(), 16) if vsid not in self.regs[reg_name][reg_value]["vsids"]: - self.regs[reg_name][reg_value]["vsids"][vsid] = (time_now, time_now) + self.regs[reg_name][reg_value]["vsids"][vsid] = ( + time_now, + time_now, + ) else: first_seen = self.regs[reg_name][reg_value]["vsids"][vsid][0] - self.regs[reg_name][reg_value]["vsids"][vsid] = (first_seen, time_now) + self.regs[reg_name][reg_value]["vsids"][vsid] = ( + first_seen, + time_now, + ) pass else: @@ -342,13 +451,19 @@ def parse_log_row(self, data_log, start_time): self.regs[reg_name][reg_value] = (first_seen, time_now) -parser = argparse.ArgumentParser(description='You have to call QEMU with "-qmp tcp:HOST:PORT,server,' - 'nowait -d fossil -D pipe_file" options and WITHOUT "-enable-kvm" option') +parser = argparse.ArgumentParser( + description='You have to call QEMU with "-qmp tcp:HOST:PORT,server,' + 'nowait -d fossil -D pipe_file" options and WITHOUT "-enable-kvm" option' +) parser.add_argument("pipe_file", help="PIPE for QEMU log file", type=str) parser.add_argument("qmp", 
help="QEMU QMP channel (host:port)", type=str)
 parser.add_argument("prefix_filename", help="Prefix for dump and .regs file.", type=str)
-parser.add_argument("--debug", help="Print debug info", action="store_true", default=False)
-parser.add_argument("--timer", help="Shutdown machine after N seconds", type=int, default=0)
+parser.add_argument(
+    "--debug", help="Print debug info", action="store_true", default=False
+)
+parser.add_argument(
+    "--timer", help="Shutdown machine after N seconds", type=int, default=0
+)
 subparser = parser.add_subparsers(required=True, help="Architectures", dest="arch")
 parser_intel = subparser.add_parser("intel")
 parser_intelq35 = subparser.add_parser("intel_q35")
@@ -481,9 +596,13 @@ def timer_handler():
 if args.timer > 0:
     t = Timer(args.timer, timer_handler)
     t.start()
-    print("After {} seconds it dump the memory, save the registers, "
-          "and shutdown the machine. So wait...".format(str(args.timer)))
+    print(
+        "After {} seconds it dumps the memory, saves the registers, "
+        "and shuts down the machine. Please wait...".format(str(args.timer))
+    )
 else:
-    print("Press CTRL-C to dump the memory, save the registers, and shutdown the machine")
+    print(
+        "Press CTRL-C to dump the memory, save the registers, and shut down the machine"
+    )
 
 log_class.parse_log_row(data, start_time_g)
diff --git a/requirements.txt b/requirements.txt
index 8c72ce6..d1b1c18 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,9 +1,10 @@
+# mmushell
 wheel
 qmp
 prettytable
 tqdm
 colorama
-git+git://github.com/cea-sec/miasm.git@218492cd10b339a8d47d2fdbd61953fcf954fb8b#egg=miasm
+miasm
 pyparsing
 future
 portion
@@ -15,3 +16,10 @@ cerberus
 sortedcontainers
 numpy
 pyelftools
+
+# documentation
+mkdocs
+mkdocs-material
+mkdocs-git-revision-date-localized-plugin
+mkdocs-git-committers-plugin-2
+mkdocstrings[python]