From 6a94ecacca862822d268aab067f15cbfe4b6aee7 Mon Sep 17 00:00:00 2001 From: Kyle Sanderson Date: Wed, 27 Dec 2023 17:04:25 -0800 Subject: [PATCH] refactor(http): auth handlers (#1311) * fix(auth): implement invalid cookie handling * that escalated quickly * refactor(http): auth handlers * add tests for auth handler * refactor methods * chore(tests): add header and build tag * add build tag integration * chore(tests): run in ci --------- Co-authored-by: ze0s --- .github/workflows/release.yml | 2 +- internal/database/action_test.go | 5 + internal/database/api_test.go | 5 + internal/database/database_test.go | 5 + internal/database/download_client_test.go | 5 + internal/database/feed_cache_test.go | 5 + internal/database/feed_test.go | 5 + internal/database/filter_test.go | 5 + internal/database/indexer_test.go | 5 + internal/database/irc_test.go | 5 + internal/database/notification_test.go | 5 + internal/database/release_test.go | 5 + internal/database/user_test.go | 5 + internal/http/auth.go | 133 ++++--- internal/http/auth_test.go | 402 ++++++++++++++++++++++ internal/http/logs_sanitize_test.go | 3 +- internal/http/middleware.go | 15 +- internal/http/server.go | 2 +- 18 files changed, 537 insertions(+), 80 deletions(-) create mode 100644 internal/http/auth_test.go diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index b30c60d..4ef7d50 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -83,7 +83,7 @@ jobs: cache: true - name: Test - run: go test -v ./... + run: go test -v ./... -tags=integration goreleaserbuild: name: Build distribution binaries diff --git a/internal/database/action_test.go b/internal/database/action_test.go index 53eb8fd..1f4de40 100644 --- a/internal/database/action_test.go +++ b/internal/database/action_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/api_test.go b/internal/database/api_test.go index 61bf7f3..3534e4f 100644 --- a/internal/database/api_test.go +++ b/internal/database/api_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/database_test.go b/internal/database/database_test.go index 41c3c50..fe2540d 100644 --- a/internal/database/database_test.go +++ b/internal/database/database_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/download_client_test.go b/internal/database/download_client_test.go index 13f2ddd..2531682 100644 --- a/internal/database/download_client_test.go +++ b/internal/database/download_client_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/feed_cache_test.go b/internal/database/feed_cache_test.go index a5905f0..2f6bdb4 100644 --- a/internal/database/feed_cache_test.go +++ b/internal/database/feed_cache_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/feed_test.go b/internal/database/feed_test.go index ca99d18..7a25ed9 100644 --- a/internal/database/feed_test.go +++ b/internal/database/feed_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/filter_test.go b/internal/database/filter_test.go index ac5b04f..0b1e846 100644 --- a/internal/database/filter_test.go +++ b/internal/database/filter_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/indexer_test.go b/internal/database/indexer_test.go index 004b8b6..0e67d1d 100644 --- a/internal/database/indexer_test.go +++ b/internal/database/indexer_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/irc_test.go b/internal/database/irc_test.go index 6c1883a..0025ecb 100644 --- a/internal/database/irc_test.go +++ b/internal/database/irc_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/notification_test.go b/internal/database/notification_test.go index c2d513b..91fa04c 100644 --- a/internal/database/notification_test.go +++ b/internal/database/notification_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/release_test.go b/internal/database/release_test.go index 830af56..df3e306 100644 --- a/internal/database/release_test.go +++ b/internal/database/release_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/database/user_test.go b/internal/database/user_test.go index 5800121..285b6d0 100644 --- a/internal/database/user_test.go +++ b/internal/database/user_test.go @@ -1,3 +1,8 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + package database import ( diff --git a/internal/http/auth.go b/internal/http/auth.go index c739ac1..739bf42 100644 --- a/internal/http/auth.go +++ b/internal/http/auth.go @@ -33,7 +33,7 @@ type authHandler struct { cookieStore *sessions.CookieStore } -func newAuthHandler(encoder encoder, log zerolog.Logger, config *domain.Config, cookieStore *sessions.CookieStore, service authService, server Server) *authHandler { +func newAuthHandler(encoder encoder, log zerolog.Logger, server Server, config *domain.Config, cookieStore *sessions.CookieStore, service authService) *authHandler { return &authHandler{ log: log, encoder: encoder, @@ -46,26 +46,21 @@ func newAuthHandler(encoder encoder, log zerolog.Logger, config *domain.Config, func (h authHandler) Routes(r chi.Router) { r.Post("/login", h.login) - r.Post("/logout", h.logout) r.Post("/onboard", h.onboard) r.Get("/onboard", h.canOnboard) - r.Get("/validate", h.validate) // Group for authenticated routes r.Group(func(r chi.Router) { r.Use(h.server.IsAuthenticated) - // Authenticated routes + r.Post("/logout", h.logout) + r.Get("/validate", h.validate) r.Patch("/user/{username}", h.updateUser) }) } func (h authHandler) login(w http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - data domain.User - ) - + var data domain.User if err := json.NewDecoder(r.Body).Decode(&data); err != nil { h.encoder.StatusError(w, http.StatusBadRequest, errors.Wrap(err, "could not decode json")) return @@ -79,63 +74,64 @@ func (h authHandler) login(w http.ResponseWriter, r *http.Request) { // if forwarded protocol is https then set cookie secure // SameSite Strict can only be set with a secure cookie. So we overwrite it here if possible. // https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Set-Cookie/SameSite - fwdProto := r.Header.Get("X-Forwarded-Proto") - if fwdProto == "https" { + if r.Header.Get("X-Forwarded-Proto") == "https" { h.cookieStore.Options.Secure = true h.cookieStore.Options.SameSite = http.SameSiteStrictMode } - if _, err := h.service.Login(ctx, data.Username, data.Password); err != nil { - h.log.Error().Err(err).Msgf("Auth: Failed login attempt username: [%s] ip: %s", data.Username, ReadUserIP(r)) + if _, err := h.service.Login(r.Context(), data.Username, data.Password); err != nil { + h.log.Error().Err(err).Msgf("Auth: Failed login attempt username: [%s] ip: %s", data.Username, r.RemoteAddr) h.encoder.StatusError(w, http.StatusUnauthorized, errors.New("could not login: bad credentials")) return } // create new session - session, _ := h.cookieStore.Get(r, "user_session") + session, err := h.cookieStore.New(r, "user_session") + if err != nil { + h.log.Error().Err(err).Msgf("Auth: Failed to parse cookies with attempt username: [%s] ip: %s", data.Username, r.RemoteAddr) + h.encoder.StatusError(w, http.StatusUnauthorized, errors.New("could not parse cookies")) + return + } // Set user as authenticated session.Values["authenticated"] = true + if err := session.Save(r, w); err != nil { h.encoder.StatusError(w, http.StatusInternalServerError, errors.Wrap(err, "could not save session")) return } - h.encoder.StatusResponse(w, http.StatusNoContent, nil) + h.encoder.NoContent(w) } func (h authHandler) logout(w http.ResponseWriter, r *http.Request) { - session, _ := h.cookieStore.Get(r, "user_session") - - // cookieStore.Get will create a new session if it does not exist - // so if it created a new then lets just return without saving it - if session.IsNew { - h.encoder.StatusResponse(w, http.StatusNoContent, nil) + // get session from context + session, ok := r.Context().Value("session").(*sessions.Session) + if !ok { + h.encoder.StatusError(w, http.StatusInternalServerError, errors.New("could not get session from context")) return } - // Revoke users authentication - session.Values["authenticated"] = false - session.Options.MaxAge = -1 - if err := session.Save(r, w); err != nil { - h.encoder.StatusError(w, http.StatusInternalServerError, errors.Wrap(err, "could not save session")) - return + if session != nil { + session.Values["authenticated"] = false + + // MaxAge<0 means delete cookie immediately + session.Options.MaxAge = -1 + + if err := session.Save(r, w); err != nil { + h.log.Error().Err(err).Msgf("could not store session: %s", r.RemoteAddr) + h.encoder.StatusError(w, http.StatusInternalServerError, err) + return + } } h.encoder.StatusResponse(w, http.StatusNoContent, nil) } func (h authHandler) onboard(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - session, _ := h.cookieStore.Get(r, "user_session") - - // Don't proceed if user is authenticated - if authenticated, ok := session.Values["authenticated"].(bool); ok { - if ok && authenticated { - h.encoder.StatusError(w, http.StatusForbidden, errors.New("active session found")) - return - } + if status, err := h.onboardEligible(r.Context()); err != nil { + h.encoder.StatusError(w, status, err) + return } var req domain.CreateUserRequest @@ -144,7 +140,7 @@ func (h authHandler) onboard(w http.ResponseWriter, r *http.Request) { return } - if err := h.service.CreateUser(ctx, req); err != nil { + if err := h.service.CreateUser(r.Context(), req); err != nil { h.encoder.StatusError(w, http.StatusForbidden, err) return } @@ -154,16 +150,8 @@ func (h authHandler) onboard(w http.ResponseWriter, r *http.Request) { } func (h authHandler) canOnboard(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - - userCount, err := h.service.GetUserCount(ctx) - if err != nil { - h.encoder.StatusError(w, http.StatusInternalServerError, errors.New("could not get user count")) - return - } - - if userCount > 0 { - h.encoder.StatusError(w, http.StatusForbidden, errors.New("onboarding unavailable")) + if status, err := h.onboardEligible(r.Context()); err != nil { + h.encoder.StatusError(w, status, err) return } @@ -172,28 +160,34 @@ func (h authHandler) canOnboard(w http.ResponseWriter, r *http.Request) { h.encoder.NoContent(w) } -func (h authHandler) validate(w http.ResponseWriter, r *http.Request) { - session, _ := h.cookieStore.Get(r, "user_session") - - // Check if user is authenticated - if auth, ok := session.Values["authenticated"].(bool); !ok || !auth { - session.Values["authenticated"] = false - session.Options.MaxAge = -1 - session.Save(r, w) - h.encoder.StatusError(w, http.StatusUnauthorized, errors.New("forbidden: invalid session")) - return +// onboardEligible checks if the onboarding process is eligible. +func (h authHandler) onboardEligible(ctx context.Context) (int, error) { + userCount, err := h.service.GetUserCount(ctx) + if err != nil { + return http.StatusInternalServerError, errors.New("could not get user count") } + if userCount > 0 { + return http.StatusForbidden, errors.New("onboarding unavailable") + } + + return http.StatusOK, nil +} + +// validate sits behind the IsAuthenticated middleware which takes care of checking for a valid session +// If there is a valid session return OK, otherwise the middleware returns early with a 401 +func (h authHandler) validate(w http.ResponseWriter, r *http.Request) { + session := r.Context().Value("session").(*sessions.Session) + if session != nil { + h.log.Debug().Msgf("found user session: %+v", session) + + } // send empty response as ok h.encoder.NoContent(w) } func (h authHandler) updateUser(w http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - data domain.UpdateUserRequest - ) - + var data domain.UpdateUserRequest if err := json.NewDecoder(r.Body).Decode(&data); err != nil { h.encoder.StatusError(w, http.StatusBadRequest, errors.Wrap(err, "could not decode json")) return @@ -201,7 +195,7 @@ func (h authHandler) updateUser(w http.ResponseWriter, r *http.Request) { data.UsernameCurrent = chi.URLParam(r, "username") - if err := h.service.UpdateUser(ctx, data); err != nil { + if err := h.service.UpdateUser(r.Context(), data); err != nil { h.encoder.StatusError(w, http.StatusForbidden, err) return } @@ -209,14 +203,3 @@ func (h authHandler) updateUser(w http.ResponseWriter, r *http.Request) { // send response as ok h.encoder.StatusResponseMessage(w, http.StatusOK, "user successfully updated") } - -func ReadUserIP(r *http.Request) string { - IPAddress := r.Header.Get("X-Real-Ip") - if IPAddress == "" { - IPAddress = r.Header.Get("X-Forwarded-For") - } - if IPAddress == "" { - IPAddress = r.RemoteAddr - } - return IPAddress -} diff --git a/internal/http/auth_test.go b/internal/http/auth_test.go new file mode 100644 index 0000000..fab13fc --- /dev/null +++ b/internal/http/auth_test.go @@ -0,0 +1,402 @@ +// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + +package http + +import ( + "bytes" + "context" + "encoding/json" + "log" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "testing" + + "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/pkg/errors" + + "github.com/go-chi/chi/v5" + "github.com/gorilla/sessions" + "github.com/rs/zerolog" +) + +type authServiceMock struct { + users map[string]*domain.User +} + +func (a authServiceMock) GetUserCount(ctx context.Context) (int, error) { + return len(a.users), nil +} + +func (a authServiceMock) Login(ctx context.Context, username, password string) (*domain.User, error) { + u, ok := a.users[username] + if !ok { + return nil, errors.New("invalid login") + } + + if u.Password != password { + return nil, errors.New("bad credentials") + } + + return u, nil +} + +func (a authServiceMock) CreateUser(ctx context.Context, req domain.CreateUserRequest) error { + if req.Username != "" { + a.users[req.Username] = &domain.User{ + ID: len(a.users) + 1, + Username: req.Username, + Password: req.Password, + } + } + + return nil +} + +func (a authServiceMock) UpdateUser(ctx context.Context, req domain.UpdateUserRequest) error { + u, ok := a.users[req.UsernameCurrent] + if !ok { + return errors.New("user not found") + } + + if req.UsernameNew != "" { + u.Username = req.UsernameNew + } + + if req.PasswordNew != "" { + u.Password = req.PasswordNew + } + + return nil +} + +func setupServer() chi.Router { + r := chi.NewRouter() + //r.Use(middleware.Logger) + r.Get("/", func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte("OK")) + }) + return r +} + +func runTestServer(s chi.Router) *httptest.Server { + return httptest.NewServer(s) +} + +func setupAuthHandler() { + +} + +func TestAuthHandlerLogin(t *testing.T) { + logger := zerolog.Nop() + encoder := encoder{} + cookieStore := sessions.NewCookieStore([]byte("test")) + + service := authServiceMock{ + users: map[string]*domain.User{ + "test": { + ID: 0, + Username: "test", + Password: "pass", + }, + }, + } + + server := Server{ + log: logger, + cookieStore: cookieStore, + } + + handler := newAuthHandler(encoder, logger, server, &domain.Config{}, cookieStore, service) + s := setupServer() + s.Route("/auth", handler.Routes) + + testServer := runTestServer(s) + defer testServer.Close() + + // generate request, here we'll use login as example + reqBody, err := json.Marshal(map[string]string{ + "username": "test", + "password": "pass", + }) + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + jarOptions := &cookiejar.Options{PublicSuffixList: nil} + jar, err := cookiejar.New(jarOptions) + if err != nil { + log.Fatalf("error creating cookiejar: %v", err) + } + + client := http.DefaultClient + client.Jar = jar + + // make request + resp, err := client.Post(testServer.URL+"/auth/login", "application/json", bytes.NewBuffer(reqBody)) + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + // check for response, here we'll just check for 204 NoContent + if status := resp.StatusCode; status != http.StatusNoContent { + t.Errorf("login: handler returned wrong status code: got %v want %v", status, http.StatusNoContent) + } + + if v := resp.Header.Get("Set-Cookie"); v == "" { + t.Errorf("handler returned no cookie") + } +} + +func TestAuthHandlerValidateOK(t *testing.T) { + logger := zerolog.Nop() + encoder := encoder{} + cookieStore := sessions.NewCookieStore([]byte("test")) + + service := authServiceMock{ + users: map[string]*domain.User{ + "test": { + ID: 0, + Username: "test", + Password: "pass", + }, + }, + } + + server := Server{ + log: logger, + cookieStore: cookieStore, + } + + handler := newAuthHandler(encoder, logger, server, &domain.Config{}, cookieStore, service) + s := setupServer() + s.Route("/auth", handler.Routes) + + testServer := runTestServer(s) + defer testServer.Close() + + // generate request, here we'll use login as example + reqBody, err := json.Marshal(map[string]string{ + "username": "test", + "password": "pass", + }) + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + jarOptions := &cookiejar.Options{PublicSuffixList: nil} + jar, err := cookiejar.New(jarOptions) + if err != nil { + log.Fatalf("error creating cookiejar: %v", err) + } + + client := http.DefaultClient + client.Jar = jar + + // make request + resp, err := client.Post(testServer.URL+"/auth/login", "application/json", bytes.NewBuffer(reqBody)) + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + // check for response, here we'll just check for 204 NoContent + if status := resp.StatusCode; status != http.StatusNoContent { + t.Errorf("login: handler returned wrong status code: got %v want %v", status, http.StatusNoContent) + } + + if v := resp.Header.Get("Set-Cookie"); v == "" { + t.Errorf("handler returned no cookie") + } + + // validate token + resp, err = client.Get(testServer.URL + "/auth/validate") + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + if status := resp.StatusCode; status != http.StatusNoContent { + t.Errorf("validate: handler returned wrong status code: got %v want %v", status, http.StatusNoContent) + } +} + +func TestAuthHandlerValidateBad(t *testing.T) { + logger := zerolog.Nop() + encoder := encoder{} + cookieStore := sessions.NewCookieStore([]byte("test")) + + service := authServiceMock{ + users: map[string]*domain.User{ + "test": { + ID: 0, + Username: "test", + Password: "pass", + }, + }, + } + + server := Server{ + log: logger, + cookieStore: cookieStore, + } + + handler := newAuthHandler(encoder, logger, server, &domain.Config{}, cookieStore, service) + s := setupServer() + s.Route("/auth", handler.Routes) + + testServer := runTestServer(s) + defer testServer.Close() + + jarOptions := &cookiejar.Options{PublicSuffixList: nil} + jar, err := cookiejar.New(jarOptions) + if err != nil { + log.Fatalf("error creating cookiejar: %v", err) + } + + client := http.DefaultClient + client.Jar = jar + + // validate token + resp, err := client.Get(testServer.URL + "/auth/validate") + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + if status := resp.StatusCode; status != http.StatusUnauthorized { + t.Errorf("validate: handler returned wrong status code: got %v want %v", status, http.StatusUnauthorized) + } +} + +func TestAuthHandlerLoginBad(t *testing.T) { + logger := zerolog.Nop() + encoder := encoder{} + cookieStore := sessions.NewCookieStore([]byte("test")) + + service := authServiceMock{ + users: map[string]*domain.User{ + "test": { + ID: 0, + Username: "test", + Password: "pass", + }, + }, + } + + server := Server{ + log: logger, + } + + handler := newAuthHandler(encoder, logger, server, &domain.Config{}, cookieStore, service) + s := setupServer() + s.Route("/auth", handler.Routes) + + testServer := runTestServer(s) + defer testServer.Close() + + // generate request, here we'll use login as example + reqBody, err := json.Marshal(map[string]string{ + "username": "test", + "password": "notmypass", + }) + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + // make request + resp, err := http.Post(testServer.URL+"/auth/login", "application/json", bytes.NewBuffer(reqBody)) + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + // check for response, here we'll just check for 204 NoContent + if status := resp.StatusCode; status != http.StatusUnauthorized { + t.Errorf("handler returned wrong status code: got %v want %v", status, http.StatusUnauthorized) + } +} + +func TestAuthHandlerLogout(t *testing.T) { + logger := zerolog.Nop() + encoder := encoder{} + cookieStore := sessions.NewCookieStore([]byte("test")) + + service := authServiceMock{ + users: map[string]*domain.User{ + "test": { + ID: 0, + Username: "test", + Password: "pass", + }, + }, + } + + server := Server{ + log: logger, + cookieStore: cookieStore, + } + + handler := newAuthHandler(encoder, logger, server, &domain.Config{}, cookieStore, service) + s := setupServer() + s.Route("/auth", handler.Routes) + + testServer := runTestServer(s) + defer testServer.Close() + + // generate request, here we'll use login as example + reqBody, err := json.Marshal(map[string]string{ + "username": "test", + "password": "pass", + }) + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + jarOptions := &cookiejar.Options{PublicSuffixList: nil} + jar, err := cookiejar.New(jarOptions) + if err != nil { + log.Fatalf("error creating cookiejar: %v", err) + } + + client := http.DefaultClient + client.Jar = jar + + // make request + resp, err := client.Post(testServer.URL+"/auth/login", "application/json", bytes.NewBuffer(reqBody)) + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + // check for response, here we'll just check for 204 NoContent + if status := resp.StatusCode; status != http.StatusNoContent { + t.Errorf("login: handler returned wrong status code: got %v want %v", status, http.StatusNoContent) + } + + if v := resp.Header.Get("Set-Cookie"); v == "" { + t.Errorf("handler returned no cookie") + } + + // validate token + resp, err = client.Get(testServer.URL + "/auth/validate") + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + if status := resp.StatusCode; status != http.StatusNoContent { + t.Errorf("validate: handler returned wrong status code: got %v want %v", status, http.StatusNoContent) + } + + // logout + resp, err = client.Post(testServer.URL+"/auth/logout", "application/json", nil) + if err != nil { + log.Fatalf("Error occurred: %v", err) + } + + if status := resp.StatusCode; status != http.StatusNoContent { + t.Errorf("validate: handler returned wrong status code: got %v want %v", status, http.StatusNoContent) + } + + //if v := resp.Header.Get("Set-Cookie"); v != "" { + // t.Errorf("logout handler returned cookie") + //} +} diff --git a/internal/http/logs_sanitize_test.go b/internal/http/logs_sanitize_test.go index b030248..c2ee91f 100644 --- a/internal/http/logs_sanitize_test.go +++ b/internal/http/logs_sanitize_test.go @@ -5,7 +5,6 @@ package http import ( "bytes" - "io/ioutil" "os" "strings" "testing" @@ -162,7 +161,7 @@ func TestSanitizeLogFile(t *testing.T) { for _, testCase := range testCases { t.Run(testCase.name, func(t *testing.T) { // Create a temporary file with sample log data - tmpFile, err := ioutil.TempFile("", "test-log-*.log") + tmpFile, err := os.CreateTemp("", "test-log-*.log") if err != nil { t.Fatal(err) } diff --git a/internal/http/middleware.go b/internal/http/middleware.go index be57ee4..48f955f 100644 --- a/internal/http/middleware.go +++ b/internal/http/middleware.go @@ -4,6 +4,7 @@ package http import ( + "context" "net/http" "runtime/debug" "strings" @@ -30,13 +31,25 @@ func (s Server) IsAuthenticated(next http.Handler) http.Handler { } } else { // check session - session, _ := s.cookieStore.Get(r, "user_session") + session, err := s.cookieStore.Get(r, "user_session") + if err != nil { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } + + if session.IsNew { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + return + } // Check if user is authenticated if auth, ok := session.Values["authenticated"].(bool); !ok || !auth { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } + + ctx := context.WithValue(r.Context(), "session", session) + r = r.WithContext(ctx) } next.ServeHTTP(w, r) diff --git a/internal/http/server.go b/internal/http/server.go index a63423b..fb42a62 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -126,7 +126,7 @@ func (s Server) Handler() http.Handler { encoder := encoder{} r.Route("/api", func(r chi.Router) { - r.Route("/auth", newAuthHandler(encoder, s.log, s.config.Config, s.cookieStore, s.authService, s).Routes) + r.Route("/auth", newAuthHandler(encoder, s.log, s, s.config.Config, s.cookieStore, s.authService).Routes) r.Route("/healthz", newHealthHandler(encoder, s.db).Routes) r.Group(func(r chi.Router) {