Kaynağa Gözat

Merge pull request #160 from lengyhua/main

fix: thread safety issue and len inaccuracy in Map
Yeuoly 6 ay önce
ebeveyn
işleme
fdf260cc41

+ 30 - 2
internal/utils/mapping/sync.go

@@ -8,9 +8,12 @@ import (
 type Map[K comparable, V any] struct {
 	len   int32
 	store sync.Map
+	mu    sync.RWMutex
 }
 
 func (m *Map[K, V]) Load(key K) (value V, ok bool) {
+	m.mu.RLock()
+	defer m.mu.RUnlock()
 	v, ok := m.store.Load(key)
 	if !ok {
 		return
@@ -21,12 +24,25 @@ func (m *Map[K, V]) Load(key K) (value V, ok bool) {
 }
 
 func (m *Map[K, V]) Store(key K, value V) {
-	atomic.AddInt32(&m.len, 1)
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	// If the key already exists, we don't want to increment the length
+	_, loaded := m.store.Load(key)
+	if !loaded {
+		atomic.AddInt32(&m.len, 1)
+	}
 	m.store.Store(key, value)
 }
 
 func (m *Map[K, V]) Delete(key K) {
-	atomic.AddInt32(&m.len, -1)
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	_, loaded := m.store.Load(key)
+	// If the key exists, we want to decrement the length
+	// If the key does not exist, we don't want to decrement the length
+	if loaded {
+		atomic.AddInt32(&m.len, -1)
+	}
 	m.store.Delete(key)
 }
 
@@ -37,6 +53,9 @@ func (m *Map[K, V]) Range(f func(key K, value V) bool) {
 }
 
 func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	
 	v, loaded := m.store.LoadOrStore(key, value)
 	actual = v.(V)
 	if !loaded {
@@ -46,6 +65,9 @@ func (m *Map[K, V]) LoadOrStore(key K, value V) (actual V, loaded bool) {
 }
 
 func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
 	v, loaded := m.store.LoadAndDelete(key)
 	value = v.(V)
 	if loaded {
@@ -55,12 +77,18 @@ func (m *Map[K, V]) LoadAndDelete(key K) (value V, loaded bool) {
 }
 
 func (m *Map[K, V]) Swap(key K, value V) (actual V, swapped bool) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
 	v, swapped := m.store.Swap(key, value)
 	actual = v.(V)
 	return
 }
 
 func (m *Map[K, V]) Clear() {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+	// Clear the map
 	m.store.Range(func(key, value interface{}) bool {
 		m.store.Delete(key)
 		return true

+ 111 - 0
internal/utils/mapping/sync_test.go

@@ -0,0 +1,111 @@
+package mapping
+
+import (
+	"sync"
+	"testing"
+)
+
+// TestLoadStore validates basic read/write operations
+func TestLoadStore(t *testing.T) {
+	t.Parallel()
+	m := Map[string, int]{}
+
+	// Test initial state
+	if val, ok := m.Load("missing"); ok {
+		t.Fatalf("Unexpected value for missing key: %v", val)
+	}
+
+	// Test basic store
+	m.Store("answer", 42)
+	if val, ok := m.Load("answer"); !ok || val != 42 {
+		t.Errorf("Load after Store failed, got (%v, %v)", val, ok)
+	}
+
+	// Test overwrite
+	prevLen := m.Len()
+	m.Store("answer", 100)
+	if m.Len() != prevLen {
+		t.Error("Overwriting existing key should not change length")
+	}
+}
+
+// TestDelete validates deletion behavior
+func TestDelete(t *testing.T) {
+	t.Parallel()
+	m := Map[string, string]{}
+
+	// Delete non-existent key
+	m.Delete("ghost")
+	if m.Len() != 0 {
+		t.Error("Deleting non-existent key should not affect length")
+	}
+
+	// Delete existing key
+	m.Store("name", "gopher")
+	m.Delete("name")
+	if _, ok := m.Load("name"); ok || m.Len() != 0 {
+		t.Error("Delete failed to remove item")
+	}
+}
+
+// TestConcurrentAccess verifies thread safety
+func TestConcurrentAccess(t *testing.T) {
+	t.Parallel()
+	m := Map[int, float64]{}
+	const workers = 100
+
+	var wg sync.WaitGroup
+	wg.Add(workers)
+	
+	for i := 0; i < workers; i++ {
+		go func(i int) {
+			defer wg.Done()
+			m.Store(i, float64(i)*1.5)
+			m.Load(i)
+			m.Delete(i)
+		}(i)
+	}
+	wg.Wait()
+
+	if m.Len() != 0 {
+		t.Errorf("Expected empty map after concurrent ops, got len %d", m.Len())
+	}
+}
+
+// TestLoadOrStore verifies conditional storage
+func TestLoadOrStore(t *testing.T) {
+	t.Parallel()
+	m := Map[string, interface{}]{}
+
+	// First store
+	val, loaded := m.LoadOrStore("data", []byte{1,2,3})
+	if loaded || val.([]byte)[0] != 1 {
+		t.Error("Initial LoadOrStore failed")
+	}
+
+	// Existing key
+	val, loaded = m.LoadOrStore("data", "new value")
+	if !loaded || len(val.([]byte)) != 3 {
+		t.Error("Existing key LoadOrStore failed")
+	}
+}
+
+
+
+// TestEdgeCases covers special scenarios
+func TestEdgeCases(t *testing.T) {
+	t.Parallel()
+	m := Map[bool, bool]{}
+
+	// Zero value storage
+	m.Store(true, false)
+	if val, _ := m.Load(true); val != false {
+		t.Error("Zero value storage failed")
+	}
+
+	// Clear operation
+	m.Clear()
+	if m.Len() != 0 {
+		t.Error("Clear failed to reset map")
+	}
+}