fix(actions): reject if client is disabled (#1626)

* fix(actions): error on disabled client

* fix(actions): sql scan args

* refactor: download client cache for actions

* fix: tests client store

* fix: tests client store and int conversion

* fix: tests revert findbyid ctx timeout

* fix: tests row.err

* feat: add logging to download client cache
This commit is contained in:
ze0s 2024-08-27 19:45:06 +02:00 committed by GitHub
parent 77e1c2c305
commit 861f30c144
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
30 changed files with 928 additions and 680 deletions

View file

@ -117,7 +117,7 @@ func main() {
downloadClientService = download_client.NewService(log, downloadClientRepo)
actionService = action.NewService(log, actionRepo, downloadClientService, bus)
indexerService = indexer.NewService(log, cfg.Config, indexerRepo, releaseRepo, indexerAPIService, schedulingService)
filterService = filter.NewService(log, filterRepo, actionRepo, releaseRepo, indexerAPIService, indexerService)
filterService = filter.NewService(log, filterRepo, actionService, releaseRepo, indexerAPIService, indexerService)
releaseService = release.NewService(log, releaseRepo, actionService, filterService, indexerService)
ircService = irc.NewService(log, serverEvents, ircRepo, releaseService, indexerService, notificationService)
feedService = feed.NewService(log, feedRepo, feedCacheRepo, releaseService, schedulingService)

View file

@ -7,7 +7,6 @@ import (
"context"
"encoding/base64"
"os"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
@ -20,20 +19,18 @@ func (s *service) deluge(ctx context.Context, action *domain.Action, release dom
var err error
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
s.log.Error().Stack().Err(err).Msgf("error finding client: %d", action.ClientID)
return nil, err
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
if client == nil {
return nil, errors.New("could not find client by id: %d", action.ClientID)
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
var rejections []string
switch client.Type {
switch action.Client.Type {
case "DELUGE_V1":
rejections, err = s.delugeV1(ctx, client, action, release)
@ -90,27 +87,18 @@ func (s *service) delugeCheckRulesCanDownload(ctx context.Context, del deluge.De
}
func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) {
settings := deluge.Settings{
Hostname: client.Host,
Port: uint(client.Port),
Login: client.Username,
Password: client.Password,
DebugServerResponses: true,
ReadWriteTimeout: time.Second * 30,
}
del := deluge.NewV1(settings)
downloadClient := client.Client.(*deluge.Client)
// perform connection to Deluge server
err := del.Connect(ctx)
err := downloadClient.Connect(ctx)
if err != nil {
return nil, errors.Wrap(err, "could not connect to client %s at %s", client.Name, client.Host)
}
defer del.Close()
defer downloadClient.Close()
// perform connection to Deluge server
rejections, err := s.delugeCheckRulesCanDownload(ctx, del, client, action)
rejections, err := s.delugeCheckRulesCanDownload(ctx, downloadClient, client, action)
if err != nil {
s.log.Error().Err(err).Msgf("error checking client rules: %s", action.Name)
return nil, err
@ -127,13 +115,13 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a
s.log.Trace().Msgf("action Deluge options: %+v", options)
torrentHash, err := del.AddTorrentMagnet(ctx, release.MagnetURI, &options)
torrentHash, err := downloadClient.AddTorrentMagnet(ctx, release.MagnetURI, &options)
if err != nil {
return nil, errors.Wrap(err, "could not add torrent magnet %s to client: %s", release.MagnetURI, client.Name)
}
if action.Label != "" {
labelPluginActive, err := del.LabelPlugin(ctx)
labelPluginActive, err := downloadClient.LabelPlugin(ctx)
if err != nil {
return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name)
}
@ -176,13 +164,13 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a
s.log.Trace().Msgf("action Deluge options: %+v", options)
torrentHash, err := del.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options)
torrentHash, err := downloadClient.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options)
if err != nil {
return nil, errors.Wrap(err, "could not add torrent %v to client: %v", release.TorrentTmpFile, client.Name)
}
if action.Label != "" {
labelPluginActive, err := del.LabelPlugin(ctx)
labelPluginActive, err := downloadClient.LabelPlugin(ctx)
if err != nil {
return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name)
}
@ -203,27 +191,18 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a
}
func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) {
settings := deluge.Settings{
Hostname: client.Host,
Port: uint(client.Port),
Login: client.Username,
Password: client.Password,
DebugServerResponses: true,
ReadWriteTimeout: time.Second * 30,
}
del := deluge.NewV2(settings)
downloadClient := client.Client.(*deluge.ClientV2)
// perform connection to Deluge server
err := del.Connect(ctx)
err := downloadClient.Connect(ctx)
if err != nil {
return nil, errors.Wrap(err, "could not connect to client %s at %s", client.Name, client.Host)
}
defer del.Close()
defer downloadClient.Close()
// perform connection to Deluge server
rejections, err := s.delugeCheckRulesCanDownload(ctx, del, client, action)
rejections, err := s.delugeCheckRulesCanDownload(ctx, downloadClient, client, action)
if err != nil {
s.log.Error().Err(err).Msgf("error checking client rules: %s", action.Name)
return nil, err
@ -240,13 +219,13 @@ func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, a
s.log.Trace().Msgf("action Deluge options: %+v", options)
torrentHash, err := del.AddTorrentMagnet(ctx, release.MagnetURI, &options)
torrentHash, err := downloadClient.AddTorrentMagnet(ctx, release.MagnetURI, &options)
if err != nil {
return nil, errors.Wrap(err, "could not add torrent magnet %s to client: %s", release.MagnetURI, client.Name)
}
if action.Label != "" {
labelPluginActive, err := del.LabelPlugin(ctx)
labelPluginActive, err := downloadClient.LabelPlugin(ctx)
if err != nil {
return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name)
}
@ -290,13 +269,13 @@ func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, a
s.log.Trace().Msgf("action Deluge options: %+v", options)
torrentHash, err := del.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options)
torrentHash, err := downloadClient.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options)
if err != nil {
return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.TorrentTmpFile, client.Name)
}
if action.Label != "" {
labelPluginActive, err := del.LabelPlugin(ctx)
labelPluginActive, err := downloadClient.LabelPlugin(ctx)
if err != nil {
return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name)
}

View file

@ -17,41 +17,16 @@ func (s *service) lidarr(ctx context.Context, action *domain.Action, release dom
// TODO validate data
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
s.log.Error().Err(err).Msgf("lidarr: error finding client: %v", action.ClientID)
return nil, err
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
// return early if no client found
if client == nil {
return nil, errors.New("could not find client by id: %v", action.ClientID)
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
// initial config
cfg := lidarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: s.subLogger,
}
// only set basic auth if enabled
if client.Settings.Basic.Auth {
cfg.BasicAuth = client.Settings.Basic.Auth
cfg.Username = client.Settings.Basic.Username
cfg.Password = client.Settings.Basic.Password
}
externalClientId := client.Settings.ExternalDownloadClientId
if action.ExternalDownloadClientID > 0 {
externalClientId = int(action.ExternalDownloadClientID)
}
externalClient := client.Settings.ExternalDownloadClient
if action.ExternalDownloadClient != "" {
externalClient = action.ExternalDownloadClient
}
arr := client.Client.(lidarr.Client)
r := lidarr.Release{
Title: release.TorrentName,
@ -60,14 +35,20 @@ func (s *service) lidarr(ctx context.Context, action *domain.Action, release dom
MagnetUrl: release.MagnetURI,
Size: int64(release.Size),
Indexer: release.Indexer.GetExternalIdentifier(),
DownloadClientId: externalClientId,
DownloadClient: externalClient,
DownloadClientId: client.Settings.ExternalDownloadClientId,
DownloadClient: client.Settings.ExternalDownloadClient,
DownloadProtocol: release.Protocol.String(),
Protocol: release.Protocol.String(),
PublishDate: time.Now().Format(time.RFC3339),
}
arr := lidarr.New(cfg)
if action.ExternalDownloadClientID > 0 {
r.DownloadClientId = int(action.ExternalDownloadClientID)
}
if action.ExternalDownloadClient != "" {
r.DownloadClient = action.ExternalDownloadClient
}
rejections, err := arr.Push(ctx, r)
if err != nil {

View file

@ -13,34 +13,21 @@ import (
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/porla"
"github.com/dcarbone/zadapters/zstdlog"
"github.com/rs/zerolog"
)
func (s *service) porla(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action Porla: %s", action.Name)
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
return nil, errors.Wrap(err, "error finding client: %d", action.ClientID)
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
if client == nil {
return nil, errors.New("could not find client by id: %d", action.ClientID)
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
porlaSettings := porla.Config{
Hostname: client.Host,
AuthToken: client.Settings.APIKey,
TLSSkipVerify: client.TLSSkipVerify,
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
}
porlaSettings.Log = zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Porla").Str("client", client.Name).Logger(), zerolog.TraceLevel)
prl := porla.NewClient(porlaSettings)
prl := client.Client.(*porla.Client)
rejections, err := s.porlaCheckRulesCanDownload(ctx, action, client, prl)
if err != nil {

View file

@ -17,11 +17,20 @@ import (
func (s *service) qbittorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action qBittorrent: %s", action.Name)
c := s.clientSvc.GetCachedClient(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
if c.Dc.Settings.Rules.Enabled && !action.IgnoreRules {
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
qbtClient := client.Client.(*qbittorrent.Client)
if client.Settings.Rules.Enabled && !action.IgnoreRules {
// check for active downloads and other rules
rejections, err := s.qbittorrentCheckRulesCanDownload(ctx, action, c.Dc.Settings.Rules, c.Qbt)
rejections, err := s.qbittorrentCheckRulesCanDownload(ctx, action, client.Settings.Rules, qbtClient)
if err != nil {
return nil, errors.Wrap(err, "error checking client rules: %s", action.Name)
}
@ -39,11 +48,11 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas
s.log.Trace().Msgf("action qBittorrent options: %+v", options)
if err = c.Qbt.AddTorrentFromUrlCtx(ctx, release.MagnetURI, options); err != nil {
return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.MagnetURI, c.Dc.Name)
if err = qbtClient.AddTorrentFromUrlCtx(ctx, release.MagnetURI, options); err != nil {
return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.MagnetURI, client.Name)
}
s.log.Info().Msgf("torrent from magnet successfully added to client: '%s'", c.Dc.Name)
s.log.Info().Msgf("torrent from magnet successfully added to client: '%s'", client.Name)
return nil, nil
}
@ -61,37 +70,37 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas
s.log.Trace().Msgf("action qBittorrent options: %+v", options)
if err = c.Qbt.AddTorrentFromFileCtx(ctx, release.TorrentTmpFile, options); err != nil {
return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.TorrentTmpFile, c.Dc.Name)
if err = qbtClient.AddTorrentFromFileCtx(ctx, release.TorrentTmpFile, options); err != nil {
return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.TorrentTmpFile, client.Name)
}
if release.TorrentHash != "" {
// check if torrent queueing is enabled if priority is set
switch action.PriorityLayout {
case domain.PriorityLayoutMax, domain.PriorityLayoutMin:
prefs, err := c.Qbt.GetAppPreferencesCtx(ctx)
prefs, err := qbtClient.GetAppPreferencesCtx(ctx)
if err != nil {
return nil, errors.Wrap(err, "could not get application preferences from client: '%s'", c.Dc.Name)
return nil, errors.Wrap(err, "could not get application preferences from client: '%s'", client.Name)
}
// enable queueing if it's disabled
if !prefs.QueueingEnabled {
if err := c.Qbt.SetPreferencesQueueingEnabled(true); err != nil {
if err := qbtClient.SetPreferencesQueueingEnabled(true); err != nil {
return nil, errors.Wrap(err, "could not enable torrent queueing")
}
s.log.Trace().Msgf("torrent queueing was disabled, now enabled in client: '%s'", c.Dc.Name)
s.log.Trace().Msgf("torrent queueing was disabled, now enabled in client: '%s'", client.Name)
}
// set priority if queueing is enabled
if action.PriorityLayout == domain.PriorityLayoutMax {
if err := c.Qbt.SetMaxPriorityCtx(ctx, []string{release.TorrentHash}); err != nil {
if err := qbtClient.SetMaxPriorityCtx(ctx, []string{release.TorrentHash}); err != nil {
return nil, errors.Wrap(err, "could not set torrent %s to max priority", release.TorrentHash)
}
s.log.Debug().Msgf("torrent with hash %s set to max priority in client: '%s'", release.TorrentHash, c.Dc.Name)
s.log.Debug().Msgf("torrent with hash %s set to max priority in client: '%s'", release.TorrentHash, client.Name)
} else { // domain.PriorityLayoutMin
if err := c.Qbt.SetMinPriorityCtx(ctx, []string{release.TorrentHash}); err != nil {
if err := qbtClient.SetMinPriorityCtx(ctx, []string{release.TorrentHash}); err != nil {
return nil, errors.Wrap(err, "could not set torrent %s to min priority", release.TorrentHash)
}
s.log.Debug().Msgf("torrent with hash %s set to min priority in client: '%s'", release.TorrentHash, c.Dc.Name)
s.log.Debug().Msgf("torrent with hash %s set to min priority in client: '%s'", release.TorrentHash, client.Name)
}
case domain.PriorityLayoutDefault:
@ -111,7 +120,7 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas
DeleteOnFailure: action.ReAnnounceDelete,
}
if err := c.Qbt.ReannounceTorrentWithRetry(ctx, release.TorrentHash, &opts); err != nil {
if err := qbtClient.ReannounceTorrentWithRetry(ctx, release.TorrentHash, &opts); err != nil {
if errors.Is(err, qbittorrent.ErrReannounceTookTooLong) {
return []string{fmt.Sprintf("re-announce took too long for hash: %s", release.TorrentHash)}, nil
}
@ -120,7 +129,7 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas
}
}
s.log.Info().Msgf("torrent with hash %s successfully added to client: '%s'", release.TorrentHash, c.Dc.Name)
s.log.Info().Msgf("torrent with hash %s successfully added to client: '%s'", release.TorrentHash, client.Name)
return nil, nil
}

View file

@ -17,40 +17,16 @@ func (s *service) radarr(ctx context.Context, action *domain.Action, release dom
// TODO validate data
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
return nil, errors.Wrap(err, "error finding client: %v", action.ClientID)
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
// return early if no client found
if client == nil {
return nil, errors.New("could not find client by id: %v", action.ClientID)
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
// initial config
cfg := radarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: s.subLogger,
}
// only set basic auth if enabled
if client.Settings.Basic.Auth {
cfg.BasicAuth = client.Settings.Basic.Auth
cfg.Username = client.Settings.Basic.Username
cfg.Password = client.Settings.Basic.Password
}
externalClientId := client.Settings.ExternalDownloadClientId
if action.ExternalDownloadClientID > 0 {
externalClientId = int(action.ExternalDownloadClientID)
}
externalClient := client.Settings.ExternalDownloadClient
if action.ExternalDownloadClient != "" {
externalClient = action.ExternalDownloadClient
}
arr := client.Client.(radarr.Client)
r := radarr.Release{
Title: release.TorrentName,
@ -59,14 +35,20 @@ func (s *service) radarr(ctx context.Context, action *domain.Action, release dom
MagnetUrl: release.MagnetURI,
Size: int64(release.Size),
Indexer: release.Indexer.GetExternalIdentifier(),
DownloadClientId: externalClientId,
DownloadClient: externalClient,
DownloadClientId: client.Settings.ExternalDownloadClientId,
DownloadClient: client.Settings.ExternalDownloadClient,
DownloadProtocol: release.Protocol.String(),
Protocol: release.Protocol.String(),
PublishDate: time.Now().Format(time.RFC3339),
}
arr := radarr.New(cfg)
if action.ExternalDownloadClientID > 0 {
r.DownloadClientId = int(action.ExternalDownloadClientID)
}
if action.ExternalDownloadClient != "" {
r.DownloadClient = action.ExternalDownloadClient
}
rejections, err := arr.Push(ctx, r)
if err != nil {

View file

@ -17,40 +17,16 @@ func (s *service) readarr(ctx context.Context, action *domain.Action, release do
// TODO validate data
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
return nil, errors.Wrap(err, "readarr could not find client: %v", action.ClientID)
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
// return early if no client found
if client == nil {
return nil, errors.New("no client found")
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
// initial config
cfg := readarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: s.subLogger,
}
// only set basic auth if enabled
if client.Settings.Basic.Auth {
cfg.BasicAuth = client.Settings.Basic.Auth
cfg.Username = client.Settings.Basic.Username
cfg.Password = client.Settings.Basic.Password
}
externalClientId := client.Settings.ExternalDownloadClientId
if action.ExternalDownloadClientID > 0 {
externalClientId = int(action.ExternalDownloadClientID)
}
externalClient := client.Settings.ExternalDownloadClient
if action.ExternalDownloadClient != "" {
externalClient = action.ExternalDownloadClient
}
arr := client.Client.(readarr.Client)
r := readarr.Release{
Title: release.TorrentName,
@ -59,14 +35,20 @@ func (s *service) readarr(ctx context.Context, action *domain.Action, release do
MagnetUrl: release.MagnetURI,
Size: int64(release.Size),
Indexer: release.Indexer.GetExternalIdentifier(),
DownloadClientId: externalClientId,
DownloadClient: externalClient,
DownloadClientId: client.Settings.ExternalDownloadClientId,
DownloadClient: client.Settings.ExternalDownloadClient,
DownloadProtocol: release.Protocol.String(),
Protocol: release.Protocol.String(),
PublishDate: time.Now().Format(time.RFC3339),
}
arr := readarr.New(cfg)
if action.ExternalDownloadClientID > 0 {
r.DownloadClientId = int(action.ExternalDownloadClientID)
}
if action.ExternalDownloadClient != "" {
r.DownloadClient = action.ExternalDownloadClient
}
rejections, err := arr.Push(ctx, r)
if err != nil {

View file

@ -16,32 +16,19 @@ import (
func (s *service) rtorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action rTorrent: %s", action.Name)
var err error
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
s.log.Error().Stack().Err(err).Msgf("error finding client: %d", action.ClientID)
return nil, err
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
if client == nil {
return nil, errors.New("could not find client by id: %d", action.ClientID)
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
rt := client.Client.(*rtorrent.Client)
var rejections []string
// create config
cfg := rtorrent.Config{
Addr: client.Host,
TLSSkipVerify: client.TLSSkipVerify,
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
}
// create client
rt := rtorrent.NewClient(cfg)
if release.HasMagnetUri() {
var args []*rtorrent.FieldValue
@ -79,8 +66,8 @@ func (s *service) rtorrent(ctx context.Context, action *domain.Action, release d
s.log.Info().Msgf("torrent from magnet successfully added to client: '%s'", client.Name)
return nil, nil
}
} else {
if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFileCtx(ctx); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName)
@ -127,7 +114,6 @@ func (s *service) rtorrent(ctx context.Context, action *domain.Action, release d
}
s.log.Info().Msgf("torrent successfully added to client: '%s'", client.Name)
}
return rejections, nil
}

View file

@ -19,7 +19,6 @@ import (
)
func (s *service) RunAction(ctx context.Context, action *domain.Action, release *domain.Release) ([]string, error) {
var (
err error
rejections []string
@ -33,6 +32,10 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release
}
}()
if action.ClientID > 0 && action.Client != nil && !action.Client.Enabled {
return nil, errors.New("action %s client %s %s not enabled, skipping", action.Name, action.Client.Type, action.Client.Name)
}
// if set, try to resolve MagnetURI before parsing macros
// to allow webhook and exec to get the magnet_uri
if err := release.ResolveMagnetUri(ctx); err != nil {

View file

@ -18,29 +18,16 @@ func (s *service) sabnzbd(ctx context.Context, action *domain.Action, release do
return nil, errors.New("action type: %s invalid protocol: %s", action.Type, release.Protocol)
}
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
return nil, errors.Wrap(err, "sonarr could not find client: %d", action.ClientID)
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
// return early if no client found
if client == nil {
return nil, errors.New("no sabnzbd client found by id: %d", action.ClientID)
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
opts := sabnzbd.Options{
Addr: client.Host,
ApiKey: client.Settings.APIKey,
Log: nil,
}
if client.Settings.Basic.Auth {
opts.BasicUser = client.Settings.Basic.Username
opts.BasicPass = client.Settings.Basic.Password
}
sab := sabnzbd.New(opts)
sab := client.Client.(*sabnzbd.Client)
ids, err := sab.AddFromUrl(ctx, sabnzbd.AddNzbRequest{Url: release.DownloadURL, Category: action.Category})
if err != nil {

View file

@ -21,9 +21,10 @@ import (
type Service interface {
Store(ctx context.Context, action domain.Action) (*domain.Action, error)
StoreFilterActions(ctx context.Context, filterID int64, actions []*domain.Action) ([]*domain.Action, error)
List(ctx context.Context) ([]domain.Action, error)
Get(ctx context.Context, req *domain.GetActionRequest) (*domain.Action, error)
FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error)
FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*domain.Action, error)
Delete(ctx context.Context, req *domain.DeleteActionRequest) error
DeleteByFilterID(ctx context.Context, filterID int) error
ToggleEnabled(actionID int) error
@ -63,6 +64,10 @@ func (s *service) Store(ctx context.Context, action domain.Action) (*domain.Acti
return s.repo.Store(ctx, action)
}
func (s *service) StoreFilterActions(ctx context.Context, filterID int64, actions []*domain.Action) ([]*domain.Action, error) {
return s.repo.StoreFilterActions(ctx, filterID, actions)
}
func (s *service) List(ctx context.Context) ([]domain.Action, error) {
return s.repo.List(ctx)
}
@ -86,8 +91,8 @@ func (s *service) Get(ctx context.Context, req *domain.GetActionRequest) (*domai
return a, nil
}
func (s *service) FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) {
return s.repo.FindByFilterID(ctx, filterID, active)
func (s *service) FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*domain.Action, error) {
return s.repo.FindByFilterID(ctx, filterID, active, withClient)
}
func (s *service) Delete(ctx context.Context, req *domain.DeleteActionRequest) error {

View file

@ -17,40 +17,16 @@ func (s *service) sonarr(ctx context.Context, action *domain.Action, release dom
// TODO validate data
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID)
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
// return early if no client found
if client == nil {
return nil, errors.New("no client found")
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
// initial config
cfg := sonarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: s.subLogger,
}
// only set basic auth if enabled
if client.Settings.Basic.Auth {
cfg.BasicAuth = client.Settings.Basic.Auth
cfg.Username = client.Settings.Basic.Username
cfg.Password = client.Settings.Basic.Password
}
externalClientId := client.Settings.ExternalDownloadClientId
if action.ExternalDownloadClientID > 0 {
externalClientId = int(action.ExternalDownloadClientID)
}
externalClient := client.Settings.ExternalDownloadClient
if action.ExternalDownloadClient != "" {
externalClient = action.ExternalDownloadClient
}
arr := client.Client.(sonarr.Client)
r := sonarr.Release{
Title: release.TorrentName,
@ -59,14 +35,20 @@ func (s *service) sonarr(ctx context.Context, action *domain.Action, release dom
MagnetUrl: release.MagnetURI,
Size: int64(release.Size),
Indexer: release.Indexer.GetExternalIdentifier(),
DownloadClientId: externalClientId,
DownloadClient: externalClient,
DownloadClientId: client.Settings.ExternalDownloadClientId,
DownloadClient: client.Settings.ExternalDownloadClient,
DownloadProtocol: release.Protocol.String(),
Protocol: release.Protocol.String(),
PublishDate: time.Now().Format(time.RFC3339),
}
arr := sonarr.New(cfg)
if action.ExternalDownloadClientID > 0 {
r.DownloadClientId = int(action.ExternalDownloadClientID)
}
if action.ExternalDownloadClient != "" {
r.DownloadClient = action.ExternalDownloadClient
}
rejections, err := arr.Push(ctx, r)
if err != nil {

View file

@ -6,14 +6,11 @@ package action
import (
"context"
"fmt"
"net/url"
"strings"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/transmission"
"github.com/hekmon/transmissionrpc/v3"
)
@ -28,38 +25,16 @@ var TrTrue = true
func (s *service) transmission(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action Transmission: %s", action.Name)
var err error
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
s.log.Error().Stack().Err(err).Msgf("error finding client: %d", action.ClientID)
return nil, err
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
if client == nil {
return nil, errors.New("could not find client by id: %d", action.ClientID)
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
scheme := "http"
if client.TLS {
scheme = "https"
}
u, err := url.Parse(fmt.Sprintf("%s://%s:%d/transmission/rpc", scheme, client.Host, client.Port))
if err != nil {
return nil, err
}
tbt, err := transmission.New(u, &transmission.Config{
UserAgent: "autobrr",
Username: client.Username,
Password: client.Password,
TLSSkipVerify: client.TLSSkipVerify,
})
if err != nil {
return nil, errors.Wrap(err, "error logging into client: %s", client.Host)
}
tbt := client.Client.(*transmissionrpc.Client)
rejections, err := s.transmissionCheckRulesCanDownload(ctx, action, client, tbt)
if err != nil {

View file

@ -17,40 +17,16 @@ func (s *service) whisparr(ctx context.Context, action *domain.Action, release d
// TODO validate data
// get client for action
client, err := s.clientSvc.FindByID(ctx, action.ClientID)
client, err := s.clientSvc.GetClient(ctx, action.ClientID)
if err != nil {
return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID)
return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID)
}
// return early if no client found
if client == nil {
return nil, errors.New("could not find client by id: %v", action.ClientID)
if !client.Enabled {
return nil, errors.New("client %s %s not enabled", client.Type, client.Name)
}
// initial config
cfg := whisparr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: s.subLogger,
}
// only set basic auth if enabled
if client.Settings.Basic.Auth {
cfg.BasicAuth = client.Settings.Basic.Auth
cfg.Username = client.Settings.Basic.Username
cfg.Password = client.Settings.Basic.Password
}
externalClientId := client.Settings.ExternalDownloadClientId
if action.ExternalDownloadClientID > 0 {
externalClientId = int(action.ExternalDownloadClientID)
}
externalClient := client.Settings.ExternalDownloadClient
if action.ExternalDownloadClient != "" {
externalClient = action.ExternalDownloadClient
}
arr := client.Client.(whisparr.Client)
r := whisparr.Release{
Title: release.TorrentName,
@ -59,14 +35,20 @@ func (s *service) whisparr(ctx context.Context, action *domain.Action, release d
MagnetUrl: release.MagnetURI,
Size: int64(release.Size),
Indexer: release.Indexer.GetExternalIdentifier(),
DownloadClientId: externalClientId,
DownloadClient: externalClient,
DownloadClientId: client.Settings.ExternalDownloadClientId,
DownloadClient: client.Settings.ExternalDownloadClient,
DownloadProtocol: release.Protocol.String(),
Protocol: release.Protocol.String(),
PublishDate: time.Now().Format(time.RFC3339),
}
arr := whisparr.New(cfg)
if action.ExternalDownloadClientID > 0 {
r.DownloadClientId = int(action.ExternalDownloadClientID)
}
if action.ExternalDownloadClient != "" {
r.DownloadClient = action.ExternalDownloadClient
}
rejections, err := arr.Push(ctx, r)
if err != nil {

View file

@ -30,7 +30,266 @@ func NewActionRepo(log logger.Logger, db *DB, clientRepo domain.DownloadClientRe
}
}
func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) {
func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*domain.Action, error) {
if withClient {
return r.findByFilterIDWithClient(ctx, filterID, active)
}
return r.findByFilterID(ctx, filterID, active)
}
func (r *ActionRepo) findByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) {
queryBuilder := r.db.squirrel.
Select(
"a.id",
"a.name",
"a.type",
"a.enabled",
"a.exec_cmd",
"a.exec_args",
"a.watch_folder",
"a.category",
"a.tags",
"a.label",
"a.save_path",
"a.paused",
"a.ignore_rules",
"a.first_last_piece_prio",
"a.skip_hash_check",
"a.content_layout",
"a.priority",
"a.limit_download_speed",
"a.limit_upload_speed",
"a.limit_ratio",
"a.limit_seed_time",
"a.reannounce_skip",
"a.reannounce_delete",
"a.reannounce_interval",
"a.reannounce_max_attempts",
"a.webhook_host",
"a.webhook_type",
"a.webhook_method",
"a.webhook_data",
"a.external_client_id",
"a.external_client",
"a.client_id",
).
From("action a").
Where(sq.Eq{"a.filter_id": filterID})
if active != nil {
queryBuilder = queryBuilder.Where(sq.Eq{"enabled": *active})
}
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, errors.Wrap(err, "error building query")
}
rows, err := r.db.handler.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
defer rows.Close()
actions := make([]*domain.Action, 0)
for rows.Next() {
var a domain.Action
var execCmd, execArgs, watchFolder, category, tags, label, savePath, contentLayout, priorityLayout, webhookHost, webhookType, webhookMethod, webhookData, externalClient sql.NullString
var limitUl, limitDl, limitSeedTime sql.NullInt64
var limitRatio sql.NullFloat64
var externalClientID, clientID sql.NullInt32
var paused, ignoreRules sql.NullBool
if err := rows.Scan(&a.ID, &a.Name, &a.Type, &a.Enabled, &execCmd, &execArgs, &watchFolder, &category, &tags, &label, &savePath, &paused, &ignoreRules, &a.FirstLastPiecePrio, &a.SkipHashCheck, &contentLayout, &priorityLayout, &limitDl, &limitUl, &limitRatio, &limitSeedTime, &a.ReAnnounceSkip, &a.ReAnnounceDelete, &a.ReAnnounceInterval, &a.ReAnnounceMaxAttempts, &webhookHost, &webhookType, &webhookMethod, &webhookData, &externalClientID, &externalClient, &clientID); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
a.ExecCmd = execCmd.String
a.ExecArgs = execArgs.String
a.WatchFolder = watchFolder.String
a.Category = category.String
a.Tags = tags.String
a.Label = label.String
a.SavePath = savePath.String
a.Paused = paused.Bool
a.IgnoreRules = ignoreRules.Bool
a.ContentLayout = domain.ActionContentLayout(contentLayout.String)
a.PriorityLayout = domain.PriorityLayout(priorityLayout.String)
a.LimitDownloadSpeed = limitDl.Int64
a.LimitUploadSpeed = limitUl.Int64
a.LimitRatio = limitRatio.Float64
a.LimitSeedTime = limitSeedTime.Int64
a.WebhookHost = webhookHost.String
a.WebhookType = webhookType.String
a.WebhookMethod = webhookMethod.String
a.WebhookData = webhookData.String
a.ExternalDownloadClientID = externalClientID.Int32
a.ExternalDownloadClient = externalClient.String
a.ClientID = clientID.Int32
actions = append(actions, &a)
}
if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "row error")
}
return actions, nil
}
func (r *ActionRepo) findByFilterIDWithClient(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) {
queryBuilder := r.db.squirrel.
Select(
"a.id",
"a.name",
"a.type",
"a.enabled",
"a.exec_cmd",
"a.exec_args",
"a.watch_folder",
"a.category",
"a.tags",
"a.label",
"a.save_path",
"a.paused",
"a.ignore_rules",
"a.first_last_piece_prio",
"a.skip_hash_check",
"a.content_layout",
"a.priority",
"a.limit_download_speed",
"a.limit_upload_speed",
"a.limit_ratio",
"a.limit_seed_time",
"a.reannounce_skip",
"a.reannounce_delete",
"a.reannounce_interval",
"a.reannounce_max_attempts",
"a.webhook_host",
"a.webhook_type",
"a.webhook_method",
"a.webhook_data",
"a.external_client_id",
"a.external_client",
"a.client_id",
"c.id",
"c.name",
"c.type",
"c.enabled",
"c.host",
"c.port",
"c.tls",
"c.tls_skip_verify",
"c.username",
"c.password",
"c.settings",
).
From("action a").
Join("client c ON a.client_id = c.id").
Where(sq.Eq{"a.filter_id": filterID})
if active != nil {
queryBuilder = queryBuilder.Where(sq.Eq{"enabled": *active})
}
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, errors.Wrap(err, "error building query")
}
rows, err := r.db.handler.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
defer rows.Close()
actions := make([]*domain.Action, 0)
for rows.Next() {
var a domain.Action
var c domain.DownloadClient
var execCmd, execArgs, watchFolder, category, tags, label, savePath, contentLayout, priorityLayout, webhookHost, webhookType, webhookMethod, webhookData, externalClient sql.NullString
var limitUl, limitDl, limitSeedTime sql.NullInt64
var limitRatio sql.NullFloat64
var externalClientID, clientID sql.NullInt32
var paused, ignoreRules sql.NullBool
var clientClientId, clientPort sql.Null[int32]
var clientName, clientType, clientHost, clientUsername, clientPassword, clientSettings sql.Null[string]
var clientEnabled, clientTLS, clientTLSSkip sql.Null[bool]
if err := rows.Scan(&a.ID, &a.Name, &a.Type, &a.Enabled, &execCmd, &execArgs, &watchFolder, &category, &tags, &label, &savePath, &paused, &ignoreRules, &a.FirstLastPiecePrio, &a.SkipHashCheck, &contentLayout, &priorityLayout, &limitDl, &limitUl, &limitRatio, &limitSeedTime, &a.ReAnnounceSkip, &a.ReAnnounceDelete, &a.ReAnnounceInterval, &a.ReAnnounceMaxAttempts, &webhookHost, &webhookType, &webhookMethod, &webhookData, &externalClientID, &externalClient, &clientID, &clientClientId, &clientName, &clientType, &clientEnabled, &clientHost, &clientPort, &clientTLS, &clientTLSSkip, &clientUsername, &clientPassword, &clientSettings); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
a.ExecCmd = execCmd.String
a.ExecArgs = execArgs.String
a.WatchFolder = watchFolder.String
a.Category = category.String
a.Tags = tags.String
a.Label = label.String
a.SavePath = savePath.String
a.Paused = paused.Bool
a.IgnoreRules = ignoreRules.Bool
a.ContentLayout = domain.ActionContentLayout(contentLayout.String)
a.PriorityLayout = domain.PriorityLayout(priorityLayout.String)
a.LimitDownloadSpeed = limitDl.Int64
a.LimitUploadSpeed = limitUl.Int64
a.LimitRatio = limitRatio.Float64
a.LimitSeedTime = limitSeedTime.Int64
a.WebhookHost = webhookHost.String
a.WebhookType = webhookType.String
a.WebhookMethod = webhookMethod.String
a.WebhookData = webhookData.String
a.ExternalDownloadClientID = externalClientID.Int32
a.ExternalDownloadClient = externalClient.String
a.ClientID = clientID.Int32
c.ID = clientClientId.V
c.Name = clientName.V
c.Type = domain.DownloadClientType(clientType.V)
c.Enabled = clientEnabled.V
c.Host = clientHost.V
c.Port = int(clientPort.V)
c.TLS = clientTLS.V
c.TLSSkipVerify = clientTLSSkip.V
c.Username = clientUsername.V
c.Password = clientPassword.V
//c.Settings = clientSettings.String
if a.ClientID > 0 {
if clientSettings.Valid {
if err := json.Unmarshal([]byte(clientSettings.V), &c.Settings); err != nil {
return nil, errors.Wrap(err, "could not unmarshal download client settings: %v", clientSettings.V)
}
}
a.Client = &c
}
actions = append(actions, &a)
}
if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "row error")
}
return actions, nil
}
func (r *ActionRepo) FindByFilterIDTx(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) {
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil {
return nil, err
@ -38,7 +297,7 @@ func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *b
defer tx.Rollback()
actions, err := r.findByFilterID(ctx, tx, filterID, active)
actions, err := r.findByFilterIDTx(ctx, tx, filterID, active)
if err != nil {
return nil, err
}
@ -59,7 +318,7 @@ func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *b
return actions, nil
}
func (r *ActionRepo) findByFilterID(ctx context.Context, tx *Tx, filterID int, active *bool) ([]*domain.Action, error) {
func (r *ActionRepo) findByFilterIDTx(ctx context.Context, tx *Tx, filterID int, active *bool) ([]*domain.Action, error) {
queryBuilder := r.db.squirrel.
Select(
"id",

View file

@ -62,9 +62,10 @@ func TestActionRepo_Store(t *testing.T) {
t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -73,7 +74,7 @@ func TestActionRepo_Store(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID)
mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
// Actual test for Store
@ -84,7 +85,7 @@ func TestActionRepo_Store(t *testing.T) {
// Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("Store_Succeeds_With_Missing_or_empty_fields [%s]", dbType), func(t *testing.T) {
@ -125,9 +126,10 @@ func TestActionRepo_StoreFilterActions(t *testing.T) {
t.Run(fmt.Sprintf("StoreFilterActions_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -136,7 +138,7 @@ func TestActionRepo_StoreFilterActions(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID)
mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
// Actual test for StoreFilterActions
@ -148,7 +150,7 @@ func TestActionRepo_StoreFilterActions(t *testing.T) {
// Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("StoreFilterActions_Fails_Invalid_FilterID [%s]", dbType), func(t *testing.T) {
@ -203,9 +205,10 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
t.Run(fmt.Sprintf("FindByFilterID_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -214,13 +217,13 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID)
mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
assert.NoError(t, err)
// Actual test for FindByFilterID
actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil)
actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil, false)
assert.NoError(t, err)
assert.NotNil(t, actions)
assert.Equal(t, 1, len(actions))
@ -228,7 +231,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
// Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("FindByFilterID_Fails_No_Actions [%s]", dbType), func(t *testing.T) {
@ -241,7 +244,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
assert.NotNil(t, createdFilters)
// Actual test for FindByFilterID
actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil)
actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil, false)
assert.NoError(t, err)
assert.Equal(t, 0, len(actions))
@ -250,7 +253,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
})
t.Run(fmt.Sprintf("FindByFilterID_Succeeds_With_Invalid_FilterID [%s]", dbType), func(t *testing.T) {
actions, err := repo.FindByFilterID(context.Background(), 9999, nil) // 9999 is an invalid filter ID
actions, err := repo.FindByFilterID(context.Background(), 9999, nil, false) // 9999 is an invalid filter ID
assert.NoError(t, err)
assert.NotNil(t, actions)
assert.Equal(t, 0, len(actions))
@ -260,7 +263,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()
actions, err := repo.FindByFilterID(ctx, 1, nil)
actions, err := repo.FindByFilterID(ctx, 1, nil, false)
assert.Error(t, err)
assert.Nil(t, actions)
})
@ -277,9 +280,10 @@ func TestActionRepo_List(t *testing.T) {
t.Run(fmt.Sprintf("List_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -288,7 +292,7 @@ func TestActionRepo_List(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID)
mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
assert.NoError(t, err)
@ -302,7 +306,7 @@ func TestActionRepo_List(t *testing.T) {
// Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("List_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) {
@ -326,9 +330,10 @@ func TestActionRepo_Get(t *testing.T) {
t.Run(fmt.Sprintf("Get_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -337,7 +342,7 @@ func TestActionRepo_Get(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID)
mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
assert.NoError(t, err)
@ -351,7 +356,7 @@ func TestActionRepo_Get(t *testing.T) {
// Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("Get_Fails_No_Record [%s]", dbType), func(t *testing.T) {
@ -382,9 +387,10 @@ func TestActionRepo_Delete(t *testing.T) {
t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -393,7 +399,7 @@ func TestActionRepo_Delete(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID)
mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
assert.NoError(t, err)
@ -411,7 +417,7 @@ func TestActionRepo_Delete(t *testing.T) {
// Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("Delete_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) {
@ -435,9 +441,10 @@ func TestActionRepo_DeleteByFilterID(t *testing.T) {
t.Run(fmt.Sprintf("DeleteByFilterID_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -446,7 +453,7 @@ func TestActionRepo_DeleteByFilterID(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID)
mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
assert.NoError(t, err)
@ -463,7 +470,7 @@ func TestActionRepo_DeleteByFilterID(t *testing.T) {
// Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("DeleteByFilterID_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) {
@ -486,9 +493,10 @@ func TestActionRepo_ToggleEnabled(t *testing.T) {
t.Run(fmt.Sprintf("ToggleEnabled_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -497,7 +505,7 @@ func TestActionRepo_ToggleEnabled(t *testing.T) {
assert.NoError(t, err)
assert.NotNil(t, createdFilters)
mockData.ClientID = int32(createdClient.ID)
mockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
mockData.Enabled = false
createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData})
@ -515,7 +523,7 @@ func TestActionRepo_ToggleEnabled(t *testing.T) {
// Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("ToggleEnabled_Fails_No_Record [%s]", dbType), func(t *testing.T) {

View file

@ -7,7 +7,6 @@ import (
"context"
"database/sql"
"encoding/json"
"sync"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
@ -20,53 +19,16 @@ import (
type DownloadClientRepo struct {
log zerolog.Logger
db *DB
cache *clientCache
}
type clientCache struct {
mu sync.RWMutex
clients map[int]*domain.DownloadClient
}
func NewClientCache() *clientCache {
return &clientCache{
clients: make(map[int]*domain.DownloadClient, 0),
}
}
func (c *clientCache) Set(id int, client *domain.DownloadClient) {
c.mu.Lock()
c.clients[id] = client
c.mu.Unlock()
}
func (c *clientCache) Get(id int) *domain.DownloadClient {
c.mu.RLock()
defer c.mu.RUnlock()
v, ok := c.clients[id]
if ok {
return v
}
return nil
}
func (c *clientCache) Pop(id int) {
c.mu.Lock()
delete(c.clients, id)
c.mu.Unlock()
}
func NewDownloadClientRepo(log logger.Logger, db *DB) domain.DownloadClientRepo {
return &DownloadClientRepo{
log: log.With().Str("repo", "action").Logger(),
db: db,
cache: NewClientCache(),
}
}
func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, error) {
clients := make([]domain.DownloadClient, 0)
queryBuilder := r.db.squirrel.
Select(
"id",
@ -100,6 +62,8 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient,
}
}(rows)
clients := make([]domain.DownloadClient, 0)
for rows.Next() {
var f domain.DownloadClient
var settingsJsonStr string
@ -124,12 +88,6 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient,
}
func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
// get client from cache
c := r.cache.Get(int(id))
if c != nil {
return c, nil
}
queryBuilder := r.db.squirrel.
Select(
"id",
@ -153,7 +111,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
}
row := r.db.handler.QueryRowContext(ctx, query, args...)
if err != nil {
if err := row.Err(); err != nil {
return nil, errors.Wrap(err, "error executing query")
}
@ -177,9 +135,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
return &client, nil
}
func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
var err error
func (r *DownloadClientRepo) Store(ctx context.Context, client *domain.DownloadClient) error {
settings := domain.DownloadClientSettings{
APIKey: client.Settings.APIKey,
Basic: client.Settings.Basic,
@ -190,7 +146,7 @@ func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadCl
settingsJson, err := json.Marshal(&settings)
if err != nil {
return nil, errors.Wrap(err, "error marshal download client settings %+v", settings)
return errors.Wrap(err, "error marshal download client settings %+v", settings)
}
queryBuilder := r.db.squirrel.
@ -204,22 +160,17 @@ func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadCl
err = queryBuilder.QueryRowContext(ctx).Scan(&retID)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
return errors.Wrap(err, "error executing query")
}
client.ID = retID
client.ID = int32(retID)
r.log.Debug().Msgf("download_client.store: %d", client.ID)
// save to cache
r.cache.Set(client.ID, &client)
return &client, nil
return nil
}
func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
var err error
func (r *DownloadClientRepo) Update(ctx context.Context, client *domain.DownloadClient) error {
settings := domain.DownloadClientSettings{
APIKey: client.Settings.APIKey,
Basic: client.Settings.Basic,
@ -230,7 +181,7 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC
settingsJson, err := json.Marshal(&settings)
if err != nil {
return nil, errors.Wrap(err, "error marshal download client settings %+v", settings)
return errors.Wrap(err, "error marshal download client settings %+v", settings)
}
queryBuilder := r.db.squirrel.
@ -249,32 +200,29 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, errors.Wrap(err, "error building query")
return errors.Wrap(err, "error building query")
}
result, err := r.db.handler.ExecContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
return errors.Wrap(err, "error executing query")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return nil, errors.Wrap(err, "error getting rows affected")
return errors.Wrap(err, "error getting rows affected")
}
if rowsAffected == 0 {
return nil, errors.New("no rows updated")
return errors.New("no rows updated")
}
r.log.Debug().Msgf("download_client.update: %d", client.ID)
// save to cache
r.cache.Set(client.ID, &client)
return &client, nil
return nil
}
func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error {
func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int32) error {
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{})
if err != nil {
return err
@ -311,10 +259,11 @@ func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error {
}
r.log.Debug().Msgf("delete download client: %d", clientID)
return nil
}
func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) error {
func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int32) error {
queryBuilder := r.db.squirrel.
Delete("client").
Where(sq.Eq{"id": clientID})
@ -329,9 +278,6 @@ func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) e
return errors.Wrap(err, "error executing query")
}
// remove from cache
r.cache.Pop(clientID)
rows, _ := res.RowsAffected()
if rows == 0 {
return errors.New("no rows affected")
@ -342,9 +288,7 @@ func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) e
return nil
}
func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx, clientID int) error {
var err error
func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx, clientID int32) error {
queryBuilder := r.db.squirrel.
Update("action").
Set("enabled", false).
@ -355,12 +299,14 @@ func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx,
// return values
var filterID int
if err = queryBuilder.QueryRowContext(ctx).Scan(&filterID); err != nil {
err := queryBuilder.QueryRowContext(ctx).Scan(&filterID)
if err != nil {
// this will throw when the client is not connected to any actions
// it is not an error in this case
if errors.Is(err, sql.ErrNoRows) {
return nil
}
return errors.Wrap(err, "error executing query")
}

View file

@ -8,10 +8,12 @@ package database
import (
"context"
"fmt"
"github.com/autobrr/autobrr/internal/domain"
"github.com/stretchr/testify/assert"
"testing"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/stretchr/testify/assert"
)
func getMockDownloadClient() domain.DownloadClient {
@ -54,13 +56,14 @@ func TestDownloadClientRepo_List(t *testing.T) {
t.Run(fmt.Sprintf("List_Succeeds_With_No_Filters [%s]", dbType), func(t *testing.T) {
// Insert mock data
createdClient, err := repo.Store(context.Background(), mockData)
mock := &mockData
err := repo.Store(context.Background(), mock)
clients, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotEmpty(t, clients)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("List_Succeeds_With_Empty_Database [%s]", dbType), func(t *testing.T) {
@ -77,32 +80,34 @@ func TestDownloadClientRepo_List(t *testing.T) {
})
t.Run(fmt.Sprintf("List_Succeeds_With_Data_Integrity [%s]", dbType), func(t *testing.T) {
createdClient, err := repo.Store(context.Background(), mockData)
mock := &mockData
err := repo.Store(context.Background(), mock)
clients, err := repo.List(context.Background())
assert.NoError(t, err)
assert.Equal(t, 1, len(clients))
assert.Equal(t, createdClient.Name, clients[0].Name)
assert.Equal(t, mock.Name, clients[0].Name)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("List_Succeeds_With_Boundary_Value_For_Port [%s]", dbType), func(t *testing.T) {
mockData.Port = 65535
createdClient, err := repo.Store(context.Background(), mockData)
mock := &mockData
mock.Port = 65535
err := repo.Store(context.Background(), mock)
clients, err := repo.List(context.Background())
assert.NoError(t, err)
assert.Equal(t, 65535, clients[0].Port)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("List_Succeeds_With_Boolean_Flags_Set_To_False [%s]", dbType), func(t *testing.T) {
mockData.Enabled = false
mockData.TLS = false
mockData.TLSSkipVerify = false
createdClient, err := repo.Store(context.Background(), mockData)
err := repo.Store(context.Background(), &mockData)
clients, err := repo.List(context.Background())
assert.NoError(t, err)
assert.Equal(t, false, clients[0].Enabled)
@ -110,18 +115,18 @@ func TestDownloadClientRepo_List(t *testing.T) {
assert.Equal(t, false, clients[0].TLSSkipVerify)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mockData.ID)
})
t.Run(fmt.Sprintf("List_Succeeds_With_Special_Characters_In_Name [%s]", dbType), func(t *testing.T) {
mockData.Name = "Special$Name"
createdClient, err := repo.Store(context.Background(), mockData)
err := repo.Store(context.Background(), &mockData)
clients, err := repo.List(context.Background())
assert.NoError(t, err)
assert.Equal(t, "Special$Name", clients[0].Name)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mockData.ID)
})
}
}
@ -133,13 +138,14 @@ func TestDownloadClientRepo_FindByID(t *testing.T) {
mockData := getMockDownloadClient()
t.Run(fmt.Sprintf("FindByID_Succeeds [%s]", dbType), func(t *testing.T) {
createdClient, _ := repo.Store(context.Background(), mockData)
foundClient, err := repo.FindByID(context.Background(), int32(createdClient.ID))
mock := &mockData
_ = repo.Store(context.Background(), mock)
foundClient, err := repo.FindByID(context.Background(), mock.ID)
assert.NoError(t, err)
assert.NotNil(t, foundClient)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("FindByID_Fails_With_Nonexistent_ID [%s]", dbType), func(t *testing.T) {
@ -156,40 +162,44 @@ func TestDownloadClientRepo_FindByID(t *testing.T) {
t.Run(fmt.Sprintf("FindByID_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()
_, err := repo.FindByID(ctx, 1)
assert.Error(t, err)
})
t.Run(fmt.Sprintf("FindByID_Fails_After_Client_Deleted [%s]", dbType), func(t *testing.T) {
createdClient, _ := repo.Store(context.Background(), mockData)
_ = repo.Delete(context.Background(), createdClient.ID)
_, err := repo.FindByID(context.Background(), int32(createdClient.ID))
mock := &mockData
_ = repo.Store(context.Background(), mock)
_ = repo.Delete(context.Background(), mock.ID)
_, err := repo.FindByID(context.Background(), mock.ID)
assert.Error(t, err)
assert.Equal(t, "no client configured", err.Error())
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("FindByID_Succeeds_With_Data_Integrity [%s]", dbType), func(t *testing.T) {
createdClient, _ := repo.Store(context.Background(), mockData)
foundClient, err := repo.FindByID(context.Background(), int32(createdClient.ID))
mock := &mockData
_ = repo.Store(context.Background(), mock)
foundClient, err := repo.FindByID(context.Background(), mock.ID)
assert.NoError(t, err)
assert.Equal(t, createdClient.Name, foundClient.Name)
assert.Equal(t, mock.Name, foundClient.Name)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mock.ID)
})
t.Run(fmt.Sprintf("FindByID_Succeeds_From_Cache [%s]", dbType), func(t *testing.T) {
createdClient, _ := repo.Store(context.Background(), mockData)
foundClient1, _ := repo.FindByID(context.Background(), int32(createdClient.ID))
foundClient2, err := repo.FindByID(context.Background(), int32(createdClient.ID))
mock := &mockData
_ = repo.Store(context.Background(), mock)
foundClient1, _ := repo.FindByID(context.Background(), mock.ID)
foundClient2, err := repo.FindByID(context.Background(), mock.ID)
assert.NoError(t, err)
assert.Equal(t, foundClient1, foundClient2)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mock.ID)
})
}
}
@ -201,17 +211,17 @@ func TestDownloadClientRepo_Store(t *testing.T) {
t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) {
mockData := getMockDownloadClient()
createdClient, err := repo.Store(context.Background(), mockData)
err := repo.Store(context.Background(), &mockData)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mockData)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mockData.ID)
})
//TODO: Is this okay? Should we be able to store a client with no name (empty string)?
t.Run(fmt.Sprintf("Store_Succeeds?_With_Missing_Required_Fields [%s]", dbType), func(t *testing.T) {
badMockData := domain.DownloadClient{
badMockData := &domain.DownloadClient{
Type: "",
Enabled: false,
Host: "",
@ -222,30 +232,30 @@ func TestDownloadClientRepo_Store(t *testing.T) {
Password: "",
Settings: domain.DownloadClientSettings{},
}
createdClient, err := repo.Store(context.Background(), badMockData)
err := repo.Store(context.Background(), badMockData)
assert.NoError(t, err)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), badMockData.ID)
})
t.Run(fmt.Sprintf("Store_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) {
mockData := getMockDownloadClient()
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()
_, err := repo.Store(ctx, mockData)
err := repo.Store(ctx, &mockData)
assert.Error(t, err)
})
t.Run(fmt.Sprintf("Store_Succeeds_And_Caches [%s]", dbType), func(t *testing.T) {
mockData := getMockDownloadClient()
createdClient, _ := repo.Store(context.Background(), mockData)
_ = repo.Store(context.Background(), &mockData)
cachedClient, _ := repo.FindByID(context.Background(), int32(createdClient.ID))
assert.Equal(t, createdClient, cachedClient)
cachedClient, _ := repo.FindByID(context.Background(), mockData.ID)
assert.Equal(t, &mockData, cachedClient)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mockData.ID)
})
}
}
@ -258,22 +268,22 @@ func TestDownloadClientRepo_Update(t *testing.T) {
t.Run(fmt.Sprintf("Update_Successfully_Updates_Record [%s]", dbType), func(t *testing.T) {
mockClient := getMockDownloadClient()
createdClient, _ := repo.Store(context.Background(), mockClient)
createdClient.Name = "updatedName"
updatedClient, err := repo.Update(context.Background(), *createdClient)
_ = repo.Store(context.Background(), &mockClient)
mockClient.Name = "updatedName"
err := repo.Update(context.Background(), &mockClient)
assert.NoError(t, err)
assert.Equal(t, "updatedName", updatedClient.Name)
assert.Equal(t, "updatedName", mockClient.Name)
// Cleanup
_ = repo.Delete(context.Background(), updatedClient.ID)
_ = repo.Delete(context.Background(), mockClient.ID)
})
t.Run(fmt.Sprintf("Update_Fails_With_Missing_ID [%s]", dbType), func(t *testing.T) {
badMockData := getMockDownloadClient()
badMockData.ID = 0
_, err := repo.Update(context.Background(), badMockData)
err := repo.Update(context.Background(), &badMockData)
assert.Error(t, err)
@ -283,7 +293,7 @@ func TestDownloadClientRepo_Update(t *testing.T) {
badMockData := getMockDownloadClient()
badMockData.ID = 9999
_, err := repo.Update(context.Background(), badMockData)
err := repo.Update(context.Background(), &badMockData)
assert.Error(t, err)
})
@ -291,7 +301,7 @@ func TestDownloadClientRepo_Update(t *testing.T) {
t.Run(fmt.Sprintf("Update_Fails_With_Missing_Required_Fields [%s]", dbType), func(t *testing.T) {
badMockData := domain.DownloadClient{}
_, err := repo.Update(context.Background(), badMockData)
err := repo.Update(context.Background(), &badMockData)
assert.Error(t, err)
})
@ -305,13 +315,13 @@ func TestDownloadClientRepo_Delete(t *testing.T) {
t.Run(fmt.Sprintf("Delete_Successfully_Deletes_Client [%s]", dbType), func(t *testing.T) {
mockClient := getMockDownloadClient()
createdClient, _ := repo.Store(context.Background(), mockClient)
_ = repo.Store(context.Background(), &mockClient)
err := repo.Delete(context.Background(), createdClient.ID)
err := repo.Delete(context.Background(), mockClient.ID)
assert.NoError(t, err)
// Verify client was deleted
_, err = repo.FindByID(context.Background(), int32(createdClient.ID))
_, err = repo.FindByID(context.Background(), mockClient.ID)
assert.Error(t, err)
})
@ -322,16 +332,16 @@ func TestDownloadClientRepo_Delete(t *testing.T) {
t.Run(fmt.Sprintf("Delete_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) {
mockClient := getMockDownloadClient()
createdClient, _ := repo.Store(context.Background(), mockClient)
_ = repo.Store(context.Background(), &mockClient)
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()
err := repo.Delete(ctx, createdClient.ID)
err := repo.Delete(ctx, mockClient.ID)
assert.Error(t, err)
// Cleanup
_ = repo.Delete(context.Background(), createdClient.ID)
_ = repo.Delete(context.Background(), mockClient.ID)
})
}
}

View file

@ -255,13 +255,12 @@ func (r *FilterRepo) FindByID(ctx context.Context, filterID int) (*domain.Filter
}
row := r.db.handler.QueryRowContext(ctx, query, args...)
if row.Err() != nil {
if errors.Is(row.Err(), sql.ErrNoRows) {
if err := row.Err(); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(row.Err(), "error row")
return nil, errors.Wrap(err, "error row")
}
var f domain.Filter

View file

@ -791,12 +791,14 @@ func TestFilterRepo_GetDownloadsByFilterId(t *testing.T) {
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mockClient := getMockDownloadClient()
err = downloadClientRepo.Store(context.Background(), &mockClient)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mockClient)
mockAction.FilterID = mockData.ID
mockAction.ClientID = int32(createdClient.ID)
mockAction.ClientID = mockClient.ID
action, err := actionRepo.Store(context.Background(), mockAction)
@ -827,7 +829,7 @@ func TestFilterRepo_GetDownloadsByFilterId(t *testing.T) {
// Cleanup
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: action.ID})
_ = repo.Delete(context.Background(), mockData.ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mockClient.ID)
_ = releaseRepo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
})

View file

@ -89,9 +89,10 @@ func TestReleaseRepo_Store(t *testing.T) {
t.Run(fmt.Sprintf("StoreReleaseActionStatus_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -101,7 +102,7 @@ func TestReleaseRepo_Store(t *testing.T) {
assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID)
actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
// Execute
@ -124,7 +125,7 @@ func TestReleaseRepo_Store(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
}
}
@ -144,9 +145,10 @@ func TestReleaseRepo_StoreReleaseActionStatus(t *testing.T) {
t.Run(fmt.Sprintf("StoreReleaseActionStatus_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -156,7 +158,7 @@ func TestReleaseRepo_StoreReleaseActionStatus(t *testing.T) {
assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID)
actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
// Execute
@ -179,7 +181,7 @@ func TestReleaseRepo_StoreReleaseActionStatus(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
}
}
@ -199,9 +201,10 @@ func TestReleaseRepo_Find(t *testing.T) {
t.Run(fmt.Sprintf("FindReleases_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -211,7 +214,7 @@ func TestReleaseRepo_Find(t *testing.T) {
assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID)
actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
// Execute
@ -238,7 +241,7 @@ func TestReleaseRepo_Find(t *testing.T) {
// Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
}
}
@ -258,9 +261,10 @@ func TestReleaseRepo_FindRecent(t *testing.T) {
t.Run(fmt.Sprintf("FindRecent_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -270,7 +274,7 @@ func TestReleaseRepo_FindRecent(t *testing.T) {
assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID)
actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
// Execute
@ -286,7 +290,7 @@ func TestReleaseRepo_FindRecent(t *testing.T) {
// Cleanup
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
}
}
@ -306,9 +310,10 @@ func TestReleaseRepo_GetIndexerOptions(t *testing.T) {
t.Run(fmt.Sprintf("GetIndexerOptions_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -318,7 +323,7 @@ func TestReleaseRepo_GetIndexerOptions(t *testing.T) {
assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID)
actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData)
@ -344,7 +349,7 @@ func TestReleaseRepo_GetIndexerOptions(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
}
}
@ -364,9 +369,10 @@ func TestReleaseRepo_GetActionStatusByReleaseID(t *testing.T) {
t.Run(fmt.Sprintf("GetActionStatusByReleaseID_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -376,7 +382,7 @@ func TestReleaseRepo_GetActionStatusByReleaseID(t *testing.T) {
assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID)
actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData)
@ -403,7 +409,7 @@ func TestReleaseRepo_GetActionStatusByReleaseID(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
}
}
@ -423,9 +429,10 @@ func TestReleaseRepo_Get(t *testing.T) {
t.Run(fmt.Sprintf("Get_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -435,7 +442,7 @@ func TestReleaseRepo_Get(t *testing.T) {
assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID)
actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData)
@ -462,7 +469,7 @@ func TestReleaseRepo_Get(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
}
}
@ -482,9 +489,10 @@ func TestReleaseRepo_Stats(t *testing.T) {
t.Run(fmt.Sprintf("Stats_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -494,7 +502,7 @@ func TestReleaseRepo_Stats(t *testing.T) {
assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID)
actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData)
@ -521,7 +529,7 @@ func TestReleaseRepo_Stats(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
}
}
@ -541,9 +549,10 @@ func TestReleaseRepo_Delete(t *testing.T) {
t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -553,7 +562,7 @@ func TestReleaseRepo_Delete(t *testing.T) {
assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID)
actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData)
@ -577,7 +586,7 @@ func TestReleaseRepo_Delete(t *testing.T) {
// Cleanup
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
}
}
@ -597,9 +606,10 @@ func TestReleaseRepo_CheckSmartEpisodeCanDownloadShow(t *testing.T) {
t.Run(fmt.Sprintf("Check_Smart_Episode_Can_Download [%s]", dbType), func(t *testing.T) {
// Setup
createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient())
mock := getMockDownloadClient()
err := downloadClientRepo.Store(context.Background(), &mock)
assert.NoError(t, err)
assert.NotNil(t, createdClient)
assert.NotNil(t, mock)
err = filterRepo.Store(context.Background(), getMockFilter())
assert.NoError(t, err)
@ -609,7 +619,7 @@ func TestReleaseRepo_CheckSmartEpisodeCanDownloadShow(t *testing.T) {
assert.NotNil(t, createdFilters)
actionMockData.FilterID = createdFilters[0].ID
actionMockData.ClientID = int32(createdClient.ID)
actionMockData.ClientID = mock.ID
mockData.FilterID = createdFilters[0].ID
err = repo.Store(context.Background(), mockData)
@ -644,7 +654,7 @@ func TestReleaseRepo_CheckSmartEpisodeCanDownloadShow(t *testing.T) {
_ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0})
_ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID})
_ = filterRepo.Delete(context.Background(), createdFilters[0].ID)
_ = downloadClientRepo.Delete(context.Background(), createdClient.ID)
_ = downloadClientRepo.Delete(context.Background(), mock.ID)
})
}
}

View file

@ -14,7 +14,7 @@ import (
type ActionRepo interface {
Store(ctx context.Context, action Action) (*Action, error)
StoreFilterActions(ctx context.Context, filterID int64, actions []*Action) ([]*Action, error)
FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*Action, error)
FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*Action, error)
List(ctx context.Context) ([]Action, error)
Get(ctx context.Context, req *GetActionRequest) (*Action, error)
Delete(ctx context.Context, req *DeleteActionRequest) error

View file

@ -9,20 +9,18 @@ import (
"net/url"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/go-qbittorrent"
)
type DownloadClientRepo interface {
List(ctx context.Context) ([]DownloadClient, error)
FindByID(ctx context.Context, id int32) (*DownloadClient, error)
Store(ctx context.Context, client DownloadClient) (*DownloadClient, error)
Update(ctx context.Context, client DownloadClient) (*DownloadClient, error)
Delete(ctx context.Context, clientID int) error
Store(ctx context.Context, client *DownloadClient) error
Update(ctx context.Context, client *DownloadClient) error
Delete(ctx context.Context, clientID int32) error
}
type DownloadClient struct {
ID int `json:"id"`
ID int32 `json:"id"`
Name string `json:"name"`
Type DownloadClientType `json:"type"`
Enabled bool `json:"enabled"`
@ -33,11 +31,9 @@ type DownloadClient struct {
Username string `json:"username"`
Password string `json:"password"`
Settings DownloadClientSettings `json:"settings,omitempty"`
}
type DownloadClientCached struct {
Dc *DownloadClient
Qbt *qbittorrent.Client
// cached http client
Client any
}
type DownloadClientSettings struct {

View file

@ -11,7 +11,7 @@ import (
func TestDownloadClient_qbitBuildLegacyHost(t *testing.T) {
type fields struct {
ID int
ID int32
Name string
Type DownloadClientType
Enabled bool

View file

@ -0,0 +1,48 @@
package download_client
import (
"sync"
"github.com/autobrr/autobrr/internal/domain"
)
type ClientCacheStore interface {
Set(id int32, client *domain.DownloadClient)
Get(id int32) *domain.DownloadClient
Pop(id int32)
}
type ClientCache struct {
mu sync.RWMutex
clients map[int32]*domain.DownloadClient
}
func NewClientCache() *ClientCache {
return &ClientCache{
clients: make(map[int32]*domain.DownloadClient),
}
}
func (c *ClientCache) Set(id int32, client *domain.DownloadClient) {
if client != nil {
c.mu.Lock()
c.clients[id] = client
c.mu.Unlock()
}
}
func (c *ClientCache) Get(id int32) *domain.DownloadClient {
c.mu.RLock()
defer c.mu.RUnlock()
v, ok := c.clients[id]
if ok {
return v
}
return nil
}
func (c *ClientCache) Pop(id int32) {
c.mu.Lock()
delete(c.clients, id)
c.mu.Unlock()
}

View file

@ -5,13 +5,27 @@ package download_client
import (
"context"
"fmt"
"log"
"net/url"
"sync"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/lidarr"
"github.com/autobrr/autobrr/pkg/porla"
"github.com/autobrr/autobrr/pkg/radarr"
"github.com/autobrr/autobrr/pkg/readarr"
"github.com/autobrr/autobrr/pkg/sabnzbd"
"github.com/autobrr/autobrr/pkg/sonarr"
"github.com/autobrr/autobrr/pkg/transmission"
"github.com/autobrr/autobrr/pkg/whisparr"
"github.com/autobrr/go-deluge"
"github.com/autobrr/go-qbittorrent"
"github.com/autobrr/go-rtorrent"
"github.com/dcarbone/zadapters/zstdlog"
"github.com/rs/zerolog"
)
@ -19,12 +33,12 @@ import (
type Service interface {
List(ctx context.Context) ([]domain.DownloadClient, error)
FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error)
Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(ctx context.Context, clientID int) error
Store(ctx context.Context, client *domain.DownloadClient) error
Update(ctx context.Context, client *domain.DownloadClient) error
Delete(ctx context.Context, clientID int32) error
Test(ctx context.Context, client domain.DownloadClient) error
GetCachedClient(ctx context.Context, clientId int32) *domain.DownloadClientCached
GetClient(ctx context.Context, clientId int32) (*domain.DownloadClient, error)
}
type service struct {
@ -32,7 +46,7 @@ type service struct {
repo domain.DownloadClientRepo
subLogger *log.Logger
qbitClients map[int32]*domain.DownloadClientCached
cache *ClientCache
m sync.RWMutex
}
@ -41,7 +55,7 @@ func NewService(log logger.Logger, repo domain.DownloadClientRepo) Service {
log: log.With().Str("module", "download_client").Logger(),
repo: repo,
qbitClients: map[int32]*domain.DownloadClientCached{},
cache: NewClientCache(),
m: sync.RWMutex{},
}
@ -61,6 +75,13 @@ func (s *service) List(ctx context.Context) ([]domain.DownloadClient, error) {
}
func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
client := s.cache.Get(id)
if client != nil {
return client, nil
}
s.log.Trace().Msgf("cache miss for client id %d, continue to repo lookup", id)
client, err := s.repo.FindByID(ctx, id)
if err != nil {
s.log.Error().Err(err).Msgf("could not find download client by id: %v", id)
@ -70,53 +91,49 @@ func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClien
return client, nil
}
func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
func (s *service) Store(ctx context.Context, client *domain.DownloadClient) error {
// basic validation of client
if err := client.Validate(); err != nil {
return nil, err
return err
}
// store
c, err := s.repo.Store(ctx, client)
err := s.repo.Store(ctx, client)
if err != nil {
s.log.Error().Err(err).Msgf("could not store download client: %+v", client)
return nil, err
return err
}
return c, err
s.cache.Set(client.ID, client)
return err
}
func (s *service) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
func (s *service) Update(ctx context.Context, client *domain.DownloadClient) error {
// basic validation of client
if err := client.Validate(); err != nil {
return nil, err
return err
}
// update
c, err := s.repo.Update(ctx, client)
err := s.repo.Update(ctx, client)
if err != nil {
s.log.Error().Err(err).Msgf("could not update download client: %+v", client)
return nil, err
return err
}
if client.Type == domain.DownloadClientTypeQbittorrent {
s.m.Lock()
delete(s.qbitClients, int32(client.ID))
s.m.Unlock()
}
s.cache.Set(client.ID, client)
return c, err
return err
}
func (s *service) Delete(ctx context.Context, clientID int) error {
func (s *service) Delete(ctx context.Context, clientID int32) error {
if err := s.repo.Delete(ctx, clientID); err != nil {
s.log.Error().Err(err).Msgf("could not delete download client: %v", clientID)
return err
}
s.m.Lock()
delete(s.qbitClients, int32(clientID))
s.m.Unlock()
s.cache.Pop(clientID)
return nil
}
@ -136,53 +153,165 @@ func (s *service) Test(ctx context.Context, client domain.DownloadClient) error
return nil
}
func (s *service) GetCachedClient(ctx context.Context, clientId int32) *domain.DownloadClientCached {
// check if client exists in cache
s.m.RLock()
cached, ok := s.qbitClients[clientId]
s.m.RUnlock()
if ok {
return cached
}
// get client for action
client, err := s.FindByID(ctx, clientId)
if err != nil {
return nil
}
// GetClient get client from cache or repo and attach downloadClient implementation
func (s *service) GetClient(ctx context.Context, clientId int32) (*domain.DownloadClient, error) {
l := s.log.With().Str("cache", "download-client").Logger()
client := s.cache.Get(clientId)
if client == nil {
return nil
l.Trace().Msgf("cache miss for client id %d, continue to repo lookup", clientId)
var err error
client, err = s.repo.FindByID(ctx, clientId)
if err != nil {
return nil, errors.Wrap(err, "could not find client repo.FindByID")
}
}
qbtSettings := qbittorrent.Config{
// if we have the client return it
if client.Client != nil {
l.Trace().Msgf("cache hit for client id %d %s", clientId, client.Name)
return client, nil
}
l.Trace().Msgf("init cache client id %d %s", clientId, client.Name)
switch client.Type {
case domain.DownloadClientTypeQbittorrent:
client.Client = qbittorrent.NewClient(qbittorrent.Config{
Host: client.BuildLegacyHost(),
Username: client.Username,
Password: client.Password,
TLSSkipVerify: client.TLSSkipVerify,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "qBittorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
})
case domain.DownloadClientTypePorla:
client.Client = porla.NewClient(porla.Config{
Hostname: client.Host,
AuthToken: client.Settings.APIKey,
TLSSkipVerify: client.TLSSkipVerify,
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Porla").Str("client", client.Name).Logger(), zerolog.TraceLevel),
})
case domain.DownloadClientTypeDelugeV1:
client.Client = deluge.NewV1(deluge.Settings{
Hostname: client.Host,
Port: uint(client.Port),
Login: client.Username,
Password: client.Password,
DebugServerResponses: true,
ReadWriteTimeout: time.Second * 60,
})
case domain.DownloadClientTypeDelugeV2:
client.Client = deluge.NewV2(deluge.Settings{
Hostname: client.Host,
Port: uint(client.Port),
Login: client.Username,
Password: client.Password,
DebugServerResponses: true,
ReadWriteTimeout: time.Second * 60,
})
case domain.DownloadClientTypeTransmission:
scheme := "http"
if client.TLS {
scheme = "https"
}
// setup sub logger adapter which is compatible with *log.Logger
qbtSettings.Log = zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "qBittorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel)
// only set basic auth if enabled
if client.Settings.Basic.Auth {
qbtSettings.BasicUser = client.Settings.Basic.Username
qbtSettings.BasicPass = client.Settings.Basic.Password
transmissionURL, err := url.Parse(fmt.Sprintf("%s://%s:%d/transmission/rpc", scheme, client.Host, client.Port))
if err != nil {
return nil, errors.Wrap(err, "could not parse transmission url")
}
qc := &domain.DownloadClientCached{
Dc: client,
Qbt: qbittorrent.NewClient(qbtSettings),
tbt, err := transmission.New(transmissionURL, &transmission.Config{
UserAgent: "autobrr",
Username: client.Username,
Password: client.Password,
TLSSkipVerify: client.TLSSkipVerify,
})
if err != nil {
return nil, errors.Wrap(err, "error logging into transmission client: %s", client.Host)
}
client.Client = tbt
case domain.DownloadClientTypeRTorrent:
client.Client = rtorrent.NewClient(rtorrent.Config{
Addr: client.Host,
TLSSkipVerify: client.TLSSkipVerify,
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "rTorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel),
})
case domain.DownloadClientTypeLidarr:
client.Client = lidarr.New(lidarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Lidarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicAuth: client.Settings.Basic.Auth,
Username: client.Settings.Basic.Username,
Password: client.Settings.Basic.Password,
})
case domain.DownloadClientTypeRadarr:
client.Client = radarr.New(radarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Radarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicAuth: client.Settings.Basic.Auth,
Username: client.Settings.Basic.Username,
Password: client.Settings.Basic.Password,
})
case domain.DownloadClientTypeReadarr:
client.Client = readarr.New(readarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Readarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicAuth: client.Settings.Basic.Auth,
Username: client.Settings.Basic.Username,
Password: client.Settings.Basic.Password,
})
case domain.DownloadClientTypeSonarr:
client.Client = sonarr.New(sonarr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Sonarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicAuth: client.Settings.Basic.Auth,
Username: client.Settings.Basic.Username,
Password: client.Settings.Basic.Password,
})
case domain.DownloadClientTypeWhisparr:
client.Client = whisparr.New(whisparr.Config{
Hostname: client.Host,
APIKey: client.Settings.APIKey,
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Whisparr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
BasicAuth: client.Settings.Basic.Auth,
Username: client.Settings.Basic.Username,
Password: client.Settings.Basic.Password,
})
case domain.DownloadClientTypeSabnzbd:
client.Client = sabnzbd.New(sabnzbd.Options{
Addr: client.Host,
ApiKey: client.Settings.APIKey,
Log: nil,
BasicUser: client.Settings.Basic.Username,
BasicPass: client.Settings.Basic.Password,
})
}
cached = qc
l.Trace().Msgf("set cache client id %d %s", clientId, client.Name)
s.m.Lock()
s.qbitClients[clientId] = cached
s.m.Unlock()
s.cache.Set(clientId, client)
return cached
return client, nil
}

View file

@ -7,6 +7,7 @@ import (
"bytes"
"context"
"fmt"
"github.com/autobrr/autobrr/internal/action"
"io"
"net/http"
"os"
@ -48,7 +49,7 @@ type Service interface {
type service struct {
log zerolog.Logger
repo domain.FilterRepo
actionRepo domain.ActionRepo
actionService action.Service
releaseRepo domain.ReleaseRepo
indexerSvc indexer.Service
apiService indexer.APIService
@ -56,12 +57,12 @@ type service struct {
httpClient *http.Client
}
func NewService(log logger.Logger, repo domain.FilterRepo, actionRepo domain.ActionRepo, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service) Service {
func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Service, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service) Service {
return &service{
log: log.With().Str("module", "filter").Logger(),
repo: repo,
actionRepo: actionRepo,
releaseRepo: releaseRepo,
actionService: actionSvc,
apiService: apiService,
indexerSvc: indexerSvc,
httpClient: &http.Client{
@ -130,7 +131,7 @@ func (s *service) FindByID(ctx context.Context, filterID int) (*domain.Filter, e
}
filter.External = externalFilters
actions, err := s.actionRepo.FindByFilterID(ctx, filter.ID, nil)
actions, err := s.actionService.FindByFilterID(ctx, filter.ID, nil, false)
if err != nil {
s.log.Error().Err(err).Msgf("could not find filter actions for filter id: %v", filter.ID)
}
@ -222,7 +223,7 @@ func (s *service) Update(ctx context.Context, filter *domain.Filter) error {
}
// take care of filter actions
actions, err := s.actionRepo.StoreFilterActions(ctx, int64(filter.ID), filter.Actions)
actions, err := s.actionService.StoreFilterActions(ctx, int64(filter.ID), filter.Actions)
if err != nil {
s.log.Error().Err(err).Msgf("could not store filter actions: %s", filter.Name)
return err
@ -267,7 +268,7 @@ func (s *service) UpdatePartial(ctx context.Context, filter domain.FilterUpdate)
if filter.Actions != nil {
// take care of filter actions
if _, err := s.actionRepo.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil {
if _, err := s.actionService.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil {
s.log.Error().Err(err).Msgf("could not store filter actions: %v", filter.ID)
return err
}
@ -308,7 +309,7 @@ func (s *service) Duplicate(ctx context.Context, filterID int) (*domain.Filter,
}
// take care of filter actions
if _, err := s.actionRepo.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil {
if _, err := s.actionService.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil {
s.log.Error().Err(err).Msgf("could not store filter actions: %s", filter.Name)
return nil, err
}
@ -340,7 +341,7 @@ func (s *service) Delete(ctx context.Context, filterID int) error {
}
// take care of filter actions
if err := s.actionRepo.DeleteByFilterID(ctx, filterID); err != nil {
if err := s.actionService.DeleteByFilterID(ctx, filterID); err != nil {
s.log.Error().Err(err).Msg("could not delete filter actions")
return err
}

View file

@ -17,9 +17,9 @@ import (
type downloadClientService interface {
List(ctx context.Context) ([]domain.DownloadClient, error)
Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(ctx context.Context, clientID int) error
Store(ctx context.Context, client *domain.DownloadClient) error
Update(ctx context.Context, client *domain.DownloadClient) error
Delete(ctx context.Context, clientID int32) error
Test(ctx context.Context, client domain.DownloadClient) error
}
@ -56,20 +56,20 @@ func (h downloadClientHandler) listDownloadClients(w http.ResponseWriter, r *htt
}
func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) {
var data domain.DownloadClient
var data *domain.DownloadClient
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.Error(w, err)
return
}
client, err := h.service.Store(r.Context(), data)
err := h.service.Store(r.Context(), data)
if err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(w, http.StatusCreated, client)
h.encoder.StatusResponse(w, http.StatusCreated, data)
}
func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) {
@ -89,20 +89,20 @@ func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) {
}
func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) {
var data domain.DownloadClient
var data *domain.DownloadClient
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.Error(w, err)
return
}
client, err := h.service.Update(r.Context(), data)
err := h.service.Update(r.Context(), data)
if err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(w, http.StatusCreated, client)
h.encoder.StatusResponse(w, http.StatusCreated, data)
}
func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) {
@ -113,13 +113,13 @@ func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) {
return
}
id, err := strconv.Atoi(clientID)
id, err := strconv.ParseInt(clientID, 10, 32)
if err != nil {
h.encoder.Error(w, err)
return
}
if err = h.service.Delete(r.Context(), id); err != nil {
if err = h.service.Delete(r.Context(), int32(id)); err != nil {
h.encoder.Error(w, err)
return
}

View file

@ -221,7 +221,7 @@ func (s *service) processFilters(ctx context.Context, filters []*domain.Filter,
// found matching filter, lets find the filter actions and attach
active := true
actions, err := s.actionSvc.FindByFilterID(ctx, f.ID, &active)
actions, err := s.actionSvc.FindByFilterID(ctx, f.ID, &active, false)
if err != nil {
s.log.Error().Err(err).Msgf("release.Process: error finding actions for filter: %s", f.Name)
return err

View file

@ -17,7 +17,7 @@ type Config struct {
Username string
Password string
TLSSkipVerify bool
Timeout int
Timeout time.Duration
}
func New(endpoint *url.URL, cfg *Config) (*transmissionrpc.Client, error) {