Refactor to use dependency injection, allowing for unit testing (#67)

This commit is contained in:
Sirz Benjie 2025-09-16 22:04:45 -05:00 committed by GitHub
commit 9c9b5243cf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 579 additions and 296 deletions

View File

@ -1,7 +1,8 @@
version: "2"
run: run:
timeout: 10m timeout: 10m
severity: severity:
default-severity: error default: error
rules: rules:
- linters: - linters:
- unused - unused

View File

@ -20,11 +20,15 @@ package account
import ( import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"github.com/pagefaultgames/rogueserver/db"
) )
func ChangePW(uuid []byte, password string) error { // Interface for database operations needed for changing password.
type ChangePWStore interface {
RemoveSessionsFromUUID(uuid []byte) error
UpdateAccountPassword(uuid []byte, newKey []byte, newSalt []byte) error
}
func ChangePW[T ChangePWStore](store T, uuid []byte, password string) error {
if len(password) < 6 { if len(password) < 6 {
return fmt.Errorf("invalid password") return fmt.Errorf("invalid password")
} }
@ -35,12 +39,12 @@ func ChangePW(uuid []byte, password string) error {
return fmt.Errorf("failed to generate salt: %s", err) return fmt.Errorf("failed to generate salt: %s", err)
} }
err = db.RemoveSessionsFromUUID(uuid) err = store.RemoveSessionsFromUUID(uuid)
if err != nil { if err != nil {
return fmt.Errorf("failed to remove sessions: %s", err) return fmt.Errorf("failed to remove sessions: %s", err)
} }
err = db.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt) err = store.UpdateAccountPassword(uuid, deriveArgon2IDKey([]byte(password), salt), salt)
if err != nil { if err != nil {
return fmt.Errorf("failed to add account record: %s", err) return fmt.Errorf("failed to add account record: %s", err)
} }

View File

@ -35,14 +35,24 @@ var (
DiscordGuildID string DiscordGuildID string
) )
func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) { type DiscordProvider interface {
HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error)
RetrieveDiscordId(code string) (string, error)
IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error)
}
type discordProvider struct{}
var Discord = &discordProvider{}
func (s *discordProvider) HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, error) {
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
if code == "" { if code == "" {
http.Redirect(w, r, GameURL, http.StatusSeeOther) http.Redirect(w, r, GameURL, http.StatusSeeOther)
return "", errors.New("code is empty") return "", errors.New("code is empty")
} }
discordId, err := RetrieveDiscordId(code) discordId, err := s.RetrieveDiscordId(code)
if err != nil { if err != nil {
http.Redirect(w, r, GameURL, http.StatusSeeOther) http.Redirect(w, r, GameURL, http.StatusSeeOther)
return "", err return "", err
@ -51,7 +61,7 @@ func HandleDiscordCallback(w http.ResponseWriter, r *http.Request) (string, erro
return discordId, nil return discordId, nil
} }
func RetrieveDiscordId(code string) (string, error) { func (s *discordProvider) RetrieveDiscordId(code string) (string, error) {
v := make(url.Values) v := make(url.Values)
v.Set("client_id", DiscordClientID) v.Set("client_id", DiscordClientID)
v.Set("client_secret", DiscordClientSecret) v.Set("client_secret", DiscordClientSecret)
@ -112,7 +122,7 @@ func RetrieveDiscordId(code string) (string, error) {
return user.Id, nil return user.Id, nil
} }
func IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) { func (s *discordProvider) IsUserDiscordAdmin(discordId string, discordGuildID string) (bool, error) {
// fetch all roles from discord // fetch all roles from discord
roles, err := DiscordSession.GuildRoles(discordGuildID) roles, err := DiscordSession.GuildRoles(discordGuildID)
if err != nil { if err != nil {

View File

@ -32,7 +32,16 @@ var (
GoogleCallbackURL string GoogleCallbackURL string
) )
func HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) { type GoogleProvider interface {
HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error)
RetrieveGoogleId(code string) (string, error)
}
type googleProvider struct{}
var Google = &googleProvider{}
func (g *googleProvider) HandleGoogleCallback(w http.ResponseWriter, r *http.Request) (string, error) {
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
if code == "" { if code == "" {
http.Redirect(w, r, GameURL, http.StatusSeeOther) http.Redirect(w, r, GameURL, http.StatusSeeOther)

View File

@ -17,10 +17,6 @@
package account package account
import (
"github.com/pagefaultgames/rogueserver/db"
)
type InfoResponse struct { type InfoResponse struct {
Username string `json:"username"` Username string `json:"username"`
DiscordId string `json:"discordId"` DiscordId string `json:"discordId"`
@ -29,9 +25,13 @@ type InfoResponse struct {
HasAdminRole bool `json:"hasAdminRole"` HasAdminRole bool `json:"hasAdminRole"`
} }
type InfoStore interface {
GetLatestSessionSaveDataSlot(uuid []byte) (int, error)
}
// /account/info - get account info // /account/info - get account info
func Info(username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) { func Info[T InfoStore](store T, username string, discordId string, googleId string, uuid []byte, hasAdminRole bool) (InfoResponse, error) {
slot, _ := db.GetLatestSessionSaveDataSlot(uuid) slot, _ := store.GetLatestSessionSaveDataSlot(uuid)
response := InfoResponse{ response := InfoResponse{
Username: username, Username: username,
LastSessionSlot: slot, LastSessionSlot: slot,

View File

@ -24,14 +24,18 @@ import (
"encoding/base64" "encoding/base64"
"errors" "errors"
"fmt" "fmt"
"github.com/pagefaultgames/rogueserver/db"
) )
type LoginResponse GenericAuthResponse type LoginResponse GenericAuthResponse
// Interface for database operations needed for login.
type LoginStore interface {
FetchAccountKeySaltFromUsername(username string) (key, salt []byte, err error)
AddAccountSession(username string, token []byte) error
}
// /account/login - log into account // /account/login - log into account
func Login(username, password string) (LoginResponse, error) { func Login[T LoginStore](store T, username, password string) (LoginResponse, error) {
var response LoginResponse var response LoginResponse
if !isValidUsername(username) { if !isValidUsername(username) {
@ -42,12 +46,11 @@ func Login(username, password string) (LoginResponse, error) {
return response, fmt.Errorf("invalid password") return response, fmt.Errorf("invalid password")
} }
key, salt, err := db.FetchAccountKeySaltFromUsername(username) key, salt, err := store.FetchAccountKeySaltFromUsername(username)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return response, fmt.Errorf("account doesn't exist") return response, fmt.Errorf("account doesn't exist")
} }
return response, err return response, err
} }
@ -55,8 +58,7 @@ func Login(username, password string) (LoginResponse, error) {
return response, fmt.Errorf("password doesn't match") return response, fmt.Errorf("password doesn't match")
} }
response.Token, err = GenerateTokenForUsername(username) response.Token, err = GenerateTokenForUsername(store, username)
if err != nil { if err != nil {
return response, fmt.Errorf("failed to generate token: %s", err) return response, fmt.Errorf("failed to generate token: %s", err)
} }
@ -64,14 +66,19 @@ func Login(username, password string) (LoginResponse, error) {
return response, nil return response, nil
} }
func GenerateTokenForUsername(username string) (string, error) { type GenerateTokenForUsernameStore interface {
AddAccountSession(username string, token []byte) error
}
// GenerateTokenForUsername generates a session token and stores it using the provided DBAccountStore.
func GenerateTokenForUsername[T GenerateTokenForUsernameStore](store T, username string) (string, error) {
token := make([]byte, TokenSize) token := make([]byte, TokenSize)
_, err := rand.Read(token) _, err := rand.Read(token)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to generate token: %s", err) return "", fmt.Errorf("failed to generate token: %s", err)
} }
err = db.AddAccountSession(username, token) err = store.AddAccountSession(username, token)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to add account session") return "", fmt.Errorf("failed to add account session")
} }

154
api/account/login_test.go Normal file
View File

@ -0,0 +1,154 @@
package account
import (
"database/sql"
"errors"
"testing"
)
func defaultMockStore() *mockDBAccountStore {
return &mockDBAccountStore{
FetchFunc: func(username string) ([]byte, []byte, error) {
return []byte("key"), []byte("salt"), nil
},
AddSessionFunc: func(username string, token []byte) error { return nil },
}
}
type mockDBAccountStore struct {
FetchFunc func(username string) ([]byte, []byte, error)
AddSessionFunc func(username string, token []byte) error
}
func (m *mockDBAccountStore) FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
return m.FetchFunc(username)
}
func (m *mockDBAccountStore) AddAccountSession(username string, token []byte) error {
return m.AddSessionFunc(username, token)
}
func TestLogin(t *testing.T) {
t.Run("UsernameMinLength", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "a", "password123")
if err == nil {
t.Errorf("expected error due to password mismatch or DB, got nil")
}
})
t.Run("UsernameMaxLength", func(t *testing.T) {
uname := "abcdefghijklmnop"
store := defaultMockStore()
_, err := Login(store, uname, "password123")
if err == nil {
t.Errorf("expected error due to password mismatch or DB, got nil")
}
})
t.Run("UsernameTooLong", func(t *testing.T) {
uname := "abcdefghijklmnopq"
store := defaultMockStore()
_, err := Login(store, uname, "password123")
if err == nil || err.Error() != "invalid username" {
t.Errorf("expected invalid username error for too long username, got: %v", err)
}
})
t.Run("UsernameWithInvalidChars", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "user!@#", "password123")
if err == nil || err.Error() != "invalid username" {
t.Errorf("expected invalid username error for special chars, got: %v", err)
}
})
t.Run("EmptyUsername", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "", "password123")
if err == nil || err.Error() != "invalid username" {
t.Errorf("expected invalid username error for empty username, got: %v", err)
}
})
t.Run("EmptyPassword", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "validuser", "")
if err == nil || err.Error() != "invalid password" {
t.Errorf("expected invalid password error for empty password, got: %v", err)
}
})
t.Run("MinPasswordLength", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "validuser", "123456")
if err == nil {
t.Errorf("expected error due to password mismatch or DB, got nil")
}
})
t.Run("PasswordWithSpecialChars", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "validuser", "p@$$w0rd!")
if err == nil {
t.Errorf("expected error due to password mismatch or DB, got nil")
}
})
t.Run("DBUnexpectedError", func(t *testing.T) {
store := defaultMockStore()
store.FetchFunc = func(username string) ([]byte, []byte, error) {
return nil, nil, errors.New("some db error")
}
_, err := Login(store, "validuser", "password123")
if err == nil || err.Error() != "some db error" {
t.Errorf("expected DB error to propagate, got: %v", err)
}
})
t.Run("InvalidUsername", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "!invaliduser", "password123")
if err == nil || err.Error() != "invalid username" {
t.Errorf("expected invalid username error, got: %v", err)
}
})
t.Run("ShortPassword", func(t *testing.T) {
store := defaultMockStore()
_, err := Login(store, "validuser", "123")
if err == nil || err.Error() != "invalid password" {
t.Errorf("expected invalid password error, got: %v", err)
}
})
t.Run("AccountDoesNotExist", func(t *testing.T) {
store := defaultMockStore()
store.FetchFunc = func(username string) ([]byte, []byte, error) {
return nil, nil, sql.ErrNoRows
}
_, err := Login(store, "nonexistent", "password123")
if err == nil || err.Error() != "account doesn't exist" {
t.Errorf("expected account doesn't exist error, got: %v", err)
}
})
t.Run("PasswordMismatch", func(t *testing.T) {
correctSalt := []byte("somesalt")
correctKey := []byte("correctkey")
store := defaultMockStore()
store.FetchFunc = func(username string) ([]byte, []byte, error) {
return correctKey, correctSalt, nil
}
_, err := Login(store, "validuser", "wrongpassword")
if err == nil || err.Error() != "password doesn't match" {
t.Errorf("expected password doesn't match error, got: %v", err)
}
})
t.Run("Success", func(t *testing.T) {
correctSalt := []byte("somesalt")
password := "goodpassword"
correctKey := deriveArgon2IDKey([]byte(password), correctSalt)
store := defaultMockStore()
store.FetchFunc = func(username string) ([]byte, []byte, error) {
return correctKey, correctSalt, nil
}
store.AddSessionFunc = func(username string, token []byte) error {
return nil
}
resp, err := Login(store, "validuser", password)
if err != nil {
t.Errorf("expected success, got error: %v", err)
}
if resp.Token == "" {
t.Errorf("expected token to be set on success")
}
})
}

View File

@ -21,13 +21,17 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"github.com/pagefaultgames/rogueserver/db"
) )
// /account/logout - log out of account // /account/logout - log out of account
func Logout(token []byte) error {
err := db.RemoveSessionFromToken(token) // Interface for database operations needed for logout.
type LogoutStore interface {
RemoveSessionFromToken(token []byte) error
}
func Logout[T LogoutStore](store T, token []byte) error {
err := store.RemoveSessionFromToken(token)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return fmt.Errorf("token not found") return fmt.Errorf("token not found")

View File

@ -20,12 +20,15 @@ package account
import ( import (
"crypto/rand" "crypto/rand"
"fmt" "fmt"
"github.com/pagefaultgames/rogueserver/db"
) )
// Interface for database operations needed for registration.
type RegisterStore interface {
AddAccountRecord(uuid []byte, username string, passwordHash []byte, salt []byte) error
}
// /account/register - register account // /account/register - register account
func Register(username, password string) error { func Register[T RegisterStore](store T, username, password string) error {
if !isValidUsername(username) { if !isValidUsername(username) {
return fmt.Errorf("invalid username") return fmt.Errorf("invalid username")
} }
@ -46,7 +49,7 @@ func Register(username, password string) error {
return fmt.Errorf("failed to generate salt: %s", err) return fmt.Errorf("failed to generate salt: %s", err)
} }
err = db.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt) err = store.AddAccountRecord(uuid, username, deriveArgon2IDKey([]byte(password), salt), salt)
if err != nil { if err != nil {
return fmt.Errorf("failed to add account record: %s", err) return fmt.Errorf("failed to add account record: %s", err)
} }

View File

@ -30,7 +30,7 @@ import (
) )
func Init(mux *http.ServeMux) error { func Init(mux *http.ServeMux) error {
err := scheduleStatRefresh() err := scheduleStatRefresh(db.Store)
if err != nil { if err != nil {
return err return err
} }
@ -109,7 +109,7 @@ func tokenAndUuidFromRequest(r *http.Request) ([]byte, []byte, error) {
return nil, nil, err return nil, nil, err
} }
uuid, err := db.FetchUUIDFromToken(token) uuid, err := db.Store.FetchUUIDFromToken(token)
if err != nil { if err != nil {
return nil, nil, fmt.Errorf("failed to validate token: %s", err) return nil, nil, fmt.Errorf("failed to validate token: %s", err)
} }

View File

@ -61,7 +61,7 @@ func Init() error {
secret = newSecret secret = newSecret
} }
seed, err := db.TryAddDailyRun(Seed()) seed, err := db.Store.TryAddDailyRun(Seed())
if err != nil { if err != nil {
log.Print(err) log.Print(err)
} }
@ -71,7 +71,7 @@ func Init() error {
_, err = scheduler.AddFunc("@daily", func() { _, err = scheduler.AddFunc("@daily", func() {
time.Sleep(time.Second) time.Sleep(time.Second)
seed, err = db.TryAddDailyRun(Seed()) seed, err = db.Store.TryAddDailyRun(Seed())
if err != nil { if err != nil {
log.Printf("error while recording new daily: %s", err) log.Printf("error while recording new daily: %s", err)
} else { } else {

View File

@ -22,9 +22,14 @@ import (
"github.com/pagefaultgames/rogueserver/defs" "github.com/pagefaultgames/rogueserver/defs"
) )
// Interface for database operations needed for fetching rankings.
type RankingsStore interface {
FetchRankings(category, page int) ([]defs.DailyRanking, error)
}
// /daily/rankings - fetch daily rankings // /daily/rankings - fetch daily rankings
func Rankings(category, page int) ([]defs.DailyRanking, error) { func Rankings[T RankingsStore](store T, category, page int) ([]defs.DailyRanking, error) {
rankings, err := db.FetchRankings(category, page) rankings, err := db.Store.FetchRankings(category, page)
if err != nil { if err != nil {
return rankings, err return rankings, err
} }

View File

@ -17,13 +17,13 @@
package daily package daily
import ( type RankingPageCountStore interface {
"github.com/pagefaultgames/rogueserver/db" FetchRankingPageCount(category int) (int, error)
) }
// /daily/rankingpagecount - fetch daily ranking page count // /daily/rankingpagecount - fetch daily ranking page count
func RankingPageCount(category int) (int, error) { func RankingPageCount[T RankingPageCountStore](store T, category int) (int, error) {
pageCount, err := db.FetchRankingPageCount(category) pageCount, err := store.FetchRankingPageCount(category)
if err != nil { if err != nil {
return pageCount, err return pageCount, err
} }

View File

@ -50,17 +50,17 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
return return
} }
username, err := db.FetchUsernameFromUUID(uuid) username, err := db.Store.FetchUsernameFromUUID(uuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
discordId, err := db.FetchDiscordIdByUsername(username) discordId, err := db.Store.FetchDiscordIdByUsername(username)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
googleId, err := db.FetchGoogleIdByUsername(username) googleId, err := db.Store.FetchGoogleIdByUsername(username)
if err != nil && !errors.Is(err, sql.ErrNoRows) { if err != nil && !errors.Is(err, sql.ErrNoRows) {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -68,10 +68,10 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
var hasAdminRole bool var hasAdminRole bool
if discordId != "" { if discordId != "" {
hasAdminRole, _ = account.IsUserDiscordAdmin(discordId, account.DiscordGuildID) hasAdminRole, _ = account.Discord.IsUserDiscordAdmin(discordId, account.DiscordGuildID)
} }
response, err := account.Info(username, discordId, googleId, uuid, hasAdminRole) response, err := account.Info(db.Store, username, discordId, googleId, uuid, hasAdminRole)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -81,7 +81,7 @@ func handleAccountInfo(w http.ResponseWriter, r *http.Request) {
} }
func handleAccountRegister(w http.ResponseWriter, r *http.Request) { func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
err := account.Register(r.PostFormValue("username"), r.PostFormValue("password")) err := account.Register(db.Store, r.PostFormValue("username"), r.PostFormValue("password"))
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -91,7 +91,7 @@ func handleAccountRegister(w http.ResponseWriter, r *http.Request) {
} }
func handleAccountLogin(w http.ResponseWriter, r *http.Request) { func handleAccountLogin(w http.ResponseWriter, r *http.Request) {
response, err := account.Login(r.PostFormValue("username"), r.PostFormValue("password")) response, err := account.Login(db.Store, r.PostFormValue("username"), r.PostFormValue("password"))
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -107,20 +107,20 @@ func handleAccountChangePW(w http.ResponseWriter, r *http.Request) {
return return
} }
err = account.ChangePW(uuid, r.PostFormValue("password")) err = account.ChangePW(db.Store, uuid, r.PostFormValue("password"))
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
username, err := db.FetchUsernameFromUUID(uuid) username, err := db.Store.FetchUsernameFromUUID(uuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
// create a new session with these credentials // create a new session with these credentials
response, err := account.Login(username, r.Form.Get("password")) response, err := account.Login(db.Store, username, r.Form.Get("password"))
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -136,7 +136,7 @@ func handleAccountLogout(w http.ResponseWriter, r *http.Request) {
return return
} }
err = account.Logout(token) err = account.Logout(db.Store, token)
if err != nil { if err != nil {
// also possible for InternalServerError but that's unlikely unless the server blew up // also possible for InternalServerError but that's unlikely unless the server blew up
httpError(w, r, err, http.StatusUnauthorized) httpError(w, r, err, http.StatusUnauthorized)
@ -183,7 +183,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
return return
} }
err = db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId")) err = db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return return
@ -191,7 +191,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("action") { switch r.PathValue("action") {
case "get": case "get":
save, err := savedata.GetSession(uuid, slot) save, err := savedata.GetSession(db.Store, uuid, slot)
if err != nil { if err != nil {
if errors.Is(err, savedata.ErrSaveNotExist) { if errors.Is(err, savedata.ErrSaveNotExist) {
http.Error(w, err.Error(), http.StatusNotFound) http.Error(w, err.Error(), http.StatusNotFound)
@ -211,7 +211,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
return return
} }
existingSave, err := savedata.GetSession(uuid, slot) existingSave, err := savedata.GetSession(db.Store, uuid, slot)
if err != nil { if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) { if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError) httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError)
@ -224,7 +224,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
} }
} }
err = savedata.UpdateSession(uuid, slot, session) err = savedata.UpdateSession(db.Store, uuid, slot, session)
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to put session data: %s", err), http.StatusInternalServerError) httpError(w, r, fmt.Errorf("failed to put session data: %s", err), http.StatusInternalServerError)
return return
@ -239,13 +239,13 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
return return
} }
seed, err := db.GetDailyRunSeed() seed, err := db.Store.GetDailyRunSeed()
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
resp, err := savedata.Clear(uuid, slot, seed, session) resp, err := savedata.Clear(db.Store, uuid, slot, seed, session)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -253,7 +253,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, resp) writeJSON(w, r, resp)
case "newclear": case "newclear":
resp, err := savedata.NewClear(uuid, slot) resp, err := savedata.NewClear(db.Store, uuid, slot)
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to read new clear: %s", err), http.StatusInternalServerError) httpError(w, r, fmt.Errorf("failed to read new clear: %s", err), http.StatusInternalServerError)
return return
@ -261,7 +261,7 @@ func handleSession(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, resp) writeJSON(w, r, resp)
case "delete": case "delete":
err := savedata.DeleteSession(uuid, slot) err := savedata.DeleteSession(db.Store, uuid, slot)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -301,7 +301,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
return return
} }
active, err := db.IsActiveSession(uuid, data.ClientSessionId) active, err := db.Store.IsActiveSession(uuid, data.ClientSessionId)
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return return
@ -312,7 +312,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
return return
} }
storedTrainerId, storedSecretId, err := db.FetchTrainerIds(uuid) storedTrainerId, storedSecretId, err := db.Store.FetchTrainerIds(uuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -324,14 +324,14 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
return return
} }
} else { } else {
err = db.UpdateTrainerIds(data.System.TrainerId, data.System.SecretId, uuid) err = db.Store.UpdateTrainerIds(data.System.TrainerId, data.System.SecretId, uuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
} }
oldSystem, err := savedata.GetSystem(uuid) oldSystem, err := savedata.GetSystem(db.Store, uuid)
if err != nil { if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) { if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError) httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError)
@ -356,7 +356,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
} }
} }
existingSave, err := savedata.GetSession(uuid, data.SessionSlotId) existingSave, err := savedata.GetSession(db.Store, uuid, data.SessionSlotId)
if err != nil { if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) { if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError) httpError(w, r, fmt.Errorf("failed to retrieve session save data: %s", err), http.StatusInternalServerError)
@ -369,13 +369,13 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) {
} }
} }
err = savedata.Update(uuid, data.SessionSlotId, data.Session) err = savedata.Update(db.Store, uuid, data.SessionSlotId, data.Session)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
err = savedata.Update(uuid, 0, data.System) err = savedata.Update(db.Store, uuid, 0, data.System)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -401,7 +401,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
return return
} }
active, err := db.IsActiveSession(uuid, r.URL.Query().Get("clientSessionId")) active, err := db.Store.IsActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest)
return return
@ -410,14 +410,14 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("action") { switch r.PathValue("action") {
case "get": case "get":
if !active { if !active {
err = db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId")) err = db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return return
} }
} }
save, err := savedata.GetSystem(uuid) save, err := savedata.GetSystem(db.Store, uuid)
if err != nil { if err != nil {
if errors.Is(err, savedata.ErrSaveNotExist) { if errors.Is(err, savedata.ErrSaveNotExist) {
http.Error(w, err.Error(), http.StatusNotFound) http.Error(w, err.Error(), http.StatusNotFound)
@ -442,7 +442,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
return return
} }
oldSystem, err := savedata.GetSystem(uuid) oldSystem, err := savedata.GetSystem(db.Store, uuid)
if err != nil { if err != nil {
if !errors.Is(err, savedata.ErrSaveNotExist) { if !errors.Is(err, savedata.ErrSaveNotExist) {
httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError) httpError(w, r, fmt.Errorf("failed to retrieve playtime: %s", err), http.StatusInternalServerError)
@ -467,7 +467,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
} }
} }
err = savedata.UpdateSystem(uuid, system) err = savedata.UpdateSystem(db.Store, uuid, system)
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to put system data: %s", err), http.StatusInternalServerError) httpError(w, r, fmt.Errorf("failed to put system data: %s", err), http.StatusInternalServerError)
return return
@ -481,13 +481,13 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
// not valid, send server state // not valid, send server state
if !active { if !active {
err := db.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId")) err := db.Store.UpdateActiveSession(uuid, r.URL.Query().Get("clientSessionId"))
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest)
return return
} }
storedSaveData, err := db.ReadSystemSaveData(uuid) storedSaveData, err := db.Store.ReadSystemSaveData(uuid)
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("failed to read session save data: %s", err), http.StatusInternalServerError) httpError(w, r, fmt.Errorf("failed to read session save data: %s", err), http.StatusInternalServerError)
return return
@ -498,7 +498,7 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, response) writeJSON(w, r, response)
case "delete": case "delete":
err := savedata.DeleteSystem(uuid) err := savedata.DeleteSystem(db.Store, uuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -511,9 +511,14 @@ func handleSystem(w http.ResponseWriter, r *http.Request) {
} }
} }
// Interface providing database operations needed for getting daily seed.
type HandleDailySeedStore interface {
GetDailyRunSeed() (string, error)
}
// daily // daily
func handleDailySeed(w http.ResponseWriter, r *http.Request) { func handleDailySeed(w http.ResponseWriter, r *http.Request) {
seed, err := db.GetDailyRunSeed() seed, err := db.Store.GetDailyRunSeed()
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -522,6 +527,11 @@ func handleDailySeed(w http.ResponseWriter, r *http.Request) {
fmt.Fprint(w, seed) fmt.Fprint(w, seed)
} }
// Interface for database operations needed for getting daily rankings.
type HandleDailyRankingsStore interface {
daily.RankingsStore
}
func handleDailyRankings(w http.ResponseWriter, r *http.Request) { func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
var err error var err error
@ -543,7 +553,7 @@ func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
} }
} }
rankings, err := daily.Rankings(category, page) rankings, err := daily.Rankings(db.Store, category, page)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -552,6 +562,10 @@ func handleDailyRankings(w http.ResponseWriter, r *http.Request) {
writeJSON(w, r, rankings) writeJSON(w, r, rankings)
} }
type HandleDailyRankingsPageCountStore interface {
daily.RankingPageCountStore
}
func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) { func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) {
var category int var category int
if r.URL.Query().Has("category") { if r.URL.Query().Has("category") {
@ -563,7 +577,7 @@ func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) {
} }
} }
count, err := daily.RankingPageCount(category) count, err := daily.RankingPageCount(db.Store, category)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
} }
@ -579,9 +593,9 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
var err error var err error
switch provider { switch provider {
case "discord": case "discord":
externalAuthId, err = account.HandleDiscordCallback(w, r) externalAuthId, err = account.Discord.HandleDiscordCallback(w, r)
case "google": case "google":
externalAuthId, err = account.HandleGoogleCallback(w, r) externalAuthId, err = account.Google.HandleGoogleCallback(w, r)
default: default:
http.Error(w, "invalid provider", http.StatusBadRequest) http.Error(w, "invalid provider", http.StatusBadRequest)
return return
@ -600,7 +614,7 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
return return
} }
userName, err := db.FetchUsernameBySessionToken(stateByte) userName, err := db.Store.FetchUsernameBySessionToken(stateByte)
if err != nil { if err != nil {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther) http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
return return
@ -608,9 +622,9 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
switch provider { switch provider {
case "discord": case "discord":
err = db.AddDiscordIdByUsername(externalAuthId, userName) err = db.Store.AddDiscordIdByUsername(externalAuthId, userName)
case "google": case "google":
err = db.AddGoogleIdByUsername(externalAuthId, userName) err = db.Store.AddGoogleIdByUsername(externalAuthId, userName)
} }
if err != nil { if err != nil {
@ -622,16 +636,16 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
var userName string var userName string
switch provider { switch provider {
case "discord": case "discord":
userName, err = db.FetchUsernameByDiscordId(externalAuthId) userName, err = db.Store.FetchUsernameByDiscordId(externalAuthId)
case "google": case "google":
userName, err = db.FetchUsernameByGoogleId(externalAuthId) userName, err = db.Store.FetchUsernameByGoogleId(externalAuthId)
} }
if err != nil { if err != nil {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther) http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
return return
} }
sessionToken, err := account.GenerateTokenForUsername(userName) sessionToken, err := account.GenerateTokenForUsername(db.Store, userName)
if err != nil { if err != nil {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther) http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
return return
@ -651,6 +665,11 @@ func handleProviderCallback(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, account.GameURL, http.StatusSeeOther) http.Redirect(w, r, account.GameURL, http.StatusSeeOther)
} }
type HandleProviderLogoutStore interface {
RemoveDiscordIdByUUID(uuid []byte) error
RemoveGoogleIdByUUID(uuid []byte) error
}
func handleProviderLogout(w http.ResponseWriter, r *http.Request) { func handleProviderLogout(w http.ResponseWriter, r *http.Request) {
uuid, err := uuidFromRequest(r) uuid, err := uuidFromRequest(r)
if err != nil { if err != nil {
@ -660,9 +679,9 @@ func handleProviderLogout(w http.ResponseWriter, r *http.Request) {
switch r.PathValue("provider") { switch r.PathValue("provider") {
case "discord": case "discord":
err = db.RemoveDiscordIdByUUID(uuid) err = db.Store.RemoveDiscordIdByUUID(uuid)
case "google": case "google":
err = db.RemoveGoogleIdByUUID(uuid) err = db.Store.RemoveGoogleIdByUUID(uuid)
default: default:
http.Error(w, "invalid provider", http.StatusBadRequest) http.Error(w, "invalid provider", http.StatusBadRequest)
return return
@ -681,13 +700,13 @@ func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
return return
} }
userDiscordId, err := db.FetchDiscordIdByUUID(uuid) userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusUnauthorized) httpError(w, r, err, http.StatusUnauthorized)
return return
} }
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil { if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return return
@ -698,19 +717,19 @@ func handleAdminDiscordLink(w http.ResponseWriter, r *http.Request) {
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run // this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server // this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
_, err = db.CheckUsernameExists(username) _, err = db.Store.CheckUsernameExists(username)
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound) httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return return
} }
userUuid, err := db.FetchUUIDFromUsername(username) userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
err = db.AddDiscordIdByUUID(discordId, userUuid) err = db.Store.AddDiscordIdByUUID(discordId, userUuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -728,13 +747,13 @@ func handleAdminDiscordUnlink(w http.ResponseWriter, r *http.Request) {
return return
} }
userDiscordId, err := db.FetchDiscordIdByUUID(uuid) userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusUnauthorized) httpError(w, r, err, http.StatusUnauthorized)
return return
} }
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil { if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return return
@ -748,26 +767,26 @@ func handleAdminDiscordUnlink(w http.ResponseWriter, r *http.Request) {
log.Printf("Username given, removing discordId") log.Printf("Username given, removing discordId")
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run // this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server // this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
_, err = db.CheckUsernameExists(username) _, err = db.Store.CheckUsernameExists(username)
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound) httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return return
} }
userUuid, err := db.FetchUUIDFromUsername(username) userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
err = db.RemoveDiscordIdByUUID(userUuid) err = db.Store.RemoveDiscordIdByUUID(userUuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
case discordId != "": case discordId != "":
log.Printf("DiscordID given, removing discordId") log.Printf("DiscordID given, removing discordId")
err = db.RemoveDiscordIdByDiscordId(discordId) err = db.Store.RemoveDiscordIdByDiscordId(discordId)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -786,13 +805,13 @@ func handleAdminGoogleLink(w http.ResponseWriter, r *http.Request) {
return return
} }
userDiscordId, err := db.FetchDiscordIdByUUID(uuid) userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusUnauthorized) httpError(w, r, err, http.StatusUnauthorized)
return return
} }
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil { if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return return
@ -803,19 +822,19 @@ func handleAdminGoogleLink(w http.ResponseWriter, r *http.Request) {
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run // this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server // this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
_, err = db.CheckUsernameExists(username) _, err = db.Store.CheckUsernameExists(username)
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound) httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return return
} }
userUuid, err := db.FetchUUIDFromUsername(username) userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
err = db.AddGoogleIdByUUID(googleId, userUuid) err = db.Store.AddGoogleIdByUUID(googleId, userUuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -833,13 +852,13 @@ func handleAdminGoogleUnlink(w http.ResponseWriter, r *http.Request) {
return return
} }
userDiscordId, err := db.FetchDiscordIdByUUID(uuid) userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusUnauthorized) httpError(w, r, err, http.StatusUnauthorized)
return return
} }
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil { if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return return
@ -853,26 +872,26 @@ func handleAdminGoogleUnlink(w http.ResponseWriter, r *http.Request) {
log.Printf("Username given, removing googleId") log.Printf("Username given, removing googleId")
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run // this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server // this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
_, err = db.CheckUsernameExists(username) _, err = db.Store.CheckUsernameExists(username)
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound) httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return return
} }
userUuid, err := db.FetchUUIDFromUsername(username) userUuid, err := db.Store.FetchUUIDFromUsername(username)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
err = db.RemoveGoogleIdByUUID(userUuid) err = db.Store.RemoveGoogleIdByUUID(userUuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
} }
case googleId != "": case googleId != "":
log.Printf("DiscordID given, removing googleId") log.Printf("DiscordID given, removing googleId")
err = db.RemoveGoogleIdByDiscordId(googleId) err = db.Store.RemoveGoogleIdByDiscordId(googleId)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return
@ -891,13 +910,13 @@ func handleAdminSearch(w http.ResponseWriter, r *http.Request) {
return return
} }
userDiscordId, err := db.FetchDiscordIdByUUID(uuid) userDiscordId, err := db.Store.FetchDiscordIdByUUID(uuid)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusUnauthorized) httpError(w, r, err, http.StatusUnauthorized)
return return
} }
hasRole, err := account.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID) hasRole, err := account.Discord.IsUserDiscordAdmin(userDiscordId, account.DiscordGuildID)
if !hasRole || err != nil { if !hasRole || err != nil {
httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden) httpError(w, r, fmt.Errorf("user does not have the required role"), http.StatusForbidden)
return return
@ -907,14 +926,14 @@ func handleAdminSearch(w http.ResponseWriter, r *http.Request) {
// this does a quick call to make sure the username exists on the server before allowing the rest of the code to run // this does a quick call to make sure the username exists on the server before allowing the rest of the code to run
// this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server // this calls error value 404 (StatusNotFound) if there's no data; this means the username does not exist in the server
_, err = db.CheckUsernameExists(username) _, err = db.Store.CheckUsernameExists(username)
if err != nil { if err != nil {
httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound) httpError(w, r, fmt.Errorf("username does not exist on the server"), http.StatusNotFound)
return return
} }
// this does a single call that does a query for multiple columns from our database and makes an object out of it, which is returned to us // this does a single call that does a query for multiple columns from our database and makes an object out of it, which is returned to us
adminSearchResult, err := db.FetchAdminDetailsByUsername(username) adminSearchResult, err := db.Store.FetchAdminDetailsByUsername(username)
if err != nil { if err != nil {
httpError(w, r, err, http.StatusInternalServerError) httpError(w, r, err, http.StatusInternalServerError)
return return

View File

@ -21,7 +21,6 @@ import (
"fmt" "fmt"
"log" "log"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs" "github.com/pagefaultgames/rogueserver/defs"
) )
@ -30,10 +29,20 @@ type ClearResponse struct {
Error string `json:"error"` Error string `json:"error"`
} }
// Interface for database operations needed for `Clear`
// Helps with testing and reduces coupling.
type ClearStore interface {
UpdateAccountLastActivity(uuid []byte) error
TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error)
DeleteSessionSaveData(uuid []byte, slot int) error
AddOrUpdateAccountDailyRun(uuid []byte, score int, waveCompleted int) error
SetAccountBanned(uuid []byte, banned bool) error
}
// /savedata/clear - mark session save data as cleared and delete // /savedata/clear - mark session save data as cleared and delete
func Clear(uuid []byte, slot int, seed string, save defs.SessionSaveData) (ClearResponse, error) { func Clear[T ClearStore](store T, uuid []byte, slot int, seed string, save defs.SessionSaveData) (ClearResponse, error) {
var response ClearResponse var response ClearResponse
err := db.UpdateAccountLastActivity(uuid) err := store.UpdateAccountLastActivity(uuid)
if err != nil { if err != nil {
log.Print("failed to update account last activity") log.Print("failed to update account last activity")
} }
@ -51,23 +60,23 @@ func Clear(uuid []byte, slot int, seed string, save defs.SessionSaveData) (Clear
} }
if save.Score >= 20000 { if save.Score >= 20000 {
db.SetAccountBanned(uuid, true) store.SetAccountBanned(uuid, true)
} }
err = db.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted) err = store.AddOrUpdateAccountDailyRun(uuid, save.Score, waveCompleted)
if err != nil { if err != nil {
log.Printf("failed to add or update daily run record: %s", err) log.Printf("failed to add or update daily run record: %s", err)
} }
} }
if sessionCompleted { if sessionCompleted {
response.Success, err = db.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode)) response.Success, err = store.TryAddSeedCompletion(uuid, save.Seed, int(save.GameMode))
if err != nil { if err != nil {
log.Printf("failed to mark seed as completed: %s", err) log.Printf("failed to mark seed as completed: %s", err)
} }
} }
err = db.DeleteSessionSaveData(uuid, slot) err = store.DeleteSessionSaveData(uuid, slot)
if err != nil { if err != nil {
log.Printf("failed to delete session save data: %s", err) log.Printf("failed to delete session save data: %s", err)
} }

View File

@ -21,13 +21,19 @@ import (
"fmt" "fmt"
"log" "log"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs" "github.com/pagefaultgames/rogueserver/defs"
) )
// Interface for database operations needed for delete.
// This is to allow easier testing and less coupling.
type DeleteStore interface {
UpdateAccountLastActivity(uuid []byte) error
DeleteSessionSaveData(uuid []byte, slot int) error
}
// /savedata/delete - delete save data // /savedata/delete - delete save data
func Delete(uuid []byte, datatype, slot int) error { func Delete[T DeleteStore](store T, uuid []byte, datatype, slot int) error {
err := db.UpdateAccountLastActivity(uuid) err := store.UpdateAccountLastActivity(uuid)
if err != nil { if err != nil {
log.Print("failed to update account last activity") log.Print("failed to update account last activity")
} }
@ -39,7 +45,7 @@ func Delete(uuid []byte, datatype, slot int) error {
break break
} }
err = db.DeleteSessionSaveData(uuid, slot) err = store.DeleteSessionSaveData(uuid, slot)
default: default:
err = fmt.Errorf("invalid data type") err = fmt.Errorf("invalid data type")
} }

View File

@ -20,22 +20,26 @@ package savedata
import ( import (
"fmt" "fmt"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs" "github.com/pagefaultgames/rogueserver/defs"
) )
type NewClearStore interface {
ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error)
ReadSeedCompleted(uuid []byte, seed string) (bool, error)
}
// /savedata/newclear - return whether a session is a new clear for its seed // /savedata/newclear - return whether a session is a new clear for its seed
func NewClear(uuid []byte, slot int) (bool, error) { func NewClear[T NewClearStore](store T, uuid []byte, slot int) (bool, error) {
if slot < 0 || slot >= defs.SessionSlotCount { if slot < 0 || slot >= defs.SessionSlotCount {
return false, fmt.Errorf("slot id %d out of range", slot) return false, fmt.Errorf("slot id %d out of range", slot)
} }
session, err := db.ReadSessionSaveData(uuid, slot) session, err := store.ReadSessionSaveData(uuid, slot)
if err != nil { if err != nil {
return false, err return false, err
} }
completed, err := db.ReadSeedCompleted(uuid, session.Seed) completed, err := store.ReadSeedCompleted(uuid, session.Seed)
if err != nil { if err != nil {
return false, fmt.Errorf("failed to read seed completed: %s", err) return false, fmt.Errorf("failed to read seed completed: %s", err)
} }

View File

@ -21,12 +21,15 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs" "github.com/pagefaultgames/rogueserver/defs"
) )
func GetSession(uuid []byte, slot int) (defs.SessionSaveData, error) { type GetSessionStore interface {
session, err := db.ReadSessionSaveData(uuid, slot) ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error)
}
func GetSession[T GetSessionStore](store T, uuid []byte, slot int) (defs.SessionSaveData, error) {
session, err := store.ReadSessionSaveData(uuid, slot)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
err = ErrSaveNotExist err = ErrSaveNotExist
@ -38,8 +41,13 @@ func GetSession(uuid []byte, slot int) (defs.SessionSaveData, error) {
return session, nil return session, nil
} }
func UpdateSession(uuid []byte, slot int, data defs.SessionSaveData) error { type UpdateSessionStore interface {
err := db.StoreSessionSaveData(uuid, data, slot) StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error
DeleteSessionSaveData(uuid []byte, slot int) error
}
func UpdateSession[T UpdateSessionStore](store T, uuid []byte, slot int, data defs.SessionSaveData) error {
err := store.StoreSessionSaveData(uuid, data, slot)
if err != nil { if err != nil {
return err return err
} }
@ -47,8 +55,12 @@ func UpdateSession(uuid []byte, slot int, data defs.SessionSaveData) error {
return nil return nil
} }
func DeleteSession(uuid []byte, slot int) error { type DeleteSessionStore interface {
err := db.DeleteSessionSaveData(uuid, slot) DeleteSessionSaveData(uuid []byte, slot int) error
}
func DeleteSession[T DeleteSessionStore](store T, uuid []byte, slot int) error {
err := store.DeleteSessionSaveData(uuid, slot)
if err != nil { if err != nil {
return err return err
} }

View File

@ -24,24 +24,28 @@ import (
"os" "os"
"github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs" "github.com/pagefaultgames/rogueserver/defs"
) )
var ErrSaveNotExist = errors.New("save does not exist") var ErrSaveNotExist = errors.New("save does not exist")
func GetSystem(uuid []byte) (defs.SystemSaveData, error) { type GetSystemStore interface {
GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error)
ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error)
}
func GetSystem[T GetSystemStore](store T, uuid []byte) (defs.SystemSaveData, error) {
var system defs.SystemSaveData var system defs.SystemSaveData
var err error var err error
if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3 if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3
system, err = db.GetSystemSaveFromS3(uuid) system, err = store.GetSystemSaveFromS3(uuid)
var nokey *types.NoSuchKey var nokey *types.NoSuchKey
if errors.As(err, &nokey) { if errors.As(err, &nokey) {
err = ErrSaveNotExist err = ErrSaveNotExist
} }
} else { // use database } else { // use database
system, err = db.ReadSystemSaveData(uuid) system, err = store.ReadSystemSaveData(uuid)
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
err = ErrSaveNotExist err = ErrSaveNotExist
} }
@ -53,20 +57,27 @@ func GetSystem(uuid []byte) (defs.SystemSaveData, error) {
return system, nil return system, nil
} }
func UpdateSystem(uuid []byte, data defs.SystemSaveData) error { // Interface for database operations needed for updating system data.
type UpdateSystemStore interface {
UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error
StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error
StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error
}
func UpdateSystem[T UpdateSystemStore](store T, uuid []byte, data defs.SystemSaveData) error {
if data.TrainerId == 0 && data.SecretId == 0 { if data.TrainerId == 0 && data.SecretId == 0 {
return fmt.Errorf("invalid system data") return fmt.Errorf("invalid system data")
} }
err := db.UpdateAccountStats(uuid, data.GameStats, data.VoucherCounts) err := store.UpdateAccountStats(uuid, data.GameStats, data.VoucherCounts)
if err != nil { if err != nil {
return fmt.Errorf("failed to update account stats: %s", err) return fmt.Errorf("failed to update account stats: %s", err)
} }
if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3 if os.Getenv("S3_SYSTEM_BUCKET_NAME") != "" { // use S3
err = db.StoreSystemSaveDataS3(uuid, data) err = store.StoreSystemSaveDataS3(uuid, data)
} else { } else {
err = db.StoreSystemSaveData(uuid, data) err = store.StoreSystemSaveData(uuid, data)
} }
if err != nil { if err != nil {
return err return err
@ -75,8 +86,12 @@ func UpdateSystem(uuid []byte, data defs.SystemSaveData) error {
return nil return nil
} }
func DeleteSystem(uuid []byte) error { type DeleteSystemStore interface {
err := db.DeleteSystemSaveData(uuid) DeleteSystemSaveData(uuid []byte) error
}
func DeleteSystem[T DeleteSystemStore](store T, uuid []byte) error {
err := store.DeleteSystemSaveData(uuid)
if err != nil { if err != nil {
return err return err
} }

View File

@ -21,13 +21,18 @@ import (
"fmt" "fmt"
"log" "log"
"github.com/pagefaultgames/rogueserver/db"
"github.com/pagefaultgames/rogueserver/defs" "github.com/pagefaultgames/rogueserver/defs"
) )
type UpdateStore interface {
UpdateAccountLastActivity(uuid []byte) error
UpdateSystemStore
UpdateSessionStore
}
// /savedata/update - update save data // /savedata/update - update save data
func Update(uuid []byte, slot int, save any) error { func Update[T UpdateStore](store T, uuid []byte, slot int, save any) error {
err := db.UpdateAccountLastActivity(uuid) err := store.UpdateAccountLastActivity(uuid)
if err != nil { if err != nil {
log.Print("failed to update account last activity") log.Print("failed to update account last activity")
} }
@ -38,13 +43,13 @@ func Update(uuid []byte, slot int, save any) error {
return fmt.Errorf("invalid system data") return fmt.Errorf("invalid system data")
} }
return UpdateSystem(uuid, save) return UpdateSystem(store, uuid, save)
case defs.SessionSaveData: // Session case defs.SessionSaveData: // Session
if slot < 0 || slot >= defs.SessionSlotCount { if slot < 0 || slot >= defs.SessionSlotCount {
return fmt.Errorf("slot id %d out of range", slot) return fmt.Errorf("slot id %d out of range", slot)
} }
return UpdateSession(uuid, slot, save) return UpdateSession(store, uuid, slot, save)
default: default:
return fmt.Errorf("invalid data type") return fmt.Errorf("invalid data type")
} }

View File

@ -21,7 +21,6 @@ import (
"log" "log"
"time" "time"
"github.com/pagefaultgames/rogueserver/db"
"github.com/robfig/cron/v3" "github.com/robfig/cron/v3"
) )
@ -32,9 +31,9 @@ var (
classicSessionCount int classicSessionCount int
) )
func scheduleStatRefresh() error { func scheduleStatRefresh[T updateStatsStore](store T) error {
_, err := scheduler.AddFunc("@every 30s", func() { _, err := scheduler.AddFunc("@every 30s", func() {
err := updateStats() err := updateStats(store)
if err != nil { if err != nil {
log.Printf("failed to update stats: %s", err) log.Printf("failed to update stats: %s", err)
} }
@ -47,19 +46,25 @@ func scheduleStatRefresh() error {
return nil return nil
} }
func updateStats() error { type updateStatsStore interface {
FetchPlayerCount() (int, error)
FetchBattleCount() (int, error)
FetchClassicSessionCount() (int, error)
}
func updateStats[T updateStatsStore](store T) error {
var err error var err error
playerCount, err = db.FetchPlayerCount() playerCount, err = store.FetchPlayerCount()
if err != nil { if err != nil {
return err return err
} }
battleCount, err = db.FetchBattleCount() battleCount, err = store.FetchBattleCount()
if err != nil { if err != nil {
return err return err
} }
classicSessionCount, err = db.FetchClassicSessionCount() classicSessionCount, err = store.FetchClassicSessionCount()
if err != nil { if err != nil {
return err return err
} }

View File

@ -1,18 +1,18 @@
/* /*
Copyright (C) 2024 - 2025 Pagefault Games Copyright (C) 2024 - 2025 Pagefault Games
This program is free software: you can redistribute it and/or modify This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published by it under the terms of the GNU Affero General Public License as published by
the Free Software Foundation, either version 3 of the License, or the Free Software Foundation, either version 3 of the License, or
(at your option) any later version. (at your option) any later version.
This program is distributed in the hope that it will be useful, This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details. GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <http://www.gnu.org/licenses/>. along with this program. If not, see <http://www.gnu.org/licenses/>.
*/ */
package db package db
@ -27,30 +27,30 @@ import (
"github.com/pagefaultgames/rogueserver/defs" "github.com/pagefaultgames/rogueserver/defs"
) )
func AddAccountRecord(uuid []byte, username string, key, salt []byte) error { func (s *store) AddAccountSession(username string, token []byte) error {
_, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
if err != nil {
return err
}
return nil
}
func AddAccountSession(username string, token []byte) error {
_, err := handle.Exec("INSERT INTO sessions (uuid, token, expire) SELECT a.uuid, ?, DATE_ADD(UTC_TIMESTAMP(), INTERVAL 1 WEEK) FROM accounts a WHERE a.username = ?", token, username) _, err := handle.Exec("INSERT INTO sessions (uuid, token, expire) SELECT a.uuid, ?, DATE_ADD(UTC_TIMESTAMP(), INTERVAL 1 WEEK) FROM accounts a WHERE a.username = ?", token, username)
if err != nil { if err != nil {
return err return err
} }
_, err = handle.Exec("UPDATE accounts SET lastLoggedIn = UTC_TIMESTAMP() WHERE username = ?", username) _, err = handle.Exec("UPDATE accounts SET lastLoggedIn = UTC_TIMESTAMP() WHERE username = ?", username)
if err != nil { if err != nil {
return err return err
} }
return nil return nil
} }
func AddDiscordIdByUsername(discordId string, username string) error { // (removed, now a method on store)
func (s *store) FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) {
var key, salt []byte
err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt)
if err != nil {
return nil, nil, err
}
return key, salt, nil
}
func (s *store) AddDiscordIdByUsername(discordId string, username string) error {
_, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE username = ?", discordId, username) _, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE username = ?", discordId, username)
if err != nil { if err != nil {
return err return err
@ -59,7 +59,7 @@ func AddDiscordIdByUsername(discordId string, username string) error {
return nil return nil
} }
func AddGoogleIdByUsername(googleId string, username string) error { func (s *store) AddGoogleIdByUsername(googleId string, username string) error {
_, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE username = ?", googleId, username) _, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE username = ?", googleId, username)
if err != nil { if err != nil {
return err return err
@ -68,7 +68,7 @@ func AddGoogleIdByUsername(googleId string, username string) error {
return nil return nil
} }
func AddGoogleIdByUUID(googleId string, uuid []byte) error { func (s *store) AddGoogleIdByUUID(googleId string, uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE uuid = ?", googleId, uuid) _, err := handle.Exec("UPDATE accounts SET googleId = ? WHERE uuid = ?", googleId, uuid)
if err != nil { if err != nil {
return err return err
@ -77,7 +77,7 @@ func AddGoogleIdByUUID(googleId string, uuid []byte) error {
return nil return nil
} }
func AddDiscordIdByUUID(discordId string, uuid []byte) error { func (s *store) AddDiscordIdByUUID(discordId string, uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE uuid = ?", discordId, uuid) _, err := handle.Exec("UPDATE accounts SET discordId = ? WHERE uuid = ?", discordId, uuid)
if err != nil { if err != nil {
return err return err
@ -86,7 +86,7 @@ func AddDiscordIdByUUID(discordId string, uuid []byte) error {
return nil return nil
} }
func FetchUsernameByDiscordId(discordId string) (string, error) { func (s *store) FetchUsernameByDiscordId(discordId string) (string, error) {
var username string var username string
err := handle.QueryRow("SELECT username FROM accounts WHERE discordId = ?", discordId).Scan(&username) err := handle.QueryRow("SELECT username FROM accounts WHERE discordId = ?", discordId).Scan(&username)
if err != nil { if err != nil {
@ -96,7 +96,7 @@ func FetchUsernameByDiscordId(discordId string) (string, error) {
return username, nil return username, nil
} }
func FetchUsernameByGoogleId(googleId string) (string, error) { func (s *store) FetchUsernameByGoogleId(googleId string) (string, error) {
var username string var username string
err := handle.QueryRow("SELECT username FROM accounts WHERE googleId = ?", googleId).Scan(&username) err := handle.QueryRow("SELECT username FROM accounts WHERE googleId = ?", googleId).Scan(&username)
if err != nil { if err != nil {
@ -106,7 +106,7 @@ func FetchUsernameByGoogleId(googleId string) (string, error) {
return username, nil return username, nil
} }
func FetchDiscordIdByUsername(username string) (string, error) { func (s *store) FetchDiscordIdByUsername(username string) (string, error) {
var discordId sql.NullString var discordId sql.NullString
err := handle.QueryRow("SELECT discordId FROM accounts WHERE username = ?", username).Scan(&discordId) err := handle.QueryRow("SELECT discordId FROM accounts WHERE username = ?", username).Scan(&discordId)
if err != nil { if err != nil {
@ -120,7 +120,7 @@ func FetchDiscordIdByUsername(username string) (string, error) {
return discordId.String, nil return discordId.String, nil
} }
func FetchGoogleIdByUsername(username string) (string, error) { func (s *store) FetchGoogleIdByUsername(username string) (string, error) {
var googleId sql.NullString var googleId sql.NullString
err := handle.QueryRow("SELECT googleId FROM accounts WHERE username = ?", username).Scan(&googleId) err := handle.QueryRow("SELECT googleId FROM accounts WHERE username = ?", username).Scan(&googleId)
if err != nil { if err != nil {
@ -134,7 +134,7 @@ func FetchGoogleIdByUsername(username string) (string, error) {
return googleId.String, nil return googleId.String, nil
} }
func FetchDiscordIdByUUID(uuid []byte) (string, error) { func (s *store) FetchDiscordIdByUUID(uuid []byte) (string, error) {
var discordId sql.NullString var discordId sql.NullString
err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId) err := handle.QueryRow("SELECT discordId FROM accounts WHERE uuid = ?", uuid).Scan(&discordId)
if err != nil { if err != nil {
@ -148,7 +148,7 @@ func FetchDiscordIdByUUID(uuid []byte) (string, error) {
return discordId.String, nil return discordId.String, nil
} }
func FetchGoogleIdByUUID(uuid []byte) (string, error) { func (s *store) FetchGoogleIdByUUID(uuid []byte) (string, error) {
var googleId sql.NullString var googleId sql.NullString
err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId) err := handle.QueryRow("SELECT googleId FROM accounts WHERE uuid = ?", uuid).Scan(&googleId)
if err != nil { if err != nil {
@ -162,7 +162,7 @@ func FetchGoogleIdByUUID(uuid []byte) (string, error) {
return googleId.String, nil return googleId.String, nil
} }
func FetchUsernameBySessionToken(token []byte) (string, error) { func (s *store) FetchUsernameBySessionToken(token []byte) (string, error) {
var username string var username string
err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username) err := handle.QueryRow("SELECT a.username FROM accounts a JOIN sessions s ON a.uuid = s.uuid WHERE s.token = ?", token).Scan(&username)
if err != nil { if err != nil {
@ -172,7 +172,7 @@ func FetchUsernameBySessionToken(token []byte) (string, error) {
return username, nil return username, nil
} }
func CheckUsernameExists(username string) (string, error) { func (s *store) CheckUsernameExists(username string) (string, error) {
var dbUsername sql.NullString var dbUsername sql.NullString
err := handle.QueryRow("SELECT username FROM accounts WHERE username = ?", username).Scan(&dbUsername) err := handle.QueryRow("SELECT username FROM accounts WHERE username = ?", username).Scan(&dbUsername)
if err != nil { if err != nil {
@ -185,7 +185,7 @@ func CheckUsernameExists(username string) (string, error) {
return dbUsername.String, nil return dbUsername.String, nil
} }
func FetchLastLoggedInDateByUsername(username string) (string, error) { func (s *store) FetchLastLoggedInDateByUsername(username string) (string, error) {
var lastLoggedIn sql.NullString var lastLoggedIn sql.NullString
err := handle.QueryRow("SELECT lastLoggedIn FROM accounts WHERE username = ?", username).Scan(&lastLoggedIn) err := handle.QueryRow("SELECT lastLoggedIn FROM accounts WHERE username = ?", username).Scan(&lastLoggedIn)
if err != nil { if err != nil {
@ -206,7 +206,7 @@ type AdminSearchResponse struct {
Registered string `json:"registered"` Registered string `json:"registered"`
} }
func FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) { func (s *store) FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error) {
var username, discordId, googleId, lastActivity, registered sql.NullString var username, discordId, googleId, lastActivity, registered sql.NullString
var adminResponse AdminSearchResponse var adminResponse AdminSearchResponse
@ -226,7 +226,7 @@ func FetchAdminDetailsByUsername(dbUsername string) (AdminSearchResponse, error)
return adminResponse, nil return adminResponse, nil
} }
func UpdateAccountPassword(uuid, key, salt []byte) error { func (s *store) UpdateAccountPassword(uuid, key, salt []byte) error {
_, err := handle.Exec("UPDATE accounts SET hash = ?, salt = ? WHERE uuid = ?", key, salt, uuid) _, err := handle.Exec("UPDATE accounts SET hash = ?, salt = ? WHERE uuid = ?", key, salt, uuid)
if err != nil { if err != nil {
return err return err
@ -235,7 +235,7 @@ func UpdateAccountPassword(uuid, key, salt []byte) error {
return nil return nil
} }
func UpdateAccountLastActivity(uuid []byte) error { func (s *store) UpdateAccountLastActivity(uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET lastActivity = UTC_TIMESTAMP() WHERE uuid = ?", uuid) _, err := handle.Exec("UPDATE accounts SET lastActivity = UTC_TIMESTAMP() WHERE uuid = ?", uuid)
if err != nil { if err != nil {
return err return err
@ -244,7 +244,7 @@ func UpdateAccountLastActivity(uuid []byte) error {
return nil return nil
} }
func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error { func (s *store) UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[string]int) error {
var columns = []string{"playTime", "battles", "classicSessionsPlayed", "sessionsWon", "highestEndlessWave", "highestLevel", "pokemonSeen", "pokemonDefeated", "pokemonCaught", "pokemonHatched", "eggsPulled", "regularVouchers", "plusVouchers", "premiumVouchers", "goldenVouchers"} var columns = []string{"playTime", "battles", "classicSessionsPlayed", "sessionsWon", "highestEndlessWave", "highestLevel", "pokemonSeen", "pokemonDefeated", "pokemonCaught", "pokemonHatched", "eggsPulled", "regularVouchers", "plusVouchers", "premiumVouchers", "goldenVouchers"}
var statCols []string var statCols []string
@ -321,7 +321,7 @@ func UpdateAccountStats(uuid []byte, stats defs.GameStats, voucherCounts map[str
return nil return nil
} }
func SetAccountBanned(uuid []byte, banned bool) error { func (s *store) SetAccountBanned(uuid []byte, banned bool) error {
_, err := handle.Exec("UPDATE accounts SET banned = ? WHERE uuid = ?", banned, uuid) _, err := handle.Exec("UPDATE accounts SET banned = ? WHERE uuid = ?", banned, uuid)
if err != nil { if err != nil {
return err return err
@ -330,17 +330,16 @@ func SetAccountBanned(uuid []byte, banned bool) error {
return nil return nil
} }
func FetchAccountKeySaltFromUsername(username string) ([]byte, []byte, error) { func (s *store) AddAccountRecord(uuid []byte, username string, key, salt []byte) error {
var key, salt []byte _, err := handle.Exec("INSERT INTO accounts (uuid, username, hash, salt, registered) VALUES (?, ?, ?, ?, UTC_TIMESTAMP())", uuid, username, key, salt)
err := handle.QueryRow("SELECT hash, salt FROM accounts WHERE username = ?", username).Scan(&key, &salt)
if err != nil { if err != nil {
return nil, nil, err return err
} }
return key, salt, nil return nil
} }
func FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) { func (s *store) FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
err = handle.QueryRow("SELECT trainerId, secretId FROM accounts WHERE uuid = ?", uuid).Scan(&trainerId, &secretId) err = handle.QueryRow("SELECT trainerId, secretId FROM accounts WHERE uuid = ?", uuid).Scan(&trainerId, &secretId)
if err != nil { if err != nil {
return 0, 0, err return 0, 0, err
@ -349,7 +348,7 @@ func FetchTrainerIds(uuid []byte) (trainerId, secretId int, err error) {
return trainerId, secretId, nil return trainerId, secretId, nil
} }
func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error { func (s *store) UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET trainerId = ?, secretId = ? WHERE uuid = ?", trainerId, secretId, uuid) _, err := handle.Exec("UPDATE accounts SET trainerId = ?, secretId = ? WHERE uuid = ?", trainerId, secretId, uuid)
if err != nil { if err != nil {
return err return err
@ -358,12 +357,12 @@ func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error {
return nil return nil
} }
func IsActiveSession(uuid []byte, sessionId string) (bool, error) { func (s *store) IsActiveSession(uuid []byte, sessionId string) (bool, error) {
var id string var id string
err := handle.QueryRow("SELECT clientSessionId FROM activeClientSessions WHERE uuid = ?", uuid).Scan(&id) err := handle.QueryRow("SELECT clientSessionId FROM activeClientSessions WHERE uuid = ?", uuid).Scan(&id)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
err = UpdateActiveSession(uuid, sessionId) err = s.UpdateActiveSession(uuid, sessionId)
if err != nil { if err != nil {
return false, err return false, err
} }
@ -377,7 +376,7 @@ func IsActiveSession(uuid []byte, sessionId string) (bool, error) {
return id == "" || id == sessionId, nil return id == "" || id == sessionId, nil
} }
func UpdateActiveSession(uuid []byte, clientSessionId string) error { func (s *store) UpdateActiveSession(uuid []byte, clientSessionId string) error {
_, err := handle.Exec("INSERT INTO activeClientSessions (uuid, clientSessionId) VALUES (?, ?) ON DUPLICATE KEY UPDATE clientSessionId = ?", uuid, clientSessionId, clientSessionId) _, err := handle.Exec("INSERT INTO activeClientSessions (uuid, clientSessionId) VALUES (?, ?) ON DUPLICATE KEY UPDATE clientSessionId = ?", uuid, clientSessionId, clientSessionId)
if err != nil { if err != nil {
return err return err
@ -386,7 +385,7 @@ func UpdateActiveSession(uuid []byte, clientSessionId string) error {
return nil return nil
} }
func FetchUUIDFromToken(token []byte) ([]byte, error) { func (s *store) FetchUUIDFromToken(token []byte) ([]byte, error) {
var uuid []byte var uuid []byte
err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ?", token).Scan(&uuid) err := handle.QueryRow("SELECT uuid FROM sessions WHERE token = ?", token).Scan(&uuid)
if err != nil { if err != nil {
@ -396,7 +395,7 @@ func FetchUUIDFromToken(token []byte) ([]byte, error) {
return uuid, nil return uuid, nil
} }
func RemoveSessionFromToken(token []byte) error { func (s *store) RemoveSessionFromToken(token []byte) error {
_, err := handle.Exec("DELETE FROM sessions WHERE token = ?", token) _, err := handle.Exec("DELETE FROM sessions WHERE token = ?", token)
if err != nil { if err != nil {
return err return err
@ -405,7 +404,7 @@ func RemoveSessionFromToken(token []byte) error {
return nil return nil
} }
func RemoveSessionsFromUUID(uuid []byte) error { func (s *store) RemoveSessionsFromUUID(uuid []byte) error {
_, err := handle.Exec("DELETE FROM sessions WHERE uuid = ?", uuid) _, err := handle.Exec("DELETE FROM sessions WHERE uuid = ?", uuid)
if err != nil { if err != nil {
return err return err
@ -414,7 +413,7 @@ func RemoveSessionsFromUUID(uuid []byte) error {
return nil return nil
} }
func FetchUsernameFromUUID(uuid []byte) (string, error) { func (s *store) FetchUsernameFromUUID(uuid []byte) (string, error) {
var username string var username string
err := handle.QueryRow("SELECT username FROM accounts WHERE uuid = ?", uuid).Scan(&username) err := handle.QueryRow("SELECT username FROM accounts WHERE uuid = ?", uuid).Scan(&username)
if err != nil { if err != nil {
@ -424,7 +423,7 @@ func FetchUsernameFromUUID(uuid []byte) (string, error) {
return username, nil return username, nil
} }
func FetchUUIDFromUsername(username string) ([]byte, error) { func (s *store) FetchUUIDFromUsername(username string) ([]byte, error) {
var uuid []byte var uuid []byte
err := handle.QueryRow("SELECT uuid FROM accounts WHERE username = ?", username).Scan(&uuid) err := handle.QueryRow("SELECT uuid FROM accounts WHERE username = ?", username).Scan(&uuid)
if err != nil { if err != nil {
@ -434,7 +433,7 @@ func FetchUUIDFromUsername(username string) ([]byte, error) {
return uuid, nil return uuid, nil
} }
func RemoveDiscordIdByUUID(uuid []byte) error { func (s *store) RemoveDiscordIdByUUID(uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE uuid = ?", uuid) _, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE uuid = ?", uuid)
if err != nil { if err != nil {
return err return err
@ -443,7 +442,7 @@ func RemoveDiscordIdByUUID(uuid []byte) error {
return nil return nil
} }
func RemoveGoogleIdByUUID(uuid []byte) error { func (s *store) RemoveGoogleIdByUUID(uuid []byte) error {
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE uuid = ?", uuid) _, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE uuid = ?", uuid)
if err != nil { if err != nil {
return err return err
@ -452,7 +451,7 @@ func RemoveGoogleIdByUUID(uuid []byte) error {
return nil return nil
} }
func RemoveGoogleIdByUsername(username string) error { func (s *store) RemoveGoogleIdByUsername(username string) error {
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE username = ?", username) _, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE username = ?", username)
if err != nil { if err != nil {
return err return err
@ -461,7 +460,7 @@ func RemoveGoogleIdByUsername(username string) error {
return nil return nil
} }
func RemoveDiscordIdByUsername(username string) error { func (s *store) RemoveDiscordIdByUsername(username string) error {
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE username = ?", username) _, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE username = ?", username)
if err != nil { if err != nil {
return err return err
@ -470,7 +469,7 @@ func RemoveDiscordIdByUsername(username string) error {
return nil return nil
} }
func RemoveDiscordIdByDiscordId(discordId string) error { func (s *store) RemoveDiscordIdByDiscordId(discordId string) error {
_, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE discordId = ?", discordId) _, err := handle.Exec("UPDATE accounts SET discordId = NULL WHERE discordId = ?", discordId)
if err != nil { if err != nil {
return err return err
@ -479,7 +478,7 @@ func RemoveDiscordIdByDiscordId(discordId string) error {
return nil return nil
} }
func RemoveGoogleIdByDiscordId(discordId string) error { func (s *store) RemoveGoogleIdByDiscordId(discordId string) error {
_, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE discordId = ?", discordId) _, err := handle.Exec("UPDATE accounts SET googleId = NULL WHERE discordId = ?", discordId)
if err != nil { if err != nil {
return err return err

View File

@ -23,7 +23,7 @@ import (
"github.com/pagefaultgames/rogueserver/defs" "github.com/pagefaultgames/rogueserver/defs"
) )
func TryAddDailyRun(seed string) (string, error) { func (s *store) TryAddDailyRun(seed string) (string, error) {
var actualSeed string var actualSeed string
err := handle.QueryRow("INSERT INTO dailyRuns (seed, date) VALUES (?, UTC_DATE()) ON DUPLICATE KEY UPDATE date = date RETURNING seed", seed).Scan(&actualSeed) err := handle.QueryRow("INSERT INTO dailyRuns (seed, date) VALUES (?, UTC_DATE()) ON DUPLICATE KEY UPDATE date = date RETURNING seed", seed).Scan(&actualSeed)
if err != nil { if err != nil {
@ -33,7 +33,7 @@ func TryAddDailyRun(seed string) (string, error) {
return actualSeed, nil return actualSeed, nil
} }
func GetDailyRunSeed() (string, error) { func (s *store) GetDailyRunSeed() (string, error) {
var seed string var seed string
err := handle.QueryRow("SELECT seed FROM dailyRuns WHERE date = UTC_DATE()").Scan(&seed) err := handle.QueryRow("SELECT seed FROM dailyRuns WHERE date = UTC_DATE()").Scan(&seed)
if err != nil { if err != nil {
@ -43,7 +43,7 @@ func GetDailyRunSeed() (string, error) {
return seed, nil return seed, nil
} }
func AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error { func (s *store) AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
_, err := handle.Exec("INSERT INTO accountDailyRuns (uuid, date, score, wave, timestamp) VALUES (?, UTC_DATE(), ?, ?, UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE score = GREATEST(score, ?), wave = GREATEST(wave, ?), timestamp = IF(score < ?, UTC_TIMESTAMP(), timestamp)", uuid, score, wave, score, wave, score) _, err := handle.Exec("INSERT INTO accountDailyRuns (uuid, date, score, wave, timestamp) VALUES (?, UTC_DATE(), ?, ?, UTC_TIMESTAMP()) ON DUPLICATE KEY UPDATE score = GREATEST(score, ?), wave = GREATEST(wave, ?), timestamp = IF(score < ?, UTC_TIMESTAMP(), timestamp)", uuid, score, wave, score, wave, score)
if err != nil { if err != nil {
return err return err
@ -52,7 +52,7 @@ func AddOrUpdateAccountDailyRun(uuid []byte, score int, wave int) error {
return nil return nil
} }
func FetchRankings(category int, page int) ([]defs.DailyRanking, error) { func (s *store) FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
var rankings []defs.DailyRanking var rankings []defs.DailyRanking
offset := (page - 1) * 10 offset := (page - 1) * 10
@ -85,7 +85,7 @@ func FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
return rankings, nil return rankings, nil
} }
func FetchRankingPageCount(category int) (int, error) { func (s *store) FetchRankingPageCount(category int) (int, error) {
var query string var query string
switch category { switch category {
case 0: case 0:

View File

@ -31,6 +31,12 @@ import (
var handle *sql.DB var handle *sql.DB
var s3client *s3.Client var s3client *s3.Client
// internal type used to implement the Store interface
type store struct{}
// Store is the global instance for DB access.
var Store = &store{}
func Init(username, password, protocol, address, database string) error { func Init(username, password, protocol, address, database string) error {
var err error var err error

View File

@ -17,7 +17,7 @@
package db package db
func FetchPlayerCount() (int, error) { func (s *store) FetchPlayerCount() (int, error) {
var playerCount int var playerCount int
err := handle.QueryRow("SELECT COUNT(*) FROM accounts WHERE lastActivity > DATE_SUB(UTC_TIMESTAMP(), INTERVAL 5 MINUTE)").Scan(&playerCount) err := handle.QueryRow("SELECT COUNT(*) FROM accounts WHERE lastActivity > DATE_SUB(UTC_TIMESTAMP(), INTERVAL 5 MINUTE)").Scan(&playerCount)
if err != nil { if err != nil {
@ -27,7 +27,7 @@ func FetchPlayerCount() (int, error) {
return playerCount, nil return playerCount, nil
} }
func FetchBattleCount() (int, error) { func (s *store) FetchBattleCount() (int, error) {
var battleCount int var battleCount int
err := handle.QueryRow("SELECT COALESCE(SUM(s.battles), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&battleCount) err := handle.QueryRow("SELECT COALESCE(SUM(s.battles), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&battleCount)
if err != nil { if err != nil {
@ -37,7 +37,7 @@ func FetchBattleCount() (int, error) {
return battleCount, nil return battleCount, nil
} }
func FetchClassicSessionCount() (int, error) { func (s *store) FetchClassicSessionCount() (int, error) {
var classicSessionCount int var classicSessionCount int
err := handle.QueryRow("SELECT COALESCE(SUM(s.classicSessionsPlayed), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&classicSessionCount) err := handle.QueryRow("SELECT COALESCE(SUM(s.classicSessionsPlayed), 0) FROM accountStats s JOIN accounts a ON a.uuid = s.uuid WHERE a.banned = 0").Scan(&classicSessionCount)
if err != nil { if err != nil {

61
db/s3.go Normal file
View File

@ -0,0 +1,61 @@
package db
import (
"bytes"
"context"
"encoding/json"
"os"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
"github.com/pagefaultgames/rogueserver/defs"
)
func (s *store) GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) {
var system defs.SystemSaveData
username, err := Store.FetchUsernameFromUUID(uuid)
if err != nil {
return system, err
}
resp, err := s3client.GetObject(context.TODO(), &s3.GetObjectInput{
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
Key: aws.String(username),
})
if err != nil {
return system, err
}
err = json.NewDecoder(resp.Body).Decode(&system)
if err != nil {
return system, err
}
return system, nil
}
func (s *store) StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error {
username, err := s.FetchUsernameFromUUID(uuid)
if err != nil {
return err
}
buf := new(bytes.Buffer)
err = json.NewEncoder(buf).Encode(data)
if err != nil {
return err
}
_, err = s3client.PutObject(context.Background(), &s3.PutObjectInput{
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
Key: aws.String(username),
Body: buf,
})
if err != nil {
return err
}
return nil
}

View File

@ -19,19 +19,13 @@ package db
import ( import (
"bytes" "bytes"
"context"
"encoding/gob" "encoding/gob"
"encoding/json"
"os"
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
"github.com/pagefaultgames/rogueserver/defs" "github.com/pagefaultgames/rogueserver/defs"
"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/service/s3"
) )
func TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) { func (s *store) TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
var count int var count int
err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count) err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count)
if err != nil { if err != nil {
@ -48,7 +42,7 @@ func TryAddSeedCompletion(uuid []byte, seed string, mode int) (bool, error) {
return true, nil return true, nil
} }
func ReadSeedCompleted(uuid []byte, seed string) (bool, error) { func (s *store) ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
var count int var count int
err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count) err := handle.QueryRow("SELECT COUNT(*) FROM dailyRunCompletions WHERE uuid = ? AND seed = ?", uuid, seed).Scan(&count)
if err != nil { if err != nil {
@ -58,7 +52,7 @@ func ReadSeedCompleted(uuid []byte, seed string) (bool, error) {
return count > 0, nil return count > 0, nil
} }
func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) { func (s *store) ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
var system defs.SystemSaveData var system defs.SystemSaveData
var data []byte var data []byte
@ -82,7 +76,7 @@ func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
return system, nil return system, nil
} }
func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error { func (s *store) StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
zw, err := zstd.NewWriter(buf) zw, err := zstd.NewWriter(buf)
@ -108,32 +102,7 @@ func StoreSystemSaveData(uuid []byte, data defs.SystemSaveData) error {
return nil return nil
} }
func StoreSystemSaveDataS3(uuid []byte, data defs.SystemSaveData) error { func (s *store) DeleteSystemSaveData(uuid []byte) error {
username, err := FetchUsernameFromUUID(uuid)
if err != nil {
return err
}
buf := new(bytes.Buffer)
err = json.NewEncoder(buf).Encode(data)
if err != nil {
return err
}
_, err = s3client.PutObject(context.Background(), &s3.PutObjectInput{
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
Key: aws.String(username),
Body: buf,
})
if err != nil {
return err
}
return nil
}
func DeleteSystemSaveData(uuid []byte) error {
_, err := handle.Exec("DELETE FROM systemSaveData WHERE uuid = ?", uuid) _, err := handle.Exec("DELETE FROM systemSaveData WHERE uuid = ?", uuid)
if err != nil { if err != nil {
return err return err
@ -142,7 +111,7 @@ func DeleteSystemSaveData(uuid []byte) error {
return nil return nil
} }
func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) { func (s *store) ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
var session defs.SessionSaveData var session defs.SessionSaveData
var data []byte var data []byte
@ -166,7 +135,7 @@ func ReadSessionSaveData(uuid []byte, slot int) (defs.SessionSaveData, error) {
return session, nil return session, nil
} }
func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) { func (s *store) GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
var slot int var slot int
err := handle.QueryRow("SELECT slot FROM sessionSaveData WHERE uuid = ? ORDER BY timestamp DESC, slot ASC LIMIT 1", uuid).Scan(&slot) err := handle.QueryRow("SELECT slot FROM sessionSaveData WHERE uuid = ? ORDER BY timestamp DESC, slot ASC LIMIT 1", uuid).Scan(&slot)
if err != nil { if err != nil {
@ -176,7 +145,7 @@ func GetLatestSessionSaveDataSlot(uuid []byte) (int, error) {
return slot, nil return slot, nil
} }
func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error { func (s *store) StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) error {
buf := new(bytes.Buffer) buf := new(bytes.Buffer)
zw, err := zstd.NewWriter(buf) zw, err := zstd.NewWriter(buf)
@ -202,7 +171,7 @@ func StoreSessionSaveData(uuid []byte, data defs.SessionSaveData, slot int) erro
return nil return nil
} }
func DeleteSessionSaveData(uuid []byte, slot int) error { func (s *store) DeleteSessionSaveData(uuid []byte, slot int) error {
_, err := handle.Exec("DELETE FROM sessionSaveData WHERE uuid = ? AND slot = ?", uuid, slot) _, err := handle.Exec("DELETE FROM sessionSaveData WHERE uuid = ? AND slot = ?", uuid, slot)
if err != nil { if err != nil {
return err return err
@ -210,27 +179,3 @@ func DeleteSessionSaveData(uuid []byte, slot int) error {
return nil return nil
} }
func GetSystemSaveFromS3(uuid []byte) (defs.SystemSaveData, error) {
var system defs.SystemSaveData
username, err := FetchUsernameFromUUID(uuid)
if err != nil {
return system, err
}
resp, err := s3client.GetObject(context.TODO(), &s3.GetObjectInput{
Bucket: aws.String(os.Getenv("S3_SYSTEM_BUCKET_NAME")),
Key: aws.String(username),
})
if err != nil {
return system, err
}
err = json.NewDecoder(resp.Body).Decode(&system)
if err != nil {
return system, err
}
return system, nil
}