From 40323809f779b804db34f3edc02e421530b18f1f Mon Sep 17 00:00:00 2001 From: Sirz Benjie <142067137+SirzBenjie@users.noreply.github.com> Date: Tue, 16 Sep 2025 15:01:00 -0500 Subject: [PATCH] Adjust db to allow mocking for unit tests --- api/account/changepw.go | 14 ++- api/account/discord.go | 18 +++- api/account/google.go | 11 ++- api/account/info.go | 12 +-- api/account/login.go | 25 +++-- api/account/login_test.go | 154 ++++++++++++++++++++++++++++++ api/account/logout.go | 12 ++- api/account/register.go | 11 ++- api/common.go | 4 +- api/daily/common.go | 4 +- api/daily/rankings.go | 9 +- api/daily/rankingspagecount.go | 10 +- api/endpoints.go | 165 ++++++++++++++++++--------------- api/savedata/clear.go | 23 +++-- api/savedata/delete.go | 14 ++- api/savedata/newclear.go | 12 ++- api/savedata/session.go | 26 ++++-- api/savedata/system.go | 35 +++++-- api/savedata/update.go | 15 ++- api/stats.go | 19 ++-- db/account.go | 123 ++++++++++++------------ db/daily.go | 10 +- db/db.go | 6 ++ db/game.go | 6 +- db/s3.go | 61 ++++++++++++ db/savedata.go | 73 ++------------- 26 files changed, 577 insertions(+), 295 deletions(-) create mode 100644 api/account/login_test.go create mode 100644 db/s3.go diff --git a/api/account/changepw.go b/api/account/changepw.go index e9726ec..edb36da 100644 --- a/api/account/changepw.go +++ b/api/account/changepw.go @@ -20,11 +20,15 @@ package account import ( "crypto/rand" "fmt" - - "github.com/pagefaultgames/rogueserver/db" ) -func ChangePW(uuid []byte, password string) error { +// Interface for database operations needed for changing password. +type ChangePWStore interface { + RemoveSessionsFromUUID(uuid []byte) error + UpdateAccountPassword(uuid []byte, newKey []byte, newSalt []byte) error +} + +func ChangePW[T ChangePWStore](store T, uuid []byte, password string) error { if len(password) < 6 { return fmt.Errorf("invalid password") } @@ -35,12 +39,12 @@ func ChangePW(uuid []byte, password string) error { return fmt.Errorf("failed to generate salt: %s", err) } - err = db.RemoveSessionsFromUUID(uuid) + err = store.RemoveSessionsFromUUID(uuid) if err != nil { return fmt.Errorf("failed to remove sessions: %s", err) } - err = db.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt) + err = store.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt) if err != nil { return fmt.Errorf("failed to add account record: %s", err) } diff --git a/api/account/discord.go b/api/account/discord.go index 0be4087..f959eb8 100644 --- a/api/account/discord.go +++ b/api/account/discord.go @@ -35,14 +35,24 @@ var ( DiscordGuildID string ) -func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) { +type DiscordProvider interface { + HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) + RetrieveDiscordId(code string) (string, error) + IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) +} + +type discordProvider struct{} + +var Discord = &discordProvider{} + +func (s *discordProvider) HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) { code := r.URL.Query().Get("code") if code == "" { http.Redirect(w, r, GameURL, http.StatusSeeOther) return "", errors.New("code is empty") } - discordId, err := RetrieveDiscordId(code) + discordId, err := s.RetrieveDiscordId(code) if err != nil { http.Redirect(w, r, GameURL, http.StatusSeeOther) return "", err @@ -51,7 +61,7 @@ func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, erro return discordId, nil } -func RetrieveDiscordId(code string) (string, error) { +func (s *discordProvider) RetrieveDiscordId(code string) (string, error) { v := make(url.Values) v.Set("client_id", DiscordClientID) v.Set("client_secret", DiscordClientSecret) @@ -112,7 +122,7 @@ func RetrieveDiscordId(code string) (string, error) { return user.Id, nil } -func IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) { +func (s *discordProvider) IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) { // fetch all roles from discord roles, err := DiscordSession.GuildRoles(discordGuildID) if err != nil { diff --git a/api/account/google.go b/api/account/google.go index 30c5ae5..28ee6b4 100644 --- a/api/account/google.go +++ b/api/account/google.go @@ -32,7 +32,16 @@ var ( GoogleCallbackURL string ) -func HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) { +type GoogleProvider interface { + HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) + RetrieveGoogleId(code string) (string, error) +} + +type googleProvider struct{} + +var Google = &googleProvider{} + +func (g *googleProvider) HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) { code := r.URL.Query().Get("code") if code == "" { http.Redirect(w, r, GameURL, http.StatusSeeOther) diff --git a/api/account/info.go b/api/account/info.go index 32cf9ff..67d96f6 100644 --- a/api/account/info.go +++ b/api/account/info.go @@ -17,10 +17,6 @@ package account -import ( - "github.com/pagefaultgames/rogueserver/db" -) - type InfoResponse struct { Username string `json:"username"` DiscordId string `json:"discordId"` @@ -29,9 +25,13 @@ type InfoResponse struct { HasAdminRole bool `json:"hasAdminRole"` } +type InfoStore interface { + GetLatestSessionSaveDataSlot(uuid []byte) (int, error) +} + // /account/info - get account info -func Info(username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) { - slot, _ := db.GetLatestSessionSaveDataSlot(uuid) +func Info[T InfoStore](store T, username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) { + slot, _ := store.GetLatestSessionSaveDataSlot(uuid) response := InfoResponse{ Username: username, LastSessionSlot: slot, diff --git a/api/account/login.go b/api/account/login.go index a283444..41a373f 100644 --- a/api/account/login.go +++ b/api/account/login.go @@ -24,14 +24,18 @@ import ( "encoding/base64" "errors" "fmt" - - "github.com/pagefaultgames/rogueserver/db" ) type LoginResponse GenericAuthResponse +// Interface for database operations needed for login. +type LoginStore interface { + FetchAccountKeySaltFromUsername(username string) (key, salt []byte, err error) + AddAccountSession(username string, token []byte) error +} + // /account/login - log into account -func Login(username, password string) (LoginResponse, error) { +func Login[T LoginStore](store T, username, password string) (LoginResponse, error) { var response LoginResponse if !isValidUsername(username) { @@ -42,12 +46,11 @@ func Login(username, password string) (LoginResponse, error) { return response, fmt.Errorf("invalid password") } - key, salt, err := db.FetchAccountKeySaltFromUsername(username) + key, salt, err := store.FetchAccountKeySaltFromUsername(username) if err != nil { if errors.Is(err, sql.ErrNoRows) { return response, fmt.Errorf("account doesn't exist") } - return response, err } @@ -55,8 +58,7 @@ func Login(username, password string) (LoginResponse, error) { return response, fmt.Errorf("password doesn't match") } - response.Token, err = GenerateTokenForUsername(username) - + response.Token, err = GenerateTokenForUsername(store, username) if err != nil { return response, fmt.Errorf("failed to generate token: %s", err) } @@ -64,14 +66,19 @@ func Login(username, password string) (LoginResponse, error) { return response, nil } -func GenerateTokenForUsername(username string) (string, error) { +type GenerateTokenForUsernameStore interface { + AddAccountSession(username string, token []byte) error +} + +// GenerateTokenForUsername generates a session token and stores it using the provided DBAccountStore. +func GenerateTokenForUsername[T GenerateTokenForUsernameStore](store T, username string) (string, error) { token := make([]byte, TokenSize) _, err := rand.Read(token) if err != nil { return "", fmt.Errorf("failed to generate token: %s", err) } - err = db.AddAccountSession(username, token) + err = store.AddAccountSession(username, token) if err != nil { return "", fmt.Errorf("failed to add account session") } diff --git a/api/account/login_test.go b/api/account/login_test.go new file mode 100644 index 0000000..c883065 --- /dev/null +++ b/api/account/login_test.go @@ -0,0 +1,154 @@ +package account + +import ( + "database/sql" + "errors" + "testing" +) + +func defaultMockStore() *mockDBAccountStore { + return &mockDBAccountStore{ + FetchFunc: func(username string) ([]byte, []byte, error) { + return []byte("key"), []byte("salt"), nil + }, + AddSessionFunc: func(username string, token []byte) error { return nil }, + } +} + +type mockDBAccountStore struct { + FetchFunc func(username string) ([]byte, []byte, error) + AddSessionFunc func(username string, token []byte) error +} + +func (m *mockDBAccountStore) FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) { + return m.FetchFunc(username) +} +func (m *mockDBAccountStore) AddAccountSession(username string, token []byte) error { + return m.AddSessionFunc(username, token) +} + +func TestLogin(t *testing.T) { + t.Run("UsernameMinLength", func(t *testing.T) { + store := defaultMockStore() + _, err := Login(store, "a", "password123") + if err == nil { + t.Errorf("expected error due to password mismatch or DB, got nil") + } + }) + t.Run("UsernameMaxLength", func(t *testing.T) { + uname := "abcdefghijklmnop" + store := defaultMockStore() + _, err := Login(store, uname, "password123") + if err == nil { + t.Errorf("expected error due to password mismatch or DB, got nil") + } + }) + t.Run("UsernameTooLong", func(t *testing.T) { + uname := "abcdefghijklmnopq" + store := defaultMockStore() + _, err := Login(store, uname, "password123") + if err == nil || err.Error() != "invalid username" { + t.Errorf("expected invalid username error for too long username, got: %v", err) + } + }) + t.Run("UsernameWithInvalidChars", func(t *testing.T) { + store := defaultMockStore() + _, err := Login(store, "user!@#", "password123") + if err == nil || err.Error() != "invalid username" { + t.Errorf("expected invalid username error for special chars, got: %v", err) + } + }) + t.Run("EmptyUsername", func(t *testing.T) { + store := defaultMockStore() + _, err := Login(store, "", "password123") + if err == nil || err.Error() != "invalid username" { + t.Errorf("expected invalid username error for empty username, got: %v", err) + } + }) + t.Run("EmptyPassword", func(t *testing.T) { + store := defaultMockStore() + _, err := Login(store, "validuser", "") + if err == nil || err.Error() != "invalid password" { + t.Errorf("expected invalid password error for empty password, got: %v", err) + } + }) + t.Run("MinPasswordLength", func(t *testing.T) { + store := defaultMockStore() + _, err := Login(store, "validuser", "123456") + if err == nil { + t.Errorf("expected error due to password mismatch or DB, got nil") + } + }) + t.Run("PasswordWithSpecialChars", func(t *testing.T) { + store := defaultMockStore() + _, err := Login(store, "validuser", "p@$$w0rd!") + if err == nil { + t.Errorf("expected error due to password mismatch or DB, got nil") + } + }) + t.Run("DBUnexpectedError", func(t *testing.T) { + store := defaultMockStore() + store.FetchFunc = func(username string) ([]byte, []byte, error) { + return nil, nil, errors.New("some db error") + } + _, err := Login(store, "validuser", "password123") + if err == nil || err.Error() != "some db error" { + t.Errorf("expected DB error to propagate, got: %v", err) + } + }) + t.Run("InvalidUsername", func(t *testing.T) { + store := defaultMockStore() + _, err := Login(store, "!invaliduser", "password123") + if err == nil || err.Error() != "invalid username" { + t.Errorf("expected invalid username error, got: %v", err) + } + }) + t.Run("ShortPassword", func(t *testing.T) { + store := defaultMockStore() + _, err := Login(store, "validuser", "123") + if err == nil || err.Error() != "invalid password" { + t.Errorf("expected invalid password error, got: %v", err) + } + }) + t.Run("AccountDoesNotExist", func(t *testing.T) { + store := defaultMockStore() + store.FetchFunc = func(username string) ([]byte, []byte, error) { + return nil, nil, sql.ErrNoRows + } + _, err := Login(store, "nonexistent", "password123") + if err == nil || err.Error() != "account doesn't exist" { + t.Errorf("expected account doesn't exist error, got: %v", err) + } + }) + t.Run("PasswordMismatch", func(t *testing.T) { + correctSalt := []byte("somesalt") + correctKey := []byte("correctkey") + store := defaultMockStore() + store.FetchFunc = func(username string) ([]byte, []byte, error) { + return correctKey, correctSalt, nil + } + _, err := Login(store, "validuser", "wrongpassword") + if err == nil || err.Error() != "password doesn't match" { + t.Errorf("expected password doesn't match error, got: %v", err) + } + }) + t.Run("Success", func(t *testing.T) { + correctSalt := []byte("somesalt") + password := "goodpassword" + correctKey := deriveArgon2IDKey([]byte(password), correctSalt) + store := defaultMockStore() + store.FetchFunc = func(username string) ([]byte, []byte, error) { + return correctKey, correctSalt, nil + } + store.AddSessionFunc = func(username string, token []byte) error { + return nil + } + resp, err := Login(store, "validuser", password) + if err != nil { + t.Errorf("expected success, got error: %v", err) + } + if resp.Token == "" { + t.Errorf("expected token to be set on success") + } + }) +} diff --git a/api/account/logout.go b/api/account/logout.go index ae3b26f..385f0e7 100644 --- a/api/account/logout.go +++ b/api/account/logout.go @@ -21,13 +21,17 @@ import ( "database/sql" "errors" "fmt" - - "github.com/pagefaultgames/rogueserver/db" ) // /account/logout - log out of account -func Logout(token []byte) error { - err := db.RemoveSessionFromToken(token) + +// Interface for database operations needed for logout. +type LogoutStore interface { + RemoveSessionFromToken(token []byte) error +} + +func Logout[T LogoutStore](store T, token []byte) error { + err := store.RemoveSessionFromToken(token) if err != nil { if errors.Is(err, sql.ErrNoRows) { return fmt.Errorf("token not found") diff --git a/api/account/register.go b/api/account/register.go index 06639e5..f7fc2c2 100644 --- a/api/account/register.go +++ b/api/account/register.go @@ -20,12 +20,15 @@ package account import ( "crypto/rand" "fmt" - - "github.com/pagefaultgames/rogueserver/db" ) +// Interface for database operations needed for registration. +type RegisterStore interface { + AddAccountRecord(uuid []byte, username string, passwordHash []byte, salt []byte) error +} + // /account/register - register account -func Register(username, password string) error { +func Register[T RegisterStore](store T, username, password string) error { if !isValidUsername(username) { return fmt.Errorf("invalid username") } @@ -46,7 +49,7 @@ func Register(username, password string) error { return fmt.Errorf("failed to generate salt: %s", err) } - err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt) + err = store.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt) if err != nil { return fmt.Errorf("failed to add account record: %s", err) } diff --git a/api/common.go b/api/common.go index f7f4a15..a7324ad 100644 --- a/api/common.go +++ b/api/common.go @@ -30,7 +30,7 @@ import ( ) func Init(mux *http.ServeMux) error { - err := scheduleStatRefresh() + err := scheduleStatRefresh(db.Store) if err != nil { return err } @@ -109,7 +109,7 @@ func tokenAndUuidFromRequest(r *http.Request) ([]byte, []byte, error) { return nil, nil, err } - uuid, err := db.FetchUUIDFromToken(token) + uuid, err := db.Store.FetchUUIDFromToken(token) if err != nil { return nil, nil, fmt.Errorf("failed to validate token: %s", err) } diff --git a/api/daily/common.go b/api/daily/common.go index d31d78e..15a662b 100644 --- a/api/daily/common.go +++ b/api/daily/common.go @@ -61,7 +61,7 @@ func Init() error { secret = newSecret } - seed, err := db.TryAddDailyRun(Seed()) + seed, err := db.Store.TryAddDailyRun(Seed()) if err != nil { log.Print(err) } @@ -71,7 +71,7 @@ func Init() error { _, err = scheduler.AddFunc("@daily", func() { time.Sleep(time.Second) - seed, err = db.TryAddDailyRun(Seed()) + seed, err = db.Store.TryAddDailyRun(Seed()) if err != nil { log.Printf("error while recording new daily: %s", err) } else { diff --git a/api/daily/rankings.go b/api/daily/rankings.go index 175a486..1acc1e0 100644 --- a/api/daily/rankings.go +++ b/api/daily/rankings.go @@ -22,9 +22,14 @@ import ( "github.com/pagefaultgames/rogueserver/defs" ) +// Interface for database operations needed for fetching rankings. +type RankingsStore interface { + FetchRankings(category, page int) ([]defs.DailyRanking, error) +} + // /daily/rankings - fetch daily rankings -func Rankings(category, page int) ([]defs.DailyRanking, error) { - rankings, err := db.FetchRankings(category, page) +func Rankings[T RankingsStore](store T, category, page int) ([]defs.DailyRanking, error) { + rankings, err := db.Store.FetchRankings(category, page) if err != nil { return rankings, err } diff --git a/api/daily/rankingspagecount.go b/api/daily/rankingspagecount.go index 5950c77..76ddabc 100644 --- a/api/daily/rankingspagecount.go +++ b/api/daily/rankingspagecount.go @@ -17,13 +17,13 @@ package daily -import ( - "github.com/pagefaultgames/rogueserver/db" -) +type RankingPageCountStore interface { + FetchRankingPageCount(category int) (int, error) +} // /daily/rankingpagecount - fetch daily ranking page count -func RankingPageCount(category int) (int, error) { - pageCount, err := db.FetchRankingPageCount(category) +func RankingPageCount[T RankingPageCountStore](store T, category int) (int, error) { + pageCount, err := store.FetchRankingPageCount(category) if err != nil { return pageCount, err } diff --git a/api/endpoints.go b/api/endpoints.go index 8db9f1a..5630bd8 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -50,17 +50,17 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) { return } - username, err := db.FetchUsernameFromUUID(uuid) + username, err := db.Store.FetchUsernameFromUUID(uuid) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } - discordId, err := db.FetchDiscordIdByUsername(username) + discordId, err := db.Store.FetchDiscordIdByUsername(username) if err != nil && !errors.Is(err, sql.ErrNoRows) { httpError(w, r, err, http.StatusInternalServerError) return } - googleId, err := db.FetchGoogleIdByUsername(username) + googleId, err := db.Store.FetchGoogleIdByUsername(username) if err != nil && !errors.Is(err, sql.ErrNoRows) { httpError(w, r, err, http.StatusInternalServerError) return @@ -68,10 +68,10 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) { var hasAdminRole bool if discordId != "" { - hasAdminRole, _ = account.IsUserDiscordAdmin(discordId, account.DiscordGuildID) + hasAdminRole, _ = account.Discord.IsUserDiscordAdmin(discordId, account.DiscordGuildID) } - response, err := account.Info(username, discordId, googleId, uuid, hasAdminRole) + response, err := account.Info(db.Store, username, discordId, googleId, uuid, hasAdminRole) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -81,7 +81,7 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) { } func handleAccountRegister(w http.ResponseWriter, r *http.Request) { - err := account.Register(r.PostFormValue("username"), r.PostFormValue("password")) + err := account.Register(db.Store, r.PostFormValue("username"), r.PostFormValue("password")) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -91,7 +91,7 @@ func handleAccountRegister(w http.ResponseWriter, r *http.Request) { } func handleAccountLogin(w http.ResponseWriter, r *http.Request) { - response, err := account.Login(r.PostFormValue("username"), r.PostFormValue("password")) + response, err := account.Login(db.Store, r.PostFormValue("username"), r.PostFormValue("password")) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -107,20 +107,20 @@ func handleAccountChangePW(w http.ResponseWriter, r *http.Request) { return } - err = account.ChangePW(uuid, r.PostFormValue("password")) + err = account.ChangePW(db.Store, uuid, r.PostFormValue("password")) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } - username, err := db.FetchUsernameFromUUID(uuid) + username, err := db.Store.FetchUsernameFromUUID(uuid) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } // create a new session with these credentials - response, err := account.Login(username, r.Form.Get("password")) + response, err := account.Login(db.Store, username, r.Form.Get("password")) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -136,7 +136,7 @@ func handleAccountLogout(w http.ResponseWriter, r *http.Request) { return } - err = account.Logout(token) + err = account.Logout(db.Store, token) if err != nil { // also possible for InternalServerError but that's unlikely unless the server blew up httpError(w, r, err, http.StatusUnauthorized) @@ -183,7 +183,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) { return } - err = db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId")) + err = db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId")) if err != nil { httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) return @@ -191,7 +191,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) { switch r.PathValue("action") { case "get": - save, err := savedata.GetSession(uuid, slot) + save, err := savedata.GetSession(db.Store, uuid, slot) if err != nil { if errors.Is(err, savedata.ErrSaveNotExist) { http.Error(w, err.Error(), http.StatusNotFound) @@ -211,7 +211,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) { return } - existingSave, err := savedata.GetSession(uuid, slot) + existingSave, err := savedata.GetSession(db.Store, uuid, slot) if err != nil { if !errors.Is(err, savedata.ErrSaveNotExist) { httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError) @@ -224,7 +224,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) { } } - err = savedata.UpdateSession(uuid, slot, session) + err = savedata.UpdateSession(db.Store, uuid, slot, session) if err != nil { httpError(w, r, fmt.Errorf("failed to put session data: %s", err), http.StatusInternalServerError) return @@ -239,13 +239,13 @@ func handleSession(w http.ResponseWriter, r *http.Request) { return } - seed, err := db.GetDailyRunSeed() + seed, err := db.Store.GetDailyRunSeed() if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } - resp, err := savedata.Clear(uuid, slot, seed, session) + resp, err := savedata.Clear(db.Store, uuid, slot, seed, session) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -253,7 +253,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) { writeJSON(w, r, resp) case "newclear": - resp, err := savedata.NewClear(uuid, slot) + resp, err := savedata.NewClear(db.Store, uuid, slot) if err != nil { httpError(w, r, fmt.Errorf("failed to read new clear: %s", err), http.StatusInternalServerError) return @@ -261,7 +261,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) { writeJSON(w, r, resp) case "delete": - err := savedata.DeleteSession(uuid, slot) + err := savedata.DeleteSession(db.Store, uuid, slot) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -301,7 +301,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) { return } - active, err := db.IsActiveSession(uuid, data.ClientSessionId) + active, err := db.Store.IsActiveSession(uuid, data.ClientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) return @@ -312,7 +312,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) { return } - storedTrainerId, storedSecretId, err := db.FetchTrainerIds(uuid) + storedTrainerId, storedSecretId, err := db.Store.FetchTrainerIds(uuid) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -324,14 +324,14 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) { return } } else { - err = db.UpdateTrainerIds(data.System.TrainerId, data.System.SecretId, uuid) + err = db.Store.UpdateTrainerIds(data.System.TrainerId, data.System.SecretId, uuid) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } } - oldSystem, err := savedata.GetSystem(uuid) + oldSystem, err := savedata.GetSystem(db.Store, uuid) if err != nil { if !errors.Is(err, savedata.ErrSaveNotExist) { httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError) @@ -356,7 +356,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) { } } - existingSave, err := savedata.GetSession(uuid, data.SessionSlotId) + existingSave, err := savedata.GetSession(db.Store, uuid, data.SessionSlotId) if err != nil { if !errors.Is(err, savedata.ErrSaveNotExist) { httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError) @@ -369,13 +369,13 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) { } } - err = savedata.Update(uuid, data.SessionSlotId, data.Session) + err = savedata.Update(db.Store, uuid, data.SessionSlotId, data.Session) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } - err = savedata.Update(uuid, 0, data.System) + err = savedata.Update(db.Store, uuid, 0, data.System) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -401,7 +401,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) { return } - active, err := db.IsActiveSession(uuid, r.URL.Query().Get("clientSessionId")) + active, err := db.Store.IsActiveSession(uuid, r.URL.Query().Get("clientSessionId")) if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) return @@ -410,14 +410,14 @@ func handleSystem(w http.ResponseWriter, r *http.Request) { switch r.PathValue("action") { case "get": if !active { - err = db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId")) + err = db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId")) if err != nil { httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) return } } - save, err := savedata.GetSystem(uuid) + save, err := savedata.GetSystem(db.Store, uuid) if err != nil { if errors.Is(err, savedata.ErrSaveNotExist) { http.Error(w, err.Error(), http.StatusNotFound) @@ -442,7 +442,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) { return } - oldSystem, err := savedata.GetSystem(uuid) + oldSystem, err := savedata.GetSystem(db.Store, uuid) if err != nil { if !errors.Is(err, savedata.ErrSaveNotExist) { httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError) @@ -467,7 +467,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) { } } - err = savedata.UpdateSystem(uuid, system) + err = savedata.UpdateSystem(db.Store, uuid, system) if err != nil { httpError(w, r, fmt.Errorf("failed to put system data: %s", err), http.StatusInternalServerError) return @@ -481,13 +481,13 @@ func handleSystem(w http.ResponseWriter, r *http.Request) { // not valid, send server state if !active { - err := db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId")) + err := db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId")) if err != nil { httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) return } - storedSaveData, err := db.ReadSystemSaveData(uuid) + storedSaveData, err := db.Store.ReadSystemSaveData(uuid) if err != nil { httpError(w, r, fmt.Errorf("failed to read session save data: %s", err), http.StatusInternalServerError) return @@ -498,7 +498,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) { writeJSON(w, r, response) case "delete": - err := savedata.DeleteSystem(uuid) + err := savedata.DeleteSystem(db.Store, uuid) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -511,9 +511,14 @@ func handleSystem(w http.ResponseWriter, r *http.Request) { } } +// Interface providing database operations needed for getting daily seed. +type HandleDailySeedStore interface { + GetDailyRunSeed() (string, error) +} + // daily func handleDailySeed(w http.ResponseWriter, r *http.Request) { - seed, err := db.GetDailyRunSeed() + seed, err := db.Store.GetDailyRunSeed() if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -522,6 +527,11 @@ func handleDailySeed(w http.ResponseWriter, r *http.Request) { fmt.Fprint(w, seed) } +// Interface for database operations needed for getting daily rankings. +type HandleDailyRankingsStore interface { + daily.RankingsStore +} + func handleDailyRankings(w http.ResponseWriter, r *http.Request) { var err error @@ -543,7 +553,7 @@ func handleDailyRankings(w http.ResponseWriter, r *http.Request) { } } - rankings, err := daily.Rankings(category, page) + rankings, err := daily.Rankings(db.Store, category, page) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -552,6 +562,10 @@ func handleDailyRankings(w http.ResponseWriter, r *http.Request) { writeJSON(w, r, rankings) } +type HandleDailyRankingsPageCountStore interface { + daily.RankingPageCountStore +} + func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) { var category int if r.URL.Query().Has("category") { @@ -563,7 +577,7 @@ func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) { } } - count, err := daily.RankingPageCount(category) + count, err := daily.RankingPageCount(db.Store, category) if err != nil { httpError(w, r, err, http.StatusInternalServerError) } @@ -579,9 +593,9 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) { var err error switch provider { case "discord": - externalAuthId, err = account.HandleDiscordCallback(w, r) + externalAuthId, err = account.Discord.HandleDiscordCallback(w, r) case "google": - externalAuthId, err = account.HandleGoogleCallback(w, r) + externalAuthId, err = account.Google.HandleGoogleCallback(w, r) default: http.Error(w, "invalid provider", http.StatusBadRequest) return @@ -600,7 +614,7 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) { return } - userName, err := db.FetchUsernameBySessionToken(stateByte) + userName, err := db.Store.FetchUsernameBySessionToken(stateByte) if err != nil { http.Redirect(w, r, account.GameURL, http.StatusSeeOther) return @@ -608,9 +622,9 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) { switch provider { case "discord": - err = db.AddDiscordIdByUsername(externalAuthId, userName) + err = db.Store.AddDiscordIdByUsername(externalAuthId, userName) case "google": - err = db.AddGoogleIdByUsername(externalAuthId, userName) + err = db.Store.AddGoogleIdByUsername(externalAuthId, userName) } if err != nil { @@ -622,16 +636,16 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) { var userName string switch provider { case "discord": - userName, err = db.FetchUsernameByDiscordId(externalAuthId) + userName, err = db.Store.FetchUsernameByDiscordId(externalAuthId) case "google": - userName, err = db.FetchUsernameByGoogleId(externalAuthId) + userName, err = db.Store.FetchUsernameByGoogleId(externalAuthId) } if err != nil { http.Redirect(w, r, account.GameURL, http.StatusSeeOther) return } - sessionToken, err := account.GenerateTokenForUsername(userName) + sessionToken, err := account.GenerateTokenForUsername(db.Store, userName) if err != nil { http.Redirect(w, r, account.GameURL, http.StatusSeeOther) return @@ -651,6 +665,11 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, account.GameURL, http.StatusSeeOther) } +type HandleProviderLogoutStore interface { + RemoveDiscordIdByUUID(uuid []byte) error + RemoveGoogleIdByUUID(uuid []byte) error +} + func handleProviderLogout(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { @@ -660,9 +679,9 @@ func handleProviderLogout(w http.ResponseWriter, r *http.Request) { switch r.PathValue("provider") { case "discord": - err = db.RemoveDiscordIdByUUID(uuid) + err = db.Store.RemoveDiscordIdByUUID(uuid) case "google": - err = db.RemoveGoogleIdByUUID(uuid) + err = db.Store.RemoveGoogleIdByUUID(uuid) default: http.Error(w, "invalid provider", http.StatusBadRequest) return @@ -681,13 +700,13 @@ func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) { return } - userDiscordId, err := db.FetchDiscordIdByUUID(uuid) + userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid) if err != nil { httpError(w, r, err, http.StatusUnauthorized) return } - hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) + hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) if !hasRole || err != nil { httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) return @@ -698,19 +717,19 @@ func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) { // this does a quick call to make sure the username exists on the server before allowing the rest of the code to run // this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server - _, err = db.CheckUsernameExists(username) + _, err = db.Store.CheckUsernameExists(username) if err != nil { httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound) return } - userUuid, err := db.FetchUUIDFromUsername(username) + userUuid, err := db.Store.FetchUUIDFromUsername(username) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } - err = db.AddDiscordIdByUUID(discordId, userUuid) + err = db.Store.AddDiscordIdByUUID(discordId, userUuid) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -728,13 +747,13 @@ func handleAdminDiscordUnlink(w http.ResponseWriter, r *http.Request) { return } - userDiscordId, err := db.FetchDiscordIdByUUID(uuid) + userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid) if err != nil { httpError(w, r, err, http.StatusUnauthorized) return } - hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) + hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) if !hasRole || err != nil { httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) return @@ -748,26 +767,26 @@ func handleAdminDiscordUnlink(w http.ResponseWriter, r *http.Request) { log.Printf("Username given, removing discordId") // this does a quick call to make sure the username exists on the server before allowing the rest of the code to run // this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server - _, err = db.CheckUsernameExists(username) + _, err = db.Store.CheckUsernameExists(username) if err != nil { httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound) return } - userUuid, err := db.FetchUUIDFromUsername(username) + userUuid, err := db.Store.FetchUUIDFromUsername(username) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } - err = db.RemoveDiscordIdByUUID(userUuid) + err = db.Store.RemoveDiscordIdByUUID(userUuid) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } case discordId != "": log.Printf("DiscordID given, removing discordId") - err = db.RemoveDiscordIdByDiscordId(discordId) + err = db.Store.RemoveDiscordIdByDiscordId(discordId) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -786,13 +805,13 @@ func handleAdminGoogleLink(w http.ResponseWriter, r *http.Request) { return } - userDiscordId, err := db.FetchDiscordIdByUUID(uuid) + userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid) if err != nil { httpError(w, r, err, http.StatusUnauthorized) return } - hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) + hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) if !hasRole || err != nil { httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) return @@ -803,19 +822,19 @@ func handleAdminGoogleLink(w http.ResponseWriter, r *http.Request) { // this does a quick call to make sure the username exists on the server before allowing the rest of the code to run // this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server - _, err = db.CheckUsernameExists(username) + _, err = db.Store.CheckUsernameExists(username) if err != nil { httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound) return } - userUuid, err := db.FetchUUIDFromUsername(username) + userUuid, err := db.Store.FetchUUIDFromUsername(username) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } - err = db.AddGoogleIdByUUID(googleId, userUuid) + err = db.Store.AddGoogleIdByUUID(googleId, userUuid) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -833,13 +852,13 @@ func handleAdminGoogleUnlink(w http.ResponseWriter, r *http.Request) { return } - userDiscordId, err := db.FetchDiscordIdByUUID(uuid) + userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid) if err != nil { httpError(w, r, err, http.StatusUnauthorized) return } - hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) + hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) if !hasRole || err != nil { httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) return @@ -853,26 +872,26 @@ func handleAdminGoogleUnlink(w http.ResponseWriter, r *http.Request) { log.Printf("Username given, removing googleId") // this does a quick call to make sure the username exists on the server before allowing the rest of the code to run // this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server - _, err = db.CheckUsernameExists(username) + _, err = db.Store.CheckUsernameExists(username) if err != nil { httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound) return } - userUuid, err := db.FetchUUIDFromUsername(username) + userUuid, err := db.Store.FetchUUIDFromUsername(username) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } - err = db.RemoveGoogleIdByUUID(userUuid) + err = db.Store.RemoveGoogleIdByUUID(userUuid) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return } case googleId != "": log.Printf("DiscordID given, removing googleId") - err = db.RemoveGoogleIdByDiscordId(googleId) + err = db.Store.RemoveGoogleIdByDiscordId(googleId) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return @@ -891,13 +910,13 @@ func handleAdminSearch(w http.ResponseWriter, r *http.Request) { return } - userDiscordId, err := db.FetchDiscordIdByUUID(uuid) + userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid) if err != nil { httpError(w, r, err, http.StatusUnauthorized) return } - hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) + hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) if !hasRole || err != nil { httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) return @@ -907,14 +926,14 @@ func handleAdminSearch(w http.ResponseWriter, r *http.Request) { // this does a quick call to make sure the username exists on the server before allowing the rest of the code to run // this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server - _, err = db.CheckUsernameExists(username) + _, err = db.Store.CheckUsernameExists(username) if err != nil { httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound) return } // this does a single call that does a query for multiple columns from our database and makes an object out of it, which is returned to us - adminSearchResult, err := db.FetchAdminDetailsByUsername(username) + adminSearchResult, err := db.Store.FetchAdminDetailsByUsername(username) if err != nil { httpError(w, r, err, http.StatusInternalServerError) return diff --git a/api/savedata/clear.go b/api/savedata/clear.go index 35b9a59..ad50fcc 100644 --- a/api/savedata/clear.go +++ b/api/savedata/clear.go @@ -21,7 +21,6 @@ import ( "fmt" "log" - "github.com/pagefaultgames/rogueserver/db" "github.com/pagefaultgames/rogueserver/defs" ) @@ -30,10 +29,20 @@ type ClearResponse struct { Error string `json:"error"` } +// Interface for database operations needed for `Clear` +// Helps with testing and reduces coupling. +type ClearStore interface { + UpdateAccountLastActivity(uuid []byte) error + TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) + DeleteSessionSaveData(uuid []byte, slot int) error + AddOrUpdateAccountDailyRun(uuid []byte, score int, waveCompleted int) error + SetAccountBanned(uuid []byte, banned bool) error +} + // /savedata/clear - mark session save data as cleared and delete -func Clear(uuid []byte, slot int, seed string, save defs.SessionSaveData) (ClearResponse, error) { +func Clear[T ClearStore](store T, uuid []byte, slot int, seed string, save defs.SessionSaveData) (ClearResponse, error) { var response ClearResponse - err := db.UpdateAccountLastActivity(uuid) + err := store.UpdateAccountLastActivity(uuid) if err != nil { log.Print("failed to update account last activity") } @@ -51,23 +60,23 @@ func Clear(uuid []byte, slot int, seed string, save defs.SessionSaveData) (Clear } if save.Score >= 20000 { - db.SetAccountBanned(uuid, true) + store.SetAccountBanned(uuid, true) } - err = db.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted) + err = store.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted) if err != nil { log.Printf("failed to add or update daily run record: %s", err) } } if sessionCompleted { - response.Success, err = db.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode)) + response.Success, err = store.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode)) if err != nil { log.Printf("failed to mark seed as completed: %s", err) } } - err = db.DeleteSessionSaveData(uuid, slot) + err = store.DeleteSessionSaveData(uuid, slot) if err != nil { log.Printf("failed to delete session save data: %s", err) } diff --git a/api/savedata/delete.go b/api/savedata/delete.go index cb59c3e..fd57902 100644 --- a/api/savedata/delete.go +++ b/api/savedata/delete.go @@ -21,13 +21,19 @@ import ( "fmt" "log" - "github.com/pagefaultgames/rogueserver/db" "github.com/pagefaultgames/rogueserver/defs" ) +// Interface for database operations needed for delete. +// This is to allow easier testing and less coupling. +type DeleteStore interface { + UpdateAccountLastActivity(uuid []byte) error + DeleteSessionSaveData(uuid []byte, slot int) error +} + // /savedata/delete - delete save data -func Delete(uuid []byte, datatype, slot int) error { - err := db.UpdateAccountLastActivity(uuid) +func Delete[T DeleteStore](store T, uuid []byte, datatype, slot int) error { + err := store.UpdateAccountLastActivity(uuid) if err != nil { log.Print("failed to update account last activity") } @@ -39,7 +45,7 @@ func Delete(uuid []byte, datatype, slot int) error { break } - err = db.DeleteSessionSaveData(uuid, slot) + err = store.DeleteSessionSaveData(uuid, slot) default: err = fmt.Errorf("invalid data type") } diff --git a/api/savedata/newclear.go b/api/savedata/newclear.go index 1420027..30c914d 100644 --- a/api/savedata/newclear.go +++ b/api/savedata/newclear.go @@ -20,22 +20,26 @@ package savedata import ( "fmt" - "github.com/pagefaultgames/rogueserver/db" "github.com/pagefaultgames/rogueserver/defs" ) +type NewClearStore interface { + ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) + ReadSeedCompleted(uuid []byte, seed string) (bool, error) +} + // /savedata/newclear - return whether a session is a new clear for its seed -func NewClear(uuid []byte, slot int) (bool, error) { +func NewClear[T NewClearStore](store T, uuid []byte, slot int) (bool, error) { if slot < 0 || slot >= defs.SessionSlotCount { return false, fmt.Errorf("slot id %d out of range", slot) } - session, err := db.ReadSessionSaveData(uuid, slot) + session, err := store.ReadSessionSaveData(uuid, slot) if err != nil { return false, err } - completed, err := db.ReadSeedCompleted(uuid, session.Seed) + completed, err := store.ReadSeedCompleted(uuid, session.Seed) if err != nil { return false, fmt.Errorf("failed to read seed completed: %s", err) } diff --git a/api/savedata/session.go b/api/savedata/session.go index 76bf71f..fc2916c 100644 --- a/api/savedata/session.go +++ b/api/savedata/session.go @@ -21,12 +21,15 @@ import ( "database/sql" "errors" - "github.com/pagefaultgames/rogueserver/db" "github.com/pagefaultgames/rogueserver/defs" ) -func GetSession(uuid []byte, slot int) (defs.SessionSaveData, error) { - session, err := db.ReadSessionSaveData(uuid, slot) +type GetSessionStore interface { + ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) +} + +func GetSession[T GetSessionStore](store T, uuid []byte, slot int) (defs.SessionSaveData, error) { + session, err := store.ReadSessionSaveData(uuid, slot) if err != nil { if errors.Is(err, sql.ErrNoRows) { err = ErrSaveNotExist @@ -38,8 +41,13 @@ func GetSession(uuid []byte, slot int) (defs.SessionSaveData, error) { return session, nil } -func UpdateSession(uuid []byte, slot int, data defs.SessionSaveData) error { - err := db.StoreSessionSaveData(uuid, data, slot) +type UpdateSessionStore interface { + StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error + DeleteSessionSaveData(uuid []byte, slot int) error +} + +func UpdateSession[T UpdateSessionStore](store T, uuid []byte, slot int, data defs.SessionSaveData) error { + err := store.StoreSessionSaveData(uuid, data, slot) if err != nil { return err } @@ -47,8 +55,12 @@ func UpdateSession(uuid []byte, slot int, data defs.SessionSaveData) error { return nil } -func DeleteSession(uuid []byte, slot int) error { - err := db.DeleteSessionSaveData(uuid, slot) +type DeleteSessionStore interface { + DeleteSessionSaveData(uuid []byte, slot int) error +} + +func DeleteSession[T DeleteSessionStore](store T, uuid []byte, slot int) error { + err := store.DeleteSessionSaveData(uuid, slot) if err != nil { return err } diff --git a/api/savedata/system.go b/api/savedata/system.go index 25d6b58..7d42bf7 100644 --- a/api/savedata/system.go +++ b/api/savedata/system.go @@ -24,24 +24,28 @@ import ( "os" "github.com/aws/aws-sdk-go-v2/service/s3/types" - "github.com/pagefaultgames/rogueserver/db" "github.com/pagefaultgames/rogueserver/defs" ) var ErrSaveNotExist = errors.New("save does not exist") -func GetSystem(uuid []byte) (defs.SystemSaveData, error) { +type GetSystemStore interface { + GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) + ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) +} + +func GetSystem[T GetSystemStore](store T, uuid []byte) (defs.SystemSaveData, error) { var system defs.SystemSaveData var err error if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3 - system, err = db.GetSystemSaveFromS3(uuid) + system, err = store.GetSystemSaveFromS3(uuid) var nokey *types.NoSuchKey if errors.As(err, &nokey) { err = ErrSaveNotExist } } else { // use database - system, err = db.ReadSystemSaveData(uuid) + system, err = store.ReadSystemSaveData(uuid) if errors.Is(err, sql.ErrNoRows) { err = ErrSaveNotExist } @@ -53,20 +57,27 @@ func GetSystem(uuid []byte) (defs.SystemSaveData, error) { return system, nil } -func UpdateSystem(uuid []byte, data defs.SystemSaveData) error { +// Interface for database operations needed for updating system data. +type UpdateSystemStore interface { + UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error + StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error + StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error +} + +func UpdateSystem[T UpdateSystemStore](store T, uuid []byte, data defs.SystemSaveData) error { if data.TrainerId == 0 && data.SecretId == 0 { return fmt.Errorf("invalid system data") } - err := db.UpdateAccountStats(uuid, data.GameStats, data.VoucherCounts) + err := store.UpdateAccountStats(uuid, data.GameStats, data.VoucherCounts) if err != nil { return fmt.Errorf("failed to update account stats: %s", err) } if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3 - err = db.StoreSystemSaveDataS3(uuid, data) + err = store.StoreSystemSaveDataS3(uuid, data) } else { - err = db.StoreSystemSaveData(uuid, data) + err = store.StoreSystemSaveData(uuid, data) } if err != nil { return err @@ -75,8 +86,12 @@ func UpdateSystem(uuid []byte, data defs.SystemSaveData) error { return nil } -func DeleteSystem(uuid []byte) error { - err := db.DeleteSystemSaveData(uuid) +type DeleteSystemStore interface { + DeleteSystemSaveData(uuid []byte) error +} + +func DeleteSystem[T DeleteSystemStore](store T, uuid []byte) error { + err := store.DeleteSystemSaveData(uuid) if err != nil { return err } diff --git a/api/savedata/update.go b/api/savedata/update.go index 7bbe91d..52609bb 100644 --- a/api/savedata/update.go +++ b/api/savedata/update.go @@ -21,13 +21,18 @@ import ( "fmt" "log" - "github.com/pagefaultgames/rogueserver/db" "github.com/pagefaultgames/rogueserver/defs" ) +type UpdateStore interface { + UpdateAccountLastActivity(uuid []byte) error + UpdateSystemStore + UpdateSessionStore +} + // /savedata/update - update save data -func Update(uuid []byte, slot int, save any) error { - err := db.UpdateAccountLastActivity(uuid) +func Update[T UpdateStore](store T, uuid []byte, slot int, save any) error { + err := store.UpdateAccountLastActivity(uuid) if err != nil { log.Print("failed to update account last activity") } @@ -38,13 +43,13 @@ func Update(uuid []byte, slot int, save any) error { return fmt.Errorf("invalid system data") } - return UpdateSystem(uuid, save) + return UpdateSystem(store, uuid, save) case defs.SessionSaveData: // Session if slot < 0 || slot >= defs.SessionSlotCount { return fmt.Errorf("slot id %d out of range", slot) } - return UpdateSession(uuid, slot, save) + return UpdateSession(store, uuid, slot, save) default: return fmt.Errorf("invalid data type") } diff --git a/api/stats.go b/api/stats.go index c2ab176..a4b5933 100644 --- a/api/stats.go +++ b/api/stats.go @@ -21,7 +21,6 @@ import ( "log" "time" - "github.com/pagefaultgames/rogueserver/db" "github.com/robfig/cron/v3" ) @@ -32,9 +31,9 @@ var ( classicSessionCount int ) -func scheduleStatRefresh() error { +func scheduleStatRefresh[T updateStatsStore](store T) error { _, err := scheduler.AddFunc("@every 30s", func() { - err := updateStats() + err := updateStats(store) if err != nil { log.Printf("failed to update stats: %s", err) } @@ -47,19 +46,25 @@ func scheduleStatRefresh() error { return nil } -func updateStats() error { +type updateStatsStore interface { + FetchPlayerCount() (int, error) + FetchBattleCount() (int, error) + FetchClassicSessionCount() (int, error) +} + +func updateStats[T updateStatsStore](store T) error { var err error - playerCount, err = db.FetchPlayerCount() + playerCount, err = store.FetchPlayerCount() if err != nil { return err } - battleCount, err = db.FetchBattleCount() + battleCount, err = store.FetchBattleCount() if err != nil { return err } - classicSessionCount, err = db.FetchClassicSessionCount() + classicSessionCount, err = store.FetchClassicSessionCount() if err != nil { return err } diff --git a/db/account.go b/db/account.go index fbcce0e..6d2ec64 100644 --- a/db/account.go +++ b/db/account.go @@ -1,18 +1,18 @@ /* - Copyright (C) 2024 - 2025 Pagefault Games + Copyright (C) 2024 - 2025 Pagefault Games - This program is free software: you can redistribute it and/or modify - it under the terms of the GNU Affero General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. - This program is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU Affero General Public License for more details. + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. - You should have received a copy of the GNU Affero General Public License - along with this program. If not, see . + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . */ package db @@ -27,30 +27,30 @@ import ( "github.com/pagefaultgames/rogueserver/defs" ) -func AddAccountRecord(uuid []byte, username string, key, salt []byte) error { - _, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt) - if err != nil { - return err - } - - return nil -} - -func AddAccountSession(username string, token []byte) error { +func (s *store) AddAccountSession(username string, token []byte) error { _, err := handle.Exec("INSERT INTO sessions (uuid, token, expire) SELECT a.uuid, ?, DATE_ADD(UTC_TIMESTAMP(), INTERVAL 1 WEEK) FROM accounts a WHERE a.username = ?", token, username) if err != nil { return err } - _, err = handle.Exec("UPDATE accounts SET lastLoggedIn = UTC_TIMESTAMP() WHERE username = ?", username) if err != nil { return err } - return nil } -func AddDiscordIdByUsername(discordId string, username string) error { +// (removed, now a method on store) +func (s *store) FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) { + var key, salt []byte + err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt) + if err != nil { + return nil, nil, err + } + + return key, salt, nil +} + +func (s *store) AddDiscordIdByUsername(discordId string, username string) error { _, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE username = ?", discordId, username) if err != nil { return err @@ -59,7 +59,7 @@ func AddDiscordIdByUsername(discordId string, username string) error { return nil } -func AddGoogleIdByUsername(googleId string, username string) error { +func (s *store) AddGoogleIdByUsername(googleId string, username string) error { _, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE username = ?", googleId, username) if err != nil { return err @@ -68,7 +68,7 @@ func AddGoogleIdByUsername(googleId string, username string) error { return nil } -func AddGoogleIdByUUID(googleId string, uuid []byte) error { +func (s *store) AddGoogleIdByUUID(googleId string, uuid []byte) error { _, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE uuid = ?", googleId, uuid) if err != nil { return err @@ -77,7 +77,7 @@ func AddGoogleIdByUUID(googleId string, uuid []byte) error { return nil } -func AddDiscordIdByUUID(discordId string, uuid []byte) error { +func (s *store) AddDiscordIdByUUID(discordId string, uuid []byte) error { _, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE uuid = ?", discordId, uuid) if err != nil { return err @@ -86,7 +86,7 @@ func AddDiscordIdByUUID(discordId string, uuid []byte) error { return nil } -func FetchUsernameByDiscordId(discordId string) (string, error) { +func (s *store) FetchUsernameByDiscordId(discordId string) (string, error) { var username string err := handle.QueryRow("SELECT username FROM accounts WHERE discordId = ?", discordId).Scan(&username) if err != nil { @@ -96,7 +96,7 @@ func FetchUsernameByDiscordId(discordId string) (string, error) { return username, nil } -func FetchUsernameByGoogleId(googleId string) (string, error) { +func (s *store) FetchUsernameByGoogleId(googleId string) (string, error) { var username string err := handle.QueryRow("SELECT username FROM accounts WHERE googleId = ?", googleId).Scan(&username) if err != nil { @@ -106,7 +106,7 @@ func FetchUsernameByGoogleId(googleId string) (string, error) { return username, nil } -func FetchDiscordIdByUsername(username string) (string, error) { +func (s *store) FetchDiscordIdByUsername(username string) (string, error) { var discordId sql.NullString err := handle.QueryRow("SELECT discordId FROM accounts WHERE username = ?", username).Scan(&discordId) if err != nil { @@ -120,7 +120,7 @@ func FetchDiscordIdByUsername(username string) (string, error) { return discordId.String, nil } -func FetchGoogleIdByUsername(username string) (string, error) { +func (s *store) FetchGoogleIdByUsername(username string) (string, error) { var googleId sql.NullString err := handle.QueryRow("SELECT googleId FROM accounts WHERE username = ?", username).Scan(&googleId) if err != nil { @@ -134,7 +134,7 @@ func FetchGoogleIdByUsername(username string) (string, error) { return googleId.String, nil } -func FetchDiscordIdByUUID(uuid []byte) (string, error) { +func (s *store) FetchDiscordIdByUUID(uuid []byte) (string, error) { var discordId sql.NullString err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId) if err != nil { @@ -148,7 +148,7 @@ func FetchDiscordIdByUUID(uuid []byte) (string, error) { return discordId.String, nil } -func FetchGoogleIdByUUID(uuid []byte) (string, error) { +func (s *store) FetchGoogleIdByUUID(uuid []byte) (string, error) { var googleId sql.NullString err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId) if err != nil { @@ -162,7 +162,7 @@ func FetchGoogleIdByUUID(uuid []byte) (string, error) { return googleId.String, nil } -func FetchUsernameBySessionToken(token []byte) (string, error) { +func (s *store) FetchUsernameBySessionToken(token []byte) (string, error) { var username string err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username) if err != nil { @@ -172,7 +172,7 @@ func FetchUsernameBySessionToken(token []byte) (string, error) { return username, nil } -func CheckUsernameExists(username string) (string, error) { +func (s *store) CheckUsernameExists(username string) (string, error) { var dbUsername sql.NullString err := handle.QueryRow("SELECT username FROM accounts WHERE username = ?", username).Scan(&dbUsername) if err != nil { @@ -185,7 +185,7 @@ func CheckUsernameExists(username string) (string, error) { return dbUsername.String, nil } -func FetchLastLoggedInDateByUsername(username string) (string, error) { +func (s *store) FetchLastLoggedInDateByUsername(username string) (string, error) { var lastLoggedIn sql.NullString err := handle.QueryRow("SELECT lastLoggedIn FROM accounts WHERE username = ?", username).Scan(&lastLoggedIn) if err != nil { @@ -206,7 +206,7 @@ type AdminSearchResponse struct { Registered string `json:"registered"` } -func FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) { +func (s *store) FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) { var username, discordId, googleId, lastActivity, registered sql.NullString var adminResponse AdminSearchResponse @@ -226,7 +226,7 @@ func FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) return adminResponse, nil } -func UpdateAccountPassword(uuid, key, salt []byte) error { +func (s *store) UpdateAccountPassword(uuid, key, salt []byte) error { _, err := handle.Exec("UPDATE accounts SET hash = ?, salt = ? WHERE uuid = ?", key, salt, uuid) if err != nil { return err @@ -235,7 +235,7 @@ func UpdateAccountPassword(uuid, key, salt []byte) error { return nil } -func UpdateAccountLastActivity(uuid []byte) error { +func (s *store) UpdateAccountLastActivity(uuid []byte) error { _, err := handle.Exec("UPDATE accounts SET lastActivity = UTC_TIMESTAMP() WHERE uuid = ?", uuid) if err != nil { return err @@ -244,7 +244,7 @@ func UpdateAccountLastActivity(uuid []byte) error { return nil } -func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error { +func (s *store) UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error { var columns = []string{"playTime", "battles", "classicSessionsPlayed", "sessionsWon", "highestEndlessWave", "highestLevel", "pokemonSeen", "pokemonDefeated", "pokemonCaught", "pokemonHatched", "eggsPulled", "regularVouchers", "plusVouchers", "premiumVouchers", "goldenVouchers"} var statCols []string @@ -321,7 +321,7 @@ func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[str return nil } -func SetAccountBanned(uuid []byte, banned bool) error { +func (s *store) SetAccountBanned(uuid []byte, banned bool) error { _, err := handle.Exec("UPDATE accounts SET banned = ? WHERE uuid = ?", banned, uuid) if err != nil { return err @@ -330,17 +330,16 @@ func SetAccountBanned(uuid []byte, banned bool) error { return nil } -func FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) { - var key, salt []byte - err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt) +func (s *store) AddAccountRecord(uuid []byte, username string, key, salt []byte) error { + _, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt) if err != nil { - return nil, nil, err + return err } - return key, salt, nil + return nil } -func FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) { +func (s *store) FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) { err = handle.QueryRow("SELECT trainerId, secretId FROM accounts WHERE uuid = ?", uuid).Scan(&trainerId, &secretId) if err != nil { return 0, 0, err @@ -349,7 +348,7 @@ func FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) { return trainerId, secretId, nil } -func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error { +func (s *store) UpdateTrainerIds(trainerId, secretId int, uuid []byte) error { _, err := handle.Exec("UPDATE accounts SET trainerId = ?, secretId = ? WHERE uuid = ?", trainerId, secretId, uuid) if err != nil { return err @@ -358,12 +357,12 @@ func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error { return nil } -func IsActiveSession(uuid []byte, sessionId string) (bool, error) { +func (s *store) IsActiveSession(uuid []byte, sessionId string) (bool, error) { var id string err := handle.QueryRow("SELECT clientSessionId FROM activeClientSessions WHERE uuid = ?", uuid).Scan(&id) if err != nil { if errors.Is(err, sql.ErrNoRows) { - err = UpdateActiveSession(uuid, sessionId) + err = s.UpdateActiveSession(uuid, sessionId) if err != nil { return false, err } @@ -377,7 +376,7 @@ func IsActiveSession(uuid []byte, sessionId string) (bool, error) { return id == "" || id == sessionId, nil } -func UpdateActiveSession(uuid []byte, clientSessionId string) error { +func (s *store) UpdateActiveSession(uuid []byte, clientSessionId string) error { _, err := handle.Exec("INSERT INTO activeClientSessions (uuid, clientSessionId) VALUES (?, ?) ON DUPLICATE KEY UPDATE clientSessionId = ?", uuid, clientSessionId, clientSessionId) if err != nil { return err @@ -386,7 +385,7 @@ func UpdateActiveSession(uuid []byte, clientSessionId string) error { return nil } -func FetchUUIDFromToken(token []byte) ([]byte, error) { +func (s *store) FetchUUIDFromToken(token []byte) ([]byte, error) { var uuid []byte err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ?", token).Scan(&uuid) if err != nil { @@ -396,7 +395,7 @@ func FetchUUIDFromToken(token []byte) ([]byte, error) { return uuid, nil } -func RemoveSessionFromToken(token []byte) error { +func (s *store) RemoveSessionFromToken(token []byte) error { _, err := handle.Exec("DELETE FROM sessions WHERE token = ?", token) if err != nil { return err @@ -405,7 +404,7 @@ func RemoveSessionFromToken(token []byte) error { return nil } -func RemoveSessionsFromUUID(uuid []byte) error { +func (s *store) RemoveSessionsFromUUID(uuid []byte) error { _, err := handle.Exec("DELETE FROM sessions WHERE uuid = ?", uuid) if err != nil { return err @@ -414,7 +413,7 @@ func RemoveSessionsFromUUID(uuid []byte) error { return nil } -func FetchUsernameFromUUID(uuid []byte) (string, error) { +func (s *store) FetchUsernameFromUUID(uuid []byte) (string, error) { var username string err := handle.QueryRow("SELECT username FROM accounts WHERE uuid = ?", uuid).Scan(&username) if err != nil { @@ -424,7 +423,7 @@ func FetchUsernameFromUUID(uuid []byte) (string, error) { return username, nil } -func FetchUUIDFromUsername(username string) ([]byte, error) { +func (s *store) FetchUUIDFromUsername(username string) ([]byte, error) { var uuid []byte err := handle.QueryRow("SELECT uuid FROM accounts WHERE username = ?", username).Scan(&uuid) if err != nil { @@ -434,7 +433,7 @@ func FetchUUIDFromUsername(username string) ([]byte, error) { return uuid, nil } -func RemoveDiscordIdByUUID(uuid []byte) error { +func (s *store) RemoveDiscordIdByUUID(uuid []byte) error { _, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE uuid = ?", uuid) if err != nil { return err @@ -443,7 +442,7 @@ func RemoveDiscordIdByUUID(uuid []byte) error { return nil } -func RemoveGoogleIdByUUID(uuid []byte) error { +func (s *store) RemoveGoogleIdByUUID(uuid []byte) error { _, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE uuid = ?", uuid) if err != nil { return err @@ -452,7 +451,7 @@ func RemoveGoogleIdByUUID(uuid []byte) error { return nil } -func RemoveGoogleIdByUsername(username string) error { +func (s *store) RemoveGoogleIdByUsername(username string) error { _, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE username = ?", username) if err != nil { return err @@ -461,7 +460,7 @@ func RemoveGoogleIdByUsername(username string) error { return nil } -func RemoveDiscordIdByUsername(username string) error { +func (s *store) RemoveDiscordIdByUsername(username string) error { _, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE username = ?", username) if err != nil { return err @@ -470,7 +469,7 @@ func RemoveDiscordIdByUsername(username string) error { return nil } -func RemoveDiscordIdByDiscordId(discordId string) error { +func (s *store) RemoveDiscordIdByDiscordId(discordId string) error { _, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE discordId = ?", discordId) if err != nil { return err @@ -479,7 +478,7 @@ func RemoveDiscordIdByDiscordId(discordId string) error { return nil } -func RemoveGoogleIdByDiscordId(discordId string) error { +func (s *store) RemoveGoogleIdByDiscordId(discordId string) error { _, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE discordId = ?", discordId) if err != nil { return err diff --git a/db/daily.go b/db/daily.go index 59c8f92..155c3d6 100644 --- a/db/daily.go +++ b/db/daily.go @@ -23,7 +23,7 @@ import ( "github.com/pagefaultgames/rogueserver/defs" ) -func TryAddDailyRun(seed string) (string, error) { +func (s *store) TryAddDailyRun(seed string) (string, error) { var actualSeed string err := handle.QueryRow("INSERT INTO dailyRuns (seed, date) VALUES (?, UTC_DATE()) ON DUPLICATE KEY UPDATE date = date RETURNING seed", seed).Scan(&actualSeed) if err != nil { @@ -33,7 +33,7 @@ func TryAddDailyRun(seed string) (string, error) { return actualSeed, nil } -func GetDailyRunSeed() (string, error) { +func (s *store) GetDailyRunSeed() (string, error) { var seed string err := handle.QueryRow("SELECT seed FROM dailyRuns WHERE date = UTC_DATE()").Scan(&seed) if err != nil { @@ -43,7 +43,7 @@ func GetDailyRunSeed() (string, error) { return seed, nil } -func AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error { +func (s *store) AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error { _, err := handle.Exec("INSERT INTO accountDailyRuns (uuid, date, score, wave, timestamp) VALUES (?, UTC_DATE(), ?, ?, UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE score = GREATEST(score, ?), wave = GREATEST(wave, ?), timestamp = IF(score < ?, UTC_TIMESTAMP(), timestamp)", uuid, score, wave, score, wave, score) if err != nil { return err @@ -52,7 +52,7 @@ func AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error { return nil } -func FetchRankings(category int, page int) ([]defs.DailyRanking, error) { +func (s *store) FetchRankings(category int, page int) ([]defs.DailyRanking, error) { var rankings []defs.DailyRanking offset := (page - 1) * 10 @@ -85,7 +85,7 @@ func FetchRankings(category int, page int) ([]defs.DailyRanking, error) { return rankings, nil } -func FetchRankingPageCount(category int) (int, error) { +func (s *store) FetchRankingPageCount(category int) (int, error) { var query string switch category { case 0: diff --git a/db/db.go b/db/db.go index 06adb03..25b0039 100644 --- a/db/db.go +++ b/db/db.go @@ -31,6 +31,12 @@ import ( var handle *sql.DB var s3client *s3.Client +// internal type used to implement the Store interface +type store struct{} + +// Store is the global instance for DB access. +var Store = &store{} + func Init(username, password, protocol, address, database string) error { var err error diff --git a/db/game.go b/db/game.go index ec862b6..7e722cc 100644 --- a/db/game.go +++ b/db/game.go @@ -17,7 +17,7 @@ package db -func FetchPlayerCount() (int, error) { +func (s *store) FetchPlayerCount() (int, error) { var playerCount int err := handle.QueryRow("SELECT COUNT(*) FROM accounts WHERE lastActivity > DATE_SUB(UTC_TIMESTAMP(), INTERVAL 5 MINUTE)").Scan(&playerCount) if err != nil { @@ -27,7 +27,7 @@ func FetchPlayerCount() (int, error) { return playerCount, nil } -func FetchBattleCount() (int, error) { +func (s *store) FetchBattleCount() (int, error) { var battleCount int err := handle.QueryRow("SELECT COALESCE(SUM(s.battles), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&battleCount) if err != nil { @@ -37,7 +37,7 @@ func FetchBattleCount() (int, error) { return battleCount, nil } -func FetchClassicSessionCount() (int, error) { +func (s *store) FetchClassicSessionCount() (int, error) { var classicSessionCount int err := handle.QueryRow("SELECT COALESCE(SUM(s.classicSessionsPlayed), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&classicSessionCount) if err != nil { diff --git a/db/s3.go b/db/s3.go new file mode 100644 index 0000000..3ac41af --- /dev/null +++ b/db/s3.go @@ -0,0 +1,61 @@ +package db + +import ( + "bytes" + "context" + "encoding/json" + "os" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/pagefaultgames/rogueserver/defs" +) + +func (s *store) GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) { + var system defs.SystemSaveData + + username, err := Store.FetchUsernameFromUUID(uuid) + if err != nil { + return system, err + } + + resp, err := s3client.GetObject(context.TODO(), &s3.GetObjectInput{ + Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")), + Key: aws.String(username), + }) + if err != nil { + return system, err + } + + err = json.NewDecoder(resp.Body).Decode(&system) + if err != nil { + return system, err + } + + return system, nil +} + +func (s *store) StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error { + username, err := s.FetchUsernameFromUUID(uuid) + if err != nil { + return err + } + + buf := new(bytes.Buffer) + + err = json.NewEncoder(buf).Encode(data) + if err != nil { + return err + } + + _, err = s3client.PutObject(context.Background(), &s3.PutObjectInput{ + Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")), + Key: aws.String(username), + Body: buf, + }) + if err != nil { + return err + } + + return nil +} diff --git a/db/savedata.go b/db/savedata.go index ce86720..9226d5e 100644 --- a/db/savedata.go +++ b/db/savedata.go @@ -19,19 +19,13 @@ package db import ( "bytes" - "context" "encoding/gob" - "encoding/json" - "os" "github.com/klauspost/compress/zstd" "github.com/pagefaultgames/rogueserver/defs" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/service/s3" ) -func TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) { +func (s *store) TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) { var count int err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count) if err != nil { @@ -48,7 +42,7 @@ func TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) { return true, nil } -func ReadSeedCompleted(uuid []byte, seed string) (bool, error) { +func (s *store) ReadSeedCompleted(uuid []byte, seed string) (bool, error) { var count int err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count) if err != nil { @@ -58,7 +52,7 @@ func ReadSeedCompleted(uuid []byte, seed string) (bool, error) { return count > 0, nil } -func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) { +func (s *store) ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) { var system defs.SystemSaveData var data []byte @@ -82,7 +76,7 @@ func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) { return system, nil } -func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error { +func (s *store) StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error { buf := new(bytes.Buffer) zw, err := zstd.NewWriter(buf) @@ -108,32 +102,7 @@ func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error { return nil } -func StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error { - username, err := FetchUsernameFromUUID(uuid) - if err != nil { - return err - } - - buf := new(bytes.Buffer) - - err = json.NewEncoder(buf).Encode(data) - if err != nil { - return err - } - - _, err = s3client.PutObject(context.Background(), &s3.PutObjectInput{ - Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")), - Key: aws.String(username), - Body: buf, - }) - if err != nil { - return err - } - - return nil -} - -func DeleteSystemSaveData(uuid []byte) error { +func (s *store) DeleteSystemSaveData(uuid []byte) error { _, err := handle.Exec("DELETE FROM systemSaveData WHERE uuid = ?", uuid) if err != nil { return err @@ -142,7 +111,7 @@ func DeleteSystemSaveData(uuid []byte) error { return nil } -func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) { +func (s *store) ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) { var session defs.SessionSaveData var data []byte @@ -166,7 +135,7 @@ func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) { return session, nil } -func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) { +func (s *store) GetLatestSessionSaveDataSlot(uuid []byte) (int, error) { var slot int err := handle.QueryRow("SELECT slot FROM sessionSaveData WHERE uuid = ? ORDER BY timestamp DESC, slot ASC LIMIT 1", uuid).Scan(&slot) if err != nil { @@ -176,7 +145,7 @@ func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) { return slot, nil } -func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error { +func (s *store) StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error { buf := new(bytes.Buffer) zw, err := zstd.NewWriter(buf) @@ -202,7 +171,7 @@ func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) erro return nil } -func DeleteSessionSaveData(uuid []byte, slot int) error { +func (s *store) DeleteSessionSaveData(uuid []byte, slot int) error { _, err := handle.Exec("DELETE FROM sessionSaveData WHERE uuid = ? AND slot = ?", uuid, slot) if err != nil { return err @@ -210,27 +179,3 @@ func DeleteSessionSaveData(uuid []byte, slot int) error { return nil } - -func GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) { - var system defs.SystemSaveData - - username, err := FetchUsernameFromUUID(uuid) - if err != nil { - return system, err - } - - resp, err := s3client.GetObject(context.TODO(), &s3.GetObjectInput{ - Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")), - Key: aws.String(username), - }) - if err != nil { - return system, err - } - - err = json.NewDecoder(resp.Body).Decode(&system) - if err != nil { - return system, err - } - - return system, nil -}