Initial commit

This commit is contained in:
2025-09-06 21:35:45 -04:00
commit b02525e28a
32 changed files with 1478 additions and 0 deletions

76
pkg/auth/auth.go Normal file
View File

@@ -0,0 +1,76 @@
package auth
import (
"addrss/pkg/repo"
"fmt"
"golang.org/x/crypto/bcrypt"
)
type UserLogin struct {
EmailAddress string `json:"emailAddress"`
Password string `json:"password"`
}
func AuthenticateGuest() (Tokens, error) {
gt, err := getGuestToken()
if err != nil {
return Tokens{}, err
}
return Tokens{AccessToken: gt}, nil
}
func AuthenticateUserLogin(userLogin UserLogin) (Tokens, error) {
user, err := repo.GetUserByEmail(userLogin.EmailAddress)
if err != nil {
return Tokens{}, &ErrorUnauthorized{err}
}
if err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(userLogin.Password)); err != nil {
return Tokens{}, &ErrorUnauthorized{err}
}
tokens, err := AcquireTokens(user)
if err != nil {
return Tokens{}, &ErrorForbidden{err}
}
return tokens, nil
}
func AuthenticateUserRefresh(refreshToken string) (Tokens, error) {
claims := RefreshClaims{}
if err := ValidateJwtToken(refreshToken, &claims); err != nil {
return Tokens{}, &ErrorUnauthorized{err}
}
us, err := repo.GetUserSessionById(claims.Sub)
if err != nil {
return Tokens{}, &ErrorUnauthorized{err}
}
if us.TokenId != claims.Jti {
_ = repo.DeleteUserSession(claims.Sub)
return Tokens{}, &ErrorUnauthorized{fmt.Errorf("token id mismatch")}
}
user, err := repo.GetUserById(claims.Sub)
if err != nil {
return Tokens{}, &ErrorUnauthorized{err}
}
tokens, err := AcquireTokens(user)
if err != nil {
return Tokens{}, &ErrorForbidden{err}
}
return tokens, nil
}
func DestroySession(userId int64) error {
if err := repo.DeleteUserSession(userId); err != nil {
return err
}
return nil
}

17
pkg/auth/errors.go Normal file
View File

@@ -0,0 +1,17 @@
package auth
type authError struct {
InnerError error
}
type ErrorUnauthorized authError
func (eu *ErrorUnauthorized) Error() string {
return "Invalid username or password"
}
type ErrorForbidden authError
func (ef *ErrorForbidden) Error() string {
return "Access is denied"
}

74
pkg/auth/password.go Normal file
View File

@@ -0,0 +1,74 @@
package auth
import (
"addrss/pkg/repo"
"fmt"
"math/rand"
"time"
"golang.org/x/crypto/bcrypt"
)
type NewPassword struct {
Password string `json:"-"`
Hash string `json:"-"`
}
type PasswordChange struct {
UserId int64 `json:"userId"`
OldPassword string `json:"oldPassword"`
NewPassword string `json:"newPassword"`
ConfirmPassword string `json:"confirmPassword"`
}
func ChangePassword(pc PasswordChange) error {
u, err := repo.GetUserById(pc.UserId)
if err != nil {
return fmt.Errorf("could not get user for user id %d: %v", pc.UserId, err)
}
if err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(pc.OldPassword)); err != nil {
return fmt.Errorf("incorrect password for user id %d", pc.UserId)
}
h, err := bcrypt.GenerateFromPassword([]byte(pc.NewPassword), bcrypt.DefaultCost)
if err != nil {
return fmt.Errorf("bcrypt error: %v", err)
}
u.Password = string(h)
if err = repo.UpdateUserPassword(u); err != nil {
return fmt.Errorf("failed to update password for user id %d: %v", u.Id, err)
}
return nil
}
func GetRandomPassword(length int) (NewPassword, error) {
p := GetRandomString(length)
h, err := bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost)
if err != nil {
return NewPassword{}, err
}
np := NewPassword{
Password: p,
Hash: string(h),
}
return np, nil
}
// GetRandomString Keep as a separate function in case this becomes useful for some other purpose
func GetRandomString(length int) string {
const chars = "!@#$%?abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
b := make([]byte, length)
for i := range b {
b[i] = chars[rnd.Intn(len(chars))]
}
return string(b)
}

268
pkg/auth/tokens.go Normal file
View File

@@ -0,0 +1,268 @@
package auth
import (
"addrss/pkg/config"
"addrss/pkg/repo"
"crypto/hmac"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/google/uuid"
)
type header struct {
Alg string `json:"alg"`
Typ string `json:"typ"`
}
type baseClaims struct {
Aud string `json:"aud"`
Iss string `json:"iss"`
Exp int64 `json:"exp"`
Sub int64 `json:"sub"`
}
type Payload interface {
getBaseClaims() baseClaims
*AccessClaims | *IdClaims | *RefreshClaims
}
type AccessClaims struct {
baseClaims
Scopes string `json:"scope"`
}
func (ac *AccessClaims) getBaseClaims() baseClaims {
return ac.baseClaims
}
type IdClaims struct {
baseClaims
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
Email string `json:"email"`
PhoneNumber string `json:"phone_number"`
}
func (ic *IdClaims) getBaseClaims() baseClaims {
return ic.baseClaims
}
type RefreshClaims struct {
baseClaims
Jti string `json:"jti"`
}
func (rc *RefreshClaims) getBaseClaims() baseClaims {
return rc.baseClaims
}
type Tokens struct {
AccessToken string `json:"accessToken,omitempty"`
IdToken string `json:"idToken,omitempty"`
RefreshToken string `json:"refreshToken,omitempty"`
}
func AcquireTokens(user repo.User) (Tokens, error) {
at, err := getAccessToken(user)
if err != nil {
return Tokens{}, err
}
it, err := getIdToken(user)
if err != nil {
return Tokens{}, err
}
rt, err := getRefreshToken(user)
if err != nil {
return Tokens{}, err
}
t := Tokens{
AccessToken: at,
IdToken: it,
RefreshToken: rt,
}
return t, nil
}
func ValidateJwtToken[T Payload](token string, claims T) error {
segments := strings.Split(token, ".")
if len(segments) != 3 {
return fmt.Errorf("invalid segment count")
}
k, err := config.GetString("jwtKey")
if err != nil {
return err
}
mac := hmac.New(sha256.New, []byte(k))
mac.Write([]byte(segments[0] + "." + segments[1]))
if strings.ReplaceAll(base64.URLEncoding.EncodeToString(mac.Sum(nil)), "=", "") != segments[2] {
return fmt.Errorf("invalid sigature")
}
for len(segments[1])%4 != 0 {
segments[1] += "="
}
bytes, _ := base64.URLEncoding.DecodeString(segments[1])
if err := json.Unmarshal(bytes, &claims); err != nil {
return fmt.Errorf("failed to unmarshal token payload")
}
c := claims.getBaseClaims()
aud, err := config.GetString("jwtAudience")
if err != nil {
return err
}
iss, err := config.GetString("jwtIssuer")
if err != nil {
return err
}
if c.Aud == aud && c.Iss == iss && c.Exp >= time.Now().Unix() {
return nil
}
return fmt.Errorf("invalid payload parameters")
}
func getGuestToken() (string, error) {
bc, err := newBaseClaims(0)
if err != nil {
return "", err
}
ap := AccessClaims{
bc,
"",
}
return encodeAndSignJwt(ap)
}
func getAccessToken(user repo.User) (string, error) {
bc, err := newBaseClaims(user.Id)
if err != nil {
return "", err
}
ac := AccessClaims{
bc,
"",
}
return encodeAndSignJwt(ac)
}
func getIdToken(user repo.User) (string, error) {
bc, err := newBaseClaims(user.Id)
if err != nil {
return "", err
}
ic := IdClaims{
bc,
user.FirstName,
user.LastName,
user.EmailAddress,
user.PhoneNumber,
}
return encodeAndSignJwt(ic)
}
func getRefreshToken(user repo.User) (string, error) {
session := repo.UserSession{
UserId: user.Id,
TokenId: uuid.New().String(),
Expiration: time.Now().AddDate(0, 1, 0),
}
if err := repo.AddUserSession(session); err != nil {
return "", err
}
bc, err := newBaseClaims(user.Id)
if err != nil {
return "", err
}
rc := RefreshClaims{
bc,
session.TokenId,
}
rc.Exp = session.Expiration.Unix()
return encodeAndSignJwt(rc)
}
func encodeAndSignJwt(payload any) (string, error) {
h, err := json.Marshal(newHeader())
if err != nil {
return "", err
}
p, err := json.Marshal(payload)
if err != nil {
return "", err
}
token := base64.URLEncoding.EncodeToString(h) + "." + base64.URLEncoding.EncodeToString(p)
token = strings.ReplaceAll(token, "=", "")
key, err := config.GetString("jwtKey")
if err != nil {
return "", err
}
mac := hmac.New(sha256.New, []byte(key))
_, err = mac.Write([]byte(token))
if err != nil {
return "", err
}
signed := strings.ReplaceAll(token+"."+base64.URLEncoding.EncodeToString(mac.Sum(nil)), "=", "")
return signed, nil
}
func newHeader() header {
return header{
Alg: "HS256",
Typ: "JWT",
}
}
func newBaseClaims(userId int64) (baseClaims, error) {
aud, err := config.GetString("jwtAudience")
if err != nil {
return baseClaims{}, err
}
iss, err := config.GetString("jwtIssuer")
if err != nil {
return baseClaims{}, err
}
lt, err := config.GetInt64("jwtLifetime")
if err != nil {
return baseClaims{}, err
}
bc := baseClaims{
Aud: aud,
Iss: iss,
Exp: time.Now().Add(time.Duration(lt * int64(time.Second))).Unix(),
Sub: userId,
}
return bc, nil
}