From c67be1bd752f4b976be277b8478941aca7f4cded Mon Sep 17 00:00:00 2001 From: Daniel Mason Date: Fri, 26 Mar 2021 12:21:28 +1300 Subject: [PATCH] Issue JWT on login --- .env.example | 4 +- cmd/go-scrobble/main.go | 18 +++++ internal/goscrobble/db.go | 11 +++ internal/goscrobble/jwt.go | 6 +- internal/goscrobble/server.go | 32 +++++++- internal/goscrobble/user.go | 145 +++++++++++++++++++++++++++++----- internal/goscrobble/utils.go | 55 +++++++++++++ migrations/1_init.up.sql | 2 +- migrations/2_users.up.sql | 3 + 9 files changed, 249 insertions(+), 27 deletions(-) diff --git a/.env.example b/.env.example index 70cda35c..ceaba4b8 100644 --- a/.env.example +++ b/.env.example @@ -4,4 +4,6 @@ MYSQL_PASS= MYSQL_DB= JWT_SECRET= -JWT_EXPIRY=86400 \ No newline at end of file +JWT_EXPIRY=86400 + +REVERSE_PROXIES=127.0.0.1 \ No newline at end of file diff --git a/cmd/go-scrobble/main.go b/cmd/go-scrobble/main.go index 906fcaba..4662ef56 100644 --- a/cmd/go-scrobble/main.go +++ b/cmd/go-scrobble/main.go @@ -3,6 +3,9 @@ package main import ( "log" "os" + "strconv" + "strings" + "time" "git.m2.nz/go-scrobble/internal/goscrobble" "github.com/joho/godotenv" @@ -17,6 +20,21 @@ func main() { // Store JWT secret goscrobble.JwtToken = []byte(os.Getenv("JWT_SECRET")) + // Store JWT expiry + goscrobble.JwtExpiry = 86400 + jwtExpiryStr := os.Getenv("JWT_EXPIRY") + if jwtExpiryStr != "" { + i, err := strconv.ParseFloat(jwtExpiryStr, 64) + if err != nil { + panic("Invalid JWT_EXPIRY value") + } + + goscrobble.JwtExpiry = time.Duration(i) * time.Second + } + + // Ignore reverse proxies + goscrobble.ReverseProxies = strings.Split(os.Getenv("REVERSE_PROXIES"), ",") + // // Boot up DB connection for life of application goscrobble.InitDb() defer goscrobble.CloseDbConn() diff --git a/internal/goscrobble/db.go b/internal/goscrobble/db.go index 8071afec..b30fdb1c 100644 --- a/internal/goscrobble/db.go +++ b/internal/goscrobble/db.go @@ -2,6 +2,7 @@ package goscrobble import ( "database/sql" + "errors" "fmt" "log" "os" @@ -71,3 +72,13 @@ func runMigrations() { panic(fmt.Errorf("Error running DB Migrations %v", err)) } } + +func getDbCount(query string, args ...interface{}) (int, error) { + var result int + err := db.QueryRow(query, args...).Scan(&result) + if err != nil { + return 0, errors.New("Error fetching data") + } + + return result, nil +} diff --git a/internal/goscrobble/jwt.go b/internal/goscrobble/jwt.go index 1f58158f..7b0948fb 100644 --- a/internal/goscrobble/jwt.go +++ b/internal/goscrobble/jwt.go @@ -1,8 +1,8 @@ package goscrobble import ( - "log" "net/http" + "time" "github.com/dgrijalva/jwt-go" ) @@ -10,6 +10,9 @@ import ( // JwtToken - Store token from .env var JwtToken []byte +// JwtExpiry - Expiry in seconds +var JwtExpiry time.Duration + // Store custom claims here type Claims struct { UUID string `json:"uuid"` @@ -26,7 +29,6 @@ func verifyToken(token string, w http.ResponseWriter) bool { }) if err != nil { - log.Printf("%v", err) if err == jwt.ErrSignatureInvalid { w.WriteHeader(http.StatusUnauthorized) return false diff --git a/internal/goscrobble/server.go b/internal/goscrobble/server.go index c00cfe9c..7cb74080 100644 --- a/internal/goscrobble/server.go +++ b/internal/goscrobble/server.go @@ -29,6 +29,9 @@ var heavyLimiter = NewIPRateLimiter(0.1, 1) // Limits to 5 req / sec var standardLimiter = NewIPRateLimiter(5, 5) +// List of Reverse proxies +var ReverseProxies []string + // HandleRequests - Boot HTTP! func HandleRequests() { // Create a new router @@ -44,7 +47,7 @@ func HandleRequests() { // No Auth v1.HandleFunc("/register", limitMiddleware(handleRegister, heavyLimiter)).Methods("POST") - v1.HandleFunc("/login", limitMiddleware(serveEndpoint, standardLimiter)).Methods("POST") + v1.HandleFunc("/login", limitMiddleware(handleLogin, standardLimiter)).Methods("POST") v1.HandleFunc("/logout", serveEndpoint).Methods("POST") // This just prevents it serving frontend stuff over /api @@ -133,7 +136,8 @@ func handleRegister(w http.ResponseWriter, r *http.Request) { return } - err = createUser(®Req) + ip := getUserIp(r) + err = createUser(®Req, ip) if err != nil { throwBadReq(w, err.Error()) return @@ -144,16 +148,36 @@ func handleRegister(w http.ResponseWriter, r *http.Request) { w.Write(msg) } +// handleLogin - Does as it says! +func handleLogin(w http.ResponseWriter, r *http.Request) { + logReq := LoginRequest{} + decoder := json.NewDecoder(r.Body) + err := decoder.Decode(&logReq) + if err != nil { + throwBadReq(w, err.Error()) + return + } + + ip := getUserIp(r) + data, err := loginUser(&logReq, ip) + if err != nil { + throwBadReq(w, err.Error()) + return + } + + w.WriteHeader(http.StatusOK) + w.Write(data) +} + // serveEndpoint - API stuffs func serveEndpoint(w http.ResponseWriter, r *http.Request) { - json, err := decodeJson(r.Body) + _, err := decodeJson(r.Body) if err != nil { // If we can't decode. Lets tell them nicely. http.Error(w, "{\"error\":\"Invalid JSON\"}", http.StatusBadRequest) return } - log.Printf("%v", json) // Lets trick 'em for now ;) ;) fmt.Fprintf(w, "{}") } diff --git a/internal/goscrobble/user.go b/internal/goscrobble/user.go index cefaf838..8e7feac6 100644 --- a/internal/goscrobble/user.go +++ b/internal/goscrobble/user.go @@ -1,24 +1,33 @@ package goscrobble import ( + "database/sql" + "encoding/json" "errors" "fmt" + "log" + "net" + "strings" "time" + "github.com/dgrijalva/jwt-go" "golang.org/x/crypto/bcrypt" ) const bCryptCost = 16 type User struct { - UUID string `json:"uuid"` - CreatedAt time.Time `json:"created_at"` - Username string `json:"username"` - password []byte - Email string `json:"email"` - Verified bool `json:"verified"` - Active bool `json:"active"` - Admin bool `json:"admin"` + UUID string `json:"uuid"` + CreatedAt time.Time `json:"created_at"` + CreatedIp net.IP `json:"created_ip"` + ModifiedAt time.Time `json:"modified_at"` + ModifiedIP net.IP `jsos:"modified_ip"` + Username string `json:"username"` + Password []byte `json:"password"` + Email string `json:"email"` + Verified bool `json:"verified"` + Active bool `json:"active"` + Admin bool `json:"admin"` } // RegisterRequest - Incoming JSON @@ -28,18 +37,39 @@ type RegisterRequest struct { Password string `json:"password"` } +// RegisterRequest - Incoming JSON +type LoginRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} + +// LoginResponse - JWT issued +type LoginResponse struct { + Token string `json:"token"` +} + // createUser - Called from API -func createUser(req *RegisterRequest) error { +func createUser(req *RegisterRequest, ip net.IP) error { // Check if user already exists.. if len(req.Password) < 8 { return errors.New("Password must be at least 8 characters") } - // Check username is set + // Check Username is set if req.Username == "" { return errors.New("A username is required") } + // Check max length for Username + if len(req.Username) > 64 { + return errors.New("Username cannot be longer than 64 characters") + } + + // Check username doesn't contain @ + if strings.Contains(req.Username, "@") { + return errors.New("Username contains invalid characters") + } + // If set an email.. validate it! if req.Email != "" { if !isEmailValid(req.Email) { @@ -58,12 +88,85 @@ func createUser(req *RegisterRequest) error { return err } - return insertUser(req.Username, req.Email, hash) + return insertUser(req.Username, req.Email, hash, ip) +} + +func loginUser(logReq *LoginRequest, ip net.IP) ([]byte, error) { + var resp []byte + var user User + + if logReq.Username == "" { + return resp, errors.New("username must be set") + } + + if logReq.Password == "" { + return resp, errors.New("password must be set") + } + + if strings.Contains(logReq.Username, "@") { + err := db.QueryRow("SELECT BIN_TO_UUID(uuid), username, email, password FROM users WHERE email = ? AND active = 1", logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password) + if err != nil { + if err == sql.ErrNoRows { + return resp, errors.New("Invalid Username or Password") + } + } + } else { + err := db.QueryRow("SELECT BIN_TO_UUID(uuid), username, email, password FROM users WHERE username = ? AND active = 1", logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password) + if err == sql.ErrNoRows { + return resp, errors.New("Invalid Username or Password") + } + } + + if !isValidPassword(logReq.Password, user) { + return resp, errors.New("Invalid Username or Password") + } + + // Issue JWT + Response + token, err := generateJwt(user) + if err != nil { + log.Printf("Error generating JWT: %v", err) + return resp, errors.New("Error logging in") + } + + loginResp := LoginResponse{ + Token: token, + } + + resp, _ = json.Marshal(&loginResp) + return resp, nil +} + +func generateJwt(user User) (string, error) { + atClaims := jwt.MapClaims{} + atClaims["sub"] = user.UUID + atClaims["username"] = user.Username + atClaims["email"] = user.Email + atClaims["exp"] = time.Now().Add(JwtExpiry).Unix() + at := jwt.NewWithClaims(jwt.SigningMethodHS512, atClaims) + token, err := at.SignedString(JwtToken) + if err != nil { + return "", err + } + + return token, nil } // insertUser - Does the dirtywork! -func insertUser(username string, email string, password []byte) error { - _, err := db.Exec("INSERT INTO users (uuid, created_at, username, email, password) VALUES (UUID_TO_BIN(UUID(), true),NOW(),?,?,?)", username, email, password) +func insertUser(username string, email string, password []byte, ip net.IP) error { + _, err := db.Exec("INSERT INTO users (uuid, created_at, created_ip, modified_at, modified_ip, username, email, password) "+ + "VALUES (UUID_TO_BIN(UUID(), true),NOW(),?,NOW(),?,?,?,?)", ip, ip, username, email, password) + + return err +} + +func updateUser(uuid string, field string, value string, ip string) error { + _, err := db.Exec("UPDATE users SET ? = ?, modified_at = NOW(), modified_ip = ? WHERE uuid = ?", field, value, uuid, ip) + + return err +} + +func updateUserDirect(uuid string, field string, value string) error { + _, err := db.Exec("UPDATE users SET ? = ? WHERE uuid = ?", field, value, uuid) return err } @@ -75,7 +178,7 @@ func hashPassword(password string) ([]byte, error) { // isValidPassword - Checks if password is valid func isValidPassword(password string, user User) bool { - err := bcrypt.CompareHashAndPassword(user.password, []byte(password)) + err := bcrypt.CompareHashAndPassword(user.Password, []byte(password)) if err != nil { return false } @@ -86,15 +189,19 @@ func isValidPassword(password string, user User) bool { // userAlreadyExists - Returns bool indicating if a record exists for either username or email // Using two look ups to make use of DB indexes. func userAlreadyExists(req *RegisterRequest) bool { - var userExists int - err := db.QueryRow("SELECT COUNT(*) FROM users WHERE username = ?", req.Username).Scan(&userExists) - if userExists > 0 { + count, err := getDbCount("SELECT COUNT(*) FROM users WHERE username = ?", req.Username) + if err != nil { + fmt.Printf("Error querying for duplicate users: %v", err) + return true + } + + if count > 0 { return true } if req.Email != "" { // Only run email check if there's an email... - err = db.QueryRow("SELECT COUNT(*) FROM users WHERE email = ?", req.Email).Scan(&userExists) + count, err = getDbCount("SELECT COUNT(*) FROM users WHERE email = ?", req.Email) } if err != nil { @@ -102,5 +209,5 @@ func userAlreadyExists(req *RegisterRequest) bool { return true } - return userExists > 0 + return count > 0 } diff --git a/internal/goscrobble/utils.go b/internal/goscrobble/utils.go index 60570c5e..a5dc3d39 100644 --- a/internal/goscrobble/utils.go +++ b/internal/goscrobble/utils.go @@ -1,8 +1,12 @@ package goscrobble import ( + "encoding/hex" "encoding/json" "io" + "math/big" + "net" + "net/http" "regexp" ) @@ -24,3 +28,54 @@ func isEmailValid(e string) bool { } return emailRegex.MatchString(e) } + +// contains - Check if string is in list +func contains(s []string, e string) bool { + for _, a := range s { + if a == e { + return true + } + } + return false +} + +// getUserIp - Returns IP that isn't set in REVERSE_PROXY +func getUserIp(r *http.Request) net.IP { + var ip net.IP + host, _, _ := net.SplitHostPort(r.RemoteAddr) + if contains(ReverseProxies, host) { + host = r.Header.Get("X-FOWARDED-FOR") + } + + ip = net.ParseIP(host) + + return ip +} + +// Inet_Aton converts an IPv4 net.IP object to a 64 bit integer. +func Inet_Aton(ip net.IP) int64 { + ipv4Int := big.NewInt(0) + ipv4Int.SetBytes(ip.To4()) + return ipv4Int.Int64() +} + +// Inet6_Aton converts an IP Address (IPv4 or IPv6) net.IP object to a hexadecimal +// representaiton. This function is the equivalent of +// inet6_aton({{ ip address }}) in MySQL. +func Inet6_Aton(ip net.IP) string { + ipv4 := false + if ip.To4() != nil { + ipv4 = true + } + + ipInt := big.NewInt(0) + if ipv4 { + ipInt.SetBytes(ip.To4()) + ipHex := hex.EncodeToString(ipInt.Bytes()) + return ipHex + } + + ipInt.SetBytes(ip.To16()) + ipHex := hex.EncodeToString(ipInt.Bytes()) + return ipHex +} diff --git a/migrations/1_init.up.sql b/migrations/1_init.up.sql index ad916fec..d2d1efdd 100644 --- a/migrations/1_init.up.sql +++ b/migrations/1_init.up.sql @@ -1,5 +1,5 @@ CREATE TABLE IF NOT EXISTS `config` ( `key` VARCHAR(255) NOT NULL, - `value` INT(11) NOT NULL DEFAULT 1, + `value` VARCHAR(255) NULL DEFAULT NULL, PRIMARY KEY(`key`) ) DEFAULT CHARACTER SET = utf8mb4 COLLATE = utf8mb4_0900_ai_ci; \ No newline at end of file diff --git a/migrations/2_users.up.sql b/migrations/2_users.up.sql index cdc764a6..7203af66 100644 --- a/migrations/2_users.up.sql +++ b/migrations/2_users.up.sql @@ -1,6 +1,9 @@ CREATE TABLE IF NOT EXISTS `users` ( `uuid` BINARY(16) PRIMARY KEY, `created_at` DATETIME NOT NULL, + `created_ip` VARBINARY(16) NULL DEFAULT NULL, + `modified_at` DATETIME NOT NULL, + `modified_ip` VARBINARY(16) NULL DEFAULT NULL, `username` VARCHAR(64) NOT NULL, `password` VARCHAR(60) NOT NULL, `email` VARCHAR(255) NULL DEFAULT NULL,