Issue JWT on login

This commit is contained in:
Daniel Mason 2021-03-26 12:21:28 +13:00
parent 12f4fb6d89
commit c67be1bd75
9 changed files with 249 additions and 27 deletions

View File

@ -5,3 +5,5 @@ MYSQL_DB=
JWT_SECRET= JWT_SECRET=
JWT_EXPIRY=86400 JWT_EXPIRY=86400
REVERSE_PROXIES=127.0.0.1

View File

@ -3,6 +3,9 @@ package main
import ( import (
"log" "log"
"os" "os"
"strconv"
"strings"
"time"
"git.m2.nz/go-scrobble/internal/goscrobble" "git.m2.nz/go-scrobble/internal/goscrobble"
"github.com/joho/godotenv" "github.com/joho/godotenv"
@ -17,6 +20,21 @@ func main() {
// Store JWT secret // Store JWT secret
goscrobble.JwtToken = []byte(os.Getenv("JWT_SECRET")) goscrobble.JwtToken = []byte(os.Getenv("JWT_SECRET"))
// Store JWT expiry
goscrobble.JwtExpiry = 86400
jwtExpiryStr := os.Getenv("JWT_EXPIRY")
if jwtExpiryStr != "" {
i, err := strconv.ParseFloat(jwtExpiryStr, 64)
if err != nil {
panic("Invalid JWT_EXPIRY value")
}
goscrobble.JwtExpiry = time.Duration(i) * time.Second
}
// Ignore reverse proxies
goscrobble.ReverseProxies = strings.Split(os.Getenv("REVERSE_PROXIES"), ",")
// // Boot up DB connection for life of application // // Boot up DB connection for life of application
goscrobble.InitDb() goscrobble.InitDb()
defer goscrobble.CloseDbConn() defer goscrobble.CloseDbConn()

View File

@ -2,6 +2,7 @@ package goscrobble
import ( import (
"database/sql" "database/sql"
"errors"
"fmt" "fmt"
"log" "log"
"os" "os"
@ -71,3 +72,13 @@ func runMigrations() {
panic(fmt.Errorf("Error running DB Migrations %v", err)) panic(fmt.Errorf("Error running DB Migrations %v", err))
} }
} }
func getDbCount(query string, args ...interface{}) (int, error) {
var result int
err := db.QueryRow(query, args...).Scan(&result)
if err != nil {
return 0, errors.New("Error fetching data")
}
return result, nil
}

View File

@ -1,8 +1,8 @@
package goscrobble package goscrobble
import ( import (
"log"
"net/http" "net/http"
"time"
"github.com/dgrijalva/jwt-go" "github.com/dgrijalva/jwt-go"
) )
@ -10,6 +10,9 @@ import (
// JwtToken - Store token from .env // JwtToken - Store token from .env
var JwtToken []byte var JwtToken []byte
// JwtExpiry - Expiry in seconds
var JwtExpiry time.Duration
// Store custom claims here // Store custom claims here
type Claims struct { type Claims struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
@ -26,7 +29,6 @@ func verifyToken(token string, w http.ResponseWriter) bool {
}) })
if err != nil { if err != nil {
log.Printf("%v", err)
if err == jwt.ErrSignatureInvalid { if err == jwt.ErrSignatureInvalid {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
return false return false

View File

@ -29,6 +29,9 @@ var heavyLimiter = NewIPRateLimiter(0.1, 1)
// Limits to 5 req / sec // Limits to 5 req / sec
var standardLimiter = NewIPRateLimiter(5, 5) var standardLimiter = NewIPRateLimiter(5, 5)
// List of Reverse proxies
var ReverseProxies []string
// HandleRequests - Boot HTTP! // HandleRequests - Boot HTTP!
func HandleRequests() { func HandleRequests() {
// Create a new router // Create a new router
@ -44,7 +47,7 @@ func HandleRequests() {
// No Auth // No Auth
v1.HandleFunc("/register", limitMiddleware(handleRegister, heavyLimiter)).Methods("POST") v1.HandleFunc("/register", limitMiddleware(handleRegister, heavyLimiter)).Methods("POST")
v1.HandleFunc("/login", limitMiddleware(serveEndpoint, standardLimiter)).Methods("POST") v1.HandleFunc("/login", limitMiddleware(handleLogin, standardLimiter)).Methods("POST")
v1.HandleFunc("/logout", serveEndpoint).Methods("POST") v1.HandleFunc("/logout", serveEndpoint).Methods("POST")
// This just prevents it serving frontend stuff over /api // This just prevents it serving frontend stuff over /api
@ -133,7 +136,8 @@ func handleRegister(w http.ResponseWriter, r *http.Request) {
return return
} }
err = createUser(&regReq) ip := getUserIp(r)
err = createUser(&regReq, ip)
if err != nil { if err != nil {
throwBadReq(w, err.Error()) throwBadReq(w, err.Error())
return return
@ -144,16 +148,36 @@ func handleRegister(w http.ResponseWriter, r *http.Request) {
w.Write(msg) w.Write(msg)
} }
// handleLogin - Does as it says!
func handleLogin(w http.ResponseWriter, r *http.Request) {
logReq := LoginRequest{}
decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&logReq)
if err != nil {
throwBadReq(w, err.Error())
return
}
ip := getUserIp(r)
data, err := loginUser(&logReq, ip)
if err != nil {
throwBadReq(w, err.Error())
return
}
w.WriteHeader(http.StatusOK)
w.Write(data)
}
// serveEndpoint - API stuffs // serveEndpoint - API stuffs
func serveEndpoint(w http.ResponseWriter, r *http.Request) { func serveEndpoint(w http.ResponseWriter, r *http.Request) {
json, err := decodeJson(r.Body) _, err := decodeJson(r.Body)
if err != nil { if err != nil {
// If we can't decode. Lets tell them nicely. // If we can't decode. Lets tell them nicely.
http.Error(w, "{\"error\":\"Invalid JSON\"}", http.StatusBadRequest) http.Error(w, "{\"error\":\"Invalid JSON\"}", http.StatusBadRequest)
return return
} }
log.Printf("%v", json)
// Lets trick 'em for now ;) ;) // Lets trick 'em for now ;) ;)
fmt.Fprintf(w, "{}") fmt.Fprintf(w, "{}")
} }

View File

@ -1,24 +1,33 @@
package goscrobble package goscrobble
import ( import (
"database/sql"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"log"
"net"
"strings"
"time" "time"
"github.com/dgrijalva/jwt-go"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
) )
const bCryptCost = 16 const bCryptCost = 16
type User struct { type User struct {
UUID string `json:"uuid"` UUID string `json:"uuid"`
CreatedAt time.Time `json:"created_at"` CreatedAt time.Time `json:"created_at"`
Username string `json:"username"` CreatedIp net.IP `json:"created_ip"`
password []byte ModifiedAt time.Time `json:"modified_at"`
Email string `json:"email"` ModifiedIP net.IP `jsos:"modified_ip"`
Verified bool `json:"verified"` Username string `json:"username"`
Active bool `json:"active"` Password []byte `json:"password"`
Admin bool `json:"admin"` Email string `json:"email"`
Verified bool `json:"verified"`
Active bool `json:"active"`
Admin bool `json:"admin"`
} }
// RegisterRequest - Incoming JSON // RegisterRequest - Incoming JSON
@ -28,18 +37,39 @@ type RegisterRequest struct {
Password string `json:"password"` Password string `json:"password"`
} }
// RegisterRequest - Incoming JSON
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
// LoginResponse - JWT issued
type LoginResponse struct {
Token string `json:"token"`
}
// createUser - Called from API // createUser - Called from API
func createUser(req *RegisterRequest) error { func createUser(req *RegisterRequest, ip net.IP) error {
// Check if user already exists.. // Check if user already exists..
if len(req.Password) < 8 { if len(req.Password) < 8 {
return errors.New("Password must be at least 8 characters") return errors.New("Password must be at least 8 characters")
} }
// Check username is set // Check Username is set
if req.Username == "" { if req.Username == "" {
return errors.New("A username is required") return errors.New("A username is required")
} }
// Check max length for Username
if len(req.Username) > 64 {
return errors.New("Username cannot be longer than 64 characters")
}
// Check username doesn't contain @
if strings.Contains(req.Username, "@") {
return errors.New("Username contains invalid characters")
}
// If set an email.. validate it! // If set an email.. validate it!
if req.Email != "" { if req.Email != "" {
if !isEmailValid(req.Email) { if !isEmailValid(req.Email) {
@ -58,12 +88,85 @@ func createUser(req *RegisterRequest) error {
return err return err
} }
return insertUser(req.Username, req.Email, hash) return insertUser(req.Username, req.Email, hash, ip)
}
func loginUser(logReq *LoginRequest, ip net.IP) ([]byte, error) {
var resp []byte
var user User
if logReq.Username == "" {
return resp, errors.New("username must be set")
}
if logReq.Password == "" {
return resp, errors.New("password must be set")
}
if strings.Contains(logReq.Username, "@") {
err := db.QueryRow("SELECT BIN_TO_UUID(uuid), username, email, password FROM users WHERE email = ? AND active = 1", logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password)
if err != nil {
if err == sql.ErrNoRows {
return resp, errors.New("Invalid Username or Password")
}
}
} else {
err := db.QueryRow("SELECT BIN_TO_UUID(uuid), username, email, password FROM users WHERE username = ? AND active = 1", logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password)
if err == sql.ErrNoRows {
return resp, errors.New("Invalid Username or Password")
}
}
if !isValidPassword(logReq.Password, user) {
return resp, errors.New("Invalid Username or Password")
}
// Issue JWT + Response
token, err := generateJwt(user)
if err != nil {
log.Printf("Error generating JWT: %v", err)
return resp, errors.New("Error logging in")
}
loginResp := LoginResponse{
Token: token,
}
resp, _ = json.Marshal(&loginResp)
return resp, nil
}
func generateJwt(user User) (string, error) {
atClaims := jwt.MapClaims{}
atClaims["sub"] = user.UUID
atClaims["username"] = user.Username
atClaims["email"] = user.Email
atClaims["exp"] = time.Now().Add(JwtExpiry).Unix()
at := jwt.NewWithClaims(jwt.SigningMethodHS512, atClaims)
token, err := at.SignedString(JwtToken)
if err != nil {
return "", err
}
return token, nil
} }
// insertUser - Does the dirtywork! // insertUser - Does the dirtywork!
func insertUser(username string, email string, password []byte) error { func insertUser(username string, email string, password []byte, ip net.IP) error {
_, err := db.Exec("INSERT INTO users (uuid, created_at, username, email, password) VALUES (UUID_TO_BIN(UUID(), true),NOW(),?,?,?)", username, email, password) _, err := db.Exec("INSERT INTO users (uuid, created_at, created_ip, modified_at, modified_ip, username, email, password) "+
"VALUES (UUID_TO_BIN(UUID(), true),NOW(),?,NOW(),?,?,?,?)", ip, ip, username, email, password)
return err
}
func updateUser(uuid string, field string, value string, ip string) error {
_, err := db.Exec("UPDATE users SET ? = ?, modified_at = NOW(), modified_ip = ? WHERE uuid = ?", field, value, uuid, ip)
return err
}
func updateUserDirect(uuid string, field string, value string) error {
_, err := db.Exec("UPDATE users SET ? = ? WHERE uuid = ?", field, value, uuid)
return err return err
} }
@ -75,7 +178,7 @@ func hashPassword(password string) ([]byte, error) {
// isValidPassword - Checks if password is valid // isValidPassword - Checks if password is valid
func isValidPassword(password string, user User) bool { func isValidPassword(password string, user User) bool {
err := bcrypt.CompareHashAndPassword(user.password, []byte(password)) err := bcrypt.CompareHashAndPassword(user.Password, []byte(password))
if err != nil { if err != nil {
return false return false
} }
@ -86,15 +189,19 @@ func isValidPassword(password string, user User) bool {
// userAlreadyExists - Returns bool indicating if a record exists for either username or email // userAlreadyExists - Returns bool indicating if a record exists for either username or email
// Using two look ups to make use of DB indexes. // Using two look ups to make use of DB indexes.
func userAlreadyExists(req *RegisterRequest) bool { func userAlreadyExists(req *RegisterRequest) bool {
var userExists int count, err := getDbCount("SELECT COUNT(*) FROM users WHERE username = ?", req.Username)
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE username = ?", req.Username).Scan(&userExists) if err != nil {
if userExists > 0 { fmt.Printf("Error querying for duplicate users: %v", err)
return true
}
if count > 0 {
return true return true
} }
if req.Email != "" { if req.Email != "" {
// Only run email check if there's an email... // Only run email check if there's an email...
err = db.QueryRow("SELECT COUNT(*) FROM users WHERE email = ?", req.Email).Scan(&userExists) count, err = getDbCount("SELECT COUNT(*) FROM users WHERE email = ?", req.Email)
} }
if err != nil { if err != nil {
@ -102,5 +209,5 @@ func userAlreadyExists(req *RegisterRequest) bool {
return true return true
} }
return userExists > 0 return count > 0
} }

View File

@ -1,8 +1,12 @@
package goscrobble package goscrobble
import ( import (
"encoding/hex"
"encoding/json" "encoding/json"
"io" "io"
"math/big"
"net"
"net/http"
"regexp" "regexp"
) )
@ -24,3 +28,54 @@ func isEmailValid(e string) bool {
} }
return emailRegex.MatchString(e) return emailRegex.MatchString(e)
} }
// contains - Check if string is in list
func contains(s []string, e string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}
// getUserIp - Returns IP that isn't set in REVERSE_PROXY
func getUserIp(r *http.Request) net.IP {
var ip net.IP
host, _, _ := net.SplitHostPort(r.RemoteAddr)
if contains(ReverseProxies, host) {
host = r.Header.Get("X-FOWARDED-FOR")
}
ip = net.ParseIP(host)
return ip
}
// Inet_Aton converts an IPv4 net.IP object to a 64 bit integer.
func Inet_Aton(ip net.IP) int64 {
ipv4Int := big.NewInt(0)
ipv4Int.SetBytes(ip.To4())
return ipv4Int.Int64()
}
// Inet6_Aton converts an IP Address (IPv4 or IPv6) net.IP object to a hexadecimal
// representaiton. This function is the equivalent of
// inet6_aton({{ ip address }}) in MySQL.
func Inet6_Aton(ip net.IP) string {
ipv4 := false
if ip.To4() != nil {
ipv4 = true
}
ipInt := big.NewInt(0)
if ipv4 {
ipInt.SetBytes(ip.To4())
ipHex := hex.EncodeToString(ipInt.Bytes())
return ipHex
}
ipInt.SetBytes(ip.To16())
ipHex := hex.EncodeToString(ipInt.Bytes())
return ipHex
}

View File

@ -1,5 +1,5 @@
CREATE TABLE IF NOT EXISTS `config` ( CREATE TABLE IF NOT EXISTS `config` (
`key` VARCHAR(255) NOT NULL, `key` VARCHAR(255) NOT NULL,
`value` INT(11) NOT NULL DEFAULT 1, `value` VARCHAR(255) NULL DEFAULT NULL,
PRIMARY KEY(`key`) PRIMARY KEY(`key`)
) DEFAULT CHARACTER SET = utf8mb4 COLLATE = utf8mb4_0900_ai_ci; ) DEFAULT CHARACTER SET = utf8mb4 COLLATE = utf8mb4_0900_ai_ci;

View File

@ -1,6 +1,9 @@
CREATE TABLE IF NOT EXISTS `users` ( CREATE TABLE IF NOT EXISTS `users` (
`uuid` BINARY(16) PRIMARY KEY, `uuid` BINARY(16) PRIMARY KEY,
`created_at` DATETIME NOT NULL, `created_at` DATETIME NOT NULL,
`created_ip` VARBINARY(16) NULL DEFAULT NULL,
`modified_at` DATETIME NOT NULL,
`modified_ip` VARBINARY(16) NULL DEFAULT NULL,
`username` VARCHAR(64) NOT NULL, `username` VARCHAR(64) NOT NULL,
`password` VARCHAR(60) NOT NULL, `password` VARCHAR(60) NOT NULL,
`email` VARCHAR(255) NULL DEFAULT NULL, `email` VARCHAR(255) NULL DEFAULT NULL,