From fb1c47b32155f0a40224404dfa53d3fd5eed827a Mon Sep 17 00:00:00 2001 From: Yi-Ting Shih Date: Sun, 7 Dec 2025 11:27:54 +0800 Subject: [PATCH] Feat: add session --- cmds/serve.go | 32 +++++---- go.mod | 1 + go.sum | 2 + implements/bunDatabase.go | 65 ++++++++++++++++++ interfaces/database.go | 15 ++++ middlewares/checkAccessToken.go | 113 +++++++++++++++++++++++++++++++ middlewares/checkRefreshToken.go | 74 ++++++++++++++++++++ middlewares/handlers.go | 11 +++ models/session.go | 68 +++++++++++++++++++ types/definition.go | 6 ++ types/errors.go | 10 +++ utils/initDB.go | 15 ++++ 12 files changed, 400 insertions(+), 12 deletions(-) create mode 100644 middlewares/checkAccessToken.go create mode 100644 middlewares/checkRefreshToken.go create mode 100644 middlewares/handlers.go create mode 100644 models/session.go create mode 100644 types/definition.go create mode 100644 types/errors.go diff --git a/cmds/serve.go b/cmds/serve.go index 81d8274..eb9490b 100644 --- a/cmds/serve.go +++ b/cmds/serve.go @@ -1,15 +1,21 @@ package cmds import ( + "database/sql" "log" "net/http" "gitea.konchin.com/go2025/backend/handlers/api" + "gitea.konchin.com/go2025/backend/implements" "gitea.konchin.com/go2025/backend/middlewares" "gitea.konchin.com/go2025/backend/tracing" "github.com/spf13/cobra" "github.com/spf13/viper" httpSwagger "github.com/swaggo/http-swagger" + "github.com/uptrace/bun" + "github.com/uptrace/bun/dialect/pgdialect" + "github.com/uptrace/bun/driver/pgdriver" + "github.com/uptrace/bun/extra/bunotel" "github.com/uptrace/bunrouter" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" @@ -35,13 +41,13 @@ var serveCmd = &cobra.Command{ defer tracing.DeferUptrace(ctx) } - /* - // Initialize DB instance - sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN( - viper.GetString("pg-connection-string")))) - bunDB := bun.NewDB(sqldb, pgdialect.New()) - bunDB.AddQueryHook(bunotel.NewQueryHook(bunotel.WithDBName("backend"))) + // Initialize DB instance + sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN( + viper.GetString("pg-connection-string")))) + bunDB := bun.NewDB(sqldb, pgdialect.New()) + bunDB.AddQueryHook(bunotel.NewQueryHook(bunotel.WithDBName("backend"))) + /* // Initialize MinIO instance mc, err := minio.New(viper.GetString("minio-host"), &minio.Options{ Creds: credentials.NewStaticV4( @@ -64,15 +70,15 @@ var serveCmd = &cobra.Command{ zap.Error(err)) panic(err) } - - // Initialize custom interfaces - db := implements.NewBunDatabase(bunDB) - s3 := implements.NewMinIOObjectStorage(mc) - */ + // Initialize custom interfaces + db := implements.NewBunDatabase(bunDB) + // s3 := implements.NewMinIOObjectStorage(mc) + // Initialize handlers apis := api.NewHandlers() + midHandlers := middlewares.NewHandlers(db) // Initialize backend router router := bunrouter.New() @@ -82,7 +88,9 @@ var serveCmd = &cobra.Command{ Use(middlewares.AccessLog). Use(middlewares.CORSHandler) - apiGroup := backend.NewGroup("/api") + apiGroup := backend.NewGroup("/api"). + Use(midHandlers.CheckRefreshToken). + Use(midHandlers.CheckAccessToken) apiGroup.GET("/images", apis.GetImages) if viper.GetBool("swagger") { diff --git a/go.mod b/go.mod index 8cb1346..76146f7 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module gitea.konchin.com/go2025/backend go 1.25.4 require ( + github.com/golang-jwt/jwt/v5 v5.3.0 github.com/minio/minio-go/v7 v7.0.97 github.com/spf13/cobra v1.10.2 github.com/spf13/viper v1.21.0 diff --git a/go.sum b/go.sum index 4c2f18d..a8ce5bc 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,8 @@ github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyr github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= +github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo= +github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= diff --git a/implements/bunDatabase.go b/implements/bunDatabase.go index ffc7c25..c1c5bbb 100644 --- a/implements/bunDatabase.go +++ b/implements/bunDatabase.go @@ -1,4 +1,69 @@ package implements +import ( + "context" + + "gitea.konchin.com/go2025/backend/models" + "gitea.konchin.com/go2025/backend/tracing" + "github.com/uptrace/bun" + "go.uber.org/zap" +) + type BunDatabase struct { + db *bun.DB +} + +func NewBunDatabase(db *bun.DB) *BunDatabase { + return &BunDatabase{db: db} +} + +func (self *BunDatabase) GetSession( + ctx context.Context, + userId string, +) (models.Session, error) { + ret := models.Session{ + UserId: userId, + } + err := self.db.NewSelect(). + Model(&ret). + WherePK(). + Scan(ctx) + if err != nil { + return models.Session{}, err + } + return ret, nil +} + +func (self *BunDatabase) UpdateRefreshToken( + ctx context.Context, + userId string, +) (models.Session, error) { + ret := models.Session{ + UserId: userId, + } + err := self.db.NewSelect(). + Model(&ret). + WherePK(). + Scan(ctx) + if err != nil { + return models.Session{}, err + } + + if err := ret.RotateRefreshToken(); err != nil { + tracing.Logger.Ctx(ctx). + Error("failed to rotate refresh token", + zap.Error(err)) + return models.Session{}, err + } + + err = self.db.NewUpdate(). + Model((*models.Session)(nil)). + Set("refresh_token = ?", ret.RefreshToken). + Where("user_id = ?", ret.UserId). + Returning("*"). + Scan(ctx, &ret) + if err != nil { + return models.Session{}, err + } + return ret, nil } diff --git a/interfaces/database.go b/interfaces/database.go index 1d3fe42..c211a28 100644 --- a/interfaces/database.go +++ b/interfaces/database.go @@ -1,4 +1,19 @@ package interfaces +import ( + "context" + + "gitea.konchin.com/go2025/backend/models" +) + type Database interface { + GetSession( + ctx context.Context, + userId string, + ) (models.Session, error) + + UpdateRefreshToken( + ctx context.Context, + userId string, + ) (models.Session, error) } diff --git a/middlewares/checkAccessToken.go b/middlewares/checkAccessToken.go new file mode 100644 index 0000000..52b264d --- /dev/null +++ b/middlewares/checkAccessToken.go @@ -0,0 +1,113 @@ +package middlewares + +import ( + "context" + "net/http" + + "gitea.konchin.com/go2025/backend/interfaces" + "gitea.konchin.com/go2025/backend/models" + "gitea.konchin.com/go2025/backend/tracing" + "gitea.konchin.com/go2025/backend/types" + "github.com/golang-jwt/jwt/v5" + "github.com/spf13/viper" + "github.com/uptrace/bunrouter" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/trace" + "go.uber.org/zap" +) + +func refreshAccessToken( + ctx context.Context, + db interfaces.Database, + w http.ResponseWriter, + req bunrouter.Request, +) (string, error) { + refreshTokenClaim, ok := req.Context(). + Value(types.RefreshToken("")).(models.RefreshTokenClaim) + if !ok { + tracing.Logger.Ctx(ctx). + Warn("refresh token not exist") + return "", types.ContextNotExistError + } + + session, err := db.GetSession(ctx, refreshTokenClaim.UserId) + if err != nil { + tracing.Logger.Ctx(ctx). + Warn("session not exist", zap.Error(err)) + return "", err + } + + ret, err := session.ToAccessToken() + if err != nil { + tracing.Logger.Ctx(ctx). + Warn("access token generate failed", zap.Error(err)) + return "", err + } + + http.SetCookie(w, &http.Cookie{ + Name: "access_token", + Value: ret, + Path: "/", + Secure: false, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }) + return ret, nil +} + +func (self *Handlers) CheckAccessToken( + next bunrouter.HandlerFunc, +) bunrouter.HandlerFunc { + return func(w http.ResponseWriter, req bunrouter.Request) error { + ctx := req.Context() + + var accessTokenString string + accessTokenCookie, err := req.Cookie("access_token") + if err != nil { + accessTokenString, err = refreshAccessToken(ctx, self.db, w, req) + if err != nil { + return HTTPError{ + StatusCode: http.StatusUnauthorized, + Message: "access token refresh failed", + OriginError: err, + } + } + } else { + accessTokenString = accessTokenCookie.Value + } + + var claim models.AccessTokenClaim + token, err := jwt.ParseWithClaims(accessTokenString, &claim, + func(*jwt.Token) (interface{}, error) { + return []byte(viper.GetString("ACCESS_TOKEN_SECRET")), nil + }) + if err != nil || !token.Valid { + accessTokenString, err = refreshAccessToken(ctx, self.db, w, req) + if err != nil { + return HTTPError{ + StatusCode: http.StatusUnauthorized, + Message: "access token refresh failed", + OriginError: err, + } + } + token, err := jwt.ParseWithClaims(accessTokenString, &claim, + func(*jwt.Token) (interface{}, error) { + return []byte(viper.GetString("ACCESS_TOKEN_SECRET")), nil + }) + if err != nil || !token.Valid { + return HTTPError{ + StatusCode: http.StatusUnauthorized, + Message: "access token jwt cannot parse or invalid", + OriginError: err, + } + } + } + + span := trace.SpanFromContext(ctx) + span.SetAttributes( + attribute.String("owner.UserId", claim.UserId)) + + ctx = context.WithValue(ctx, types.AccessToken(""), claim) + return next(w, req.WithContext(ctx)) + } +} diff --git a/middlewares/checkRefreshToken.go b/middlewares/checkRefreshToken.go new file mode 100644 index 0000000..b848748 --- /dev/null +++ b/middlewares/checkRefreshToken.go @@ -0,0 +1,74 @@ +package middlewares + +import ( + "context" + "net/http" + "time" + + "gitea.konchin.com/go2025/backend/models" + "gitea.konchin.com/go2025/backend/types" + "github.com/golang-jwt/jwt/v5" + "github.com/spf13/viper" + "github.com/uptrace/bunrouter" +) + +func (self *Handlers) CheckRefreshToken( + next bunrouter.HandlerFunc, +) bunrouter.HandlerFunc { + return func(w http.ResponseWriter, req bunrouter.Request) error { + ctx := req.Context() + + refreshTokenCookie, err := req.Cookie("refresh_token") + if err != nil { + return HTTPError{ + StatusCode: http.StatusUnauthorized, + Message: "user did not login", + OriginError: err, + } + } + + var claim models.RefreshTokenClaim + token, err := jwt.ParseWithClaims(refreshTokenCookie.Value, &claim, + func(*jwt.Token) (interface{}, error) { + return []byte(viper.GetString("REFRESH_TOKEN_SECRET")), nil + }) + if err != nil { + return HTTPError{ + StatusCode: http.StatusUnauthorized, + Message: "refresh token jwt cannot parse", + OriginError: err, + } + } + if !token.Valid { + return HTTPError{ + StatusCode: http.StatusUnauthorized, + Message: "refresh token jwt invalid", + } + } + + // check time and refresh + timeLeft := claim.ExpiresAt.Time.Sub(time.Now()) / time.Second + if int64(timeLeft) < viper.GetInt64("REFRESH_TOKEN_TIMEOUT")/2 { + session, err := self.db.UpdateRefreshToken(ctx, claim.UserId) + if err != nil { + return HTTPError{ + StatusCode: http.StatusInternalServerError, + Message: "upsert session failed", + OriginError: err, + } + } + + http.SetCookie(w, &http.Cookie{ + Name: "refresh_token", + Value: session.RefreshToken, + Path: "/", + Secure: false, + HttpOnly: true, + SameSite: http.SameSiteLaxMode, + }) + } + + ctx = context.WithValue(ctx, types.RefreshToken(""), claim) + return next(w, req.WithContext(ctx)) + } +} diff --git a/middlewares/handlers.go b/middlewares/handlers.go new file mode 100644 index 0000000..052e1c9 --- /dev/null +++ b/middlewares/handlers.go @@ -0,0 +1,11 @@ +package middlewares + +import "gitea.konchin.com/go2025/backend/interfaces" + +type Handlers struct { + db interfaces.Database +} + +func NewHandlers(db interfaces.Database) *Handlers { + return &Handlers{db: db} +} diff --git a/models/session.go b/models/session.go new file mode 100644 index 0000000..4077bbd --- /dev/null +++ b/models/session.go @@ -0,0 +1,68 @@ +package models + +import ( + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/spf13/viper" + "github.com/uptrace/bun" +) + +type Session struct { + bun.BaseModel `bun:"table:session"` + + UserId string `bun:"user_id,pk"` + LoginToken string `bun:"login_token,unique"` + RefreshToken string `bun:"refresh_token,unique"` + + IsValid bool `bun:"is_valid"` +} + +type RefreshTokenClaimFields struct { + UserId string `json:"user_id"` +} + +type RefreshTokenClaim struct { + jwt.RegisteredClaims + + RefreshTokenClaimFields `json:",inline"` +} + +type AccessTokenClaimFields struct { + UserId string `json:"user_id"` +} + +type AccessTokenClaim struct { + jwt.RegisteredClaims + + AccessTokenClaimFields `json:",inline"` +} + +func (self *Session) RotateRefreshToken() error { + refreshToken, err := + jwt.NewWithClaims(jwt.SigningMethodHS256, RefreshTokenClaim{ + RefreshTokenClaimFields: RefreshTokenClaimFields{ + UserId: self.UserId, + }, + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: &jwt.NumericDate{time.Now()}, + ExpiresAt: &jwt.NumericDate{time.Now().Add(time.Duration( + viper.GetInt64("REFRESH_TOKEN_TIMEOUT")) * time.Second)}, + }}).SignedString([]byte(viper.GetString("REFRESH_TOKEN_SECRET"))) + if err != nil { + self.RefreshToken = refreshToken + } + return err +} + +func (self *Session) ToAccessToken() (string, error) { + return jwt.NewWithClaims(jwt.SigningMethodHS256, AccessTokenClaim{ + AccessTokenClaimFields: AccessTokenClaimFields{ + UserId: self.UserId, + }, + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: &jwt.NumericDate{time.Now()}, + ExpiresAt: &jwt.NumericDate{time.Now().Add(time.Duration( + viper.GetInt64("ACCESS_TOKEN_TIMEOUT")) * time.Second)}, + }}).SignedString([]byte(viper.GetString("ACCESS_TOKEN_SECRET"))) +} diff --git a/types/definition.go b/types/definition.go new file mode 100644 index 0000000..07f2926 --- /dev/null +++ b/types/definition.go @@ -0,0 +1,6 @@ +package types + +type ( + AccessToken string + RefreshToken string +) diff --git a/types/errors.go b/types/errors.go new file mode 100644 index 0000000..14114b1 --- /dev/null +++ b/types/errors.go @@ -0,0 +1,10 @@ +package types + +import "fmt" + +var ( + ContextNotExistError = fmt.Errorf("context not exist") + + WrongFormatError = fmt.Errorf("wrong format") + HTTPRequestFailedError = fmt.Errorf("http request failed") +) diff --git a/utils/initDB.go b/utils/initDB.go index d4b585b..b1c3203 100644 --- a/utils/initDB.go +++ b/utils/initDB.go @@ -1 +1,16 @@ package utils + +import ( + "context" + + "gitea.konchin.com/go2025/backend/models" + "github.com/uptrace/bun" +) + +func initDB(ctx context.Context, db *bun.DB) error { + return db.ResetModel(ctx, + (*models.Alias)(nil), + (*models.Image)(nil), + (*models.Session)(nil), + ) +}