From 7f05dd1efdefda4e24dc3e5bc9139414aa47ce1f Mon Sep 17 00:00:00 2001 From: ze0s <43699394+zze0s@users.noreply.github.com> Date: Mon, 17 Apr 2023 20:56:17 +0200 Subject: [PATCH] fix(onboarding): could not create user (#848) fix: onboarding not working --- cmd/autobrrctl/main.go | 2 +- internal/auth/service.go | 22 ++++---- internal/database/user.go | 6 +- internal/domain/user.go | 7 ++- internal/http/action.go | 18 +++--- internal/http/apikey.go | 2 +- internal/http/auth.go | 97 ++++++++++++++++++++------------ internal/http/download_client.go | 6 +- internal/http/encoder.go | 36 +++++++++++- internal/http/feed.go | 16 +++--- internal/http/filter.go | 18 +++--- internal/http/indexer.go | 18 +++--- internal/http/irc.go | 4 +- internal/http/notification.go | 10 ++-- internal/http/release.go | 44 +++++++-------- internal/user/service.go | 6 +- 16 files changed, 182 insertions(+), 130 deletions(-) diff --git a/cmd/autobrrctl/main.go b/cmd/autobrrctl/main.go index 4c8ccb3..dabce76 100644 --- a/cmd/autobrrctl/main.go +++ b/cmd/autobrrctl/main.go @@ -121,7 +121,7 @@ func main() { log.Fatalf("failed to hash password: %v", err) } - user := domain.User{ + user := domain.CreateUserRequest{ Username: username, Password: hashed, } diff --git a/internal/auth/service.go b/internal/auth/service.go index 37022c9..0e58818 100644 --- a/internal/auth/service.go +++ b/internal/auth/service.go @@ -15,7 +15,7 @@ import ( type Service interface { GetUserCount(ctx context.Context) (int, error) Login(ctx context.Context, username, password string) (*domain.User, error) - CreateUser(ctx context.Context, username, password string) error + CreateUser(ctx context.Context, req domain.CreateUserRequest) error } type service struct { @@ -64,9 +64,11 @@ func (s *service) Login(ctx context.Context, username, password string) (*domain return u, nil } -func (s *service) CreateUser(ctx context.Context, username, password string) error { - if username == "" || password == "" { - return errors.New("empty credentials supplied") +func (s *service) CreateUser(ctx context.Context, req domain.CreateUserRequest) error { + if req.Username == "" { + return errors.New("validation error: empty username supplied") + } else if req.Password == "" { + return errors.New("validation error: empty password supplied") } userCount, err := s.userSvc.GetUserCount(ctx) @@ -78,17 +80,15 @@ func (s *service) CreateUser(ctx context.Context, username, password string) err return errors.New("only 1 user account is supported at the moment") } - hashed, err := argon2id.CreateHash(password, argon2id.DefaultParams) + hashed, err := argon2id.CreateHash(req.Password, argon2id.DefaultParams) if err != nil { return errors.New("failed to hash password") } - newUser := domain.User{ - Username: username, - Password: hashed, - } - if err := s.userSvc.CreateUser(context.Background(), newUser); err != nil { - s.log.Error().Err(err).Msgf("could not create user: %v", username) + req.Password = hashed + + if err := s.userSvc.CreateUser(ctx, req); err != nil { + s.log.Error().Err(err).Msgf("could not create user: %s", req.Username) return errors.New("failed to create new user") } diff --git a/internal/database/user.go b/internal/database/user.go index b3c544d..fefa7a0 100644 --- a/internal/database/user.go +++ b/internal/database/user.go @@ -7,8 +7,8 @@ import ( "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/pkg/errors" - sq "github.com/Masterminds/squirrel" + sq "github.com/Masterminds/squirrel" "github.com/rs/zerolog" ) @@ -75,14 +75,14 @@ func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain return &user, nil } -func (r *UserRepo) Store(ctx context.Context, user domain.User) error { +func (r *UserRepo) Store(ctx context.Context, req domain.CreateUserRequest) error { var err error queryBuilder := r.db.squirrel. Insert("users"). Columns("username", "password"). - Values(user.Username, user.Password) + Values(req.Username, req.Password) query, args, err := queryBuilder.ToSql() if err != nil { diff --git a/internal/domain/user.go b/internal/domain/user.go index 366da73..9973f5d 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -5,7 +5,7 @@ import "context" type UserRepo interface { GetUserCount(ctx context.Context) (int, error) FindByUsername(ctx context.Context, username string) (*User, error) - Store(ctx context.Context, user User) error + Store(ctx context.Context, req CreateUserRequest) error Update(ctx context.Context, user User) error } @@ -14,3 +14,8 @@ type User struct { Username string `json:"username"` Password string `json:"password"` } + +type CreateUserRequest struct { + Username string `json:"username"` + Password string `json:"password"` +} diff --git a/internal/http/action.go b/internal/http/action.go index 7396c36..c4b4156 100644 --- a/internal/http/action.go +++ b/internal/http/action.go @@ -44,7 +44,7 @@ func (h actionHandler) getActions(w http.ResponseWriter, r *http.Request) { // encode error } - h.encoder.StatusResponse(r.Context(), w, actions, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, actions) } func (h actionHandler) storeAction(w http.ResponseWriter, r *http.Request) { @@ -63,7 +63,7 @@ func (h actionHandler) storeAction(w http.ResponseWriter, r *http.Request) { // encode error } - h.encoder.StatusResponse(ctx, w, action, http.StatusCreated) + h.encoder.StatusResponse(w, http.StatusCreated, action) } func (h actionHandler) updateAction(w http.ResponseWriter, r *http.Request) { @@ -82,37 +82,33 @@ func (h actionHandler) updateAction(w http.ResponseWriter, r *http.Request) { // encode error } - h.encoder.StatusResponse(ctx, w, action, http.StatusCreated) + h.encoder.StatusResponse(w, http.StatusCreated, action) } func (h actionHandler) deleteAction(w http.ResponseWriter, r *http.Request) { - var ctx = r.Context() - actionID, err := parseInt(chi.URLParam(r, "id")) if err != nil { - h.encoder.StatusResponse(ctx, w, errors.New("bad param id"), http.StatusBadRequest) + h.encoder.StatusResponse(w, http.StatusBadRequest, errors.New("bad param id")) } if err := h.service.Delete(actionID); err != nil { // encode error } - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.StatusResponse(w, http.StatusNoContent, nil) } func (h actionHandler) toggleActionEnabled(w http.ResponseWriter, r *http.Request) { - var ctx = r.Context() - actionID, err := parseInt(chi.URLParam(r, "id")) if err != nil { - h.encoder.StatusResponse(ctx, w, errors.New("bad param id"), http.StatusBadRequest) + h.encoder.StatusResponse(w, http.StatusBadRequest, errors.New("bad param id")) } if err := h.service.ToggleEnabled(actionID); err != nil { // encode error } - h.encoder.StatusResponse(ctx, w, nil, http.StatusCreated) + h.encoder.StatusResponse(w, http.StatusCreated, nil) } func parseInt(s string) (int, error) { diff --git a/internal/http/apikey.go b/internal/http/apikey.go index 45f5026..50c9e20 100644 --- a/internal/http/apikey.go +++ b/internal/http/apikey.go @@ -66,7 +66,7 @@ func (h apikeyHandler) store(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, data, http.StatusCreated) + h.encoder.StatusResponse(w, http.StatusCreated, data) } func (h apikeyHandler) delete(w http.ResponseWriter, r *http.Request) { diff --git a/internal/http/auth.go b/internal/http/auth.go index ff35da2..c0d903a 100644 --- a/internal/http/auth.go +++ b/internal/http/auth.go @@ -6,6 +6,7 @@ import ( "net/http" "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/pkg/errors" "github.com/go-chi/chi/v5" "github.com/gorilla/sessions" @@ -15,7 +16,7 @@ import ( type authService interface { GetUserCount(ctx context.Context) (int, error) Login(ctx context.Context, username, password string) (*domain.User, error) - CreateUser(ctx context.Context, username, password string) error + CreateUser(ctx context.Context, req domain.CreateUserRequest) error } type authHandler struct { @@ -52,8 +53,7 @@ func (h authHandler) login(w http.ResponseWriter, r *http.Request) { ) if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - // encode error - h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest) + h.encoder.StatusError(w, http.StatusBadRequest, errors.Wrap(err, "could not decode json")) return } @@ -71,59 +71,82 @@ func (h authHandler) login(w http.ResponseWriter, r *http.Request) { h.cookieStore.Options.SameSite = http.SameSiteStrictMode } - session, _ := h.cookieStore.Get(r, "user_session") - - _, err := h.service.Login(ctx, data.Username, data.Password) + session, err := h.cookieStore.Get(r, "user_session") if err != nil { + h.encoder.StatusError(w, http.StatusInternalServerError, errors.New("could not get session")) + return + } + + if _, err = h.service.Login(ctx, data.Username, data.Password); err != nil { h.log.Error().Err(err).Msgf("Auth: Failed login attempt username: [%s] ip: %s", data.Username, ReadUserIP(r)) - h.encoder.StatusResponse(ctx, w, nil, http.StatusUnauthorized) + h.encoder.StatusError(w, http.StatusUnauthorized, errors.New("could not login: bad credentials")) return } // Set user as authenticated session.Values["authenticated"] = true - session.Save(r, w) + if err := session.Save(r, w); err != nil { + h.encoder.StatusError(w, http.StatusInternalServerError, errors.Wrap(err, "could not save session")) + return + } - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.StatusResponse(w, http.StatusNoContent, nil) } func (h authHandler) logout(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() + session, err := h.cookieStore.Get(r, "user_session") + if err != nil { + h.log.Error().Err(err).Msg("could not get session") + h.encoder.StatusError(w, http.StatusInternalServerError, errors.New("could not get session")) + return + } - session, _ := h.cookieStore.Get(r, "user_session") + if session.IsNew { + h.encoder.StatusResponse(w, http.StatusNoContent, nil) + return + } // Revoke users authentication session.Values["authenticated"] = false - session.Save(r, w) + if err := session.Save(r, w); err != nil { + h.encoder.StatusError(w, http.StatusInternalServerError, errors.Wrap(err, "could not save session")) + return + } - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.StatusResponse(w, http.StatusNoContent, nil) } func (h authHandler) onboard(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - session, _ := h.cookieStore.Get(r, "user_session") + + session, err := h.cookieStore.Get(r, "user_session") + if err != nil { + h.log.Error().Err(err).Msg("could not get session") + h.encoder.StatusError(w, http.StatusInternalServerError, errors.New("could not get session")) + return + } // Don't proceed if user is authenticated - if _, ok := session.Values["authenticated"].(bool); ok { - http.Error(w, "Forbidden", http.StatusForbidden) + if authenticated, ok := session.Values["authenticated"].(bool); ok { + if ok && authenticated { + h.encoder.StatusError(w, http.StatusForbidden, errors.New("active session found")) + return + } + } + + var req domain.CreateUserRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + h.encoder.StatusError(w, http.StatusBadRequest, errors.Wrap(err, "could not decode json")) return } - var data domain.User - if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - // encode error - h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest) + if err := h.service.CreateUser(ctx, req); err != nil { + h.encoder.StatusError(w, http.StatusForbidden, err) return } - err := h.service.CreateUser(ctx, data.Username, data.Password) - if err != nil { - http.Error(w, "Forbidden", http.StatusForbidden) - return - } - - // send empty response as ok - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + // send response as ok + h.encoder.StatusResponseMessage(w, http.StatusOK, "user successfully created") } func (h authHandler) canOnboard(w http.ResponseWriter, r *http.Request) { @@ -131,33 +154,35 @@ func (h authHandler) canOnboard(w http.ResponseWriter, r *http.Request) { userCount, err := h.service.GetUserCount(ctx) if err != nil { - http.Error(w, "Internal server error", http.StatusInternalServerError) + h.encoder.StatusError(w, http.StatusInternalServerError, errors.New("could not get user count")) return } if userCount > 0 { - // send 503 service onboarding unavailable - http.Error(w, "Onboarding unavailable", http.StatusForbidden) + h.encoder.StatusError(w, http.StatusForbidden, errors.New("onboarding unavailable")) return } // send empty response as ok // (client can proceed with redirection to onboarding page) - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.NoContent(w) } func (h authHandler) validate(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - session, _ := h.cookieStore.Get(r, "user_session") + session, err := h.cookieStore.Get(r, "user_session") + if err != nil { + h.encoder.StatusError(w, http.StatusInternalServerError, errors.New("could not get session")) + return + } // Check if user is authenticated if auth, ok := session.Values["authenticated"].(bool); !ok || !auth { - http.Error(w, "Forbidden", http.StatusUnauthorized) + h.encoder.StatusError(w, http.StatusUnauthorized, errors.New("forbidden: invalid session")) return } // send empty response as ok - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.NoContent(w) } func ReadUserIP(r *http.Request) string { diff --git a/internal/http/download_client.go b/internal/http/download_client.go index a7250e9..b42af04 100644 --- a/internal/http/download_client.go +++ b/internal/http/download_client.go @@ -49,7 +49,7 @@ func (h downloadClientHandler) listDownloadClients(w http.ResponseWriter, r *htt return } - h.encoder.StatusResponse(ctx, w, clients, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, clients) } func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) { @@ -66,7 +66,7 @@ func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(r.Context(), w, client, http.StatusCreated) + h.encoder.StatusResponse(w, http.StatusCreated, client) } func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) { @@ -99,7 +99,7 @@ func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(r.Context(), w, client, http.StatusCreated) + h.encoder.StatusResponse(w, http.StatusCreated, client) } func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) { diff --git a/internal/http/encoder.go b/internal/http/encoder.go index d42336d..49a6af9 100644 --- a/internal/http/encoder.go +++ b/internal/http/encoder.go @@ -1,7 +1,6 @@ package http import ( - "context" "encoding/json" "net/http" ) @@ -13,7 +12,12 @@ type errorResponse struct { Status int `json:"status,omitempty"` } -func (e encoder) StatusResponse(ctx context.Context, w http.ResponseWriter, response interface{}, status int) { +type statusResponse struct { + Message string `json:"message"` + Status int `json:"status,omitempty"` +} + +func (e encoder) StatusResponse(w http.ResponseWriter, status int, response interface{}) { if response != nil { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(status) @@ -26,6 +30,19 @@ func (e encoder) StatusResponse(ctx context.Context, w http.ResponseWriter, resp } } +func (e encoder) StatusResponseMessage(w http.ResponseWriter, status int, message string) { + if message != "" { + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(statusResponse{Message: message}); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + } else { + w.WriteHeader(status) + } +} + func (e encoder) StatusCreated(w http.ResponseWriter) { w.WriteHeader(http.StatusCreated) } @@ -43,7 +60,7 @@ func (e encoder) NoContent(w http.ResponseWriter) { w.WriteHeader(http.StatusNoContent) } -func (e encoder) StatusNotFound(ctx context.Context, w http.ResponseWriter) { +func (e encoder) StatusNotFound(w http.ResponseWriter) { w.WriteHeader(http.StatusNotFound) } @@ -60,3 +77,16 @@ func (e encoder) Error(w http.ResponseWriter, err error) { w.WriteHeader(http.StatusInternalServerError) json.NewEncoder(w).Encode(res) } + +func (e encoder) StatusError(w http.ResponseWriter, status int, err error) { + res := errorResponse{ + Message: err.Error(), + } + + w.Header().Set("Content-Type", "application/json; charset=utf-8") + w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(res); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +} diff --git a/internal/http/feed.go b/internal/http/feed.go index b4a96c0..94d819d 100644 --- a/internal/http/feed.go +++ b/internal/http/feed.go @@ -48,11 +48,11 @@ func (h feedHandler) find(w http.ResponseWriter, r *http.Request) { feeds, err := h.service.Find(ctx) if err != nil { - h.encoder.StatusNotFound(ctx, w) + h.encoder.StatusNotFound(w) return } - h.encoder.StatusResponse(ctx, w, feeds, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, feeds) } func (h feedHandler) store(w http.ResponseWriter, r *http.Request) { @@ -63,7 +63,7 @@ func (h feedHandler) store(w http.ResponseWriter, r *http.Request) { if err := json.NewDecoder(r.Body).Decode(&data); err != nil { // encode error - h.encoder.StatusNotFound(ctx, w) + h.encoder.StatusNotFound(w) return } @@ -74,7 +74,7 @@ func (h feedHandler) store(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, data, http.StatusCreated) + h.encoder.StatusResponse(w, http.StatusCreated, data) } func (h feedHandler) test(w http.ResponseWriter, r *http.Request) { @@ -117,7 +117,7 @@ func (h feedHandler) update(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, data, http.StatusCreated) + h.encoder.StatusResponse(w, http.StatusCreated, data) } func (h feedHandler) toggleEnabled(w http.ResponseWriter, r *http.Request) { @@ -144,7 +144,7 @@ func (h feedHandler) toggleEnabled(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.StatusResponse(w, http.StatusNoContent, nil) } func (h feedHandler) delete(w http.ResponseWriter, r *http.Request) { @@ -160,7 +160,7 @@ func (h feedHandler) delete(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.StatusResponse(w, http.StatusNoContent, nil) } func (h feedHandler) latestRun(w http.ResponseWriter, r *http.Request) { @@ -182,7 +182,7 @@ func (h feedHandler) latestRun(w http.ResponseWriter, r *http.Request) { } if feed == "" { - h.encoder.StatusNotFound(ctx, w) + h.encoder.StatusNotFound(w) w.Write([]byte("No data found")) return } diff --git a/internal/http/filter.go b/internal/http/filter.go index 1c4ca6c..16660e5 100644 --- a/internal/http/filter.go +++ b/internal/http/filter.go @@ -78,10 +78,10 @@ func (h filterHandler) getFilters(w http.ResponseWriter, r *http.Request) { u, err := url.Parse(r.URL.String()) if err != nil { - h.encoder.StatusResponse(r.Context(), w, map[string]interface{}{ + h.encoder.StatusResponse(w, http.StatusBadRequest, map[string]interface{}{ "code": "BAD_REQUEST_PARAMS", "message": "indexer parameter is invalid", - }, http.StatusBadRequest) + }) return } vals := u.Query() @@ -94,7 +94,7 @@ func (h filterHandler) getFilters(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, trackers, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, trackers) } func (h filterHandler) getByID(w http.ResponseWriter, r *http.Request) { @@ -111,11 +111,11 @@ func (h filterHandler) getByID(w http.ResponseWriter, r *http.Request) { filter, err := h.service.FindByID(ctx, id) if err != nil { - h.encoder.StatusNotFound(ctx, w) + h.encoder.StatusNotFound(w) return } - h.encoder.StatusResponse(ctx, w, filter, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, filter) } func (h filterHandler) duplicate(w http.ResponseWriter, r *http.Request) { @@ -132,11 +132,11 @@ func (h filterHandler) duplicate(w http.ResponseWriter, r *http.Request) { filter, err := h.service.Duplicate(ctx, id) if err != nil { - h.encoder.StatusNotFound(ctx, w) + h.encoder.StatusNotFound(w) return } - h.encoder.StatusResponse(ctx, w, filter, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, filter) } func (h filterHandler) store(w http.ResponseWriter, r *http.Request) { @@ -180,7 +180,7 @@ func (h filterHandler) update(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, filter, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, filter) } func (h filterHandler) updatePartial(w http.ResponseWriter, r *http.Request) { @@ -260,5 +260,5 @@ func (h filterHandler) delete(w http.ResponseWriter, r *http.Request) { h.encoder.Error(w, err) } - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.StatusResponse(w, http.StatusNoContent, nil) } diff --git a/internal/http/indexer.go b/internal/http/indexer.go index 88b578b..37a3447 100644 --- a/internal/http/indexer.go +++ b/internal/http/indexer.go @@ -46,15 +46,13 @@ func (h indexerHandler) Routes(r chi.Router) { } func (h indexerHandler) getSchema(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - indexers, err := h.service.GetTemplates() if err != nil { h.encoder.Error(w, err) return } - h.encoder.StatusResponse(ctx, w, indexers, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, indexers) } func (h indexerHandler) store(w http.ResponseWriter, r *http.Request) { @@ -74,7 +72,7 @@ func (h indexerHandler) store(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, indexer, http.StatusCreated) + h.encoder.StatusResponse(w, http.StatusCreated, indexer) } func (h indexerHandler) update(w http.ResponseWriter, r *http.Request) { @@ -94,7 +92,7 @@ func (h indexerHandler) update(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, indexer, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, indexer) } func (h indexerHandler) delete(w http.ResponseWriter, r *http.Request) { @@ -110,19 +108,17 @@ func (h indexerHandler) delete(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.StatusResponse(w, http.StatusNoContent, nil) } func (h indexerHandler) getAll(w http.ResponseWriter, r *http.Request) { - ctx := r.Context() - indexers, err := h.service.GetAll() if err != nil { h.encoder.Error(w, err) return } - h.encoder.StatusResponse(ctx, w, indexers, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, indexers) } func (h indexerHandler) list(w http.ResponseWriter, r *http.Request) { @@ -134,7 +130,7 @@ func (h indexerHandler) list(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, indexers, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, indexers) } func (h indexerHandler) testApi(w http.ResponseWriter, r *http.Request) { @@ -170,5 +166,5 @@ func (h indexerHandler) testApi(w http.ResponseWriter, r *http.Request) { Message: "Indexer api test OK", } - h.encoder.StatusResponse(ctx, w, res, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, res) } diff --git a/internal/http/irc.go b/internal/http/irc.go index 00df374..8f386ae 100644 --- a/internal/http/irc.go +++ b/internal/http/irc.go @@ -52,7 +52,7 @@ func (h ircHandler) listNetworks(w http.ResponseWriter, r *http.Request) { h.encoder.Error(w, err) } - h.encoder.StatusResponse(ctx, w, networks, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, networks) } func (h ircHandler) getNetworkByID(w http.ResponseWriter, r *http.Request) { @@ -68,7 +68,7 @@ func (h ircHandler) getNetworkByID(w http.ResponseWriter, r *http.Request) { h.encoder.Error(w, err) } - h.encoder.StatusResponse(ctx, w, network, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, network) } func (h ircHandler) restartNetwork(w http.ResponseWriter, r *http.Request) { diff --git a/internal/http/notification.go b/internal/http/notification.go index f661999..1baec10 100644 --- a/internal/http/notification.go +++ b/internal/http/notification.go @@ -45,11 +45,11 @@ func (h notificationHandler) list(w http.ResponseWriter, r *http.Request) { list, _, err := h.service.Find(ctx, domain.NotificationQueryParams{}) if err != nil { - h.encoder.StatusNotFound(ctx, w) + h.encoder.StatusNotFound(w) return } - h.encoder.StatusResponse(ctx, w, list, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, list) } func (h notificationHandler) store(w http.ResponseWriter, r *http.Request) { @@ -69,7 +69,7 @@ func (h notificationHandler) store(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, filter, http.StatusCreated) + h.encoder.StatusResponse(w, http.StatusCreated, filter) } func (h notificationHandler) update(w http.ResponseWriter, r *http.Request) { @@ -89,7 +89,7 @@ func (h notificationHandler) update(w http.ResponseWriter, r *http.Request) { return } - h.encoder.StatusResponse(ctx, w, filter, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, filter) } func (h notificationHandler) delete(w http.ResponseWriter, r *http.Request) { @@ -104,7 +104,7 @@ func (h notificationHandler) delete(w http.ResponseWriter, r *http.Request) { // return err } - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.StatusResponse(w, http.StatusNoContent, nil) } func (h notificationHandler) test(w http.ResponseWriter, r *http.Request) { diff --git a/internal/http/release.go b/internal/http/release.go index 7d395e8..d3bc286 100644 --- a/internal/http/release.go +++ b/internal/http/release.go @@ -43,10 +43,10 @@ func (h releaseHandler) findReleases(w http.ResponseWriter, r *http.Request) { limitP := r.URL.Query().Get("limit") limit, err := strconv.Atoi(limitP) if err != nil && limitP != "" { - h.encoder.StatusResponse(r.Context(), w, map[string]interface{}{ + h.encoder.StatusResponse(w, http.StatusBadRequest, map[string]interface{}{ "code": "BAD_REQUEST_PARAMS", "message": "limit parameter is invalid", - }, http.StatusBadRequest) + }) return } if limit == 0 { @@ -56,10 +56,10 @@ func (h releaseHandler) findReleases(w http.ResponseWriter, r *http.Request) { offsetP := r.URL.Query().Get("offset") offset, err := strconv.Atoi(offsetP) if err != nil && offsetP != "" { - h.encoder.StatusResponse(r.Context(), w, map[string]interface{}{ + h.encoder.StatusResponse(w, http.StatusBadRequest, map[string]interface{}{ "code": "BAD_REQUEST_PARAMS", "message": "offset parameter is invalid", - }, http.StatusBadRequest) + }) return } @@ -68,20 +68,20 @@ func (h releaseHandler) findReleases(w http.ResponseWriter, r *http.Request) { if cursorP != "" { cursor, err = strconv.Atoi(cursorP) if err != nil && cursorP != "" { - h.encoder.StatusResponse(r.Context(), w, map[string]interface{}{ + h.encoder.StatusResponse(w, http.StatusBadRequest, map[string]interface{}{ "code": "BAD_REQUEST_PARAMS", "message": "cursor parameter is invalid", - }, http.StatusBadRequest) + }) } return } u, err := url.Parse(r.URL.String()) if err != nil { - h.encoder.StatusResponse(r.Context(), w, map[string]interface{}{ + h.encoder.StatusResponse(w, http.StatusBadRequest, map[string]interface{}{ "code": "BAD_REQUEST_PARAMS", "message": "indexer parameter is invalid", - }, http.StatusBadRequest) + }) return } vals := u.Query() @@ -104,10 +104,10 @@ func (h releaseHandler) findReleases(w http.ResponseWriter, r *http.Request) { releases, nextCursor, count, err := h.service.Find(r.Context(), query) if err != nil { - h.encoder.StatusResponse(r.Context(), w, map[string]interface{}{ + h.encoder.StatusResponse(w, http.StatusInternalServerError, map[string]interface{}{ "code": "INTERNAL_SERVER_ERROR", "message": err.Error(), - }, http.StatusInternalServerError) + }) return } @@ -121,17 +121,17 @@ func (h releaseHandler) findReleases(w http.ResponseWriter, r *http.Request) { Count: count, } - h.encoder.StatusResponse(r.Context(), w, ret, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, ret) } func (h releaseHandler) findRecentReleases(w http.ResponseWriter, r *http.Request) { releases, err := h.service.FindRecent(r.Context()) if err != nil { - h.encoder.StatusResponse(r.Context(), w, map[string]interface{}{ + h.encoder.StatusResponse(w, http.StatusInternalServerError, map[string]interface{}{ "code": "INTERNAL_SERVER_ERROR", "message": err.Error(), - }, http.StatusInternalServerError) + }) return } @@ -141,43 +141,43 @@ func (h releaseHandler) findRecentReleases(w http.ResponseWriter, r *http.Reques Data: releases, } - h.encoder.StatusResponse(r.Context(), w, ret, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, ret) } func (h releaseHandler) getIndexerOptions(w http.ResponseWriter, r *http.Request) { stats, err := h.service.GetIndexerOptions(r.Context()) if err != nil { - h.encoder.StatusResponse(r.Context(), w, map[string]interface{}{ + h.encoder.StatusResponse(w, http.StatusInternalServerError, map[string]interface{}{ "code": "INTERNAL_SERVER_ERROR", "message": err.Error(), - }, http.StatusInternalServerError) + }) return } - h.encoder.StatusResponse(r.Context(), w, stats, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, stats) } func (h releaseHandler) getStats(w http.ResponseWriter, r *http.Request) { stats, err := h.service.Stats(r.Context()) if err != nil { - h.encoder.StatusResponse(r.Context(), w, map[string]interface{}{ + h.encoder.StatusResponse(w, http.StatusInternalServerError, map[string]interface{}{ "code": "INTERNAL_SERVER_ERROR", "message": err.Error(), - }, http.StatusInternalServerError) + }) return } - h.encoder.StatusResponse(r.Context(), w, stats, http.StatusOK) + h.encoder.StatusResponse(w, http.StatusOK, stats) } func (h releaseHandler) deleteReleases(w http.ResponseWriter, r *http.Request) { err := h.service.Delete(r.Context()) if err != nil { - h.encoder.StatusResponse(r.Context(), w, map[string]interface{}{ + h.encoder.StatusResponse(w, http.StatusInternalServerError, map[string]interface{}{ "code": "INTERNAL_SERVER_ERROR", "message": err.Error(), - }, http.StatusInternalServerError) + }) return } diff --git a/internal/user/service.go b/internal/user/service.go index 1dbad52..2c0a1a0 100644 --- a/internal/user/service.go +++ b/internal/user/service.go @@ -10,7 +10,7 @@ import ( type Service interface { GetUserCount(ctx context.Context) (int, error) FindByUsername(ctx context.Context, username string) (*domain.User, error) - CreateUser(ctx context.Context, user domain.User) error + CreateUser(ctx context.Context, req domain.CreateUserRequest) error } type service struct { @@ -36,7 +36,7 @@ func (s *service) FindByUsername(ctx context.Context, username string) (*domain. return user, nil } -func (s *service) CreateUser(ctx context.Context, newUser domain.User) error { +func (s *service) CreateUser(ctx context.Context, req domain.CreateUserRequest) error { userCount, err := s.repo.GetUserCount(ctx) if err != nil { return err @@ -46,5 +46,5 @@ func (s *service) CreateUser(ctx context.Context, newUser domain.User) error { return errors.New("only 1 user account is supported at the moment") } - return s.repo.Store(ctx, newUser) + return s.repo.Store(ctx, req) }