diff --git a/api/endpoints.go b/api/endpoints.go index 676570e..62a958d 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -147,7 +147,7 @@ func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) { } func handleGetSessionData(w http.ResponseWriter, r *http.Request) { - token, uuid, err := tokenAndUuidFromRequest(r) + uuid, err := uuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) return @@ -169,7 +169,7 @@ func handleGetSessionData(w http.ResponseWriter, r *http.Request) { httpError(w, r, fmt.Errorf("missing clientSessionId"), http.StatusBadRequest) } - err = db.UpdateActiveSession(token, clientSessionId) + err = db.UpdateActiveSession(uuid, clientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) return @@ -193,7 +193,7 @@ func handleGetSessionData(w http.ResponseWriter, r *http.Request) { const legacyClientSessionId = "LEGACY_CLIENT" func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) { - token, uuid, err := tokenAndUuidFromRequest(r) + uuid, err := uuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) return @@ -219,7 +219,7 @@ func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) { var save any if datatype == 0 { - err = db.UpdateActiveSession(token, legacyClientSessionId) // we dont have a client id + err = db.UpdateActiveSession(uuid, legacyClientSessionId) // we dont have a client id if err != nil { httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) return @@ -242,7 +242,7 @@ func legacyHandleGetSaveData(w http.ResponseWriter, r *http.Request) { // FIXME UNFINISHED!!! func clearSessionData(w http.ResponseWriter, r *http.Request) { - token, uuid, err := tokenAndUuidFromRequest(r) + uuid, err := uuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) return @@ -268,7 +268,7 @@ func clearSessionData(w http.ResponseWriter, r *http.Request) { save = session var active bool - active, err = db.IsActiveSession(token, legacyClientSessionId) //TODO unfinished, read token from query + active, err = db.IsActiveSession(uuid, legacyClientSessionId) //TODO unfinished, read token from query if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) return @@ -330,7 +330,7 @@ func clearSessionData(w http.ResponseWriter, r *http.Request) { // FIXME UNFINISHED!!! func deleteSystemSave(w http.ResponseWriter, r *http.Request) { - token, uuid, err := tokenAndUuidFromRequest(r) + uuid, err := uuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) return @@ -355,7 +355,7 @@ func deleteSystemSave(w http.ResponseWriter, r *http.Request) { } var active bool - active, err = db.IsActiveSession(token, legacyClientSessionId) //TODO unfinished, read token from query + active, err = db.IsActiveSession(uuid, legacyClientSessionId) //TODO unfinished, read token from query if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusInternalServerError) return @@ -410,7 +410,7 @@ func deleteSystemSave(w http.ResponseWriter, r *http.Request) { } func legacyHandleSaveData(w http.ResponseWriter, r *http.Request) { - token, uuid, err := tokenAndUuidFromRequest(r) + uuid, err := uuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) return @@ -470,14 +470,14 @@ func legacyHandleSaveData(w http.ResponseWriter, r *http.Request) { var active bool if r.URL.Path == "/savedata/get" { if datatype == 0 { - err = db.UpdateActiveSession(token, clientSessionId) + err = db.UpdateActiveSession(uuid, clientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) return } } } else { - active, err = db.IsActiveSession(token, clientSessionId) + active, err = db.IsActiveSession(uuid, clientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) return @@ -579,8 +579,7 @@ type CombinedSaveData struct { // TODO wrap this in a transaction func handleUpdateAll(w http.ResponseWriter, r *http.Request) { - var token []byte - token, uuid, err := tokenAndUuidFromRequest(r) + uuid, err := uuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) return @@ -602,7 +601,7 @@ func handleUpdateAll(w http.ResponseWriter, r *http.Request) { } var active bool - active, err = db.IsActiveSession(token, clientSessionId) + active, err = db.IsActiveSession(uuid, clientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) return @@ -658,8 +657,7 @@ type SessionVerifyRequest struct { } func handleSessionVerify(w http.ResponseWriter, r *http.Request) { - var token []byte - token, err := tokenFromRequest(r) + uuid, err := uuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) return @@ -673,7 +671,7 @@ func handleSessionVerify(w http.ResponseWriter, r *http.Request) { } var active bool - active, err = db.IsActiveSession(token, input.ClientSessionId) + active, err = db.IsActiveSession(uuid, input.ClientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to check active session: %s", err), http.StatusBadRequest) return @@ -685,17 +683,12 @@ func handleSessionVerify(w http.ResponseWriter, r *http.Request) { // not valid, send server state if !active { - err = db.UpdateActiveSession(token, input.ClientSessionId) + err = db.UpdateActiveSession(uuid, input.ClientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) return } - var uuid []byte - uuid, err = db.FetchUUIDFromToken(token) - if err != nil { - httpError(w, r, fmt.Errorf("failed to fetch UUID from token: %s", err), http.StatusInternalServerError) - } var storedSaveData defs.SessionSaveData storedSaveData, err = db.ReadSessionSaveData(uuid, input.Slot) if err != nil { @@ -710,7 +703,7 @@ func handleSessionVerify(w http.ResponseWriter, r *http.Request) { } func handleGetSystemData(w http.ResponseWriter, r *http.Request) { - token, uuid, err := tokenAndUuidFromRequest(r) + uuid, err := uuidFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) return @@ -723,7 +716,7 @@ func handleGetSystemData(w http.ResponseWriter, r *http.Request) { httpError(w, r, fmt.Errorf("missing clientSessionId"), http.StatusBadRequest) } - err = db.UpdateActiveSession(token, clientSessionId) + err = db.UpdateActiveSession(uuid, clientSessionId) if err != nil { httpError(w, r, fmt.Errorf("failed to update active session: %s", err), http.StatusBadRequest) return diff --git a/db/account.go b/db/account.go index 88b54db..e000122 100644 --- a/db/account.go +++ b/db/account.go @@ -18,6 +18,8 @@ package db import ( + "database/sql" + "errors" "fmt" "slices" @@ -208,14 +210,17 @@ func UpdateTrainerIds(trainerId, secretId int, uuid []byte) error { return nil } -func IsActiveSession(token []byte, clientSessionId string) (bool, error) { +func IsActiveSession(uuid []byte, clientSessionId string) (bool, error) { var storedId string - err := handle.QueryRow("SELECT `clientSessionId` FROM sessions WHERE token = ?", token).Scan(&storedId) + err := handle.QueryRow("SELECT clientSessionId FROM activeClientSessions WHERE sessions.uuid = ?", uuid).Scan(&storedId) if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, nil + } return false, err } if storedId == "" { - err = UpdateActiveSession(token, clientSessionId) + err = UpdateActiveSession(uuid, clientSessionId) if err != nil { return false, err } @@ -225,8 +230,8 @@ func IsActiveSession(token []byte, clientSessionId string) (bool, error) { return storedId == clientSessionId, nil } -func UpdateActiveSession(token []byte, clientSessionId string) error { - _, err := handle.Exec("UPDATE sessions SET clientSessionId = ? WHERE token = ?", clientSessionId, token) +func UpdateActiveSession(uuid []byte, clientSessionId string) error { + _, err := handle.Exec("REPLACE INTO activeClientSessions VALUES (?, ?)", uuid, clientSessionId) if err != nil { return err } diff --git a/db/db.go b/db/db.go index d4ddc88..7aa881c 100644 --- a/db/db.go +++ b/db/db.go @@ -175,7 +175,7 @@ func setupDb(tx *sql.Tx) error { // MIGRATION 001 `ALTER TABLE sessions DROP COLUMN IF EXISTS active`, - `ALTER TABLE sessions ADD COLUMN IF NOT EXISTS clientSessionId VARCHAR(32)`, + `CREATE TABLE IF NOT EXISTS activeClientSessions (uuid BINARY(16) NOT NULL PRIMARY KEY, clientSessionId VARCHAR(32) NOT NULL)`, } for _, q := range queries {