Files
addrss.io/pkg/auth/tokens.go
2025-09-06 21:35:45 -04:00

269 lines
4.8 KiB
Go

package auth
import (
"addrss/pkg/config"
"addrss/pkg/repo"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
)
// header is the JOSE header encoded as the first segment of every token
// produced by encodeAndSignJwt. Its JSON encoding (and therefore field
// order) is part of the signed bytes.
type header struct {
	Alg string `json:"alg"` // signing algorithm; newHeader sets "HS256"
	Typ string `json:"typ"` // token type; newHeader sets "JWT"
}
// baseClaims holds the registered JWT claims shared by every token type in
// this package. It is embedded (by value) in AccessClaims, IdClaims and
// RefreshClaims, so its fields appear at the top level of their JSON.
// NOTE(review): RFC 7519 defines "sub" as a StringOrURI; here it is encoded
// as a JSON number.
type baseClaims struct {
	Aud string `json:"aud"` // audience; newBaseClaims reads config "jwtAudience"
	Iss string `json:"iss"` // issuer; newBaseClaims reads config "jwtIssuer"
	Exp int64  `json:"exp"` // expiration as Unix seconds
	Sub int64  `json:"sub"` // subject: the user ID (0 for guest tokens)
}
type Payload interface {
getBaseClaims() baseClaims
*AccessClaims | *IdClaims | *RefreshClaims
}
// AccessClaims is the payload of an access token: the registered claims
// plus a "scope" claim. The token builders visible in this file
// (getAccessToken, getGuestToken) always set Scopes to "".
type AccessClaims struct {
	baseClaims
	Scopes string `json:"scope"`
}

// getBaseClaims returns a copy of the embedded registered claims so that
// *AccessClaims satisfies Payload.
func (ac *AccessClaims) getBaseClaims() baseClaims {
	return ac.baseClaims
}
// IdClaims is the payload of an ID token: the registered claims plus the
// profile fields that getIdToken copies from repo.User.
// NOTE(review): the JSON keys mix camelCase (firstName, lastName) with
// snake_case (phone_number); OIDC's standard names would be given_name and
// family_name. Renaming them would likely break existing token consumers.
type IdClaims struct {
	baseClaims
	FirstName   string `json:"firstName"`
	LastName    string `json:"lastName"`
	Email       string `json:"email"`
	PhoneNumber string `json:"phone_number"`
}

// getBaseClaims returns a copy of the embedded registered claims so that
// *IdClaims satisfies Payload.
func (ic *IdClaims) getBaseClaims() baseClaims {
	return ic.baseClaims
}
// RefreshClaims is the payload of a refresh token. getRefreshToken sets Jti
// to the TokenId of the repo.UserSession it creates, and sets the embedded
// Exp to that session's expiration.
type RefreshClaims struct {
	baseClaims
	Jti string `json:"jti"`
}

// getBaseClaims returns a copy of the embedded registered claims so that
// *RefreshClaims satisfies Payload.
func (rc *RefreshClaims) getBaseClaims() baseClaims {
	return rc.baseClaims
}
// Tokens bundles the encoded JWTs returned by AcquireTokens. Fields left
// empty are omitted from the JSON encoding (omitempty).
type Tokens struct {
	AccessToken  string `json:"accessToken,omitempty"`
	IdToken      string `json:"idToken,omitempty"`
	RefreshToken string `json:"refreshToken,omitempty"`
}
// AcquireTokens mints an access token, an ID token and a refresh token for
// user, in that order. The refresh token comes last, so a failure in either
// of the first two steps never reaches getRefreshToken (which calls
// repo.AddUserSession). On any error the zero Tokens value is returned
// together with the error.
func AcquireTokens(user repo.User) (Tokens, error) {
	var (
		out Tokens
		err error
	)
	if out.AccessToken, err = getAccessToken(user); err != nil {
		return Tokens{}, err
	}
	if out.IdToken, err = getIdToken(user); err != nil {
		return Tokens{}, err
	}
	if out.RefreshToken, err = getRefreshToken(user); err != nil {
		return Tokens{}, err
	}
	return out, nil
}
// ValidateJwtToken verifies a compact HS256 JWT of the form produced by
// encodeAndSignJwt and decodes its payload into claims.
//
// Checks, in order:
//   - the token has exactly three '.'-separated segments;
//   - the third segment equals the unpadded base64url HMAC-SHA256 of
//     "<segment0>.<segment1>" keyed with config "jwtKey", compared in
//     constant time;
//   - the payload segment is valid unpadded base64url and valid JSON for
//     claims;
//   - aud and iss equal config "jwtAudience" and "jwtIssuer", and the
//     current Unix time is strictly before exp (RFC 7519 §4.1.4).
//
// The header segment is not parsed; it is covered only by the signature.
// claims should be a non-nil pointer (e.g. &AccessClaims{}) so the caller
// can read the decoded values afterwards.
//
// It returns nil when every check passes, otherwise a non-nil error
// (including any error from the config lookups).
func ValidateJwtToken[T Payload](token string, claims T) error {
	segments := strings.Split(token, ".")
	if len(segments) != 3 {
		return fmt.Errorf("invalid segment count")
	}

	k, err := config.GetString("jwtKey")
	if err != nil {
		return err
	}
	mac := hmac.New(sha256.New, []byte(k))
	mac.Write([]byte(segments[0] + "." + segments[1])) // hash.Hash.Write never returns an error
	expected := base64.RawURLEncoding.EncodeToString(mac.Sum(nil))
	// hmac.Equal is constant time, so response timing does not reveal how
	// many leading bytes of a forged signature were correct.
	if !hmac.Equal([]byte(expected), []byte(segments[2])) {
		return fmt.Errorf("invalid signature")
	}

	// encodeAndSignJwt strips all '=' padding, so decode with the unpadded
	// alphabet and reject anything that is not valid base64url.
	payload, err := base64.RawURLEncoding.DecodeString(segments[1])
	if err != nil {
		return fmt.Errorf("invalid payload encoding")
	}
	if err := json.Unmarshal(payload, &claims); err != nil {
		return fmt.Errorf("failed to unmarshal token payload")
	}
	c := claims.getBaseClaims()

	aud, err := config.GetString("jwtAudience")
	if err != nil {
		return err
	}
	iss, err := config.GetString("jwtIssuer")
	if err != nil {
		return err
	}
	// RFC 7519: the current time MUST be before exp, so now == exp is expired.
	if c.Aud == aud && c.Iss == iss && time.Now().Unix() < c.Exp {
		return nil
	}
	return fmt.Errorf("invalid payload parameters")
}
// getGuestToken issues an access token for an anonymous caller: the
// subject is 0 and the scope is empty.
func getGuestToken() (string, error) {
	claims, err := newBaseClaims(0)
	if err != nil {
		return "", err
	}
	return encodeAndSignJwt(AccessClaims{baseClaims: claims})
}
// getAccessToken issues an access token whose subject is user.Id and whose
// scope is empty.
func getAccessToken(user repo.User) (string, error) {
	claims, err := newBaseClaims(user.Id)
	if err != nil {
		return "", err
	}
	return encodeAndSignJwt(AccessClaims{baseClaims: claims})
}
// getIdToken issues an ID token for user. It carries the registered claims
// plus the user's first/last name, email address and phone number.
func getIdToken(user repo.User) (string, error) {
	base, err := newBaseClaims(user.Id)
	if err != nil {
		return "", err
	}
	return encodeAndSignJwt(IdClaims{
		baseClaims:  base,
		FirstName:   user.FirstName,
		LastName:    user.LastName,
		Email:       user.EmailAddress,
		PhoneNumber: user.PhoneNumber,
	})
}
// getRefreshToken creates a refresh session for user and returns a signed
// refresh token that references it.
//
// The session gets a random UUID TokenId and expires one month from now.
// The token's jti is that TokenId, and its exp is the session expiration
// rather than the configured "jwtLifetime".
//
// The session is stored with repo.AddUserSession only after the token has
// been built and signed. Storing it first would leave a persisted session
// for a token that was never handed out whenever a later step failed
// (e.g. a missing config key). If storing fails, no token is returned.
func getRefreshToken(user repo.User) (string, error) {
	bc, err := newBaseClaims(user.Id)
	if err != nil {
		return "", err
	}
	session := repo.UserSession{
		UserId:     user.Id,
		TokenId:    uuid.New().String(),
		Expiration: time.Now().AddDate(0, 1, 0),
	}
	rc := RefreshClaims{
		baseClaims: bc,
		Jti:        session.TokenId,
	}
	// The refresh token lives exactly as long as its session.
	rc.Exp = session.Expiration.Unix()
	token, err := encodeAndSignJwt(rc)
	if err != nil {
		return "", err
	}
	if err := repo.AddUserSession(session); err != nil {
		return "", err
	}
	return token, nil
}
// encodeAndSignJwt serializes payload into a compact HS256 JWT:
//
//	base64url(header) "." base64url(payload) "." base64url(HMAC-SHA256)
//
// Every segment uses unpadded base64url. The MAC is keyed with config
// "jwtKey" and computed over the first two segments joined by '.'. payload
// is encoded with encoding/json, so its exported fields and JSON tags
// determine the claim names.
//
// It returns an error if either JSON encoding fails or the key cannot be
// read from config.
func encodeAndSignJwt(payload any) (string, error) {
	h, err := json.Marshal(newHeader())
	if err != nil {
		return "", err
	}
	p, err := json.Marshal(payload)
	if err != nil {
		return "", err
	}
	key, err := config.GetString("jwtKey")
	if err != nil {
		return "", err
	}
	// RawURLEncoding omits '=' padding, matching the JWS compact form
	// without post-processing the encoded strings.
	signingInput := base64.RawURLEncoding.EncodeToString(h) + "." + base64.RawURLEncoding.EncodeToString(p)
	mac := hmac.New(sha256.New, []byte(key))
	mac.Write([]byte(signingInput)) // hash.Hash.Write never returns an error
	return signingInput + "." + base64.RawURLEncoding.EncodeToString(mac.Sum(nil)), nil
}
// newHeader returns the fixed JOSE header used for every token this
// package signs: HMAC-SHA256, type JWT.
func newHeader() header {
	var h header
	h.Alg, h.Typ = "HS256", "JWT"
	return h
}
// newBaseClaims builds the registered claims for userId from config. The
// audience comes from "jwtAudience" and the issuer from "jwtIssuer". The
// expiry is "jwtLifetime" seconds from now, stored as Unix seconds.
// The first config lookup error is returned together with zero claims.
func newBaseClaims(userId int64) (baseClaims, error) {
	var (
		claims baseClaims
		err    error
	)
	if claims.Aud, err = config.GetString("jwtAudience"); err != nil {
		return baseClaims{}, err
	}
	if claims.Iss, err = config.GetString("jwtIssuer"); err != nil {
		return baseClaims{}, err
	}
	lifetime, err := config.GetInt64("jwtLifetime")
	if err != nil {
		return baseClaims{}, err
	}
	claims.Exp = time.Now().Add(time.Duration(lifetime) * time.Second).Unix()
	claims.Sub = userId
	return claims, nil
}