diff --git a/orderedmap.go b/orderedmap.go index 9c013ec..816fcc6 100644 --- a/orderedmap.go +++ b/orderedmap.go @@ -1,6 +1,9 @@ package orderedmap +import "sync" + type OrderedMap struct { + sync.RWMutex kv map[interface{}]*Element ll list } @@ -22,6 +25,9 @@ func NewOrderedMapWithCapacity(capacity int) *OrderedMap { // Get returns the value for a key. If the key does not exist, the second return // parameter will be false and the value will be nil. func (m *OrderedMap) Get(key interface{}) (interface{}, bool) { + m.RLock() + defer m.RUnlock() + element, ok := m.kv[key] if ok { return element.Value, true @@ -34,6 +40,9 @@ func (m *OrderedMap) Get(key interface{}) (interface{}, bool) { // will be returned. The returned value will be false if the value was replaced // (even if the value was the same). func (m *OrderedMap) Set(key, value interface{}) bool { + m.Lock() + defer m.Unlock() + _, alreadyExist := m.kv[key] if alreadyExist { m.kv[key].Value = value @@ -48,6 +57,9 @@ func (m *OrderedMap) Set(key, value interface{}) bool { // GetOrDefault returns the value for a key. If the key does not exist, returns // the default value instead. func (m *OrderedMap) GetOrDefault(key, defaultValue interface{}) interface{} { + m.RLock() + defer m.RUnlock() + if element, ok := m.kv[key]; ok { return element.Value } @@ -58,6 +70,9 @@ func (m *OrderedMap) GetOrDefault(key, defaultValue interface{}) interface{} { // GetElement returns the element for a key. If the key does not exist, the // pointer will be nil. func (m *OrderedMap) GetElement(key interface{}) *Element { + m.RLock() + defer m.RUnlock() + element, ok := m.kv[key] if ok { return element @@ -68,6 +83,9 @@ func (m *OrderedMap) GetElement(key interface{}) *Element { // Len returns the number of elements in the map. func (m *OrderedMap) Len() int { + m.RLock() + defer m.RUnlock() + return len(m.kv) } @@ -75,6 +93,9 @@ func (m *OrderedMap) Len() int { // replaced it will retain the same position. To ensure most recently set keys // are always at the end you must always Delete before Set. func (m *OrderedMap) Keys() (keys []interface{}) { + m.RLock() + defer m.RUnlock() + keys = make([]interface{}, 0, m.Len()) for el := m.Front(); el != nil; el = el.Next() { keys = append(keys, el.Key) @@ -85,6 +106,9 @@ func (m *OrderedMap) Keys() (keys []interface{}) { // Delete will remove a key from the map. It will return true if the key was // removed (the key did exist). func (m *OrderedMap) Delete(key interface{}) (didDelete bool) { + m.Lock() + defer m.Unlock() + element, ok := m.kv[key] if ok { m.ll.Remove(element) @@ -97,18 +121,27 @@ func (m *OrderedMap) Delete(key interface{}) (didDelete bool) { // Front will return the element that is the first (oldest Set element). If // there are no elements this will return nil. func (m *OrderedMap) Front() *Element { + m.RLock() + defer m.RUnlock() + return m.ll.Front() } // Back will return the element that is the last (most recent Set element). If // there are no elements this will return nil. func (m *OrderedMap) Back() *Element { + m.RLock() + defer m.RUnlock() + return m.ll.Back() } // Copy returns a new OrderedMap with the same elements. // Using Copy while there are concurrent writes may mangle the result. func (m *OrderedMap) Copy() *OrderedMap { + m.RLock() + defer m.RUnlock() + m2 := NewOrderedMapWithCapacity(m.Len()) for el := m.Front(); el != nil; el = el.Next() { @@ -120,6 +153,9 @@ func (m *OrderedMap) Copy() *OrderedMap { // Has checks if a key exists in the map. func (m *OrderedMap) Has(key interface{}) bool { + m.RLock() + defer m.RUnlock() + _, exists := m.kv[key] return exists } diff --git a/orderedmap_test.go b/orderedmap_test.go index 153f60c..e7f17c0 100644 --- a/orderedmap_test.go +++ b/orderedmap_test.go @@ -3,6 +3,7 @@ package orderedmap_test import ( "fmt" "strconv" + "sync" "testing" "github.com/elliotchance/orderedmap" @@ -1168,6 +1169,31 @@ func BenchmarkBigOrderedMapString_Has(b *testing.B) { benchmarkBigOrderedMapString_Has()(b) } +func TestThreadSafe(t *testing.T) { + m := orderedmap.NewOrderedMap() + wg := &sync.WaitGroup{} + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 1000000; i++ { + _ = m.Set(i, i) + } + }() + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < 1000000; i++ { + _ = m.Set(i, i) + } + }() + wg.Wait() + for i := 0; i < 1000000; i++ { + if !m.Has(i) { + t.Fail() + } + } +} + func BenchmarkAll(b *testing.B) { b.Run("BenchmarkOrderedMap_Keys", BenchmarkOrderedMap_Keys)