Adjust db to allow mocking for unit tests

This commit is contained in:
Sirz Benjie 2025-09-16 15:01:00 -05:00
parent 27a1e8f363
commit 40323809f7
No known key found for this signature in database
GPG Key ID: 4A524B4D196C759E
26 changed files with 577 additions and 295 deletions

View File

@ -20,11 +20,15 @@ package account
import (
"crypto/rand"
"fmt"
"github.com/pagefaultgames/rogueserver/db"
)
func ChangePW(uuid []byte, password string) error {
// Interface for database operations needed for changing password.
type ChangePWStore interface {
RemoveSessionsFromUUID(uuid []byte) error
UpdateAccountPassword(uuid []byte, newKey []byte, newSalt []byte) error
}
func ChangePW[T ChangePWStore](store T, uuid []byte, password string) error {
if len(password) < 6 {
return fmt.Errorf("invalid password")
}
@ -35,12 +39,12 @@ func ChangePW(uuid []byte, password string) error {
return fmt.Errorf("failed to generate salt: %s", err)
}
err = db.RemoveSessionsFromUUID(uuid)
err = store.RemoveSessionsFromUUID(uuid)
if err != nil {
return fmt.Errorf("failed to remove sessions: %s", err)
}
err = db.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt)
err = store.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt)
if err != nil {
return fmt.Errorf("failed to add account record: %s", err)
}

View File

@ -35,14 +35,24 @@ var (
DiscordGuildID string
)
func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
type DiscordProvider interface {
HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error)
RetrieveDiscordId(code string) (string, error)
IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error)
}
type discordProvider struct{}
var Discord = &discordProvider{}
func (s *discordProvider) HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
if code == "" {
http.Redirect(w, r, GameURL, http.StatusSeeOther)
return "", errors.New("code is empty")
}
discordId, err := RetrieveDiscordId(code)
discordId, err := s.RetrieveDiscordId(code)
if err != nil {
http.Redirect(w, r, GameURL, http.StatusSeeOther)
return "", err
@ -51,7 +61,7 @@ func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, erro
return discordId, nil
}
func RetrieveDiscordId(code string) (string, error) {
func (s *discordProvider) RetrieveDiscordId(code string) (string, error) {
v := make(url.Values)
v.Set("client_id", DiscordClientID)
v.Set("client_secret", DiscordClientSecret)
@ -112,7 +122,7 @@ func RetrieveDiscordId(code string) (string, error) {
return user.Id, nil
}
func IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) {
func (s *discordProvider) IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) {
// fetch all roles from discord
roles, err := DiscordSession.GuildRoles(discordGuildID)
if err != nil {

View File

@ -32,7 +32,16 @@ var (
GoogleCallbackURL string
)
func HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) {
type GoogleProvider interface {
HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error)
RetrieveGoogleId(code string) (string, error)
}
type googleProvider struct{}
var Google = &googleProvider{}
func (g *googleProvider) HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
if code == "" {
http.Redirect(w, r, GameURL, http.StatusSeeOther)

View File

@ -17,10 +17,6 @@
package account
import (
"github.com/pagefaultgames/rogueserver/db"
)
type InfoResponse struct {
Username string `json:"username"`
DiscordId string `json:"discordId"`
@ -29,9 +25,13 @@ type InfoResponse struct {
HasAdminRole bool `json:"hasAdminRole"`
}
type InfoStore interface {
GetLatestSessionSaveDataSlot(uuid []byte) (int, error)
}
// /account/info - get account info
func Info(username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) {
slot, _ := db.GetLatestSessionSaveDataSlot(uuid)
func Info[T InfoStore](store T, username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) {
slot, _ := store.GetLatestSessionSaveDataSlot(uuid)
response := InfoResponse{
Username: username,
LastSessionSlot: slot,

View File

@ -24,14 +24,18 @@ import (
"encoding/base64"
"errors"
"fmt"
"github.com/pagefaultgames/rogueserver/db"
)
type LoginResponse GenericAuthResponse
// Interface for database operations needed for login.
type LoginStore interface {
FetchAccountKeySaltFromUsername(username string) (key, salt []byte, err error)
AddAccountSession(username string, token []byte) error
}
// /account/login - log into account
func Login(username, password string) (LoginResponse, error) {
func Login[T LoginStore](store T, username, password string) (LoginResponse, error) {
var response LoginResponse
if !isValidUsername(username) {
@ -42,12 +46,11 @@ func Login(username, password string) (LoginResponse, error) {
return response, fmt.Errorf("invalid password")
}
key, salt, err := db.FetchAccountKeySaltFromUsername(username)
key, salt, err := store.FetchAccountKeySaltFromUsername(username)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return response, fmt.Errorf("account doesn't exist")
}
return response, err
}
@ -55,8 +58,7 @@ func Login(username, password string) (LoginResponse, error) {
return response, fmt.Errorf("password doesn't match")
}
response.Token, err = GenerateTokenForUsername(username)
response.Token, err = GenerateTokenForUsername(store, username)
if err != nil {
return response, fmt.Errorf("failed to generate token: %s", err)
}
@ -64,14 +66,19 @@ func Login(username, password string) (LoginResponse, error) {
return response, nil
}
func GenerateTokenForUsername(username string) (string, error) {
type GenerateTokenForUsernameStore interface {
AddAccountSession(username string, token []byte) error
}
// GenerateTokenForUsername generates a session token and stores it using the provided DBAccountStore.
func GenerateTokenForUsername[T GenerateTokenForUsernameStore](store T, username string) (string, error) {
token := make([]byte, TokenSize)
_, err := rand.Read(token)
if err != nil {
return "", fmt.Errorf("failed to generate token: %s", err)
}
err = db.AddAccountSession(username, token)
err = store.AddAccountSession(username, token)
if err != nil {
return "", fmt.Errorf("failed to add account session")
}

154
api/account/login_test.go Normal file
View File

@ -0,0 +1,154 @@
package account
import (
"database/sql"
"errors"
"testing"
)
func defaultMockStore() *mockDBAccountStore {
return &mockDBAccountStore{
FetchFunc: func(username string) ([]byte, []byte, error) {
return []byte("key"), []byte("salt"), nil
},
AddSessionFunc: func(username string, token []byte) error { return nil },
}
}
type mockDBAccountStore struct {
FetchFunc func(username string) ([]byte, []byte, error)
AddSessionFunc func(username string, token []byte) error
}
func (m *mockDBAccountStore) FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
return m.FetchFunc(username)
}
func (m *mockDBAccountStore) AddAccountSession(username string, token []byte) error {
return m.AddSessionFunc(username, token)
}
func TestLogin(t *testing.T) {
t.Run("UsernameMinLength", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "a", "password123")
if err == nil {
t.Errorf("expected error due to password mismatch or DB, got nil")
}
})
t.Run("UsernameMaxLength", func(t *testing.T) {
uname := "abcdefghijklmnop"
store := defaultMockStore()
_, err := Login(store, uname, "password123")
if err == nil {
t.Errorf("expected error due to password mismatch or DB, got nil")
}
})
t.Run("UsernameTooLong", func(t *testing.T) {
uname := "abcdefghijklmnopq"
store := defaultMockStore()
_, err := Login(store, uname, "password123")
if err == nil || err.Error() != "invalid username" {
t.Errorf("expected invalid username error for too long username, got: %v", err)
}
})
t.Run("UsernameWithInvalidChars", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "user!@#", "password123")
if err == nil || err.Error() != "invalid username" {
t.Errorf("expected invalid username error for special chars, got: %v", err)
}
})
t.Run("EmptyUsername", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "", "password123")
if err == nil || err.Error() != "invalid username" {
t.Errorf("expected invalid username error for empty username, got: %v", err)
}
})
t.Run("EmptyPassword", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "validuser", "")
if err == nil || err.Error() != "invalid password" {
t.Errorf("expected invalid password error for empty password, got: %v", err)
}
})
t.Run("MinPasswordLength", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "validuser", "123456")
if err == nil {
t.Errorf("expected error due to password mismatch or DB, got nil")
}
})
t.Run("PasswordWithSpecialChars", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "validuser", "p@$$w0rd!")
if err == nil {
t.Errorf("expected error due to password mismatch or DB, got nil")
}
})
t.Run("DBUnexpectedError", func(t *testing.T) {
store := defaultMockStore()
store.FetchFunc = func(username string) ([]byte, []byte, error) {
return nil, nil, errors.New("some db error")
}
_, err := Login(store, "validuser", "password123")
if err == nil || err.Error() != "some db error" {
t.Errorf("expected DB error to propagate, got: %v", err)
}
})
t.Run("InvalidUsername", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "!invaliduser", "password123")
if err == nil || err.Error() != "invalid username" {
t.Errorf("expected invalid username error, got: %v", err)
}
})
t.Run("ShortPassword", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "validuser", "123")
if err == nil || err.Error() != "invalid password" {
t.Errorf("expected invalid password error, got: %v", err)
}
})
t.Run("AccountDoesNotExist", func(t *testing.T) {
store := defaultMockStore()
store.FetchFunc = func(username string) ([]byte, []byte, error) {
return nil, nil, sql.ErrNoRows
}
_, err := Login(store, "nonexistent", "password123")
if err == nil || err.Error() != "account doesn't exist" {
t.Errorf("expected account doesn't exist error, got: %v", err)
}
})
t.Run("PasswordMismatch", func(t *testing.T) {
correctSalt := []byte("somesalt")
correctKey := []byte("correctkey")
store := defaultMockStore()
store.FetchFunc = func(username string) ([]byte, []byte, error) {
return correctKey, correctSalt, nil
}
_, err := Login(store, "validuser", "wrongpassword")
if err == nil || err.Error() != "password doesn't match" {
t.Errorf("expected password doesn't match error, got: %v", err)
}
})
t.Run("Success", func(t *testing.T) {
correctSalt := []byte("somesalt")
password := "goodpassword"
correctKey := deriveArgon2IDKey([]byte(password), correctSalt)
store := defaultMockStore()
store.FetchFunc = func(username string) ([]byte, []byte, error) {
return correctKey, correctSalt, nil
}
store.AddSessionFunc = func(username string, token []byte) error {
return nil
}
resp, err := Login(store, "validuser", password)
if err != nil {
t.Errorf("expected success, got error: %v", err)
}
if resp.Token == "" {
t.Errorf("expected token to be set on success")
}
})
}

View File

@ -21,13 +21,17 @@ import (
"database/sql"
"errors"
"fmt"
"github.com/pagefaultgames/rogueserver/db"
)
// /account/logout - log out of account
func Logout(token []byte) error {
err := db.RemoveSessionFromToken(token)
// Interface for database operations needed for logout.
type LogoutStore interface {
RemoveSessionFromToken(token []byte) error
}
func Logout[T LogoutStore](store T, token []byte) error {
err := store.RemoveSessionFromToken(token)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("token not found")

View File

@ -20,12 +20,15 @@ package account
import (
"crypto/rand"
"fmt"
"github.com/pagefaultgames/rogueserver/db"
)
// Interface for database operations needed for registration.
type RegisterStore interface {
AddAccountRecord(uuid []byte, username string, passwordHash []byte, salt []byte) error
}
// /account/register - register account
func Register(username, password string) error {
func Register[T RegisterStore](store T, username, password string) error {
if !isValidUsername(username) {
return fmt.Errorf("invalid username")
}
@ -46,7 +49,7 @@ func Register(username, password string) error {
return fmt.Errorf("failed to generate salt: %s", err)
}
err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
err = store.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
if err != nil {
return fmt.Errorf("failed to add account record: %s", err)
}

View File

@ -30,7 +30,7 @@ import (
)
func Init(mux *http.ServeMux) error {
err := scheduleStatRefresh()
err := scheduleStatRefresh(db.Store)
if err != nil {
return err
}
@ -109,7 +109,7 @@ func tokenAndUuidFromRequest(r *http.Request) ([]byte, []byte, error) {
return nil, nil, err
}
uuid, err := db.FetchUUIDFromToken(token)
uuid, err := db.Store.FetchUUIDFromToken(token)
if err != nil {
return nil, nil, fmt.Errorf("failed to validate token: %s", err)
}

View File

@ -61,7 +61,7 @@ func Init() error {
secret = newSecret
}
seed, err := db.TryAddDailyRun(Seed())
seed, err := db.Store.TryAddDailyRun(Seed())
if err != nil {
log.Print(err)
}
@ -71,7 +71,7 @@ func Init() error {
_, err = scheduler.AddFunc("@daily", func() {
time.Sleep(time.Second)
seed, err = db.TryAddDailyRun(Seed())
seed, err = db.Store.TryAddDailyRun(Seed())
if err != nil {
log.Printf("error while recording new daily: %s", err)
} else {

View File

@ -22,9 +22,14 @@ import (
"github.com/pagefaultgames/rogueserver/defs"
)
// Interface for database operations needed for fetching rankings.
type RankingsStore interface {
FetchRankings(category, page int) ([]defs.DailyRanking, error)
}
// /daily/rankings - fetch daily rankings
func Rankings(category, page int) ([]defs.DailyRanking, error) {
rankings, err := db.FetchRankings(category, page)
func Rankings[T RankingsStore](store T, category, page int) ([]defs.DailyRanking, error) {
rankings, err := db.Store.FetchRankings(category, page)
if err != nil {
return rankings, err
}

View File

@ -17,13 +17,13 @@
package daily
import (
"github.com/pagefaultgames/rogueserver/db"
)
type RankingPageCountStore interface {
FetchRankingPageCount(category int) (int, error)
}
// /daily/rankingpagecount - fetch daily ranking page count
func RankingPageCount(category int) (int, error) {
pageCount, err := db.FetchRankingPageCount(category)
func RankingPageCount[T RankingPageCountStore](store T, category int) (int, error) {
pageCount, err := store.FetchRankingPageCount(category)
if err != nil {
return pageCount, err
}

View File

@ -50,17 +50,17 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
return
}
username, err := db.FetchUsernameFromUUID(uuid)
username, err := db.Store.FetchUsernameFromUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
discordId, err := db.FetchDiscordIdByUsername(username)
discordId, err := db.Store.FetchDiscordIdByUsername(username)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
httpError(w, r, err, http.StatusInternalServerError)
return
}
googleId, err := db.FetchGoogleIdByUsername(username)
googleId, err := db.Store.FetchGoogleIdByUsername(username)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -68,10 +68,10 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
var hasAdminRole bool
if discordId != "" {
hasAdminRole, _ = account.IsUserDiscordAdmin(discordId, account.DiscordGuildID)
hasAdminRole, _ = account.Discord.IsUserDiscordAdmin(discordId, account.DiscordGuildID)
}
response, err := account.Info(username, discordId, googleId, uuid, hasAdminRole)
response, err := account.Info(db.Store, username, discordId, googleId, uuid, hasAdminRole)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -81,7 +81,7 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
}
func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
err := account.Register(r.PostFormValue("username"), r.PostFormValue("password"))
err := account.Register(db.Store, r.PostFormValue("username"), r.PostFormValue("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -91,7 +91,7 @@ func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
}
func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
response, err := account.Login(r.PostFormValue("username"), r.PostFormValue("password"))
response, err := account.Login(db.Store, r.PostFormValue("username"), r.PostFormValue("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -107,20 +107,20 @@ func handleAccountChangePW(w http.ResponseWriter, r *http.Request) {
return
}
err = account.ChangePW(uuid, r.PostFormValue("password"))
err = account.ChangePW(db.Store, uuid, r.PostFormValue("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
username, err := db.FetchUsernameFromUUID(uuid)
username, err := db.Store.FetchUsernameFromUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
// create a new session with these credentials
response, err := account.Login(username, r.Form.Get("password"))
response, err := account.Login(db.Store, username, r.Form.Get("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -136,7 +136,7 @@ func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
return
}
err = account.Logout(token)
err = account.Logout(db.Store, token)
if err != nil {
// also possible for InternalServerError but that's unlikely unless the server blew up
httpError(w, r, err, http.StatusUnauthorized)
@ -183,7 +183,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
return
}
err = db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
err = db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
@ -191,7 +191,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("action") {
case "get":
save, err := savedata.GetSession(uuid, slot)
save, err := savedata.GetSession(db.Store, uuid, slot)
if err != nil {
if errors.Is(err, savedata.ErrSaveNotExist) {
http.Error(w, err.Error(), http.StatusNotFound)
@ -211,7 +211,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
return
}
existingSave, err := savedata.GetSession(uuid, slot)
existingSave, err := savedata.GetSession(db.Store, uuid, slot)
if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError)
@ -224,7 +224,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
}
}
err = savedata.UpdateSession(uuid, slot, session)
err = savedata.UpdateSession(db.Store, uuid, slot, session)
if err != nil {
httpError(w, r, fmt.Errorf("failed to put session data: %s", err), http.StatusInternalServerError)
return
@ -239,13 +239,13 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
return
}
seed, err := db.GetDailyRunSeed()
seed, err := db.Store.GetDailyRunSeed()
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
resp, err := savedata.Clear(uuid, slot, seed, session)
resp, err := savedata.Clear(db.Store, uuid, slot, seed, session)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -253,7 +253,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, resp)
case "newclear":
resp, err := savedata.NewClear(uuid, slot)
resp, err := savedata.NewClear(db.Store, uuid, slot)
if err != nil {
httpError(w, r, fmt.Errorf("failed to read new clear: %s", err), http.StatusInternalServerError)
return
@ -261,7 +261,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, resp)
case "delete":
err := savedata.DeleteSession(uuid, slot)
err := savedata.DeleteSession(db.Store, uuid, slot)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -301,7 +301,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
return
}
active, err := db.IsActiveSession(uuid, data.ClientSessionId)
active, err := db.Store.IsActiveSession(uuid, data.ClientSessionId)
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return
@ -312,7 +312,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
return
}
storedTrainerId, storedSecretId, err := db.FetchTrainerIds(uuid)
storedTrainerId, storedSecretId, err := db.Store.FetchTrainerIds(uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -324,14 +324,14 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
return
}
} else {
err = db.UpdateTrainerIds(data.System.TrainerId, data.System.SecretId, uuid)
err = db.Store.UpdateTrainerIds(data.System.TrainerId, data.System.SecretId, uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
}
oldSystem, err := savedata.GetSystem(uuid)
oldSystem, err := savedata.GetSystem(db.Store, uuid)
if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError)
@ -356,7 +356,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
}
}
existingSave, err := savedata.GetSession(uuid, data.SessionSlotId)
existingSave, err := savedata.GetSession(db.Store, uuid, data.SessionSlotId)
if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError)
@ -369,13 +369,13 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
}
}
err = savedata.Update(uuid, data.SessionSlotId, data.Session)
err = savedata.Update(db.Store, uuid, data.SessionSlotId, data.Session)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = savedata.Update(uuid, 0, data.System)
err = savedata.Update(db.Store, uuid, 0, data.System)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -401,7 +401,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
return
}
active, err := db.IsActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
active, err := db.Store.IsActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return
@ -410,14 +410,14 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("action") {
case "get":
if !active {
err = db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
err = db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
}
}
save, err := savedata.GetSystem(uuid)
save, err := savedata.GetSystem(db.Store, uuid)
if err != nil {
if errors.Is(err, savedata.ErrSaveNotExist) {
http.Error(w, err.Error(), http.StatusNotFound)
@ -442,7 +442,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
return
}
oldSystem, err := savedata.GetSystem(uuid)
oldSystem, err := savedata.GetSystem(db.Store, uuid)
if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError)
@ -467,7 +467,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
}
}
err = savedata.UpdateSystem(uuid, system)
err = savedata.UpdateSystem(db.Store, uuid, system)
if err != nil {
httpError(w, r, fmt.Errorf("failed to put system data: %s", err), http.StatusInternalServerError)
return
@ -481,13 +481,13 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
// not valid, send server state
if !active {
err := db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
err := db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
}
storedSaveData, err := db.ReadSystemSaveData(uuid)
storedSaveData, err := db.Store.ReadSystemSaveData(uuid)
if err != nil {
httpError(w, r, fmt.Errorf("failed to read session save data: %s", err), http.StatusInternalServerError)
return
@ -498,7 +498,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, response)
case "delete":
err := savedata.DeleteSystem(uuid)
err := savedata.DeleteSystem(db.Store, uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -511,9 +511,14 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
}
}
// Interface providing database operations needed for getting daily seed.
type HandleDailySeedStore interface {
GetDailyRunSeed() (string, error)
}
// daily
func handleDailySeed(w http.ResponseWriter, r *http.Request) {
seed, err := db.GetDailyRunSeed()
seed, err := db.Store.GetDailyRunSeed()
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -522,6 +527,11 @@ func handleDailySeed(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, seed)
}
// Interface for database operations needed for getting daily rankings.
type HandleDailyRankingsStore interface {
daily.RankingsStore
}
func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
var err error
@ -543,7 +553,7 @@ func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
}
}
rankings, err := daily.Rankings(category, page)
rankings, err := daily.Rankings(db.Store, category, page)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -552,6 +562,10 @@ func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, rankings)
}
type HandleDailyRankingsPageCountStore interface {
daily.RankingPageCountStore
}
func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) {
var category int
if r.URL.Query().Has("category") {
@ -563,7 +577,7 @@ func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) {
}
}
count, err := daily.RankingPageCount(category)
count, err := daily.RankingPageCount(db.Store, category)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
}
@ -579,9 +593,9 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
var err error
switch provider {
case "discord":
externalAuthId, err = account.HandleDiscordCallback(w, r)
externalAuthId, err = account.Discord.HandleDiscordCallback(w, r)
case "google":
externalAuthId, err = account.HandleGoogleCallback(w, r)
externalAuthId, err = account.Google.HandleGoogleCallback(w, r)
default:
http.Error(w, "invalid provider", http.StatusBadRequest)
return
@ -600,7 +614,7 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
return
}
userName, err := db.FetchUsernameBySessionToken(stateByte)
userName, err := db.Store.FetchUsernameBySessionToken(stateByte)
if err != nil {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
return
@ -608,9 +622,9 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
switch provider {
case "discord":
err = db.AddDiscordIdByUsername(externalAuthId, userName)
err = db.Store.AddDiscordIdByUsername(externalAuthId, userName)
case "google":
err = db.AddGoogleIdByUsername(externalAuthId, userName)
err = db.Store.AddGoogleIdByUsername(externalAuthId, userName)
}
if err != nil {
@ -622,16 +636,16 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
var userName string
switch provider {
case "discord":
userName, err = db.FetchUsernameByDiscordId(externalAuthId)
userName, err = db.Store.FetchUsernameByDiscordId(externalAuthId)
case "google":
userName, err = db.FetchUsernameByGoogleId(externalAuthId)
userName, err = db.Store.FetchUsernameByGoogleId(externalAuthId)
}
if err != nil {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
return
}
sessionToken, err := account.GenerateTokenForUsername(userName)
sessionToken, err := account.GenerateTokenForUsername(db.Store, userName)
if err != nil {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
return
@ -651,6 +665,11 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
}
type HandleProviderLogoutStore interface {
RemoveDiscordIdByUUID(uuid []byte) error
RemoveGoogleIdByUUID(uuid []byte) error
}
func handleProviderLogout(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
@ -660,9 +679,9 @@ func handleProviderLogout(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("provider") {
case "discord":
err = db.RemoveDiscordIdByUUID(uuid)
err = db.Store.RemoveDiscordIdByUUID(uuid)
case "google":
err = db.RemoveGoogleIdByUUID(uuid)
err = db.Store.RemoveGoogleIdByUUID(uuid)
default:
http.Error(w, "invalid provider", http.StatusBadRequest)
return
@ -681,13 +700,13 @@ func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
return
}
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
@ -698,19 +717,19 @@ func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
_, err = db.CheckUsernameExists(username)
_, err = db.Store.CheckUsernameExists(username)
if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return
}
userUuid, err := db.FetchUUIDFromUsername(username)
userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = db.AddDiscordIdByUUID(discordId, userUuid)
err = db.Store.AddDiscordIdByUUID(discordId, userUuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -728,13 +747,13 @@ func handleAdminDiscordUnlink(w http.ResponseWriter, r *http.Request) {
return
}
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
@ -748,26 +767,26 @@ func handleAdminDiscordUnlink(w http.ResponseWriter, r *http.Request) {
log.Printf("Username given, removing discordId")
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
_, err = db.CheckUsernameExists(username)
_, err = db.Store.CheckUsernameExists(username)
if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return
}
userUuid, err := db.FetchUUIDFromUsername(username)
userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = db.RemoveDiscordIdByUUID(userUuid)
err = db.Store.RemoveDiscordIdByUUID(userUuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
case discordId != "":
log.Printf("DiscordID given, removing discordId")
err = db.RemoveDiscordIdByDiscordId(discordId)
err = db.Store.RemoveDiscordIdByDiscordId(discordId)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -786,13 +805,13 @@ func handleAdminGoogleLink(w http.ResponseWriter, r *http.Request) {
return
}
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
@ -803,19 +822,19 @@ func handleAdminGoogleLink(w http.ResponseWriter, r *http.Request) {
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
_, err = db.CheckUsernameExists(username)
_, err = db.Store.CheckUsernameExists(username)
if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return
}
userUuid, err := db.FetchUUIDFromUsername(username)
userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = db.AddGoogleIdByUUID(googleId, userUuid)
err = db.Store.AddGoogleIdByUUID(googleId, userUuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -833,13 +852,13 @@ func handleAdminGoogleUnlink(w http.ResponseWriter, r *http.Request) {
return
}
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
@ -853,26 +872,26 @@ func handleAdminGoogleUnlink(w http.ResponseWriter, r *http.Request) {
log.Printf("Username given, removing googleId")
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
_, err = db.CheckUsernameExists(username)
_, err = db.Store.CheckUsernameExists(username)
if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return
}
userUuid, err := db.FetchUUIDFromUsername(username)
userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
err = db.RemoveGoogleIdByUUID(userUuid)
err = db.Store.RemoveGoogleIdByUUID(userUuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
case googleId != "":
log.Printf("DiscordID given, removing googleId")
err = db.RemoveGoogleIdByDiscordId(googleId)
err = db.Store.RemoveGoogleIdByDiscordId(googleId)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@ -891,13 +910,13 @@ func handleAdminSearch(w http.ResponseWriter, r *http.Request) {
return
}
userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
@ -907,14 +926,14 @@ func handleAdminSearch(w http.ResponseWriter, r *http.Request) {
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
_, err = db.CheckUsernameExists(username)
_, err = db.Store.CheckUsernameExists(username)
if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return
}
// this does a single call that does a query for multiple columns from our database and makes an object out of it, which is returned to us
adminSearchResult, err := db.FetchAdminDetailsByUsername(username)
adminSearchResult, err := db.Store.FetchAdminDetailsByUsername(username)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return

View File

@ -21,7 +21,6 @@ import (
"fmt"
"log"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
@ -30,10 +29,20 @@ type ClearResponse struct {
Error string `json:"error"`
}
// Interface for database operations needed for `Clear`
// Helps with testing and reduces coupling.
type ClearStore interface {
UpdateAccountLastActivity(uuid []byte) error
TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error)
DeleteSessionSaveData(uuid []byte, slot int) error
AddOrUpdateAccountDailyRun(uuid []byte, score int, waveCompleted int) error
SetAccountBanned(uuid []byte, banned bool) error
}
// /savedata/clear - mark session save data as cleared and delete
func Clear(uuid []byte, slot int, seed string, save defs.SessionSaveData) (ClearResponse, error) {
func Clear[T ClearStore](store T, uuid []byte, slot int, seed string, save defs.SessionSaveData) (ClearResponse, error) {
var response ClearResponse
err := db.UpdateAccountLastActivity(uuid)
err := store.UpdateAccountLastActivity(uuid)
if err != nil {
log.Print("failed to update account last activity")
}
@ -51,23 +60,23 @@ func Clear(uuid []byte, slot int, seed string, save defs.SessionSaveData) (Clear
}
if save.Score >= 20000 {
db.SetAccountBanned(uuid, true)
store.SetAccountBanned(uuid, true)
}
err = db.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted)
err = store.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted)
if err != nil {
log.Printf("failed to add or update daily run record: %s", err)
}
}
if sessionCompleted {
response.Success, err = db.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode))
response.Success, err = store.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode))
if err != nil {
log.Printf("failed to mark seed as completed: %s", err)
}
}
err = db.DeleteSessionSaveData(uuid, slot)
err = store.DeleteSessionSaveData(uuid, slot)
if err != nil {
log.Printf("failed to delete session save data: %s", err)
}

View File

@ -21,13 +21,19 @@ import (
"fmt"
"log"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
// Interface for database operations needed for delete.
// This is to allow easier testing and less coupling.
type DeleteStore interface {
UpdateAccountLastActivity(uuid []byte) error
DeleteSessionSaveData(uuid []byte, slot int) error
}
// /savedata/delete - delete save data
func Delete(uuid []byte, datatype, slot int) error {
err := db.UpdateAccountLastActivity(uuid)
func Delete[T DeleteStore](store T, uuid []byte, datatype, slot int) error {
err := store.UpdateAccountLastActivity(uuid)
if err != nil {
log.Print("failed to update account last activity")
}
@ -39,7 +45,7 @@ func Delete(uuid []byte, datatype, slot int) error {
break
}
err = db.DeleteSessionSaveData(uuid, slot)
err = store.DeleteSessionSaveData(uuid, slot)
default:
err = fmt.Errorf("invalid data type")
}

View File

@ -20,22 +20,26 @@ package savedata
import (
"fmt"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
type NewClearStore interface {
ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error)
ReadSeedCompleted(uuid []byte, seed string) (bool, error)
}
// /savedata/newclear - return whether a session is a new clear for its seed
func NewClear(uuid []byte, slot int) (bool, error) {
func NewClear[T NewClearStore](store T, uuid []byte, slot int) (bool, error) {
if slot < 0 || slot >= defs.SessionSlotCount {
return false, fmt.Errorf("slot id %d out of range", slot)
}
session, err := db.ReadSessionSaveData(uuid, slot)
session, err := store.ReadSessionSaveData(uuid, slot)
if err != nil {
return false, err
}
completed, err := db.ReadSeedCompleted(uuid, session.Seed)
completed, err := store.ReadSeedCompleted(uuid, session.Seed)
if err != nil {
return false, fmt.Errorf("failed to read seed completed: %s", err)
}

View File

@ -21,12 +21,15 @@ import (
"database/sql"
"errors"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
func GetSession(uuid []byte, slot int) (defs.SessionSaveData, error) {
session, err := db.ReadSessionSaveData(uuid, slot)
type GetSessionStore interface {
ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error)
}
func GetSession[T GetSessionStore](store T, uuid []byte, slot int) (defs.SessionSaveData, error) {
session, err := store.ReadSessionSaveData(uuid, slot)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
err = ErrSaveNotExist
@ -38,8 +41,13 @@ func GetSession(uuid []byte, slot int) (defs.SessionSaveData, error) {
return session, nil
}
func UpdateSession(uuid []byte, slot int, data defs.SessionSaveData) error {
err := db.StoreSessionSaveData(uuid, data, slot)
type UpdateSessionStore interface {
StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error
DeleteSessionSaveData(uuid []byte, slot int) error
}
func UpdateSession[T UpdateSessionStore](store T, uuid []byte, slot int, data defs.SessionSaveData) error {
err := store.StoreSessionSaveData(uuid, data, slot)
if err != nil {
return err
}
@ -47,8 +55,12 @@ func UpdateSession(uuid []byte, slot int, data defs.SessionSaveData) error {
return nil
}
func DeleteSession(uuid []byte, slot int) error {
err := db.DeleteSessionSaveData(uuid, slot)
type DeleteSessionStore interface {
DeleteSessionSaveData(uuid []byte, slot int) error
}
func DeleteSession[T DeleteSessionStore](store T, uuid []byte, slot int) error {
err := store.DeleteSessionSaveData(uuid, slot)
if err != nil {
return err
}

View File

@ -24,24 +24,28 @@ import (
"os"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
var ErrSaveNotExist = errors.New("save does not exist")
func GetSystem(uuid []byte) (defs.SystemSaveData, error) {
type GetSystemStore interface {
GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error)
ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error)
}
func GetSystem[T GetSystemStore](store T, uuid []byte) (defs.SystemSaveData, error) {
var system defs.SystemSaveData
var err error
if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3
system, err = db.GetSystemSaveFromS3(uuid)
system, err = store.GetSystemSaveFromS3(uuid)
var nokey *types.NoSuchKey
if errors.As(err, &nokey) {
err = ErrSaveNotExist
}
} else { // use database
system, err = db.ReadSystemSaveData(uuid)
system, err = store.ReadSystemSaveData(uuid)
if errors.Is(err, sql.ErrNoRows) {
err = ErrSaveNotExist
}
@ -53,20 +57,27 @@ func GetSystem(uuid []byte) (defs.SystemSaveData, error) {
return system, nil
}
func UpdateSystem(uuid []byte, data defs.SystemSaveData) error {
// Interface for database operations needed for updating system data.
type UpdateSystemStore interface {
UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error
StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error
StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error
}
func UpdateSystem[T UpdateSystemStore](store T, uuid []byte, data defs.SystemSaveData) error {
if data.TrainerId == 0 && data.SecretId == 0 {
return fmt.Errorf("invalid system data")
}
err := db.UpdateAccountStats(uuid, data.GameStats, data.VoucherCounts)
err := store.UpdateAccountStats(uuid, data.GameStats, data.VoucherCounts)
if err != nil {
return fmt.Errorf("failed to update account stats: %s", err)
}
if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3
err = db.StoreSystemSaveDataS3(uuid, data)
err = store.StoreSystemSaveDataS3(uuid, data)
} else {
err = db.StoreSystemSaveData(uuid, data)
err = store.StoreSystemSaveData(uuid, data)
}
if err != nil {
return err
@ -75,8 +86,12 @@ func UpdateSystem(uuid []byte, data defs.SystemSaveData) error {
return nil
}
func DeleteSystem(uuid []byte) error {
err := db.DeleteSystemSaveData(uuid)
type DeleteSystemStore interface {
DeleteSystemSaveData(uuid []byte) error
}
func DeleteSystem[T DeleteSystemStore](store T, uuid []byte) error {
err := store.DeleteSystemSaveData(uuid)
if err != nil {
return err
}

View File

@ -21,13 +21,18 @@ import (
"fmt"
"log"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
type UpdateStore interface {
UpdateAccountLastActivity(uuid []byte) error
UpdateSystemStore
UpdateSessionStore
}
// /savedata/update - update save data
func Update(uuid []byte, slot int, save any) error {
err := db.UpdateAccountLastActivity(uuid)
func Update[T UpdateStore](store T, uuid []byte, slot int, save any) error {
err := store.UpdateAccountLastActivity(uuid)
if err != nil {
log.Print("failed to update account last activity")
}
@ -38,13 +43,13 @@ func Update(uuid []byte, slot int, save any) error {
return fmt.Errorf("invalid system data")
}
return UpdateSystem(uuid, save)
return UpdateSystem(store, uuid, save)
case defs.SessionSaveData: // Session
if slot < 0 || slot >= defs.SessionSlotCount {
return fmt.Errorf("slot id %d out of range", slot)
}
return UpdateSession(uuid, slot, save)
return UpdateSession(store, uuid, slot, save)
default:
return fmt.Errorf("invalid data type")
}

View File

@ -21,7 +21,6 @@ import (
"log"
"time"
"github.com/pagefaultgames/rogueserver/db"
"github.com/robfig/cron/v3"
)
@ -32,9 +31,9 @@ var (
classicSessionCount int
)
func scheduleStatRefresh() error {
func scheduleStatRefresh[T updateStatsStore](store T) error {
_, err := scheduler.AddFunc("@every 30s", func() {
err := updateStats()
err := updateStats(store)
if err != nil {
log.Printf("failed to update stats: %s", err)
}
@ -47,19 +46,25 @@ func scheduleStatRefresh() error {
return nil
}
func updateStats() error {
type updateStatsStore interface {
FetchPlayerCount() (int, error)
FetchBattleCount() (int, error)
FetchClassicSessionCount() (int, error)
}
func updateStats[T updateStatsStore](store T) error {
var err error
playerCount, err = db.FetchPlayerCount()
playerCount, err = store.FetchPlayerCount()
if err != nil {
return err
}
battleCount, err = db.FetchBattleCount()
battleCount, err = store.FetchBattleCount()
if err != nil {
return err
}
classicSessionCount, err = db.FetchClassicSessionCount()
classicSessionCount, err = store.FetchClassicSessionCount()
if err != nil {
return err
}

View File

@ -1,18 +1,18 @@
/*
Copyright (C) 2024 - 2025 Pagefault Games
Copyright (C) 2024 - 2025 Pagefault Games
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
package db
@ -27,30 +27,30 @@ import (
"github.com/pagefaultgames/rogueserver/defs"
)
func AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
_, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
if err != nil {
return err
}
return nil
}
func AddAccountSession(username string, token []byte) error {
func (s *store) AddAccountSession(username string, token []byte) error {
_, err := handle.Exec("INSERT INTO sessions (uuid, token, expire) SELECT a.uuid, ?, DATE_ADD(UTC_TIMESTAMP(), INTERVAL 1 WEEK) FROM accounts a WHERE a.username = ?", token, username)
if err != nil {
return err
}
_, err = handle.Exec("UPDATE accounts SET lastLoggedIn = UTC_TIMESTAMP() WHERE username = ?", username)
if err != nil {
return err
}
return nil
}
func AddDiscordIdByUsername(discordId string, username string) error {
// (removed, now a method on store)
func (s *store) FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
var key, salt []byte
err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt)
if err != nil {
return nil, nil, err
}
return key, salt, nil
}
func (s *store) AddDiscordIdByUsername(discordId string, username string) error {
_, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE username = ?", discordId, username)
if err != nil {
return err
@ -59,7 +59,7 @@ func AddDiscordIdByUsername(discordId string, username string) error {
return nil
}
func AddGoogleIdByUsername(googleId string, username string) error {
func (s *store) AddGoogleIdByUsername(googleId string, username string) error {
_, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE username = ?", googleId, username)
if err != nil {
return err
@ -68,7 +68,7 @@ func AddGoogleIdByUsername(googleId string, username string) error {
return nil
}
func AddGoogleIdByUUID(googleId string, uuid []byte) error {
func (s *store) AddGoogleIdByUUID(googleId string, uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE uuid = ?", googleId, uuid)
if err != nil {
return err
@ -77,7 +77,7 @@ func AddGoogleIdByUUID(googleId string, uuid []byte) error {
return nil
}
func AddDiscordIdByUUID(discordId string, uuid []byte) error {
func (s *store) AddDiscordIdByUUID(discordId string, uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE uuid = ?", discordId, uuid)
if err != nil {
return err
@ -86,7 +86,7 @@ func AddDiscordIdByUUID(discordId string, uuid []byte) error {
return nil
}
func FetchUsernameByDiscordId(discordId string) (string, error) {
func (s *store) FetchUsernameByDiscordId(discordId string) (string, error) {
var username string
err := handle.QueryRow("SELECT username FROM accounts WHERE discordId = ?", discordId).Scan(&username)
if err != nil {
@ -96,7 +96,7 @@ func FetchUsernameByDiscordId(discordId string) (string, error) {
return username, nil
}
func FetchUsernameByGoogleId(googleId string) (string, error) {
func (s *store) FetchUsernameByGoogleId(googleId string) (string, error) {
var username string
err := handle.QueryRow("SELECT username FROM accounts WHERE googleId = ?", googleId).Scan(&username)
if err != nil {
@ -106,7 +106,7 @@ func FetchUsernameByGoogleId(googleId string) (string, error) {
return username, nil
}
func FetchDiscordIdByUsername(username string) (string, error) {
func (s *store) FetchDiscordIdByUsername(username string) (string, error) {
var discordId sql.NullString
err := handle.QueryRow("SELECT discordId FROM accounts WHERE username = ?", username).Scan(&discordId)
if err != nil {
@ -120,7 +120,7 @@ func FetchDiscordIdByUsername(username string) (string, error) {
return discordId.String, nil
}
func FetchGoogleIdByUsername(username string) (string, error) {
func (s *store) FetchGoogleIdByUsername(username string) (string, error) {
var googleId sql.NullString
err := handle.QueryRow("SELECT googleId FROM accounts WHERE username = ?", username).Scan(&googleId)
if err != nil {
@ -134,7 +134,7 @@ func FetchGoogleIdByUsername(username string) (string, error) {
return googleId.String, nil
}
func FetchDiscordIdByUUID(uuid []byte) (string, error) {
func (s *store) FetchDiscordIdByUUID(uuid []byte) (string, error) {
var discordId sql.NullString
err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId)
if err != nil {
@ -148,7 +148,7 @@ func FetchDiscordIdByUUID(uuid []byte) (string, error) {
return discordId.String, nil
}
func FetchGoogleIdByUUID(uuid []byte) (string, error) {
func (s *store) FetchGoogleIdByUUID(uuid []byte) (string, error) {
var googleId sql.NullString
err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId)
if err != nil {
@ -162,7 +162,7 @@ func FetchGoogleIdByUUID(uuid []byte) (string, error) {
return googleId.String, nil
}
func FetchUsernameBySessionToken(token []byte) (string, error) {
func (s *store) FetchUsernameBySessionToken(token []byte) (string, error) {
var username string
err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username)
if err != nil {
@ -172,7 +172,7 @@ func FetchUsernameBySessionToken(token []byte) (string, error) {
return username, nil
}
func CheckUsernameExists(username string) (string, error) {
func (s *store) CheckUsernameExists(username string) (string, error) {
var dbUsername sql.NullString
err := handle.QueryRow("SELECT username FROM accounts WHERE username = ?", username).Scan(&dbUsername)
if err != nil {
@ -185,7 +185,7 @@ func CheckUsernameExists(username string) (string, error) {
return dbUsername.String, nil
}
func FetchLastLoggedInDateByUsername(username string) (string, error) {
func (s *store) FetchLastLoggedInDateByUsername(username string) (string, error) {
var lastLoggedIn sql.NullString
err := handle.QueryRow("SELECT lastLoggedIn FROM accounts WHERE username = ?", username).Scan(&lastLoggedIn)
if err != nil {
@ -206,7 +206,7 @@ type AdminSearchResponse struct {
Registered string `json:"registered"`
}
func FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) {
func (s *store) FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) {
var username, discordId, googleId, lastActivity, registered sql.NullString
var adminResponse AdminSearchResponse
@ -226,7 +226,7 @@ func FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error)
return adminResponse, nil
}
func UpdateAccountPassword(uuid, key, salt []byte) error {
func (s *store) UpdateAccountPassword(uuid, key, salt []byte) error {
_, err := handle.Exec("UPDATE accounts SET hash = ?, salt = ? WHERE uuid = ?", key, salt, uuid)
if err != nil {
return err
@ -235,7 +235,7 @@ func UpdateAccountPassword(uuid, key, salt []byte) error {
return nil
}
func UpdateAccountLastActivity(uuid []byte) error {
func (s *store) UpdateAccountLastActivity(uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET lastActivity = UTC_TIMESTAMP() WHERE uuid = ?", uuid)
if err != nil {
return err
@ -244,7 +244,7 @@ func UpdateAccountLastActivity(uuid []byte) error {
return nil
}
func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error {
func (s *store) UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error {
var columns = []string{"playTime", "battles", "classicSessionsPlayed", "sessionsWon", "highestEndlessWave", "highestLevel", "pokemonSeen", "pokemonDefeated", "pokemonCaught", "pokemonHatched", "eggsPulled", "regularVouchers", "plusVouchers", "premiumVouchers", "goldenVouchers"}
var statCols []string
@ -321,7 +321,7 @@ func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[str
return nil
}
func SetAccountBanned(uuid []byte, banned bool) error {
func (s *store) SetAccountBanned(uuid []byte, banned bool) error {
_, err := handle.Exec("UPDATE accounts SET banned = ? WHERE uuid = ?", banned, uuid)
if err != nil {
return err
@ -330,17 +330,16 @@ func SetAccountBanned(uuid []byte, banned bool) error {
return nil
}
func FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
var key, salt []byte
err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt)
func (s *store) AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
_, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
if err != nil {
return nil, nil, err
return err
}
return key, salt, nil
return nil
}
func FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
func (s *store) FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
err = handle.QueryRow("SELECT trainerId, secretId FROM accounts WHERE uuid = ?", uuid).Scan(&trainerId, &secretId)
if err != nil {
return 0, 0, err
@ -349,7 +348,7 @@ func FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
return trainerId, secretId, nil
}
func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
func (s *store) UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET trainerId = ?, secretId = ? WHERE uuid = ?", trainerId, secretId, uuid)
if err != nil {
return err
@ -358,12 +357,12 @@ func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
return nil
}
func IsActiveSession(uuid []byte, sessionId string) (bool, error) {
func (s *store) IsActiveSession(uuid []byte, sessionId string) (bool, error) {
var id string
err := handle.QueryRow("SELECT clientSessionId FROM activeClientSessions WHERE uuid = ?", uuid).Scan(&id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
err = UpdateActiveSession(uuid, sessionId)
err = s.UpdateActiveSession(uuid, sessionId)
if err != nil {
return false, err
}
@ -377,7 +376,7 @@ func IsActiveSession(uuid []byte, sessionId string) (bool, error) {
return id == "" || id == sessionId, nil
}
func UpdateActiveSession(uuid []byte, clientSessionId string) error {
func (s *store) UpdateActiveSession(uuid []byte, clientSessionId string) error {
_, err := handle.Exec("INSERT INTO activeClientSessions (uuid, clientSessionId) VALUES (?, ?) ON DUPLICATE KEY UPDATE clientSessionId = ?", uuid, clientSessionId, clientSessionId)
if err != nil {
return err
@ -386,7 +385,7 @@ func UpdateActiveSession(uuid []byte, clientSessionId string) error {
return nil
}
func FetchUUIDFromToken(token []byte) ([]byte, error) {
func (s *store) FetchUUIDFromToken(token []byte) ([]byte, error) {
var uuid []byte
err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ?", token).Scan(&uuid)
if err != nil {
@ -396,7 +395,7 @@ func FetchUUIDFromToken(token []byte) ([]byte, error) {
return uuid, nil
}
func RemoveSessionFromToken(token []byte) error {
func (s *store) RemoveSessionFromToken(token []byte) error {
_, err := handle.Exec("DELETE FROM sessions WHERE token = ?", token)
if err != nil {
return err
@ -405,7 +404,7 @@ func RemoveSessionFromToken(token []byte) error {
return nil
}
func RemoveSessionsFromUUID(uuid []byte) error {
func (s *store) RemoveSessionsFromUUID(uuid []byte) error {
_, err := handle.Exec("DELETE FROM sessions WHERE uuid = ?", uuid)
if err != nil {
return err
@ -414,7 +413,7 @@ func RemoveSessionsFromUUID(uuid []byte) error {
return nil
}
func FetchUsernameFromUUID(uuid []byte) (string, error) {
func (s *store) FetchUsernameFromUUID(uuid []byte) (string, error) {
var username string
err := handle.QueryRow("SELECT username FROM accounts WHERE uuid = ?", uuid).Scan(&username)
if err != nil {
@ -424,7 +423,7 @@ func FetchUsernameFromUUID(uuid []byte) (string, error) {
return username, nil
}
func FetchUUIDFromUsername(username string) ([]byte, error) {
func (s *store) FetchUUIDFromUsername(username string) ([]byte, error) {
var uuid []byte
err := handle.QueryRow("SELECT uuid FROM accounts WHERE username = ?", username).Scan(&uuid)
if err != nil {
@ -434,7 +433,7 @@ func FetchUUIDFromUsername(username string) ([]byte, error) {
return uuid, nil
}
func RemoveDiscordIdByUUID(uuid []byte) error {
func (s *store) RemoveDiscordIdByUUID(uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE uuid = ?", uuid)
if err != nil {
return err
@ -443,7 +442,7 @@ func RemoveDiscordIdByUUID(uuid []byte) error {
return nil
}
func RemoveGoogleIdByUUID(uuid []byte) error {
func (s *store) RemoveGoogleIdByUUID(uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE uuid = ?", uuid)
if err != nil {
return err
@ -452,7 +451,7 @@ func RemoveGoogleIdByUUID(uuid []byte) error {
return nil
}
func RemoveGoogleIdByUsername(username string) error {
func (s *store) RemoveGoogleIdByUsername(username string) error {
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE username = ?", username)
if err != nil {
return err
@ -461,7 +460,7 @@ func RemoveGoogleIdByUsername(username string) error {
return nil
}
func RemoveDiscordIdByUsername(username string) error {
func (s *store) RemoveDiscordIdByUsername(username string) error {
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE username = ?", username)
if err != nil {
return err
@ -470,7 +469,7 @@ func RemoveDiscordIdByUsername(username string) error {
return nil
}
func RemoveDiscordIdByDiscordId(discordId string) error {
func (s *store) RemoveDiscordIdByDiscordId(discordId string) error {
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE discordId = ?", discordId)
if err != nil {
return err
@ -479,7 +478,7 @@ func RemoveDiscordIdByDiscordId(discordId string) error {
return nil
}
func RemoveGoogleIdByDiscordId(discordId string) error {
func (s *store) RemoveGoogleIdByDiscordId(discordId string) error {
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE discordId = ?", discordId)
if err != nil {
return err

View File

@ -23,7 +23,7 @@ import (
"github.com/pagefaultgames/rogueserver/defs"
)
func TryAddDailyRun(seed string) (string, error) {
func (s *store) TryAddDailyRun(seed string) (string, error) {
var actualSeed string
err := handle.QueryRow("INSERT INTO dailyRuns (seed, date) VALUES (?, UTC_DATE()) ON DUPLICATE KEY UPDATE date = date RETURNING seed", seed).Scan(&actualSeed)
if err != nil {
@ -33,7 +33,7 @@ func TryAddDailyRun(seed string) (string, error) {
return actualSeed, nil
}
func GetDailyRunSeed() (string, error) {
func (s *store) GetDailyRunSeed() (string, error) {
var seed string
err := handle.QueryRow("SELECT seed FROM dailyRuns WHERE date = UTC_DATE()").Scan(&seed)
if err != nil {
@ -43,7 +43,7 @@ func GetDailyRunSeed() (string, error) {
return seed, nil
}
func AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
func (s *store) AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
_, err := handle.Exec("INSERT INTO accountDailyRuns (uuid, date, score, wave, timestamp) VALUES (?, UTC_DATE(), ?, ?, UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE score = GREATEST(score, ?), wave = GREATEST(wave, ?), timestamp = IF(score < ?, UTC_TIMESTAMP(), timestamp)", uuid, score, wave, score, wave, score)
if err != nil {
return err
@ -52,7 +52,7 @@ func AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
return nil
}
func FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
func (s *store) FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
var rankings []defs.DailyRanking
offset := (page - 1) * 10
@ -85,7 +85,7 @@ func FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
return rankings, nil
}
func FetchRankingPageCount(category int) (int, error) {
func (s *store) FetchRankingPageCount(category int) (int, error) {
var query string
switch category {
case 0:

View File

@ -31,6 +31,12 @@ import (
var handle *sql.DB
var s3client *s3.Client
// internal type used to implement the Store interface
type store struct{}
// Store is the global instance for DB access.
var Store = &store{}
func Init(username, password, protocol, address, database string) error {
var err error

View File

@ -17,7 +17,7 @@
package db
func FetchPlayerCount() (int, error) {
func (s *store) FetchPlayerCount() (int, error) {
var playerCount int
err := handle.QueryRow("SELECT COUNT(*) FROM accounts WHERE lastActivity > DATE_SUB(UTC_TIMESTAMP(), INTERVAL 5 MINUTE)").Scan(&playerCount)
if err != nil {
@ -27,7 +27,7 @@ func FetchPlayerCount() (int, error) {
return playerCount, nil
}
func FetchBattleCount() (int, error) {
func (s *store) FetchBattleCount() (int, error) {
var battleCount int
err := handle.QueryRow("SELECT COALESCE(SUM(s.battles), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&battleCount)
if err != nil {
@ -37,7 +37,7 @@ func FetchBattleCount() (int, error) {
return battleCount, nil
}
func FetchClassicSessionCount() (int, error) {
func (s *store) FetchClassicSessionCount() (int, error) {
var classicSessionCount int
err := handle.QueryRow("SELECT COALESCE(SUM(s.classicSessionsPlayed), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&classicSessionCount)
if err != nil {

61
db/s3.go Normal file
View File

@ -0,0 +1,61 @@
package db
import (
"bytes"
"context"
"encoding/json"
"os"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/pagefaultgames/rogueserver/defs"
)
func (s *store) GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) {
var system defs.SystemSaveData
username, err := Store.FetchUsernameFromUUID(uuid)
if err != nil {
return system, err
}
resp, err := s3client.GetObject(context.TODO(), &s3.GetObjectInput{
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
Key: aws.String(username),
})
if err != nil {
return system, err
}
err = json.NewDecoder(resp.Body).Decode(&system)
if err != nil {
return system, err
}
return system, nil
}
func (s *store) StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error {
username, err := s.FetchUsernameFromUUID(uuid)
if err != nil {
return err
}
buf := new(bytes.Buffer)
err = json.NewEncoder(buf).Encode(data)
if err != nil {
return err
}
_, err = s3client.PutObject(context.Background(), &s3.PutObjectInput{
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
Key: aws.String(username),
Body: buf,
})
if err != nil {
return err
}
return nil
}

View File

@ -19,19 +19,13 @@ package db
import (
"bytes"
"context"
"encoding/gob"
"encoding/json"
"os"
"github.com/klauspost/compress/zstd"
"github.com/pagefaultgames/rogueserver/defs"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
)
func TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
func (s *store) TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
var count int
err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count)
if err != nil {
@ -48,7 +42,7 @@ func TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
return true, nil
}
func ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
func (s *store) ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
var count int
err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count)
if err != nil {
@ -58,7 +52,7 @@ func ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
return count > 0, nil
}
func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
func (s *store) ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
var system defs.SystemSaveData
var data []byte
@ -82,7 +76,7 @@ func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
return system, nil
}
func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
func (s *store) StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
buf := new(bytes.Buffer)
zw, err := zstd.NewWriter(buf)
@ -108,32 +102,7 @@ func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
return nil
}
func StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error {
username, err := FetchUsernameFromUUID(uuid)
if err != nil {
return err
}
buf := new(bytes.Buffer)
err = json.NewEncoder(buf).Encode(data)
if err != nil {
return err
}
_, err = s3client.PutObject(context.Background(), &s3.PutObjectInput{
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
Key: aws.String(username),
Body: buf,
})
if err != nil {
return err
}
return nil
}
func DeleteSystemSaveData(uuid []byte) error {
func (s *store) DeleteSystemSaveData(uuid []byte) error {
_, err := handle.Exec("DELETE FROM systemSaveData WHERE uuid = ?", uuid)
if err != nil {
return err
@ -142,7 +111,7 @@ func DeleteSystemSaveData(uuid []byte) error {
return nil
}
func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
func (s *store) ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
var session defs.SessionSaveData
var data []byte
@ -166,7 +135,7 @@ func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
return session, nil
}
func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
func (s *store) GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
var slot int
err := handle.QueryRow("SELECT slot FROM sessionSaveData WHERE uuid = ? ORDER BY timestamp DESC, slot ASC LIMIT 1", uuid).Scan(&slot)
if err != nil {
@ -176,7 +145,7 @@ func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
return slot, nil
}
func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error {
func (s *store) StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error {
buf := new(bytes.Buffer)
zw, err := zstd.NewWriter(buf)
@ -202,7 +171,7 @@ func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) erro
return nil
}
func DeleteSessionSaveData(uuid []byte, slot int) error {
func (s *store) DeleteSessionSaveData(uuid []byte, slot int) error {
_, err := handle.Exec("DELETE FROM sessionSaveData WHERE uuid = ? AND slot = ?", uuid, slot)
if err != nil {
return err
@ -210,27 +179,3 @@ func DeleteSessionSaveData(uuid []byte, slot int) error {
return nil
}
func GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) {
var system defs.SystemSaveData
username, err := FetchUsernameFromUUID(uuid)
if err != nil {
return system, err
}
resp, err := s3client.GetObject(context.TODO(), &s3.GetObjectInput{
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
Key: aws.String(username),
})
if err != nil {
return system, err
}
err = json.NewDecoder(resp.Body).Decode(&system)
if err != nil {
return system, err
}
return system, nil
}