Fix: rotate access token and add test
This commit is contained in:
@@ -44,7 +44,7 @@ func (self *Handlers) PostLogin(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := self.db.GetSession(ctx, input.Token)
|
session, err := self.db.GetSessionByLoginToken(ctx, input.Token)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return middlewares.HTTPError{
|
return middlewares.HTTPError{
|
||||||
StatusCode: http.StatusUnauthorized,
|
StatusCode: http.StatusUnauthorized,
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func NewBunDatabase(db *bun.DB) *BunDatabase {
|
|||||||
return &BunDatabase{db: db}
|
return &BunDatabase{db: db}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (self *BunDatabase) GetSession(
|
func (self *BunDatabase) GetSessionByLoginToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
loginToken string,
|
loginToken string,
|
||||||
) (models.Session, error) {
|
) (models.Session, error) {
|
||||||
@@ -35,6 +35,23 @@ func (self *BunDatabase) GetSession(
|
|||||||
return ret, nil
|
return ret, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (self *BunDatabase) GetSessionByUserId(
|
||||||
|
ctx context.Context,
|
||||||
|
userId string,
|
||||||
|
) (models.Session, error) {
|
||||||
|
ret := models.Session{
|
||||||
|
UserId: userId,
|
||||||
|
}
|
||||||
|
err := self.db.NewSelect().
|
||||||
|
Model(&ret).
|
||||||
|
Where("user_id = ?", userId).
|
||||||
|
Scan(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return models.Session{}, err
|
||||||
|
}
|
||||||
|
return ret, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (self *BunDatabase) UpdateRefreshToken(
|
func (self *BunDatabase) UpdateRefreshToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userId string,
|
userId string,
|
||||||
@@ -88,6 +105,7 @@ func (self *BunDatabase) UpsertLoginToken(
|
|||||||
session := models.Session{
|
session := models.Session{
|
||||||
UserId: userId,
|
UserId: userId,
|
||||||
LoginToken: token,
|
LoginToken: token,
|
||||||
|
IsValid: true,
|
||||||
}
|
}
|
||||||
_, err = self.db.NewInsert().
|
_, err = self.db.NewInsert().
|
||||||
Model(&session).
|
Model(&session).
|
||||||
|
|||||||
@@ -7,11 +7,16 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Database interface {
|
type Database interface {
|
||||||
GetSession(
|
GetSessionByLoginToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
loginToken string,
|
loginToken string,
|
||||||
) (models.Session, error)
|
) (models.Session, error)
|
||||||
|
|
||||||
|
GetSessionByUserId(
|
||||||
|
ctx context.Context,
|
||||||
|
userId string,
|
||||||
|
) (models.Session, error)
|
||||||
|
|
||||||
UpdateRefreshToken(
|
UpdateRefreshToken(
|
||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
userId string,
|
userId string,
|
||||||
|
|||||||
@@ -30,7 +30,8 @@ func refreshAccessToken(
|
|||||||
return "", types.ContextNotExistError
|
return "", types.ContextNotExistError
|
||||||
}
|
}
|
||||||
|
|
||||||
session, err := db.GetSession(ctx, refreshTokenClaim.UserId)
|
session, err := db.GetSessionByUserId(ctx,
|
||||||
|
refreshTokenClaim.UserId)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
tracing.Logger.Ctx(ctx).
|
tracing.Logger.Ctx(ctx).
|
||||||
Warn("session not exist", zap.Error(err))
|
Warn("session not exist", zap.Error(err))
|
||||||
@@ -68,7 +69,7 @@ func (self *Handlers) CheckAccessToken(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return HTTPError{
|
return HTTPError{
|
||||||
StatusCode: http.StatusUnauthorized,
|
StatusCode: http.StatusUnauthorized,
|
||||||
Message: "access token refresh failed",
|
Message: "failed to refresh access token",
|
||||||
OriginError: err,
|
OriginError: err,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -86,7 +87,7 @@ func (self *Handlers) CheckAccessToken(
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return HTTPError{
|
return HTTPError{
|
||||||
StatusCode: http.StatusUnauthorized,
|
StatusCode: http.StatusUnauthorized,
|
||||||
Message: "access token refresh failed",
|
Message: "failed to refresh access token",
|
||||||
OriginError: err,
|
OriginError: err,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -67,6 +67,7 @@ func Test_01_Login(t *testing.T) {
|
|||||||
if len(cookie.Value) == 0 {
|
if len(cookie.Value) == 0 {
|
||||||
t.Fatal("empty refresh token")
|
t.Fatal("empty refresh token")
|
||||||
}
|
}
|
||||||
|
client.SetCookie(cookie)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
15
tests/02_getAliases_test.go
Normal file
15
tests/02_getAliases_test.go
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
package main
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Test_02_GetImages(t *testing.T) {
|
||||||
|
resp, err := client.R().
|
||||||
|
Get("http://localhost:8080/api/aliases")
|
||||||
|
if err != nil || resp.StatusCode() != http.StatusOK {
|
||||||
|
t.Logf("%+v", resp)
|
||||||
|
t.Fatal("failed to fetch aliases")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +0,0 @@
|
|||||||
package main
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func Test_02_GetImages(t *testing.T) {
|
|
||||||
}
|
|
||||||
Reference in New Issue
Block a user