Initial commit
This commit is contained in:
268
pkg/auth/tokens.go
Normal file
268
pkg/auth/tokens.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"addrss/pkg/config"
|
||||
"addrss/pkg/repo"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type header struct {
|
||||
Alg string `json:"alg"`
|
||||
Typ string `json:"typ"`
|
||||
}
|
||||
|
||||
type baseClaims struct {
|
||||
Aud string `json:"aud"`
|
||||
Iss string `json:"iss"`
|
||||
Exp int64 `json:"exp"`
|
||||
Sub int64 `json:"sub"`
|
||||
}
|
||||
|
||||
type Payload interface {
|
||||
getBaseClaims() baseClaims
|
||||
*AccessClaims | *IdClaims | *RefreshClaims
|
||||
}
|
||||
|
||||
type AccessClaims struct {
|
||||
baseClaims
|
||||
Scopes string `json:"scope"`
|
||||
}
|
||||
|
||||
func (ac *AccessClaims) getBaseClaims() baseClaims {
|
||||
return ac.baseClaims
|
||||
}
|
||||
|
||||
type IdClaims struct {
|
||||
baseClaims
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
Email string `json:"email"`
|
||||
PhoneNumber string `json:"phone_number"`
|
||||
}
|
||||
|
||||
func (ic *IdClaims) getBaseClaims() baseClaims {
|
||||
return ic.baseClaims
|
||||
}
|
||||
|
||||
type RefreshClaims struct {
|
||||
baseClaims
|
||||
Jti string `json:"jti"`
|
||||
}
|
||||
|
||||
func (rc *RefreshClaims) getBaseClaims() baseClaims {
|
||||
return rc.baseClaims
|
||||
}
|
||||
|
||||
type Tokens struct {
|
||||
AccessToken string `json:"accessToken,omitempty"`
|
||||
IdToken string `json:"idToken,omitempty"`
|
||||
RefreshToken string `json:"refreshToken,omitempty"`
|
||||
}
|
||||
|
||||
func AcquireTokens(user repo.User) (Tokens, error) {
|
||||
at, err := getAccessToken(user)
|
||||
if err != nil {
|
||||
return Tokens{}, err
|
||||
}
|
||||
|
||||
it, err := getIdToken(user)
|
||||
if err != nil {
|
||||
return Tokens{}, err
|
||||
}
|
||||
|
||||
rt, err := getRefreshToken(user)
|
||||
if err != nil {
|
||||
return Tokens{}, err
|
||||
}
|
||||
|
||||
t := Tokens{
|
||||
AccessToken: at,
|
||||
IdToken: it,
|
||||
RefreshToken: rt,
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func ValidateJwtToken[T Payload](token string, claims T) error {
|
||||
segments := strings.Split(token, ".")
|
||||
if len(segments) != 3 {
|
||||
return fmt.Errorf("invalid segment count")
|
||||
}
|
||||
|
||||
k, err := config.GetString("jwtKey")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(k))
|
||||
mac.Write([]byte(segments[0] + "." + segments[1]))
|
||||
if strings.ReplaceAll(base64.URLEncoding.EncodeToString(mac.Sum(nil)), "=", "") != segments[2] {
|
||||
return fmt.Errorf("invalid sigature")
|
||||
}
|
||||
|
||||
for len(segments[1])%4 != 0 {
|
||||
segments[1] += "="
|
||||
}
|
||||
|
||||
bytes, _ := base64.URLEncoding.DecodeString(segments[1])
|
||||
if err := json.Unmarshal(bytes, &claims); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal token payload")
|
||||
}
|
||||
|
||||
c := claims.getBaseClaims()
|
||||
aud, err := config.GetString("jwtAudience")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
iss, err := config.GetString("jwtIssuer")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Aud == aud && c.Iss == iss && c.Exp >= time.Now().Unix() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid payload parameters")
|
||||
}
|
||||
|
||||
func getGuestToken() (string, error) {
|
||||
bc, err := newBaseClaims(0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ap := AccessClaims{
|
||||
bc,
|
||||
"",
|
||||
}
|
||||
return encodeAndSignJwt(ap)
|
||||
}
|
||||
|
||||
func getAccessToken(user repo.User) (string, error) {
|
||||
bc, err := newBaseClaims(user.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ac := AccessClaims{
|
||||
bc,
|
||||
"",
|
||||
}
|
||||
|
||||
return encodeAndSignJwt(ac)
|
||||
}
|
||||
|
||||
func getIdToken(user repo.User) (string, error) {
|
||||
bc, err := newBaseClaims(user.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ic := IdClaims{
|
||||
bc,
|
||||
user.FirstName,
|
||||
user.LastName,
|
||||
user.EmailAddress,
|
||||
user.PhoneNumber,
|
||||
}
|
||||
|
||||
return encodeAndSignJwt(ic)
|
||||
}
|
||||
|
||||
func getRefreshToken(user repo.User) (string, error) {
|
||||
session := repo.UserSession{
|
||||
UserId: user.Id,
|
||||
TokenId: uuid.New().String(),
|
||||
Expiration: time.Now().AddDate(0, 1, 0),
|
||||
}
|
||||
|
||||
if err := repo.AddUserSession(session); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
bc, err := newBaseClaims(user.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
rc := RefreshClaims{
|
||||
bc,
|
||||
session.TokenId,
|
||||
}
|
||||
rc.Exp = session.Expiration.Unix()
|
||||
|
||||
return encodeAndSignJwt(rc)
|
||||
}
|
||||
|
||||
func encodeAndSignJwt(payload any) (string, error) {
|
||||
h, err := json.Marshal(newHeader())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
p, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
token := base64.URLEncoding.EncodeToString(h) + "." + base64.URLEncoding.EncodeToString(p)
|
||||
token = strings.ReplaceAll(token, "=", "")
|
||||
|
||||
key, err := config.GetString("jwtKey")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
_, err = mac.Write([]byte(token))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signed := strings.ReplaceAll(token+"."+base64.URLEncoding.EncodeToString(mac.Sum(nil)), "=", "")
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
func newHeader() header {
|
||||
return header{
|
||||
Alg: "HS256",
|
||||
Typ: "JWT",
|
||||
}
|
||||
}
|
||||
|
||||
func newBaseClaims(userId int64) (baseClaims, error) {
|
||||
aud, err := config.GetString("jwtAudience")
|
||||
if err != nil {
|
||||
return baseClaims{}, err
|
||||
}
|
||||
|
||||
iss, err := config.GetString("jwtIssuer")
|
||||
if err != nil {
|
||||
return baseClaims{}, err
|
||||
}
|
||||
|
||||
lt, err := config.GetInt64("jwtLifetime")
|
||||
if err != nil {
|
||||
return baseClaims{}, err
|
||||
}
|
||||
|
||||
bc := baseClaims{
|
||||
Aud: aud,
|
||||
Iss: iss,
|
||||
Exp: time.Now().Add(time.Duration(lt * int64(time.Second))).Unix(),
|
||||
Sub: userId,
|
||||
}
|
||||
|
||||
return bc, nil
|
||||
}
|
||||
Reference in New Issue
Block a user