diff --git a/.golangci.yml b/.golangci.yml
index 58c100b..df75f09 100644
--- a/.golangci.yml
+++ b/.golangci.yml
@@ -1,7 +1,8 @@
+version: "2"
run:
timeout: 10m
severity:
- default-severity: error
+ default: error
rules:
- linters:
- unused
diff --git a/api/account/changepw.go b/api/account/changepw.go
index e9726ec..edb36da 100644
--- a/api/account/changepw.go
+++ b/api/account/changepw.go
@@ -20,11 +20,15 @@ package account
import (
"crypto/rand"
"fmt"
-
- "github.com/pagefaultgames/rogueserver/db"
)
-func ChangePW(uuid []byte, password string) error {
+// Interface for database operations needed for changing password.
+type ChangePWStore interface {
+ RemoveSessionsFromUUID(uuid []byte) error
+ UpdateAccountPassword(uuid []byte, newKey []byte, newSalt []byte) error
+}
+
+func ChangePW[T ChangePWStore](store T, uuid []byte, password string) error {
if len(password) < 6 {
return fmt.Errorf("invalid password")
}
@@ -35,12 +39,12 @@ func ChangePW(uuid []byte, password string) error {
return fmt.Errorf("failed to generate salt: %s", err)
}
- err = db.RemoveSessionsFromUUID(uuid)
+ err = store.RemoveSessionsFromUUID(uuid)
if err != nil {
return fmt.Errorf("failed to remove sessions: %s", err)
}
- err = db.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt)
+ err = store.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt)
if err != nil {
return fmt.Errorf("failed to add account record: %s", err)
}
diff --git a/api/account/discord.go b/api/account/discord.go
index 0be4087..f959eb8 100644
--- a/api/account/discord.go
+++ b/api/account/discord.go
@@ -35,14 +35,24 @@ var (
DiscordGuildID string
)
-func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
+type DiscordProvider interface {
+ HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error)
+ RetrieveDiscordId(code string) (string, error)
+ IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error)
+}
+
+type discordProvider struct{}
+
+var Discord = &discordProvider{}
+
+func (s *discordProvider) HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
if code == "" {
http.Redirect(w, r, GameURL, http.StatusSeeOther)
return "", errors.New("code is empty")
}
- discordId, err := RetrieveDiscordId(code)
+ discordId, err := s.RetrieveDiscordId(code)
if err != nil {
http.Redirect(w, r, GameURL, http.StatusSeeOther)
return "", err
@@ -51,7 +61,7 @@ func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, erro
return discordId, nil
}
-func RetrieveDiscordId(code string) (string, error) {
+func (s *discordProvider) RetrieveDiscordId(code string) (string, error) {
v := make(url.Values)
v.Set("client_id", DiscordClientID)
v.Set("client_secret", DiscordClientSecret)
@@ -112,7 +122,7 @@ func RetrieveDiscordId(code string) (string, error) {
return user.Id, nil
}
-func IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) {
+func (s *discordProvider) IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) {
// fetch all roles from discord
roles, err := DiscordSession.GuildRoles(discordGuildID)
if err != nil {
diff --git a/api/account/google.go b/api/account/google.go
index 30c5ae5..28ee6b4 100644
--- a/api/account/google.go
+++ b/api/account/google.go
@@ -32,7 +32,16 @@ var (
GoogleCallbackURL string
)
-func HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) {
+type GoogleProvider interface {
+ HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error)
+ RetrieveGoogleId(code string) (string, error)
+}
+
+type googleProvider struct{}
+
+var Google = &googleProvider{}
+
+func (g *googleProvider) HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) {
code := r.URL.Query().Get("code")
if code == "" {
http.Redirect(w, r, GameURL, http.StatusSeeOther)
diff --git a/api/account/info.go b/api/account/info.go
index 32cf9ff..67d96f6 100644
--- a/api/account/info.go
+++ b/api/account/info.go
@@ -17,10 +17,6 @@
package account
-import (
- "github.com/pagefaultgames/rogueserver/db"
-)
-
type InfoResponse struct {
Username string `json:"username"`
DiscordId string `json:"discordId"`
@@ -29,9 +25,13 @@ type InfoResponse struct {
HasAdminRole bool `json:"hasAdminRole"`
}
+type InfoStore interface {
+ GetLatestSessionSaveDataSlot(uuid []byte) (int, error)
+}
+
// /account/info - get account info
-func Info(username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) {
- slot, _ := db.GetLatestSessionSaveDataSlot(uuid)
+func Info[T InfoStore](store T, username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) {
+ slot, _ := store.GetLatestSessionSaveDataSlot(uuid)
response := InfoResponse{
Username: username,
LastSessionSlot: slot,
diff --git a/api/account/login.go b/api/account/login.go
index a283444..41a373f 100644
--- a/api/account/login.go
+++ b/api/account/login.go
@@ -24,14 +24,18 @@ import (
"encoding/base64"
"errors"
"fmt"
-
- "github.com/pagefaultgames/rogueserver/db"
)
type LoginResponse GenericAuthResponse
+// Interface for database operations needed for login.
+type LoginStore interface {
+ FetchAccountKeySaltFromUsername(username string) (key, salt []byte, err error)
+ AddAccountSession(username string, token []byte) error
+}
+
// /account/login - log into account
-func Login(username, password string) (LoginResponse, error) {
+func Login[T LoginStore](store T, username, password string) (LoginResponse, error) {
var response LoginResponse
if !isValidUsername(username) {
@@ -42,12 +46,11 @@ func Login(username, password string) (LoginResponse, error) {
return response, fmt.Errorf("invalid password")
}
- key, salt, err := db.FetchAccountKeySaltFromUsername(username)
+ key, salt, err := store.FetchAccountKeySaltFromUsername(username)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return response, fmt.Errorf("account doesn't exist")
}
-
return response, err
}
@@ -55,8 +58,7 @@ func Login(username, password string) (LoginResponse, error) {
return response, fmt.Errorf("password doesn't match")
}
- response.Token, err = GenerateTokenForUsername(username)
-
+ response.Token, err = GenerateTokenForUsername(store, username)
if err != nil {
return response, fmt.Errorf("failed to generate token: %s", err)
}
@@ -64,14 +66,19 @@ func Login(username, password string) (LoginResponse, error) {
return response, nil
}
-func GenerateTokenForUsername(username string) (string, error) {
+type GenerateTokenForUsernameStore interface {
+ AddAccountSession(username string, token []byte) error
+}
+
+// GenerateTokenForUsername generates a session token and stores it using the provided DBAccountStore.
+func GenerateTokenForUsername[T GenerateTokenForUsernameStore](store T, username string) (string, error) {
token := make([]byte, TokenSize)
_, err := rand.Read(token)
if err != nil {
return "", fmt.Errorf("failed to generate token: %s", err)
}
- err = db.AddAccountSession(username, token)
+ err = store.AddAccountSession(username, token)
if err != nil {
return "", fmt.Errorf("failed to add account session")
}
diff --git a/api/account/login_test.go b/api/account/login_test.go
new file mode 100644
index 0000000..c883065
--- /dev/null
+++ b/api/account/login_test.go
@@ -0,0 +1,154 @@
+package account
+
+import (
+ "database/sql"
+ "errors"
+ "testing"
+)
+
+func defaultMockStore() *mockDBAccountStore {
+ return &mockDBAccountStore{
+ FetchFunc: func(username string) ([]byte, []byte, error) {
+ return []byte("key"), []byte("salt"), nil
+ },
+ AddSessionFunc: func(username string, token []byte) error { return nil },
+ }
+}
+
+type mockDBAccountStore struct {
+ FetchFunc func(username string) ([]byte, []byte, error)
+ AddSessionFunc func(username string, token []byte) error
+}
+
+func (m *mockDBAccountStore) FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
+ return m.FetchFunc(username)
+}
+func (m *mockDBAccountStore) AddAccountSession(username string, token []byte) error {
+ return m.AddSessionFunc(username, token)
+}
+
+func TestLogin(t *testing.T) {
+ t.Run("UsernameMinLength", func(t *testing.T) {
+ store := defaultMockStore()
+ _, err := Login(store, "a", "password123")
+ if err == nil {
+ t.Errorf("expected error due to password mismatch or DB, got nil")
+ }
+ })
+ t.Run("UsernameMaxLength", func(t *testing.T) {
+ uname := "abcdefghijklmnop"
+ store := defaultMockStore()
+ _, err := Login(store, uname, "password123")
+ if err == nil {
+ t.Errorf("expected error due to password mismatch or DB, got nil")
+ }
+ })
+ t.Run("UsernameTooLong", func(t *testing.T) {
+ uname := "abcdefghijklmnopq"
+ store := defaultMockStore()
+ _, err := Login(store, uname, "password123")
+ if err == nil || err.Error() != "invalid username" {
+ t.Errorf("expected invalid username error for too long username, got: %v", err)
+ }
+ })
+ t.Run("UsernameWithInvalidChars", func(t *testing.T) {
+ store := defaultMockStore()
+ _, err := Login(store, "user!@#", "password123")
+ if err == nil || err.Error() != "invalid username" {
+ t.Errorf("expected invalid username error for special chars, got: %v", err)
+ }
+ })
+ t.Run("EmptyUsername", func(t *testing.T) {
+ store := defaultMockStore()
+ _, err := Login(store, "", "password123")
+ if err == nil || err.Error() != "invalid username" {
+ t.Errorf("expected invalid username error for empty username, got: %v", err)
+ }
+ })
+ t.Run("EmptyPassword", func(t *testing.T) {
+ store := defaultMockStore()
+ _, err := Login(store, "validuser", "")
+ if err == nil || err.Error() != "invalid password" {
+ t.Errorf("expected invalid password error for empty password, got: %v", err)
+ }
+ })
+ t.Run("MinPasswordLength", func(t *testing.T) {
+ store := defaultMockStore()
+ _, err := Login(store, "validuser", "123456")
+ if err == nil {
+ t.Errorf("expected error due to password mismatch or DB, got nil")
+ }
+ })
+ t.Run("PasswordWithSpecialChars", func(t *testing.T) {
+ store := defaultMockStore()
+ _, err := Login(store, "validuser", "p@$$w0rd!")
+ if err == nil {
+ t.Errorf("expected error due to password mismatch or DB, got nil")
+ }
+ })
+ t.Run("DBUnexpectedError", func(t *testing.T) {
+ store := defaultMockStore()
+ store.FetchFunc = func(username string) ([]byte, []byte, error) {
+ return nil, nil, errors.New("some db error")
+ }
+ _, err := Login(store, "validuser", "password123")
+ if err == nil || err.Error() != "some db error" {
+ t.Errorf("expected DB error to propagate, got: %v", err)
+ }
+ })
+ t.Run("InvalidUsername", func(t *testing.T) {
+ store := defaultMockStore()
+ _, err := Login(store, "!invaliduser", "password123")
+ if err == nil || err.Error() != "invalid username" {
+ t.Errorf("expected invalid username error, got: %v", err)
+ }
+ })
+ t.Run("ShortPassword", func(t *testing.T) {
+ store := defaultMockStore()
+ _, err := Login(store, "validuser", "123")
+ if err == nil || err.Error() != "invalid password" {
+ t.Errorf("expected invalid password error, got: %v", err)
+ }
+ })
+ t.Run("AccountDoesNotExist", func(t *testing.T) {
+ store := defaultMockStore()
+ store.FetchFunc = func(username string) ([]byte, []byte, error) {
+ return nil, nil, sql.ErrNoRows
+ }
+ _, err := Login(store, "nonexistent", "password123")
+ if err == nil || err.Error() != "account doesn't exist" {
+ t.Errorf("expected account doesn't exist error, got: %v", err)
+ }
+ })
+ t.Run("PasswordMismatch", func(t *testing.T) {
+ correctSalt := []byte("somesalt")
+ correctKey := []byte("correctkey")
+ store := defaultMockStore()
+ store.FetchFunc = func(username string) ([]byte, []byte, error) {
+ return correctKey, correctSalt, nil
+ }
+ _, err := Login(store, "validuser", "wrongpassword")
+ if err == nil || err.Error() != "password doesn't match" {
+ t.Errorf("expected password doesn't match error, got: %v", err)
+ }
+ })
+ t.Run("Success", func(t *testing.T) {
+ correctSalt := []byte("somesalt")
+ password := "goodpassword"
+ correctKey := deriveArgon2IDKey([]byte(password), correctSalt)
+ store := defaultMockStore()
+ store.FetchFunc = func(username string) ([]byte, []byte, error) {
+ return correctKey, correctSalt, nil
+ }
+ store.AddSessionFunc = func(username string, token []byte) error {
+ return nil
+ }
+ resp, err := Login(store, "validuser", password)
+ if err != nil {
+ t.Errorf("expected success, got error: %v", err)
+ }
+ if resp.Token == "" {
+ t.Errorf("expected token to be set on success")
+ }
+ })
+}
diff --git a/api/account/logout.go b/api/account/logout.go
index ae3b26f..385f0e7 100644
--- a/api/account/logout.go
+++ b/api/account/logout.go
@@ -21,13 +21,17 @@ import (
"database/sql"
"errors"
"fmt"
-
- "github.com/pagefaultgames/rogueserver/db"
)
// /account/logout - log out of account
-func Logout(token []byte) error {
- err := db.RemoveSessionFromToken(token)
+
+// Interface for database operations needed for logout.
+type LogoutStore interface {
+ RemoveSessionFromToken(token []byte) error
+}
+
+func Logout[T LogoutStore](store T, token []byte) error {
+ err := store.RemoveSessionFromToken(token)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("token not found")
diff --git a/api/account/register.go b/api/account/register.go
index 06639e5..f7fc2c2 100644
--- a/api/account/register.go
+++ b/api/account/register.go
@@ -20,12 +20,15 @@ package account
import (
"crypto/rand"
"fmt"
-
- "github.com/pagefaultgames/rogueserver/db"
)
+// Interface for database operations needed for registration.
+type RegisterStore interface {
+ AddAccountRecord(uuid []byte, username string, passwordHash []byte, salt []byte) error
+}
+
// /account/register - register account
-func Register(username, password string) error {
+func Register[T RegisterStore](store T, username, password string) error {
if !isValidUsername(username) {
return fmt.Errorf("invalid username")
}
@@ -46,7 +49,7 @@ func Register(username, password string) error {
return fmt.Errorf("failed to generate salt: %s", err)
}
- err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
+ err = store.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
if err != nil {
return fmt.Errorf("failed to add account record: %s", err)
}
diff --git a/api/common.go b/api/common.go
index f7f4a15..a7324ad 100644
--- a/api/common.go
+++ b/api/common.go
@@ -30,7 +30,7 @@ import (
)
func Init(mux *http.ServeMux) error {
- err := scheduleStatRefresh()
+ err := scheduleStatRefresh(db.Store)
if err != nil {
return err
}
@@ -109,7 +109,7 @@ func tokenAndUuidFromRequest(r *http.Request) ([]byte, []byte, error) {
return nil, nil, err
}
- uuid, err := db.FetchUUIDFromToken(token)
+ uuid, err := db.Store.FetchUUIDFromToken(token)
if err != nil {
return nil, nil, fmt.Errorf("failed to validate token: %s", err)
}
diff --git a/api/daily/common.go b/api/daily/common.go
index d31d78e..15a662b 100644
--- a/api/daily/common.go
+++ b/api/daily/common.go
@@ -61,7 +61,7 @@ func Init() error {
secret = newSecret
}
- seed, err := db.TryAddDailyRun(Seed())
+ seed, err := db.Store.TryAddDailyRun(Seed())
if err != nil {
log.Print(err)
}
@@ -71,7 +71,7 @@ func Init() error {
_, err = scheduler.AddFunc("@daily", func() {
time.Sleep(time.Second)
- seed, err = db.TryAddDailyRun(Seed())
+ seed, err = db.Store.TryAddDailyRun(Seed())
if err != nil {
log.Printf("error while recording new daily: %s", err)
} else {
diff --git a/api/daily/rankings.go b/api/daily/rankings.go
index 175a486..1acc1e0 100644
--- a/api/daily/rankings.go
+++ b/api/daily/rankings.go
@@ -22,9 +22,14 @@ import (
"github.com/pagefaultgames/rogueserver/defs"
)
+// Interface for database operations needed for fetching rankings.
+type RankingsStore interface {
+ FetchRankings(category, page int) ([]defs.DailyRanking, error)
+}
+
// /daily/rankings - fetch daily rankings
-func Rankings(category, page int) ([]defs.DailyRanking, error) {
- rankings, err := db.FetchRankings(category, page)
+func Rankings[T RankingsStore](store T, category, page int) ([]defs.DailyRanking, error) {
+ rankings, err := db.Store.FetchRankings(category, page)
if err != nil {
return rankings, err
}
diff --git a/api/daily/rankingspagecount.go b/api/daily/rankingspagecount.go
index 5950c77..76ddabc 100644
--- a/api/daily/rankingspagecount.go
+++ b/api/daily/rankingspagecount.go
@@ -17,13 +17,13 @@
package daily
-import (
- "github.com/pagefaultgames/rogueserver/db"
-)
+type RankingPageCountStore interface {
+ FetchRankingPageCount(category int) (int, error)
+}
// /daily/rankingpagecount - fetch daily ranking page count
-func RankingPageCount(category int) (int, error) {
- pageCount, err := db.FetchRankingPageCount(category)
+func RankingPageCount[T RankingPageCountStore](store T, category int) (int, error) {
+ pageCount, err := store.FetchRankingPageCount(category)
if err != nil {
return pageCount, err
}
diff --git a/api/endpoints.go b/api/endpoints.go
index 8db9f1a..5630bd8 100644
--- a/api/endpoints.go
+++ b/api/endpoints.go
@@ -50,17 +50,17 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
return
}
- username, err := db.FetchUsernameFromUUID(uuid)
+ username, err := db.Store.FetchUsernameFromUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
- discordId, err := db.FetchDiscordIdByUsername(username)
+ discordId, err := db.Store.FetchDiscordIdByUsername(username)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
httpError(w, r, err, http.StatusInternalServerError)
return
}
- googleId, err := db.FetchGoogleIdByUsername(username)
+ googleId, err := db.Store.FetchGoogleIdByUsername(username)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -68,10 +68,10 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
var hasAdminRole bool
if discordId != "" {
- hasAdminRole, _ = account.IsUserDiscordAdmin(discordId, account.DiscordGuildID)
+ hasAdminRole, _ = account.Discord.IsUserDiscordAdmin(discordId, account.DiscordGuildID)
}
- response, err := account.Info(username, discordId, googleId, uuid, hasAdminRole)
+ response, err := account.Info(db.Store, username, discordId, googleId, uuid, hasAdminRole)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -81,7 +81,7 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
}
func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
- err := account.Register(r.PostFormValue("username"), r.PostFormValue("password"))
+ err := account.Register(db.Store, r.PostFormValue("username"), r.PostFormValue("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -91,7 +91,7 @@ func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
}
func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
- response, err := account.Login(r.PostFormValue("username"), r.PostFormValue("password"))
+ response, err := account.Login(db.Store, r.PostFormValue("username"), r.PostFormValue("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -107,20 +107,20 @@ func handleAccountChangePW(w http.ResponseWriter, r *http.Request) {
return
}
- err = account.ChangePW(uuid, r.PostFormValue("password"))
+ err = account.ChangePW(db.Store, uuid, r.PostFormValue("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
- username, err := db.FetchUsernameFromUUID(uuid)
+ username, err := db.Store.FetchUsernameFromUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
// create a new session with these credentials
- response, err := account.Login(username, r.Form.Get("password"))
+ response, err := account.Login(db.Store, username, r.Form.Get("password"))
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -136,7 +136,7 @@ func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
return
}
- err = account.Logout(token)
+ err = account.Logout(db.Store, token)
if err != nil {
// also possible for InternalServerError but that's unlikely unless the server blew up
httpError(w, r, err, http.StatusUnauthorized)
@@ -183,7 +183,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
return
}
- err = db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
+ err = db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
@@ -191,7 +191,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("action") {
case "get":
- save, err := savedata.GetSession(uuid, slot)
+ save, err := savedata.GetSession(db.Store, uuid, slot)
if err != nil {
if errors.Is(err, savedata.ErrSaveNotExist) {
http.Error(w, err.Error(), http.StatusNotFound)
@@ -211,7 +211,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
return
}
- existingSave, err := savedata.GetSession(uuid, slot)
+ existingSave, err := savedata.GetSession(db.Store, uuid, slot)
if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError)
@@ -224,7 +224,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
}
}
- err = savedata.UpdateSession(uuid, slot, session)
+ err = savedata.UpdateSession(db.Store, uuid, slot, session)
if err != nil {
httpError(w, r, fmt.Errorf("failed to put session data: %s", err), http.StatusInternalServerError)
return
@@ -239,13 +239,13 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
return
}
- seed, err := db.GetDailyRunSeed()
+ seed, err := db.Store.GetDailyRunSeed()
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
- resp, err := savedata.Clear(uuid, slot, seed, session)
+ resp, err := savedata.Clear(db.Store, uuid, slot, seed, session)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -253,7 +253,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, resp)
case "newclear":
- resp, err := savedata.NewClear(uuid, slot)
+ resp, err := savedata.NewClear(db.Store, uuid, slot)
if err != nil {
httpError(w, r, fmt.Errorf("failed to read new clear: %s", err), http.StatusInternalServerError)
return
@@ -261,7 +261,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, resp)
case "delete":
- err := savedata.DeleteSession(uuid, slot)
+ err := savedata.DeleteSession(db.Store, uuid, slot)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -301,7 +301,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
return
}
- active, err := db.IsActiveSession(uuid, data.ClientSessionId)
+ active, err := db.Store.IsActiveSession(uuid, data.ClientSessionId)
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return
@@ -312,7 +312,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
return
}
- storedTrainerId, storedSecretId, err := db.FetchTrainerIds(uuid)
+ storedTrainerId, storedSecretId, err := db.Store.FetchTrainerIds(uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -324,14 +324,14 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
return
}
} else {
- err = db.UpdateTrainerIds(data.System.TrainerId, data.System.SecretId, uuid)
+ err = db.Store.UpdateTrainerIds(data.System.TrainerId, data.System.SecretId, uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
}
- oldSystem, err := savedata.GetSystem(uuid)
+ oldSystem, err := savedata.GetSystem(db.Store, uuid)
if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError)
@@ -356,7 +356,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
}
}
- existingSave, err := savedata.GetSession(uuid, data.SessionSlotId)
+ existingSave, err := savedata.GetSession(db.Store, uuid, data.SessionSlotId)
if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError)
@@ -369,13 +369,13 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
}
}
- err = savedata.Update(uuid, data.SessionSlotId, data.Session)
+ err = savedata.Update(db.Store, uuid, data.SessionSlotId, data.Session)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
- err = savedata.Update(uuid, 0, data.System)
+ err = savedata.Update(db.Store, uuid, 0, data.System)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -401,7 +401,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
return
}
- active, err := db.IsActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
+ active, err := db.Store.IsActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return
@@ -410,14 +410,14 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("action") {
case "get":
if !active {
- err = db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
+ err = db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
}
}
- save, err := savedata.GetSystem(uuid)
+ save, err := savedata.GetSystem(db.Store, uuid)
if err != nil {
if errors.Is(err, savedata.ErrSaveNotExist) {
http.Error(w, err.Error(), http.StatusNotFound)
@@ -442,7 +442,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
return
}
- oldSystem, err := savedata.GetSystem(uuid)
+ oldSystem, err := savedata.GetSystem(db.Store, uuid)
if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError)
@@ -467,7 +467,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
}
}
- err = savedata.UpdateSystem(uuid, system)
+ err = savedata.UpdateSystem(db.Store, uuid, system)
if err != nil {
httpError(w, r, fmt.Errorf("failed to put system data: %s", err), http.StatusInternalServerError)
return
@@ -481,13 +481,13 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
// not valid, send server state
if !active {
- err := db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
+ err := db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return
}
- storedSaveData, err := db.ReadSystemSaveData(uuid)
+ storedSaveData, err := db.Store.ReadSystemSaveData(uuid)
if err != nil {
httpError(w, r, fmt.Errorf("failed to read session save data: %s", err), http.StatusInternalServerError)
return
@@ -498,7 +498,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, response)
case "delete":
- err := savedata.DeleteSystem(uuid)
+ err := savedata.DeleteSystem(db.Store, uuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -511,9 +511,14 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
}
}
+// Interface providing database operations needed for getting daily seed.
+type HandleDailySeedStore interface {
+ GetDailyRunSeed() (string, error)
+}
+
// daily
func handleDailySeed(w http.ResponseWriter, r *http.Request) {
- seed, err := db.GetDailyRunSeed()
+ seed, err := db.Store.GetDailyRunSeed()
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -522,6 +527,11 @@ func handleDailySeed(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, seed)
}
+// Interface for database operations needed for getting daily rankings.
+type HandleDailyRankingsStore interface {
+ daily.RankingsStore
+}
+
func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
var err error
@@ -543,7 +553,7 @@ func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
}
}
- rankings, err := daily.Rankings(category, page)
+ rankings, err := daily.Rankings(db.Store, category, page)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -552,6 +562,10 @@ func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, rankings)
}
+type HandleDailyRankingsPageCountStore interface {
+ daily.RankingPageCountStore
+}
+
func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) {
var category int
if r.URL.Query().Has("category") {
@@ -563,7 +577,7 @@ func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) {
}
}
- count, err := daily.RankingPageCount(category)
+ count, err := daily.RankingPageCount(db.Store, category)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
}
@@ -579,9 +593,9 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
var err error
switch provider {
case "discord":
- externalAuthId, err = account.HandleDiscordCallback(w, r)
+ externalAuthId, err = account.Discord.HandleDiscordCallback(w, r)
case "google":
- externalAuthId, err = account.HandleGoogleCallback(w, r)
+ externalAuthId, err = account.Google.HandleGoogleCallback(w, r)
default:
http.Error(w, "invalid provider", http.StatusBadRequest)
return
@@ -600,7 +614,7 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
return
}
- userName, err := db.FetchUsernameBySessionToken(stateByte)
+ userName, err := db.Store.FetchUsernameBySessionToken(stateByte)
if err != nil {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
return
@@ -608,9 +622,9 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
switch provider {
case "discord":
- err = db.AddDiscordIdByUsername(externalAuthId, userName)
+ err = db.Store.AddDiscordIdByUsername(externalAuthId, userName)
case "google":
- err = db.AddGoogleIdByUsername(externalAuthId, userName)
+ err = db.Store.AddGoogleIdByUsername(externalAuthId, userName)
}
if err != nil {
@@ -622,16 +636,16 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
var userName string
switch provider {
case "discord":
- userName, err = db.FetchUsernameByDiscordId(externalAuthId)
+ userName, err = db.Store.FetchUsernameByDiscordId(externalAuthId)
case "google":
- userName, err = db.FetchUsernameByGoogleId(externalAuthId)
+ userName, err = db.Store.FetchUsernameByGoogleId(externalAuthId)
}
if err != nil {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
return
}
- sessionToken, err := account.GenerateTokenForUsername(userName)
+ sessionToken, err := account.GenerateTokenForUsername(db.Store, userName)
if err != nil {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
return
@@ -651,6 +665,11 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
}
+type HandleProviderLogoutStore interface {
+ RemoveDiscordIdByUUID(uuid []byte) error
+ RemoveGoogleIdByUUID(uuid []byte) error
+}
+
func handleProviderLogout(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r)
if err != nil {
@@ -660,9 +679,9 @@ func handleProviderLogout(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("provider") {
case "discord":
- err = db.RemoveDiscordIdByUUID(uuid)
+ err = db.Store.RemoveDiscordIdByUUID(uuid)
case "google":
- err = db.RemoveGoogleIdByUUID(uuid)
+ err = db.Store.RemoveGoogleIdByUUID(uuid)
default:
http.Error(w, "invalid provider", http.StatusBadRequest)
return
@@ -681,13 +700,13 @@ func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
return
}
- userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
+ userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
- hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
+ hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
@@ -698,19 +717,19 @@ func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
- _, err = db.CheckUsernameExists(username)
+ _, err = db.Store.CheckUsernameExists(username)
if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return
}
- userUuid, err := db.FetchUUIDFromUsername(username)
+ userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
- err = db.AddDiscordIdByUUID(discordId, userUuid)
+ err = db.Store.AddDiscordIdByUUID(discordId, userUuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -728,13 +747,13 @@ func handleAdminDiscordUnlink(w http.ResponseWriter, r *http.Request) {
return
}
- userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
+ userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
- hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
+ hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
@@ -748,26 +767,26 @@ func handleAdminDiscordUnlink(w http.ResponseWriter, r *http.Request) {
log.Printf("Username given, removing discordId")
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
- _, err = db.CheckUsernameExists(username)
+ _, err = db.Store.CheckUsernameExists(username)
if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return
}
- userUuid, err := db.FetchUUIDFromUsername(username)
+ userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
- err = db.RemoveDiscordIdByUUID(userUuid)
+ err = db.Store.RemoveDiscordIdByUUID(userUuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
case discordId != "":
log.Printf("DiscordID given, removing discordId")
- err = db.RemoveDiscordIdByDiscordId(discordId)
+ err = db.Store.RemoveDiscordIdByDiscordId(discordId)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -786,13 +805,13 @@ func handleAdminGoogleLink(w http.ResponseWriter, r *http.Request) {
return
}
- userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
+ userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
- hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
+ hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
@@ -803,19 +822,19 @@ func handleAdminGoogleLink(w http.ResponseWriter, r *http.Request) {
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
- _, err = db.CheckUsernameExists(username)
+ _, err = db.Store.CheckUsernameExists(username)
if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return
}
- userUuid, err := db.FetchUUIDFromUsername(username)
+ userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
- err = db.AddGoogleIdByUUID(googleId, userUuid)
+ err = db.Store.AddGoogleIdByUUID(googleId, userUuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -833,13 +852,13 @@ func handleAdminGoogleUnlink(w http.ResponseWriter, r *http.Request) {
return
}
- userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
+ userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
- hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
+ hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
@@ -853,26 +872,26 @@ func handleAdminGoogleUnlink(w http.ResponseWriter, r *http.Request) {
log.Printf("Username given, removing googleId")
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
- _, err = db.CheckUsernameExists(username)
+ _, err = db.Store.CheckUsernameExists(username)
if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return
}
- userUuid, err := db.FetchUUIDFromUsername(username)
+ userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
- err = db.RemoveGoogleIdByUUID(userUuid)
+ err = db.Store.RemoveGoogleIdByUUID(userUuid)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
}
case googleId != "":
log.Printf("DiscordID given, removing googleId")
- err = db.RemoveGoogleIdByDiscordId(googleId)
+ err = db.Store.RemoveGoogleIdByDiscordId(googleId)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
@@ -891,13 +910,13 @@ func handleAdminSearch(w http.ResponseWriter, r *http.Request) {
return
}
- userDiscordId, err := db.FetchDiscordIdByUUID(uuid)
+ userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil {
httpError(w, r, err, http.StatusUnauthorized)
return
}
- hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
+ hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return
@@ -907,14 +926,14 @@ func handleAdminSearch(w http.ResponseWriter, r *http.Request) {
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
- _, err = db.CheckUsernameExists(username)
+ _, err = db.Store.CheckUsernameExists(username)
if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return
}
// this does a single call that does a query for multiple columns from our database and makes an object out of it, which is returned to us
- adminSearchResult, err := db.FetchAdminDetailsByUsername(username)
+ adminSearchResult, err := db.Store.FetchAdminDetailsByUsername(username)
if err != nil {
httpError(w, r, err, http.StatusInternalServerError)
return
diff --git a/api/savedata/clear.go b/api/savedata/clear.go
index 35b9a59..ad50fcc 100644
--- a/api/savedata/clear.go
+++ b/api/savedata/clear.go
@@ -21,7 +21,6 @@ import (
"fmt"
"log"
- "github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
@@ -30,10 +29,20 @@ type ClearResponse struct {
Error string `json:"error"`
}
+// Interface for database operations needed for `Clear`
+// Helps with testing and reduces coupling.
+type ClearStore interface {
+ UpdateAccountLastActivity(uuid []byte) error
+ TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error)
+ DeleteSessionSaveData(uuid []byte, slot int) error
+ AddOrUpdateAccountDailyRun(uuid []byte, score int, waveCompleted int) error
+ SetAccountBanned(uuid []byte, banned bool) error
+}
+
// /savedata/clear - mark session save data as cleared and delete
-func Clear(uuid []byte, slot int, seed string, save defs.SessionSaveData) (ClearResponse, error) {
+func Clear[T ClearStore](store T, uuid []byte, slot int, seed string, save defs.SessionSaveData) (ClearResponse, error) {
var response ClearResponse
- err := db.UpdateAccountLastActivity(uuid)
+ err := store.UpdateAccountLastActivity(uuid)
if err != nil {
log.Print("failed to update account last activity")
}
@@ -51,23 +60,23 @@ func Clear(uuid []byte, slot int, seed string, save defs.SessionSaveData) (Clear
}
if save.Score >= 20000 {
- db.SetAccountBanned(uuid, true)
+ store.SetAccountBanned(uuid, true)
}
- err = db.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted)
+ err = store.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted)
if err != nil {
log.Printf("failed to add or update daily run record: %s", err)
}
}
if sessionCompleted {
- response.Success, err = db.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode))
+ response.Success, err = store.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode))
if err != nil {
log.Printf("failed to mark seed as completed: %s", err)
}
}
- err = db.DeleteSessionSaveData(uuid, slot)
+ err = store.DeleteSessionSaveData(uuid, slot)
if err != nil {
log.Printf("failed to delete session save data: %s", err)
}
diff --git a/api/savedata/delete.go b/api/savedata/delete.go
index cb59c3e..fd57902 100644
--- a/api/savedata/delete.go
+++ b/api/savedata/delete.go
@@ -21,13 +21,19 @@ import (
"fmt"
"log"
- "github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
+// Interface for database operations needed for delete.
+// This is to allow easier testing and less coupling.
+type DeleteStore interface {
+ UpdateAccountLastActivity(uuid []byte) error
+ DeleteSessionSaveData(uuid []byte, slot int) error
+}
+
// /savedata/delete - delete save data
-func Delete(uuid []byte, datatype, slot int) error {
- err := db.UpdateAccountLastActivity(uuid)
+func Delete[T DeleteStore](store T, uuid []byte, datatype, slot int) error {
+ err := store.UpdateAccountLastActivity(uuid)
if err != nil {
log.Print("failed to update account last activity")
}
@@ -39,7 +45,7 @@ func Delete(uuid []byte, datatype, slot int) error {
break
}
- err = db.DeleteSessionSaveData(uuid, slot)
+ err = store.DeleteSessionSaveData(uuid, slot)
default:
err = fmt.Errorf("invalid data type")
}
diff --git a/api/savedata/newclear.go b/api/savedata/newclear.go
index 1420027..30c914d 100644
--- a/api/savedata/newclear.go
+++ b/api/savedata/newclear.go
@@ -20,22 +20,26 @@ package savedata
import (
"fmt"
- "github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
+type NewClearStore interface {
+ ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error)
+ ReadSeedCompleted(uuid []byte, seed string) (bool, error)
+}
+
// /savedata/newclear - return whether a session is a new clear for its seed
-func NewClear(uuid []byte, slot int) (bool, error) {
+func NewClear[T NewClearStore](store T, uuid []byte, slot int) (bool, error) {
if slot < 0 || slot >= defs.SessionSlotCount {
return false, fmt.Errorf("slot id %d out of range", slot)
}
- session, err := db.ReadSessionSaveData(uuid, slot)
+ session, err := store.ReadSessionSaveData(uuid, slot)
if err != nil {
return false, err
}
- completed, err := db.ReadSeedCompleted(uuid, session.Seed)
+ completed, err := store.ReadSeedCompleted(uuid, session.Seed)
if err != nil {
return false, fmt.Errorf("failed to read seed completed: %s", err)
}
diff --git a/api/savedata/session.go b/api/savedata/session.go
index 76bf71f..fc2916c 100644
--- a/api/savedata/session.go
+++ b/api/savedata/session.go
@@ -21,12 +21,15 @@ import (
"database/sql"
"errors"
- "github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
-func GetSession(uuid []byte, slot int) (defs.SessionSaveData, error) {
- session, err := db.ReadSessionSaveData(uuid, slot)
+type GetSessionStore interface {
+ ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error)
+}
+
+func GetSession[T GetSessionStore](store T, uuid []byte, slot int) (defs.SessionSaveData, error) {
+ session, err := store.ReadSessionSaveData(uuid, slot)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
err = ErrSaveNotExist
@@ -38,8 +41,13 @@ func GetSession(uuid []byte, slot int) (defs.SessionSaveData, error) {
return session, nil
}
-func UpdateSession(uuid []byte, slot int, data defs.SessionSaveData) error {
- err := db.StoreSessionSaveData(uuid, data, slot)
+type UpdateSessionStore interface {
+ StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error
+ DeleteSessionSaveData(uuid []byte, slot int) error
+}
+
+func UpdateSession[T UpdateSessionStore](store T, uuid []byte, slot int, data defs.SessionSaveData) error {
+ err := store.StoreSessionSaveData(uuid, data, slot)
if err != nil {
return err
}
@@ -47,8 +55,12 @@ func UpdateSession(uuid []byte, slot int, data defs.SessionSaveData) error {
return nil
}
-func DeleteSession(uuid []byte, slot int) error {
- err := db.DeleteSessionSaveData(uuid, slot)
+type DeleteSessionStore interface {
+ DeleteSessionSaveData(uuid []byte, slot int) error
+}
+
+func DeleteSession[T DeleteSessionStore](store T, uuid []byte, slot int) error {
+ err := store.DeleteSessionSaveData(uuid, slot)
if err != nil {
return err
}
diff --git a/api/savedata/system.go b/api/savedata/system.go
index 25d6b58..7d42bf7 100644
--- a/api/savedata/system.go
+++ b/api/savedata/system.go
@@ -24,24 +24,28 @@ import (
"os"
"github.com/aws/aws-sdk-go-v2/service/s3/types"
- "github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
var ErrSaveNotExist = errors.New("save does not exist")
-func GetSystem(uuid []byte) (defs.SystemSaveData, error) {
+type GetSystemStore interface {
+ GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error)
+ ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error)
+}
+
+func GetSystem[T GetSystemStore](store T, uuid []byte) (defs.SystemSaveData, error) {
var system defs.SystemSaveData
var err error
if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3
- system, err = db.GetSystemSaveFromS3(uuid)
+ system, err = store.GetSystemSaveFromS3(uuid)
var nokey *types.NoSuchKey
if errors.As(err, &nokey) {
err = ErrSaveNotExist
}
} else { // use database
- system, err = db.ReadSystemSaveData(uuid)
+ system, err = store.ReadSystemSaveData(uuid)
if errors.Is(err, sql.ErrNoRows) {
err = ErrSaveNotExist
}
@@ -53,20 +57,27 @@ func GetSystem(uuid []byte) (defs.SystemSaveData, error) {
return system, nil
}
-func UpdateSystem(uuid []byte, data defs.SystemSaveData) error {
+// Interface for database operations needed for updating system data.
+type UpdateSystemStore interface {
+ UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error
+ StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error
+ StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error
+}
+
+func UpdateSystem[T UpdateSystemStore](store T, uuid []byte, data defs.SystemSaveData) error {
if data.TrainerId == 0 && data.SecretId == 0 {
return fmt.Errorf("invalid system data")
}
- err := db.UpdateAccountStats(uuid, data.GameStats, data.VoucherCounts)
+ err := store.UpdateAccountStats(uuid, data.GameStats, data.VoucherCounts)
if err != nil {
return fmt.Errorf("failed to update account stats: %s", err)
}
if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3
- err = db.StoreSystemSaveDataS3(uuid, data)
+ err = store.StoreSystemSaveDataS3(uuid, data)
} else {
- err = db.StoreSystemSaveData(uuid, data)
+ err = store.StoreSystemSaveData(uuid, data)
}
if err != nil {
return err
@@ -75,8 +86,12 @@ func UpdateSystem(uuid []byte, data defs.SystemSaveData) error {
return nil
}
-func DeleteSystem(uuid []byte) error {
- err := db.DeleteSystemSaveData(uuid)
+type DeleteSystemStore interface {
+ DeleteSystemSaveData(uuid []byte) error
+}
+
+func DeleteSystem[T DeleteSystemStore](store T, uuid []byte) error {
+ err := store.DeleteSystemSaveData(uuid)
if err != nil {
return err
}
diff --git a/api/savedata/update.go b/api/savedata/update.go
index 7bbe91d..52609bb 100644
--- a/api/savedata/update.go
+++ b/api/savedata/update.go
@@ -21,13 +21,18 @@ import (
"fmt"
"log"
- "github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs"
)
+type UpdateStore interface {
+ UpdateAccountLastActivity(uuid []byte) error
+ UpdateSystemStore
+ UpdateSessionStore
+}
+
// /savedata/update - update save data
-func Update(uuid []byte, slot int, save any) error {
- err := db.UpdateAccountLastActivity(uuid)
+func Update[T UpdateStore](store T, uuid []byte, slot int, save any) error {
+ err := store.UpdateAccountLastActivity(uuid)
if err != nil {
log.Print("failed to update account last activity")
}
@@ -38,13 +43,13 @@ func Update(uuid []byte, slot int, save any) error {
return fmt.Errorf("invalid system data")
}
- return UpdateSystem(uuid, save)
+ return UpdateSystem(store, uuid, save)
case defs.SessionSaveData: // Session
if slot < 0 || slot >= defs.SessionSlotCount {
return fmt.Errorf("slot id %d out of range", slot)
}
- return UpdateSession(uuid, slot, save)
+ return UpdateSession(store, uuid, slot, save)
default:
return fmt.Errorf("invalid data type")
}
diff --git a/api/stats.go b/api/stats.go
index c2ab176..a4b5933 100644
--- a/api/stats.go
+++ b/api/stats.go
@@ -21,7 +21,6 @@ import (
"log"
"time"
- "github.com/pagefaultgames/rogueserver/db"
"github.com/robfig/cron/v3"
)
@@ -32,9 +31,9 @@ var (
classicSessionCount int
)
-func scheduleStatRefresh() error {
+func scheduleStatRefresh[T updateStatsStore](store T) error {
_, err := scheduler.AddFunc("@every 30s", func() {
- err := updateStats()
+ err := updateStats(store)
if err != nil {
log.Printf("failed to update stats: %s", err)
}
@@ -47,19 +46,25 @@ func scheduleStatRefresh() error {
return nil
}
-func updateStats() error {
+type updateStatsStore interface {
+ FetchPlayerCount() (int, error)
+ FetchBattleCount() (int, error)
+ FetchClassicSessionCount() (int, error)
+}
+
+func updateStats[T updateStatsStore](store T) error {
var err error
- playerCount, err = db.FetchPlayerCount()
+ playerCount, err = store.FetchPlayerCount()
if err != nil {
return err
}
- battleCount, err = db.FetchBattleCount()
+ battleCount, err = store.FetchBattleCount()
if err != nil {
return err
}
- classicSessionCount, err = db.FetchClassicSessionCount()
+ classicSessionCount, err = store.FetchClassicSessionCount()
if err != nil {
return err
}
diff --git a/db/account.go b/db/account.go
index fbcce0e..6d2ec64 100644
--- a/db/account.go
+++ b/db/account.go
@@ -1,18 +1,18 @@
/*
- Copyright (C) 2024 - 2025 Pagefault Games
+ Copyright (C) 2024 - 2025 Pagefault Games
- This program is free software: you can redistribute it and/or modify
- it under the terms of the GNU Affero General Public License as published by
- the Free Software Foundation, either version 3 of the License, or
- (at your option) any later version.
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU Affero General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
- This program is distributed in the hope that it will be useful,
- but WITHOUT ANY WARRANTY; without even the implied warranty of
- MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- GNU Affero General Public License for more details.
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU Affero General Public License for more details.
- You should have received a copy of the GNU Affero General Public License
- along with this program. If not, see .
+ You should have received a copy of the GNU Affero General Public License
+ along with this program. If not, see .
*/
package db
@@ -27,30 +27,30 @@ import (
"github.com/pagefaultgames/rogueserver/defs"
)
-func AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
- _, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
- if err != nil {
- return err
- }
-
- return nil
-}
-
-func AddAccountSession(username string, token []byte) error {
+func (s *store) AddAccountSession(username string, token []byte) error {
_, err := handle.Exec("INSERT INTO sessions (uuid, token, expire) SELECT a.uuid, ?, DATE_ADD(UTC_TIMESTAMP(), INTERVAL 1 WEEK) FROM accounts a WHERE a.username = ?", token, username)
if err != nil {
return err
}
-
_, err = handle.Exec("UPDATE accounts SET lastLoggedIn = UTC_TIMESTAMP() WHERE username = ?", username)
if err != nil {
return err
}
-
return nil
}
-func AddDiscordIdByUsername(discordId string, username string) error {
+// (removed, now a method on store)
+func (s *store) FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
+ var key, salt []byte
+ err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt)
+ if err != nil {
+ return nil, nil, err
+ }
+
+ return key, salt, nil
+}
+
+func (s *store) AddDiscordIdByUsername(discordId string, username string) error {
_, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE username = ?", discordId, username)
if err != nil {
return err
@@ -59,7 +59,7 @@ func AddDiscordIdByUsername(discordId string, username string) error {
return nil
}
-func AddGoogleIdByUsername(googleId string, username string) error {
+func (s *store) AddGoogleIdByUsername(googleId string, username string) error {
_, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE username = ?", googleId, username)
if err != nil {
return err
@@ -68,7 +68,7 @@ func AddGoogleIdByUsername(googleId string, username string) error {
return nil
}
-func AddGoogleIdByUUID(googleId string, uuid []byte) error {
+func (s *store) AddGoogleIdByUUID(googleId string, uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE uuid = ?", googleId, uuid)
if err != nil {
return err
@@ -77,7 +77,7 @@ func AddGoogleIdByUUID(googleId string, uuid []byte) error {
return nil
}
-func AddDiscordIdByUUID(discordId string, uuid []byte) error {
+func (s *store) AddDiscordIdByUUID(discordId string, uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE uuid = ?", discordId, uuid)
if err != nil {
return err
@@ -86,7 +86,7 @@ func AddDiscordIdByUUID(discordId string, uuid []byte) error {
return nil
}
-func FetchUsernameByDiscordId(discordId string) (string, error) {
+func (s *store) FetchUsernameByDiscordId(discordId string) (string, error) {
var username string
err := handle.QueryRow("SELECT username FROM accounts WHERE discordId = ?", discordId).Scan(&username)
if err != nil {
@@ -96,7 +96,7 @@ func FetchUsernameByDiscordId(discordId string) (string, error) {
return username, nil
}
-func FetchUsernameByGoogleId(googleId string) (string, error) {
+func (s *store) FetchUsernameByGoogleId(googleId string) (string, error) {
var username string
err := handle.QueryRow("SELECT username FROM accounts WHERE googleId = ?", googleId).Scan(&username)
if err != nil {
@@ -106,7 +106,7 @@ func FetchUsernameByGoogleId(googleId string) (string, error) {
return username, nil
}
-func FetchDiscordIdByUsername(username string) (string, error) {
+func (s *store) FetchDiscordIdByUsername(username string) (string, error) {
var discordId sql.NullString
err := handle.QueryRow("SELECT discordId FROM accounts WHERE username = ?", username).Scan(&discordId)
if err != nil {
@@ -120,7 +120,7 @@ func FetchDiscordIdByUsername(username string) (string, error) {
return discordId.String, nil
}
-func FetchGoogleIdByUsername(username string) (string, error) {
+func (s *store) FetchGoogleIdByUsername(username string) (string, error) {
var googleId sql.NullString
err := handle.QueryRow("SELECT googleId FROM accounts WHERE username = ?", username).Scan(&googleId)
if err != nil {
@@ -134,7 +134,7 @@ func FetchGoogleIdByUsername(username string) (string, error) {
return googleId.String, nil
}
-func FetchDiscordIdByUUID(uuid []byte) (string, error) {
+func (s *store) FetchDiscordIdByUUID(uuid []byte) (string, error) {
var discordId sql.NullString
err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId)
if err != nil {
@@ -148,7 +148,7 @@ func FetchDiscordIdByUUID(uuid []byte) (string, error) {
return discordId.String, nil
}
-func FetchGoogleIdByUUID(uuid []byte) (string, error) {
+func (s *store) FetchGoogleIdByUUID(uuid []byte) (string, error) {
var googleId sql.NullString
err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId)
if err != nil {
@@ -162,7 +162,7 @@ func FetchGoogleIdByUUID(uuid []byte) (string, error) {
return googleId.String, nil
}
-func FetchUsernameBySessionToken(token []byte) (string, error) {
+func (s *store) FetchUsernameBySessionToken(token []byte) (string, error) {
var username string
err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username)
if err != nil {
@@ -172,7 +172,7 @@ func FetchUsernameBySessionToken(token []byte) (string, error) {
return username, nil
}
-func CheckUsernameExists(username string) (string, error) {
+func (s *store) CheckUsernameExists(username string) (string, error) {
var dbUsername sql.NullString
err := handle.QueryRow("SELECT username FROM accounts WHERE username = ?", username).Scan(&dbUsername)
if err != nil {
@@ -185,7 +185,7 @@ func CheckUsernameExists(username string) (string, error) {
return dbUsername.String, nil
}
-func FetchLastLoggedInDateByUsername(username string) (string, error) {
+func (s *store) FetchLastLoggedInDateByUsername(username string) (string, error) {
var lastLoggedIn sql.NullString
err := handle.QueryRow("SELECT lastLoggedIn FROM accounts WHERE username = ?", username).Scan(&lastLoggedIn)
if err != nil {
@@ -206,7 +206,7 @@ type AdminSearchResponse struct {
Registered string `json:"registered"`
}
-func FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) {
+func (s *store) FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) {
var username, discordId, googleId, lastActivity, registered sql.NullString
var adminResponse AdminSearchResponse
@@ -226,7 +226,7 @@ func FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error)
return adminResponse, nil
}
-func UpdateAccountPassword(uuid, key, salt []byte) error {
+func (s *store) UpdateAccountPassword(uuid, key, salt []byte) error {
_, err := handle.Exec("UPDATE accounts SET hash = ?, salt = ? WHERE uuid = ?", key, salt, uuid)
if err != nil {
return err
@@ -235,7 +235,7 @@ func UpdateAccountPassword(uuid, key, salt []byte) error {
return nil
}
-func UpdateAccountLastActivity(uuid []byte) error {
+func (s *store) UpdateAccountLastActivity(uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET lastActivity = UTC_TIMESTAMP() WHERE uuid = ?", uuid)
if err != nil {
return err
@@ -244,7 +244,7 @@ func UpdateAccountLastActivity(uuid []byte) error {
return nil
}
-func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error {
+func (s *store) UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error {
var columns = []string{"playTime", "battles", "classicSessionsPlayed", "sessionsWon", "highestEndlessWave", "highestLevel", "pokemonSeen", "pokemonDefeated", "pokemonCaught", "pokemonHatched", "eggsPulled", "regularVouchers", "plusVouchers", "premiumVouchers", "goldenVouchers"}
var statCols []string
@@ -321,7 +321,7 @@ func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[str
return nil
}
-func SetAccountBanned(uuid []byte, banned bool) error {
+func (s *store) SetAccountBanned(uuid []byte, banned bool) error {
_, err := handle.Exec("UPDATE accounts SET banned = ? WHERE uuid = ?", banned, uuid)
if err != nil {
return err
@@ -330,17 +330,16 @@ func SetAccountBanned(uuid []byte, banned bool) error {
return nil
}
-func FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
- var key, salt []byte
- err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt)
+func (s *store) AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
+ _, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
if err != nil {
- return nil, nil, err
+ return err
}
- return key, salt, nil
+ return nil
}
-func FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
+func (s *store) FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
err = handle.QueryRow("SELECT trainerId, secretId FROM accounts WHERE uuid = ?", uuid).Scan(&trainerId, &secretId)
if err != nil {
return 0, 0, err
@@ -349,7 +348,7 @@ func FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
return trainerId, secretId, nil
}
-func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
+func (s *store) UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET trainerId = ?, secretId = ? WHERE uuid = ?", trainerId, secretId, uuid)
if err != nil {
return err
@@ -358,12 +357,12 @@ func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
return nil
}
-func IsActiveSession(uuid []byte, sessionId string) (bool, error) {
+func (s *store) IsActiveSession(uuid []byte, sessionId string) (bool, error) {
var id string
err := handle.QueryRow("SELECT clientSessionId FROM activeClientSessions WHERE uuid = ?", uuid).Scan(&id)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
- err = UpdateActiveSession(uuid, sessionId)
+ err = s.UpdateActiveSession(uuid, sessionId)
if err != nil {
return false, err
}
@@ -377,7 +376,7 @@ func IsActiveSession(uuid []byte, sessionId string) (bool, error) {
return id == "" || id == sessionId, nil
}
-func UpdateActiveSession(uuid []byte, clientSessionId string) error {
+func (s *store) UpdateActiveSession(uuid []byte, clientSessionId string) error {
_, err := handle.Exec("INSERT INTO activeClientSessions (uuid, clientSessionId) VALUES (?, ?) ON DUPLICATE KEY UPDATE clientSessionId = ?", uuid, clientSessionId, clientSessionId)
if err != nil {
return err
@@ -386,7 +385,7 @@ func UpdateActiveSession(uuid []byte, clientSessionId string) error {
return nil
}
-func FetchUUIDFromToken(token []byte) ([]byte, error) {
+func (s *store) FetchUUIDFromToken(token []byte) ([]byte, error) {
var uuid []byte
err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ?", token).Scan(&uuid)
if err != nil {
@@ -396,7 +395,7 @@ func FetchUUIDFromToken(token []byte) ([]byte, error) {
return uuid, nil
}
-func RemoveSessionFromToken(token []byte) error {
+func (s *store) RemoveSessionFromToken(token []byte) error {
_, err := handle.Exec("DELETE FROM sessions WHERE token = ?", token)
if err != nil {
return err
@@ -405,7 +404,7 @@ func RemoveSessionFromToken(token []byte) error {
return nil
}
-func RemoveSessionsFromUUID(uuid []byte) error {
+func (s *store) RemoveSessionsFromUUID(uuid []byte) error {
_, err := handle.Exec("DELETE FROM sessions WHERE uuid = ?", uuid)
if err != nil {
return err
@@ -414,7 +413,7 @@ func RemoveSessionsFromUUID(uuid []byte) error {
return nil
}
-func FetchUsernameFromUUID(uuid []byte) (string, error) {
+func (s *store) FetchUsernameFromUUID(uuid []byte) (string, error) {
var username string
err := handle.QueryRow("SELECT username FROM accounts WHERE uuid = ?", uuid).Scan(&username)
if err != nil {
@@ -424,7 +423,7 @@ func FetchUsernameFromUUID(uuid []byte) (string, error) {
return username, nil
}
-func FetchUUIDFromUsername(username string) ([]byte, error) {
+func (s *store) FetchUUIDFromUsername(username string) ([]byte, error) {
var uuid []byte
err := handle.QueryRow("SELECT uuid FROM accounts WHERE username = ?", username).Scan(&uuid)
if err != nil {
@@ -434,7 +433,7 @@ func FetchUUIDFromUsername(username string) ([]byte, error) {
return uuid, nil
}
-func RemoveDiscordIdByUUID(uuid []byte) error {
+func (s *store) RemoveDiscordIdByUUID(uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE uuid = ?", uuid)
if err != nil {
return err
@@ -443,7 +442,7 @@ func RemoveDiscordIdByUUID(uuid []byte) error {
return nil
}
-func RemoveGoogleIdByUUID(uuid []byte) error {
+func (s *store) RemoveGoogleIdByUUID(uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE uuid = ?", uuid)
if err != nil {
return err
@@ -452,7 +451,7 @@ func RemoveGoogleIdByUUID(uuid []byte) error {
return nil
}
-func RemoveGoogleIdByUsername(username string) error {
+func (s *store) RemoveGoogleIdByUsername(username string) error {
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE username = ?", username)
if err != nil {
return err
@@ -461,7 +460,7 @@ func RemoveGoogleIdByUsername(username string) error {
return nil
}
-func RemoveDiscordIdByUsername(username string) error {
+func (s *store) RemoveDiscordIdByUsername(username string) error {
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE username = ?", username)
if err != nil {
return err
@@ -470,7 +469,7 @@ func RemoveDiscordIdByUsername(username string) error {
return nil
}
-func RemoveDiscordIdByDiscordId(discordId string) error {
+func (s *store) RemoveDiscordIdByDiscordId(discordId string) error {
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE discordId = ?", discordId)
if err != nil {
return err
@@ -479,7 +478,7 @@ func RemoveDiscordIdByDiscordId(discordId string) error {
return nil
}
-func RemoveGoogleIdByDiscordId(discordId string) error {
+func (s *store) RemoveGoogleIdByDiscordId(discordId string) error {
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE discordId = ?", discordId)
if err != nil {
return err
diff --git a/db/daily.go b/db/daily.go
index 59c8f92..155c3d6 100644
--- a/db/daily.go
+++ b/db/daily.go
@@ -23,7 +23,7 @@ import (
"github.com/pagefaultgames/rogueserver/defs"
)
-func TryAddDailyRun(seed string) (string, error) {
+func (s *store) TryAddDailyRun(seed string) (string, error) {
var actualSeed string
err := handle.QueryRow("INSERT INTO dailyRuns (seed, date) VALUES (?, UTC_DATE()) ON DUPLICATE KEY UPDATE date = date RETURNING seed", seed).Scan(&actualSeed)
if err != nil {
@@ -33,7 +33,7 @@ func TryAddDailyRun(seed string) (string, error) {
return actualSeed, nil
}
-func GetDailyRunSeed() (string, error) {
+func (s *store) GetDailyRunSeed() (string, error) {
var seed string
err := handle.QueryRow("SELECT seed FROM dailyRuns WHERE date = UTC_DATE()").Scan(&seed)
if err != nil {
@@ -43,7 +43,7 @@ func GetDailyRunSeed() (string, error) {
return seed, nil
}
-func AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
+func (s *store) AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
_, err := handle.Exec("INSERT INTO accountDailyRuns (uuid, date, score, wave, timestamp) VALUES (?, UTC_DATE(), ?, ?, UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE score = GREATEST(score, ?), wave = GREATEST(wave, ?), timestamp = IF(score < ?, UTC_TIMESTAMP(), timestamp)", uuid, score, wave, score, wave, score)
if err != nil {
return err
@@ -52,7 +52,7 @@ func AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
return nil
}
-func FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
+func (s *store) FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
var rankings []defs.DailyRanking
offset := (page - 1) * 10
@@ -85,7 +85,7 @@ func FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
return rankings, nil
}
-func FetchRankingPageCount(category int) (int, error) {
+func (s *store) FetchRankingPageCount(category int) (int, error) {
var query string
switch category {
case 0:
diff --git a/db/db.go b/db/db.go
index 06adb03..25b0039 100644
--- a/db/db.go
+++ b/db/db.go
@@ -31,6 +31,12 @@ import (
var handle *sql.DB
var s3client *s3.Client
+// internal type used to implement the Store interface
+type store struct{}
+
+// Store is the global instance for DB access.
+var Store = &store{}
+
func Init(username, password, protocol, address, database string) error {
var err error
diff --git a/db/game.go b/db/game.go
index ec862b6..7e722cc 100644
--- a/db/game.go
+++ b/db/game.go
@@ -17,7 +17,7 @@
package db
-func FetchPlayerCount() (int, error) {
+func (s *store) FetchPlayerCount() (int, error) {
var playerCount int
err := handle.QueryRow("SELECT COUNT(*) FROM accounts WHERE lastActivity > DATE_SUB(UTC_TIMESTAMP(), INTERVAL 5 MINUTE)").Scan(&playerCount)
if err != nil {
@@ -27,7 +27,7 @@ func FetchPlayerCount() (int, error) {
return playerCount, nil
}
-func FetchBattleCount() (int, error) {
+func (s *store) FetchBattleCount() (int, error) {
var battleCount int
err := handle.QueryRow("SELECT COALESCE(SUM(s.battles), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&battleCount)
if err != nil {
@@ -37,7 +37,7 @@ func FetchBattleCount() (int, error) {
return battleCount, nil
}
-func FetchClassicSessionCount() (int, error) {
+func (s *store) FetchClassicSessionCount() (int, error) {
var classicSessionCount int
err := handle.QueryRow("SELECT COALESCE(SUM(s.classicSessionsPlayed), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&classicSessionCount)
if err != nil {
diff --git a/db/s3.go b/db/s3.go
new file mode 100644
index 0000000..3ac41af
--- /dev/null
+++ b/db/s3.go
@@ -0,0 +1,61 @@
+package db
+
+import (
+ "bytes"
+ "context"
+ "encoding/json"
+ "os"
+
+ "github.com/aws/aws-sdk-go-v2/aws"
+ "github.com/aws/aws-sdk-go-v2/service/s3"
+ "github.com/pagefaultgames/rogueserver/defs"
+)
+
+func (s *store) GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) {
+ var system defs.SystemSaveData
+
+ username, err := Store.FetchUsernameFromUUID(uuid)
+ if err != nil {
+ return system, err
+ }
+
+ resp, err := s3client.GetObject(context.TODO(), &s3.GetObjectInput{
+ Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
+ Key: aws.String(username),
+ })
+ if err != nil {
+ return system, err
+ }
+
+ err = json.NewDecoder(resp.Body).Decode(&system)
+ if err != nil {
+ return system, err
+ }
+
+ return system, nil
+}
+
+func (s *store) StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error {
+ username, err := s.FetchUsernameFromUUID(uuid)
+ if err != nil {
+ return err
+ }
+
+ buf := new(bytes.Buffer)
+
+ err = json.NewEncoder(buf).Encode(data)
+ if err != nil {
+ return err
+ }
+
+ _, err = s3client.PutObject(context.Background(), &s3.PutObjectInput{
+ Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
+ Key: aws.String(username),
+ Body: buf,
+ })
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
diff --git a/db/savedata.go b/db/savedata.go
index ce86720..9226d5e 100644
--- a/db/savedata.go
+++ b/db/savedata.go
@@ -19,19 +19,13 @@ package db
import (
"bytes"
- "context"
"encoding/gob"
- "encoding/json"
- "os"
"github.com/klauspost/compress/zstd"
"github.com/pagefaultgames/rogueserver/defs"
-
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/service/s3"
)
-func TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
+func (s *store) TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
var count int
err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count)
if err != nil {
@@ -48,7 +42,7 @@ func TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
return true, nil
}
-func ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
+func (s *store) ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
var count int
err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count)
if err != nil {
@@ -58,7 +52,7 @@ func ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
return count > 0, nil
}
-func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
+func (s *store) ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
var system defs.SystemSaveData
var data []byte
@@ -82,7 +76,7 @@ func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
return system, nil
}
-func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
+func (s *store) StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
buf := new(bytes.Buffer)
zw, err := zstd.NewWriter(buf)
@@ -108,32 +102,7 @@ func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
return nil
}
-func StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error {
- username, err := FetchUsernameFromUUID(uuid)
- if err != nil {
- return err
- }
-
- buf := new(bytes.Buffer)
-
- err = json.NewEncoder(buf).Encode(data)
- if err != nil {
- return err
- }
-
- _, err = s3client.PutObject(context.Background(), &s3.PutObjectInput{
- Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
- Key: aws.String(username),
- Body: buf,
- })
- if err != nil {
- return err
- }
-
- return nil
-}
-
-func DeleteSystemSaveData(uuid []byte) error {
+func (s *store) DeleteSystemSaveData(uuid []byte) error {
_, err := handle.Exec("DELETE FROM systemSaveData WHERE uuid = ?", uuid)
if err != nil {
return err
@@ -142,7 +111,7 @@ func DeleteSystemSaveData(uuid []byte) error {
return nil
}
-func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
+func (s *store) ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
var session defs.SessionSaveData
var data []byte
@@ -166,7 +135,7 @@ func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
return session, nil
}
-func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
+func (s *store) GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
var slot int
err := handle.QueryRow("SELECT slot FROM sessionSaveData WHERE uuid = ? ORDER BY timestamp DESC, slot ASC LIMIT 1", uuid).Scan(&slot)
if err != nil {
@@ -176,7 +145,7 @@ func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
return slot, nil
}
-func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error {
+func (s *store) StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error {
buf := new(bytes.Buffer)
zw, err := zstd.NewWriter(buf)
@@ -202,7 +171,7 @@ func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) erro
return nil
}
-func DeleteSessionSaveData(uuid []byte, slot int) error {
+func (s *store) DeleteSessionSaveData(uuid []byte, slot int) error {
_, err := handle.Exec("DELETE FROM sessionSaveData WHERE uuid = ? AND slot = ?", uuid, slot)
if err != nil {
return err
@@ -210,27 +179,3 @@ func DeleteSessionSaveData(uuid []byte, slot int) error {
return nil
}
-
-func GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) {
- var system defs.SystemSaveData
-
- username, err := FetchUsernameFromUUID(uuid)
- if err != nil {
- return system, err
- }
-
- resp, err := s3client.GetObject(context.TODO(), &s3.GetObjectInput{
- Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
- Key: aws.String(username),
- })
- if err != nil {
- return system, err
- }
-
- err = json.NewDecoder(resp.Body).Decode(&system)
- if err != nil {
- return system, err
- }
-
- return system, nil
-}