// Mirror of https://github.com/dnote/dnote
// Synced 2026-03-16 15:35:52 +01:00
// 474 lines, 12 KiB, Go
/* Copyright (C) 2019 Monomax Software Pty Ltd
|
|
*
|
|
* This file is part of Dnote.
|
|
*
|
|
* Dnote is free software: you can redistribute it and/or modify
|
|
* it under the terms of the GNU Affero General Public License as published by
|
|
* the Free Software Foundation, either version 3 of the License, or
|
|
* (at your option) any later version.
|
|
*
|
|
* Dnote is distributed in the hope that it will be useful,
|
|
* but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
* GNU Affero General Public License for more details.
|
|
*
|
|
* You should have received a copy of the GNU Affero General Public License
|
|
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
|
|
*/
|
|
|
|
package handlers
|
|
|
|
import (
|
|
"encoding/json"
|
|
"fmt"
|
|
"io/ioutil"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
|
|
"github.com/dnote/dnote/pkg/server/api/helpers"
|
|
"github.com/dnote/dnote/pkg/server/api/operations"
|
|
"github.com/dnote/dnote/pkg/server/database"
|
|
"github.com/jinzhu/gorm"
|
|
"github.com/pkg/errors"
|
|
"github.com/stripe/stripe-go"
|
|
"github.com/stripe/stripe-go/card"
|
|
"github.com/stripe/stripe-go/customer"
|
|
"github.com/stripe/stripe-go/paymentsource"
|
|
"github.com/stripe/stripe-go/source"
|
|
"github.com/stripe/stripe-go/sub"
|
|
"github.com/stripe/stripe-go/webhook"
|
|
)
|
|
|
|
// proPlanID is the Stripe plan ID that createSub passes to
// createCustomerSubscription when subscribing a customer.
var proPlanID = "plan_EpgsEvY27pajfo"
|
|
|
|
func getOrCreateStripeCustomer(tx *gorm.DB, user database.User) (*stripe.Customer, error) {
|
|
if user.StripeCustomerID != "" {
|
|
c, err := customer.Get(user.StripeCustomerID, nil)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "getting customer")
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
var account database.Account
|
|
if err := tx.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
|
|
return nil, errors.Wrap(err, "finding account")
|
|
}
|
|
|
|
customerParams := &stripe.CustomerParams{
|
|
Email: &account.Email.String,
|
|
}
|
|
c, err := customer.New(customerParams)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "creating customer")
|
|
}
|
|
|
|
user.StripeCustomerID = c.ID
|
|
if err := tx.Save(&user).Error; err != nil {
|
|
return nil, errors.Wrap(err, "updating user")
|
|
}
|
|
|
|
return c, nil
|
|
}
|
|
|
|
func addCustomerSource(customerID, sourceID string) (*stripe.PaymentSource, error) {
|
|
params := &stripe.CustomerSourceParams{
|
|
Customer: stripe.String(customerID),
|
|
Source: &stripe.SourceParams{
|
|
Token: stripe.String(sourceID),
|
|
},
|
|
}
|
|
|
|
src, err := paymentsource.New(params)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "creating source for customer")
|
|
}
|
|
|
|
return src, nil
|
|
}
|
|
|
|
func createCustomerSubscription(customerID, planID string) (*stripe.Subscription, error) {
|
|
subParams := &stripe.SubscriptionParams{
|
|
Customer: stripe.String(customerID),
|
|
Items: []*stripe.SubscriptionItemsParams{
|
|
{
|
|
Plan: stripe.String(planID),
|
|
},
|
|
},
|
|
}
|
|
|
|
s, err := sub.New(subParams)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "creating subscription for customer")
|
|
}
|
|
|
|
return s, nil
|
|
}
|
|
|
|
type createSubPayload struct {
|
|
Source stripe.Source `json:"source"`
|
|
Country string `json:"country"`
|
|
}
|
|
|
|
// createSub creates a subscription for a the current user
|
|
func (a *App) createSub(w http.ResponseWriter, r *http.Request) {
|
|
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
|
|
if !ok {
|
|
http.Error(w, "No authenticated user found", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
var payload createSubPayload
|
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
handleError(w, "decoding params", err, http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
db := database.DBConn
|
|
tx := db.Begin()
|
|
|
|
if err := tx.Model(&user).
|
|
Update(map[string]interface{}{
|
|
"cloud": true,
|
|
"billing_country": payload.Country,
|
|
}).Error; err != nil {
|
|
tx.Rollback()
|
|
handleError(w, "updating user", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
customer, err := getOrCreateStripeCustomer(tx, user)
|
|
if err != nil {
|
|
tx.Rollback()
|
|
handleError(w, "getting customer", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if _, err = addCustomerSource(customer.ID, payload.Source.ID); err != nil {
|
|
tx.Rollback()
|
|
handleError(w, "attaching source", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if _, err := createCustomerSubscription(customer.ID, proPlanID); err != nil {
|
|
tx.Rollback()
|
|
handleError(w, "creating subscription", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if err := tx.Commit().Error; err != nil {
|
|
handleError(w, "committing a subscription transaction", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
// updateSubPayload is the JSON request body decoded by updateSub.
type updateSubPayload struct {
	// StripeSubcriptionID identifies the subscription to operate on; it must
	// be non-empty to pass validateUpdateSubPayload.
	StripeSubcriptionID string `json:"stripe_subscription_id"`
	// Op names the operation to perform; it must be one of validUpdateSubOp.
	Op string `json:"op"`
	// Body is decoded from the "body" field of the request.
	Body *interface{} `json:"body"`
}
|
|
|
|
// Operation names accepted in the "op" field of an updateSubPayload.
var (
	// updateSubOpCancel makes updateSub call operations.CancelSub.
	updateSubOpCancel = "cancel"
	// updateSubOpReactivate makes updateSub call operations.ReactivateSub.
	updateSubOpReactivate = "reactivate"
)

// validUpdateSubOp lists every operation that validateUpdateSubPayload accepts.
var validUpdateSubOp = []string{updateSubOpCancel, updateSubOpReactivate}
|
|
|
|
func validateUpdateSubPayload(p updateSubPayload) error {
|
|
var isOpValid bool
|
|
|
|
for _, op := range validUpdateSubOp {
|
|
if p.Op == op {
|
|
isOpValid = true
|
|
break
|
|
}
|
|
}
|
|
|
|
if !isOpValid {
|
|
return errors.Errorf("Invalid operation %s", p.Op)
|
|
}
|
|
|
|
if p.StripeSubcriptionID == "" {
|
|
return errors.New("stripe_subscription_id is required")
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (a *App) updateSub(w http.ResponseWriter, r *http.Request) {
|
|
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
|
|
if !ok {
|
|
http.Error(w, "No authenticated user found", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if user.StripeCustomerID == "" {
|
|
http.Error(w, "Customer does not exist", http.StatusForbidden)
|
|
return
|
|
}
|
|
|
|
var payload updateSubPayload
|
|
if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
|
|
handleError(w, "decoding params", err, http.StatusBadRequest)
|
|
return
|
|
}
|
|
if err := validateUpdateSubPayload(payload); err != nil {
|
|
handleError(w, "invalid payload", err, http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
var err error
|
|
if payload.Op == updateSubOpCancel {
|
|
err = operations.CancelSub(payload.StripeSubcriptionID, user)
|
|
} else if payload.Op == updateSubOpReactivate {
|
|
err = operations.ReactivateSub(payload.StripeSubcriptionID, user)
|
|
}
|
|
|
|
if err != nil {
|
|
var statusCode int
|
|
if err == operations.ErrSubscriptionActive {
|
|
statusCode = http.StatusBadRequest
|
|
} else {
|
|
statusCode = http.StatusInternalServerError
|
|
}
|
|
|
|
handleError(w, fmt.Sprintf("during operation %s", payload.Op), err, statusCode)
|
|
return
|
|
}
|
|
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|
|
|
|
// GetSubResponseItem is a single subscription item in the response written by
// getSub.
type GetSubResponseItem struct {
	// PlanID is the ID of the plan attached to the subscription item.
	PlanID string `json:"plan_id"`
	// ProductID is the ID of the product that the plan belongs to.
	ProductID string `json:"product_id"`
}
|
|
|
|
// GetSubResponse is a response for getSub
|
|
type GetSubResponse struct {
|
|
SubscriptionID string `json:"id"`
|
|
Items []GetSubResponseItem `json:"items"`
|
|
CurrentPeriodStart int64 `json:"current_period_start"`
|
|
CurrentPeriodEnd int64 `json:"current_period_end"`
|
|
Status stripe.SubscriptionStatus `json:"status"`
|
|
CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
|
|
}
|
|
|
|
func respondWithEmptySub(w http.ResponseWriter) {
|
|
emptyGetSubResponse := GetSubResponse{
|
|
Items: []GetSubResponseItem{},
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(emptyGetSubResponse); err != nil {
|
|
handleError(w, "encoding response", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (a *App) getSub(w http.ResponseWriter, r *http.Request) {
|
|
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
|
|
if !ok {
|
|
http.Error(w, "No authenticated user found", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if user.StripeCustomerID == "" {
|
|
respondWithEmptySub(w)
|
|
return
|
|
}
|
|
|
|
listParams := &stripe.SubscriptionListParams{}
|
|
listParams.Filters.AddFilter("customer", "", user.StripeCustomerID)
|
|
listParams.Filters.AddFilter("status", "", "active")
|
|
i := sub.List(listParams)
|
|
|
|
if !i.Next() {
|
|
if err := i.Err(); err != nil {
|
|
handleError(w, "fetching subscription", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
// If no active subscription exists, respond with an empty subscription
|
|
respondWithEmptySub(w)
|
|
return
|
|
}
|
|
|
|
s := i.Subscription()
|
|
|
|
resp := GetSubResponse{
|
|
SubscriptionID: s.ID,
|
|
CurrentPeriodStart: s.CurrentPeriodStart,
|
|
CurrentPeriodEnd: s.CurrentPeriodEnd,
|
|
Status: s.Status,
|
|
CancelAtPeriodEnd: s.CancelAtPeriodEnd,
|
|
}
|
|
|
|
for _, item := range s.Items.Data {
|
|
i := GetSubResponseItem{
|
|
PlanID: item.Plan.ID,
|
|
ProductID: item.Plan.Product.ID,
|
|
}
|
|
resp.Items = append(resp.Items, i)
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
|
handleError(w, "encoding response", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
// GetStripeSourceResponse is the JSON response body written by
// getStripeSource. It describes the card behind the customer's default
// payment source; all fields are zero values when there is none.
type GetStripeSourceResponse struct {
	// Brand is the card brand as reported by Stripe.
	Brand string `json:"brand"`
	// Last4 is the "last4" value of the card as reported by Stripe.
	Last4 string `json:"last4"`
	// ExpMonth and ExpYear are the card's expiry month and year.
	ExpMonth uint8  `json:"exp_month"`
	ExpYear  uint16 `json:"exp_year"`
}
|
|
|
|
func respondWithEmptyStripeToken(w http.ResponseWriter) {
|
|
var resp GetStripeSourceResponse
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
|
handleError(w, "encoding response", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
// getStripeCard retrieves card information from stripe and returns a stripe.Card
|
|
// It handles legacy 'card' resource which have 'card_' prefixes, as well as the
|
|
// more up-to-date 'source' resources which have 'src_' prefixes.
|
|
func getStripeCard(stripeCustomerID, sourceID string) (*stripe.Card, error) {
|
|
if strings.HasPrefix(sourceID, "card_") {
|
|
params := &stripe.CardParams{
|
|
Customer: stripe.String(stripeCustomerID),
|
|
}
|
|
cd, err := card.Get(sourceID, params)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "fetching card")
|
|
}
|
|
|
|
return cd, nil
|
|
} else if strings.HasPrefix(sourceID, "src_") {
|
|
src, err := source.Get(sourceID, nil)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "fetching source")
|
|
}
|
|
|
|
brand, ok := src.TypeData["brand"].(string)
|
|
if !ok {
|
|
return nil, errors.New("casting brand")
|
|
}
|
|
last4, ok := src.TypeData["last4"].(string)
|
|
if !ok {
|
|
return nil, errors.New("casting last4")
|
|
}
|
|
expMonth, ok := src.TypeData["exp_month"].(float64)
|
|
if !ok {
|
|
return nil, errors.New("casting exp_month")
|
|
}
|
|
expYear, ok := src.TypeData["exp_year"].(float64)
|
|
if !ok {
|
|
return nil, errors.New("casting exp_year")
|
|
}
|
|
|
|
cd := &stripe.Card{
|
|
Brand: stripe.CardBrand(brand),
|
|
Last4: last4,
|
|
ExpMonth: uint8(expMonth),
|
|
ExpYear: uint16(expYear),
|
|
}
|
|
|
|
return cd, nil
|
|
}
|
|
|
|
return nil, errors.Errorf("malformed sourceID %s", sourceID)
|
|
}
|
|
|
|
func (a *App) getStripeSource(w http.ResponseWriter, r *http.Request) {
|
|
user, ok := r.Context().Value(helpers.KeyUser).(database.User)
|
|
if !ok {
|
|
http.Error(w, "No authenticated user found", http.StatusInternalServerError)
|
|
return
|
|
}
|
|
if user.StripeCustomerID == "" {
|
|
respondWithEmptyStripeToken(w)
|
|
return
|
|
}
|
|
|
|
c, err := customer.Get(user.StripeCustomerID, nil)
|
|
if err != nil {
|
|
handleError(w, "fetching stripe customer", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
if c.DefaultSource == nil {
|
|
respondWithEmptyStripeToken(w)
|
|
return
|
|
}
|
|
|
|
cd, err := getStripeCard(user.StripeCustomerID, c.DefaultSource.ID)
|
|
if err != nil {
|
|
handleError(w, "fetching stripe source", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
|
|
resp := GetStripeSourceResponse{
|
|
Brand: string(cd.Brand),
|
|
Last4: cd.Last4,
|
|
ExpMonth: cd.ExpMonth,
|
|
ExpYear: cd.ExpYear,
|
|
}
|
|
|
|
w.Header().Set("Content-Type", "application/json")
|
|
if err := json.NewEncoder(w).Encode(resp); err != nil {
|
|
handleError(w, "encoding response", err, http.StatusInternalServerError)
|
|
return
|
|
}
|
|
}
|
|
|
|
func (a *App) stripeWebhook(w http.ResponseWriter, req *http.Request) {
|
|
body, err := ioutil.ReadAll(req.Body)
|
|
if err != nil {
|
|
handleError(w, "reading body", err, http.StatusServiceUnavailable)
|
|
return
|
|
}
|
|
|
|
webhookSecret := os.Getenv("StripeWebhookSecret")
|
|
event, err := webhook.ConstructEvent(body, req.Header.Get("Stripe-Signature"), webhookSecret)
|
|
if err != nil {
|
|
handleError(w, "verifying stripe webhook signature", err, http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
switch event.Type {
|
|
case "customer.subscription.deleted":
|
|
{
|
|
var subscription stripe.Subscription
|
|
if json.Unmarshal(event.Data.Raw, &subscription); err != nil {
|
|
handleError(w, "unmarshaling payload", err, http.StatusBadRequest)
|
|
return
|
|
}
|
|
|
|
operations.MarkUnsubscribed(subscription.Customer.ID)
|
|
}
|
|
default:
|
|
{
|
|
msg := fmt.Sprintf("Unsupported webhook event type %s", event.Type)
|
|
handleError(w, msg, err, http.StatusBadRequest)
|
|
return
|
|
}
|
|
}
|
|
|
|
// Return a response to acknowledge receipt of the event
|
|
w.WriteHeader(http.StatusOK)
|
|
}
|