- Fixed looking up invalid profiles
- Added valid error handling to bad request && rate limiting
- Add Sendgrid library (Will add SMTP later)
- Complete password reset process
This commit is contained in:
Daniel Mason 2021-04-02 01:56:08 +13:00
parent e570314ac2
commit fd615102a8
Signed by: idanoo
GPG key ID: 387387CDBC02F132
28 changed files with 871 additions and 261 deletions

View file

@ -1,4 +0,0 @@
package goscrobble
func getImageLastFM(src string) {
}

View file

@ -10,6 +10,9 @@ import (
// ParseJellyfinInput - Transform API data into a common struct
func ParseJellyfinInput(userUUID string, data map[string]interface{}, ip net.IP, tx *sql.Tx) error {
// Debugging
fmt.Printf("%+v", data)
if data["ItemType"] != "Audio" {
return errors.New("Media type not audio")
}

View file

@ -0,0 +1,66 @@
package goscrobble
import (
"database/sql"
"fmt"
"net"
)
// ParseMultiScrobblerInput - Transform API data
func ParseMultiScrobblerInput(userUUID string, data map[string]interface{}, ip net.IP, tx *sql.Tx) error {
// Debugging
fmt.Printf("%+v", data)
// if data["ItemType"] != "Audio" {
// return errors.New("Media type not audio")
// }
// // Safety Checks
// if data["Artist"] == nil {
// return errors.New("Missing artist data")
// }
// if data["Album"] == nil {
// return errors.New("Missing album data")
// }
// if data["Name"] == nil {
// return errors.New("Missing track data")
// }
// // Insert artist if not exist
// artist, err := insertArtist(fmt.Sprintf("%s", data["Artist"]), fmt.Sprintf("%s", data["Provider_musicbrainzartist"]), tx)
// if err != nil {
// log.Printf("%+v", err)
// return errors.New("Failed to map artist")
// }
// // Insert album if not exist
// artists := []string{artist.Uuid}
// album, err := insertAlbum(fmt.Sprintf("%s", data["Album"]), fmt.Sprintf("%s", data["Provider_musicbrainzalbum"]), artists, tx)
// if err != nil {
// log.Printf("%+v", err)
// return errors.New("Failed to map album")
// }
// // Insert album if not exist
// track, err := insertTrack(fmt.Sprintf("%s", data["Name"]), fmt.Sprintf("%s", data["Provider_musicbrainztrack"]), album.Uuid, artists, tx)
// if err != nil {
// log.Printf("%+v", err)
// return errors.New("Failed to map track")
// }
// // Insert album if not exist
// err = insertScrobble(userUUID, track.Uuid, "jellyfin", ip, tx)
// if err != nil {
// log.Printf("%+v", err)
// return errors.New("Failed to map track")
// }
// _ = album
// _ = artist
// _ = track
// Insert track if not exist
return nil
}

View file

@ -2,38 +2,21 @@ package goscrobble
import (
"encoding/json"
"errors"
"fmt"
"log"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/gorilla/mux"
"github.com/rs/cors"
)
// spaHandler - Handles Single Page Applications (React)
type spaHandler struct {
staticPath string
indexPath string
}
type jsonResponse struct {
Err string `json:"error,omitempty"`
Msg string `json:"message,omitempty"`
Err string `json:"error,omitempty"`
Msg string `json:"message,omitempty"`
Valid bool `json:"valid,omitempty"`
}
// Limits to 1 req / 4 sec
var heavyLimiter = NewIPRateLimiter(0.25, 2)
// Limits to 5 req / sec
var standardLimiter = NewIPRateLimiter(5, 5)
// Limits to 10 req / sec
var lightLimiter = NewIPRateLimiter(10, 10)
// List of Reverse proxies
var ReverseProxies []string
@ -50,6 +33,7 @@ func HandleRequests(port string) {
// Static Token for /ingress
v1.HandleFunc("/ingress/jellyfin", tokenMiddleware(handleIngress)).Methods("POST")
v1.HandleFunc("/ingress/multiscrobbler", tokenMiddleware(handleIngress)).Methods("POST")
// JWT Auth - PWN PROFILE ONLY.
v1.HandleFunc("/user", jwtMiddleware(fetchUser)).Methods("GET")
@ -66,6 +50,8 @@ func HandleRequests(port string) {
v1.HandleFunc("/register", limitMiddleware(handleRegister, heavyLimiter)).Methods("POST")
v1.HandleFunc("/login", limitMiddleware(handleLogin, standardLimiter)).Methods("POST")
v1.HandleFunc("/sendreset", limitMiddleware(handleSendReset, heavyLimiter)).Methods("POST")
v1.HandleFunc("/resetpassword", limitMiddleware(handleResetPassword, heavyLimiter)).Methods("POST")
// This just prevents it serving frontend stuff over /api
r.PathPrefix("/api")
@ -87,161 +73,7 @@ func HandleRequests(port string) {
log.Fatal(http.ListenAndServe(":"+port, handler))
}
// MIDDLEWARE RESPONSES
// throwUnauthorized - Throws a 403
func throwUnauthorized(w http.ResponseWriter, m string) {
jr := jsonResponse{
Err: m,
}
js, _ := json.Marshal(&jr)
err := errors.New(string(js))
http.Error(w, err.Error(), http.StatusUnauthorized)
}
// throwUnauthorized - Throws a 403
func throwBadReq(w http.ResponseWriter, m string) {
jr := jsonResponse{
Err: m,
}
js, _ := json.Marshal(&jr)
err := errors.New(string(js))
http.Error(w, err.Error(), http.StatusBadRequest)
}
// throwOkError - Throws a 403
func throwOkError(w http.ResponseWriter, m string) {
jr := jsonResponse{
Err: m,
}
js, _ := json.Marshal(&jr)
w.WriteHeader(http.StatusOK)
w.Write(js)
}
// throwOkMessage - Throws a happy 200
func throwOkMessage(w http.ResponseWriter, m string) {
jr := jsonResponse{
Msg: m,
}
js, _ := json.Marshal(&jr)
w.WriteHeader(http.StatusOK)
w.Write(js)
}
// throwOkMessage - Throws a happy 200
func throwInvalidJson(w http.ResponseWriter) {
jr := jsonResponse{
Err: "Invalid JSON",
}
js, _ := json.Marshal(&jr)
w.WriteHeader(http.StatusBadRequest)
w.Write(js)
}
// generateJsonMessage - Generates a message:str response
func generateJsonMessage(m string) []byte {
jr := jsonResponse{
Msg: m,
}
js, _ := json.Marshal(&jr)
return js
}
// generateJsonError - Generates a err:str response
func generateJsonError(m string) []byte {
jr := jsonResponse{
Err: m,
}
js, _ := json.Marshal(&jr)
return js
}
// MIDDLEWARE ACTIONS
// tokenMiddleware - Validates token to a user
func tokenMiddleware(next func(http.ResponseWriter, *http.Request, string)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fullToken := r.Header.Get("Authorization")
authToken := strings.Replace(fullToken, "Bearer ", "", 1)
if authToken == "" {
throwUnauthorized(w, "A token is required")
return
}
userUuid, err := getUserUuidForToken(authToken)
if err != nil {
throwUnauthorized(w, err.Error())
return
}
next(w, r, userUuid)
}
}
// jwtMiddleware - Validates middleware to a user
func jwtMiddleware(next func(http.ResponseWriter, *http.Request, string, string)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fullToken := r.Header.Get("Authorization")
authToken := strings.Replace(fullToken, "Bearer ", "", 1)
claims, err := verifyJWTToken(authToken)
if err != nil {
throwUnauthorized(w, "Invalid JWT Token")
return
}
var reqUuid string
for k, v := range mux.Vars(r) {
if k == "uuid" {
reqUuid = v
}
}
next(w, r, claims.Subject, reqUuid)
}
}
// adminMiddleware - Validates user is admin
func adminMiddleware(next func(http.ResponseWriter, *http.Request, string)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fullToken := r.Header.Get("Authorization")
authToken := strings.Replace(fullToken, "Bearer ", "", 1)
claims, err := verifyJWTToken(authToken)
if err != nil {
throwUnauthorized(w, "Invalid JWT Token")
return
}
user, err := getUser(claims.Subject)
if err != nil {
throwUnauthorized(w, err.Error())
return
}
if !user.Admin {
throwUnauthorized(w, "User is not admin")
return
}
next(w, r, claims.Subject)
}
}
// limitMiddleware - Rate limits important stuff
func limitMiddleware(next http.HandlerFunc, limiter *IPRateLimiter) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limiter := limiter.GetLimiter(r.RemoteAddr)
if !limiter.Allow() {
msg := generateJsonMessage("Too many requests")
w.WriteHeader(http.StatusTooManyRequests)
w.Write(msg)
return
}
next(w, r)
})
}
// API ENDPOINT HANDLING
// handleRegister - Does as it says!
func handleRegister(w http.ResponseWriter, r *http.Request) {
regReq := RegisterRequest{}
@ -296,6 +128,86 @@ func handleStats(w http.ResponseWriter, r *http.Request) {
w.Write(js)
}
// handleSendReset - Does as it says!
func handleSendReset(w http.ResponseWriter, r *http.Request) {
req := RegisterRequest{}
decoder := json.NewDecoder(r.Body)
err := decoder.Decode(&req)
if err != nil {
throwBadReq(w, err.Error())
return
}
if req.Email == "" {
throwOkError(w, "Invalid Email")
return
}
_ = getUserIp(r)
user, err := getUserByEmail(req.Email)
if err != nil {
throwOkError(w, err.Error())
return
}
ip := getUserIp(r)
err = user.sendResetEmail(ip)
if err != nil {
throwOkError(w, err.Error())
return
}
throwOkMessage(w, "Password reset email sent")
}
// handleSendReset - Does as it says!
func handleResetPassword(w http.ResponseWriter, r *http.Request) {
bodyJson, err := decodeJson(r.Body)
if err != nil {
throwInvalidJson(w)
return
}
if bodyJson["password"] == nil {
// validating
valid, err := checkResetToken(fmt.Sprintf("%s", bodyJson["token"]))
if err != nil {
throwOkError(w, err.Error())
return
}
jr := jsonResponse{
Valid: valid,
}
msg, _ := json.Marshal(&jr)
w.WriteHeader(http.StatusOK)
w.Write(msg)
return
} else {
// resetting
token := fmt.Sprintf("%s", bodyJson["token"])
pw := fmt.Sprintf("%s", bodyJson["password"])
if len(pw) < 8 {
throwOkError(w, "Password must be at least 8 characters")
return
}
ip := getUserIp(r)
user, err := getUserByResetToken(token)
if err != nil {
throwOkError(w, err.Error())
return
}
err = user.updatePassword(pw, ip)
if err != nil {
throwOkError(w, err.Error())
return
}
throwOkMessage(w, "Password updated successfully!")
return
}
}
// serveEndpoint - API stuffs
func handleIngress(w http.ResponseWriter, r *http.Request, userUuid string) {
bodyJson, err := decodeJson(r.Body)
@ -304,32 +216,39 @@ func handleIngress(w http.ResponseWriter, r *http.Request, userUuid string) {
return
}
ip := getUserIp(r)
tx, _ := db.Begin()
ingressType := strings.Replace(r.URL.Path, "/api/v1/ingress/", "", 1)
switch ingressType {
case "jellyfin":
tx, _ := db.Begin()
ip := getUserIp(r)
err := ParseJellyfinInput(userUuid, bodyJson, ip, tx)
if err != nil {
// log.Printf("Error inserting track: %+v", err)
tx.Rollback()
throwOkError(w, err.Error())
return
}
err = tx.Commit()
case "multiscrobbler":
err := ParseMultiScrobblerInput(userUuid, bodyJson, ip, tx)
if err != nil {
tx.Rollback()
throwOkError(w, err.Error())
return
}
default:
tx.Rollback()
throwBadReq(w, "Unknown ingress type")
}
throwOkMessage(w, "success")
err = tx.Commit()
if err != nil {
throwOkError(w, err.Error())
return
}
throwBadReq(w, "Unknown ingress type")
throwOkMessage(w, "success")
return
}
// fetchUser - Return personal userprofile
@ -432,35 +351,3 @@ func fetchProfile(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
w.Write(json)
}
// FRONTEND HANDLING
// ServerHTTP - Frontend server
func (h spaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Get the absolute path to prevent directory traversal
path, err := filepath.Abs(r.URL.Path)
if err != nil {
// If we failed to get the absolute path respond with a 400 bad request and return
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// prepend the path with the path to the static directory
path = filepath.Join(h.staticPath, path)
// check whether a file exists at the given path
_, err = os.Stat(path)
if os.IsNotExist(err) {
// file does not exist, serve index.html
http.ServeFile(w, r, filepath.Join(h.staticPath, h.indexPath))
return
} else if err != nil {
// if we got an error (that wasn't that the file doesn't exist) stating the
// file, return a 500 internal server error and stop
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// otherwise, use http.FileServer to serve the static dir
http.FileServer(http.Dir(h.staticPath)).ServeHTTP(w, r)
}

View file

@ -0,0 +1,104 @@
package goscrobble
import (
"encoding/json"
"net/http"
"strings"
"github.com/gorilla/mux"
)
// Limits to 1 req / 4 sec
var heavyLimiter = NewIPRateLimiter(0.25, 2)
// Limits to 5 req / sec
var standardLimiter = NewIPRateLimiter(5, 5)
// Limits to 10 req / sec
var lightLimiter = NewIPRateLimiter(10, 10)
// tokenMiddleware - Validates token to a user
func tokenMiddleware(next func(http.ResponseWriter, *http.Request, string)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fullToken := r.Header.Get("Authorization")
authToken := strings.Replace(fullToken, "Bearer ", "", 1)
if authToken == "" {
throwUnauthorized(w, "A token is required")
return
}
userUuid, err := getUserUuidForToken(authToken)
if err != nil {
throwUnauthorized(w, err.Error())
return
}
next(w, r, userUuid)
}
}
// jwtMiddleware - Validates middleware to a user
func jwtMiddleware(next func(http.ResponseWriter, *http.Request, string, string)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fullToken := r.Header.Get("Authorization")
authToken := strings.Replace(fullToken, "Bearer ", "", 1)
claims, err := verifyJWTToken(authToken)
if err != nil {
throwUnauthorized(w, "Invalid JWT Token")
return
}
var reqUuid string
for k, v := range mux.Vars(r) {
if k == "uuid" {
reqUuid = v
}
}
next(w, r, claims.Subject, reqUuid)
}
}
// adminMiddleware - Validates user is admin
func adminMiddleware(next func(http.ResponseWriter, *http.Request, string)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fullToken := r.Header.Get("Authorization")
authToken := strings.Replace(fullToken, "Bearer ", "", 1)
claims, err := verifyJWTToken(authToken)
if err != nil {
throwUnauthorized(w, "Invalid JWT Token")
return
}
user, err := getUser(claims.Subject)
if err != nil {
throwUnauthorized(w, err.Error())
return
}
if !user.Admin {
throwUnauthorized(w, "User is not admin")
return
}
next(w, r, claims.Subject)
}
}
// limitMiddleware - Rate limits important stuff
func limitMiddleware(next http.HandlerFunc, limiter *IPRateLimiter) http.HandlerFunc {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
limiter := limiter.GetLimiter(r.RemoteAddr)
if !limiter.Allow() {
jr := jsonResponse{
Msg: "Too many requests",
}
msg, _ := json.Marshal(&jr)
w.WriteHeader(http.StatusTooManyRequests)
w.Write(msg)
return
}
next(w, r)
})
}

View file

@ -0,0 +1,58 @@
package goscrobble
import (
"encoding/json"
"errors"
"net/http"
)
// MIDDLEWARE RESPONSES
// throwUnauthorized - Throws a 403
func throwUnauthorized(w http.ResponseWriter, m string) {
jr := jsonResponse{
Err: m,
}
js, _ := json.Marshal(&jr)
err := errors.New(string(js))
http.Error(w, err.Error(), http.StatusUnauthorized)
}
// throwUnauthorized - Throws a 403
func throwBadReq(w http.ResponseWriter, m string) {
jr := jsonResponse{
Err: m,
}
js, _ := json.Marshal(&jr)
err := errors.New(string(js))
http.Error(w, err.Error(), http.StatusBadRequest)
}
// throwOkError - Throws a 403
func throwOkError(w http.ResponseWriter, m string) {
jr := jsonResponse{
Err: m,
}
js, _ := json.Marshal(&jr)
w.WriteHeader(http.StatusOK)
w.Write(js)
}
// throwOkMessage - Throws a happy 200
func throwOkMessage(w http.ResponseWriter, m string) {
jr := jsonResponse{
Msg: m,
}
js, _ := json.Marshal(&jr)
w.WriteHeader(http.StatusOK)
w.Write(js)
}
// throwOkMessage - Throws a happy 200
func throwInvalidJson(w http.ResponseWriter) {
jr := jsonResponse{
Err: "Invalid JSON",
}
js, _ := json.Marshal(&jr)
w.WriteHeader(http.StatusBadRequest)
w.Write(js)
}

View file

@ -0,0 +1,43 @@
package goscrobble
import (
"net/http"
"os"
"path/filepath"
)
// spaHandler - Handles Single Page Applications (React)
type spaHandler struct {
staticPath string
indexPath string
}
// ServerHTTP - Frontend React server
func (h spaHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// Get the absolute path to prevent directory traversal
path, err := filepath.Abs(r.URL.Path)
if err != nil {
// If we failed to get the absolute path respond with a 400 bad request and return
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
// prepend the path with the path to the static directory
path = filepath.Join(h.staticPath, path)
// check whether a file exists at the given path
_, err = os.Stat(path)
if os.IsNotExist(err) {
// file does not exist, serve index.html
http.ServeFile(w, r, filepath.Join(h.staticPath, h.indexPath))
return
} else if err != nil {
// if we got an error (that wasn't that the file doesn't exist) stating the
// file, return a 500 internal server error and stop
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// otherwise, use http.FileServer to serve the static dir
http.FileServer(http.Dir(h.staticPath)).ServeHTTP(w, r)
}

View file

@ -0,0 +1,18 @@
package goscrobble
import (
"os"
"github.com/sendgrid/sendgrid-go"
"github.com/sendgrid/sendgrid-go/helpers/mail"
)
func sendEmail(destName string, destEmail string, subject string, content string) error {
from := mail.NewEmail(os.Getenv("MAIL_FROM_NAME"), os.Getenv("MAIL_FROM_ADDRESS"))
to := mail.NewEmail(destName, destEmail)
message := mail.NewSingleEmail(from, subject, to, content, "")
client := sendgrid.NewSendClient(os.Getenv("SENDGRID_API_KEY"))
_, err := client.Send(message)
return err
}

View file

@ -0,0 +1,14 @@
package goscrobble
import (
"fmt"
"time"
)
func ClearTokenTimer() {
go func() {
for now := range time.Tick(time.Second) {
fmt.Println(now)
}
}()
}

View file

@ -7,6 +7,7 @@ import (
"fmt"
"log"
"net"
"os"
"strings"
"time"
@ -153,13 +154,13 @@ func insertUser(username string, email string, password []byte, ip net.IP) error
}
func updateUser(uuid string, field string, value string, ip net.IP) error {
_, err := db.Exec("UPDATE users SET ? = ?, modified_at = NOW(), modified_ip = ? WHERE uuid = ?", field, value, uuid, ip)
_, err := db.Exec("UPDATE users SET `"+field+"` = ?, modified_at = NOW(), modified_ip = ? WHERE uuid = ?", value, uuid, ip)
return err
}
func updateUserDirect(uuid string, field string, value string) error {
_, err := db.Exec("UPDATE users SET ? = ? WHERE uuid = ?", field, value, uuid)
_, err := db.Exec("UPDATE users SET `"+field+"` = ? WHERE uuid = ?", value, uuid)
return err
}
@ -228,3 +229,90 @@ func getUserByUsername(username string) (User, error) {
return user, nil
}
func getUserByEmail(email string) (User, error) {
var user User
err := db.QueryRow("SELECT BIN_TO_UUID(`uuid`, true), `created_at`, `created_ip`, `modified_at`, `modified_ip`, `username`, `email`, `password`, `verified`, `admin` FROM `users` WHERE `email` = ? AND `active` = 1",
email).Scan(&user.UUID, &user.CreatedAt, &user.CreatedIp, &user.ModifiedAt, &user.ModifiedIP, &user.Username, &user.Email, &user.Password, &user.Verified, &user.Admin)
if err == sql.ErrNoRows {
return user, errors.New("Invalid Email")
}
return user, nil
}
func getUserByResetToken(token string) (User, error) {
var user User
err := db.QueryRow("SELECT BIN_TO_UUID(`users`.`uuid`, true), `created_at`, `created_ip`, `modified_at`, `modified_ip`, `username`, `email`, `password`, `verified`, `admin` FROM `users` "+
"JOIN `resettoken` ON `resettoken`.`user` = `users`.`uuid` WHERE `resettoken`.`token` = ? AND `active` = 1",
token).Scan(&user.UUID, &user.CreatedAt, &user.CreatedIp, &user.ModifiedAt, &user.ModifiedIP, &user.Username, &user.Email, &user.Password, &user.Verified, &user.Admin)
fmt.Println(err)
if err == sql.ErrNoRows {
return user, errors.New("Invalid Token")
}
return user, nil
}
func (user *User) sendResetEmail(ip net.IP) error {
token := generateToken(16)
// 24 hours
exp := time.Now().AddDate(0, 0, 1)
err := user.saveResetToken(token, exp)
if err != nil {
return err
}
content := fmt.Sprintf(
"Someone at %s has request a password reset for %s. Click the following link to reset your password: %s/reset/%s",
ip, user.Username, os.Getenv("GOSCROBBLE_DOMAIN"), token)
return sendEmail(user.Username, user.Email, "GoScrobble - Password Reset", content)
}
func (user *User) saveResetToken(token string, expiry time.Time) error {
_, _ = db.Exec("DELETE FROM `resettoken` WHERE `user` = UUID_TO_BIN(?, true)", user.UUID)
_, err := db.Exec("INSERT INTO `resettoken` (`user`, `token`, `expiry`) "+
"VALUES (UUID_TO_BIN(?, true),?, ?)", user.UUID, token, expiry)
return err
}
func clearOldResetTokens() {
_, _ = db.Exec("DELETE FROM `resettoken` WHERE `expiry` < NOW()")
}
func clearResetToken(token string) error {
_, err := db.Exec("DELETE FROM `resettoken` WHERE `token` = ?", token)
return err
}
// checkResetToken - If a token exists check it
func checkResetToken(token string) (bool, error) {
count, err := getDbCount("SELECT COUNT(*) FROM `resettoken` WHERE `token` = ? ", token)
if err != nil {
return false, err
}
return count > 0, nil
}
func (user *User) updatePassword(newPassword string, ip net.IP) error {
hash, err := hashPassword(newPassword)
if err != nil {
return errors.New("Bad password")
}
_, err = db.Exec("UPDATE `users` SET `password` = ? WHERE `uuid` = UUID_TO_BIN(?, true)", hash, user.UUID)
if err != nil {
return errors.New("Failed to update password")
}
return nil
}