From 19b3899a5cd9e5fb0df54a4fbc98436037cccae2 Mon Sep 17 00:00:00 2001 From: Kyle Sanderson Date: Fri, 30 Dec 2022 14:53:45 -0800 Subject: [PATCH] refactor(database): clean-up queries (#569) * fix(database): build WHERE using squirrel * flip LIKEs * change sql.LevelReadCommitted * feat(db): add missing commits * implement single query for releases * cleanup * feat(releases): properly handle limit for Find * feat(releases): make dynamic ILike helper * feat(releases): check for empty ReleaseActionStatus * add values as sql.NullX * check if ID is non 0 * feat(releases): improve find --- internal/database/action.go | 20 +-- internal/database/api.go | 3 +- internal/database/database.go | 18 +++ internal/database/download_client.go | 7 +- internal/database/feed.go | 14 +- internal/database/feed_cache.go | 21 +-- internal/database/filter.go | 37 ++--- internal/database/indexer.go | 9 +- internal/database/irc.go | 32 ++--- internal/database/notification.go | 6 +- internal/database/release.go | 202 ++++++++++++++------------- internal/database/user.go | 5 +- 12 files changed, 206 insertions(+), 168 deletions(-) diff --git a/internal/database/action.go b/internal/database/action.go index b1821a9..1814e58 100644 --- a/internal/database/action.go +++ b/internal/database/action.go @@ -29,7 +29,7 @@ func NewActionRepo(log logger.Logger, db *DB, clientRepo domain.DownloadClientRe func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int) ([]*domain.Action, error) { - tx, err := r.db.BeginTx(ctx, nil) + tx, err := r.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) if err != nil { return nil, err } @@ -51,6 +51,10 @@ func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int) ([]*domai } } + if err = tx.Commit(); err != nil { + return nil, errors.Wrap(err, "error finding filter by id") + } + return actions, nil } @@ -87,7 +91,7 @@ func (r *ActionRepo) findByFilterID(ctx context.Context, tx *Tx, filterID int) ( "client_id", ). From("action"). - Where("filter_id = ?", filterID) + Where(sq.Eq{"filter_id": filterID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -165,7 +169,7 @@ func (r *ActionRepo) attachDownloadClient(ctx context.Context, tx *Tx, clientID "settings", ). From("client"). - Where("id = ?", clientID) + Where(sq.Eq{"id": clientID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -282,7 +286,7 @@ func (r *ActionRepo) List(ctx context.Context) ([]domain.Action, error) { func (r *ActionRepo) Delete(actionID int) error { queryBuilder := r.db.squirrel. Delete("action"). - Where("id = ?", actionID) + Where(sq.Eq{"id": actionID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -302,7 +306,7 @@ func (r *ActionRepo) Delete(actionID int) error { func (r *ActionRepo) DeleteByFilterID(ctx context.Context, filterID int) error { queryBuilder := r.db.squirrel. Delete("action"). - Where("filter_id = ?", filterID) + Where(sq.Eq{"filter_id": filterID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -472,7 +476,7 @@ func (r *ActionRepo) Update(ctx context.Context, action domain.Action) (*domain. Set("webhook_data", webhookData). Set("client_id", clientID). Set("filter_id", filterID). - Where("id = ?", action.ID) + Where(sq.Eq{"id": action.ID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -499,7 +503,7 @@ func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []*domain.A deleteQueryBuilder := r.db.squirrel. Delete("action"). - Where("filter_id = ?", filterID) + Where(sq.Eq{"filter_id": filterID}) deleteQuery, deleteArgs, err := deleteQueryBuilder.ToSql() if err != nil { @@ -622,7 +626,7 @@ func (r *ActionRepo) ToggleEnabled(actionID int) error { queryBuilder := r.db.squirrel. Update("action"). Set("enabled", sq.Expr("NOT enabled")). - Where("id = ?", actionID) + Where(sq.Eq{"id": actionID}) query, args, err := queryBuilder.ToSql() if err != nil { diff --git a/internal/database/api.go b/internal/database/api.go index 4eadc19..d0e73d1 100644 --- a/internal/database/api.go +++ b/internal/database/api.go @@ -9,6 +9,7 @@ import ( "github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/pkg/errors" + sq "github.com/Masterminds/squirrel" "github.com/lib/pq" "github.com/rs/zerolog" ) @@ -55,7 +56,7 @@ func (r *APIRepo) Store(ctx context.Context, key *domain.APIKey) error { func (r *APIRepo) Delete(ctx context.Context, key string) error { queryBuilder := r.db.squirrel. Delete("api_key"). - Where("key = ?", key) + Where(sq.Eq{"key": key}) query, args, err := queryBuilder.ToSql() if err != nil { diff --git a/internal/database/database.go b/internal/database/database.go index 9a3230f..c9f2fbb 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -14,6 +14,8 @@ import ( "github.com/rs/zerolog" ) +var databaseDriver = "sqlite" + type DB struct { log zerolog.Logger handler *sql.DB @@ -37,6 +39,7 @@ func NewDB(cfg *domain.Config, log logger.Logger) (*DB, error) { switch cfg.DatabaseType { case "sqlite": + databaseDriver = "sqlite" db.Driver = "sqlite" db.DSN = dataSourceName(cfg.ConfigPath, "autobrr.db") case "postgres": @@ -45,6 +48,7 @@ func NewDB(cfg *domain.Config, log logger.Logger) (*DB, error) { } db.DSN = fmt.Sprintf("postgres://%v:%v@%v:%d/%v?sslmode=disable", cfg.PostgresUser, cfg.PostgresPass, cfg.PostgresHost, cfg.PostgresPort, cfg.PostgresDatabase) db.Driver = "postgres" + databaseDriver = "postgres" default: return nil, errors.New("unsupported database: %v", cfg.DatabaseType) } @@ -106,3 +110,17 @@ type Tx struct { *sql.Tx handler *DB } + +type ILikeDynamic interface { + ToSql() (sql string, args []interface{}, err error) +} + +// ILike is a wrapper for sq.Like and sq.ILike +// SQLite does not support ILike but postgres does so this checks what database is being used +func ILike(col string, val string) ILikeDynamic { + if databaseDriver == "sqlite" { + return sq.Like{col: val} + } + + return sq.ILike{col: val} +} diff --git a/internal/database/download_client.go b/internal/database/download_client.go index 5946591..fbf24ab 100644 --- a/internal/database/download_client.go +++ b/internal/database/download_client.go @@ -10,6 +10,7 @@ import ( "github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/pkg/errors" + sq "github.com/Masterminds/squirrel" "github.com/rs/zerolog" ) @@ -136,7 +137,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do "settings", ). From("client"). - Where("id = ?", id) + Where(sq.Eq{"id": id}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -232,7 +233,7 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC Set("username", client.Username). Set("password", client.Password). Set("settings", string(settingsJson)). - Where("id = ?", client.ID) + Where(sq.Eq{"id": client.ID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -255,7 +256,7 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error { queryBuilder := r.db.squirrel. Delete("client"). - Where("id = ?", clientID) + Where(sq.Eq{"id": clientID}) query, args, err := queryBuilder.ToSql() if err != nil { diff --git a/internal/database/feed.go b/internal/database/feed.go index b0b6789..3246d41 100644 --- a/internal/database/feed.go +++ b/internal/database/feed.go @@ -43,7 +43,7 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) { ). From("feed f"). Join("indexer i ON f.indexer_id = i.id"). - Where("f.id = ?", id) + Where(sq.Eq{"f.id": id}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -89,7 +89,7 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string) ). From("feed f"). Join("indexer i ON f.indexer_id = i.id"). - Where("i.name = ?", indexer) + Where(sq.Eq{"i.name": indexer}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -221,7 +221,7 @@ func (r *FeedRepo) Update(ctx context.Context, feed *domain.Feed) error { Set("api_key", feed.ApiKey). Set("cookie", feed.Cookie). Set("updated_at", sq.Expr("CURRENT_TIMESTAMP")). - Where("id = ?", feed.ID) + Where(sq.Eq{"id": feed.ID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -240,7 +240,7 @@ func (r *FeedRepo) UpdateLastRun(ctx context.Context, feedID int) error { queryBuilder := r.db.squirrel. Update("feed"). Set("last_run", sq.Expr("CURRENT_TIMESTAMP")). - Where("id = ?", feedID) + Where(sq.Eq{"id": feedID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -260,7 +260,7 @@ func (r *FeedRepo) UpdateLastRunWithData(ctx context.Context, feedID int, data s Update("feed"). Set("last_run", sq.Expr("CURRENT_TIMESTAMP")). Set("last_run_data", data). - Where("id = ?", feedID) + Where(sq.Eq{"id": feedID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -282,7 +282,7 @@ func (r *FeedRepo) ToggleEnabled(ctx context.Context, id int, enabled bool) erro Update("feed"). Set("enabled", enabled). Set("updated_at", sq.Expr("CURRENT_TIMESTAMP")). - Where("id = ?", id) + Where(sq.Eq{"id": id}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -299,7 +299,7 @@ func (r *FeedRepo) ToggleEnabled(ctx context.Context, id int, enabled bool) erro func (r *FeedRepo) Delete(ctx context.Context, id int) error { queryBuilder := r.db.squirrel. Delete("feed"). - Where("id = ?", id) + Where(sq.Eq{"id": id}) query, args, err := queryBuilder.ToSql() if err != nil { diff --git a/internal/database/feed_cache.go b/internal/database/feed_cache.go index dae3105..a208133 100644 --- a/internal/database/feed_cache.go +++ b/internal/database/feed_cache.go @@ -8,6 +8,7 @@ import ( "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/pkg/errors" + sq "github.com/Masterminds/squirrel" "github.com/rs/zerolog" ) @@ -31,9 +32,9 @@ func (r *FeedCacheRepo) Get(bucket string, key string) ([]byte, error) { "ttl", ). From("feed_cache"). - Where("bucket = ?", bucket). - Where("key = ?", key). - Where("ttl > ?", time.Now()) + Where(sq.Eq{"bucket": bucket}). + Where(sq.Eq{"key": key}). + Where(sq.Gt{"ttl": time.Now()}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -64,7 +65,7 @@ func (r *FeedCacheRepo) GetByBucket(ctx context.Context, bucket string) ([]domai "ttl", ). From("feed_cache"). - Where("bucket = ?", bucket) + Where(sq.Eq{"bucket": bucket}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -102,7 +103,7 @@ func (r *FeedCacheRepo) GetCountByBucket(ctx context.Context, bucket string) (in queryBuilder := r.db.squirrel. Select("COUNT(*)"). From("feed_cache"). - Where("bucket = ?", bucket) + Where(sq.Eq{"bucket": bucket}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -128,8 +129,8 @@ func (r *FeedCacheRepo) Exists(bucket string, key string) (bool, error) { Select("1"). Prefix("SELECT EXISTS ("). From("feed_cache"). - Where("bucket = ?", bucket). - Where("key = ?", key). + Where(sq.Eq{"bucket": bucket}). + Where(sq.Eq{"key": key}). Suffix(")") query, args, err := queryBuilder.ToSql() @@ -167,8 +168,8 @@ func (r *FeedCacheRepo) Put(bucket string, key string, val []byte, ttl time.Time func (r *FeedCacheRepo) Delete(ctx context.Context, bucket string, key string) error { queryBuilder := r.db.squirrel. Delete("feed_cache"). - Where("bucket = ?", bucket). - Where("key = ?", key) + Where(sq.Eq{"bucket": bucket}). + Where(sq.Eq{"key": key}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -186,7 +187,7 @@ func (r *FeedCacheRepo) Delete(ctx context.Context, bucket string, key string) e func (r *FeedCacheRepo) DeleteBucket(ctx context.Context, bucket string) error { queryBuilder := r.db.squirrel. Delete("feed_cache"). - Where("bucket = ?", bucket) + Where(sq.Eq{"bucket": bucket}) query, args, err := queryBuilder.ToSql() if err != nil { diff --git a/internal/database/filter.go b/internal/database/filter.go index d4346b4..7188ded 100644 --- a/internal/database/filter.go +++ b/internal/database/filter.go @@ -29,7 +29,7 @@ func NewFilterRepo(log logger.Logger, db *DB) domain.FilterRepo { } func (r *FilterRepo) Find(ctx context.Context, params domain.FilterQueryParams) ([]domain.Filter, error) { - tx, err := r.db.BeginTx(ctx, &sql.TxOptions{}) + tx, err := r.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) if err != nil { return nil, errors.Wrap(err, "error begin transaction") } @@ -40,7 +40,7 @@ func (r *FilterRepo) Find(ctx context.Context, params domain.FilterQueryParams) return nil, err } - if err = tx.Commit(); err != nil { + if err := tx.Commit(); err != nil { return nil, errors.Wrap(err, "error commit transaction find releases") } @@ -52,7 +52,7 @@ func (r *FilterRepo) find(ctx context.Context, tx *Tx, params domain.FilterQuery actionCountQuery := r.db.squirrel. Select("COUNT(*)"). From("action a"). - Where("a.filter_id = f.id") + Where(sq.Eq{"a.filter_id": "f.id"}) queryBuilder := r.db.squirrel. Select( @@ -70,7 +70,7 @@ func (r *FilterRepo) find(ctx context.Context, tx *Tx, params domain.FilterQuery From("filter f") if params.Search != "" { - queryBuilder = queryBuilder.Where("f.name LIKE ?", fmt.Sprint("%", params.Search, "%")) + queryBuilder = queryBuilder.Where(sq.Like{"f.name": params.Search + "%"}) } if len(params.Sort) > 0 { @@ -123,7 +123,7 @@ func (r *FilterRepo) ListFilters(ctx context.Context) ([]domain.Filter, error) { actionCountQuery := r.db.squirrel. Select("COUNT(*)"). From("action a"). - Where("a.filter_id = f.id") + Where(sq.Eq{"a.filter_id": "f.id"}) queryBuilder := r.db.squirrel. Select( @@ -233,7 +233,7 @@ func (r *FilterRepo) FindByID(ctx context.Context, filterID int) (*domain.Filter "updated_at", ). From("filter"). - Where("id = ?", filterID) + Where(sq.Eq{"id": filterID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -321,6 +321,10 @@ func (r *FilterRepo) FindByIndexerIdentifier(indexer string) ([]domain.Filter, e filters[i].Downloads = downloads } + if err := tx.Commit(); err != nil { + return nil, errors.Wrap(err, "error finding filter by identifier") + } + return filters, nil } @@ -392,9 +396,9 @@ func (r *FilterRepo) findByIndexerIdentifier(ctx context.Context, tx *Tx, indexe From("filter f"). Join("filter_indexer fi ON f.id = fi.filter_id"). Join("indexer i ON i.id = fi.indexer_id"). - Where("i.identifier = ?", indexer). - Where("i.enabled = ?", true). - Where("f.enabled = ?", true). + Where(sq.Eq{"i.identifier": indexer}). + Where(sq.Eq{"i.enabled": true}). + Where(sq.Eq{"f.enabled": true}). OrderBy("f.priority DESC") query, args, err := queryBuilder.ToSql() @@ -671,7 +675,7 @@ func (r *FilterRepo) Update(ctx context.Context, filter domain.Filter) (*domain. Set("external_webhook_data", filter.ExternalWebhookData). Set("external_webhook_expect_status", filter.ExternalWebhookExpectStatus). Set("updated_at", time.Now().Format(time.RFC3339)). - Where("id = ?", filter.ID) + Where(sq.Eq{"id": filter.ID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -866,7 +870,7 @@ func (r *FilterRepo) UpdatePartial(ctx context.Context, filter domain.FilterUpda q = q.Set("external_webhook_expect_status", filter.ExternalWebhookExpectStatus) } - q = q.Where("id = ?", filter.ID) + q = q.Where(sq.Eq{"id": filter.ID}) query, args, err := q.ToSql() if err != nil { @@ -897,7 +901,7 @@ func (r *FilterRepo) ToggleEnabled(ctx context.Context, filterID int, enabled bo Update("filter"). Set("enabled", enabled). Set("updated_at", sq.Expr("CURRENT_TIMESTAMP")). - Where("id = ?", filterID) + Where(sq.Eq{"id": filterID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -921,7 +925,7 @@ func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int, deleteQueryBuilder := r.db.squirrel. Delete("filter_indexer"). - Where("filter_id = ?", filterID) + Where(sq.Eq{"filter_id": filterID}) deleteQuery, deleteArgs, err := deleteQueryBuilder.ToSql() if err != nil { @@ -949,8 +953,7 @@ func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int, r.log.Debug().Msgf("filter.StoreIndexerConnections: store '%v' on filter: %v", indexer.Name, filterID) } - err = tx.Commit() - if err != nil { + if err := tx.Commit(); err != nil { return errors.Wrap(err, "error store indexers for filter: %v", filterID) } @@ -978,7 +981,7 @@ func (r *FilterRepo) StoreIndexerConnection(ctx context.Context, filterID int, i func (r *FilterRepo) DeleteIndexerConnections(ctx context.Context, filterID int) error { queryBuilder := r.db.squirrel. Delete("filter_indexer"). - Where("filter_id = ?", filterID) + Where(sq.Eq{"filter_id": filterID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -996,7 +999,7 @@ func (r *FilterRepo) DeleteIndexerConnections(ctx context.Context, filterID int) func (r *FilterRepo) Delete(ctx context.Context, filterID int) error { queryBuilder := r.db.squirrel. Delete("filter"). - Where("id = ?", filterID) + Where(sq.Eq{"id": filterID}) query, args, err := queryBuilder.ToSql() if err != nil { diff --git a/internal/database/indexer.go b/internal/database/indexer.go index b87fbc5..c67d041 100644 --- a/internal/database/indexer.go +++ b/internal/database/indexer.go @@ -9,6 +9,7 @@ import ( "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/pkg/errors" + sq "github.com/Masterminds/squirrel" "github.com/rs/zerolog" ) @@ -61,7 +62,7 @@ func (r *IndexerRepo) Update(ctx context.Context, indexer domain.Indexer) (*doma Set("base_url", indexer.BaseURL). Set("settings", settings). Set("updated_at", time.Now().Format(time.RFC3339)). - Where("id = ?", indexer.ID) + Where(sq.Eq{"id": indexer.ID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -117,7 +118,7 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er queryBuilder := r.db.squirrel. Select("id", "enabled", "name", "identifier", "implementation", "base_url", "settings"). From("indexer"). - Where("id = ?", id) + Where(sq.Eq{"id": id}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -156,7 +157,7 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde Select("id", "enabled", "name", "identifier", "base_url", "settings"). From("indexer"). Join("filter_indexer ON indexer.id = filter_indexer.indexer_id"). - Where("filter_indexer.filter_id = ?", id) + Where(sq.Eq{"filter_indexer.filter_id": id}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -202,7 +203,7 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde func (r *IndexerRepo) Delete(ctx context.Context, id int) error { queryBuilder := r.db.squirrel. Delete("indexer"). - Where("id = ?", id) + Where(sq.Eq{"id": id}) query, args, err := queryBuilder.ToSql() if err != nil { diff --git a/internal/database/irc.go b/internal/database/irc.go index fdc2f96..edb47f4 100644 --- a/internal/database/irc.go +++ b/internal/database/irc.go @@ -8,6 +8,7 @@ import ( "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/pkg/errors" + sq "github.com/Masterminds/squirrel" "github.com/rs/zerolog" ) @@ -28,7 +29,7 @@ func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetw queryBuilder := r.db.squirrel. Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command"). From("irc_network"). - Where("id = ?", id) + Where(sq.Eq{"id": id}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -67,7 +68,7 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error { queryBuilder := r.db.squirrel. Delete("irc_channel"). - Where("network_id = ?", id) + Where(sq.Eq{"network_id": id}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -81,7 +82,7 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error { netQueryBuilder := r.db.squirrel. Delete("irc_network"). - Where("id = ?", id) + Where(sq.Eq{"id": id}) netQuery, netArgs, err := netQueryBuilder.ToSql() if err != nil { @@ -93,10 +94,8 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error { return errors.Wrap(err, "error executing query") } - err = tx.Commit() - if err != nil { + if err := tx.Commit(); err != nil { return errors.Wrap(err, "error commit deleting network") - } return nil @@ -106,7 +105,7 @@ func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, queryBuilder := r.db.squirrel. Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command"). From("irc_network"). - Where("enabled = ?", true) + Where(sq.Eq{"enabled": true}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -200,7 +199,7 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) { queryBuilder := r.db.squirrel. Select("id", "name", "enabled", "password"). From("irc_channel"). - Where("network_id = ?", networkID) + Where(sq.Eq{"network_id": networkID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -237,8 +236,8 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN queryBuilder := r.db.squirrel. Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command"). From("irc_network"). - Where("server = ?", network.Server). - Where("auth_account = ?", network.Auth.Account) + Where(sq.Eq{"server": network.Server}). + Where(sq.Eq{"auth_account": network.Auth.Account}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -350,7 +349,7 @@ func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) Set("auth_password", password). Set("invite_command", inviteCmd). Set("updated_at", time.Now().Format(time.RFC3339)). - Where("id = ?", network.ID) + Where(sq.Eq{"id": network.ID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -378,7 +377,7 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha queryBuilder := r.db.squirrel. Delete("irc_channel"). - Where("network_id = ?", networkID) + Where(sq.Eq{"network_id": networkID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -438,8 +437,7 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha //channel.ID, err = res.LastInsertId() } - err = tx.Commit() - if err != nil { + if err := tx.Commit(); err != nil { return errors.Wrap(err, "error commit transaction store network") } @@ -458,7 +456,7 @@ func (r *IrcRepo) StoreChannel(networkID int64, channel *domain.IrcChannel) erro Set("detached", channel.Detached). Set("name", channel.Name). Set("pass", pass). - Where("id = ?", channel.ID) + Where(sq.Eq{"id": channel.ID}) query, args, err := channelQueryBuilder.ToSql() if err != nil { @@ -528,7 +526,7 @@ func (r *IrcRepo) UpdateChannel(channel *domain.IrcChannel) error { Set("detached", channel.Detached). Set("name", channel.Name). Set("pass", pass). - Where("id = ?", channel.ID) + Where(sq.Eq{"id": channel.ID}) query, args, err := channelQueryBuilder.ToSql() if err != nil { @@ -549,7 +547,7 @@ func (r *IrcRepo) UpdateInviteCommand(networkID int64, invite string) error { channelQueryBuilder := r.db.squirrel. Update("irc_network"). Set("invite_command", invite). - Where("id = ?", networkID) + Where(sq.Eq{"id": networkID}) query, args, err := channelQueryBuilder.ToSql() if err != nil { diff --git a/internal/database/notification.go b/internal/database/notification.go index 69f1074..80c9ebb 100644 --- a/internal/database/notification.go +++ b/internal/database/notification.go @@ -144,7 +144,7 @@ func (r *NotificationRepo) FindByID(ctx context.Context, id int) (*domain.Notifi "updated_at", ). From("notification"). - Where("id = ?", id) + Where(sq.Eq{"id": id}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -240,7 +240,7 @@ func (r *NotificationRepo) Update(ctx context.Context, notification domain.Notif Set("api_key", apiKey). Set("channel", channel). Set("updated_at", sq.Expr("CURRENT_TIMESTAMP")). - Where("id = ?", notification.ID) + Where(sq.Eq{"id": notification.ID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -260,7 +260,7 @@ func (r *NotificationRepo) Update(ctx context.Context, notification domain.Notif func (r *NotificationRepo) Delete(ctx context.Context, notificationID int) error { queryBuilder := r.db.squirrel. Delete("notification"). - Where("id = ?", notificationID) + Where(sq.Eq{"id": notificationID}) query, args, err := queryBuilder.ToSql() if err != nil { diff --git a/internal/database/release.go b/internal/database/release.go index 44dc096..a8719e4 100644 --- a/internal/database/release.go +++ b/internal/database/release.go @@ -61,8 +61,8 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain Set("status", a.Status). Set("rejections", pq.Array(a.Rejections)). Set("timestamp", a.Timestamp). - Where("id = ?", a.ID). - Where("release_id = ?", a.ReleaseID) + Where(sq.Eq{"id": a.ID}). + Where(sq.Eq{"release_id": a.ReleaseID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -98,7 +98,7 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain } func (repo *ReleaseRepo) Find(ctx context.Context, params domain.ReleaseQueryParams) ([]*domain.Release, int64, int64, error) { - tx, err := repo.db.BeginTx(ctx, &sql.TxOptions{}) + tx, err := repo.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) if err != nil { return nil, 0, 0, errors.Wrap(err, "error begin transaction") } @@ -109,14 +109,6 @@ func (repo *ReleaseRepo) Find(ctx context.Context, params domain.ReleaseQueryPar return nil, nextCursor, total, err } - for _, release := range releases { - statuses, err := repo.attachActionStatus(ctx, tx, release.ID) - if err != nil { - return releases, nextCursor, total, err - } - release.ActionStatus = statuses - } - if err = tx.Commit(); err != nil { return nil, 0, 0, errors.Wrap(err, "error commit transaction find releases") } @@ -125,23 +117,9 @@ func (repo *ReleaseRepo) Find(ctx context.Context, params domain.ReleaseQueryPar } func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain.ReleaseQueryParams) ([]*domain.Release, int64, int64, error) { - queryBuilder := repo.db.squirrel. - Select("r.id", "r.filter_status", "r.rejections", "r.indexer", "r.filter", "r.protocol", "r.title", "r.torrent_name", "r.size", "r.timestamp", "(SELECT COUNT(*) FROM release) AS total_count"). - From("release r"). - OrderBy("r.id DESC") - - if params.Limit > 0 { - queryBuilder = queryBuilder.Limit(params.Limit) - } else { - queryBuilder = queryBuilder.Limit(20) - } - - if params.Offset > 0 { - queryBuilder = queryBuilder.Offset(params.Offset) - } - + whereQueryBuilder := sq.And{} if params.Cursor > 0 { - queryBuilder = queryBuilder.Where(sq.Lt{"r.id": params.Cursor}) + whereQueryBuilder = append(whereQueryBuilder, sq.Lt{"r.id": params.Cursor}) } if params.Search != "" { @@ -165,16 +143,20 @@ func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain if reskey := r.FindAllStringSubmatch(search, -1); len(reskey) != 0 { filter := sq.Or{} for _, found := range reskey { - filter = append(filter, sq.Like{v: strings.ReplaceAll(strings.Trim(strings.Trim(found[1], `"`), `'`), ".", "_") + "%"}) + filter = append(filter, ILike(v, strings.ReplaceAll(strings.Trim(strings.Trim(found[1], `"`), `'`), ".", "_")+"%")) } - queryBuilder = queryBuilder.Where(filter) + if len(filter) == 0 { + continue + } + + whereQueryBuilder = append(whereQueryBuilder, filter) search = strings.TrimSpace(r.ReplaceAllLiteralString(search, "")) } } if len(search) != 0 { - queryBuilder = queryBuilder.Where(sq.Like{"r.torrent_name": search + "%"}) + whereQueryBuilder = append(whereQueryBuilder, sq.Like{"r.torrent_name": search + "%"}) } } @@ -184,13 +166,56 @@ func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain filter = append(filter, sq.Eq{"r.indexer": v}) } - queryBuilder = queryBuilder.Where(filter) + if len(filter) > 0 { + whereQueryBuilder = append(whereQueryBuilder, filter) + } } - if params.Filters.PushStatus != "" { - queryBuilder = queryBuilder.InnerJoin("release_action_status ras ON r.id = ras.release_id").Where(sq.Eq{"ras.status": params.Filters.PushStatus}) + whereQuery, _, err := whereQueryBuilder.ToSql() + if err != nil { + return nil, 0, 0, errors.Wrap(err, "error building wherequery") } + subQueryBuilder := repo.db.squirrel. + Select("r.id"). + Distinct(). + From("release r") + + if params.Limit > 0 { + subQueryBuilder = subQueryBuilder.Limit(params.Limit) + } else { + subQueryBuilder = subQueryBuilder.Limit(20) + } + + if params.Offset > 0 { + subQueryBuilder = subQueryBuilder.Offset(params.Offset) + } + + if len(whereQueryBuilder) != 0 { + subQueryBuilder = subQueryBuilder.Where(whereQueryBuilder) + } + + countQuery := repo.db.squirrel.Select("COUNT(*)").From("release r").Where(whereQuery) + + if params.Filters.PushStatus != "" { + subQueryBuilder = subQueryBuilder.InnerJoin("release_action_status ras ON r.id = ras.release_id").Where(sq.Eq{"ras.status": params.Filters.PushStatus}) + countQuery = countQuery.InnerJoin("release_action_status ras ON r.id = ras.release_id").Where("ras.status = '" + params.Filters.PushStatus + `'`) + } + + subQuery, subArgs, err := subQueryBuilder.ToSql() + if err != nil { + return nil, 0, 0, errors.Wrap(err, "error building subquery") + } + + queryBuilder := repo.db.squirrel. + Select("r.id", "r.filter_status", "r.rejections", "r.indexer", "r.filter", "r.protocol", "r.title", "r.torrent_name", "r.size", "r.timestamp", + "ras.id", "ras.status", "ras.action", "ras.type", "ras.client", "ras.filter", "ras.rejections", "ras.timestamp"). + Column(sq.Alias(countQuery, "page_total")). + From("release r"). + OrderBy("r.id DESC"). + Where("r.id IN ("+subQuery+")", subArgs...). + LeftJoin("release_action_status ras ON r.id = ras.release_id") + query, args, err := queryBuilder.ToSql() if err != nil { return nil, 0, 0, errors.Wrap(err, "error building query") @@ -215,15 +240,54 @@ func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain for rows.Next() { var rls domain.Release + var ras domain.ReleaseActionStatus - var indexer, filter sql.NullString + var rlsindexer, rlsfilter sql.NullString - if err := rows.Scan(&rls.ID, &rls.FilterStatus, pq.Array(&rls.Rejections), &indexer, &filter, &rls.Protocol, &rls.Title, &rls.TorrentName, &rls.Size, &rls.Timestamp, &countItems); err != nil { + var rasId sql.NullInt64 + var rasStatus, rasAction, rasType, rasClient, rasFilter sql.NullString + var rasRejections []sql.NullString + var rasTimestamp sql.NullTime + + if err := rows.Scan(&rls.ID, &rls.FilterStatus, pq.Array(&rls.Rejections), &rlsindexer, &rlsfilter, &rls.Protocol, &rls.Title, &rls.TorrentName, &rls.Size, &rls.Timestamp, &rasId, &rasStatus, &rasAction, &rasType, &rasClient, &rasFilter, pq.Array(&rasRejections), &rasTimestamp, &countItems); err != nil { return res, 0, 0, errors.Wrap(err, "error scanning row") } - rls.Indexer = indexer.String - rls.FilterName = filter.String + ras.ID = rasId.Int64 + ras.Status = domain.ReleasePushStatus(rasStatus.String) + ras.Action = rasAction.String + ras.Type = domain.ActionType(rasType.String) + ras.Client = rasClient.String + ras.Filter = rasFilter.String + ras.Timestamp = rasTimestamp.Time + ras.Rejections = []string{} + + for _, rejection := range rasRejections { + ras.Rejections = append(ras.Rejections, rejection.String) + } + + idx := 0 + for ; idx < len(res); idx++ { + if res[idx].ID != rls.ID { + continue + } + + res[idx].ActionStatus = append(res[idx].ActionStatus, ras) + break + } + + if idx != len(res) { + continue + } + + rls.Indexer = rlsindexer.String + rls.FilterName = rlsfilter.String + rls.ActionStatus = make([]domain.ReleaseActionStatus, 0) + + // only add ActionStatus if it's not empty + if ras.ID > 0 { + rls.ActionStatus = append(rls.ActionStatus, ras) + } res = append(res, &rls) } @@ -238,25 +302,17 @@ func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain } func (repo *ReleaseRepo) FindRecent(ctx context.Context) ([]*domain.Release, error) { - tx, err := repo.db.BeginTx(ctx, &sql.TxOptions{}) + tx, err := repo.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) if err != nil { return nil, errors.Wrap(err, "error begin transaction") } defer tx.Rollback() - releases, err := repo.findRecentReleases(ctx, tx) + releases, _, _, err := repo.findReleases(ctx, tx, domain.ReleaseQueryParams{Limit: 10}) if err != nil { return nil, err } - for _, release := range releases { - statuses, err := repo.attachActionStatus(ctx, tx, release.ID) - if err != nil { - return releases, err - } - release.ActionStatus = statuses - } - if err = tx.Commit(); err != nil { return nil, errors.Wrap(err, "error transaction commit") } @@ -264,51 +320,6 @@ func (repo *ReleaseRepo) FindRecent(ctx context.Context) ([]*domain.Release, err return releases, nil } -func (repo *ReleaseRepo) findRecentReleases(ctx context.Context, tx *Tx) ([]*domain.Release, error) { - queryBuilder := repo.db.squirrel. - Select("r.id", "r.filter_status", "r.rejections", "r.indexer", "r.filter", "r.protocol", "r.title", "r.torrent_name", "r.size", "r.timestamp"). - From("release r"). - OrderBy("r.id DESC"). - Limit(10) - - query, args, err := queryBuilder.ToSql() - if err != nil { - return nil, errors.Wrap(err, "error building query") - } - - repo.log.Trace().Str("database", "release.find").Msgf("query: '%v', args: '%v'", query, args) - - res := make([]*domain.Release, 0) - - rows, err := tx.QueryContext(ctx, query, args...) - if err != nil { - return res, errors.Wrap(err, "error executing query") - } - - defer rows.Close() - - if err := rows.Err(); err != nil { - return res, errors.Wrap(err, "rows error") - } - - for rows.Next() { - var rls domain.Release - - var indexer, filter sql.NullString - - if err := rows.Scan(&rls.ID, &rls.FilterStatus, pq.Array(&rls.Rejections), &indexer, &filter, &rls.Protocol, &rls.Title, &rls.TorrentName, &rls.Size, &rls.Timestamp); err != nil { - return res, errors.Wrap(err, "error scanning row") - } - - rls.Indexer = indexer.String - rls.FilterName = filter.String - - res = append(res, &rls) - } - - return res, nil -} - func (repo *ReleaseRepo) GetIndexerOptions(ctx context.Context) ([]string, error) { query := `SELECT DISTINCT indexer FROM "release" UNION SELECT DISTINCT identifier indexer FROM indexer;` @@ -346,7 +357,7 @@ func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, release queryBuilder := repo.db.squirrel. Select("id", "status", "action", "type", "client", "filter", "rejections", "timestamp"). From("release_action_status"). - Where("release_id = ?", releaseID) + Where(sq.Eq{"release_id": releaseID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -386,11 +397,10 @@ func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, release } func (repo *ReleaseRepo) attachActionStatus(ctx context.Context, tx *Tx, releaseID int64) ([]domain.ReleaseActionStatus, error) { - queryBuilder := repo.db.squirrel. Select("id", "status", "action", "type", "client", "filter", "rejections", "timestamp"). From("release_action_status"). - Where("release_id = ?", releaseID) + Where(sq.Eq{"release_id": releaseID}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -511,7 +521,7 @@ func (repo *ReleaseRepo) CanDownloadShow(ctx context.Context, title string, seas queryBuilder := repo.db.squirrel. Select("COUNT(*)"). From("release"). - Where("title LIKE ?", fmt.Sprint("%", title, "%")) + Where(ILike("title", title+"%")) if season > 0 && episode > 0 { queryBuilder = queryBuilder.Where(sq.Or{ diff --git a/internal/database/user.go b/internal/database/user.go index d66ab97..b3c544d 100644 --- a/internal/database/user.go +++ b/internal/database/user.go @@ -7,6 +7,7 @@ import ( "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/pkg/errors" + sq "github.com/Masterminds/squirrel" "github.com/rs/zerolog" ) @@ -49,7 +50,7 @@ func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain queryBuilder := r.db.squirrel. Select("id", "username", "password"). From("users"). - Where("username = ?", username) + Where(sq.Eq{"username": username}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -104,7 +105,7 @@ func (r *UserRepo) Update(ctx context.Context, user domain.User) error { Update("users"). Set("username", user.Username). Set("password", user.Password). - Where("username = ?", user.Username) + Where(sq.Eq{"username": user.Username}) query, args, err := queryBuilder.ToSql() if err != nil {