Various changes

pull/1/head
maru 9 months ago
parent 253e462536
commit d12a008259
No known key found for this signature in database
GPG Key ID: 37689350E9CD0F0D

@ -8,12 +8,12 @@ import (
"github.com/Flashfyre/pokerogue-server/db" "github.com/Flashfyre/pokerogue-server/db"
) )
func GetUsernameFromRequest(request *http.Request) (string, error) { func getUsernameFromRequest(r *http.Request) (string, error) {
if request.Header.Get("Authorization") == "" { if r.Header.Get("Authorization") == "" {
return "", fmt.Errorf("missing token") return "", fmt.Errorf("missing token")
} }
token, err := base64.StdEncoding.DecodeString(request.Header.Get("Authorization")) token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
if err != nil { if err != nil {
return "", fmt.Errorf("failed to decode token: %s", err) return "", fmt.Errorf("failed to decode token: %s", err)
} }
@ -30,12 +30,12 @@ func GetUsernameFromRequest(request *http.Request) (string, error) {
return username, nil return username, nil
} }
func GetUuidFromRequest(request *http.Request) ([]byte, error) { func getUuidFromRequest(r *http.Request) ([]byte, error) {
if request.Header.Get("Authorization") == "" { if r.Header.Get("Authorization") == "" {
return nil, fmt.Errorf("missing token") return nil, fmt.Errorf("missing token")
} }
token, err := base64.StdEncoding.DecodeString(request.Header.Get("Authorization")) token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to decode token: %s", err) return nil, fmt.Errorf("failed to decode token: %s", err)
} }

@ -26,21 +26,20 @@ const (
var isValidUsername = regexp.MustCompile(`^\w{1,16}$`).MatchString var isValidUsername = regexp.MustCompile(`^\w{1,16}$`).MatchString
// /account/info - get account info
type AccountInfoResponse struct { type AccountInfoResponse struct {
Username string `json:"username"` Username string `json:"username"`
LastSessionSlot int `json:"lastSessionSlot"` LastSessionSlot int `json:"lastSessionSlot"`
} }
func (s *Server) HandleAccountInfo(w http.ResponseWriter, r *http.Request) { // /account/info - get account info
username, err := GetUsernameFromRequest(r) func (s *Server) handleAccountInfo(w http.ResponseWriter, r *http.Request) {
username, err := getUsernameFromRequest(r)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
} }
uuid, err := GetUuidFromRequest(r) // lazy uuid, err := getUuidFromRequest(r) // lazy
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
@ -74,11 +73,11 @@ func (s *Server) HandleAccountInfo(w http.ResponseWriter, r *http.Request) {
w.Write(response) w.Write(response)
} }
// /account/register - register account
type AccountRegisterRequest GenericAuthRequest type AccountRegisterRequest GenericAuthRequest
func (s *Server) HandleAccountRegister(w http.ResponseWriter, r *http.Request) { // /account/register - register account
func (s *Server) handleAccountRegister(w http.ResponseWriter, r *http.Request) {
var request AccountRegisterRequest var request AccountRegisterRequest
err := json.NewDecoder(r.Body).Decode(&request) err := json.NewDecoder(r.Body).Decode(&request)
if err != nil { if err != nil {
@ -121,12 +120,11 @@ func (s *Server) HandleAccountRegister(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK) w.WriteHeader(http.StatusOK)
} }
// /account/login - log into account
type AccountLoginRequest GenericAuthRequest type AccountLoginRequest GenericAuthRequest
type AccountLoginResponse GenericAuthResponse type AccountLoginResponse GenericAuthResponse
func (s *Server) HandleAccountLogin(w http.ResponseWriter, r *http.Request) { // /account/login - log into account
func (s *Server) handleAccountLogin(w http.ResponseWriter, r *http.Request) {
var request AccountLoginRequest var request AccountLoginRequest
err := json.NewDecoder(r.Body).Decode(&request) err := json.NewDecoder(r.Body).Decode(&request)
if err != nil { if err != nil {
@ -184,8 +182,7 @@ func (s *Server) HandleAccountLogin(w http.ResponseWriter, r *http.Request) {
} }
// /account/logout - log out of account // /account/logout - log out of account
func (s *Server) handleAccountLogout(w http.ResponseWriter, r *http.Request) {
func (s *Server) HandleAccountLogout(w http.ResponseWriter, r *http.Request) {
token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization"))
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("failed to decode token: %s", err), http.StatusBadRequest) http.Error(w, fmt.Sprintf("failed to decode token: %s", err), http.StatusBadRequest)

@ -53,16 +53,17 @@ func InitDailyRun() {
dailyRunSecret = secret dailyRunSecret = secret
dailyRunSeed = base64.StdEncoding.EncodeToString(DeriveDailyRunSeed(time.Now().UTC())) dailyRunSeed = base64.StdEncoding.EncodeToString(deriveDailyRunSeed(time.Now().UTC()))
err = db.TryAddDailyRun(dailyRunSeed) err = db.TryAddDailyRun(dailyRunSeed)
if err != nil { if err != nil {
log.Print(err.Error()) log.Print(err)
} }
log.Printf("Daily Run Seed: %s", dailyRunSeed) log.Printf("Daily Run Seed: %s", dailyRunSeed)
} }
func DeriveDailyRunSeed(seedTime time.Time) []byte { func deriveDailyRunSeed(seedTime time.Time) []byte {
day := make([]byte, 8) day := make([]byte, 8)
binary.BigEndian.PutUint64(day, uint64(seedTime.Unix()/secondsPerDay)) binary.BigEndian.PutUint64(day, uint64(seedTime.Unix()/secondsPerDay))
@ -72,20 +73,13 @@ func DeriveDailyRunSeed(seedTime time.Time) []byte {
} }
// /daily/seed - fetch daily run seed // /daily/seed - fetch daily run seed
func (s *Server) handleSeed(w http.ResponseWriter) {
func (s *Server) HandleSeed(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(dailyRunSeed)) w.Write([]byte(dailyRunSeed))
} }
// /daily/rankings - fetch daily rankings // /daily/rankings - fetch daily rankings
func (s *Server) handleRankings(w http.ResponseWriter, r *http.Request) {
func (s *Server) HandleRankings(w http.ResponseWriter, r *http.Request) { uuid, err := getUuidFromRequest(r)
var err error
var category int
var page int
var uuid []byte
uuid, err = GetUuidFromRequest(r)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
@ -96,24 +90,22 @@ func (s *Server) HandleRankings(w http.ResponseWriter, r *http.Request) {
log.Print("failed to update account last activity") log.Print("failed to update account last activity")
} }
var category int
if r.URL.Query().Has("category") { if r.URL.Query().Has("category") {
category, err = strconv.Atoi(r.URL.Query().Get("category")) category, err = strconv.Atoi(r.URL.Query().Get("category"))
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("failed to convert category: %s", err), http.StatusBadRequest) http.Error(w, fmt.Sprintf("failed to convert category: %s", err), http.StatusBadRequest)
return return
} }
} else {
category = 0
} }
page := 1
if r.URL.Query().Has("page") { if r.URL.Query().Has("page") {
page, err = strconv.Atoi(r.URL.Query().Get("page")) page, err = strconv.Atoi(r.URL.Query().Get("page"))
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("failed to convert page: %s", err), http.StatusBadRequest) http.Error(w, fmt.Sprintf("failed to convert page: %s", err), http.StatusBadRequest)
return return
} }
} else {
page = 1
} }
rankings, err := db.FetchRankings(category, page) rankings, err := db.FetchRankings(category, page)
@ -131,8 +123,7 @@ func (s *Server) HandleRankings(w http.ResponseWriter, r *http.Request) {
} }
// /daily/rankingpagecount - fetch daily ranking page count // /daily/rankingpagecount - fetch daily ranking page count
func (s *Server) handleRankingPageCount(w http.ResponseWriter, r *http.Request) {
func (s *Server) HandleRankingPageCount(w http.ResponseWriter, r *http.Request) {
var err error var err error
var category int var category int
@ -142,8 +133,6 @@ func (s *Server) HandleRankingPageCount(w http.ResponseWriter, r *http.Request)
http.Error(w, fmt.Sprintf("failed to convert category: %s", err), http.StatusBadRequest) http.Error(w, fmt.Sprintf("failed to convert category: %s", err), http.StatusBadRequest)
return return
} }
} else {
category = 0
} }
pageCount, err := db.FetchRankingPageCount(category) pageCount, err := db.FetchRankingPageCount(category)

@ -13,25 +13,24 @@ import (
var ( var (
playerCountScheduler = gocron.NewScheduler(time.UTC) playerCountScheduler = gocron.NewScheduler(time.UTC)
playerCount = 0 playerCount int
) )
func SchedulePlayerCountRefresh() { func SchedulePlayerCountRefresh() {
playerCountScheduler.Every(10).Second().Do(UpdatePlayerCount) playerCountScheduler.Every(10).Second().Do(updatePlayerCount)
playerCountScheduler.StartAsync() playerCountScheduler.StartAsync()
} }
func UpdatePlayerCount() { func updatePlayerCount() {
var err error var err error
playerCount, err = db.FetchPlayerCount() playerCount, err = db.FetchPlayerCount()
if err != nil { if err != nil {
log.Print(err.Error()) log.Print(err)
} }
} }
// /game/playercount - get player count // /game/playercount - get player count
func (s *Server) handlePlayerCountGet(w http.ResponseWriter) {
func (s *Server) HandlePlayerCountGet(w http.ResponseWriter, r *http.Request) {
response, err := json.Marshal(playerCount) response, err := json.Marshal(playerCount)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("failed to marshal response json: %s", err), http.StatusInternalServerError)

@ -1,7 +1,6 @@
package api package api
import ( import (
"encoding/gob"
"net/http" "net/http"
) )
@ -21,37 +20,34 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
} }
gob.Register([]interface{}{})
gob.Register(map[string]interface{}{})
switch r.URL.Path { switch r.URL.Path {
case "/account/info": case "/account/info":
s.HandleAccountInfo(w, r) s.handleAccountInfo(w, r)
case "/account/register": case "/account/register":
s.HandleAccountRegister(w, r) s.handleAccountRegister(w, r)
case "/account/login": case "/account/login":
s.HandleAccountLogin(w, r) s.handleAccountLogin(w, r)
case "/account/logout": case "/account/logout":
s.HandleAccountLogout(w, r) s.handleAccountLogout(w, r)
case "/game/playercount": case "/game/playercount":
s.HandlePlayerCountGet(w, r) s.handlePlayerCountGet(w)
case "/savedata/get": case "/savedata/get":
s.HandleSavedataGet(w, r) s.handleSavedataGet(w, r)
case "/savedata/update": case "/savedata/update":
s.HandleSavedataUpdate(w, r) s.handleSavedataUpdate(w, r)
case "/savedata/delete": case "/savedata/delete":
s.HandleSavedataDelete(w, r) s.handleSavedataDelete(w, r)
case "/savedata/clear": case "/savedata/clear":
s.HandleSavedataClear(w, r) s.handleSavedataClear(w, r)
case "/daily/seed": case "/daily/seed":
s.HandleSeed(w, r) s.handleSeed(w)
case "/daily/rankings": case "/daily/rankings":
s.HandleRankings(w, r) s.handleRankings(w, r)
case "/daily/rankingpagecount": case "/daily/rankingpagecount":
s.HandleRankingPageCount(w, r) s.handleRankingPageCount(w, r)
} }
} }

@ -12,7 +12,10 @@ import (
"github.com/klauspost/compress/zstd" "github.com/klauspost/compress/zstd"
) )
func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) { func readSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
gob.Register([]interface{}{})
gob.Register(map[string]interface{}{})
var system defs.SystemSaveData var system defs.SystemSaveData
save, err := os.ReadFile("userdata/" + hex.EncodeToString(uuid) + "/system.pzs") save, err := os.ReadFile("userdata/" + hex.EncodeToString(uuid) + "/system.pzs")
@ -40,7 +43,10 @@ func ReadSystemSaveData(uuid []byte) (defs.SystemSaveData, error) {
return system, nil return system, nil
} }
func ReadSessionSaveData(uuid []byte, slotId int) (defs.SessionSaveData, error) { func readSessionSaveData(uuid []byte, slotId int) (defs.SessionSaveData, error) {
gob.Register([]interface{}{})
gob.Register(map[string]interface{}{})
var session defs.SessionSaveData var session defs.SessionSaveData
fileName := "session" fileName := "session"
@ -73,12 +79,13 @@ func ReadSessionSaveData(uuid []byte, slotId int) (defs.SessionSaveData, error)
return session, nil return session, nil
} }
func ValidateSessionCompleted(session defs.SessionSaveData) bool { func validateSessionCompleted(session defs.SessionSaveData) bool {
switch session.GameMode { switch session.GameMode {
case 0: case 0:
return session.BattleType == 2 && session.WaveIndex == 200 return session.BattleType == 2 && session.WaveIndex == 200
case 3: case 3:
return session.BattleType == 2 && session.WaveIndex == 50 return session.BattleType == 2 && session.WaveIndex == 50
} }
return false return false
} }

@ -19,9 +19,8 @@ import (
const sessionSlotCount = 3 const sessionSlotCount = 3
// /savedata/get - get save data // /savedata/get - get save data
func (s *Server) handleSavedataGet(w http.ResponseWriter, r *http.Request) {
func (s *Server) HandleSavedataGet(w http.ResponseWriter, r *http.Request) { uuid, err := getUuidFromRequest(r)
uuid, err := GetUuidFromRequest(r)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
@ -29,7 +28,7 @@ func (s *Server) HandleSavedataGet(w http.ResponseWriter, r *http.Request) {
switch r.URL.Query().Get("datatype") { switch r.URL.Query().Get("datatype") {
case "0": // System case "0": // System
system, err := ReadSystemSaveData(uuid) system, err := readSystemSaveData(uuid)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@ -54,7 +53,7 @@ func (s *Server) HandleSavedataGet(w http.ResponseWriter, r *http.Request) {
return return
} }
session, err := ReadSessionSaveData(uuid, slotId) session, err := readSessionSaveData(uuid, slotId)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError) http.Error(w, err.Error(), http.StatusInternalServerError)
return return
@ -74,9 +73,8 @@ func (s *Server) HandleSavedataGet(w http.ResponseWriter, r *http.Request) {
} }
// /savedata/update - update save data // /savedata/update - update save data
func (s *Server) handleSavedataUpdate(w http.ResponseWriter, r *http.Request) {
func (s *Server) HandleSavedataUpdate(w http.ResponseWriter, r *http.Request) { uuid, err := getUuidFromRequest(r)
uuid, err := GetUuidFromRequest(r)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
@ -188,9 +186,8 @@ func (s *Server) HandleSavedataUpdate(w http.ResponseWriter, r *http.Request) {
} }
// /savedata/delete - delete save data // /savedata/delete - delete save data
func (s *Server) handleSavedataDelete(w http.ResponseWriter, r *http.Request) {
func (s *Server) HandleSavedataDelete(w http.ResponseWriter, r *http.Request) { uuid, err := getUuidFromRequest(r)
uuid, err := GetUuidFromRequest(r)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
@ -245,9 +242,8 @@ type SavedataClearResponse struct {
} }
// /savedata/clear - mark session save data as cleared and delete // /savedata/clear - mark session save data as cleared and delete
func (s *Server) handleSavedataClear(w http.ResponseWriter, r *http.Request) {
func (s *Server) HandleSavedataClear(w http.ResponseWriter, r *http.Request) { uuid, err := getUuidFromRequest(r)
uuid, err := GetUuidFromRequest(r)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest) http.Error(w, err.Error(), http.StatusBadRequest)
return return
@ -276,7 +272,7 @@ func (s *Server) HandleSavedataClear(w http.ResponseWriter, r *http.Request) {
return return
} }
sessionCompleted := ValidateSessionCompleted(session) sessionCompleted := validateSessionCompleted(session)
newCompletion := false newCompletion := false
if session.GameMode == 3 && session.Seed == dailyRunSeed { if session.GameMode == 3 && session.Seed == dailyRunSeed {
@ -286,14 +282,14 @@ func (s *Server) HandleSavedataClear(w http.ResponseWriter, r *http.Request) {
} }
err = db.AddOrUpdateAccountDailyRun(uuid, session.Score, waveCompleted) err = db.AddOrUpdateAccountDailyRun(uuid, session.Score, waveCompleted)
if err != nil { if err != nil {
log.Printf("failed to add or update daily run record: %s", err.Error()) log.Printf("failed to add or update daily run record: %s", err)
} }
} }
if sessionCompleted { if sessionCompleted {
newCompletion, err = db.TryAddSeedCompletion(uuid, session.Seed, int(session.GameMode)) newCompletion, err = db.TryAddSeedCompletion(uuid, session.Seed, int(session.GameMode))
if err != nil { if err != nil {
log.Printf("failed to mark seed as completed: %s", err.Error()) log.Printf("failed to mark seed as completed: %s", err)
} }
} }

@ -45,7 +45,7 @@ func FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
defer results.Close() defer results.Close()
for results.Next() { for results.Next() {
ranking := defs.DailyRanking{} var ranking defs.DailyRanking
err = results.Scan(&ranking.Rank, &ranking.Username, &ranking.Score, &ranking.Wave) err = results.Scan(&ranking.Rank, &ranking.Username, &ranking.Score, &ranking.Wave)
if err != nil { if err != nil {
return rankings, err return rankings, err
@ -58,8 +58,6 @@ func FetchRankings(category int, page int) ([]defs.DailyRanking, error) {
} }
func FetchRankingPageCount(category int) (int, error) { func FetchRankingPageCount(category int) (int, error) {
var recordCount int
var query string var query string
switch category { switch category {
case 0: case 0:
@ -68,6 +66,7 @@ func FetchRankingPageCount(category int) (int, error) {
query = "SELECT COUNT(DISTINCT a.username) FROM accountDailyRuns adr JOIN dailyRuns dr ON dr.date = adr.date JOIN accounts a ON adr.uuid = a.uuid WHERE dr.date >= DATE_SUB(DATE(UTC_TIMESTAMP()), INTERVAL DAYOFWEEK(UTC_TIMESTAMP()) - 1 DAY)" query = "SELECT COUNT(DISTINCT a.username) FROM accountDailyRuns adr JOIN dailyRuns dr ON dr.date = adr.date JOIN accounts a ON adr.uuid = a.uuid WHERE dr.date >= DATE_SUB(DATE(UTC_TIMESTAMP()), INTERVAL DAYOFWEEK(UTC_TIMESTAMP()) - 1 DAY)"
} }
var recordCount int
err := handle.QueryRow(query).Scan(&recordCount) err := handle.QueryRow(query).Scan(&recordCount)
if err != nil { if err != nil {
return 0, err return 0, err

@ -2,7 +2,6 @@ package db
func FetchPlayerCount() (int, error) { func FetchPlayerCount() (int, error) {
var playerCount int var playerCount int
err := handle.QueryRow("SELECT COUNT(*) FROM accounts WHERE lastActivity > DATE_SUB(UTC_TIMESTAMP(), INTERVAL 5 MINUTE)").Scan(&playerCount) err := handle.QueryRow("SELECT COUNT(*) FROM accounts WHERE lastActivity > DATE_SUB(UTC_TIMESTAMP(), INTERVAL 5 MINUTE)").Scan(&playerCount)
if err != nil { if err != nil {
return 0, err return 0, err

Loading…
Cancel
Save