From 42b34c5884d7cfa6cc322b5a428764b7d3483d67 Mon Sep 17 00:00:00 2001 From: Avi Deitcher Date: Mon, 5 Jan 2026 12:06:41 +0200 Subject: [PATCH] properly handle partition index with get and create Signed-off-by: Avi Deitcher --- disk/disk.go | 83 ++++++++++++++++++++++++----------------------- disk/disk_test.go | 36 ++++++++++---------- disk/error.go | 14 ++++++++ 3 files changed, 75 insertions(+), 58 deletions(-) diff --git a/disk/disk.go b/disk/disk.go index 139dedd..30a6d51 100644 --- a/disk/disk.go +++ b/disk/disk.go @@ -16,6 +16,7 @@ import ( "github.com/diskfs/go-diskfs/filesystem/iso9660" "github.com/diskfs/go-diskfs/filesystem/squashfs" "github.com/diskfs/go-diskfs/partition" + "github.com/diskfs/go-diskfs/partition/part" log "github.com/sirupsen/logrus" ) @@ -81,23 +82,16 @@ func (d *Disk) Partition(table partition.Table) error { // // returns an error if there was an error writing to the disk, reading from the reader, the table // is invalid, or the partition is invalid -func (d *Disk) WritePartitionContents(part int, reader io.Reader) (int64, error) { +func (d *Disk) WritePartitionContents(partIndex int, reader io.Reader) (int64, error) { backingRwFile, err := d.Backend.Writable() if err != nil { return -1, err } - if d.Table == nil { - return -1, fmt.Errorf("cannot write contents of a partition on a disk without a partition table") - } - if part < 0 { - return -1, fmt.Errorf("cannot write contents of a partition without specifying a partition") - } - partitions := d.Table.GetPartitions() - // API indexes from 1, but slice from 0 - if part > len(partitions) { - return -1, fmt.Errorf("cannot write contents of partition %d which is greater than max partition %d", part, len(partitions)) + foundPart, err := d.GetPartition(partIndex) + if err != nil { + return -1, err } - written, err := partitions[part-1].WriteContents(backingRwFile, reader) + written, err := foundPart.WriteContents(backingRwFile, reader) return int64(written), err } @@ -107,19 +101,12 @@ func (d *Disk) WritePartitionContents(part int, reader io.Reader) (int64, error) // // returns an error if there was an error reading from the disk, writing to the writer, the table // is invalid, or the partition is invalid -func (d *Disk) ReadPartitionContents(part int, writer io.Writer) (int64, error) { - if d.Table == nil { - return -1, fmt.Errorf("cannot read contents of a partition on a disk without a partition table") - } - if part < 0 { - return -1, fmt.Errorf("cannot read contents of a partition without specifying a partition") - } - partitions := d.Table.GetPartitions() - // API indexes from 1, but slice from 0 - if part > len(partitions) { - return -1, fmt.Errorf("cannot read contents of partition %d which is greater than max partition %d", part, len(partitions)) +func (d *Disk) ReadPartitionContents(partIndex int, writer io.Writer) (int64, error) { + foundPart, err := d.GetPartition(partIndex) + if err != nil { + return -1, err } - return partitions[part-1].ReadContents(d.Backend, writer) + return foundPart.ReadContents(d.Backend, writer) } // FilesystemSpec represents the specification of a filesystem to be created @@ -160,14 +147,12 @@ func (d *Disk) CreateFilesystem(spec FilesystemSpec) (filesystem.FileSystem, err case d.Table == nil: return nil, fmt.Errorf("cannot create filesystem on a partition without a partition table") default: - partitions := d.Table.GetPartitions() - // API indexes from 1, but slice from 0 - part := spec.Partition - 1 - if spec.Partition > len(partitions) { - return nil, fmt.Errorf("cannot create filesystem on partition %d greater than maximum partition %d", spec.Partition, len(partitions)) + foundPart, err := d.GetPartition(spec.Partition) + if err != nil { + return nil, err } - size = partitions[part].GetSize() - start = partitions[part].GetStart() + size = foundPart.GetSize() + start = foundPart.GetStart() } switch spec.FSType { @@ -192,7 +177,7 @@ func (d *Disk) CreateFilesystem(spec FilesystemSpec) (filesystem.FileSystem, err // // returns error if there was an error reading the filesystem, or the partition table is invalid and did not // request the entire disk. -func (d *Disk) GetFilesystem(part int) (filesystem.FileSystem, error) { +func (d *Disk) GetFilesystem(partIndex int) (filesystem.FileSystem, error) { // find out where the partition starts and ends, or if it is the entire disk var ( size, start int64 @@ -200,19 +185,18 @@ func (d *Disk) GetFilesystem(part int) (filesystem.FileSystem, error) { ) switch { - case part == 0: + case partIndex == 0: size = d.Size start = 0 case d.Table == nil: return nil, &NoPartitionTableError{} default: - partitions := d.Table.GetPartitions() - // API indexes from 1, but slice from 0 - if part > len(partitions) { - return nil, NewMaxPartitionsExceededError(part, len(partitions)) + foundPart, err := d.GetPartition(partIndex) + if err != nil { + return nil, err } - size = partitions[part-1].GetSize() - start = partitions[part-1].GetStart() + size = foundPart.GetSize() + start = foundPart.GetStart() } // just try each type @@ -242,7 +226,7 @@ func (d *Disk) GetFilesystem(part int) (filesystem.FileSystem, error) { return ext4FS, nil } log.Debugf("ext4 failed: %v", err) - return nil, NewUnknownFilesystemError(part) + return nil, NewUnknownFilesystemError(partIndex) } // Close the disk. Once successfully closed, it can no longer be used. @@ -253,3 +237,22 @@ func (d *Disk) Close() error { *d = Disk{} return nil } + +func (d *Disk) GetPartition(partIndex int) (part.Partition, error) { + if d.Table == nil { + return nil, &NoPartitionTableError{} + } + partitions := d.Table.GetPartitions() + // find the specific partition + var foundPart part.Partition + for _, p := range partitions { + if p.GetIndex() == partIndex { + foundPart = p + break + } + } + if foundPart == nil { + return nil, NewInvalidPartitionError(partIndex) + } + return foundPart, nil +} diff --git a/disk/disk_test.go b/disk/disk_test.go index 6fdcf8a..37353f4 100644 --- a/disk/disk_test.go +++ b/disk/disk_test.go @@ -187,7 +187,7 @@ func TestWritePartitionContents(t *testing.T) { partitionEnd := partitionStart + partitionSize/512 - 1 table := &gpt.Table{ Partitions: []*gpt.Partition{ - {Start: 2048, End: partitionEnd, Type: gpt.EFISystemPartition, Name: "EFI System"}, + {Index: 1, Start: 2048, End: partitionEnd, Type: gpt.EFISystemPartition, Name: "EFI System"}, }, LogicalSectorSize: 512, } @@ -198,9 +198,9 @@ func TestWritePartitionContents(t *testing.T) { err error }{ // various invalid table scenarios - {"no table, write to partition 1", nil, 1, fmt.Errorf("cannot write contents of a partition on a disk without a partition table")}, - {"no table, write to partition 0", nil, 0, fmt.Errorf("cannot write contents of a partition on a disk without a partition table")}, - {"no table, write to partition -1", nil, -1, fmt.Errorf("cannot write contents of a partition on a disk without a partition table")}, + {"no table, write to partition 1", nil, 1, fmt.Errorf("no partition table found on disk")}, + {"no table, write to partition 0", nil, 0, fmt.Errorf("no partition table found on disk")}, + {"no table, write to partition -1", nil, -1, fmt.Errorf("no partition table found on disk")}, {"good table, write to partition 1", table, 1, nil}, } for _, t2 := range tests { @@ -258,7 +258,7 @@ func TestReadPartitionContents(t *testing.T) { partitionSize := uint64(1000) table := &gpt.Table{ Partitions: []*gpt.Partition{ - {Start: partitionStart, Size: partitionSize * 512, Type: gpt.LinuxFilesystem}, + {Index: 1, Start: partitionStart, Size: partitionSize * 512, Type: gpt.LinuxFilesystem}, }, LogicalSectorSize: 512, } @@ -269,12 +269,12 @@ func TestReadPartitionContents(t *testing.T) { err error }{ // various invalid table scenarios - {"no table, partition 1", nil, 1, fmt.Errorf("cannot read contents of a partition on a disk without a partition table")}, - {"no table, partition 0", nil, 0, fmt.Errorf("cannot read contents of a partition on a disk without a partition table")}, - {"no table, partition -1", nil, -1, fmt.Errorf("cannot read contents of a partition on a disk without a partition table")}, + {"no table, partition 1", nil, 1, fmt.Errorf("no partition table found on disk")}, + {"no table, partition 0", nil, 0, fmt.Errorf("no partition table found on disk")}, + {"no table, partition -1", nil, -1, fmt.Errorf("no partition table found on disk")}, // invalid partition number scenarios - {"good table, partition -1", table, -1, fmt.Errorf("cannot read contents of a partition without specifying a partition")}, - {"good table, partition greater than max", table, 5, fmt.Errorf("cannot read contents of partition %d which is greater than max partition %d", 5, 1)}, + {"good table, partition -1", table, -1, fmt.Errorf("requested partition -1 not found")}, + {"good table, partition greater than max", table, 5, fmt.Errorf("requested partition %d not found", 5)}, {"good table, good partition 1", table, 1, nil}, } for _, t2 := range tests { @@ -325,7 +325,7 @@ func TestReadPartitionContents(t *testing.T) { partitionSize := uint32(1000) table := &mbr.Table{ Partitions: []*mbr.Partition{ - {Start: partitionStart, Size: partitionSize}, + {Index: 1, Start: partitionStart, Size: partitionSize}, }, LogicalSectorSize: 512, } @@ -336,12 +336,12 @@ func TestReadPartitionContents(t *testing.T) { err error }{ // various invalid table scenarios - {"no table partition 1", nil, 1, fmt.Errorf("cannot read contents of a partition on a disk without a partition table")}, - {"no table partition 0", nil, 0, fmt.Errorf("cannot read contents of a partition on a disk without a partition table")}, - {"no table partition -1", nil, -1, fmt.Errorf("cannot read contents of a partition on a disk without a partition table")}, + {"no table partition 1", nil, 1, fmt.Errorf("no partition table found on disk")}, + {"no table partition 0", nil, 0, fmt.Errorf("no partition table found on disk")}, + {"no table partition -1", nil, -1, fmt.Errorf("no partition table found on disk")}, // invalid partition number scenarios - {"valid table partition -1", table, -1, fmt.Errorf("cannot read contents of a partition without specifying a partition")}, - {"valid table partition 5", table, 5, fmt.Errorf("cannot read contents of partition %d which is greater than max partition %d", 5, 1)}, + {"valid table partition -1", table, -1, fmt.Errorf("requested partition -1 not found")}, + {"valid table partition 5", table, 5, fmt.Errorf("requested partition %d not found", 5)}, {"valid table partition 1", table, 1, nil}, } for _, t2 := range tests { @@ -472,7 +472,7 @@ func TestCreateFilesystem(t *testing.T) { partitionSize := uint32(20480) table := &mbr.Table{ Partitions: []*mbr.Partition{ - {Start: partitionStart, Size: partitionSize}, + {Index: 1, Start: partitionStart, Size: partitionSize}, }, LogicalSectorSize: 512, } @@ -597,7 +597,7 @@ func TestGetFilesystem(t *testing.T) { partitionSize := uint32(20480) table := &mbr.Table{ Partitions: []*mbr.Partition{ - {Start: partitionStart, Size: partitionSize}, + {Index: 1, Start: partitionStart, Size: partitionSize}, }, LogicalSectorSize: 512, } diff --git a/disk/error.go b/disk/error.go index ea4c0e9..90f2687 100644 --- a/disk/error.go +++ b/disk/error.go @@ -37,3 +37,17 @@ func NewMaxPartitionsExceededError(requested, maxPart int) *MaxPartitionsExceede max: maxPart, } } + +type InvalidPartitionError struct { + requested int +} + +func (e *InvalidPartitionError) Error() string { + return fmt.Sprintf("requested partition %d not found", e.requested) +} + +func NewInvalidPartitionError(requested int) *InvalidPartitionError { + return &InvalidPartitionError{ + requested: requested, + } +}