From 436fce875901474beb2d48d7ee5676579974c37d Mon Sep 17 00:00:00 2001 From: Up Date: Tue, 14 May 2024 12:54:06 +0200 Subject: [PATCH] add client session ID tokens --- api/common.go | 18 +++-- api/endpoints.go | 185 +++++++++++++++++++++++++++++++++++++++++++---- db/account.go | 24 +++--- db/db.go | 8 ++ 4 files changed, 203 insertions(+), 32 deletions(-) diff --git a/api/common.go b/api/common.go index b1f11dd..ec87644 100644 --- a/api/common.go +++ b/api/common.go @@ -21,12 +21,11 @@ import ( "encoding/base64" "encoding/json" "fmt" - "log" - "net/http" - "github.com/pagefaultgames/rogueserver/api/account" "github.com/pagefaultgames/rogueserver/api/daily" "github.com/pagefaultgames/rogueserver/db" + "log" + "net/http" ) func Init(mux *http.ServeMux) error { @@ -49,14 +48,17 @@ func Init(mux *http.ServeMux) error { mux.HandleFunc("GET /game/classicsessioncount", handleGameClassicSessionCount) // savedata - mux.HandleFunc("GET /savedata/get", handleGetSaveData) - mux.HandleFunc("POST /savedata/update", handleSaveData) - mux.HandleFunc("GET /savedata/delete", handleSaveData) // TODO use deleteSystemSave - mux.HandleFunc("POST /savedata/clear", handleSaveData) // TODO use clearSessionData - mux.HandleFunc("GET /savedata/newclear", handleNewClear) + mux.HandleFunc("GET /savedata/get", legacyHandleGetSaveData) + mux.HandleFunc("POST /savedata/update", legacyHandleSaveData) + mux.HandleFunc("GET /savedata/delete", legacyHandleSaveData) // TODO use deleteSystemSave + mux.HandleFunc("POST /savedata/clear", legacyHandleSaveData) // TODO use clearSessionData + mux.HandleFunc("GET /savedata/newclear", legacyHandleNewClear) // new session mux.HandleFunc("POST /savedata/updateall", handleUpdateAll) + mux.HandleFunc("POST /savedata/verify", handleSessionVerify) + mux.HandleFunc("GET /savedata/system", handleGetSystemData) + mux.HandleFunc("GET /savedata/session", handleGetSessionData) // daily mux.HandleFunc("GET /daily/seed", handleDailySeed) diff --git a/api/endpoints.go b/api/endpoints.go index be24e33..676570e 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -146,7 +146,53 @@ func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) { _, _ = w.Write([]byte(strconv.Itoa(classicSessionCount))) } -func handleGetSaveData(w http.ResponseWriter, r *http.Request) { +func handleGetSessionData(w http.ResponseWriter, r *http.Request) { + token, uuid, err := tokenAndUuidFromRequest(r) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } + + var slot int + if r.URL.Query().Has("slot") { + slot, err = strconv.Atoi(r.URL.Query().Get("slot")) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } + } + + var clientSessionId string + if r.URL.Query().Has("clientSessionId") { + clientSessionId = r.URL.Query().Get("clientSessionId") + } else { + httpError(w, r, fmt.Errorf("missing clientSessionId"), http.StatusBadRequest) + } + + err = db.UpdateActiveSession(token, clientSessionId) + if err != nil { + httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) + return + } + + var save any + save, err = savedata.Get(uuid, 1, slot) + if errors.Is(err, sql.ErrNoRows) { + http.Error(w, err.Error(), http.StatusNotFound) + return + } + + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } + + jsonResponse(w, r, save) +} + +const legacyClientSessionId = "LEGACY_CLIENT" + +func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) { token, uuid, err := tokenAndUuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) @@ -173,7 +219,7 @@ func handleGetSaveData(w http.ResponseWriter, r *http.Request) { var save any if datatype == 0 { - err = db.UpdateActiveSession(uuid, token) + err = db.UpdateActiveSession(token, legacyClientSessionId) // we dont have a client id if err != nil { httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) return @@ -222,7 +268,7 @@ func clearSessionData(w http.ResponseWriter, r *http.Request) { save = session var active bool - active, err = db.IsActiveSession(token) + active, err = db.IsActiveSession(token, legacyClientSessionId) //TODO unfinished, read token from query if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) return @@ -309,7 +355,7 @@ func deleteSystemSave(w http.ResponseWriter, r *http.Request) { } var active bool - active, err = db.IsActiveSession(token) + active, err = db.IsActiveSession(token, legacyClientSessionId) //TODO unfinished, read token from query if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusInternalServerError) return @@ -363,7 +409,7 @@ func deleteSystemSave(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -func handleSaveData(w http.ResponseWriter, r *http.Request) { +func legacyHandleSaveData(w http.ResponseWriter, r *http.Request) { token, uuid, err := tokenAndUuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) @@ -388,6 +434,14 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) { } } + var clientSessionId string + if r.URL.Query().Has("clientSessionId") { + clientSessionId = r.URL.Query().Get("clientSessionId") + } + if clientSessionId == "" { + clientSessionId = legacyClientSessionId + } + var save any // /savedata/get and /savedata/delete specify datatype, but don't expect data in body if r.URL.Path != "/savedata/get" && r.URL.Path != "/savedata/delete" { @@ -416,14 +470,14 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) { var active bool if r.URL.Path == "/savedata/get" { if datatype == 0 { - err = db.UpdateActiveSession(uuid, token) + err = db.UpdateActiveSession(token, clientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) return } } } else { - active, err = db.IsActiveSession(token) + active, err = db.IsActiveSession(token, clientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) return @@ -517,9 +571,10 @@ func handleSaveData(w http.ResponseWriter, r *http.Request) { } type CombinedSaveData struct { - System defs.SystemSaveData `json:"system"` - Session defs.SessionSaveData `json:"session"` - SessionSlotId int `json:"sessionSlotId"` + System defs.SystemSaveData `json:"system"` + Session defs.SessionSaveData `json:"session"` + SessionSlotId int `json:"sessionSlotId"` + ClientSessionId string `json:"clientSessionId"` } // TODO wrap this in a transaction @@ -531,6 +586,14 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) { return } + var clientSessionId string + if r.URL.Query().Has("clientSessionId") { + clientSessionId = r.URL.Query().Get("clientSessionId") + } + if clientSessionId == "" { + clientSessionId = legacyClientSessionId + } + var data CombinedSaveData err = json.NewDecoder(r.Body).Decode(&data) if err != nil { @@ -539,7 +602,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) { } var active bool - active, err = db.IsActiveSession(token) + active, err = db.IsActiveSession(token, clientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) return @@ -584,7 +647,104 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -func handleNewClear(w http.ResponseWriter, r *http.Request) { +type SessionVerifyResponse struct { + Valid bool `json:"valid"` + SessionData *defs.SessionSaveData `json:"sessionData"` +} + +type SessionVerifyRequest struct { + ClientSessionId string `json:"clientSessionId"` + Slot int `json:"slot"` +} + +func handleSessionVerify(w http.ResponseWriter, r *http.Request) { + var token []byte + token, err := tokenFromRequest(r) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } + + var input SessionVerifyRequest + err = json.NewDecoder(r.Body).Decode(&input) + if err != nil { + httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + return + } + + var active bool + active, err = db.IsActiveSession(token, input.ClientSessionId) + if err != nil { + httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) + return + } + + response := SessionVerifyResponse{ + Valid: active, + } + + // not valid, send server state + if !active { + err = db.UpdateActiveSession(token, input.ClientSessionId) + if err != nil { + httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) + return + } + + var uuid []byte + uuid, err = db.FetchUUIDFromToken(token) + if err != nil { + httpError(w, r, fmt.Errorf("failed to fetch UUID from token: %s", err), http.StatusInternalServerError) + } + var storedSaveData defs.SessionSaveData + storedSaveData, err = db.ReadSessionSaveData(uuid, input.Slot) + if err != nil { + httpError(w, r, fmt.Errorf("failed to read session save data: %s", err), http.StatusInternalServerError) + return + } + + response.SessionData = &storedSaveData + } + + jsonResponse(w, r, response) +} + +func handleGetSystemData(w http.ResponseWriter, r *http.Request) { + token, uuid, err := tokenAndUuidFromRequest(r) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } + + var clientSessionId string + if r.URL.Query().Has("clientSessionId") { + clientSessionId = r.URL.Query().Get("clientSessionId") + } else { + httpError(w, r, fmt.Errorf("missing clientSessionId"), http.StatusBadRequest) + } + + err = db.UpdateActiveSession(token, clientSessionId) + if err != nil { + httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) + return + } + + var save any //TODO this is always system save data + save, err = savedata.Get(uuid, 0, 0) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + http.Error(w, err.Error(), http.StatusNotFound) + } else { + httpError(w, r, err, http.StatusInternalServerError) + } + + return + } + + jsonResponse(w, r, save) +} + +func legacyHandleNewClear(w http.ResponseWriter, r *http.Request) { uuid, err := uuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) @@ -610,7 +770,6 @@ func handleNewClear(w http.ResponseWriter, r *http.Request) { } // daily - func handleDailySeed(w http.ResponseWriter, r *http.Request) { seed, err := db.GetDailyRunSeed() if err != nil { diff --git a/db/account.go b/db/account.go index a3dddd2..88b54db 100644 --- a/db/account.go +++ b/db/account.go @@ -40,11 +40,6 @@ func AddAccountSession(username string, token []byte) error { return err } - _, err = handle.Exec("UPDATE sessions s JOIN accounts a ON a.uuid = s.uuid SET s.active = 1 WHERE a.username = ? AND a.lastLoggedIn IS NULL", username) - if err != nil { - return err - } - _, err = handle.Exec("UPDATE accounts SET lastLoggedIn = UTC_TIMESTAMP() WHERE username = ?", username) if err != nil { return err @@ -213,18 +208,25 @@ func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error { return nil } -func IsActiveSession(token []byte) (bool, error) { - var active int - err := handle.QueryRow("SELECT `active` FROM sessions WHERE token = ?", token).Scan(&active) +func IsActiveSession(token []byte, clientSessionId string) (bool, error) { + var storedId string + err := handle.QueryRow("SELECT `clientSessionId` FROM sessions WHERE token = ?", token).Scan(&storedId) if err != nil { return false, err } + if storedId == "" { + err = UpdateActiveSession(token, clientSessionId) + if err != nil { + return false, err + } + return true, nil + } - return active == 1, nil + return storedId == clientSessionId, nil } -func UpdateActiveSession(uuid []byte, token []byte) error { - _, err := handle.Exec("UPDATE sessions SET `active` = CASE WHEN token = ? THEN 1 ELSE 0 END WHERE uuid = ?", token, uuid) +func UpdateActiveSession(token []byte, clientSessionId string) error { + _, err := handle.Exec("UPDATE sessions SET clientSessionId = ? WHERE token = ?", clientSessionId, token) if err != nil { return err } diff --git a/db/db.go b/db/db.go index 874a48c..d4ddc88 100644 --- a/db/db.go +++ b/db/db.go @@ -145,6 +145,8 @@ func Init(username, password, protocol, address, database string) error { func setupDb(tx *sql.Tx) error { queries := []string{ + // MIGRATION 000 + `CREATE TABLE IF NOT EXISTS accounts (uuid BINARY(16) NOT NULL PRIMARY KEY, username VARCHAR(16) UNIQUE NOT NULL, hash BINARY(32) NOT NULL, salt BINARY(16) NOT NULL, registered TIMESTAMP NOT NULL, lastLoggedIn TIMESTAMP DEFAULT NULL, lastActivity TIMESTAMP DEFAULT NULL, banned TINYINT(1) NOT NULL DEFAULT 0, trainerId SMALLINT(5) UNSIGNED DEFAULT 0, secretId SMALLINT(5) UNSIGNED DEFAULT 0)`, `CREATE INDEX IF NOT EXISTS accountsByActivity ON accounts (lastActivity)`, @@ -168,6 +170,12 @@ func setupDb(tx *sql.Tx) error { `CREATE TABLE IF NOT EXISTS systemSaveData (uuid BINARY(16) PRIMARY KEY, data LONGBLOB, timestamp TIMESTAMP)`, `CREATE TABLE IF NOT EXISTS sessionSaveData (uuid BINARY(16), slot TINYINT, data LONGBLOB, timestamp TIMESTAMP, PRIMARY KEY (uuid, slot))`, + + // ---------------------------------- + // MIGRATION 001 + + `ALTER TABLE sessions DROP COLUMN IF EXISTS active`, + `ALTER TABLE sessions ADD COLUMN IF NOT EXISTS clientSessionId VARCHAR(32)`, } for _, q := range queries {