Feature: Auth (#4)

* feat(api): add auth

* feat(web): add auth and refactor

* refactor(web): baseurl

* feat: add autobrrctl cli for user creation

* build: move static assets

* refactor(web): auth guard and routing

* refactor: rename var

* fix: remove subrouter

* build: update default config
This commit is contained in:
Ludvig Lundgren 2021-08-14 14:19:21 +02:00 committed by GitHub
parent 2e8d0950c1
commit 40b855bf39
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
56 changed files with 1208 additions and 257 deletions

51
internal/auth/service.go Normal file
View file

@ -0,0 +1,51 @@
package auth
import (
"errors"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/user"
"github.com/autobrr/autobrr/pkg/argon2id"
)
type Service interface {
Login(username, password string) (*domain.User, error)
}
type service struct {
userSvc user.Service
}
func NewService(userSvc user.Service) Service {
return &service{
userSvc: userSvc,
}
}
func (s *service) Login(username, password string) (*domain.User, error) {
if username == "" || password == "" {
return nil, errors.New("bad credentials")
}
// find user
u, err := s.userSvc.FindByUsername(username)
if err != nil {
return nil, err
}
if u == nil {
return nil, errors.New("bad credentials")
}
// compare password from request and the saved password
match, err := argon2id.ComparePasswordAndHash(password, u.Password)
if err != nil {
return nil, errors.New("error checking credentials")
}
if !match {
return nil, errors.New("bad credentials")
}
return u, nil
}

View file

@ -7,30 +7,21 @@ import (
"path"
"path/filepath"
"github.com/autobrr/autobrr/internal/domain"
"github.com/spf13/viper"
)
type Cfg struct {
Host string `toml:"host"`
Port int `toml:"port"`
LogLevel string `toml:"logLevel"`
LogPath string `toml:"logPath"`
BaseURL string `toml:"baseUrl"`
}
var Config domain.Config
var Config Cfg
func Defaults() Cfg {
hostname, err := os.Hostname()
if err != nil {
hostname = "localhost"
}
return Cfg{
Host: hostname,
Port: 8989,
LogLevel: "DEBUG",
LogPath: "",
BaseURL: "/",
func Defaults() domain.Config {
return domain.Config{
Host: "localhost",
Port: 8989,
LogLevel: "DEBUG",
LogPath: "",
BaseURL: "/",
SessionSecret: "secret-session-key",
}
}
@ -92,7 +83,12 @@ port = 8989
#
# Options: "ERROR", "DEBUG", "INFO", "WARN"
#
logLevel = "DEBUG"`)
logLevel = "DEBUG"
# Session secret
#
sessionSecret = "secret-session-key"`)
if err != nil {
log.Printf("error writing contents to file: %v %q", configPath, err)
return err
@ -105,7 +101,7 @@ logLevel = "DEBUG"`)
return nil
}
func Read(configPath string) Cfg {
func Read(configPath string) domain.Config {
config := Defaults()
// or use viper.SetDefault(val, def)

View file

@ -3,11 +3,18 @@ package database
import (
"database/sql"
"fmt"
"github.com/rs/zerolog/log"
)
const schema = `
CREATE TABLE users
(
id INTEGER PRIMARY KEY,
username TEXT NOT NULL,
password TEXT NOT NULL,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE indexer
(
id INTEGER PRIMARY KEY,
@ -135,8 +142,6 @@ var migrations = []string{
}
func Migrate(db *sql.DB) error {
log.Info().Msg("Migrating database...")
var version int
if err := db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
return fmt.Errorf("failed to query schema version: %v", err)

47
internal/database/user.go Normal file
View file

@ -0,0 +1,47 @@
package database
import (
"database/sql"
"github.com/rs/zerolog/log"
"github.com/autobrr/autobrr/internal/domain"
)
type UserRepo struct {
db *sql.DB
}
func NewUserRepo(db *sql.DB) domain.UserRepo {
return &UserRepo{db: db}
}
func (r *UserRepo) FindByUsername(username string) (*domain.User, error) {
query := `SELECT username, password FROM users WHERE username = ?`
row := r.db.QueryRow(query, username)
if err := row.Err(); err != nil {
return nil, err
}
var user domain.User
if err := row.Scan(&user.Username, &user.Password); err != nil {
log.Error().Err(err).Msg("could not scan user to struct")
return nil, err
}
return &user, nil
}
func (r *UserRepo) Store(user domain.User) error {
query := `INSERT INTO users (username, password) VALUES (?, ?)`
_, err := r.db.Exec(query, user.Username, user.Password)
if err != nil {
log.Error().Stack().Err(err).Msg("error executing query")
return err
}
return nil
}

View file

@ -1,11 +1,10 @@
package domain
type Settings struct {
Host string `toml:"host"`
Debug bool
type Config struct {
Host string `toml:"host"`
Port int `toml:"port"`
LogLevel string `toml:"logLevel"`
LogPath string `toml:"logPath"`
BaseURL string `toml:"baseUrl"`
SessionSecret string `toml:"sessionSecret"`
}
//type AppConfig struct {
// Settings `toml:"settings"`
// Trackers []Tracker `mapstructure:"tracker"`
//}

11
internal/domain/user.go Normal file
View file

@ -0,0 +1,11 @@
package domain
type UserRepo interface {
FindByUsername(username string) (*User, error)
Store(user User) error
}
type User struct {
Username string `json:"username"`
Password string `json:"password"`
}

86
internal/http/auth.go Normal file
View file

@ -0,0 +1,86 @@
package http
import (
"encoding/json"
"net/http"
"github.com/go-chi/chi"
"github.com/gorilla/sessions"
"github.com/autobrr/autobrr/internal/config"
"github.com/autobrr/autobrr/internal/domain"
)
type authService interface {
Login(username, password string) (*domain.User, error)
}
type authHandler struct {
encoder encoder
authService authService
}
var (
// key will only be valid as long as it's running.
key = []byte(config.Config.SessionSecret)
store = sessions.NewCookieStore(key)
)
func (h authHandler) Routes(r chi.Router) {
r.Post("/login", h.login)
r.Post("/logout", h.logout)
r.Get("/test", h.test)
}
func (h authHandler) login(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
data domain.User
)
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
// encode error
h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest)
return
}
session, _ := store.Get(r, "user_session")
_, err := h.authService.Login(data.Username, data.Password)
if err != nil {
h.encoder.StatusResponse(ctx, w, nil, http.StatusUnauthorized)
return
}
// Set user as authenticated
session.Values["authenticated"] = true
session.Save(r, w)
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent)
}
func (h authHandler) logout(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
session, _ := store.Get(r, "user_session")
// Revoke users authentication
session.Values["authenticated"] = false
session.Save(r, w)
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent)
}
func (h authHandler) test(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
session, _ := store.Get(r, "user_session")
// Check if user is authenticated
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
http.Error(w, "Forbidden", http.StatusForbidden)
return
}
// send empty response as ok
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent)
}

View file

@ -0,0 +1,17 @@
package http
import "net/http"
func IsAuthenticated(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// check session
session, _ := store.Get(r, "user_session")
// Check if user is authenticated
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
next.ServeHTTP(w, r)
})
}

View file

@ -15,17 +15,19 @@ type Server struct {
address string
baseUrl string
actionService actionService
authService authService
downloadClientService downloadClientService
filterService filterService
indexerService indexerService
ircService ircService
}
func NewServer(address string, baseUrl string, actionService actionService, downloadClientSvc downloadClientService, filterSvc filterService, indexerSvc indexerService, ircSvc ircService) Server {
func NewServer(address string, baseUrl string, actionService actionService, authService authService, downloadClientSvc downloadClientService, filterSvc filterService, indexerSvc indexerService, ircSvc ircService) Server {
return Server{
address: address,
baseUrl: baseUrl,
actionService: actionService,
authService: authService,
downloadClientService: downloadClientSvc,
filterService: filterSvc,
indexerService: indexerSvc,
@ -62,7 +64,15 @@ func (s Server) Handler() http.Handler {
fileSystem.ServeHTTP(w, r)
})
authHandler := authHandler{
encoder: encoder,
authService: s.authService,
}
r.Route("/api/auth", authHandler.Routes)
r.Group(func(r chi.Router) {
r.Use(IsAuthenticated)
actionHandler := actionHandler{
encoder: encoder,

View file

@ -5,14 +5,14 @@ import (
"os"
"time"
"github.com/autobrr/autobrr/internal/config"
"github.com/autobrr/autobrr/internal/domain"
"github.com/rs/zerolog"
"github.com/rs/zerolog/log"
"gopkg.in/natefinch/lumberjack.v2"
)
func Setup(cfg config.Cfg) {
func Setup(cfg domain.Config) {
zerolog.TimeFieldFormat = time.RFC3339
switch cfg.LogLevel {

26
internal/user/service.go Normal file
View file

@ -0,0 +1,26 @@
package user
import "github.com/autobrr/autobrr/internal/domain"
type Service interface {
FindByUsername(username string) (*domain.User, error)
}
type service struct {
repo domain.UserRepo
}
func NewService(repo domain.UserRepo) Service {
return &service{
repo: repo,
}
}
func (s *service) FindByUsername(username string) (*domain.User, error) {
user, err := s.repo.FindByUsername(username)
if err != nil {
return nil, err
}
return user, nil
}