diff --git a/probert/storage.py b/probert/storage.py index 453ddac..27ad6b8 100644 --- a/probert/storage.py +++ b/probert/storage.py @@ -21,7 +21,7 @@ import subprocess from probert.utils import ( - read_sys_block_size_bytes, + read_sys_devpath_size_bytes, sane_block_devices, udev_get_attributes, ) @@ -135,11 +135,12 @@ def _extract_partition_table(devname): blockdev = {} for device in interesting_storage_devs(context): devname = device.properties['DEVNAME'] + devpath = device.properties['DEVPATH'] attrs = udev_get_attributes(device) # update the size attr as it may only be the number # of blocks rather than size in bytes. - attrs['size'] = \ - str(read_sys_block_size_bytes(devname)) + attrs['size'] = str(read_sys_devpath_size_bytes( + devpath, log_inexistent=True)) # When dereferencing device[prop], pyudev calls bytes.decode(), which # can fail if the value is invalid utf-8. We don't want a single # invalid value to completely prevent probing. So we iterate diff --git a/probert/tests/test_utils.py b/probert/tests/test_utils.py index 1cd10d3..7cc1b9e 100644 --- a/probert/tests/test_utils.py +++ b/probert/tests/test_utils.py @@ -1,13 +1,14 @@ import contextlib import logging import os +import pathlib import tempfile import textwrap import unittest from unittest.mock import call from probert import utils -from probert.tests.helpers import random_string, simple_mocked_open +from probert.tests.helpers import random_string class ProbertTestUtils(unittest.TestCase): @@ -45,23 +46,67 @@ def test_utils_dict_merge_dicts(self): def test_utils_read_sys_block_size_bytes(self): devname = random_string() - expected_fname = '/sys/class/block/%s/size' % devname + expected_path = pathlib.Path(f'/sys/class/block/{devname}/size') expected_bytes = 10737418240 content = '20971520' - with simple_mocked_open(content=content) as m_open: + + with unittest.mock.patch("probert.utils.Path.read_text", + autospec=True, + return_value=content) as m_read_text: result = utils.read_sys_block_size_bytes(devname) self.assertEqual(expected_bytes, result) - self.assertEqual([call(expected_fname)], m_open.call_args_list) + m_read_text.assert_called_once() + self.assertEqual([call(expected_path)], m_read_text.call_args_list) def test_utils_read_sys_block_size_bytes_strips_value(self): devname = random_string() - expected_fname = '/sys/class/block/%s/size' % devname + expected_path = pathlib.Path(f'/sys/class/block/{devname}/size') expected_bytes = 10737418240 content = ' 20971520 \n ' - with simple_mocked_open(content=content) as m_open: + + with unittest.mock.patch("probert.utils.Path.read_text", + autospec=True, + return_value=content) as m_read_text: result = utils.read_sys_block_size_bytes(devname) self.assertEqual(expected_bytes, result) - self.assertEqual([call(expected_fname)], m_open.call_args_list) + m_read_text.assert_called_once() + self.assertEqual([call(expected_path)], m_read_text.call_args_list) + + def test_utils_read_sys_devpath_size_bytes_strips_value(self): + devpath = """\ +/devices/pci0000:00/0000:00:1d.0/0000:03:00.0/nvme/nvme0/nvme0n1/nvme0n1p3""" + expected_path = pathlib.Path(f'/sys{devpath}/size') + expected_bytes = 10737418240 + content = ' 20971520 \n ' + + with unittest.mock.patch("probert.utils.Path.read_text", + autospec=True, + return_value=content) as m_read_text: + result = utils.read_sys_devpath_size_bytes(devpath) + self.assertEqual(expected_bytes, result) + self.assertEqual([call(expected_path)], m_read_text.call_args_list) + + def test_utils_read_sys_devpath_size_bytes__inexistent_nologging(self): + with self.assertRaises(FileNotFoundError): + utils.read_sys_devpath_size_bytes("/devices/that/does/not/exist") + + def test_utils_read_sys_devpath_size_bytes__existent_directory(self): + with self.assertRaises(FileNotFoundError) as cm_exc: + # /sys/devices/ should not exist but /sys/devices should + with self.assertLogs("probert.utils", level="WARNING") as cm_log: + utils.read_sys_devpath_size_bytes("/devices", log_inexistent=True) + self.assertEqual("%s contains %s", cm_log.records[0].msg) + path, child_paths = cm_log.records[0].args + self.assertEqual(pathlib.Path("/sys/devices"), path) + for child in child_paths: + self.assertIsInstance(child, pathlib.Path) + self.assertIsNone(cm_exc.exception.__context__) + + def test_utils_read_sys_devpath_size_bytes__inexistent_directory(self): + with self.assertRaises(FileNotFoundError) as cm_exc: + utils.read_sys_devpath_size_bytes("/devices/that/does/not/exist/nvme0n1p3", + log_inexistent=True) + self.assertIsInstance(cm_exc.exception.__context__, FileNotFoundError) @contextlib.contextmanager diff --git a/probert/utils.py b/probert/utils.py index 0e08cc1..69297d7 100644 --- a/probert/utils.py +++ b/probert/utils.py @@ -7,6 +7,7 @@ import re import shlex import subprocess +from pathlib import Path from subprocess import PIPE import pyudev @@ -293,14 +294,28 @@ def parse_etc_network_interfaces(ifaces, contents, path): ifaces[iface]['auto'] = False -def read_sys_block_size_bytes(device): - """ /sys/class/block//size and return integer value in bytes""" - device_dir = os.path.join('/sys/class/block', os.path.basename(device)) - blockdev_size = os.path.join(device_dir, 'size') - with open(blockdev_size) as d: - size = int(d.read().strip()) * SECTOR_SIZE_BYTES - - return size +def read_sys_block_size_bytes(device: str) -> int: + """ /sys/class/block//size and return integer value in bytes. + NOTE: if you are not sure whether the /sys/class/block/ directory + exists, consider using read_sys_devpath_size_bytes instead. """ + path = Path("/sys/class/block") / os.path.basename(device) / "size" + return int(path.read_text().strip()) * SECTOR_SIZE_BYTES + + +def read_sys_devpath_size_bytes(devpath: Path | str, + log_inexistent=False) -> int: + """ Based on the value of a DEVPATH udev property, return the associated + size (converted to bytes) by reading from the sysfs. """ + path = Path("/sys") / Path(devpath).relative_to("/") / "size" + + try: + return int(path.read_text().strip()) * SECTOR_SIZE_BYTES + except FileNotFoundError: + if log_inexistent: + # path.parent.iterdir can raise another FileNotFoundError exception + # This is fine, we will know it means the directory does not exist. + log.warning("%s contains %s", path.parent, list(path.parent.iterdir())) + raise def read_sys_block_slaves(device):