diff --git a/internal/auth/service.go b/internal/auth/service.go index 0411325..4f7e761 100644 --- a/internal/auth/service.go +++ b/internal/auth/service.go @@ -10,7 +10,9 @@ import ( ) type Service interface { + GetUserCount(ctx context.Context) (int, error) Login(ctx context.Context, username, password string) (*domain.User, error) + CreateUser(ctx context.Context, username, password string) error } type service struct { @@ -23,9 +25,13 @@ func NewService(userSvc user.Service) Service { } } +func (s *service) GetUserCount(ctx context.Context) (int, error) { + return s.userSvc.GetUserCount(ctx) +} + func (s *service) Login(ctx context.Context, username, password string) (*domain.User, error) { if username == "" || password == "" { - return nil, errors.New("bad credentials") + return nil, errors.New("empty credentials supplied") } // find user @@ -50,3 +56,33 @@ func (s *service) Login(ctx context.Context, username, password string) (*domain return u, nil } + +func (s *service) CreateUser(ctx context.Context, username, password string) error { + if username == "" || password == "" { + return errors.New("empty credentials supplied") + } + + userCount, err := s.userSvc.GetUserCount(ctx) + if err != nil { + return err + } + + if userCount > 0 { + return errors.New("only 1 user account is supported at the moment") + } + + hashed, err := argon2id.CreateHash(password, argon2id.DefaultParams) + if err != nil { + return errors.New("failed to hash password") + } + + newUser := domain.User{ + Username: username, + Password: hashed, + } + if err := s.userSvc.CreateUser(context.Background(), newUser); err != nil { + return errors.New("failed to create new user") + } + + return nil +} diff --git a/internal/database/user.go b/internal/database/user.go index 9d7190c..148b290 100644 --- a/internal/database/user.go +++ b/internal/database/user.go @@ -15,6 +15,29 @@ func NewUserRepo(db *DB) domain.UserRepo { return &UserRepo{db: db} } +func (r *UserRepo) GetUserCount(ctx context.Context) (int, error) { + queryBuilder := r.db.squirrel.Select("count(*)").From("users") + + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("user.store: error building query") + return 0, err + } + + row := r.db.handler.QueryRowContext(ctx, query, args...) + if err := row.Err(); err != nil { + return 0, err + } + + result := 0 + if err := row.Scan(&result); err != nil { + log.Error().Err(err).Msg("could not query number of users") + return 0, err + } + + return result, nil +} + func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain.User, error) { queryBuilder := r.db.squirrel. @@ -66,6 +89,7 @@ func (r *UserRepo) Store(ctx context.Context, user domain.User) error { return err } + func (r *UserRepo) Update(ctx context.Context, user domain.User) error { var err error diff --git a/internal/domain/user.go b/internal/domain/user.go index 2804e6d..366da73 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -3,6 +3,7 @@ package domain import "context" type UserRepo interface { + GetUserCount(ctx context.Context) (int, error) FindByUsername(ctx context.Context, username string) (*User, error) Store(ctx context.Context, user User) error Update(ctx context.Context, user User) error diff --git a/internal/http/auth.go b/internal/http/auth.go index 8178ee3..d30851e 100644 --- a/internal/http/auth.go +++ b/internal/http/auth.go @@ -12,7 +12,9 @@ import ( ) type authService interface { + GetUserCount(ctx context.Context) (int, error) Login(ctx context.Context, username, password string) (*domain.User, error) + CreateUser(ctx context.Context, username, password string) error } type authHandler struct { @@ -35,6 +37,8 @@ func newAuthHandler(encoder encoder, config domain.Config, cookieStore *sessions func (h authHandler) Routes(r chi.Router) { r.Post("/login", h.login) r.Post("/logout", h.logout) + r.Post("/onboard", h.onboard) + r.Get("/onboard", h.canOnboard) r.Get("/validate", h.validate) } @@ -91,6 +95,53 @@ func (h authHandler) logout(w http.ResponseWriter, r *http.Request) { h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) } +func (h authHandler) onboard(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + session, _ := h.cookieStore.Get(r, "user_session") + + // Don't proceed if user is authenticated + if _, ok := session.Values["authenticated"].(bool); ok { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + var data domain.User + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + // encode error + h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest) + return + } + + err := h.service.CreateUser(ctx, data.Username, data.Password) + if err != nil { + http.Error(w, "Forbidden", http.StatusForbidden) + return + } + + // send empty response as ok + h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) +} + +func (h authHandler) canOnboard(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + userCount, err := h.service.GetUserCount(ctx) + if err != nil { + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + if userCount > 0 { + // send 503 service onboarding unavailable + http.Error(w, "Onboarding unavailable", http.StatusServiceUnavailable) + return + } + + // send empty response as ok + // (client can proceed with redirection to onboarding page) + h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) +} + func (h authHandler) validate(w http.ResponseWriter, r *http.Request) { ctx := r.Context() session, _ := h.cookieStore.Get(r, "user_session") diff --git a/internal/user/service.go b/internal/user/service.go index ba9ed85..5bf38e5 100644 --- a/internal/user/service.go +++ b/internal/user/service.go @@ -2,11 +2,14 @@ package user import ( "context" + "errors" "github.com/autobrr/autobrr/internal/domain" ) type Service interface { + GetUserCount(ctx context.Context) (int, error) FindByUsername(ctx context.Context, username string) (*domain.User, error) + CreateUser(ctx context.Context, user domain.User) error } type service struct { @@ -19,6 +22,10 @@ func NewService(repo domain.UserRepo) Service { } } +func (s *service) GetUserCount(ctx context.Context) (int, error) { + return s.repo.GetUserCount(ctx) +} + func (s *service) FindByUsername(ctx context.Context, username string) (*domain.User, error) { user, err := s.repo.FindByUsername(ctx, username) if err != nil { @@ -27,3 +34,16 @@ func (s *service) FindByUsername(ctx context.Context, username string) (*domain. return user, nil } + +func (s *service) CreateUser(ctx context.Context, newUser domain.User) error { + userCount, err := s.repo.GetUserCount(ctx) + if err != nil { + return err + } + + if userCount > 0 { + return errors.New("only 1 user account is supported at the moment") + } + + return s.repo.Store(ctx, newUser) +} diff --git a/web/src/App.tsx b/web/src/App.tsx index d508d50..1782170 100644 --- a/web/src/App.tsx +++ b/web/src/App.tsx @@ -1,44 +1,49 @@ import { Fragment } from "react"; -import { BrowserRouter as Router, Route } from "react-router-dom"; +import { BrowserRouter as Router, Route, Switch } from "react-router-dom"; import { QueryClient, QueryClientProvider } from "react-query"; import { ReactQueryDevtools } from "react-query/devtools"; import { Toaster } from "react-hot-toast"; import Base from "./screens/Base"; -import Login from "./screens/auth/login"; -import Logout from "./screens/auth/logout"; +import { Login } from "./screens/auth/login"; +import { Logout } from "./screens/auth/logout"; +import { Onboarding } from "./screens/auth/onboarding"; import { baseUrl } from "./utils"; import { AuthContext, SettingsContext } from "./utils/Context"; function Protected() { - return ( - - - - - ) + return ( + + + + + ) } export const queryClient = new QueryClient(); export function App() { - const authContext = AuthContext.useValue(); - const settings = SettingsContext.useValue(); + const authContext = AuthContext.useValue(); + const settings = SettingsContext.useValue(); - return ( - - - {authContext.isLoggedIn ? ( - - ) : - - } - - - {settings.debug ? ( - - ) : null} - - ); + return ( + + + + + {authContext.isLoggedIn ? ( + + ) : ( + + + + + )} + + {settings.debug ? ( + + ) : null} + + ); } \ No newline at end of file diff --git a/web/src/api/APIClient.ts b/web/src/api/APIClient.ts index 6c9cbe6..daebb86 100644 --- a/web/src/api/APIClient.ts +++ b/web/src/api/APIClient.ts @@ -62,6 +62,8 @@ export const APIClient = { login: (username: string, password: string) => appClient.Post("api/auth/login", { username: username, password: password }), logout: () => appClient.Post("api/auth/logout", null), validate: () => appClient.Get("api/auth/validate"), + onboard: (username: string, password: string) => appClient.Post("api/auth/onboard", { username: username, password: password }), + canOnboard: () => appClient.Get("api/auth/onboard"), }, actions: { create: (action: Action) => appClient.Post("api/actions", action), diff --git a/web/src/components/inputs/input.tsx b/web/src/components/inputs/input.tsx index b43eb0e..1fe2618 100644 --- a/web/src/components/inputs/input.tsx +++ b/web/src/components/inputs/input.tsx @@ -47,12 +47,12 @@ export const TextField = ({ type="text" defaultValue={defaultValue} autoComplete={autoComplete} - className="mt-2 block w-full dark:bg-gray-800 border border-gray-300 dark:border-gray-700 rounded-md py-2 px-3 focus:outline-none focus:ring-blue-500 focus:border-blue-500 dark:text-gray-100" + className={classNames(meta.touched && meta.error ? "focus:ring-red-500 focus:border-red-500 border-red-500" : "focus:ring-indigo-500 dark:focus:ring-blue-500 focus:border-indigo-500 dark:focus:border-blue-500 border-gray-300 dark:border-gray-700", "mt-2 block w-full dark:bg-gray-800 dark:text-gray-100 rounded-md")} placeholder={placeholder} /> {meta.touched && meta.error && ( -
{meta.error}
+

* {meta.error}

)} )} @@ -118,7 +118,7 @@ export const PasswordField = ({ )} {meta.touched && meta.error && ( -
{meta.error}
+

* {meta.error}

)} )} diff --git a/web/src/screens/auth/login.tsx b/web/src/screens/auth/login.tsx index fc40d4b..d97eb43 100644 --- a/web/src/screens/auth/login.tsx +++ b/web/src/screens/auth/login.tsx @@ -7,68 +7,72 @@ import { TextField, PasswordField } from "../../components/inputs"; import logo from "../../logo.png"; import { AuthContext } from "../../utils/Context"; +import { useEffect } from "react"; interface LoginData { - username: string; - password: string; + username: string; + password: string; } -function Login() { - const history = useHistory(); - const [, setAuthContext] = AuthContext.use(); +export const Login = () => { + const history = useHistory(); + const [, setAuthContext] = AuthContext.use(); - const mutation = useMutation( - (data: LoginData) => APIClient.auth.login(data.username, data.password), - { - onSuccess: (_, variables: LoginData) => { - setAuthContext({ - username: variables.username, - isLoggedIn: true - }); - history.push("/"); - }, - } - ); + useEffect(() => { + // Check if onboarding is available for this instance + // and redirect if needed + APIClient.auth.canOnboard() + .then(() => history.push("/onboard")); + }, [history]); - const handleSubmit = (data: any) => mutation.mutate(data); + const mutation = useMutation( + (data: LoginData) => APIClient.auth.login(data.username, data.password), + { + onSuccess: (_, variables: LoginData) => { + setAuthContext({ + username: variables.username, + isLoggedIn: true + }); + history.push("/"); + }, + } + ); - return ( -
-
- logo -
-
-
+ const handleSubmit = (data: any) => mutation.mutate(data); - - {() => ( -
-
- - -
-
- -
-
- )} -
-
-
+ return ( +
+
+ logo +
+
+
+ + +
+
+ + +
+
+ +
+
+
- ) +
+
+ ); } - -export default Login; diff --git a/web/src/screens/auth/logout.tsx b/web/src/screens/auth/logout.tsx index e0e2cdd..48e5d1f 100644 --- a/web/src/screens/auth/logout.tsx +++ b/web/src/screens/auth/logout.tsx @@ -1,32 +1,32 @@ -import {useEffect} from "react"; -import {useCookies} from "react-cookie"; -import {useHistory} from "react-router-dom"; +import { useEffect } from "react"; +import { useCookies } from "react-cookie"; +import { useHistory } from "react-router-dom"; import { APIClient } from "../../api/APIClient"; import { AuthContext } from "../../utils/Context"; -function Logout() { - const history = useHistory(); +export const Logout = () => { + const history = useHistory(); - const [, setAuthContext] = AuthContext.use(); - const [,, removeCookie] = useCookies(['user_session']); + const [, setAuthContext] = AuthContext.use(); + const [,, removeCookie] = useCookies(["user_session"]); - useEffect( - () => { - APIClient.auth.logout().then(() => { - setAuthContext({ username: "", isLoggedIn: false }); - removeCookie("user_session"); - history.push('/login'); - }) - }, - [history, removeCookie, setAuthContext] - ); + useEffect( + () => { + APIClient.auth.logout() + .then(() => { + setAuthContext({ username: "", isLoggedIn: false }); + removeCookie("user_session"); - return ( -
-

Logged out

-
- ) + history.push("/login"); + }); + }, + [history, removeCookie, setAuthContext] + ); + + return ( +
+

Logged out

+
+ ); } - -export default Logout; \ No newline at end of file diff --git a/web/src/screens/auth/onboarding.tsx b/web/src/screens/auth/onboarding.tsx new file mode 100644 index 0000000..60a45fe --- /dev/null +++ b/web/src/screens/auth/onboarding.tsx @@ -0,0 +1,85 @@ +import { Form, Formik } from "formik"; +import { useMutation } from "react-query"; +import { useHistory } from "react-router-dom"; +import { APIClient } from "../../api/APIClient"; + +import { TextField, PasswordField } from "../../components/inputs"; + +interface InputValues { + username: string; + password1: string; + password2: string; +} + +export const Onboarding = () => { + const validate = (values: InputValues) => { + const obj: Record = {}; + + if (!values.username) + obj.username = "Required"; + + if (!values.password1) + obj.password1 = "Required"; + + if (!values.password2) + obj.password2 = "Required"; + + if (values.password1 !== values.password2) + obj.password2 = "Passwords don't match!"; + + return obj; + }; + + const history = useHistory(); + + const mutation = useMutation( + (data: InputValues) => APIClient.auth.onboard(data.username, data.password1), + { + onSuccess: () => { + history.push("/login"); + }, + } + ); + + return ( +
+
+

+ Create a new user +

+
+
+
+ mutation.mutate(data)} + validate={validate} + > +
+
+ + + +
+
+ +
+
+
+
+
+
+ ); +} +