Skip to content
This repository was archived by the owner on Jan 2, 2024. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 5 additions & 7 deletions api/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,10 @@ func (h *Handler) serveSubscribe(w http.ResponseWriter, r *http.Request) error {
return err
}
phases = append(phases, control.Phase{
Trial: p.Trial,
Effective: p.Effective,
Features: fs,
AutomaticTax: sr.Tax.Automatic,
Trial: p.Trial,
Effective: p.Effective,
Features: fs,
Tax: sr.Tax,
})
}
}
Expand Down Expand Up @@ -342,9 +342,7 @@ func (h *Handler) servePhase(w http.ResponseWriter, r *http.Request) error {
Plans: p.Plans,
Fragments: p.Fragments(),
Trial: p.Trial,
Tax: apitypes.Taxation{
Automatic: p.AutomaticTax,
},
Tax: p.Tax,
})
}
}
Expand Down
3 changes: 2 additions & 1 deletion api/api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"tier.run/refs"
"tier.run/stripe"
"tier.run/stripe/stroke"
"tier.run/types/tax"
"tier.run/types/they"
)

Expand Down Expand Up @@ -319,7 +320,7 @@ func TestScheduleAutomaticTax(t *testing.T) {
}
})
_, err := tc.Schedule(ctx, "org:test", &tier.ScheduleParams{
Tax: tier.Taxation{Automatic: true},
Tax: tax.Applied{Automatically: true},
Phases: []apitypes.Phase{
{
Features: []string{"plan:test@0"},
Expand Down
17 changes: 7 additions & 10 deletions api/apitypes/apitypes.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (

"tier.run/refs"
"tier.run/types/payment"
"tier.run/types/tax"
)

type Error struct {
Expand All @@ -26,18 +27,14 @@ type Phase struct {
Features []string `json:"features,omitempty"`
}

type Taxation struct {
Automatic bool `json:"automatic,omitempty"`
}

type PhaseResponse struct {
Effective time.Time `json:"effective,omitempty"`
End time.Time `json:"end,omitempty"`
Features []refs.FeaturePlan `json:"features,omitempty"`
Plans []refs.Plan `json:"plans,omitempty"`
Fragments []refs.FeaturePlan `json:"fragments,omitempty"`
Trial bool `json:"trial,omitempty"`
Tax Taxation `json:"tax,omitempty"`
Tax tax.Applied `json:"tax"`
}

func (pr PhaseResponse) MarshalJSON() ([]byte, error) {
Expand Down Expand Up @@ -91,11 +88,11 @@ type CheckoutRequest struct {
}

type ScheduleRequest struct {
Org string `json:"org"`
PaymentMethodID string `json:"payment_method_id"`
Info *OrgInfo `json:"info"`
Phases []Phase `json:"phases"`
Tax Taxation `json:"tax"`
Org string `json:"org"`
PaymentMethodID string `json:"payment_method_id"`
Info *OrgInfo `json:"info"`
Phases []Phase `json:"phases"`
Tax tax.Applied `json:"tax"`
}

// ScheduleResponse is the expected response from a schedule request. It is
Expand Down
48 changes: 42 additions & 6 deletions api/apitypes/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"

"tier.run/refs"
"tier.run/types/tax"
"tier.run/values"
)

Expand Down Expand Up @@ -53,12 +54,47 @@ type Divide struct {
}

type Feature struct {
Title string `json:"title,omitempty"`
Base float64 `json:"base,omitempty"`
Mode string `json:"mode,omitempty"`
Aggregate string `json:"aggregate,omitempty"`
Tiers []Tier `json:"tiers,omitempty"`
Divide *Divide `json:"divide,omitempty"`
Title string `json:"title,omitempty"`
Base float64 `json:"base,omitempty"`
Mode string `json:"mode,omitempty"`
Aggregate string `json:"aggregate,omitempty"`
Tiers []Tier `json:"tiers,omitempty"`
Divide Divide `json:"divide"`
Tax tax.Settings `json:"tax"`
}

func (v Feature) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Title string `json:"title,omitempty"`
Base float64 `json:"base,omitempty"`
Mode string `json:"mode,omitempty"`
Aggregate string `json:"aggregate,omitempty"`
Tiers []Tier `json:"tiers,omitempty"`
Divide *Divide `json:"divide,omitempty"`
Tax *tax.Settings `json:"tax,omitempty"`
}{
Title: v.Title,
Base: v.Base,
Mode: v.Mode,
Aggregate: v.Aggregate,
Tiers: v.Tiers,
Divide: zeroAsNil(v.Divide),
Tax: zeroAsNil(v.Tax),
})
}

// zeroAsNil returns a pointer to v if v is not the zero value for its type.
// If v implements IsZero, it is used to determine if v is the zero value.
func zeroAsNil[T comparable](v T) *T {
z, ok := any(v).(interface{ IsZero() bool })
if ok && z.IsZero() {
return nil
}
var zero T
if v == zero {
return nil
}
return &v
}

type Plan struct {
Expand Down
10 changes: 6 additions & 4 deletions api/materialize/views.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,6 @@ func FromPricingHuJSON(data []byte) (fs []control.Feature, err error) {
for feature, f := range p.Features {
fn := feature.WithPlan(plan)

divide := values.Coalesce(f.Divide, &apitypes.Divide{})
ff := control.Feature{
FeaturePlan: fn,

Expand All @@ -55,8 +54,10 @@ func FromPricingHuJSON(data []byte) (fs []control.Feature, err error) {
Mode: values.Coalesce(f.Mode, "graduated"),
Aggregate: values.Coalesce(f.Aggregate, "sum"),

TransformDenominator: divide.By,
TransformRoundUp: divide.Rounding == "up",
TransformDenominator: f.Divide.By,
TransformRoundUp: f.Divide.Rounding == "up",

Tax: f.Tax,
}

if len(f.Tiers) > 0 {
Expand Down Expand Up @@ -110,13 +111,14 @@ func ToPricingJSON(fs []control.Feature) ([]byte, error) {
Mode: values.ZeroIf(f.Mode, "graduated"),
Aggregate: values.ZeroIf(f.Aggregate, "sum"),
Tiers: tiers,
Tax: f.Tax,
}
if f.TransformDenominator != 0 {
var round string
if f.TransformRoundUp {
round = "up"
}
af.Divide = &apitypes.Divide{
af.Divide = apitypes.Divide{
By: f.TransformDenominator,
Rounding: round,
}
Expand Down
27 changes: 27 additions & 0 deletions api/materialize/views_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"tier.run/client/tier"
"tier.run/control"
"tier.run/refs"
"tier.run/types/tax"
)

func TestPricingHuJSON(t *testing.T) {
Expand Down Expand Up @@ -41,6 +42,13 @@ func TestPricingHuJSON(t *testing.T) {
}
},
},
"plan:tax@1": {
"features": {
"feature:tax:not:included": {
"tax": {"included": true},
},
},
},
}
}`)

Expand All @@ -63,6 +71,16 @@ func TestPricingHuJSON(t *testing.T) {
Aggregate: "sum", // defaults
Base: 100,
},
{
PlanTitle: "plan:tax@1",
Title: "feature:tax:not:included@plan:tax@1",
FeaturePlan: refs.MustParseFeaturePlan("feature:tax:not:included@plan:tax@1"),
Currency: "usd",
Interval: "@monthly",
Mode: "graduated", // defaults
Aggregate: "sum",
Tax: tax.Settings{Included: true},
},
{
PlanTitle: "Just an example plan to show off features",
Title: "feature:volume@plan:example@1",
Expand Down Expand Up @@ -123,6 +141,14 @@ func TestPricingHuJSON(t *testing.T) {
"divide": {"by": 100, "rounding": "up"},
}
}
},
"plan:tax@1": {
"title": "plan:tax@1",
"features": {
"feature:tax:not:included": {
"tax": {"included": true}
}
}
}
}
}`)
Expand All @@ -134,6 +160,7 @@ func diffJSON(t *testing.T, got, want []byte) {
t.Helper()

format := func(b []byte) string {
t.Helper()
b, err := hujson.Standardize(b)
if err != nil {
t.Fatal(err)
Expand Down
5 changes: 2 additions & 3 deletions client/tier/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"tier.run/api/apitypes"
"tier.run/fetch"
"tier.run/refs"
"tier.run/types/tax"
)

// ClockHeader is the header used to pass the clock ID to the tier sidecar.
Expand Down Expand Up @@ -305,14 +306,12 @@ type CheckoutParams struct {
RequireBillingAddress bool
}

type Taxation = apitypes.Taxation

type ScheduleParams struct {
Info *OrgInfo
Phases []Phase
PaymentMethodID string

Tax Taxation
Tax tax.Applied
}

func (c *Client) Schedule(ctx context.Context, org string, p *ScheduleParams) (*apitypes.ScheduleResponse, error) {
Expand Down
9 changes: 5 additions & 4 deletions cmd/tier/tier.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (
"tier.run/control"
"tier.run/profile"
"tier.run/stripe"
"tier.run/types/tax"
"tier.run/version"
)

Expand Down Expand Up @@ -279,7 +280,7 @@ func runTier(cmd string, args []string) (err error) {
cancelURL := fs.String("cancel_url", "", "sets the cancel URL for use with -checkout")
requireBillingAddress := fs.Bool("require_billing_address", false, "require billing address for use with --checkout")
paymentMethod := fs.String("paymentmethod", "", "sets the Stripe payment method for the subscription (e.g. pm_123). It is ignored with --checkout")
tax := fs.String("tax", "", "sets the Stripe tax rate for the subscription ('auto' is currently the only supported value)")
taxtype := fs.String("tax", "", "sets the Stripe tax rate for the subscription ('auto' is currently the only supported value)")
if err := fs.Parse(args); err != nil {
return err
}
Expand All @@ -293,8 +294,8 @@ func runTier(cmd string, args []string) (err error) {
fmt.Fprintln(stderr, "tier: the -cancel flag must be used without arguments")
return errUsage
}
if *tax != "" && *tax != "auto" {
fmt.Fprintf(stderr, "tier: invalid tax rate %q\n", *tax)
if *taxtype != "" && *taxtype != "auto" {
fmt.Fprintf(stderr, "tier: invalid tax rate %q\n", *taxtype)
return errUsage
}

Expand Down Expand Up @@ -324,7 +325,7 @@ func runTier(cmd string, args []string) (err error) {
Email: *email,
},
PaymentMethodID: *paymentMethod,
Tax: tier.Taxation{Automatic: *tax == "auto"},
Tax: tax.Applied{Automatically: *taxtype == "auto"},
}
switch {
case *trial > 0:
Expand Down
13 changes: 11 additions & 2 deletions control/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"golang.org/x/sync/errgroup"
"tier.run/refs"
"tier.run/stripe"
"tier.run/types/tax"
"tier.run/values"
)

Expand Down Expand Up @@ -106,6 +107,8 @@ type Feature struct {

TransformDenominator int // the denominator for transforming usage
TransformRoundUp bool // whether to round up transformed usage; otherwise round down

Tax tax.Settings
}

// TODO(bmizerany): remove FQN and replace with simply adding the version to
Expand Down Expand Up @@ -330,9 +333,13 @@ func (c *Client) pushFeature(ctx context.Context, f Feature) (providerID string,
data.Set("metadata", "tier.limit", limit)
}

if f.Tax.Included {
data.Set("tax_behavior", "inclusive")
} else {
data.Set("tax_behavior", "exclusive")
}

// TODO(bmizerany): data.Set("active", ?)
// TODO(bmizerany): data.Set("tax_behavior", "?")
// TODO(bmizerany): data.Set("transform_quantity", "?")
// TODO(bmizerany): data.Set("currency_options", "?")

var v struct {
Expand Down Expand Up @@ -374,6 +381,7 @@ type stripePrice struct {
DivideBy int `json:"divide_by"`
Round string `json:"round"`
} `json:"transform_quantity"`
TaxBehavior string `json:"tax_behavior"`
}

func stripePriceToFeature(p stripePrice) Feature {
Expand All @@ -388,6 +396,7 @@ func stripePriceToFeature(p stripePrice) Feature {
Aggregate: aggregateFromStripe[p.Recurring.AggregateUsage],
TransformDenominator: p.TransformQuantity.DivideBy,
TransformRoundUp: p.TransformQuantity.Round == "up",
Tax: tax.Settings{Included: p.TaxBehavior == "inclusive"},
}

if len(p.Tiers) == 0 && p.Recurring.UsageType == "metered" {
Expand Down
9 changes: 9 additions & 0 deletions control/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"kr.dev/diff"
"tier.run/refs"
"tier.run/stripe/stroke"
"tier.run/types/tax"
)

func newTestClient(t *testing.T) *Client {
Expand Down Expand Up @@ -80,6 +81,14 @@ func TestRoundTrip(t *testing.T) {
{Upto: 1, Price: 100, Base: 0},
},
},
{
FeaturePlan: refs.MustParseFeaturePlan("feature:tax@0"),
Interval: "@daily",
Currency: "eur",
Title: "Test2",
Base: 1000,
Tax: tax.Settings{Included: true},
},
}

if !slices.IsSortedFunc(want, func(a, b Feature) bool {
Expand Down
Loading