refactor(database): clean-up queries (#569)

* fix(database): build WHERE using squirrel

* flip LIKEs

* change sql.LevelReadCommitted

* feat(db): add missing commits

* implement single query for releases

* cleanup

* feat(releases): properly handle limit for Find

* feat(releases): make dynamic ILike helper

* feat(releases): check for empty ReleaseActionStatus

* add values as sql.NullX
* check if ID is non 0

* feat(releases): improve find
This commit is contained in:
Kyle Sanderson 2022-12-30 14:53:45 -08:00 committed by GitHub
parent e6c48a5228
commit 19b3899a5c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
12 changed files with 206 additions and 168 deletions

View file

@ -29,7 +29,7 @@ func NewActionRepo(log logger.Logger, db *DB, clientRepo domain.DownloadClientRe
func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int) ([]*domain.Action, error) {
tx, err := r.db.BeginTx(ctx, nil)
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil {
return nil, err
}
@ -51,6 +51,10 @@ func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int) ([]*domai
}
}
if err = tx.Commit(); err != nil {
return nil, errors.Wrap(err, "error finding filter by id")
}
return actions, nil
}
@ -87,7 +91,7 @@ func (r *ActionRepo) findByFilterID(ctx context.Context, tx *Tx, filterID int) (
"client_id",
).
From("action").
Where("filter_id = ?", filterID)
Where(sq.Eq{"filter_id": filterID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -165,7 +169,7 @@ func (r *ActionRepo) attachDownloadClient(ctx context.Context, tx *Tx, clientID
"settings",
).
From("client").
Where("id = ?", clientID)
Where(sq.Eq{"id": clientID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -282,7 +286,7 @@ func (r *ActionRepo) List(ctx context.Context) ([]domain.Action, error) {
func (r *ActionRepo) Delete(actionID int) error {
queryBuilder := r.db.squirrel.
Delete("action").
Where("id = ?", actionID)
Where(sq.Eq{"id": actionID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -302,7 +306,7 @@ func (r *ActionRepo) Delete(actionID int) error {
func (r *ActionRepo) DeleteByFilterID(ctx context.Context, filterID int) error {
queryBuilder := r.db.squirrel.
Delete("action").
Where("filter_id = ?", filterID)
Where(sq.Eq{"filter_id": filterID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -472,7 +476,7 @@ func (r *ActionRepo) Update(ctx context.Context, action domain.Action) (*domain.
Set("webhook_data", webhookData).
Set("client_id", clientID).
Set("filter_id", filterID).
Where("id = ?", action.ID)
Where(sq.Eq{"id": action.ID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -499,7 +503,7 @@ func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []*domain.A
deleteQueryBuilder := r.db.squirrel.
Delete("action").
Where("filter_id = ?", filterID)
Where(sq.Eq{"filter_id": filterID})
deleteQuery, deleteArgs, err := deleteQueryBuilder.ToSql()
if err != nil {
@ -622,7 +626,7 @@ func (r *ActionRepo) ToggleEnabled(actionID int) error {
queryBuilder := r.db.squirrel.
Update("action").
Set("enabled", sq.Expr("NOT enabled")).
Where("id = ?", actionID)
Where(sq.Eq{"id": actionID})
query, args, err := queryBuilder.ToSql()
if err != nil {

View file

@ -9,6 +9,7 @@ import (
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
sq "github.com/Masterminds/squirrel"
"github.com/lib/pq"
"github.com/rs/zerolog"
)
@ -55,7 +56,7 @@ func (r *APIRepo) Store(ctx context.Context, key *domain.APIKey) error {
func (r *APIRepo) Delete(ctx context.Context, key string) error {
queryBuilder := r.db.squirrel.
Delete("api_key").
Where("key = ?", key)
Where(sq.Eq{"key": key})
query, args, err := queryBuilder.ToSql()
if err != nil {

View file

@ -14,6 +14,8 @@ import (
"github.com/rs/zerolog"
)
var databaseDriver = "sqlite"
type DB struct {
log zerolog.Logger
handler *sql.DB
@ -37,6 +39,7 @@ func NewDB(cfg *domain.Config, log logger.Logger) (*DB, error) {
switch cfg.DatabaseType {
case "sqlite":
databaseDriver = "sqlite"
db.Driver = "sqlite"
db.DSN = dataSourceName(cfg.ConfigPath, "autobrr.db")
case "postgres":
@ -45,6 +48,7 @@ func NewDB(cfg *domain.Config, log logger.Logger) (*DB, error) {
}
db.DSN = fmt.Sprintf("postgres://%v:%v@%v:%d/%v?sslmode=disable", cfg.PostgresUser, cfg.PostgresPass, cfg.PostgresHost, cfg.PostgresPort, cfg.PostgresDatabase)
db.Driver = "postgres"
databaseDriver = "postgres"
default:
return nil, errors.New("unsupported database: %v", cfg.DatabaseType)
}
@ -106,3 +110,17 @@ type Tx struct {
*sql.Tx
handler *DB
}
type ILikeDynamic interface {
ToSql() (sql string, args []interface{}, err error)
}
// ILike is a wrapper for sq.Like and sq.ILike
// SQLite does not support ILike but postgres does so this checks what database is being used
func ILike(col string, val string) ILikeDynamic {
if databaseDriver == "sqlite" {
return sq.Like{col: val}
}
return sq.ILike{col: val}
}

View file

@ -10,6 +10,7 @@ import (
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
sq "github.com/Masterminds/squirrel"
"github.com/rs/zerolog"
)
@ -136,7 +137,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
"settings",
).
From("client").
Where("id = ?", id)
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -232,7 +233,7 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC
Set("username", client.Username).
Set("password", client.Password).
Set("settings", string(settingsJson)).
Where("id = ?", client.ID)
Where(sq.Eq{"id": client.ID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -255,7 +256,7 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC
func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error {
queryBuilder := r.db.squirrel.
Delete("client").
Where("id = ?", clientID)
Where(sq.Eq{"id": clientID})
query, args, err := queryBuilder.ToSql()
if err != nil {

View file

@ -43,7 +43,7 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) {
).
From("feed f").
Join("indexer i ON f.indexer_id = i.id").
Where("f.id = ?", id)
Where(sq.Eq{"f.id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -89,7 +89,7 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string)
).
From("feed f").
Join("indexer i ON f.indexer_id = i.id").
Where("i.name = ?", indexer)
Where(sq.Eq{"i.name": indexer})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -221,7 +221,7 @@ func (r *FeedRepo) Update(ctx context.Context, feed *domain.Feed) error {
Set("api_key", feed.ApiKey).
Set("cookie", feed.Cookie).
Set("updated_at", sq.Expr("CURRENT_TIMESTAMP")).
Where("id = ?", feed.ID)
Where(sq.Eq{"id": feed.ID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -240,7 +240,7 @@ func (r *FeedRepo) UpdateLastRun(ctx context.Context, feedID int) error {
queryBuilder := r.db.squirrel.
Update("feed").
Set("last_run", sq.Expr("CURRENT_TIMESTAMP")).
Where("id = ?", feedID)
Where(sq.Eq{"id": feedID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -260,7 +260,7 @@ func (r *FeedRepo) UpdateLastRunWithData(ctx context.Context, feedID int, data s
Update("feed").
Set("last_run", sq.Expr("CURRENT_TIMESTAMP")).
Set("last_run_data", data).
Where("id = ?", feedID)
Where(sq.Eq{"id": feedID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -282,7 +282,7 @@ func (r *FeedRepo) ToggleEnabled(ctx context.Context, id int, enabled bool) erro
Update("feed").
Set("enabled", enabled).
Set("updated_at", sq.Expr("CURRENT_TIMESTAMP")).
Where("id = ?", id)
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -299,7 +299,7 @@ func (r *FeedRepo) ToggleEnabled(ctx context.Context, id int, enabled bool) erro
func (r *FeedRepo) Delete(ctx context.Context, id int) error {
queryBuilder := r.db.squirrel.
Delete("feed").
Where("id = ?", id)
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {

View file

@ -8,6 +8,7 @@ import (
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
sq "github.com/Masterminds/squirrel"
"github.com/rs/zerolog"
)
@ -31,9 +32,9 @@ func (r *FeedCacheRepo) Get(bucket string, key string) ([]byte, error) {
"ttl",
).
From("feed_cache").
Where("bucket = ?", bucket).
Where("key = ?", key).
Where("ttl > ?", time.Now())
Where(sq.Eq{"bucket": bucket}).
Where(sq.Eq{"key": key}).
Where(sq.Gt{"ttl": time.Now()})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -64,7 +65,7 @@ func (r *FeedCacheRepo) GetByBucket(ctx context.Context, bucket string) ([]domai
"ttl",
).
From("feed_cache").
Where("bucket = ?", bucket)
Where(sq.Eq{"bucket": bucket})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -102,7 +103,7 @@ func (r *FeedCacheRepo) GetCountByBucket(ctx context.Context, bucket string) (in
queryBuilder := r.db.squirrel.
Select("COUNT(*)").
From("feed_cache").
Where("bucket = ?", bucket)
Where(sq.Eq{"bucket": bucket})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -128,8 +129,8 @@ func (r *FeedCacheRepo) Exists(bucket string, key string) (bool, error) {
Select("1").
Prefix("SELECT EXISTS (").
From("feed_cache").
Where("bucket = ?", bucket).
Where("key = ?", key).
Where(sq.Eq{"bucket": bucket}).
Where(sq.Eq{"key": key}).
Suffix(")")
query, args, err := queryBuilder.ToSql()
@ -167,8 +168,8 @@ func (r *FeedCacheRepo) Put(bucket string, key string, val []byte, ttl time.Time
func (r *FeedCacheRepo) Delete(ctx context.Context, bucket string, key string) error {
queryBuilder := r.db.squirrel.
Delete("feed_cache").
Where("bucket = ?", bucket).
Where("key = ?", key)
Where(sq.Eq{"bucket": bucket}).
Where(sq.Eq{"key": key})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -186,7 +187,7 @@ func (r *FeedCacheRepo) Delete(ctx context.Context, bucket string, key string) e
func (r *FeedCacheRepo) DeleteBucket(ctx context.Context, bucket string) error {
queryBuilder := r.db.squirrel.
Delete("feed_cache").
Where("bucket = ?", bucket)
Where(sq.Eq{"bucket": bucket})
query, args, err := queryBuilder.ToSql()
if err != nil {

View file

@ -29,7 +29,7 @@ func NewFilterRepo(log logger.Logger, db *DB) domain.FilterRepo {
}
func (r *FilterRepo) Find(ctx context.Context, params domain.FilterQueryParams) ([]domain.Filter, error) {
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{})
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil {
return nil, errors.Wrap(err, "error begin transaction")
}
@ -40,7 +40,7 @@ func (r *FilterRepo) Find(ctx context.Context, params domain.FilterQueryParams)
return nil, err
}
if err = tx.Commit(); err != nil {
if err := tx.Commit(); err != nil {
return nil, errors.Wrap(err, "error commit transaction find releases")
}
@ -52,7 +52,7 @@ func (r *FilterRepo) find(ctx context.Context, tx *Tx, params domain.FilterQuery
actionCountQuery := r.db.squirrel.
Select("COUNT(*)").
From("action a").
Where("a.filter_id = f.id")
Where(sq.Eq{"a.filter_id": "f.id"})
queryBuilder := r.db.squirrel.
Select(
@ -70,7 +70,7 @@ func (r *FilterRepo) find(ctx context.Context, tx *Tx, params domain.FilterQuery
From("filter f")
if params.Search != "" {
queryBuilder = queryBuilder.Where("f.name LIKE ?", fmt.Sprint("%", params.Search, "%"))
queryBuilder = queryBuilder.Where(sq.Like{"f.name": params.Search + "%"})
}
if len(params.Sort) > 0 {
@ -123,7 +123,7 @@ func (r *FilterRepo) ListFilters(ctx context.Context) ([]domain.Filter, error) {
actionCountQuery := r.db.squirrel.
Select("COUNT(*)").
From("action a").
Where("a.filter_id = f.id")
Where(sq.Eq{"a.filter_id": "f.id"})
queryBuilder := r.db.squirrel.
Select(
@ -233,7 +233,7 @@ func (r *FilterRepo) FindByID(ctx context.Context, filterID int) (*domain.Filter
"updated_at",
).
From("filter").
Where("id = ?", filterID)
Where(sq.Eq{"id": filterID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -321,6 +321,10 @@ func (r *FilterRepo) FindByIndexerIdentifier(indexer string) ([]domain.Filter, e
filters[i].Downloads = downloads
}
if err := tx.Commit(); err != nil {
return nil, errors.Wrap(err, "error finding filter by identifier")
}
return filters, nil
}
@ -392,9 +396,9 @@ func (r *FilterRepo) findByIndexerIdentifier(ctx context.Context, tx *Tx, indexe
From("filter f").
Join("filter_indexer fi ON f.id = fi.filter_id").
Join("indexer i ON i.id = fi.indexer_id").
Where("i.identifier = ?", indexer).
Where("i.enabled = ?", true).
Where("f.enabled = ?", true).
Where(sq.Eq{"i.identifier": indexer}).
Where(sq.Eq{"i.enabled": true}).
Where(sq.Eq{"f.enabled": true}).
OrderBy("f.priority DESC")
query, args, err := queryBuilder.ToSql()
@ -671,7 +675,7 @@ func (r *FilterRepo) Update(ctx context.Context, filter domain.Filter) (*domain.
Set("external_webhook_data", filter.ExternalWebhookData).
Set("external_webhook_expect_status", filter.ExternalWebhookExpectStatus).
Set("updated_at", time.Now().Format(time.RFC3339)).
Where("id = ?", filter.ID)
Where(sq.Eq{"id": filter.ID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -866,7 +870,7 @@ func (r *FilterRepo) UpdatePartial(ctx context.Context, filter domain.FilterUpda
q = q.Set("external_webhook_expect_status", filter.ExternalWebhookExpectStatus)
}
q = q.Where("id = ?", filter.ID)
q = q.Where(sq.Eq{"id": filter.ID})
query, args, err := q.ToSql()
if err != nil {
@ -897,7 +901,7 @@ func (r *FilterRepo) ToggleEnabled(ctx context.Context, filterID int, enabled bo
Update("filter").
Set("enabled", enabled).
Set("updated_at", sq.Expr("CURRENT_TIMESTAMP")).
Where("id = ?", filterID)
Where(sq.Eq{"id": filterID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -921,7 +925,7 @@ func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int,
deleteQueryBuilder := r.db.squirrel.
Delete("filter_indexer").
Where("filter_id = ?", filterID)
Where(sq.Eq{"filter_id": filterID})
deleteQuery, deleteArgs, err := deleteQueryBuilder.ToSql()
if err != nil {
@ -949,8 +953,7 @@ func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int,
r.log.Debug().Msgf("filter.StoreIndexerConnections: store '%v' on filter: %v", indexer.Name, filterID)
}
err = tx.Commit()
if err != nil {
if err := tx.Commit(); err != nil {
return errors.Wrap(err, "error store indexers for filter: %v", filterID)
}
@ -978,7 +981,7 @@ func (r *FilterRepo) StoreIndexerConnection(ctx context.Context, filterID int, i
func (r *FilterRepo) DeleteIndexerConnections(ctx context.Context, filterID int) error {
queryBuilder := r.db.squirrel.
Delete("filter_indexer").
Where("filter_id = ?", filterID)
Where(sq.Eq{"filter_id": filterID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -996,7 +999,7 @@ func (r *FilterRepo) DeleteIndexerConnections(ctx context.Context, filterID int)
func (r *FilterRepo) Delete(ctx context.Context, filterID int) error {
queryBuilder := r.db.squirrel.
Delete("filter").
Where("id = ?", filterID)
Where(sq.Eq{"id": filterID})
query, args, err := queryBuilder.ToSql()
if err != nil {

View file

@ -9,6 +9,7 @@ import (
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
sq "github.com/Masterminds/squirrel"
"github.com/rs/zerolog"
)
@ -61,7 +62,7 @@ func (r *IndexerRepo) Update(ctx context.Context, indexer domain.Indexer) (*doma
Set("base_url", indexer.BaseURL).
Set("settings", settings).
Set("updated_at", time.Now().Format(time.RFC3339)).
Where("id = ?", indexer.ID)
Where(sq.Eq{"id": indexer.ID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -117,7 +118,7 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "identifier", "implementation", "base_url", "settings").
From("indexer").
Where("id = ?", id)
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -156,7 +157,7 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde
Select("id", "enabled", "name", "identifier", "base_url", "settings").
From("indexer").
Join("filter_indexer ON indexer.id = filter_indexer.indexer_id").
Where("filter_indexer.filter_id = ?", id)
Where(sq.Eq{"filter_indexer.filter_id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -202,7 +203,7 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde
func (r *IndexerRepo) Delete(ctx context.Context, id int) error {
queryBuilder := r.db.squirrel.
Delete("indexer").
Where("id = ?", id)
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {

View file

@ -8,6 +8,7 @@ import (
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
sq "github.com/Masterminds/squirrel"
"github.com/rs/zerolog"
)
@ -28,7 +29,7 @@ func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetw
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command").
From("irc_network").
Where("id = ?", id)
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -67,7 +68,7 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
queryBuilder := r.db.squirrel.
Delete("irc_channel").
Where("network_id = ?", id)
Where(sq.Eq{"network_id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -81,7 +82,7 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
netQueryBuilder := r.db.squirrel.
Delete("irc_network").
Where("id = ?", id)
Where(sq.Eq{"id": id})
netQuery, netArgs, err := netQueryBuilder.ToSql()
if err != nil {
@ -93,10 +94,8 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
return errors.Wrap(err, "error executing query")
}
err = tx.Commit()
if err != nil {
if err := tx.Commit(); err != nil {
return errors.Wrap(err, "error commit deleting network")
}
return nil
@ -106,7 +105,7 @@ func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork,
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command").
From("irc_network").
Where("enabled = ?", true)
Where(sq.Eq{"enabled": true})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -200,7 +199,7 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
queryBuilder := r.db.squirrel.
Select("id", "name", "enabled", "password").
From("irc_channel").
Where("network_id = ?", networkID)
Where(sq.Eq{"network_id": networkID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -237,8 +236,8 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command").
From("irc_network").
Where("server = ?", network.Server).
Where("auth_account = ?", network.Auth.Account)
Where(sq.Eq{"server": network.Server}).
Where(sq.Eq{"auth_account": network.Auth.Account})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -350,7 +349,7 @@ func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork)
Set("auth_password", password).
Set("invite_command", inviteCmd).
Set("updated_at", time.Now().Format(time.RFC3339)).
Where("id = ?", network.ID)
Where(sq.Eq{"id": network.ID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -378,7 +377,7 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha
queryBuilder := r.db.squirrel.
Delete("irc_channel").
Where("network_id = ?", networkID)
Where(sq.Eq{"network_id": networkID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -438,8 +437,7 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha
//channel.ID, err = res.LastInsertId()
}
err = tx.Commit()
if err != nil {
if err := tx.Commit(); err != nil {
return errors.Wrap(err, "error commit transaction store network")
}
@ -458,7 +456,7 @@ func (r *IrcRepo) StoreChannel(networkID int64, channel *domain.IrcChannel) erro
Set("detached", channel.Detached).
Set("name", channel.Name).
Set("pass", pass).
Where("id = ?", channel.ID)
Where(sq.Eq{"id": channel.ID})
query, args, err := channelQueryBuilder.ToSql()
if err != nil {
@ -528,7 +526,7 @@ func (r *IrcRepo) UpdateChannel(channel *domain.IrcChannel) error {
Set("detached", channel.Detached).
Set("name", channel.Name).
Set("pass", pass).
Where("id = ?", channel.ID)
Where(sq.Eq{"id": channel.ID})
query, args, err := channelQueryBuilder.ToSql()
if err != nil {
@ -549,7 +547,7 @@ func (r *IrcRepo) UpdateInviteCommand(networkID int64, invite string) error {
channelQueryBuilder := r.db.squirrel.
Update("irc_network").
Set("invite_command", invite).
Where("id = ?", networkID)
Where(sq.Eq{"id": networkID})
query, args, err := channelQueryBuilder.ToSql()
if err != nil {

View file

@ -144,7 +144,7 @@ func (r *NotificationRepo) FindByID(ctx context.Context, id int) (*domain.Notifi
"updated_at",
).
From("notification").
Where("id = ?", id)
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -240,7 +240,7 @@ func (r *NotificationRepo) Update(ctx context.Context, notification domain.Notif
Set("api_key", apiKey).
Set("channel", channel).
Set("updated_at", sq.Expr("CURRENT_TIMESTAMP")).
Where("id = ?", notification.ID)
Where(sq.Eq{"id": notification.ID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -260,7 +260,7 @@ func (r *NotificationRepo) Update(ctx context.Context, notification domain.Notif
func (r *NotificationRepo) Delete(ctx context.Context, notificationID int) error {
queryBuilder := r.db.squirrel.
Delete("notification").
Where("id = ?", notificationID)
Where(sq.Eq{"id": notificationID})
query, args, err := queryBuilder.ToSql()
if err != nil {

View file

@ -61,8 +61,8 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain
Set("status", a.Status).
Set("rejections", pq.Array(a.Rejections)).
Set("timestamp", a.Timestamp).
Where("id = ?", a.ID).
Where("release_id = ?", a.ReleaseID)
Where(sq.Eq{"id": a.ID}).
Where(sq.Eq{"release_id": a.ReleaseID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -98,7 +98,7 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain
}
func (repo *ReleaseRepo) Find(ctx context.Context, params domain.ReleaseQueryParams) ([]*domain.Release, int64, int64, error) {
tx, err := repo.db.BeginTx(ctx, &sql.TxOptions{})
tx, err := repo.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil {
return nil, 0, 0, errors.Wrap(err, "error begin transaction")
}
@ -109,14 +109,6 @@ func (repo *ReleaseRepo) Find(ctx context.Context, params domain.ReleaseQueryPar
return nil, nextCursor, total, err
}
for _, release := range releases {
statuses, err := repo.attachActionStatus(ctx, tx, release.ID)
if err != nil {
return releases, nextCursor, total, err
}
release.ActionStatus = statuses
}
if err = tx.Commit(); err != nil {
return nil, 0, 0, errors.Wrap(err, "error commit transaction find releases")
}
@ -125,23 +117,9 @@ func (repo *ReleaseRepo) Find(ctx context.Context, params domain.ReleaseQueryPar
}
func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain.ReleaseQueryParams) ([]*domain.Release, int64, int64, error) {
queryBuilder := repo.db.squirrel.
Select("r.id", "r.filter_status", "r.rejections", "r.indexer", "r.filter", "r.protocol", "r.title", "r.torrent_name", "r.size", "r.timestamp", "(SELECT COUNT(*) FROM release) AS total_count").
From("release r").
OrderBy("r.id DESC")
if params.Limit > 0 {
queryBuilder = queryBuilder.Limit(params.Limit)
} else {
queryBuilder = queryBuilder.Limit(20)
}
if params.Offset > 0 {
queryBuilder = queryBuilder.Offset(params.Offset)
}
whereQueryBuilder := sq.And{}
if params.Cursor > 0 {
queryBuilder = queryBuilder.Where(sq.Lt{"r.id": params.Cursor})
whereQueryBuilder = append(whereQueryBuilder, sq.Lt{"r.id": params.Cursor})
}
if params.Search != "" {
@ -165,16 +143,20 @@ func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain
if reskey := r.FindAllStringSubmatch(search, -1); len(reskey) != 0 {
filter := sq.Or{}
for _, found := range reskey {
filter = append(filter, sq.Like{v: strings.ReplaceAll(strings.Trim(strings.Trim(found[1], `"`), `'`), ".", "_") + "%"})
filter = append(filter, ILike(v, strings.ReplaceAll(strings.Trim(strings.Trim(found[1], `"`), `'`), ".", "_")+"%"))
}
queryBuilder = queryBuilder.Where(filter)
if len(filter) == 0 {
continue
}
whereQueryBuilder = append(whereQueryBuilder, filter)
search = strings.TrimSpace(r.ReplaceAllLiteralString(search, ""))
}
}
if len(search) != 0 {
queryBuilder = queryBuilder.Where(sq.Like{"r.torrent_name": search + "%"})
whereQueryBuilder = append(whereQueryBuilder, sq.Like{"r.torrent_name": search + "%"})
}
}
@ -184,13 +166,56 @@ func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain
filter = append(filter, sq.Eq{"r.indexer": v})
}
queryBuilder = queryBuilder.Where(filter)
if len(filter) > 0 {
whereQueryBuilder = append(whereQueryBuilder, filter)
}
}
if params.Filters.PushStatus != "" {
queryBuilder = queryBuilder.InnerJoin("release_action_status ras ON r.id = ras.release_id").Where(sq.Eq{"ras.status": params.Filters.PushStatus})
whereQuery, _, err := whereQueryBuilder.ToSql()
if err != nil {
return nil, 0, 0, errors.Wrap(err, "error building wherequery")
}
subQueryBuilder := repo.db.squirrel.
Select("r.id").
Distinct().
From("release r")
if params.Limit > 0 {
subQueryBuilder = subQueryBuilder.Limit(params.Limit)
} else {
subQueryBuilder = subQueryBuilder.Limit(20)
}
if params.Offset > 0 {
subQueryBuilder = subQueryBuilder.Offset(params.Offset)
}
if len(whereQueryBuilder) != 0 {
subQueryBuilder = subQueryBuilder.Where(whereQueryBuilder)
}
countQuery := repo.db.squirrel.Select("COUNT(*)").From("release r").Where(whereQuery)
if params.Filters.PushStatus != "" {
subQueryBuilder = subQueryBuilder.InnerJoin("release_action_status ras ON r.id = ras.release_id").Where(sq.Eq{"ras.status": params.Filters.PushStatus})
countQuery = countQuery.InnerJoin("release_action_status ras ON r.id = ras.release_id").Where("ras.status = '" + params.Filters.PushStatus + `'`)
}
subQuery, subArgs, err := subQueryBuilder.ToSql()
if err != nil {
return nil, 0, 0, errors.Wrap(err, "error building subquery")
}
queryBuilder := repo.db.squirrel.
Select("r.id", "r.filter_status", "r.rejections", "r.indexer", "r.filter", "r.protocol", "r.title", "r.torrent_name", "r.size", "r.timestamp",
"ras.id", "ras.status", "ras.action", "ras.type", "ras.client", "ras.filter", "ras.rejections", "ras.timestamp").
Column(sq.Alias(countQuery, "page_total")).
From("release r").
OrderBy("r.id DESC").
Where("r.id IN ("+subQuery+")", subArgs...).
LeftJoin("release_action_status ras ON r.id = ras.release_id")
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, 0, 0, errors.Wrap(err, "error building query")
@ -215,15 +240,54 @@ func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain
for rows.Next() {
var rls domain.Release
var ras domain.ReleaseActionStatus
var indexer, filter sql.NullString
var rlsindexer, rlsfilter sql.NullString
if err := rows.Scan(&rls.ID, &rls.FilterStatus, pq.Array(&rls.Rejections), &indexer, &filter, &rls.Protocol, &rls.Title, &rls.TorrentName, &rls.Size, &rls.Timestamp, &countItems); err != nil {
var rasId sql.NullInt64
var rasStatus, rasAction, rasType, rasClient, rasFilter sql.NullString
var rasRejections []sql.NullString
var rasTimestamp sql.NullTime
if err := rows.Scan(&rls.ID, &rls.FilterStatus, pq.Array(&rls.Rejections), &rlsindexer, &rlsfilter, &rls.Protocol, &rls.Title, &rls.TorrentName, &rls.Size, &rls.Timestamp, &rasId, &rasStatus, &rasAction, &rasType, &rasClient, &rasFilter, pq.Array(&rasRejections), &rasTimestamp, &countItems); err != nil {
return res, 0, 0, errors.Wrap(err, "error scanning row")
}
rls.Indexer = indexer.String
rls.FilterName = filter.String
ras.ID = rasId.Int64
ras.Status = domain.ReleasePushStatus(rasStatus.String)
ras.Action = rasAction.String
ras.Type = domain.ActionType(rasType.String)
ras.Client = rasClient.String
ras.Filter = rasFilter.String
ras.Timestamp = rasTimestamp.Time
ras.Rejections = []string{}
for _, rejection := range rasRejections {
ras.Rejections = append(ras.Rejections, rejection.String)
}
idx := 0
for ; idx < len(res); idx++ {
if res[idx].ID != rls.ID {
continue
}
res[idx].ActionStatus = append(res[idx].ActionStatus, ras)
break
}
if idx != len(res) {
continue
}
rls.Indexer = rlsindexer.String
rls.FilterName = rlsfilter.String
rls.ActionStatus = make([]domain.ReleaseActionStatus, 0)
// only add ActionStatus if it's not empty
if ras.ID > 0 {
rls.ActionStatus = append(rls.ActionStatus, ras)
}
res = append(res, &rls)
}
@ -238,25 +302,17 @@ func (repo *ReleaseRepo) findReleases(ctx context.Context, tx *Tx, params domain
}
func (repo *ReleaseRepo) FindRecent(ctx context.Context) ([]*domain.Release, error) {
tx, err := repo.db.BeginTx(ctx, &sql.TxOptions{})
tx, err := repo.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil {
return nil, errors.Wrap(err, "error begin transaction")
}
defer tx.Rollback()
releases, err := repo.findRecentReleases(ctx, tx)
releases, _, _, err := repo.findReleases(ctx, tx, domain.ReleaseQueryParams{Limit: 10})
if err != nil {
return nil, err
}
for _, release := range releases {
statuses, err := repo.attachActionStatus(ctx, tx, release.ID)
if err != nil {
return releases, err
}
release.ActionStatus = statuses
}
if err = tx.Commit(); err != nil {
return nil, errors.Wrap(err, "error transaction commit")
}
@ -264,51 +320,6 @@ func (repo *ReleaseRepo) FindRecent(ctx context.Context) ([]*domain.Release, err
return releases, nil
}
func (repo *ReleaseRepo) findRecentReleases(ctx context.Context, tx *Tx) ([]*domain.Release, error) {
queryBuilder := repo.db.squirrel.
Select("r.id", "r.filter_status", "r.rejections", "r.indexer", "r.filter", "r.protocol", "r.title", "r.torrent_name", "r.size", "r.timestamp").
From("release r").
OrderBy("r.id DESC").
Limit(10)
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, errors.Wrap(err, "error building query")
}
repo.log.Trace().Str("database", "release.find").Msgf("query: '%v', args: '%v'", query, args)
res := make([]*domain.Release, 0)
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return res, errors.Wrap(err, "error executing query")
}
defer rows.Close()
if err := rows.Err(); err != nil {
return res, errors.Wrap(err, "rows error")
}
for rows.Next() {
var rls domain.Release
var indexer, filter sql.NullString
if err := rows.Scan(&rls.ID, &rls.FilterStatus, pq.Array(&rls.Rejections), &indexer, &filter, &rls.Protocol, &rls.Title, &rls.TorrentName, &rls.Size, &rls.Timestamp); err != nil {
return res, errors.Wrap(err, "error scanning row")
}
rls.Indexer = indexer.String
rls.FilterName = filter.String
res = append(res, &rls)
}
return res, nil
}
func (repo *ReleaseRepo) GetIndexerOptions(ctx context.Context) ([]string, error) {
query := `SELECT DISTINCT indexer FROM "release" UNION SELECT DISTINCT identifier indexer FROM indexer;`
@ -346,7 +357,7 @@ func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, release
queryBuilder := repo.db.squirrel.
Select("id", "status", "action", "type", "client", "filter", "rejections", "timestamp").
From("release_action_status").
Where("release_id = ?", releaseID)
Where(sq.Eq{"release_id": releaseID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -386,11 +397,10 @@ func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, release
}
func (repo *ReleaseRepo) attachActionStatus(ctx context.Context, tx *Tx, releaseID int64) ([]domain.ReleaseActionStatus, error) {
queryBuilder := repo.db.squirrel.
Select("id", "status", "action", "type", "client", "filter", "rejections", "timestamp").
From("release_action_status").
Where("release_id = ?", releaseID)
Where(sq.Eq{"release_id": releaseID})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -511,7 +521,7 @@ func (repo *ReleaseRepo) CanDownloadShow(ctx context.Context, title string, seas
queryBuilder := repo.db.squirrel.
Select("COUNT(*)").
From("release").
Where("title LIKE ?", fmt.Sprint("%", title, "%"))
Where(ILike("title", title+"%"))
if season > 0 && episode > 0 {
queryBuilder = queryBuilder.Where(sq.Or{

View file

@ -7,6 +7,7 @@ import (
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
sq "github.com/Masterminds/squirrel"
"github.com/rs/zerolog"
)
@ -49,7 +50,7 @@ func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain
queryBuilder := r.db.squirrel.
Select("id", "username", "password").
From("users").
Where("username = ?", username)
Where(sq.Eq{"username": username})
query, args, err := queryBuilder.ToSql()
if err != nil {
@ -104,7 +105,7 @@ func (r *UserRepo) Update(ctx context.Context, user domain.User) error {
Update("users").
Set("username", user.Username).
Set("password", user.Password).
Where("username = ?", user.Username)
Where(sq.Eq{"username": user.Username})
query, args, err := queryBuilder.ToSql()
if err != nil {