Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 153 additions & 0 deletions set/set.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Copyright The Perses Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package set

import (
"cmp"
"encoding/json"
"maps"
"reflect"
"slices"
"strings"
)

type Set[T comparable] map[T]struct{}

func NewSet[T comparable](vals ...T) Set[T] {
s := Set[T]{}
for _, v := range vals {
s[v] = struct{}{}
}
return s
}

func MergeSet[T comparable](old, new Set[T]) Set[T] {
if new == nil {
return old
}
if old == nil {
return new
}
s := Set[T]{}
maps.Copy(s, new)
maps.Copy(s, old)
return s
}

func (s Set[T]) Add(vals ...T) {
for _, v := range vals {
s[v] = struct{}{}
}
}

func (s Set[T]) Remove(value T) {
delete(s, value)
}

func (s Set[T]) Contains(value T) bool {
_, ok := s[value]
return ok
}

func (s Set[T]) Merge(other Set[T]) {
for v := range other {
s.Add(v)
}
}

func (s Set[T]) TransformAsSlice() []T {
if s == nil {
return nil
}

var slice []T
for v := range s {
slice = append(slice, v)
}
slices.SortFunc(slice, compare[T])

return slice
}

func (s Set[T]) MarshalJSON() ([]byte, error) {
if len(s) == 0 {
return []byte("[]"), nil
}

return json.Marshal(s.TransformAsSlice())
}

func (s *Set[T]) UnmarshalJSON(b []byte) error {
var slice []T
if err := json.Unmarshal(b, &slice); err != nil {
return err
}
if len(slice) == 0 {
return nil
}
*s = make(map[T]struct{}, len(slice))
for _, v := range slice {
s.Add(v)
}
return nil
}

// compare has similar semantics to cmp.Compare except that it works for
// strings and structs. When comparing Go structs, it only checks the struct
// fields of string type.
// If the compared values aren't strings or structs, they are considered equal.
func compare[T comparable](a, b T) int {
switch reflect.TypeOf(a).Kind() {
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
return cmp.Compare(
reflect.ValueOf(a).Int(),
reflect.ValueOf(b).Int(),
)
case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return cmp.Compare(
reflect.ValueOf(a).Uint(),
reflect.ValueOf(b).Uint(),
)
case reflect.Float32, reflect.Float64:
return cmp.Compare(
reflect.ValueOf(a).Float(),
reflect.ValueOf(b).Float(),
)
case reflect.String:
return cmp.Compare(
reflect.ValueOf(a).String(),
reflect.ValueOf(b).String(),
)
case reflect.Struct:
return cmp.Compare(
buildKey(reflect.ValueOf(a)),
buildKey(reflect.ValueOf(b)),
)
}

return 0
}

func buildKey(v reflect.Value) string {
var key strings.Builder
for i := 0; i < v.Type().NumField(); i++ {
v := v.Field(i)
if v.Type().Kind() != reflect.String {
continue
}
key.WriteString(v.String())
}

return key.String()
}
109 changes: 109 additions & 0 deletions set/set_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
// Copyright The Perses Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package set

import (
"encoding/json"
"reflect"
"testing"
)

func TestNewSet(t *testing.T) {
set := NewSet(1, 2, 3)
if len(set) != 3 {
t.Errorf("Expected set length 3, got %d", len(set))
}
if !set.Contains(1) || !set.Contains(2) || !set.Contains(3) {
t.Errorf("Set does not contain expected elements")
}
}

func TestMergeSet(t *testing.T) {
set1 := NewSet(1, 2)
set2 := NewSet(3, 4)
merged := MergeSet(set1, set2)
if len(merged) != 4 {
t.Errorf("Expected merged set length 4, got %d", len(merged))
}
}

func TestSetAdd(t *testing.T) {
set := NewSet(1, 2)
set.Add(3)
if !set.Contains(3) {
t.Errorf("Set does not contain added element")
}
}

func TestSetRemove(t *testing.T) {
set := NewSet(1, 2, 3)
set.Remove(2)
if set.Contains(2) {
t.Errorf("Set still contains removed element")
}
}

func TestSetContains(t *testing.T) {
set := NewSet(1, 2, 3)
if !set.Contains(2) {
t.Errorf("Set does not contain expected element")
}
if set.Contains(4) {
t.Errorf("Set contains unexpected element")
}
}

func TestSetMerge(t *testing.T) {
set1 := NewSet(1, 2)
set2 := NewSet(3, 4)
set1.Merge(set2)
if len(set1) != 4 {
t.Errorf("Expected merged set length 4, got %d", len(set1))
}
}

func TestSetTransformAsSlice(t *testing.T) {
set := NewSet(3, 1, 2)
slice := set.TransformAsSlice()
if len(slice) != 3 {
t.Errorf("Expected slice length 3, got %d", len(slice))
}
if !reflect.DeepEqual(slice, []int{1, 2, 3}) {
t.Errorf("Expected sorted slice [1, 2, 3], got %v", slice)
}
}

func TestSetMarshalJSON(t *testing.T) {
set := NewSet(1, 2, 3)
data, err := json.Marshal(set)
if err != nil {
t.Errorf("Failed to marshal JSON: %v", err)
}
expected := "[1,2,3]"
if string(data) != expected {
t.Errorf("Expected JSON %s, got %s", expected, string(data))
}
}

func TestSetUnmarshalJSON(t *testing.T) {
data := "[1,2,3]"
var set Set[int]
err := json.Unmarshal([]byte(data), &set)
if err != nil {
t.Errorf("Failed to unmarshal JSON: %v", err)
}
if len(set) != 3 || !set.Contains(1) || !set.Contains(2) || !set.Contains(3) {
t.Errorf("Set does not contain expected elements after unmarshalling")
}
}