diff --git a/.github/workflows/ci-tests.yml b/.github/workflows/ci-tests.yml index b4b199de..7c514b55 100644 --- a/.github/workflows/ci-tests.yml +++ b/.github/workflows/ci-tests.yml @@ -114,6 +114,10 @@ jobs: database: redis version: 7.0.0 instancetype: TLS + - storagetype: kv + database: none + version: none + instancetype: none steps: - name: Generate GitHub App token diff --git a/Taskfile.yml b/Taskfile.yml index 731f9437..e05eda32 100644 --- a/Taskfile.yml +++ b/Taskfile.yml @@ -62,3 +62,11 @@ tasks: DB_VERSION: "{{.DB_VERSION}}" - mkdir -p coverage/temporal - cp temporal/*.cov coverage/temporal/ + + test-kv: + desc: "Run tests for kv storage" + cmds: + - task: run-tests + vars: + STORAGE_TYPE: kv + DB: "none" diff --git a/kv/config.go b/kv/config.go new file mode 100644 index 00000000..f8e6ed56 --- /dev/null +++ b/kv/config.go @@ -0,0 +1,60 @@ +package kv + +import ( + "encoding/json" +) + +// Config represents the top-level "kv" configuration block in component configs. +// It contains global settings and named store definitions. +// +// Example JSON structure: +// +// { +// "kv": { +// "cache": {"enabled": true, "ttl": "60s"}, +// "stores": { +// "vault-prod": {"type": "vault", "required": true, "config": {...}} +// } +// } +// } +type Config struct { + Stores map[string]StoreConfig `json:"stores"` + Cache CacheConfig `json:"cache"` +} + +// StoreConfig defines the configuration for a single named KV store instance. +type StoreConfig struct { + // Type specifies which provider factory to use. + Type ProviderType `json:"type"` + + // Required determines startup behavior if the store fails to initialize. + Required bool `json:"required"` + + // Config contains provider-specific configuration as raw JSON. + // Each provider's factory knows how to parse its own config format. + Config json.RawMessage `json:"config"` +} + +// CacheConfig controls the caching behavior for resolved secrets. 
+type CacheConfig struct { + // Enabled controls whether resolved secrets are cached in memory + Enabled bool `json:"enabled"` + + // TTL specifies how long cached values remain valid before refresh. + // Format: Go duration string (e.g., "60s", "5m", "1h") + TTL string `json:"ttl"` + + // RefreshBeforeExpiry specifies the threshold before TTL expiration when a background + // refresh is proactively triggered. It must be less than TTL. + // Format: Go duration string (e.g., "10s"). 0s or empty disables background refresh. + RefreshBeforeExpiry string `json:"refresh_before_expiry"` + + // NegativeTTLNotFound specifies how long to cache "key not found" errors. + // This is typically longer than transient errors as missing keys rarely resolve quickly. + NegativeTTLNotFound string `json:"negative_ttl_not_found"` + + // NegativeTTLTransient specifies how long to cache transient provider errors + // (e.g., network timeouts, service unavailable) to prevent hammering a failing provider. + // This should typically be short to allow quick recovery. + NegativeTTLTransient string `json:"negative_ttl_transient"` +} diff --git a/kv/errors.go b/kv/errors.go new file mode 100644 index 00000000..8baf8e2d --- /dev/null +++ b/kv/errors.go @@ -0,0 +1,48 @@ +package kv + +import ( + "errors" + "fmt" +) + +var ( + // ErrStoreNotFound is returned when referencing an unregistered store name. + ErrStoreNotFound = errors.New("store not found") + + // ErrContractViolation indicates that an underlying KV provider returned data + // violates the expected API contract (e.g., type assertion failures) + ErrContractViolation = errors.New("provider contract violation") + + // ErrStoreClosed is returned when an operation is attempted on closed store or + // provider that has already been shut down via its Close method. 
+ ErrStoreClosed = errors.New("secret store is closed") +) + +func NewStoreNotFoundError(storeName string) error { + return fmt.Errorf("store %q: %w", storeName, ErrStoreNotFound) +} + +// KeyNotFoundError indicates the store is reachable but the key does not exist. +type KeyNotFoundError struct { + StoreName string + KeyPath string +} + +func (e *KeyNotFoundError) Error() string { + return fmt.Sprintf("key %q not found in store %q", e.KeyPath, e.StoreName) +} + +// StoreUnavailableError indicates a transient failure reaching the store. +type StoreUnavailableError struct { + StoreName string + KeyPath string + Err error +} + +func (e *StoreUnavailableError) Error() string { + return fmt.Sprintf("store %q unavailable when fetching key %q: %v", e.StoreName, e.KeyPath, e.Err) +} + +func (e *StoreUnavailableError) Unwrap() error { + return e.Err +} diff --git a/kv/internal/cache/cache.go b/kv/internal/cache/cache.go new file mode 100644 index 00000000..c6a313bc --- /dev/null +++ b/kv/internal/cache/cache.go @@ -0,0 +1,288 @@ +package cache + +import ( + "context" + "errors" + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/TykTechnologies/storage/kv" +) + +const ( + defaultTTL = 60 * time.Second + defaultNegativeTTLNotFound = 60 * time.Second + defaultNegativeTTLTransient = 5 * time.Second +) + +// Cache provides TTL-based in-memory caching for secret values. +// It's thread-safe and automatically expires entries based on configured TTL. +type Cache struct { + entries map[string]*cacheEntry + enabled bool + ttl time.Duration + refreshBeforeExpiry time.Duration + negativeTTLNotFound time.Duration + negativeTTLTransient time.Duration + mu sync.RWMutex + done chan struct{} + closeOnce sync.Once + isClosed atomic.Bool +} + +// cacheEntry holds a Cached value with its expiration time +type cacheEntry struct { + value string + err error + expiresAt time.Time +} + +// Get retrieves a Cached value by key and returns metadata about Cache state. 
+// +// Returns: +// - value: the Cached string value (empty if Cache miss or expired) +// - found: true if a valid (non-expired) Cache entry exists +// - needsRefresh: true if entry exists but is within refreshBeforeExpiry window +// - err: the Cached error from the original fetch operation (nil for successful Cached values) +func (c *Cache) Get(key string) (string, bool, bool, error) { + if !c.enabled || c.isClosed.Load() { + return "", false, false, nil + } + + entry, exists, expired := c.get(key) + if !exists || expired { + return "", false, false, nil + } + + var needsRefresh bool + if c.refreshBeforeExpiry > 0 && time.Until(entry.expiresAt) <= c.refreshBeforeExpiry { + needsRefresh = true + } + + return entry.value, true, needsRefresh, entry.err +} + +func (c *Cache) Set(key, value string, err error) { + if !c.enabled || c.isClosed.Load() { + return + } + + // Context errors should NOT be Cached - they indicate caller abandonment, not provider failure + if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) { + return + } + + ttl, shouldCache := c.selectTTL(err) + if !shouldCache { + return + } + + c.mu.Lock() + defer c.mu.Unlock() + + if c.entries == nil { + return + } + + c.entries[key] = &cacheEntry{ + value: value, + expiresAt: time.Now().Add(ttl), + err: err, + } +} + +// Close stops the cleanup goroutine and releases resources. +// It is fully thread-safe and safe to call multiple times. 
+func (c *Cache) Close() { + if c.done != nil { + c.closeOnce.Do(func() { + c.isClosed.Store(true) + close(c.done) + + c.mu.Lock() + c.entries = nil + c.mu.Unlock() + }) + } +} + +func (c *Cache) selectTTL(err error) (time.Duration, bool) { + if err == nil { + return c.ttl, true + } + + var notFoundErr *kv.KeyNotFoundError + var transientErr *kv.StoreUnavailableError + + if errors.As(err, ¬FoundErr) { + return c.negativeTTLNotFound, true + } + + if errors.As(err, &transientErr) { + return c.negativeTTLTransient, true + } + + return 0, false +} + +func (c *Cache) get(key string) (*cacheEntry, bool, bool) { + now := time.Now() + + c.mu.RLock() + defer c.mu.RUnlock() + + entry, exists := c.entries[key] + + var expired bool + if entry != nil && !now.Before(entry.expiresAt) { + expired = true + } + + return entry, exists, expired +} + +func (c *Cache) cleanupLoop() { + interval := min(c.ttl, c.negativeTTLNotFound, c.negativeTTLTransient) + if interval < time.Second { + interval = time.Second + } + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-c.done: + return + case <-ticker.C: + c.cleanup() + } + } +} + +func (c *Cache) cleanup() { + now := time.Now() + var expired []string + + c.mu.RLock() + for k, v := range c.entries { + if v == nil || !now.Before(v.expiresAt) { + expired = append(expired, k) + } + } + c.mu.RUnlock() + + if len(expired) == 0 { + return + } + + // Recompute the current time under the write lock to avoid deleting entries + // that were written (or renewed) during the window between RUnlock and Lock. + deleteNow := time.Now() + + c.mu.Lock() + for _, k := range expired { + entry := c.entries[k] + if entry == nil || !deleteNow.Before(entry.expiresAt) { + delete(c.entries, k) + } + } + c.mu.Unlock() +} + +// NewCache creates a TTL-based in-memory cache. When Enabled is true, callers +// must call Close() to stop the background cleanup goroutine and release resources. 
+func NewCache(config kv.CacheConfig) (*Cache, error) { + if !config.Enabled { + return &Cache{enabled: false, entries: make(map[string]*cacheEntry)}, nil + } + + ttl, err := parseOptionalDuration(config.TTL, defaultTTL, "ttl") + if err != nil { + return nil, err + } + + refreshBeforeExpiry, err := parseRefreshBeforeExpiry(config.RefreshBeforeExpiry, ttl) + if err != nil { + return nil, err + } + + negativeTTLNotFound, err := parseOptionalDuration( + config.NegativeTTLNotFound, + defaultNegativeTTLNotFound, + "negative_ttl_not_found", + ) + if err != nil { + return nil, err + } + + negativeTTLTransient, err := parseOptionalDuration( + config.NegativeTTLTransient, + defaultNegativeTTLTransient, + "negative_ttl_transient", + ) + if err != nil { + return nil, err + } + + c := &Cache{ + entries: make(map[string]*cacheEntry), + enabled: config.Enabled, + ttl: ttl, + refreshBeforeExpiry: refreshBeforeExpiry, + negativeTTLNotFound: negativeTTLNotFound, + negativeTTLTransient: negativeTTLTransient, + done: make(chan struct{}), + } + + go c.cleanupLoop() + + return c, nil +} + +// parseOptionalDuration parses a duration string, returning a default value if empty. +// It also validates that the parsed duration is strictly positive. +func parseOptionalDuration(val string, defaultVal time.Duration, name string) (time.Duration, error) { + if val == "" { + return defaultVal, nil + } + + d, err := time.ParseDuration(val) + if err != nil { + return 0, fmt.Errorf("invalid cache %s value: %w", name, err) + } + + if d <= 0 { + return 0, fmt.Errorf("cache %s must be positive, got %v", name, val) + } + + return d, nil +} + +// parseRefreshBeforeExpiry parses and validates the refresh_before_expiry configuration. 
+func parseRefreshBeforeExpiry(val string, ttl time.Duration) (time.Duration, error) { + if val == "" { + return 0, nil + } + + d, err := time.ParseDuration(val) + if err != nil { + return 0, fmt.Errorf("invalid cache refresh_before_expiry value: %w", err) + } + + if d < 0 { + return 0, fmt.Errorf("cache refresh_before_expiry must be positive, got %v", val) + } + + if d >= ttl { + return 0, fmt.Errorf( + "refresh_before_expiry(%v) must be less than ttl(%v)", + d, + ttl, + ) + } + + return d, nil +} diff --git a/kv/internal/cache/cache_test.go b/kv/internal/cache/cache_test.go new file mode 100644 index 00000000..988afda7 --- /dev/null +++ b/kv/internal/cache/cache_test.go @@ -0,0 +1,583 @@ +package cache + +import ( + "context" + "fmt" + "sync" + "testing" + "testing/synctest" + "time" + + "github.com/TykTechnologies/storage/kv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCache(t *testing.T) { + t.Parallel() + + t.Run("disable cache", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: false} + c, err := NewCache(cfg) + require.NoError(t, err) + require.NotNil(t, c) + require.False(t, c.enabled) + }) + + t.Run("invalid TTL", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: true, TTL: "invalid"} + c, err := NewCache(cfg) + require.Error(t, err) + require.Nil(t, c) + require.Contains(t, err.Error(), "invalid cache ttl") + }) + + t.Run("negative TTL", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: true, TTL: "-5s"} + c, err := NewCache(cfg) + require.Error(t, err) + require.Nil(t, c) + require.Contains(t, err.Error(), "cache ttl must be positive") + }) + + t.Run("invalid refresh before expiry", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: true, TTL: "1s", RefreshBeforeExpiry: "some"} + c, err := NewCache(cfg) + require.Error(t, err) + require.Nil(t, c) + require.Contains(t, err.Error(), "invalid cache refresh_before_expiry") + }) + + t.Run("negative refresh before expiry", func(t *testing.T) { 
+ cfg := kv.CacheConfig{Enabled: true, TTL: "1s", RefreshBeforeExpiry: "-1s"} + c, err := NewCache(cfg) + require.Error(t, err) + require.Nil(t, c) + require.Contains(t, err.Error(), "refresh_before_expiry must be positive") + }) + + t.Run("invalid negative ttl not found", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: true, TTL: "1s", NegativeTTLNotFound: "some"} + c, err := NewCache(cfg) + require.Error(t, err) + require.Nil(t, c) + require.Contains(t, err.Error(), "invalid cache negative_ttl_not_found") + }) + + t.Run("negative value for negative ttl not found", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: true, TTL: "1s", NegativeTTLNotFound: "-1s"} + c, err := NewCache(cfg) + require.Error(t, err) + require.Nil(t, c) + require.Contains(t, err.Error(), "negative_ttl_not_found must be positive") + }) + + t.Run("invalid negative ttl transient", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: true, TTL: "1s", NegativeTTLTransient: "some"} + c, err := NewCache(cfg) + require.Error(t, err) + require.Nil(t, c) + require.Contains(t, err.Error(), "invalid cache negative_ttl_transient") + }) + + t.Run("negative value for negative ttl transient", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: true, TTL: "1s", NegativeTTLTransient: "-1s"} + c, err := NewCache(cfg) + require.Error(t, err) + require.Nil(t, c) + require.Contains(t, err.Error(), "negative_ttl_transient must be positive") + }) + + t.Run("returns error when refresh before expiry >= ttl", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: true, TTL: "1s", RefreshBeforeExpiry: "1s"} + c, err := NewCache(cfg) + require.Error(t, err) + require.Contains(t, err.Error(), "must be less than ttl") + require.Nil(t, c) + }) + + t.Run("sets correct defaults", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: true} + c, err := NewCache(cfg) + require.NoError(t, err) + require.NotNil(t, c) + require.Equal(t, defaultTTL, c.ttl) + require.Empty(t, c.refreshBeforeExpiry) + require.Equal(t, 
defaultNegativeTTLNotFound, c.negativeTTLNotFound) + require.Equal(t, defaultNegativeTTLTransient, c.negativeTTLTransient) + }) + + t.Run("valid config", func(t *testing.T) { + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "100ms", + RefreshBeforeExpiry: "50ms", + NegativeTTLNotFound: "20s", + NegativeTTLTransient: "2s", + } + c := newTestCache(t, cfg) + require.NotNil(t, c) + require.Equal(t, 100*time.Millisecond, c.ttl) + require.Equal(t, 50*time.Millisecond, c.refreshBeforeExpiry) + require.Equal(t, 20*time.Second, c.negativeTTLNotFound) + require.Equal(t, 2*time.Second, c.negativeTTLTransient) + require.NotNil(t, c.entries) + }) +} + +func TestCache_GetSet(t *testing.T) { + t.Parallel() + + defaultConfig := kv.CacheConfig{ + Enabled: true, + TTL: "500ms", + } + c, err := NewCache(defaultConfig) + require.NoError(t, err) + + t.Run("cache disabled", func(t *testing.T) { + cfg := kv.CacheConfig{Enabled: false, TTL: "500ms", RefreshBeforeExpiry: "200ms"} + c := newTestCache(t, cfg) + require.NoError(t, err) + + c.Set("cache-disabled", "some", nil) + + val, exists, needsRefresh, err := c.Get("cache-disabled") + assert.False(t, exists) + assert.False(t, needsRefresh) + assert.Empty(t, val) + assert.NoError(t, err) + }) + + t.Run("cache miss", func(t *testing.T) { + val, exists, needsRefresh, err := c.Get("non-existent") + assert.False(t, exists) + assert.False(t, needsRefresh) + assert.Empty(t, val) + assert.NoError(t, err) + }) + + t.Run("cache hit", func(t *testing.T) { + c.Set("key1", "value1", nil) + + val, exists, needsRefresh, err := c.Get("key1") + assert.True(t, exists) + assert.False(t, needsRefresh) + assert.Equal(t, "value1", val) + assert.NoError(t, err) + }) + + t.Run("negative caching with KeyNotFoundError", func(t *testing.T) { + expectedErr := &kv.KeyNotFoundError{} + c.Set("key2", "", expectedErr) + + val, exists, needsRefresh, err := c.Get("key2") + assert.True(t, exists) + assert.False(t, needsRefresh) + assert.Empty(t, val) + assert.ErrorAs(t, 
err, &expectedErr) + + entry, _, _ := c.get("key2") + expectedTTL := time.Now().Add(defaultNegativeTTLNotFound) + require.WithinDuration(t, expectedTTL, entry.expiresAt, time.Second) + }) + + t.Run("negative caching with StoreUnavailableError", func(t *testing.T) { + expectedErr := &kv.StoreUnavailableError{} + c.Set("key3", "", expectedErr) + + val, exists, needsRefresh, err := c.Get("key3") + assert.True(t, exists) + assert.False(t, needsRefresh) + assert.Empty(t, val) + assert.ErrorAs(t, err, &expectedErr) + + entry, _, _ := c.get("key3") + expectedTTL := time.Now().Add(defaultNegativeTTLTransient) + require.WithinDuration(t, expectedTTL, entry.expiresAt, time.Second) + }) + + t.Run("negative caching is disabled for context errors", func(t *testing.T) { + c.Set("key4", "value4", context.Canceled) + c.Set("key5", "value5", context.DeadlineExceeded) + + val, exists, needsRefresh, err := c.Get("key4") + assert.False(t, exists) + assert.False(t, needsRefresh) + assert.Empty(t, val) + assert.NoError(t, err) + + val, exists, needsRefresh, err = c.Get("key5") + assert.False(t, exists) + assert.False(t, needsRefresh) + assert.Empty(t, val) + assert.NoError(t, err) + + entry, _, _ := c.get("key4") + require.Empty(t, entry) + + entry, _, _ = c.get("key5") + require.Empty(t, entry) + }) +} + +func TestCache_RefreshBeforeExpiry(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "2s", + RefreshBeforeExpiry: "1s", + } + c := newTestCache(t, cfg) + c.Set("key1", "value1", nil) + time.Sleep(time.Second) + + val, exists, needsRefresh, err := c.Get("key1") + assert.True(t, exists) + assert.True(t, needsRefresh) + assert.Equal(t, "value1", val) + assert.NoError(t, err) + }) +} + +func TestCache_RefreshBeforeExpiryBoundary(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "1s", + RefreshBeforeExpiry: "500ms", + } + c := newTestCache(t, cfg) 
+ c.Set("key1", "value1", nil) + + // Just before refresh window + time.Sleep(490 * time.Millisecond) + synctest.Wait() + + _, _, needsRefresh, err := c.Get("key1") + assert.NoError(t, err) + assert.False(t, needsRefresh, "Should not need refresh yet") + + time.Sleep(20 * time.Millisecond) + synctest.Wait() + + _, _, needsRefresh, err = c.Get("key1") + assert.NoError(t, err) + assert.True(t, needsRefresh, "Should need refresh now") + }) +} + +func TestCache_ZeroRefreshBeforeExpiry(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "1s", + RefreshBeforeExpiry: "0s", + } + c := newTestCache(t, cfg) + + c.Set("key1", "value1", nil) + + time.Sleep(900 * time.Millisecond) + synctest.Wait() + + _, _, needsRefresh, err := c.Get("key1") + require.NoError(t, err) + assert.False(t, needsRefresh, "Zero refresh window should never trigger refresh") + }) +} + +func TestCache_NegativeCachingBoundaries(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "1s", + NegativeTTLNotFound: "100ms", + } + c := newTestCache(t, cfg) + + c.Set("short-not-found", "", &kv.KeyNotFoundError{}) + + time.Sleep(60 * time.Millisecond) + synctest.Wait() + + _, exists, _, err := c.Get("short-not-found") + require.Error(t, err) + assert.True(t, exists, "Should still exist before negative TTL expires") + + time.Sleep(50 * time.Millisecond) + synctest.Wait() + + _, exists, _, err = c.Get("short-not-found") + require.NoError(t, err) + assert.False(t, exists, "Should expire after negative TTL") + }) +} + +func TestCache_OverwriteExistingEntry(t *testing.T) { + t.Parallel() + + cfg := kv.CacheConfig{Enabled: true, TTL: "1s"} + c := newTestCache(t, cfg) + + // Set initial value + c.Set("key1", "value1", nil) + val, exists, _, err := c.Get("key1") + assert.True(t, exists) + assert.Equal(t, "value1", val) + assert.NoError(t, err) + + // Overwrite with error + 
c.Set("key1", "", &kv.KeyNotFoundError{}) + val, exists, _, err = c.Get("key1") + assert.True(t, exists) + assert.Empty(t, val) + assert.Error(t, err) + + // Overwrite error with success + c.Set("key1", "value2", nil) + val, exists, _, err = c.Get("key1") + assert.True(t, exists) + assert.Equal(t, "value2", val) + assert.NoError(t, err) +} + +func TestCache_UnknownErrorTypes(t *testing.T) { + t.Parallel() + + cfg := kv.CacheConfig{Enabled: true, TTL: "1s"} + c := newTestCache(t, cfg) + + unknownErr := fmt.Errorf("some random error") + c.Set("key1", "value1", unknownErr) + + _, exists, _, err := c.Get("key1") + require.NoError(t, err) + assert.False(t, exists, "Unknown errors should not be cached") +} + +func TestCache_CleanupIntervalScaling(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "100ms", + } + c := newTestCache(t, cfg) + + c.Set("key1", "value1", nil) + + time.Sleep(150 * time.Millisecond) + synctest.Wait() + + // Value is expired but still physically exists + _, exists, _, err := c.Get("key1") + require.NoError(t, err) + assert.False(t, exists) + + _, physicallyExists, _ := c.get("key1") + assert.True(t, physicallyExists, "Should be physically present as cleanup hasn't run yet") + + // Wait for cleanup interval (should be 1s minimum) + time.Sleep(2 * time.Second) + synctest.Wait() + + _, physicallyExists, _ = c.get("key1") + assert.False(t, physicallyExists, "Should be physically removed after cleanup") + }) +} + +func TestCache_CleanupExpiredEntries(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "2s", + NegativeTTLNotFound: "10s", + NegativeTTLTransient: "4s", + } + c := newTestCache(t, cfg) + + testEntries := []struct { + key string + value string + err error + description string + }{ + {"key1", "value1", nil, "normal entry (2s TTL)"}, + {"key2", "value2", &kv.KeyNotFoundError{}, "not found entry (10s 
TTL)"}, + {"key3", "value3", &kv.StoreUnavailableError{}, "transient error entry (5s TTL)"}, + } + + for _, entry := range testEntries { + c.Set(entry.key, entry.value, entry.err) + } + + // Phase 1: Verify all entries are initially present and accessible + assertCacheEntry(t, c, "key1", "value1", true, false, "normal entry should be accessible") + assertCacheEntry(t, c, "key2", "value2", true, true, "not found entry should be accessible with error") + assertCacheEntry(t, c, "key3", "value3", true, true, "transient error entry should be accessible with error") + + // Phase 2: After 2s - normal entry (key1) should expire, negative entries should remain + time.Sleep(2 * time.Second) + synctest.Wait() + + assertCacheEntry(t, c, "key1", "", false, false, "normal entry should be expired") + assertCacheEntry(t, c, "key2", "value2", true, true, "not found entry should still be present") + assertCacheEntry(t, c, "key3", "value3", true, true, "transient error entry should still be present") + + _, exists, _ := c.get("key1") + assert.False(t, exists, "key1 should be removed from internal storage") + + // Phase 3: After 4s total - transient error entry (key3) should expire + time.Sleep(2 * time.Second) + synctest.Wait() + + assertCacheEntry(t, c, "key1", "", false, false, "normal entry should still be expired") + assertCacheEntry(t, c, "key2", "value2", true, true, "not found entry should still be present") + assertCacheEntry(t, c, "key3", "", false, false, "transient error entry should now be expired") + + _, exists, _ = c.get("key3") + assert.False(t, exists, "key3 should be removed from internal storage") + + // Phase 4: After 10s total - not found error entry (key2) should expire + time.Sleep(6 * time.Second) + synctest.Wait() + + assertCacheEntry(t, c, "key1", "", false, false, "normal entry should still be expired") + assertCacheEntry(t, c, "key2", "", false, false, "not found entry should now be expired") + assertCacheEntry(t, c, "key3", "", false, false, "transient 
error entry should still be expired") + + _, exists, _ = c.get("key2") + assert.False(t, exists, "key2 should be removed from internal storage") + }) +} + +func TestCache_Concurrency(t *testing.T) { + t.Parallel() + + cfg := kv.CacheConfig{Enabled: true, TTL: "10m"} + c := newTestCache(t, cfg) + + var wg sync.WaitGroup + + for range 50 { + wg.Go(func() { + for j := range 100 { + key := fmt.Sprintf("key-%d", j) + c.Set(key, "value", nil) + } + }) + } + + for range 50 { + wg.Go(func() { + for j := range 100 { + key := fmt.Sprintf("key-%d", j) + _, _, _, err := c.Get(key) + assert.NoError(t, err) + } + }) + } + + wg.Wait() + + c.mu.RLock() + assert.Greater(t, len(c.entries), 0) + c.mu.RUnlock() +} + +func TestCache_Close(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + cache := newTestCache(t, kv.CacheConfig{Enabled: true, TTL: "1s"}) + + cache.Set("key", "value", nil) + + time.Sleep(time.Second) + synctest.Wait() + + // Value is cleared by normal cleanup work + _, exists, _, err := cache.Get("key") + require.NoError(t, err) + assert.False(t, exists) + + cache.Set("key2", "value2", nil) + + _, exists, _, err = cache.Get("key2") + require.NoError(t, err) + assert.True(t, exists) + + // The cleanup goroutine stopped + cache.Close() + + // Advance time past TTL + time.Sleep(time.Second) + synctest.Wait() + + _, exists, _, err = cache.Get("key2") + require.NoError(t, err) + assert.False(t, exists) + + entry, _, _ := cache.get("key2") + require.Empty(t, entry) + + // Stress test concurrent idempotency of Close() + var wg sync.WaitGroup + for range 10 { + wg.Go(func() { + cache.Close() + }) + } + }) +} + +func newTestCache(t *testing.T, cfg kv.CacheConfig) *Cache { + t.Helper() + + cache, err := NewCache(cfg) + require.NoError(t, err) + t.Cleanup(func() { + cache.Close() + }) + + return cache +} + +func assertCacheEntry( + t *testing.T, + c *Cache, + key, + expectedValue string, + shouldExist, + shouldHaveError bool, + description string, +) { + 
t.Helper() + + val, exists, _, err := c.Get(key) + + if shouldExist { + assert.True(t, exists, "%s: key %s should exist in cache", description, key) + assert.Equal(t, expectedValue, val, "%s: key %s should have expected value", description, key) + } else { + assert.False(t, exists, "%s: key %s should not exist in cache", description, key) + } + + if shouldHaveError { + assert.Error(t, err, "%s: key %s should return an error", description, key) + } else { + assert.NoError(t, err, "%s: key %s should not return an error", description, key) + } +} diff --git a/kv/internal/store/store.go b/kv/internal/store/store.go new file mode 100644 index 00000000..59e6ead0 --- /dev/null +++ b/kv/internal/store/store.go @@ -0,0 +1,180 @@ +package store + +import ( + "context" + "errors" + "fmt" + "sync/atomic" + "time" + + "github.com/TykTechnologies/storage/kv" + "github.com/TykTechnologies/storage/kv/internal/cache" + "golang.org/x/sync/singleflight" +) + +const defaultProviderTimeout = 5 * time.Second + +// SecretStore is an internal decorator that adds caching and singleflight to a Provider. +type SecretStore struct { + name string + provider kv.Provider + cache *cache.Cache + sf *singleflight.Group + sfRefresh *singleflight.Group + isClosed atomic.Bool + timeout time.Duration +} + +// Option defines a functional option for configuring the SecretStore. +type Option func(*SecretStore) + +// Get retrieves a secret value with caching and deduplication. 
+func (s *SecretStore) Get(ctx context.Context, path string) (string, error) { + if s.isClosed.Load() { + return "", kv.ErrStoreClosed + } + + val, exists, needsRefresh, err := s.cache.Get(path) + if exists { + // Fail fast on cached errors + if err != nil { + return "", err + } + + // If value is almost expired on cache, the process should refresh it + // on background which is called "stale-while-revalidate" strategy + if needsRefresh && !s.isClosed.Load() { + s.triggerBackgroundRefreshOnce(path) + } + + return val, err + } + + if s.isClosed.Load() { + return "", kv.ErrStoreClosed + } + + ch := s.sf.DoChan(path, func() (any, error) { + fetchCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), s.timeout) + defer cancel() + + newVal, err := s.provider.Get(fetchCtx, path) + + // Return earlier to prevent cache poisoning with context errors + if errors.Is(err, context.DeadlineExceeded) { + return "", fmt.Errorf("timeout fetching %q: %w", path, err) + } + + if !s.isClosed.Load() { + s.cache.Set(path, newVal, err) + } + + return newVal, err + }) + + select { + case <-ctx.Done(): + return "", ctx.Err() + case res := <-ch: + if res.Err != nil { + return "", res.Err + } + + v, ok := res.Val.(string) + if !ok { + return "", fmt.Errorf( + "%w: path %q returned non-string type", + kv.ErrContractViolation, + path, + ) + } + + return v, nil + } +} + +// Unwrap allows callers to access the underlying provider for optional interfaces (like Lister) +func (s *SecretStore) Unwrap() kv.Provider { + return s.provider +} + +func (s *SecretStore) Close(ctx context.Context) error { + if s.isClosed.Swap(true) { + return nil + } + + s.cache.Close() + + if closer, ok := kv.AsCloser(s.provider); ok { + return closer.Close(ctx) + } + + return nil +} + +func (s *SecretStore) triggerBackgroundRefreshOnce(path string) { + ch := s.sfRefresh.DoChan(path, func() (any, error) { + return s.doBackgroundRefresh(path) + }) + _ = ch +} + +func (s *SecretStore) doBackgroundRefresh(path string) 
(any, error) { + if s.isClosed.Load() { + return "", kv.ErrStoreClosed + } + + // We're creating a new context for background refresh because we don't want + // a cancelled HTTP request to abort a cache refresh that benefits all future callers. + ctx, cancel := context.WithTimeout(context.Background(), s.timeout) + defer cancel() + + newVal, err := s.provider.Get(ctx, path) + // Update the cache on success to ensure errors don't overwrite valid entries. + if err == nil && !s.isClosed.Load() { + s.cache.Set(path, newVal, nil) + } + + return newVal, err +} + +// WithTimeout overrides the global default provider timeout. +func WithTimeout(timeout time.Duration) Option { + return func(store *SecretStore) { + if timeout > 0 { + store.timeout = timeout + } + } +} + +// NewSecretStore instantiates the store wrapper with optional configurations. +func NewSecretStore( + name string, + provider kv.Provider, + cacheConfig kv.CacheConfig, + opts ...Option, +) (*SecretStore, error) { + if provider == nil { + return nil, fmt.Errorf("failed to create a secret store with name %q: provider cannot be nil", name) + } + + cache, err := cache.NewCache(cacheConfig) + if err != nil { + return nil, fmt.Errorf("failed to create secret store: %w", err) + } + + s := &SecretStore{ + name: name, + provider: provider, + cache: cache, + sf: &singleflight.Group{}, + sfRefresh: &singleflight.Group{}, + timeout: defaultProviderTimeout, + } + + for _, opt := range opts { + opt(s) + } + + return s, nil +} diff --git a/kv/internal/store/store_test.go b/kv/internal/store/store_test.go new file mode 100644 index 00000000..f33616b0 --- /dev/null +++ b/kv/internal/store/store_test.go @@ -0,0 +1,669 @@ +package store + +import ( + "context" + "fmt" + "sync" + "sync/atomic" + "testing" + "testing/synctest" + "time" + + "github.com/TykTechnologies/storage/kv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockProvider struct { + calls atomic.Int32 + delay 
// mockProvider is a configurable in-memory provider used by the store tests.
type mockProvider struct {
	// calls counts every Get invocation, including ones that end in an error.
	calls atomic.Int32
	// delay, when positive, is how long Get waits before answering.
	delay time.Duration
	// mockGetFunc, when set, produces the result instead of the default secret.
	mockGetFunc func(ctx context.Context, path string) (string, error)
	// closed is set once Close has been called.
	closed atomic.Bool
}

// Get records the call, waits for the configured delay (or until ctx is done)
// and then answers with mockGetFunc's result or the fixed "mock-secret".
func (m *mockProvider) Get(ctx context.Context, path string) (string, error) {
	m.calls.Add(1)

	if err := m.simulateLatency(ctx); err != nil {
		return "", err
	}

	if stub := m.mockGetFunc; stub != nil {
		return stub(ctx, path)
	}

	return "mock-secret", nil
}

// simulateLatency blocks for m.delay and returns ctx.Err() if ctx finishes first.
func (m *mockProvider) simulateLatency(ctx context.Context) error {
	if m.delay <= 0 {
		return nil
	}

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-time.After(m.delay):
		return nil
	}
}

// Close records that the provider was closed.
func (m *mockProvider) Close(_ context.Context) error {
	m.closed.Store(true)
	return nil
}
err = store.Get(t.Context(), "key1") + require.NoError(t, err) + require.Equal(t, int32(2), provider.calls.Load()) + }) +} + +func TestGet_CacheMissAndHit(t *testing.T) { + t.Parallel() + + provider := &mockProvider{} + cfg := kv.CacheConfig{Enabled: true, TTL: "1m"} + store, err := NewSecretStore("test-store", provider, cfg) + require.NoError(t, err) + + // First call: cache miss + val, err := store.Get(t.Context(), "secret-1") + require.NoError(t, err) + assert.Equal(t, "mock-secret", val) + assert.Equal(t, int32(1), provider.calls.Load(), "cache miss should call provider") + + // Second call: cache hit + val, err = store.Get(t.Context(), "secret-1") + require.NoError(t, err) + assert.Equal(t, "mock-secret", val) + assert.Equal(t, int32(1), provider.calls.Load(), "cache hit should not call provider") +} + +func TestGet_ProviderErrorReturned(t *testing.T) { + t.Parallel() + + expectedErr := &kv.KeyNotFoundError{} + provider := &mockProvider{ + mockGetFunc: func(ctx context.Context, path string) (string, error) { + return "", expectedErr + }, + } + cfg := kv.CacheConfig{Enabled: true, TTL: "1m"} + store, err := NewSecretStore("test-store", provider, cfg) + require.NoError(t, err) + + val, err := store.Get(t.Context(), "secret-err") + require.Error(t, err) + require.ErrorAs(t, err, &expectedErr) + assert.Empty(t, val) + assert.Equal(t, int32(1), provider.calls.Load()) +} + +func TestGet_NegativeCachingForKeyNotFoundError(t *testing.T) { + t.Parallel() + + expectedErr := &kv.KeyNotFoundError{} + provider := &mockProvider{ + mockGetFunc: func(ctx context.Context, path string) (string, error) { + return "secret", expectedErr + }, + } + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "1m", + NegativeTTLNotFound: "30s", + } + store, err := NewSecretStore("test-store", provider, cfg) + require.NoError(t, err) + + val, err := store.Get(t.Context(), "secret-err") + require.Error(t, err) + require.ErrorAs(t, err, &expectedErr) + assert.Empty(t, val, "value should be empty 
even if provider returned non-empty string") + assert.Equal(t, int32(1), provider.calls.Load()) + + val, err = store.Get(t.Context(), "secret-err") + require.Error(t, err) + require.ErrorAs(t, err, &expectedErr) + assert.Empty(t, val) + assert.Equal(t, int32(1), provider.calls.Load(), "cached error should prevent provider call") +} + +func TestGet_SingleFlightDeduplication(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + provider := &mockProvider{ + delay: time.Second, + } + cfg := kv.CacheConfig{Enabled: true, TTL: "10s"} + store, err := NewSecretStore("test-store", provider, cfg) + require.NoError(t, err) + + t.Cleanup(func() { + store.Close(t.Context()) + }) + + var wg sync.WaitGroup + + start := time.Now() + + for range 100 { + wg.Go(func() { + val, err := store.Get(t.Context(), "concurrent-secret") + require.NoError(t, err) + assert.Equal(t, "mock-secret", val) + }) + } + + wg.Wait() + + require.Less( + t, + time.Since(start), + 1001*time.Millisecond, + "all 100 requests will return after first success singleflight call", + ) + assert.Equal( + t, + int32(1), + provider.calls.Load(), + "100 concurrent requests should deduplicate to 1 provider call", + ) + }) +} + +func TestGet_CacheDisabled_AlwaysCallsProvider(t *testing.T) { + t.Parallel() + + provider := &mockProvider{} + cfg := kv.CacheConfig{Enabled: false} + store, err := NewSecretStore("test-store", provider, cfg) + require.NoError(t, err) + + val, err := store.Get(t.Context(), "key1") + require.NoError(t, err) + assert.Equal(t, "mock-secret", val) + assert.Equal(t, int32(1), provider.calls.Load()) + + val, err = store.Get(t.Context(), "key1") + require.NoError(t, err) + assert.Equal(t, "mock-secret", val) + assert.Equal( + t, + int32(2), + provider.calls.Load(), + "cache disabled should call provider every time", + ) +} + +func TestGet_DifferentKeysIndependent(t *testing.T) { + t.Parallel() + + var callCount atomic.Int32 + provider := &mockProvider{ + mockGetFunc: func(ctx 
context.Context, path string) (string, error) { + callCount.Add(1) + return fmt.Sprintf("secret-%s", path), nil + }, + } + cfg := kv.CacheConfig{Enabled: true, TTL: "1m"} + store, err := NewSecretStore("test-store", provider, cfg) + require.NoError(t, err) + + val1, err := store.Get(t.Context(), "key1") + require.NoError(t, err) + assert.Equal(t, "secret-key1", val1) + + val2, err := store.Get(t.Context(), "key2") + require.NoError(t, err) + assert.Equal(t, "secret-key2", val2) + + assert.Equal(t, int32(2), callCount.Load(), "different keys should trigger separate provider calls") + + val1, err = store.Get(t.Context(), "key1") + require.NoError(t, err) + assert.Equal(t, "secret-key1", val1) + assert.Equal(t, int32(2), callCount.Load(), "refetch should use cache") +} + +func TestGet_TimeoutEnforcement(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + opts []Option + expectedTimeout time.Duration + }{ + { + name: "Use default timeout if not explicitly provided", + opts: nil, + expectedTimeout: 5 * time.Second, + }, + { + name: "Override default timeout with custom duration", + opts: []Option{WithTimeout(10 * time.Second)}, + expectedTimeout: 10 * time.Second, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + provider := &mockProvider{ + delay: 30 * time.Second, + } + cfg := kv.CacheConfig{Enabled: true, TTL: "1m"} + + store, err := NewSecretStore("test-store", provider, cfg, tt.opts...) 
+ require.NotNil(t, store) + require.NoError(t, err) + t.Cleanup(func() { + store.Close(t.Context()) + }) + + start := time.Now() + + var wg sync.WaitGroup + wg.Go(func() { + val, err := store.Get(t.Context(), "slow-key") + require.Error(t, err) + require.Contains(t, err.Error(), "timeout fetching") + assert.Empty(t, val) + }) + + synctest.Wait() + wg.Wait() + + elapsed := time.Since(start) + assert.Equal(t, tt.expectedTimeout, elapsed) + }) + }) + } +} + +func TestStaleWhileRevalidate(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + var callCount int32 + + provider := &mockProvider{ + mockGetFunc: func(ctx context.Context, path string) (string, error) { + count := atomic.AddInt32(&callCount, 1) + return fmt.Sprintf("secret-v%d", count), nil + }, + } + cfg := kv.CacheConfig{Enabled: true, TTL: "5s", RefreshBeforeExpiry: "1s"} + store := newTestStore(t, provider, cfg) + + // Cache miss + val, err := store.Get(t.Context(), "stale-secret") + require.NoError(t, err) + assert.Equal(t, "secret-v1", val) + + time.Sleep(4 * time.Second) + + // Cache hit and triggers background refresh + val, err = store.Get(t.Context(), "stale-secret") + require.NoError(t, err) + assert.Equal(t, "secret-v1", val) + + // Wait for background refresh to finish + synctest.Wait() + + // Refreshed value + start := time.Now() + val, err = store.Get(t.Context(), "stale-secret") + require.NoError(t, err) + assert.Equal(t, "secret-v2", val) + assert.Equal(t, int32(2), callCount) + + latency := time.Since(start) + require.Less(t, latency, 10*time.Millisecond, "should return stale value immediately") + }) +} + +func TestBackgroundRefreshDeduplication(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + provider := &mockProvider{ + delay: 100 * time.Millisecond, + } + + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "1s", + RefreshBeforeExpiry: "500ms", + } + store := newTestStore(t, provider, cfg) + + _, err := store.Get(t.Context(), "key1") + 
require.NoError(t, err) + + // Advance time to enter RefreshBeforeExpiry window + time.Sleep(600 * time.Millisecond) + + var wg sync.WaitGroup + + for range 100 { + wg.Go(func() { + _, err := store.Get(t.Context(), "key1") + require.NoError(t, err) + }) + } + + wg.Wait() + + // We wanna be sure that second request to provider is finished + time.Sleep(100 * time.Millisecond) + synctest.Wait() + + require.Equal(t, int32(2), provider.calls.Load()) + }) +} + +func TestBackgroundRefreshSurvivesRequestCancellation(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + var callCount atomic.Int32 + provider := &mockProvider{ + delay: 100 * time.Millisecond, + mockGetFunc: func(ctx context.Context, path string) (string, error) { + count := callCount.Add(1) + return fmt.Sprintf("secret-v%d", count), nil + }, + } + + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "2s", + RefreshBeforeExpiry: "1s", + } + store := newTestStore(t, provider, cfg) + + // Initial fetch + val, err := store.Get(t.Context(), "key1") + require.NoError(t, err) + require.Equal(t, "secret-v1", val) + + // Advance time to enter the RefreshBeforeExpiry window + time.Sleep(time.Second) + + cancelCtx, cancel := context.WithCancel(context.Background()) + val, err = store.Get(cancelCtx, "key1") + require.NoError(t, err) + require.Equal(t, "secret-v1", val) + + cancel() + + // Wait for background refresh to complete + time.Sleep(100 * time.Millisecond) + synctest.Wait() + + // Verify fresh value was cached despite cancellation + val, err = store.Get(t.Context(), "key1") + require.NoError(t, err) + require.Equal(t, "secret-v2", val) + require.Equal(t, int32(2), callCount.Load()) + }) +} + +func TestConcurrentBackgroundRefreshDifferentKeys(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + var key1Calls, key2Calls atomic.Int32 + provider := &mockProvider{ + delay: 100 * time.Millisecond, + mockGetFunc: func(ctx context.Context, path string) (string, error) { + if path == 
"key1" { + key1Calls.Add(1) + return "secret-key1", nil + } + + key2Calls.Add(1) + return "secret-key2", nil + }, + } + + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "2s", + RefreshBeforeExpiry: "1s", + } + store := newTestStore(t, provider, cfg) + + _, err := store.Get(t.Context(), "key1") + require.NoError(t, err) + _, err = store.Get(t.Context(), "key2") + require.NoError(t, err) + + time.Sleep(time.Second) + + var wg sync.WaitGroup + + wg.Go(func() { + _, err := store.Get(t.Context(), "key1") + require.NoError(t, err) + }) + wg.Go(func() { + _, err := store.Get(t.Context(), "key2") + require.NoError(t, err) + }) + + wg.Wait() + + time.Sleep(100 * time.Millisecond) // Wait for refreshes to complete + synctest.Wait() + + // Each key should have exactly 2 calls (initial + 1 refresh) + require.Equal(t, int32(2), key1Calls.Load()) + require.Equal(t, int32(2), key2Calls.Load()) + }) +} + +func TestContextCancellationDoesNotPoisonCache(t *testing.T) { + t.Parallel() + + synctest.Test(t, func(t *testing.T) { + provider := &mockProvider{ + // Each call to provider will end-up deadline exceeded + // if context is not canceled before. + delay: 10 * time.Second, + } + + cfg := kv.CacheConfig{Enabled: true, TTL: "5s"} + store := newTestStore(t, provider, cfg) + + go func() { + // Foreground fetch with canceled request. + // The select immediately returns an error and provider + // hasn't been called. 
+ ctx, cancel := context.WithCancel(t.Context()) + + cancel() + + val, err := store.Get(ctx, "cancel-secret") + require.Error(t, err) + require.Contains(t, err.Error(), "context canceled") + require.Empty(t, val) + }() + + val, err := store.Get(t.Context(), "cancel-secret") + require.Error(t, err) + require.Contains(t, err.Error(), "timeout fetching ") + require.Empty(t, val) + + val, err = store.Get(t.Context(), "cancel-secret") + require.Error(t, err) + require.Contains(t, err.Error(), "timeout fetching ") + require.Empty(t, val) + + time.Sleep(10 * time.Second) + synctest.Wait() + + require.Equal(t, int32(2), provider.calls.Load()) + }) +} + +func TestClose_LifecycleBoundaries(t *testing.T) { + t.Parallel() + + t.Run("Get rejects calls immediately after close", func(t *testing.T) { + provider := &mockProvider{} + cfg := kv.CacheConfig{Enabled: true, TTL: "1m"} + store := newTestStore(t, provider, cfg) + + err := store.Close(t.Context()) + require.NoError(t, err) + + val, err := store.Get(t.Context(), "any-key") + assert.ErrorIs(t, err, kv.ErrStoreClosed) + assert.Empty(t, val) + assert.Equal(t, int32(0), provider.calls.Load(), "Should never hit provider once closed") + }) + + t.Run("In-flight foreground fetches do not write to cache on mid-flight close", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + provider := &mockProvider{ + delay: 2 * time.Second, + } + cfg := kv.CacheConfig{Enabled: true, TTL: "10s"} + store := newTestStore(t, provider, cfg) + + var wg sync.WaitGroup + wg.Go(func() { + _, err := store.Get(t.Context(), "mid-flight-key") + require.NoError(t, err) + }) + + // Give the goroutine a small virtual tick to enter the provider block + time.Sleep(100 * time.Millisecond) + + // Suddenly close the store while the provider call is working + err := store.Close(t.Context()) + require.NoError(t, err) + + // Let the provider finish its work + time.Sleep(2 * time.Second) + synctest.Wait() + wg.Wait() + + // Because store was closed 
mid-flight, the singleflight shouldn't poison/write to cache. + _, err = store.Get(t.Context(), "mid-flight-key") + assert.ErrorIs(t, err, kv.ErrStoreClosed) + }) + }) + + t.Run("Background refresh drops writes if closed mid-execution", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + provider := &mockProvider{ + delay: 500 * time.Millisecond, + } + cfg := kv.CacheConfig{ + Enabled: true, + TTL: "2s", + RefreshBeforeExpiry: "1s", + } + store := newTestStore(t, provider, cfg) + + _, err := store.Get(t.Context(), "refresh-key") + require.NoError(t, err) + + // Move virtual clock into the refresh window + time.Sleep(1200 * time.Millisecond) + + // Trigger the background task by asking for it + _, err = store.Get(t.Context(), "refresh-key") + require.NoError(t, err) + + err = store.Close(t.Context()) + require.NoError(t, err) + + // Advance past the provider delay so background worker wraps up + time.Sleep(600 * time.Millisecond) + synctest.Wait() + + assert.True(t, provider.closed.Load()) + }) + }) + + t.Run("Close is idempotent", func(t *testing.T) { + provider := &mockProvider{} + cfg := kv.CacheConfig{Enabled: true, TTL: "1s"} + store := newTestStore(t, provider, cfg) + + var wg sync.WaitGroup + for range 10 { + wg.Go(func() { + err := store.Close(t.Context()) + require.NoError(t, err) + }) + } + + wg.Wait() + + assert.True(t, provider.closed.Load()) + }) +} + +func newTestStore(t *testing.T, provider kv.Provider, cfg kv.CacheConfig) *SecretStore { + t.Helper() + + store, err := NewSecretStore("test", provider, cfg) + require.NoError(t, err) + t.Cleanup(func() { + store.Close(t.Context()) + }) + + return store +} diff --git a/kv/logger.go b/kv/logger.go new file mode 100644 index 00000000..f7f20cad --- /dev/null +++ b/kv/logger.go @@ -0,0 +1,11 @@ +package kv + +type Logger interface { + Warn(msg string, fields map[string]any) + Warnf(format string, args ...any) +} + +type NoopLogger struct{} + +func (NoopLogger) Warn(_ string, _ map[string]any) {} 
// ProviderType is the string identifier of a KV provider implementation.
type ProviderType string

const (
	// --- Open Source (OSS) Providers ---

	// Env reads secrets from environment variables.
	Env ProviderType = "env"

	// Inline takes secrets from plain text in the configuration.
	Inline ProviderType = "inline"

	// Vault reads secrets from HashiCorp Vault.
	Vault ProviderType = "hashicorp_vault"

	// Consul reads secrets from HashiCorp Consul.
	Consul ProviderType = "hashicorp_consul"

	// K8s reads secrets from Kubernetes Secrets mounted as files.
	K8s ProviderType = "k8s_files"

	// --- Enterprise Edition (EE) Providers ---

	// AWS reads secrets from AWS Secrets Manager.
	AWS ProviderType = "aws_secrets_manager"

	// GCP reads secrets from Google Cloud Secret Manager.
	GCP ProviderType = "gcp_secret_manager"

	// Azure reads secrets from Azure Key Vault.
	Azure ProviderType = "azure_key_vault"

	// Conjur reads secrets from CyberArk Conjur.
	Conjur ProviderType = "cyberark_conjur"
)

// KeyValueRetriever is the core read capability: look up a value by key.
type KeyValueRetriever interface {
	Get(ctx context.Context, key string) (string, error)
}

// Provider is the interface every KV provider must implement. Today that is
// read access only (KeyValueRetriever).
//
// A provider may additionally implement Initializer, Closer, Lister,
// Standalone or Timeouter; those capabilities are discovered at runtime via
// the As* helpers below.
type Provider interface {
	KeyValueRetriever
}

// ProviderFactory builds a provider from its raw JSON configuration. Each
// provider type supplies its own factory, which knows how to parse that
// provider's configuration format.
type ProviderFactory func(config json.RawMessage) (Provider, error)

// Initializer is an optional interface for providers that need a setup step
// (such as establishing a connection) before use.
type Initializer interface {
	Init(ctx context.Context) error
}

// Lister is an optional interface for providers that can enumerate keys by
// prefix.
type Lister interface {
	List(ctx context.Context, prefix string) ([]string, error)
}

// Closer is an optional interface for providers that need graceful shutdown
// or resource cleanup.
type Closer interface {
	Close(ctx context.Context) error
}

// Standalone is an optional interface for providers that should be used
// directly, without the caching and singleflight wrapper.
type Standalone interface {
	IsStandalone() bool
}

// Timeouter is an optional interface for providers that expose their own
// operation timeout.
type Timeouter interface {
	Timeout() time.Duration
}

// AsLister returns p, or a provider it wraps, as a Lister.
func AsLister(p Provider) (Lister, bool) {
	return As[Lister](p)
}

// AsInitializer returns p, or a provider it wraps, as an Initializer.
func AsInitializer(p Provider) (Initializer, bool) {
	return As[Initializer](p)
}

// AsCloser returns p, or a provider it wraps, as a Closer.
func AsCloser(p Provider) (Closer, bool) {
	return As[Closer](p)
}

// AsStandalone returns p, or a provider it wraps, as a Standalone.
func AsStandalone(p Provider) (Standalone, bool) {
	return As[Standalone](p)
}

// AsTimeouter returns p, or a provider it wraps, as a Timeouter.
func AsTimeouter(p Provider) (Timeouter, bool) {
	return As[Timeouter](p)
}

// As looks for a value of type T in p or in the chain of providers reachable
// through `Unwrap() Provider`. It returns the first match, outermost first.
//
// The search gives up (returning the zero T and false) when it meets a nil
// provider, a provider that cannot be unwrapped, or after 100 levels. The
// depth limit guarantees termination even if a provider unwraps to itself.
func As[T any](p Provider) (T, bool) {
	const maxDepth = 100

	var none T

	current := p
	for depth := 0; depth < maxDepth && current != nil; depth++ {
		if target, ok := current.(T); ok {
			return target, true
		}

		unwrapper, ok := current.(interface{ Unwrap() Provider })
		if !ok {
			break
		}

		current = unwrapper.Unwrap()
	}

	return none, false
}
"golang.org/x/sync/errgroup" +) + +// Registry manages provider factories and initialized stores without global state. +// It provides a clean separation between provider registration (factories) and +// runtime instances (stores), enabling components to control their own KV lifecycle. +// +// All operations are safe for concurrent use. +type Registry struct { + factories map[kv.ProviderType]kv.ProviderFactory + stores map[string]kv.Provider + mu sync.RWMutex + isInitialized atomic.Bool + logger kv.Logger +} + +type Option func(r *Registry) + +func WithLogger(l kv.Logger) Option { + return func(r *Registry) { + if l != nil { + r.logger = l + } + } +} + +// NewRegistry creates a new empty registry with no registered factories or stores. +func NewRegistry(opts ...Option) *Registry { + r := &Registry{ + factories: make(map[kv.ProviderType]kv.ProviderFactory), + stores: make(map[string]kv.Provider), + logger: kv.NoopLogger{}, + } + + for _, opt := range opts { + opt(r) + } + + return r +} + +// NewDefaultRegistry creates a registry with added OSS providers. +func NewDefaultRegistry(opts ...Option) *Registry { + r := NewRegistry(opts...) + + // TODO: Uncomment provider registration when implementation is added + // r.Add(kv.Env, env.NewFactory()) + // r.Add(kv.Inline, inline.NewFactory()) + // r.Add(kv.Vault, vault.NewFactory()) + // r.Add(kv.Consul, consul.NewFactory()) + // r.Add(kv.K8s, k8s.NewFactory()) + + return r +} + +// Add registers a provider factory for the given provider type. 
+func (r *Registry) Add(pt kv.ProviderType, factory kv.ProviderFactory) error { + if pt == "" { + return errors.New("provider type cannot be empty") + } + + if factory == nil { + return errors.New("factory cannot be nil") + } + + r.mu.Lock() + defer r.mu.Unlock() + + // Safe check within the write lock prevents the TOCTOU race condition + if _, ok := r.factories[pt]; ok { + return fmt.Errorf("factory for type %q is already provided; override is not allowed", pt) + } + + r.factories[pt] = factory + + return nil +} + +// InitStores initializes named store instances using registered provider factories. +// The configs map keys become the store names used in KV references. +// +// If a store is marked as required:true and fails to initialize, InitStores +// returns an error. Optional stores (required:false) log warnings but don't +// fail the initialization process. +// +// Example config: +// +// { +// "kv": { +// "cache": { +// "enabled": true, +// "ttl": "60s" +// }, +// "stores": { +// "vault-prod": { +// "type": "vault", +// "required": true, +// "config": { +// "address": "https://vault.internal:8200", +// "token": "kv://env/VAULT_TOKEN" +// } +// } +// } +// } +// } +func (r *Registry) InitStores(ctx context.Context, config *kv.Config) (err error) { + r.mu.RLock() + factoriesCount := len(r.factories) + r.mu.RUnlock() + + if factoriesCount == 0 { + return errors.New("factories must be added before initialize stores") + } + + if config == nil || config.Stores == nil { + return nil + } + + if r.isInitialized.Swap(true) { + return errors.New("stores have been initialized") + } + + var tempMu sync.Mutex + tempStores := make(map[string]kv.Provider, len(config.Stores)) + + // This defer block guarantees cleanup of partially initialized stores if the + // overall initialization process fails, preventing resource leaks. 
+ defer func() { + if err != nil { + r.isInitialized.Store(false) + + cleanupCtx := context.WithoutCancel(ctx) + + tempMu.Lock() + defer tempMu.Unlock() + + for _, store := range tempStores { + if closer, ok := kv.AsCloser(store); ok { + _ = closer.Close(cleanupCtx) + } + } + } + }() + + eg, egCtx := errgroup.WithContext(ctx) + + for name, storeCfg := range config.Stores { + r.mu.RLock() + factory, ok := r.factories[storeCfg.Type] + r.mu.RUnlock() + + if !ok { + initErr := fmt.Errorf("unknown provider type %q for store %q", storeCfg.Type, name) + if storeCfg.Required { + return initErr + } + + r.logger.Warn("Skipping optional store initialization", map[string]any{ + "store": name, + "error": initErr, + }) + + continue + } + + eg.Go(func() error { + store, initErr := buildSingleStore(egCtx, name, storeCfg, config.Cache, factory) + if initErr != nil { + if storeCfg.Required { + return initErr + } + + r.logger.Warn("Skipping optional store initialization", map[string]any{ + "store": name, + "error": initErr, + }) + + return nil + } + + tempMu.Lock() + tempStores[name] = store + tempMu.Unlock() + + return nil + }) + } + + err = eg.Wait() + if err != nil { + return err + } + + r.mu.Lock() + defer r.mu.Unlock() + + // Double-check registry wasn't closed during the initialization + if !r.isInitialized.Load() { + return errors.New("registry was closed during initialization") + } + + for name, store := range tempStores { + r.stores[name] = store + } + + return nil +} + +// GetStore retrieves an initialized store by name. +// Returns ErrStoreNotFound if no store with the given name was initialized. +func (r *Registry) GetStore(name string) (kv.Provider, error) { + r.mu.RLock() + defer r.mu.RUnlock() + + store, ok := r.stores[name] + if !ok { + return nil, kv.NewStoreNotFoundError(name) + } + + return store, nil +} + +// Close gracefully shuts down all initialized stores. 
+func (r *Registry) Close(ctx context.Context) error { + r.mu.Lock() + stores := r.stores + r.stores = make(map[string]kv.Provider) + r.isInitialized.Store(false) + r.mu.Unlock() + + var ( + mu sync.Mutex + errs []error + wg sync.WaitGroup + ) + + for name, store := range stores { + wg.Go(func() { + if closer, ok := kv.AsCloser(store); ok { + if err := closer.Close(ctx); err != nil { + mu.Lock() + errs = append(errs, fmt.Errorf("failed to close store %q: %w", name, err)) + mu.Unlock() + } + } + }) + } + + wg.Wait() + + return errors.Join(errs...) +} + +func buildSingleStore( + ctx context.Context, + name string, + storeCfg kv.StoreConfig, + cacheCfg kv.CacheConfig, + factory kv.ProviderFactory, +) (kv.Provider, error) { + provider, err := factory(storeCfg.Config) + if err != nil { + return nil, fmt.Errorf("failed to create provider %q (type: %s): %w", name, storeCfg.Type, err) + } + + if initializer, ok := kv.AsInitializer(provider); ok { + err := initializer.Init(ctx) + if err != nil { + return nil, fmt.Errorf("failed to initialize store %q (type: %s): %w", name, storeCfg.Type, err) + } + } + + if s, ok := kv.AsStandalone(provider); ok && s.IsStandalone() { + return provider, nil + } + + var timeout time.Duration + if t, ok := kv.AsTimeouter(provider); ok { + timeout = t.Timeout() + } + + ss, err := store.NewSecretStore( + name, + provider, + cacheCfg, + store.WithTimeout(timeout), + ) + if err != nil { + return nil, fmt.Errorf("failed to wrap store %q: %w", name, err) + } + + return ss, nil +} diff --git a/kv/registry/registry_test.go b/kv/registry/registry_test.go new file mode 100644 index 00000000..1d5928bf --- /dev/null +++ b/kv/registry/registry_test.go @@ -0,0 +1,592 @@ +package registry + +import ( + "context" + "encoding/json" + "errors" + "sync" + "sync/atomic" + "testing" + "testing/synctest" + "time" + + "github.com/TykTechnologies/storage/kv" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type mockProvider struct 
// mockLogger counts Warn calls so tests can assert how many stores were
// skipped. InitStores calls Warn from concurrently running builder
// goroutines, so the counter is guarded by a mutex.
type mockLogger struct {
	mu        sync.Mutex
	warnCalls int
}

// Warn records one warning; the message and fields are ignored.
func (l *mockLogger) Warn(_ string, _ map[string]any) {
	l.mu.Lock()
	defer l.mu.Unlock()

	l.warnCalls++
}

// Warnf is a no-op.
func (*mockLogger) Warnf(_ string, _ ...any) {}
+func TestNewDefaultRegistry(t *testing.T) {} + +func TestAddFactory(t *testing.T) { + t.Parallel() + + t.Run("successful registration", func(t *testing.T) { + r := NewRegistry() + + err := r.Add(kv.Env, newFactory(nil, nil)) + require.NoError(t, err) + + err = r.Add(kv.Inline, newFactory(nil, nil)) + require.NoError(t, err) + + require.Len(t, r.factories, 2) + }) + + t.Run("reject empty provider type", func(t *testing.T) { + r := NewRegistry() + err := r.Add("", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot be empty") + + require.Len(t, r.factories, 0) + }) + + t.Run("prevent factory duplication override", func(t *testing.T) { + r := NewRegistry() + err := r.Add("env", newFactory(nil, nil)) + require.NoError(t, err) + + err = r.Add("env", newFactory(nil, nil)) + require.Error(t, err) + require.Contains(t, err.Error(), "is already provided") + + require.Len(t, r.factories, 1) + }) + + t.Run("reject nil factory", func(t *testing.T) { + r := NewRegistry() + err := r.Add("env", nil) + require.Error(t, err) + require.Contains(t, err.Error(), "factory cannot be nil") + require.Len(t, r.factories, 0) + }) +} + +func TestGetStore(t *testing.T) { + t.Parallel() + + r := NewRegistry() + + err := r.Add("valid", newFactory(nil, nil)) + require.NoError(t, err) + + err = r.InitStores(t.Context(), &kv.Config{ + Stores: map[string]kv.StoreConfig{ + "valid-1": {Type: "valid", Required: true}, + }, + }) + require.NoError(t, err) + + p, err := r.GetStore("valid-1") + require.NoError(t, err) + require.NotNil(t, p) +} + +func TestInitStores_BlastRadius(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + storeType kv.ProviderType + required bool + factoryErr error + initErr error + expectError bool + } + + table := []testCase{ + { + name: "required store initializes perfectly", + storeType: kv.Env, + required: true, + expectError: false, + }, + { + name: "unregistered provider type fails if required", + storeType: kv.Inline, + required: 
true, + expectError: true, + }, + { + name: "unregistered provider type skipped if optional", + storeType: kv.Inline, + required: false, + expectError: false, + }, + { + name: "factory generation failure blocks startup if required", + storeType: kv.Env, + required: true, + factoryErr: errors.New("bad config content"), + expectError: true, + }, + { + name: "factory generation failure skipped if optional", + storeType: kv.Env, + required: false, + factoryErr: errors.New("bad config content"), + expectError: false, + }, + { + name: "network init phase failure blocks startup if required", + storeType: kv.Env, + required: true, + initErr: errors.New("vault unreachable"), + expectError: true, + }, + { + name: "network init phase failure skipped if not required", + storeType: kv.Env, + required: false, + initErr: errors.New("consul dead"), + expectError: false, + }, + } + + for _, tc := range table { + t.Run(tc.name, func(t *testing.T) { + reg := NewRegistry() + + // Adding factory for env provider + err := reg.Add(kv.Env, func(cfg json.RawMessage) (kv.Provider, error) { + if tc.factoryErr != nil { + return nil, tc.factoryErr + } + + return &mockProvider{ + initFunc: func(ctx context.Context) error { + return tc.initErr + }, + }, nil + }) + require.NoError(t, err) + + config := &kv.Config{ + Stores: map[string]kv.StoreConfig{ + "target-store": { + Type: tc.storeType, + Required: tc.required, + }, + }, + } + + err = reg.InitStores(t.Context(), config) + if tc.expectError { + require.Error(t, err) + } + }) + } +} + +func TestInitStores_EdgeCases(t *testing.T) { + t.Parallel() + + t.Run("returns nil early without initializing stores when config is empty", func(t *testing.T) { + r := NewRegistry() + err := r.Add("mock", newFactory(nil, nil)) + require.NoError(t, err) + + err = r.InitStores(t.Context(), nil) + assert.NoError(t, err) + + err = r.InitStores(t.Context(), &kv.Config{}) + assert.NoError(t, err) + + assert.False(t, r.isInitialized.Load()) + }) + + t.Run("returns 
error if no factory provided", func(t *testing.T) { + r := NewRegistry() + err := r.InitStores(t.Context(), &kv.Config{}) + require.Error(t, err) + require.Contains(t, err.Error(), "factories must be added before initialize stores") + require.False(t, r.isInitialized.Load()) + }) + + t.Run("should be called once unless Close() was called", func(t *testing.T) { + reg := NewRegistry() + + err := reg.Add("mock", newFactory(nil, nil)) + require.NoError(t, err) + + config := &kv.Config{ + Stores: map[string]kv.StoreConfig{ + "target-store": { + Type: "mock", + Required: true, + }, + }, + } + + err = reg.InitStores(t.Context(), config) + require.NoError(t, err) + + err = reg.InitStores(t.Context(), config) + require.Error(t, err) + require.Contains(t, err.Error(), "stores have been initialized") + + err = reg.Close(t.Context()) + require.NoError(t, err) + + err = reg.InitStores(t.Context(), config) + require.NoError(t, err) + }) + + t.Run("should close temporarily added stores when required store failed", func(t *testing.T) { + // We have to iterate over until valid store is initialized because + // we have map non-deterministic iteration order. 
+ for { + r := NewRegistry() + + var validInitialized bool + var closeFuncCalled bool + validFactory := newFactory(func(ctx context.Context) error { + validInitialized = true + return nil + }, func(ctx context.Context) error { + closeFuncCalled = true + return nil + }) + + err := r.Add("valid", validFactory) + require.NoError(t, err) + + invalidFactory := newFactory(func(ctx context.Context) error { + return errors.New("init error") + }, nil) + err = r.Add("invalid", invalidFactory) + require.NoError(t, err) + + err = r.InitStores(t.Context(), &kv.Config{ + Stores: map[string]kv.StoreConfig{ + "valid-1": {Type: "valid", Required: true}, + "invalid-1": {Type: "invalid", Required: true}, + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to initialize store") + + if validInitialized { + require.True( + t, + closeFuncCalled, + "expected temporarily added store to be closed, but Close was not called", + ) + require.False(t, r.isInitialized.Load()) + + break + } + } + }) + + t.Run("should handle error returned by secret store wrapper", func(t *testing.T) { + r := NewRegistry() + + err := r.Add("valid", newFactory(nil, nil)) + require.NoError(t, err) + + err = r.InitStores(t.Context(), &kv.Config{ + Stores: map[string]kv.StoreConfig{ + "valid-1": {Type: "valid", Required: true}, + }, + Cache: kv.CacheConfig{ + Enabled: true, + TTL: "-10s", + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "failed to wrap store") + require.False(t, r.isInitialized.Load()) + }) + + t.Run("should log warning when optional store failed one of the steps", func(t *testing.T) { + l := &mockLogger{} + r := NewRegistry(WithLogger(l)) + + err := r.Add("valid", newFactory(nil, nil)) + require.NoError(t, err) + + err = r.InitStores(t.Context(), &kv.Config{ + Stores: map[string]kv.StoreConfig{ + "valid-1": {Type: "valid", Required: false}, + }, + Cache: kv.CacheConfig{ + Enabled: true, + TTL: "-10s", + }, + }) + require.NoError(t, err) + require.Equal(t, 
1, l.warnCalls) + }) + + t.Run("should skip secret store wrapping if provider is standalone", func(t *testing.T) { + r := NewRegistry() + + err := r.Add("valid", func(_ json.RawMessage) (kv.Provider, error) { + return &mockProvider{isStandalone: true}, nil + }) + require.NoError(t, err) + + err = r.InitStores(t.Context(), &kv.Config{ + Stores: map[string]kv.StoreConfig{ + "valid-1": {Type: "valid", Required: true}, + }, + }) + require.NoError(t, err) + require.Len(t, r.stores, 1) + }) + + t.Run("should initialize stores concurrently", func(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + r := NewRegistry() + + err := r.Add("valid", newFactory(func(ctx context.Context) error { + time.Sleep(10 * time.Second) + return nil + }, nil)) + require.NoError(t, err) + + start := time.Now() + err = r.InitStores(t.Context(), &kv.Config{ + Stores: map[string]kv.StoreConfig{ + "valid-1": {Type: "valid", Required: true}, + "valid-2": {Type: "valid", Required: true}, + "valid-3": {Type: "valid", Required: true}, + "valid-4": {Type: "valid", Required: true}, + "valid-5": {Type: "valid", Required: true}, + }, + }) + require.NoError(t, err) + synctest.Wait() + + elapsed := time.Since(start) + + require.Equal(t, 10*time.Second, elapsed) + require.Len(t, r.stores, 5) + }) + }) +} + +func TestClose(t *testing.T) { + t.Parallel() + + t.Run("can be called multiple times without error", func(t *testing.T) { + reg := NewRegistry() + err := reg.Close(t.Context()) + require.NoError(t, err) + err = reg.Close(t.Context()) + require.NoError(t, err) + err = reg.Close(t.Context()) + require.NoError(t, err) + }) + + t.Run("aggregates multiple errors", func(t *testing.T) { + reg := NewRegistry() + + err := reg.Add("mock", func(cfg json.RawMessage) (kv.Provider, error) { + return &mockProvider{ + closeFunc: func(ctx context.Context) error { + return errors.New("cleanup failed") + }, + }, nil + }) + require.NoError(t, err) + + config := &kv.Config{ + Stores: map[string]kv.StoreConfig{ + 
"store-1": {Type: "mock", Required: true}, + "store-2": {Type: "mock", Required: true}, + }, + } + + err = reg.InitStores(t.Context(), config) + require.NoError(t, err) + + closeErr := reg.Close(t.Context()) + require.Error(t, closeErr) + + if err, ok := closeErr.(interface{ Unwrap() []error }); ok { + er := err.Unwrap() + require.Len(t, er, 2) + } else { + t.Error("close error must be a result of errors.Join which implements Unwrap") + } + }) +} + +func TestRegistry_Concurrency(t *testing.T) { + t.Parallel() + + reg := NewRegistry() + + err := reg.Add("static-type", newFactory(nil, nil)) + require.NoError(t, err) + + var wg sync.WaitGroup + + wg.Go(func() { + err := reg.Add(kv.Env, newFactory(nil, nil)) + require.NoError(t, err) + }) + + wg.Go(func() { + _, err := reg.GetStore("any-store") + require.ErrorIs(t, err, kv.ErrStoreNotFound) + }) + + wg.Wait() +} + +func TestConcurrentInitStoresAndCloseAreHandledCorrectly(t *testing.T) { + t.Parallel() + + reg := NewRegistry() + + inInit := make(chan struct{}) + closeDone := make(chan struct{}) + + err := reg.Add("mock", newFactory(func(ctx context.Context) error { + close(inInit) + <-closeDone + + return nil + }, nil)) + require.NoError(t, err) + + var wg sync.WaitGroup + + // While InitStores is running another goroutine calls the Close method + wg.Go(func() { + config := &kv.Config{ + Stores: map[string]kv.StoreConfig{ + "store-1": {Type: "mock", Required: true}, + }, + } + + err := reg.InitStores(t.Context(), config) + require.Error(t, err) + require.Contains(t, err.Error(), "registry was closed during initialization") + }) + + wg.Go(func() { + <-inInit + + err := reg.Close(t.Context()) + require.NoError(t, err) + + close(closeDone) + }) + + wg.Wait() +} + +func TestInitStores_CacheCleanupSurvivesInitialization(t *testing.T) { + synctest.Test(t, func(t *testing.T) { + r := NewRegistry() + + t.Cleanup(func() { + r.Close(t.Context()) + }) + + p := &mockProvider{} + err := r.Add("test", func(config json.RawMessage) 
(kv.Provider, error) { + return p, nil + }) + require.NoError(t, err) + + cfg := &kv.Config{ + Cache: kv.CacheConfig{ + Enabled: true, + TTL: "1s", + }, + Stores: map[string]kv.StoreConfig{ + "test-store": {Type: "test", Required: true}, + }, + } + + err = r.InitStores(t.Context(), cfg) + require.NoError(t, err) + + store, err := r.GetStore("test-store") + require.NoError(t, err) + + // Populate cache + val, err := store.Get(t.Context(), "key1") + require.NoError(t, err) + require.Equal(t, "value", val) + require.Equal(t, int32(1), p.calls.Load()) + + // Second call should hit cache (no provider call) + _, err = store.Get(t.Context(), "key1") + require.NoError(t, err) + require.Equal(t, int32(1), p.calls.Load(), "should hit cache") + + time.Sleep(time.Second) + synctest.Wait() + + _, err = store.Get(t.Context(), "key1") + require.NoError(t, err) + require.Equal(t, int32(2), p.calls.Load(), "cache should have cleaned up expired entry") + }) +} diff --git a/kv/resolver/resolver.go b/kv/resolver/resolver.go new file mode 100644 index 00000000..d354d3ee --- /dev/null +++ b/kv/resolver/resolver.go @@ -0,0 +1,54 @@ +package resolver + +import ( + "context" + + "github.com/TykTechnologies/storage/kv/registry" +) + +// Resolver handles string replacement for KV references in configuration strings. +// It supports two syntax patterns: +// - Whole-value references: "kv://store-name/path/to/secret#field" +// - Inline references: "https://$kv{store-name:path/to/secret#field}/api/v1" +// +// The resolver works against a registry of named stores, allowing the same +// syntax to work across different provider types (Vault, Consul, AWS, etc.). +// +// JSON field extraction is supported via the #field syntax using JSON Pointer +// notation for nested field access. +type Resolver interface { + // Resolve processes the input string and replaces any KV references with + // their resolved values from the configured stores. 
+ // + // Returns the resolved string with all KV references replaced, or an error + // if any reference cannot be resolved. + // + // If the input contains no KV references, it is returned unchanged. + Resolve(ctx context.Context, input string) (string, error) +} + +// ResolveConfig processes an entire configuration and resolves any +// KV references found within string fields. +// +// This function enables config-level resolution during component startup, +// allowing any string field in any configuration structure to contain +// KV references that will be resolved before the config is used. +// +// The resolver traverses the JSON structure recursively, applying Resolve() +// to all string values while preserving the overall structure and non-string +// fields unchanged. +func ResolveConfig(ctx context.Context, resolver Resolver, rawConfig []byte) ([]byte, error) { + return nil, nil +} + +type resolver struct { + registry *registry.Registry +} + +func NewResolver(registry *registry.Registry) Resolver { + return &resolver{registry: registry} +} + +func (r *resolver) Resolve(ctx context.Context, input string) (string, error) { + return "", nil +}