diff --git a/.gitignore b/.gitignore index f668971..09245dc 100644 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,3 @@ pokerogue-server* userdata/* secret.key -www/ \ No newline at end of file diff --git a/api/common.go b/api/common.go index dec5d2d..afe494c 100644 --- a/api/common.go +++ b/api/common.go @@ -10,31 +10,9 @@ import ( "github.com/pagefaultgames/pokerogue-server/db" ) -func Init(mux *http.ServeMux) { +func Init() { scheduleStatRefresh() daily.Init() - - // account - mux.HandleFunc("GET /api/account/info", handleAccountInfo) - mux.HandleFunc("POST /api/account/register", handleAccountRegister) - mux.HandleFunc("POST /api/account/login", handleAccountLogin) - mux.HandleFunc("GET /api/account/logout", handleAccountLogout) - - // game - mux.HandleFunc("GET /api/game/playercount", handleGamePlayerCount) - mux.HandleFunc("GET /api/game/titlestats", handleGameTitleStats) - mux.HandleFunc("GET /api/game/classicsessioncount", handleGameClassicSessionCount) - - // savedata - mux.HandleFunc("GET /api/savedata/get", handleSaveData) - mux.HandleFunc("POST /api/savedata/update", handleSaveData) - mux.HandleFunc("GET /api/savedata/delete", handleSaveData) - mux.HandleFunc("POST /api/savedata/clear", handleSaveData) - - // daily - mux.HandleFunc("GET /api/daily/seed", handleDailySeed) - mux.HandleFunc("GET /api/daily/rankings", handleDailyRankings) - mux.HandleFunc("GET /api/daily/rankingpagecount", handleDailyRankingPageCount) } func getUsernameFromRequest(r *http.Request) (string, error) { diff --git a/api/endpoints.go b/api/endpoints.go index 8f0d97e..396aab6 100644 --- a/api/endpoints.go +++ b/api/endpoints.go @@ -7,6 +7,7 @@ import ( "log" "net/http" "strconv" + "sync" "github.com/pagefaultgames/pokerogue-server/api/account" "github.com/pagefaultgames/pokerogue-server/api/daily" @@ -14,252 +15,262 @@ import ( "github.com/pagefaultgames/pokerogue-server/defs" ) +type Server struct { + Debug bool + Exit *sync.RWMutex +} + /* The caller of endpoint handler functions are responsible for extracting the necessary data from the request. Handler functions are responsible for checking the validity of this data and returning a result or error. Handlers should not return serialized JSON, instead return the struct itself. */ -func handleAccountInfo(w http.ResponseWriter, r *http.Request) { - username, err := getUsernameFromRequest(r) - if err != nil { - httpError(w, r, err, http.StatusBadRequest) - return - } - - uuid, err := getUUIDFromRequest(r) // lazy - if err != nil { - httpError(w, r, err, http.StatusBadRequest) - return - } +func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { + // kind of misusing the RWMutex but it doesn't matter + s.Exit.RLock() + defer s.Exit.RUnlock() - response, err := account.Info(username, uuid) - if err != nil { - httpError(w, r, err, http.StatusInternalServerError) - return - } + if s.Debug { + w.Header().Set("Access-Control-Allow-Headers", "*") + w.Header().Set("Access-Control-Allow-Methods", "*") + w.Header().Set("Access-Control-Allow-Origin", "*") - err = json.NewEncoder(w).Encode(response) - if err != nil { - httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) - return - } -} - -func handleAccountRegister(w http.ResponseWriter, r *http.Request) { - err := r.ParseForm() - if err != nil { - httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest) - return - } - - err = account.Register(r.Form.Get("username"), r.Form.Get("password")) - if err != nil { - httpError(w, r, err, http.StatusInternalServerError) - return + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } } - w.WriteHeader(http.StatusOK) -} - -func handleAccountLogin(w http.ResponseWriter, r *http.Request) { - err := r.ParseForm() - if err != nil { - httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest) - return - } + switch r.URL.Path { + // /account + case "/account/info": + username, err := getUsernameFromRequest(r) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } - response, err := account.Login(r.Form.Get("username"), r.Form.Get("password")) - if err != nil { - httpError(w, r, err, http.StatusInternalServerError) - return - } + uuid, err := getUUIDFromRequest(r) // lazy + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } - err = json.NewEncoder(w).Encode(response) - if err != nil { - httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) - return - } -} + response, err := account.Info(username, uuid) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } -func handleAccountLogout(w http.ResponseWriter, r *http.Request) { - token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) - if err != nil { - httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest) - return - } + err = json.NewEncoder(w).Encode(response) + if err != nil { + httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) + return + } + case "/account/register": + err := r.ParseForm() + if err != nil { + httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest) + return + } - err = account.Logout(token) - if err != nil { - httpError(w, r, err, http.StatusInternalServerError) - return - } + err = account.Register(r.Form.Get("username"), r.Form.Get("password")) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } - w.WriteHeader(http.StatusOK) -} + w.WriteHeader(http.StatusOK) + case "/account/login": + err := r.ParseForm() + if err != nil { + httpError(w, r, fmt.Errorf("failed to parse request form: %s", err), http.StatusBadRequest) + return + } -func handleGamePlayerCount(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(strconv.Itoa(playerCount))) -} + response, err := account.Login(r.Form.Get("username"), r.Form.Get("password")) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } -func handleGameTitleStats(w http.ResponseWriter, r *http.Request) { - err := json.NewEncoder(w).Encode(defs.TitleStats{ - PlayerCount: playerCount, - BattleCount: battleCount, - }) - if err != nil { - httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) - return - } -} + err = json.NewEncoder(w).Encode(response) + if err != nil { + httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) + return + } + case "/account/logout": + token, err := base64.StdEncoding.DecodeString(r.Header.Get("Authorization")) + if err != nil { + httpError(w, r, fmt.Errorf("failed to decode token: %s", err), http.StatusBadRequest) + return + } -func handleGameClassicSessionCount(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(strconv.Itoa(classicSessionCount))) -} + err = account.Logout(token) + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) + return + } -func handleSaveData(w http.ResponseWriter, r *http.Request) { - uuid, err := getUUIDFromRequest(r) - if err != nil { - httpError(w, r, err, http.StatusBadRequest) - return - } + w.WriteHeader(http.StatusOK) - datatype := -1 - if r.URL.Query().Has("datatype") { - datatype, err = strconv.Atoi(r.URL.Query().Get("datatype")) + // /game + case "/game/playercount": + w.Write([]byte(strconv.Itoa(playerCount))) + case "/game/titlestats": + err := json.NewEncoder(w).Encode(defs.TitleStats{ + PlayerCount: playerCount, + BattleCount: battleCount, + }) if err != nil { - httpError(w, r, err, http.StatusBadRequest) + httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) return } - } + case "/game/classicsessioncount": + w.Write([]byte(strconv.Itoa(classicSessionCount))) - var slot int - if r.URL.Query().Has("slot") { - slot, err = strconv.Atoi(r.URL.Query().Get("slot")) + // /savedata + case "/savedata/get", "/savedata/update", "/savedata/delete", "/savedata/clear": + uuid, err := getUUIDFromRequest(r) if err != nil { httpError(w, r, err, http.StatusBadRequest) return } - } - var save any - // /savedata/get and /savedata/delete specify datatype, but don't expect data in body - if r.URL.Path != "/api/savedata/get" && r.URL.Path != "/api/savedata/delete" { - if datatype == 0 { - var system defs.SystemSaveData - err = json.NewDecoder(r.Body).Decode(&system) + datatype := -1 + if r.URL.Query().Has("datatype") { + datatype, err = strconv.Atoi(r.URL.Query().Get("datatype")) if err != nil { - httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + httpError(w, r, err, http.StatusBadRequest) return } + } - save = system - // /savedata/clear doesn't specify datatype, it is assumed to be 1 (session) - } else if datatype == 1 || r.URL.Path == "/api/savedata/clear" { - var session defs.SessionSaveData - err = json.NewDecoder(r.Body).Decode(&session) + var slot int + if r.URL.Query().Has("slot") { + slot, err = strconv.Atoi(r.URL.Query().Get("slot")) if err != nil { - httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + httpError(w, r, err, http.StatusBadRequest) return } + } - save = session + var save any + // /savedata/get and /savedata/delete specify datatype, but don't expect data in body + if r.URL.Path != "/savedata/get" && r.URL.Path != "/savedata/delete" { + if datatype == 0 { + var system defs.SystemSaveData + err = json.NewDecoder(r.Body).Decode(&system) + if err != nil { + httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + return + } + + save = system + // /savedata/clear doesn't specify datatype, it is assumed to be 1 (session) + } else if datatype == 1 || r.URL.Path == "/savedata/clear" { + var session defs.SessionSaveData + err = json.NewDecoder(r.Body).Decode(&session) + if err != nil { + httpError(w, r, fmt.Errorf("failed to decode request body: %s", err), http.StatusBadRequest) + return + } + + save = session + } } - } - switch r.URL.Path { - case "/api/savedata/get": - save, err = savedata.Get(uuid, datatype, slot) - case "/api/savedata/update": - err = savedata.Update(uuid, slot, save) - case "/api/savedata/delete": - err = savedata.Delete(uuid, datatype, slot) - case "/api/savedata/clear": - s, ok := save.(defs.SessionSaveData) - if !ok { - httpError(w, r, fmt.Errorf("save data is not type SessionSaveData"), http.StatusBadRequest) + switch r.URL.Path { + case "/savedata/get": + save, err = savedata.Get(uuid, datatype, slot) + case "/savedata/update": + err = savedata.Update(uuid, slot, save) + case "/savedata/delete": + err = savedata.Delete(uuid, datatype, slot) + case "/savedata/clear": + s, ok := save.(defs.SessionSaveData) + if !ok { + httpError(w, r, fmt.Errorf("save data is not type SessionSaveData"), http.StatusBadRequest) + return + } + + // doesn't return a save, but it works + save, err = savedata.Clear(uuid, slot, daily.Seed(), s) + } + if err != nil { + httpError(w, r, err, http.StatusInternalServerError) return } - // doesn't return a save, but it works - save, err = savedata.Clear(uuid, slot, daily.Seed(), s) - } - if err != nil { - httpError(w, r, err, http.StatusInternalServerError) - return - } + if save == nil || r.URL.Path == "/savedata/update" { + w.WriteHeader(http.StatusOK) + return + } - if save == nil || r.URL.Path == "/api/savedata/update" { - w.WriteHeader(http.StatusOK) - return - } + err = json.NewEncoder(w).Encode(save) + if err != nil { + httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) + return + } - err = json.NewEncoder(w).Encode(save) - if err != nil { - httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) - return - } -} + // /daily + case "/daily/seed": + w.Write([]byte(daily.Seed())) + case "/daily/rankings": + uuid, err := getUUIDFromRequest(r) + if err != nil { + httpError(w, r, err, http.StatusBadRequest) + return + } -func handleDailySeed(w http.ResponseWriter, r *http.Request) { - w.Write([]byte(daily.Seed())) -} + var category int + if r.URL.Query().Has("category") { + category, err = strconv.Atoi(r.URL.Query().Get("category")) + if err != nil { + httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest) + return + } + } -func handleDailyRankings(w http.ResponseWriter, r *http.Request) { - uuid, err := getUUIDFromRequest(r) - if err != nil { - httpError(w, r, err, http.StatusBadRequest) - return - } + page := 1 + if r.URL.Query().Has("page") { + page, err = strconv.Atoi(r.URL.Query().Get("page")) + if err != nil { + httpError(w, r, fmt.Errorf("failed to convert page: %s", err), http.StatusBadRequest) + return + } + } - var category int - if r.URL.Query().Has("category") { - category, err = strconv.Atoi(r.URL.Query().Get("category")) + rankings, err := daily.Rankings(uuid, category, page) if err != nil { - httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest) + httpError(w, r, err, http.StatusInternalServerError) return } - } - page := 1 - if r.URL.Query().Has("page") { - page, err = strconv.Atoi(r.URL.Query().Get("page")) + err = json.NewEncoder(w).Encode(rankings) if err != nil { - httpError(w, r, fmt.Errorf("failed to convert page: %s", err), http.StatusBadRequest) + httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) return } - } - - rankings, err := daily.Rankings(uuid, category, page) - if err != nil { - httpError(w, r, err, http.StatusInternalServerError) - return - } - - err = json.NewEncoder(w).Encode(rankings) - if err != nil { - httpError(w, r, fmt.Errorf("failed to encode response json: %s", err), http.StatusInternalServerError) - return - } -} + case "/daily/rankingpagecount": + var category int + if r.URL.Query().Has("category") { + var err error + category, err = strconv.Atoi(r.URL.Query().Get("category")) + if err != nil { + httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest) + return + } + } -func handleDailyRankingPageCount(w http.ResponseWriter, r *http.Request) { - var category int - if r.URL.Query().Has("category") { - var err error - category, err = strconv.Atoi(r.URL.Query().Get("category")) + count, err := daily.RankingPageCount(category) if err != nil { - httpError(w, r, fmt.Errorf("failed to convert category: %s", err), http.StatusBadRequest) - return + httpError(w, r, err, http.StatusInternalServerError) } - } - count, err := daily.RankingPageCount(category) - if err != nil { - httpError(w, r, err, http.StatusInternalServerError) + w.Write([]byte(strconv.Itoa(count))) } - - w.Write([]byte(strconv.Itoa(count))) } func httpError(w http.ResponseWriter, r *http.Request, err error, code int) { diff --git a/pokerogue-server.go b/pokerogue-server.go index a147f63..acd901c 100644 --- a/pokerogue-server.go +++ b/pokerogue-server.go @@ -4,7 +4,12 @@ import ( "encoding/gob" "flag" "log" + "net" "net/http" + "os" + "os/signal" + "sync" + "syscall" "github.com/pagefaultgames/pokerogue-server/api" "github.com/pagefaultgames/pokerogue-server/db" @@ -12,10 +17,10 @@ import ( func main() { // flag stuff - addr := flag.String("addr", "0.0.0.0:80", "network address for api to listen on") - wwwpath := flag.String("wwwpath", "www", "path to static content to serve") - tlscert := flag.String("tlscert", "", "path to tls certificate to use for https") - tlskey := flag.String("tlskey", "", "path to tls private key to use for https") + debug := flag.Bool("debug", false, "debug mode") + + proto := flag.String("proto", "tcp", "protocol for api to use (tcp, unix)") + addr := flag.String("addr", "0.0.0.0", "network address for api to listen on") dbuser := flag.String("dbuser", "pokerogue", "database username") dbpass := flag.String("dbpass", "", "database password") @@ -25,6 +30,7 @@ func main() { flag.Parse() + // register gob types gob.Register([]interface{}{}) gob.Register(map[string]interface{}{}) @@ -35,19 +41,55 @@ func main() { log.Fatalf("failed to initialize database: %s", err) } - // start web server - mux := http.NewServeMux() + // create listener + listener, err := createListener(*proto, *addr) + if err != nil { + log.Fatalf("failed to create net listener: %s", err) + } - api.Init(mux) + // create exit handler + var exit sync.RWMutex + createExitHandler(&exit) - mux.Handle("/", http.FileServer(http.Dir(*wwwpath))) - - if *tlscert != "" && *tlskey != "" { - err = http.ListenAndServeTLS(*addr, *tlscert, *tlskey, mux) - } else { - err = http.ListenAndServe(*addr, mux) - } + // init api + api.Init() + + // start web server + err = http.Serve(listener, &api.Server{Debug: *debug, Exit: &exit}) if err != nil { log.Fatalf("failed to create http server or server errored: %s", err) } } + +func createListener(proto, addr string) (net.Listener, error) { + if proto == "unix" { + os.Remove(addr) + } + + listener, err := net.Listen(proto, addr) + if err != nil { + return nil, err + } + + if proto == "unix" { + os.Chmod(addr, 0777) + } + + return listener, nil +} + +func createExitHandler(mtx *sync.RWMutex) { + s := make(chan os.Signal, 1) + signal.Notify(s, syscall.SIGINT, syscall.SIGTERM) + + go func() { + // wait for exit signal of some kind + <-s + + // block new requests and wait for existing ones to finish + mtx.Lock() + + // bail + os.Exit(0) + }() +}