dnote/pkg/server/api/handlers/subscription.go
Sung Won Cho 23a511dbe0
Improve package structure (#207)
* Improve package structure

* Set up travis
2019-06-25 19:20:19 +10:00

474 lines
12 KiB
Go

/* Copyright (C) 2019 Monomax Software Pty Ltd
*
* This file is part of Dnote.
*
* Dnote is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* Dnote is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with Dnote. If not, see <https://www.gnu.org/licenses/>.
*/
package handlers
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"github.com/dnote/dnote/pkg/server/api/helpers"
"github.com/dnote/dnote/pkg/server/api/operations"
"github.com/dnote/dnote/pkg/server/database"
"github.com/jinzhu/gorm"
"github.com/pkg/errors"
"github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/card"
"github.com/stripe/stripe-go/customer"
"github.com/stripe/stripe-go/paymentsource"
"github.com/stripe/stripe-go/source"
"github.com/stripe/stripe-go/sub"
"github.com/stripe/stripe-go/webhook"
)
var (
	// proPlanID is the ID of the Stripe plan that createSub subscribes
	// customers to.
	proPlanID = "plan_EpgsEvY27pajfo"
)
// getOrCreateStripeCustomer returns the Stripe customer associated with user.
//
// If user already has a StripeCustomerID, that customer is fetched from Stripe.
// Otherwise a new Stripe customer is created with the email of the user's
// account (looked up through tx), and the new customer ID is written to the
// user's row through tx. This function does not commit or roll back tx.
//
// NOTE(review): the Stripe customer is created outside the database
// transaction, so if the caller later rolls tx back the customer still exists
// in Stripe without being linked to the user.
func getOrCreateStripeCustomer(tx *gorm.DB, user database.User) (*stripe.Customer, error) {
	if user.StripeCustomerID != "" {
		c, err := customer.Get(user.StripeCustomerID, nil)
		if err != nil {
			return nil, errors.Wrap(err, "getting customer")
		}
		return c, nil
	}

	var account database.Account
	if err := tx.Where("user_id = ?", user.ID).First(&account).Error; err != nil {
		return nil, errors.Wrap(err, "finding account")
	}

	customerParams := &stripe.CustomerParams{
		Email: stripe.String(account.Email.String),
	}
	c, err := customer.New(customerParams)
	if err != nil {
		return nil, errors.Wrap(err, "creating customer")
	}

	// Persist only the new customer ID. Saving the whole struct would write
	// every column back from this by-value (and possibly stale) copy of the
	// user. A struct argument to Update writes its non-zero fields only.
	if err := tx.Model(&user).Update(database.User{StripeCustomerID: c.ID}).Error; err != nil {
		return nil, errors.Wrap(err, "updating user")
	}

	return c, nil
}
// addCustomerSource attaches the payment source identified by the token
// sourceID to the Stripe customer customerID. It returns the payment source
// that Stripe created.
func addCustomerSource(customerID, sourceID string) (*stripe.PaymentSource, error) {
	srcParams := &stripe.SourceParams{Token: stripe.String(sourceID)}
	params := &stripe.CustomerSourceParams{
		Customer: stripe.String(customerID),
		Source:   srcParams,
	}

	ps, err := paymentsource.New(params)
	if err != nil {
		return nil, errors.Wrap(err, "creating source for customer")
	}
	return ps, nil
}
// createCustomerSubscription subscribes the Stripe customer customerID to the
// single plan planID and returns the subscription Stripe created.
func createCustomerSubscription(customerID, planID string) (*stripe.Subscription, error) {
	item := &stripe.SubscriptionItemsParams{Plan: stripe.String(planID)}
	params := &stripe.SubscriptionParams{
		Customer: stripe.String(customerID),
		Items:    []*stripe.SubscriptionItemsParams{item},
	}

	subscription, err := sub.New(params)
	if err != nil {
		return nil, errors.Wrap(err, "creating subscription for customer")
	}
	return subscription, nil
}
// createSubPayload is the JSON request body decoded by createSub.
type createSubPayload struct {
	// Source is the Stripe source; only its ID is used, to attach it to the
	// user's Stripe customer.
	Source stripe.Source `json:"source"`
	// Country is stored on the user as billing_country.
	Country string `json:"country"`
}
// createSub subscribes the current user to the pro plan (proPlanID).
//
// The request body is a createSubPayload. Inside one database transaction it:
//  1. marks the user as a cloud user and records their billing country,
//  2. fetches or creates the user's Stripe customer,
//  3. attaches the payment source to that customer, and
//  4. subscribes the customer to proPlanID.
//
// The transaction is committed only after every step succeeds.
//
// It responds 400 when the body cannot be decoded or has no source ID, and
// 500 for failures talking to the database or Stripe.
//
// NOTE(review): steps 2-4 have side effects in Stripe that a database
// rollback does not undo.
func (a *App) createSub(w http.ResponseWriter, r *http.Request) {
	user, ok := r.Context().Value(helpers.KeyUser).(database.User)
	if !ok {
		http.Error(w, "No authenticated user found", http.StatusInternalServerError)
		return
	}

	var payload createSubPayload
	if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
		handleError(w, "decoding params", err, http.StatusBadRequest)
		return
	}
	// Reject a missing source before touching the database or Stripe.
	// Otherwise this client mistake surfaces as a 500 from the Stripe API,
	// possibly after a Stripe customer has already been created.
	if payload.Source.ID == "" {
		http.Error(w, "source is required", http.StatusBadRequest)
		return
	}

	db := database.DBConn
	tx := db.Begin()

	if err := tx.Model(&user).
		Update(map[string]interface{}{
			"cloud":           true,
			"billing_country": payload.Country,
		}).Error; err != nil {
		tx.Rollback()
		handleError(w, "updating user", err, http.StatusInternalServerError)
		return
	}

	// Named stripeCustomer so it does not shadow the imported customer package.
	stripeCustomer, err := getOrCreateStripeCustomer(tx, user)
	if err != nil {
		tx.Rollback()
		handleError(w, "getting customer", err, http.StatusInternalServerError)
		return
	}

	if _, err := addCustomerSource(stripeCustomer.ID, payload.Source.ID); err != nil {
		tx.Rollback()
		handleError(w, "attaching source", err, http.StatusInternalServerError)
		return
	}

	if _, err := createCustomerSubscription(stripeCustomer.ID, proPlanID); err != nil {
		tx.Rollback()
		handleError(w, "creating subscription", err, http.StatusInternalServerError)
		return
	}

	if err := tx.Commit().Error; err != nil {
		handleError(w, "committing a subscription transaction", err, http.StatusInternalServerError)
		return
	}

	w.WriteHeader(http.StatusOK)
}
// updateSubPayload is the JSON request body decoded by updateSub.
type updateSubPayload struct {
	// StripeSubcriptionID identifies the Stripe subscription to operate on.
	// NOTE(review): the field name is misspelled ("Subcription"). It is kept
	// because other code in the package may reference it.
	StripeSubcriptionID string `json:"stripe_subscription_id"`
	// Op is the operation to perform; it must be one of validUpdateSubOp.
	Op string `json:"op"`
	// Body is decoded but not read by updateSub.
	Body *interface{} `json:"body"`
}
// Operations accepted in the "op" field of an updateSub request.
var (
	updateSubOpCancel     = "cancel"
	updateSubOpReactivate = "reactivate"

	// validUpdateSubOp lists every operation validateUpdateSubPayload accepts.
	validUpdateSubOp = []string{updateSubOpCancel, updateSubOpReactivate}
)
// validateUpdateSubPayload reports whether p names a supported operation (one
// of validUpdateSubOp) and a non-empty subscription ID. The operation is
// checked first, so a payload that is wrong in both ways gets the operation
// error.
func validateUpdateSubPayload(p updateSubPayload) error {
	opKnown := false
	for i := 0; i < len(validUpdateSubOp) && !opKnown; i++ {
		opKnown = validUpdateSubOp[i] == p.Op
	}
	if !opKnown {
		return errors.Errorf("Invalid operation %s", p.Op)
	}

	if p.StripeSubcriptionID == "" {
		return errors.New("stripe_subscription_id is required")
	}
	return nil
}
// updateSub applies the operation named in the request body ("cancel" or
// "reactivate") to the Stripe subscription it identifies, on behalf of the
// current user.
//
// Responses:
//   - 403 when the user has no Stripe customer.
//   - 400 for an undecodable or invalid payload.
//   - 400 when the operation fails with operations.ErrSubscriptionActive,
//     whether returned directly or wrapped.
//   - 500 for any other failure.
func (a *App) updateSub(w http.ResponseWriter, r *http.Request) {
	user, ok := r.Context().Value(helpers.KeyUser).(database.User)
	if !ok {
		http.Error(w, "No authenticated user found", http.StatusInternalServerError)
		return
	}
	if user.StripeCustomerID == "" {
		http.Error(w, "Customer does not exist", http.StatusForbidden)
		return
	}

	var payload updateSubPayload
	if err := json.NewDecoder(r.Body).Decode(&payload); err != nil {
		handleError(w, "decoding params", err, http.StatusBadRequest)
		return
	}
	if err := validateUpdateSubPayload(payload); err != nil {
		handleError(w, "invalid payload", err, http.StatusBadRequest)
		return
	}

	var err error
	switch payload.Op {
	case updateSubOpCancel:
		err = operations.CancelSub(payload.StripeSubcriptionID, user)
	case updateSubOpReactivate:
		err = operations.ReactivateSub(payload.StripeSubcriptionID, user)
	}

	if err != nil {
		statusCode := http.StatusInternalServerError
		// Compare the root cause so the sentinel is still recognized if it
		// reaches here wrapped with context.
		if errors.Cause(err) == operations.ErrSubscriptionActive {
			statusCode = http.StatusBadRequest
		}
		handleError(w, fmt.Sprintf("during operation %s", payload.Op), err, statusCode)
		return
	}

	w.WriteHeader(http.StatusOK)
}
// GetSubResponseItem represents a subscription item in the response for get subscription.
// It carries the ID of the item's Stripe plan and of that plan's product.
type GetSubResponseItem struct {
	PlanID    string `json:"plan_id"`
	ProductID string `json:"product_id"`
}
// GetSubResponse is a response for getSub.
// When there is no subscription to report, every field is zero-valued except
// Items, which is an empty list.
type GetSubResponse struct {
	// SubscriptionID is the Stripe subscription ID.
	SubscriptionID string `json:"id"`
	// Items has one entry per item of the subscription.
	Items []GetSubResponseItem `json:"items"`
	// CurrentPeriodStart and CurrentPeriodEnd are copied from the Stripe
	// subscription (presumably Unix timestamps in seconds; confirm against the
	// stripe-go version in use).
	CurrentPeriodStart int64 `json:"current_period_start"`
	CurrentPeriodEnd   int64 `json:"current_period_end"`
	// Status is the Stripe subscription status.
	Status stripe.SubscriptionStatus `json:"status"`
	// CancelAtPeriodEnd is copied from the Stripe subscription.
	CancelAtPeriodEnd bool `json:"cancel_at_period_end"`
}
// respondWithEmptySub writes a JSON GetSubResponse that has no subscription ID
// and an empty (non-null) item list. getSub uses it when there is no active
// subscription to report.
func respondWithEmptySub(w http.ResponseWriter) {
	w.Header().Set("Content-Type", "application/json")

	err := json.NewEncoder(w).Encode(GetSubResponse{Items: []GetSubResponseItem{}})
	if err != nil {
		handleError(w, "encoding response", err, http.StatusInternalServerError)
	}
}
// getSub responds with the current user's active Stripe subscription as a
// JSON GetSubResponse.
//
// An empty subscription is returned in two cases:
//   - the user has no Stripe customer, or
//   - Stripe lists no subscription in the "active" status for that customer.
//
// Only the first active subscription returned by Stripe is reported.
func (a *App) getSub(w http.ResponseWriter, r *http.Request) {
	user, ok := r.Context().Value(helpers.KeyUser).(database.User)
	if !ok {
		http.Error(w, "No authenticated user found", http.StatusInternalServerError)
		return
	}
	if user.StripeCustomerID == "" {
		respondWithEmptySub(w)
		return
	}

	listParams := &stripe.SubscriptionListParams{}
	listParams.Filters.AddFilter("customer", "", user.StripeCustomerID)
	listParams.Filters.AddFilter("status", "", "active")

	iter := sub.List(listParams)
	if !iter.Next() {
		if err := iter.Err(); err != nil {
			handleError(w, "fetching subscription", err, http.StatusInternalServerError)
			return
		}
		// If no active subscription exists, respond with an empty subscription
		respondWithEmptySub(w)
		return
	}

	s := iter.Subscription()
	resp := GetSubResponse{
		SubscriptionID:     s.ID,
		CurrentPeriodStart: s.CurrentPeriodStart,
		CurrentPeriodEnd:   s.CurrentPeriodEnd,
		Status:             s.Status,
		CancelAtPeriodEnd:  s.CancelAtPeriodEnd,
		// Non-nil so the field encodes as [] rather than null, matching
		// respondWithEmptySub.
		Items: []GetSubResponseItem{},
	}

	if s.Items != nil {
		for _, item := range s.Items.Data {
			// Guard the nested pointers so an item without a plan, or a plan
			// without a product, cannot cause a nil dereference.
			if item == nil || item.Plan == nil {
				continue
			}
			respItem := GetSubResponseItem{PlanID: item.Plan.ID}
			if item.Plan.Product != nil {
				respItem.ProductID = item.Plan.Product.ID
			}
			resp.Items = append(resp.Items, respItem)
		}
	}

	w.Header().Set("Content-Type", "application/json")
	if err := json.NewEncoder(w).Encode(resp); err != nil {
		handleError(w, "encoding response", err, http.StatusInternalServerError)
		return
	}
}
// GetStripeSourceResponse is the JSON response for getStripeSource. It
// describes the card behind the user's default Stripe payment source and is
// zero-valued when there is none.
type GetStripeSourceResponse struct {
	Brand    string `json:"brand"`
	Last4    string `json:"last4"`
	ExpMonth uint8  `json:"exp_month"`
	ExpYear  uint16 `json:"exp_year"`
}
func respondWithEmptyStripeToken(w http.ResponseWriter) {
var resp GetStripeSourceResponse
w.Header().Set("Content-Type", "application/json")
if err := json.NewEncoder(w).Encode(resp); err != nil {
handleError(w, "encoding response", err, http.StatusInternalServerError)
return
}
}
// getStripeCard fetches card details for sourceID from Stripe and returns them
// as a *stripe.Card.
//
// Two kinds of ID are supported:
//   - legacy card resources ("card_" prefix), fetched through the card API and
//     scoped to stripeCustomerID;
//   - source resources ("src_" prefix), whose brand, last4, exp_month and
//     exp_year are read from the source's type data.
//
// Any other prefix yields a "malformed sourceID" error.
func getStripeCard(stripeCustomerID, sourceID string) (*stripe.Card, error) {
	switch {
	case strings.HasPrefix(sourceID, "card_"):
		cd, err := card.Get(sourceID, &stripe.CardParams{
			Customer: stripe.String(stripeCustomerID),
		})
		if err != nil {
			return nil, errors.Wrap(err, "fetching card")
		}
		return cd, nil

	case strings.HasPrefix(sourceID, "src_"):
		src, err := source.Get(sourceID, nil)
		if err != nil {
			return nil, errors.Wrap(err, "fetching source")
		}
		data := src.TypeData

		brand, ok := data["brand"].(string)
		if !ok {
			return nil, errors.New("casting brand")
		}
		last4, ok := data["last4"].(string)
		if !ok {
			return nil, errors.New("casting last4")
		}
		// JSON numbers are decoded as float64 in the untyped type data.
		month, ok := data["exp_month"].(float64)
		if !ok {
			return nil, errors.New("casting exp_month")
		}
		year, ok := data["exp_year"].(float64)
		if !ok {
			return nil, errors.New("casting exp_year")
		}

		return &stripe.Card{
			Brand:    stripe.CardBrand(brand),
			Last4:    last4,
			ExpMonth: uint8(month),
			ExpYear:  uint16(year),
		}, nil

	default:
		return nil, errors.Errorf("malformed sourceID %s", sourceID)
	}
}
// getStripeSource responds with the brand, last four digits and expiry of the
// card behind the current user's default Stripe payment source.
//
// A zero-valued response is sent when the user has no Stripe customer or the
// customer has no default source.
func (a *App) getStripeSource(w http.ResponseWriter, r *http.Request) {
	user, ok := r.Context().Value(helpers.KeyUser).(database.User)
	if !ok {
		http.Error(w, "No authenticated user found", http.StatusInternalServerError)
		return
	}

	customerID := user.StripeCustomerID
	if customerID == "" {
		respondWithEmptyStripeToken(w)
		return
	}

	cust, err := customer.Get(customerID, nil)
	if err != nil {
		handleError(w, "fetching stripe customer", err, http.StatusInternalServerError)
		return
	}

	defaultSrc := cust.DefaultSource
	if defaultSrc == nil {
		respondWithEmptyStripeToken(w)
		return
	}

	crd, err := getStripeCard(customerID, defaultSrc.ID)
	if err != nil {
		handleError(w, "fetching stripe source", err, http.StatusInternalServerError)
		return
	}

	w.Header().Set("Content-Type", "application/json")
	out := GetStripeSourceResponse{
		Brand:    string(crd.Brand),
		Last4:    crd.Last4,
		ExpMonth: crd.ExpMonth,
		ExpYear:  crd.ExpYear,
	}
	if encErr := json.NewEncoder(w).Encode(out); encErr != nil {
		handleError(w, "encoding response", encErr, http.StatusInternalServerError)
		return
	}
}
// stripeWebhook receives Stripe webhook events.
//
// The request body is verified against the Stripe-Signature header using the
// secret in the StripeWebhookSecret environment variable. A
// "customer.subscription.deleted" event causes operations.MarkUnsubscribed to
// be called with the subscription's customer ID. Any other event type is
// rejected with 400.
func (a *App) stripeWebhook(w http.ResponseWriter, req *http.Request) {
	// Cap the body size: this endpoint is reachable without authentication,
	// and the whole body is buffered before the signature can be checked.
	// 64 KiB follows Stripe's webhook handling example.
	const maxBodyBytes = int64(65536)
	req.Body = http.MaxBytesReader(w, req.Body, maxBodyBytes)

	body, err := ioutil.ReadAll(req.Body)
	if err != nil {
		handleError(w, "reading body", err, http.StatusServiceUnavailable)
		return
	}

	webhookSecret := os.Getenv("StripeWebhookSecret")
	event, err := webhook.ConstructEvent(body, req.Header.Get("Stripe-Signature"), webhookSecret)
	if err != nil {
		handleError(w, "verifying stripe webhook signature", err, http.StatusBadRequest)
		return
	}

	switch event.Type {
	case "customer.subscription.deleted":
		var subscription stripe.Subscription
		// Check the result of Unmarshal itself so a malformed payload is
		// rejected instead of being processed as an empty subscription.
		if err := json.Unmarshal(event.Data.Raw, &subscription); err != nil {
			handleError(w, "unmarshaling payload", err, http.StatusBadRequest)
			return
		}
		if subscription.Customer == nil {
			handleError(w, "unmarshaling payload", errors.New("subscription has no customer"), http.StatusBadRequest)
			return
		}
		// NOTE(review): any value returned by MarkUnsubscribed is ignored here,
		// so a failure still acknowledges the event with 200 and Stripe will
		// not retry. If it returns an error, it should be checked and reported
		// with a 5xx.
		operations.MarkUnsubscribed(subscription.Customer.ID)
	default:
		msg := fmt.Sprintf("Unsupported webhook event type %s", event.Type)
		// err is nil at this point; pass a real error to handleError.
		handleError(w, msg, errors.New("unsupported webhook event type"), http.StatusBadRequest)
		return
	}

	// Return a response to acknowledge receipt of the event
	w.WriteHeader(http.StatusOK)
}