package auth import ( "addrss/pkg/config" "addrss/pkg/repo" "crypto/hmac" "crypto/sha256" "encoding/base64" "encoding/json" "fmt" "strings" "time" "github.com/google/uuid" ) type header struct { Alg string `json:"alg"` Typ string `json:"typ"` } type baseClaims struct { Aud string `json:"aud"` Iss string `json:"iss"` Exp int64 `json:"exp"` Sub int64 `json:"sub"` } type Payload interface { getBaseClaims() baseClaims *AccessClaims | *IdClaims | *RefreshClaims } type AccessClaims struct { baseClaims Scopes string `json:"scope"` } func (ac *AccessClaims) getBaseClaims() baseClaims { return ac.baseClaims } type IdClaims struct { baseClaims FirstName string `json:"firstName"` LastName string `json:"lastName"` Email string `json:"email"` PhoneNumber string `json:"phone_number"` } func (ic *IdClaims) getBaseClaims() baseClaims { return ic.baseClaims } type RefreshClaims struct { baseClaims Jti string `json:"jti"` } func (rc *RefreshClaims) getBaseClaims() baseClaims { return rc.baseClaims } type Tokens struct { AccessToken string `json:"accessToken,omitempty"` IdToken string `json:"idToken,omitempty"` RefreshToken string `json:"refreshToken,omitempty"` } func AcquireTokens(user repo.User) (Tokens, error) { at, err := getAccessToken(user) if err != nil { return Tokens{}, err } it, err := getIdToken(user) if err != nil { return Tokens{}, err } rt, err := getRefreshToken(user) if err != nil { return Tokens{}, err } t := Tokens{ AccessToken: at, IdToken: it, RefreshToken: rt, } return t, nil } func ValidateJwtToken[T Payload](token string, claims T) error { segments := strings.Split(token, ".") if len(segments) != 3 { return fmt.Errorf("invalid segment count") } k, err := config.GetString("jwtKey") if err != nil { return err } mac := hmac.New(sha256.New, []byte(k)) mac.Write([]byte(segments[0] + "." + segments[1])) if strings.ReplaceAll(base64.URLEncoding.EncodeToString(mac.Sum(nil)), "=", "") != segments[2] { return fmt.Errorf("invalid sigature") } for len(segments[1])%4 != 0 { segments[1] += "=" } bytes, _ := base64.URLEncoding.DecodeString(segments[1]) if err := json.Unmarshal(bytes, &claims); err != nil { return fmt.Errorf("failed to unmarshal token payload") } c := claims.getBaseClaims() aud, err := config.GetString("jwtAudience") if err != nil { return err } iss, err := config.GetString("jwtIssuer") if err != nil { return err } if c.Aud == aud && c.Iss == iss && c.Exp >= time.Now().Unix() { return nil } return fmt.Errorf("invalid payload parameters") } func getGuestToken() (string, error) { bc, err := newBaseClaims(0) if err != nil { return "", err } ap := AccessClaims{ bc, "", } return encodeAndSignJwt(ap) } func getAccessToken(user repo.User) (string, error) { bc, err := newBaseClaims(user.Id) if err != nil { return "", err } ac := AccessClaims{ bc, "", } return encodeAndSignJwt(ac) } func getIdToken(user repo.User) (string, error) { bc, err := newBaseClaims(user.Id) if err != nil { return "", err } ic := IdClaims{ bc, user.FirstName, user.LastName, user.EmailAddress, user.PhoneNumber, } return encodeAndSignJwt(ic) } func getRefreshToken(user repo.User) (string, error) { session := repo.UserSession{ UserId: user.Id, TokenId: uuid.New().String(), Expiration: time.Now().AddDate(0, 1, 0), } if err := repo.AddUserSession(session); err != nil { return "", err } bc, err := newBaseClaims(user.Id) if err != nil { return "", err } rc := RefreshClaims{ bc, session.TokenId, } rc.Exp = session.Expiration.Unix() return encodeAndSignJwt(rc) } func encodeAndSignJwt(payload any) (string, error) { h, err := json.Marshal(newHeader()) if err != nil { return "", err } p, err := json.Marshal(payload) if err != nil { return "", err } token := base64.URLEncoding.EncodeToString(h) + "." + base64.URLEncoding.EncodeToString(p) token = strings.ReplaceAll(token, "=", "") key, err := config.GetString("jwtKey") if err != nil { return "", err } mac := hmac.New(sha256.New, []byte(key)) _, err = mac.Write([]byte(token)) if err != nil { return "", err } signed := strings.ReplaceAll(token+"."+base64.URLEncoding.EncodeToString(mac.Sum(nil)), "=", "") return signed, nil } func newHeader() header { return header{ Alg: "HS256", Typ: "JWT", } } func newBaseClaims(userId int64) (baseClaims, error) { aud, err := config.GetString("jwtAudience") if err != nil { return baseClaims{}, err } iss, err := config.GetString("jwtIssuer") if err != nil { return baseClaims{}, err } lt, err := config.GetInt64("jwtLifetime") if err != nil { return baseClaims{}, err } bc := baseClaims{ Aud: aud, Iss: iss, Exp: time.Now().Add(time.Duration(lt * int64(time.Second))).Unix(), Sub: userId, } return bc, nil }