269 lines
4.8 KiB
Go
269 lines
4.8 KiB
Go
package auth
|
|
|
|
import (
|
|
"addrss/pkg/config"
|
|
"addrss/pkg/repo"
|
|
"crypto/hmac"
|
|
"crypto/sha256"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// header is the JOSE header that encodeAndSignJwt serializes (JSON, then
// unpadded base64url) into the first segment of every token. newHeader always
// fills it with Alg "HS256" and Typ "JWT".
//
// Field order is significant: it fixes the marshaled JSON bytes, which are
// part of the HMAC-signed input.
//
// NOTE(review): ValidateJwtToken never decodes or checks this header.
type header struct {
	Alg string `json:"alg"` // signing algorithm identifier
	Typ string `json:"typ"` // token type
}
|
|
|
|
// baseClaims holds the registered claims shared by every token type in this
// package. newBaseClaims populates it from configuration, and
// ValidateJwtToken checks Aud, Iss and Exp against configuration and the
// current time.
type baseClaims struct {
	Aud string `json:"aud"` // audience; compared with the "jwtAudience" config value
	Iss string `json:"iss"` // issuer; compared with the "jwtIssuer" config value
	Exp int64  `json:"exp"` // expiry as Unix seconds
	// Sub is the user id the token was issued for (0 for getGuestToken).
	// NOTE(review): RFC 7519 defines "sub" as a StringOrURI; emitting a JSON
	// number may be rejected by strict third-party JWT libraries.
	Sub int64 `json:"sub"`
}
|
|
|
|
// Payload is the type-parameter constraint used by ValidateJwtToken. It admits
// only the three claim pointer types defined in this file, and requires each to
// expose its embedded baseClaims so aud/iss/exp can be checked generically.
// Because it contains a type union, it can only be used as a constraint.
type Payload interface {
	getBaseClaims() baseClaims
	*AccessClaims | *IdClaims | *RefreshClaims
}
|
|
|
|
// AccessClaims is the payload of an access token: the shared base claims plus
// a scope string. Field order matters because this file builds it with
// positional composite literals.
type AccessClaims struct {
	baseClaims
	// Scopes is serialized as the "scope" claim. Every constructor in this
	// file currently leaves it empty.
	Scopes string `json:"scope"`
}
|
|
|
|
func (ac *AccessClaims) getBaseClaims() baseClaims {
|
|
return ac.baseClaims
|
|
}
|
|
|
|
// IdClaims is the payload of the ID token built by getIdToken: the shared base
// claims plus the user's profile fields. Field order matters because this file
// builds it with a positional composite literal.
//
// NOTE(review): tokens are only base64url-encoded and HMAC-signed, not
// encrypted, so these personal fields are readable by anyone holding the token.
// NOTE(review): JSON tags mix camelCase ("firstName", "lastName") and
// snake_case ("phone_number"); changing them would break existing consumers.
type IdClaims struct {
	baseClaims
	FirstName   string `json:"firstName"`
	LastName    string `json:"lastName"`
	Email       string `json:"email"` // filled from repo.User.EmailAddress
	PhoneNumber string `json:"phone_number"`
}
|
|
|
|
func (ic *IdClaims) getBaseClaims() baseClaims {
|
|
return ic.baseClaims
|
|
}
|
|
|
|
// RefreshClaims is the payload of a refresh token: the shared base claims plus
// a token id. Field order matters because this file builds it with a
// positional composite literal.
type RefreshClaims struct {
	baseClaims
	// Jti is the TokenId of the repo.UserSession created by getRefreshToken,
	// presumably so a refresh request can be matched to its stored session.
	Jti string `json:"jti"`
}
|
|
|
|
func (rc *RefreshClaims) getBaseClaims() baseClaims {
|
|
return rc.baseClaims
|
|
}
|
|
|
|
// Tokens is the set of signed tokens returned by AcquireTokens. Empty fields
// are omitted when the value is marshaled to JSON.
type Tokens struct {
	AccessToken  string `json:"accessToken,omitempty"`
	IdToken      string `json:"idToken,omitempty"`
	RefreshToken string `json:"refreshToken,omitempty"`
}
|
|
|
|
func AcquireTokens(user repo.User) (Tokens, error) {
|
|
at, err := getAccessToken(user)
|
|
if err != nil {
|
|
return Tokens{}, err
|
|
}
|
|
|
|
it, err := getIdToken(user)
|
|
if err != nil {
|
|
return Tokens{}, err
|
|
}
|
|
|
|
rt, err := getRefreshToken(user)
|
|
if err != nil {
|
|
return Tokens{}, err
|
|
}
|
|
|
|
t := Tokens{
|
|
AccessToken: at,
|
|
IdToken: it,
|
|
RefreshToken: rt,
|
|
}
|
|
|
|
return t, nil
|
|
}
|
|
|
|
func ValidateJwtToken[T Payload](token string, claims T) error {
|
|
segments := strings.Split(token, ".")
|
|
if len(segments) != 3 {
|
|
return fmt.Errorf("invalid segment count")
|
|
}
|
|
|
|
k, err := config.GetString("jwtKey")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
mac := hmac.New(sha256.New, []byte(k))
|
|
mac.Write([]byte(segments[0] + "." + segments[1]))
|
|
if strings.ReplaceAll(base64.URLEncoding.EncodeToString(mac.Sum(nil)), "=", "") != segments[2] {
|
|
return fmt.Errorf("invalid sigature")
|
|
}
|
|
|
|
for len(segments[1])%4 != 0 {
|
|
segments[1] += "="
|
|
}
|
|
|
|
bytes, _ := base64.URLEncoding.DecodeString(segments[1])
|
|
if err := json.Unmarshal(bytes, &claims); err != nil {
|
|
return fmt.Errorf("failed to unmarshal token payload")
|
|
}
|
|
|
|
c := claims.getBaseClaims()
|
|
aud, err := config.GetString("jwtAudience")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
iss, err := config.GetString("jwtIssuer")
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
if c.Aud == aud && c.Iss == iss && c.Exp >= time.Now().Unix() {
|
|
return nil
|
|
}
|
|
|
|
return fmt.Errorf("invalid payload parameters")
|
|
}
|
|
|
|
func getGuestToken() (string, error) {
|
|
bc, err := newBaseClaims(0)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
ap := AccessClaims{
|
|
bc,
|
|
"",
|
|
}
|
|
return encodeAndSignJwt(ap)
|
|
}
|
|
|
|
func getAccessToken(user repo.User) (string, error) {
|
|
bc, err := newBaseClaims(user.Id)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
ac := AccessClaims{
|
|
bc,
|
|
"",
|
|
}
|
|
|
|
return encodeAndSignJwt(ac)
|
|
}
|
|
|
|
func getIdToken(user repo.User) (string, error) {
|
|
bc, err := newBaseClaims(user.Id)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
ic := IdClaims{
|
|
bc,
|
|
user.FirstName,
|
|
user.LastName,
|
|
user.EmailAddress,
|
|
user.PhoneNumber,
|
|
}
|
|
|
|
return encodeAndSignJwt(ic)
|
|
}
|
|
|
|
func getRefreshToken(user repo.User) (string, error) {
|
|
session := repo.UserSession{
|
|
UserId: user.Id,
|
|
TokenId: uuid.New().String(),
|
|
Expiration: time.Now().AddDate(0, 1, 0),
|
|
}
|
|
|
|
if err := repo.AddUserSession(session); err != nil {
|
|
return "", err
|
|
}
|
|
|
|
bc, err := newBaseClaims(user.Id)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
rc := RefreshClaims{
|
|
bc,
|
|
session.TokenId,
|
|
}
|
|
rc.Exp = session.Expiration.Unix()
|
|
|
|
return encodeAndSignJwt(rc)
|
|
}
|
|
|
|
func encodeAndSignJwt(payload any) (string, error) {
|
|
h, err := json.Marshal(newHeader())
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
p, err := json.Marshal(payload)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
token := base64.URLEncoding.EncodeToString(h) + "." + base64.URLEncoding.EncodeToString(p)
|
|
token = strings.ReplaceAll(token, "=", "")
|
|
|
|
key, err := config.GetString("jwtKey")
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
mac := hmac.New(sha256.New, []byte(key))
|
|
_, err = mac.Write([]byte(token))
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
signed := strings.ReplaceAll(token+"."+base64.URLEncoding.EncodeToString(mac.Sum(nil)), "=", "")
|
|
return signed, nil
|
|
}
|
|
|
|
func newHeader() header {
|
|
return header{
|
|
Alg: "HS256",
|
|
Typ: "JWT",
|
|
}
|
|
}
|
|
|
|
func newBaseClaims(userId int64) (baseClaims, error) {
|
|
aud, err := config.GetString("jwtAudience")
|
|
if err != nil {
|
|
return baseClaims{}, err
|
|
}
|
|
|
|
iss, err := config.GetString("jwtIssuer")
|
|
if err != nil {
|
|
return baseClaims{}, err
|
|
}
|
|
|
|
lt, err := config.GetInt64("jwtLifetime")
|
|
if err != nil {
|
|
return baseClaims{}, err
|
|
}
|
|
|
|
bc := baseClaims{
|
|
Aud: aud,
|
|
Iss: iss,
|
|
Exp: time.Now().Add(time.Duration(lt * int64(time.Second))).Unix(),
|
|
Sub: userId,
|
|
}
|
|
|
|
return bc, nil
|
|
}
|