diff --git a/cmd/autobrrctl/main.go b/cmd/autobrrctl/main.go index 8f285bf..6edaadc 100644 --- a/cmd/autobrrctl/main.go +++ b/cmd/autobrrctl/main.go @@ -14,11 +14,12 @@ import ( "os" "time" + "github.com/autobrr/autobrr/internal/auth" "github.com/autobrr/autobrr/internal/config" "github.com/autobrr/autobrr/internal/database" "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" - "github.com/autobrr/autobrr/pkg/argon2id" + "github.com/autobrr/autobrr/internal/user" "github.com/autobrr/autobrr/pkg/errors" "golang.org/x/term" @@ -95,6 +96,12 @@ func main() { log.Fatal("--config required") } + username := flag.Arg(1) + if username == "" { + flag.Usage() + os.Exit(1) + } + // read config cfg := config.New(configPath, version) @@ -109,34 +116,42 @@ func main() { userRepo := database.NewUserRepo(l, db) - username := flag.Arg(1) - if username == "" { - flag.Usage() - os.Exit(1) - } + userSvc := user.NewService(userRepo) + authSvc := auth.NewService(l, userSvc) + + ctx := context.Background() password, err := readPassword() if err != nil { log.Fatalf("failed to read password: %v", err) } - hashed, err := argon2id.CreateHash(string(password), argon2id.DefaultParams) + + hashed, err := authSvc.CreateHash(string(password)) if err != nil { log.Fatalf("failed to hash password: %v", err) } - user := domain.CreateUserRequest{ + req := domain.CreateUserRequest{ Username: username, Password: hashed, } - if err := userRepo.Store(context.Background(), user); err != nil { + + if err := userRepo.Store(ctx, req); err != nil { log.Fatalf("failed to create user: %v", err) } + case "change-password": if configPath == "" { log.Fatal("--config required") } + username := flag.Arg(1) + if username == "" { + flag.Usage() + os.Exit(1) + } + // read config cfg := config.New(configPath, version) @@ -151,18 +166,17 @@ func main() { userRepo := database.NewUserRepo(l, db) - username := flag.Arg(1) - if username == "" { - flag.Usage() - os.Exit(1) - } + userSvc := user.NewService(userRepo) + authSvc := auth.NewService(l, userSvc) - user, err := userRepo.FindByUsername(context.Background(), username) + ctx := context.Background() + + usr, err := userSvc.FindByUsername(ctx, username) if err != nil { log.Fatalf("failed to get user: %v", err) } - if user == nil { + if usr == nil { log.Fatalf("failed to get user: %v", err) } @@ -170,15 +184,26 @@ func main() { if err != nil { log.Fatalf("failed to read password: %v", err) } - hashed, err := argon2id.CreateHash(string(password), argon2id.DefaultParams) + + hashed, err := authSvc.CreateHash(string(password)) if err != nil { log.Fatalf("failed to hash password: %v", err) } - user.Password = hashed - if err := userRepo.Update(context.Background(), *user); err != nil { + usr.Password = hashed + + req := domain.UpdateUserRequest{ + UsernameCurrent: username, + PasswordNew: string(password), + PasswordNewHash: hashed, + } + + if err := userSvc.Update(ctx, req); err != nil { log.Fatalf("failed to create user: %v", err) } + + log.Printf("successfully updated password for user %q", username) + default: flag.Usage() if cmd != "help" { diff --git a/internal/auth/service.go b/internal/auth/service.go index c58f4f2..7060155 100644 --- a/internal/auth/service.go +++ b/internal/auth/service.go @@ -19,6 +19,9 @@ type Service interface { GetUserCount(ctx context.Context) (int, error) Login(ctx context.Context, username, password string) (*domain.User, error) CreateUser(ctx context.Context, req domain.CreateUserRequest) error + UpdateUser(ctx context.Context, req domain.UpdateUserRequest) error + CreateHash(password string) (hash string, err error) + ComparePasswordAndHash(password string, hash string) (match bool, err error) } type service struct { @@ -54,7 +57,7 @@ func (s *service) Login(ctx context.Context, username, password string) (*domain } // compare password from request and the saved password - match, err := argon2id.ComparePasswordAndHash(password, u.Password) + match, err := s.ComparePasswordAndHash(password, u.Password) if err != nil { return nil, errors.New("error checking credentials") } @@ -83,7 +86,7 @@ func (s *service) CreateUser(ctx context.Context, req domain.CreateUserRequest) return errors.New("only 1 user account is supported at the moment") } - hashed, err := argon2id.CreateHash(req.Password, argon2id.DefaultParams) + hashed, err := s.CreateHash(req.Password) if err != nil { return errors.New("failed to hash password") } @@ -97,3 +100,59 @@ func (s *service) CreateUser(ctx context.Context, req domain.CreateUserRequest) return nil } + +func (s *service) UpdateUser(ctx context.Context, req domain.UpdateUserRequest) error { + if req.PasswordCurrent == "" { + return errors.New("validation error: empty current password supplied") + } + + if req.PasswordNew != "" && req.PasswordCurrent != "" { + if req.PasswordNew == req.PasswordCurrent { + return errors.New("validation error: new password must be different") + } + } + + // find user + u, err := s.userSvc.FindByUsername(ctx, req.UsernameCurrent) + if err != nil { + s.log.Trace().Err(err).Msgf("invalid login %v", req.UsernameCurrent) + return errors.Wrapf(err, "invalid login: %s", req.UsernameCurrent) + } + + if u == nil { + return errors.Errorf("invalid login: %s", req.UsernameCurrent) + } + + // compare password from request and the saved password + match, err := s.ComparePasswordAndHash(req.PasswordCurrent, u.Password) + if err != nil { + return errors.New("error checking credentials") + } + + if !match { + s.log.Debug().Msgf("bad credentials: %q | %q", req.UsernameCurrent, req.PasswordCurrent) + return errors.Errorf("invalid login: %s", req.UsernameCurrent) + } + + hashed, err := s.CreateHash(req.PasswordNew) + if err != nil { + return errors.New("failed to hash password") + } + + req.PasswordNewHash = hashed + + if err := s.userSvc.Update(ctx, req); err != nil { + s.log.Error().Err(err).Msgf("could not change password for user: %s", req.UsernameCurrent) + return errors.New("failed to change password") + } + + return nil +} + +func (s *service) ComparePasswordAndHash(password string, hash string) (match bool, err error) { + return argon2id.ComparePasswordAndHash(password, hash) +} + +func (s *service) CreateHash(password string) (hash string, err error) { + return argon2id.CreateHash(password, argon2id.DefaultParams) +} diff --git a/internal/database/user.go b/internal/database/user.go index b035721..1139393 100644 --- a/internal/database/user.go +++ b/internal/database/user.go @@ -49,7 +49,6 @@ func (r *UserRepo) GetUserCount(ctx context.Context) (int, error) { } func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain.User, error) { - queryBuilder := r.db.squirrel. Select("id", "username", "password"). From("users"). @@ -79,9 +78,6 @@ func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain } func (r *UserRepo) Store(ctx context.Context, req domain.CreateUserRequest) error { - - var err error - queryBuilder := r.db.squirrel. Insert("users"). Columns("username", "password"). @@ -100,15 +96,18 @@ func (r *UserRepo) Store(ctx context.Context, req domain.CreateUserRequest) erro return err } -func (r *UserRepo) Update(ctx context.Context, user domain.User) error { +func (r *UserRepo) Update(ctx context.Context, user domain.UpdateUserRequest) error { + queryBuilder := r.db.squirrel.Update("users") - var err error + if user.UsernameNew != "" { + queryBuilder = queryBuilder.Set("username", user.UsernameNew) + } - queryBuilder := r.db.squirrel. - Update("users"). - Set("username", user.Username). - Set("password", user.Password). - Where(sq.Eq{"username": user.Username}) + if user.PasswordNewHash != "" { + queryBuilder = queryBuilder.Set("password", user.PasswordNewHash) + } + + queryBuilder = queryBuilder.Where(sq.Eq{"username": user.UsernameCurrent}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -120,11 +119,10 @@ func (r *UserRepo) Update(ctx context.Context, user domain.User) error { return errors.Wrap(err, "error executing query") } - return err + return nil } func (r *UserRepo) Delete(ctx context.Context, username string) error { - queryBuilder := r.db.squirrel. Delete("users"). Where(sq.Eq{"username": username}) diff --git a/internal/database/user_test.go b/internal/database/user_test.go index f34296d..5800121 100644 --- a/internal/database/user_test.go +++ b/internal/database/user_test.go @@ -55,11 +55,19 @@ func TestUserRepo_Update(t *testing.T) { }) assert.NoError(t, err) + storedUser, err := repo.FindByUsername(context.Background(), user.Username) + assert.NoError(t, err) + user.ID = storedUser.ID + t.Run(fmt.Sprintf("UpdateUser_Succeeds [%s]", dbType), func(t *testing.T) { // Update the user newPassword := "newPassword123" user.Password = newPassword - err := repo.Update(context.Background(), user) + req := domain.UpdateUserRequest{ + UsernameCurrent: user.Username, + PasswordNewHash: newPassword, + } + err := repo.Update(context.Background(), req) assert.NoError(t, err) // Verify @@ -68,7 +76,7 @@ func TestUserRepo_Update(t *testing.T) { assert.Equal(t, newPassword, updatedUser.Password) // Cleanup - _ = repo.Delete(context.Background(), user.Username) + _ = repo.Delete(context.Background(), updatedUser.Username) }) } } diff --git a/internal/domain/user.go b/internal/domain/user.go index c3b84cd..87faec3 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -9,7 +9,7 @@ type UserRepo interface { GetUserCount(ctx context.Context) (int, error) FindByUsername(ctx context.Context, username string) (*User, error) Store(ctx context.Context, req CreateUserRequest) error - Update(ctx context.Context, user User) error + Update(ctx context.Context, req UpdateUserRequest) error Delete(ctx context.Context, username string) error } @@ -19,6 +19,14 @@ type User struct { Password string `json:"password"` } +type UpdateUserRequest struct { + UsernameCurrent string `json:"username_username"` + UsernameNew string `json:"username_new"` + PasswordCurrent string `json:"password_current"` + PasswordNew string `json:"password_new"` + PasswordNewHash string `json:"-"` +} + type CreateUserRequest struct { Username string `json:"username"` Password string `json:"password"` diff --git a/internal/http/auth.go b/internal/http/auth.go index ca7ff77..c739ac1 100644 --- a/internal/http/auth.go +++ b/internal/http/auth.go @@ -20,6 +20,7 @@ type authService interface { GetUserCount(ctx context.Context) (int, error) Login(ctx context.Context, username, password string) (*domain.User, error) CreateUser(ctx context.Context, req domain.CreateUserRequest) error + UpdateUser(ctx context.Context, req domain.UpdateUserRequest) error } type authHandler struct { @@ -27,17 +28,19 @@ type authHandler struct { encoder encoder config *domain.Config service authService + server Server cookieStore *sessions.CookieStore } -func newAuthHandler(encoder encoder, log zerolog.Logger, config *domain.Config, cookieStore *sessions.CookieStore, service authService) *authHandler { +func newAuthHandler(encoder encoder, log zerolog.Logger, config *domain.Config, cookieStore *sessions.CookieStore, service authService, server Server) *authHandler { return &authHandler{ log: log, encoder: encoder, config: config, service: service, cookieStore: cookieStore, + server: server, } } @@ -47,6 +50,14 @@ func (h authHandler) Routes(r chi.Router) { r.Post("/onboard", h.onboard) r.Get("/onboard", h.canOnboard) r.Get("/validate", h.validate) + + // Group for authenticated routes + r.Group(func(r chi.Router) { + r.Use(h.server.IsAuthenticated) + + // Authenticated routes + r.Patch("/user/{username}", h.updateUser) + }) } func (h authHandler) login(w http.ResponseWriter, r *http.Request) { @@ -177,6 +188,28 @@ func (h authHandler) validate(w http.ResponseWriter, r *http.Request) { h.encoder.NoContent(w) } +func (h authHandler) updateUser(w http.ResponseWriter, r *http.Request) { + var ( + ctx = r.Context() + data domain.UpdateUserRequest + ) + + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + h.encoder.StatusError(w, http.StatusBadRequest, errors.Wrap(err, "could not decode json")) + return + } + + data.UsernameCurrent = chi.URLParam(r, "username") + + if err := h.service.UpdateUser(ctx, data); err != nil { + h.encoder.StatusError(w, http.StatusForbidden, err) + return + } + + // send response as ok + h.encoder.StatusResponseMessage(w, http.StatusOK, "user successfully updated") +} + func ReadUserIP(r *http.Request) string { IPAddress := r.Header.Get("X-Real-Ip") if IPAddress == "" { diff --git a/internal/http/server.go b/internal/http/server.go index 15a8fa4..a63423b 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -126,7 +126,7 @@ func (s Server) Handler() http.Handler { encoder := encoder{} r.Route("/api", func(r chi.Router) { - r.Route("/auth", newAuthHandler(encoder, s.log, s.config.Config, s.cookieStore, s.authService).Routes) + r.Route("/auth", newAuthHandler(encoder, s.log, s.config.Config, s.cookieStore, s.authService, s).Routes) r.Route("/healthz", newHealthHandler(encoder, s.db).Routes) r.Group(func(r chi.Router) { diff --git a/internal/user/service.go b/internal/user/service.go index bb0fcbb..6584f22 100644 --- a/internal/user/service.go +++ b/internal/user/service.go @@ -14,6 +14,7 @@ type Service interface { GetUserCount(ctx context.Context) (int, error) FindByUsername(ctx context.Context, username string) (*domain.User, error) CreateUser(ctx context.Context, req domain.CreateUserRequest) error + Update(ctx context.Context, req domain.UpdateUserRequest) error } type service struct { @@ -51,3 +52,7 @@ func (s *service) CreateUser(ctx context.Context, req domain.CreateUserRequest) return s.repo.Store(ctx, req) } + +func (s *service) Update(ctx context.Context, req domain.UpdateUserRequest) error { + return s.repo.Update(ctx, req) +} diff --git a/web/src/api/APIClient.ts b/web/src/api/APIClient.ts index 9e63723..bfb3c0a 100644 --- a/web/src/api/APIClient.ts +++ b/web/src/api/APIClient.ts @@ -158,7 +158,9 @@ export const APIClient = { onboard: (username: string, password: string) => appClient.Post("api/auth/onboard", { body: { username, password } }), - canOnboard: () => appClient.Get("api/auth/onboard") + canOnboard: () => appClient.Get("api/auth/onboard"), + updateUser: (req: UserUpdate) => appClient.Patch(`api/auth/user/${req.username_current}`, + { body: req }) }, actions: { create: (action: Action) => appClient.Post("api/actions", { diff --git a/web/src/components/header/RightNav.tsx b/web/src/components/header/RightNav.tsx index d501109..4605909 100644 --- a/web/src/components/header/RightNav.tsx +++ b/web/src/components/header/RightNav.tsx @@ -55,6 +55,25 @@ export const RightNav = (props: RightNavProps) => { static className="origin-top-right absolute right-0 mt-2 w-48 z-10 divide-y divide-gray-100 dark:divide-gray-750 rounded-md shadow-lg bg-white dark:bg-gray-800 border border-gray-250 dark:border-gray-775 focus:outline-none" > +