mirror of
https://github.com/idanoo/GoScrobble.git
synced 2024-11-21 16:11:56 +00:00
Issue JWT on login
This commit is contained in:
parent
12f4fb6d89
commit
c67be1bd75
@ -4,4 +4,6 @@ MYSQL_PASS=
|
||||
MYSQL_DB=
|
||||
|
||||
JWT_SECRET=
|
||||
JWT_EXPIRY=86400
|
||||
JWT_EXPIRY=86400
|
||||
|
||||
REVERSE_PROXIES=127.0.0.1
|
@ -3,6 +3,9 @@ package main
|
||||
import (
|
||||
"log"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"git.m2.nz/go-scrobble/internal/goscrobble"
|
||||
"github.com/joho/godotenv"
|
||||
@ -17,6 +20,21 @@ func main() {
|
||||
// Store JWT secret
|
||||
goscrobble.JwtToken = []byte(os.Getenv("JWT_SECRET"))
|
||||
|
||||
// Store JWT expiry
|
||||
goscrobble.JwtExpiry = 86400
|
||||
jwtExpiryStr := os.Getenv("JWT_EXPIRY")
|
||||
if jwtExpiryStr != "" {
|
||||
i, err := strconv.ParseFloat(jwtExpiryStr, 64)
|
||||
if err != nil {
|
||||
panic("Invalid JWT_EXPIRY value")
|
||||
}
|
||||
|
||||
goscrobble.JwtExpiry = time.Duration(i) * time.Second
|
||||
}
|
||||
|
||||
// Ignore reverse proxies
|
||||
goscrobble.ReverseProxies = strings.Split(os.Getenv("REVERSE_PROXIES"), ",")
|
||||
|
||||
// // Boot up DB connection for life of application
|
||||
goscrobble.InitDb()
|
||||
defer goscrobble.CloseDbConn()
|
||||
|
@ -2,6 +2,7 @@ package goscrobble
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
@ -71,3 +72,13 @@ func runMigrations() {
|
||||
panic(fmt.Errorf("Error running DB Migrations %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func getDbCount(query string, args ...interface{}) (int, error) {
|
||||
var result int
|
||||
err := db.QueryRow(query, args...).Scan(&result)
|
||||
if err != nil {
|
||||
return 0, errors.New("Error fetching data")
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
package goscrobble
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
@ -10,6 +10,9 @@ import (
|
||||
// JwtToken - Store token from .env
|
||||
var JwtToken []byte
|
||||
|
||||
// JwtExpiry - Expiry in seconds
|
||||
var JwtExpiry time.Duration
|
||||
|
||||
// Store custom claims here
|
||||
type Claims struct {
|
||||
UUID string `json:"uuid"`
|
||||
@ -26,7 +29,6 @@ func verifyToken(token string, w http.ResponseWriter) bool {
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Printf("%v", err)
|
||||
if err == jwt.ErrSignatureInvalid {
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
return false
|
||||
|
@ -29,6 +29,9 @@ var heavyLimiter = NewIPRateLimiter(0.1, 1)
|
||||
// Limits to 5 req / sec
|
||||
var standardLimiter = NewIPRateLimiter(5, 5)
|
||||
|
||||
// List of Reverse proxies
|
||||
var ReverseProxies []string
|
||||
|
||||
// HandleRequests - Boot HTTP!
|
||||
func HandleRequests() {
|
||||
// Create a new router
|
||||
@ -44,7 +47,7 @@ func HandleRequests() {
|
||||
|
||||
// No Auth
|
||||
v1.HandleFunc("/register", limitMiddleware(handleRegister, heavyLimiter)).Methods("POST")
|
||||
v1.HandleFunc("/login", limitMiddleware(serveEndpoint, standardLimiter)).Methods("POST")
|
||||
v1.HandleFunc("/login", limitMiddleware(handleLogin, standardLimiter)).Methods("POST")
|
||||
v1.HandleFunc("/logout", serveEndpoint).Methods("POST")
|
||||
|
||||
// This just prevents it serving frontend stuff over /api
|
||||
@ -133,7 +136,8 @@ func handleRegister(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
err = createUser(®Req)
|
||||
ip := getUserIp(r)
|
||||
err = createUser(®Req, ip)
|
||||
if err != nil {
|
||||
throwBadReq(w, err.Error())
|
||||
return
|
||||
@ -144,16 +148,36 @@ func handleRegister(w http.ResponseWriter, r *http.Request) {
|
||||
w.Write(msg)
|
||||
}
|
||||
|
||||
// handleLogin - Does as it says!
|
||||
func handleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
logReq := LoginRequest{}
|
||||
decoder := json.NewDecoder(r.Body)
|
||||
err := decoder.Decode(&logReq)
|
||||
if err != nil {
|
||||
throwBadReq(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ip := getUserIp(r)
|
||||
data, err := loginUser(&logReq, ip)
|
||||
if err != nil {
|
||||
throwBadReq(w, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
w.WriteHeader(http.StatusOK)
|
||||
w.Write(data)
|
||||
}
|
||||
|
||||
// serveEndpoint - API stuffs
|
||||
func serveEndpoint(w http.ResponseWriter, r *http.Request) {
|
||||
json, err := decodeJson(r.Body)
|
||||
_, err := decodeJson(r.Body)
|
||||
if err != nil {
|
||||
// If we can't decode. Lets tell them nicely.
|
||||
http.Error(w, "{\"error\":\"Invalid JSON\"}", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log.Printf("%v", json)
|
||||
// Lets trick 'em for now ;) ;)
|
||||
fmt.Fprintf(w, "{}")
|
||||
}
|
||||
|
@ -1,24 +1,33 @@
|
||||
package goscrobble
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"log"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dgrijalva/jwt-go"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
const bCryptCost = 16
|
||||
|
||||
type User struct {
|
||||
UUID string `json:"uuid"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
Username string `json:"username"`
|
||||
password []byte
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"verified"`
|
||||
Active bool `json:"active"`
|
||||
Admin bool `json:"admin"`
|
||||
UUID string `json:"uuid"`
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
CreatedIp net.IP `json:"created_ip"`
|
||||
ModifiedAt time.Time `json:"modified_at"`
|
||||
ModifiedIP net.IP `jsos:"modified_ip"`
|
||||
Username string `json:"username"`
|
||||
Password []byte `json:"password"`
|
||||
Email string `json:"email"`
|
||||
Verified bool `json:"verified"`
|
||||
Active bool `json:"active"`
|
||||
Admin bool `json:"admin"`
|
||||
}
|
||||
|
||||
// RegisterRequest - Incoming JSON
|
||||
@ -28,18 +37,39 @@ type RegisterRequest struct {
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// RegisterRequest - Incoming JSON
|
||||
type LoginRequest struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
||||
|
||||
// LoginResponse - JWT issued
|
||||
type LoginResponse struct {
|
||||
Token string `json:"token"`
|
||||
}
|
||||
|
||||
// createUser - Called from API
|
||||
func createUser(req *RegisterRequest) error {
|
||||
func createUser(req *RegisterRequest, ip net.IP) error {
|
||||
// Check if user already exists..
|
||||
if len(req.Password) < 8 {
|
||||
return errors.New("Password must be at least 8 characters")
|
||||
}
|
||||
|
||||
// Check username is set
|
||||
// Check Username is set
|
||||
if req.Username == "" {
|
||||
return errors.New("A username is required")
|
||||
}
|
||||
|
||||
// Check max length for Username
|
||||
if len(req.Username) > 64 {
|
||||
return errors.New("Username cannot be longer than 64 characters")
|
||||
}
|
||||
|
||||
// Check username doesn't contain @
|
||||
if strings.Contains(req.Username, "@") {
|
||||
return errors.New("Username contains invalid characters")
|
||||
}
|
||||
|
||||
// If set an email.. validate it!
|
||||
if req.Email != "" {
|
||||
if !isEmailValid(req.Email) {
|
||||
@ -58,12 +88,85 @@ func createUser(req *RegisterRequest) error {
|
||||
return err
|
||||
}
|
||||
|
||||
return insertUser(req.Username, req.Email, hash)
|
||||
return insertUser(req.Username, req.Email, hash, ip)
|
||||
}
|
||||
|
||||
func loginUser(logReq *LoginRequest, ip net.IP) ([]byte, error) {
|
||||
var resp []byte
|
||||
var user User
|
||||
|
||||
if logReq.Username == "" {
|
||||
return resp, errors.New("username must be set")
|
||||
}
|
||||
|
||||
if logReq.Password == "" {
|
||||
return resp, errors.New("password must be set")
|
||||
}
|
||||
|
||||
if strings.Contains(logReq.Username, "@") {
|
||||
err := db.QueryRow("SELECT BIN_TO_UUID(uuid), username, email, password FROM users WHERE email = ? AND active = 1", logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return resp, errors.New("Invalid Username or Password")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
err := db.QueryRow("SELECT BIN_TO_UUID(uuid), username, email, password FROM users WHERE username = ? AND active = 1", logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password)
|
||||
if err == sql.ErrNoRows {
|
||||
return resp, errors.New("Invalid Username or Password")
|
||||
}
|
||||
}
|
||||
|
||||
if !isValidPassword(logReq.Password, user) {
|
||||
return resp, errors.New("Invalid Username or Password")
|
||||
}
|
||||
|
||||
// Issue JWT + Response
|
||||
token, err := generateJwt(user)
|
||||
if err != nil {
|
||||
log.Printf("Error generating JWT: %v", err)
|
||||
return resp, errors.New("Error logging in")
|
||||
}
|
||||
|
||||
loginResp := LoginResponse{
|
||||
Token: token,
|
||||
}
|
||||
|
||||
resp, _ = json.Marshal(&loginResp)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func generateJwt(user User) (string, error) {
|
||||
atClaims := jwt.MapClaims{}
|
||||
atClaims["sub"] = user.UUID
|
||||
atClaims["username"] = user.Username
|
||||
atClaims["email"] = user.Email
|
||||
atClaims["exp"] = time.Now().Add(JwtExpiry).Unix()
|
||||
at := jwt.NewWithClaims(jwt.SigningMethodHS512, atClaims)
|
||||
token, err := at.SignedString(JwtToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// insertUser - Does the dirtywork!
|
||||
func insertUser(username string, email string, password []byte) error {
|
||||
_, err := db.Exec("INSERT INTO users (uuid, created_at, username, email, password) VALUES (UUID_TO_BIN(UUID(), true),NOW(),?,?,?)", username, email, password)
|
||||
func insertUser(username string, email string, password []byte, ip net.IP) error {
|
||||
_, err := db.Exec("INSERT INTO users (uuid, created_at, created_ip, modified_at, modified_ip, username, email, password) "+
|
||||
"VALUES (UUID_TO_BIN(UUID(), true),NOW(),?,NOW(),?,?,?,?)", ip, ip, username, email, password)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func updateUser(uuid string, field string, value string, ip string) error {
|
||||
_, err := db.Exec("UPDATE users SET ? = ?, modified_at = NOW(), modified_ip = ? WHERE uuid = ?", field, value, uuid, ip)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func updateUserDirect(uuid string, field string, value string) error {
|
||||
_, err := db.Exec("UPDATE users SET ? = ? WHERE uuid = ?", field, value, uuid)
|
||||
|
||||
return err
|
||||
}
|
||||
@ -75,7 +178,7 @@ func hashPassword(password string) ([]byte, error) {
|
||||
|
||||
// isValidPassword - Checks if password is valid
|
||||
func isValidPassword(password string, user User) bool {
|
||||
err := bcrypt.CompareHashAndPassword(user.password, []byte(password))
|
||||
err := bcrypt.CompareHashAndPassword(user.Password, []byte(password))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
@ -86,15 +189,19 @@ func isValidPassword(password string, user User) bool {
|
||||
// userAlreadyExists - Returns bool indicating if a record exists for either username or email
|
||||
// Using two look ups to make use of DB indexes.
|
||||
func userAlreadyExists(req *RegisterRequest) bool {
|
||||
var userExists int
|
||||
err := db.QueryRow("SELECT COUNT(*) FROM users WHERE username = ?", req.Username).Scan(&userExists)
|
||||
if userExists > 0 {
|
||||
count, err := getDbCount("SELECT COUNT(*) FROM users WHERE username = ?", req.Username)
|
||||
if err != nil {
|
||||
fmt.Printf("Error querying for duplicate users: %v", err)
|
||||
return true
|
||||
}
|
||||
|
||||
if count > 0 {
|
||||
return true
|
||||
}
|
||||
|
||||
if req.Email != "" {
|
||||
// Only run email check if there's an email...
|
||||
err = db.QueryRow("SELECT COUNT(*) FROM users WHERE email = ?", req.Email).Scan(&userExists)
|
||||
count, err = getDbCount("SELECT COUNT(*) FROM users WHERE email = ?", req.Email)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
@ -102,5 +209,5 @@ func userAlreadyExists(req *RegisterRequest) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
return userExists > 0
|
||||
return count > 0
|
||||
}
|
||||
|
@ -1,8 +1,12 @@
|
||||
package goscrobble
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"regexp"
|
||||
)
|
||||
|
||||
@ -24,3 +28,54 @@ func isEmailValid(e string) bool {
|
||||
}
|
||||
return emailRegex.MatchString(e)
|
||||
}
|
||||
|
||||
// contains - Check if string is in list
|
||||
func contains(s []string, e string) bool {
|
||||
for _, a := range s {
|
||||
if a == e {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getUserIp - Returns IP that isn't set in REVERSE_PROXY
|
||||
func getUserIp(r *http.Request) net.IP {
|
||||
var ip net.IP
|
||||
host, _, _ := net.SplitHostPort(r.RemoteAddr)
|
||||
if contains(ReverseProxies, host) {
|
||||
host = r.Header.Get("X-FOWARDED-FOR")
|
||||
}
|
||||
|
||||
ip = net.ParseIP(host)
|
||||
|
||||
return ip
|
||||
}
|
||||
|
||||
// Inet_Aton converts an IPv4 net.IP object to a 64 bit integer.
|
||||
func Inet_Aton(ip net.IP) int64 {
|
||||
ipv4Int := big.NewInt(0)
|
||||
ipv4Int.SetBytes(ip.To4())
|
||||
return ipv4Int.Int64()
|
||||
}
|
||||
|
||||
// Inet6_Aton converts an IP Address (IPv4 or IPv6) net.IP object to a hexadecimal
|
||||
// representaiton. This function is the equivalent of
|
||||
// inet6_aton({{ ip address }}) in MySQL.
|
||||
func Inet6_Aton(ip net.IP) string {
|
||||
ipv4 := false
|
||||
if ip.To4() != nil {
|
||||
ipv4 = true
|
||||
}
|
||||
|
||||
ipInt := big.NewInt(0)
|
||||
if ipv4 {
|
||||
ipInt.SetBytes(ip.To4())
|
||||
ipHex := hex.EncodeToString(ipInt.Bytes())
|
||||
return ipHex
|
||||
}
|
||||
|
||||
ipInt.SetBytes(ip.To16())
|
||||
ipHex := hex.EncodeToString(ipInt.Bytes())
|
||||
return ipHex
|
||||
}
|
||||
|
@ -1,5 +1,5 @@
|
||||
CREATE TABLE IF NOT EXISTS `config` (
|
||||
`key` VARCHAR(255) NOT NULL,
|
||||
`value` INT(11) NOT NULL DEFAULT 1,
|
||||
`value` VARCHAR(255) NULL DEFAULT NULL,
|
||||
PRIMARY KEY(`key`)
|
||||
) DEFAULT CHARACTER SET = utf8mb4 COLLATE = utf8mb4_0900_ai_ci;
|
@ -1,6 +1,9 @@
|
||||
CREATE TABLE IF NOT EXISTS `users` (
|
||||
`uuid` BINARY(16) PRIMARY KEY,
|
||||
`created_at` DATETIME NOT NULL,
|
||||
`created_ip` VARBINARY(16) NULL DEFAULT NULL,
|
||||
`modified_at` DATETIME NOT NULL,
|
||||
`modified_ip` VARBINARY(16) NULL DEFAULT NULL,
|
||||
`username` VARCHAR(64) NOT NULL,
|
||||
`password` VARCHAR(60) NOT NULL,
|
||||
`email` VARCHAR(255) NULL DEFAULT NULL,
|
||||
|
Loading…
Reference in New Issue
Block a user