Initial commit
This commit is contained in:
76
pkg/auth/auth.go
Normal file
76
pkg/auth/auth.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"addrss/pkg/repo"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type UserLogin struct {
|
||||
EmailAddress string `json:"emailAddress"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
func AuthenticateGuest() (Tokens, error) {
|
||||
gt, err := getGuestToken()
|
||||
if err != nil {
|
||||
return Tokens{}, err
|
||||
}
|
||||
return Tokens{AccessToken: gt}, nil
|
||||
}
|
||||
|
||||
func AuthenticateUserLogin(userLogin UserLogin) (Tokens, error) {
|
||||
user, err := repo.GetUserByEmail(userLogin.EmailAddress)
|
||||
if err != nil {
|
||||
return Tokens{}, &ErrorUnauthorized{err}
|
||||
}
|
||||
|
||||
if err = bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(userLogin.Password)); err != nil {
|
||||
return Tokens{}, &ErrorUnauthorized{err}
|
||||
}
|
||||
|
||||
tokens, err := AcquireTokens(user)
|
||||
if err != nil {
|
||||
return Tokens{}, &ErrorForbidden{err}
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func AuthenticateUserRefresh(refreshToken string) (Tokens, error) {
|
||||
claims := RefreshClaims{}
|
||||
if err := ValidateJwtToken(refreshToken, &claims); err != nil {
|
||||
return Tokens{}, &ErrorUnauthorized{err}
|
||||
}
|
||||
|
||||
us, err := repo.GetUserSessionById(claims.Sub)
|
||||
if err != nil {
|
||||
return Tokens{}, &ErrorUnauthorized{err}
|
||||
}
|
||||
|
||||
if us.TokenId != claims.Jti {
|
||||
_ = repo.DeleteUserSession(claims.Sub)
|
||||
return Tokens{}, &ErrorUnauthorized{fmt.Errorf("token id mismatch")}
|
||||
}
|
||||
|
||||
user, err := repo.GetUserById(claims.Sub)
|
||||
if err != nil {
|
||||
return Tokens{}, &ErrorUnauthorized{err}
|
||||
}
|
||||
|
||||
tokens, err := AcquireTokens(user)
|
||||
if err != nil {
|
||||
return Tokens{}, &ErrorForbidden{err}
|
||||
}
|
||||
|
||||
return tokens, nil
|
||||
}
|
||||
|
||||
func DestroySession(userId int64) error {
|
||||
if err := repo.DeleteUserSession(userId); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
17
pkg/auth/errors.go
Normal file
17
pkg/auth/errors.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package auth
|
||||
|
||||
type authError struct {
|
||||
InnerError error
|
||||
}
|
||||
|
||||
type ErrorUnauthorized authError
|
||||
|
||||
func (eu *ErrorUnauthorized) Error() string {
|
||||
return "Invalid username or password"
|
||||
}
|
||||
|
||||
type ErrorForbidden authError
|
||||
|
||||
func (ef *ErrorForbidden) Error() string {
|
||||
return "Access is denied"
|
||||
}
|
||||
74
pkg/auth/password.go
Normal file
74
pkg/auth/password.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"addrss/pkg/repo"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type NewPassword struct {
|
||||
Password string `json:"-"`
|
||||
Hash string `json:"-"`
|
||||
}
|
||||
|
||||
type PasswordChange struct {
|
||||
UserId int64 `json:"userId"`
|
||||
OldPassword string `json:"oldPassword"`
|
||||
NewPassword string `json:"newPassword"`
|
||||
ConfirmPassword string `json:"confirmPassword"`
|
||||
}
|
||||
|
||||
func ChangePassword(pc PasswordChange) error {
|
||||
u, err := repo.GetUserById(pc.UserId)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not get user for user id %d: %v", pc.UserId, err)
|
||||
}
|
||||
|
||||
if err = bcrypt.CompareHashAndPassword([]byte(u.Password), []byte(pc.OldPassword)); err != nil {
|
||||
return fmt.Errorf("incorrect password for user id %d", pc.UserId)
|
||||
}
|
||||
|
||||
h, err := bcrypt.GenerateFromPassword([]byte(pc.NewPassword), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return fmt.Errorf("bcrypt error: %v", err)
|
||||
}
|
||||
|
||||
u.Password = string(h)
|
||||
if err = repo.UpdateUserPassword(u); err != nil {
|
||||
return fmt.Errorf("failed to update password for user id %d: %v", u.Id, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GetRandomPassword(length int) (NewPassword, error) {
|
||||
p := GetRandomString(length)
|
||||
|
||||
h, err := bcrypt.GenerateFromPassword([]byte(p), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return NewPassword{}, err
|
||||
}
|
||||
|
||||
np := NewPassword{
|
||||
Password: p,
|
||||
Hash: string(h),
|
||||
}
|
||||
|
||||
return np, nil
|
||||
}
|
||||
|
||||
// GetRandomString Keep as a separate function in case this becomes useful for some other purpose
|
||||
func GetRandomString(length int) string {
|
||||
const chars = "!@#$%?abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
rnd := rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
|
||||
b := make([]byte, length)
|
||||
for i := range b {
|
||||
b[i] = chars[rnd.Intn(len(chars))]
|
||||
}
|
||||
|
||||
return string(b)
|
||||
}
|
||||
268
pkg/auth/tokens.go
Normal file
268
pkg/auth/tokens.go
Normal file
@@ -0,0 +1,268 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"addrss/pkg/config"
|
||||
"addrss/pkg/repo"
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
type header struct {
|
||||
Alg string `json:"alg"`
|
||||
Typ string `json:"typ"`
|
||||
}
|
||||
|
||||
type baseClaims struct {
|
||||
Aud string `json:"aud"`
|
||||
Iss string `json:"iss"`
|
||||
Exp int64 `json:"exp"`
|
||||
Sub int64 `json:"sub"`
|
||||
}
|
||||
|
||||
type Payload interface {
|
||||
getBaseClaims() baseClaims
|
||||
*AccessClaims | *IdClaims | *RefreshClaims
|
||||
}
|
||||
|
||||
type AccessClaims struct {
|
||||
baseClaims
|
||||
Scopes string `json:"scope"`
|
||||
}
|
||||
|
||||
func (ac *AccessClaims) getBaseClaims() baseClaims {
|
||||
return ac.baseClaims
|
||||
}
|
||||
|
||||
type IdClaims struct {
|
||||
baseClaims
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
Email string `json:"email"`
|
||||
PhoneNumber string `json:"phone_number"`
|
||||
}
|
||||
|
||||
func (ic *IdClaims) getBaseClaims() baseClaims {
|
||||
return ic.baseClaims
|
||||
}
|
||||
|
||||
type RefreshClaims struct {
|
||||
baseClaims
|
||||
Jti string `json:"jti"`
|
||||
}
|
||||
|
||||
func (rc *RefreshClaims) getBaseClaims() baseClaims {
|
||||
return rc.baseClaims
|
||||
}
|
||||
|
||||
type Tokens struct {
|
||||
AccessToken string `json:"accessToken,omitempty"`
|
||||
IdToken string `json:"idToken,omitempty"`
|
||||
RefreshToken string `json:"refreshToken,omitempty"`
|
||||
}
|
||||
|
||||
func AcquireTokens(user repo.User) (Tokens, error) {
|
||||
at, err := getAccessToken(user)
|
||||
if err != nil {
|
||||
return Tokens{}, err
|
||||
}
|
||||
|
||||
it, err := getIdToken(user)
|
||||
if err != nil {
|
||||
return Tokens{}, err
|
||||
}
|
||||
|
||||
rt, err := getRefreshToken(user)
|
||||
if err != nil {
|
||||
return Tokens{}, err
|
||||
}
|
||||
|
||||
t := Tokens{
|
||||
AccessToken: at,
|
||||
IdToken: it,
|
||||
RefreshToken: rt,
|
||||
}
|
||||
|
||||
return t, nil
|
||||
}
|
||||
|
||||
func ValidateJwtToken[T Payload](token string, claims T) error {
|
||||
segments := strings.Split(token, ".")
|
||||
if len(segments) != 3 {
|
||||
return fmt.Errorf("invalid segment count")
|
||||
}
|
||||
|
||||
k, err := config.GetString("jwtKey")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(k))
|
||||
mac.Write([]byte(segments[0] + "." + segments[1]))
|
||||
if strings.ReplaceAll(base64.URLEncoding.EncodeToString(mac.Sum(nil)), "=", "") != segments[2] {
|
||||
return fmt.Errorf("invalid sigature")
|
||||
}
|
||||
|
||||
for len(segments[1])%4 != 0 {
|
||||
segments[1] += "="
|
||||
}
|
||||
|
||||
bytes, _ := base64.URLEncoding.DecodeString(segments[1])
|
||||
if err := json.Unmarshal(bytes, &claims); err != nil {
|
||||
return fmt.Errorf("failed to unmarshal token payload")
|
||||
}
|
||||
|
||||
c := claims.getBaseClaims()
|
||||
aud, err := config.GetString("jwtAudience")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
iss, err := config.GetString("jwtIssuer")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if c.Aud == aud && c.Iss == iss && c.Exp >= time.Now().Unix() {
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("invalid payload parameters")
|
||||
}
|
||||
|
||||
func getGuestToken() (string, error) {
|
||||
bc, err := newBaseClaims(0)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ap := AccessClaims{
|
||||
bc,
|
||||
"",
|
||||
}
|
||||
return encodeAndSignJwt(ap)
|
||||
}
|
||||
|
||||
func getAccessToken(user repo.User) (string, error) {
|
||||
bc, err := newBaseClaims(user.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ac := AccessClaims{
|
||||
bc,
|
||||
"",
|
||||
}
|
||||
|
||||
return encodeAndSignJwt(ac)
|
||||
}
|
||||
|
||||
func getIdToken(user repo.User) (string, error) {
|
||||
bc, err := newBaseClaims(user.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
ic := IdClaims{
|
||||
bc,
|
||||
user.FirstName,
|
||||
user.LastName,
|
||||
user.EmailAddress,
|
||||
user.PhoneNumber,
|
||||
}
|
||||
|
||||
return encodeAndSignJwt(ic)
|
||||
}
|
||||
|
||||
func getRefreshToken(user repo.User) (string, error) {
|
||||
session := repo.UserSession{
|
||||
UserId: user.Id,
|
||||
TokenId: uuid.New().String(),
|
||||
Expiration: time.Now().AddDate(0, 1, 0),
|
||||
}
|
||||
|
||||
if err := repo.AddUserSession(session); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
bc, err := newBaseClaims(user.Id)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
rc := RefreshClaims{
|
||||
bc,
|
||||
session.TokenId,
|
||||
}
|
||||
rc.Exp = session.Expiration.Unix()
|
||||
|
||||
return encodeAndSignJwt(rc)
|
||||
}
|
||||
|
||||
func encodeAndSignJwt(payload any) (string, error) {
|
||||
h, err := json.Marshal(newHeader())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
p, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
token := base64.URLEncoding.EncodeToString(h) + "." + base64.URLEncoding.EncodeToString(p)
|
||||
token = strings.ReplaceAll(token, "=", "")
|
||||
|
||||
key, err := config.GetString("jwtKey")
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
mac := hmac.New(sha256.New, []byte(key))
|
||||
_, err = mac.Write([]byte(token))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
signed := strings.ReplaceAll(token+"."+base64.URLEncoding.EncodeToString(mac.Sum(nil)), "=", "")
|
||||
return signed, nil
|
||||
}
|
||||
|
||||
func newHeader() header {
|
||||
return header{
|
||||
Alg: "HS256",
|
||||
Typ: "JWT",
|
||||
}
|
||||
}
|
||||
|
||||
func newBaseClaims(userId int64) (baseClaims, error) {
|
||||
aud, err := config.GetString("jwtAudience")
|
||||
if err != nil {
|
||||
return baseClaims{}, err
|
||||
}
|
||||
|
||||
iss, err := config.GetString("jwtIssuer")
|
||||
if err != nil {
|
||||
return baseClaims{}, err
|
||||
}
|
||||
|
||||
lt, err := config.GetInt64("jwtLifetime")
|
||||
if err != nil {
|
||||
return baseClaims{}, err
|
||||
}
|
||||
|
||||
bc := baseClaims{
|
||||
Aud: aud,
|
||||
Iss: iss,
|
||||
Exp: time.Now().Add(time.Duration(lt * int64(time.Second))).Unix(),
|
||||
Sub: userId,
|
||||
}
|
||||
|
||||
return bc, nil
|
||||
}
|
||||
Reference in New Issue
Block a user