From ee80534b1c47cc36a0f9b00053cee666822b7306 Mon Sep 17 00:00:00 2001 From: Mary Dickson Date: Fri, 24 Apr 2026 14:22:34 -0700 Subject: [PATCH 01/10] feat(core): accept shorthand enum names in policy API requests Developers can now use short, readable enum values (e.g. "IN", "AND", "ALL_OF", "ACTIVE") instead of verbose proto names (e.g. "SUBJECT_MAPPING_OPERATOR_ENUM_IN") in JSON API requests. Full canonical names continue to work for backward compatibility. Adds an HTTP middleware that normalizes shorthand enum strings to their canonical proto form before ConnectRPC deserializes the request. The middleware is scoped to specific RPC paths and only processes JSON content types. Enums covered: - SubjectMappingOperatorEnum: IN, NOT_IN, IN_CONTAINS - ConditionBooleanTypeEnum: AND, OR - AttributeRuleTypeEnum: ALL_OF, ANY_OF, HIERARCHY - ActiveStateEnum: ACTIVE, INACTIVE, ANY Closes #3338 Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Mary Dickson --- .../internal/enumnormalize/enumnormalize.go | 74 ++++ .../enumnormalize/enumnormalize_test.go | 323 ++++++++++++++++++ service/internal/enumnormalize/middleware.go | 62 ++++ .../internal/enumnormalize/middleware_test.go | 110 ++++++ service/internal/server/server.go | 36 ++ 5 files changed, 605 insertions(+) create mode 100644 service/internal/enumnormalize/enumnormalize.go create mode 100644 service/internal/enumnormalize/enumnormalize_test.go create mode 100644 service/internal/enumnormalize/middleware.go create mode 100644 service/internal/enumnormalize/middleware_test.go diff --git a/service/internal/enumnormalize/enumnormalize.go b/service/internal/enumnormalize/enumnormalize.go new file mode 100644 index 0000000000..b2026a96d1 --- /dev/null +++ b/service/internal/enumnormalize/enumnormalize.go @@ -0,0 +1,74 @@ +package enumnormalize + +import ( + "encoding/json" + "strings" +) + +// EnumFieldRule maps a JSON field name to the prefix that protobuf requires. +// When the middleware encounters a string value in a matching field that does +// not already carry the prefix, it prepends the prefix so that protojson +// recognises the canonical enum name. +type EnumFieldRule struct { + // JSONField is the protojson camelCase field name (e.g. "operator", "booleanOperator"). + JSONField string + // Prefix is the proto enum type prefix including trailing underscore + // (e.g. "SUBJECT_MAPPING_OPERATOR_ENUM_"). + Prefix string +} + +// NormalizeJSON rewrites shorthand enum string values in body according to +// rules. Values that already carry the full prefix, numeric values, and fields +// not covered by any rule pass through unchanged. +func NormalizeJSON(body []byte, rules []EnumFieldRule) ([]byte, error) { + if len(body) == 0 || len(rules) == 0 { + return body, nil + } + + // Build a lookup: lowercase field name → prefix + lookup := make(map[string]string, len(rules)) + for _, r := range rules { + lookup[strings.ToLower(r.JSONField)] = r.Prefix + } + + var parsed any + if err := json.Unmarshal(body, &parsed); err != nil { + // Not valid JSON — pass through and let ConnectRPC surface the error. + return body, nil //nolint:nilerr // intentional: invalid JSON is not our error to report + } + + normalizeValue(parsed, lookup) + + return json.Marshal(parsed) +} + +// normalizeValue recursively walks a decoded JSON value, normalizing string +// enum fields according to the lookup map. +func normalizeValue(v any, lookup map[string]string) { + switch val := v.(type) { + case map[string]any: + for key, child := range val { + if prefix, ok := lookup[strings.ToLower(key)]; ok { + if s, isStr := child.(string); isStr { + val[key] = applyPrefix(s, prefix) + } + } + normalizeValue(child, lookup) + } + case []any: + for _, item := range val { + normalizeValue(item, lookup) + } + } +} + +// applyPrefix prepends prefix to value if it is not already present +// (case-insensitive check). The value is upper-cased before comparison and +// before prepending so that "in" and "IN" both resolve correctly. +func applyPrefix(value, prefix string) string { + upper := strings.ToUpper(value) + if strings.HasPrefix(upper, strings.ToUpper(prefix)) { + return upper + } + return prefix + upper +} diff --git a/service/internal/enumnormalize/enumnormalize_test.go b/service/internal/enumnormalize/enumnormalize_test.go new file mode 100644 index 0000000000..a8bd38f62c --- /dev/null +++ b/service/internal/enumnormalize/enumnormalize_test.go @@ -0,0 +1,323 @@ +package enumnormalize + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var allRules = []EnumFieldRule{ + {JSONField: "operator", Prefix: "SUBJECT_MAPPING_OPERATOR_ENUM_"}, + {JSONField: "booleanOperator", Prefix: "CONDITION_BOOLEAN_TYPE_ENUM_"}, + {JSONField: "rule", Prefix: "ATTRIBUTE_RULE_TYPE_ENUM_"}, + {JSONField: "state", Prefix: "ACTIVE_STATE_ENUM_"}, +} + +func TestNormalizeJSON_ShorthandOperators(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "IN shorthand", + input: `{"operator":"IN"}`, + expected: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN"}`, + }, + { + name: "NOT_IN shorthand", + input: `{"operator":"NOT_IN"}`, + expected: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN"}`, + }, + { + name: "IN_CONTAINS shorthand", + input: `{"operator":"IN_CONTAINS"}`, + expected: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN_CONTAINS"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := NormalizeJSON([]byte(tt.input), allRules) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(out)) + }) + } +} + +func TestNormalizeJSON_ShorthandBooleanOperators(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "AND shorthand", + input: `{"booleanOperator":"AND"}`, + expected: `{"booleanOperator":"CONDITION_BOOLEAN_TYPE_ENUM_AND"}`, + }, + { + name: "OR shorthand", + input: `{"booleanOperator":"OR"}`, + expected: `{"booleanOperator":"CONDITION_BOOLEAN_TYPE_ENUM_OR"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := NormalizeJSON([]byte(tt.input), allRules) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(out)) + }) + } +} + +func TestNormalizeJSON_ShorthandAttributeRuleType(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "ALL_OF shorthand", + input: `{"rule":"ALL_OF"}`, + expected: `{"rule":"ATTRIBUTE_RULE_TYPE_ENUM_ALL_OF"}`, + }, + { + name: "ANY_OF shorthand", + input: `{"rule":"ANY_OF"}`, + expected: `{"rule":"ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF"}`, + }, + { + name: "HIERARCHY shorthand", + input: `{"rule":"HIERARCHY"}`, + expected: `{"rule":"ATTRIBUTE_RULE_TYPE_ENUM_HIERARCHY"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := NormalizeJSON([]byte(tt.input), allRules) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(out)) + }) + } +} + +func TestNormalizeJSON_ShorthandActiveState(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "ACTIVE shorthand", + input: `{"state":"ACTIVE"}`, + expected: `{"state":"ACTIVE_STATE_ENUM_ACTIVE"}`, + }, + { + name: "INACTIVE shorthand", + input: `{"state":"INACTIVE"}`, + expected: `{"state":"ACTIVE_STATE_ENUM_INACTIVE"}`, + }, + { + name: "ANY shorthand", + input: `{"state":"ANY"}`, + expected: `{"state":"ACTIVE_STATE_ENUM_ANY"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := NormalizeJSON([]byte(tt.input), allRules) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(out)) + }) + } +} + +func TestNormalizeJSON_CaseInsensitive(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "lowercase in", + input: `{"operator":"in"}`, + expected: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN"}`, + }, + { + name: "lowercase and", + input: `{"booleanOperator":"and"}`, + expected: `{"booleanOperator":"CONDITION_BOOLEAN_TYPE_ENUM_AND"}`, + }, + { + name: "mixed case Not_In", + input: `{"operator":"Not_In"}`, + expected: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := NormalizeJSON([]byte(tt.input), allRules) + require.NoError(t, err) + assert.JSONEq(t, tt.expected, string(out)) + }) + } +} + +func TestNormalizeJSON_FullCanonicalNamesPassThrough(t *testing.T) { + tests := []struct { + name string + input string + }{ + { + name: "full operator name", + input: `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN"}`, + }, + { + name: "full boolean operator name", + input: `{"booleanOperator":"CONDITION_BOOLEAN_TYPE_ENUM_AND"}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := NormalizeJSON([]byte(tt.input), allRules) + require.NoError(t, err) + assert.JSONEq(t, tt.input, string(out)) + }) + } +} + +func TestNormalizeJSON_NumericValuesPassThrough(t *testing.T) { + input := `{"operator":1,"booleanOperator":2}` + out, err := NormalizeJSON([]byte(input), allRules) + require.NoError(t, err) + assert.JSONEq(t, input, string(out)) +} + +func TestNormalizeJSON_UnknownValuesGetPrefixed(t *testing.T) { + // Unknown shorthand values get the prefix prepended; downstream + // protovalidate will reject them. + input := `{"operator":"FOOBAR"}` + expected := `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_FOOBAR"}` + out, err := NormalizeJSON([]byte(input), allRules) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} + +func TestNormalizeJSON_UnrelatedFieldsUntouched(t *testing.T) { + input := `{"name":"test","description":"IN","operator":"IN"}` + out, err := NormalizeJSON([]byte(input), allRules) + require.NoError(t, err) + + var result map[string]any + require.NoError(t, json.Unmarshal(out, &result)) + + // "description" should NOT be prefixed — only "operator" is a rule field + assert.Equal(t, "IN", result["description"]) + assert.Equal(t, "SUBJECT_MAPPING_OPERATOR_ENUM_IN", result["operator"]) +} + +func TestNormalizeJSON_DeeplyNestedStructure(t *testing.T) { + // Simulates a CreateSubjectConditionSetRequest with nested condition groups + input := `{ + "subjectConditionSet": { + "subjectSets": [{ + "conditionGroups": [{ + "booleanOperator": "AND", + "conditions": [ + { + "subjectExternalSelectorValue": ".email", + "operator": "IN", + "subjectExternalValues": ["user@example.com"] + }, + { + "subjectExternalSelectorValue": ".groups", + "operator": "NOT_IN", + "subjectExternalValues": ["banned"] + } + ] + }] + }] + } + }` + + expected := `{ + "subjectConditionSet": { + "subjectSets": [{ + "conditionGroups": [{ + "booleanOperator": "CONDITION_BOOLEAN_TYPE_ENUM_AND", + "conditions": [ + { + "subjectExternalSelectorValue": ".email", + "operator": "SUBJECT_MAPPING_OPERATOR_ENUM_IN", + "subjectExternalValues": ["user@example.com"] + }, + { + "subjectExternalSelectorValue": ".groups", + "operator": "SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN", + "subjectExternalValues": ["banned"] + } + ] + }] + }] + } + }` + + out, err := NormalizeJSON([]byte(input), allRules) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} + +func TestNormalizeJSON_MixedShorthandAndFullNames(t *testing.T) { + input := `{ + "conditionGroups": [{ + "booleanOperator": "CONDITION_BOOLEAN_TYPE_ENUM_OR", + "conditions": [ + {"operator": "IN"}, + {"operator": "SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN"} + ] + }] + }` + + expected := `{ + "conditionGroups": [{ + "booleanOperator": "CONDITION_BOOLEAN_TYPE_ENUM_OR", + "conditions": [ + {"operator": "SUBJECT_MAPPING_OPERATOR_ENUM_IN"}, + {"operator": "SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN"} + ] + }] + }` + + out, err := NormalizeJSON([]byte(input), allRules) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} + +func TestNormalizeJSON_EmptyBody(t *testing.T) { + out, err := NormalizeJSON([]byte{}, allRules) + require.NoError(t, err) + assert.Empty(t, out) +} + +func TestNormalizeJSON_NoRules(t *testing.T) { + input := `{"operator":"IN"}` + out, err := NormalizeJSON([]byte(input), nil) + require.NoError(t, err) + assert.Equal(t, input, string(out)) +} + +func TestNormalizeJSON_InvalidJSON(t *testing.T) { + input := `not json at all` + out, err := NormalizeJSON([]byte(input), allRules) + require.NoError(t, err) + // Invalid JSON passes through unchanged + assert.Equal(t, input, string(out)) +} diff --git a/service/internal/enumnormalize/middleware.go b/service/internal/enumnormalize/middleware.go new file mode 100644 index 0000000000..f972edba9e --- /dev/null +++ b/service/internal/enumnormalize/middleware.go @@ -0,0 +1,62 @@ +package enumnormalize + +import ( + "bytes" + "io" + "net/http" + "strconv" + "strings" +) + +// NewMiddleware returns HTTP middleware that normalises shorthand enum string +// values in JSON request bodies for the given RPC paths. Requests that do not +// match (wrong content-type, wrong path) are forwarded unchanged with zero +// overhead. +func NewMiddleware(rules []EnumFieldRule, paths []string) func(http.Handler) http.Handler { + pathSet := make(map[string]struct{}, len(paths)) + for _, p := range paths { + pathSet[p] = struct{}{} + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Only rewrite JSON bodies on matching RPC paths. + if !isJSON(r) || !matchesPath(r, pathSet) { + next.ServeHTTP(w, r) + return + } + + body, err := io.ReadAll(r.Body) + if err != nil { + next.ServeHTTP(w, r) + return + } + _ = r.Body.Close() + + normalized, err := NormalizeJSON(body, rules) + if err != nil { + // On normalisation failure, send the original body so + // ConnectRPC can surface its own error. + normalized = body + } + + r.Body = io.NopCloser(bytes.NewReader(normalized)) + r.ContentLength = int64(len(normalized)) + r.Header.Set("Content-Length", strconv.Itoa(len(normalized))) + + next.ServeHTTP(w, r) + }) + } +} + +// isJSON returns true when the request Content-Type indicates a JSON payload +// (application/json or application/connect+json). +func isJSON(r *http.Request) bool { + return strings.Contains(r.Header.Get("Content-Type"), "json") +} + +// matchesPath returns true when the request URL path is in pathSet. +func matchesPath(r *http.Request, pathSet map[string]struct{}) bool { + _, ok := pathSet[r.URL.Path] + return ok +} diff --git a/service/internal/enumnormalize/middleware_test.go b/service/internal/enumnormalize/middleware_test.go new file mode 100644 index 0000000000..bdd155e595 --- /dev/null +++ b/service/internal/enumnormalize/middleware_test.go @@ -0,0 +1,110 @@ +package enumnormalize + +import ( + "io" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +const testPath = "/policy.subjectmapping.SubjectMappingService/CreateSubjectConditionSet" + +var testRules = []EnumFieldRule{ + {JSONField: "operator", Prefix: "SUBJECT_MAPPING_OPERATOR_ENUM_"}, + {JSONField: "booleanOperator", Prefix: "CONDITION_BOOLEAN_TYPE_ENUM_"}, +} + +// captureHandler records the request body it receives. +type captureHandler struct { + body string +} + +func (h *captureHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { + b, _ := io.ReadAll(r.Body) + h.body = string(b) +} + +func TestMiddleware_NormalizesMatchingJSONRequest(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"booleanOperator":"AND","conditions":[{"operator":"IN"}]}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + assert.Contains(t, capture.body, "CONDITION_BOOLEAN_TYPE_ENUM_AND") + assert.Contains(t, capture.body, "SUBJECT_MAPPING_OPERATOR_ENUM_IN") +} + +func TestMiddleware_ConnectJSONContentType(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"operator":"NOT_IN"}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/connect+json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + assert.Contains(t, capture.body, "SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN") +} + +func TestMiddleware_NonMatchingPathPassesThrough(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"operator":"IN"}` + req := httptest.NewRequest(http.MethodPost, "/policy.attributes.AttributesService/ListAttributes", strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + // Should be the original body, not normalized + assert.Equal(t, body, capture.body) +} + +func TestMiddleware_NonJSONContentTypePassesThrough(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"operator":"IN"}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/proto") + handler.ServeHTTP(httptest.NewRecorder(), req) + + assert.Equal(t, body, capture.body) +} + +func TestMiddleware_CanonicalNamesUnchanged(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN"}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + assert.JSONEq(t, body, capture.body) +} + +func TestMiddleware_ContentLengthUpdated(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + body := `{"operator":"IN"}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + // The normalized body is longer than the original + require.Greater(t, len(capture.body), len(body)) +} diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 8d16b784f7..41f1f87949 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -19,9 +19,14 @@ import ( "connectrpc.com/validate" "github.com/go-chi/cors" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" + attrconnect "github.com/opentdf/platform/protocol/go/policy/attributes/attributesconnect" + nsconnect "github.com/opentdf/platform/protocol/go/policy/namespaces/namespacesconnect" + smconnect "github.com/opentdf/platform/protocol/go/policy/subjectmapping/subjectmappingconnect" + unsafeconnect "github.com/opentdf/platform/protocol/go/policy/unsafe/unsafeconnect" "github.com/opentdf/platform/sdk" sdkAudit "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/service/internal/auth" + "github.com/opentdf/platform/service/internal/enumnormalize" "github.com/opentdf/platform/service/internal/security" "github.com/opentdf/platform/service/internal/server/memhttp" "github.com/opentdf/platform/service/logger" @@ -392,6 +397,37 @@ func newHTTPServer(c Config, connectRPC http.Handler, originalGrpcGateway http.H tc *tls.Config ) + // Normalize shorthand enum names (e.g. "IN" → "SUBJECT_MAPPING_OPERATOR_ENUM_IN") + // in JSON request bodies before ConnectRPC deserializes them. Accepts the + // suffix after the enum type prefix, case-insensitive, while full canonical + // names continue to work unchanged. See: opentdf/platform#3338 + connectRPC = enumnormalize.NewMiddleware( + []enumnormalize.EnumFieldRule{ + // Subject Mapping enums + {JSONField: "operator", Prefix: "SUBJECT_MAPPING_OPERATOR_ENUM_"}, + {JSONField: "booleanOperator", Prefix: "CONDITION_BOOLEAN_TYPE_ENUM_"}, + // Attribute rule type + {JSONField: "rule", Prefix: "ATTRIBUTE_RULE_TYPE_ENUM_"}, + // Active state filter (list requests) + {JSONField: "state", Prefix: "ACTIVE_STATE_ENUM_"}, + }, + []string{ + // Subject Mapping RPCs + smconnect.SubjectMappingServiceCreateSubjectMappingProcedure, + smconnect.SubjectMappingServiceCreateSubjectConditionSetProcedure, + smconnect.SubjectMappingServiceUpdateSubjectConditionSetProcedure, + // Attribute RPCs (rule + state) + attrconnect.AttributesServiceCreateAttributeProcedure, + attrconnect.AttributesServiceUpdateAttributeProcedure, + attrconnect.AttributesServiceListAttributesProcedure, + attrconnect.AttributesServiceListAttributeValuesProcedure, + // Namespace RPCs (state) + nsconnect.NamespaceServiceListNamespacesProcedure, + // Unsafe RPCs (rule) + unsafeconnect.UnsafeServiceUnsafeUpdateAttributeProcedure, + }, + )(connectRPC) + // Adds deprecation header to any grpcGateway responses. var grpcGateway http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { grpcRW := &grpcGatewayResponseWriter{w: w, code: http.StatusOK} From f421f4d256fe44979a34aa37838b15b265abffd7 Mon Sep 17 00:00:00 2001 From: Mary Dickson Date: Fri, 24 Apr 2026 14:32:30 -0700 Subject: [PATCH 02/10] fix(core): address review feedback on enum normalization - Use json.Decoder with UseNumber() to preserve numeric precision for large int64 values (avoids float64 conversion) - Add http.MaxBytesReader (1 MB cap) to prevent DoS via oversized request bodies - Pre-build field lookup map once at middleware init instead of per-request - Use exact key matching for JSON field names (protojson always emits camelCase) instead of case-insensitive comparison Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Mary Dickson --- .../internal/enumnormalize/enumnormalize.go | 40 +++++++++++++------ .../enumnormalize/enumnormalize_test.go | 32 +++++++-------- service/internal/enumnormalize/middleware.go | 13 ++++-- 3 files changed, 53 insertions(+), 32 deletions(-) diff --git a/service/internal/enumnormalize/enumnormalize.go b/service/internal/enumnormalize/enumnormalize.go index b2026a96d1..cae966c86b 100644 --- a/service/internal/enumnormalize/enumnormalize.go +++ b/service/internal/enumnormalize/enumnormalize.go @@ -1,6 +1,7 @@ package enumnormalize import ( + "bytes" "encoding/json" "strings" ) @@ -17,22 +18,35 @@ type EnumFieldRule struct { Prefix string } -// NormalizeJSON rewrites shorthand enum string values in body according to -// rules. Values that already carry the full prefix, numeric values, and fields -// not covered by any rule pass through unchanged. -func NormalizeJSON(body []byte, rules []EnumFieldRule) ([]byte, error) { - if len(body) == 0 || len(rules) == 0 { - return body, nil - } +// fieldLookup is a pre-built map from JSON field name to enum prefix, +// constructed once at middleware initialization time. +type fieldLookup map[string]string - // Build a lookup: lowercase field name → prefix - lookup := make(map[string]string, len(rules)) +// buildLookup creates a fieldLookup from a set of rules. Keys are stored +// exactly as declared (protojson always emits camelCase). +func buildLookup(rules []EnumFieldRule) fieldLookup { + m := make(fieldLookup, len(rules)) for _, r := range rules { - lookup[strings.ToLower(r.JSONField)] = r.Prefix + m[r.JSONField] = r.Prefix } + return m +} + +// normalizeJSON rewrites shorthand enum string values in body according to +// the pre-built lookup. Values that already carry the full prefix, numeric +// values, and fields not covered by any rule pass through unchanged. +func normalizeJSON(body []byte, lookup fieldLookup) ([]byte, error) { + if len(body) == 0 || len(lookup) == 0 { + return body, nil + } + + // Use json.Decoder with UseNumber to preserve numeric precision + // (avoids float64 conversion of large int64 values). + decoder := json.NewDecoder(bytes.NewReader(body)) + decoder.UseNumber() var parsed any - if err := json.Unmarshal(body, &parsed); err != nil { + if err := decoder.Decode(&parsed); err != nil { // Not valid JSON — pass through and let ConnectRPC surface the error. return body, nil //nolint:nilerr // intentional: invalid JSON is not our error to report } @@ -44,11 +58,11 @@ func NormalizeJSON(body []byte, rules []EnumFieldRule) ([]byte, error) { // normalizeValue recursively walks a decoded JSON value, normalizing string // enum fields according to the lookup map. -func normalizeValue(v any, lookup map[string]string) { +func normalizeValue(v any, lookup fieldLookup) { switch val := v.(type) { case map[string]any: for key, child := range val { - if prefix, ok := lookup[strings.ToLower(key)]; ok { + if prefix, ok := lookup[key]; ok { if s, isStr := child.(string); isStr { val[key] = applyPrefix(s, prefix) } diff --git a/service/internal/enumnormalize/enumnormalize_test.go b/service/internal/enumnormalize/enumnormalize_test.go index a8bd38f62c..671fb46a89 100644 --- a/service/internal/enumnormalize/enumnormalize_test.go +++ b/service/internal/enumnormalize/enumnormalize_test.go @@ -8,12 +8,12 @@ import ( "github.com/stretchr/testify/require" ) -var allRules = []EnumFieldRule{ +var allLookup = buildLookup([]EnumFieldRule{ {JSONField: "operator", Prefix: "SUBJECT_MAPPING_OPERATOR_ENUM_"}, {JSONField: "booleanOperator", Prefix: "CONDITION_BOOLEAN_TYPE_ENUM_"}, {JSONField: "rule", Prefix: "ATTRIBUTE_RULE_TYPE_ENUM_"}, {JSONField: "state", Prefix: "ACTIVE_STATE_ENUM_"}, -} +}) func TestNormalizeJSON_ShorthandOperators(t *testing.T) { tests := []struct { @@ -40,7 +40,7 @@ func TestNormalizeJSON_ShorthandOperators(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - out, err := NormalizeJSON([]byte(tt.input), allRules) + out, err := normalizeJSON([]byte(tt.input), allLookup) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(out)) }) @@ -67,7 +67,7 @@ func TestNormalizeJSON_ShorthandBooleanOperators(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - out, err := NormalizeJSON([]byte(tt.input), allRules) + out, err := normalizeJSON([]byte(tt.input), allLookup) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(out)) }) @@ -99,7 +99,7 @@ func TestNormalizeJSON_ShorthandAttributeRuleType(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - out, err := NormalizeJSON([]byte(tt.input), allRules) + out, err := normalizeJSON([]byte(tt.input), allLookup) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(out)) }) @@ -131,7 +131,7 @@ func TestNormalizeJSON_ShorthandActiveState(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - out, err := NormalizeJSON([]byte(tt.input), allRules) + out, err := normalizeJSON([]byte(tt.input), allLookup) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(out)) }) @@ -163,7 +163,7 @@ func TestNormalizeJSON_CaseInsensitive(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - out, err := NormalizeJSON([]byte(tt.input), allRules) + out, err := normalizeJSON([]byte(tt.input), allLookup) require.NoError(t, err) assert.JSONEq(t, tt.expected, string(out)) }) @@ -187,7 +187,7 @@ func TestNormalizeJSON_FullCanonicalNamesPassThrough(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - out, err := NormalizeJSON([]byte(tt.input), allRules) + out, err := normalizeJSON([]byte(tt.input), allLookup) require.NoError(t, err) assert.JSONEq(t, tt.input, string(out)) }) @@ -196,7 +196,7 @@ func TestNormalizeJSON_FullCanonicalNamesPassThrough(t *testing.T) { func TestNormalizeJSON_NumericValuesPassThrough(t *testing.T) { input := `{"operator":1,"booleanOperator":2}` - out, err := NormalizeJSON([]byte(input), allRules) + out, err := normalizeJSON([]byte(input), allLookup) require.NoError(t, err) assert.JSONEq(t, input, string(out)) } @@ -206,14 +206,14 @@ func TestNormalizeJSON_UnknownValuesGetPrefixed(t *testing.T) { // protovalidate will reject them. input := `{"operator":"FOOBAR"}` expected := `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_FOOBAR"}` - out, err := NormalizeJSON([]byte(input), allRules) + out, err := normalizeJSON([]byte(input), allLookup) require.NoError(t, err) assert.JSONEq(t, expected, string(out)) } func TestNormalizeJSON_UnrelatedFieldsUntouched(t *testing.T) { input := `{"name":"test","description":"IN","operator":"IN"}` - out, err := NormalizeJSON([]byte(input), allRules) + out, err := normalizeJSON([]byte(input), allLookup) require.NoError(t, err) var result map[string]any @@ -270,7 +270,7 @@ func TestNormalizeJSON_DeeplyNestedStructure(t *testing.T) { } }` - out, err := NormalizeJSON([]byte(input), allRules) + out, err := normalizeJSON([]byte(input), allLookup) require.NoError(t, err) assert.JSONEq(t, expected, string(out)) } @@ -296,27 +296,27 @@ func TestNormalizeJSON_MixedShorthandAndFullNames(t *testing.T) { }] }` - out, err := NormalizeJSON([]byte(input), allRules) + out, err := normalizeJSON([]byte(input), allLookup) require.NoError(t, err) assert.JSONEq(t, expected, string(out)) } func TestNormalizeJSON_EmptyBody(t *testing.T) { - out, err := NormalizeJSON([]byte{}, allRules) + out, err := normalizeJSON([]byte{}, allLookup) require.NoError(t, err) assert.Empty(t, out) } func TestNormalizeJSON_NoRules(t *testing.T) { input := `{"operator":"IN"}` - out, err := NormalizeJSON([]byte(input), nil) + out, err := normalizeJSON([]byte(input), nil) require.NoError(t, err) assert.Equal(t, input, string(out)) } func TestNormalizeJSON_InvalidJSON(t *testing.T) { input := `not json at all` - out, err := NormalizeJSON([]byte(input), allRules) + out, err := normalizeJSON([]byte(input), allLookup) require.NoError(t, err) // Invalid JSON passes through unchanged assert.Equal(t, input, string(out)) diff --git a/service/internal/enumnormalize/middleware.go b/service/internal/enumnormalize/middleware.go index f972edba9e..f2340add76 100644 --- a/service/internal/enumnormalize/middleware.go +++ b/service/internal/enumnormalize/middleware.go @@ -8,11 +8,19 @@ import ( "strings" ) +// maxBodySize is the upper bound on request bodies the middleware will read +// into memory for normalization. Policy API request bodies are small (typically +// under 10 KB); this cap prevents abuse while being generous enough for any +// legitimate request. ConnectRPC enforces its own message size limits downstream. +const maxBodySize = 1 << 20 // 1 MB + // NewMiddleware returns HTTP middleware that normalises shorthand enum string // values in JSON request bodies for the given RPC paths. Requests that do not // match (wrong content-type, wrong path) are forwarded unchanged with zero // overhead. func NewMiddleware(rules []EnumFieldRule, paths []string) func(http.Handler) http.Handler { + lookup := buildLookup(rules) + pathSet := make(map[string]struct{}, len(paths)) for _, p := range paths { pathSet[p] = struct{}{} @@ -26,14 +34,13 @@ func NewMiddleware(rules []EnumFieldRule, paths []string) func(http.Handler) htt return } - body, err := io.ReadAll(r.Body) + body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodySize)) if err != nil { next.ServeHTTP(w, r) return } - _ = r.Body.Close() - normalized, err := NormalizeJSON(body, rules) + normalized, err := normalizeJSON(body, lookup) if err != nil { // On normalisation failure, send the original body so // ConnectRPC can surface its own error. From ebc91394026ebd9b0e3d156bc08df1ec0ef418d9 Mon Sep 17 00:00:00 2001 From: Mary Dickson Date: Mon, 27 Apr 2026 11:10:51 -0700 Subject: [PATCH 03/10] chore(core): add numeric enum passthrough tests Verify that numeric enum values (e.g., operator: 1, booleanOperator: 3) pass through the normalization middleware unchanged. These are valid protojson and must continue to work alongside shorthand string names. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../enumnormalize/enumnormalize_test.go | 67 ++++++++++++++++++- .../internal/enumnormalize/middleware_test.go | 15 +++++ 2 files changed, 81 insertions(+), 1 deletion(-) diff --git a/service/internal/enumnormalize/enumnormalize_test.go b/service/internal/enumnormalize/enumnormalize_test.go index 671fb46a89..cb7bc5e01f 100644 --- a/service/internal/enumnormalize/enumnormalize_test.go +++ b/service/internal/enumnormalize/enumnormalize_test.go @@ -195,9 +195,74 @@ func TestNormalizeJSON_FullCanonicalNamesPassThrough(t *testing.T) { } func TestNormalizeJSON_NumericValuesPassThrough(t *testing.T) { - input := `{"operator":1,"booleanOperator":2}` + tests := []struct { + name string + input string + }{ + { + name: "operator 1 (IN) and booleanOperator 2 (OR)", + input: `{"operator":1,"booleanOperator":2}`, + }, + { + name: "operator 3 (IN_CONTAINS) and booleanOperator 1 (AND)", + input: `{"operator":3,"booleanOperator":1}`, + }, + { + name: "operator 2 (NOT_IN)", + input: `{"operator":2}`, + }, + { + name: "rule 1 (ALL_OF)", + input: `{"rule":1}`, + }, + { + name: "state 1 (ACTIVE)", + input: `{"state":1}`, + }, + { + name: "numeric zero (UNSPECIFIED) passes through", + input: `{"operator":0}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + out, err := normalizeJSON([]byte(tt.input), allLookup) + require.NoError(t, err) + assert.JSONEq(t, tt.input, string(out)) + }) + } +} + +func TestNormalizeJSON_NumericValuesInNestedStructure(t *testing.T) { + // Simulates the JSON format that was previously used in documentation: + // numeric enum codes instead of string names. + input := `{ + "subjectConditionSet": { + "subjectSets": [{ + "conditionGroups": [{ + "booleanOperator": 1, + "conditions": [ + { + "subjectExternalSelectorValue": ".email", + "operator": 3, + "subjectExternalValues": ["@example.com"] + }, + { + "subjectExternalSelectorValue": ".role", + "operator": 1, + "subjectExternalValues": ["admin"] + } + ] + }] + }] + } + }` + out, err := normalizeJSON([]byte(input), allLookup) require.NoError(t, err) + // Numeric values should pass through unchanged — protojson natively + // accepts numeric enum representations. assert.JSONEq(t, input, string(out)) } diff --git a/service/internal/enumnormalize/middleware_test.go b/service/internal/enumnormalize/middleware_test.go index bdd155e595..534f180539 100644 --- a/service/internal/enumnormalize/middleware_test.go +++ b/service/internal/enumnormalize/middleware_test.go @@ -95,6 +95,21 @@ func TestMiddleware_CanonicalNamesUnchanged(t *testing.T) { assert.JSONEq(t, body, capture.body) } +func TestMiddleware_NumericEnumValuesPassThrough(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + // Numeric enum values (e.g., 1 for IN, 3 for IN_CONTAINS) are valid + // protojson and should pass through the middleware unchanged. + body := `{"booleanOperator":1,"conditions":[{"operator":3}]}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(body)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + assert.JSONEq(t, body, capture.body) +} + func TestMiddleware_ContentLengthUpdated(t *testing.T) { capture := &captureHandler{} mw := NewMiddleware(testRules, []string{testPath}) From 9668d0d8c7bf7fa637c45e16c764473c848a441e Mon Sep 17 00:00:00 2001 From: Mary Dickson Date: Mon, 27 Apr 2026 11:53:23 -0700 Subject: [PATCH 04/10] chore(core): add e2e BDD tests for shorthand enum names Add end-to-end tests that send raw HTTP POST requests with shorthand enum strings (bypassing the SDK) to verify the normalization middleware works through the full ConnectRPC stack. Tests cover: - Shorthand operator/boolean enums (IN, IN_CONTAINS, AND) - Shorthand rule type enum (ANY_OF) - Mixed shorthand and canonical names in the same request Co-Authored-By: Claude Opus 4.6 (1M context) --- tests-bdd/cukes/steps_enum_shorthand.go | 276 +++++++++++++++++++++ tests-bdd/features/shorthand-enums.feature | 18 ++ tests-bdd/platform_test.go | 1 + 3 files changed, 295 insertions(+) create mode 100644 tests-bdd/cukes/steps_enum_shorthand.go create mode 100644 tests-bdd/features/shorthand-enums.feature diff --git a/tests-bdd/cukes/steps_enum_shorthand.go b/tests-bdd/cukes/steps_enum_shorthand.go new file mode 100644 index 0000000000..3b471ba9e8 --- /dev/null +++ b/tests-bdd/cukes/steps_enum_shorthand.go @@ -0,0 +1,276 @@ +package cukes + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "log/slog" + "net/http" + "net/url" + "strings" + + "github.com/cucumber/godog" +) + +type EnumShorthandStepDefinitions struct{} + +// getAccessToken fetches a bearer token from the Keycloak token endpoint using +// the same client credentials the BDD test SDK uses. +func getAccessToken(tokenEndpoint string) (string, error) { + data := url.Values{ + "grant_type": {"client_credentials"}, + "client_id": {"opentdf"}, + "client_secret": {"secret"}, + } + + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // test-only + }, + } + resp, err := client.PostForm(tokenEndpoint, data) + if err != nil { + return "", fmt.Errorf("token request failed: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + return "", fmt.Errorf("token request returned %d: %s", resp.StatusCode, body) + } + + var tokenResp struct { + AccessToken string `json:"access_token"` + } + if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil { + return "", fmt.Errorf("failed to decode token response: %w", err) + } + return tokenResp.AccessToken, nil +} + +// postConnectRPC sends a raw JSON body to a ConnectRPC endpoint and returns the +// HTTP status code and response body. +func postConnectRPC(endpoint, rpcPath, token, jsonBody string) (int, string, error) { + client := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // test-only + }, + } + + req, err := http.NewRequest(http.MethodPost, endpoint+rpcPath, strings.NewReader(jsonBody)) + if err != nil { + return 0, "", err + } + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+token) + + resp, err := client.Do(req) + if err != nil { + return 0, "", err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return resp.StatusCode, "", err + } + return resp.StatusCode, string(body), nil +} + +// iCreateASubjectConditionSetViaHTTPWithShorthandEnums sends a raw HTTP POST with +// shorthand enum strings and verifies the platform accepts it. +func (s *EnumShorthandStepDefinitions) iCreateASubjectConditionSetViaHTTPWithShorthandEnums(ctx context.Context) (context.Context, error) { + scenarioContext := GetPlatformScenarioContext(ctx) + scenarioContext.ClearError() + + endpoint := scenarioContext.ScenarioOptions.PlatformEndpoint + tokenEndpoint, err := scenarioContext.SDK.PlatformConfiguration.TokenEndpoint() + if err != nil { + return ctx, fmt.Errorf("failed to get token endpoint: %w", err) + } + + token, err := getAccessToken(tokenEndpoint) + if err != nil { + return ctx, fmt.Errorf("failed to get access token: %w", err) + } + + // Raw JSON with shorthand enum values — this is what the middleware normalizes. + body := `{ + "subjectConditionSet": { + "subjectSets": [{ + "conditionGroups": [{ + "booleanOperator": "AND", + "conditions": [{ + "subjectExternalSelectorValue": ".email", + "operator": "IN_CONTAINS", + "subjectExternalValues": ["@example.com"] + }] + }] + }] + } + }` + + rpcPath := "/policy.subjectmapping.SubjectMappingService/CreateSubjectConditionSet" + statusCode, respBody, err := postConnectRPC(endpoint, rpcPath, token, body) + if err != nil { + return ctx, fmt.Errorf("HTTP request failed: %w", err) + } + + slog.Debug("shorthand enum e2e response", + slog.Int("status", statusCode), + slog.String("body", respBody)) + + if statusCode != http.StatusOK { + return ctx, fmt.Errorf("expected HTTP 200, got %d: %s", statusCode, respBody) + } + + // Verify the response contains a valid subject condition set ID + var result map[string]any + if err := json.Unmarshal([]byte(respBody), &result); err != nil { + return ctx, fmt.Errorf("failed to parse response: %w", err) + } + scs, ok := result["subjectConditionSet"].(map[string]any) + if !ok || scs["id"] == nil { + return ctx, fmt.Errorf("response missing subjectConditionSet.id: %s", respBody) + } + + scenarioContext.RecordObject("shorthand_scs_id", scs["id"].(string)) + return ctx, nil +} + +// iCreateAnAttributeViaHTTPWithShorthandRule sends a raw HTTP POST to create an +// attribute using a shorthand rule type enum. +func (s *EnumShorthandStepDefinitions) iCreateAnAttributeViaHTTPWithShorthandRule(ctx context.Context) (context.Context, error) { + scenarioContext := GetPlatformScenarioContext(ctx) + scenarioContext.ClearError() + + endpoint := scenarioContext.ScenarioOptions.PlatformEndpoint + tokenEndpoint, err := scenarioContext.SDK.PlatformConfiguration.TokenEndpoint() + if err != nil { + return ctx, fmt.Errorf("failed to get token endpoint: %w", err) + } + + token, err := getAccessToken(tokenEndpoint) + if err != nil { + return ctx, fmt.Errorf("failed to get access token: %w", err) + } + + // Get the namespace ID that was created by the scenario setup + nsID, ok := scenarioContext.GetObject("ns1").(string) + if !ok { + return ctx, fmt.Errorf("namespace ns1 not found in scenario context") + } + + // Raw JSON with shorthand rule type + body := fmt.Sprintf(`{ + "attribute": { + "namespaceId": "%s", + "name": "shorthand_test_attr", + "rule": "ANY_OF", + "values": ["val1", "val2"] + } + }`, nsID) + + rpcPath := "/policy.attributes.AttributesService/CreateAttribute" + statusCode, respBody, err := postConnectRPC(endpoint, rpcPath, token, body) + if err != nil { + return ctx, fmt.Errorf("HTTP request failed: %w", err) + } + + slog.Debug("shorthand rule e2e response", + slog.Int("status", statusCode), + slog.String("body", respBody)) + + if statusCode != http.StatusOK { + return ctx, fmt.Errorf("expected HTTP 200, got %d: %s", statusCode, respBody) + } + + // Verify the response contains a valid attribute with the correct rule + var result map[string]any + if err := json.Unmarshal([]byte(respBody), &result); err != nil { + return ctx, fmt.Errorf("failed to parse response: %w", err) + } + attr, ok := result["attribute"].(map[string]any) + if !ok || attr["id"] == nil { + return ctx, fmt.Errorf("response missing attribute.id: %s", respBody) + } + + // Verify the rule was accepted and stored as the canonical name + rule, _ := attr["rule"].(string) + if rule != "ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF" { + return ctx, fmt.Errorf("expected rule ATTRIBUTE_RULE_TYPE_ENUM_ANY_OF, got %s", rule) + } + + return ctx, nil +} + +// iCreateASubjectConditionSetViaHTTPWithMixedEnumFormats verifies that a request +// mixing shorthand and canonical enum names works correctly. +func (s *EnumShorthandStepDefinitions) iCreateASubjectConditionSetViaHTTPWithMixedEnumFormats(ctx context.Context) (context.Context, error) { + scenarioContext := GetPlatformScenarioContext(ctx) + scenarioContext.ClearError() + + endpoint := scenarioContext.ScenarioOptions.PlatformEndpoint + tokenEndpoint, err := scenarioContext.SDK.PlatformConfiguration.TokenEndpoint() + if err != nil { + return ctx, fmt.Errorf("failed to get token endpoint: %w", err) + } + + token, err := getAccessToken(tokenEndpoint) + if err != nil { + return ctx, fmt.Errorf("failed to get access token: %w", err) + } + + // Mix shorthand and canonical names in the same request + body := `{ + "subjectConditionSet": { + "subjectSets": [{ + "conditionGroups": [{ + "booleanOperator": "CONDITION_BOOLEAN_TYPE_ENUM_AND", + "conditions": [ + { + "subjectExternalSelectorValue": ".email", + "operator": "IN", + "subjectExternalValues": ["@test.com"] + }, + { + "subjectExternalSelectorValue": ".role", + "operator": "SUBJECT_MAPPING_OPERATOR_ENUM_NOT_IN", + "subjectExternalValues": ["guest"] + } + ] + }] + }] + } + }` + + rpcPath := "/policy.subjectmapping.SubjectMappingService/CreateSubjectConditionSet" + statusCode, respBody, err := postConnectRPC(endpoint, rpcPath, token, body) + if err != nil { + return ctx, fmt.Errorf("HTTP request failed: %w", err) + } + + if statusCode != http.StatusOK { + return ctx, fmt.Errorf("expected HTTP 200, got %d: %s", statusCode, respBody) + } + + var result map[string]any + if err := json.Unmarshal([]byte(respBody), &result); err != nil { + return ctx, fmt.Errorf("failed to parse response: %w", err) + } + scs, ok := result["subjectConditionSet"].(map[string]any) + if !ok || scs["id"] == nil { + return ctx, fmt.Errorf("response missing subjectConditionSet.id: %s", respBody) + } + + return ctx, nil +} + +func RegisterEnumShorthandStepDefinitions(ctx *godog.ScenarioContext) { + steps := &EnumShorthandStepDefinitions{} + ctx.Step(`^I create a subject condition set via HTTP with shorthand enums$`, steps.iCreateASubjectConditionSetViaHTTPWithShorthandEnums) + ctx.Step(`^I create an attribute via HTTP with shorthand rule type$`, steps.iCreateAnAttributeViaHTTPWithShorthandRule) + ctx.Step(`^I create a subject condition set via HTTP with mixed enum formats$`, steps.iCreateASubjectConditionSetViaHTTPWithMixedEnumFormats) +} diff --git a/tests-bdd/features/shorthand-enums.feature b/tests-bdd/features/shorthand-enums.feature new file mode 100644 index 0000000000..274667dec1 --- /dev/null +++ b/tests-bdd/features/shorthand-enums.feature @@ -0,0 +1,18 @@ +@shorthand-enums +Feature: Shorthand Enum Names E2E + Verify that the platform accepts shorthand enum names (e.g., "IN", "AND", + "ANY_OF") in raw HTTP JSON requests. These tests bypass the SDK and send + raw ConnectRPC JSON to prove the normalization middleware works end-to-end. + + Background: + Given an empty local platform + And I submit a request to create a namespace with name "shorthandenums.io" and reference id "ns1" + + Scenario: Create subject condition set with shorthand operator and boolean enums + When I create a subject condition set via HTTP with shorthand enums + + Scenario: Create attribute with shorthand rule type enum + When I create an attribute via HTTP with shorthand rule type + + Scenario: Create subject condition set with mixed shorthand and canonical enum names + When I create a subject condition set via HTTP with mixed enum formats diff --git a/tests-bdd/platform_test.go b/tests-bdd/platform_test.go index f1251c8d8f..6acafef0a1 100644 --- a/tests-bdd/platform_test.go +++ b/tests-bdd/platform_test.go @@ -108,6 +108,7 @@ func runTests() int { cukes.RegisterSmokeStepDefinitions(ctx, platformCukesContext) cukes.RegisterAuthorizationStepDefinitions(ctx) cukes.RegisterSubjectMappingsStepsDefinitions(ctx) + cukes.RegisterEnumShorthandStepDefinitions(ctx) cukes.RegisterRegisteredResourcesStepDefinitions(ctx) cukes.RegisterObligationsStepDefinitions(ctx, platformCukesContext) platformCukesContext.InitializeScenario(ctx) From 2fa0ca80631f5f932176c99f332ea8e41948019c Mon Sep 17 00:00:00 2001 From: Mary Dickson Date: Mon, 27 Apr 2026 12:14:49 -0700 Subject: [PATCH 05/10] chore(core): add oversized body test for enum normalization middleware Verify that when a request body exceeds the MaxBytesReader limit (1 MB), the middleware skips normalization and forwards the request unchanged. Co-Authored-By: Claude Opus 4.6 (1M context) --- .../internal/enumnormalize/middleware_test.go | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/service/internal/enumnormalize/middleware_test.go b/service/internal/enumnormalize/middleware_test.go index 534f180539..1bccb21b1c 100644 --- a/service/internal/enumnormalize/middleware_test.go +++ b/service/internal/enumnormalize/middleware_test.go @@ -123,3 +123,20 @@ func TestMiddleware_ContentLengthUpdated(t *testing.T) { // The normalized body is longer than the original require.Greater(t, len(capture.body), len(body)) } + +func TestMiddleware_OversizedBodySkipsNormalization(t *testing.T) { + capture := &captureHandler{} + mw := NewMiddleware(testRules, []string{testPath}) + handler := mw(capture) + + // Build a body that exceeds maxBodySize (1 MB). + oversized := `{"operator":"` + strings.Repeat("A", maxBodySize) + `"}` + req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(oversized)) + req.Header.Set("Content-Type", "application/json") + handler.ServeHTTP(httptest.NewRecorder(), req) + + // The middleware should skip normalization on read error and forward the + // request. The downstream handler receives whatever MaxBytesReader yielded + // before the limit — NOT a normalized body. + assert.NotContains(t, capture.body, "SUBJECT_MAPPING_OPERATOR_ENUM_") +} From c70f7478216952673a3ccbc88ad9ce62826ea417 Mon Sep 17 00:00:00 2001 From: Mary Dickson Date: Mon, 27 Apr 2026 13:04:38 -0700 Subject: [PATCH 06/10] fix(core): address review feedback on BDD shorthand enum tests - Remove unnecessary TLS config (endpoints are plain HTTP) - Extract prepareAuthenticatedRequest helper to reduce duplication - Use http.NewRequestWithContext to satisfy noctx linter - Fix errcheck on type assertion and perfsprint on static error string - Revert stray io import in enumnormalize.go Co-Authored-By: Claude Opus 4.6 (1M context) --- tests-bdd/cukes/steps_enum_shorthand.go | 90 ++++++++++++------------- 1 file changed, 43 insertions(+), 47 deletions(-) diff --git a/tests-bdd/cukes/steps_enum_shorthand.go b/tests-bdd/cukes/steps_enum_shorthand.go index 3b471ba9e8..f671c877d4 100644 --- a/tests-bdd/cukes/steps_enum_shorthand.go +++ b/tests-bdd/cukes/steps_enum_shorthand.go @@ -2,8 +2,8 @@ package cukes import ( "context" - "crypto/tls" "encoding/json" + "errors" "fmt" "io" "log/slog" @@ -18,19 +18,21 @@ type EnumShorthandStepDefinitions struct{} // getAccessToken fetches a bearer token from the Keycloak token endpoint using // the same client credentials the BDD test SDK uses. -func getAccessToken(tokenEndpoint string) (string, error) { +func getAccessToken(ctx context.Context, tokenEndpoint string) (string, error) { data := url.Values{ "grant_type": {"client_credentials"}, "client_id": {"opentdf"}, "client_secret": {"secret"}, } - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // test-only - }, + req, err := http.NewRequestWithContext(ctx, http.MethodPost, tokenEndpoint, strings.NewReader(data.Encode())) + if err != nil { + return "", fmt.Errorf("token request creation failed: %w", err) } - resp, err := client.PostForm(tokenEndpoint, data) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded") + + client := &http.Client{} + resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("token request failed: %w", err) } @@ -52,14 +54,10 @@ func getAccessToken(tokenEndpoint string) (string, error) { // postConnectRPC sends a raw JSON body to a ConnectRPC endpoint and returns the // HTTP status code and response body. -func postConnectRPC(endpoint, rpcPath, token, jsonBody string) (int, string, error) { - client := &http.Client{ - Transport: &http.Transport{ - TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, //nolint:gosec // test-only - }, - } +func postConnectRPC(ctx context.Context, endpoint, rpcPath, token, jsonBody string) (int, string, error) { + client := &http.Client{} - req, err := http.NewRequest(http.MethodPost, endpoint+rpcPath, strings.NewReader(jsonBody)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+rpcPath, strings.NewReader(jsonBody)) if err != nil { return 0, "", err } @@ -79,21 +77,33 @@ func postConnectRPC(endpoint, rpcPath, token, jsonBody string) (int, string, err return resp.StatusCode, string(body), nil } -// iCreateASubjectConditionSetViaHTTPWithShorthandEnums sends a raw HTTP POST with -// shorthand enum strings and verifies the platform accepts it. -func (s *EnumShorthandStepDefinitions) iCreateASubjectConditionSetViaHTTPWithShorthandEnums(ctx context.Context) (context.Context, error) { +// prepareAuthenticatedRequest extracts the platform endpoint and fetches a +// bearer token for raw HTTP requests. This is the common setup shared by all +// shorthand enum e2e step definitions. +func prepareAuthenticatedRequest(ctx context.Context) (*PlatformScenarioContext, string, string, error) { scenarioContext := GetPlatformScenarioContext(ctx) scenarioContext.ClearError() endpoint := scenarioContext.ScenarioOptions.PlatformEndpoint tokenEndpoint, err := scenarioContext.SDK.PlatformConfiguration.TokenEndpoint() if err != nil { - return ctx, fmt.Errorf("failed to get token endpoint: %w", err) + return nil, "", "", fmt.Errorf("failed to get token endpoint: %w", err) } - token, err := getAccessToken(tokenEndpoint) + token, err := getAccessToken(ctx, tokenEndpoint) if err != nil { - return ctx, fmt.Errorf("failed to get access token: %w", err) + return nil, "", "", fmt.Errorf("failed to get access token: %w", err) + } + + return scenarioContext, endpoint, token, nil +} + +// iCreateASubjectConditionSetViaHTTPWithShorthandEnums sends a raw HTTP POST with +// shorthand enum strings and verifies the platform accepts it. +func (s *EnumShorthandStepDefinitions) iCreateASubjectConditionSetViaHTTPWithShorthandEnums(ctx context.Context) (context.Context, error) { + scenarioContext, endpoint, token, err := prepareAuthenticatedRequest(ctx) + if err != nil { + return ctx, err } // Raw JSON with shorthand enum values — this is what the middleware normalizes. @@ -113,7 +123,7 @@ func (s *EnumShorthandStepDefinitions) iCreateASubjectConditionSetViaHTTPWithSho }` rpcPath := "/policy.subjectmapping.SubjectMappingService/CreateSubjectConditionSet" - statusCode, respBody, err := postConnectRPC(endpoint, rpcPath, token, body) + statusCode, respBody, err := postConnectRPC(ctx, endpoint, rpcPath, token, body) if err != nil { return ctx, fmt.Errorf("HTTP request failed: %w", err) } @@ -136,31 +146,26 @@ func (s *EnumShorthandStepDefinitions) iCreateASubjectConditionSetViaHTTPWithSho return ctx, fmt.Errorf("response missing subjectConditionSet.id: %s", respBody) } - scenarioContext.RecordObject("shorthand_scs_id", scs["id"].(string)) + scsID, ok := scs["id"].(string) + if !ok { + return ctx, fmt.Errorf("subjectConditionSet.id is not a string: %s", respBody) + } + scenarioContext.RecordObject("shorthand_scs_id", scsID) return ctx, nil } // iCreateAnAttributeViaHTTPWithShorthandRule sends a raw HTTP POST to create an // attribute using a shorthand rule type enum. func (s *EnumShorthandStepDefinitions) iCreateAnAttributeViaHTTPWithShorthandRule(ctx context.Context) (context.Context, error) { - scenarioContext := GetPlatformScenarioContext(ctx) - scenarioContext.ClearError() - - endpoint := scenarioContext.ScenarioOptions.PlatformEndpoint - tokenEndpoint, err := scenarioContext.SDK.PlatformConfiguration.TokenEndpoint() - if err != nil { - return ctx, fmt.Errorf("failed to get token endpoint: %w", err) - } - - token, err := getAccessToken(tokenEndpoint) + scenarioContext, endpoint, token, err := prepareAuthenticatedRequest(ctx) if err != nil { - return ctx, fmt.Errorf("failed to get access token: %w", err) + return ctx, err } // Get the namespace ID that was created by the scenario setup nsID, ok := scenarioContext.GetObject("ns1").(string) if !ok { - return ctx, fmt.Errorf("namespace ns1 not found in scenario context") + return ctx, errors.New("namespace ns1 not found in scenario context") } // Raw JSON with shorthand rule type @@ -174,7 +179,7 @@ func (s *EnumShorthandStepDefinitions) iCreateAnAttributeViaHTTPWithShorthandRul }`, nsID) rpcPath := "/policy.attributes.AttributesService/CreateAttribute" - statusCode, respBody, err := postConnectRPC(endpoint, rpcPath, token, body) + statusCode, respBody, err := postConnectRPC(ctx, endpoint, rpcPath, token, body) if err != nil { return ctx, fmt.Errorf("HTTP request failed: %w", err) } @@ -209,18 +214,9 @@ func (s *EnumShorthandStepDefinitions) iCreateAnAttributeViaHTTPWithShorthandRul // iCreateASubjectConditionSetViaHTTPWithMixedEnumFormats verifies that a request // mixing shorthand and canonical enum names works correctly. func (s *EnumShorthandStepDefinitions) iCreateASubjectConditionSetViaHTTPWithMixedEnumFormats(ctx context.Context) (context.Context, error) { - scenarioContext := GetPlatformScenarioContext(ctx) - scenarioContext.ClearError() - - endpoint := scenarioContext.ScenarioOptions.PlatformEndpoint - tokenEndpoint, err := scenarioContext.SDK.PlatformConfiguration.TokenEndpoint() - if err != nil { - return ctx, fmt.Errorf("failed to get token endpoint: %w", err) - } - - token, err := getAccessToken(tokenEndpoint) + _, endpoint, token, err := prepareAuthenticatedRequest(ctx) if err != nil { - return ctx, fmt.Errorf("failed to get access token: %w", err) + return ctx, err } // Mix shorthand and canonical names in the same request @@ -247,7 +243,7 @@ func (s *EnumShorthandStepDefinitions) iCreateASubjectConditionSetViaHTTPWithMix }` rpcPath := "/policy.subjectmapping.SubjectMappingService/CreateSubjectConditionSet" - statusCode, respBody, err := postConnectRPC(endpoint, rpcPath, token, body) + statusCode, respBody, err := postConnectRPC(ctx, endpoint, rpcPath, token, body) if err != nil { return ctx, fmt.Errorf("HTTP request failed: %w", err) } From 57921d8c1aca9b4b4dcc7b23f8a5221b0c2d3691 Mon Sep 17 00:00:00 2001 From: Mary Dickson Date: Mon, 27 Apr 2026 13:59:43 -0700 Subject: [PATCH 07/10] feat(core): export enumnormalize, add parent-scoped rules and WithHTTPMiddleware Move enumnormalize from internal/ to pkg/ so downstream consumers (e.g. DSP) can import EnumFieldRule and NewMiddleware. Add ParentField to EnumFieldRule for disambiguating shared field names (e.g. DSP's tagging enums all use "type" but with different prefixes under different parent keys). Add WithHTTPMiddleware server option so downstream servers can inject HTTP middleware into the handler chain for their own enum normalization rules. Co-Authored-By: Claude Opus 4.6 (1M context) Signed-off-by: Mary Dickson --- .../internal/enumnormalize/enumnormalize.go | 88 ------------- service/internal/server/server.go | 12 +- service/pkg/enumnormalize/enumnormalize.go | 120 ++++++++++++++++++ .../enumnormalize/enumnormalize_test.go | 87 ++++++++++++- .../enumnormalize/middleware.go | 2 +- .../enumnormalize/middleware_test.go | 0 service/pkg/server/options.go | 21 +++ service/pkg/server/start.go | 1 + 8 files changed, 237 insertions(+), 94 deletions(-) delete mode 100644 service/internal/enumnormalize/enumnormalize.go create mode 100644 service/pkg/enumnormalize/enumnormalize.go rename service/{internal => pkg}/enumnormalize/enumnormalize_test.go (78%) rename service/{internal => pkg}/enumnormalize/middleware.go (98%) rename service/{internal => pkg}/enumnormalize/middleware_test.go (100%) diff --git a/service/internal/enumnormalize/enumnormalize.go b/service/internal/enumnormalize/enumnormalize.go deleted file mode 100644 index cae966c86b..0000000000 --- a/service/internal/enumnormalize/enumnormalize.go +++ /dev/null @@ -1,88 +0,0 @@ -package enumnormalize - -import ( - "bytes" - "encoding/json" - "strings" -) - -// EnumFieldRule maps a JSON field name to the prefix that protobuf requires. -// When the middleware encounters a string value in a matching field that does -// not already carry the prefix, it prepends the prefix so that protojson -// recognises the canonical enum name. -type EnumFieldRule struct { - // JSONField is the protojson camelCase field name (e.g. "operator", "booleanOperator"). - JSONField string - // Prefix is the proto enum type prefix including trailing underscore - // (e.g. "SUBJECT_MAPPING_OPERATOR_ENUM_"). - Prefix string -} - -// fieldLookup is a pre-built map from JSON field name to enum prefix, -// constructed once at middleware initialization time. -type fieldLookup map[string]string - -// buildLookup creates a fieldLookup from a set of rules. Keys are stored -// exactly as declared (protojson always emits camelCase). -func buildLookup(rules []EnumFieldRule) fieldLookup { - m := make(fieldLookup, len(rules)) - for _, r := range rules { - m[r.JSONField] = r.Prefix - } - return m -} - -// normalizeJSON rewrites shorthand enum string values in body according to -// the pre-built lookup. Values that already carry the full prefix, numeric -// values, and fields not covered by any rule pass through unchanged. -func normalizeJSON(body []byte, lookup fieldLookup) ([]byte, error) { - if len(body) == 0 || len(lookup) == 0 { - return body, nil - } - - // Use json.Decoder with UseNumber to preserve numeric precision - // (avoids float64 conversion of large int64 values). - decoder := json.NewDecoder(bytes.NewReader(body)) - decoder.UseNumber() - - var parsed any - if err := decoder.Decode(&parsed); err != nil { - // Not valid JSON — pass through and let ConnectRPC surface the error. - return body, nil //nolint:nilerr // intentional: invalid JSON is not our error to report - } - - normalizeValue(parsed, lookup) - - return json.Marshal(parsed) -} - -// normalizeValue recursively walks a decoded JSON value, normalizing string -// enum fields according to the lookup map. -func normalizeValue(v any, lookup fieldLookup) { - switch val := v.(type) { - case map[string]any: - for key, child := range val { - if prefix, ok := lookup[key]; ok { - if s, isStr := child.(string); isStr { - val[key] = applyPrefix(s, prefix) - } - } - normalizeValue(child, lookup) - } - case []any: - for _, item := range val { - normalizeValue(item, lookup) - } - } -} - -// applyPrefix prepends prefix to value if it is not already present -// (case-insensitive check). The value is upper-cased before comparison and -// before prepending so that "in" and "IN" both resolve correctly. -func applyPrefix(value, prefix string) string { - upper := strings.ToUpper(value) - if strings.HasPrefix(upper, strings.ToUpper(prefix)) { - return upper - } - return prefix + upper -} diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 41f1f87949..7f32fc4f1a 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -26,13 +26,13 @@ import ( "github.com/opentdf/platform/sdk" sdkAudit "github.com/opentdf/platform/sdk/audit" "github.com/opentdf/platform/service/internal/auth" - "github.com/opentdf/platform/service/internal/enumnormalize" "github.com/opentdf/platform/service/internal/security" "github.com/opentdf/platform/service/internal/server/memhttp" "github.com/opentdf/platform/service/logger" "github.com/opentdf/platform/service/logger/audit" ctxAuth "github.com/opentdf/platform/service/pkg/auth" "github.com/opentdf/platform/service/pkg/cache" + "github.com/opentdf/platform/service/pkg/enumnormalize" "github.com/opentdf/platform/service/tracing" "golang.org/x/net/http2" "golang.org/x/net/http2/h2c" @@ -67,8 +67,9 @@ type Config struct { WellKnownConfigRegister func(namespace string, config any) error `mapstructure:"-" json:"-"` // Programmatic interceptors injected at startup (not loaded from config) - ExtraConnectInterceptors []connect.Interceptor `mapstructure:"-" json:"-"` - ExtraIPCInterceptors []connect.Interceptor `mapstructure:"-" json:"-"` + ExtraConnectInterceptors []connect.Interceptor `mapstructure:"-" json:"-"` + ExtraIPCInterceptors []connect.Interceptor `mapstructure:"-" json:"-"` + ExtraHTTPMiddleware []func(http.Handler) http.Handler `mapstructure:"-" json:"-"` // Port to listen on Port int `mapstructure:"port" json:"port" default:"8080"` Host string `mapstructure:"host,omitempty" json:"host"` @@ -428,6 +429,11 @@ func newHTTPServer(c Config, connectRPC http.Handler, originalGrpcGateway http.H }, )(connectRPC) + // Apply extra HTTP middleware injected by downstream consumers (e.g. DSP). + for _, mw := range c.ExtraHTTPMiddleware { + connectRPC = mw(connectRPC) + } + // Adds deprecation header to any grpcGateway responses. var grpcGateway http.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { grpcRW := &grpcGatewayResponseWriter{w: w, code: http.StatusOK} diff --git a/service/pkg/enumnormalize/enumnormalize.go b/service/pkg/enumnormalize/enumnormalize.go new file mode 100644 index 0000000000..267b4b3d7d --- /dev/null +++ b/service/pkg/enumnormalize/enumnormalize.go @@ -0,0 +1,120 @@ +package enumnormalize + +import ( + "bytes" + "encoding/json" + "strings" +) + +// EnumFieldRule maps a JSON field name to the prefix that protobuf requires. +// When the middleware encounters a string value in a matching field that does +// not already carry the prefix, it prepends the prefix so that protojson +// recognises the canonical enum name. +type EnumFieldRule struct { + // JSONField is the protojson camelCase field name (e.g. "operator", "booleanOperator"). + JSONField string + // Prefix is the proto enum type prefix including trailing underscore + // (e.g. "SUBJECT_MAPPING_OPERATOR_ENUM_"). + Prefix string + // ParentField optionally scopes this rule to only match when JSONField + // appears inside an object that is a direct child of a key named + // ParentField (at any depth). This disambiguates cases where multiple + // enum types share the same field name (e.g. "type") but live under + // different parent keys (e.g. "contentExtractors" vs "tagProcessors"). + // When empty, the rule matches JSONField at any position (original behavior). + ParentField string +} + +// ruleLookup stores pre-built lookup tables for fast matching. +type ruleLookup struct { + // global maps field name → prefix for rules with no ParentField. + global map[string]string + // scoped maps parentField → (field name → prefix) for parent-scoped rules. + scoped map[string]map[string]string +} + +// buildRuleLookup creates a ruleLookup from a set of rules. +func buildRuleLookup(rules []EnumFieldRule) ruleLookup { + rl := ruleLookup{ + global: make(map[string]string), + scoped: make(map[string]map[string]string), + } + for _, r := range rules { + if r.ParentField == "" { + rl.global[r.JSONField] = r.Prefix + } else { + if rl.scoped[r.ParentField] == nil { + rl.scoped[r.ParentField] = make(map[string]string) + } + rl.scoped[r.ParentField][r.JSONField] = r.Prefix + } + } + return rl +} + +// normalizeJSON rewrites shorthand enum string values in body according to +// the configured rules. Values that already carry the full prefix, numeric +// values, and fields not covered by any rule pass through unchanged. +func normalizeJSON(body []byte, rl ruleLookup) ([]byte, error) { + if len(body) == 0 || (len(rl.global) == 0 && len(rl.scoped) == 0) { + return body, nil + } + + // Use json.Decoder with UseNumber to preserve numeric precision + // (avoids float64 conversion of large int64 values). + decoder := json.NewDecoder(bytes.NewReader(body)) + decoder.UseNumber() + + var parsed any + if err := decoder.Decode(&parsed); err != nil { + // Not valid JSON — pass through and let ConnectRPC surface the error. + return body, nil //nolint:nilerr // intentional: invalid JSON is not our error to report + } + + normalizeValue(parsed, rl, "") + + return json.Marshal(parsed) +} + +// normalizeValue recursively walks a decoded JSON value, normalizing string +// enum fields according to the lookup rules. parentKey tracks the key under +// which the current value was found, enabling parent-scoped rules. +func normalizeValue(v any, rl ruleLookup, parentKey string) { + switch val := v.(type) { + case map[string]any: + for key, child := range val { + // Check global rules (no parent scope) + if prefix, ok := rl.global[key]; ok { + if s, isStr := child.(string); isStr { + val[key] = applyPrefix(s, prefix) + } + } + // Check parent-scoped rules + if scopedFields, hasParent := rl.scoped[parentKey]; hasParent { + if scopedPrefix, hasField := scopedFields[key]; hasField { + if s, isStr := child.(string); isStr { + val[key] = applyPrefix(s, scopedPrefix) + } + } + } + normalizeValue(child, rl, key) + } + case []any: + // Array elements inherit the parent key so that scoped rules work + // through arrays (e.g. "contentExtractors": [{"type": "..."}]). + for _, item := range val { + normalizeValue(item, rl, parentKey) + } + } +} + +// applyPrefix prepends prefix to value if it is not already present +// (case-insensitive check). The value is upper-cased before comparison and +// before prepending so that "in" and "IN" both resolve correctly. +func applyPrefix(value, prefix string) string { + upper := strings.ToUpper(value) + if strings.HasPrefix(upper, strings.ToUpper(prefix)) { + return upper + } + return prefix + upper +} diff --git a/service/internal/enumnormalize/enumnormalize_test.go b/service/pkg/enumnormalize/enumnormalize_test.go similarity index 78% rename from service/internal/enumnormalize/enumnormalize_test.go rename to service/pkg/enumnormalize/enumnormalize_test.go index cb7bc5e01f..8e3bc1fbb2 100644 --- a/service/internal/enumnormalize/enumnormalize_test.go +++ b/service/pkg/enumnormalize/enumnormalize_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" ) -var allLookup = buildLookup([]EnumFieldRule{ +var allLookup = buildRuleLookup([]EnumFieldRule{ {JSONField: "operator", Prefix: "SUBJECT_MAPPING_OPERATOR_ENUM_"}, {JSONField: "booleanOperator", Prefix: "CONDITION_BOOLEAN_TYPE_ENUM_"}, {JSONField: "rule", Prefix: "ATTRIBUTE_RULE_TYPE_ENUM_"}, @@ -374,7 +374,7 @@ func TestNormalizeJSON_EmptyBody(t *testing.T) { func TestNormalizeJSON_NoRules(t *testing.T) { input := `{"operator":"IN"}` - out, err := normalizeJSON([]byte(input), nil) + out, err := normalizeJSON([]byte(input), ruleLookup{}) require.NoError(t, err) assert.Equal(t, input, string(out)) } @@ -386,3 +386,86 @@ func TestNormalizeJSON_InvalidJSON(t *testing.T) { // Invalid JSON passes through unchanged assert.Equal(t, input, string(out)) } + +// Parent-scoped rule tests + +var scopedLookup = buildRuleLookup([]EnumFieldRule{ + // Different prefixes for the same "type" field, scoped by parent key + {JSONField: "type", Prefix: "CONTENT_EXTRACTOR_TYPE_", ParentField: "contentExtractors"}, + {JSONField: "type", Prefix: "TAG_PROCESSOR_TYPE_", ParentField: "tagProcessors"}, + // A global rule (no parent scope) for a different field + {JSONField: "state", Prefix: "ACTIVE_STATE_ENUM_"}, +}) + +func TestNormalizeJSON_ParentScopedRules(t *testing.T) { + input := `{ + "config": { + "v1": { + "contentExtractors": [{"type": "TIKA_CONTENT_EXTRACTION", "id": "ce1"}], + "tagProcessors": [{"type": "REQUIRED_TAGS", "id": "tp1"}] + } + } + }` + expected := `{ + "config": { + "v1": { + "contentExtractors": [{"type": "CONTENT_EXTRACTOR_TYPE_TIKA_CONTENT_EXTRACTION", "id": "ce1"}], + "tagProcessors": [{"type": "TAG_PROCESSOR_TYPE_REQUIRED_TAGS", "id": "tp1"}] + } + } + }` + + out, err := normalizeJSON([]byte(input), scopedLookup) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} + +func TestNormalizeJSON_ParentScopedDoesNotMatchGlobally(t *testing.T) { + // "type" at top level should NOT be rewritten — it only matches under + // "contentExtractors" or "tagProcessors". + input := `{"type": "SOME_VALUE"}` + + out, err := normalizeJSON([]byte(input), scopedLookup) + require.NoError(t, err) + assert.JSONEq(t, input, string(out)) +} + +func TestNormalizeJSON_GlobalAndScopedRulesCoexist(t *testing.T) { + // "state" is a global rule; "type" is parent-scoped. + input := `{ + "state": "ACTIVE", + "contentExtractors": [{"type": "TIKA_CONTENT_EXTRACTION"}] + }` + expected := `{ + "state": "ACTIVE_STATE_ENUM_ACTIVE", + "contentExtractors": [{"type": "CONTENT_EXTRACTOR_TYPE_TIKA_CONTENT_EXTRACTION"}] + }` + + out, err := normalizeJSON([]byte(input), scopedLookup) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} + +func TestNormalizeJSON_ParentScopedFullCanonicalPassthrough(t *testing.T) { + // Already-prefixed values pass through unchanged + input := `{ + "contentExtractors": [{"type": "CONTENT_EXTRACTOR_TYPE_TIKA_CONTENT_EXTRACTION"}] + }` + + out, err := normalizeJSON([]byte(input), scopedLookup) + require.NoError(t, err) + assert.JSONEq(t, input, string(out)) +} + +func TestNormalizeJSON_ParentScopedCaseInsensitive(t *testing.T) { + input := `{ + "tagProcessors": [{"type": "required_tags"}] + }` + expected := `{ + "tagProcessors": [{"type": "TAG_PROCESSOR_TYPE_REQUIRED_TAGS"}] + }` + + out, err := normalizeJSON([]byte(input), scopedLookup) + require.NoError(t, err) + assert.JSONEq(t, expected, string(out)) +} diff --git a/service/internal/enumnormalize/middleware.go b/service/pkg/enumnormalize/middleware.go similarity index 98% rename from service/internal/enumnormalize/middleware.go rename to service/pkg/enumnormalize/middleware.go index f2340add76..20fdbdc3cb 100644 --- a/service/internal/enumnormalize/middleware.go +++ b/service/pkg/enumnormalize/middleware.go @@ -19,7 +19,7 @@ const maxBodySize = 1 << 20 // 1 MB // match (wrong content-type, wrong path) are forwarded unchanged with zero // overhead. func NewMiddleware(rules []EnumFieldRule, paths []string) func(http.Handler) http.Handler { - lookup := buildLookup(rules) + lookup := buildRuleLookup(rules) pathSet := make(map[string]struct{}, len(paths)) for _, p := range paths { diff --git a/service/internal/enumnormalize/middleware_test.go b/service/pkg/enumnormalize/middleware_test.go similarity index 100% rename from service/internal/enumnormalize/middleware_test.go rename to service/pkg/enumnormalize/middleware_test.go diff --git a/service/pkg/server/options.go b/service/pkg/server/options.go index db3952ceef..9722f5712f 100644 --- a/service/pkg/server/options.go +++ b/service/pkg/server/options.go @@ -2,6 +2,7 @@ package server import ( "context" + "net/http" "connectrpc.com/connect" "github.com/casbin/casbin/v2/persist" @@ -28,6 +29,7 @@ type StartConfig struct { extraConnectInterceptors []connect.Interceptor extraIPCInterceptors []connect.Interceptor + extraHTTPMiddleware []func(http.Handler) http.Handler trustKeyManagerCtxs []trust.NamedKeyManagerCtxFactory @@ -186,6 +188,25 @@ func WithIPCInterceptors(interceptors ...connect.Interceptor) StartOptions { } } +// WithHTTPMiddleware appends HTTP middleware that wraps the ConnectRPC handler. +// Middleware is applied in order, with the last middleware outermost. +// This runs at the HTTP transport layer, before ConnectRPC deserialization, +// making it suitable for request body rewriting (e.g. enum normalization). +// +// Example: +// +// server.Start( +// server.WithHTTPMiddleware( +// enumnormalize.NewMiddleware(rules, paths), +// ), +// ) +func WithHTTPMiddleware(middleware ...func(http.Handler) http.Handler) StartOptions { + return func(c StartConfig) StartConfig { + c.extraHTTPMiddleware = append(c.extraHTTPMiddleware, middleware...) + return c + } +} + // WithTrustKeyManagerFactories option provides factories for creating trust key managers. // Use WithTrustKeyManagerCtxFactories instead. // EXPERIMENTAL diff --git a/service/pkg/server/start.go b/service/pkg/server/start.go index f3513a090d..85ea87a69f 100644 --- a/service/pkg/server/start.go +++ b/service/pkg/server/start.go @@ -154,6 +154,7 @@ func Start(f ...StartOptions) error { // Programmatic Connect/IPC interceptors (not config-driven) cfg.Server.ExtraConnectInterceptors = append(cfg.Server.ExtraConnectInterceptors, startConfig.extraConnectInterceptors...) cfg.Server.ExtraIPCInterceptors = append(cfg.Server.ExtraIPCInterceptors, startConfig.extraIPCInterceptors...) + cfg.Server.ExtraHTTPMiddleware = append(cfg.Server.ExtraHTTPMiddleware, startConfig.extraHTTPMiddleware...) // Set Default Policy if startConfig.builtinPolicyOverride != "" { From 6f9cdc2c0e9ea63bbc06d4cf7403bdd8592df423 Mon Sep 17 00:00:00 2001 From: Mary Dickson Date: Mon, 27 Apr 2026 14:06:40 -0700 Subject: [PATCH 08/10] fix(core): fix CreateAttribute request body in BDD shorthand test Fields are at the top level of CreateAttributeRequest, not nested under an "attribute" key. Co-Authored-By: Claude Opus 4.6 (1M context) --- tests-bdd/cukes/steps_enum_shorthand.go | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/tests-bdd/cukes/steps_enum_shorthand.go b/tests-bdd/cukes/steps_enum_shorthand.go index f671c877d4..86290f9f18 100644 --- a/tests-bdd/cukes/steps_enum_shorthand.go +++ b/tests-bdd/cukes/steps_enum_shorthand.go @@ -168,14 +168,12 @@ func (s *EnumShorthandStepDefinitions) iCreateAnAttributeViaHTTPWithShorthandRul return ctx, errors.New("namespace ns1 not found in scenario context") } - // Raw JSON with shorthand rule type + // Raw JSON with shorthand rule type — fields are at the top level per the proto definition body := fmt.Sprintf(`{ - "attribute": { - "namespaceId": "%s", - "name": "shorthand_test_attr", - "rule": "ANY_OF", - "values": ["val1", "val2"] - } + "namespaceId": "%s", + "name": "shorthand_test_attr", + "rule": "ANY_OF", + "values": ["val1", "val2"] }`, nsID) rpcPath := "/policy.attributes.AttributesService/CreateAttribute" From 5e1ce050194e1b8af8dcd6c1e7329d7ae182dee2 Mon Sep 17 00:00:00 2001 From: Mary Dickson Date: Mon, 27 Apr 2026 14:22:37 -0700 Subject: [PATCH 09/10] fix(core): address CodeRabbit review feedback - Guard against trailing JSON tokens: add EOF check after first decode to prevent silently dropping concatenated JSON values - Make maxBodySize configurable: NewMiddleware now accepts maxBodyBytes parameter (0 = default 1 MB) so callers can align with their own limits - Add HTTP client timeouts to BDD test helpers to prevent CI hangs Co-Authored-By: Claude Opus 4.6 (1M context) --- service/internal/server/server.go | 1 + service/pkg/enumnormalize/enumnormalize.go | 9 +++++++++ .../pkg/enumnormalize/enumnormalize_test.go | 9 +++++++++ service/pkg/enumnormalize/middleware.go | 19 ++++++++++-------- service/pkg/enumnormalize/middleware_test.go | 20 +++++++++---------- service/pkg/server/options.go | 2 +- tests-bdd/cukes/steps_enum_shorthand.go | 7 +++++-- 7 files changed, 46 insertions(+), 21 deletions(-) diff --git a/service/internal/server/server.go b/service/internal/server/server.go index 7f32fc4f1a..fa113c614d 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -427,6 +427,7 @@ func newHTTPServer(c Config, connectRPC http.Handler, originalGrpcGateway http.H // Unsafe RPCs (rule) unsafeconnect.UnsafeServiceUnsafeUpdateAttributeProcedure, }, + 0, // use default max body size )(connectRPC) // Apply extra HTTP middleware injected by downstream consumers (e.g. DSP). diff --git a/service/pkg/enumnormalize/enumnormalize.go b/service/pkg/enumnormalize/enumnormalize.go index 267b4b3d7d..25f133475a 100644 --- a/service/pkg/enumnormalize/enumnormalize.go +++ b/service/pkg/enumnormalize/enumnormalize.go @@ -3,6 +3,7 @@ package enumnormalize import ( "bytes" "encoding/json" + "io" "strings" ) @@ -71,6 +72,14 @@ func normalizeJSON(body []byte, rl ruleLookup) ([]byte, error) { return body, nil //nolint:nilerr // intentional: invalid JSON is not our error to report } + // Ensure the entire body is a single JSON value. If there are trailing + // tokens (e.g. `{"a":1}{"b":2}`), return the original body so ConnectRPC + // can reject the malformed input rather than silently dropping the tail. + var trailing any + if err := decoder.Decode(&trailing); err != io.EOF { + return body, nil + } + normalizeValue(parsed, rl, "") return json.Marshal(parsed) diff --git a/service/pkg/enumnormalize/enumnormalize_test.go b/service/pkg/enumnormalize/enumnormalize_test.go index 8e3bc1fbb2..4d13ca384f 100644 --- a/service/pkg/enumnormalize/enumnormalize_test.go +++ b/service/pkg/enumnormalize/enumnormalize_test.go @@ -379,6 +379,15 @@ func TestNormalizeJSON_NoRules(t *testing.T) { assert.Equal(t, input, string(out)) } +func TestNormalizeJSON_TrailingJSONTokensPassThrough(t *testing.T) { + // Multiple concatenated JSON values should pass through unchanged so + // ConnectRPC can reject the malformed input. + input := `{"operator":"IN"}{"extra":1}` + out, err := normalizeJSON([]byte(input), allLookup) + require.NoError(t, err) + assert.Equal(t, input, string(out)) +} + func TestNormalizeJSON_InvalidJSON(t *testing.T) { input := `not json at all` out, err := normalizeJSON([]byte(input), allLookup) diff --git a/service/pkg/enumnormalize/middleware.go b/service/pkg/enumnormalize/middleware.go index 20fdbdc3cb..68b11edbec 100644 --- a/service/pkg/enumnormalize/middleware.go +++ b/service/pkg/enumnormalize/middleware.go @@ -8,17 +8,20 @@ import ( "strings" ) -// maxBodySize is the upper bound on request bodies the middleware will read -// into memory for normalization. Policy API request bodies are small (typically -// under 10 KB); this cap prevents abuse while being generous enough for any -// legitimate request. ConnectRPC enforces its own message size limits downstream. -const maxBodySize = 1 << 20 // 1 MB +// defaultMaxBodySize is the fallback upper bound on request bodies the +// middleware will read into memory for normalization when no explicit limit is +// provided. Policy API request bodies are small (typically under 10 KB). +const defaultMaxBodySize = 1 << 20 // 1 MB // NewMiddleware returns HTTP middleware that normalises shorthand enum string // values in JSON request bodies for the given RPC paths. Requests that do not // match (wrong content-type, wrong path) are forwarded unchanged with zero -// overhead. -func NewMiddleware(rules []EnumFieldRule, paths []string) func(http.Handler) http.Handler { +// overhead. maxBodyBytes sets the upper bound on request body size; pass 0 to +// use the default (1 MB). +func NewMiddleware(rules []EnumFieldRule, paths []string, maxBodyBytes int64) func(http.Handler) http.Handler { + if maxBodyBytes <= 0 { + maxBodyBytes = defaultMaxBodySize + } lookup := buildRuleLookup(rules) pathSet := make(map[string]struct{}, len(paths)) @@ -34,7 +37,7 @@ func NewMiddleware(rules []EnumFieldRule, paths []string) func(http.Handler) htt return } - body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodySize)) + body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes)) if err != nil { next.ServeHTTP(w, r) return diff --git a/service/pkg/enumnormalize/middleware_test.go b/service/pkg/enumnormalize/middleware_test.go index 1bccb21b1c..24ced2f85e 100644 --- a/service/pkg/enumnormalize/middleware_test.go +++ b/service/pkg/enumnormalize/middleware_test.go @@ -30,7 +30,7 @@ func (h *captureHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { func TestMiddleware_NormalizesMatchingJSONRequest(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}) + mw := NewMiddleware(testRules, []string{testPath}, 0) handler := mw(capture) body := `{"booleanOperator":"AND","conditions":[{"operator":"IN"}]}` @@ -44,7 +44,7 @@ func TestMiddleware_NormalizesMatchingJSONRequest(t *testing.T) { func TestMiddleware_ConnectJSONContentType(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}) + mw := NewMiddleware(testRules, []string{testPath}, 0) handler := mw(capture) body := `{"operator":"NOT_IN"}` @@ -57,7 +57,7 @@ func TestMiddleware_ConnectJSONContentType(t *testing.T) { func TestMiddleware_NonMatchingPathPassesThrough(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}) + mw := NewMiddleware(testRules, []string{testPath}, 0) handler := mw(capture) body := `{"operator":"IN"}` @@ -71,7 +71,7 @@ func TestMiddleware_NonMatchingPathPassesThrough(t *testing.T) { func TestMiddleware_NonJSONContentTypePassesThrough(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}) + mw := NewMiddleware(testRules, []string{testPath}, 0) handler := mw(capture) body := `{"operator":"IN"}` @@ -84,7 +84,7 @@ func TestMiddleware_NonJSONContentTypePassesThrough(t *testing.T) { func TestMiddleware_CanonicalNamesUnchanged(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}) + mw := NewMiddleware(testRules, []string{testPath}, 0) handler := mw(capture) body := `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN"}` @@ -97,7 +97,7 @@ func TestMiddleware_CanonicalNamesUnchanged(t *testing.T) { func TestMiddleware_NumericEnumValuesPassThrough(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}) + mw := NewMiddleware(testRules, []string{testPath}, 0) handler := mw(capture) // Numeric enum values (e.g., 1 for IN, 3 for IN_CONTAINS) are valid @@ -112,7 +112,7 @@ func TestMiddleware_NumericEnumValuesPassThrough(t *testing.T) { func TestMiddleware_ContentLengthUpdated(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}) + mw := NewMiddleware(testRules, []string{testPath}, 0) handler := mw(capture) body := `{"operator":"IN"}` @@ -126,11 +126,11 @@ func TestMiddleware_ContentLengthUpdated(t *testing.T) { func TestMiddleware_OversizedBodySkipsNormalization(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}) + mw := NewMiddleware(testRules, []string{testPath}, 0) handler := mw(capture) - // Build a body that exceeds maxBodySize (1 MB). - oversized := `{"operator":"` + strings.Repeat("A", maxBodySize) + `"}` + // Build a body that exceeds the default max body size (1 MB). + oversized := `{"operator":"` + strings.Repeat("A", defaultMaxBodySize) + `"}` req := httptest.NewRequest(http.MethodPost, testPath, strings.NewReader(oversized)) req.Header.Set("Content-Type", "application/json") handler.ServeHTTP(httptest.NewRecorder(), req) diff --git a/service/pkg/server/options.go b/service/pkg/server/options.go index 9722f5712f..3bd87f5bc5 100644 --- a/service/pkg/server/options.go +++ b/service/pkg/server/options.go @@ -197,7 +197,7 @@ func WithIPCInterceptors(interceptors ...connect.Interceptor) StartOptions { // // server.Start( // server.WithHTTPMiddleware( -// enumnormalize.NewMiddleware(rules, paths), +// enumnormalize.NewMiddleware(rules, paths, 0), // ), // ) func WithHTTPMiddleware(middleware ...func(http.Handler) http.Handler) StartOptions { diff --git a/tests-bdd/cukes/steps_enum_shorthand.go b/tests-bdd/cukes/steps_enum_shorthand.go index 86290f9f18..5858ca2cf1 100644 --- a/tests-bdd/cukes/steps_enum_shorthand.go +++ b/tests-bdd/cukes/steps_enum_shorthand.go @@ -10,10 +10,13 @@ import ( "net/http" "net/url" "strings" + "time" "github.com/cucumber/godog" ) +const bddHTTPTimeout = 15 * time.Second + type EnumShorthandStepDefinitions struct{} // getAccessToken fetches a bearer token from the Keycloak token endpoint using @@ -31,7 +34,7 @@ func getAccessToken(ctx context.Context, tokenEndpoint string) (string, error) { } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") - client := &http.Client{} + client := &http.Client{Timeout: bddHTTPTimeout} resp, err := client.Do(req) if err != nil { return "", fmt.Errorf("token request failed: %w", err) @@ -55,7 +58,7 @@ func getAccessToken(ctx context.Context, tokenEndpoint string) (string, error) { // postConnectRPC sends a raw JSON body to a ConnectRPC endpoint and returns the // HTTP status code and response body. func postConnectRPC(ctx context.Context, endpoint, rpcPath, token, jsonBody string) (int, string, error) { - client := &http.Client{} + client := &http.Client{Timeout: bddHTTPTimeout} req, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint+rpcPath, strings.NewReader(jsonBody)) if err != nil { From db769f656d569e7c462827d3b515837b9464e347 Mon Sep 17 00:00:00 2001 From: Mary Dickson Date: Mon, 27 Apr 2026 14:26:48 -0700 Subject: [PATCH 10/10] refactor(core): make maxBodyBytes optional in NewMiddleware Use variadic parameter so callers can omit it entirely. Defaults to 1 MB. Callers that need a custom limit can pass it explicitly. Co-Authored-By: Claude Opus 4.6 (1M context) --- service/internal/server/server.go | 1 - service/pkg/enumnormalize/middleware.go | 13 +++++++------ service/pkg/enumnormalize/middleware_test.go | 16 ++++++++-------- service/pkg/server/options.go | 2 +- 4 files changed, 16 insertions(+), 16 deletions(-) diff --git a/service/internal/server/server.go b/service/internal/server/server.go index fa113c614d..7f32fc4f1a 100644 --- a/service/internal/server/server.go +++ b/service/internal/server/server.go @@ -427,7 +427,6 @@ func newHTTPServer(c Config, connectRPC http.Handler, originalGrpcGateway http.H // Unsafe RPCs (rule) unsafeconnect.UnsafeServiceUnsafeUpdateAttributeProcedure, }, - 0, // use default max body size )(connectRPC) // Apply extra HTTP middleware injected by downstream consumers (e.g. DSP). diff --git a/service/pkg/enumnormalize/middleware.go b/service/pkg/enumnormalize/middleware.go index 68b11edbec..a639192c1b 100644 --- a/service/pkg/enumnormalize/middleware.go +++ b/service/pkg/enumnormalize/middleware.go @@ -16,11 +16,12 @@ const defaultMaxBodySize = 1 << 20 // 1 MB // NewMiddleware returns HTTP middleware that normalises shorthand enum string // values in JSON request bodies for the given RPC paths. Requests that do not // match (wrong content-type, wrong path) are forwarded unchanged with zero -// overhead. maxBodyBytes sets the upper bound on request body size; pass 0 to -// use the default (1 MB). -func NewMiddleware(rules []EnumFieldRule, paths []string, maxBodyBytes int64) func(http.Handler) http.Handler { - if maxBodyBytes <= 0 { - maxBodyBytes = defaultMaxBodySize +// overhead. An optional maxBodyBytes sets the upper bound on request body size; +// defaults to 1 MB if omitted or zero. +func NewMiddleware(rules []EnumFieldRule, paths []string, maxBodyBytes ...int64) func(http.Handler) http.Handler { + bodyLimit := int64(defaultMaxBodySize) + if len(maxBodyBytes) > 0 && maxBodyBytes[0] > 0 { + bodyLimit = maxBodyBytes[0] } lookup := buildRuleLookup(rules) @@ -37,7 +38,7 @@ func NewMiddleware(rules []EnumFieldRule, paths []string, maxBodyBytes int64) fu return } - body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, maxBodyBytes)) + body, err := io.ReadAll(http.MaxBytesReader(w, r.Body, bodyLimit)) if err != nil { next.ServeHTTP(w, r) return diff --git a/service/pkg/enumnormalize/middleware_test.go b/service/pkg/enumnormalize/middleware_test.go index 24ced2f85e..5f26fe1b0b 100644 --- a/service/pkg/enumnormalize/middleware_test.go +++ b/service/pkg/enumnormalize/middleware_test.go @@ -30,7 +30,7 @@ func (h *captureHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) { func TestMiddleware_NormalizesMatchingJSONRequest(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}, 0) + mw := NewMiddleware(testRules, []string{testPath}) handler := mw(capture) body := `{"booleanOperator":"AND","conditions":[{"operator":"IN"}]}` @@ -44,7 +44,7 @@ func TestMiddleware_NormalizesMatchingJSONRequest(t *testing.T) { func TestMiddleware_ConnectJSONContentType(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}, 0) + mw := NewMiddleware(testRules, []string{testPath}) handler := mw(capture) body := `{"operator":"NOT_IN"}` @@ -57,7 +57,7 @@ func TestMiddleware_ConnectJSONContentType(t *testing.T) { func TestMiddleware_NonMatchingPathPassesThrough(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}, 0) + mw := NewMiddleware(testRules, []string{testPath}) handler := mw(capture) body := `{"operator":"IN"}` @@ -71,7 +71,7 @@ func TestMiddleware_NonMatchingPathPassesThrough(t *testing.T) { func TestMiddleware_NonJSONContentTypePassesThrough(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}, 0) + mw := NewMiddleware(testRules, []string{testPath}) handler := mw(capture) body := `{"operator":"IN"}` @@ -84,7 +84,7 @@ func TestMiddleware_NonJSONContentTypePassesThrough(t *testing.T) { func TestMiddleware_CanonicalNamesUnchanged(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}, 0) + mw := NewMiddleware(testRules, []string{testPath}) handler := mw(capture) body := `{"operator":"SUBJECT_MAPPING_OPERATOR_ENUM_IN"}` @@ -97,7 +97,7 @@ func TestMiddleware_CanonicalNamesUnchanged(t *testing.T) { func TestMiddleware_NumericEnumValuesPassThrough(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}, 0) + mw := NewMiddleware(testRules, []string{testPath}) handler := mw(capture) // Numeric enum values (e.g., 1 for IN, 3 for IN_CONTAINS) are valid @@ -112,7 +112,7 @@ func TestMiddleware_NumericEnumValuesPassThrough(t *testing.T) { func TestMiddleware_ContentLengthUpdated(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}, 0) + mw := NewMiddleware(testRules, []string{testPath}) handler := mw(capture) body := `{"operator":"IN"}` @@ -126,7 +126,7 @@ func TestMiddleware_ContentLengthUpdated(t *testing.T) { func TestMiddleware_OversizedBodySkipsNormalization(t *testing.T) { capture := &captureHandler{} - mw := NewMiddleware(testRules, []string{testPath}, 0) + mw := NewMiddleware(testRules, []string{testPath}) handler := mw(capture) // Build a body that exceeds the default max body size (1 MB). diff --git a/service/pkg/server/options.go b/service/pkg/server/options.go index 3bd87f5bc5..9722f5712f 100644 --- a/service/pkg/server/options.go +++ b/service/pkg/server/options.go @@ -197,7 +197,7 @@ func WithIPCInterceptors(interceptors ...connect.Interceptor) StartOptions { // // server.Start( // server.WithHTTPMiddleware( -// enumnormalize.NewMiddleware(rules, paths, 0), +// enumnormalize.NewMiddleware(rules, paths), // ), // ) func WithHTTPMiddleware(middleware ...func(http.Handler) http.Handler) StartOptions {