mirror of
https://github.com/idanoo/autobrr
synced 2025-07-23 08:49:13 +00:00
Feature: Auth (#4)
* feat(api): add auth * feat(web): add auth and refactor * refactor(web): baseurl * feat: add autobrrctl cli for user creation * build: move static assets * refactor(web): auth guard and routing * refactor: rename var * fix: remove subrouter * build: update default config
This commit is contained in:
parent
2e8d0950c1
commit
40b855bf39
56 changed files with 1208 additions and 257 deletions
51
internal/auth/service.go
Normal file
51
internal/auth/service.go
Normal file
|
@ -0,0 +1,51 @@
|
|||
package auth
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
"github.com/autobrr/autobrr/internal/user"
|
||||
"github.com/autobrr/autobrr/pkg/argon2id"
|
||||
)
|
||||
|
||||
type Service interface {
|
||||
Login(username, password string) (*domain.User, error)
|
||||
}
|
||||
|
||||
type service struct {
|
||||
userSvc user.Service
|
||||
}
|
||||
|
||||
func NewService(userSvc user.Service) Service {
|
||||
return &service{
|
||||
userSvc: userSvc,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) Login(username, password string) (*domain.User, error) {
|
||||
if username == "" || password == "" {
|
||||
return nil, errors.New("bad credentials")
|
||||
}
|
||||
|
||||
// find user
|
||||
u, err := s.userSvc.FindByUsername(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if u == nil {
|
||||
return nil, errors.New("bad credentials")
|
||||
}
|
||||
|
||||
// compare password from request and the saved password
|
||||
match, err := argon2id.ComparePasswordAndHash(password, u.Password)
|
||||
if err != nil {
|
||||
return nil, errors.New("error checking credentials")
|
||||
}
|
||||
|
||||
if !match {
|
||||
return nil, errors.New("bad credentials")
|
||||
}
|
||||
|
||||
return u, nil
|
||||
}
|
|
@ -7,30 +7,21 @@ import (
|
|||
"path"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
type Cfg struct {
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
LogLevel string `toml:"logLevel"`
|
||||
LogPath string `toml:"logPath"`
|
||||
BaseURL string `toml:"baseUrl"`
|
||||
}
|
||||
var Config domain.Config
|
||||
|
||||
var Config Cfg
|
||||
|
||||
func Defaults() Cfg {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "localhost"
|
||||
}
|
||||
return Cfg{
|
||||
Host: hostname,
|
||||
Port: 8989,
|
||||
LogLevel: "DEBUG",
|
||||
LogPath: "",
|
||||
BaseURL: "/",
|
||||
func Defaults() domain.Config {
|
||||
return domain.Config{
|
||||
Host: "localhost",
|
||||
Port: 8989,
|
||||
LogLevel: "DEBUG",
|
||||
LogPath: "",
|
||||
BaseURL: "/",
|
||||
SessionSecret: "secret-session-key",
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -92,7 +83,12 @@ port = 8989
|
|||
#
|
||||
# Options: "ERROR", "DEBUG", "INFO", "WARN"
|
||||
#
|
||||
logLevel = "DEBUG"`)
|
||||
logLevel = "DEBUG"
|
||||
|
||||
# Session secret
|
||||
#
|
||||
sessionSecret = "secret-session-key"`)
|
||||
|
||||
if err != nil {
|
||||
log.Printf("error writing contents to file: %v %q", configPath, err)
|
||||
return err
|
||||
|
@ -105,7 +101,7 @@ logLevel = "DEBUG"`)
|
|||
return nil
|
||||
}
|
||||
|
||||
func Read(configPath string) Cfg {
|
||||
func Read(configPath string) domain.Config {
|
||||
config := Defaults()
|
||||
|
||||
// or use viper.SetDefault(val, def)
|
||||
|
|
|
@ -3,11 +3,18 @@ package database
|
|||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const schema = `
|
||||
CREATE TABLE users
|
||||
(
|
||||
id INTEGER PRIMARY KEY,
|
||||
username TEXT NOT NULL,
|
||||
password TEXT NOT NULL,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE indexer
|
||||
(
|
||||
id INTEGER PRIMARY KEY,
|
||||
|
@ -135,8 +142,6 @@ var migrations = []string{
|
|||
}
|
||||
|
||||
func Migrate(db *sql.DB) error {
|
||||
log.Info().Msg("Migrating database...")
|
||||
|
||||
var version int
|
||||
if err := db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
|
||||
return fmt.Errorf("failed to query schema version: %v", err)
|
||||
|
|
47
internal/database/user.go
Normal file
47
internal/database/user.go
Normal file
|
@ -0,0 +1,47 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
)
|
||||
|
||||
type UserRepo struct {
|
||||
db *sql.DB
|
||||
}
|
||||
|
||||
func NewUserRepo(db *sql.DB) domain.UserRepo {
|
||||
return &UserRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *UserRepo) FindByUsername(username string) (*domain.User, error) {
|
||||
query := `SELECT username, password FROM users WHERE username = ?`
|
||||
|
||||
row := r.db.QueryRow(query, username)
|
||||
if err := row.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var user domain.User
|
||||
|
||||
if err := row.Scan(&user.Username, &user.Password); err != nil {
|
||||
log.Error().Err(err).Msg("could not scan user to struct")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (r *UserRepo) Store(user domain.User) error {
|
||||
query := `INSERT INTO users (username, password) VALUES (?, ?)`
|
||||
|
||||
_, err := r.db.Exec(query, user.Username, user.Password)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error executing query")
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -1,11 +1,10 @@
|
|||
package domain
|
||||
|
||||
type Settings struct {
|
||||
Host string `toml:"host"`
|
||||
Debug bool
|
||||
type Config struct {
|
||||
Host string `toml:"host"`
|
||||
Port int `toml:"port"`
|
||||
LogLevel string `toml:"logLevel"`
|
||||
LogPath string `toml:"logPath"`
|
||||
BaseURL string `toml:"baseUrl"`
|
||||
SessionSecret string `toml:"sessionSecret"`
|
||||
}
|
||||
|
||||
//type AppConfig struct {
|
||||
// Settings `toml:"settings"`
|
||||
// Trackers []Tracker `mapstructure:"tracker"`
|
||||
//}
|
||||
|
|
11
internal/domain/user.go
Normal file
11
internal/domain/user.go
Normal file
|
@ -0,0 +1,11 @@
|
|||
package domain
|
||||
|
||||
type UserRepo interface {
|
||||
FindByUsername(username string) (*User, error)
|
||||
Store(user User) error
|
||||
}
|
||||
|
||||
type User struct {
|
||||
Username string `json:"username"`
|
||||
Password string `json:"password"`
|
||||
}
|
86
internal/http/auth.go
Normal file
86
internal/http/auth.go
Normal file
|
@ -0,0 +1,86 @@
|
|||
package http
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/gorilla/sessions"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/config"
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
)
|
||||
|
||||
type authService interface {
|
||||
Login(username, password string) (*domain.User, error)
|
||||
}
|
||||
|
||||
type authHandler struct {
|
||||
encoder encoder
|
||||
authService authService
|
||||
}
|
||||
|
||||
var (
|
||||
// key will only be valid as long as it's running.
|
||||
key = []byte(config.Config.SessionSecret)
|
||||
store = sessions.NewCookieStore(key)
|
||||
)
|
||||
|
||||
func (h authHandler) Routes(r chi.Router) {
|
||||
r.Post("/login", h.login)
|
||||
r.Post("/logout", h.logout)
|
||||
r.Get("/test", h.test)
|
||||
}
|
||||
|
||||
func (h authHandler) login(w http.ResponseWriter, r *http.Request) {
|
||||
var (
|
||||
ctx = r.Context()
|
||||
data domain.User
|
||||
)
|
||||
|
||||
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
|
||||
// encode error
|
||||
h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
session, _ := store.Get(r, "user_session")
|
||||
|
||||
_, err := h.authService.Login(data.Username, data.Password)
|
||||
if err != nil {
|
||||
h.encoder.StatusResponse(ctx, w, nil, http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Set user as authenticated
|
||||
session.Values["authenticated"] = true
|
||||
session.Save(r, w)
|
||||
|
||||
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h authHandler) logout(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
session, _ := store.Get(r, "user_session")
|
||||
|
||||
// Revoke users authentication
|
||||
session.Values["authenticated"] = false
|
||||
session.Save(r, w)
|
||||
|
||||
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent)
|
||||
}
|
||||
|
||||
func (h authHandler) test(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
session, _ := store.Get(r, "user_session")
|
||||
|
||||
// Check if user is authenticated
|
||||
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
|
||||
http.Error(w, "Forbidden", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// send empty response as ok
|
||||
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent)
|
||||
}
|
17
internal/http/middleware.go
Normal file
17
internal/http/middleware.go
Normal file
|
@ -0,0 +1,17 @@
|
|||
package http
|
||||
|
||||
import "net/http"
|
||||
|
||||
func IsAuthenticated(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
// check session
|
||||
session, _ := store.Get(r, "user_session")
|
||||
|
||||
// Check if user is authenticated
|
||||
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
|
||||
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
|
@ -15,17 +15,19 @@ type Server struct {
|
|||
address string
|
||||
baseUrl string
|
||||
actionService actionService
|
||||
authService authService
|
||||
downloadClientService downloadClientService
|
||||
filterService filterService
|
||||
indexerService indexerService
|
||||
ircService ircService
|
||||
}
|
||||
|
||||
func NewServer(address string, baseUrl string, actionService actionService, downloadClientSvc downloadClientService, filterSvc filterService, indexerSvc indexerService, ircSvc ircService) Server {
|
||||
func NewServer(address string, baseUrl string, actionService actionService, authService authService, downloadClientSvc downloadClientService, filterSvc filterService, indexerSvc indexerService, ircSvc ircService) Server {
|
||||
return Server{
|
||||
address: address,
|
||||
baseUrl: baseUrl,
|
||||
actionService: actionService,
|
||||
authService: authService,
|
||||
downloadClientService: downloadClientSvc,
|
||||
filterService: filterSvc,
|
||||
indexerService: indexerSvc,
|
||||
|
@ -62,7 +64,15 @@ func (s Server) Handler() http.Handler {
|
|||
fileSystem.ServeHTTP(w, r)
|
||||
})
|
||||
|
||||
authHandler := authHandler{
|
||||
encoder: encoder,
|
||||
authService: s.authService,
|
||||
}
|
||||
|
||||
r.Route("/api/auth", authHandler.Routes)
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(IsAuthenticated)
|
||||
|
||||
actionHandler := actionHandler{
|
||||
encoder: encoder,
|
||||
|
|
|
@ -5,14 +5,14 @@ import (
|
|||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/config"
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/natefinch/lumberjack.v2"
|
||||
)
|
||||
|
||||
func Setup(cfg config.Cfg) {
|
||||
func Setup(cfg domain.Config) {
|
||||
zerolog.TimeFieldFormat = time.RFC3339
|
||||
|
||||
switch cfg.LogLevel {
|
||||
|
|
26
internal/user/service.go
Normal file
26
internal/user/service.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package user
|
||||
|
||||
import "github.com/autobrr/autobrr/internal/domain"
|
||||
|
||||
type Service interface {
|
||||
FindByUsername(username string) (*domain.User, error)
|
||||
}
|
||||
|
||||
type service struct {
|
||||
repo domain.UserRepo
|
||||
}
|
||||
|
||||
func NewService(repo domain.UserRepo) Service {
|
||||
return &service{
|
||||
repo: repo,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *service) FindByUsername(username string) (*domain.User, error) {
|
||||
user, err := s.repo.FindByUsername(username)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue