diff --git a/goose_std.go b/goose_std.go index 987231e..694e8ef 100644 --- a/goose_std.go +++ b/goose_std.go @@ -69,6 +69,22 @@ func SumAssumeNoOverflow(x uint64, y uint64) uint64 { return x + y } +// MulNoOverflow returns true if x * y does not overflow +func MulNoOverflow(x uint64, y uint64) bool { + if x == 0 || y == 0 { + return true + } + return x <= (1<<64-1)/y +} + +// MulAssumeNoOverflow returns x * y, `Assume`ing that this does not overflow. +// +// *Use with care* - if the assumption is violated this function will panic. +func MulAssumeNoOverflow(x uint64, y uint64) uint64 { + primitive.Assume(MulNoOverflow(x, y)) + return x * y +} + // JoinHandle is a mechanism to wait for a goroutine to finish. Calling `Join()` // on the handle returned by `Spawn(f)` will wait for f to finish. type JoinHandle struct { diff --git a/goose_std_test.go b/goose_std_test.go index 265089d..802e56b 100644 --- a/goose_std_test.go +++ b/goose_std_test.go @@ -84,6 +84,21 @@ func TestSumAssumeNoOverflow(t *testing.T) { }) } +func TestMulAssumeNoOverflow(t *testing.T) { + assert := assert.New(t) + + assert.Equal(uint64(6), MulAssumeNoOverflow(2, 3)) + assert.Equal(uint64(0), MulAssumeNoOverflow(0, 3)) + assert.Equal(uint64(1<<64-1), MulAssumeNoOverflow(1<<32-1, 1<<32+1)) + assert.Equal(uint64(1<<64-1), MulAssumeNoOverflow(1<<32+1, 1<<32-1)) + assert.Panics(func() { + MulAssumeNoOverflow(1<<63, 2) + }) + assert.Panics(func() { + MulAssumeNoOverflow(2, 1<<63) + }) +} + func TestMultipar(t *testing.T) { ch := make(chan uint64) go Multipar(5, func(i uint64) {