- Login flow working..

- Jellyfin scrobble working
- Returns scrobbles via API for authed users /api/v1/user/{uuid}/scrobble
- Add redis handler + funcs
- Move middleware to pass in uuid as needed
This commit is contained in:
Daniel Mason 2021-03-29 20:56:34 +13:00
parent c83c086cdd
commit 5fd9d41069
Signed by: idanoo
GPG key ID: 387387CDBC02F132
54 changed files with 1093 additions and 386 deletions

View file

@ -29,7 +29,7 @@ func InitDb() {
dbTz = "&loc=" + strings.Replace(timeZone, "/", fmt.Sprintf("%%2F"), 1)
}
dbConn, err := sql.Open("mysql", dbUser+":"+dbPass+"@tcp("+dbHost+")/"+dbName+"?multiStatements=true"+dbTz)
dbConn, err := sql.Open("mysql", dbUser+":"+dbPass+"@tcp("+dbHost+")/"+dbName+"?multiStatements=true&parseTime=true"+dbTz)
if err != nil {
panic(err)
}

View file

@ -1,7 +1,6 @@
package goscrobble
import (
"net/http"
"time"
"github.com/dgrijalva/jwt-go"
@ -13,34 +12,51 @@ var JwtToken []byte
// JwtExpiry - Expiry in seconds
var JwtExpiry time.Duration
// Store custom claims here
type Claims struct {
UUID string `json:"uuid"`
type CustomClaims struct {
Username string `json:"username"`
Email string `json:"email"`
jwt.StandardClaims
}
// verifyToken - Verifies the JWT is valid
func verifyToken(token string, w http.ResponseWriter) bool {
// Initialize a new instance of `Claims`
claims := &Claims{}
func generateJWTToken(user User) (string, error) {
atClaims := jwt.MapClaims{}
atClaims["sub"] = user.UUID
atClaims["username"] = user.Username
atClaims["email"] = user.Email
atClaims["iat"] = time.Now().Unix()
atClaims["exp"] = time.Now().Add(JwtExpiry).Unix()
at := jwt.NewWithClaims(jwt.SigningMethodHS512, atClaims)
token, err := at.SignedString(JwtToken)
if err != nil {
return "", err
}
tkn, err := jwt.ParseWithClaims(token, claims, func(JwtToken *jwt.Token) (interface{}, error) {
return token, nil
}
// verifyToken - Verifies the JWT is valid
func verifyJWTToken(token string) (CustomClaims, error) {
// Initialize a new instance of `Claims`
claims := CustomClaims{}
_, err := jwt.ParseWithClaims(token, &claims, func(token *jwt.Token) (interface{}, error) {
return JwtToken, nil
})
// Verify Signature
if err != nil {
if err == jwt.ErrSignatureInvalid {
w.WriteHeader(http.StatusUnauthorized)
return false
}
w.WriteHeader(http.StatusBadRequest)
return false
}
if !tkn.Valid {
w.WriteHeader(http.StatusUnauthorized)
return false
return claims, err
}
return true
// Verify expiry
err = claims.Valid()
if err != nil {
return claims, err
}
return claims, err
}
func getClaims(token *jwt.Token) CustomClaims {
claims, _ := token.Claims.(CustomClaims)
return claims
}

View file

@ -0,0 +1,75 @@
package goscrobble
import (
"context"
"fmt"
"log"
"os"
"strconv"
"time"
"github.com/go-redis/redis/v8"
)
var redisDb *redis.Client
var redisPrefix string
var ctx = context.Background()
// InitRedis - Boot redis connection!
func InitRedis() {
redisHost := os.Getenv("REDIS_HOST")
redisPort := os.Getenv("REDIS_PORT")
redisDatabase := os.Getenv("REDIS_DB")
redisAuth := os.Getenv("REDIS_AUTH")
redisPrefix = os.Getenv("REDIS_PREFIX")
redisDbNum := 0
if redisDatabase != "" {
redisDbNum, _ = strconv.Atoi(redisDatabase)
}
// Create new connection
redisDb = redis.NewClient(&redis.Options{
Addr: redisHost + ":" + redisPort,
Password: redisAuth,
DB: redisDbNum,
})
// Lets just check it's active..
err := redisDb.Set(ctx, "testSetKey", "value", 0).Err()
if err != nil {
panic(err)
}
redisDb.Del(ctx, "testSetKey")
fmt.Println("Redis connected")
}
func CloseRedisConn() {
redisDb.Close()
}
// setRedis - Uses default 24 hour TTL
func setRedisVal(key string, val string) error {
ttl := time.Hour * time.Duration(24)
return setRedisValTtl(key, val, ttl)
}
// setRedisTtl - Allows custom TTL
func setRedisValTtl(key string, val string, ttl time.Duration) error {
return redisDb.Set(ctx, redisPrefix+key, val, 0).Err()
}
// getRedisVal - Returns value if exists
func getRedisVal(key string) string {
val, err := redisDb.Get(ctx, redisPrefix+key).Result()
if err != nil {
if err == redis.Nil {
return ""
}
log.Printf("Failed to fetch redis key (%+v) Error: %+v", key, err)
}
return val
}

View file

@ -16,6 +16,25 @@ type Scrobble struct {
Track string `json:"track"`
}
type ScrobbleRequest struct {
Meta ScrobbleRequestMeta `json:"meta"`
Items []ScrobbleRequestItem `json:"items"`
}
type ScrobbleRequestMeta struct {
Count int `json:"count"`
Total int `json:"total"`
Page int `json:"page"`
}
type ScrobbleRequestItem struct {
UUID string `json:"uuid"`
Timestamp time.Time `json:"time"`
Artist string `json:"artist"`
Album string `json:"album"`
Track string `json:"track"`
}
// insertScrobble - This will return if it exists or create it based on MBID > Name
func insertScrobble(user string, track string, ip net.IP, tx *sql.Tx) error {
err := insertNewScrobble(user, track, ip, tx)
@ -27,19 +46,66 @@ func insertScrobble(user string, track string, ip net.IP, tx *sql.Tx) error {
return nil
}
func fetchScrobble(col string, val string, tx *sql.Tx) Scrobble {
var scrobble Scrobble
err := tx.QueryRow(
"SELECT BIN_TO_UUID(`uuid`, true), `created_at`, `created_ip`, `user`, `track` FROM `scrobbles` WHERE `"+col+"` = ?",
val).Scan(&scrobble.Uuid, &scrobble.CreatedAt, &scrobble.CreatedIp, &scrobble.User, &scrobble.Track)
func fetchScrobblesForUser(userUuid string, page int) (ScrobbleRequest, error) {
scrobbleReq := ScrobbleRequest{}
var count int
// Yeah this isn't great. But for now.. it works! Cache later
total, err := getDbCount(
"SELECT COUNT(*) FROM `scrobbles` "+
"JOIN tracks ON scrobbles.track = tracks.uuid "+
"JOIN track_artist ON track_artist.track = tracks.uuid "+
"JOIN track_album ON track_album.track = tracks.uuid "+
"JOIN artists ON track_artist.artist = artists.uuid "+
"JOIN albums ON track_album.album = albums.uuid "+
"JOIN users ON scrobbles.user = users.uuid "+
"WHERE user = UUID_TO_BIN(?, true)",
userUuid)
if err != nil {
if err != sql.ErrNoRows {
log.Printf("Error fetching scrobbles: %+v", err)
}
log.Printf("Failed to fetch scrobble count: %+v", err)
return scrobbleReq, errors.New("Failed to fetch scrobbles")
}
return scrobble
rows, err := db.Query(
"SELECT BIN_TO_UUID(`scrobbles`.`uuid`, true), `scrobbles`.`created_at`, `artists`.`name`, `albums`.`name`,`tracks`.`name` FROM `scrobbles` "+
"JOIN tracks ON scrobbles.track = tracks.uuid "+
"JOIN track_artist ON track_artist.track = tracks.uuid "+
"JOIN track_album ON track_album.track = tracks.uuid "+
"JOIN artists ON track_artist.artist = artists.uuid "+
"JOIN albums ON track_album.album = albums.uuid "+
"JOIN users ON scrobbles.user = users.uuid "+
"WHERE user = UUID_TO_BIN(?, true) "+
"ORDER BY scrobbles.created_at DESC LIMIT 500",
userUuid)
if err != nil {
log.Printf("Failed to fetch scrobbles: %+v", err)
return scrobbleReq, errors.New("Failed to fetch scrobbles")
}
defer rows.Close()
for rows.Next() {
item := ScrobbleRequestItem{}
err := rows.Scan(&item.UUID, &item.Timestamp, &item.Artist, &item.Album, &item.Track)
if err != nil {
log.Printf("Failed to fetch scrobbles: %+v", err)
return scrobbleReq, errors.New("Failed to fetch scrobbles")
}
count++
scrobbleReq.Items = append(scrobbleReq.Items, item)
}
err = rows.Err()
if err != nil {
log.Printf("Failed to fetch scrobbles: %+v", err)
return scrobbleReq, errors.New("Failed to fetch scrobbles")
}
scrobbleReq.Meta.Count = count
scrobbleReq.Meta.Total = total
scrobbleReq.Meta.Page = page
return scrobbleReq, nil
}
func insertNewScrobble(user string, track string, ip net.IP, tx *sql.Tx) error {

View file

@ -46,16 +46,14 @@ func HandleRequests(port string) {
v1 := r.PathPrefix("/api/v1").Subrouter()
// Static Token for /ingress
v1.HandleFunc("/ingress/jellyfin", tokenMiddleware(handleIngress))
v1.HandleFunc("/ingress/jellyfin", tokenMiddleware(handleIngress)).Methods("POST")
// JWT Auth
// v1.HandleFunc("/profile/{id}", jwtMiddleware(handleIngress))
v1.HandleFunc("/user/{id}/scrobbles", jwtMiddleware(fetchScrobbleResponse)).Methods("GET")
// No Auth
v1.HandleFunc("/register", limitMiddleware(handleRegister, heavyLimiter)).Methods("POST")
v1.HandleFunc("/login", limitMiddleware(handleLogin, standardLimiter)).Methods("POST")
// For now just trash JWT in frontend until we have full state management "Good enough"
// v1.HandleFunc("/logout", handleIngress).Methods("POST")
// This just prevents it serving frontend stuff over /api
r.PathPrefix("/api")
@ -65,9 +63,10 @@ func HandleRequests(port string) {
r.PathPrefix("/").Handler(spa)
c := cors.New(cors.Options{
// Grrrr CORS
// Grrrr CORS. To clean up at a later date
AllowedOrigins: []string{"*"},
AllowCredentials: true,
AllowedHeaders: []string{"*"},
})
handler := c.Handler(r)
@ -97,14 +96,24 @@ func throwBadReq(w http.ResponseWriter, m string) {
http.Error(w, err.Error(), http.StatusBadRequest)
}
// throwOkError - Throws a 403
func throwOkError(w http.ResponseWriter, m string) {
jr := jsonResponse{
Err: m,
}
js, _ := json.Marshal(&jr)
w.WriteHeader(http.StatusOK)
w.Write(js)
}
// throwOkMessage - Throws a happy 200
func throwOkMessage(w http.ResponseWriter, m string) {
jr := jsonResponse{
Msg: m,
}
js, _ := json.Marshal(&jr)
err := errors.New(string(js))
http.Error(w, err.Error(), http.StatusOK)
w.WriteHeader(http.StatusOK)
w.Write(js)
}
// generateJsonMessage - Generates a message:str response
@ -126,7 +135,7 @@ func generateJsonError(m string) []byte {
}
// tokenMiddleware - Validates token to a user
func tokenMiddleware(next http.HandlerFunc) http.HandlerFunc {
func tokenMiddleware(next func(http.ResponseWriter, *http.Request, string)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
fullToken := r.Header.Get("Authorization")
authToken := strings.Replace(fullToken, "Bearer ", "", 1)
@ -140,17 +149,29 @@ func tokenMiddleware(next http.HandlerFunc) http.HandlerFunc {
return
}
// Lets tack this on the request for now..
r.Header.Set("UserUUID", userUuid)
next(w, r)
next(w, r, userUuid)
}
}
// jwtMiddleware - Validates middleware to a user
func jwtMiddleware(next http.HandlerFunc) http.HandlerFunc {
func jwtMiddleware(next func(http.ResponseWriter, *http.Request, string, string)) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
throwUnauthorized(w, "Invalid JWT Token")
next(w, r)
fullToken := r.Header.Get("Authorization")
authToken := strings.Replace(fullToken, "Bearer ", "", 1)
claims, err := verifyJWTToken(authToken)
if err != nil {
throwUnauthorized(w, "Invalid JWT Token")
return
}
var v string
for k, v := range mux.Vars(r) {
if k == "id" {
log.Printf("key=%v, value=%v", k, v)
}
}
next(w, r, claims.Subject, v)
}
}
@ -188,9 +209,7 @@ func handleRegister(w http.ResponseWriter, r *http.Request) {
return
}
msg := generateJsonMessage("User created succesfully. You may now login")
w.WriteHeader(http.StatusCreated)
w.Write(msg)
throwOkMessage(w, "User created succesfully. You may now login")
}
// handleLogin - Does as it says!
@ -206,7 +225,7 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
ip := getUserIp(r)
data, err := loginUser(&logReq, ip)
if err != nil {
throwOkMessage(w, err.Error())
throwOkError(w, err.Error())
return
}
@ -215,26 +234,29 @@ func handleLogin(w http.ResponseWriter, r *http.Request) {
}
// serveEndpoint - API stuffs
func handleIngress(w http.ResponseWriter, r *http.Request) {
func handleIngress(w http.ResponseWriter, r *http.Request, userUuid string) {
bodyJson, err := decodeJson(r.Body)
if err != nil {
// If we can't decode. Lets tell them nicely.
http.Error(w, "{\"error\":\"Invalid JSON\"}", http.StatusBadRequest)
return
}
ingressType := strings.Replace(r.URL.Path, "/api/v1/ingress/", "", 1)
switch ingressType {
case "jellyfin":
tx, _ := db.Begin()
ip := getUserIp(r)
err := ParseJellyfinInput(r.Header.Get("UserUUID"), bodyJson, ip, tx)
err := ParseJellyfinInput(userUuid, bodyJson, ip, tx)
if err != nil {
log.Printf("Error inserting track: %+v", err)
tx.Rollback()
throwBadReq(w, err.Error())
return
}
err = tx.Commit()
if err != nil {
throwBadReq(w, err.Error())
@ -248,6 +270,20 @@ func handleIngress(w http.ResponseWriter, r *http.Request) {
throwBadReq(w, "Unknown ingress type")
}
// fetchScrobbles - Return an array of scrobbles
func fetchScrobbleResponse(w http.ResponseWriter, r *http.Request, jwtUser string, reqUser string) {
resp, err := fetchScrobblesForUser(reqUser, 1)
if err != nil {
throwBadReq(w, "Failed to fetch scrobbles")
return
}
// Fetch last 500 scrobbles
json, _ := json.Marshal(&resp)
w.WriteHeader(http.StatusOK)
w.Write(json)
}
// FRONTEND HANDLING
// ServerHTTP - Frontend server

View file

@ -17,9 +17,16 @@ func generateToken(n int) string {
func getUserForToken(token string) (string, error) {
var uuid string
err := db.QueryRow("SELECT BIN_TO_UUID(`uuid`, true) FROM `users` WHERE `token` = ? AND `active` = 1", token).Scan(&uuid)
if err != nil {
return "", errors.New("Invalid Token")
cachedKey := getRedisVal("user_token:" + token)
if cachedKey == "" {
err := db.QueryRow("SELECT BIN_TO_UUID(`uuid`, true) FROM `users` WHERE `token` = ? AND `active` = 1", token).Scan(&uuid)
if err != nil {
return "", errors.New("Invalid Token")
}
setRedisVal("user_token:"+token, uuid)
} else {
uuid = cachedKey
}
return uuid, nil
}

View file

@ -10,7 +10,6 @@ import (
"strings"
"time"
"github.com/dgrijalva/jwt-go"
"golang.org/x/crypto/bcrypt"
)
@ -104,14 +103,16 @@ func loginUser(logReq *LoginRequest, ip net.IP) ([]byte, error) {
}
if strings.Contains(logReq.Username, "@") {
err := db.QueryRow("SELECT BIN_TO_UUID(uuid), username, email, password FROM users WHERE email = ? AND active = 1", logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password)
err := db.QueryRow("SELECT BIN_TO_UUID(`uuid`, true), `username`, `email`, `password` FROM `users` WHERE `email` = ? AND `active` = 1",
logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password)
if err != nil {
if err == sql.ErrNoRows {
return resp, errors.New("Invalid Username or Password")
}
}
} else {
err := db.QueryRow("SELECT BIN_TO_UUID(uuid), username, email, password FROM users WHERE username = ? AND active = 1", logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password)
err := db.QueryRow("SELECT BIN_TO_UUID(`uuid`, true), `username`, `email`, `password` FROM `users` WHERE `username` = ? AND `active` = 1",
logReq.Username).Scan(&user.UUID, &user.Username, &user.Email, &user.Password)
if err == sql.ErrNoRows {
return resp, errors.New("Invalid Username or Password")
}
@ -122,7 +123,7 @@ func loginUser(logReq *LoginRequest, ip net.IP) ([]byte, error) {
}
// Issue JWT + Response
token, err := generateJwt(user)
token, err := generateJWTToken(user)
if err != nil {
log.Printf("Error generating JWT: %v", err)
return resp, errors.New("Error logging in")
@ -136,21 +137,6 @@ func loginUser(logReq *LoginRequest, ip net.IP) ([]byte, error) {
return resp, nil
}
func generateJwt(user User) (string, error) {
atClaims := jwt.MapClaims{}
atClaims["sub"] = user.UUID
atClaims["username"] = user.Username
atClaims["email"] = user.Email
atClaims["exp"] = time.Now().Add(JwtExpiry).Unix()
at := jwt.NewWithClaims(jwt.SigningMethodHS512, atClaims)
token, err := at.SignedString(JwtToken)
if err != nil {
return "", err
}
return token, nil
}
// insertUser - Does the dirtywork!
func insertUser(username string, email string, password []byte, ip net.IP) error {
token := generateToken(32)

View file

@ -3,6 +3,7 @@ package goscrobble
import (
"encoding/hex"
"encoding/json"
"fmt"
"io"
"math/big"
"net"
@ -86,3 +87,7 @@ func Inet6_Aton(ip net.IP) string {
ipHex := hex.EncodeToString(ipInt.Bytes())
return ipHex
}
func calcPageOffsetString(page int, offset int) string {
return fmt.Sprintf("%d", page*offset)
}