From cada4d25fa6cc06000f1ec2cbab8c90f4ad44c7e Mon Sep 17 00:00:00 2001 From: Yi-Ting Shih Date: Sun, 7 Dec 2025 19:14:06 +0800 Subject: [PATCH] Feat: add preshared key check --- cmds/serve.go | 5 ++++- middlewares/checkPresharedKey.go | 32 ++++++++++++++++++++++++++++++++ tests/01_login_test.go | 13 +++++++++++++ 3 files changed, 49 insertions(+), 1 deletion(-) create mode 100644 middlewares/checkPresharedKey.go diff --git a/cmds/serve.go b/cmds/serve.go index aec0aa4..04e2f0c 100644 --- a/cmds/serve.go +++ b/cmds/serve.go @@ -106,7 +106,8 @@ var serveCmd = &cobra.Command{ authGroup := backend.NewGroup("/auth") authGroup.POST("/login", auths.PostLogin) - authGroup.POST("/gen-login-url", auths.PostGenLoginUrl) + authGroup.POST("/gen-login-url", + midHandlers.CheckPresharedKey(auths.PostGenLoginUrl)) if viper.GetBool("swagger") { backend.GET("/swagger/*any", @@ -127,6 +128,8 @@ func init() { String("external-url", "http://localhost:8080", "External url for login") serveCmd.Flags(). String("cors-origin", "", "CORS origin") + serveCmd.Flags(). + String("preshared-key", "poop", "Preshared key for Discord Bot") serveCmd.Flags(). Int64("access-token-timeout", 300, "Timeout of Access Token JWT") diff --git a/middlewares/checkPresharedKey.go b/middlewares/checkPresharedKey.go new file mode 100644 index 0000000..7d26854 --- /dev/null +++ b/middlewares/checkPresharedKey.go @@ -0,0 +1,32 @@ +package middlewares + +import ( + "net/http" + "strings" + + "github.com/spf13/viper" + "github.com/uptrace/bunrouter" +) + +func (self *Handlers) CheckPresharedKey( + next bunrouter.HandlerFunc, +) bunrouter.HandlerFunc { + return func(w http.ResponseWriter, req bunrouter.Request) error { + authHeader := strings.Split(req.Header.Get("Authorization"), " ") + if len(authHeader) != 2 || authHeader[0] != "Bearer" { + return HTTPError{ + StatusCode: http.StatusUnauthorized, + Message: "missing preshared key", + } + } + + if authHeader[1] != viper.GetString("preshared-key") { + return HTTPError{ + StatusCode: http.StatusUnauthorized, + Message: "preshared key mismatched", + } + } + + return next(w, req) + } +} diff --git a/tests/01_login_test.go b/tests/01_login_test.go index 50d756f..4240c9a 100644 --- a/tests/01_login_test.go +++ b/tests/01_login_test.go @@ -21,9 +21,22 @@ type loginPayload struct { func Test_01_Login(t *testing.T) { client = resty.New() + t.Run("check preshared key failed", func(t *testing.T) { + resp, err := client.R(). + SetBody(`{"userId": "testuser1"}`). + Post("http://localhost:8080/auth/gen-login-url") + if err != nil { + t.Fatal("request failed") + } + if resp.StatusCode() != http.StatusUnauthorized { + t.Fatal("preshared key check should failed") + } + }) + var payload genLoginUrlPayload resp, err := client.R(). SetBody(`{"userId": "testuser1"}`). + SetAuthToken("poop"). SetResult(&payload). Post("http://localhost:8080/auth/gen-login-url")