payment-poc/main.go

469 lines
14 KiB
Go

package main
import (
	"embed"
	"errors"
	"fmt"
	"html/template"
	"log"
	"math"
	"net/http"
	"os"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/joho/godotenv"
	"github.com/stripe/stripe-go/v72"

	"payment-poc/database"
	"payment-poc/migration"
	"payment-poc/providers/mock"
	stripe2 "payment-poc/providers/stripe"
	"payment-poc/providers/viva"
	wspay2 "payment-poc/providers/wspay"
	"payment-poc/state"
)
// devMigrations embeds every .sql file under db/dev into the binary; main
// passes it to migration.InitializeMigrations during startup.
//go:embed db/dev/*.sql
var devMigrations embed.FS
// PaymentProvider is the contract main uses to drive one payment gateway
// (mock, WsPay, Stripe or Viva Wallet). Every method receives the currently
// stored entry and returns the entry to persist; main writes the returned
// entry back through PaymentEntryProvider.UpdateEntry only when err is nil.
type PaymentProvider interface {
// CreatePaymentUrl registers the payment with the gateway and returns the
// updated entry together with the URL the user is redirected to.
CreatePaymentUrl(entry database.PaymentEntry) (updatedEntry database.PaymentEntry, url string, err error)
// CompleteTransaction settles the payment for amount, which callers pass in
// hundredths of the currency unit (the form produced by fetchAmount).
CompleteTransaction(entry database.PaymentEntry, amount int64) (database.PaymentEntry, error)
// CancelTransaction cancels the payment with the gateway.
CancelTransaction(entry database.PaymentEntry) (database.PaymentEntry, error)
// UpdatePayment fetches the current payment information from the gateway
// and returns the refreshed entry.
UpdatePayment(entry database.PaymentEntry) (updatedEntry database.PaymentEntry, err error)
}
// init configures the standard logger (date, time and file:line on every
// message) and loads environment variables from a .env file when present.
func init() {
	// Configure logging first so any message below already uses the final
	// format.
	log.SetPrefix("")
	log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile)

	// A missing .env file is expected (configuration may come from the real
	// environment), so only other failures — e.g. a malformed file — are
	// reported. Startup continues either way; required variables are checked
	// later where they are read.
	if err := godotenv.Load(); err != nil && !errors.Is(err, os.ErrNotExist) {
		log.Printf("couldn't load .env file: %v", err)
	}
}
// main connects to the database, applies the embedded dev migrations and
// serves the payment dashboard on :5281. A gateway is registered only when
// its name appears in the PROFILE environment variable; HTTP basic auth is
// enforced unless the "no-auth" profile is active.
func main() {
	client, err := connectToDb()
	if err != nil {
		log.Fatalf("couldn't connect to db: %v", err)
	}
	if err := migration.InitializeMigrations(client, devMigrations); err != nil {
		log.Fatalf("couldn't execute migrations: %v", err)
	}

	g := gin.Default()
	if !hasProfile("no-auth") {
		g.Use(gin.BasicAuth(getAccounts()))
	}
	g.SetFuncMap(template.FuncMap{
		"formatCurrency":    formatCurrency,
		"formatCurrencyPtr": formatCurrencyPtr,
		"decimalCurrency":   decimalCurrency,
		"formatState":       formatState,
		"omitempty":         omitempty,
	})
	g.NoRoute(func(c *gin.Context) {
		c.JSON(http.StatusNotFound, gin.H{"message": "no action on given url", "created": time.Now()})
	})
	g.NoMethod(func(c *gin.Context) {
		c.JSON(http.StatusMethodNotAllowed, gin.H{"message": "no action on given method", "created": time.Now()})
	})

	backendUrl := envMustExist("BACKEND_URL")
	paymentGateways := map[state.PaymentGateway]PaymentProvider{}
	entryProvider := &database.PaymentEntryProvider{DB: client}
	g.LoadHTMLGlob("./templates/*.gohtml")

	if hasProfile(string(state.GatewayMock)) {
		mockService := mock.Service{
			BackendUrl: backendUrl,
		}
		mockHandlers(g.Group("mock"), entryProvider, &mockService)
		paymentGateways[state.GatewayMock] = &mockService
		log.Printf("Registered provider: %s", state.GatewayMock)
	}
	if hasProfile(string(state.GatewayWsPay)) {
		wspayService := wspay2.Service{
			ShopId:     envMustExist("WSPAY_SHOP_ID"),
			ShopSecret: envMustExist("WSPAY_SHOP_SECRET"),
			BackendUrl: backendUrl,
		}
		wsPayHandlers(g.Group("wspay"), entryProvider, &wspayService)
		paymentGateways[state.GatewayWsPay] = &wspayService
		log.Printf("Registered provider: %s", state.GatewayWsPay)
	}
	if hasProfile(string(state.GatewayStripe)) {
		stripeService := stripe2.Service{
			ApiKey:     envMustExist("STRIPE_KEY"),
			BackendUrl: backendUrl,
		}
		stripeHandlers(g.Group("stripe"), entryProvider, &stripeService)
		paymentGateways[state.GatewayStripe] = &stripeService
		// The stripe-go client reads its key from package state; reuse the
		// value already loaded instead of reading the environment twice.
		stripe.Key = stripeService.ApiKey
		log.Printf("Registered provider: %s", state.GatewayStripe)
	}
	if hasProfile(string(state.GatewayVivaWallet)) {
		vivaService := viva.Service{
			ClientId:     envMustExist("VIVA_WALLET_CLIENT_ID"),
			ClientSecret: envMustExist("VIVA_WALLET_CLIENT_SECRET"),
			SourceCode:   envMustExist("VIVA_WALLET_SOURCE_CODE"),
			MerchantId:   envMustExist("VIVA_WALLET_MERCHANT_ID"),
			ApiKey:       envMustExist("VIVA_WALLET_API_KEY"),
		}
		vivaHandlers(g.Group("viva"), entryProvider, &vivaService)
		paymentGateways[state.GatewayVivaWallet] = &vivaService
		log.Printf("Registered provider: %s", state.GatewayVivaWallet)
	}

	// loadEntry resolves the ":id" path parameter to a stored entry. When the
	// id is not a valid UUID or the lookup fails it aborts the request with
	// 400 and reports false, so handlers only need to return.
	loadEntry := func(c *gin.Context) (uuid.UUID, database.PaymentEntry, bool) {
		raw := c.Param("id")
		id, err := uuid.Parse(raw)
		if err != nil {
			c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid entry id %q: %w", raw, err))
			return uuid.UUID{}, database.PaymentEntry{}, false
		}
		entry, err := entryProvider.FetchById(id)
		if err != nil {
			c.AbortWithError(http.StatusBadRequest, err)
			return uuid.UUID{}, database.PaymentEntry{}, false
		}
		return id, entry, true
	}

	// gatewayFor returns the provider registered for the entry's gateway. When
	// that gateway is not enabled in this process it aborts the request with
	// 500 and reports false.
	gatewayFor := func(c *gin.Context, entry database.PaymentEntry) (PaymentProvider, bool) {
		paymentGateway, ok := paymentGateways[entry.Gateway]
		if !ok {
			c.AbortWithError(http.StatusInternalServerError, errors.New("payment gateway not supported: "+string(entry.Gateway)))
		}
		return paymentGateway, ok
	}

	g.GET("/", func(c *gin.Context) {
		entries, err := entryProvider.FetchAll()
		if err != nil {
			c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
		c.HTML(http.StatusOK, "index.gohtml", gin.H{"Entries": entries})
	})
	g.GET("/methods", func(c *gin.Context) {
		amount, err := strconv.ParseFloat(c.Query("amount"), 64)
		if err != nil {
			amount = 10.00
		}
		c.HTML(http.StatusOK, "methods.gohtml", gin.H{"Amount": amount, "Gateways": mapGateways(paymentGateways)})
	})
	g.GET("/methods/:gateway", func(c *gin.Context) {
		gateway, err := fetchGateway(c.Param("gateway"))
		if err != nil {
			c.AbortWithError(http.StatusBadRequest, err)
			return
		}
		paymentGateway, ok := paymentGateways[gateway]
		if !ok {
			c.AbortWithError(http.StatusBadRequest, errors.New("unsupported payment gateway: "+string(gateway)))
			return
		}
		amount, err := fetchAmount(c.Query("amount"))
		if err != nil {
			c.AbortWithError(http.StatusBadRequest, err)
			return
		}
		entry, err := entryProvider.CreateEntry(database.PaymentEntry{
			Gateway:     gateway,
			State:       state.StatePreinitialized,
			TotalAmount: amount,
		})
		if err != nil {
			// Without a stored entry the gateway callback has nothing to update.
			c.AbortWithError(http.StatusInternalServerError, fmt.Errorf("creating payment entry: %w", err))
			return
		}
		log.Printf("[%s:%s] creating payment with gateway '%s' for '%f'", entry.Id.String(), entry.State, gateway, float64(amount)/100.0)
		entry, url, err := paymentGateway.CreatePaymentUrl(entry)
		if err != nil {
			c.AbortWithError(http.StatusBadRequest, err)
			return
		}
		log.Printf("[%s:%s] created redirect url", entry.Id, entry.State)
		// NOTE(review): UpdateEntry's result (if any) is not inspected here or
		// below; confirm its signature and surface persistence failures.
		entryProvider.UpdateEntry(entry)
		c.Redirect(http.StatusSeeOther, url)
	})
	g.GET("/entries/:id", func(c *gin.Context) {
		_, entry, ok := loadEntry(c)
		if !ok {
			return
		}
		c.HTML(http.StatusOK, "info.gohtml", gin.H{"Entry": entry})
	})
	g.POST("/entries/:id/complete", func(c *gin.Context) {
		id, entry, ok := loadEntry(c)
		if !ok {
			return
		}
		paymentGateway, ok := gatewayFor(c, entry)
		if !ok {
			return
		}
		amount, err := fetchAmount(c.PostForm("amount"))
		if err != nil {
			c.AbortWithError(http.StatusBadRequest, err)
			return
		}
		log.Printf("[%s:%s] completing payment with amount %f", id.String(), entry.State, float64(amount)/100.0)
		entry, err = paymentGateway.CompleteTransaction(entry, amount)
		if err != nil {
			c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
		entryProvider.UpdateEntry(entry)
		log.Printf("[%s:%s] completed payment", id.String(), entry.State)
		c.Redirect(http.StatusSeeOther, "/entries/"+id.String())
	})
	g.POST("/entries/:id/cancel", func(c *gin.Context) {
		id, entry, ok := loadEntry(c)
		if !ok {
			return
		}
		paymentGateway, ok := gatewayFor(c, entry)
		if !ok {
			return
		}
		log.Printf("[%s:%s] canceling payment", id.String(), entry.State)
		entry, err := paymentGateway.CancelTransaction(entry)
		if err != nil {
			c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
		entryProvider.UpdateEntry(entry)
		log.Printf("[%s:%s] canceled payment", id.String(), entry.State)
		c.Redirect(http.StatusSeeOther, "/entries/"+id.String())
	})
	g.POST("/entries/:id/refresh", func(c *gin.Context) {
		id, entry, ok := loadEntry(c)
		if !ok {
			return
		}
		paymentGateway, ok := gatewayFor(c, entry)
		if !ok {
			return
		}
		log.Printf("[%s:%s] fetching payment info", entry.Id.String(), entry.State)
		updated, err := paymentGateway.UpdatePayment(entry)
		if err != nil {
			// Refreshing is best-effort: still show the stored entry, but keep a
			// record of why the gateway lookup failed.
			log.Printf("[%s:%s] couldn't fetch payment info: %v", entry.Id.String(), entry.State, err)
		} else {
			entryProvider.UpdateEntry(updated)
			log.Printf("[%s:%s] fetched payment info", updated.Id.String(), updated.State)
		}
		c.Redirect(http.StatusSeeOther, "/entries/"+id.String())
	})

	// A bare http.ListenAndServe has no timeouts; bound header reads so slow
	// clients cannot hold connections open indefinitely.
	server := &http.Server{
		Addr:              ":5281",
		Handler:           g,
		ReadHeaderTimeout: 10 * time.Second,
	}
	log.Fatal(server.ListenAndServe())
}
// mockHandlers registers the mock gateway routes under g:
//   - GET gateway/:id renders mock_gateway.gohtml for the stored entry;
//   - GET success / GET error pass the request to mockService.HandleResponse
//     with state.StateAccepted / state.StateError and redirect (303) to the
//     URL it returns, aborting with 500 when it fails.
func mockHandlers(g *gin.RouterGroup, provider *database.PaymentEntryProvider, mockService *mock.Service) {
	g.GET("/gateway/:id", func(c *gin.Context) {
		// The id comes straight from the URL, so a malformed value is a client
		// error (400) rather than a panic inside the handler.
		raw := c.Param("id")
		id, err := uuid.Parse(raw)
		if err != nil {
			c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid entry id %q: %w", raw, err))
			return
		}
		entry, err := provider.FetchById(id)
		if err != nil {
			c.AbortWithError(http.StatusBadRequest, err)
			return
		}
		c.HTML(http.StatusOK, "mock_gateway.gohtml", gin.H{"Entry": entry})
	})
	g.GET("success", func(c *gin.Context) {
		url, err := mockService.HandleResponse(c, provider, state.StateAccepted)
		if err != nil {
			c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
		c.Redirect(http.StatusSeeOther, url)
	})
	g.GET("error", func(c *gin.Context) {
		url, err := mockService.HandleResponse(c, provider, state.StateError)
		if err != nil {
			c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
		c.Redirect(http.StatusSeeOther, url)
	})
}
// mapGateways converts the registered gateways into a map from gateway key
// (as a plain string) to its human-readable display name, for the methods
// template.
func mapGateways(gateways map[state.PaymentGateway]PaymentProvider) map[string]string {
	names := make(map[string]string, len(gateways))
	for gateway := range gateways {
		names[string(gateway)] = mapGatewayName(gateway)
	}
	return names
}
// mapGatewayName returns the display name of a gateway, or "" when the key
// is not one of the known gateways.
func mapGatewayName(key state.PaymentGateway) string {
	var name string
	switch key {
	case state.GatewayMock:
		name = "mock"
	case state.GatewayWsPay:
		name = "WsPay"
	case state.GatewayVivaWallet:
		name = "Viva wallet"
	case state.GatewayStripe:
		name = "Stripe"
	}
	return name
}
// fetchGateway maps a raw gateway name (e.g. from a URL parameter) to its
// state.PaymentGateway constant. Unrecognized names yield an
// "unknown gateway" error.
func fetchGateway(gateway string) (state.PaymentGateway, error) {
	known := []state.PaymentGateway{
		state.GatewayWsPay,
		state.GatewayStripe,
		state.GatewayVivaWallet,
		state.GatewayMock,
	}
	for _, candidate := range known {
		if string(candidate) == gateway {
			return candidate, nil
		}
	}
	return "", errors.New("unknown gateway: " + gateway)
}
// getAccounts builds the basic-auth credential set from the AUTH environment
// variable, which must have the form "user:password". Only the first colon
// separates the user from the password, so passwords may contain colons.
// A malformed value stops the process with a descriptive message instead of
// an index-out-of-range panic. The value itself is not logged because it
// holds a secret.
func getAccounts() gin.Accounts {
	parts := strings.SplitN(envMustExist("AUTH"), ":", 2)
	// gin.BasicAuth rejects an empty user name, so catch that here too.
	if len(parts) != 2 || parts[0] == "" {
		log.Fatalf("AUTH must have the form 'user:password'")
	}
	return gin.Accounts{parts[0]: parts[1]}
}
// fetchAmount parses a decimal amount in whole currency units (e.g. "12.34")
// and returns it in hundredths (1234).
//
// The value is rounded to the nearest hundredth rather than truncated.
// Inputs such as "0.29" are not exactly representable as float64
// (0.29*100 == 28.999999999999996), so truncation would lose a cent.
//
// An error is returned when the input is not a number, is NaN or infinite,
// is negative, or is too large to fit in an int64 after scaling.
func fetchAmount(amount string) (int64, error) {
	value, err := strconv.ParseFloat(amount, 64)
	if err != nil {
		return 0, err
	}
	if math.IsNaN(value) || math.IsInf(value, 0) {
		return 0, fmt.Errorf("invalid amount %q", amount)
	}
	if value < 0 {
		return 0, fmt.Errorf("amount must not be negative: %q", amount)
	}
	cents := math.Round(value * 100)
	// Converting an out-of-range float64 to int64 is implementation-defined,
	// so reject it explicitly.
	if cents >= float64(math.MaxInt64) {
		return 0, fmt.Errorf("amount too large: %q", amount)
	}
	return int64(cents), nil
}
// vivaHandlers registers the Viva Wallet return endpoints under g. Both
// "success" and "error" delegate to vivaService.HandleResponse with the
// matching target state (state.StateAccepted / state.StateError). They
// redirect (303) to the URL it returns, or abort with 500 when it fails.
func vivaHandlers(g *gin.RouterGroup, provider *database.PaymentEntryProvider, vivaService *viva.Service) {
	// redirectTo wraps a resolver that computes the redirect target.
	redirectTo := func(resolve func(*gin.Context) (string, error)) gin.HandlerFunc {
		return func(c *gin.Context) {
			location, err := resolve(c)
			if err != nil {
				c.AbortWithError(http.StatusInternalServerError, err)
				return
			}
			c.Redirect(http.StatusSeeOther, location)
		}
	}

	g.GET("success", redirectTo(func(c *gin.Context) (string, error) {
		return vivaService.HandleResponse(c, provider, state.StateAccepted)
	}))
	g.GET("error", redirectTo(func(c *gin.Context) (string, error) {
		return vivaService.HandleResponse(c, provider, state.StateError)
	}))
}
// stripeHandlers registers the Stripe return endpoints under g. Both
// "success" and "error" delegate to stripeService.HandleResponse with the
// matching target state (state.StateAccepted / state.StateError). They
// redirect (303) to the URL it returns, or abort with 500 when it fails.
func stripeHandlers(g *gin.RouterGroup, provider *database.PaymentEntryProvider, stripeService *stripe2.Service) {
	// redirectTo wraps a resolver that computes the redirect target.
	redirectTo := func(resolve func(*gin.Context) (string, error)) gin.HandlerFunc {
		return func(c *gin.Context) {
			location, err := resolve(c)
			if err != nil {
				c.AbortWithError(http.StatusInternalServerError, err)
				return
			}
			c.Redirect(http.StatusSeeOther, location)
		}
	}

	g.GET("success", redirectTo(func(c *gin.Context) (string, error) {
		return stripeService.HandleResponse(c, provider, state.StateAccepted)
	}))
	g.GET("error", redirectTo(func(c *gin.Context) (string, error) {
		return stripeService.HandleResponse(c, provider, state.StateError)
	}))
}
// wsPayHandlers registers the WsPay routes under g:
//   - GET initialize/:id renders wspay.gohtml with the authorisation form
//     for an entry that is still in state.StatePreinitialized;
//   - GET success is handled by wspayService.HandleSuccessResponse;
//   - GET error and GET cancel are handled by wspayService.HandleErrorResponse
//     with state.StateError and state.StateCanceled respectively.
//
// The success, error and cancel routes redirect (303) to the URL returned by
// the service, or abort with 500 when it fails.
func wsPayHandlers(g *gin.RouterGroup, provider *database.PaymentEntryProvider, wspayService *wspay2.Service) {
	g.GET("/initialize/:id", func(c *gin.Context) {
		// The id comes straight from the URL, so a malformed value is a client
		// error (400) rather than a panic inside the handler.
		raw := c.Param("id")
		id, err := uuid.Parse(raw)
		if err != nil {
			c.AbortWithError(http.StatusBadRequest, fmt.Errorf("invalid entry id %q: %w", raw, err))
			return
		}
		entry, err := provider.FetchById(id)
		if err != nil {
			c.AbortWithError(http.StatusNotFound, err)
			return
		}
		if entry.State != state.StatePreinitialized {
			// err is nil here, and gin panics when AbortWithError receives a
			// nil error, so build an error that explains the rejection.
			c.AbortWithError(http.StatusBadRequest, fmt.Errorf("entry %s is in state '%s', expected '%s'", id, entry.State, state.StatePreinitialized))
			return
		}
		form := wspayService.InitializePayment(entry)
		c.HTML(http.StatusOK, "wspay.gohtml", gin.H{"Action": wspay2.AuthorisationForm, "Form": form})
	})
	g.GET("success", func(c *gin.Context) {
		url, err := wspayService.HandleSuccessResponse(c, provider)
		if err != nil {
			c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
		c.Redirect(http.StatusSeeOther, url)
	})
	g.GET("error", func(c *gin.Context) {
		url, err := wspayService.HandleErrorResponse(c, provider, state.StateError)
		if err != nil {
			c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
		c.Redirect(http.StatusSeeOther, url)
	})
	g.GET("cancel", func(c *gin.Context) {
		url, err := wspayService.HandleErrorResponse(c, provider, state.StateCanceled)
		if err != nil {
			c.AbortWithError(http.StatusInternalServerError, err)
			return
		}
		c.Redirect(http.StatusSeeOther, url)
	})
}
// hasProfile reports whether profile is listed in the comma-separated PROFILE
// environment variable. Entries are compared after trimming surrounding
// whitespace, so "mock, stripe" enables both "mock" and "stripe".
func hasProfile(profile string) bool {
	entries := strings.Split(os.Getenv("PROFILE"), ",")
	for i := range entries {
		if strings.TrimSpace(entries[i]) == profile {
			return true
		}
	}
	return false
}
// formatState returns the Croatian display label for a payment state, as
// rendered in the HTML templates. States without a dedicated label render as
// "nepoznato stanje '<state>'" ("unknown state").
func formatState(stt state.PaymentState) string {
	var label string
	switch stt {
	case state.StatePreinitialized:
		label = "Predinicijalizirana" // pre-initialized
	case state.StateInitialized:
		label = "Inicijalizirana" // initialized
	case state.StateCanceledInitialization:
		label = "Otkazana tijekom izrade" // canceled during creation
	case state.StateAccepted:
		label = "Predautorizirana" // pre-authorized
	case state.StateCompleted:
		label = "Autorizirana" // authorized
	case state.StateCanceled:
		label = "Otkazana" // canceled
	case state.StateVoided:
		label = "Poništena" // voided
	case state.StateError:
		label = "Greška" // error
	default:
		label = "nepoznato stanje '" + string(stt) + "'"
	}
	return label
}
// formatCurrency renders an amount given in hundredths of the currency unit
// as "<units>,<hundredths>" with a decimal comma, e.g. 1234 -> "12,34".
// The sign is written once in front of the whole amount (-150 -> "-1,50").
// Go's / and % truncate toward zero, so formatting the signed parts directly
// would give "-1,-50".
func formatCurrency(current int64) string {
	sign := ""
	magnitude := uint64(current)
	if current < 0 {
		sign = "-"
		// Negate in uint64 so math.MinInt64 yields its true magnitude, 2^63.
		magnitude = -magnitude
	}
	return fmt.Sprintf("%s%d,%02d", sign, magnitude/100, magnitude%100)
}
// formatCurrencyPtr is the optional-value form of currency formatting. A nil
// pointer renders as "-". Otherwise the amount (in hundredths of the currency
// unit) renders as "<units>,<hundredths>", with the sign written once in front
// (-150 -> "-1,50" rather than "-1,-50").
func formatCurrencyPtr(current *int64) string {
	if current == nil {
		return "-"
	}
	amount := *current
	sign := ""
	magnitude := uint64(amount)
	if amount < 0 {
		sign = "-"
		// Negate in uint64 so math.MinInt64 yields its true magnitude, 2^63.
		magnitude = -magnitude
	}
	return fmt.Sprintf("%s%d,%02d", sign, magnitude/100, magnitude%100)
}
// decimalCurrency renders an amount given in hundredths of the currency unit
// as "<units>,<hundredths>", with the sign written once in front
// (-150 -> "-1,50" rather than "-1,-50").
//
// NOTE(review): this currently produces exactly the same text as
// formatCurrency. If templates use it to pre-fill numeric form inputs, the
// posted value goes through strconv.ParseFloat (see fetchAmount), which only
// accepts '.' as the decimal separator. Confirm against the templates whether
// '.' was intended here.
func decimalCurrency(current int64) string {
	sign := ""
	magnitude := uint64(current)
	if current < 0 {
		sign = "-"
		// Negate in uint64 so math.MinInt64 yields its true magnitude, 2^63.
		magnitude = -magnitude
	}
	return fmt.Sprintf("%s%d,%02d", sign, magnitude/100, magnitude%100)
}
// omitempty substitutes a "-" placeholder for an empty string so blank fields
// still render visibly in templates; non-empty values pass through unchanged.
func omitempty(value string) string {
	if value != "" {
		return value
	}
	return "-"
}