mirror of
https://github.com/idanoo/autobrr
synced 2025-07-23 00:39:13 +00:00
fix(actions): reject if client is disabled (#1626)
* fix(actions): error on disabled client * fix(actions): sql scan args * refactor: download client cache for actions * fix: tests client store * fix: tests client store and int conversion * fix: tests revert findbyid ctx timeout * fix: tests row.err * feat: add logging to download client cache
This commit is contained in:
parent
77e1c2c305
commit
861f30c144
30 changed files with 928 additions and 680 deletions
|
@ -7,7 +7,6 @@ import (
|
|||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"sync"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
"github.com/autobrr/autobrr/internal/logger"
|
||||
|
@ -18,55 +17,18 @@ import (
|
|||
)
|
||||
|
||||
type DownloadClientRepo struct {
|
||||
log zerolog.Logger
|
||||
db *DB
|
||||
cache *clientCache
|
||||
}
|
||||
|
||||
type clientCache struct {
|
||||
mu sync.RWMutex
|
||||
clients map[int]*domain.DownloadClient
|
||||
}
|
||||
|
||||
func NewClientCache() *clientCache {
|
||||
return &clientCache{
|
||||
clients: make(map[int]*domain.DownloadClient, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientCache) Set(id int, client *domain.DownloadClient) {
|
||||
c.mu.Lock()
|
||||
c.clients[id] = client
|
||||
c.mu.Unlock()
|
||||
}
|
||||
|
||||
func (c *clientCache) Get(id int) *domain.DownloadClient {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
v, ok := c.clients[id]
|
||||
if ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientCache) Pop(id int) {
|
||||
c.mu.Lock()
|
||||
delete(c.clients, id)
|
||||
c.mu.Unlock()
|
||||
log zerolog.Logger
|
||||
db *DB
|
||||
}
|
||||
|
||||
func NewDownloadClientRepo(log logger.Logger, db *DB) domain.DownloadClientRepo {
|
||||
return &DownloadClientRepo{
|
||||
log: log.With().Str("repo", "action").Logger(),
|
||||
db: db,
|
||||
cache: NewClientCache(),
|
||||
log: log.With().Str("repo", "action").Logger(),
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, error) {
|
||||
clients := make([]domain.DownloadClient, 0)
|
||||
|
||||
queryBuilder := r.db.squirrel.
|
||||
Select(
|
||||
"id",
|
||||
|
@ -100,6 +62,8 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient,
|
|||
}
|
||||
}(rows)
|
||||
|
||||
clients := make([]domain.DownloadClient, 0)
|
||||
|
||||
for rows.Next() {
|
||||
var f domain.DownloadClient
|
||||
var settingsJsonStr string
|
||||
|
@ -124,12 +88,6 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient,
|
|||
}
|
||||
|
||||
func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
|
||||
// get client from cache
|
||||
c := r.cache.Get(int(id))
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
queryBuilder := r.db.squirrel.
|
||||
Select(
|
||||
"id",
|
||||
|
@ -153,7 +111,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
|
|||
}
|
||||
|
||||
row := r.db.handler.QueryRowContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
if err := row.Err(); err != nil {
|
||||
return nil, errors.Wrap(err, "error executing query")
|
||||
}
|
||||
|
||||
|
@ -177,9 +135,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
|
|||
return &client, nil
|
||||
}
|
||||
|
||||
func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
|
||||
var err error
|
||||
|
||||
func (r *DownloadClientRepo) Store(ctx context.Context, client *domain.DownloadClient) error {
|
||||
settings := domain.DownloadClientSettings{
|
||||
APIKey: client.Settings.APIKey,
|
||||
Basic: client.Settings.Basic,
|
||||
|
@ -190,7 +146,7 @@ func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadCl
|
|||
|
||||
settingsJson, err := json.Marshal(&settings)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshal download client settings %+v", settings)
|
||||
return errors.Wrap(err, "error marshal download client settings %+v", settings)
|
||||
}
|
||||
|
||||
queryBuilder := r.db.squirrel.
|
||||
|
@ -204,22 +160,17 @@ func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadCl
|
|||
|
||||
err = queryBuilder.QueryRowContext(ctx).Scan(&retID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error executing query")
|
||||
return errors.Wrap(err, "error executing query")
|
||||
}
|
||||
|
||||
client.ID = retID
|
||||
client.ID = int32(retID)
|
||||
|
||||
r.log.Debug().Msgf("download_client.store: %d", client.ID)
|
||||
|
||||
// save to cache
|
||||
r.cache.Set(client.ID, &client)
|
||||
|
||||
return &client, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
|
||||
var err error
|
||||
|
||||
func (r *DownloadClientRepo) Update(ctx context.Context, client *domain.DownloadClient) error {
|
||||
settings := domain.DownloadClientSettings{
|
||||
APIKey: client.Settings.APIKey,
|
||||
Basic: client.Settings.Basic,
|
||||
|
@ -230,7 +181,7 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC
|
|||
|
||||
settingsJson, err := json.Marshal(&settings)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error marshal download client settings %+v", settings)
|
||||
return errors.Wrap(err, "error marshal download client settings %+v", settings)
|
||||
}
|
||||
|
||||
queryBuilder := r.db.squirrel.
|
||||
|
@ -249,32 +200,29 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC
|
|||
|
||||
query, args, err := queryBuilder.ToSql()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error building query")
|
||||
return errors.Wrap(err, "error building query")
|
||||
}
|
||||
|
||||
result, err := r.db.handler.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error executing query")
|
||||
return errors.Wrap(err, "error executing query")
|
||||
}
|
||||
|
||||
rowsAffected, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error getting rows affected")
|
||||
return errors.Wrap(err, "error getting rows affected")
|
||||
}
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return nil, errors.New("no rows updated")
|
||||
return errors.New("no rows updated")
|
||||
}
|
||||
|
||||
r.log.Debug().Msgf("download_client.update: %d", client.ID)
|
||||
|
||||
// save to cache
|
||||
r.cache.Set(client.ID, &client)
|
||||
|
||||
return &client, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error {
|
||||
func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int32) error {
|
||||
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{})
|
||||
if err != nil {
|
||||
return err
|
||||
|
@ -311,10 +259,11 @@ func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error {
|
|||
}
|
||||
|
||||
r.log.Debug().Msgf("delete download client: %d", clientID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) error {
|
||||
func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int32) error {
|
||||
queryBuilder := r.db.squirrel.
|
||||
Delete("client").
|
||||
Where(sq.Eq{"id": clientID})
|
||||
|
@ -329,9 +278,6 @@ func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) e
|
|||
return errors.Wrap(err, "error executing query")
|
||||
}
|
||||
|
||||
// remove from cache
|
||||
r.cache.Pop(clientID)
|
||||
|
||||
rows, _ := res.RowsAffected()
|
||||
if rows == 0 {
|
||||
return errors.New("no rows affected")
|
||||
|
@ -342,9 +288,7 @@ func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) e
|
|||
return nil
|
||||
}
|
||||
|
||||
func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx, clientID int) error {
|
||||
var err error
|
||||
|
||||
func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx, clientID int32) error {
|
||||
queryBuilder := r.db.squirrel.
|
||||
Update("action").
|
||||
Set("enabled", false).
|
||||
|
@ -355,12 +299,14 @@ func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx,
|
|||
// return values
|
||||
var filterID int
|
||||
|
||||
if err = queryBuilder.QueryRowContext(ctx).Scan(&filterID); err != nil {
|
||||
err := queryBuilder.QueryRowContext(ctx).Scan(&filterID)
|
||||
if err != nil {
|
||||
// this will throw when the client is not connected to any actions
|
||||
// it is not an error in this case
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.Wrap(err, "error executing query")
|
||||
}
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue