diff --git a/openmeter/customer/adapter/customer.go b/openmeter/customer/adapter/customer.go index e0631d099f..adf349978f 100644 --- a/openmeter/customer/adapter/customer.go +++ b/openmeter/customer/adapter/customer.go @@ -15,6 +15,7 @@ import ( appcustominvoicingcustomerdb "github.com/openmeterio/openmeter/openmeter/ent/db/appcustominvoicingcustomer" appstripecustomerdb "github.com/openmeterio/openmeter/openmeter/ent/db/appstripecustomer" billingcustomeroverridedb "github.com/openmeterio/openmeter/openmeter/ent/db/billingcustomeroverride" + billingprofiledb "github.com/openmeterio/openmeter/openmeter/ent/db/billingprofile" customerdb "github.com/openmeterio/openmeter/openmeter/ent/db/customer" customersubjectsdb "github.com/openmeterio/openmeter/openmeter/ent/db/customersubjects" plandb "github.com/openmeterio/openmeter/openmeter/ent/db/plan" @@ -78,14 +79,19 @@ func (a *adapter) ListCustomers(ctx context.Context, input customer.ListCustomer } if input.BillingProfileID != nil { - if p := filter.SelectPredicate[predicate.BillingCustomerOverride]( - filter.Filter(*input.BillingProfileID), - billingcustomeroverridedb.FieldBillingProfileID, - ); p != nil { - query = query.Where(customerdb.HasBillingCustomerOverrideWith( - *p, - billingcustomeroverridedb.DeletedAtIsNil(), - )) + defaultProfileID, err := repo.db.BillingProfile.Query(). + Where( + billingprofiledb.Namespace(input.Namespace), + billingprofiledb.Default(true), + billingprofiledb.DeletedAtIsNil(), + ). + FirstID(ctx) + if err != nil { + return pagination.Result[customer.Customer]{}, fmt.Errorf("resolving default billing profile id: %w", err) + } + + if p := buildBillingProfileIDPredicate(*input.BillingProfileID, defaultProfileID); p != nil { + query = query.Where(*p) } } @@ -809,3 +815,56 @@ func activeSubscriptionFilter(at time.Time) []predicate.Subscription { subscriptiondb.CreatedAtLTE(at), } } + +// buildBillingProfileIDPredicate builds a customer predicate that filters on +// the customer's *effective* billing profile id — i.e. +// COALESCE(override.billing_profile_id, namespace_default_profile.id). +// +// The filter is routed through filter.Select on both +// billing_customer_override.billing_profile_id (for customers with an explicit +// live override) and billing_profile.id (via an EXISTS subquery, for customers +// who resolve to the namespace default). +func buildBillingProfileIDPredicate(f filter.FilterULID, defaultProfileID string) *predicate.Customer { + overrideSelector := f.Select(billingcustomeroverridedb.FieldBillingProfileID) + if overrideSelector == nil { + return nil + } + + preds := []predicate.Customer{ + customerdb.HasBillingCustomerOverrideWith( + billingcustomeroverridedb.DeletedAtIsNil(), + predicate.BillingCustomerOverride(overrideSelector), + ), + } + + if defaultSelector := f.Select(billingprofiledb.FieldID); defaultSelector != nil { + // Resolves to the default profile: no live override OR a live override + // with NULL profile_id. + resolvesToDefault := customerdb.Or( + customerdb.Not(customerdb.HasBillingCustomerOverrideWith( + billingcustomeroverridedb.DeletedAtIsNil(), + )), + customerdb.HasBillingCustomerOverrideWith( + billingcustomeroverridedb.DeletedAtIsNil(), + billingcustomeroverridedb.BillingProfileIDIsNil(), + ), + ) + + // EXISTS subquery that pins the namespace default profile row by id + // and applies the user's filter to that row's id. This lets eq, ne, + // in (and the And/Or wrappers) all flow through filter.Select. + defaultMatchesFilter := predicate.Customer(func(s *sql.Selector) { + bp := sql.Table(billingprofiledb.Table) + sub := sql.Select(bp.C(billingprofiledb.FieldID)). + From(bp). + Where(sql.EQ(bp.C(billingprofiledb.FieldID), defaultProfileID)) + predicate.BillingProfile(defaultSelector)(sub) + s.Where(sql.Exists(sub)) + }) + + preds = append(preds, customerdb.And(resolvesToDefault, defaultMatchesFilter)) + } + + p := customerdb.Or(preds...) + return &p +} diff --git a/test/customer/customer.go b/test/customer/customer.go index f78cbbe5f6..503acb88a5 100644 --- a/test/customer/customer.go +++ b/test/customer/customer.go @@ -11,6 +11,7 @@ import ( "github.com/samber/lo" "github.com/stretchr/testify/require" + "github.com/openmeterio/openmeter/openmeter/billing" "github.com/openmeterio/openmeter/openmeter/customer" "github.com/openmeterio/openmeter/openmeter/entitlement" "github.com/openmeterio/openmeter/openmeter/productcatalog" @@ -672,6 +673,148 @@ func (s *CustomerHandlerTestSuite) TestList(ctx context.Context, t *testing.T) { require.Equal(t, createCustomer1.ID, list.Items[1].ID, "Customer 1 must be second in order") } +// TestListBillingProfileFilter tests that the billing_profile_id filter operates +// on the customer's effective billing profile — i.e. customers without an +// explicit override are matched when the filtered id is the namespace default. +func (s *CustomerHandlerTestSuite) TestListBillingProfileFilter(ctx context.Context, t *testing.T) { + s.setupNamespace(t) + + customerService := s.Env.Customer() + billingService := s.Env.Billing() + + sandboxApp := s.installSandboxApp(t, s.namespace) + defaultProfile := s.createDefaultProfile(t, sandboxApp, s.namespace) + + pinnedInput := minimalCreateProfileInputTemplate(sandboxApp.GetID()) + pinnedInput.Namespace = s.namespace + pinnedInput.Default = false + pinnedInput.Name = "Pinned Profile" + pinnedProfile, err := billingService.CreateProfile(ctx, pinnedInput) + require.NoError(t, err, "creating pinned profile must not fail") + + // noOverride: relies on the namespace default profile. + noOverride, err := customerService.CreateCustomer(ctx, customer.CreateCustomerInput{ + Namespace: s.namespace, + CustomerMutate: customer.CustomerMutate{ + Key: lo.ToPtr("no-override"), + Name: "No Override", + }, + }) + require.NoError(t, err) + + // overrideNullProfile: has an override row with billing_profile_id IS NULL, + // which also resolves to the namespace default profile. + overrideNullProfile, err := customerService.CreateCustomer(ctx, customer.CreateCustomerInput{ + Namespace: s.namespace, + CustomerMutate: customer.CustomerMutate{ + Key: lo.ToPtr("override-null-profile"), + Name: "Override Null Profile", + }, + }) + require.NoError(t, err) + _, err = billingService.UpsertCustomerOverride(ctx, billing.UpsertCustomerOverrideInput{ + Namespace: s.namespace, + CustomerID: overrideNullProfile.ID, + Collection: billing.CollectionOverrideConfig{ + Interval: lo.ToPtr(datetime.MustParseDuration(t, "PT1H")), + }, + }) + require.NoError(t, err, "upserting customer override without profile id must not fail") + + // overrideDefault: has an override pointing explicitly at the default profile. + overrideDefault, err := customerService.CreateCustomer(ctx, customer.CreateCustomerInput{ + Namespace: s.namespace, + CustomerMutate: customer.CustomerMutate{ + Key: lo.ToPtr("override-default"), + Name: "Override Default", + }, + }) + require.NoError(t, err) + _, err = billingService.UpsertCustomerOverride(ctx, billing.UpsertCustomerOverrideInput{ + Namespace: s.namespace, + CustomerID: overrideDefault.ID, + ProfileID: defaultProfile.ID, + }) + require.NoError(t, err, "upserting customer override pinned to default must not fail") + + // overridePinned: has an override pointing at the non-default pinned profile. + overridePinned, err := customerService.CreateCustomer(ctx, customer.CreateCustomerInput{ + Namespace: s.namespace, + CustomerMutate: customer.CustomerMutate{ + Key: lo.ToPtr("override-pinned"), + Name: "Override Pinned", + }, + }) + require.NoError(t, err) + _, err = billingService.UpsertCustomerOverride(ctx, billing.UpsertCustomerOverrideInput{ + Namespace: s.namespace, + CustomerID: overridePinned.ID, + ProfileID: pinnedProfile.ID, + }) + require.NoError(t, err, "upserting customer override pinned to non-default must not fail") + + // overrideSoftDeleted: had an override pinned to the non-default profile, + // then deleted. The soft-deleted row must not affect the effective profile, + // so the customer falls back to the namespace default. + overrideSoftDeleted, err := customerService.CreateCustomer(ctx, customer.CreateCustomerInput{ + Namespace: s.namespace, + CustomerMutate: customer.CustomerMutate{ + Key: lo.ToPtr("override-soft-deleted"), + Name: "Override Soft Deleted", + }, + }) + require.NoError(t, err) + _, err = billingService.UpsertCustomerOverride(ctx, billing.UpsertCustomerOverrideInput{ + Namespace: s.namespace, + CustomerID: overrideSoftDeleted.ID, + ProfileID: pinnedProfile.ID, + }) + require.NoError(t, err) + require.NoError(t, billingService.DeleteCustomerOverride(ctx, billing.DeleteCustomerOverrideInput{ + Customer: customer.CustomerID{Namespace: s.namespace, ID: overrideSoftDeleted.ID}, + }), "deleting customer override must not fail") + + page := pagination.Page{PageNumber: 1, PageSize: 50} + idsOf := func(items []customer.Customer) []string { + ids := make([]string, 0, len(items)) + for _, c := range items { + ids = append(ids, c.ID) + } + return ids + } + + // eq default covers the bug: customers with no override, with override + // pinned to default, with override.billing_profile_id IS NULL, and with a + // soft-deleted override all resolve to the default profile and must match. + t.Run("eq default", func(t *testing.T) { + list, err := customerService.ListCustomers(ctx, customer.ListCustomersInput{ + Namespace: s.namespace, + Page: page, + BillingProfileID: &filter.FilterULID{FilterString: filter.FilterString{Eq: &defaultProfile.ID}}, + }) + require.NoError(t, err) + require.ElementsMatch(t, []string{ + noOverride.ID, + overrideNullProfile.ID, + overrideDefault.ID, + overrideSoftDeleted.ID, + }, idsOf(list.Items)) + }) + + // eq pinned guards against over-matching from the default branch — only + // customers with an explicit live override pointing at the pinned profile + // should match. + t.Run("eq pinned", func(t *testing.T) { + list, err := customerService.ListCustomers(ctx, customer.ListCustomersInput{ + Namespace: s.namespace, + Page: page, + BillingProfileID: &filter.FilterULID{FilterString: filter.FilterString{Eq: &pinnedProfile.ID}}, + }) + require.NoError(t, err) + require.ElementsMatch(t, []string{overridePinned.ID}, idsOf(list.Items)) + }) +} + // TestListCustomerUsageAttributions tests the listing of customer usage attributions func (s *CustomerHandlerTestSuite) TestListCustomerUsageAttributions(ctx context.Context, t *testing.T) { s.setupNamespace(t) diff --git a/test/customer/customer_test.go b/test/customer/customer_test.go index e19e1c9583..913a607039 100644 --- a/test/customer/customer_test.go +++ b/test/customer/customer_test.go @@ -35,6 +35,10 @@ func TestCustomer(t *testing.T) { testSuite.TestList(ctx, t) }) + t.Run("TestListBillingProfileFilter", func(t *testing.T) { + testSuite.TestListBillingProfileFilter(ctx, t) + }) + t.Run("TestListCustomerUsageAttributions", func(t *testing.T) { testSuite.TestListCustomerUsageAttributions(ctx, t) })