// Copyright (c) 2021 - 2023, Ludvig Lundgren and the autobrr contributors. // SPDX-License-Identifier: GPL-2.0-or-later package http import ( "net/http" "runtime/debug" "strings" "time" "github.com/go-chi/chi/v5/middleware" "github.com/rs/zerolog" ) func (s Server) IsAuthenticated(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if token := r.Header.Get("X-API-Token"); token != "" { // check header if !s.apiService.ValidateAPIKey(r.Context(), token) { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } } else if key := r.URL.Query().Get("apikey"); key != "" { // check query param lke ?apikey=TOKEN if !s.apiService.ValidateAPIKey(r.Context(), key) { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } } else { // check session session, _ := s.cookieStore.Get(r, "user_session") // Check if user is authenticated if auth, ok := session.Values["authenticated"].(bool); !ok || !auth { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } } next.ServeHTTP(w, r) }) } func LoggerMiddleware(logger *zerolog.Logger) func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { log := logger.With().Logger() ww := middleware.NewWrapResponseWriter(w, r.ProtoMajor) t1 := time.Now() defer func() { t2 := time.Now() // Recover and record stack traces in case of a panic if rec := recover(); rec != nil { log.Error(). Str("type", "error"). Timestamp(). Interface("recover_info", rec). Bytes("debug_stack", debug.Stack()). Msg("log system error") http.Error(ww, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } if !strings.Contains("/api/healthz/liveness|/api/healthz/readiness", r.URL.Path) { // log end request log.Trace(). Str("type", "access"). Timestamp(). Fields(map[string]interface{}{ "remote_ip": r.RemoteAddr, "url": r.URL.Path, "proto": r.Proto, "method": r.Method, "user_agent": r.Header.Get("User-Agent"), "status": ww.Status(), "latency_ms": float64(t2.Sub(t1).Nanoseconds()) / 1000000.0, "bytes_in": r.Header.Get("Content-Length"), "bytes_out": ww.BytesWritten(), }). Msg("incoming_request") } }() next.ServeHTTP(ww, r) } return http.HandlerFunc(fn) } }