fix(actions): reject if client is disabled (#1626)

* fix(actions): error on disabled client

* fix(actions): sql scan args

* refactor: download client cache for actions

* fix: tests client store

* fix: tests client store and int conversion

* fix: tests revert findbyid ctx timeout

* fix: tests row.err

* feat: add logging to download client cache
This commit is contained in:
ze0s 2024-08-27 19:45:06 +02:00 committed by GitHub
parent 77e1c2c305
commit 861f30c144
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 928 additions and 680 deletions

View file

@ -117,7 +117,7 @@ func main() {
downloadClientService = download_client.NewService(log, downloadClientRepo) downloadClientService = download_client.NewService(log, downloadClientRepo)
actionService = action.NewService(log, actionRepo, downloadClientService, bus) actionService = action.NewService(log, actionRepo, downloadClientService, bus)
indexerService = indexer.NewService(log, cfg.Config, indexerRepo, releaseRepo, indexerAPIService, schedulingService) indexerService = indexer.NewService(log, cfg.Config, indexerRepo, releaseRepo, indexerAPIService, schedulingService)
filterService = filter.NewService(log, filterRepo, actionRepo, releaseRepo, indexerAPIService, indexerService) filterService = filter.NewService(log, filterRepo, actionService, releaseRepo, indexerAPIService, indexerService)
releaseService = release.NewService(log, releaseRepo, actionService, filterService, indexerService) releaseService = release.NewService(log, releaseRepo, actionService, filterService, indexerService)
ircService = irc.NewService(log, serverEvents, ircRepo, releaseService, indexerService, notificationService) ircService = irc.NewService(log, serverEvents, ircRepo, releaseService, indexerService, notificationService)
feedService = feed.NewService(log, feedRepo, feedCacheRepo, releaseService, schedulingService) feedService = feed.NewService(log, feedRepo, feedCacheRepo, releaseService, schedulingService)

View file

@ -7,7 +7,6 @@ import (
"context" "context"
"encoding/base64" "encoding/base64"
"os" "os"
"time"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
@ -20,20 +19,18 @@ func (s *service) deluge(ctx context.Context, action *domain.Action, release dom
var err error var err error
// get client for action client, err := s.clientSvc.GetClient(ctx, action.ClientID)
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
s.log.Error().Stack().Err(err).Msgf("error finding client: %d", action.ClientID) return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
return nil, err
} }
if client == nil { if !client.Enabled {
return nil, errors.New("could not find client by id: %d", action.ClientID) return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
} }
var rejections []string var rejections []string
switch client.Type { switch action.Client.Type {
case "DELUGE_V1": case "DELUGE_V1":
rejections, err = s.delugeV1(ctx, client, action, release) rejections, err = s.delugeV1(ctx, client, action, release)
@ -90,27 +87,18 @@ func (s *service) delugeCheckRulesCanDownload(ctx context.Context, del deluge.De
} }
func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) { func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) {
settings := deluge.Settings{ downloadClient := client.Client.(*deluge.Client)
Hostname: client.Host,
Port: uint(client.Port),
Login: client.Username,
Password: client.Password,
DebugServerResponses: true,
ReadWriteTimeout: time.Second * 30,
}
del := deluge.NewV1(settings)
// perform connection to Deluge server // perform connection to Deluge server
err := del.Connect(ctx) err := downloadClient.Connect(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not connect to client %s at %s", client.Name, client.Host) return nil, errors.Wrap(err, "could not connect to client %s at %s", client.Name, client.Host)
} }
defer del.Close() defer downloadClient.Close()
// perform connection to Deluge server // perform connection to Deluge server
rejections, err := s.delugeCheckRulesCanDownload(ctx, del, client, action) rejections, err := s.delugeCheckRulesCanDownload(ctx, downloadClient, client, action)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("error checking client rules: %s", action.Name) s.log.Error().Err(err).Msgf("error checking client rules: %s", action.Name)
return nil, err return nil, err
@ -127,13 +115,13 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a
s.log.Trace().Msgf("action Deluge options: %+v", options) s.log.Trace().Msgf("action Deluge options: %+v", options)
torrentHash, err := del.AddTorrentMagnet(ctx, release.MagnetURI, &options) torrentHash, err := downloadClient.AddTorrentMagnet(ctx, release.MagnetURI, &options)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not add torrent magnet %s to client: %s", release.MagnetURI, client.Name) return nil, errors.Wrap(err, "could not add torrent magnet %s to client: %s", release.MagnetURI, client.Name)
} }
if action.Label != "" { if action.Label != "" {
labelPluginActive, err := del.LabelPlugin(ctx) labelPluginActive, err := downloadClient.LabelPlugin(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name) return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name)
} }
@ -176,13 +164,13 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a
s.log.Trace().Msgf("action Deluge options: %+v", options) s.log.Trace().Msgf("action Deluge options: %+v", options)
torrentHash, err := del.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options) torrentHash, err := downloadClient.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not add torrent %v to client: %v", release.TorrentTmpFile, client.Name) return nil, errors.Wrap(err, "could not add torrent %v to client: %v", release.TorrentTmpFile, client.Name)
} }
if action.Label != "" { if action.Label != "" {
labelPluginActive, err := del.LabelPlugin(ctx) labelPluginActive, err := downloadClient.LabelPlugin(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name) return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name)
} }
@ -203,27 +191,18 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a
} }
func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) { func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) {
settings := deluge.Settings{ downloadClient := client.Client.(*deluge.ClientV2)
Hostname: client.Host,
Port: uint(client.Port),
Login: client.Username,
Password: client.Password,
DebugServerResponses: true,
ReadWriteTimeout: time.Second * 30,
}
del := deluge.NewV2(settings)
// perform connection to Deluge server // perform connection to Deluge server
err := del.Connect(ctx) err := downloadClient.Connect(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not connect to client %s at %s", client.Name, client.Host) return nil, errors.Wrap(err, "could not connect to client %s at %s", client.Name, client.Host)
} }
defer del.Close() defer downloadClient.Close()
// perform connection to Deluge server // perform connection to Deluge server
rejections, err := s.delugeCheckRulesCanDownload(ctx, del, client, action) rejections, err := s.delugeCheckRulesCanDownload(ctx, downloadClient, client, action)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("error checking client rules: %s", action.Name) s.log.Error().Err(err).Msgf("error checking client rules: %s", action.Name)
return nil, err return nil, err
@ -240,13 +219,13 @@ func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, a
s.log.Trace().Msgf("action Deluge options: %+v", options) s.log.Trace().Msgf("action Deluge options: %+v", options)
torrentHash, err := del.AddTorrentMagnet(ctx, release.MagnetURI, &options) torrentHash, err := downloadClient.AddTorrentMagnet(ctx, release.MagnetURI, &options)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not add torrent magnet %s to client: %s", release.MagnetURI, client.Name) return nil, errors.Wrap(err, "could not add torrent magnet %s to client: %s", release.MagnetURI, client.Name)
} }
if action.Label != "" { if action.Label != "" {
labelPluginActive, err := del.LabelPlugin(ctx) labelPluginActive, err := downloadClient.LabelPlugin(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name) return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name)
} }
@ -290,13 +269,13 @@ func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, a
s.log.Trace().Msgf("action Deluge options: %+v", options) s.log.Trace().Msgf("action Deluge options: %+v", options)
torrentHash, err := del.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options) torrentHash, err := downloadClient.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.TorrentTmpFile, client.Name) return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.TorrentTmpFile, client.Name)
} }
if action.Label != "" { if action.Label != "" {
labelPluginActive, err := del.LabelPlugin(ctx) labelPluginActive, err := downloadClient.LabelPlugin(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name) return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name)
} }

View file

@ -17,41 +17,16 @@ func (s *service) lidarr(ctx context.Context, action *domain.Action, release dom
// TODO validate data // TODO validate data
// get client for action client, err := s.clientSvc.GetClient(ctx, action.ClientID)
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("lidarr: error finding client: %v", action.ClientID) return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
return nil, err
} }
// return early if no client found if !client.Enabled {
if client == nil { return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
return nil, errors.New("could not find client by id: %v", action.ClientID)
} }
// initial config arr := client.Client.(lidarr.Client)
cfg := lidarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: s.subLogger,
}
// only set basic auth if enabled
if client.Settings.Basic.Auth {
cfg.BasicAuth = client.Settings.Basic.Auth
cfg.Username = client.Settings.Basic.Username
cfg.Password = client.Settings.Basic.Password
}
externalClientId := client.Settings.ExternalDownloadClientId
if action.ExternalDownloadClientID > 0 {
externalClientId = int(action.ExternalDownloadClientID)
}
externalClient := client.Settings.ExternalDownloadClient
if action.ExternalDownloadClient != "" {
externalClient = action.ExternalDownloadClient
}
r := lidarr.Release{ r := lidarr.Release{
Title: release.TorrentName, Title: release.TorrentName,
@ -60,14 +35,20 @@ func (s *service) lidarr(ctx context.Context, action *domain.Action, release dom
MagnetUrl: release.MagnetURI, MagnetUrl: release.MagnetURI,
Size: int64(release.Size), Size: int64(release.Size),
Indexer: release.Indexer.GetExternalIdentifier(), Indexer: release.Indexer.GetExternalIdentifier(),
DownloadClientId: externalClientId, DownloadClientId: client.Settings.ExternalDownloadClientId,
DownloadClient: externalClient, DownloadClient: client.Settings.ExternalDownloadClient,
DownloadProtocol: release.Protocol.String(), DownloadProtocol: release.Protocol.String(),
Protocol: release.Protocol.String(), Protocol: release.Protocol.String(),
PublishDate: time.Now().Format(time.RFC3339), PublishDate: time.Now().Format(time.RFC3339),
} }
arr := lidarr.New(cfg) if action.ExternalDownloadClientID > 0 {
r.DownloadClientId = int(action.ExternalDownloadClientID)
}
if action.ExternalDownloadClient != "" {
r.DownloadClient = action.ExternalDownloadClient
}
rejections, err := arr.Push(ctx, r) rejections, err := arr.Push(ctx, r)
if err != nil { if err != nil {

View file

@ -13,34 +13,21 @@ import (
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/porla" "github.com/autobrr/autobrr/pkg/porla"
"github.com/dcarbone/zadapters/zstdlog"
"github.com/rs/zerolog"
) )
func (s *service) porla(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { func (s *service) porla(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action Porla: %s", action.Name) s.log.Debug().Msgf("action Porla: %s", action.Name)
client, err := s.clientSvc.FindByID(ctx, action.ClientID) client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error finding client: %d", action.ClientID) return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
} }
if client == nil { if !client.Enabled {
return nil, errors.New("could not find client by id: %d", action.ClientID) return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
} }
porlaSettings := porla.Config{ prl := client.Client.(*porla.Client)
Hostname: client.Host,
AuthToken: client.Settings.APIKey,
TLSSkipVerify: client.TLSSkipVerify,
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
}
porlaSettings.Log = zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Porla").Str("client", client.Name).Logger(), zerolog.TraceLevel)
prl := porla.NewClient(porlaSettings)
rejections, err := s.porlaCheckRulesCanDownload(ctx, action, client, prl) rejections, err := s.porlaCheckRulesCanDownload(ctx, action, client, prl)
if err != nil { if err != nil {

View file

@ -17,11 +17,20 @@ import (
func (s *service) qbittorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { func (s *service) qbittorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action qBittorrent: %s", action.Name) s.log.Debug().Msgf("action qBittorrent: %s", action.Name)
c := s.clientSvc.GetCachedClient(ctx, action.ClientID) client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
if c.Dc.Settings.Rules.Enabled && !action.IgnoreRules { if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
qbtClient := client.Client.(*qbittorrent.Client)
if client.Settings.Rules.Enabled && !action.IgnoreRules {
// check for active downloads and other rules // check for active downloads and other rules
rejections, err := s.qbittorrentCheckRulesCanDownload(ctx, action, c.Dc.Settings.Rules, c.Qbt) rejections, err := s.qbittorrentCheckRulesCanDownload(ctx, action, client.Settings.Rules, qbtClient)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error checking client rules: %s", action.Name) return nil, errors.Wrap(err, "error checking client rules: %s", action.Name)
} }
@ -39,11 +48,11 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas
s.log.Trace().Msgf("action qBittorrent options: %+v", options) s.log.Trace().Msgf("action qBittorrent options: %+v", options)
if err = c.Qbt.AddTorrentFromUrlCtx(ctx, release.MagnetURI, options); err != nil { if err = qbtClient.AddTorrentFromUrlCtx(ctx, release.MagnetURI, options); err != nil {
return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.MagnetURI, c.Dc.Name) return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.MagnetURI, client.Name)
} }
s.log.Info().Msgf("torrent from magnet successfully added to client: '%s'", c.Dc.Name) s.log.Info().Msgf("torrent from magnet successfully added to client: '%s'", client.Name)
return nil, nil return nil, nil
} }
@ -61,37 +70,37 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas
s.log.Trace().Msgf("action qBittorrent options: %+v", options) s.log.Trace().Msgf("action qBittorrent options: %+v", options)
if err = c.Qbt.AddTorrentFromFileCtx(ctx, release.TorrentTmpFile, options); err != nil { if err = qbtClient.AddTorrentFromFileCtx(ctx, release.TorrentTmpFile, options); err != nil {
return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.TorrentTmpFile, c.Dc.Name) return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.TorrentTmpFile, client.Name)
} }
if release.TorrentHash != "" { if release.TorrentHash != "" {
// check if torrent queueing is enabled if priority is set // check if torrent queueing is enabled if priority is set
switch action.PriorityLayout { switch action.PriorityLayout {
case domain.PriorityLayoutMax, domain.PriorityLayoutMin: case domain.PriorityLayoutMax, domain.PriorityLayoutMin:
prefs, err := c.Qbt.GetAppPreferencesCtx(ctx) prefs, err := qbtClient.GetAppPreferencesCtx(ctx)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not get application preferences from client: '%s'", c.Dc.Name) return nil, errors.Wrap(err, "could not get application preferences from client: '%s'", client.Name)
} }
// enable queueing if it's disabled // enable queueing if it's disabled
if !prefs.QueueingEnabled { if !prefs.QueueingEnabled {
if err := c.Qbt.SetPreferencesQueueingEnabled(true); err != nil { if err := qbtClient.SetPreferencesQueueingEnabled(true); err != nil {
return nil, errors.Wrap(err, "could not enable torrent queueing") return nil, errors.Wrap(err, "could not enable torrent queueing")
} }
s.log.Trace().Msgf("torrent queueing was disabled, now enabled in client: '%s'", c.Dc.Name) s.log.Trace().Msgf("torrent queueing was disabled, now enabled in client: '%s'", client.Name)
} }
// set priority if queueing is enabled // set priority if queueing is enabled
if action.PriorityLayout == domain.PriorityLayoutMax { if action.PriorityLayout == domain.PriorityLayoutMax {
if err := c.Qbt.SetMaxPriorityCtx(ctx, []string{release.TorrentHash}); err != nil { if err := qbtClient.SetMaxPriorityCtx(ctx, []string{release.TorrentHash}); err != nil {
return nil, errors.Wrap(err, "could not set torrent %s to max priority", release.TorrentHash) return nil, errors.Wrap(err, "could not set torrent %s to max priority", release.TorrentHash)
} }
s.log.Debug().Msgf("torrent with hash %s set to max priority in client: '%s'", release.TorrentHash, c.Dc.Name) s.log.Debug().Msgf("torrent with hash %s set to max priority in client: '%s'", release.TorrentHash, client.Name)
} else { // domain.PriorityLayoutMin } else { // domain.PriorityLayoutMin
if err := c.Qbt.SetMinPriorityCtx(ctx, []string{release.TorrentHash}); err != nil { if err := qbtClient.SetMinPriorityCtx(ctx, []string{release.TorrentHash}); err != nil {
return nil, errors.Wrap(err, "could not set torrent %s to min priority", release.TorrentHash) return nil, errors.Wrap(err, "could not set torrent %s to min priority", release.TorrentHash)
} }
s.log.Debug().Msgf("torrent with hash %s set to min priority in client: '%s'", release.TorrentHash, c.Dc.Name) s.log.Debug().Msgf("torrent with hash %s set to min priority in client: '%s'", release.TorrentHash, client.Name)
} }
case domain.PriorityLayoutDefault: case domain.PriorityLayoutDefault:
@ -111,7 +120,7 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas
DeleteOnFailure: action.ReAnnounceDelete, DeleteOnFailure: action.ReAnnounceDelete,
} }
if err := c.Qbt.ReannounceTorrentWithRetry(ctx, release.TorrentHash, &opts); err != nil { if err := qbtClient.ReannounceTorrentWithRetry(ctx, release.TorrentHash, &opts); err != nil {
if errors.Is(err, qbittorrent.ErrReannounceTookTooLong) { if errors.Is(err, qbittorrent.ErrReannounceTookTooLong) {
return []string{fmt.Sprintf("re-announce took too long for hash: %s", release.TorrentHash)}, nil return []string{fmt.Sprintf("re-announce took too long for hash: %s", release.TorrentHash)}, nil
} }
@ -120,7 +129,7 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas
} }
} }
s.log.Info().Msgf("torrent with hash %s successfully added to client: '%s'", release.TorrentHash, c.Dc.Name) s.log.Info().Msgf("torrent with hash %s successfully added to client: '%s'", release.TorrentHash, client.Name)
return nil, nil return nil, nil
} }

View file

@ -17,40 +17,16 @@ func (s *service) radarr(ctx context.Context, action *domain.Action, release dom
// TODO validate data // TODO validate data
// get client for action client, err := s.clientSvc.GetClient(ctx, action.ClientID)
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error finding client: %v", action.ClientID) return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
} }
// return early if no client found if !client.Enabled {
if client == nil { return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
return nil, errors.New("could not find client by id: %v", action.ClientID)
} }
// initial config arr := client.Client.(radarr.Client)
cfg := radarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: s.subLogger,
}
// only set basic auth if enabled
if client.Settings.Basic.Auth {
cfg.BasicAuth = client.Settings.Basic.Auth
cfg.Username = client.Settings.Basic.Username
cfg.Password = client.Settings.Basic.Password
}
externalClientId := client.Settings.ExternalDownloadClientId
if action.ExternalDownloadClientID > 0 {
externalClientId = int(action.ExternalDownloadClientID)
}
externalClient := client.Settings.ExternalDownloadClient
if action.ExternalDownloadClient != "" {
externalClient = action.ExternalDownloadClient
}
r := radarr.Release{ r := radarr.Release{
Title: release.TorrentName, Title: release.TorrentName,
@ -59,14 +35,20 @@ func (s *service) radarr(ctx context.Context, action *domain.Action, release dom
MagnetUrl: release.MagnetURI, MagnetUrl: release.MagnetURI,
Size: int64(release.Size), Size: int64(release.Size),
Indexer: release.Indexer.GetExternalIdentifier(), Indexer: release.Indexer.GetExternalIdentifier(),
DownloadClientId: externalClientId, DownloadClientId: client.Settings.ExternalDownloadClientId,
DownloadClient: externalClient, DownloadClient: client.Settings.ExternalDownloadClient,
DownloadProtocol: release.Protocol.String(), DownloadProtocol: release.Protocol.String(),
Protocol: release.Protocol.String(), Protocol: release.Protocol.String(),
PublishDate: time.Now().Format(time.RFC3339), PublishDate: time.Now().Format(time.RFC3339),
} }
arr := radarr.New(cfg) if action.ExternalDownloadClientID > 0 {
r.DownloadClientId = int(action.ExternalDownloadClientID)
}
if action.ExternalDownloadClient != "" {
r.DownloadClient = action.ExternalDownloadClient
}
rejections, err := arr.Push(ctx, r) rejections, err := arr.Push(ctx, r)
if err != nil { if err != nil {

View file

@ -17,40 +17,16 @@ func (s *service) readarr(ctx context.Context, action *domain.Action, release do
// TODO validate data // TODO validate data
// get client for action client, err := s.clientSvc.GetClient(ctx, action.ClientID)
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "readarr could not find client: %v", action.ClientID) return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
} }
// return early if no client found if !client.Enabled {
if client == nil { return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
return nil, errors.New("no client found")
} }
// initial config arr := client.Client.(readarr.Client)
cfg := readarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: s.subLogger,
}
// only set basic auth if enabled
if client.Settings.Basic.Auth {
cfg.BasicAuth = client.Settings.Basic.Auth
cfg.Username = client.Settings.Basic.Username
cfg.Password = client.Settings.Basic.Password
}
externalClientId := client.Settings.ExternalDownloadClientId
if action.ExternalDownloadClientID > 0 {
externalClientId = int(action.ExternalDownloadClientID)
}
externalClient := client.Settings.ExternalDownloadClient
if action.ExternalDownloadClient != "" {
externalClient = action.ExternalDownloadClient
}
r := readarr.Release{ r := readarr.Release{
Title: release.TorrentName, Title: release.TorrentName,
@ -59,14 +35,20 @@ func (s *service) readarr(ctx context.Context, action *domain.Action, release do
MagnetUrl: release.MagnetURI, MagnetUrl: release.MagnetURI,
Size: int64(release.Size), Size: int64(release.Size),
Indexer: release.Indexer.GetExternalIdentifier(), Indexer: release.Indexer.GetExternalIdentifier(),
DownloadClientId: externalClientId, DownloadClientId: client.Settings.ExternalDownloadClientId,
DownloadClient: externalClient, DownloadClient: client.Settings.ExternalDownloadClient,
DownloadProtocol: release.Protocol.String(), DownloadProtocol: release.Protocol.String(),
Protocol: release.Protocol.String(), Protocol: release.Protocol.String(),
PublishDate: time.Now().Format(time.RFC3339), PublishDate: time.Now().Format(time.RFC3339),
} }
arr := readarr.New(cfg) if action.ExternalDownloadClientID > 0 {
r.DownloadClientId = int(action.ExternalDownloadClientID)
}
if action.ExternalDownloadClient != "" {
r.DownloadClient = action.ExternalDownloadClient
}
rejections, err := arr.Push(ctx, r) rejections, err := arr.Push(ctx, r)
if err != nil { if err != nil {

View file

@ -16,32 +16,19 @@ import (
func (s *service) rtorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { func (s *service) rtorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action rTorrent: %s", action.Name) s.log.Debug().Msgf("action rTorrent: %s", action.Name)
var err error client, err := s.clientSvc.GetClient(ctx, action.ClientID)
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
s.log.Error().Stack().Err(err).Msgf("error finding client: %d", action.ClientID) return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
return nil, err
} }
if client == nil { if !client.Enabled {
return nil, errors.New("could not find client by id: %d", action.ClientID) return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
} }
rt := client.Client.(*rtorrent.Client)
var rejections []string var rejections []string
// create config
cfg := rtorrent.Config{
Addr: client.Host,
TLSSkipVerify: client.TLSSkipVerify,
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
}
// create client
rt := rtorrent.NewClient(cfg)
if release.HasMagnetUri() { if release.HasMagnetUri() {
var args []*rtorrent.FieldValue var args []*rtorrent.FieldValue
@ -79,8 +66,8 @@ func (s *service) rtorrent(ctx context.Context, action *domain.Action, release d
s.log.Info().Msgf("torrent from magnet successfully added to client: '%s'", client.Name) s.log.Info().Msgf("torrent from magnet successfully added to client: '%s'", client.Name)
return nil, nil return nil, nil
}
} else {
if release.TorrentTmpFile == "" { if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFileCtx(ctx); err != nil { if err := release.DownloadTorrentFileCtx(ctx); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName) s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName)
@ -127,7 +114,6 @@ func (s *service) rtorrent(ctx context.Context, action *domain.Action, release d
} }
s.log.Info().Msgf("torrent successfully added to client: '%s'", client.Name) s.log.Info().Msgf("torrent successfully added to client: '%s'", client.Name)
}
return rejections, nil return rejections, nil
} }

View file

@ -19,7 +19,6 @@ import (
) )
func (s *service) RunAction(ctx context.Context, action *domain.Action, release *domain.Release) ([]string, error) { func (s *service) RunAction(ctx context.Context, action *domain.Action, release *domain.Release) ([]string, error) {
var ( var (
err error err error
rejections []string rejections []string
@ -33,6 +32,10 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release
} }
}() }()
if action.ClientID > 0 && action.Client != nil && !action.Client.Enabled {
return nil, errors.New("action %s client %s %s not enabled, skipping", action.Name, action.Client.Type, action.Client.Name)
}
// if set, try to resolve MagnetURI before parsing macros // if set, try to resolve MagnetURI before parsing macros
// to allow webhook and exec to get the magnet_uri // to allow webhook and exec to get the magnet_uri
if err := release.ResolveMagnetUri(ctx); err != nil { if err := release.ResolveMagnetUri(ctx); err != nil {

View file

@ -18,29 +18,16 @@ func (s *service) sabnzbd(ctx context.Context, action *domain.Action, release do
return nil, errors.New("action type: %s invalid protocol: %s", action.Type, release.Protocol) return nil, errors.New("action type: %s invalid protocol: %s", action.Type, release.Protocol)
} }
// get client for action client, err := s.clientSvc.GetClient(ctx, action.ClientID)
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sonarr could not find client: %d", action.ClientID) return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
} }
// return early if no client found if !client.Enabled {
if client == nil { return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
return nil, errors.New("no sabnzbd client found by id: %d", action.ClientID)
} }
opts := sabnzbd.Options{ sab := client.Client.(*sabnzbd.Client)
Addr: client.Host,
ApiKey: client.Settings.APIKey,
Log: nil,
}
if client.Settings.Basic.Auth {
opts.BasicUser = client.Settings.Basic.Username
opts.BasicPass = client.Settings.Basic.Password
}
sab := sabnzbd.New(opts)
ids, err := sab.AddFromUrl(ctx, sabnzbd.AddNzbRequest{Url: release.DownloadURL, Category: action.Category}) ids, err := sab.AddFromUrl(ctx, sabnzbd.AddNzbRequest{Url: release.DownloadURL, Category: action.Category})
if err != nil { if err != nil {

View file

@ -21,9 +21,10 @@ import (
type Service interface { type Service interface {
Store(ctx context.Context, action domain.Action) (*domain.Action, error) Store(ctx context.Context, action domain.Action) (*domain.Action, error)
StoreFilterActions(ctx context.Context, filterID int64, actions []*domain.Action) ([]*domain.Action, error)
List(ctx context.Context) ([]domain.Action, error) List(ctx context.Context) ([]domain.Action, error)
Get(ctx context.Context, req *domain.GetActionRequest) (*domain.Action, error) Get(ctx context.Context, req *domain.GetActionRequest) (*domain.Action, error)
FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*domain.Action, error)
Delete(ctx context.Context, req *domain.DeleteActionRequest) error Delete(ctx context.Context, req *domain.DeleteActionRequest) error
DeleteByFilterID(ctx context.Context, filterID int) error DeleteByFilterID(ctx context.Context, filterID int) error
ToggleEnabled(actionID int) error ToggleEnabled(actionID int) error
@ -63,6 +64,10 @@ func (s *service) Store(ctx context.Context, action domain.Action) (*domain.Acti
return s.repo.Store(ctx, action) return s.repo.Store(ctx, action)
} }
func (s *service) StoreFilterActions(ctx context.Context, filterID int64, actions []*domain.Action) ([]*domain.Action, error) {
return s.repo.StoreFilterActions(ctx, filterID, actions)
}
func (s *service) List(ctx context.Context) ([]domain.Action, error) { func (s *service) List(ctx context.Context) ([]domain.Action, error) {
return s.repo.List(ctx) return s.repo.List(ctx)
} }
@ -86,8 +91,8 @@ func (s *service) Get(ctx context.Context, req *domain.GetActionRequest) (*domai
return a, nil return a, nil
} }
func (s *service) FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) { func (s *service) FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*domain.Action, error) {
return s.repo.FindByFilterID(ctx, filterID, active) return s.repo.FindByFilterID(ctx, filterID, active, withClient)
} }
func (s *service) Delete(ctx context.Context, req *domain.DeleteActionRequest) error { func (s *service) Delete(ctx context.Context, req *domain.DeleteActionRequest) error {

View file

@ -17,40 +17,16 @@ func (s *service) sonarr(ctx context.Context, action *domain.Action, release dom
// TODO validate data // TODO validate data
// get client for action client, err := s.clientSvc.GetClient(ctx, action.ClientID)
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID) return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
} }
// return early if no client found if !client.Enabled {
if client == nil { return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
return nil, errors.New("no client found")
} }
// initial config arr := client.Client.(sonarr.Client)
cfg := sonarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: s.subLogger,
}
// only set basic auth if enabled
if client.Settings.Basic.Auth {
cfg.BasicAuth = client.Settings.Basic.Auth
cfg.Username = client.Settings.Basic.Username
cfg.Password = client.Settings.Basic.Password
}
externalClientId := client.Settings.ExternalDownloadClientId
if action.ExternalDownloadClientID > 0 {
externalClientId = int(action.ExternalDownloadClientID)
}
externalClient := client.Settings.ExternalDownloadClient
if action.ExternalDownloadClient != "" {
externalClient = action.ExternalDownloadClient
}
r := sonarr.Release{ r := sonarr.Release{
Title: release.TorrentName, Title: release.TorrentName,
@ -59,14 +35,20 @@ func (s *service) sonarr(ctx context.Context, action *domain.Action, release dom
MagnetUrl: release.MagnetURI, MagnetUrl: release.MagnetURI,
Size: int64(release.Size), Size: int64(release.Size),
Indexer: release.Indexer.GetExternalIdentifier(), Indexer: release.Indexer.GetExternalIdentifier(),
DownloadClientId: externalClientId, DownloadClientId: client.Settings.ExternalDownloadClientId,
DownloadClient: externalClient, DownloadClient: client.Settings.ExternalDownloadClient,
DownloadProtocol: release.Protocol.String(), DownloadProtocol: release.Protocol.String(),
Protocol: release.Protocol.String(), Protocol: release.Protocol.String(),
PublishDate: time.Now().Format(time.RFC3339), PublishDate: time.Now().Format(time.RFC3339),
} }
arr := sonarr.New(cfg) if action.ExternalDownloadClientID > 0 {
r.DownloadClientId = int(action.ExternalDownloadClientID)
}
if action.ExternalDownloadClient != "" {
r.DownloadClient = action.ExternalDownloadClient
}
rejections, err := arr.Push(ctx, r) rejections, err := arr.Push(ctx, r)
if err != nil { if err != nil {

View file

@ -6,14 +6,11 @@ package action
import ( import (
"context" "context"
"fmt" "fmt"
"net/url"
"strings" "strings"
"time" "time"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/transmission"
"github.com/hekmon/transmissionrpc/v3" "github.com/hekmon/transmissionrpc/v3"
) )
@ -28,38 +25,16 @@ var TrTrue = true
func (s *service) transmission(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { func (s *service) transmission(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action Transmission: %s", action.Name) s.log.Debug().Msgf("action Transmission: %s", action.Name)
var err error client, err := s.clientSvc.GetClient(ctx, action.ClientID)
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
s.log.Error().Stack().Err(err).Msgf("error finding client: %d", action.ClientID) return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
return nil, err
} }
if client == nil { if !client.Enabled {
return nil, errors.New("could not find client by id: %d", action.ClientID) return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
} }
scheme := "http" tbt := client.Client.(*transmissionrpc.Client)
if client.TLS {
scheme = "https"
}
u, err := url.Parse(fmt.Sprintf("%s://%s:%d/transmission/rpc", scheme, client.Host, client.Port))
if err != nil {
return nil, err
}
tbt, err := transmission.New(u, &transmission.Config{
UserAgent: "autobrr",
Username: client.Username,
Password: client.Password,
TLSSkipVerify: client.TLSSkipVerify,
})
if err != nil {
return nil, errors.Wrap(err, "error logging into client: %s", client.Host)
}
rejections, err := s.transmissionCheckRulesCanDownload(ctx, action, client, tbt) rejections, err := s.transmissionCheckRulesCanDownload(ctx, action, client, tbt)
if err != nil { if err != nil {

View file

@ -17,40 +17,16 @@ func (s *service) whisparr(ctx context.Context, action *domain.Action, release d
// TODO validate data // TODO validate data
// get client for action client, err := s.clientSvc.GetClient(ctx, action.ClientID)
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID) return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
} }
// return early if no client found if !client.Enabled {
if client == nil { return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
return nil, errors.New("could not find client by id: %v", action.ClientID)
} }
// initial config arr := client.Client.(whisparr.Client)
cfg := whisparr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: s.subLogger,
}
// only set basic auth if enabled
if client.Settings.Basic.Auth {
cfg.BasicAuth = client.Settings.Basic.Auth
cfg.Username = client.Settings.Basic.Username
cfg.Password = client.Settings.Basic.Password
}
externalClientId := client.Settings.ExternalDownloadClientId
if action.ExternalDownloadClientID > 0 {
externalClientId = int(action.ExternalDownloadClientID)
}
externalClient := client.Settings.ExternalDownloadClient
if action.ExternalDownloadClient != "" {
externalClient = action.ExternalDownloadClient
}
r := whisparr.Release{ r := whisparr.Release{
Title: release.TorrentName, Title: release.TorrentName,
@ -59,14 +35,20 @@ func (s *service) whisparr(ctx context.Context, action *domain.Action, release d
MagnetUrl: release.MagnetURI, MagnetUrl: release.MagnetURI,
Size: int64(release.Size), Size: int64(release.Size),
Indexer: release.Indexer.GetExternalIdentifier(), Indexer: release.Indexer.GetExternalIdentifier(),
DownloadClientId: externalClientId, DownloadClientId: client.Settings.ExternalDownloadClientId,
DownloadClient: externalClient, DownloadClient: client.Settings.ExternalDownloadClient,
DownloadProtocol: release.Protocol.String(), DownloadProtocol: release.Protocol.String(),
Protocol: release.Protocol.String(), Protocol: release.Protocol.String(),
PublishDate: time.Now().Format(time.RFC3339), PublishDate: time.Now().Format(time.RFC3339),
} }
arr := whisparr.New(cfg) if action.ExternalDownloadClientID > 0 {
r.DownloadClientId = int(action.ExternalDownloadClientID)
}
if action.ExternalDownloadClient != "" {
r.DownloadClient = action.ExternalDownloadClient
}
rejections, err := arr.Push(ctx, r) rejections, err := arr.Push(ctx, r)
if err != nil { if err != nil {

View file

@ -30,7 +30,266 @@ func NewActionRepo(log logger.Logger, db *DB, clientRepo domain.DownloadClientRe
} }
} }
func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) { func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*domain.Action, error) {
if withClient {
return r.findByFilterIDWithClient(ctx, filterID, active)
}
return r.findByFilterID(ctx, filterID, active)
}
func (r *ActionRepo) findByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) {
queryBuilder := r.db.squirrel.
Select(
"a.id",
"a.name",
"a.type",
"a.enabled",
"a.exec_cmd",
"a.exec_args",
"a.watch_folder",
"a.category",
"a.tags",
"a.label",
"a.save_path",
"a.paused",
"a.ignore_rules",
"a.first_last_piece_prio",
"a.skip_hash_check",
"a.content_layout",
"a.priority",
"a.limit_download_speed",
"a.limit_upload_speed",
"a.limit_ratio",
"a.limit_seed_time",
"a.reannounce_skip",
"a.reannounce_delete",
"a.reannounce_interval",
"a.reannounce_max_attempts",
"a.webhook_host",
"a.webhook_type",
"a.webhook_method",
"a.webhook_data",
"a.external_client_id",
"a.external_client",
"a.client_id",
).
From("action a").
Where(sq.Eq{"a.filter_id": filterID})
if active != nil {
queryBuilder = queryBuilder.Where(sq.Eq{"enabled": *active})
}
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, errors.Wrap(err, "error building query")
}
rows, err := r.db.handler.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
defer rows.Close()
actions := make([]*domain.Action, 0)
for rows.Next() {
var a domain.Action
var execCmd, execArgs, watchFolder, category, tags, label, savePath, contentLayout, priorityLayout, webhookHost, webhookType, webhookMethod, webhookData, externalClient sql.NullString
var limitUl, limitDl, limitSeedTime sql.NullInt64
var limitRatio sql.NullFloat64
var externalClientID, clientID sql.NullInt32
var paused, ignoreRules sql.NullBool
if err := rows.Scan(&a.ID, &a.Name, &a.Type, &a.Enabled, &execCmd, &execArgs, &watchFolder, &category, &tags, &label, &savePath, &paused, &ignoreRules, &a.FirstLastPiecePrio, &a.SkipHashCheck, &contentLayout, &priorityLayout, &limitDl, &limitUl, &limitRatio, &limitSeedTime, &a.ReAnnounceSkip, &a.ReAnnounceDelete, &a.ReAnnounceInterval, &a.ReAnnounceMaxAttempts, &webhookHost, &webhookType, &webhookMethod, &webhookData, &externalClientID, &externalClient, &clientID); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
a.ExecCmd = execCmd.String
a.ExecArgs = execArgs.String
a.WatchFolder = watchFolder.String
a.Category = category.String
a.Tags = tags.String
a.Label = label.String
a.SavePath = savePath.String
a.Paused = paused.Bool
a.IgnoreRules = ignoreRules.Bool
a.ContentLayout = domain.ActionContentLayout(contentLayout.String)
a.PriorityLayout = domain.PriorityLayout(priorityLayout.String)
a.LimitDownloadSpeed = limitDl.Int64
a.LimitUploadSpeed = limitUl.Int64
a.LimitRatio = limitRatio.Float64
a.LimitSeedTime = limitSeedTime.Int64
a.WebhookHost = webhookHost.String
a.WebhookType = webhookType.String
a.WebhookMethod = webhookMethod.String
a.WebhookData = webhookData.String
a.ExternalDownloadClientID = externalClientID.Int32
a.ExternalDownloadClient = externalClient.String
a.ClientID = clientID.Int32
actions = append(actions, &a)
}
if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "row error")
}
return actions, nil
}
func (r *ActionRepo) findByFilterIDWithClient(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) {
queryBuilder := r.db.squirrel.
Select(
"a.id",
"a.name",
"a.type",
"a.enabled",
"a.exec_cmd",
"a.exec_args",
"a.watch_folder",
"a.category",
"a.tags",
"a.label",
"a.save_path",
"a.paused",
"a.ignore_rules",
"a.first_last_piece_prio",
"a.skip_hash_check",
"a.content_layout",
"a.priority",
"a.limit_download_speed",
"a.limit_upload_speed",
"a.limit_ratio",
"a.limit_seed_time",
"a.reannounce_skip",
"a.reannounce_delete",
"a.reannounce_interval",
"a.reannounce_max_attempts",
"a.webhook_host",
"a.webhook_type",
"a.webhook_method",
"a.webhook_data",
"a.external_client_id",
"a.external_client",
"a.client_id",
"c.id",
"c.name",
"c.type",
"c.enabled",
"c.host",
"c.port",
"c.tls",
"c.tls_skip_verify",
"c.username",
"c.password",
"c.settings",
).
From("action a").
Join("client c ON a.client_id = c.id").
Where(sq.Eq{"a.filter_id": filterID})
if active != nil {
queryBuilder = queryBuilder.Where(sq.Eq{"enabled": *active})
}
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, errors.Wrap(err, "error building query")
}
rows, err := r.db.handler.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
defer rows.Close()
actions := make([]*domain.Action, 0)
for rows.Next() {
var a domain.Action
var c domain.DownloadClient
var execCmd, execArgs, watchFolder, category, tags, label, savePath, contentLayout, priorityLayout, webhookHost, webhookType, webhookMethod, webhookData, externalClient sql.NullString
var limitUl, limitDl, limitSeedTime sql.NullInt64
var limitRatio sql.NullFloat64
var externalClientID, clientID sql.NullInt32
var paused, ignoreRules sql.NullBool
var clientClientId, clientPort sql.Null[int32]
var clientName, clientType, clientHost, clientUsername, clientPassword, clientSettings sql.Null[string]
var clientEnabled, clientTLS, clientTLSSkip sql.Null[bool]
if err := rows.Scan(&a.ID, &a.Name, &a.Type, &a.Enabled, &execCmd, &execArgs, &watchFolder, &category, &tags, &label, &savePath, &paused, &ignoreRules, &a.FirstLastPiecePrio, &a.SkipHashCheck, &contentLayout, &priorityLayout, &limitDl, &limitUl, &limitRatio, &limitSeedTime, &a.ReAnnounceSkip, &a.ReAnnounceDelete, &a.ReAnnounceInterval, &a.ReAnnounceMaxAttempts, &webhookHost, &webhookType, &webhookMethod, &webhookData, &externalClientID, &externalClient, &clientID, &clientClientId, &clientName, &clientType, &clientEnabled, &clientHost, &clientPort, &clientTLS, &clientTLSSkip, &clientUsername, &clientPassword, &clientSettings); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
a.ExecCmd = execCmd.String
a.ExecArgs = execArgs.String
a.WatchFolder = watchFolder.String
a.Category = category.String
a.Tags = tags.String
a.Label = label.String
a.SavePath = savePath.String
a.Paused = paused.Bool
a.IgnoreRules = ignoreRules.Bool
a.ContentLayout = domain.ActionContentLayout(contentLayout.String)
a.PriorityLayout = domain.PriorityLayout(priorityLayout.String)
a.LimitDownloadSpeed = limitDl.Int64
a.LimitUploadSpeed = limitUl.Int64
a.LimitRatio = limitRatio.Float64
a.LimitSeedTime = limitSeedTime.Int64
a.WebhookHost = webhookHost.String
a.WebhookType = webhookType.String
a.WebhookMethod = webhookMethod.String
a.WebhookData = webhookData.String
a.ExternalDownloadClientID = externalClientID.Int32
a.ExternalDownloadClient = externalClient.String
a.ClientID = clientID.Int32
c.ID = clientClientId.V
c.Name = clientName.V
c.Type = domain.DownloadClientType(clientType.V)
c.Enabled = clientEnabled.V
c.Host = clientHost.V
c.Port = int(clientPort.V)
c.TLS = clientTLS.V
c.TLSSkipVerify = clientTLSSkip.V
c.Username = clientUsername.V
c.Password = clientPassword.V
//c.Settings = clientSettings.String
if a.ClientID > 0 {
if clientSettings.Valid {
if err := json.Unmarshal([]byte(clientSettings.V), &c.Settings); err != nil {
return nil, errors.Wrap(err, "could not unmarshal download client settings: %v", clientSettings.V)
}
}
a.Client = &c
}
actions = append(actions, &a)
}
if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "row error")
}
return actions, nil
}
func (r *ActionRepo) FindByFilterIDTx(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) {
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) tx, err := r.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil { if err != nil {
return nil, err return nil, err
@ -38,7 +297,7 @@ func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *b
defer tx.Rollback() defer tx.Rollback()
actions, err := r.findByFilterID(ctx, tx, filterID, active) actions, err := r.findByFilterIDTx(ctx, tx, filterID, active)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -59,7 +318,7 @@ func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *b
return actions, nil return actions, nil
} }
func (r *ActionRepo) findByFilterID(ctx context.Context, tx *Tx, filterID int, active *bool) ([]*domain.Action, error) { func (r *ActionRepo) findByFilterIDTx(ctx context.Context, tx *Tx, filterID int, active *bool) ([]*domain.Action, error) {
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select( Select(
"id", "id",

View file

@ -62,9 +62,10 @@ func TestActionRepo_Store(t *testing.T) {
t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -73,7 +74,7 @@ func TestActionRepo_Store(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID) mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
// Actual test for Store // Actual test for Store
@ -84,7 +85,7 @@ func TestActionRepo_Store(t *testing.T) {
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("Store_Succeeds_With_Missing_or_empty_fields [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Store_Succeeds_With_Missing_or_empty_fields [%s]", dbType), func(t *testing.T) {
@ -125,9 +126,10 @@ func TestActionRepo_StoreFilterActions(t *testing.T) {
t.Run(fmt.Sprintf("StoreFilterActions_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("StoreFilterActions_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -136,7 +138,7 @@ func TestActionRepo_StoreFilterActions(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID) mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
// Actual test for StoreFilterActions // Actual test for StoreFilterActions
@ -148,7 +150,7 @@ func TestActionRepo_StoreFilterActions(t *testing.T) {
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("StoreFilterActions_Fails_Invalid_FilterID [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("StoreFilterActions_Fails_Invalid_FilterID [%s]", dbType), func(t *testing.T) {
@ -203,9 +205,10 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
t.Run(fmt.Sprintf("FindByFilterID_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindByFilterID_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -214,13 +217,13 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID) mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
assert.NoError(t, err) assert.NoError(t, err)
// Actual test for FindByFilterID // Actual test for FindByFilterID
actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil) actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil, false)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, actions) assert.NotNil(t, actions)
assert.Equal(t, 1, len(actions)) assert.Equal(t, 1, len(actions))
@ -228,7 +231,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("FindByFilterID_Fails_No_Actions [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindByFilterID_Fails_No_Actions [%s]", dbType), func(t *testing.T) {
@ -241,7 +244,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
// Actual test for FindByFilterID // Actual test for FindByFilterID
actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil) actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil, false)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 0, len(actions)) assert.Equal(t, 0, len(actions))
@ -250,7 +253,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
}) })
t.Run(fmt.Sprintf("FindByFilterID_Succeeds_With_Invalid_FilterID [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindByFilterID_Succeeds_With_Invalid_FilterID [%s]", dbType), func(t *testing.T) {
actions, err := repo.FindByFilterID(context.Background(), 9999, nil) // 9999 is an invalid filter ID actions, err := repo.FindByFilterID(context.Background(), 9999, nil, false) // 9999 is an invalid filter ID
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, actions) assert.NotNil(t, actions)
assert.Equal(t, 0, len(actions)) assert.Equal(t, 0, len(actions))
@ -260,7 +263,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel() defer cancel()
actions, err := repo.FindByFilterID(ctx, 1, nil) actions, err := repo.FindByFilterID(ctx, 1, nil, false)
assert.Error(t, err) assert.Error(t, err)
assert.Nil(t, actions) assert.Nil(t, actions)
}) })
@ -277,9 +280,10 @@ func TestActionRepo_List(t *testing.T) {
t.Run(fmt.Sprintf("List_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("List_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -288,7 +292,7 @@ func TestActionRepo_List(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID) mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
assert.NoError(t, err) assert.NoError(t, err)
@ -302,7 +306,7 @@ func TestActionRepo_List(t *testing.T) {
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("List_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("List_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) {
@ -326,9 +330,10 @@ func TestActionRepo_Get(t *testing.T) {
t.Run(fmt.Sprintf("Get_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Get_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -337,7 +342,7 @@ func TestActionRepo_Get(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID) mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
assert.NoError(t, err) assert.NoError(t, err)
@ -351,7 +356,7 @@ func TestActionRepo_Get(t *testing.T) {
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("Get_Fails_No_Record [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Get_Fails_No_Record [%s]", dbType), func(t *testing.T) {
@ -382,9 +387,10 @@ func TestActionRepo_Delete(t *testing.T) {
t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -393,7 +399,7 @@ func TestActionRepo_Delete(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID) mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
assert.NoError(t, err) assert.NoError(t, err)
@ -411,7 +417,7 @@ func TestActionRepo_Delete(t *testing.T) {
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("Delete_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Delete_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) {
@ -435,9 +441,10 @@ func TestActionRepo_DeleteByFilterID(t *testing.T) {
t.Run(fmt.Sprintf("DeleteByFilterID_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("DeleteByFilterID_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -446,7 +453,7 @@ func TestActionRepo_DeleteByFilterID(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID) mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
assert.NoError(t, err) assert.NoError(t, err)
@ -463,7 +470,7 @@ func TestActionRepo_DeleteByFilterID(t *testing.T) {
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("DeleteByFilterID_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("DeleteByFilterID_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) {
@ -486,9 +493,10 @@ func TestActionRepo_ToggleEnabled(t *testing.T) {
t.Run(fmt.Sprintf("ToggleEnabled_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("ToggleEnabled_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -497,7 +505,7 @@ func TestActionRepo_ToggleEnabled(t *testing.T) {
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID) mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
mockData.Enabled = false mockData.Enabled = false
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
@ -515,7 +523,7 @@ func TestActionRepo_ToggleEnabled(t *testing.T) {
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("ToggleEnabled_Fails_No_Record [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("ToggleEnabled_Fails_No_Record [%s]", dbType), func(t *testing.T) {

View file

@ -7,7 +7,6 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"sync"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/internal/logger"
@ -20,53 +19,16 @@ import (
type DownloadClientRepo struct { type DownloadClientRepo struct {
log zerolog.Logger log zerolog.Logger
db *DB db *DB
cache *clientCache
}
type clientCache struct {
mu sync.RWMutex
clients map[int]*domain.DownloadClient
}
func NewClientCache() *clientCache {
return &clientCache{
clients: make(map[int]*domain.DownloadClient, 0),
}
}
func (c *clientCache) Set(id int, client *domain.DownloadClient) {
c.mu.Lock()
c.clients[id] = client
c.mu.Unlock()
}
func (c *clientCache) Get(id int) *domain.DownloadClient {
c.mu.RLock()
defer c.mu.RUnlock()
v, ok := c.clients[id]
if ok {
return v
}
return nil
}
func (c *clientCache) Pop(id int) {
c.mu.Lock()
delete(c.clients, id)
c.mu.Unlock()
} }
func NewDownloadClientRepo(log logger.Logger, db *DB) domain.DownloadClientRepo { func NewDownloadClientRepo(log logger.Logger, db *DB) domain.DownloadClientRepo {
return &DownloadClientRepo{ return &DownloadClientRepo{
log: log.With().Str("repo", "action").Logger(), log: log.With().Str("repo", "action").Logger(),
db: db, db: db,
cache: NewClientCache(),
} }
} }
func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, error) { func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, error) {
clients := make([]domain.DownloadClient, 0)
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select( Select(
"id", "id",
@ -100,6 +62,8 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient,
} }
}(rows) }(rows)
clients := make([]domain.DownloadClient, 0)
for rows.Next() { for rows.Next() {
var f domain.DownloadClient var f domain.DownloadClient
var settingsJsonStr string var settingsJsonStr string
@ -124,12 +88,6 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient,
} }
func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) { func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
// get client from cache
c := r.cache.Get(int(id))
if c != nil {
return c, nil
}
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select( Select(
"id", "id",
@ -153,7 +111,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
} }
row := r.db.handler.QueryRowContext(ctx, query, args...) row := r.db.handler.QueryRowContext(ctx, query, args...)
if err != nil { if err := row.Err(); err != nil {
return nil, errors.Wrap(err, "error executing query") return nil, errors.Wrap(err, "error executing query")
} }
@ -177,9 +135,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
return &client, nil return &client, nil
} }
func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { func (r *DownloadClientRepo) Store(ctx context.Context, client *domain.DownloadClient) error {
var err error
settings := domain.DownloadClientSettings{ settings := domain.DownloadClientSettings{
APIKey: client.Settings.APIKey, APIKey: client.Settings.APIKey,
Basic: client.Settings.Basic, Basic: client.Settings.Basic,
@ -190,7 +146,7 @@ func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadCl
settingsJson, err := json.Marshal(&settings) settingsJson, err := json.Marshal(&settings)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshal download client settings %+v", settings) return errors.Wrap(err, "error marshal download client settings %+v", settings)
} }
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
@ -204,22 +160,17 @@ func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadCl
err = queryBuilder.QueryRowContext(ctx).Scan(&retID) err = queryBuilder.QueryRowContext(ctx).Scan(&retID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error executing query") return errors.Wrap(err, "error executing query")
} }
client.ID = retID client.ID = int32(retID)
r.log.Debug().Msgf("download_client.store: %d", client.ID) r.log.Debug().Msgf("download_client.store: %d", client.ID)
// save to cache return nil
r.cache.Set(client.ID, &client)
return &client, nil
} }
func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { func (r *DownloadClientRepo) Update(ctx context.Context, client *domain.DownloadClient) error {
var err error
settings := domain.DownloadClientSettings{ settings := domain.DownloadClientSettings{
APIKey: client.Settings.APIKey, APIKey: client.Settings.APIKey,
Basic: client.Settings.Basic, Basic: client.Settings.Basic,
@ -230,7 +181,7 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC
settingsJson, err := json.Marshal(&settings) settingsJson, err := json.Marshal(&settings)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error marshal download client settings %+v", settings) return errors.Wrap(err, "error marshal download client settings %+v", settings)
} }
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
@ -249,32 +200,29 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC
query, args, err := queryBuilder.ToSql() query, args, err := queryBuilder.ToSql()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error building query") return errors.Wrap(err, "error building query")
} }
result, err := r.db.handler.ExecContext(ctx, query, args...) result, err := r.db.handler.ExecContext(ctx, query, args...)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error executing query") return errors.Wrap(err, "error executing query")
} }
rowsAffected, err := result.RowsAffected() rowsAffected, err := result.RowsAffected()
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error getting rows affected") return errors.Wrap(err, "error getting rows affected")
} }
if rowsAffected == 0 { if rowsAffected == 0 {
return nil, errors.New("no rows updated") return errors.New("no rows updated")
} }
r.log.Debug().Msgf("download_client.update: %d", client.ID) r.log.Debug().Msgf("download_client.update: %d", client.ID)
// save to cache return nil
r.cache.Set(client.ID, &client)
return &client, nil
} }
func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error { func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int32) error {
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{}) tx, err := r.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil { if err != nil {
return err return err
@ -311,10 +259,11 @@ func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error {
} }
r.log.Debug().Msgf("delete download client: %d", clientID) r.log.Debug().Msgf("delete download client: %d", clientID)
return nil return nil
} }
func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) error { func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int32) error {
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Delete("client"). Delete("client").
Where(sq.Eq{"id": clientID}) Where(sq.Eq{"id": clientID})
@ -329,9 +278,6 @@ func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) e
return errors.Wrap(err, "error executing query") return errors.Wrap(err, "error executing query")
} }
// remove from cache
r.cache.Pop(clientID)
rows, _ := res.RowsAffected() rows, _ := res.RowsAffected()
if rows == 0 { if rows == 0 {
return errors.New("no rows affected") return errors.New("no rows affected")
@ -342,9 +288,7 @@ func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) e
return nil return nil
} }
func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx, clientID int) error { func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx, clientID int32) error {
var err error
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Update("action"). Update("action").
Set("enabled", false). Set("enabled", false).
@ -355,12 +299,14 @@ func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx,
// return values // return values
var filterID int var filterID int
if err = queryBuilder.QueryRowContext(ctx).Scan(&filterID); err != nil { err := queryBuilder.QueryRowContext(ctx).Scan(&filterID)
if err != nil {
// this will throw when the client is not connected to any actions // this will throw when the client is not connected to any actions
// it is not an error in this case // it is not an error in this case
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil return nil
} }
return errors.Wrap(err, "error executing query") return errors.Wrap(err, "error executing query")
} }

View file

@ -8,10 +8,12 @@ package database
import ( import (
"context" "context"
"fmt" "fmt"
"github.com/autobrr/autobrr/internal/domain"
"github.com/stretchr/testify/assert"
"testing" "testing"
"time" "time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/stretchr/testify/assert"
) )
func getMockDownloadClient() domain.DownloadClient { func getMockDownloadClient() domain.DownloadClient {
@ -54,13 +56,14 @@ func TestDownloadClientRepo_List(t *testing.T) {
t.Run(fmt.Sprintf("List_Succeeds_With_No_Filters [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("List_Succeeds_With_No_Filters [%s]", dbType), func(t *testing.T) {
// Insert mock data // Insert mock data
createdClient, err := repo.Store(context.Background(), mockData) mock := &mockData
err := repo.Store(context.Background(), mock)
clients, err := repo.List(context.Background()) clients, err := repo.List(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
assert.NotEmpty(t, clients) assert.NotEmpty(t, clients)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("List_Succeeds_With_Empty_Database [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("List_Succeeds_With_Empty_Database [%s]", dbType), func(t *testing.T) {
@ -77,32 +80,34 @@ func TestDownloadClientRepo_List(t *testing.T) {
}) })
t.Run(fmt.Sprintf("List_Succeeds_With_Data_Integrity [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("List_Succeeds_With_Data_Integrity [%s]", dbType), func(t *testing.T) {
createdClient, err := repo.Store(context.Background(), mockData) mock := &mockData
err := repo.Store(context.Background(), mock)
clients, err := repo.List(context.Background()) clients, err := repo.List(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 1, len(clients)) assert.Equal(t, 1, len(clients))
assert.Equal(t, createdClient.Name, clients[0].Name) assert.Equal(t, mock.Name, clients[0].Name)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("List_Succeeds_With_Boundary_Value_For_Port [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("List_Succeeds_With_Boundary_Value_For_Port [%s]", dbType), func(t *testing.T) {
mockData.Port = 65535 mock := &mockData
createdClient, err := repo.Store(context.Background(), mockData) mock.Port = 65535
err := repo.Store(context.Background(), mock)
clients, err := repo.List(context.Background()) clients, err := repo.List(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, 65535, clients[0].Port) assert.Equal(t, 65535, clients[0].Port)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("List_Succeeds_With_Boolean_Flags_Set_To_False [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("List_Succeeds_With_Boolean_Flags_Set_To_False [%s]", dbType), func(t *testing.T) {
mockData.Enabled = false mockData.Enabled = false
mockData.TLS = false mockData.TLS = false
mockData.TLSSkipVerify = false mockData.TLSSkipVerify = false
createdClient, err := repo.Store(context.Background(), mockData) err := repo.Store(context.Background(), &mockData)
clients, err := repo.List(context.Background()) clients, err := repo.List(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, false, clients[0].Enabled) assert.Equal(t, false, clients[0].Enabled)
@ -110,18 +115,18 @@ func TestDownloadClientRepo_List(t *testing.T) {
assert.Equal(t, false, clients[0].TLSSkipVerify) assert.Equal(t, false, clients[0].TLSSkipVerify)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mockData.ID)
}) })
t.Run(fmt.Sprintf("List_Succeeds_With_Special_Characters_In_Name [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("List_Succeeds_With_Special_Characters_In_Name [%s]", dbType), func(t *testing.T) {
mockData.Name = "Special$Name" mockData.Name = "Special$Name"
createdClient, err := repo.Store(context.Background(), mockData) err := repo.Store(context.Background(), &mockData)
clients, err := repo.List(context.Background()) clients, err := repo.List(context.Background())
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "Special$Name", clients[0].Name) assert.Equal(t, "Special$Name", clients[0].Name)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mockData.ID)
}) })
} }
} }
@ -133,13 +138,14 @@ func TestDownloadClientRepo_FindByID(t *testing.T) {
mockData := getMockDownloadClient() mockData := getMockDownloadClient()
t.Run(fmt.Sprintf("FindByID_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindByID_Succeeds [%s]", dbType), func(t *testing.T) {
createdClient, _ := repo.Store(context.Background(), mockData) mock := &mockData
foundClient, err := repo.FindByID(context.Background(), int32(createdClient.ID)) _ = repo.Store(context.Background(), mock)
foundClient, err := repo.FindByID(context.Background(), mock.ID)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, foundClient) assert.NotNil(t, foundClient)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("FindByID_Fails_With_Nonexistent_ID [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindByID_Fails_With_Nonexistent_ID [%s]", dbType), func(t *testing.T) {
@ -156,40 +162,44 @@ func TestDownloadClientRepo_FindByID(t *testing.T) {
t.Run(fmt.Sprintf("FindByID_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindByID_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel() defer cancel()
_, err := repo.FindByID(ctx, 1) _, err := repo.FindByID(ctx, 1)
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run(fmt.Sprintf("FindByID_Fails_After_Client_Deleted [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindByID_Fails_After_Client_Deleted [%s]", dbType), func(t *testing.T) {
createdClient, _ := repo.Store(context.Background(), mockData) mock := &mockData
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Store(context.Background(), mock)
_, err := repo.FindByID(context.Background(), int32(createdClient.ID)) _ = repo.Delete(context.Background(), mock.ID)
_, err := repo.FindByID(context.Background(), mock.ID)
assert.Error(t, err) assert.Error(t, err)
assert.Equal(t, "no client configured", err.Error()) assert.Equal(t, "no client configured", err.Error())
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("FindByID_Succeeds_With_Data_Integrity [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindByID_Succeeds_With_Data_Integrity [%s]", dbType), func(t *testing.T) {
createdClient, _ := repo.Store(context.Background(), mockData) mock := &mockData
foundClient, err := repo.FindByID(context.Background(), int32(createdClient.ID)) _ = repo.Store(context.Background(), mock)
foundClient, err := repo.FindByID(context.Background(), mock.ID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, createdClient.Name, foundClient.Name) assert.Equal(t, mock.Name, foundClient.Name)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mock.ID)
}) })
t.Run(fmt.Sprintf("FindByID_Succeeds_From_Cache [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindByID_Succeeds_From_Cache [%s]", dbType), func(t *testing.T) {
createdClient, _ := repo.Store(context.Background(), mockData) mock := &mockData
foundClient1, _ := repo.FindByID(context.Background(), int32(createdClient.ID)) _ = repo.Store(context.Background(), mock)
foundClient2, err := repo.FindByID(context.Background(), int32(createdClient.ID)) foundClient1, _ := repo.FindByID(context.Background(), mock.ID)
foundClient2, err := repo.FindByID(context.Background(), mock.ID)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, foundClient1, foundClient2) assert.Equal(t, foundClient1, foundClient2)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mock.ID)
}) })
} }
} }
@ -201,17 +211,17 @@ func TestDownloadClientRepo_Store(t *testing.T) {
t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) {
mockData := getMockDownloadClient() mockData := getMockDownloadClient()
createdClient, err := repo.Store(context.Background(), mockData) err := repo.Store(context.Background(), &mockData)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mockData)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mockData.ID)
}) })
//TODO: Is this okay? Should we be able to store a client with no name (empty string)? //TODO: Is this okay? Should we be able to store a client with no name (empty string)?
t.Run(fmt.Sprintf("Store_Succeeds?_With_Missing_Required_Fields [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Store_Succeeds?_With_Missing_Required_Fields [%s]", dbType), func(t *testing.T) {
badMockData := domain.DownloadClient{ badMockData := &domain.DownloadClient{
Type: "", Type: "",
Enabled: false, Enabled: false,
Host: "", Host: "",
@ -222,30 +232,30 @@ func TestDownloadClientRepo_Store(t *testing.T) {
Password: "", Password: "",
Settings: domain.DownloadClientSettings{}, Settings: domain.DownloadClientSettings{},
} }
createdClient, err := repo.Store(context.Background(), badMockData) err := repo.Store(context.Background(), badMockData)
assert.NoError(t, err) assert.NoError(t, err)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), badMockData.ID)
}) })
t.Run(fmt.Sprintf("Store_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Store_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) {
mockData := getMockDownloadClient() mockData := getMockDownloadClient()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel() defer cancel()
_, err := repo.Store(ctx, mockData) err := repo.Store(ctx, &mockData)
assert.Error(t, err) assert.Error(t, err)
}) })
t.Run(fmt.Sprintf("Store_Succeeds_And_Caches [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Store_Succeeds_And_Caches [%s]", dbType), func(t *testing.T) {
mockData := getMockDownloadClient() mockData := getMockDownloadClient()
createdClient, _ := repo.Store(context.Background(), mockData) _ = repo.Store(context.Background(), &mockData)
cachedClient, _ := repo.FindByID(context.Background(), int32(createdClient.ID)) cachedClient, _ := repo.FindByID(context.Background(), mockData.ID)
assert.Equal(t, createdClient, cachedClient) assert.Equal(t, &mockData, cachedClient)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mockData.ID)
}) })
} }
} }
@ -258,22 +268,22 @@ func TestDownloadClientRepo_Update(t *testing.T) {
t.Run(fmt.Sprintf("Update_Successfully_Updates_Record [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Update_Successfully_Updates_Record [%s]", dbType), func(t *testing.T) {
mockClient := getMockDownloadClient() mockClient := getMockDownloadClient()
createdClient, _ := repo.Store(context.Background(), mockClient) _ = repo.Store(context.Background(), &mockClient)
createdClient.Name = "updatedName" mockClient.Name = "updatedName"
updatedClient, err := repo.Update(context.Background(), *createdClient) err := repo.Update(context.Background(), &mockClient)
assert.NoError(t, err) assert.NoError(t, err)
assert.Equal(t, "updatedName", updatedClient.Name) assert.Equal(t, "updatedName", mockClient.Name)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), updatedClient.ID) _ = repo.Delete(context.Background(), mockClient.ID)
}) })
t.Run(fmt.Sprintf("Update_Fails_With_Missing_ID [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Update_Fails_With_Missing_ID [%s]", dbType), func(t *testing.T) {
badMockData := getMockDownloadClient() badMockData := getMockDownloadClient()
badMockData.ID = 0 badMockData.ID = 0
_, err := repo.Update(context.Background(), badMockData) err := repo.Update(context.Background(), &badMockData)
assert.Error(t, err) assert.Error(t, err)
@ -283,7 +293,7 @@ func TestDownloadClientRepo_Update(t *testing.T) {
badMockData := getMockDownloadClient() badMockData := getMockDownloadClient()
badMockData.ID = 9999 badMockData.ID = 9999
_, err := repo.Update(context.Background(), badMockData) err := repo.Update(context.Background(), &badMockData)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -291,7 +301,7 @@ func TestDownloadClientRepo_Update(t *testing.T) {
t.Run(fmt.Sprintf("Update_Fails_With_Missing_Required_Fields [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Update_Fails_With_Missing_Required_Fields [%s]", dbType), func(t *testing.T) {
badMockData := domain.DownloadClient{} badMockData := domain.DownloadClient{}
_, err := repo.Update(context.Background(), badMockData) err := repo.Update(context.Background(), &badMockData)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -305,13 +315,13 @@ func TestDownloadClientRepo_Delete(t *testing.T) {
t.Run(fmt.Sprintf("Delete_Successfully_Deletes_Client [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Delete_Successfully_Deletes_Client [%s]", dbType), func(t *testing.T) {
mockClient := getMockDownloadClient() mockClient := getMockDownloadClient()
createdClient, _ := repo.Store(context.Background(), mockClient) _ = repo.Store(context.Background(), &mockClient)
err := repo.Delete(context.Background(), createdClient.ID) err := repo.Delete(context.Background(), mockClient.ID)
assert.NoError(t, err) assert.NoError(t, err)
// Verify client was deleted // Verify client was deleted
_, err = repo.FindByID(context.Background(), int32(createdClient.ID)) _, err = repo.FindByID(context.Background(), mockClient.ID)
assert.Error(t, err) assert.Error(t, err)
}) })
@ -322,16 +332,16 @@ func TestDownloadClientRepo_Delete(t *testing.T) {
t.Run(fmt.Sprintf("Delete_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Delete_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) {
mockClient := getMockDownloadClient() mockClient := getMockDownloadClient()
createdClient, _ := repo.Store(context.Background(), mockClient) _ = repo.Store(context.Background(), &mockClient)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel() defer cancel()
err := repo.Delete(ctx, createdClient.ID) err := repo.Delete(ctx, mockClient.ID)
assert.Error(t, err) assert.Error(t, err)
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), createdClient.ID) _ = repo.Delete(context.Background(), mockClient.ID)
}) })
} }
} }

View file

@ -255,13 +255,12 @@ func (r *FilterRepo) FindByID(ctx context.Context, filterID int) (*domain.Filter
} }
row := r.db.handler.QueryRowContext(ctx, query, args...) row := r.db.handler.QueryRowContext(ctx, query, args...)
if err := row.Err(); err != nil {
if row.Err() != nil { if errors.Is(err, sql.ErrNoRows) {
if errors.Is(row.Err(), sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound return nil, domain.ErrRecordNotFound
} }
return nil, errors.Wrap(row.Err(), "error row") return nil, errors.Wrap(err, "error row")
} }
var f domain.Filter var f domain.Filter

View file

@ -791,12 +791,14 @@ func TestFilterRepo_GetDownloadsByFilterId(t *testing.T) {
err := repo.Store(context.Background(), mockData) err := repo.Store(context.Background(), mockData)
assert.NoError(t, err) assert.NoError(t, err)
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mockClient := getMockDownloadClient()
err = downloadClientRepo.Store(context.Background(), &mockClient)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mockClient)
mockAction.FilterID = mockData.ID mockAction.FilterID = mockData.ID
mockAction.ClientID = int32(createdClient.ID) mockAction.ClientID = mockClient.ID
action, err := actionRepo.Store(context.Background(), mockAction) action, err := actionRepo.Store(context.Background(), mockAction)
@ -827,7 +829,7 @@ func TestFilterRepo_GetDownloadsByFilterId(t *testing.T) {
// Cleanup // Cleanup
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: action.ID}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: action.ID})
_ = repo.Delete(context.Background(), mockData.ID) _ = repo.Delete(context.Background(), mockData.ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mockClient.ID)
_ = releaseRepo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = releaseRepo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
}) })

View file

@ -89,9 +89,10 @@ func TestReleaseRepo_Store(t *testing.T) {
t.Run(fmt.Sprintf("StoreReleaseActionStatus_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("StoreReleaseActionStatus_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -101,7 +102,7 @@ func TestReleaseRepo_Store(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID) actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
// Execute // Execute
@ -124,7 +125,7 @@ func TestReleaseRepo_Store(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
} }
} }
@ -144,9 +145,10 @@ func TestReleaseRepo_StoreReleaseActionStatus(t *testing.T) {
t.Run(fmt.Sprintf("StoreReleaseActionStatus_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("StoreReleaseActionStatus_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -156,7 +158,7 @@ func TestReleaseRepo_StoreReleaseActionStatus(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID) actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
// Execute // Execute
@ -179,7 +181,7 @@ func TestReleaseRepo_StoreReleaseActionStatus(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
} }
} }
@ -199,9 +201,10 @@ func TestReleaseRepo_Find(t *testing.T) {
t.Run(fmt.Sprintf("FindReleases_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindReleases_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -211,7 +214,7 @@ func TestReleaseRepo_Find(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID) actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
// Execute // Execute
@ -238,7 +241,7 @@ func TestReleaseRepo_Find(t *testing.T) {
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
} }
} }
@ -258,9 +261,10 @@ func TestReleaseRepo_FindRecent(t *testing.T) {
t.Run(fmt.Sprintf("FindRecent_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("FindRecent_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -270,7 +274,7 @@ func TestReleaseRepo_FindRecent(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID) actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
// Execute // Execute
@ -286,7 +290,7 @@ func TestReleaseRepo_FindRecent(t *testing.T) {
// Cleanup // Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
} }
} }
@ -306,9 +310,10 @@ func TestReleaseRepo_GetIndexerOptions(t *testing.T) {
t.Run(fmt.Sprintf("GetIndexerOptions_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("GetIndexerOptions_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -318,7 +323,7 @@ func TestReleaseRepo_GetIndexerOptions(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID) actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData) err = repo.Store(context.Background(), mockData)
@ -344,7 +349,7 @@ func TestReleaseRepo_GetIndexerOptions(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
} }
} }
@ -364,9 +369,10 @@ func TestReleaseRepo_GetActionStatusByReleaseID(t *testing.T) {
t.Run(fmt.Sprintf("GetActionStatusByReleaseID_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("GetActionStatusByReleaseID_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -376,7 +382,7 @@ func TestReleaseRepo_GetActionStatusByReleaseID(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID) actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData) err = repo.Store(context.Background(), mockData)
@ -403,7 +409,7 @@ func TestReleaseRepo_GetActionStatusByReleaseID(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
} }
} }
@ -423,9 +429,10 @@ func TestReleaseRepo_Get(t *testing.T) {
t.Run(fmt.Sprintf("Get_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Get_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -435,7 +442,7 @@ func TestReleaseRepo_Get(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID) actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData) err = repo.Store(context.Background(), mockData)
@ -462,7 +469,7 @@ func TestReleaseRepo_Get(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
} }
} }
@ -482,9 +489,10 @@ func TestReleaseRepo_Stats(t *testing.T) {
t.Run(fmt.Sprintf("Stats_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Stats_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -494,7 +502,7 @@ func TestReleaseRepo_Stats(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID) actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData) err = repo.Store(context.Background(), mockData)
@ -521,7 +529,7 @@ func TestReleaseRepo_Stats(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
} }
} }
@ -541,9 +549,10 @@ func TestReleaseRepo_Delete(t *testing.T) {
t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -553,7 +562,7 @@ func TestReleaseRepo_Delete(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID) actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData) err = repo.Store(context.Background(), mockData)
@ -577,7 +586,7 @@ func TestReleaseRepo_Delete(t *testing.T) {
// Cleanup // Cleanup
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
} }
} }
@ -597,9 +606,10 @@ func TestReleaseRepo_CheckSmartEpisodeCanDownloadShow(t *testing.T) {
t.Run(fmt.Sprintf("Check_Smart_Episode_Can_Download [%s]", dbType), func(t *testing.T) { t.Run(fmt.Sprintf("Check_Smart_Episode_Can_Download [%s]", dbType), func(t *testing.T) {
// Setup // Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err) assert.NoError(t, err)
assert.NotNil(t, createdClient) assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter()) err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err) assert.NoError(t, err)
@ -609,7 +619,7 @@ func TestReleaseRepo_CheckSmartEpisodeCanDownloadShow(t *testing.T) {
assert.NotNil(t, createdFilters) assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID) actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData) err = repo.Store(context.Background(), mockData)
@ -644,7 +654,7 @@ func TestReleaseRepo_CheckSmartEpisodeCanDownloadShow(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID) _ = downloadClientRepo.Delete(context.Background(), mock.ID)
}) })
} }
} }

View file

@ -14,7 +14,7 @@ import (
type ActionRepo interface { type ActionRepo interface {
Store(ctx context.Context, action Action) (*Action, error) Store(ctx context.Context, action Action) (*Action, error)
StoreFilterActions(ctx context.Context, filterID int64, actions []*Action) ([]*Action, error) StoreFilterActions(ctx context.Context, filterID int64, actions []*Action) ([]*Action, error)
FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*Action, error) FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*Action, error)
List(ctx context.Context) ([]Action, error) List(ctx context.Context) ([]Action, error)
Get(ctx context.Context, req *GetActionRequest) (*Action, error) Get(ctx context.Context, req *GetActionRequest) (*Action, error)
Delete(ctx context.Context, req *DeleteActionRequest) error Delete(ctx context.Context, req *DeleteActionRequest) error

View file

@ -9,20 +9,18 @@ import (
"net/url" "net/url"
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/go-qbittorrent"
) )
type DownloadClientRepo interface { type DownloadClientRepo interface {
List(ctx context.Context) ([]DownloadClient, error) List(ctx context.Context) ([]DownloadClient, error)
FindByID(ctx context.Context, id int32) (*DownloadClient, error) FindByID(ctx context.Context, id int32) (*DownloadClient, error)
Store(ctx context.Context, client DownloadClient) (*DownloadClient, error) Store(ctx context.Context, client *DownloadClient) error
Update(ctx context.Context, client DownloadClient) (*DownloadClient, error) Update(ctx context.Context, client *DownloadClient) error
Delete(ctx context.Context, clientID int) error Delete(ctx context.Context, clientID int32) error
} }
type DownloadClient struct { type DownloadClient struct {
ID int `json:"id"` ID int32 `json:"id"`
Name string `json:"name"` Name string `json:"name"`
Type DownloadClientType `json:"type"` Type DownloadClientType `json:"type"`
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
@ -33,11 +31,9 @@ type DownloadClient struct {
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Settings DownloadClientSettings `json:"settings,omitempty"` Settings DownloadClientSettings `json:"settings,omitempty"`
}
type DownloadClientCached struct { // cached http client
Dc *DownloadClient Client any
Qbt *qbittorrent.Client
} }
type DownloadClientSettings struct { type DownloadClientSettings struct {

View file

@ -11,7 +11,7 @@ import (
func TestDownloadClient_qbitBuildLegacyHost(t *testing.T) { func TestDownloadClient_qbitBuildLegacyHost(t *testing.T) {
type fields struct { type fields struct {
ID int ID int32
Name string Name string
Type DownloadClientType Type DownloadClientType
Enabled bool Enabled bool

View file

@ -0,0 +1,48 @@
package download_client
import (
"sync"
"github.com/autobrr/autobrr/internal/domain"
)
type ClientCacheStore interface {
Set(id int32, client *domain.DownloadClient)
Get(id int32) *domain.DownloadClient
Pop(id int32)
}
type ClientCache struct {
mu sync.RWMutex
clients map[int32]*domain.DownloadClient
}
func NewClientCache() *ClientCache {
return &ClientCache{
clients: make(map[int32]*domain.DownloadClient),
}
}
func (c *ClientCache) Set(id int32, client *domain.DownloadClient) {
if client != nil {
c.mu.Lock()
c.clients[id] = client
c.mu.Unlock()
}
}
func (c *ClientCache) Get(id int32) *domain.DownloadClient {
c.mu.RLock()
defer c.mu.RUnlock()
v, ok := c.clients[id]
if ok {
return v
}
return nil
}
func (c *ClientCache) Pop(id int32) {
c.mu.Lock()
delete(c.clients, id)
c.mu.Unlock()
}

View file

@ -5,13 +5,27 @@ package download_client
import ( import (
"context" "context"
"fmt"
"log" "log"
"net/url"
"sync" "sync"
"time"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/lidarr"
"github.com/autobrr/autobrr/pkg/porla"
"github.com/autobrr/autobrr/pkg/radarr"
"github.com/autobrr/autobrr/pkg/readarr"
"github.com/autobrr/autobrr/pkg/sabnzbd"
"github.com/autobrr/autobrr/pkg/sonarr"
"github.com/autobrr/autobrr/pkg/transmission"
"github.com/autobrr/autobrr/pkg/whisparr"
"github.com/autobrr/go-deluge"
"github.com/autobrr/go-qbittorrent" "github.com/autobrr/go-qbittorrent"
"github.com/autobrr/go-rtorrent"
"github.com/dcarbone/zadapters/zstdlog" "github.com/dcarbone/zadapters/zstdlog"
"github.com/rs/zerolog" "github.com/rs/zerolog"
) )
@ -19,12 +33,12 @@ import (
type Service interface { type Service interface {
List(ctx context.Context) ([]domain.DownloadClient, error) List(ctx context.Context) ([]domain.DownloadClient, error)
FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error)
Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) Store(ctx context.Context, client *domain.DownloadClient) error
Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) Update(ctx context.Context, client *domain.DownloadClient) error
Delete(ctx context.Context, clientID int) error Delete(ctx context.Context, clientID int32) error
Test(ctx context.Context, client domain.DownloadClient) error Test(ctx context.Context, client domain.DownloadClient) error
GetCachedClient(ctx context.Context, clientId int32) *domain.DownloadClientCached GetClient(ctx context.Context, clientId int32) (*domain.DownloadClient, error)
} }
type service struct { type service struct {
@ -32,7 +46,7 @@ type service struct {
repo domain.DownloadClientRepo repo domain.DownloadClientRepo
subLogger *log.Logger subLogger *log.Logger
qbitClients map[int32]*domain.DownloadClientCached cache *ClientCache
m sync.RWMutex m sync.RWMutex
} }
@ -41,7 +55,7 @@ func NewService(log logger.Logger, repo domain.DownloadClientRepo) Service {
log: log.With().Str("module", "download_client").Logger(), log: log.With().Str("module", "download_client").Logger(),
repo: repo, repo: repo,
qbitClients: map[int32]*domain.DownloadClientCached{}, cache: NewClientCache(),
m: sync.RWMutex{}, m: sync.RWMutex{},
} }
@ -61,6 +75,13 @@ func (s *service) List(ctx context.Context) ([]domain.DownloadClient, error) {
} }
func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) { func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
client := s.cache.Get(id)
if client != nil {
return client, nil
}
s.log.Trace().Msgf("cache miss for client id %d, continue to repo lookup", id)
client, err := s.repo.FindByID(ctx, id) client, err := s.repo.FindByID(ctx, id)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("could not find download client by id: %v", id) s.log.Error().Err(err).Msgf("could not find download client by id: %v", id)
@ -70,53 +91,49 @@ func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClien
return client, nil return client, nil
} }
func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { func (s *service) Store(ctx context.Context, client *domain.DownloadClient) error {
// basic validation of client // basic validation of client
if err := client.Validate(); err != nil { if err := client.Validate(); err != nil {
return nil, err return err
} }
// store // store
c, err := s.repo.Store(ctx, client) err := s.repo.Store(ctx, client)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("could not store download client: %+v", client) s.log.Error().Err(err).Msgf("could not store download client: %+v", client)
return nil, err return err
} }
return c, err s.cache.Set(client.ID, client)
return err
} }
func (s *service) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { func (s *service) Update(ctx context.Context, client *domain.DownloadClient) error {
// basic validation of client // basic validation of client
if err := client.Validate(); err != nil { if err := client.Validate(); err != nil {
return nil, err return err
} }
// update // update
c, err := s.repo.Update(ctx, client) err := s.repo.Update(ctx, client)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("could not update download client: %+v", client) s.log.Error().Err(err).Msgf("could not update download client: %+v", client)
return nil, err return err
} }
if client.Type == domain.DownloadClientTypeQbittorrent { s.cache.Set(client.ID, client)
s.m.Lock()
delete(s.qbitClients, int32(client.ID)) return err
s.m.Unlock()
} }
return c, err func (s *service) Delete(ctx context.Context, clientID int32) error {
}
func (s *service) Delete(ctx context.Context, clientID int) error {
if err := s.repo.Delete(ctx, clientID); err != nil { if err := s.repo.Delete(ctx, clientID); err != nil {
s.log.Error().Err(err).Msgf("could not delete download client: %v", clientID) s.log.Error().Err(err).Msgf("could not delete download client: %v", clientID)
return err return err
} }
s.m.Lock() s.cache.Pop(clientID)
delete(s.qbitClients, int32(clientID))
s.m.Unlock()
return nil return nil
} }
@ -136,53 +153,165 @@ func (s *service) Test(ctx context.Context, client domain.DownloadClient) error
return nil return nil
} }
func (s *service) GetCachedClient(ctx context.Context, clientId int32) *domain.DownloadClientCached { // GetClient get client from cache or repo and attach downloadClient implementation
func (s *service) GetClient(ctx context.Context, clientId int32) (*domain.DownloadClient, error) {
// check if client exists in cache l := s.log.With().Str("cache", "download-client").Logger()
s.m.RLock()
cached, ok := s.qbitClients[clientId]
s.m.RUnlock()
if ok {
return cached
}
// get client for action
client, err := s.FindByID(ctx, clientId)
if err != nil {
return nil
}
client := s.cache.Get(clientId)
if client == nil { if client == nil {
return nil l.Trace().Msgf("cache miss for client id %d, continue to repo lookup", clientId)
var err error
client, err = s.repo.FindByID(ctx, clientId)
if err != nil {
return nil, errors.Wrap(err, "could not find client repo.FindByID")
}
} }
qbtSettings := qbittorrent.Config{ // if we have the client return it
if client.Client != nil {
l.Trace().Msgf("cache hit for client id %d %s", clientId, client.Name)
return client, nil
}
l.Trace().Msgf("init cache client id %d %s", clientId, client.Name)
switch client.Type {
case domain.DownloadClientTypeQbittorrent:
client.Client = qbittorrent.NewClient(qbittorrent.Config{
Host: client.BuildLegacyHost(), Host: client.BuildLegacyHost(),
Username: client.Username, Username: client.Username,
Password: client.Password, Password: client.Password,
TLSSkipVerify: client.TLSSkipVerify, TLSSkipVerify: client.TLSSkipVerify,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "qBittorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
})
case domain.DownloadClientTypePorla:
client.Client = porla.NewClient(porla.Config{
Hostname: client.Host,
AuthToken: client.Settings.APIKey,
TLSSkipVerify: client.TLSSkipVerify,
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Porla").Str("client", client.Name).Logger(), zerolog.TraceLevel),
})
case domain.DownloadClientTypeDelugeV1:
client.Client = deluge.NewV1(deluge.Settings{
Hostname: client.Host,
Port: uint(client.Port),
Login: client.Username,
Password: client.Password,
DebugServerResponses: true,
ReadWriteTimeout: time.Second * 60,
})
case domain.DownloadClientTypeDelugeV2:
client.Client = deluge.NewV2(deluge.Settings{
Hostname: client.Host,
Port: uint(client.Port),
Login: client.Username,
Password: client.Password,
DebugServerResponses: true,
ReadWriteTimeout: time.Second * 60,
})
case domain.DownloadClientTypeTransmission:
scheme := "http"
if client.TLS {
scheme = "https"
} }
// setup sub logger adapter which is compatible with *log.Logger transmissionURL, err := url.Parse(fmt.Sprintf("%s://%s:%d/transmission/rpc", scheme, client.Host, client.Port))
qbtSettings.Log = zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "qBittorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel) if err != nil {
return nil, errors.Wrap(err, "could not parse transmission url")
// only set basic auth if enabled
if client.Settings.Basic.Auth {
qbtSettings.BasicUser = client.Settings.Basic.Username
qbtSettings.BasicPass = client.Settings.Basic.Password
} }
qc := &domain.DownloadClientCached{ tbt, err := transmission.New(transmissionURL, &transmission.Config{
Dc: client, UserAgent: "autobrr",
Qbt: qbittorrent.NewClient(qbtSettings), Username: client.Username,
Password: client.Password,
TLSSkipVerify: client.TLSSkipVerify,
})
if err != nil {
return nil, errors.Wrap(err, "error logging into transmission client: %s", client.Host)
}
client.Client = tbt
case domain.DownloadClientTypeRTorrent:
client.Client = rtorrent.NewClient(rtorrent.Config{
Addr: client.Host,
TLSSkipVerify: client.TLSSkipVerify,
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "rTorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel),
})
case domain.DownloadClientTypeLidarr:
client.Client = lidarr.New(lidarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Lidarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicAuth: client.Settings.Basic.Auth,
Username: client.Settings.Basic.Username,
Password: client.Settings.Basic.Password,
})
case domain.DownloadClientTypeRadarr:
client.Client = radarr.New(radarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Radarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicAuth: client.Settings.Basic.Auth,
Username: client.Settings.Basic.Username,
Password: client.Settings.Basic.Password,
})
case domain.DownloadClientTypeReadarr:
client.Client = readarr.New(readarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Readarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicAuth: client.Settings.Basic.Auth,
Username: client.Settings.Basic.Username,
Password: client.Settings.Basic.Password,
})
case domain.DownloadClientTypeSonarr:
client.Client = sonarr.New(sonarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Sonarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicAuth: client.Settings.Basic.Auth,
Username: client.Settings.Basic.Username,
Password: client.Settings.Basic.Password,
})
case domain.DownloadClientTypeWhisparr:
client.Client = whisparr.New(whisparr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Whisparr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicAuth: client.Settings.Basic.Auth,
Username: client.Settings.Basic.Username,
Password: client.Settings.Basic.Password,
})
case domain.DownloadClientTypeSabnzbd:
client.Client = sabnzbd.New(sabnzbd.Options{
Addr: client.Host,
ApiKey: client.Settings.APIKey,
Log: nil,
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
})
} }
cached = qc l.Trace().Msgf("set cache client id %d %s", clientId, client.Name)
s.m.Lock() s.cache.Set(clientId, client)
s.qbitClients[clientId] = cached
s.m.Unlock()
return cached return client, nil
} }

View file

@ -7,6 +7,7 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"github.com/autobrr/autobrr/internal/action"
"io" "io"
"net/http" "net/http"
"os" "os"
@ -48,7 +49,7 @@ type Service interface {
type service struct { type service struct {
log zerolog.Logger log zerolog.Logger
repo domain.FilterRepo repo domain.FilterRepo
actionRepo domain.ActionRepo actionService action.Service
releaseRepo domain.ReleaseRepo releaseRepo domain.ReleaseRepo
indexerSvc indexer.Service indexerSvc indexer.Service
apiService indexer.APIService apiService indexer.APIService
@ -56,12 +57,12 @@ type service struct {
httpClient *http.Client httpClient *http.Client
} }
func NewService(log logger.Logger, repo domain.FilterRepo, actionRepo domain.ActionRepo, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service) Service { func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Service, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service) Service {
return &service{ return &service{
log: log.With().Str("module", "filter").Logger(), log: log.With().Str("module", "filter").Logger(),
repo: repo, repo: repo,
actionRepo: actionRepo,
releaseRepo: releaseRepo, releaseRepo: releaseRepo,
actionService: actionSvc,
apiService: apiService, apiService: apiService,
indexerSvc: indexerSvc, indexerSvc: indexerSvc,
httpClient: &http.Client{ httpClient: &http.Client{
@ -130,7 +131,7 @@ func (s *service) FindByID(ctx context.Context, filterID int) (*domain.Filter, e
} }
filter.External = externalFilters filter.External = externalFilters
actions, err := s.actionRepo.FindByFilterID(ctx, filter.ID, nil) actions, err := s.actionService.FindByFilterID(ctx, filter.ID, nil, false)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("could not find filter actions for filter id: %v", filter.ID) s.log.Error().Err(err).Msgf("could not find filter actions for filter id: %v", filter.ID)
} }
@ -222,7 +223,7 @@ func (s *service) Update(ctx context.Context, filter *domain.Filter) error {
} }
// take care of filter actions // take care of filter actions
actions, err := s.actionRepo.StoreFilterActions(ctx, int64(filter.ID), filter.Actions) actions, err := s.actionService.StoreFilterActions(ctx, int64(filter.ID), filter.Actions)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("could not store filter actions: %s", filter.Name) s.log.Error().Err(err).Msgf("could not store filter actions: %s", filter.Name)
return err return err
@ -267,7 +268,7 @@ func (s *service) UpdatePartial(ctx context.Context, filter domain.FilterUpdate)
if filter.Actions != nil { if filter.Actions != nil {
// take care of filter actions // take care of filter actions
if _, err := s.actionRepo.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil { if _, err := s.actionService.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil {
s.log.Error().Err(err).Msgf("could not store filter actions: %v", filter.ID) s.log.Error().Err(err).Msgf("could not store filter actions: %v", filter.ID)
return err return err
} }
@ -308,7 +309,7 @@ func (s *service) Duplicate(ctx context.Context, filterID int) (*domain.Filter,
} }
// take care of filter actions // take care of filter actions
if _, err := s.actionRepo.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil { if _, err := s.actionService.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil {
s.log.Error().Err(err).Msgf("could not store filter actions: %s", filter.Name) s.log.Error().Err(err).Msgf("could not store filter actions: %s", filter.Name)
return nil, err return nil, err
} }
@ -340,7 +341,7 @@ func (s *service) Delete(ctx context.Context, filterID int) error {
} }
// take care of filter actions // take care of filter actions
if err := s.actionRepo.DeleteByFilterID(ctx, filterID); err != nil { if err := s.actionService.DeleteByFilterID(ctx, filterID); err != nil {
s.log.Error().Err(err).Msg("could not delete filter actions") s.log.Error().Err(err).Msg("could not delete filter actions")
return err return err
} }

View file

@ -17,9 +17,9 @@ import (
type downloadClientService interface { type downloadClientService interface {
List(ctx context.Context) ([]domain.DownloadClient, error) List(ctx context.Context) ([]domain.DownloadClient, error)
Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) Store(ctx context.Context, client *domain.DownloadClient) error
Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) Update(ctx context.Context, client *domain.DownloadClient) error
Delete(ctx context.Context, clientID int) error Delete(ctx context.Context, clientID int32) error
Test(ctx context.Context, client domain.DownloadClient) error Test(ctx context.Context, client domain.DownloadClient) error
} }
@ -56,20 +56,20 @@ func (h downloadClientHandler) listDownloadClients(w http.ResponseWriter, r *htt
} }
func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) { func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) {
var data domain.DownloadClient var data *domain.DownloadClient
if err := json.NewDecoder(r.Body).Decode(&data); err != nil { if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.Error(w, err) h.encoder.Error(w, err)
return return
} }
client, err := h.service.Store(r.Context(), data) err := h.service.Store(r.Context(), data)
if err != nil { if err != nil {
h.encoder.Error(w, err) h.encoder.Error(w, err)
return return
} }
h.encoder.StatusResponse(w, http.StatusCreated, client) h.encoder.StatusResponse(w, http.StatusCreated, data)
} }
func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) { func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) {
@ -89,20 +89,20 @@ func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) {
} }
func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) { func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) {
var data domain.DownloadClient var data *domain.DownloadClient
if err := json.NewDecoder(r.Body).Decode(&data); err != nil { if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.Error(w, err) h.encoder.Error(w, err)
return return
} }
client, err := h.service.Update(r.Context(), data) err := h.service.Update(r.Context(), data)
if err != nil { if err != nil {
h.encoder.Error(w, err) h.encoder.Error(w, err)
return return
} }
h.encoder.StatusResponse(w, http.StatusCreated, client) h.encoder.StatusResponse(w, http.StatusCreated, data)
} }
func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) { func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) {
@ -113,13 +113,13 @@ func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) {
return return
} }
id, err := strconv.Atoi(clientID) id, err := strconv.ParseInt(clientID, 10, 32)
if err != nil { if err != nil {
h.encoder.Error(w, err) h.encoder.Error(w, err)
return return
} }
if err = h.service.Delete(r.Context(), id); err != nil { if err = h.service.Delete(r.Context(), int32(id)); err != nil {
h.encoder.Error(w, err) h.encoder.Error(w, err)
return return
} }

View file

@ -221,7 +221,7 @@ func (s *service) processFilters(ctx context.Context, filters []*domain.Filter,
// found matching filter, lets find the filter actions and attach // found matching filter, lets find the filter actions and attach
active := true active := true
actions, err := s.actionSvc.FindByFilterID(ctx, f.ID, &active) actions, err := s.actionSvc.FindByFilterID(ctx, f.ID, &active, false)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("release.Process: error finding actions for filter: %s", f.Name) s.log.Error().Err(err).Msgf("release.Process: error finding actions for filter: %s", f.Name)
return err return err

View file

@ -17,7 +17,7 @@ type Config struct {
Username string Username string
Password string Password string
TLSSkipVerify bool TLSSkipVerify bool
Timeout int Timeout time.Duration
} }
func New(endpoint *url.URL, cfg *Config) (*transmissionrpc.Client, error) { func New(endpoint *url.URL, cfg *Config) (*transmissionrpc.Client, error) {