diff --git a/.github/workflows/cd.yaml b/.github/workflows/cd.yaml new file mode 100644 index 000000000..873b9f269 --- /dev/null +++ b/.github/workflows/cd.yaml @@ -0,0 +1,37 @@ +name: CD + +on: + release: + types: [published] + workflow_dispatch: + inputs: + TAG: + required: true + description: Docker image tag + default: latest + +jobs: + publish: + name: Publish the Docker image + # runs-on: ubuntu-latest + runs-on: zerion-arm-runners # self-hosted arm64 runner + permissions: + id-token: write + contents: read + steps: + - name: Generate token + id: app-token + uses: actions/create-github-app-token@v1 + with: + app-id: ${{ secrets.ZERION_CI_APP_ID }} + private-key: ${{ secrets.ZERION_CI_APP_PRIVATE_KEY }} + owner: zeriontech + - name: Zerion AWS + uses: zeriontech/zerion-github-actions/aws@v5 + with: + GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + AWS_REGION: us-east-1 + ECR_REPOSITORY: ${{ github.event.repository.name }} + ZERION_PAT: ${{ steps.app-token.outputs.token }} + PLATFORM: linux/arm64 + IMAGE_TAG: ${{ github.event.inputs.TAG || github.ref_name }} diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml deleted file mode 100644 index fc8bf19f9..000000000 --- a/.github/workflows/codeql.yml +++ /dev/null @@ -1,48 +0,0 @@ -name: "CodeQL" - -permissions: - actions: read - contents: read - security-events: write - -on: - push: - branches: ["main"] - pull_request: - branches: ["main"] - schedule: - - cron: "30 11 * * 6" - -jobs: - analyze: - name: Analyze - runs-on: ubuntu-latest - timeout-minutes: 360 - - strategy: - fail-fast: false - matrix: - language: ["go"] - - steps: - - name: Checkout repository - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - - - name: Install Go - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 - with: - go-version-file: go.mod - - # Initializes the CodeQL tools for scanning. - - name: Initialize CodeQL - uses: github/codeql-action/init@181d5eefc20863364f96762470ba6f862bdef56b # v3.29.2 - with: - languages: ${{ matrix.language }} - - - name: Autobuild - uses: github/codeql-action/autobuild@181d5eefc20863364f96762470ba6f862bdef56b # v3.29.2 - - - name: Perform CodeQL Analysis - uses: github/codeql-action/analyze@181d5eefc20863364f96762470ba6f862bdef56b # v3.29.2 - with: - category: "/language:${{matrix.language}}" diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml deleted file mode 100644 index 69cce14ba..000000000 --- a/.github/workflows/main.yaml +++ /dev/null @@ -1,57 +0,0 @@ -name: Build and push :master image - -permissions: - contents: read - -on: - push: - branches: - - main - -jobs: - check: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - - name: check format - run: make check_format - - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - - - name: Set up QEMU - uses: docker/setup-qemu-action@49b3bc8e6bdd4a60e6116a5414239cba5943d3cf # v3.2.0 - - - name: Set up Docker buildx - id: buildx - uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1 - - - name: build and push docker image - run: | - echo "$DOCKER_PASSWORD" | docker login -u "$DOCKER_USERNAME" --password-stdin - VERSION=master make docker_multiarch_push # Push image tagged with "master" - make docker_multiarch_push # Push image tagged with git sha - env: - DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} - DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} - - precommits: - runs-on: ubuntu-latest - - steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - - - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 - with: - python-version: "3.9" - - - uses: actions/setup-go@0a12ed9d6a96ab950c8f026ed9f722fe0da7ef32 # v5.0.2 - with: - go-version: "1.26.2" - - - name: run pre-commits - run: | - make precommit_install - pre-commit run -a diff --git a/.github/workflows/release.yaml b/.github/workflows/release.yaml deleted file mode 100644 index 8193b01d9..000000000 --- a/.github/workflows/release.yaml +++ /dev/null @@ -1,36 +0,0 @@ -name: Build and push :release image - -permissions: - contents: read - -on: - push: - tags: - - "v*" - -jobs: - check: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - - name: check format - run: make check_format - build: - runs-on: ubuntu-latest - steps: - - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - - - name: Set up QEMU - uses: docker/setup-qemu-action@49b3bc8e6bdd4a60e6116a5414239cba5943d3cf # v3.2.0 - - - name: Set up Docker buildx - id: buildx - uses: docker/setup-buildx-action@988b5a0280414f521da01fcc63a27aeeb4b104db # v3.6.1 - - - name: build and push docker image - run: | - echo "$DOCKER_PASSWORD" | docker login -u "$DOCKER_USERNAME" --password-stdin - make docker_multiarch_push - env: - DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }} - DOCKER_PASSWORD: ${{ secrets.DOCKER_PASSWORD }} diff --git a/.github/workflows/scorecard.yml b/.github/workflows/scorecard.yml deleted file mode 100644 index f053a9c89..000000000 --- a/.github/workflows/scorecard.yml +++ /dev/null @@ -1,69 +0,0 @@ -name: Scorecard supply-chain security - -permissions: - contents: read - -on: - # For Branch-Protection check. Only the default branch is supported. See - # https://github.com/ossf/scorecard/blob/main/docs/checks.md#branch-protection - branch_protection_rule: - # To guarantee Maintained check is occasionally updated. See - # https://github.com/ossf/scorecard/blob/main/docs/checks.md#maintained - schedule: - - cron: "31 17 * * 3" - push: - branches: ["main"] - -jobs: - analysis: - name: Scorecard analysis - runs-on: ubuntu-latest - permissions: - # Needed to upload the results to code-scanning dashboard. - security-events: write - # Needed to publish results and get a badge (see publish_results below). - id-token: write - # Uncomment the permissions below if installing in a private repository. - # contents: read - # actions: read - - steps: - - name: "Checkout code" - uses: actions/checkout@692973e3d937129bcbf40652eb9f2f61becf3332 # v4.1.7 - with: - persist-credentials: false - - - name: "Run analysis" - uses: ossf/scorecard-action@62b2cac7ed8198b15735ed49ab1e5cf35480ba46 # v2.4.0 - with: - results_file: results.sarif - results_format: sarif - # (Optional) "write" PAT token. Uncomment the `repo_token` line below if: - # - you want to enable the Branch-Protection check on a *public* repository, or - # - you are installing Scorecard on a *private* repository - # To create the PAT, follow the steps in https://github.com/ossf/scorecard-action#authentication-with-pat. - # repo_token: ${{ secrets.SCORECARD_TOKEN }} - - # Public repositories: - # - Publish results to OpenSSF REST API for easy access by consumers - # - Allows the repository to include the Scorecard badge. - # - See https://github.com/ossf/scorecard-action#publishing-results. - # For private repositories: - # - `publish_results` will always be set to `false`, regardless - # of the value entered here. - publish_results: true - - # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF - # format to the repository Actions tab. - - name: "Upload artifact" - uses: actions/upload-artifact@834a144ee995460fba8ed112a2fc961b36a5ec5a # v4.3.6 - with: - name: SARIF file - path: results.sarif - retention-days: 5 - - # Upload the results to GitHub's code scanning dashboard. - - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@181d5eefc20863364f96762470ba6f862bdef56b # v3.29.2 - with: - sarif_file: results.sarif diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml deleted file mode 100644 index b96b0ffa5..000000000 --- a/.github/workflows/stale.yml +++ /dev/null @@ -1,48 +0,0 @@ -permissions: - contents: read - -on: - workflow_dispatch: - schedule: - - cron: "0 */4 * * *" - -jobs: - prune_stale: - permissions: - issues: write # for actions/stale to close stale issues - pull-requests: write # for actions/stale to close stale PRs - name: Prune Stale - runs-on: ubuntu-latest - - steps: - - name: Prune Stale - uses: actions/stale@87c2b794b9b47a9bec68ae03c01aeb572ffebdb1 # v3.0.14 - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - # Different amounts of days for issues/PRs are not currently supported but there is a PR - # open for it: https://github.com/actions/stale/issues/214 - days-before-stale: 30 - days-before-close: 7 - stale-issue-message: > - This issue has been automatically marked as stale because it has not had activity in the - last 30 days. It will be closed in the next 7 days unless it is tagged "help wanted" or "no stalebot" or other activity - occurs. Thank you for your contributions. - close-issue-message: > - This issue has been automatically closed because it has not had activity in the - last 37 days. If this issue is still valid, please ping a maintainer and ask them to label it as "help wanted" or "no stalebot". - Thank you for your contributions. - stale-pr-message: > - This pull request has been automatically marked as stale because it has not had - activity in the last 30 days. It will be closed in 7 days if no further activity occurs. Please - feel free to give a status update now, ping for review, or re-open when it's ready. - Thank you for your contributions! - close-pr-message: > - This pull request has been automatically closed because it has not had - activity in the last 37 days. Please feel free to give a status update now, ping for review, or re-open when it's ready. - Thank you for your contributions! - stale-issue-label: "stale" - exempt-issue-labels: "no stalebot,help wanted" - stale-pr-label: "stale" - exempt-pr-labels: "no stalebot" - operations-per-run: 500 - ascending: true diff --git a/Makefile b/Makefile index 342e2beba..c0063af3f 100644 --- a/Makefile +++ b/Makefile @@ -20,6 +20,14 @@ endif .PHONY: bootstrap bootstrap: ; +.PHONY: proto +proto: + protoc \ + --proto_path=api \ + --go_out=. --go_opt=module=$(MODULE) \ + --go-grpc_out=. --go-grpc_opt=module=$(MODULE) \ + api/ratelimit/corrector/v1/corrector.proto + define REDIS_STUNNEL cert = private.pem pid = /var/run/stunnel.pid diff --git a/README.md b/README.md index 6cf2e51c2..2d42d0156 100644 --- a/README.md +++ b/README.md @@ -57,6 +57,7 @@ - [Prometheus](#prometheus) - [HTTP Port](#http-port) - [/json endpoint](#json-endpoint) +- [Corrector Service](#corrector-service) - [Debug Port](#debug-port) - [Local Cache](#local-cache) - [Redis](#redis) @@ -1251,6 +1252,43 @@ The response is a RateLimitResponse encoded with } ``` +# Corrector Service + +`CorrectorService.DecrementCounter` is an internal admin RPC for refunding +quota that was charged but never actually consumed (for example, when an +upstream returns a non-2xx and the caller decides to roll the slot back). +It is exposed on the existing gRPC port and on the HTTP/1 listener at +`/json/corrector/decrement`. + +Trust model: same as the data plane. No per-RPC authentication; access is +gated by network isolation and the existing TLS configuration. + +Semantics: + +- Only limits with units of HOUR or larger are eligible. Sub-hour units are + refused with `UNIT_NOT_ALLOWED` — an out-of-band correction loop can't + keep up with shorter windows. +- The caller must echo back the current bucket-start timestamp in + `original_bucket_timestamp`; if the bucket has rolled, the call returns + `BUCKET_EXPIRED` and the caller knows to drop the correction. +- Underflow is clamped at 0 atomically inside Redis (Lua). Over-decrement + returns `OK` with `new_value == 0` rather than letting the counter go + negative. +- Backend support is Redis only. Memcached returns `Unimplemented`. + +Companion setting: `LOCAL_CACHE_MAX_UNIT_SECONDS` (default `0`, no cap) +restricts the per-replica freecache "over limit" markers to units at or +below the configured cap. Set it to `3600` so HOUR+ limits — the ones the +corrector touches — never go through the local cache; otherwise a stale +marker can outlive a decrement until the bucket boundary. + +Observability: + +- Prometheus counter `ratelimit_corrector_decrements_total{domain, limit_unit, result}`. + `result` values: `ok`, `bucket_expired`, `key_not_found`, `unit_not_allowed`, `limit_not_found`, `invalid_argument`, `unimplemented`, `no_config`, `transport_error`, `service_error`, `error` (fallback for non-status backend errors). Caller-controlled `domain` collapses to `` whenever no configured limit was matched, so cardinality stays bounded. +- Structured audit log line (`component=corrector`) on every accepted call. + Validation failures log at Debug to avoid flooding on hostile traffic. + # Debug Port The debug port can be used to interact with the running process. diff --git a/api/ratelimit/corrector/v1/corrector.proto b/api/ratelimit/corrector/v1/corrector.proto new file mode 100644 index 000000000..e15b0a98d --- /dev/null +++ b/api/ratelimit/corrector/v1/corrector.proto @@ -0,0 +1,55 @@ +syntax = "proto3"; + +package ratelimit.corrector.v1; + +option go_package = "github.com/envoyproxy/ratelimit/src/proto/corrector/v1;corrector_v1"; + +// CorrectorService is an internal admin RPC that decrements long-window rate +// limit counters (HOUR or larger). It is intended to refund quota that was +// charged for requests that did not actually consume it (e.g. a non-2xx +// response from a downstream service that the caller is choosing to refund). +// +// Not exposed to public clients; trust is granted by network isolation and the +// existing TLS configuration of the data plane. +service CorrectorService { + rpc DecrementCounter(DecrementCounterRequest) returns (DecrementCounterResponse); +} + +// Entry mirrors envoy.extensions.common.ratelimit.v3.RateLimitDescriptor.Entry +// without taking a direct dependency on the upstream Envoy proto. +message Entry { + string key = 1; + string value = 2; +} + +// Descriptor mirrors envoy.extensions.common.ratelimit.v3.RateLimitDescriptor. +// Only the fields the corrector uses are included. Nested descriptors and +// per-descriptor limit overrides are intentionally omitted; this service +// looks up an existing configured limit by its key path. +message Descriptor { + repeated Entry entries = 1; +} + +message DecrementCounterRequest { + string domain = 1; + Descriptor descriptor = 2; + uint64 delta = 3; + // Unix seconds. Must equal the current bucket start for the resolved limit; + // protects against decrementing a bucket the caller no longer observes. + int64 original_bucket_timestamp = 4; +} + +message DecrementCounterResponse { + enum Code { + UNKNOWN = 0; + OK = 1; + BUCKET_EXPIRED = 2; // original_bucket_timestamp != current bucket + KEY_NOT_FOUND = 3; // Redis key absent (never written or already expired) + UNIT_NOT_ALLOWED = 4; // limit unit < HOUR + LIMIT_NOT_FOUND = 5; // descriptor matches no configured limit + } + Code code = 1; + // Resulting counter value after the decrement. Floored at 0 when the + // requested delta would have underflowed. + int64 new_value = 2; +} diff --git a/src/limiter/base_limiter.go b/src/limiter/base_limiter.go index 3ee5ce8e5..55f4385c9 100644 --- a/src/limiter/base_limiter.go +++ b/src/limiter/base_limiter.go @@ -20,10 +20,31 @@ type BaseRateLimiter struct { ExpirationJitterMaxSeconds int64 cacheKeyGenerator CacheKeyGenerator localCache *freecache.Cache + localCacheMaxUnitSeconds int64 nearLimitRatio float32 StatsManager stats.Manager } +// shouldUseLocalCache reports whether the per-replica over-limit local cache +// should be populated for a limit with the given unit. The local cache is +// correct only under monotonic increments, so deployments that decrement +// counters (e.g. via CorrectorService) should exclude long-window units to +// avoid stale "over limit" markers surviving until the next bucket boundary. +// +// Returns true when the cap is unset (zero) or when the unit divider is +// strictly below the configured cap. Strict inequality matches the documented +// recommendation of setting the cap to a unit's divider to exclude that unit +// and longer (e.g. 3600 to exclude HOUR+). +func (this *BaseRateLimiter) shouldUseLocalCache(unit pb.RateLimitResponse_RateLimit_Unit) bool { + if this.localCache == nil { + return false + } + if this.localCacheMaxUnitSeconds <= 0 { + return true + } + return utils.UnitToDivider(unit) < this.localCacheMaxUnitSeconds +} + type LimitInfo struct { limit *config.RateLimit limitBeforeIncrease uint64 @@ -108,7 +129,7 @@ func (this *BaseRateLimiter) GetResponseDescriptorStatus(key string, limitInfo * this.checkOverLimitThreshold(limitInfo, hitsAddend) - if this.localCache != nil { + if this.shouldUseLocalCache(limitInfo.limit.Limit.Unit) { // Set the TTL of the local_cache to be the entire duration. // Since the cache_key gets changed once the time crosses over current time slot, the over-the-limit // cache keys in local_cache lose effectiveness. @@ -144,6 +165,7 @@ func (this *BaseRateLimiter) GetResponseDescriptorStatus(key string, limitInfo * func NewBaseRateLimit(timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string, statsManager stats.Manager, + localCacheMaxUnitSeconds int64, ) *BaseRateLimiter { return &BaseRateLimiter{ timeSource: timeSource, @@ -151,6 +173,7 @@ func NewBaseRateLimit(timeSource utils.TimeSource, jitterRand *rand.Rand, expira ExpirationJitterMaxSeconds: expirationJitterMaxSeconds, cacheKeyGenerator: NewCacheKeyGenerator(cacheKeyPrefix), localCache: localCache, + localCacheMaxUnitSeconds: localCacheMaxUnitSeconds, nearLimitRatio: nearLimitRatio, StatsManager: statsManager, } diff --git a/src/limiter/bucket_corrector.go b/src/limiter/bucket_corrector.go new file mode 100644 index 000000000..cb7f8ba94 --- /dev/null +++ b/src/limiter/bucket_corrector.go @@ -0,0 +1,24 @@ +package limiter + +import ( + "context" + "errors" +) + +// ErrBucketNotFound signals that the bucket counter is absent in the backend +// — either never written or already expired with the bucket boundary. +// Distinct from a successful decrement that lands on 0. +var ErrBucketNotFound = errors.New("bucket counter not found") + +// BucketCorrector is the storage-facing abstraction the CorrectorService +// uses to manipulate long-window rate limit counters out-of-band. Keeping +// this separate from RateLimitCache means each backend implements only the +// operations it actually supports (memcached, for example, has no atomic +// floor-at-0 primitive and need not satisfy this interface). +type BucketCorrector interface { + // DecrementBucket atomically decrements the counter at `key` by `delta`, + // clamping at 0 so the value never goes negative. Returns the resulting + // value, or ErrBucketNotFound if the key does not exist. Implementations + // must not create the key when it is missing. + DecrementBucket(ctx context.Context, key string, delta uint64) (newValue int64, err error) +} diff --git a/src/memcached/cache_impl.go b/src/memcached/cache_impl.go index c8bcd0774..8bcae2a93 100644 --- a/src/memcached/cache_impl.go +++ b/src/memcached/cache_impl.go @@ -303,6 +303,7 @@ func runAsync(task func()) { func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, statsManager stats.Manager, nearLimitRatio float32, cacheKeyPrefix string, + localCacheMaxUnitSeconds int64, ) limiter.RateLimitCache { return &rateLimitMemcacheImpl{ client: client, @@ -311,7 +312,7 @@ func NewRateLimitCacheImpl(client Client, timeSource utils.TimeSource, jitterRan expirationJitterMaxSeconds: expirationJitterMaxSeconds, localCache: localCache, nearLimitRatio: nearLimitRatio, - baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix, statsManager), + baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix, statsManager, localCacheMaxUnitSeconds), } } @@ -327,5 +328,6 @@ func NewRateLimitCacheImplFromSettings(s settings.Settings, timeSource utils.Tim statsManager, s.NearLimitRatio, s.CacheKeyPrefix, + s.LocalCacheMaxUnitSeconds, ) } diff --git a/src/proto/corrector/v1/corrector.pb.go b/src/proto/corrector/v1/corrector.pb.go new file mode 100644 index 000000000..10c61dd33 --- /dev/null +++ b/src/proto/corrector/v1/corrector.pb.go @@ -0,0 +1,400 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.11 +// protoc v7.34.1 +// source: ratelimit/corrector/v1/corrector.proto + +package corrector_v1 + +import ( + reflect "reflect" + sync "sync" + unsafe "unsafe" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type DecrementCounterResponse_Code int32 + +const ( + DecrementCounterResponse_UNKNOWN DecrementCounterResponse_Code = 0 + DecrementCounterResponse_OK DecrementCounterResponse_Code = 1 + DecrementCounterResponse_BUCKET_EXPIRED DecrementCounterResponse_Code = 2 // original_bucket_timestamp != current bucket + DecrementCounterResponse_KEY_NOT_FOUND DecrementCounterResponse_Code = 3 // Redis key absent (never written or already expired) + DecrementCounterResponse_UNIT_NOT_ALLOWED DecrementCounterResponse_Code = 4 // limit unit < HOUR + DecrementCounterResponse_LIMIT_NOT_FOUND DecrementCounterResponse_Code = 5 // descriptor matches no configured limit +) + +// Enum value maps for DecrementCounterResponse_Code. +var ( + DecrementCounterResponse_Code_name = map[int32]string{ + 0: "UNKNOWN", + 1: "OK", + 2: "BUCKET_EXPIRED", + 3: "KEY_NOT_FOUND", + 4: "UNIT_NOT_ALLOWED", + 5: "LIMIT_NOT_FOUND", + } + DecrementCounterResponse_Code_value = map[string]int32{ + "UNKNOWN": 0, + "OK": 1, + "BUCKET_EXPIRED": 2, + "KEY_NOT_FOUND": 3, + "UNIT_NOT_ALLOWED": 4, + "LIMIT_NOT_FOUND": 5, + } +) + +func (x DecrementCounterResponse_Code) Enum() *DecrementCounterResponse_Code { + p := new(DecrementCounterResponse_Code) + *p = x + return p +} + +func (x DecrementCounterResponse_Code) String() string { + return protoimpl.X.EnumStringOf(x.Descriptor(), protoreflect.EnumNumber(x)) +} + +func (DecrementCounterResponse_Code) Descriptor() protoreflect.EnumDescriptor { + return file_ratelimit_corrector_v1_corrector_proto_enumTypes[0].Descriptor() +} + +func (DecrementCounterResponse_Code) Type() protoreflect.EnumType { + return &file_ratelimit_corrector_v1_corrector_proto_enumTypes[0] +} + +func (x DecrementCounterResponse_Code) Number() protoreflect.EnumNumber { + return protoreflect.EnumNumber(x) +} + +// Deprecated: Use DecrementCounterResponse_Code.Descriptor instead. +func (DecrementCounterResponse_Code) EnumDescriptor() ([]byte, []int) { + return file_ratelimit_corrector_v1_corrector_proto_rawDescGZIP(), []int{3, 0} +} + +// Entry mirrors envoy.extensions.common.ratelimit.v3.RateLimitDescriptor.Entry +// without taking a direct dependency on the upstream Envoy proto. +type Entry struct { + state protoimpl.MessageState `protogen:"open.v1"` + Key string `protobuf:"bytes,1,opt,name=key,proto3" json:"key,omitempty"` + Value string `protobuf:"bytes,2,opt,name=value,proto3" json:"value,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Entry) Reset() { + *x = Entry{} + mi := &file_ratelimit_corrector_v1_corrector_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Entry) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Entry) ProtoMessage() {} + +func (x *Entry) ProtoReflect() protoreflect.Message { + mi := &file_ratelimit_corrector_v1_corrector_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Entry.ProtoReflect.Descriptor instead. +func (*Entry) Descriptor() ([]byte, []int) { + return file_ratelimit_corrector_v1_corrector_proto_rawDescGZIP(), []int{0} +} + +func (x *Entry) GetKey() string { + if x != nil { + return x.Key + } + return "" +} + +func (x *Entry) GetValue() string { + if x != nil { + return x.Value + } + return "" +} + +// Descriptor mirrors envoy.extensions.common.ratelimit.v3.RateLimitDescriptor. +// Only the fields the corrector uses are included. Nested descriptors and +// per-descriptor limit overrides are intentionally omitted; this service +// looks up an existing configured limit by its key path. +type Descriptor struct { + state protoimpl.MessageState `protogen:"open.v1"` + Entries []*Entry `protobuf:"bytes,1,rep,name=entries,proto3" json:"entries,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Descriptor) Reset() { + *x = Descriptor{} + mi := &file_ratelimit_corrector_v1_corrector_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Descriptor) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Descriptor) ProtoMessage() {} + +func (x *Descriptor) ProtoReflect() protoreflect.Message { + mi := &file_ratelimit_corrector_v1_corrector_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Descriptor.ProtoReflect.Descriptor instead. +func (*Descriptor) Descriptor() ([]byte, []int) { + return file_ratelimit_corrector_v1_corrector_proto_rawDescGZIP(), []int{1} +} + +func (x *Descriptor) GetEntries() []*Entry { + if x != nil { + return x.Entries + } + return nil +} + +type DecrementCounterRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Domain string `protobuf:"bytes,1,opt,name=domain,proto3" json:"domain,omitempty"` + Descriptor_ *Descriptor `protobuf:"bytes,2,opt,name=descriptor,proto3" json:"descriptor,omitempty"` + Delta uint64 `protobuf:"varint,3,opt,name=delta,proto3" json:"delta,omitempty"` + // Unix seconds. Must equal the current bucket start for the resolved limit; + // protects against decrementing a bucket the caller no longer observes. + OriginalBucketTimestamp int64 `protobuf:"varint,4,opt,name=original_bucket_timestamp,json=originalBucketTimestamp,proto3" json:"original_bucket_timestamp,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DecrementCounterRequest) Reset() { + *x = DecrementCounterRequest{} + mi := &file_ratelimit_corrector_v1_corrector_proto_msgTypes[2] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DecrementCounterRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DecrementCounterRequest) ProtoMessage() {} + +func (x *DecrementCounterRequest) ProtoReflect() protoreflect.Message { + mi := &file_ratelimit_corrector_v1_corrector_proto_msgTypes[2] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DecrementCounterRequest.ProtoReflect.Descriptor instead. +func (*DecrementCounterRequest) Descriptor() ([]byte, []int) { + return file_ratelimit_corrector_v1_corrector_proto_rawDescGZIP(), []int{2} +} + +func (x *DecrementCounterRequest) GetDomain() string { + if x != nil { + return x.Domain + } + return "" +} + +func (x *DecrementCounterRequest) GetDescriptor_() *Descriptor { + if x != nil { + return x.Descriptor_ + } + return nil +} + +func (x *DecrementCounterRequest) GetDelta() uint64 { + if x != nil { + return x.Delta + } + return 0 +} + +func (x *DecrementCounterRequest) GetOriginalBucketTimestamp() int64 { + if x != nil { + return x.OriginalBucketTimestamp + } + return 0 +} + +type DecrementCounterResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + Code DecrementCounterResponse_Code `protobuf:"varint,1,opt,name=code,proto3,enum=ratelimit.corrector.v1.DecrementCounterResponse_Code" json:"code,omitempty"` + // Resulting counter value after the decrement. Floored at 0 when the + // requested delta would have underflowed. + NewValue int64 `protobuf:"varint,2,opt,name=new_value,json=newValue,proto3" json:"new_value,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *DecrementCounterResponse) Reset() { + *x = DecrementCounterResponse{} + mi := &file_ratelimit_corrector_v1_corrector_proto_msgTypes[3] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *DecrementCounterResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*DecrementCounterResponse) ProtoMessage() {} + +func (x *DecrementCounterResponse) ProtoReflect() protoreflect.Message { + mi := &file_ratelimit_corrector_v1_corrector_proto_msgTypes[3] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use DecrementCounterResponse.ProtoReflect.Descriptor instead. +func (*DecrementCounterResponse) Descriptor() ([]byte, []int) { + return file_ratelimit_corrector_v1_corrector_proto_rawDescGZIP(), []int{3} +} + +func (x *DecrementCounterResponse) GetCode() DecrementCounterResponse_Code { + if x != nil { + return x.Code + } + return DecrementCounterResponse_UNKNOWN +} + +func (x *DecrementCounterResponse) GetNewValue() int64 { + if x != nil { + return x.NewValue + } + return 0 +} + +var File_ratelimit_corrector_v1_corrector_proto protoreflect.FileDescriptor + +const file_ratelimit_corrector_v1_corrector_proto_rawDesc = "" + + "\n" + + "&ratelimit/corrector/v1/corrector.proto\x12\x16ratelimit.corrector.v1\"/\n" + + "\x05Entry\x12\x10\n" + + "\x03key\x18\x01 \x01(\tR\x03key\x12\x14\n" + + "\x05value\x18\x02 \x01(\tR\x05value\"E\n" + + "\n" + + "Descriptor\x127\n" + + "\aentries\x18\x01 \x03(\v2\x1d.ratelimit.corrector.v1.EntryR\aentries\"\xc7\x01\n" + + "\x17DecrementCounterRequest\x12\x16\n" + + "\x06domain\x18\x01 \x01(\tR\x06domain\x12B\n" + + "\n" + + "descriptor\x18\x02 \x01(\v2\".ratelimit.corrector.v1.DescriptorR\n" + + "descriptor\x12\x14\n" + + "\x05delta\x18\x03 \x01(\x04R\x05delta\x12:\n" + + "\x19original_bucket_timestamp\x18\x04 \x01(\x03R\x17originalBucketTimestamp\"\xf1\x01\n" + + "\x18DecrementCounterResponse\x12I\n" + + "\x04code\x18\x01 \x01(\x0e25.ratelimit.corrector.v1.DecrementCounterResponse.CodeR\x04code\x12\x1b\n" + + "\tnew_value\x18\x02 \x01(\x03R\bnewValue\"m\n" + + "\x04Code\x12\v\n" + + "\aUNKNOWN\x10\x00\x12\x06\n" + + "\x02OK\x10\x01\x12\x12\n" + + "\x0eBUCKET_EXPIRED\x10\x02\x12\x11\n" + + "\rKEY_NOT_FOUND\x10\x03\x12\x14\n" + + "\x10UNIT_NOT_ALLOWED\x10\x04\x12\x13\n" + + "\x0fLIMIT_NOT_FOUND\x10\x052\x89\x01\n" + + "\x10CorrectorService\x12u\n" + + "\x10DecrementCounter\x12/.ratelimit.corrector.v1.DecrementCounterRequest\x1a0.ratelimit.corrector.v1.DecrementCounterResponseBEZCgithub.com/envoyproxy/ratelimit/src/proto/corrector/v1;corrector_v1b\x06proto3" + +var ( + file_ratelimit_corrector_v1_corrector_proto_rawDescOnce sync.Once + file_ratelimit_corrector_v1_corrector_proto_rawDescData []byte +) + +func file_ratelimit_corrector_v1_corrector_proto_rawDescGZIP() []byte { + file_ratelimit_corrector_v1_corrector_proto_rawDescOnce.Do(func() { + file_ratelimit_corrector_v1_corrector_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_ratelimit_corrector_v1_corrector_proto_rawDesc), len(file_ratelimit_corrector_v1_corrector_proto_rawDesc))) + }) + return file_ratelimit_corrector_v1_corrector_proto_rawDescData +} + +var ( + file_ratelimit_corrector_v1_corrector_proto_enumTypes = make([]protoimpl.EnumInfo, 1) + file_ratelimit_corrector_v1_corrector_proto_msgTypes = make([]protoimpl.MessageInfo, 4) + file_ratelimit_corrector_v1_corrector_proto_goTypes = []any{ + (DecrementCounterResponse_Code)(0), // 0: ratelimit.corrector.v1.DecrementCounterResponse.Code + (*Entry)(nil), // 1: ratelimit.corrector.v1.Entry + (*Descriptor)(nil), // 2: ratelimit.corrector.v1.Descriptor + (*DecrementCounterRequest)(nil), // 3: ratelimit.corrector.v1.DecrementCounterRequest + (*DecrementCounterResponse)(nil), // 4: ratelimit.corrector.v1.DecrementCounterResponse + } +) + +var file_ratelimit_corrector_v1_corrector_proto_depIdxs = []int32{ + 1, // 0: ratelimit.corrector.v1.Descriptor.entries:type_name -> ratelimit.corrector.v1.Entry + 2, // 1: ratelimit.corrector.v1.DecrementCounterRequest.descriptor:type_name -> ratelimit.corrector.v1.Descriptor + 0, // 2: ratelimit.corrector.v1.DecrementCounterResponse.code:type_name -> ratelimit.corrector.v1.DecrementCounterResponse.Code + 3, // 3: ratelimit.corrector.v1.CorrectorService.DecrementCounter:input_type -> ratelimit.corrector.v1.DecrementCounterRequest + 4, // 4: ratelimit.corrector.v1.CorrectorService.DecrementCounter:output_type -> ratelimit.corrector.v1.DecrementCounterResponse + 4, // [4:5] is the sub-list for method output_type + 3, // [3:4] is the sub-list for method input_type + 3, // [3:3] is the sub-list for extension type_name + 3, // [3:3] is the sub-list for extension extendee + 0, // [0:3] is the sub-list for field type_name +} + +func init() { file_ratelimit_corrector_v1_corrector_proto_init() } +func file_ratelimit_corrector_v1_corrector_proto_init() { + if File_ratelimit_corrector_v1_corrector_proto != nil { + return + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: unsafe.Slice(unsafe.StringData(file_ratelimit_corrector_v1_corrector_proto_rawDesc), len(file_ratelimit_corrector_v1_corrector_proto_rawDesc)), + NumEnums: 1, + NumMessages: 4, + NumExtensions: 0, + NumServices: 1, + }, + GoTypes: file_ratelimit_corrector_v1_corrector_proto_goTypes, + DependencyIndexes: file_ratelimit_corrector_v1_corrector_proto_depIdxs, + EnumInfos: file_ratelimit_corrector_v1_corrector_proto_enumTypes, + MessageInfos: file_ratelimit_corrector_v1_corrector_proto_msgTypes, + }.Build() + File_ratelimit_corrector_v1_corrector_proto = out.File + file_ratelimit_corrector_v1_corrector_proto_goTypes = nil + file_ratelimit_corrector_v1_corrector_proto_depIdxs = nil +} diff --git a/src/proto/corrector/v1/corrector_grpc.pb.go b/src/proto/corrector/v1/corrector_grpc.pb.go new file mode 100644 index 000000000..3a7ca3543 --- /dev/null +++ b/src/proto/corrector/v1/corrector_grpc.pb.go @@ -0,0 +1,138 @@ +// Code generated by protoc-gen-go-grpc. DO NOT EDIT. +// versions: +// - protoc-gen-go-grpc v1.6.1 +// - protoc v7.34.1 +// source: ratelimit/corrector/v1/corrector.proto + +package corrector_v1 + +import ( + context "context" + + grpc "google.golang.org/grpc" + codes "google.golang.org/grpc/codes" + status "google.golang.org/grpc/status" +) + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the grpc package it is being compiled against. +// Requires gRPC-Go v1.64.0 or later. +const _ = grpc.SupportPackageIsVersion9 + +const ( + CorrectorService_DecrementCounter_FullMethodName = "/ratelimit.corrector.v1.CorrectorService/DecrementCounter" +) + +// CorrectorServiceClient is the client API for CorrectorService service. +// +// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. +// +// CorrectorService is an internal admin RPC that decrements long-window rate +// limit counters (HOUR or larger). It is intended to refund quota that was +// charged for requests that did not actually consume it (e.g. a non-2xx +// response from a downstream service that the caller is choosing to refund). +// +// Not exposed to public clients; trust is granted by network isolation and the +// existing TLS configuration of the data plane. +type CorrectorServiceClient interface { + DecrementCounter(ctx context.Context, in *DecrementCounterRequest, opts ...grpc.CallOption) (*DecrementCounterResponse, error) +} + +type correctorServiceClient struct { + cc grpc.ClientConnInterface +} + +func NewCorrectorServiceClient(cc grpc.ClientConnInterface) CorrectorServiceClient { + return &correctorServiceClient{cc} +} + +func (c *correctorServiceClient) DecrementCounter(ctx context.Context, in *DecrementCounterRequest, opts ...grpc.CallOption) (*DecrementCounterResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(DecrementCounterResponse) + err := c.cc.Invoke(ctx, CorrectorService_DecrementCounter_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + +// CorrectorServiceServer is the server API for CorrectorService service. +// All implementations must embed UnimplementedCorrectorServiceServer +// for forward compatibility. +// +// CorrectorService is an internal admin RPC that decrements long-window rate +// limit counters (HOUR or larger). It is intended to refund quota that was +// charged for requests that did not actually consume it (e.g. a non-2xx +// response from a downstream service that the caller is choosing to refund). +// +// Not exposed to public clients; trust is granted by network isolation and the +// existing TLS configuration of the data plane. +type CorrectorServiceServer interface { + DecrementCounter(context.Context, *DecrementCounterRequest) (*DecrementCounterResponse, error) + mustEmbedUnimplementedCorrectorServiceServer() +} + +// UnimplementedCorrectorServiceServer must be embedded to have +// forward compatible implementations. +// +// NOTE: this should be embedded by value instead of pointer to avoid a nil +// pointer dereference when methods are called. +type UnimplementedCorrectorServiceServer struct{} + +func (UnimplementedCorrectorServiceServer) DecrementCounter(context.Context, *DecrementCounterRequest) (*DecrementCounterResponse, error) { + return nil, status.Error(codes.Unimplemented, "method DecrementCounter not implemented") +} +func (UnimplementedCorrectorServiceServer) mustEmbedUnimplementedCorrectorServiceServer() {} +func (UnimplementedCorrectorServiceServer) testEmbeddedByValue() {} + +// UnsafeCorrectorServiceServer may be embedded to opt out of forward compatibility for this service. +// Use of this interface is not recommended, as added methods to CorrectorServiceServer will +// result in compilation errors. +type UnsafeCorrectorServiceServer interface { + mustEmbedUnimplementedCorrectorServiceServer() +} + +func RegisterCorrectorServiceServer(s grpc.ServiceRegistrar, srv CorrectorServiceServer) { + // If the following call panics, it indicates UnimplementedCorrectorServiceServer was + // embedded by pointer and is nil. This will cause panics if an + // unimplemented method is ever invoked, so we test this at initialization + // time to prevent it from happening at runtime later due to I/O. + if t, ok := srv.(interface{ testEmbeddedByValue() }); ok { + t.testEmbeddedByValue() + } + s.RegisterService(&CorrectorService_ServiceDesc, srv) +} + +func _CorrectorService_DecrementCounter_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(DecrementCounterRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(CorrectorServiceServer).DecrementCounter(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: CorrectorService_DecrementCounter_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(CorrectorServiceServer).DecrementCounter(ctx, req.(*DecrementCounterRequest)) + } + return interceptor(ctx, in, info, handler) +} + +// CorrectorService_ServiceDesc is the grpc.ServiceDesc for CorrectorService service. +// It's only intended for direct use with grpc.RegisterService, +// and not to be introspected or modified (even as a copy) +var CorrectorService_ServiceDesc = grpc.ServiceDesc{ + ServiceName: "ratelimit.corrector.v1.CorrectorService", + HandlerType: (*CorrectorServiceServer)(nil), + Methods: []grpc.MethodDesc{ + { + MethodName: "DecrementCounter", + Handler: _CorrectorService_DecrementCounter_Handler, + }, + }, + Streams: []grpc.StreamDesc{}, + Metadata: "ratelimit/corrector/v1/corrector.proto", +} diff --git a/src/redis/bucket_corrector.go b/src/redis/bucket_corrector.go new file mode 100644 index 000000000..3e8b6ff2d --- /dev/null +++ b/src/redis/bucket_corrector.go @@ -0,0 +1,50 @@ +package redis + +import ( + "context" + "strconv" + + "github.com/mediocregopher/radix/v4" + + "github.com/envoyproxy/ratelimit/src/limiter" +) + +// DecrementLua atomically clamps the requested delta at the current counter +// value and then DECRBYs by the clamped amount. Returns false (nil at the +// RESP layer) when the key does not exist, so callers can distinguish +// "never written / already expired" from a successful drain-to-zero. +// +// The clamp gives a "floor at 0" guarantee without a follow-up SET, which +// would race against concurrent INCRBYs from the data plane. Exported so +// integration tests can exercise it against a real Redis without going +// through the full handler. +const DecrementLua = `local v = redis.call('GET', KEYS[1]) +if not v then return false end +local d = tonumber(ARGV[1]) +local cur = tonumber(v) +if d > cur then d = cur end +return redis.call('DECRBY', KEYS[1], d)` + +type bucketCorrector struct { + client Client +} + +// NewBucketCorrector returns a BucketCorrector backed by the given Redis +// client. The client is expected to be the same pool the data plane uses, +// so cluster routing, auth, and TLS stay consistent between increments and +// decrements. +func NewBucketCorrector(client Client) limiter.BucketCorrector { + return &bucketCorrector{client: client} +} + +func (b *bucketCorrector) DecrementBucket(_ context.Context, key string, delta uint64) (int64, error) { + var newValue int64 + maybe := radix.Maybe{Rcv: &newValue} + if err := b.client.DoCmd(&maybe, "EVAL", DecrementLua, 1, key, strconv.FormatUint(delta, 10)); err != nil { + return 0, RedisError(err.Error()) + } + if maybe.Null { + return 0, limiter.ErrBucketNotFound + } + return newValue, nil +} diff --git a/src/redis/cache_impl.go b/src/redis/cache_impl.go index 21db8b6d2..e7051e782 100644 --- a/src/redis/cache_impl.go +++ b/src/redis/cache_impl.go @@ -13,24 +13,29 @@ import ( "github.com/envoyproxy/ratelimit/src/utils" ) -func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freecache.Cache, srv server.Server, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) { - closer := &utils.MultiCloser{} - var perSecondPool Client +// NewClientsFromSettings builds the Redis connection pool(s) used by both +// the data plane (rate limit cache) and the admin CorrectorService. +// Returning the clients at the composition root lets the runner share them +// across consumers instead of stashing a copy inside the cache impl. +func NewClientsFromSettings(s settings.Settings, srv server.Server) (mainClient, perSecondClient Client, closer io.Closer) { + mc := &utils.MultiCloser{} if s.RedisPerSecond { - perSecondPool = NewClientImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondSocketType, + perSecondClient = NewClientImpl(srv.Scope().Scope("redis_per_second_pool"), s.RedisPerSecondTls, s.RedisPerSecondAuth, s.RedisPerSecondSocketType, s.RedisPerSecondType, s.RedisPerSecondUrl, s.RedisPerSecondPoolSize, s.RedisPerSecondPipelineWindow, s.RedisPerSecondPipelineLimit, s.RedisTlsConfig, s.RedisHealthCheckActiveConnection, srv, s.RedisPerSecondTimeout, s.RedisPerSecondPoolOnEmptyBehavior, s.RedisPerSecondSentinelAuth) - closer.Closers = append(closer.Closers, perSecondPool) + mc.Closers = append(mc.Closers, perSecondClient) } - - otherPool := NewClientImpl(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisSocketType, s.RedisType, s.RedisUrl, s.RedisPoolSize, + mainClient = NewClientImpl(srv.Scope().Scope("redis_pool"), s.RedisTls, s.RedisAuth, s.RedisSocketType, s.RedisType, s.RedisUrl, s.RedisPoolSize, s.RedisPipelineWindow, s.RedisPipelineLimit, s.RedisTlsConfig, s.RedisHealthCheckActiveConnection, srv, s.RedisTimeout, s.RedisPoolOnEmptyBehavior, s.RedisSentinelAuth) - closer.Closers = append(closer.Closers, otherPool) + mc.Closers = append(mc.Closers, mainClient) + return mainClient, perSecondClient, mc +} +func NewRateLimiterCacheImplFromSettings(s settings.Settings, mainClient, perSecondClient Client, localCache *freecache.Cache, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, statsManager stats.Manager) limiter.RateLimitCache { return NewFixedRateLimitCacheImpl( - otherPool, - perSecondPool, + mainClient, + perSecondClient, timeSource, jitterRand, expirationJitterMaxSeconds, @@ -39,5 +44,6 @@ func NewRateLimiterCacheImplFromSettings(s settings.Settings, localCache *freeca s.CacheKeyPrefix, statsManager, s.StopCacheKeyIncrementWhenOverlimit, - ), closer + s.LocalCacheMaxUnitSeconds, + ) } diff --git a/src/redis/fixed_cache_impl.go b/src/redis/fixed_cache_impl.go index 1034d3beb..7611b0c53 100644 --- a/src/redis/fixed_cache_impl.go +++ b/src/redis/fixed_cache_impl.go @@ -224,12 +224,12 @@ func (this *fixedRateLimitCacheImpl) Flush() {} func NewFixedRateLimitCacheImpl(client Client, perSecondClient Client, timeSource utils.TimeSource, jitterRand *rand.Rand, expirationJitterMaxSeconds int64, localCache *freecache.Cache, nearLimitRatio float32, cacheKeyPrefix string, statsManager stats.Manager, - stopCacheKeyIncrementWhenOverlimit bool, + stopCacheKeyIncrementWhenOverlimit bool, localCacheMaxUnitSeconds int64, ) limiter.RateLimitCache { return &fixedRateLimitCacheImpl{ client: client, perSecondClient: perSecondClient, stopCacheKeyIncrementWhenOverlimit: stopCacheKeyIncrementWhenOverlimit, - baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix, statsManager), + baseRateLimiter: limiter.NewBaseRateLimit(timeSource, jitterRand, expirationJitterMaxSeconds, localCache, nearLimitRatio, cacheKeyPrefix, statsManager, localCacheMaxUnitSeconds), } } diff --git a/src/server/server.go b/src/server/server.go index 7202fa2f5..3c6f4d1b3 100644 --- a/src/server/server.go +++ b/src/server/server.go @@ -5,6 +5,7 @@ import ( pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + corrector_v1 "github.com/envoyproxy/ratelimit/src/proto/corrector/v1" "github.com/envoyproxy/ratelimit/src/provider" stats "github.com/lyft/gostats" @@ -30,6 +31,11 @@ type Server interface { AddDebugHttpEndpoint(path string, help string, handler http.HandlerFunc) AddJsonHandler(pb.RateLimitServiceServer) + /** + * Add an HTTP/1 endpoint for the CorrectorService. + */ + AddCorrectorJsonHandler(corrector_v1.CorrectorServiceServer) + /** * Returns the embedded gRPC server to be used for registering gRPC endpoints. */ diff --git a/src/server/server_impl.go b/src/server/server_impl.go index e9402da0a..3f26fee9c 100644 --- a/src/server/server_impl.go +++ b/src/server/server_impl.go @@ -19,6 +19,11 @@ import ( "google.golang.org/grpc/keepalive" "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" + + corrector_v1 "github.com/envoyproxy/ratelimit/src/proto/corrector/v1" "github.com/envoyproxy/ratelimit/src/provider" "github.com/envoyproxy/ratelimit/src/stats" @@ -162,6 +167,86 @@ func (server *server) AddJsonHandler(svc pb.RateLimitServiceServer) { server.router.HandleFunc("/json", NewJsonHandler(svc)) } +// httpPeerAddr adapts an HTTP RemoteAddr into net.Addr so the corrector's +// audit log can use the same peer.FromContext path for HTTP and gRPC callers. +type httpPeerAddr string + +func (h httpPeerAddr) Network() string { return "tcp" } +func (h httpPeerAddr) String() string { return string(h) } + +// NewCorrectorJsonHandler exposes CorrectorService.DecrementCounter over the +// same HTTP/1 listener as the existing /json endpoint for tools that cannot +// speak gRPC. Like /json, it relies on network isolation for authorization; +// it does not authenticate the caller. +func NewCorrectorJsonHandler(svc corrector_v1.CorrectorServiceServer) func(http.ResponseWriter, *http.Request) { + return func(writer http.ResponseWriter, request *http.Request) { + if request.Method != http.MethodPost { + writer.Header().Set("Allow", http.MethodPost) + writeHttpStatus(writer, http.StatusMethodNotAllowed) + return + } + var req corrector_v1.DecrementCounterRequest + + body, err := io.ReadAll(request.Body) + if err != nil { + logger.Warnf("corrector: read body: %s", err.Error()) + writeHttpStatus(writer, http.StatusBadRequest) + return + } + if err := protojson.Unmarshal(body, &req); err != nil { + logger.Warnf("corrector: unmarshal: %s", err.Error()) + writeHttpStatus(writer, http.StatusBadRequest) + return + } + + // Attach the HTTP caller address as a gRPC-style peer so the handler's + // audit log records it uniformly across both transports. + ctx := peer.NewContext(request.Context(), &peer.Peer{Addr: httpPeerAddr(request.RemoteAddr)}) + + resp, err := svc.DecrementCounter(ctx, &req) + if err != nil { + // Don't log here — the handler's audit log already covers this + // call (and chose its level intentionally). Just map the gRPC + // code to an appropriate HTTP status. + writeHttpStatus(writer, httpStatusForGrpcCode(status.Code(err))) + return + } + jsonResp, err := protojson.Marshal(resp) + if err != nil { + logger.Errorf("corrector: marshal: %s", err.Error()) + writeHttpStatus(writer, http.StatusInternalServerError) + return + } + writer.Header().Set("Content-Type", "application/json") + writer.Write(jsonResp) + } +} + +// httpStatusForGrpcCode maps the small set of gRPC codes the corrector +// actually returns to sensible HTTP statuses. Anything we don't recognize +// falls back to 500 — callers should not rely on exact codes outside the +// listed set. +func httpStatusForGrpcCode(code codes.Code) int { + switch code { + case codes.OK: + return http.StatusOK + case codes.InvalidArgument: + return http.StatusBadRequest + case codes.Unimplemented: + return http.StatusNotImplemented + case codes.FailedPrecondition: + return http.StatusServiceUnavailable + case codes.Unavailable: + return http.StatusServiceUnavailable + default: + return http.StatusInternalServerError + } +} + +func (server *server) AddCorrectorJsonHandler(svc corrector_v1.CorrectorServiceServer) { + server.router.HandleFunc("/json/corrector/decrement", NewCorrectorJsonHandler(svc)) +} + func (server *server) GrpcServer() *grpc.Server { return server.grpcServer } diff --git a/src/service/corrector.go b/src/service/corrector.go new file mode 100644 index 000000000..47e192af4 --- /dev/null +++ b/src/service/corrector.go @@ -0,0 +1,286 @@ +package ratelimit + +import ( + "context" + "errors" + "strings" + + pb_struct "github.com/envoyproxy/go-control-plane/envoy/extensions/common/ratelimit/v3" + gostats "github.com/lyft/gostats" + logger "github.com/sirupsen/logrus" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/peer" + "google.golang.org/grpc/status" + + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + corrector_v1 "github.com/envoyproxy/ratelimit/src/proto/corrector/v1" + "github.com/envoyproxy/ratelimit/src/redis" + "github.com/envoyproxy/ratelimit/src/utils" +) + +// minDecrementableUnitSeconds is the smallest limit-unit divider (in seconds) +// the corrector is willing to operate on. Shorter windows are intentionally +// excluded — an external correction loop can't keep up with per-second or +// per-minute buckets without producing more confusion than it removes. +const minDecrementableUnitSeconds = 3600 + +// ConfigSnapshotter is the narrow slice of the rate limit service the +// corrector depends on: a snapshot of the currently loaded config. The +// shadow/quota mode flags are part of the upstream signature but the +// corrector doesn't care about them — it operates on raw counters. +type ConfigSnapshotter interface { + GetCurrentConfig() (config.RateLimitConfig, bool, bool) +} + +type CorrectorService struct { + corrector_v1.UnimplementedCorrectorServiceServer + config ConfigSnapshotter + bucketCorrector limiter.BucketCorrector + cacheKeyGenerator limiter.CacheKeyGenerator + timeSource utils.TimeSource + statsStore gostats.Store +} + +// NewCorrectorService wires the handler. A nil bucketCorrector signals an +// unsupported backend (memcached today); the handler will short-circuit +// with codes.Unimplemented in that case. +func NewCorrectorService( + config ConfigSnapshotter, + bucketCorrector limiter.BucketCorrector, + cacheKeyPrefix string, + timeSource utils.TimeSource, + statsStore gostats.Store, +) *CorrectorService { + return &CorrectorService{ + config: config, + bucketCorrector: bucketCorrector, + cacheKeyGenerator: limiter.NewCacheKeyGenerator(cacheKeyPrefix), + timeSource: timeSource, + statsStore: statsStore, + } +} + +func (s *CorrectorService) DecrementCounter( + ctx context.Context, req *corrector_v1.DecrementCounterRequest, +) (resp *corrector_v1.DecrementCounterResponse, err error) { + var limitUnitLabel string + + // LIFO defers: telemetry runs *after* recover, and crucially also fires + // while a re-thrown unknown panic is unwinding — so a real bug still + // leaves a corrector audit/metric trail before the gRPC server's own + // recover takes over. + defer func() { s.emitTelemetry(ctx, req, resp, err, limitUnitLabel) }() + defer func() { + r := recover() + if r == nil { + return + } + switch t := r.(type) { + case redis.RedisError: + resp, err = nil, status.Error(codes.Unavailable, t.Error()) + case serviceError: + resp, err = nil, status.Error(codes.Internal, t.Error()) + default: + panic(r) + } + }() + + if s.bucketCorrector == nil { + return nil, status.Error(codes.Unimplemented, "corrector is not supported for the configured backend") + } + if req.GetDomain() == "" { + return nil, status.Error(codes.InvalidArgument, "domain must not be empty") + } + if req.GetDescriptor_() == nil || len(req.GetDescriptor_().GetEntries()) == 0 { + return nil, status.Error(codes.InvalidArgument, "descriptor must not be empty") + } + if req.GetDelta() == 0 { + return nil, status.Error(codes.InvalidArgument, "delta must be positive") + } + if req.GetOriginalBucketTimestamp() <= 0 { + return nil, status.Error(codes.InvalidArgument, "original_bucket_timestamp must be positive") + } + + descriptor := toEnvoyDescriptor(req.GetDescriptor_()) + + snappedConfig, _, _ := s.config.GetCurrentConfig() + if snappedConfig == nil { + // FailedPrecondition rather than Unavailable: the service is up, + // it just hasn't loaded a config yet. Keeps Unavailable reserved + // for backend (Redis) transport failures, which lets the metric + // label distinguish the two cleanly. + return nil, status.Error(codes.FailedPrecondition, "no rate limit configuration loaded") + } + limit := snappedConfig.GetLimit(ctx, req.GetDomain(), descriptor) + if limit == nil { + return &corrector_v1.DecrementCounterResponse{ + Code: corrector_v1.DecrementCounterResponse_LIMIT_NOT_FOUND, + }, nil + } + limitUnitLabel = strings.ToLower(limit.Limit.Unit.String()) + + divider := utils.UnitToDivider(limit.Limit.Unit) + if divider < minDecrementableUnitSeconds { + return &corrector_v1.DecrementCounterResponse{ + Code: corrector_v1.DecrementCounterResponse_UNIT_NOT_ALLOWED, + }, nil + } + + now := s.timeSource.UnixNow() + currentBucket := (now / divider) * divider + if currentBucket != req.GetOriginalBucketTimestamp() { + return &corrector_v1.DecrementCounterResponse{ + Code: corrector_v1.DecrementCounterResponse_BUCKET_EXPIRED, + }, nil + } + + cacheKey := s.cacheKeyGenerator.GenerateCacheKey(req.GetDomain(), descriptor, limit, now) + if cacheKey.Key == "" { + // Defensive: GenerateCacheKey only returns an empty key when limit + // is nil, which we already filtered above. Keep the guard so a + // future signature change can't silently produce a malformed EVAL. + return nil, status.Error(codes.Internal, "failed to generate cache key") + } + + newValue, decErr := s.bucketCorrector.DecrementBucket(ctx, cacheKey.Key, req.GetDelta()) + if errors.Is(decErr, limiter.ErrBucketNotFound) { + return &corrector_v1.DecrementCounterResponse{ + Code: corrector_v1.DecrementCounterResponse_KEY_NOT_FOUND, + }, nil + } + if decErr != nil { + // Surface backend transport failures through the existing recover() + // path so they map to the same gRPC code as data-plane Redis errors. + // BucketCorrector implementations are expected to return either + // ErrBucketNotFound (handled above) or redis.RedisError; anything + // else gets re-thrown as a generic service error. + if redisErr, ok := decErr.(redis.RedisError); ok { + panic(redisErr) + } + panic(serviceError(decErr.Error())) + } + + return &corrector_v1.DecrementCounterResponse{ + Code: corrector_v1.DecrementCounterResponse_OK, + NewValue: newValue, + }, nil +} + +// unresolvedDomainLabel substitutes for the caller-supplied domain on paths +// where we never confirmed it matches a configured limit. Without this the +// `domain` Prometheus label would have unbounded cardinality — a hostile or +// misbehaving caller could create one series per unique string they send. +const unresolvedDomainLabel = "" + +// emitTelemetry writes a structured audit log line and bumps a Prometheus +// counter tagged with {domain, limit_unit, result}. +// +// On paths where the limit wasn't resolved (validation errors, no config, +// LIMIT_NOT_FOUND) both domain and limit_unit collapse to a placeholder so +// the metric retains bounded cardinality. The raw request domain is still +// available in the audit log for debugging. +// +// Validation failures (InvalidArgument) drop to Debug to keep hostile spray +// off the Info-level pipeline; everything else is Info. +func (s *CorrectorService) emitTelemetry( + ctx context.Context, + req *corrector_v1.DecrementCounterRequest, + resp *corrector_v1.DecrementCounterResponse, + err error, + limitUnit string, +) { + result := resultLabel(resp, err) + var newValue int64 + if err == nil && resp != nil { + newValue = resp.GetNewValue() + } + + entry := logger.WithFields(logger.Fields{ + "component": "corrector", + "domain": req.GetDomain(), + "descriptor": formatDescriptor(req.GetDescriptor_()), + "delta": req.GetDelta(), + "bucket_ts": req.GetOriginalBucketTimestamp(), + "result": result, + "new_value": newValue, + "caller": callerAddr(ctx), + "grpc_error": errString(err), + }) + if status.Code(err) == codes.InvalidArgument { + entry.Debug("corrector decrement") + } else { + entry.Info("corrector decrement") + } + + if s.statsStore != nil { + domainLabel := req.GetDomain() + if limitUnit == "" { + // We never matched a configured limit, so domain is caller- + // controlled and unbounded. Collapse to a fixed label. + domainLabel = unresolvedDomainLabel + } + s.statsStore.ScopeWithTags("ratelimit_corrector", map[string]string{ + "domain": domainLabel, + "limit_unit": limitUnit, + "result": result, + }).NewCounter("decrements_total").Inc() + } +} + +// resultLabel produces the `result` Prometheus label / audit log value. +// On success it surfaces the response Code (ok, bucket_expired, ...). On +// error it surfaces the gRPC failure mode so dashboards can distinguish a +// Redis outage from a validation spike or a config-not-ready window. +func resultLabel(resp *corrector_v1.DecrementCounterResponse, err error) string { + if err == nil && resp != nil { + return strings.ToLower(resp.GetCode().String()) + } + switch status.Code(err) { + case codes.InvalidArgument: + return "invalid_argument" + case codes.Unimplemented: + return "unimplemented" + case codes.FailedPrecondition: + return "no_config" + case codes.Unavailable: + return "transport_error" + case codes.Internal: + return "service_error" + default: + return "error" + } +} + +func toEnvoyDescriptor(d *corrector_v1.Descriptor) *pb_struct.RateLimitDescriptor { + entries := make([]*pb_struct.RateLimitDescriptor_Entry, len(d.GetEntries())) + for i, e := range d.GetEntries() { + entries[i] = &pb_struct.RateLimitDescriptor_Entry{Key: e.GetKey(), Value: e.GetValue()} + } + return &pb_struct.RateLimitDescriptor{Entries: entries} +} + +func formatDescriptor(d *corrector_v1.Descriptor) string { + if d == nil { + return "" + } + parts := make([]string, 0, len(d.GetEntries())) + for _, e := range d.GetEntries() { + parts = append(parts, e.GetKey()+"="+e.GetValue()) + } + return strings.Join(parts, ",") +} + +func callerAddr(ctx context.Context) string { + if p, ok := peer.FromContext(ctx); ok && p.Addr != nil { + return p.Addr.String() + } + return "" +} + +func errString(err error) string { + if err == nil { + return "" + } + return err.Error() +} diff --git a/src/service/corrector_label_test.go b/src/service/corrector_label_test.go new file mode 100644 index 000000000..54e63dd0f --- /dev/null +++ b/src/service/corrector_label_test.go @@ -0,0 +1,38 @@ +package ratelimit + +import ( + "errors" + "testing" + + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + corrector_v1 "github.com/envoyproxy/ratelimit/src/proto/corrector/v1" +) + +// resultLabel is the source of truth for the corrector's `result` Prom label. +// A regression here would silently rename dashboard series, so pin every +// mapping a caller could hit. +func TestResultLabel(t *testing.T) { + okResp := &corrector_v1.DecrementCounterResponse{Code: corrector_v1.DecrementCounterResponse_OK} + cases := []struct { + name string + resp *corrector_v1.DecrementCounterResponse + err error + want string + }{ + {"OK from response code", okResp, nil, "ok"}, + {"BUCKET_EXPIRED", &corrector_v1.DecrementCounterResponse{Code: corrector_v1.DecrementCounterResponse_BUCKET_EXPIRED}, nil, "bucket_expired"}, + {"InvalidArgument", nil, status.Error(codes.InvalidArgument, "x"), "invalid_argument"}, + {"Unimplemented", nil, status.Error(codes.Unimplemented, "x"), "unimplemented"}, + {"FailedPrecondition", nil, status.Error(codes.FailedPrecondition, "x"), "no_config"}, + {"Unavailable", nil, status.Error(codes.Unavailable, "x"), "transport_error"}, + {"Internal", nil, status.Error(codes.Internal, "x"), "service_error"}, + {"plain error", nil, errors.New("x"), "error"}, + } + for _, c := range cases { + if got := resultLabel(c.resp, c.err); got != c.want { + t.Errorf("%s: got %q, want %q", c.name, got, c.want) + } + } +} diff --git a/src/service_cmd/runner/runner.go b/src/service_cmd/runner/runner.go index a3889a58d..044b3ae47 100644 --- a/src/service_cmd/runner/runner.go +++ b/src/service_cmd/runner/runner.go @@ -18,6 +18,7 @@ import ( "github.com/envoyproxy/ratelimit/src/limiter" "github.com/envoyproxy/ratelimit/src/memcached" "github.com/envoyproxy/ratelimit/src/metrics" + corrector_v1 "github.com/envoyproxy/ratelimit/src/proto/corrector/v1" "github.com/envoyproxy/ratelimit/src/redis" "github.com/envoyproxy/ratelimit/src/server" ratelimit "github.com/envoyproxy/ratelimit/src/service" @@ -89,13 +90,17 @@ func (runner *Runner) GetStatsStore() gostats.Store { return runner.statsManager.GetStatsStore() } -func createLimiter(srv server.Server, s settings.Settings, localCache *freecache.Cache, statsManager stats.Manager) (limiter.RateLimitCache, io.Closer) { +// createLimiter constructs the rate limit cache. The Redis backend takes +// pre-built clients so the runner can share them with the CorrectorService; +// memcached owns its own client lifecycle internally. +func createLimiter(srv server.Server, s settings.Settings, redisClient, redisPerSecondClient redis.Client, localCache *freecache.Cache, statsManager stats.Manager) limiter.RateLimitCache { switch s.BackendType { case "redis", "": return redis.NewRateLimiterCacheImplFromSettings( s, + redisClient, + redisPerSecondClient, localCache, - srv, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), s.ExpirationJitterMaxSeconds, @@ -108,7 +113,7 @@ func createLimiter(srv server.Server, s settings.Settings, localCache *freecache rand.New(utils.NewLockedSource(time.Now().Unix())), localCache, srv.Scope(), - statsManager), &utils.MultiCloser{} // memcache client can't be closed + statsManager) default: logger.Fatalf("Invalid setting for BackendType: %s", s.BackendType) panic("This line should not be reachable") @@ -156,11 +161,26 @@ func (runner *Runner) Run() { runner.srv = srv runner.mu.Unlock() - limiter, limiterCloser := createLimiter(srv, s, localCache, runner.statsManager) - runner.ratelimitCloser = limiterCloser + // Build Redis clients once at the composition root so the data plane + // (cache) and the admin CorrectorService share the same pool. For the + // memcached backend redisClient stays nil and the corrector returns + // Unimplemented. + var ( + redisClient redis.Client + redisPerSecondClient redis.Client + ) + if s.BackendType == "redis" || s.BackendType == "" { + var redisCloser io.Closer + redisClient, redisPerSecondClient, redisCloser = redis.NewClientsFromSettings(s, srv) + runner.ratelimitCloser = redisCloser + } else { + runner.ratelimitCloser = &utils.MultiCloser{} + } + + cache := createLimiter(srv, s, redisClient, redisPerSecondClient, localCache, runner.statsManager) service := ratelimit.NewService( - limiter, + cache, srv.Provider(), runner.statsManager, srv.HealthChecker(), @@ -186,6 +206,18 @@ func (runner *Runner) Run() { // v2 proto is no longer supported pb.RegisterRateLimitServiceServer(srv.GrpcServer(), service) + // CorrectorService is registered unconditionally. The handler talks to + // the backend through limiter.BucketCorrector; for backends that have no + // implementation (memcached) we pass nil and the handler short-circuits + // with codes.Unimplemented. + var bucketCorrector limiter.BucketCorrector + if redisClient != nil { + bucketCorrector = redis.NewBucketCorrector(redisClient) + } + corrector := ratelimit.NewCorrectorService(service, bucketCorrector, s.CacheKeyPrefix, utils.NewTimeSourceImpl(), runner.statsManager.GetStatsStore()) + corrector_v1.RegisterCorrectorServiceServer(srv.GrpcServer(), corrector) + srv.AddCorrectorJsonHandler(corrector) + srv.Start() } diff --git a/src/settings/settings.go b/src/settings/settings.go index 6a1be618f..440b86e9d 100644 --- a/src/settings/settings.go +++ b/src/settings/settings.go @@ -106,6 +106,7 @@ type Settings struct { // Settings for all cache types ExpirationJitterMaxSeconds int64 `envconfig:"EXPIRATION_JITTER_MAX_SECONDS" default:"300"` LocalCacheSizeInBytes int `envconfig:"LOCAL_CACHE_SIZE_IN_BYTES" default:"0"` + LocalCacheMaxUnitSeconds int64 `envconfig:"LOCAL_CACHE_MAX_UNIT_SECONDS" default:"0"` NearLimitRatio float32 `envconfig:"NEAR_LIMIT_RATIO" default:"0.8"` CacheKeyPrefix string `envconfig:"CACHE_KEY_PREFIX" default:""` BackendType string `envconfig:"BACKEND_TYPE" default:"redis"` diff --git a/test/integration/corrector_test.go b/test/integration/corrector_test.go new file mode 100644 index 000000000..15cfd9c22 --- /dev/null +++ b/test/integration/corrector_test.go @@ -0,0 +1,117 @@ +//go:build integration + +package integration_test + +import ( + "context" + "strconv" + "testing" + "time" + + "github.com/mediocregopher/radix/v4" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/envoyproxy/ratelimit/src/redis" +) + +// dialCorrectorRedis connects to the same Redis the rest of the integration +// suite uses (port 6381 with password auth from the Makefile). We talk to it +// directly via radix instead of through the project's wrapped Client so the +// test pins the Lua script behavior, not the wrapping. +func dialCorrectorRedis(t *testing.T) radix.Client { + t.Helper() + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + pool, err := (radix.PoolConfig{ + Dialer: radix.Dialer{AuthPass: "password123"}, + Size: 1, + }).New(ctx, "tcp", "127.0.0.1:6381") + require.NoError(t, err) + t.Cleanup(func() { pool.Close() }) + return pool +} + +func corrKey(t *testing.T) string { + t.Helper() + return "ratelimit_corrector_it_" + strconv.FormatInt(time.Now().UnixNano(), 10) +} + +// Verifies normal decrement and that the TTL set on the counter survives — +// the Lua script must not strip the expiration. +func TestCorrector_Lua_NormalDecrement(t *testing.T) { + ctx := context.Background() + cli := dialCorrectorRedis(t) + key := corrKey(t) + require.NoError(t, cli.Do(ctx, radix.FlatCmd(nil, "SET", key, 10, "EX", 3600))) + + var v int64 + require.NoError(t, cli.Do(ctx, radix.FlatCmd(&v, "EVAL", redis.DecrementLua, 1, key, 3))) + assert.Equal(t, int64(7), v) + + var ttl int64 + require.NoError(t, cli.Do(ctx, radix.Cmd(&ttl, "TTL", key))) + assert.Greater(t, ttl, int64(0), "TTL should be preserved by the script") +} + +// Over-decrement: delta > current value floors at 0 (does not go negative). +func TestCorrector_Lua_OverDecrementFloorsAtZero(t *testing.T) { + ctx := context.Background() + cli := dialCorrectorRedis(t) + key := corrKey(t) + require.NoError(t, cli.Do(ctx, radix.FlatCmd(nil, "SET", key, 5, "EX", 3600))) + + var v int64 + require.NoError(t, cli.Do(ctx, radix.FlatCmd(&v, "EVAL", redis.DecrementLua, 1, key, 100))) + assert.Equal(t, int64(0), v, "should floor at 0, not go negative") + + var stored string + require.NoError(t, cli.Do(ctx, radix.Cmd(&stored, "GET", key))) + assert.Equal(t, "0", stored) +} + +// Missing key returns nil at the RESP layer and creates no phantom key. +func TestCorrector_Lua_MissingKey(t *testing.T) { + ctx := context.Background() + cli := dialCorrectorRedis(t) + key := corrKey(t) + + var n int64 + maybe := radix.Maybe{Rcv: &n} + require.NoError(t, cli.Do(ctx, radix.FlatCmd(&maybe, "EVAL", redis.DecrementLua, 1, key, 1))) + assert.True(t, maybe.Null, "script must return nil for a missing key") + + var exists int + require.NoError(t, cli.Do(ctx, radix.Cmd(&exists, "EXISTS", key))) + assert.Equal(t, 0, exists, "script must not create a phantom key") +} + +// Concurrent INCRBYs from the data plane racing with DECRBYs from the +// corrector must leave the counter at the algebraically correct value — the +// Lua's atomicity is what makes the floor-at-0 logic safe under contention. +func TestCorrector_Lua_ConcurrentIncrDecrAlgebra(t *testing.T) { + ctx := context.Background() + cli := dialCorrectorRedis(t) + key := corrKey(t) + require.NoError(t, cli.Do(ctx, radix.FlatCmd(nil, "SET", key, 1000, "EX", 3600))) + + const N = 50 + errs := make(chan error, 2*N) + for i := 0; i < N; i++ { + go func() { + errs <- cli.Do(ctx, radix.FlatCmd(nil, "INCRBY", key, 1)) + }() + go func() { + var v int64 + errs <- cli.Do(ctx, radix.FlatCmd(&v, "EVAL", redis.DecrementLua, 1, key, 1)) + }() + } + for i := 0; i < 2*N; i++ { + require.NoError(t, <-errs) + } + + var final int64 + require.NoError(t, cli.Do(ctx, radix.Cmd(&final, "GET", key))) + // Started at 1000; +N from incrs, -N from decrs (no underflow since start > N). + assert.Equal(t, int64(1000), final) +} diff --git a/test/limiter/base_limiter_test.go b/test/limiter/base_limiter_test.go index 7bc404079..219dfddea 100644 --- a/test/limiter/base_limiter_test.go +++ b/test/limiter/base_limiter_test.go @@ -27,7 +27,7 @@ func TestGenerateCacheKeys(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) timeSource.EXPECT().UnixNow().Return(int64(1234)) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8, "", sm, 0) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, false, "", nil, false)} assert.Equal(uint64(0), limits[0].Stats.TotalHits.Value()) @@ -46,7 +46,7 @@ func TestGenerateCacheKeysPrefix(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) timeSource.EXPECT().UnixNow().Return(int64(1234)) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8, "prefix:", sm) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8, "prefix:", sm, 0) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, false, "", nil, false)} assert.Equal(uint64(0), limits[0].Stats.TotalHits.Value()) @@ -65,7 +65,7 @@ func TestGenerateCacheKeysWithShareThreshold(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) timeSource.EXPECT().UnixNow().Return(int64(1234)).AnyTimes() - baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, rand.New(jitterSource), 3600, nil, 0.8, "", sm, 0) // Test 1: Simple case - different values with same wildcard prefix generate same cache key limit := config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("files_files/*"), false, false, false, "", nil, false) @@ -144,7 +144,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { localCache := freecache.NewCache(100) localCache.Set([]byte("key"), []byte("value"), 100) sm := mockstats.NewMockStatManager(stats.NewStore(stats.NewNullSink(), false)) - baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, localCache, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, localCache, 0.8, "", sm, 0) // Returns true, as local cache contains over limit value for the key. assert.Equal(true, baseRateLimit.IsOverLimitWithLocalCache("key")) } @@ -154,11 +154,11 @@ func TestNoOverLimitWithLocalCache(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() sm := mockstats.NewMockStatManager(stats.NewStore(stats.NewNullSink(), false)) - baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, nil, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, nil, 0.8, "", sm, 0) // Returns false, as local cache is nil. assert.Equal(false, baseRateLimit.IsOverLimitWithLocalCache("domain_key_value_1234")) localCache := freecache.NewCache(100) - baseRateLimitWithLocalCache := limiter.NewBaseRateLimit(nil, nil, 3600, localCache, 0.8, "", sm) + baseRateLimitWithLocalCache := limiter.NewBaseRateLimit(nil, nil, 3600, localCache, 0.8, "", sm, 0) // Returns false, as local cache does not contain value for cache key. assert.Equal(false, baseRateLimitWithLocalCache.IsOverLimitWithLocalCache("domain_key_value_1234")) } @@ -168,7 +168,7 @@ func TestGetResponseStatusEmptyKey(t *testing.T) { controller := gomock.NewController(t) defer controller.Finish() sm := mockstats.NewMockStatManager(stats.NewStore(stats.NewNullSink(), false)) - baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, nil, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(nil, nil, 3600, nil, 0.8, "", sm, 0) responseStatus := baseRateLimit.GetResponseDescriptorStatus("", nil, false, 1) assert.Equal(pb.RateLimitResponse_OK, responseStatus.GetCode()) assert.Equal(uint32(0), responseStatus.GetLimitRemaining()) @@ -182,7 +182,7 @@ func TestGetResponseStatusOverLimitWithLocalCache(t *testing.T) { timeSource.EXPECT().UnixNow().Return(int64(1234)) statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm, 0) limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, false, "", nil, false)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 4, 5) // As `isOverLimitWithLocalCache` is passed as `true`, immediate response is returned with no checks of the limits. @@ -204,7 +204,7 @@ func TestGetResponseStatusOverLimitWithLocalCacheShadowMode(t *testing.T) { timeSource.EXPECT().UnixNow().Return(int64(1234)) statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm, 0) // This limit is in ShadowMode limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, true, false, "", nil, false)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 4, 5) @@ -229,7 +229,7 @@ func TestGetResponseStatusOverLimit(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) localCache := freecache.NewCache(100) sm := mockstats.NewMockStatManager(statsStore) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, localCache, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, localCache, 0.8, "", sm, 0) limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, false, "", nil, false)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 7, 4, 5) responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) @@ -245,6 +245,43 @@ func TestGetResponseStatusOverLimit(t *testing.T) { assert.Equal(uint64(0), limits[0].Stats.ShadowMode.Value()) } +func TestGetResponseStatusOverLimitLocalCacheMaxUnit(t *testing.T) { + // With localCacheMaxUnitSeconds=0 (unset), HOUR limits populate the local cache. + // With localCacheMaxUnitSeconds=3600, HOUR is excluded (divider >= cap) but + // MINUTE (divider=60 < 3600) still populates. + cases := []struct { + name string + cap int64 + unit pb.RateLimitResponse_RateLimit_Unit + expectInCache bool + }{ + {"unset_cap_hour_populates", 0, pb.RateLimitResponse_RateLimit_HOUR, true}, + {"capped_hour_excluded", 3600, pb.RateLimitResponse_RateLimit_HOUR, false}, + {"capped_minute_populates", 3600, pb.RateLimitResponse_RateLimit_MINUTE, true}, + } + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + assert := assert.New(t) + controller := gomock.NewController(t) + defer controller.Finish() + timeSource := mock_utils.NewMockTimeSource(controller) + timeSource.EXPECT().UnixNow().Return(int64(1234)).AnyTimes() + localCache := freecache.NewCache(100) + sm := mockstats.NewMockStatManager(stats.NewStore(stats.NewNullSink(), false)) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, localCache, 0.8, "", sm, tc.cap) + limits := []*config.RateLimit{config.NewRateLimit(5, tc.unit, sm.NewStats("key_value"), false, false, false, "", nil, false)} + limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 7, 4, 5) + baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) + _, err := localCache.Get([]byte("key")) + if tc.expectInCache { + assert.NoError(err, "expected over-limit key to be cached") + } else { + assert.Error(err, "expected over-limit key NOT to be cached") + } + }) + } +} + func TestGetResponseStatusOverLimitShadowMode(t *testing.T) { assert := assert.New(t) controller := gomock.NewController(t) @@ -254,7 +291,7 @@ func TestGetResponseStatusOverLimitShadowMode(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) localCache := freecache.NewCache(100) sm := mockstats.NewMockStatManager(statsStore) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, localCache, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, localCache, 0.8, "", sm, 0) // Key is in shadow_mode: true limits := []*config.RateLimit{config.NewRateLimit(5, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, true, false, "", nil, false)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 7, 4, 5) @@ -277,7 +314,7 @@ func TestGetResponseStatusBelowLimit(t *testing.T) { timeSource.EXPECT().UnixNow().Return(int64(1234)) statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm, 0) limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, false, "", nil, false)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 9, 10) responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) @@ -298,7 +335,7 @@ func TestGetResponseStatusBelowLimitShadowMode(t *testing.T) { timeSource.EXPECT().UnixNow().Return(int64(1234)) statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) - baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm) + baseRateLimit := limiter.NewBaseRateLimit(timeSource, nil, 3600, nil, 0.8, "", sm, 0) limits := []*config.RateLimit{config.NewRateLimit(10, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, true, false, "", nil, false)} limitInfo := limiter.NewRateLimitInfo(limits[0], 2, 6, 9, 10) responseStatus := baseRateLimit.GetResponseDescriptorStatus("key", limitInfo, false, 1) diff --git a/test/memcached/cache_impl_test.go b/test/memcached/cache_impl_test.go index 606a9d844..ab943120f 100644 --- a/test/memcached/cache_impl_test.go +++ b/test/memcached/cache_impl_test.go @@ -44,7 +44,7 @@ func TestMemcached(t *testing.T) { client := mock_memcached.NewMockClient(controller) statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "") + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "", 0) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( @@ -141,7 +141,7 @@ func TestMemcachedGetError(t *testing.T) { client := mock_memcached.NewMockClient(controller) statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "") + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "", 0) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( @@ -229,7 +229,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { sink := &common.TestStatSink{} statsStore := stats.NewStore(sink, true) sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, localCache, sm, 0.8, "") + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, localCache, sm, 0.8, "", 0) localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope("localcache")) // Test Near Limit Stats. Under Near Limit Ratio @@ -331,7 +331,7 @@ func TestNearLimit(t *testing.T) { client := mock_memcached.NewMockClient(controller) statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "") + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "", 0) // Test Near Limit Stats. Under Near Limit Ratio timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) @@ -513,7 +513,7 @@ func TestMemcacheWithJitter(t *testing.T) { jitterSource := mock_utils.NewMockJitterRandSource(controller) statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), 3600, nil, sm, 0.8, "") + cache := memcached.NewRateLimitCacheImpl(client, timeSource, rand.New(jitterSource), 3600, nil, sm, 0.8, "", 0) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) jitterSource.EXPECT().Int63().Return(int64(100)) @@ -556,7 +556,7 @@ func TestMemcacheAdd(t *testing.T) { client := mock_memcached.NewMockClient(controller) statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "") + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "", 0) // Test a race condition with the initial add timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) @@ -665,7 +665,7 @@ func TestMemcachedTracer(t *testing.T) { statsStore := stats.NewStore(stats.NewNullSink(), false) sm := mockstats.NewMockStatManager(statsStore) - cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "") + cache := memcached.NewRateLimitCacheImpl(client, timeSource, nil, 0, nil, sm, 0.8, "", 0) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) client.EXPECT().GetMulti([]string{"domain_key_value_1234"}).Return( diff --git a/test/redis/bench_test.go b/test/redis/bench_test.go index 2c153ba97..f4505b718 100644 --- a/test/redis/bench_test.go +++ b/test/redis/bench_test.go @@ -47,7 +47,7 @@ func BenchmarkParallelDoLimit(b *testing.B) { client := redis.NewClientImpl(statsStore, false, "", "tcp", "single", "127.0.0.1:6379", poolSize, pipelineWindow, pipelineLimit, nil, false, nil, 10*time.Second, "", "") defer client.Close() - cache := redis.NewFixedRateLimitCacheImpl(client, nil, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8, "", sm, true) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, utils.NewTimeSourceImpl(), rand.New(utils.NewLockedSource(time.Now().Unix())), 10, nil, 0.8, "", sm, true, 0) request := common.NewRateLimitRequest("domain", [][][2]string{{{"key", "value"}}}, 1) limits := []*config.RateLimit{config.NewRateLimit(1000000000, pb.RateLimitResponse_RateLimit_SECOND, sm.NewStats("key_value"), false, false, false, "", nil, false)} diff --git a/test/redis/fixed_cache_impl_test.go b/test/redis/fixed_cache_impl_test.go index e52abb68d..ba49c81ab 100644 --- a/test/redis/fixed_cache_impl_test.go +++ b/test/redis/fixed_cache_impl_test.go @@ -54,9 +54,9 @@ func testRedis(usePerSecondRedis bool) func(*testing.T) { timeSource := mock_utils.NewMockTimeSource(controller) var cache limiter.RateLimitCache if usePerSecondRedis { - cache = redis.NewFixedRateLimitCacheImpl(client, perSecondClient, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm, false) + cache = redis.NewFixedRateLimitCacheImpl(client, perSecondClient, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm, false, 0) } else { - cache = redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm, false) + cache = redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm, false, 0) } timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) @@ -196,7 +196,7 @@ func TestOverLimitWithLocalCache(t *testing.T) { sink := common.NewTestStatSink() statsStore := gostats.NewStore(sink, false) sm := stats.NewMockStatManager(statsStore) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8, "", sm, false) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8, "", sm, false, 0) localCacheScopeName := "localcache" localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope(localCacheScopeName)) @@ -299,7 +299,7 @@ func TestNearLimit(t *testing.T) { timeSource := mock_utils.NewMockTimeSource(controller) statsStore := gostats.NewStore(gostats.NewNullSink(), false) sm := stats.NewMockStatManager(statsStore) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm, false) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm, false, 0) // Test Near Limit Stats. Under Near Limit Ratio timeSource.EXPECT().UnixNow().Return(int64(1000000)).MaxTimes(3) @@ -473,7 +473,7 @@ func TestRedisWithJitter(t *testing.T) { jitterSource := mock_utils.NewMockJitterRandSource(controller) statsStore := gostats.NewStore(gostats.NewNullSink(), false) sm := stats.NewMockStatManager(statsStore) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(jitterSource), 3600, nil, 0.8, "", sm, false) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(jitterSource), 3600, nil, 0.8, "", sm, false, 0) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) jitterSource.EXPECT().Int63().Return(int64(100)) @@ -504,7 +504,7 @@ func TestOverLimitWithLocalCacheShadowRule(t *testing.T) { sink := common.NewTestStatSink() statsStore := gostats.NewStore(sink, false) sm := stats.NewMockStatManager(statsStore) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8, "", sm, false) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8, "", sm, false, 0) localCacheScopeName := "localcache" localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope(localCacheScopeName)) @@ -618,7 +618,7 @@ func TestRedisTracer(t *testing.T) { client := mock_redis.NewMockClient(controller) timeSource := mock_utils.NewMockTimeSource(controller) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm, false) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, nil, 0.8, "", sm, false, 0) timeSource.EXPECT().UnixNow().Return(int64(1234)).MaxTimes(3) @@ -647,7 +647,7 @@ func TestOverLimitWithStopCacheKeyIncrementWhenOverlimitConfig(t *testing.T) { sink := common.NewTestStatSink() statsStore := gostats.NewStore(sink, false) sm := stats.NewMockStatManager(statsStore) - cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8, "", sm, true) + cache := redis.NewFixedRateLimitCacheImpl(client, nil, timeSource, rand.New(rand.NewSource(1)), 0, localCache, 0.8, "", sm, true, 0) localCacheScopeName := "localcache" localCacheStats := limiter.NewLocalCacheStats(localCache, statsStore.Scope(localCacheScopeName)) diff --git a/test/service/corrector_test.go b/test/service/corrector_test.go new file mode 100644 index 000000000..c98aec3a6 --- /dev/null +++ b/test/service/corrector_test.go @@ -0,0 +1,234 @@ +package ratelimit_test + +import ( + "context" + "errors" + "testing" + + pb_struct "github.com/envoyproxy/go-control-plane/envoy/extensions/common/ratelimit/v3" + pb "github.com/envoyproxy/go-control-plane/envoy/service/ratelimit/v3" + "github.com/golang/mock/gomock" + gostats "github.com/lyft/gostats" + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/envoyproxy/ratelimit/src/config" + "github.com/envoyproxy/ratelimit/src/limiter" + corrector_v1 "github.com/envoyproxy/ratelimit/src/proto/corrector/v1" + ratelimit "github.com/envoyproxy/ratelimit/src/service" + + mock_config "github.com/envoyproxy/ratelimit/test/mocks/config" + mockstats "github.com/envoyproxy/ratelimit/test/mocks/stats" + mock_utils "github.com/envoyproxy/ratelimit/test/mocks/utils" +) + +// snapshotter returns a fixed config snapshot. Trivially small — easier than +// generating a mock for the in-package ConfigSnapshotter interface. +type snapshotter struct{ cfg config.RateLimitConfig } + +func (s snapshotter) GetCurrentConfig() (config.RateLimitConfig, bool, bool) { + return s.cfg, false, false +} + +// stubBucketCorrector is a hand-rolled BucketCorrector used in tests; gomock +// adds little for a single-method interface and forces ceremony. +type stubBucketCorrector struct { + gotKey string + gotDelta uint64 + newValue int64 + err error +} + +func (s *stubBucketCorrector) DecrementBucket(_ context.Context, key string, delta uint64) (int64, error) { + s.gotKey = key + s.gotDelta = delta + return s.newValue, s.err +} + +// fixedNow lets us pin the bucket-start timestamp from a test without coupling +// the time source to a particular real clock. +const fixedNow = int64(1_700_000_000) + +func newRateLimit(unit pb.RateLimitResponse_RateLimit_Unit, store gostats.Store) *config.RateLimit { + sm := mockstats.NewMockStatManager(store) + return config.NewRateLimit(100, unit, sm.NewStats("k_v"), false, false, false, "", nil, false) +} + +func newRequest(delta uint64, bucketTs int64) *corrector_v1.DecrementCounterRequest { + return &corrector_v1.DecrementCounterRequest{ + Domain: "d", + Descriptor_: &corrector_v1.Descriptor{ + Entries: []*corrector_v1.Entry{{Key: "k", Value: "v"}}, + }, + Delta: delta, + OriginalBucketTimestamp: bucketTs, + } +} + +// matchEnvoyDescriptor checks that the *RateLimitDescriptor passed to GetLimit +// has the entries we sent — protects against accidental loss in conversion. +type matchEnvoyDescriptor struct{} + +func (matchEnvoyDescriptor) Matches(x interface{}) bool { + d, ok := x.(*pb_struct.RateLimitDescriptor) + if !ok || len(d.Entries) != 1 { + return false + } + return d.Entries[0].Key == "k" && d.Entries[0].Value == "v" +} +func (matchEnvoyDescriptor) String() string { return "descriptor{k=v}" } + +func buildHourBucketTs() int64 { return (fixedNow / 3600) * 3600 } + +func TestCorrector_InvalidArgument(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ts := mock_utils.NewMockTimeSource(ctrl) + store := gostats.NewStore(gostats.NewNullSink(), false) + bc := &stubBucketCorrector{} + + svc := ratelimit.NewCorrectorService(snapshotter{cfg: nil}, bc, "", ts, store) + + cases := map[string]*corrector_v1.DecrementCounterRequest{ + "empty domain": {Descriptor_: newRequest(1, 1).Descriptor_, Delta: 1, OriginalBucketTimestamp: 1}, + "empty descriptor": {Domain: "d", Delta: 1, OriginalBucketTimestamp: 1}, + "zero delta": {Domain: "d", Descriptor_: newRequest(1, 1).Descriptor_, OriginalBucketTimestamp: 1}, + "negative bucket": {Domain: "d", Descriptor_: newRequest(1, 1).Descriptor_, Delta: 1, OriginalBucketTimestamp: 0}, + } + for name, req := range cases { + resp, err := svc.DecrementCounter(context.Background(), req) + assert.Nil(t, resp, name) + st, ok := status.FromError(err) + assert.True(t, ok, name) + assert.Equal(t, codes.InvalidArgument, st.Code(), name) + } +} + +func TestCorrector_UnimplementedForMissingBackend(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ts := mock_utils.NewMockTimeSource(ctrl) + store := gostats.NewStore(gostats.NewNullSink(), false) + + // nil BucketCorrector ⇒ backend has no implementation (memcached today). + svc := ratelimit.NewCorrectorService(snapshotter{}, nil, "", ts, store) + + resp, err := svc.DecrementCounter(context.Background(), newRequest(1, 1)) + assert.Nil(t, resp) + st, _ := status.FromError(err) + assert.Equal(t, codes.Unimplemented, st.Code()) +} + +func TestCorrector_LimitNotFound(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ts := mock_utils.NewMockTimeSource(ctrl) + store := gostats.NewStore(gostats.NewNullSink(), false) + bc := &stubBucketCorrector{} + + cfg := mock_config.NewMockRateLimitConfig(ctrl) + cfg.EXPECT().GetLimit(gomock.Any(), "d", matchEnvoyDescriptor{}).Return(nil) + + svc := ratelimit.NewCorrectorService(snapshotter{cfg: cfg}, bc, "", ts, store) + resp, err := svc.DecrementCounter(context.Background(), newRequest(1, buildHourBucketTs())) + assert.NoError(t, err) + assert.Equal(t, corrector_v1.DecrementCounterResponse_LIMIT_NOT_FOUND, resp.GetCode()) + assert.Equal(t, "", bc.gotKey, "should not call backend on LIMIT_NOT_FOUND") +} + +func TestCorrector_UnitNotAllowed(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ts := mock_utils.NewMockTimeSource(ctrl) + store := gostats.NewStore(gostats.NewNullSink(), false) + bc := &stubBucketCorrector{} + + cfg := mock_config.NewMockRateLimitConfig(ctrl) + cfg.EXPECT().GetLimit(gomock.Any(), "d", matchEnvoyDescriptor{}). + Return(newRateLimit(pb.RateLimitResponse_RateLimit_MINUTE, store)) + + svc := ratelimit.NewCorrectorService(snapshotter{cfg: cfg}, bc, "", ts, store) + resp, err := svc.DecrementCounter(context.Background(), newRequest(1, buildHourBucketTs())) + assert.NoError(t, err) + assert.Equal(t, corrector_v1.DecrementCounterResponse_UNIT_NOT_ALLOWED, resp.GetCode()) +} + +func TestCorrector_BucketExpired(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ts := mock_utils.NewMockTimeSource(ctrl) + ts.EXPECT().UnixNow().Return(fixedNow) + store := gostats.NewStore(gostats.NewNullSink(), false) + bc := &stubBucketCorrector{} + + cfg := mock_config.NewMockRateLimitConfig(ctrl) + cfg.EXPECT().GetLimit(gomock.Any(), "d", matchEnvoyDescriptor{}). + Return(newRateLimit(pb.RateLimitResponse_RateLimit_HOUR, store)) + + svc := ratelimit.NewCorrectorService(snapshotter{cfg: cfg}, bc, "", ts, store) + // supply a bucket ts that doesn't match the current hour + resp, err := svc.DecrementCounter(context.Background(), newRequest(1, buildHourBucketTs()-3600)) + assert.NoError(t, err) + assert.Equal(t, corrector_v1.DecrementCounterResponse_BUCKET_EXPIRED, resp.GetCode()) + assert.Equal(t, "", bc.gotKey, "should not call backend on BUCKET_EXPIRED") +} + +func TestCorrector_KeyNotFound(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ts := mock_utils.NewMockTimeSource(ctrl) + ts.EXPECT().UnixNow().Return(fixedNow) + store := gostats.NewStore(gostats.NewNullSink(), false) + bc := &stubBucketCorrector{err: limiter.ErrBucketNotFound} + + cfg := mock_config.NewMockRateLimitConfig(ctrl) + cfg.EXPECT().GetLimit(gomock.Any(), "d", matchEnvoyDescriptor{}). + Return(newRateLimit(pb.RateLimitResponse_RateLimit_HOUR, store)) + + svc := ratelimit.NewCorrectorService(snapshotter{cfg: cfg}, bc, "", ts, store) + resp, err := svc.DecrementCounter(context.Background(), newRequest(1, buildHourBucketTs())) + assert.NoError(t, err) + assert.Equal(t, corrector_v1.DecrementCounterResponse_KEY_NOT_FOUND, resp.GetCode()) +} + +func TestCorrector_OK(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ts := mock_utils.NewMockTimeSource(ctrl) + ts.EXPECT().UnixNow().Return(fixedNow) + store := gostats.NewStore(gostats.NewNullSink(), false) + bc := &stubBucketCorrector{newValue: 7} + + cfg := mock_config.NewMockRateLimitConfig(ctrl) + cfg.EXPECT().GetLimit(gomock.Any(), "d", matchEnvoyDescriptor{}). + Return(newRateLimit(pb.RateLimitResponse_RateLimit_HOUR, store)) + + svc := ratelimit.NewCorrectorService(snapshotter{cfg: cfg}, bc, "", ts, store) + resp, err := svc.DecrementCounter(context.Background(), newRequest(3, buildHourBucketTs())) + assert.NoError(t, err) + assert.Equal(t, corrector_v1.DecrementCounterResponse_OK, resp.GetCode()) + assert.Equal(t, int64(7), resp.GetNewValue()) + assert.Equal(t, uint64(3), bc.gotDelta) + assert.NotEmpty(t, bc.gotKey, "handler should forward a non-empty cache key") +} + +func TestCorrector_BackendTransportError(t *testing.T) { + ctrl := gomock.NewController(t) + defer ctrl.Finish() + ts := mock_utils.NewMockTimeSource(ctrl) + ts.EXPECT().UnixNow().Return(fixedNow) + store := gostats.NewStore(gostats.NewNullSink(), false) + bc := &stubBucketCorrector{err: errors.New("boom")} + + cfg := mock_config.NewMockRateLimitConfig(ctrl) + cfg.EXPECT().GetLimit(gomock.Any(), "d", matchEnvoyDescriptor{}). + Return(newRateLimit(pb.RateLimitResponse_RateLimit_HOUR, store)) + + svc := ratelimit.NewCorrectorService(snapshotter{cfg: cfg}, bc, "", ts, store) + resp, err := svc.DecrementCounter(context.Background(), newRequest(1, buildHourBucketTs())) + assert.Nil(t, resp) + st, ok := status.FromError(err) + assert.True(t, ok) + assert.Equal(t, codes.Internal, st.Code()) +}