Feat: add session

This commit is contained in:
2025-12-07 11:27:54 +08:00
parent 9c30bc009d
commit fb1c47b321
12 changed files with 400 additions and 12 deletions

View File

@@ -1,15 +1,21 @@
package cmds package cmds
import ( import (
"database/sql"
"log" "log"
"net/http" "net/http"
"gitea.konchin.com/go2025/backend/handlers/api" "gitea.konchin.com/go2025/backend/handlers/api"
"gitea.konchin.com/go2025/backend/implements"
"gitea.konchin.com/go2025/backend/middlewares" "gitea.konchin.com/go2025/backend/middlewares"
"gitea.konchin.com/go2025/backend/tracing" "gitea.konchin.com/go2025/backend/tracing"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/spf13/viper" "github.com/spf13/viper"
httpSwagger "github.com/swaggo/http-swagger" httpSwagger "github.com/swaggo/http-swagger"
"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
"github.com/uptrace/bun/extra/bunotel"
"github.com/uptrace/bunrouter" "github.com/uptrace/bunrouter"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
@@ -35,13 +41,13 @@ var serveCmd = &cobra.Command{
defer tracing.DeferUptrace(ctx) defer tracing.DeferUptrace(ctx)
} }
/* // Initialize DB instance
// Initialize DB instance sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN(
sqldb := sql.OpenDB(pgdriver.NewConnector(pgdriver.WithDSN( viper.GetString("pg-connection-string"))))
viper.GetString("pg-connection-string")))) bunDB := bun.NewDB(sqldb, pgdialect.New())
bunDB := bun.NewDB(sqldb, pgdialect.New()) bunDB.AddQueryHook(bunotel.NewQueryHook(bunotel.WithDBName("backend")))
bunDB.AddQueryHook(bunotel.NewQueryHook(bunotel.WithDBName("backend")))
/*
// Initialize MinIO instance // Initialize MinIO instance
mc, err := minio.New(viper.GetString("minio-host"), &minio.Options{ mc, err := minio.New(viper.GetString("minio-host"), &minio.Options{
Creds: credentials.NewStaticV4( Creds: credentials.NewStaticV4(
@@ -64,15 +70,15 @@ var serveCmd = &cobra.Command{
zap.Error(err)) zap.Error(err))
panic(err) panic(err)
} }
// Initialize custom interfaces
db := implements.NewBunDatabase(bunDB)
s3 := implements.NewMinIOObjectStorage(mc)
*/ */
// Initialize custom interfaces
db := implements.NewBunDatabase(bunDB)
// s3 := implements.NewMinIOObjectStorage(mc)
// Initialize handlers // Initialize handlers
apis := api.NewHandlers() apis := api.NewHandlers()
midHandlers := middlewares.NewHandlers(db)
// Initialize backend router // Initialize backend router
router := bunrouter.New() router := bunrouter.New()
@@ -82,7 +88,9 @@ var serveCmd = &cobra.Command{
Use(middlewares.AccessLog). Use(middlewares.AccessLog).
Use(middlewares.CORSHandler) Use(middlewares.CORSHandler)
apiGroup := backend.NewGroup("/api") apiGroup := backend.NewGroup("/api").
Use(midHandlers.CheckRefreshToken).
Use(midHandlers.CheckAccessToken)
apiGroup.GET("/images", apis.GetImages) apiGroup.GET("/images", apis.GetImages)
if viper.GetBool("swagger") { if viper.GetBool("swagger") {

1
go.mod
View File

@@ -3,6 +3,7 @@ module gitea.konchin.com/go2025/backend
go 1.25.4 go 1.25.4
require ( require (
github.com/golang-jwt/jwt/v5 v5.3.0
github.com/minio/minio-go/v7 v7.0.97 github.com/minio/minio-go/v7 v7.0.97
github.com/spf13/cobra v1.10.2 github.com/spf13/cobra v1.10.2
github.com/spf13/viper v1.21.0 github.com/spf13/viper v1.21.0

2
go.sum
View File

@@ -34,6 +34,8 @@ github.com/go-openapi/swag v0.19.15 h1:D2NRCBzS9/pEY3gP9Nl8aDqGUcPFrwG2p+CNFrLyr
github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ= github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs= github.com/go-viper/mapstructure/v2 v2.4.0 h1:EBsztssimR/CONLSZZ04E8qAkxNYq4Qp9LvH92wZUgs=
github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/go-viper/mapstructure/v2 v2.4.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM=
github.com/golang-jwt/jwt/v5 v5.3.0 h1:pv4AsKCKKZuqlgs5sUmn4x8UlGa0kEVt/puTpKx9vvo=
github.com/golang-jwt/jwt/v5 v5.3.0/go.mod h1:fxCRLWMO43lRc8nhHWY6LGqRcf+1gQWArsqaEUEa5bE=
github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek= github.com/golang/protobuf v1.5.4 h1:i7eJL8qZTpSEXOPTxNKhASYpMn+8e5Q6AdndVa1dWek=
github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps= github.com/golang/protobuf v1.5.4/go.mod h1:lnTiLA8Wa4RWRcIUkrtSVa5nRhsEGBg48fD6rSs7xps=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8= github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=

View File

@@ -1,4 +1,69 @@
package implements package implements
import (
"context"
"gitea.konchin.com/go2025/backend/models"
"gitea.konchin.com/go2025/backend/tracing"
"github.com/uptrace/bun"
"go.uber.org/zap"
)
type BunDatabase struct { type BunDatabase struct {
db *bun.DB
}
func NewBunDatabase(db *bun.DB) *BunDatabase {
return &BunDatabase{db: db}
}
func (self *BunDatabase) GetSession(
ctx context.Context,
userId string,
) (models.Session, error) {
ret := models.Session{
UserId: userId,
}
err := self.db.NewSelect().
Model(&ret).
WherePK().
Scan(ctx)
if err != nil {
return models.Session{}, err
}
return ret, nil
}
func (self *BunDatabase) UpdateRefreshToken(
ctx context.Context,
userId string,
) (models.Session, error) {
ret := models.Session{
UserId: userId,
}
err := self.db.NewSelect().
Model(&ret).
WherePK().
Scan(ctx)
if err != nil {
return models.Session{}, err
}
if err := ret.RotateRefreshToken(); err != nil {
tracing.Logger.Ctx(ctx).
Error("failed to rotate refresh token",
zap.Error(err))
return models.Session{}, err
}
err = self.db.NewUpdate().
Model((*models.Session)(nil)).
Set("refresh_token = ?", ret.RefreshToken).
Where("user_id = ?", ret.UserId).
Returning("*").
Scan(ctx, &ret)
if err != nil {
return models.Session{}, err
}
return ret, nil
} }

View File

@@ -1,4 +1,19 @@
package interfaces package interfaces
import (
"context"
"gitea.konchin.com/go2025/backend/models"
)
type Database interface { type Database interface {
GetSession(
ctx context.Context,
userId string,
) (models.Session, error)
UpdateRefreshToken(
ctx context.Context,
userId string,
) (models.Session, error)
} }

View File

@@ -0,0 +1,113 @@
package middlewares
import (
"context"
"net/http"
"gitea.konchin.com/go2025/backend/interfaces"
"gitea.konchin.com/go2025/backend/models"
"gitea.konchin.com/go2025/backend/tracing"
"gitea.konchin.com/go2025/backend/types"
"github.com/golang-jwt/jwt/v5"
"github.com/spf13/viper"
"github.com/uptrace/bunrouter"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"go.uber.org/zap"
)
func refreshAccessToken(
ctx context.Context,
db interfaces.Database,
w http.ResponseWriter,
req bunrouter.Request,
) (string, error) {
refreshTokenClaim, ok := req.Context().
Value(types.RefreshToken("")).(models.RefreshTokenClaim)
if !ok {
tracing.Logger.Ctx(ctx).
Warn("refresh token not exist")
return "", types.ContextNotExistError
}
session, err := db.GetSession(ctx, refreshTokenClaim.UserId)
if err != nil {
tracing.Logger.Ctx(ctx).
Warn("session not exist", zap.Error(err))
return "", err
}
ret, err := session.ToAccessToken()
if err != nil {
tracing.Logger.Ctx(ctx).
Warn("access token generate failed", zap.Error(err))
return "", err
}
http.SetCookie(w, &http.Cookie{
Name: "access_token",
Value: ret,
Path: "/",
Secure: false,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
return ret, nil
}
func (self *Handlers) CheckAccessToken(
next bunrouter.HandlerFunc,
) bunrouter.HandlerFunc {
return func(w http.ResponseWriter, req bunrouter.Request) error {
ctx := req.Context()
var accessTokenString string
accessTokenCookie, err := req.Cookie("access_token")
if err != nil {
accessTokenString, err = refreshAccessToken(ctx, self.db, w, req)
if err != nil {
return HTTPError{
StatusCode: http.StatusUnauthorized,
Message: "access token refresh failed",
OriginError: err,
}
}
} else {
accessTokenString = accessTokenCookie.Value
}
var claim models.AccessTokenClaim
token, err := jwt.ParseWithClaims(accessTokenString, &claim,
func(*jwt.Token) (interface{}, error) {
return []byte(viper.GetString("ACCESS_TOKEN_SECRET")), nil
})
if err != nil || !token.Valid {
accessTokenString, err = refreshAccessToken(ctx, self.db, w, req)
if err != nil {
return HTTPError{
StatusCode: http.StatusUnauthorized,
Message: "access token refresh failed",
OriginError: err,
}
}
token, err := jwt.ParseWithClaims(accessTokenString, &claim,
func(*jwt.Token) (interface{}, error) {
return []byte(viper.GetString("ACCESS_TOKEN_SECRET")), nil
})
if err != nil || !token.Valid {
return HTTPError{
StatusCode: http.StatusUnauthorized,
Message: "access token jwt cannot parse or invalid",
OriginError: err,
}
}
}
span := trace.SpanFromContext(ctx)
span.SetAttributes(
attribute.String("owner.UserId", claim.UserId))
ctx = context.WithValue(ctx, types.AccessToken(""), claim)
return next(w, req.WithContext(ctx))
}
}

View File

@@ -0,0 +1,74 @@
package middlewares
import (
"context"
"net/http"
"time"
"gitea.konchin.com/go2025/backend/models"
"gitea.konchin.com/go2025/backend/types"
"github.com/golang-jwt/jwt/v5"
"github.com/spf13/viper"
"github.com/uptrace/bunrouter"
)
func (self *Handlers) CheckRefreshToken(
next bunrouter.HandlerFunc,
) bunrouter.HandlerFunc {
return func(w http.ResponseWriter, req bunrouter.Request) error {
ctx := req.Context()
refreshTokenCookie, err := req.Cookie("refresh_token")
if err != nil {
return HTTPError{
StatusCode: http.StatusUnauthorized,
Message: "user did not login",
OriginError: err,
}
}
var claim models.RefreshTokenClaim
token, err := jwt.ParseWithClaims(refreshTokenCookie.Value, &claim,
func(*jwt.Token) (interface{}, error) {
return []byte(viper.GetString("REFRESH_TOKEN_SECRET")), nil
})
if err != nil {
return HTTPError{
StatusCode: http.StatusUnauthorized,
Message: "refresh token jwt cannot parse",
OriginError: err,
}
}
if !token.Valid {
return HTTPError{
StatusCode: http.StatusUnauthorized,
Message: "refresh token jwt invalid",
}
}
// check time and refresh
timeLeft := claim.ExpiresAt.Time.Sub(time.Now()) / time.Second
if int64(timeLeft) < viper.GetInt64("REFRESH_TOKEN_TIMEOUT")/2 {
session, err := self.db.UpdateRefreshToken(ctx, claim.UserId)
if err != nil {
return HTTPError{
StatusCode: http.StatusInternalServerError,
Message: "upsert session failed",
OriginError: err,
}
}
http.SetCookie(w, &http.Cookie{
Name: "refresh_token",
Value: session.RefreshToken,
Path: "/",
Secure: false,
HttpOnly: true,
SameSite: http.SameSiteLaxMode,
})
}
ctx = context.WithValue(ctx, types.RefreshToken(""), claim)
return next(w, req.WithContext(ctx))
}
}

11
middlewares/handlers.go Normal file
View File

@@ -0,0 +1,11 @@
package middlewares
import "gitea.konchin.com/go2025/backend/interfaces"
type Handlers struct {
db interfaces.Database
}
func NewHandlers(db interfaces.Database) *Handlers {
return &Handlers{db: db}
}

68
models/session.go Normal file
View File

@@ -0,0 +1,68 @@
package models
import (
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/spf13/viper"
"github.com/uptrace/bun"
)
type Session struct {
bun.BaseModel `bun:"table:session"`
UserId string `bun:"user_id,pk"`
LoginToken string `bun:"login_token,unique"`
RefreshToken string `bun:"refresh_token,unique"`
IsValid bool `bun:"is_valid"`
}
type RefreshTokenClaimFields struct {
UserId string `json:"user_id"`
}
type RefreshTokenClaim struct {
jwt.RegisteredClaims
RefreshTokenClaimFields `json:",inline"`
}
type AccessTokenClaimFields struct {
UserId string `json:"user_id"`
}
type AccessTokenClaim struct {
jwt.RegisteredClaims
AccessTokenClaimFields `json:",inline"`
}
func (self *Session) RotateRefreshToken() error {
refreshToken, err :=
jwt.NewWithClaims(jwt.SigningMethodHS256, RefreshTokenClaim{
RefreshTokenClaimFields: RefreshTokenClaimFields{
UserId: self.UserId,
},
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: &jwt.NumericDate{time.Now()},
ExpiresAt: &jwt.NumericDate{time.Now().Add(time.Duration(
viper.GetInt64("REFRESH_TOKEN_TIMEOUT")) * time.Second)},
}}).SignedString([]byte(viper.GetString("REFRESH_TOKEN_SECRET")))
if err != nil {
self.RefreshToken = refreshToken
}
return err
}
func (self *Session) ToAccessToken() (string, error) {
return jwt.NewWithClaims(jwt.SigningMethodHS256, AccessTokenClaim{
AccessTokenClaimFields: AccessTokenClaimFields{
UserId: self.UserId,
},
RegisteredClaims: jwt.RegisteredClaims{
IssuedAt: &jwt.NumericDate{time.Now()},
ExpiresAt: &jwt.NumericDate{time.Now().Add(time.Duration(
viper.GetInt64("ACCESS_TOKEN_TIMEOUT")) * time.Second)},
}}).SignedString([]byte(viper.GetString("ACCESS_TOKEN_SECRET")))
}

6
types/definition.go Normal file
View File

@@ -0,0 +1,6 @@
package types
type (
AccessToken string
RefreshToken string
)

10
types/errors.go Normal file
View File

@@ -0,0 +1,10 @@
package types
import "fmt"
var (
ContextNotExistError = fmt.Errorf("context not exist")
WrongFormatError = fmt.Errorf("wrong format")
HTTPRequestFailedError = fmt.Errorf("http request failed")
)

View File

@@ -1 +1,16 @@
package utils package utils
import (
"context"
"gitea.konchin.com/go2025/backend/models"
"github.com/uptrace/bun"
)
func initDB(ctx context.Context, db *bun.DB) error {
return db.ResetModel(ctx,
(*models.Alias)(nil),
(*models.Image)(nil),
(*models.Session)(nil),
)
}