package main import ( "embed" "errors" "fmt" "github.com/gin-gonic/gin" "github.com/google/uuid" "github.com/joho/godotenv" "github.com/stripe/stripe-go/v72" "html/template" "log" "net/http" "os" "payment-poc/database" "payment-poc/migration" "payment-poc/providers/mock" stripe2 "payment-poc/providers/stripe" "payment-poc/providers/viva" wspay2 "payment-poc/providers/wspay" "payment-poc/state" "strconv" "strings" "time" ) //go:embed db/dev/*.sql var devMigrations embed.FS type PaymentProvider interface { CreatePaymentUrl(entry database.PaymentEntry) (updatedEntry database.PaymentEntry, url string, err error) CompleteTransaction(entry database.PaymentEntry, amount int64) (database.PaymentEntry, error) CancelTransaction(entry database.PaymentEntry) (database.PaymentEntry, error) UpdatePayment(entry database.PaymentEntry) (updatedEntry database.PaymentEntry, err error) } func init() { godotenv.Load() log.SetPrefix("") log.SetFlags(log.Ldate | log.Ltime | log.Lshortfile) } func main() { client, err := connectToDb() if err != nil { log.Fatalf("couldn't connect to db: %v", err) } if err := migration.InitializeMigrations(client, devMigrations); err != nil { log.Fatalf("couldn't execute migrations: %v", err) } g := gin.Default() if !hasProfile("no-auth") { g.Use(gin.BasicAuth(getAccounts())) } g.SetFuncMap(template.FuncMap{ "formatCurrency": formatCurrency, "formatCurrencyPtr": formatCurrencyPtr, "decimalCurrency": decimalCurrency, "formatState": formatState, "omitempty": omitempty, }) g.NoRoute(func(c *gin.Context) { c.JSON(http.StatusNotFound, gin.H{"message": "no action on given url", "created": time.Now()}) }) g.NoMethod(func(c *gin.Context) { c.JSON(http.StatusMethodNotAllowed, gin.H{"message": "no action on given method", "created": time.Now()}) }) backendUrl := envMustExist("BACKEND_URL") paymentGateways := map[state.PaymentGateway]PaymentProvider{} entryProvider := &database.PaymentEntryProvider{DB: client} g.LoadHTMLGlob("./templates/*.gohtml") if hasProfile(string(state.GatewayMock)) { mockService := mock.Service{ BackendUrl: backendUrl, } mockHandlers(g.Group("mock"), entryProvider, &mockService) paymentGateways[state.GatewayMock] = &mockService log.Printf("Registered provider: %s", state.GatewayMock) } if hasProfile(string(state.GatewayWsPay)) { wspayService := wspay2.Service{ ShopId: envMustExist("WSPAY_SHOP_ID"), ShopSecret: envMustExist("WSPAY_SHOP_SECRET"), BackendUrl: backendUrl, } wsPayHandlers(g.Group("wspay"), entryProvider, &wspayService) paymentGateways[state.GatewayWsPay] = &wspayService log.Printf("Registered provider: %s", state.GatewayWsPay) } if hasProfile(string(state.GatewayStripe)) { stripeService := stripe2.Service{ ApiKey: envMustExist("STRIPE_KEY"), BackendUrl: backendUrl, } stripeHandlers(g.Group("stripe"), entryProvider, &stripeService) paymentGateways[state.GatewayStripe] = &stripeService stripe.Key = envMustExist("STRIPE_KEY") log.Printf("Registered provider: %s", state.GatewayStripe) } if hasProfile(string(state.GatewayVivaWallet)) { vivaService := viva.Service{ ClientId: envMustExist("VIVA_WALLET_CLIENT_ID"), ClientSecret: envMustExist("VIVA_WALLET_CLIENT_SECRET"), SourceCode: envMustExist("VIVA_WALLET_SOURCE_CODE"), MerchantId: envMustExist("VIVA_WALLET_MERCHANT_ID"), ApiKey: envMustExist("VIVA_WALLET_API_KEY"), } vivaHandlers(g.Group("viva"), entryProvider, &vivaService) paymentGateways[state.GatewayVivaWallet] = &vivaService log.Printf("Registered provider: %s", state.GatewayVivaWallet) } g.GET("/", func(c *gin.Context) { entries, _ := entryProvider.FetchAll() c.HTML(200, "index.gohtml", gin.H{"Entries": entries}) }) g.GET("/methods", func(c *gin.Context) { amount, err := strconv.ParseFloat(c.Query("amount"), 64) if err != nil { amount = 10.00 } c.HTML(200, "methods.gohtml", gin.H{"Amount": amount, "Gateways": mapGateways(paymentGateways)}) }) g.GET("/methods/:gateway", func(c *gin.Context) { gateway, err := fetchGateway(c.Param("gateway")) if err != nil { c.AbortWithError(http.StatusBadRequest, err) return } if paymentGateway, contains := paymentGateways[gateway]; contains { amount, err := fetchAmount(c.Query("amount")) if err != nil { c.AbortWithError(http.StatusBadRequest, err) return } entry, err := entryProvider.CreateEntry(database.PaymentEntry{ Gateway: gateway, State: state.StatePreinitialized, TotalAmount: amount, }) log.Printf("[%s:%s] creating payment with gateway '%s' for '%f'", entry.Id.String(), entry.State, gateway, float64(amount)/100.0) if entry, url, err := paymentGateway.CreatePaymentUrl(entry); err == nil { log.Printf("[%s:%s] created redirect url", entry.Id, entry.State) entryProvider.UpdateEntry(entry) c.Redirect(http.StatusSeeOther, url) } else { c.AbortWithError(http.StatusBadRequest, err) return } } else { c.AbortWithError(http.StatusBadRequest, errors.New("unsupported payment gateway: "+string(gateway))) return } }) g.GET("/entries/:id", func(c *gin.Context) { id := uuid.MustParse(c.Param("id")) entry, err := entryProvider.FetchById(id) if err != nil { c.AbortWithError(http.StatusBadRequest, err) return } c.HTML(200, "info.gohtml", gin.H{"Entry": entry}) }) g.POST("/entries/:id/complete", func(c *gin.Context) { id := uuid.MustParse(c.Param("id")) entry, err := entryProvider.FetchById(id) if err != nil { c.AbortWithError(http.StatusBadRequest, err) return } if paymentGateway, ok := paymentGateways[entry.Gateway]; ok { amount, err := fetchAmount(c.PostForm("amount")) if err != nil { c.AbortWithError(http.StatusBadRequest, err) return } log.Printf("[%s:%s] completing payment with amount %f", id.String(), entry.State, float64(amount)/100.0) entry, err = paymentGateway.CompleteTransaction(entry, amount) if err == nil { entryProvider.UpdateEntry(entry) log.Printf("[%s:%s] completed payment", id.String(), entry.State) c.Redirect(http.StatusSeeOther, "/entries/"+id.String()) } else { c.AbortWithError(http.StatusInternalServerError, err) return } } else { if err != nil { c.AbortWithError(http.StatusInternalServerError, errors.New("payment gateway not supported: "+string(entry.Gateway))) return } } }) g.POST("/entries/:id/cancel", func(c *gin.Context) { id := uuid.MustParse(c.Param("id")) entry, err := entryProvider.FetchById(id) if err != nil { c.AbortWithError(http.StatusBadRequest, err) return } if paymentGateway, ok := paymentGateways[entry.Gateway]; ok { log.Printf("[%s:%s] canceling payment", id.String(), entry.State) entry, err = paymentGateway.CancelTransaction(entry) if err == nil { entryProvider.UpdateEntry(entry) log.Printf("[%s:%s] canceled payment", id.String(), entry.State) c.Redirect(http.StatusSeeOther, "/entries/"+id.String()) } else { c.AbortWithError(http.StatusInternalServerError, err) return } } else { if err != nil { c.AbortWithError(http.StatusInternalServerError, errors.New("payment gateway not supported: "+string(entry.Gateway))) return } } }) g.POST("/entries/:id/refresh", func(c *gin.Context) { id := uuid.MustParse(c.Param("id")) entry, err := entryProvider.FetchById(id) if err != nil { c.AbortWithError(http.StatusBadRequest, err) return } if paymentGateway, ok := paymentGateways[entry.Gateway]; ok { log.Printf("[%s:%s] fetching payment info", entry.Id.String(), entry.State) entry, err = paymentGateway.UpdatePayment(entry) if err == nil { entryProvider.UpdateEntry(entry) log.Printf("[%s:%s] fetched payment info", entry.Id.String(), entry.State) } c.Redirect(http.StatusSeeOther, "/entries/"+id.String()) } else { if err != nil { c.AbortWithError(http.StatusInternalServerError, errors.New("payment gateway not supported: "+string(entry.Gateway))) return } } }) log.Fatal(http.ListenAndServe(":5281", g)) } func mockHandlers(g *gin.RouterGroup, provider *database.PaymentEntryProvider, mockService *mock.Service) { g.GET("/gateway/:id", func(c *gin.Context) { id := uuid.MustParse(c.Param("id")) entry, err := provider.FetchById(id) if err != nil { c.AbortWithError(http.StatusBadRequest, err) return } c.HTML(http.StatusOK, "mock_gateway.gohtml", gin.H{"Entry": entry}) }) g.GET("success", func(c *gin.Context) { url, err := mockService.HandleResponse(c, provider, state.StateAccepted) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return } c.Redirect(http.StatusSeeOther, url) }) g.GET("error", func(c *gin.Context) { url, err := mockService.HandleResponse(c, provider, state.StateError) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return } c.Redirect(http.StatusSeeOther, url) }) } func mapGateways(gateways map[state.PaymentGateway]PaymentProvider) map[string]string { providerMap := map[string]string{} for key := range gateways { providerMap[string(key)] = mapGatewayName(key) } return providerMap } func mapGatewayName(key state.PaymentGateway) string { switch key { case state.GatewayStripe: return "Stripe" case state.GatewayVivaWallet: return "Viva wallet" case state.GatewayWsPay: return "WsPay" case state.GatewayMock: return "mock" } return "" } func fetchGateway(gateway string) (state.PaymentGateway, error) { switch gateway { case string(state.GatewayWsPay): return state.GatewayWsPay, nil case string(state.GatewayStripe): return state.GatewayStripe, nil case string(state.GatewayVivaWallet): return state.GatewayVivaWallet, nil case string(state.GatewayMock): return state.GatewayMock, nil } return "", errors.New("unknown gateway: " + gateway) } func getAccounts() gin.Accounts { auth := strings.Split(envMustExist("AUTH"), ":") return gin.Accounts{auth[0]: auth[1]} } func fetchAmount(amount string) (int64, error) { if amount, err := strconv.ParseFloat(amount, 64); err == nil { return int64(amount * 100), nil } else { return 0, err } } func vivaHandlers(g *gin.RouterGroup, provider *database.PaymentEntryProvider, vivaService *viva.Service) { g.GET("success", func(c *gin.Context) { url, err := vivaService.HandleResponse(c, provider, state.StateAccepted) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return } c.Redirect(http.StatusSeeOther, url) }) g.GET("error", func(c *gin.Context) { url, err := vivaService.HandleResponse(c, provider, state.StateError) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return } c.Redirect(http.StatusSeeOther, url) }) } func stripeHandlers(g *gin.RouterGroup, provider *database.PaymentEntryProvider, stripeService *stripe2.Service) { g.GET("success", func(c *gin.Context) { url, err := stripeService.HandleResponse(c, provider, state.StateAccepted) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return } c.Redirect(http.StatusSeeOther, url) }) g.GET("error", func(c *gin.Context) { url, err := stripeService.HandleResponse(c, provider, state.StateError) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return } c.Redirect(http.StatusSeeOther, url) }) } func wsPayHandlers(g *gin.RouterGroup, provider *database.PaymentEntryProvider, wspayService *wspay2.Service) { g.GET("success", func(c *gin.Context) { url, err := wspayService.HandleSuccessResponse(c, provider) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return } c.Redirect(http.StatusSeeOther, url) }) g.GET("error", func(c *gin.Context) { url, err := wspayService.HandleErrorResponse(c, provider, state.StateError) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return } c.Redirect(http.StatusSeeOther, url) }) g.GET("cancel", func(c *gin.Context) { url, err := wspayService.HandleErrorResponse(c, provider, state.StateCanceled) if err != nil { c.AbortWithError(http.StatusInternalServerError, err) return } c.Redirect(http.StatusSeeOther, url) }) } func hasProfile(profile string) bool { profiles := strings.Split(os.Getenv("PROFILE"), ",") for _, p := range profiles { if profile == strings.TrimSpace(p) { return true } } return false } func formatState(stt state.PaymentState) string { switch stt { case state.StateCanceled: return "Otkazana" case state.StateVoided: return "Poništena" case state.StateAccepted: return "Predautorizirana" case state.StateError: return "Greška" case state.StatePreinitialized: return "Predinicijalizirana" case state.StateInitialized: return "Inicijalizirana" case state.StateCanceledInitialization: return "Otkazana tijekom izrade" case state.StateCompleted: return "Autorizirana" } return "nepoznato stanje '" + string(stt) + "'" } func formatCurrency(current int64) string { return fmt.Sprintf("%d,%02d", current/100, current%100) } func formatCurrencyPtr(current *int64) string { if current != nil { return fmt.Sprintf("%d,%02d", (*current)/100, (*current)%100) } else { return "-" } } func decimalCurrency(current int64) string { return fmt.Sprintf("%d,%02d", current/100, current%100) } func omitempty(value string) string { if value == "" { return "-" } return value }