Fix: rotate access token and add test

This commit is contained in:
2025-12-07 22:06:32 +08:00
parent 8d3cd0260e
commit 1ce2174bdc
7 changed files with 46 additions and 12 deletions

View File

@@ -44,7 +44,7 @@ func (self *Handlers) PostLogin(
} }
} }
session, err := self.db.GetSession(ctx, input.Token) session, err := self.db.GetSessionByLoginToken(ctx, input.Token)
if err != nil { if err != nil {
return middlewares.HTTPError{ return middlewares.HTTPError{
StatusCode: http.StatusUnauthorized, StatusCode: http.StatusUnauthorized,

View File

@@ -18,7 +18,7 @@ func NewBunDatabase(db *bun.DB) *BunDatabase {
return &BunDatabase{db: db} return &BunDatabase{db: db}
} }
func (self *BunDatabase) GetSession( func (self *BunDatabase) GetSessionByLoginToken(
ctx context.Context, ctx context.Context,
loginToken string, loginToken string,
) (models.Session, error) { ) (models.Session, error) {
@@ -35,6 +35,23 @@ func (self *BunDatabase) GetSession(
return ret, nil return ret, nil
} }
func (self *BunDatabase) GetSessionByUserId(
ctx context.Context,
userId string,
) (models.Session, error) {
ret := models.Session{
UserId: userId,
}
err := self.db.NewSelect().
Model(&ret).
Where("user_id = ?", userId).
Scan(ctx)
if err != nil {
return models.Session{}, err
}
return ret, nil
}
func (self *BunDatabase) UpdateRefreshToken( func (self *BunDatabase) UpdateRefreshToken(
ctx context.Context, ctx context.Context,
userId string, userId string,
@@ -88,6 +105,7 @@ func (self *BunDatabase) UpsertLoginToken(
session := models.Session{ session := models.Session{
UserId: userId, UserId: userId,
LoginToken: token, LoginToken: token,
IsValid: true,
} }
_, err = self.db.NewInsert(). _, err = self.db.NewInsert().
Model(&session). Model(&session).

View File

@@ -7,11 +7,16 @@ import (
) )
type Database interface { type Database interface {
GetSession( GetSessionByLoginToken(
ctx context.Context, ctx context.Context,
loginToken string, loginToken string,
) (models.Session, error) ) (models.Session, error)
GetSessionByUserId(
ctx context.Context,
userId string,
) (models.Session, error)
UpdateRefreshToken( UpdateRefreshToken(
ctx context.Context, ctx context.Context,
userId string, userId string,

View File

@@ -30,7 +30,8 @@ func refreshAccessToken(
return "", types.ContextNotExistError return "", types.ContextNotExistError
} }
session, err := db.GetSession(ctx, refreshTokenClaim.UserId) session, err := db.GetSessionByUserId(ctx,
refreshTokenClaim.UserId)
if err != nil { if err != nil {
tracing.Logger.Ctx(ctx). tracing.Logger.Ctx(ctx).
Warn("session not exist", zap.Error(err)) Warn("session not exist", zap.Error(err))
@@ -68,7 +69,7 @@ func (self *Handlers) CheckAccessToken(
if err != nil { if err != nil {
return HTTPError{ return HTTPError{
StatusCode: http.StatusUnauthorized, StatusCode: http.StatusUnauthorized,
Message: "access token refresh failed", Message: "failed to refresh access token",
OriginError: err, OriginError: err,
} }
} }
@@ -86,7 +87,7 @@ func (self *Handlers) CheckAccessToken(
if err != nil { if err != nil {
return HTTPError{ return HTTPError{
StatusCode: http.StatusUnauthorized, StatusCode: http.StatusUnauthorized,
Message: "access token refresh failed", Message: "failed to refresh access token",
OriginError: err, OriginError: err,
} }
} }

View File

@@ -67,6 +67,7 @@ func Test_01_Login(t *testing.T) {
if len(cookie.Value) == 0 { if len(cookie.Value) == 0 {
t.Fatal("empty refresh token") t.Fatal("empty refresh token")
} }
client.SetCookie(cookie)
return return
} }
} }

View File

@@ -0,0 +1,15 @@
package main
import (
"net/http"
"testing"
)
func Test_02_GetImages(t *testing.T) {
resp, err := client.R().
Get("http://localhost:8080/api/aliases")
if err != nil || resp.StatusCode() != http.StatusOK {
t.Logf("%+v", resp)
t.Fatal("failed to fetch aliases")
}
}

View File

@@ -1,6 +0,0 @@
package main
import "testing"
func Test_02_GetImages(t *testing.T) {
}