refactor(http): auth handlers (#1311)

* fix(auth): implement invalid cookie handling

* that escalated quickly

* refactor(http): auth handlers

* add tests for auth handler
* refactor methods

* chore(tests): add header and build tag

* add build tag integration

* chore(tests): run in ci

---------

Co-authored-by: ze0s <ze0s@riseup.net>
This commit is contained in:
Kyle Sanderson 2023-12-27 17:04:25 -08:00 committed by GitHub
parent df2612602b
commit 6a94ecacca
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
18 changed files with 537 additions and 80 deletions

View file

@ -83,7 +83,7 @@ jobs:
cache: true
- name: Test
run: go test -v ./...
run: go test -v ./... -tags=integration
goreleaserbuild:
name: Build distribution binaries

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -1,3 +1,8 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (

View file

@ -33,7 +33,7 @@ type authHandler struct {
cookieStore *sessions.CookieStore
}
func newAuthHandler(encoder encoder, log zerolog.Logger, config *domain.Config, cookieStore *sessions.CookieStore, service authService, server Server) *authHandler {
func newAuthHandler(encoder encoder, log zerolog.Logger, server Server, config *domain.Config, cookieStore *sessions.CookieStore, service authService) *authHandler {
return &authHandler{
log: log,
encoder: encoder,
@ -46,26 +46,21 @@ func newAuthHandler(encoder encoder, log zerolog.Logger, config *domain.Config,
func (h authHandler) Routes(r chi.Router) {
r.Post("/login", h.login)
r.Post("/logout", h.logout)
r.Post("/onboard", h.onboard)
r.Get("/onboard", h.canOnboard)
r.Get("/validate", h.validate)
// Group for authenticated routes
r.Group(func(r chi.Router) {
r.Use(h.server.IsAuthenticated)
// Authenticated routes
r.Post("/logout", h.logout)
r.Get("/validate", h.validate)
r.Patch("/user/{username}", h.updateUser)
})
}
func (h authHandler) login(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
data domain.User
)
var data domain.User
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.StatusError(w, http.StatusBadRequest, errors.Wrap(err, "could not decode json"))
return
@ -79,63 +74,64 @@ func (h authHandler) login(w http.ResponseWriter, r *http.Request) {
// if forwarded protocol is https then set cookie secure
// SameSite Strict can only be set with a secure cookie. So we overwrite it here if possible.
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie/SameSite
fwdProto := r.Header.Get("X-Forwarded-Proto")
if fwdProto == "https" {
if r.Header.Get("X-Forwarded-Proto") == "https" {
h.cookieStore.Options.Secure = true
h.cookieStore.Options.SameSite = http.SameSiteStrictMode
}
if _, err := h.service.Login(ctx, data.Username, data.Password); err != nil {
h.log.Error().Err(err).Msgf("Auth: Failed login attempt username: [%s] ip: %s", data.Username, ReadUserIP(r))
if _, err := h.service.Login(r.Context(), data.Username, data.Password); err != nil {
h.log.Error().Err(err).Msgf("Auth: Failed login attempt username: [%s] ip: %s", data.Username, r.RemoteAddr)
h.encoder.StatusError(w, http.StatusUnauthorized, errors.New("could not login: bad credentials"))
return
}
// create new session
session, _ := h.cookieStore.Get(r, "user_session")
session, err := h.cookieStore.New(r, "user_session")
if err != nil {
h.log.Error().Err(err).Msgf("Auth: Failed to parse cookies with attempt username: [%s] ip: %s", data.Username, r.RemoteAddr)
h.encoder.StatusError(w, http.StatusUnauthorized, errors.New("could not parse cookies"))
return
}
// Set user as authenticated
session.Values["authenticated"] = true
if err := session.Save(r, w); err != nil {
h.encoder.StatusError(w, http.StatusInternalServerError, errors.Wrap(err, "could not save session"))
return
}
h.encoder.StatusResponse(w, http.StatusNoContent, nil)
h.encoder.NoContent(w)
}
func (h authHandler) logout(w http.ResponseWriter, r *http.Request) {
session, _ := h.cookieStore.Get(r, "user_session")
// cookieStore.Get will create a new session if it does not exist
// so if it created a new then lets just return without saving it
if session.IsNew {
h.encoder.StatusResponse(w, http.StatusNoContent, nil)
// get session from context
session, ok := r.Context().Value("session").(*sessions.Session)
if !ok {
h.encoder.StatusError(w, http.StatusInternalServerError, errors.New("could not get session from context"))
return
}
// Revoke users authentication
session.Values["authenticated"] = false
session.Options.MaxAge = -1
if err := session.Save(r, w); err != nil {
h.encoder.StatusError(w, http.StatusInternalServerError, errors.Wrap(err, "could not save session"))
return
if session != nil {
session.Values["authenticated"] = false
// MaxAge<0 means delete cookie immediately
session.Options.MaxAge = -1
if err := session.Save(r, w); err != nil {
h.log.Error().Err(err).Msgf("could not store session: %s", r.RemoteAddr)
h.encoder.StatusError(w, http.StatusInternalServerError, err)
return
}
}
h.encoder.StatusResponse(w, http.StatusNoContent, nil)
}
func (h authHandler) onboard(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
session, _ := h.cookieStore.Get(r, "user_session")
// Don't proceed if user is authenticated
if authenticated, ok := session.Values["authenticated"].(bool); ok {
if ok && authenticated {
h.encoder.StatusError(w, http.StatusForbidden, errors.New("active session found"))
return
}
if status, err := h.onboardEligible(r.Context()); err != nil {
h.encoder.StatusError(w, status, err)
return
}
var req domain.CreateUserRequest
@ -144,7 +140,7 @@ func (h authHandler) onboard(w http.ResponseWriter, r *http.Request) {
return
}
if err := h.service.CreateUser(ctx, req); err != nil {
if err := h.service.CreateUser(r.Context(), req); err != nil {
h.encoder.StatusError(w, http.StatusForbidden, err)
return
}
@ -154,16 +150,8 @@ func (h authHandler) onboard(w http.ResponseWriter, r *http.Request) {
}
func (h authHandler) canOnboard(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userCount, err := h.service.GetUserCount(ctx)
if err != nil {
h.encoder.StatusError(w, http.StatusInternalServerError, errors.New("could not get user count"))
return
}
if userCount > 0 {
h.encoder.StatusError(w, http.StatusForbidden, errors.New("onboarding unavailable"))
if status, err := h.onboardEligible(r.Context()); err != nil {
h.encoder.StatusError(w, status, err)
return
}
@ -172,28 +160,34 @@ func (h authHandler) canOnboard(w http.ResponseWriter, r *http.Request) {
h.encoder.NoContent(w)
}
func (h authHandler) validate(w http.ResponseWriter, r *http.Request) {
session, _ := h.cookieStore.Get(r, "user_session")
// Check if user is authenticated
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
session.Values["authenticated"] = false
session.Options.MaxAge = -1
session.Save(r, w)
h.encoder.StatusError(w, http.StatusUnauthorized, errors.New("forbidden: invalid session"))
return
// onboardEligible checks if the onboarding process is eligible.
func (h authHandler) onboardEligible(ctx context.Context) (int, error) {
userCount, err := h.service.GetUserCount(ctx)
if err != nil {
return http.StatusInternalServerError, errors.New("could not get user count")
}
if userCount > 0 {
return http.StatusForbidden, errors.New("onboarding unavailable")
}
return http.StatusOK, nil
}
// validate sits behind the IsAuthenticated middleware which takes care of checking for a valid session
// If there is a valid session return OK, otherwise the middleware returns early with a 401
func (h authHandler) validate(w http.ResponseWriter, r *http.Request) {
session := r.Context().Value("session").(*sessions.Session)
if session != nil {
h.log.Debug().Msgf("found user session: %+v", session)
}
// send empty response as ok
h.encoder.NoContent(w)
}
func (h authHandler) updateUser(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
data domain.UpdateUserRequest
)
var data domain.UpdateUserRequest
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.StatusError(w, http.StatusBadRequest, errors.Wrap(err, "could not decode json"))
return
@ -201,7 +195,7 @@ func (h authHandler) updateUser(w http.ResponseWriter, r *http.Request) {
data.UsernameCurrent = chi.URLParam(r, "username")
if err := h.service.UpdateUser(ctx, data); err != nil {
if err := h.service.UpdateUser(r.Context(), data); err != nil {
h.encoder.StatusError(w, http.StatusForbidden, err)
return
}
@ -209,14 +203,3 @@ func (h authHandler) updateUser(w http.ResponseWriter, r *http.Request) {
// send response as ok
h.encoder.StatusResponseMessage(w, http.StatusOK, "user successfully updated")
}
func ReadUserIP(r *http.Request) string {
IPAddress := r.Header.Get("X-Real-Ip")
if IPAddress == "" {
IPAddress = r.Header.Get("X-Forwarded-For")
}
if IPAddress == "" {
IPAddress = r.RemoteAddr
}
return IPAddress
}

402
internal/http/auth_test.go Normal file
View file

@ -0,0 +1,402 @@
// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package http
import (
"bytes"
"context"
"encoding/json"
"log"
"net/http"
"net/http/cookiejar"
"net/http/httptest"
"testing"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/go-chi/chi/v5"
"github.com/gorilla/sessions"
"github.com/rs/zerolog"
)
type authServiceMock struct {
users map[string]*domain.User
}
func (a authServiceMock) GetUserCount(ctx context.Context) (int, error) {
return len(a.users), nil
}
func (a authServiceMock) Login(ctx context.Context, username, password string) (*domain.User, error) {
u, ok := a.users[username]
if !ok {
return nil, errors.New("invalid login")
}
if u.Password != password {
return nil, errors.New("bad credentials")
}
return u, nil
}
func (a authServiceMock) CreateUser(ctx context.Context, req domain.CreateUserRequest) error {
if req.Username != "" {
a.users[req.Username] = &domain.User{
ID: len(a.users) + 1,
Username: req.Username,
Password: req.Password,
}
}
return nil
}
func (a authServiceMock) UpdateUser(ctx context.Context, req domain.UpdateUserRequest) error {
u, ok := a.users[req.UsernameCurrent]
if !ok {
return errors.New("user not found")
}
if req.UsernameNew != "" {
u.Username = req.UsernameNew
}
if req.PasswordNew != "" {
u.Password = req.PasswordNew
}
return nil
}
func setupServer() chi.Router {
r := chi.NewRouter()
//r.Use(middleware.Logger)
r.Get("/", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte("OK"))
})
return r
}
func runTestServer(s chi.Router) *httptest.Server {
return httptest.NewServer(s)
}
func setupAuthHandler() {
}
func TestAuthHandlerLogin(t *testing.T) {
logger := zerolog.Nop()
encoder := encoder{}
cookieStore := sessions.NewCookieStore([]byte("test"))
service := authServiceMock{
users: map[string]*domain.User{
"test": {
ID: 0,
Username: "test",
Password: "pass",
},
},
}
server := Server{
log: logger,
cookieStore: cookieStore,
}
handler := newAuthHandler(encoder, logger, server, &domain.Config{}, cookieStore, service)
s := setupServer()
s.Route("/auth", handler.Routes)
testServer := runTestServer(s)
defer testServer.Close()
// generate request, here we'll use login as example
reqBody, err := json.Marshal(map[string]string{
"username": "test",
"password": "pass",
})
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
jarOptions := &cookiejar.Options{PublicSuffixList: nil}
jar, err := cookiejar.New(jarOptions)
if err != nil {
log.Fatalf("error creating cookiejar: %v", err)
}
client := http.DefaultClient
client.Jar = jar
// make request
resp, err := client.Post(testServer.URL+"/auth/login", "application/json", bytes.NewBuffer(reqBody))
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
// check for response, here we'll just check for 204 NoContent
if status := resp.StatusCode; status != http.StatusNoContent {
t.Errorf("login: handler returned wrong status code: got %v want %v", status, http.StatusNoContent)
}
if v := resp.Header.Get("Set-Cookie"); v == "" {
t.Errorf("handler returned no cookie")
}
}
func TestAuthHandlerValidateOK(t *testing.T) {
logger := zerolog.Nop()
encoder := encoder{}
cookieStore := sessions.NewCookieStore([]byte("test"))
service := authServiceMock{
users: map[string]*domain.User{
"test": {
ID: 0,
Username: "test",
Password: "pass",
},
},
}
server := Server{
log: logger,
cookieStore: cookieStore,
}
handler := newAuthHandler(encoder, logger, server, &domain.Config{}, cookieStore, service)
s := setupServer()
s.Route("/auth", handler.Routes)
testServer := runTestServer(s)
defer testServer.Close()
// generate request, here we'll use login as example
reqBody, err := json.Marshal(map[string]string{
"username": "test",
"password": "pass",
})
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
jarOptions := &cookiejar.Options{PublicSuffixList: nil}
jar, err := cookiejar.New(jarOptions)
if err != nil {
log.Fatalf("error creating cookiejar: %v", err)
}
client := http.DefaultClient
client.Jar = jar
// make request
resp, err := client.Post(testServer.URL+"/auth/login", "application/json", bytes.NewBuffer(reqBody))
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
// check for response, here we'll just check for 204 NoContent
if status := resp.StatusCode; status != http.StatusNoContent {
t.Errorf("login: handler returned wrong status code: got %v want %v", status, http.StatusNoContent)
}
if v := resp.Header.Get("Set-Cookie"); v == "" {
t.Errorf("handler returned no cookie")
}
// validate token
resp, err = client.Get(testServer.URL + "/auth/validate")
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
if status := resp.StatusCode; status != http.StatusNoContent {
t.Errorf("validate: handler returned wrong status code: got %v want %v", status, http.StatusNoContent)
}
}
func TestAuthHandlerValidateBad(t *testing.T) {
logger := zerolog.Nop()
encoder := encoder{}
cookieStore := sessions.NewCookieStore([]byte("test"))
service := authServiceMock{
users: map[string]*domain.User{
"test": {
ID: 0,
Username: "test",
Password: "pass",
},
},
}
server := Server{
log: logger,
cookieStore: cookieStore,
}
handler := newAuthHandler(encoder, logger, server, &domain.Config{}, cookieStore, service)
s := setupServer()
s.Route("/auth", handler.Routes)
testServer := runTestServer(s)
defer testServer.Close()
jarOptions := &cookiejar.Options{PublicSuffixList: nil}
jar, err := cookiejar.New(jarOptions)
if err != nil {
log.Fatalf("error creating cookiejar: %v", err)
}
client := http.DefaultClient
client.Jar = jar
// validate token
resp, err := client.Get(testServer.URL + "/auth/validate")
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
if status := resp.StatusCode; status != http.StatusUnauthorized {
t.Errorf("validate: handler returned wrong status code: got %v want %v", status, http.StatusUnauthorized)
}
}
func TestAuthHandlerLoginBad(t *testing.T) {
logger := zerolog.Nop()
encoder := encoder{}
cookieStore := sessions.NewCookieStore([]byte("test"))
service := authServiceMock{
users: map[string]*domain.User{
"test": {
ID: 0,
Username: "test",
Password: "pass",
},
},
}
server := Server{
log: logger,
}
handler := newAuthHandler(encoder, logger, server, &domain.Config{}, cookieStore, service)
s := setupServer()
s.Route("/auth", handler.Routes)
testServer := runTestServer(s)
defer testServer.Close()
// generate request, here we'll use login as example
reqBody, err := json.Marshal(map[string]string{
"username": "test",
"password": "notmypass",
})
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
// make request
resp, err := http.Post(testServer.URL+"/auth/login", "application/json", bytes.NewBuffer(reqBody))
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
// check for response, here we'll just check for 204 NoContent
if status := resp.StatusCode; status != http.StatusUnauthorized {
t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusUnauthorized)
}
}
func TestAuthHandlerLogout(t *testing.T) {
logger := zerolog.Nop()
encoder := encoder{}
cookieStore := sessions.NewCookieStore([]byte("test"))
service := authServiceMock{
users: map[string]*domain.User{
"test": {
ID: 0,
Username: "test",
Password: "pass",
},
},
}
server := Server{
log: logger,
cookieStore: cookieStore,
}
handler := newAuthHandler(encoder, logger, server, &domain.Config{}, cookieStore, service)
s := setupServer()
s.Route("/auth", handler.Routes)
testServer := runTestServer(s)
defer testServer.Close()
// generate request, here we'll use login as example
reqBody, err := json.Marshal(map[string]string{
"username": "test",
"password": "pass",
})
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
jarOptions := &cookiejar.Options{PublicSuffixList: nil}
jar, err := cookiejar.New(jarOptions)
if err != nil {
log.Fatalf("error creating cookiejar: %v", err)
}
client := http.DefaultClient
client.Jar = jar
// make request
resp, err := client.Post(testServer.URL+"/auth/login", "application/json", bytes.NewBuffer(reqBody))
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
// check for response, here we'll just check for 204 NoContent
if status := resp.StatusCode; status != http.StatusNoContent {
t.Errorf("login: handler returned wrong status code: got %v want %v", status, http.StatusNoContent)
}
if v := resp.Header.Get("Set-Cookie"); v == "" {
t.Errorf("handler returned no cookie")
}
// validate token
resp, err = client.Get(testServer.URL + "/auth/validate")
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
if status := resp.StatusCode; status != http.StatusNoContent {
t.Errorf("validate: handler returned wrong status code: got %v want %v", status, http.StatusNoContent)
}
// logout
resp, err = client.Post(testServer.URL+"/auth/logout", "application/json", nil)
if err != nil {
log.Fatalf("Error occurred: %v", err)
}
if status := resp.StatusCode; status != http.StatusNoContent {
t.Errorf("validate: handler returned wrong status code: got %v want %v", status, http.StatusNoContent)
}
//if v := resp.Header.Get("Set-Cookie"); v != "" {
// t.Errorf("logout handler returned cookie")
//}
}

View file

@ -5,7 +5,6 @@ package http
import (
"bytes"
"io/ioutil"
"os"
"strings"
"testing"
@ -162,7 +161,7 @@ func TestSanitizeLogFile(t *testing.T) {
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
// Create a temporary file with sample log data
tmpFile, err := ioutil.TempFile("", "test-log-*.log")
tmpFile, err := os.CreateTemp("", "test-log-*.log")
if err != nil {
t.Fatal(err)
}

View file

@ -4,6 +4,7 @@
package http
import (
"context"
"net/http"
"runtime/debug"
"strings"
@ -30,13 +31,25 @@ func (s Server) IsAuthenticated(next http.Handler) http.Handler {
}
} else {
// check session
session, _ := s.cookieStore.Get(r, "user_session")
session, err := s.cookieStore.Get(r, "user_session")
if err != nil {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
if session.IsNew {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
// Check if user is authenticated
if auth, ok := session.Values["authenticated"].(bool); !ok || !auth {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
ctx := context.WithValue(r.Context(), "session", session)
r = r.WithContext(ctx)
}
next.ServeHTTP(w, r)

View file

@ -126,7 +126,7 @@ func (s Server) Handler() http.Handler {
encoder := encoder{}
r.Route("/api", func(r chi.Router) {
r.Route("/auth", newAuthHandler(encoder, s.log, s.config.Config, s.cookieStore, s.authService, s).Routes)
r.Route("/auth", newAuthHandler(encoder, s.log, s, s.config.Config, s.cookieStore, s.authService).Routes)
r.Route("/healthz", newHealthHandler(encoder, s.db).Routes)
r.Group(func(r chi.Router) {