Issue JWT on login

This commit is contained in:
Daniel Mason 2021-03-26 12:21:28 +13:00
parent 12f4fb6d89
commit c67be1bd75
9 changed files with 249 additions and 27 deletions

View File

@ -5,3 +5,5 @@ MYSQL_DB=
JWT_SECRET=
JWT_EXPIRY=86400
REVERSE_PROXIES=127.0.0.1

View File

@ -3,6 +3,9 @@ package main
import (
"log"
"os"
"strconv"
"strings"
"time"
"git.m2.nz/go-scrobble/internal/goscrobble"
"github.com/joho/godotenv"
@ -17,6 +20,21 @@ func main() {
// Store JWT secret
goscrobble.JwtToken = []byte(os.Getenv("JWT_SECRET"))
// Store JWT expiry
goscrobble.JwtExpiry = 86400
jwtExpiryStr := os.Getenv("JWT_EXPIRY")
if jwtExpiryStr != "" {
i, err := strconv.ParseFloat(jwtExpiryStr, 64)
if err != nil {
panic("Invalid JWT_EXPIRY value")
}
goscrobble.JwtExpiry = time.Duration(i) * time.Second
}
// Ignore reverse proxies
goscrobble.ReverseProxies = strings.Split(os.Getenv("REVERSE_PROXIES"), ",")
// // Boot up DB connection for life of application
goscrobble.InitDb()
defer goscrobble.CloseDbConn()

View File

@ -2,6 +2,7 @@ package goscrobble
import (
"database/sql"
"errors"
"fmt"
"log"
"os"
@ -71,3 +72,13 @@ func runMigrations() {
panic(fmt.Errorf("Error running DB Migrations %v", err))
}
}
func getDbCount(query string, args ...interface{}) (int, error) {
var result int
err := db.QueryRow(query, args...).Scan(&result)
if err != nil {
return 0, errors.New("Error fetching data")
}
return result, nil
}

View File

@ -1,8 +1,8 @@
package goscrobble
import (
"log"
"net/http"
"time"
"github.com/dgrijalva/jwt-go"
)
@ -10,6 +10,9 @@ import (
// JwtToken - Store token from .env
var JwtToken []byte
// JwtExpiry - Expiry in seconds
var JwtExpiry time.Duration
// Store custom claims here
type Claims struct {
UUID string `json:"uuid"`
@ -26,7 +29,6 @@ func verifyToken(token string, w http.ResponseWriter) bool {
})
if err != nil {
log.Printf("%v", err)
if err == jwt.ErrSignatureInvalid {
w.WriteHeader(http.StatusUnauthorized)
return false

View File

@ -29,6 +29,9 @@ var heavyLimiter = NewIPRateLimiter(0.1, 1)
// Limits to 5 req / sec
var standardLimiter = NewIPRateLimiter(5, 5)
// List of Reverse proxies
var ReverseProxies []string
// HandleRequests - Boot HTTP!
func HandleRequests() {
// Create a new router
@ -44,7 +47,7 @@ func HandleRequests() {
// No Auth
v1.HandleFunc("/register", limitMiddleware(handleRegister, heavyLimiter)).Methods("POST")
v1.HandleFunc("/login", limitMiddleware(serveEndpoint, standardLimiter)).Methods("POST")
v1.HandleFunc("/login", limitMiddleware(handleLogin, standardLimiter)).Methods("POST")
v1.HandleFunc("/logout", serveEndpoint).Methods("POST")
// This just prevents it serving frontend stuff over /api
@ -133,7 +136,8 @@ func handleRegister(w http.ResponseWriter, r *http.Request) {
return
}
err = createUser(&regReq)
ip := getUserIp(r)
err = createUser(&regReq, ip)
if err != nil {
throwBadReq(w, err.Error())
return
@ -144,16 +148,36 @@ func handleRegister(w http.ResponseWriter, r *http.Request) {
w.Write(msg)
}
// handleLogin - Does as it says!
func handleLogin(w http.ResponseWriter, r *http.Request) {
logReq := LoginRequest{}
decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&logReq)
if err != nil {
throwBadReq(w, err.Error())
return
}
ip := getUserIp(r)
data, err := loginUser(&logReq, ip)
if err != nil {
throwBadReq(w, err.Error())
return
}
w.WriteHeader(http.StatusOK)
w.Write(data)
}
// serveEndpoint - API stuffs
func serveEndpoint(w http.ResponseWriter, r *http.Request) {
json, err := decodeJson(r.Body)
_, err := decodeJson(r.Body)
if err != nil {
// If we can't decode. Lets tell them nicely.
http.Error(w, "{\"error\":\"Invalid JSON\"}", http.StatusBadRequest)
return
}
log.Printf("%v", json)
// Lets trick 'em for now ;) ;)
fmt.Fprintf(w, "{}")
}

View File

@ -1,10 +1,16 @@
package goscrobble
import (
"database/sql"
"encoding/json"
"errors"
"fmt"
"log"
"net"
"strings"
"time"
"github.com/dgrijalva/jwt-go"
"golang.org/x/crypto/bcrypt"
)
@ -13,8 +19,11 @@ const bCryptCost = 16
type User struct {
UUID string `json:"uuid"`
CreatedAt time.Time `json:"created_at"`
CreatedIp net.IP `json:"created_ip"`
ModifiedAt time.Time `json:"modified_at"`
ModifiedIP net.IP `jsos:"modified_ip"`
Username string `json:"username"`
password []byte
Password []byte `json:"password"`
Email string `json:"email"`
Verified bool `json:"verified"`
Active bool `json:"active"`
@ -28,18 +37,39 @@ type RegisterRequest struct {
Password string `json:"password"`
}
// RegisterRequest - Incoming JSON
type LoginRequest struct {
Username string `json:"username"`
Password string `json:"password"`
}
// LoginResponse - JWT issued
type LoginResponse struct {
Token string `json:"token"`
}
// createUser - Called from API
func createUser(req *RegisterRequest) error {
func createUser(req *RegisterRequest, ip net.IP) error {
// Check if user already exists..
if len(req.Password) < 8 {
return errors.New("Password must be at least 8 characters")
}
// Check username is set
// Check Username is set
if req.Username == "" {
return errors.New("A username is required")
}
// Check max length for Username
if len(req.Username) > 64 {
return errors.New("Username cannot be longer than 64 characters")
}
// Check username doesn't contain @
if strings.Contains(req.Username, "@") {
return errors.New("Username contains invalid characters")
}
// If set an email.. validate it!
if req.Email != "" {
if !isEmailValid(req.Email) {
@ -58,12 +88,85 @@ func createUser(req *RegisterRequest) error {
return err
}
return insertUser(req.Username, req.Email, hash)
return insertUser(req.Username, req.Email, hash, ip)
}
func loginUser(logReq *LoginRequest, ip net.IP) ([]byte, error) {
var resp []byte
var user User
if logReq.Username == "" {
return resp, errors.New("username must be set")
}
if logReq.Password == "" {
return resp, errors.New("password must be set")
}
if strings.Contains(logReq.Username, "@") {
err := db.QueryRow("SELECT BIN_TO_UUID(uuid), username, email, password FROM users WHERE email = ? AND active = 1", logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password)
if err != nil {
if err == sql.ErrNoRows {
return resp, errors.New("Invalid Username or Password")
}
}
} else {
err := db.QueryRow("SELECT BIN_TO_UUID(uuid), username, email, password FROM users WHERE username = ? AND active = 1", logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password)
if err == sql.ErrNoRows {
return resp, errors.New("Invalid Username or Password")
}
}
if !isValidPassword(logReq.Password, user) {
return resp, errors.New("Invalid Username or Password")
}
// Issue JWT + Response
token, err := generateJwt(user)
if err != nil {
log.Printf("Error generating JWT: %v", err)
return resp, errors.New("Error logging in")
}
loginResp := LoginResponse{
Token: token,
}
resp, _ = json.Marshal(&loginResp)
return resp, nil
}
func generateJwt(user User) (string, error) {
atClaims := jwt.MapClaims{}
atClaims["sub"] = user.UUID
atClaims["username"] = user.Username
atClaims["email"] = user.Email
atClaims["exp"] = time.Now().Add(JwtExpiry).Unix()
at := jwt.NewWithClaims(jwt.SigningMethodHS512, atClaims)
token, err := at.SignedString(JwtToken)
if err != nil {
return "", err
}
return token, nil
}
// insertUser - Does the dirtywork!
func insertUser(username string, email string, password []byte) error {
_, err := db.Exec("INSERT INTO users (uuid, created_at, username, email, password) VALUES (UUID_TO_BIN(UUID(), true),NOW(),?,?,?)", username, email, password)
func insertUser(username string, email string, password []byte, ip net.IP) error {
_, err := db.Exec("INSERT INTO users (uuid, created_at, created_ip, modified_at, modified_ip, username, email, password) "+
"VALUES (UUID_TO_BIN(UUID(), true),NOW(),?,NOW(),?,?,?,?)", ip, ip, username, email, password)
return err
}
func updateUser(uuid string, field string, value string, ip string) error {
_, err := db.Exec("UPDATE users SET ? = ?, modified_at = NOW(), modified_ip = ? WHERE uuid = ?", field, value, uuid, ip)
return err
}
func updateUserDirect(uuid string, field string, value string) error {
_, err := db.Exec("UPDATE users SET ? = ? WHERE uuid = ?", field, value, uuid)
return err
}
@ -75,7 +178,7 @@ func hashPassword(password string) ([]byte, error) {
// isValidPassword - Checks if password is valid
func isValidPassword(password string, user User) bool {
err := bcrypt.CompareHashAndPassword(user.password, []byte(password))
err := bcrypt.CompareHashAndPassword(user.Password, []byte(password))
if err != nil {
return false
}
@ -86,15 +189,19 @@ func isValidPassword(password string, user User) bool {
// userAlreadyExists - Returns bool indicating if a record exists for either username or email
// Using two look ups to make use of DB indexes.
func userAlreadyExists(req *RegisterRequest) bool {
var userExists int
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE username = ?", req.Username).Scan(&userExists)
if userExists > 0 {
count, err := getDbCount("SELECT COUNT(*) FROM users WHERE username = ?", req.Username)
if err != nil {
fmt.Printf("Error querying for duplicate users: %v", err)
return true
}
if count > 0 {
return true
}
if req.Email != "" {
// Only run email check if there's an email...
err = db.QueryRow("SELECT COUNT(*) FROM users WHERE email = ?", req.Email).Scan(&userExists)
count, err = getDbCount("SELECT COUNT(*) FROM users WHERE email = ?", req.Email)
}
if err != nil {
@ -102,5 +209,5 @@ func userAlreadyExists(req *RegisterRequest) bool {
return true
}
return userExists > 0
return count > 0
}

View File

@ -1,8 +1,12 @@
package goscrobble
import (
"encoding/hex"
"encoding/json"
"io"
"math/big"
"net"
"net/http"
"regexp"
)
@ -24,3 +28,54 @@ func isEmailValid(e string) bool {
}
return emailRegex.MatchString(e)
}
// contains - Check if string is in list
func contains(s []string, e string) bool {
for _, a := range s {
if a == e {
return true
}
}
return false
}
// getUserIp - Returns IP that isn't set in REVERSE_PROXY
func getUserIp(r *http.Request) net.IP {
var ip net.IP
host, _, _ := net.SplitHostPort(r.RemoteAddr)
if contains(ReverseProxies, host) {
host = r.Header.Get("X-FOWARDED-FOR")
}
ip = net.ParseIP(host)
return ip
}
// Inet_Aton converts an IPv4 net.IP object to a 64 bit integer.
func Inet_Aton(ip net.IP) int64 {
ipv4Int := big.NewInt(0)
ipv4Int.SetBytes(ip.To4())
return ipv4Int.Int64()
}
// Inet6_Aton converts an IP Address (IPv4 or IPv6) net.IP object to a hexadecimal
// representaiton. This function is the equivalent of
// inet6_aton({{ ip address }}) in MySQL.
func Inet6_Aton(ip net.IP) string {
ipv4 := false
if ip.To4() != nil {
ipv4 = true
}
ipInt := big.NewInt(0)
if ipv4 {
ipInt.SetBytes(ip.To4())
ipHex := hex.EncodeToString(ipInt.Bytes())
return ipHex
}
ipInt.SetBytes(ip.To16())
ipHex := hex.EncodeToString(ipInt.Bytes())
return ipHex
}

View File

@ -1,5 +1,5 @@
CREATE TABLE IF NOT EXISTS `config` (
`key` VARCHAR(255) NOT NULL,
`value` INT(11) NOT NULL DEFAULT 1,
`value` VARCHAR(255) NULL DEFAULT NULL,
PRIMARY KEY(`key`)
) DEFAULT CHARACTER SET = utf8mb4 COLLATE = utf8mb4_0900_ai_ci;

View File

@ -1,6 +1,9 @@
CREATE TABLE IF NOT EXISTS `users` (
`uuid` BINARY(16) PRIMARY KEY,
`created_at` DATETIME NOT NULL,
`created_ip` VARBINARY(16) NULL DEFAULT NULL,
`modified_at` DATETIME NOT NULL,
`modified_ip` VARBINARY(16) NULL DEFAULT NULL,
`username` VARCHAR(64) NOT NULL,
`password` VARCHAR(60) NOT NULL,
`email` VARCHAR(255) NULL DEFAULT NULL,