mirror of
https://github.com/pagefaultgames/rogueserver.git
synced 2025-09-19 05:33:48 +08:00
Adjust db to allow mocking for unit tests
This commit is contained in:
parent
27a1e8f363
commit
40323809f7
@ -20,11 +20,15 @@ package account
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
)
|
||||
|
||||
func ChangePW(uuid []byte, password string) error {
|
||||
// Interface for database operations needed for changing password.
|
||||
type ChangePWStore interface {
|
||||
RemoveSessionsFromUUID(uuid []byte) error
|
||||
UpdateAccountPassword(uuid []byte, newKey []byte, newSalt []byte) error
|
||||
}
|
||||
|
||||
func ChangePW[T ChangePWStore](store T, uuid []byte, password string) error {
|
||||
if len(password) < 6 {
|
||||
return fmt.Errorf("invalid password")
|
||||
}
|
||||
@ -35,12 +39,12 @@ func ChangePW(uuid []byte, password string) error {
|
||||
return fmt.Errorf("failed to generate salt: %s", err)
|
||||
}
|
||||
|
||||
err = db.RemoveSessionsFromUUID(uuid)
|
||||
err = store.RemoveSessionsFromUUID(uuid)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to remove sessions: %s", err)
|
||||
}
|
||||
|
||||
err = db.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt)
|
||||
err = store.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add account record: %s", err)
|
||||
}
|
||||
|
@ -35,14 +35,24 @@ var (
|
||||
DiscordGuildID string
|
||||
)
|
||||
|
||||
func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
type DiscordProvider interface {
|
||||
HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error)
|
||||
RetrieveDiscordId(code string) (string, error)
|
||||
IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error)
|
||||
}
|
||||
|
||||
type discordProvider struct{}
|
||||
|
||||
var Discord = &discordProvider{}
|
||||
|
||||
func (s *discordProvider) HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
http.Redirect(w, r, GameURL, http.StatusSeeOther)
|
||||
return "", errors.New("code is empty")
|
||||
}
|
||||
|
||||
discordId, err := RetrieveDiscordId(code)
|
||||
discordId, err := s.RetrieveDiscordId(code)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, GameURL, http.StatusSeeOther)
|
||||
return "", err
|
||||
@ -51,7 +61,7 @@ func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, erro
|
||||
return discordId, nil
|
||||
}
|
||||
|
||||
func RetrieveDiscordId(code string) (string, error) {
|
||||
func (s *discordProvider) RetrieveDiscordId(code string) (string, error) {
|
||||
v := make(url.Values)
|
||||
v.Set("client_id", DiscordClientID)
|
||||
v.Set("client_secret", DiscordClientSecret)
|
||||
@ -112,7 +122,7 @@ func RetrieveDiscordId(code string) (string, error) {
|
||||
return user.Id, nil
|
||||
}
|
||||
|
||||
func IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) {
|
||||
func (s *discordProvider) IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) {
|
||||
// fetch all roles from discord
|
||||
roles, err := DiscordSession.GuildRoles(discordGuildID)
|
||||
if err != nil {
|
||||
|
@ -32,7 +32,16 @@ var (
|
||||
GoogleCallbackURL string
|
||||
)
|
||||
|
||||
func HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
type GoogleProvider interface {
|
||||
HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error)
|
||||
RetrieveGoogleId(code string) (string, error)
|
||||
}
|
||||
|
||||
type googleProvider struct{}
|
||||
|
||||
var Google = &googleProvider{}
|
||||
|
||||
func (g *googleProvider) HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) {
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
http.Redirect(w, r, GameURL, http.StatusSeeOther)
|
||||
|
@ -17,10 +17,6 @@
|
||||
|
||||
package account
|
||||
|
||||
import (
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
)
|
||||
|
||||
type InfoResponse struct {
|
||||
Username string `json:"username"`
|
||||
DiscordId string `json:"discordId"`
|
||||
@ -29,9 +25,13 @@ type InfoResponse struct {
|
||||
HasAdminRole bool `json:"hasAdminRole"`
|
||||
}
|
||||
|
||||
type InfoStore interface {
|
||||
GetLatestSessionSaveDataSlot(uuid []byte) (int, error)
|
||||
}
|
||||
|
||||
// /account/info - get account info
|
||||
func Info(username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) {
|
||||
slot, _ := db.GetLatestSessionSaveDataSlot(uuid)
|
||||
func Info[T InfoStore](store T, username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) {
|
||||
slot, _ := store.GetLatestSessionSaveDataSlot(uuid)
|
||||
response := InfoResponse{
|
||||
Username: username,
|
||||
LastSessionSlot: slot,
|
||||
|
@ -24,14 +24,18 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
)
|
||||
|
||||
type LoginResponse GenericAuthResponse
|
||||
|
||||
// Interface for database operations needed for login.
|
||||
type LoginStore interface {
|
||||
FetchAccountKeySaltFromUsername(username string) (key, salt []byte, err error)
|
||||
AddAccountSession(username string, token []byte) error
|
||||
}
|
||||
|
||||
// /account/login - log into account
|
||||
func Login(username, password string) (LoginResponse, error) {
|
||||
func Login[T LoginStore](store T, username, password string) (LoginResponse, error) {
|
||||
var response LoginResponse
|
||||
|
||||
if !isValidUsername(username) {
|
||||
@ -42,12 +46,11 @@ func Login(username, password string) (LoginResponse, error) {
|
||||
return response, fmt.Errorf("invalid password")
|
||||
}
|
||||
|
||||
key, salt, err := db.FetchAccountKeySaltFromUsername(username)
|
||||
key, salt, err := store.FetchAccountKeySaltFromUsername(username)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return response, fmt.Errorf("account doesn't exist")
|
||||
}
|
||||
|
||||
return response, err
|
||||
}
|
||||
|
||||
@ -55,8 +58,7 @@ func Login(username, password string) (LoginResponse, error) {
|
||||
return response, fmt.Errorf("password doesn't match")
|
||||
}
|
||||
|
||||
response.Token, err = GenerateTokenForUsername(username)
|
||||
|
||||
response.Token, err = GenerateTokenForUsername(store, username)
|
||||
if err != nil {
|
||||
return response, fmt.Errorf("failed to generate token: %s", err)
|
||||
}
|
||||
@ -64,14 +66,19 @@ func Login(username, password string) (LoginResponse, error) {
|
||||
return response, nil
|
||||
}
|
||||
|
||||
func GenerateTokenForUsername(username string) (string, error) {
|
||||
type GenerateTokenForUsernameStore interface {
|
||||
AddAccountSession(username string, token []byte) error
|
||||
}
|
||||
|
||||
// GenerateTokenForUsername generates a session token and stores it using the provided DBAccountStore.
|
||||
func GenerateTokenForUsername[T GenerateTokenForUsernameStore](store T, username string) (string, error) {
|
||||
token := make([]byte, TokenSize)
|
||||
_, err := rand.Read(token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate token: %s", err)
|
||||
}
|
||||
|
||||
err = db.AddAccountSession(username, token)
|
||||
err = store.AddAccountSession(username, token)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to add account session")
|
||||
}
|
||||
|
154
api/account/login_test.go
Normal file
154
api/account/login_test.go
Normal file
@ -0,0 +1,154 @@
|
||||
package account
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func defaultMockStore() *mockDBAccountStore {
|
||||
return &mockDBAccountStore{
|
||||
FetchFunc: func(username string) ([]byte, []byte, error) {
|
||||
return []byte("key"), []byte("salt"), nil
|
||||
},
|
||||
AddSessionFunc: func(username string, token []byte) error { return nil },
|
||||
}
|
||||
}
|
||||
|
||||
type mockDBAccountStore struct {
|
||||
FetchFunc func(username string) ([]byte, []byte, error)
|
||||
AddSessionFunc func(username string, token []byte) error
|
||||
}
|
||||
|
||||
func (m *mockDBAccountStore) FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
|
||||
return m.FetchFunc(username)
|
||||
}
|
||||
func (m *mockDBAccountStore) AddAccountSession(username string, token []byte) error {
|
||||
return m.AddSessionFunc(username, token)
|
||||
}
|
||||
|
||||
func TestLogin(t *testing.T) {
|
||||
t.Run("UsernameMinLength", func(t *testing.T) {
|
||||
store := defaultMockStore()
|
||||
_, err := Login(store, "a", "password123")
|
||||
if err == nil {
|
||||
t.Errorf("expected error due to password mismatch or DB, got nil")
|
||||
}
|
||||
})
|
||||
t.Run("UsernameMaxLength", func(t *testing.T) {
|
||||
uname := "abcdefghijklmnop"
|
||||
store := defaultMockStore()
|
||||
_, err := Login(store, uname, "password123")
|
||||
if err == nil {
|
||||
t.Errorf("expected error due to password mismatch or DB, got nil")
|
||||
}
|
||||
})
|
||||
t.Run("UsernameTooLong", func(t *testing.T) {
|
||||
uname := "abcdefghijklmnopq"
|
||||
store := defaultMockStore()
|
||||
_, err := Login(store, uname, "password123")
|
||||
if err == nil || err.Error() != "invalid username" {
|
||||
t.Errorf("expected invalid username error for too long username, got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("UsernameWithInvalidChars", func(t *testing.T) {
|
||||
store := defaultMockStore()
|
||||
_, err := Login(store, "user!@#", "password123")
|
||||
if err == nil || err.Error() != "invalid username" {
|
||||
t.Errorf("expected invalid username error for special chars, got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("EmptyUsername", func(t *testing.T) {
|
||||
store := defaultMockStore()
|
||||
_, err := Login(store, "", "password123")
|
||||
if err == nil || err.Error() != "invalid username" {
|
||||
t.Errorf("expected invalid username error for empty username, got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("EmptyPassword", func(t *testing.T) {
|
||||
store := defaultMockStore()
|
||||
_, err := Login(store, "validuser", "")
|
||||
if err == nil || err.Error() != "invalid password" {
|
||||
t.Errorf("expected invalid password error for empty password, got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("MinPasswordLength", func(t *testing.T) {
|
||||
store := defaultMockStore()
|
||||
_, err := Login(store, "validuser", "123456")
|
||||
if err == nil {
|
||||
t.Errorf("expected error due to password mismatch or DB, got nil")
|
||||
}
|
||||
})
|
||||
t.Run("PasswordWithSpecialChars", func(t *testing.T) {
|
||||
store := defaultMockStore()
|
||||
_, err := Login(store, "validuser", "p@$$w0rd!")
|
||||
if err == nil {
|
||||
t.Errorf("expected error due to password mismatch or DB, got nil")
|
||||
}
|
||||
})
|
||||
t.Run("DBUnexpectedError", func(t *testing.T) {
|
||||
store := defaultMockStore()
|
||||
store.FetchFunc = func(username string) ([]byte, []byte, error) {
|
||||
return nil, nil, errors.New("some db error")
|
||||
}
|
||||
_, err := Login(store, "validuser", "password123")
|
||||
if err == nil || err.Error() != "some db error" {
|
||||
t.Errorf("expected DB error to propagate, got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("InvalidUsername", func(t *testing.T) {
|
||||
store := defaultMockStore()
|
||||
_, err := Login(store, "!invaliduser", "password123")
|
||||
if err == nil || err.Error() != "invalid username" {
|
||||
t.Errorf("expected invalid username error, got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("ShortPassword", func(t *testing.T) {
|
||||
store := defaultMockStore()
|
||||
_, err := Login(store, "validuser", "123")
|
||||
if err == nil || err.Error() != "invalid password" {
|
||||
t.Errorf("expected invalid password error, got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("AccountDoesNotExist", func(t *testing.T) {
|
||||
store := defaultMockStore()
|
||||
store.FetchFunc = func(username string) ([]byte, []byte, error) {
|
||||
return nil, nil, sql.ErrNoRows
|
||||
}
|
||||
_, err := Login(store, "nonexistent", "password123")
|
||||
if err == nil || err.Error() != "account doesn't exist" {
|
||||
t.Errorf("expected account doesn't exist error, got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("PasswordMismatch", func(t *testing.T) {
|
||||
correctSalt := []byte("somesalt")
|
||||
correctKey := []byte("correctkey")
|
||||
store := defaultMockStore()
|
||||
store.FetchFunc = func(username string) ([]byte, []byte, error) {
|
||||
return correctKey, correctSalt, nil
|
||||
}
|
||||
_, err := Login(store, "validuser", "wrongpassword")
|
||||
if err == nil || err.Error() != "password doesn't match" {
|
||||
t.Errorf("expected password doesn't match error, got: %v", err)
|
||||
}
|
||||
})
|
||||
t.Run("Success", func(t *testing.T) {
|
||||
correctSalt := []byte("somesalt")
|
||||
password := "goodpassword"
|
||||
correctKey := deriveArgon2IDKey([]byte(password), correctSalt)
|
||||
store := defaultMockStore()
|
||||
store.FetchFunc = func(username string) ([]byte, []byte, error) {
|
||||
return correctKey, correctSalt, nil
|
||||
}
|
||||
store.AddSessionFunc = func(username string, token []byte) error {
|
||||
return nil
|
||||
}
|
||||
resp, err := Login(store, "validuser", password)
|
||||
if err != nil {
|
||||
t.Errorf("expected success, got error: %v", err)
|
||||
}
|
||||
if resp.Token == "" {
|
||||
t.Errorf("expected token to be set on success")
|
||||
}
|
||||
})
|
||||
}
|
@ -21,13 +21,17 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
)
|
||||
|
||||
// /account/logout - log out of account
|
||||
func Logout(token []byte) error {
|
||||
err := db.RemoveSessionFromToken(token)
|
||||
|
||||
// Interface for database operations needed for logout.
|
||||
type LogoutStore interface {
|
||||
RemoveSessionFromToken(token []byte) error
|
||||
}
|
||||
|
||||
func Logout[T LogoutStore](store T, token []byte) error {
|
||||
err := store.RemoveSessionFromToken(token)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return fmt.Errorf("token not found")
|
||||
|
@ -20,12 +20,15 @@ package account
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
)
|
||||
|
||||
// Interface for database operations needed for registration.
|
||||
type RegisterStore interface {
|
||||
AddAccountRecord(uuid []byte, username string, passwordHash []byte, salt []byte) error
|
||||
}
|
||||
|
||||
// /account/register - register account
|
||||
func Register(username, password string) error {
|
||||
func Register[T RegisterStore](store T, username, password string) error {
|
||||
if !isValidUsername(username) {
|
||||
return fmt.Errorf("invalid username")
|
||||
}
|
||||
@ -46,7 +49,7 @@ func Register(username, password string) error {
|
||||
return fmt.Errorf("failed to generate salt: %s", err)
|
||||
}
|
||||
|
||||
err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
|
||||
err = store.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add account record: %s", err)
|
||||
}
|
||||
|
@ -30,7 +30,7 @@ import (
|
||||
)
|
||||
|
||||
func Init(mux *http.ServeMux) error {
|
||||
err := scheduleStatRefresh()
|
||||
err := scheduleStatRefresh(db.Store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -109,7 +109,7 @@ func tokenAndUuidFromRequest(r *http.Request) ([]byte, []byte, error) {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
uuid, err := db.FetchUUIDFromToken(token)
|
||||
uuid, err := db.Store.FetchUUIDFromToken(token)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("failed to validate token: %s", err)
|
||||
}
|
||||
|
@ -61,7 +61,7 @@ func Init() error {
|
||||
secret = newSecret
|
||||
}
|
||||
|
||||
seed, err := db.TryAddDailyRun(Seed())
|
||||
seed, err := db.Store.TryAddDailyRun(Seed())
|
||||
if err != nil {
|
||||
log.Print(err)
|
||||
}
|
||||
@ -71,7 +71,7 @@ func Init() error {
|
||||
_, err = scheduler.AddFunc("@daily", func() {
|
||||
time.Sleep(time.Second)
|
||||
|
||||
seed, err = db.TryAddDailyRun(Seed())
|
||||
seed, err = db.Store.TryAddDailyRun(Seed())
|
||||
if err != nil {
|
||||
log.Printf("error while recording new daily: %s", err)
|
||||
} else {
|
||||
|
@ -22,9 +22,14 @@ import (
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
)
|
||||
|
||||
// Interface for database operations needed for fetching rankings.
|
||||
type RankingsStore interface {
|
||||
FetchRankings(category, page int) ([]defs.DailyRanking, error)
|
||||
}
|
||||
|
||||
// /daily/rankings - fetch daily rankings
|
||||
func Rankings(category, page int) ([]defs.DailyRanking, error) {
|
||||
rankings, err := db.FetchRankings(category, page)
|
||||
func Rankings[T RankingsStore](store T, category, page int) ([]defs.DailyRanking, error) {
|
||||
rankings, err := db.Store.FetchRankings(category, page)
|
||||
if err != nil {
|
||||
return rankings, err
|
||||
}
|
||||
|
@ -17,13 +17,13 @@
|
||||
|
||||
package daily
|
||||
|
||||
import (
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
)
|
||||
type RankingPageCountStore interface {
|
||||
FetchRankingPageCount(category int) (int, error)
|
||||
}
|
||||
|
||||
// /daily/rankingpagecount - fetch daily ranking page count
|
||||
func RankingPageCount(category int) (int, error) {
|
||||
pageCount, err := db.FetchRankingPageCount(category)
|
||||
func RankingPageCount[T RankingPageCountStore](store T, category int) (int, error) {
|
||||
pageCount, err := store.FetchRankingPageCount(category)
|
||||
if err != nil {
|
||||
return pageCount, err
|
||||
}
|
||||
|
165
api/endpoints.go
165
api/endpoints.go
@ -50,17 +50,17 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
username, err := db.FetchUsernameFromUUID(uuid)
|
||||
username, err := db.Store.FetchUsernameFromUUID(uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
discordId, err := db.FetchDiscordIdByUsername(username)
|
||||
discordId, err := db.Store.FetchDiscordIdByUsername(username)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
googleId, err := db.FetchGoogleIdByUsername(username)
|
||||
googleId, err := db.Store.FetchGoogleIdByUsername(username)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -68,10 +68,10 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
var hasAdminRole bool
|
||||
if discordId != "" {
|
||||
hasAdminRole, _ = account.IsUserDiscordAdmin(discordId, account.DiscordGuildID)
|
||||
hasAdminRole, _ = account.Discord.IsUserDiscordAdmin(discordId, account.DiscordGuildID)
|
||||
}
|
||||
|
||||
response, err := account.Info(username, discordId, googleId, uuid, hasAdminRole)
|
||||
response, err := account.Info(db.Store, username, discordId, googleId, uuid, hasAdminRole)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -81,7 +81,7 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
|
||||
err := account.Register(r.PostFormValue("username"), r.PostFormValue("password"))
|
||||
err := account.Register(db.Store, r.PostFormValue("username"), r.PostFormValue("password"))
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -91,7 +91,7 @@ func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
|
||||
response, err := account.Login(r.PostFormValue("username"), r.PostFormValue("password"))
|
||||
response, err := account.Login(db.Store, r.PostFormValue("username"), r.PostFormValue("password"))
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -107,20 +107,20 @@ func handleAccountChangePW(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = account.ChangePW(uuid, r.PostFormValue("password"))
|
||||
err = account.ChangePW(db.Store, uuid, r.PostFormValue("password"))
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
username, err := db.FetchUsernameFromUUID(uuid)
|
||||
username, err := db.Store.FetchUsernameFromUUID(uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// create a new session with these credentials
|
||||
response, err := account.Login(username, r.Form.Get("password"))
|
||||
response, err := account.Login(db.Store, username, r.Form.Get("password"))
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -136,7 +136,7 @@ func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = account.Logout(token)
|
||||
err = account.Logout(db.Store, token)
|
||||
if err != nil {
|
||||
// also possible for InternalServerError but that's unlikely unless the server blew up
|
||||
httpError(w, r, err, http.StatusUnauthorized)
|
||||
@ -183,7 +183,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
|
||||
err = db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
@ -191,7 +191,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
switch r.PathValue("action") {
|
||||
case "get":
|
||||
save, err := savedata.GetSession(uuid, slot)
|
||||
save, err := savedata.GetSession(db.Store, uuid, slot)
|
||||
if err != nil {
|
||||
if errors.Is(err, savedata.ErrSaveNotExist) {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
@ -211,7 +211,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
existingSave, err := savedata.GetSession(uuid, slot)
|
||||
existingSave, err := savedata.GetSession(db.Store, uuid, slot)
|
||||
if err != nil {
|
||||
if !errors.Is(err, savedata.ErrSaveNotExist) {
|
||||
httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError)
|
||||
@ -224,7 +224,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
err = savedata.UpdateSession(uuid, slot, session)
|
||||
err = savedata.UpdateSession(db.Store, uuid, slot, session)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to put session data: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
@ -239,13 +239,13 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
seed, err := db.GetDailyRunSeed()
|
||||
seed, err := db.Store.GetDailyRunSeed()
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
resp, err := savedata.Clear(uuid, slot, seed, session)
|
||||
resp, err := savedata.Clear(db.Store, uuid, slot, seed, session)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -253,7 +253,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
writeJSON(w, r, resp)
|
||||
case "newclear":
|
||||
resp, err := savedata.NewClear(uuid, slot)
|
||||
resp, err := savedata.NewClear(db.Store, uuid, slot)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to read new clear: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
@ -261,7 +261,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
writeJSON(w, r, resp)
|
||||
case "delete":
|
||||
err := savedata.DeleteSession(uuid, slot)
|
||||
err := savedata.DeleteSession(db.Store, uuid, slot)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -301,7 +301,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
active, err := db.IsActiveSession(uuid, data.ClientSessionId)
|
||||
active, err := db.Store.IsActiveSession(uuid, data.ClientSessionId)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
@ -312,7 +312,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
storedTrainerId, storedSecretId, err := db.FetchTrainerIds(uuid)
|
||||
storedTrainerId, storedSecretId, err := db.Store.FetchTrainerIds(uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -324,14 +324,14 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
err = db.UpdateTrainerIds(data.System.TrainerId, data.System.SecretId, uuid)
|
||||
err = db.Store.UpdateTrainerIds(data.System.TrainerId, data.System.SecretId, uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
oldSystem, err := savedata.GetSystem(uuid)
|
||||
oldSystem, err := savedata.GetSystem(db.Store, uuid)
|
||||
if err != nil {
|
||||
if !errors.Is(err, savedata.ErrSaveNotExist) {
|
||||
httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError)
|
||||
@ -356,7 +356,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
existingSave, err := savedata.GetSession(uuid, data.SessionSlotId)
|
||||
existingSave, err := savedata.GetSession(db.Store, uuid, data.SessionSlotId)
|
||||
if err != nil {
|
||||
if !errors.Is(err, savedata.ErrSaveNotExist) {
|
||||
httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError)
|
||||
@ -369,13 +369,13 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
err = savedata.Update(uuid, data.SessionSlotId, data.Session)
|
||||
err = savedata.Update(db.Store, uuid, data.SessionSlotId, data.Session)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = savedata.Update(uuid, 0, data.System)
|
||||
err = savedata.Update(db.Store, uuid, 0, data.System)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -401,7 +401,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
active, err := db.IsActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
|
||||
active, err := db.Store.IsActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
@ -410,14 +410,14 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.PathValue("action") {
|
||||
case "get":
|
||||
if !active {
|
||||
err = db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
|
||||
err = db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
save, err := savedata.GetSystem(uuid)
|
||||
save, err := savedata.GetSystem(db.Store, uuid)
|
||||
if err != nil {
|
||||
if errors.Is(err, savedata.ErrSaveNotExist) {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
@ -442,7 +442,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
oldSystem, err := savedata.GetSystem(uuid)
|
||||
oldSystem, err := savedata.GetSystem(db.Store, uuid)
|
||||
if err != nil {
|
||||
if !errors.Is(err, savedata.ErrSaveNotExist) {
|
||||
httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError)
|
||||
@ -467,7 +467,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
err = savedata.UpdateSystem(uuid, system)
|
||||
err = savedata.UpdateSystem(db.Store, uuid, system)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to put system data: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
@ -481,13 +481,13 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// not valid, send server state
|
||||
if !active {
|
||||
err := db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
|
||||
err := db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
storedSaveData, err := db.ReadSystemSaveData(uuid)
|
||||
storedSaveData, err := db.Store.ReadSystemSaveData(uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("failed to read session save data: %s", err), http.StatusInternalServerError)
|
||||
return
|
||||
@ -498,7 +498,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
writeJSON(w, r, response)
|
||||
case "delete":
|
||||
err := savedata.DeleteSystem(uuid)
|
||||
err := savedata.DeleteSystem(db.Store, uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -511,9 +511,14 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
// Interface providing database operations needed for getting daily seed.
|
||||
type HandleDailySeedStore interface {
|
||||
GetDailyRunSeed() (string, error)
|
||||
}
|
||||
|
||||
// daily
|
||||
func handleDailySeed(w http.ResponseWriter, r *http.Request) {
|
||||
seed, err := db.GetDailyRunSeed()
|
||||
seed, err := db.Store.GetDailyRunSeed()
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -522,6 +527,11 @@ func handleDailySeed(w http.ResponseWriter, r *http.Request) {
|
||||
fmt.Fprint(w, seed)
|
||||
}
|
||||
|
||||
// Interface for database operations needed for getting daily rankings.
|
||||
type HandleDailyRankingsStore interface {
|
||||
daily.RankingsStore
|
||||
}
|
||||
|
||||
func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
|
||||
@ -543,7 +553,7 @@ func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
rankings, err := daily.Rankings(category, page)
|
||||
rankings, err := daily.Rankings(db.Store, category, page)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -552,6 +562,10 @@ func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
|
||||
writeJSON(w, r, rankings)
|
||||
}
|
||||
|
||||
type HandleDailyRankingsPageCountStore interface {
|
||||
daily.RankingPageCountStore
|
||||
}
|
||||
|
||||
func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) {
|
||||
var category int
|
||||
if r.URL.Query().Has("category") {
|
||||
@ -563,7 +577,7 @@ func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
}
|
||||
|
||||
count, err := daily.RankingPageCount(category)
|
||||
count, err := daily.RankingPageCount(db.Store, category)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
}
|
||||
@ -579,9 +593,9 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
|
||||
var err error
|
||||
switch provider {
|
||||
case "discord":
|
||||
externalAuthId, err = account.HandleDiscordCallback(w, r)
|
||||
externalAuthId, err = account.Discord.HandleDiscordCallback(w, r)
|
||||
case "google":
|
||||
externalAuthId, err = account.HandleGoogleCallback(w, r)
|
||||
externalAuthId, err = account.Google.HandleGoogleCallback(w, r)
|
||||
default:
|
||||
http.Error(w, "invalid provider", http.StatusBadRequest)
|
||||
return
|
||||
@ -600,7 +614,7 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userName, err := db.FetchUsernameBySessionToken(stateByte)
|
||||
userName, err := db.Store.FetchUsernameBySessionToken(stateByte)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
|
||||
return
|
||||
@ -608,9 +622,9 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
switch provider {
|
||||
case "discord":
|
||||
err = db.AddDiscordIdByUsername(externalAuthId, userName)
|
||||
err = db.Store.AddDiscordIdByUsername(externalAuthId, userName)
|
||||
case "google":
|
||||
err = db.AddGoogleIdByUsername(externalAuthId, userName)
|
||||
err = db.Store.AddGoogleIdByUsername(externalAuthId, userName)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@ -622,16 +636,16 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
|
||||
var userName string
|
||||
switch provider {
|
||||
case "discord":
|
||||
userName, err = db.FetchUsernameByDiscordId(externalAuthId)
|
||||
userName, err = db.Store.FetchUsernameByDiscordId(externalAuthId)
|
||||
case "google":
|
||||
userName, err = db.FetchUsernameByGoogleId(externalAuthId)
|
||||
userName, err = db.Store.FetchUsernameByGoogleId(externalAuthId)
|
||||
}
|
||||
if err != nil {
|
||||
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
|
||||
return
|
||||
}
|
||||
|
||||
sessionToken, err := account.GenerateTokenForUsername(userName)
|
||||
sessionToken, err := account.GenerateTokenForUsername(db.Store, userName)
|
||||
if err != nil {
|
||||
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
|
||||
return
|
||||
@ -651,6 +665,11 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
|
||||
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
|
||||
}
|
||||
|
||||
type HandleProviderLogoutStore interface {
|
||||
RemoveDiscordIdByUUID(uuid []byte) error
|
||||
RemoveGoogleIdByUUID(uuid []byte) error
|
||||
}
|
||||
|
||||
func handleProviderLogout(w http.ResponseWriter, r *http.Request) {
|
||||
uuid, err := uuidFromRequest(r)
|
||||
if err != nil {
|
||||
@ -660,9 +679,9 @@ func handleProviderLogout(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
switch r.PathValue("provider") {
|
||||
case "discord":
|
||||
err = db.RemoveDiscordIdByUUID(uuid)
|
||||
err = db.Store.RemoveDiscordIdByUUID(uuid)
|
||||
case "google":
|
||||
err = db.RemoveGoogleIdByUUID(uuid)
|
||||
err = db.Store.RemoveGoogleIdByUUID(uuid)
|
||||
default:
|
||||
http.Error(w, "invalid provider", http.StatusBadRequest)
|
||||
return
|
||||
@ -681,13 +700,13 @@ func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
|
||||
userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
if !hasRole || err != nil {
|
||||
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
|
||||
return
|
||||
@ -698,19 +717,19 @@ func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
|
||||
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
|
||||
_, err = db.CheckUsernameExists(username)
|
||||
_, err = db.Store.CheckUsernameExists(username)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
userUuid, err := db.FetchUUIDFromUsername(username)
|
||||
userUuid, err := db.Store.FetchUUIDFromUsername(username)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = db.AddDiscordIdByUUID(discordId, userUuid)
|
||||
err = db.Store.AddDiscordIdByUUID(discordId, userUuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -728,13 +747,13 @@ func handleAdminDiscordUnlink(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
|
||||
userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
if !hasRole || err != nil {
|
||||
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
|
||||
return
|
||||
@ -748,26 +767,26 @@ func handleAdminDiscordUnlink(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("Username given, removing discordId")
|
||||
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
|
||||
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
|
||||
_, err = db.CheckUsernameExists(username)
|
||||
_, err = db.Store.CheckUsernameExists(username)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
userUuid, err := db.FetchUUIDFromUsername(username)
|
||||
userUuid, err := db.Store.FetchUUIDFromUsername(username)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = db.RemoveDiscordIdByUUID(userUuid)
|
||||
err = db.Store.RemoveDiscordIdByUUID(userUuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case discordId != "":
|
||||
log.Printf("DiscordID given, removing discordId")
|
||||
err = db.RemoveDiscordIdByDiscordId(discordId)
|
||||
err = db.Store.RemoveDiscordIdByDiscordId(discordId)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -786,13 +805,13 @@ func handleAdminGoogleLink(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
|
||||
userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
if !hasRole || err != nil {
|
||||
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
|
||||
return
|
||||
@ -803,19 +822,19 @@ func handleAdminGoogleLink(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
|
||||
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
|
||||
_, err = db.CheckUsernameExists(username)
|
||||
_, err = db.Store.CheckUsernameExists(username)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
userUuid, err := db.FetchUUIDFromUsername(username)
|
||||
userUuid, err := db.Store.FetchUUIDFromUsername(username)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = db.AddGoogleIdByUUID(googleId, userUuid)
|
||||
err = db.Store.AddGoogleIdByUUID(googleId, userUuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -833,13 +852,13 @@ func handleAdminGoogleUnlink(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
|
||||
userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
if !hasRole || err != nil {
|
||||
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
|
||||
return
|
||||
@ -853,26 +872,26 @@ func handleAdminGoogleUnlink(w http.ResponseWriter, r *http.Request) {
|
||||
log.Printf("Username given, removing googleId")
|
||||
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
|
||||
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
|
||||
_, err = db.CheckUsernameExists(username)
|
||||
_, err = db.Store.CheckUsernameExists(username)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
userUuid, err := db.FetchUUIDFromUsername(username)
|
||||
userUuid, err := db.Store.FetchUUIDFromUsername(username)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
err = db.RemoveGoogleIdByUUID(userUuid)
|
||||
err = db.Store.RemoveGoogleIdByUUID(userUuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
case googleId != "":
|
||||
log.Printf("DiscordID given, removing googleId")
|
||||
err = db.RemoveGoogleIdByDiscordId(googleId)
|
||||
err = db.Store.RemoveGoogleIdByDiscordId(googleId)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
@ -891,13 +910,13 @@ func handleAdminSearch(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
|
||||
userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
|
||||
if !hasRole || err != nil {
|
||||
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
|
||||
return
|
||||
@ -907,14 +926,14 @@ func handleAdminSearch(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
|
||||
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
|
||||
_, err = db.CheckUsernameExists(username)
|
||||
_, err = db.Store.CheckUsernameExists(username)
|
||||
if err != nil {
|
||||
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
// this does a single call that does a query for multiple columns from our database and makes an object out of it, which is returned to us
|
||||
adminSearchResult, err := db.FetchAdminDetailsByUsername(username)
|
||||
adminSearchResult, err := db.Store.FetchAdminDetailsByUsername(username)
|
||||
if err != nil {
|
||||
httpError(w, r, err, http.StatusInternalServerError)
|
||||
return
|
||||
|
@ -21,7 +21,6 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
)
|
||||
|
||||
@ -30,10 +29,20 @@ type ClearResponse struct {
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
// Interface for database operations needed for `Clear`
|
||||
// Helps with testing and reduces coupling.
|
||||
type ClearStore interface {
|
||||
UpdateAccountLastActivity(uuid []byte) error
|
||||
TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error)
|
||||
DeleteSessionSaveData(uuid []byte, slot int) error
|
||||
AddOrUpdateAccountDailyRun(uuid []byte, score int, waveCompleted int) error
|
||||
SetAccountBanned(uuid []byte, banned bool) error
|
||||
}
|
||||
|
||||
// /savedata/clear - mark session save data as cleared and delete
|
||||
func Clear(uuid []byte, slot int, seed string, save defs.SessionSaveData) (ClearResponse, error) {
|
||||
func Clear[T ClearStore](store T, uuid []byte, slot int, seed string, save defs.SessionSaveData) (ClearResponse, error) {
|
||||
var response ClearResponse
|
||||
err := db.UpdateAccountLastActivity(uuid)
|
||||
err := store.UpdateAccountLastActivity(uuid)
|
||||
if err != nil {
|
||||
log.Print("failed to update account last activity")
|
||||
}
|
||||
@ -51,23 +60,23 @@ func Clear(uuid []byte, slot int, seed string, save defs.SessionSaveData) (Clear
|
||||
}
|
||||
|
||||
if save.Score >= 20000 {
|
||||
db.SetAccountBanned(uuid, true)
|
||||
store.SetAccountBanned(uuid, true)
|
||||
}
|
||||
|
||||
err = db.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted)
|
||||
err = store.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted)
|
||||
if err != nil {
|
||||
log.Printf("failed to add or update daily run record: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
if sessionCompleted {
|
||||
response.Success, err = db.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode))
|
||||
response.Success, err = store.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode))
|
||||
if err != nil {
|
||||
log.Printf("failed to mark seed as completed: %s", err)
|
||||
}
|
||||
}
|
||||
|
||||
err = db.DeleteSessionSaveData(uuid, slot)
|
||||
err = store.DeleteSessionSaveData(uuid, slot)
|
||||
if err != nil {
|
||||
log.Printf("failed to delete session save data: %s", err)
|
||||
}
|
||||
|
@ -21,13 +21,19 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
)
|
||||
|
||||
// Interface for database operations needed for delete.
|
||||
// This is to allow easier testing and less coupling.
|
||||
type DeleteStore interface {
|
||||
UpdateAccountLastActivity(uuid []byte) error
|
||||
DeleteSessionSaveData(uuid []byte, slot int) error
|
||||
}
|
||||
|
||||
// /savedata/delete - delete save data
|
||||
func Delete(uuid []byte, datatype, slot int) error {
|
||||
err := db.UpdateAccountLastActivity(uuid)
|
||||
func Delete[T DeleteStore](store T, uuid []byte, datatype, slot int) error {
|
||||
err := store.UpdateAccountLastActivity(uuid)
|
||||
if err != nil {
|
||||
log.Print("failed to update account last activity")
|
||||
}
|
||||
@ -39,7 +45,7 @@ func Delete(uuid []byte, datatype, slot int) error {
|
||||
break
|
||||
}
|
||||
|
||||
err = db.DeleteSessionSaveData(uuid, slot)
|
||||
err = store.DeleteSessionSaveData(uuid, slot)
|
||||
default:
|
||||
err = fmt.Errorf("invalid data type")
|
||||
}
|
||||
|
@ -20,22 +20,26 @@ package savedata
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
)
|
||||
|
||||
type NewClearStore interface {
|
||||
ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error)
|
||||
ReadSeedCompleted(uuid []byte, seed string) (bool, error)
|
||||
}
|
||||
|
||||
// /savedata/newclear - return whether a session is a new clear for its seed
|
||||
func NewClear(uuid []byte, slot int) (bool, error) {
|
||||
func NewClear[T NewClearStore](store T, uuid []byte, slot int) (bool, error) {
|
||||
if slot < 0 || slot >= defs.SessionSlotCount {
|
||||
return false, fmt.Errorf("slot id %d out of range", slot)
|
||||
}
|
||||
|
||||
session, err := db.ReadSessionSaveData(uuid, slot)
|
||||
session, err := store.ReadSessionSaveData(uuid, slot)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
completed, err := db.ReadSeedCompleted(uuid, session.Seed)
|
||||
completed, err := store.ReadSeedCompleted(uuid, session.Seed)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to read seed completed: %s", err)
|
||||
}
|
||||
|
@ -21,12 +21,15 @@ import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
)
|
||||
|
||||
func GetSession(uuid []byte, slot int) (defs.SessionSaveData, error) {
|
||||
session, err := db.ReadSessionSaveData(uuid, slot)
|
||||
type GetSessionStore interface {
|
||||
ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error)
|
||||
}
|
||||
|
||||
func GetSession[T GetSessionStore](store T, uuid []byte, slot int) (defs.SessionSaveData, error) {
|
||||
session, err := store.ReadSessionSaveData(uuid, slot)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = ErrSaveNotExist
|
||||
@ -38,8 +41,13 @@ func GetSession(uuid []byte, slot int) (defs.SessionSaveData, error) {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func UpdateSession(uuid []byte, slot int, data defs.SessionSaveData) error {
|
||||
err := db.StoreSessionSaveData(uuid, data, slot)
|
||||
type UpdateSessionStore interface {
|
||||
StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error
|
||||
DeleteSessionSaveData(uuid []byte, slot int) error
|
||||
}
|
||||
|
||||
func UpdateSession[T UpdateSessionStore](store T, uuid []byte, slot int, data defs.SessionSaveData) error {
|
||||
err := store.StoreSessionSaveData(uuid, data, slot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -47,8 +55,12 @@ func UpdateSession(uuid []byte, slot int, data defs.SessionSaveData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteSession(uuid []byte, slot int) error {
|
||||
err := db.DeleteSessionSaveData(uuid, slot)
|
||||
type DeleteSessionStore interface {
|
||||
DeleteSessionSaveData(uuid []byte, slot int) error
|
||||
}
|
||||
|
||||
func DeleteSession[T DeleteSessionStore](store T, uuid []byte, slot int) error {
|
||||
err := store.DeleteSessionSaveData(uuid, slot)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -24,24 +24,28 @@ import (
|
||||
"os"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3/types"
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
)
|
||||
|
||||
var ErrSaveNotExist = errors.New("save does not exist")
|
||||
|
||||
func GetSystem(uuid []byte) (defs.SystemSaveData, error) {
|
||||
type GetSystemStore interface {
|
||||
GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error)
|
||||
ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error)
|
||||
}
|
||||
|
||||
func GetSystem[T GetSystemStore](store T, uuid []byte) (defs.SystemSaveData, error) {
|
||||
var system defs.SystemSaveData
|
||||
var err error
|
||||
|
||||
if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3
|
||||
system, err = db.GetSystemSaveFromS3(uuid)
|
||||
system, err = store.GetSystemSaveFromS3(uuid)
|
||||
var nokey *types.NoSuchKey
|
||||
if errors.As(err, &nokey) {
|
||||
err = ErrSaveNotExist
|
||||
}
|
||||
} else { // use database
|
||||
system, err = db.ReadSystemSaveData(uuid)
|
||||
system, err = store.ReadSystemSaveData(uuid)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = ErrSaveNotExist
|
||||
}
|
||||
@ -53,20 +57,27 @@ func GetSystem(uuid []byte) (defs.SystemSaveData, error) {
|
||||
return system, nil
|
||||
}
|
||||
|
||||
func UpdateSystem(uuid []byte, data defs.SystemSaveData) error {
|
||||
// Interface for database operations needed for updating system data.
|
||||
type UpdateSystemStore interface {
|
||||
UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error
|
||||
StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error
|
||||
StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error
|
||||
}
|
||||
|
||||
func UpdateSystem[T UpdateSystemStore](store T, uuid []byte, data defs.SystemSaveData) error {
|
||||
if data.TrainerId == 0 && data.SecretId == 0 {
|
||||
return fmt.Errorf("invalid system data")
|
||||
}
|
||||
|
||||
err := db.UpdateAccountStats(uuid, data.GameStats, data.VoucherCounts)
|
||||
err := store.UpdateAccountStats(uuid, data.GameStats, data.VoucherCounts)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update account stats: %s", err)
|
||||
}
|
||||
|
||||
if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3
|
||||
err = db.StoreSystemSaveDataS3(uuid, data)
|
||||
err = store.StoreSystemSaveDataS3(uuid, data)
|
||||
} else {
|
||||
err = db.StoreSystemSaveData(uuid, data)
|
||||
err = store.StoreSystemSaveData(uuid, data)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
@ -75,8 +86,12 @@ func UpdateSystem(uuid []byte, data defs.SystemSaveData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteSystem(uuid []byte) error {
|
||||
err := db.DeleteSystemSaveData(uuid)
|
||||
type DeleteSystemStore interface {
|
||||
DeleteSystemSaveData(uuid []byte) error
|
||||
}
|
||||
|
||||
func DeleteSystem[T DeleteSystemStore](store T, uuid []byte) error {
|
||||
err := store.DeleteSystemSaveData(uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -21,13 +21,18 @@ import (
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
)
|
||||
|
||||
type UpdateStore interface {
|
||||
UpdateAccountLastActivity(uuid []byte) error
|
||||
UpdateSystemStore
|
||||
UpdateSessionStore
|
||||
}
|
||||
|
||||
// /savedata/update - update save data
|
||||
func Update(uuid []byte, slot int, save any) error {
|
||||
err := db.UpdateAccountLastActivity(uuid)
|
||||
func Update[T UpdateStore](store T, uuid []byte, slot int, save any) error {
|
||||
err := store.UpdateAccountLastActivity(uuid)
|
||||
if err != nil {
|
||||
log.Print("failed to update account last activity")
|
||||
}
|
||||
@ -38,13 +43,13 @@ func Update(uuid []byte, slot int, save any) error {
|
||||
return fmt.Errorf("invalid system data")
|
||||
}
|
||||
|
||||
return UpdateSystem(uuid, save)
|
||||
return UpdateSystem(store, uuid, save)
|
||||
case defs.SessionSaveData: // Session
|
||||
if slot < 0 || slot >= defs.SessionSlotCount {
|
||||
return fmt.Errorf("slot id %d out of range", slot)
|
||||
}
|
||||
|
||||
return UpdateSession(uuid, slot, save)
|
||||
return UpdateSession(store, uuid, slot, save)
|
||||
default:
|
||||
return fmt.Errorf("invalid data type")
|
||||
}
|
||||
|
19
api/stats.go
19
api/stats.go
@ -21,7 +21,6 @@ import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"github.com/pagefaultgames/rogueserver/db"
|
||||
"github.com/robfig/cron/v3"
|
||||
)
|
||||
|
||||
@ -32,9 +31,9 @@ var (
|
||||
classicSessionCount int
|
||||
)
|
||||
|
||||
func scheduleStatRefresh() error {
|
||||
func scheduleStatRefresh[T updateStatsStore](store T) error {
|
||||
_, err := scheduler.AddFunc("@every 30s", func() {
|
||||
err := updateStats()
|
||||
err := updateStats(store)
|
||||
if err != nil {
|
||||
log.Printf("failed to update stats: %s", err)
|
||||
}
|
||||
@ -47,19 +46,25 @@ func scheduleStatRefresh() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateStats() error {
|
||||
type updateStatsStore interface {
|
||||
FetchPlayerCount() (int, error)
|
||||
FetchBattleCount() (int, error)
|
||||
FetchClassicSessionCount() (int, error)
|
||||
}
|
||||
|
||||
func updateStats[T updateStatsStore](store T) error {
|
||||
var err error
|
||||
playerCount, err = db.FetchPlayerCount()
|
||||
playerCount, err = store.FetchPlayerCount()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
battleCount, err = db.FetchBattleCount()
|
||||
battleCount, err = store.FetchBattleCount()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
classicSessionCount, err = db.FetchClassicSessionCount()
|
||||
classicSessionCount, err = store.FetchClassicSessionCount()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
123
db/account.go
123
db/account.go
@ -1,18 +1,18 @@
|
||||
/*
|
||||
Copyright (C) 2024 - 2025 Pagefault Games
|
||||
Copyright (C) 2024 - 2025 Pagefault Games
|
||||
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
This program is free software: you can redistribute it and/or modify
|
||||
it under the terms of the GNU Affero General Public License as published by
|
||||
the Free Software Foundation, either version 3 of the License, or
|
||||
(at your option) any later version.
|
||||
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
This program is distributed in the hope that it will be useful,
|
||||
but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
GNU Affero General Public License for more details.
|
||||
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
You should have received a copy of the GNU Affero General Public License
|
||||
along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
*/
|
||||
|
||||
package db
|
||||
@ -27,30 +27,30 @@ import (
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
)
|
||||
|
||||
func AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
|
||||
_, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddAccountSession(username string, token []byte) error {
|
||||
func (s *store) AddAccountSession(username string, token []byte) error {
|
||||
_, err := handle.Exec("INSERT INTO sessions (uuid, token, expire) SELECT a.uuid, ?, DATE_ADD(UTC_TIMESTAMP(), INTERVAL 1 WEEK) FROM accounts a WHERE a.username = ?", token, username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = handle.Exec("UPDATE accounts SET lastLoggedIn = UTC_TIMESTAMP() WHERE username = ?", username)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddDiscordIdByUsername(discordId string, username string) error {
|
||||
// (removed, now a method on store)
|
||||
func (s *store) FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
|
||||
var key, salt []byte
|
||||
err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return key, salt, nil
|
||||
}
|
||||
|
||||
func (s *store) AddDiscordIdByUsername(discordId string, username string) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE username = ?", discordId, username)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -59,7 +59,7 @@ func AddDiscordIdByUsername(discordId string, username string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddGoogleIdByUsername(googleId string, username string) error {
|
||||
func (s *store) AddGoogleIdByUsername(googleId string, username string) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE username = ?", googleId, username)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -68,7 +68,7 @@ func AddGoogleIdByUsername(googleId string, username string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddGoogleIdByUUID(googleId string, uuid []byte) error {
|
||||
func (s *store) AddGoogleIdByUUID(googleId string, uuid []byte) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE uuid = ?", googleId, uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -77,7 +77,7 @@ func AddGoogleIdByUUID(googleId string, uuid []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func AddDiscordIdByUUID(discordId string, uuid []byte) error {
|
||||
func (s *store) AddDiscordIdByUUID(discordId string, uuid []byte) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE uuid = ?", discordId, uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -86,7 +86,7 @@ func AddDiscordIdByUUID(discordId string, uuid []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchUsernameByDiscordId(discordId string) (string, error) {
|
||||
func (s *store) FetchUsernameByDiscordId(discordId string) (string, error) {
|
||||
var username string
|
||||
err := handle.QueryRow("SELECT username FROM accounts WHERE discordId = ?", discordId).Scan(&username)
|
||||
if err != nil {
|
||||
@ -96,7 +96,7 @@ func FetchUsernameByDiscordId(discordId string) (string, error) {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func FetchUsernameByGoogleId(googleId string) (string, error) {
|
||||
func (s *store) FetchUsernameByGoogleId(googleId string) (string, error) {
|
||||
var username string
|
||||
err := handle.QueryRow("SELECT username FROM accounts WHERE googleId = ?", googleId).Scan(&username)
|
||||
if err != nil {
|
||||
@ -106,7 +106,7 @@ func FetchUsernameByGoogleId(googleId string) (string, error) {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func FetchDiscordIdByUsername(username string) (string, error) {
|
||||
func (s *store) FetchDiscordIdByUsername(username string) (string, error) {
|
||||
var discordId sql.NullString
|
||||
err := handle.QueryRow("SELECT discordId FROM accounts WHERE username = ?", username).Scan(&discordId)
|
||||
if err != nil {
|
||||
@ -120,7 +120,7 @@ func FetchDiscordIdByUsername(username string) (string, error) {
|
||||
return discordId.String, nil
|
||||
}
|
||||
|
||||
func FetchGoogleIdByUsername(username string) (string, error) {
|
||||
func (s *store) FetchGoogleIdByUsername(username string) (string, error) {
|
||||
var googleId sql.NullString
|
||||
err := handle.QueryRow("SELECT googleId FROM accounts WHERE username = ?", username).Scan(&googleId)
|
||||
if err != nil {
|
||||
@ -134,7 +134,7 @@ func FetchGoogleIdByUsername(username string) (string, error) {
|
||||
return googleId.String, nil
|
||||
}
|
||||
|
||||
func FetchDiscordIdByUUID(uuid []byte) (string, error) {
|
||||
func (s *store) FetchDiscordIdByUUID(uuid []byte) (string, error) {
|
||||
var discordId sql.NullString
|
||||
err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId)
|
||||
if err != nil {
|
||||
@ -148,7 +148,7 @@ func FetchDiscordIdByUUID(uuid []byte) (string, error) {
|
||||
return discordId.String, nil
|
||||
}
|
||||
|
||||
func FetchGoogleIdByUUID(uuid []byte) (string, error) {
|
||||
func (s *store) FetchGoogleIdByUUID(uuid []byte) (string, error) {
|
||||
var googleId sql.NullString
|
||||
err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId)
|
||||
if err != nil {
|
||||
@ -162,7 +162,7 @@ func FetchGoogleIdByUUID(uuid []byte) (string, error) {
|
||||
return googleId.String, nil
|
||||
}
|
||||
|
||||
func FetchUsernameBySessionToken(token []byte) (string, error) {
|
||||
func (s *store) FetchUsernameBySessionToken(token []byte) (string, error) {
|
||||
var username string
|
||||
err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username)
|
||||
if err != nil {
|
||||
@ -172,7 +172,7 @@ func FetchUsernameBySessionToken(token []byte) (string, error) {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func CheckUsernameExists(username string) (string, error) {
|
||||
func (s *store) CheckUsernameExists(username string) (string, error) {
|
||||
var dbUsername sql.NullString
|
||||
err := handle.QueryRow("SELECT username FROM accounts WHERE username = ?", username).Scan(&dbUsername)
|
||||
if err != nil {
|
||||
@ -185,7 +185,7 @@ func CheckUsernameExists(username string) (string, error) {
|
||||
return dbUsername.String, nil
|
||||
}
|
||||
|
||||
func FetchLastLoggedInDateByUsername(username string) (string, error) {
|
||||
func (s *store) FetchLastLoggedInDateByUsername(username string) (string, error) {
|
||||
var lastLoggedIn sql.NullString
|
||||
err := handle.QueryRow("SELECT lastLoggedIn FROM accounts WHERE username = ?", username).Scan(&lastLoggedIn)
|
||||
if err != nil {
|
||||
@ -206,7 +206,7 @@ type AdminSearchResponse struct {
|
||||
Registered string `json:"registered"`
|
||||
}
|
||||
|
||||
func FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) {
|
||||
func (s *store) FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) {
|
||||
var username, discordId, googleId, lastActivity, registered sql.NullString
|
||||
var adminResponse AdminSearchResponse
|
||||
|
||||
@ -226,7 +226,7 @@ func FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error)
|
||||
return adminResponse, nil
|
||||
}
|
||||
|
||||
func UpdateAccountPassword(uuid, key, salt []byte) error {
|
||||
func (s *store) UpdateAccountPassword(uuid, key, salt []byte) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET hash = ?, salt = ? WHERE uuid = ?", key, salt, uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -235,7 +235,7 @@ func UpdateAccountPassword(uuid, key, salt []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateAccountLastActivity(uuid []byte) error {
|
||||
func (s *store) UpdateAccountLastActivity(uuid []byte) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET lastActivity = UTC_TIMESTAMP() WHERE uuid = ?", uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -244,7 +244,7 @@ func UpdateAccountLastActivity(uuid []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error {
|
||||
func (s *store) UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error {
|
||||
var columns = []string{"playTime", "battles", "classicSessionsPlayed", "sessionsWon", "highestEndlessWave", "highestLevel", "pokemonSeen", "pokemonDefeated", "pokemonCaught", "pokemonHatched", "eggsPulled", "regularVouchers", "plusVouchers", "premiumVouchers", "goldenVouchers"}
|
||||
|
||||
var statCols []string
|
||||
@ -321,7 +321,7 @@ func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[str
|
||||
return nil
|
||||
}
|
||||
|
||||
func SetAccountBanned(uuid []byte, banned bool) error {
|
||||
func (s *store) SetAccountBanned(uuid []byte, banned bool) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET banned = ? WHERE uuid = ?", banned, uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -330,17 +330,16 @@ func SetAccountBanned(uuid []byte, banned bool) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
|
||||
var key, salt []byte
|
||||
err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt)
|
||||
func (s *store) AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
|
||||
_, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
return key, salt, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
|
||||
func (s *store) FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
|
||||
err = handle.QueryRow("SELECT trainerId, secretId FROM accounts WHERE uuid = ?", uuid).Scan(&trainerId, &secretId)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
@ -349,7 +348,7 @@ func FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
|
||||
return trainerId, secretId, nil
|
||||
}
|
||||
|
||||
func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
|
||||
func (s *store) UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET trainerId = ?, secretId = ? WHERE uuid = ?", trainerId, secretId, uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -358,12 +357,12 @@ func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func IsActiveSession(uuid []byte, sessionId string) (bool, error) {
|
||||
func (s *store) IsActiveSession(uuid []byte, sessionId string) (bool, error) {
|
||||
var id string
|
||||
err := handle.QueryRow("SELECT clientSessionId FROM activeClientSessions WHERE uuid = ?", uuid).Scan(&id)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = UpdateActiveSession(uuid, sessionId)
|
||||
err = s.UpdateActiveSession(uuid, sessionId)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@ -377,7 +376,7 @@ func IsActiveSession(uuid []byte, sessionId string) (bool, error) {
|
||||
return id == "" || id == sessionId, nil
|
||||
}
|
||||
|
||||
func UpdateActiveSession(uuid []byte, clientSessionId string) error {
|
||||
func (s *store) UpdateActiveSession(uuid []byte, clientSessionId string) error {
|
||||
_, err := handle.Exec("INSERT INTO activeClientSessions (uuid, clientSessionId) VALUES (?, ?) ON DUPLICATE KEY UPDATE clientSessionId = ?", uuid, clientSessionId, clientSessionId)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -386,7 +385,7 @@ func UpdateActiveSession(uuid []byte, clientSessionId string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchUUIDFromToken(token []byte) ([]byte, error) {
|
||||
func (s *store) FetchUUIDFromToken(token []byte) ([]byte, error) {
|
||||
var uuid []byte
|
||||
err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ?", token).Scan(&uuid)
|
||||
if err != nil {
|
||||
@ -396,7 +395,7 @@ func FetchUUIDFromToken(token []byte) ([]byte, error) {
|
||||
return uuid, nil
|
||||
}
|
||||
|
||||
func RemoveSessionFromToken(token []byte) error {
|
||||
func (s *store) RemoveSessionFromToken(token []byte) error {
|
||||
_, err := handle.Exec("DELETE FROM sessions WHERE token = ?", token)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -405,7 +404,7 @@ func RemoveSessionFromToken(token []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveSessionsFromUUID(uuid []byte) error {
|
||||
func (s *store) RemoveSessionsFromUUID(uuid []byte) error {
|
||||
_, err := handle.Exec("DELETE FROM sessions WHERE uuid = ?", uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -414,7 +413,7 @@ func RemoveSessionsFromUUID(uuid []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchUsernameFromUUID(uuid []byte) (string, error) {
|
||||
func (s *store) FetchUsernameFromUUID(uuid []byte) (string, error) {
|
||||
var username string
|
||||
err := handle.QueryRow("SELECT username FROM accounts WHERE uuid = ?", uuid).Scan(&username)
|
||||
if err != nil {
|
||||
@ -424,7 +423,7 @@ func FetchUsernameFromUUID(uuid []byte) (string, error) {
|
||||
return username, nil
|
||||
}
|
||||
|
||||
func FetchUUIDFromUsername(username string) ([]byte, error) {
|
||||
func (s *store) FetchUUIDFromUsername(username string) ([]byte, error) {
|
||||
var uuid []byte
|
||||
err := handle.QueryRow("SELECT uuid FROM accounts WHERE username = ?", username).Scan(&uuid)
|
||||
if err != nil {
|
||||
@ -434,7 +433,7 @@ func FetchUUIDFromUsername(username string) ([]byte, error) {
|
||||
return uuid, nil
|
||||
}
|
||||
|
||||
func RemoveDiscordIdByUUID(uuid []byte) error {
|
||||
func (s *store) RemoveDiscordIdByUUID(uuid []byte) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE uuid = ?", uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -443,7 +442,7 @@ func RemoveDiscordIdByUUID(uuid []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveGoogleIdByUUID(uuid []byte) error {
|
||||
func (s *store) RemoveGoogleIdByUUID(uuid []byte) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE uuid = ?", uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -452,7 +451,7 @@ func RemoveGoogleIdByUUID(uuid []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveGoogleIdByUsername(username string) error {
|
||||
func (s *store) RemoveGoogleIdByUsername(username string) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE username = ?", username)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -461,7 +460,7 @@ func RemoveGoogleIdByUsername(username string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveDiscordIdByUsername(username string) error {
|
||||
func (s *store) RemoveDiscordIdByUsername(username string) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE username = ?", username)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -470,7 +469,7 @@ func RemoveDiscordIdByUsername(username string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveDiscordIdByDiscordId(discordId string) error {
|
||||
func (s *store) RemoveDiscordIdByDiscordId(discordId string) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE discordId = ?", discordId)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -479,7 +478,7 @@ func RemoveDiscordIdByDiscordId(discordId string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func RemoveGoogleIdByDiscordId(discordId string) error {
|
||||
func (s *store) RemoveGoogleIdByDiscordId(discordId string) error {
|
||||
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE discordId = ?", discordId)
|
||||
if err != nil {
|
||||
return err
|
||||
|
10
db/daily.go
10
db/daily.go
@ -23,7 +23,7 @@ import (
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
)
|
||||
|
||||
func TryAddDailyRun(seed string) (string, error) {
|
||||
func (s *store) TryAddDailyRun(seed string) (string, error) {
|
||||
var actualSeed string
|
||||
err := handle.QueryRow("INSERT INTO dailyRuns (seed, date) VALUES (?, UTC_DATE()) ON DUPLICATE KEY UPDATE date = date RETURNING seed", seed).Scan(&actualSeed)
|
||||
if err != nil {
|
||||
@ -33,7 +33,7 @@ func TryAddDailyRun(seed string) (string, error) {
|
||||
return actualSeed, nil
|
||||
}
|
||||
|
||||
func GetDailyRunSeed() (string, error) {
|
||||
func (s *store) GetDailyRunSeed() (string, error) {
|
||||
var seed string
|
||||
err := handle.QueryRow("SELECT seed FROM dailyRuns WHERE date = UTC_DATE()").Scan(&seed)
|
||||
if err != nil {
|
||||
@ -43,7 +43,7 @@ func GetDailyRunSeed() (string, error) {
|
||||
return seed, nil
|
||||
}
|
||||
|
||||
func AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
|
||||
func (s *store) AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
|
||||
_, err := handle.Exec("INSERT INTO accountDailyRuns (uuid, date, score, wave, timestamp) VALUES (?, UTC_DATE(), ?, ?, UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE score = GREATEST(score, ?), wave = GREATEST(wave, ?), timestamp = IF(score < ?, UTC_TIMESTAMP(), timestamp)", uuid, score, wave, score, wave, score)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -52,7 +52,7 @@ func AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
|
||||
func (s *store) FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
|
||||
var rankings []defs.DailyRanking
|
||||
|
||||
offset := (page - 1) * 10
|
||||
@ -85,7 +85,7 @@ func FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
|
||||
return rankings, nil
|
||||
}
|
||||
|
||||
func FetchRankingPageCount(category int) (int, error) {
|
||||
func (s *store) FetchRankingPageCount(category int) (int, error) {
|
||||
var query string
|
||||
switch category {
|
||||
case 0:
|
||||
|
6
db/db.go
6
db/db.go
@ -31,6 +31,12 @@ import (
|
||||
var handle *sql.DB
|
||||
var s3client *s3.Client
|
||||
|
||||
// internal type used to implement the Store interface
|
||||
type store struct{}
|
||||
|
||||
// Store is the global instance for DB access.
|
||||
var Store = &store{}
|
||||
|
||||
func Init(username, password, protocol, address, database string) error {
|
||||
var err error
|
||||
|
||||
|
@ -17,7 +17,7 @@
|
||||
|
||||
package db
|
||||
|
||||
func FetchPlayerCount() (int, error) {
|
||||
func (s *store) FetchPlayerCount() (int, error) {
|
||||
var playerCount int
|
||||
err := handle.QueryRow("SELECT COUNT(*) FROM accounts WHERE lastActivity > DATE_SUB(UTC_TIMESTAMP(), INTERVAL 5 MINUTE)").Scan(&playerCount)
|
||||
if err != nil {
|
||||
@ -27,7 +27,7 @@ func FetchPlayerCount() (int, error) {
|
||||
return playerCount, nil
|
||||
}
|
||||
|
||||
func FetchBattleCount() (int, error) {
|
||||
func (s *store) FetchBattleCount() (int, error) {
|
||||
var battleCount int
|
||||
err := handle.QueryRow("SELECT COALESCE(SUM(s.battles), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&battleCount)
|
||||
if err != nil {
|
||||
@ -37,7 +37,7 @@ func FetchBattleCount() (int, error) {
|
||||
return battleCount, nil
|
||||
}
|
||||
|
||||
func FetchClassicSessionCount() (int, error) {
|
||||
func (s *store) FetchClassicSessionCount() (int, error) {
|
||||
var classicSessionCount int
|
||||
err := handle.QueryRow("SELECT COALESCE(SUM(s.classicSessionsPlayed), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&classicSessionCount)
|
||||
if err != nil {
|
||||
|
61
db/s3.go
Normal file
61
db/s3.go
Normal file
@ -0,0 +1,61 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"os"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
)
|
||||
|
||||
func (s *store) GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) {
|
||||
var system defs.SystemSaveData
|
||||
|
||||
username, err := Store.FetchUsernameFromUUID(uuid)
|
||||
if err != nil {
|
||||
return system, err
|
||||
}
|
||||
|
||||
resp, err := s3client.GetObject(context.TODO(), &s3.GetObjectInput{
|
||||
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
|
||||
Key: aws.String(username),
|
||||
})
|
||||
if err != nil {
|
||||
return system, err
|
||||
}
|
||||
|
||||
err = json.NewDecoder(resp.Body).Decode(&system)
|
||||
if err != nil {
|
||||
return system, err
|
||||
}
|
||||
|
||||
return system, nil
|
||||
}
|
||||
|
||||
func (s *store) StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error {
|
||||
username, err := s.FetchUsernameFromUUID(uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
err = json.NewEncoder(buf).Encode(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s3client.PutObject(context.Background(), &s3.PutObjectInput{
|
||||
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
|
||||
Key: aws.String(username),
|
||||
Body: buf,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@ -19,19 +19,13 @@ package db
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/gob"
|
||||
"encoding/json"
|
||||
"os"
|
||||
|
||||
"github.com/klauspost/compress/zstd"
|
||||
"github.com/pagefaultgames/rogueserver/defs"
|
||||
|
||||
"github.com/aws/aws-sdk-go-v2/aws"
|
||||
"github.com/aws/aws-sdk-go-v2/service/s3"
|
||||
)
|
||||
|
||||
func TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
|
||||
func (s *store) TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
|
||||
var count int
|
||||
err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count)
|
||||
if err != nil {
|
||||
@ -48,7 +42,7 @@ func TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
|
||||
func (s *store) ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
|
||||
var count int
|
||||
err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count)
|
||||
if err != nil {
|
||||
@ -58,7 +52,7 @@ func ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
|
||||
func (s *store) ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
|
||||
var system defs.SystemSaveData
|
||||
|
||||
var data []byte
|
||||
@ -82,7 +76,7 @@ func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
|
||||
return system, nil
|
||||
}
|
||||
|
||||
func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
|
||||
func (s *store) StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
zw, err := zstd.NewWriter(buf)
|
||||
@ -108,32 +102,7 @@ func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error {
|
||||
username, err := FetchUsernameFromUUID(uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
err = json.NewEncoder(buf).Encode(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = s3client.PutObject(context.Background(), &s3.PutObjectInput{
|
||||
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
|
||||
Key: aws.String(username),
|
||||
Body: buf,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteSystemSaveData(uuid []byte) error {
|
||||
func (s *store) DeleteSystemSaveData(uuid []byte) error {
|
||||
_, err := handle.Exec("DELETE FROM systemSaveData WHERE uuid = ?", uuid)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -142,7 +111,7 @@ func DeleteSystemSaveData(uuid []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
|
||||
func (s *store) ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
|
||||
var session defs.SessionSaveData
|
||||
|
||||
var data []byte
|
||||
@ -166,7 +135,7 @@ func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
|
||||
func (s *store) GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
|
||||
var slot int
|
||||
err := handle.QueryRow("SELECT slot FROM sessionSaveData WHERE uuid = ? ORDER BY timestamp DESC, slot ASC LIMIT 1", uuid).Scan(&slot)
|
||||
if err != nil {
|
||||
@ -176,7 +145,7 @@ func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
|
||||
return slot, nil
|
||||
}
|
||||
|
||||
func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error {
|
||||
func (s *store) StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error {
|
||||
buf := new(bytes.Buffer)
|
||||
|
||||
zw, err := zstd.NewWriter(buf)
|
||||
@ -202,7 +171,7 @@ func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
func DeleteSessionSaveData(uuid []byte, slot int) error {
|
||||
func (s *store) DeleteSessionSaveData(uuid []byte, slot int) error {
|
||||
_, err := handle.Exec("DELETE FROM sessionSaveData WHERE uuid = ? AND slot = ?", uuid, slot)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -210,27 +179,3 @@ func DeleteSessionSaveData(uuid []byte, slot int) error {
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) {
|
||||
var system defs.SystemSaveData
|
||||
|
||||
username, err := FetchUsernameFromUUID(uuid)
|
||||
if err != nil {
|
||||
return system, err
|
||||
}
|
||||
|
||||
resp, err := s3client.GetObject(context.TODO(), &s3.GetObjectInput{
|
||||
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
|
||||
Key: aws.String(username),
|
||||
})
|
||||
if err != nil {
|
||||
return system, err
|
||||
}
|
||||
|
||||
err = json.NewDecoder(resp.Body).Decode(&system)
|
||||
if err != nil {
|
||||
return system, err
|
||||
}
|
||||
|
||||
return system, nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user