From 3185832708b01aa46c68fc9814c6da42e6dd82ac Mon Sep 17 00:00:00 2001 From: Ludvig Lundgren Date: Sat, 2 Apr 2022 19:24:23 +0200 Subject: [PATCH] feat: add postgres support (#215) * feat: add postgres support and refactor * feat: improve releases find * fix: autobrrctl create user --- cmd/autobrr/main.go | 11 +- cmd/autobrrctl/main.go | 4 +- internal/action/service.go | 38 +- internal/database/action.go | 369 +++++++++++++---- internal/database/database.go | 8 +- internal/database/download_client.go | 219 +++++----- internal/database/filter.go | 580 +++++++++++++++------------ internal/database/indexer.go | 91 +++-- internal/database/irc.go | 463 +++++++++++++-------- internal/database/migrate.go | 374 ++++++++++++++++- internal/database/postgres.go | 12 +- internal/database/release.go | 155 ++++--- internal/database/sqlite.go | 16 +- internal/database/user.go | 71 +++- internal/domain/action.go | 2 +- internal/domain/client.go | 1 + internal/domain/indexer.go | 6 +- internal/domain/irc.go | 3 +- internal/domain/release.go | 2 +- internal/domain/user.go | 1 + internal/download_client/service.go | 13 + internal/filter/service.go | 4 +- internal/http/action.go | 4 +- internal/http/download_client.go | 3 +- internal/http/indexer.go | 21 +- internal/http/irc.go | 4 +- internal/http/release.go | 8 +- internal/indexer/service.go | 32 +- internal/irc/service.go | 13 +- internal/release/service.go | 11 +- 30 files changed, 1708 insertions(+), 831 deletions(-) diff --git a/cmd/autobrr/main.go b/cmd/autobrr/main.go index e998503..88ec978 100644 --- a/cmd/autobrr/main.go +++ b/cmd/autobrr/main.go @@ -54,16 +54,19 @@ func main() { // setup logger logger.Setup(cfg, serverEvents) - log.Info().Msg("Starting autobrr") - log.Info().Msgf("Version: %v", version) - log.Info().Msgf("Log-level: %v", cfg.LogLevel) - // open database connection db, _ := database.NewDB(cfg) if err := db.Open(); err != nil { log.Fatal().Err(err).Msg("could not open db connection") } + log.Info().Msgf("Starting autobrr") + log.Info().Msgf("Version: %v", version) + log.Info().Msgf("Commit: %v", commit) + log.Info().Msgf("Build date: %v", date) + log.Info().Msgf("Log-level: %v", cfg.LogLevel) + log.Info().Msgf("Using database: %v", db.Driver) + // setup repos var ( actionRepo = database.NewActionRepo(db) diff --git a/cmd/autobrrctl/main.go b/cmd/autobrrctl/main.go index 41b629d..4dc9484 100644 --- a/cmd/autobrrctl/main.go +++ b/cmd/autobrrctl/main.go @@ -39,7 +39,7 @@ func main() { } // open database connection - db, _ := database.NewDB(domain.Config{ConfigPath: configPath}) + db, _ := database.NewDB(domain.Config{ConfigPath: configPath, DatabaseType: "sqlite"}) if err := db.Open(); err != nil { log.Fatal("could not open db connection") } @@ -96,7 +96,7 @@ func main() { } user.Password = hashed - if err := userRepo.Store(context.Background(), *user); err != nil { + if err := userRepo.Update(context.Background(), *user); err != nil { log.Fatalf("failed to create user: %v", err) } default: diff --git a/internal/action/service.go b/internal/action/service.go index af07e09..3331675 100644 --- a/internal/action/service.go +++ b/internal/action/service.go @@ -11,7 +11,7 @@ import ( type Service interface { Store(ctx context.Context, action domain.Action) (*domain.Action, error) - Fetch() ([]domain.Action, error) + List(ctx context.Context) ([]domain.Action, error) Delete(actionID int) error DeleteByFilterID(ctx context.Context, filterID int) error ToggleEnabled(actionID int) error @@ -31,45 +31,21 @@ func NewService(repo domain.ActionRepo, clientSvc download_client.Service, bus E } func (s *service) Store(ctx context.Context, action domain.Action) (*domain.Action, error) { - // validate data - - a, err := s.repo.Store(ctx, action) - if err != nil { - return nil, err - } - - return a, nil + return s.repo.Store(ctx, action) } func (s *service) Delete(actionID int) error { - if err := s.repo.Delete(actionID); err != nil { - return err - } - - return nil + return s.repo.Delete(actionID) } func (s *service) DeleteByFilterID(ctx context.Context, filterID int) error { - if err := s.repo.DeleteByFilterID(ctx, filterID); err != nil { - return err - } - - return nil + return s.repo.DeleteByFilterID(ctx, filterID) } -func (s *service) Fetch() ([]domain.Action, error) { - actions, err := s.repo.List() - if err != nil { - return nil, err - } - - return actions, nil +func (s *service) List(ctx context.Context) ([]domain.Action, error) { + return s.repo.List(ctx) } func (s *service) ToggleEnabled(actionID int) error { - if err := s.repo.ToggleEnabled(actionID); err != nil { - return err - } - - return nil + return s.repo.ToggleEnabled(actionID) } diff --git a/internal/database/action.go b/internal/database/action.go index ba4df7c..518375e 100644 --- a/internal/database/action.go +++ b/internal/database/action.go @@ -6,6 +6,7 @@ import ( "github.com/autobrr/autobrr/internal/domain" + sq "github.com/Masterminds/squirrel" "github.com/rs/zerolog/log" ) @@ -18,17 +19,47 @@ func NewActionRepo(db *DB) domain.ActionRepo { } func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int) ([]domain.Action, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Select( + "id", + "name", + "type", + "enabled", + "exec_cmd", + "exec_args", + "watch_folder", + "category", + "tags", + "label", + "save_path", + "paused", + "ignore_rules", + "limit_upload_speed", + "limit_download_speed", + "webhook_host", + "webhook_type", + "webhook_method", + "webhook_data", + "client_id", + ). + From("action"). + Where("filter_id = ?", filterID) - rows, err := r.db.handler.QueryContext(ctx, "SELECT id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_download_speed, limit_upload_speed, webhook_host, webhook_data, webhook_type, webhook_method, client_id FROM action WHERE action.filter_id = ?", filterID) + query, args, err := queryBuilder.ToSql() if err != nil { - log.Fatal().Err(err) + log.Error().Stack().Err(err).Msg("action.findByFilterID: error building query") + return nil, err + } + + rows, err := r.db.handler.QueryContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("action.findByFilterID: query error") + return nil, err } defer rows.Close() - var actions []domain.Action + actions := make([]domain.Action, 0) for rows.Next() { var a domain.Action @@ -39,9 +70,7 @@ func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int) ([]domain var paused, ignoreRules sql.NullBool if err := rows.Scan(&a.ID, &a.Name, &a.Type, &a.Enabled, &execCmd, &execArgs, &watchFolder, &category, &tags, &label, &savePath, &paused, &ignoreRules, &limitDl, &limitUl, &host, &data, &webhookType, &webhookMethod, &clientID); err != nil { - log.Fatal().Err(err) - } - if err != nil { + log.Error().Stack().Err(err).Msg("action.findByFilterID: error scanning row") return nil, err } @@ -65,24 +94,54 @@ func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int) ([]domain actions = append(actions, a) } if err := rows.Err(); err != nil { + log.Error().Stack().Err(err).Msg("action.findByFilterID: row error") return nil, err } return actions, nil } -func (r *ActionRepo) List() ([]domain.Action, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() +func (r *ActionRepo) List(ctx context.Context) ([]domain.Action, error) { + queryBuilder := r.db.squirrel. + Select( + "id", + "name", + "type", + "enabled", + "exec_cmd", + "exec_args", + "watch_folder", + "category", + "tags", + "label", + "save_path", + "paused", + "ignore_rules", + "limit_upload_speed", + "limit_download_speed", + "webhook_host", + "webhook_type", + "webhook_method", + "webhook_data", + "client_id", + ). + From("action") - rows, err := r.db.handler.Query("SELECT id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_download_speed, limit_upload_speed, client_id FROM action") + query, args, err := queryBuilder.ToSql() if err != nil { - log.Fatal().Err(err) + log.Error().Stack().Err(err).Msg("action.list: error building query") + return nil, err + } + + rows, err := r.db.handler.QueryContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("action.list: error executing query") + return nil, err } defer rows.Close() - var actions []domain.Action + actions := make([]domain.Action, 0) for rows.Next() { var a domain.Action @@ -92,9 +151,7 @@ func (r *ActionRepo) List() ([]domain.Action, error) { var paused, ignoreRules sql.NullBool if err := rows.Scan(&a.ID, &a.Name, &a.Type, &a.Enabled, &execCmd, &execArgs, &watchFolder, &category, &tags, &label, &savePath, &paused, &ignoreRules, &limitDl, &limitUl, &clientID); err != nil { - log.Fatal().Err(err) - } - if err != nil { + log.Error().Stack().Err(err).Msg("action.list: error scanning row") return nil, err } @@ -111,6 +168,7 @@ func (r *ActionRepo) List() ([]domain.Action, error) { actions = append(actions, a) } if err := rows.Err(); err != nil { + log.Error().Stack().Err(err).Msg("action.list: row error") return nil, err } @@ -118,40 +176,50 @@ func (r *ActionRepo) List() ([]domain.Action, error) { } func (r *ActionRepo) Delete(actionID int) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Delete("action"). + Where("id = ?", actionID) - res, err := r.db.handler.Exec(`DELETE FROM action WHERE action.id = ?`, actionID) + query, args, err := queryBuilder.ToSql() if err != nil { + log.Error().Stack().Err(err).Msg("action.delete: error building query") return err } - rows, _ := res.RowsAffected() + _, err = r.db.handler.Exec(query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("action.delete: error executing query") + return err + } - log.Info().Msgf("rows affected %v", rows) + log.Debug().Msgf("action.delete: %v", actionID) return nil } func (r *ActionRepo) DeleteByFilterID(ctx context.Context, filterID int) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Delete("action"). + Where("filter_id = ?", filterID) - _, err := r.db.handler.ExecContext(ctx, `DELETE FROM action WHERE filter_id = ?`, filterID) + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("actions: error deleting by filterid") + log.Error().Stack().Err(err).Msg("action.deleteByFilterID: error building query") return err } - log.Debug().Msgf("actions: delete by filterid %v", filterID) + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("action.deleteByFilterID: error executing query") + return err + } + + log.Debug().Msgf("action.deleteByFilterID: %v", filterID) return nil } func (r *ActionRepo) Store(ctx context.Context, action domain.Action) (*domain.Action, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - execCmd := toNullString(action.ExecCmd) execArgs := toNullString(action.ExecArgs) watchFolder := toNullString(action.WatchFolder) @@ -159,8 +227,89 @@ func (r *ActionRepo) Store(ctx context.Context, action domain.Action) (*domain.A tags := toNullString(action.Tags) label := toNullString(action.Label) savePath := toNullString(action.SavePath) - host := toNullString(action.WebhookHost) - data := toNullString(action.WebhookData) + webhookHost := toNullString(action.WebhookHost) + webhookData := toNullString(action.WebhookData) + webhookType := toNullString(action.WebhookType) + webhookMethod := toNullString(action.WebhookMethod) + + limitDL := toNullInt64(action.LimitDownloadSpeed) + limitUL := toNullInt64(action.LimitUploadSpeed) + clientID := toNullInt32(action.ClientID) + filterID := toNullInt32(int32(action.FilterID)) + + queryBuilder := r.db.squirrel. + Insert("action"). + Columns( + "name", + "type", + "enabled", + "exec_cmd", + "exec_args", + "watch_folder", + "category", + "tags", + "label", + "save_path", + "paused", + "ignore_rules", + "limit_upload_speed", + "limit_download_speed", + "webhook_host", + "webhook_type", + "webhook_method", + "webhook_data", + "client_id", + "filter_id", + ). + Values( + action.Name, + action.Type, + action.Enabled, + execCmd, + execArgs, + watchFolder, + category, + tags, + label, + savePath, + action.Paused, + action.IgnoreRules, + limitUL, + limitDL, + webhookHost, + webhookType, + webhookMethod, + webhookData, + clientID, + filterID, + ). + Suffix("RETURNING id").RunWith(r.db.handler) + + // return values + var retID int64 + + err := queryBuilder.QueryRowContext(ctx).Scan(&retID) + if err != nil { + log.Error().Stack().Err(err).Msg("action.store: error executing query") + return nil, err + } + + log.Debug().Msgf("action.store: added new %v", retID) + action.ID = int(retID) + + return &action, nil +} + +func (r *ActionRepo) Update(ctx context.Context, action domain.Action) (*domain.Action, error) { + execCmd := toNullString(action.ExecCmd) + execArgs := toNullString(action.ExecArgs) + watchFolder := toNullString(action.WatchFolder) + category := toNullString(action.Category) + tags := toNullString(action.Tags) + label := toNullString(action.Label) + savePath := toNullString(action.SavePath) + webhookHost := toNullString(action.WebhookHost) + webhookData := toNullString(action.WebhookData) webhookType := toNullString(action.WebhookType) webhookMethod := toNullString(action.WebhookMethod) @@ -170,32 +319,49 @@ func (r *ActionRepo) Store(ctx context.Context, action domain.Action) (*domain.A filterID := toNullInt32(int32(action.FilterID)) var err error - if action.ID != 0 { - log.Debug().Msg("actions: update existing record") - _, err = r.db.handler.ExecContext(ctx, `UPDATE action SET name = ?, type = ?, enabled = ?, exec_cmd = ?, exec_args = ?, watch_folder = ? , category =? , tags = ?, label = ?, save_path = ?, paused = ?, ignore_rules = ?, limit_upload_speed = ?, limit_download_speed = ?, webhook_host = ?, webhook_data = ?, webhook_type = ?, webhook_method = ?, client_id = ? - WHERE id = ?`, action.Name, action.Type, action.Enabled, execCmd, execArgs, watchFolder, category, tags, label, savePath, action.Paused, action.IgnoreRules, limitUL, limitDL, host, data, webhookType, webhookMethod, clientID, action.ID) - } else { - var res sql.Result - res, err = r.db.handler.ExecContext(ctx, `INSERT INTO action(name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_upload_speed, limit_download_speed, webhook_host, webhook_data, webhook_type, webhook_method, client_id, filter_id) - VALUES (?, ?, ?, ?, ?,? ,?, ?,?,?,?,?,?,?,?,?,?,?,?,?) ON CONFLICT DO NOTHING`, action.Name, action.Type, action.Enabled, execCmd, execArgs, watchFolder, category, tags, label, savePath, action.Paused, action.IgnoreRules, limitUL, limitDL, host, data, webhookType, webhookMethod, clientID, filterID) - if err != nil { - log.Error().Err(err) - return nil, err - } + queryBuilder := r.db.squirrel. + Update("action"). + Set("name", action.Name). + Set("type", action.Type). + Set("enabled", action.Enabled). + Set("exec_cmd", execCmd). + Set("exec_args", execArgs). + Set("watch_folder", watchFolder). + Set("category", category). + Set("tags", tags). + Set("label", label). + Set("save_path", savePath). + Set("paused", action.Paused). + Set("ignore_rules", action.IgnoreRules). + Set("limit_upload_speed", limitUL). + Set("limit_download_speed", limitDL). + Set("webhook_host", webhookHost). + Set("webhook_type", webhookType). + Set("webhook_method", webhookMethod). + Set("webhook_data", webhookData). + Set("client_id", clientID). + Set("filter_id", filterID). + Where("id = ?", action.ID) - resId, _ := res.LastInsertId() - log.Debug().Msgf("actions: added new %v", resId) - action.ID = int(resId) + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("action.update: error building query") + return nil, err } + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("action.update: error executing query") + return nil, err + } + + log.Debug().Msgf("action.update: %v", action.ID) + return &action, nil } func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []domain.Action, filterID int64) ([]domain.Action, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - tx, err := r.db.handler.BeginTx(ctx, nil) if err != nil { return nil, err @@ -203,9 +369,18 @@ func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []domain.Ac defer tx.Rollback() - _, err = tx.ExecContext(ctx, `DELETE FROM action WHERE filter_id = ?`, filterID) + deleteQueryBuilder := r.db.squirrel. + Delete("action"). + Where("filter_id = ?", filterID) + + deleteQuery, deleteArgs, err := deleteQueryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msgf("error deleting actions for filter: %v", filterID) + log.Error().Stack().Err(err).Msg("action.StoreFilterActions: error building query") + return nil, err + } + _, err = tx.ExecContext(ctx, deleteQuery, deleteArgs...) + if err != nil { + log.Error().Stack().Err(err).Msg("action.StoreFilterActions: error executing query") return nil, err } @@ -217,8 +392,8 @@ func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []domain.Ac tags := toNullString(action.Tags) label := toNullString(action.Label) savePath := toNullString(action.SavePath) - host := toNullString(action.WebhookHost) - data := toNullString(action.WebhookData) + webhookHost := toNullString(action.WebhookHost) + webhookData := toNullString(action.WebhookData) webhookType := toNullString(action.WebhookType) webhookMethod := toNullString(action.WebhookMethod) @@ -226,25 +401,71 @@ func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []domain.Ac limitUL := toNullInt64(action.LimitUploadSpeed) clientID := toNullInt32(action.ClientID) - var err error - var res sql.Result + queryBuilder := r.db.squirrel. + Insert("action"). + Columns( + "name", + "type", + "enabled", + "exec_cmd", + "exec_args", + "watch_folder", + "category", + "tags", + "label", + "save_path", + "paused", + "ignore_rules", + "limit_upload_speed", + "limit_download_speed", + "webhook_host", + "webhook_type", + "webhook_method", + "webhook_data", + "client_id", + "filter_id", + ). + Values( + action.Name, + action.Type, + action.Enabled, + execCmd, + execArgs, + watchFolder, + category, + tags, + label, + savePath, + action.Paused, + action.IgnoreRules, + limitUL, + limitDL, + webhookHost, + webhookType, + webhookMethod, + webhookData, + clientID, + filterID, + ). + Suffix("RETURNING id").RunWith(tx) - res, err = tx.ExecContext(ctx, `INSERT INTO action(name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_upload_speed, limit_download_speed, webhook_host, webhook_data, webhook_type, webhook_method, client_id, filter_id) - VALUES (?, ?, ?, ?, ?,? ,?, ?,?,?,?,?,?,?,?,?,?,?,?,?) ON CONFLICT DO NOTHING`, action.Name, action.Type, action.Enabled, execCmd, execArgs, watchFolder, category, tags, label, savePath, action.Paused, action.IgnoreRules, limitUL, limitDL, host, data, webhookType, webhookMethod, clientID, filterID) + // return values + var retID int + + err = queryBuilder.QueryRowContext(ctx).Scan(&retID) if err != nil { - log.Error().Stack().Err(err).Msg("actions: error executing query") + log.Error().Stack().Err(err).Msg("action.StoreFilterActions: error executing query") return nil, err } - resId, _ := res.LastInsertId() - action.ID = int(resId) + action.ID = retID - log.Debug().Msgf("actions: store '%v' type: '%v' on filter: %v", action.Name, action.Type, filterID) + log.Debug().Msgf("action.StoreFilterActions: store '%v' type: '%v' on filter: %v", action.Name, action.Type, filterID) } err = tx.Commit() if err != nil { - log.Error().Stack().Err(err).Msg("error updating actions") + log.Error().Stack().Err(err).Msg("action.StoreFilterActions: error updating actions") return nil, err } @@ -253,20 +474,26 @@ func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []domain.Ac } func (r *ActionRepo) ToggleEnabled(actionID int) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - var err error - var res sql.Result - res, err = r.db.handler.Exec(`UPDATE action SET enabled = NOT enabled WHERE id = ?`, actionID) + queryBuilder := r.db.squirrel. + Update("action"). + Set("enabled", sq.Expr("NOT enabled")). + Where("id = ?", actionID) + + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Err(err) + log.Error().Stack().Err(err).Msg("action.toggleEnabled: error building query") return err } - resId, _ := res.LastInsertId() - log.Info().Msgf("LAST INSERT ID %v", resId) + _, err = r.db.handler.Exec(query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("action.toggleEnabled: error executing query") + return err + } + + log.Debug().Msgf("action.toggleEnabled: %v", actionID) return nil } diff --git a/internal/database/database.go b/internal/database/database.go index 08a6619..027a274 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -6,6 +6,7 @@ import ( "fmt" "sync" + sq "github.com/Masterminds/squirrel" "github.com/rs/zerolog/log" "github.com/autobrr/autobrr/internal/domain" @@ -19,10 +20,15 @@ type DB struct { Driver string DSN string + + squirrel sq.StatementBuilderType } func NewDB(cfg domain.Config) (*DB, error) { - db := &DB{} + db := &DB{ + // set default placeholder for squirrel to support both sqlite and postgres + squirrel: sq.StatementBuilder.PlaceholderFormat(sq.Dollar), + } db.ctx, db.cancel = context.WithCancel(context.Background()) switch cfg.DatabaseType { diff --git a/internal/database/download_client.go b/internal/database/download_client.go index 041dc20..232cbdd 100644 --- a/internal/database/download_client.go +++ b/internal/database/download_client.go @@ -2,7 +2,6 @@ package database import ( "context" - "database/sql" "encoding/json" "sync" @@ -57,14 +56,34 @@ func NewDownloadClientRepo(db *DB) domain.DownloadClientRepo { } func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() clients := make([]domain.DownloadClient, 0) - rows, err := r.db.handler.QueryContext(ctx, "SELECT id, name, type, enabled, host, port, tls, tls_skip_verify, username, password, settings FROM client") + queryBuilder := r.db.squirrel. + Select( + "id", + "name", + "type", + "enabled", + "host", + "port", + "tls", + "tls_skip_verify", + "username", + "password", + "settings", + ). + From("client") + + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("could not query download client rows") - return clients, err + log.Error().Stack().Err(err).Msg("download_client.list: error building query") + return nil, err + } + + rows, err := r.db.handler.QueryContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("download_client.list: error executing query") + return nil, err } defer rows.Close() @@ -74,7 +93,7 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, var settingsJsonStr string if err := rows.Scan(&f.ID, &f.Name, &f.Type, &f.Enabled, &f.Host, &f.Port, &f.TLS, &f.TLSSkipVerify, &f.Username, &f.Password, &settingsJsonStr); err != nil { - log.Error().Stack().Err(err).Msg("could not scan download client to struct") + log.Error().Stack().Err(err).Msg("download_client.list: error scanning row") return clients, err } @@ -88,7 +107,7 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, clients = append(clients, f) } if err := rows.Err(); err != nil { - log.Error().Stack().Err(err).Msg("could not query download client rows") + log.Error().Stack().Err(err).Msg("download_client.list: row error") return clients, err } @@ -96,20 +115,38 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, } func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - // get client from cache c := r.cache.Get(int(id)) if c != nil { return c, nil } - query := `SELECT id, name, type, enabled, host, port, tls, tls_skip_verify, username, password, settings FROM client WHERE id = ?` + queryBuilder := r.db.squirrel. + Select( + "id", + "name", + "type", + "enabled", + "host", + "port", + "tls", + "tls_skip_verify", + "username", + "password", + "settings", + ). + From("client"). + Where("id = ?", id) - row := r.db.handler.QueryRowContext(ctx, query, id) - if err := row.Err(); err != nil { - log.Error().Stack().Err(err).Msg("could not query download client rows") + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("download_client.findByID: error building query") + return nil, err + } + + row := r.db.handler.QueryRowContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("download_client.findByID: error executing query") return nil, err } @@ -117,7 +154,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do var settingsJsonStr string if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.TLS, &client.TLSSkipVerify, &client.Username, &client.Password, &settingsJsonStr); err != nil { - log.Error().Stack().Err(err).Msg("could not scan download client to struct") + log.Error().Stack().Err(err).Msg("download_client.findByID: error scanning row") return nil, err } @@ -132,9 +169,6 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do } func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - var err error settings := domain.DownloadClientSettings{ @@ -149,79 +183,73 @@ func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadCl return nil, err } - if client.ID != 0 { - _, err = r.db.handler.ExecContext(ctx, ` - UPDATE - client - SET - name = ?, - type = ?, - enabled = ?, - host = ?, - port = ?, - tls = ?, - tls_skip_verify = ?, - username = ?, - password = ?, - settings = (?) - WHERE - id = ?`, - client.Name, - client.Type, - client.Enabled, - client.Host, - client.Port, - client.TLS, - client.TLSSkipVerify, - client.Username, - client.Password, - string(settingsJson), - client.ID, - ) - if err != nil { - log.Error().Stack().Err(err).Msgf("could not update download client: %v", client) - return nil, err - } - } else { - var res sql.Result + queryBuilder := r.db.squirrel. + Insert("client"). + Columns("name", "type", "enabled", "host", "port", "tls", "tls_skip_verify", "username", "password", "settings"). + Values(client.Name, client.Type, client.Enabled, client.Host, client.Port, client.TLS, client.TLSSkipVerify, client.Username, client.Password, settingsJson). + Suffix("RETURNING id").RunWith(r.db.handler) - res, err = r.db.handler.ExecContext(ctx, `INSERT INTO - client( - name, - type, - enabled, - host, - port, - tls, - tls_skip_verify, - username, - password, - settings) - VALUES (?, ?, ?, ?, ?, ? , ?, ?, ?, ?) ON CONFLICT DO NOTHING`, - client.Name, - client.Type, - client.Enabled, - client.Host, - client.Port, - client.TLS, - client.TLSSkipVerify, - client.Username, - client.Password, - string(settingsJson), - ) - if err != nil { - log.Error().Stack().Err(err).Msgf("could not store new download client: %v", client) - return nil, err - } + // return values + var retID int - resId, _ := res.LastInsertId() - client.ID = int(resId) - - log.Trace().Msgf("download_client: store new record %d", client.ID) + err = queryBuilder.QueryRowContext(ctx).Scan(&retID) + if err != nil { + log.Error().Stack().Err(err).Msg("download_client.store: error executing query") + return nil, err } - log.Info().Msgf("store download client: %v", client.Name) - log.Trace().Msgf("store download client: %+v", client) + client.ID = retID + + log.Debug().Msgf("download_client.store: %d", client.ID) + + // save to cache + r.cache.Set(client.ID, &client) + + return &client, nil +} + +func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { + var err error + + settings := domain.DownloadClientSettings{ + APIKey: client.Settings.APIKey, + Basic: client.Settings.Basic, + Rules: client.Settings.Rules, + } + + settingsJson, err := json.Marshal(&settings) + if err != nil { + log.Error().Stack().Err(err).Msgf("could not marshal download client settings %v", settings) + return nil, err + } + + queryBuilder := r.db.squirrel. + Update("client"). + Set("name", client.Name). + Set("type", client.Type). + Set("enabled", client.Enabled). + Set("host", client.Host). + Set("port", client.Port). + Set("tls", client.TLS). + Set("tls_skip_verify", client.TLSSkipVerify). + Set("username", client.Username). + Set("password", client.Password). + Set("settings", string(settingsJson)). + Where("id = ?", client.ID) + + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("download_client.update: error building query") + return nil, err + } + + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("download_client.update: error querying data") + return nil, err + } + + log.Debug().Msgf("download_client.update: %d", client.ID) // save to cache r.cache.Set(client.ID, &client) @@ -230,12 +258,19 @@ func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadCl } func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Delete("client"). + Where("id = ?", clientID) - res, err := r.db.handler.ExecContext(ctx, `DELETE FROM client WHERE client.id = ?`, clientID) + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msgf("could not delete download client: %d", clientID) + log.Error().Stack().Err(err).Msg("download_client.delete: error building query") + return err + } + + res, err := r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("download_client.delete: error query data") return err } diff --git a/internal/database/filter.go b/internal/database/filter.go index 7cbbcac..02ab7fd 100644 --- a/internal/database/filter.go +++ b/internal/database/filter.go @@ -3,7 +3,9 @@ package database import ( "context" "database/sql" + sq "github.com/Masterminds/squirrel" "strings" + "time" "github.com/lib/pq" "github.com/rs/zerolog/log" @@ -20,12 +22,28 @@ func NewFilterRepo(db *DB) domain.FilterRepo { } func (r *FilterRepo) ListFilters(ctx context.Context) ([]domain.Filter, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Select( + "id", + "enabled", + "name", + "match_releases", + "except_releases", + "created_at", + "updated_at", + ). + From("filter"). + OrderBy("name ASC") - rows, err := r.db.handler.QueryContext(ctx, "SELECT id, enabled, name, match_releases, except_releases, created_at, updated_at FROM filter ORDER BY name ASC") + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("filters_list: error query data") + log.Error().Stack().Err(err).Msg("filter.list: error building query") + return nil, err + } + + rows, err := r.db.handler.QueryContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("filter.list: error executing query") return nil, err } @@ -38,7 +56,7 @@ func (r *FilterRepo) ListFilters(ctx context.Context) ([]domain.Filter, error) { var matchReleases, exceptReleases sql.NullString if err := rows.Scan(&f.ID, &f.Enabled, &f.Name, &matchReleases, &exceptReleases, &f.CreatedAt, &f.UpdatedAt); err != nil { - log.Error().Stack().Err(err).Msg("filters_list: error scanning data to struct") + log.Error().Stack().Err(err).Msg("filter.list: error scanning row") return nil, err } @@ -48,6 +66,7 @@ func (r *FilterRepo) ListFilters(ctx context.Context) ([]domain.Filter, error) { filters = append(filters, f) } if err := rows.Err(); err != nil { + log.Error().Stack().Err(err).Msg("filter.list: row error") return nil, err } @@ -55,11 +74,64 @@ func (r *FilterRepo) ListFilters(ctx context.Context) ([]domain.Filter, error) { } func (r *FilterRepo) FindByID(ctx context.Context, filterID int) (*domain.Filter, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Select( + "id", + "enabled", + "name", + "min_size", + "max_size", + "delay", + "priority", + "match_releases", + "except_releases", + "use_regex", + "match_release_groups", + "except_release_groups", + "scene", + "freeleech", + "freeleech_percent", + "shows", + "seasons", + "episodes", + "resolutions", + "codecs", + "sources", + "containers", + "match_hdr", + "except_hdr", + "years", + "artists", + "albums", + "release_types_match", + "formats", + "quality", + "media", + "log_score", + "has_log", + "has_cue", + "perfect_flac", + "match_categories", + "except_categories", + "match_uploaders", + "except_uploaders", + "tags", + "except_tags", + "created_at", + "updated_at", + ). + From("filter"). + Where("id = ?", filterID) - row := r.db.handler.QueryRowContext(ctx, "SELECT id, enabled, name, min_size, max_size, delay, priority, match_releases, except_releases, use_regex, match_release_groups, except_release_groups, scene, freeleech, freeleech_percent, shows, seasons, episodes, resolutions, codecs, sources, containers, match_hdr, except_hdr, years, artists, albums, release_types_match, formats, quality, media, log_score, has_log, has_cue, perfect_flac, match_categories, except_categories, match_uploaders, except_uploaders, tags, except_tags, created_at, updated_at FROM filter WHERE id = ?", filterID) + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("filter.findByID: error building query") + return nil, err + } + + row := r.db.handler.QueryRowContext(ctx, query, args...) if err := row.Err(); err != nil { + log.Error().Stack().Err(err).Msg("filter.findByID: error query row") return nil, err } @@ -69,7 +141,7 @@ func (r *FilterRepo) FindByID(ctx context.Context, filterID int) (*domain.Filter var delay, logScore sql.NullInt32 if err := row.Scan(&f.ID, &f.Enabled, &f.Name, &minSize, &maxSize, &delay, &f.Priority, &matchReleases, &exceptReleases, &useRegex, &matchReleaseGroups, &exceptReleaseGroups, &scene, &freeleech, &freeleechPercent, &shows, &seasons, &episodes, pq.Array(&f.Resolutions), pq.Array(&f.Codecs), pq.Array(&f.Sources), pq.Array(&f.Containers), pq.Array(&f.MatchHDR), pq.Array(&f.ExceptHDR), &years, &artists, &albums, pq.Array(&f.MatchReleaseTypes), pq.Array(&f.Formats), pq.Array(&f.Quality), pq.Array(&f.Media), &logScore, &hasLog, &hasCue, &perfectFlac, &matchCategories, &exceptCategories, &matchUploaders, &exceptUploaders, &tags, &exceptTags, &f.CreatedAt, &f.UpdatedAt); err != nil { - log.Error().Stack().Err(err).Msgf("filter: %v : error scanning data to struct", filterID) + log.Error().Stack().Err(err).Msgf("filter.findByID: %v : error scanning row", filterID) return nil, err } @@ -106,63 +178,69 @@ func (r *FilterRepo) FindByID(ctx context.Context, filterID int) (*domain.Filter // FindByIndexerIdentifier find active filters with active indexer only func (r *FilterRepo) FindByIndexerIdentifier(indexer string) ([]domain.Filter, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Select( + "f.id", + "f.enabled", + "f.name", + "f.min_size", + "f.max_size", + "f.delay", + "f.priority", + "f.match_releases", + "f.except_releases", + "f.use_regex", + "f.match_release_groups", + "f.except_release_groups", + "f.scene", + "f.freeleech", + "f.freeleech_percent", + "f.shows", + "f.seasons", + "f.episodes", + "f.resolutions", + "f.codecs", + "f.sources", + "f.containers", + "f.match_hdr", + "f.except_hdr", + "f.years", + "f.artists", + "f.albums", + "f.release_types_match", + "f.formats", + "f.quality", + "f.media", + "f.log_score", + "f.has_log", + "f.has_cue", + "f.perfect_flac", + "f.match_categories", + "f.except_categories", + "f.match_uploaders", + "f.except_uploaders", + "f.tags", + "f.except_tags", + "f.created_at", + "f.updated_at", + ). + From("filter f"). + Join("filter_indexer fi ON f.id = fi.filter_id"). + Join("indexer i ON i.id = fi.indexer_id"). + Where("i.identifier = ?", indexer). + Where("i.enabled = ?", true). + Where("f.enabled = ?", true). + OrderBy("f.priority DESC") - rows, err := r.db.handler.Query(` - SELECT - f.id, - f.enabled, - f.name, - f.min_size, - f.max_size, - f.delay, - f.priority, - f.match_releases, - f.except_releases, - f.use_regex, - f.match_release_groups, - f.except_release_groups, - f.scene, - f.freeleech, - f.freeleech_percent, - f.shows, - f.seasons, - f.episodes, - f.resolutions, - f.codecs, - f.sources, - f.containers, - f.match_hdr, - f.except_hdr, - f.years, - f.artists, - f.albums, - f.release_types_match, - f.formats, - f.quality, - f.media, - f.log_score, - f.has_log, - f.has_cue, - f.perfect_flac, - f.match_categories, - f.except_categories, - f.match_uploaders, - f.except_uploaders, - f.tags, - f.except_tags, - f.created_at, - f.updated_at - FROM filter f - JOIN filter_indexer fi on f.id = fi.filter_id - JOIN indexer i on i.id = fi.indexer_id - WHERE i.identifier = ? - AND f.enabled = true - AND i.enabled = true - ORDER BY f.priority DESC`, indexer) + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("error querying filter row") + log.Error().Stack().Err(err).Msg("filter.findByIndexerIdentifier: error building query") + return nil, err + } + + rows, err := r.db.handler.Query(query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("filter.findByIndexerIdentifier: error executing query") return nil, err } @@ -177,7 +255,7 @@ func (r *FilterRepo) FindByIndexerIdentifier(indexer string) ([]domain.Filter, e var delay, logScore sql.NullInt32 if err := rows.Scan(&f.ID, &f.Enabled, &f.Name, &minSize, &maxSize, &delay, &f.Priority, &matchReleases, &exceptReleases, &useRegex, &matchReleaseGroups, &exceptReleaseGroups, &scene, &freeleech, &freeleechPercent, &shows, &seasons, &episodes, pq.Array(&f.Resolutions), pq.Array(&f.Codecs), pq.Array(&f.Sources), pq.Array(&f.Containers), pq.Array(&f.MatchHDR), pq.Array(&f.ExceptHDR), &years, &artists, &albums, pq.Array(&f.MatchReleaseTypes), pq.Array(&f.Formats), pq.Array(&f.Quality), pq.Array(&f.Media), &logScore, &hasLog, &hasCue, &perfectFlac, &matchCategories, &exceptCategories, &matchUploaders, &exceptUploaders, &tags, &exceptTags, &f.CreatedAt, &f.UpdatedAt); err != nil { - log.Error().Stack().Err(err).Msg("error scanning data to struct") + log.Error().Stack().Err(err).Msg("filter.findByIndexerIdentifier: error scanning row") return nil, err } @@ -211,66 +289,56 @@ func (r *FilterRepo) FindByIndexerIdentifier(indexer string) ([]domain.Filter, e filters = append(filters, f) } - if err := rows.Err(); err != nil { - return nil, err - } return filters, nil } func (r *FilterRepo) Store(ctx context.Context, filter domain.Filter) (*domain.Filter, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - - var err error - if filter.ID != 0 { - log.Debug().Msg("update existing record") - } else { - var res sql.Result - - res, err = r.db.handler.ExecContext(ctx, `INSERT INTO filter ( - name, - enabled, - min_size, - max_size, - delay, - priority, - match_releases, - except_releases, - use_regex, - match_release_groups, - except_release_groups, - scene, - freeleech, - freeleech_percent, - shows, - seasons, - episodes, - resolutions, - codecs, - sources, - containers, - match_hdr, - except_hdr, - years, - match_categories, - except_categories, - match_uploaders, - except_uploaders, - tags, - except_tags, - artists, - albums, - release_types_match, - formats, - quality, - media, - log_score, - has_log, - has_cue, - perfect_flac - ) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23, $24, $25, $26, $27, $28, $29, $30, $31, $32, $33, $34, $35, $36, $37, $38, $39, $40) ON CONFLICT DO NOTHING`, + queryBuilder := r.db.squirrel. + Insert("filter"). + Columns( + "name", + "enabled", + "min_size", + "max_size", + "delay", + "priority", + "match_releases", + "except_releases", + "use_regex", + "match_release_groups", + "except_release_groups", + "scene", + "freeleech", + "freeleech_percent", + "shows", + "seasons", + "episodes", + "resolutions", + "codecs", + "sources", + "containers", + "match_hdr", + "except_hdr", + "years", + "match_categories", + "except_categories", + "match_uploaders", + "except_uploaders", + "tags", + "except_tags", + "artists", + "albums", + "release_types_match", + "formats", + "quality", + "media", + "log_score", + "has_log", + "has_cue", + "perfect_flac", + ). + Values( filter.Name, filter.Enabled, filter.MinSize, @@ -311,114 +379,80 @@ func (r *FilterRepo) Store(ctx context.Context, filter domain.Filter) (*domain.F filter.Log, filter.Cue, filter.PerfectFlac, - ) - if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") - return nil, err - } + ). + Suffix("RETURNING id").RunWith(r.db.handler) - resId, _ := res.LastInsertId() - filter.ID = int(resId) + // return values + var retID int + + err := queryBuilder.QueryRowContext(ctx).Scan(&retID) + if err != nil { + log.Error().Stack().Err(err).Msg("filter.store: error executing query") + return nil, err } + filter.ID = retID + return &filter, nil } func (r *FilterRepo) Update(ctx context.Context, filter domain.Filter) (*domain.Filter, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - - //var res sql.Result - var err error - _, err = r.db.handler.ExecContext(ctx, ` - UPDATE filter SET - name = ?, - enabled = ?, - min_size = ?, - max_size = ?, - delay = ?, - priority = ?, - match_releases = ?, - except_releases = ?, - use_regex = ?, - match_release_groups = ?, - except_release_groups = ?, - scene = ?, - freeleech = ?, - freeleech_percent = ?, - shows = ?, - seasons = ?, - episodes = ?, - resolutions = ?, - codecs = ?, - sources = ?, - containers = ?, - match_hdr = ?, - except_hdr = ?, - years = ?, - match_categories = ?, - except_categories = ?, - match_uploaders = ?, - except_uploaders = ?, - tags = ?, - except_tags = ?, - artists = ?, - albums = ?, - release_types_match = ?, - formats = ?, - quality = ?, - media = ?, - log_score = ?, - has_log = ?, - has_cue = ?, - perfect_flac = ?, - updated_at = CURRENT_TIMESTAMP - WHERE id = ?`, - filter.Name, - filter.Enabled, - filter.MinSize, - filter.MaxSize, - filter.Delay, - filter.Priority, - filter.MatchReleases, - filter.ExceptReleases, - filter.UseRegex, - filter.MatchReleaseGroups, - filter.ExceptReleaseGroups, - filter.Scene, - filter.Freeleech, - filter.FreeleechPercent, - filter.Shows, - filter.Seasons, - filter.Episodes, - pq.Array(filter.Resolutions), - pq.Array(filter.Codecs), - pq.Array(filter.Sources), - pq.Array(filter.Containers), - pq.Array(filter.MatchHDR), - pq.Array(filter.ExceptHDR), - filter.Years, - filter.MatchCategories, - filter.ExceptCategories, - filter.MatchUploaders, - filter.ExceptUploaders, - filter.Tags, - filter.ExceptTags, - filter.Artists, - filter.Albums, - pq.Array(filter.MatchReleaseTypes), - pq.Array(filter.Formats), - pq.Array(filter.Quality), - pq.Array(filter.Media), - filter.LogScore, - filter.Log, - filter.Cue, - filter.PerfectFlac, - filter.ID, - ) + + queryBuilder := r.db.squirrel. + Update("filter"). + Set("name", filter.Name). + Set("enabled", filter.Enabled). + Set("min_size", filter.MinSize). + Set("max_size", filter.MaxSize). + Set("delay", filter.Delay). + Set("priority", filter.Priority). + Set("use_regex", filter.UseRegex). + Set("match_releases", filter.MatchReleases). + Set("except_releases", filter.ExceptReleases). + Set("match_release_groups", filter.MatchReleaseGroups). + Set("except_release_groups", filter.ExceptReleaseGroups). + Set("scene", filter.Scene). + Set("freeleech", filter.Freeleech). + Set("freeleech_percent", filter.FreeleechPercent). + Set("shows", filter.Shows). + Set("seasons", filter.Seasons). + Set("episodes", filter.Episodes). + Set("resolutions", pq.Array(filter.Resolutions)). + Set("codecs", pq.Array(filter.Codecs)). + Set("sources", pq.Array(filter.Sources)). + Set("containers", pq.Array(filter.Containers)). + Set("match_hdr", pq.Array(filter.MatchHDR)). + Set("except_hdr", pq.Array(filter.ExceptHDR)). + Set("years", filter.Years). + Set("match_categories", filter.MatchCategories). + Set("except_categories", filter.ExceptCategories). + Set("match_uploaders", filter.MatchUploaders). + Set("except_uploaders", filter.ExceptUploaders). + Set("tags", filter.Tags). + Set("except_tags", filter.ExceptTags). + Set("artists", filter.Artists). + Set("albums", filter.Albums). + Set("release_types_match", pq.Array(filter.MatchReleaseTypes)). + Set("formats", pq.Array(filter.Formats)). + Set("quality", pq.Array(filter.Quality)). + Set("media", pq.Array(filter.Media)). + Set("log_score", filter.LogScore). + Set("has_log", filter.Log). + Set("has_cue", filter.Cue). + Set("perfect_flac", filter.PerfectFlac). + Set("updated_at", time.Now().Format(time.RFC3339)). + Where("id = ?", filter.ID) + + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") + log.Error().Stack().Err(err).Msg("filter.update: error building query") + return nil, err + } + + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("filter.update: error executing query") return nil, err } @@ -426,20 +460,22 @@ func (r *FilterRepo) Update(ctx context.Context, filter domain.Filter) (*domain. } func (r *FilterRepo) ToggleEnabled(ctx context.Context, filterID int, enabled bool) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - var err error - _, err = r.db.handler.ExecContext(ctx, ` - UPDATE filter SET - enabled = ?, - updated_at = CURRENT_TIMESTAMP - WHERE id = ?`, - enabled, - filterID, - ) + + queryBuilder := r.db.squirrel. + Update("filter"). + Set("enabled", enabled). + Set("updated_at", sq.Expr("CURRENT_TIMESTAMP")). + Where("id = ?", filterID) + + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") + log.Error().Stack().Err(err).Msg("filter.toggleEnabled: error building query") + return err + } + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("filter.toggleEnabled: error executing query") return err } @@ -447,9 +483,6 @@ func (r *FilterRepo) ToggleEnabled(ctx context.Context, filterID int, enabled bo } func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int, indexers []domain.Indexer) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - tx, err := r.db.handler.BeginTx(ctx, nil) if err != nil { return err @@ -457,27 +490,43 @@ func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int, defer tx.Rollback() - deleteQuery := `DELETE FROM filter_indexer WHERE filter_id = ?` - _, err = tx.ExecContext(ctx, deleteQuery, filterID) + deleteQueryBuilder := r.db.squirrel. + Delete("filter_indexer"). + Where("filter_id = ?", filterID) + + deleteQuery, deleteArgs, err := deleteQueryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msgf("error deleting indexers for filter: %v", filterID) + log.Error().Stack().Err(err).Msg("filter.StoreIndexerConnections: error building query") + return err + } + _, err = tx.ExecContext(ctx, deleteQuery, deleteArgs...) + if err != nil { + log.Error().Stack().Err(err).Msgf("filter.StoreIndexerConnections: error deleting indexers for filter: %v", filterID) return err } for _, indexer := range indexers { - query := `INSERT INTO filter_indexer (filter_id, indexer_id) VALUES ($1, $2)` - _, err := tx.ExecContext(ctx, query, filterID, indexer.ID) + queryBuilder := r.db.squirrel. + Insert("filter_indexer").Columns("filter_id", "indexer_id"). + Values(filterID, indexer.ID) + + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") + log.Error().Stack().Err(err).Msg("filter.StoreIndexerConnections: error building query") + return err + } + _, err = tx.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("filter.StoreIndexerConnections: error executing query") return err } - log.Debug().Msgf("filter.indexers: store '%v' on filter: %v", indexer.Name, filterID) + log.Debug().Msgf("filter.StoreIndexerConnections: store '%v' on filter: %v", indexer.Name, filterID) } err = tx.Commit() if err != nil { - log.Error().Stack().Err(err).Msgf("error deleting indexers for filter: %v", filterID) + log.Error().Stack().Err(err).Msgf("filter.StoreIndexerConnections: error storing indexers for filter: %v", filterID) return err } @@ -485,13 +534,19 @@ func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int, } func (r *FilterRepo) StoreIndexerConnection(ctx context.Context, filterID int, indexerID int) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Insert("filter_indexer").Columns("filter_id", "indexer_id"). + Values(filterID, indexerID) - query := `INSERT INTO filter_indexer (filter_id, indexer_id) VALUES ($1, $2)` - _, err := r.db.handler.ExecContext(ctx, query, filterID, indexerID) + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") + log.Error().Stack().Err(err).Msg("filter.storeIndexerConnection: error building query") + return err + } + + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("filter.storeIndexerConnection: error executing query") return err } @@ -499,13 +554,19 @@ func (r *FilterRepo) StoreIndexerConnection(ctx context.Context, filterID int, i } func (r *FilterRepo) DeleteIndexerConnections(ctx context.Context, filterID int) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Delete("filter_indexer"). + Where("filter_id = ?", filterID) - query := `DELETE FROM filter_indexer WHERE filter_id = ?` - _, err := r.db.handler.ExecContext(ctx, query, filterID) + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") + log.Error().Stack().Err(err).Msg("filter.deleteIndexerConnections: error building query") + return err + } + + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("filter.deleteIndexerConnections: error executing query") return err } @@ -513,12 +574,19 @@ func (r *FilterRepo) DeleteIndexerConnections(ctx context.Context, filterID int) } func (r *FilterRepo) Delete(ctx context.Context, filterID int) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Delete("filter"). + Where("id = ?", filterID) - _, err := r.db.handler.ExecContext(ctx, `DELETE FROM filter WHERE id = ?`, filterID) + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") + log.Error().Stack().Err(err).Msg("filter.delete: error building query") + return err + } + + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("filter.delete: error executing query") return err } diff --git a/internal/database/indexer.go b/internal/database/indexer.go index fd947ad..e92f985 100644 --- a/internal/database/indexer.go +++ b/internal/database/indexer.go @@ -3,8 +3,11 @@ package database import ( "context" "encoding/json" - "github.com/autobrr/autobrr/internal/domain" + "time" + "github.com/rs/zerolog/log" + + "github.com/autobrr/autobrr/internal/domain" ) type IndexerRepo struct { @@ -17,52 +20,64 @@ func NewIndexerRepo(db *DB) domain.IndexerRepo { } } -func (r *IndexerRepo) Store(indexer domain.Indexer) (*domain.Indexer, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - +func (r *IndexerRepo) Store(ctx context.Context, indexer domain.Indexer) (*domain.Indexer, error) { settings, err := json.Marshal(indexer.Settings) if err != nil { log.Error().Stack().Err(err).Msg("error marshaling json data") return nil, err } - res, err := r.db.handler.Exec(`INSERT INTO indexer (enabled, name, identifier, settings) VALUES (?, ?, ?, ?)`, indexer.Enabled, indexer.Name, indexer.Identifier, settings) + queryBuilder := r.db.squirrel. + Insert("indexer").Columns("enabled", "name", "identifier", "settings"). + Values(indexer.Enabled, indexer.Name, indexer.Identifier, settings). + Suffix("RETURNING id").RunWith(r.db.handler) + + // return values + var retID int64 + + err = queryBuilder.QueryRowContext(ctx).Scan(&retID) if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") + log.Error().Stack().Err(err).Msg("indexer.store: error executing query") return nil, err } - id, _ := res.LastInsertId() - indexer.ID = id + indexer.ID = retID return &indexer, nil } -func (r *IndexerRepo) Update(indexer domain.Indexer) (*domain.Indexer, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - - sett, err := json.Marshal(indexer.Settings) +func (r *IndexerRepo) Update(ctx context.Context, indexer domain.Indexer) (*domain.Indexer, error) { + settings, err := json.Marshal(indexer.Settings) if err != nil { log.Error().Stack().Err(err).Msg("error marshaling json data") return nil, err } - _, err = r.db.handler.Exec(`UPDATE indexer SET enabled = ?, name = ?, settings = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, indexer.Enabled, indexer.Name, sett, indexer.ID) + queryBuilder := r.db.squirrel. + Update("indexer"). + Set("enabled", indexer.Enabled). + Set("name", indexer.Name). + Set("settings", settings). + Set("updated_at", time.Now().Format(time.RFC3339)). + Where("id = ?", indexer.ID) + + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") + log.Error().Stack().Err(err).Msg("indexer.update: error building query") + return nil, err + } + + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("indexer.update: error executing query") return nil, err } return &indexer, nil } -func (r *IndexerRepo) List() ([]domain.Indexer, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - - rows, err := r.db.handler.Query("SELECT id, enabled, name, identifier, settings FROM indexer ORDER BY name ASC") +func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) { + rows, err := r.db.handler.QueryContext(ctx, "SELECT id, enabled, name, identifier, settings FROM indexer ORDER BY name ASC") if err != nil { log.Error().Stack().Err(err).Msg("indexer.list: error query indexer") return nil, err @@ -100,14 +115,19 @@ func (r *IndexerRepo) List() ([]domain.Indexer, error) { } func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Indexer, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Select("id", "enabled", "name", "identifier"). + From("indexer"). + Join("filter_indexer ON indexer.id = filter_indexer.indexer_id"). + Where("filter_indexer.filter_id = ?", id) - rows, err := r.db.handler.QueryContext(ctx, ` - SELECT i.id, i.enabled, i.name, i.identifier - FROM indexer i - JOIN filter_indexer fi on i.id = fi.indexer_id - WHERE fi.filter_id = ?`, id) + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("irc.check_existing_network: error fetching data") + return nil, err + } + + rows, err := r.db.handler.QueryContext(ctx, query, args...) if err != nil { log.Error().Stack().Err(err).Msg("indexer.find_by_filter_id: error query indexer") return nil, err @@ -115,7 +135,7 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde defer rows.Close() - var indexers []domain.Indexer + indexers := make([]domain.Indexer, 0) for rows.Next() { var f domain.Indexer @@ -146,12 +166,17 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde } func (r *IndexerRepo) Delete(ctx context.Context, id int) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Delete("indexer"). + Where("id = ?", id) - query := `DELETE FROM indexer WHERE id = ?` + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("indexer.delete: error building query") + return err + } - _, err := r.db.handler.ExecContext(ctx, query, id) + _, err = r.db.handler.ExecContext(ctx, query, args...) if err != nil { log.Error().Stack().Err(err).Msgf("indexer.delete: error executing query: '%v'", query) return err diff --git a/internal/database/irc.go b/internal/database/irc.go index 1a9417a..e12478a 100644 --- a/internal/database/irc.go +++ b/internal/database/irc.go @@ -3,8 +3,8 @@ package database import ( "context" "database/sql" - - sq "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + "time" "github.com/autobrr/autobrr/internal/domain" @@ -19,15 +19,18 @@ func NewIrcRepo(db *DB) domain.IrcRepo { return &IrcRepo{db: db} } -func (r *IrcRepo) GetNetworkByID(id int64) (*domain.IrcNetwork, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() +func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetwork, error) { + queryBuilder := r.db.squirrel. + Select("id", "enabled", "name", "server", "port", "tls", "pass", "invite_command", "nickserv_account", "nickserv_password"). + From("irc_network"). + Where("id = ?", id) - row := r.db.handler.QueryRow("SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network WHERE id = ?", id) - if err := row.Err(); err != nil { - log.Fatal().Err(err) + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("irc.getNetworkByID: error building query") return nil, err } + log.Trace().Str("database", "irc.check_existing_network").Msgf("query: '%v', args: '%v'", query, args) var n domain.IrcNetwork @@ -35,8 +38,10 @@ func (r *IrcRepo) GetNetworkByID(id int64) (*domain.IrcNetwork, error) { var nsAccount, nsPassword sql.NullString var tls sql.NullBool + row := r.db.handler.QueryRowContext(ctx, query, args...) if err := row.Scan(&n.ID, &n.Enabled, &n.Name, &n.Server, &n.Port, &tls, &pass, &inviteCmd, &nsAccount, &nsPassword); err != nil { - log.Fatal().Err(err) + log.Error().Stack().Err(err).Msg("irc.getNetworkByID: error executing query") + return nil, err } n.TLS = tls.Bool @@ -56,21 +61,41 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error { defer tx.Rollback() - _, err = tx.ExecContext(ctx, `DELETE FROM irc_channel WHERE network_id = ?`, id) + queryBuilder := r.db.squirrel. + Delete("irc_channel"). + Where("network_id = ?", id) + + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msgf("error deleting channels for network: %v", id) + log.Error().Stack().Err(err).Msg("irc.deleteNetwork: error building query") return err } - _, err = tx.ExecContext(ctx, `DELETE FROM irc_network WHERE id = ?`, id) + _, err = tx.ExecContext(ctx, query, args...) if err != nil { - log.Error().Stack().Err(err).Msgf("error deleting network: %v", id) + log.Error().Stack().Err(err).Msg("irc.deleteNetwork: error executing query") + return err + } + + netQueryBuilder := r.db.squirrel. + Delete("irc_network"). + Where("id = ?", id) + + netQuery, netArgs, err := netQueryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("irc.deleteNetwork: error building query") + return err + } + + _, err = tx.ExecContext(ctx, netQuery, netArgs...) + if err != nil { + log.Error().Stack().Err(err).Msg("irc.deleteNetwork: error executing query") return err } err = tx.Commit() if err != nil { - log.Error().Stack().Err(err).Msgf("error deleting network: %v", id) + log.Error().Stack().Err(err).Msgf("irc.deleteNetwork: error deleting network %v", id) return err } @@ -79,12 +104,21 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error { } func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Select("id", "enabled", "name", "server", "port", "tls", "pass", "invite_command", "nickserv_account", "nickserv_password"). + From("irc_network"). + Where("enabled = ?", true) - rows, err := r.db.handler.QueryContext(ctx, "SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network WHERE enabled = true") + query, args, err := queryBuilder.ToSql() if err != nil { - log.Fatal().Err(err) + log.Error().Stack().Err(err).Msg("irc.findActiveNetworks: error building query") + return nil, err + } + + rows, err := r.db.handler.QueryContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("irc.findActiveNetworks: error executing query") + return nil, err } defer rows.Close() @@ -94,19 +128,25 @@ func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, var net domain.IrcNetwork var pass, inviteCmd sql.NullString + var nsAccount, nsPassword sql.NullString var tls sql.NullBool - if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &inviteCmd, &net.NickServ.Account, &net.NickServ.Password); err != nil { - log.Fatal().Err(err) + if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &inviteCmd, &nsAccount, &nsPassword); err != nil { + log.Error().Stack().Err(err).Msg("irc.findActiveNetworks: error scanning row") + return nil, err } net.TLS = tls.Bool net.Pass = pass.String net.InviteCommand = inviteCmd.String + net.NickServ.Account = nsAccount.String + net.NickServ.Password = nsPassword.String + networks = append(networks, net) } if err := rows.Err(); err != nil { + log.Error().Stack().Err(err).Msg("irc.findActiveNetworks: row error") return nil, err } @@ -114,12 +154,21 @@ func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, } func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Select("id", "enabled", "name", "server", "port", "tls", "pass", "invite_command", "nickserv_account", "nickserv_password"). + From("irc_network"). + OrderBy("name ASC") - rows, err := r.db.handler.QueryContext(ctx, "SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network ORDER BY name ASC") + query, args, err := queryBuilder.ToSql() if err != nil { - log.Fatal().Err(err) + log.Error().Stack().Err(err).Msg("irc.listNetworks: error building query") + return nil, err + } + + rows, err := r.db.handler.QueryContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("irc.listNetworks: error executing query") + return nil, err } defer rows.Close() @@ -129,19 +178,25 @@ func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) var net domain.IrcNetwork var pass, inviteCmd sql.NullString + var nsAccount, nsPassword sql.NullString var tls sql.NullBool - if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &inviteCmd, &net.NickServ.Account, &net.NickServ.Password); err != nil { - log.Fatal().Err(err) + if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &inviteCmd, &nsAccount, &nsPassword); err != nil { + log.Error().Stack().Err(err).Msg("irc.listNetworks: error scanning row") + return nil, err } net.TLS = tls.Bool net.Pass = pass.String net.InviteCommand = inviteCmd.String + net.NickServ.Account = nsAccount.String + net.NickServ.Password = nsPassword.String + networks = append(networks, net) } if err := rows.Err(); err != nil { + log.Error().Stack().Err(err).Msg("irc.listNetworks: row error") return nil, err } @@ -149,12 +204,20 @@ func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) } func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() + queryBuilder := r.db.squirrel. + Select("id", "name", "enabled", "password"). + From("irc_channel"). + Where("network_id = ?", networkID) - rows, err := r.db.handler.Query("SELECT id, name, enabled, password FROM irc_channel WHERE network_id = ?", networkID) + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msgf("error querying channels for network: %v", networkID) + log.Error().Stack().Err(err).Msg("irc.listChannels: error building query") + return nil, err + } + + rows, err := r.db.handler.Query(query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("irc.listChannels: error executing query") return nil, err } defer rows.Close() @@ -165,7 +228,7 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) { var pass sql.NullString if err := rows.Scan(&ch.ID, &ch.Name, &ch.Enabled, &pass); err != nil { - log.Error().Stack().Err(err).Msgf("error querying channels for network: %v", networkID) + log.Error().Stack().Err(err).Msg("irc.listChannels: error scanning row") return nil, err } @@ -174,6 +237,7 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) { channels = append(channels, ch) } if err := rows.Err(); err != nil { + log.Error().Stack().Err(err).Msg("irc.listChannels: error row") return nil, err } @@ -181,10 +245,7 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) { } func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcNetwork) (*domain.IrcNetwork, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - - queryBuilder := sq. + queryBuilder := r.db.squirrel. Select("id", "enabled", "name", "server", "port", "tls", "pass", "invite_command", "nickserv_account", "nickserv_password"). From("irc_network"). Where("server = ?", network.Server). @@ -192,10 +253,10 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("irc.check_existing_network: error fetching data") + log.Error().Stack().Err(err).Msg("irc.checkExistingNetwork: error building query") return nil, err } - log.Trace().Str("database", "irc.check_existing_network").Msgf("query: '%v', args: '%v'", query, args) + log.Trace().Str("database", "irc.checkExistingNetwork").Msgf("query: '%v', args: '%v'", query, args) row := r.db.handler.QueryRowContext(ctx, query, args...) @@ -209,7 +270,7 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN // no result is not an error in our case return nil, nil } else if err != nil { - log.Error().Stack().Err(err).Msg("irc.check_existing_network: error scanning data to struct") + log.Error().Stack().Err(err).Msg("irc.checkExistingNetwork: error scanning data to struct") return nil, err } @@ -222,9 +283,6 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN } func (r *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - netName := toNullString(network.Name) pass := toNullString(network.Pass) inviteCmd := toNullString(network.InviteCommand) @@ -233,20 +291,22 @@ func (r *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error { nsPassword := toNullString(network.NickServ.Password) var err error - if network.ID != 0 { - // update record - _, err = r.db.handler.Exec(`UPDATE irc_network - SET enabled = ?, - name = ?, - server = ?, - port = ?, - tls = ?, - pass = ?, - invite_command = ?, - nickserv_account = ?, - nickserv_password = ?, - updated_at = CURRENT_TIMESTAMP - WHERE id = ?`, + var retID int64 + + queryBuilder := r.db.squirrel. + Insert("irc_network"). + Columns( + "enabled", + "name", + "server", + "port", + "tls", + "pass", + "invite_command", + "nickserv_account", + "nickserv_password", + ). + Values( network.Enabled, netName, network.Server, @@ -256,51 +316,22 @@ func (r *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error { inviteCmd, nsAccount, nsPassword, - network.ID, - ) - if err != nil { - log.Error().Stack().Err(err).Msg("irc.store_network: error executing query") - return err - } - } else { - var res sql.Result + ). + Suffix("RETURNING id"). + RunWith(r.db.handler) - res, err = r.db.handler.Exec(`INSERT INTO irc_network ( - enabled, - name, - server, - port, - tls, - pass, - invite_command, - nickserv_account, - nickserv_password - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT DO NOTHING`, - network.Enabled, - netName, - network.Server, - network.Port, - network.TLS, - pass, - inviteCmd, - nsAccount, - nsPassword, - ) - if err != nil { - log.Error().Stack().Err(err).Msg("irc.store_network: error executing query") - return err - } - - network.ID, err = res.LastInsertId() + err = queryBuilder.QueryRow().Scan(&retID) + if err != nil { + log.Error().Stack().Err(err).Msg("irc.storeNetwork: error executing query") + return errors.Wrap(err, "error executing query") } + network.ID = retID + return err } func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - netName := toNullString(network.Name) pass := toNullString(network.Pass) inviteCmd := toNullString(network.InviteCommand) @@ -309,32 +340,31 @@ func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) nsPassword := toNullString(network.NickServ.Password) var err error - // update record - _, err = r.db.handler.ExecContext(ctx, `UPDATE irc_network - SET enabled = ?, - name = ?, - server = ?, - port = ?, - tls = ?, - pass = ?, - invite_command = ?, - nickserv_account = ?, - nickserv_password = ?, - updated_at = CURRENT_TIMESTAMP - WHERE id = ?`, - network.Enabled, - netName, - network.Server, - network.Port, - network.TLS, - pass, - inviteCmd, - nsAccount, - nsPassword, - network.ID, - ) + + queryBuilder := r.db.squirrel. + Update("irc_network"). + Set("enabled", network.Enabled). + Set("name", netName). + Set("server", network.Server). + Set("port", network.Port). + Set("tls", network.TLS). + Set("pass", pass). + Set("invite_command", inviteCmd). + Set("nickserv_account", nsAccount). + Set("nickserv_password", nsPassword). + Set("updated_at", time.Now().Format(time.RFC3339)). + Where("id = ?", network.ID) + + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("irc.store_network: error executing query") + log.Error().Stack().Err(err).Msg("irc.updateNetwork: error building query") + return err + } + + // update record + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("irc.updateNetwork: error executing query") return err } @@ -344,9 +374,6 @@ func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) // TODO create new channel handler to only add, not delete func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, channels []domain.IrcChannel) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - tx, err := r.db.handler.BeginTx(ctx, nil) if err != nil { return err @@ -354,40 +381,74 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha defer tx.Rollback() - _, err = tx.ExecContext(ctx, `DELETE FROM irc_channel WHERE network_id = ?`, networkID) + queryBuilder := r.db.squirrel. + Delete("irc_channel"). + Where("network_id = ?", networkID) + + query, args, err := queryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msgf("error deleting channels for network: %v", networkID) + log.Error().Stack().Err(err).Msg("irc.storeNetworkChannels: error building query") + return err + } + + _, err = tx.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("irc.storeNetworkChannels: error executing query") return err } for _, channel := range channels { - var res sql.Result + // values pass := toNullString(channel.Password) - res, err = tx.ExecContext(ctx, `INSERT INTO irc_channel ( - enabled, - detached, - name, - password, - network_id - ) VALUES (?, ?, ?, ?, ?)`, - channel.Enabled, - true, - channel.Name, - pass, - networkID, - ) + channelQueryBuilder := r.db.squirrel. + Insert("irc_channel"). + Columns( + "enabled", + "detached", + "name", + "password", + "network_id", + ). + Values( + channel.Enabled, + true, + channel.Name, + pass, + networkID, + ). + Suffix("RETURNING id"). + RunWith(tx) + + // returning + var retID int64 + + err = channelQueryBuilder.QueryRowContext(ctx).Scan(&retID) if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") - return err + log.Error().Stack().Err(err).Msg("irc.storeNetworkChannels: error executing query") + return errors.Wrap(err, "error executing query") } - channel.ID, err = res.LastInsertId() + channel.ID = retID + + //channelQuery, channelArgs, err := channelQueryBuilder.ToSql() + //if err != nil { + // log.Error().Stack().Err(err).Msg("irc.storeNetworkChannels: error building query") + // return err + //} + // + //res, err = r.db.handler.ExecContext(ctx, channelQuery, channelArgs...) + //if err != nil { + // log.Error().Stack().Err(err).Msg("irc.storeNetworkChannels: error executing query") + // return err + //} + // + //channel.ID, err = res.LastInsertId() } err = tx.Commit() if err != nil { - log.Error().Stack().Err(err).Msgf("error deleting network: %v", networkID) + log.Error().Stack().Err(err).Msgf("irc.storeNetworkChannels: error deleting network: %v", networkID) return err } @@ -395,50 +456,102 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha } func (r *IrcRepo) StoreChannel(networkID int64, channel *domain.IrcChannel) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - pass := toNullString(channel.Password) var err error if channel.ID != 0 { // update record - _, err = r.db.handler.Exec(`UPDATE irc_channel - SET - enabled = ?, - detached = ?, - name = ?, - password = ? - WHERE - id = ?`, - channel.Enabled, - channel.Detached, - channel.Name, - pass, - channel.ID, - ) - } else { - var res sql.Result + channelQueryBuilder := r.db.squirrel. + Update("irc_channel"). + Set("enabled", channel.Enabled). + Set("detached", channel.Detached). + Set("name", channel.Name). + Set("pass", pass). + Where("id = ?", channel.ID) - res, err = r.db.handler.Exec(`INSERT INTO irc_channel ( - enabled, - detached, - name, - password, - network_id - ) VALUES (?, ?, ?, ?, ?) ON CONFLICT DO NOTHING`, - channel.Enabled, - true, - channel.Name, - pass, - networkID, - ) + query, args, err := channelQueryBuilder.ToSql() if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") + log.Error().Stack().Err(err).Msg("irc.storeChannel: error building query") return err } - channel.ID, err = res.LastInsertId() + _, err = r.db.handler.Exec(query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("irc.storeChannel: error executing query") + return err + } + } else { + queryBuilder := r.db.squirrel. + Insert("irc_channel"). + Columns( + "enabled", + "detached", + "name", + "password", + "network_id", + ). + Values( + channel.Enabled, + true, + channel.Name, + pass, + networkID, + ). + Suffix("RETURNING id"). + RunWith(r.db.handler) + + // returning + var retID int64 + + err = queryBuilder.QueryRow().Scan(&retID) + if err != nil { + log.Error().Stack().Err(err).Msg("irc.storeChannels: error executing query") + return errors.Wrap(err, "error executing query") + } + + channel.ID = retID + + //channelQuery, channelArgs, err := channelQueryBuilder.ToSql() + //if err != nil { + // log.Error().Stack().Err(err).Msg("irc.storeChannel: error building query") + // return err + //} + // + //res, err := r.db.handler.Exec(channelQuery, channelArgs...) + //if err != nil { + // log.Error().Stack().Err(err).Msg("irc.storeChannel: error executing query") + // return errors.Wrap(err, "error executing query") + // //return err + //} + // + //channel.ID, err = res.LastInsertId() + } + + return err +} + +func (r *IrcRepo) UpdateChannel(channel *domain.IrcChannel) error { + pass := toNullString(channel.Password) + + // update record + channelQueryBuilder := r.db.squirrel. + Update("irc_channel"). + Set("enabled", channel.Enabled). + Set("detached", channel.Detached). + Set("name", channel.Name). + Set("pass", pass). + Where("id = ?", channel.ID) + + query, args, err := channelQueryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("irc.updateChannel: error building query") + return err + } + + _, err = r.db.handler.Exec(query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("irc.updateChannel: error executing query") + return err } return err diff --git a/internal/database/migrate.go b/internal/database/migrate.go index 9bd1e48..ca7197b 100644 --- a/internal/database/migrate.go +++ b/internal/database/migrate.go @@ -1,6 +1,6 @@ package database -const schema = ` +const sqliteSchema = ` CREATE TABLE users ( id INTEGER PRIMARY KEY, @@ -107,7 +107,7 @@ CREATE TABLE filter_indexer filter_id INTEGER, indexer_id INTEGER, FOREIGN KEY (filter_id) REFERENCES filter(id), - FOREIGN KEY (indexer_id) REFERENCES indexer(id), + FOREIGN KEY (indexer_id) REFERENCES indexer(id) ON DELETE CASCADE, PRIMARY KEY (filter_id, indexer_id) ); @@ -150,8 +150,8 @@ CREATE TABLE action webhook_headers TEXT [] DEFAULT '{}', client_id INTEGER, filter_id INTEGER, - FOREIGN KEY (client_id) REFERENCES client(id), - FOREIGN KEY (filter_id) REFERENCES filter(id) + FOREIGN KEY (filter_id) REFERENCES filter(id), + FOREIGN KEY (client_id) REFERENCES client(id) ON DELETE SET NULL ); CREATE TABLE "release" @@ -207,20 +207,20 @@ CREATE TABLE "release" CREATE TABLE release_action_status ( - id INTEGER PRIMARY KEY, - status TEXT, - action TEXT NOT NULL, - type TEXT NOT NULL, - rejections TEXT [] DEFAULT '{}' NOT NULL, - timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, - raw TEXT, - log TEXT, - release_id INTEGER NOT NULL, - FOREIGN KEY (release_id) REFERENCES "release"(id) + id INTEGER PRIMARY KEY, + status TEXT, + action TEXT NOT NULL, + type TEXT NOT NULL, + rejections TEXT [] DEFAULT '{}' NOT NULL, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + raw TEXT, + log TEXT, + release_id INTEGER NOT NULL, + FOREIGN KEY (release_id) REFERENCES "release"(id) ON DELETE CASCADE ); ` -var migrations = []string{ +var sqliteMigrations = []string{ "", ` CREATE TABLE "release" @@ -368,4 +368,348 @@ var migrations = []string{ ALTER TABLE "action" ADD COLUMN webhook_headers TEXT [] DEFAULT '{}'; `, + ` +CREATE TABLE action_dg_tmp +( + id INTEGER PRIMARY KEY, + name TEXT, + type TEXT, + enabled BOOLEAN, + exec_cmd TEXT, + exec_args TEXT, + watch_folder TEXT, + category TEXT, + tags TEXT, + label TEXT, + save_path TEXT, + paused BOOLEAN, + ignore_rules BOOLEAN, + limit_upload_speed INT, + limit_download_speed INT, + client_id INTEGER + CONSTRAINT action_client_id_fkey + REFERENCES client + ON DELETE SET NULL, + filter_id INTEGER + CONSTRAINT action_filter_id_fkey + REFERENCES filter, + webhook_host TEXT, + webhook_data TEXT, + webhook_method TEXT, + webhook_type TEXT, + webhook_headers TEXT [] default '{}' +); + +INSERT INTO action_dg_tmp(id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, + paused, ignore_rules, limit_upload_speed, limit_download_speed, client_id, filter_id, + webhook_host, webhook_data, webhook_method, webhook_type, webhook_headers) +SELECT id, + name, + type, + enabled, + exec_cmd, + exec_args, + watch_folder, + category, + tags, + label, + save_path, + paused, + ignore_rules, + limit_upload_speed, + limit_download_speed, + client_id, + filter_id, + webhook_host, + webhook_data, + webhook_method, + webhook_type, + webhook_headers +FROM action; + +DROP TABLE action; + +ALTER TABLE action_dg_tmp + RENAME TO action; + `, + ` +CREATE TABLE filter_indexer_dg_tmp +( + filter_id INTEGER + CONSTRAINT filter_indexer_filter_id_fkey + REFERENCES filter, + indexer_id INTEGER + CONSTRAINT filter_indexer_indexer_id_fkey + REFERENCES indexer + ON DELETE CASCADE, + PRIMARY KEY (filter_id, indexer_id) +); + +INSERT INTO filter_indexer_dg_tmp(filter_id, indexer_id) +SELECT filter_id, indexer_id +FROM filter_indexer; + +DROP TABLE filter_indexer; + +ALTER TABLE filter_indexer_dg_tmp + RENAME TO filter_indexer; + `, + ` +CREATE TABLE release_action_status_dg_tmp +( + id INTEGER PRIMARY KEY, + status TEXT, + action TEXT not null, + type TEXT not null, + rejections TEXT [] default '{}' not null, + timestamp TIMESTAMP default CURRENT_TIMESTAMP, + raw TEXT, + log TEXT, + release_id INTEGER not null + CONSTRAINT release_action_status_release_id_fkey + REFERENCES "release" + ON DELETE CASCADE +); + +INSERT INTO release_action_status_dg_tmp(id, status, action, type, rejections, timestamp, raw, log, release_id) +SELECT id, + status, + action, + type, + rejections, + timestamp, + raw, + log, + release_id +FROM release_action_status; + +DROP TABLE release_action_status; + +ALTER TABLE release_action_status_dg_tmp + RENAME TO release_action_status; + `, +} + +const postgresSchema = ` +CREATE TABLE users +( + id SERIAL PRIMARY KEY, + username TEXT NOT NULL, + password TEXT NOT NULL, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE (username) +); + +CREATE TABLE indexer +( + id SERIAL PRIMARY KEY, + identifier TEXT, + enabled BOOLEAN, + name TEXT NOT NULL, + settings TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE (identifier) +); + +CREATE TABLE irc_network +( + id SERIAL PRIMARY KEY, + enabled BOOLEAN, + name TEXT NOT NULL, + server TEXT NOT NULL, + port INTEGER NOT NULL, + tls BOOLEAN, + pass TEXT, + invite_command TEXT, + nickserv_account TEXT, + nickserv_password TEXT, + connected BOOLEAN, + connected_since TIMESTAMP, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + UNIQUE (server, port, nickserv_account) +); + +CREATE TABLE irc_channel +( + id SERIAL PRIMARY KEY, + enabled BOOLEAN, + name TEXT NOT NULL, + password TEXT, + detached BOOLEAN, + network_id INTEGER NOT NULL, + FOREIGN KEY (network_id) REFERENCES irc_network(id), + UNIQUE (network_id, name) +); + +CREATE TABLE filter +( + id SERIAL PRIMARY KEY, + enabled BOOLEAN, + name TEXT NOT NULL, + min_size TEXT, + max_size TEXT, + delay INTEGER, + priority INTEGER DEFAULT 0 NOT NULL, + match_releases TEXT, + except_releases TEXT, + use_regex BOOLEAN, + match_release_groups TEXT, + except_release_groups TEXT, + scene BOOLEAN, + freeleech BOOLEAN, + freeleech_percent TEXT, + shows TEXT, + seasons TEXT, + episodes TEXT, + resolutions TEXT [] DEFAULT '{}' NOT NULL, + codecs TEXT [] DEFAULT '{}' NOT NULL, + sources TEXT [] DEFAULT '{}' NOT NULL, + containers TEXT [] DEFAULT '{}' NOT NULL, + match_hdr TEXT [] DEFAULT '{}', + except_hdr TEXT [] DEFAULT '{}', + years TEXT, + artists TEXT, + albums TEXT, + release_types_match TEXT [] DEFAULT '{}', + release_types_ignore TEXT [] DEFAULT '{}', + formats TEXT [] DEFAULT '{}', + quality TEXT [] DEFAULT '{}', + media TEXT [] DEFAULT '{}', + log_score INTEGER, + has_log BOOLEAN, + has_cue BOOLEAN, + perfect_flac BOOLEAN, + match_categories TEXT, + except_categories TEXT, + match_uploaders TEXT, + except_uploaders TEXT, + tags TEXT, + except_tags TEXT, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE filter_indexer +( + filter_id INTEGER, + indexer_id INTEGER, + FOREIGN KEY (filter_id) REFERENCES filter(id), + FOREIGN KEY (indexer_id) REFERENCES indexer(id) ON DELETE CASCADE, + PRIMARY KEY (filter_id, indexer_id) +); + +CREATE TABLE client +( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + enabled BOOLEAN, + type TEXT, + host TEXT NOT NULL, + port INTEGER, + tls BOOLEAN, + tls_skip_verify BOOLEAN, + username TEXT, + password TEXT, + settings JSON +); + +CREATE TABLE action +( + id SERIAL PRIMARY KEY, + name TEXT, + type TEXT, + enabled BOOLEAN, + exec_cmd TEXT, + exec_args TEXT, + watch_folder TEXT, + category TEXT, + tags TEXT, + label TEXT, + save_path TEXT, + paused BOOLEAN, + ignore_rules BOOLEAN, + limit_upload_speed INT, + limit_download_speed INT, + webhook_host TEXT, + webhook_method TEXT, + webhook_type TEXT, + webhook_data TEXT, + webhook_headers TEXT [] DEFAULT '{}', + client_id INTEGER, + filter_id INTEGER, + FOREIGN KEY (filter_id) REFERENCES filter(id), + FOREIGN KEY (client_id) REFERENCES client(id) ON DELETE SET NULL +); + +CREATE TABLE "release" +( + id SERIAL PRIMARY KEY, + filter_status TEXT, + rejections TEXT [] DEFAULT '{}' NOT NULL, + indexer TEXT, + filter TEXT, + protocol TEXT, + implementation TEXT, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + group_id TEXT, + torrent_id TEXT, + torrent_name TEXT, + size INTEGER, + raw TEXT, + title TEXT, + category TEXT, + season INTEGER, + episode INTEGER, + year INTEGER, + resolution TEXT, + source TEXT, + codec TEXT, + container TEXT, + hdr TEXT, + audio TEXT, + release_group TEXT, + region TEXT, + language TEXT, + edition TEXT, + unrated BOOLEAN, + hybrid BOOLEAN, + proper BOOLEAN, + repack BOOLEAN, + website TEXT, + artists TEXT [] DEFAULT '{}' NOT NULL, + type TEXT, + format TEXT, + quality TEXT, + log_score INTEGER, + has_log BOOLEAN, + has_cue BOOLEAN, + is_scene BOOLEAN, + origin TEXT, + tags TEXT [] DEFAULT '{}' NOT NULL, + freeleech BOOLEAN, + freeleech_percent INTEGER, + uploader TEXT, + pre_time TEXT +); + +CREATE TABLE release_action_status +( + id SERIAL PRIMARY KEY, + status TEXT, + action TEXT NOT NULL, + type TEXT NOT NULL, + rejections TEXT [] DEFAULT '{}' NOT NULL, + timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + raw TEXT, + log TEXT, + release_id INTEGER NOT NULL, + FOREIGN KEY (release_id) REFERENCES "release"(id) ON DELETE CASCADE +); +` + +var postgresMigrations = []string{ + "", } diff --git a/internal/database/postgres.go b/internal/database/postgres.go index 749e2df..72d8938 100644 --- a/internal/database/postgres.go +++ b/internal/database/postgres.go @@ -55,26 +55,26 @@ func (db *DB) migratePostgres() error { return err } - if version == len(migrations) { + if version == len(postgresMigrations) { return nil } - if version > len(migrations) { + if version > len(postgresMigrations) { return fmt.Errorf("old") } if version == 0 { - if _, err := tx.Exec(schema); err != nil { + if _, err := tx.Exec(postgresSchema); err != nil { return fmt.Errorf("failed to initialize schema: %v", err) } } else { - for i := version; i < len(migrations); i++ { - if _, err := tx.Exec(migrations[i]); err != nil { + for i := version; i < len(postgresMigrations); i++ { + if _, err := tx.Exec(postgresMigrations[i]); err != nil { return fmt.Errorf("failed to execute migration #%v: %v", i, err) } } } - _, err = tx.Exec(`INSERT INTO schema_migrations (id, version) VALUES (1, $1) ON CONFLICT (id) DO UPDATE SET version = $1`, len(migrations)) + _, err = tx.Exec(`INSERT INTO schema_migrations (id, version) VALUES (1, $1) ON CONFLICT (id) DO UPDATE SET version = $1`, len(postgresMigrations)) if err != nil { return fmt.Errorf("failed to bump schema version: %v", err) } diff --git a/internal/database/release.go b/internal/database/release.go index 7fd4286..9e1b673 100644 --- a/internal/database/release.go +++ b/internal/database/release.go @@ -4,10 +4,9 @@ import ( "context" "database/sql" sq "github.com/Masterminds/squirrel" + "github.com/autobrr/autobrr/internal/domain" "github.com/lib/pq" "github.com/rs/zerolog/log" - - "github.com/autobrr/autobrr/internal/domain" ) type ReleaseRepo struct { @@ -19,42 +18,43 @@ func NewReleaseRepo(db *DB) domain.ReleaseRepo { } func (repo *ReleaseRepo) Store(ctx context.Context, r *domain.Release) (*domain.Release, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - - query, args, err := sq. + queryBuilder := repo.db.squirrel. Insert("release"). Columns("filter_status", "rejections", "indexer", "filter", "protocol", "implementation", "timestamp", "group_id", "torrent_id", "torrent_name", "size", "raw", "title", "category", "season", "episode", "year", "resolution", "source", "codec", "container", "hdr", "audio", "release_group", "region", "language", "edition", "unrated", "hybrid", "proper", "repack", "website", "artists", "type", "format", "quality", "log_score", "has_log", "has_cue", "is_scene", "origin", "tags", "freeleech", "freeleech_percent", "uploader", "pre_time"). Values(r.FilterStatus, pq.Array(r.Rejections), r.Indexer, r.FilterName, r.Protocol, r.Implementation, r.Timestamp, r.GroupID, r.TorrentID, r.TorrentName, r.Size, r.Raw, r.Title, r.Category, r.Season, r.Episode, r.Year, r.Resolution, r.Source, r.Codec, r.Container, r.HDR, r.Audio, r.Group, r.Region, r.Language, r.Edition, r.Unrated, r.Hybrid, r.Proper, r.Repack, r.Website, pq.Array(r.Artists), r.Type, r.Format, r.Quality, r.LogScore, r.HasLog, r.HasCue, r.IsScene, r.Origin, pq.Array(r.Tags), r.Freeleech, r.FreeleechPercent, r.Uploader, r.PreTime). - ToSql() + Suffix("RETURNING id").RunWith(repo.db.handler) - res, err := repo.db.handler.ExecContext(ctx, query, args...) + // return values + var retID int64 + + err := queryBuilder.QueryRowContext(ctx).Scan(&retID) if err != nil { - log.Error().Stack().Err(err).Msg("error inserting release") + log.Error().Stack().Err(err).Msg("release.store: error executing query") return nil, err } - resId, _ := res.LastInsertId() - r.ID = resId + r.ID = retID - log.Trace().Msgf("release.store: %+v", r) + log.Debug().Msgf("release.store: %+v", r) return r, nil } func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain.ReleaseActionStatus) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - if a.ID != 0 { - query, args, err := sq. + queryBuilder := repo.db.squirrel. Update("release_action_status"). Set("status", a.Status). Set("rejections", pq.Array(a.Rejections)). Set("timestamp", a.Timestamp). Where("id = ?", a.ID). - Where("release_id = ?", a.ReleaseID). - ToSql() + Where("release_id = ?", a.ReleaseID) + + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("release.store: error building query") + return err + } _, err = repo.db.handler.ExecContext(ctx, query, args...) if err != nil { @@ -63,20 +63,22 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain } } else { - query, args, err := sq. + queryBuilder := repo.db.squirrel. Insert("release_action_status"). Columns("status", "action", "type", "rejections", "timestamp", "release_id"). Values(a.Status, a.Action, a.Type, pq.Array(a.Rejections), a.Timestamp, a.ReleaseID). - ToSql() + Suffix("RETURNING id").RunWith(repo.db.handler) - res, err := repo.db.handler.ExecContext(ctx, query, args...) + // return values + var retID int64 + + err := queryBuilder.QueryRowContext(ctx).Scan(&retID) if err != nil { - log.Error().Stack().Err(err).Msg("error inserting status of release") + log.Error().Stack().Err(err).Msg("release.storeReleaseActionStatus: error executing query") return err } - resId, _ := res.LastInsertId() - a.ID = resId + a.ID = retID } log.Trace().Msgf("release.store_release_action_status: %+v", a) @@ -84,12 +86,32 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain return nil } -func (repo *ReleaseRepo) Find(ctx context.Context, params domain.ReleaseQueryParams) ([]domain.Release, int64, int64, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() +func (repo *ReleaseRepo) Find(ctx context.Context, params domain.ReleaseQueryParams) ([]*domain.Release, int64, int64, error) { + tx, err := repo.db.BeginTx(ctx, &sql.TxOptions{}) + if err != nil { + return nil, 0, 0, err + } + defer tx.Rollback() - queryBuilder := sq. - Select("r.id", "r.filter_status", "r.rejections", "r.indexer", "r.filter", "r.protocol", "r.title", "r.torrent_name", "r.size", "r.timestamp", "COUNT() OVER() AS total_count"). + releases, nextCursor, total, err := repo.findReleases(ctx, tx, params) + if err != nil { + return nil, nextCursor, total, err + } + + for _, release := range releases { + statuses, err := repo.attachActionStatus(ctx, tx, release.ID) + if err != nil { + return releases, nextCursor, total, err + } + release.ActionStatus = statuses + } + + return releases, nextCursor, total, nil +} + +func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain.ReleaseQueryParams) ([]*domain.Release, int64, int64, error) { + queryBuilder := repo.db.squirrel. + Select("r.id", "r.filter_status", "r.rejections", "r.indexer", "r.filter", "r.protocol", "r.title", "r.torrent_name", "r.size", "r.timestamp", "COUNT(*) OVER() AS total_count"). From("release r"). OrderBy("r.timestamp DESC") @@ -123,9 +145,9 @@ func (repo *ReleaseRepo) Find(ctx context.Context, params domain.ReleaseQueryPar query, args, err := queryBuilder.ToSql() log.Trace().Str("database", "release.find").Msgf("query: '%v', args: '%v'", query, args) - res := make([]domain.Release, 0) + res := make([]*domain.Release, 0) - rows, err := repo.db.handler.QueryContext(ctx, query, args...) + rows, err := tx.QueryContext(ctx, query, args...) if err != nil { log.Error().Stack().Err(err).Msg("error fetching releases") return res, 0, 0, nil @@ -153,36 +175,21 @@ func (repo *ReleaseRepo) Find(ctx context.Context, params domain.ReleaseQueryPar rls.Indexer = indexer.String rls.FilterName = filter.String - // get action status - actionStatus, err := repo.GetActionStatusByReleaseID(ctx, rls.ID) - if err != nil { - log.Error().Stack().Err(err).Msg("release.find: error getting action status") - return res, 0, 0, err - } - - rls.ActionStatus = actionStatus - - res = append(res, rls) + res = append(res, &rls) } nextCursor := int64(0) if len(res) > 0 { lastID := res[len(res)-1].ID nextCursor = lastID - //nextCursor, _ = strconv.ParseInt(lastID, 10, 64) } return res, nextCursor, countItems, nil } func (repo *ReleaseRepo) GetIndexerOptions(ctx context.Context) ([]string, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - query := ` - SELECT DISTINCT indexer FROM "release" - UNION - SELECT DISTINCT identifier indexer FROM indexer;` + query := `SELECT DISTINCT indexer FROM "release" UNION SELECT DISTINCT identifier indexer FROM indexer;` log.Trace().Str("database", "release.get_indexers").Msgf("query: '%v'", query) @@ -216,10 +223,8 @@ func (repo *ReleaseRepo) GetIndexerOptions(ctx context.Context) ([]string, error } func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, releaseID int64) ([]domain.ReleaseActionStatus, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - queryBuilder := sq. + queryBuilder := repo.db.squirrel. Select("id", "status", "action", "type", "rejections", "timestamp"). From("release_action_status"). Where("release_id = ?", releaseID) @@ -255,16 +260,52 @@ func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, release return res, nil } +func (repo *ReleaseRepo) attachActionStatus(ctx context.Context, tx *Tx, releaseID int64) ([]domain.ReleaseActionStatus, error) { + + queryBuilder := repo.db.squirrel. + Select("id", "status", "action", "type", "rejections", "timestamp"). + From("release_action_status"). + Where("release_id = ?", releaseID) + + query, args, err := queryBuilder.ToSql() + + res := make([]domain.ReleaseActionStatus, 0) + + rows, err := tx.QueryContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("error fetching releases") + return res, nil + } + + defer rows.Close() + + if err := rows.Err(); err != nil { + log.Error().Stack().Err(err) + return res, err + } + + for rows.Next() { + var rls domain.ReleaseActionStatus + + if err := rows.Scan(&rls.ID, &rls.Status, &rls.Action, &rls.Type, pq.Array(&rls.Rejections), &rls.Timestamp); err != nil { + log.Error().Stack().Err(err).Msg("release.find: error scanning data to struct") + return res, err + } + + res = append(res, rls) + } + + return res, nil +} + func (repo *ReleaseRepo) Stats(ctx context.Context) (*domain.ReleaseStats, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() query := `SELECT COUNT(*) total, - IFNULL(SUM(CASE WHEN filter_status = 'FILTER_APPROVED' THEN 1 ELSE 0 END), 0) filtered_count, - IFNULL(SUM(CASE WHEN filter_status = 'FILTER_REJECTED' THEN 1 ELSE 0 END), 0) filter_rejected_count, - (SELECT IFNULL(SUM(CASE WHEN status = 'PUSH_APPROVED' THEN 1 ELSE 0 END), 0) + COALESCE(SUM(CASE WHEN filter_status = 'FILTER_APPROVED' THEN 1 ELSE 0 END), 0) AS filtered_count, + COALESCE(SUM(CASE WHEN filter_status = 'FILTER_REJECTED' THEN 1 ELSE 0 END), 0) AS filter_rejected_count, + (SELECT COALESCE(SUM(CASE WHEN status = 'PUSH_APPROVED' THEN 1 ELSE 0 END), 0) FROM "release_action_status") AS push_approved_count, - (SELECT IFNULL(SUM(CASE WHEN status = 'PUSH_REJECTED' THEN 1 ELSE 0 END), 0) + (SELECT COALESCE(SUM(CASE WHEN status = 'PUSH_REJECTED' THEN 1 ELSE 0 END), 0) FROM "release_action_status") AS push_rejected_count FROM "release";` diff --git a/internal/database/sqlite.go b/internal/database/sqlite.go index e4e546e..fc62eb8 100644 --- a/internal/database/sqlite.go +++ b/internal/database/sqlite.go @@ -58,10 +58,10 @@ func (db *DB) migrateSQLite() error { return fmt.Errorf("failed to query schema version: %v", err) } - if version == len(migrations) { + if version == len(sqliteMigrations) { return nil - } else if version > len(migrations) { - return fmt.Errorf("autobrr (version %d) older than schema (version: %d)", len(migrations), version) + } else if version > len(sqliteMigrations) { + return fmt.Errorf("autobrr (version %d) older than schema (version: %d)", len(sqliteMigrations), version) } tx, err := db.handler.Begin() @@ -71,12 +71,12 @@ func (db *DB) migrateSQLite() error { defer tx.Rollback() if version == 0 { - if _, err := tx.Exec(schema); err != nil { + if _, err := tx.Exec(sqliteSchema); err != nil { return fmt.Errorf("failed to initialize schema: %v", err) } } else { - for i := version; i < len(migrations); i++ { - if _, err := tx.Exec(migrations[i]); err != nil { + for i := version; i < len(sqliteMigrations); i++ { + if _, err := tx.Exec(sqliteMigrations[i]); err != nil { return fmt.Errorf("failed to execute migration #%v: %v", i, err) } } @@ -86,13 +86,13 @@ func (db *DB) migrateSQLite() error { // get data from filter.sources, check if specific types, move to new table and clear // if migration 6 // TODO 2022-01-30 remove this in future version - if version == 5 && len(migrations) == 6 { + if version == 5 && len(sqliteMigrations) == 6 { if err := customMigrateCopySourcesToMedia(tx); err != nil { return fmt.Errorf("could not run custom data migration: %v", err) } } - _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations))) + _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(sqliteMigrations))) if err != nil { return fmt.Errorf("failed to bump schema version: %v", err) } diff --git a/internal/database/user.go b/internal/database/user.go index 71e12b4..3bc0af8 100644 --- a/internal/database/user.go +++ b/internal/database/user.go @@ -16,12 +16,19 @@ func NewUserRepo(db *DB) domain.UserRepo { } func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain.User, error) { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() - query := `SELECT id, username, password FROM users WHERE username = ?` + queryBuilder := r.db.squirrel. + Select("id", "username", "password"). + From("users"). + Where("username = ?", username) - row := r.db.handler.QueryRowContext(ctx, query, username) + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("user.store: error building query") + return nil, err + } + + row := r.db.handler.QueryRowContext(ctx, query, args...) if err := row.Err(); err != nil { return nil, err } @@ -37,25 +44,49 @@ func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain } func (r *UserRepo) Store(ctx context.Context, user domain.User) error { - //r.db.lock.RLock() - //defer r.db.lock.RUnlock() var err error - if user.ID != 0 { - update := `UPDATE users SET password = ? WHERE username = ?` - _, err = r.db.handler.ExecContext(ctx, update, user.Password, user.Username) - if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") - return err - } - } else { - query := `INSERT INTO users (username, password) VALUES (?, ?)` - _, err = r.db.handler.ExecContext(ctx, query, user.Username, user.Password) - if err != nil { - log.Error().Stack().Err(err).Msg("error executing query") - return err - } + queryBuilder := r.db.squirrel. + Update("users"). + Set("username", user.Username). + Set("password", user.Password). + Where("username = ?", user.Username) + + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("user.store: error building query") + return err + } + + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("user.store: error executing query") + return err + } + + return err +} +func (r *UserRepo) Update(ctx context.Context, user domain.User) error { + + var err error + + queryBuilder := r.db.squirrel. + Update("users"). + Set("username", user.Username). + Set("password", user.Password). + Where("username = ?", user.Username) + + query, args, err := queryBuilder.ToSql() + if err != nil { + log.Error().Stack().Err(err).Msg("user.store: error building query") + return err + } + + _, err = r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + log.Error().Stack().Err(err).Msg("user.store: error executing query") + return err } return err diff --git a/internal/domain/action.go b/internal/domain/action.go index eafd18f..8282a46 100644 --- a/internal/domain/action.go +++ b/internal/domain/action.go @@ -7,7 +7,7 @@ type ActionRepo interface { StoreFilterActions(ctx context.Context, actions []Action, filterID int64) ([]Action, error) DeleteByFilterID(ctx context.Context, filterID int) error FindByFilterID(ctx context.Context, filterID int) ([]Action, error) - List() ([]Action, error) + List(ctx context.Context) ([]Action, error) Delete(actionID int) error ToggleEnabled(actionID int) error } diff --git a/internal/domain/client.go b/internal/domain/client.go index 04aafcb..ad224f4 100644 --- a/internal/domain/client.go +++ b/internal/domain/client.go @@ -7,6 +7,7 @@ type DownloadClientRepo interface { List(ctx context.Context) ([]DownloadClient, error) FindByID(ctx context.Context, id int32) (*DownloadClient, error) Store(ctx context.Context, client DownloadClient) (*DownloadClient, error) + Update(ctx context.Context, client DownloadClient) (*DownloadClient, error) Delete(ctx context.Context, clientID int) error } diff --git a/internal/domain/indexer.go b/internal/domain/indexer.go index 81f30b0..5748b1f 100644 --- a/internal/domain/indexer.go +++ b/internal/domain/indexer.go @@ -7,9 +7,9 @@ import ( ) type IndexerRepo interface { - Store(indexer Indexer) (*Indexer, error) - Update(indexer Indexer) (*Indexer, error) - List() ([]Indexer, error) + Store(ctx context.Context, indexer Indexer) (*Indexer, error) + Update(ctx context.Context, indexer Indexer) (*Indexer, error) + List(ctx context.Context) ([]Indexer, error) Delete(ctx context.Context, id int) error FindByFilterID(ctx context.Context, id int) ([]Indexer, error) } diff --git a/internal/domain/irc.go b/internal/domain/irc.go index 3ee164e..8a15bd1 100644 --- a/internal/domain/irc.go +++ b/internal/domain/irc.go @@ -71,11 +71,12 @@ type IrcRepo interface { StoreNetwork(network *IrcNetwork) error UpdateNetwork(ctx context.Context, network *IrcNetwork) error StoreChannel(networkID int64, channel *IrcChannel) error + UpdateChannel(channel *IrcChannel) error StoreNetworkChannels(ctx context.Context, networkID int64, channels []IrcChannel) error CheckExistingNetwork(ctx context.Context, network *IrcNetwork) (*IrcNetwork, error) FindActiveNetworks(ctx context.Context) ([]IrcNetwork, error) ListNetworks(ctx context.Context) ([]IrcNetwork, error) ListChannels(networkID int64) ([]IrcChannel, error) - GetNetworkByID(id int64) (*IrcNetwork, error) + GetNetworkByID(ctx context.Context, id int64) (*IrcNetwork, error) DeleteNetwork(ctx context.Context, id int64) error } diff --git a/internal/domain/release.go b/internal/domain/release.go index 99ecfb8..a4552cd 100644 --- a/internal/domain/release.go +++ b/internal/domain/release.go @@ -28,7 +28,7 @@ import ( type ReleaseRepo interface { Store(ctx context.Context, release *Release) (*Release, error) - Find(ctx context.Context, params ReleaseQueryParams) (res []Release, nextCursor int64, count int64, err error) + Find(ctx context.Context, params ReleaseQueryParams) (res []*Release, nextCursor int64, count int64, err error) GetIndexerOptions(ctx context.Context) ([]string, error) GetActionStatusByReleaseID(ctx context.Context, releaseID int64) ([]ReleaseActionStatus, error) Stats(ctx context.Context) (*ReleaseStats, error) diff --git a/internal/domain/user.go b/internal/domain/user.go index 016b376..2804e6d 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -5,6 +5,7 @@ import "context" type UserRepo interface { FindByUsername(ctx context.Context, username string) (*User, error) Store(ctx context.Context, user User) error + Update(ctx context.Context, user User) error } type User struct { diff --git a/internal/download_client/service.go b/internal/download_client/service.go index a5d2b02..52e9073 100644 --- a/internal/download_client/service.go +++ b/internal/download_client/service.go @@ -11,6 +11,7 @@ type Service interface { List(ctx context.Context) ([]domain.DownloadClient, error) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) + Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) Delete(ctx context.Context, clientID int) error Test(client domain.DownloadClient) error } @@ -43,6 +44,18 @@ func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*dom return s.repo.Store(ctx, client) } +func (s *service) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { + // validate data + if client.Host == "" { + return nil, errors.New("validation error: no host") + } else if client.Type == "" { + return nil, errors.New("validation error: no type") + } + + // store + return s.repo.Update(ctx, client) +} + func (s *service) Delete(ctx context.Context, clientID int) error { return s.repo.Delete(ctx, clientID) } diff --git a/internal/filter/service.go b/internal/filter/service.go index b4bc04b..e3b7ed5 100644 --- a/internal/filter/service.go +++ b/internal/filter/service.go @@ -47,12 +47,12 @@ func (s *service) ListFilters(ctx context.Context) ([]domain.Filter, error) { return nil, err } - var ret []domain.Filter + ret := make([]domain.Filter, 0) for _, filter := range filters { indexers, err := s.indexerSvc.FindByFilterID(ctx, filter.ID) if err != nil { - return nil, err + return ret, err } filter.Indexers = indexers diff --git a/internal/http/action.go b/internal/http/action.go index 1609c1a..b5e7048 100644 --- a/internal/http/action.go +++ b/internal/http/action.go @@ -12,7 +12,7 @@ import ( ) type actionService interface { - Fetch() ([]domain.Action, error) + List(ctx context.Context) ([]domain.Action, error) Store(ctx context.Context, action domain.Action) (*domain.Action, error) Delete(actionID int) error ToggleEnabled(actionID int) error @@ -39,7 +39,7 @@ func (h actionHandler) Routes(r chi.Router) { } func (h actionHandler) getActions(w http.ResponseWriter, r *http.Request) { - actions, err := h.service.Fetch() + actions, err := h.service.List(r.Context()) if err != nil { // encode error } diff --git a/internal/http/download_client.go b/internal/http/download_client.go index 1d90d9e..5f2d7f8 100644 --- a/internal/http/download_client.go +++ b/internal/http/download_client.go @@ -15,6 +15,7 @@ import ( type downloadClientService interface { List(ctx context.Context) ([]domain.DownloadClient, error) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) + Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) Delete(ctx context.Context, clientID int) error Test(client domain.DownloadClient) error } @@ -93,7 +94,7 @@ func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) { return } - client, err := h.service.Store(r.Context(), data) + client, err := h.service.Update(r.Context(), data) if err != nil { h.encoder.Error(w, err) return diff --git a/internal/http/indexer.go b/internal/http/indexer.go index 7e736ee..7ac5be9 100644 --- a/internal/http/indexer.go +++ b/internal/http/indexer.go @@ -12,9 +12,9 @@ import ( ) type indexerService interface { - Store(indexer domain.Indexer) (*domain.Indexer, error) - Update(indexer domain.Indexer) (*domain.Indexer, error) - List() ([]domain.Indexer, error) + Store(ctx context.Context, indexer domain.Indexer) (*domain.Indexer, error) + Update(ctx context.Context, indexer domain.Indexer) (*domain.Indexer, error) + List(ctx context.Context) ([]domain.Indexer, error) GetAll() ([]*domain.IndexerDefinition, error) GetTemplates() ([]domain.IndexerDefinition, error) Delete(ctx context.Context, id int) error @@ -55,20 +55,23 @@ func (h indexerHandler) getSchema(w http.ResponseWriter, r *http.Request) { } func (h indexerHandler) store(w http.ResponseWriter, r *http.Request) { - var data domain.Indexer + var ( + ctx = r.Context() + data domain.Indexer + ) if err := json.NewDecoder(r.Body).Decode(&data); err != nil { return } - indexer, err := h.service.Store(data) + indexer, err := h.service.Store(ctx, data) if err != nil { // - h.encoder.StatusResponse(r.Context(), w, nil, http.StatusBadRequest) + h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest) return } - h.encoder.StatusResponse(r.Context(), w, indexer, http.StatusCreated) + h.encoder.StatusResponse(ctx, w, indexer, http.StatusCreated) } func (h indexerHandler) update(w http.ResponseWriter, r *http.Request) { @@ -81,7 +84,7 @@ func (h indexerHandler) update(w http.ResponseWriter, r *http.Request) { return } - indexer, err := h.service.Update(data) + indexer, err := h.service.Update(ctx, data) if err != nil { // } @@ -118,7 +121,7 @@ func (h indexerHandler) getAll(w http.ResponseWriter, r *http.Request) { func (h indexerHandler) list(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - indexers, err := h.service.List() + indexers, err := h.service.List(ctx) if err != nil { // } diff --git a/internal/http/irc.go b/internal/http/irc.go index f657d58..de8d1cf 100644 --- a/internal/http/irc.go +++ b/internal/http/irc.go @@ -15,7 +15,7 @@ type ircService interface { ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) GetNetworksWithHealth(ctx context.Context) ([]domain.IrcNetworkWithHealth, error) DeleteNetwork(ctx context.Context, id int64) error - GetNetworkByID(id int64) (*domain.IrcNetwork, error) + GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetwork, error) StoreNetwork(ctx context.Context, network *domain.IrcNetwork) error UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error StoreChannel(networkID int64, channel *domain.IrcChannel) error @@ -61,7 +61,7 @@ func (h ircHandler) getNetworkByID(w http.ResponseWriter, r *http.Request) { id, _ := strconv.Atoi(networkID) - network, err := h.service.GetNetworkByID(int64(id)) + network, err := h.service.GetNetworkByID(ctx, int64(id)) if err != nil { h.encoder.Error(w, err) } diff --git a/internal/http/release.go b/internal/http/release.go index 0ea8033..9e3b513 100644 --- a/internal/http/release.go +++ b/internal/http/release.go @@ -11,7 +11,7 @@ import ( ) type releaseService interface { - Find(ctx context.Context, query domain.ReleaseQueryParams) (res []domain.Release, nextCursor int64, count int64, err error) + Find(ctx context.Context, query domain.ReleaseQueryParams) (res []*domain.Release, nextCursor int64, count int64, err error) GetIndexerOptions(ctx context.Context) ([]string, error) Stats(ctx context.Context) (*domain.ReleaseStats, error) Delete(ctx context.Context) error @@ -105,9 +105,9 @@ func (h releaseHandler) findReleases(w http.ResponseWriter, r *http.Request) { } ret := struct { - Data []domain.Release `json:"data"` - NextCursor int64 `json:"next_cursor"` - Count int64 `json:"count"` + Data []*domain.Release `json:"data"` + NextCursor int64 `json:"next_cursor"` + Count int64 `json:"count"` }{ Data: releases, NextCursor: nextCursor, diff --git a/internal/indexer/service.go b/internal/indexer/service.go index f69bbc6..e0da0b3 100644 --- a/internal/indexer/service.go +++ b/internal/indexer/service.go @@ -15,11 +15,11 @@ import ( ) type Service interface { - Store(indexer domain.Indexer) (*domain.Indexer, error) - Update(indexer domain.Indexer) (*domain.Indexer, error) + Store(ctx context.Context, indexer domain.Indexer) (*domain.Indexer, error) + Update(ctx context.Context, indexer domain.Indexer) (*domain.Indexer, error) Delete(ctx context.Context, id int) error FindByFilterID(ctx context.Context, id int) ([]domain.Indexer, error) - List() ([]domain.Indexer, error) + List(ctx context.Context) ([]domain.Indexer, error) GetAll() ([]*domain.IndexerDefinition, error) GetTemplates() ([]domain.IndexerDefinition, error) LoadIndexerDefinitions() error @@ -52,8 +52,8 @@ func NewService(config domain.Config, repo domain.IndexerRepo, apiService APISer } } -func (s *service) Store(indexer domain.Indexer) (*domain.Indexer, error) { - i, err := s.repo.Store(indexer) +func (s *service) Store(ctx context.Context, indexer domain.Indexer) (*domain.Indexer, error) { + i, err := s.repo.Store(ctx, indexer) if err != nil { log.Error().Stack().Err(err).Msgf("failed to store indexer: %v", indexer.Name) return nil, err @@ -69,8 +69,8 @@ func (s *service) Store(indexer domain.Indexer) (*domain.Indexer, error) { return i, nil } -func (s *service) Update(indexer domain.Indexer) (*domain.Indexer, error) { - i, err := s.repo.Update(indexer) +func (s *service) Update(ctx context.Context, indexer domain.Indexer) (*domain.Indexer, error) { + i, err := s.repo.Update(ctx, indexer) if err != nil { return nil, err } @@ -97,25 +97,15 @@ func (s *service) Delete(ctx context.Context, id int) error { } func (s *service) FindByFilterID(ctx context.Context, id int) ([]domain.Indexer, error) { - filters, err := s.repo.FindByFilterID(ctx, id) - if err != nil { - return nil, err - } - - return filters, nil + return s.repo.FindByFilterID(ctx, id) } -func (s *service) List() ([]domain.Indexer, error) { - i, err := s.repo.List() - if err != nil { - return nil, err - } - - return i, nil +func (s *service) List(ctx context.Context) ([]domain.Indexer, error) { + return s.repo.List(ctx) } func (s *service) GetAll() ([]*domain.IndexerDefinition, error) { - indexers, err := s.repo.List() + indexers, err := s.repo.List(context.Background()) if err != nil { return nil, err } diff --git a/internal/irc/service.go b/internal/irc/service.go index 5b1ea05..14823e7 100644 --- a/internal/irc/service.go +++ b/internal/irc/service.go @@ -3,6 +3,7 @@ package irc import ( "context" "fmt" + "github.com/pkg/errors" "strings" "sync" @@ -20,7 +21,7 @@ type Service interface { StopNetwork(key handlerKey) error ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) GetNetworksWithHealth(ctx context.Context) ([]domain.IrcNetworkWithHealth, error) - GetNetworkByID(id int64) (*domain.IrcNetwork, error) + GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetwork, error) DeleteNetwork(ctx context.Context, id int64) error StoreNetwork(ctx context.Context, network *domain.IrcNetwork) error UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error @@ -335,8 +336,8 @@ func (s *service) StopNetworkIfRunning(key handlerKey) error { return nil } -func (s *service) GetNetworkByID(id int64) (*domain.IrcNetwork, error) { - network, err := s.repo.GetNetworkByID(id) +func (s *service) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetwork, error) { + network, err := s.repo.GetNetworkByID(ctx, id) if err != nil { log.Error().Err(err).Msgf("failed to get network: %v", id) return nil, err @@ -454,7 +455,7 @@ func (s *service) GetNetworksWithHealth(ctx context.Context) ([]domain.IrcNetwor } func (s *service) DeleteNetwork(ctx context.Context, id int64) error { - network, err := s.GetNetworkByID(id) + network, err := s.GetNetworkByID(ctx, id) if err != nil { return err } @@ -527,7 +528,9 @@ func (s *service) StoreNetwork(ctx context.Context, network *domain.IrcNetwork) if network.Channels != nil { for _, channel := range network.Channels { if err := s.repo.StoreChannel(network.ID, &channel); err != nil { - return err + log.Error().Stack().Err(err).Msg("irc.storeChannel: error executing query") + return errors.Wrap(err, "error storing channel on network") + //return err } } } diff --git a/internal/release/service.go b/internal/release/service.go index 4aa516f..9cd8818 100644 --- a/internal/release/service.go +++ b/internal/release/service.go @@ -10,7 +10,7 @@ import ( ) type Service interface { - Find(ctx context.Context, query domain.ReleaseQueryParams) (res []domain.Release, nextCursor int64, count int64, err error) + Find(ctx context.Context, query domain.ReleaseQueryParams) (res []*domain.Release, nextCursor int64, count int64, err error) GetIndexerOptions(ctx context.Context) ([]string, error) Stats(ctx context.Context) (*domain.ReleaseStats, error) Store(ctx context.Context, release *domain.Release) error @@ -31,13 +31,8 @@ func NewService(repo domain.ReleaseRepo, actionService action.Service) Service { } } -func (s *service) Find(ctx context.Context, query domain.ReleaseQueryParams) (res []domain.Release, nextCursor int64, count int64, err error) { - res, nextCursor, count, err = s.repo.Find(ctx, query) - if err != nil { - return - } - - return +func (s *service) Find(ctx context.Context, query domain.ReleaseQueryParams) (res []*domain.Release, nextCursor int64, count int64, err error) { + return s.repo.Find(ctx, query) } func (s *service) GetIndexerOptions(ctx context.Context) ([]string, error) {