diff --git a/cmd/autobrr/main.go b/cmd/autobrr/main.go index 64db8fe..a575d0a 100644 --- a/cmd/autobrr/main.go +++ b/cmd/autobrr/main.go @@ -117,7 +117,7 @@ func main() { downloadClientService = download_client.NewService(log, downloadClientRepo) actionService = action.NewService(log, actionRepo, downloadClientService, bus) indexerService = indexer.NewService(log, cfg.Config, indexerRepo, releaseRepo, indexerAPIService, schedulingService) - filterService = filter.NewService(log, filterRepo, actionRepo, releaseRepo, indexerAPIService, indexerService) + filterService = filter.NewService(log, filterRepo, actionService, releaseRepo, indexerAPIService, indexerService) releaseService = release.NewService(log, releaseRepo, actionService, filterService, indexerService) ircService = irc.NewService(log, serverEvents, ircRepo, releaseService, indexerService, notificationService) feedService = feed.NewService(log, feedRepo, feedCacheRepo, releaseService, schedulingService) diff --git a/internal/action/deluge.go b/internal/action/deluge.go index ba6f301..4bf74c3 100644 --- a/internal/action/deluge.go +++ b/internal/action/deluge.go @@ -7,7 +7,6 @@ import ( "context" "encoding/base64" "os" - "time" "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/pkg/errors" @@ -20,20 +19,18 @@ func (s *service) deluge(ctx context.Context, action *domain.Action, release dom var err error - // get client for action - client, err := s.clientSvc.FindByID(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) if err != nil { - s.log.Error().Stack().Err(err).Msgf("error finding client: %d", action.ClientID) - return nil, err + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) } - if client == nil { - return nil, errors.New("could not find client by id: %d", action.ClientID) + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) } var rejections []string - switch client.Type { + switch action.Client.Type { case "DELUGE_V1": rejections, err = s.delugeV1(ctx, client, action, release) @@ -90,27 +87,18 @@ func (s *service) delugeCheckRulesCanDownload(ctx context.Context, del deluge.De } func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) { - settings := deluge.Settings{ - Hostname: client.Host, - Port: uint(client.Port), - Login: client.Username, - Password: client.Password, - DebugServerResponses: true, - ReadWriteTimeout: time.Second * 30, - } - - del := deluge.NewV1(settings) + downloadClient := client.Client.(*deluge.Client) // perform connection to Deluge server - err := del.Connect(ctx) + err := downloadClient.Connect(ctx) if err != nil { return nil, errors.Wrap(err, "could not connect to client %s at %s", client.Name, client.Host) } - defer del.Close() + defer downloadClient.Close() // perform connection to Deluge server - rejections, err := s.delugeCheckRulesCanDownload(ctx, del, client, action) + rejections, err := s.delugeCheckRulesCanDownload(ctx, downloadClient, client, action) if err != nil { s.log.Error().Err(err).Msgf("error checking client rules: %s", action.Name) return nil, err @@ -127,13 +115,13 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a s.log.Trace().Msgf("action Deluge options: %+v", options) - torrentHash, err := del.AddTorrentMagnet(ctx, release.MagnetURI, &options) + torrentHash, err := downloadClient.AddTorrentMagnet(ctx, release.MagnetURI, &options) if err != nil { return nil, errors.Wrap(err, "could not add torrent magnet %s to client: %s", release.MagnetURI, client.Name) } if action.Label != "" { - labelPluginActive, err := del.LabelPlugin(ctx) + labelPluginActive, err := downloadClient.LabelPlugin(ctx) if err != nil { return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name) } @@ -176,13 +164,13 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a s.log.Trace().Msgf("action Deluge options: %+v", options) - torrentHash, err := del.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options) + torrentHash, err := downloadClient.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options) if err != nil { return nil, errors.Wrap(err, "could not add torrent %v to client: %v", release.TorrentTmpFile, client.Name) } if action.Label != "" { - labelPluginActive, err := del.LabelPlugin(ctx) + labelPluginActive, err := downloadClient.LabelPlugin(ctx) if err != nil { return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name) } @@ -203,27 +191,18 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a } func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) { - settings := deluge.Settings{ - Hostname: client.Host, - Port: uint(client.Port), - Login: client.Username, - Password: client.Password, - DebugServerResponses: true, - ReadWriteTimeout: time.Second * 30, - } - - del := deluge.NewV2(settings) + downloadClient := client.Client.(*deluge.ClientV2) // perform connection to Deluge server - err := del.Connect(ctx) + err := downloadClient.Connect(ctx) if err != nil { return nil, errors.Wrap(err, "could not connect to client %s at %s", client.Name, client.Host) } - defer del.Close() + defer downloadClient.Close() // perform connection to Deluge server - rejections, err := s.delugeCheckRulesCanDownload(ctx, del, client, action) + rejections, err := s.delugeCheckRulesCanDownload(ctx, downloadClient, client, action) if err != nil { s.log.Error().Err(err).Msgf("error checking client rules: %s", action.Name) return nil, err @@ -240,13 +219,13 @@ func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, a s.log.Trace().Msgf("action Deluge options: %+v", options) - torrentHash, err := del.AddTorrentMagnet(ctx, release.MagnetURI, &options) + torrentHash, err := downloadClient.AddTorrentMagnet(ctx, release.MagnetURI, &options) if err != nil { return nil, errors.Wrap(err, "could not add torrent magnet %s to client: %s", release.MagnetURI, client.Name) } if action.Label != "" { - labelPluginActive, err := del.LabelPlugin(ctx) + labelPluginActive, err := downloadClient.LabelPlugin(ctx) if err != nil { return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name) } @@ -290,13 +269,13 @@ func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, a s.log.Trace().Msgf("action Deluge options: %+v", options) - torrentHash, err := del.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options) + torrentHash, err := downloadClient.AddTorrentFile(ctx, release.TorrentTmpFile, encodedFile, &options) if err != nil { return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.TorrentTmpFile, client.Name) } if action.Label != "" { - labelPluginActive, err := del.LabelPlugin(ctx) + labelPluginActive, err := downloadClient.LabelPlugin(ctx) if err != nil { return nil, errors.Wrap(err, "could not load label plugin for client: %s", client.Name) } diff --git a/internal/action/lidarr.go b/internal/action/lidarr.go index c796fa0..791d6aa 100644 --- a/internal/action/lidarr.go +++ b/internal/action/lidarr.go @@ -17,41 +17,16 @@ func (s *service) lidarr(ctx context.Context, action *domain.Action, release dom // TODO validate data - // get client for action - client, err := s.clientSvc.FindByID(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) if err != nil { - s.log.Error().Err(err).Msgf("lidarr: error finding client: %v", action.ClientID) - return nil, err + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) } - // return early if no client found - if client == nil { - return nil, errors.New("could not find client by id: %v", action.ClientID) + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) } - // initial config - cfg := lidarr.Config{ - Hostname: client.Host, - APIKey: client.Settings.APIKey, - Log: s.subLogger, - } - - // only set basic auth if enabled - if client.Settings.Basic.Auth { - cfg.BasicAuth = client.Settings.Basic.Auth - cfg.Username = client.Settings.Basic.Username - cfg.Password = client.Settings.Basic.Password - } - - externalClientId := client.Settings.ExternalDownloadClientId - if action.ExternalDownloadClientID > 0 { - externalClientId = int(action.ExternalDownloadClientID) - } - - externalClient := client.Settings.ExternalDownloadClient - if action.ExternalDownloadClient != "" { - externalClient = action.ExternalDownloadClient - } + arr := client.Client.(lidarr.Client) r := lidarr.Release{ Title: release.TorrentName, @@ -60,14 +35,20 @@ func (s *service) lidarr(ctx context.Context, action *domain.Action, release dom MagnetUrl: release.MagnetURI, Size: int64(release.Size), Indexer: release.Indexer.GetExternalIdentifier(), - DownloadClientId: externalClientId, - DownloadClient: externalClient, + DownloadClientId: client.Settings.ExternalDownloadClientId, + DownloadClient: client.Settings.ExternalDownloadClient, DownloadProtocol: release.Protocol.String(), Protocol: release.Protocol.String(), PublishDate: time.Now().Format(time.RFC3339), } - arr := lidarr.New(cfg) + if action.ExternalDownloadClientID > 0 { + r.DownloadClientId = int(action.ExternalDownloadClientID) + } + + if action.ExternalDownloadClient != "" { + r.DownloadClient = action.ExternalDownloadClient + } rejections, err := arr.Push(ctx, r) if err != nil { diff --git a/internal/action/porla.go b/internal/action/porla.go index ba4f3da..950d39e 100644 --- a/internal/action/porla.go +++ b/internal/action/porla.go @@ -13,34 +13,21 @@ import ( "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/porla" - - "github.com/dcarbone/zadapters/zstdlog" - "github.com/rs/zerolog" ) func (s *service) porla(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Debug().Msgf("action Porla: %s", action.Name) - client, err := s.clientSvc.FindByID(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) if err != nil { - return nil, errors.Wrap(err, "error finding client: %d", action.ClientID) + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) } - if client == nil { - return nil, errors.New("could not find client by id: %d", action.ClientID) + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) } - porlaSettings := porla.Config{ - Hostname: client.Host, - AuthToken: client.Settings.APIKey, - TLSSkipVerify: client.TLSSkipVerify, - BasicUser: client.Settings.Basic.Username, - BasicPass: client.Settings.Basic.Password, - } - - porlaSettings.Log = zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Porla").Str("client", client.Name).Logger(), zerolog.TraceLevel) - - prl := porla.NewClient(porlaSettings) + prl := client.Client.(*porla.Client) rejections, err := s.porlaCheckRulesCanDownload(ctx, action, client, prl) if err != nil { diff --git a/internal/action/qbittorrent.go b/internal/action/qbittorrent.go index 9e26806..8f5c2f0 100644 --- a/internal/action/qbittorrent.go +++ b/internal/action/qbittorrent.go @@ -17,11 +17,20 @@ import ( func (s *service) qbittorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Debug().Msgf("action qBittorrent: %s", action.Name) - c := s.clientSvc.GetCachedClient(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) + if err != nil { + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) + } - if c.Dc.Settings.Rules.Enabled && !action.IgnoreRules { + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) + } + + qbtClient := client.Client.(*qbittorrent.Client) + + if client.Settings.Rules.Enabled && !action.IgnoreRules { // check for active downloads and other rules - rejections, err := s.qbittorrentCheckRulesCanDownload(ctx, action, c.Dc.Settings.Rules, c.Qbt) + rejections, err := s.qbittorrentCheckRulesCanDownload(ctx, action, client.Settings.Rules, qbtClient) if err != nil { return nil, errors.Wrap(err, "error checking client rules: %s", action.Name) } @@ -39,11 +48,11 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas s.log.Trace().Msgf("action qBittorrent options: %+v", options) - if err = c.Qbt.AddTorrentFromUrlCtx(ctx, release.MagnetURI, options); err != nil { - return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.MagnetURI, c.Dc.Name) + if err = qbtClient.AddTorrentFromUrlCtx(ctx, release.MagnetURI, options); err != nil { + return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.MagnetURI, client.Name) } - s.log.Info().Msgf("torrent from magnet successfully added to client: '%s'", c.Dc.Name) + s.log.Info().Msgf("torrent from magnet successfully added to client: '%s'", client.Name) return nil, nil } @@ -61,37 +70,37 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas s.log.Trace().Msgf("action qBittorrent options: %+v", options) - if err = c.Qbt.AddTorrentFromFileCtx(ctx, release.TorrentTmpFile, options); err != nil { - return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.TorrentTmpFile, c.Dc.Name) + if err = qbtClient.AddTorrentFromFileCtx(ctx, release.TorrentTmpFile, options); err != nil { + return nil, errors.Wrap(err, "could not add torrent %s to client: %s", release.TorrentTmpFile, client.Name) } if release.TorrentHash != "" { // check if torrent queueing is enabled if priority is set switch action.PriorityLayout { case domain.PriorityLayoutMax, domain.PriorityLayoutMin: - prefs, err := c.Qbt.GetAppPreferencesCtx(ctx) + prefs, err := qbtClient.GetAppPreferencesCtx(ctx) if err != nil { - return nil, errors.Wrap(err, "could not get application preferences from client: '%s'", c.Dc.Name) + return nil, errors.Wrap(err, "could not get application preferences from client: '%s'", client.Name) } // enable queueing if it's disabled if !prefs.QueueingEnabled { - if err := c.Qbt.SetPreferencesQueueingEnabled(true); err != nil { + if err := qbtClient.SetPreferencesQueueingEnabled(true); err != nil { return nil, errors.Wrap(err, "could not enable torrent queueing") } - s.log.Trace().Msgf("torrent queueing was disabled, now enabled in client: '%s'", c.Dc.Name) + s.log.Trace().Msgf("torrent queueing was disabled, now enabled in client: '%s'", client.Name) } // set priority if queueing is enabled if action.PriorityLayout == domain.PriorityLayoutMax { - if err := c.Qbt.SetMaxPriorityCtx(ctx, []string{release.TorrentHash}); err != nil { + if err := qbtClient.SetMaxPriorityCtx(ctx, []string{release.TorrentHash}); err != nil { return nil, errors.Wrap(err, "could not set torrent %s to max priority", release.TorrentHash) } - s.log.Debug().Msgf("torrent with hash %s set to max priority in client: '%s'", release.TorrentHash, c.Dc.Name) + s.log.Debug().Msgf("torrent with hash %s set to max priority in client: '%s'", release.TorrentHash, client.Name) } else { // domain.PriorityLayoutMin - if err := c.Qbt.SetMinPriorityCtx(ctx, []string{release.TorrentHash}); err != nil { + if err := qbtClient.SetMinPriorityCtx(ctx, []string{release.TorrentHash}); err != nil { return nil, errors.Wrap(err, "could not set torrent %s to min priority", release.TorrentHash) } - s.log.Debug().Msgf("torrent with hash %s set to min priority in client: '%s'", release.TorrentHash, c.Dc.Name) + s.log.Debug().Msgf("torrent with hash %s set to min priority in client: '%s'", release.TorrentHash, client.Name) } case domain.PriorityLayoutDefault: @@ -111,7 +120,7 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas DeleteOnFailure: action.ReAnnounceDelete, } - if err := c.Qbt.ReannounceTorrentWithRetry(ctx, release.TorrentHash, &opts); err != nil { + if err := qbtClient.ReannounceTorrentWithRetry(ctx, release.TorrentHash, &opts); err != nil { if errors.Is(err, qbittorrent.ErrReannounceTookTooLong) { return []string{fmt.Sprintf("re-announce took too long for hash: %s", release.TorrentHash)}, nil } @@ -120,7 +129,7 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas } } - s.log.Info().Msgf("torrent with hash %s successfully added to client: '%s'", release.TorrentHash, c.Dc.Name) + s.log.Info().Msgf("torrent with hash %s successfully added to client: '%s'", release.TorrentHash, client.Name) return nil, nil } diff --git a/internal/action/radarr.go b/internal/action/radarr.go index 3461761..fb0e287 100644 --- a/internal/action/radarr.go +++ b/internal/action/radarr.go @@ -17,40 +17,16 @@ func (s *service) radarr(ctx context.Context, action *domain.Action, release dom // TODO validate data - // get client for action - client, err := s.clientSvc.FindByID(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) if err != nil { - return nil, errors.Wrap(err, "error finding client: %v", action.ClientID) + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) } - // return early if no client found - if client == nil { - return nil, errors.New("could not find client by id: %v", action.ClientID) + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) } - // initial config - cfg := radarr.Config{ - Hostname: client.Host, - APIKey: client.Settings.APIKey, - Log: s.subLogger, - } - - // only set basic auth if enabled - if client.Settings.Basic.Auth { - cfg.BasicAuth = client.Settings.Basic.Auth - cfg.Username = client.Settings.Basic.Username - cfg.Password = client.Settings.Basic.Password - } - - externalClientId := client.Settings.ExternalDownloadClientId - if action.ExternalDownloadClientID > 0 { - externalClientId = int(action.ExternalDownloadClientID) - } - - externalClient := client.Settings.ExternalDownloadClient - if action.ExternalDownloadClient != "" { - externalClient = action.ExternalDownloadClient - } + arr := client.Client.(radarr.Client) r := radarr.Release{ Title: release.TorrentName, @@ -59,14 +35,20 @@ func (s *service) radarr(ctx context.Context, action *domain.Action, release dom MagnetUrl: release.MagnetURI, Size: int64(release.Size), Indexer: release.Indexer.GetExternalIdentifier(), - DownloadClientId: externalClientId, - DownloadClient: externalClient, + DownloadClientId: client.Settings.ExternalDownloadClientId, + DownloadClient: client.Settings.ExternalDownloadClient, DownloadProtocol: release.Protocol.String(), Protocol: release.Protocol.String(), PublishDate: time.Now().Format(time.RFC3339), } - arr := radarr.New(cfg) + if action.ExternalDownloadClientID > 0 { + r.DownloadClientId = int(action.ExternalDownloadClientID) + } + + if action.ExternalDownloadClient != "" { + r.DownloadClient = action.ExternalDownloadClient + } rejections, err := arr.Push(ctx, r) if err != nil { diff --git a/internal/action/readarr.go b/internal/action/readarr.go index 0ff6f0f..8a5c639 100644 --- a/internal/action/readarr.go +++ b/internal/action/readarr.go @@ -17,40 +17,16 @@ func (s *service) readarr(ctx context.Context, action *domain.Action, release do // TODO validate data - // get client for action - client, err := s.clientSvc.FindByID(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) if err != nil { - return nil, errors.Wrap(err, "readarr could not find client: %v", action.ClientID) + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) } - // return early if no client found - if client == nil { - return nil, errors.New("no client found") + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) } - // initial config - cfg := readarr.Config{ - Hostname: client.Host, - APIKey: client.Settings.APIKey, - Log: s.subLogger, - } - - // only set basic auth if enabled - if client.Settings.Basic.Auth { - cfg.BasicAuth = client.Settings.Basic.Auth - cfg.Username = client.Settings.Basic.Username - cfg.Password = client.Settings.Basic.Password - } - - externalClientId := client.Settings.ExternalDownloadClientId - if action.ExternalDownloadClientID > 0 { - externalClientId = int(action.ExternalDownloadClientID) - } - - externalClient := client.Settings.ExternalDownloadClient - if action.ExternalDownloadClient != "" { - externalClient = action.ExternalDownloadClient - } + arr := client.Client.(readarr.Client) r := readarr.Release{ Title: release.TorrentName, @@ -59,14 +35,20 @@ func (s *service) readarr(ctx context.Context, action *domain.Action, release do MagnetUrl: release.MagnetURI, Size: int64(release.Size), Indexer: release.Indexer.GetExternalIdentifier(), - DownloadClientId: externalClientId, - DownloadClient: externalClient, + DownloadClientId: client.Settings.ExternalDownloadClientId, + DownloadClient: client.Settings.ExternalDownloadClient, DownloadProtocol: release.Protocol.String(), Protocol: release.Protocol.String(), PublishDate: time.Now().Format(time.RFC3339), } - arr := readarr.New(cfg) + if action.ExternalDownloadClientID > 0 { + r.DownloadClientId = int(action.ExternalDownloadClientID) + } + + if action.ExternalDownloadClient != "" { + r.DownloadClient = action.ExternalDownloadClient + } rejections, err := arr.Push(ctx, r) if err != nil { diff --git a/internal/action/rtorrent.go b/internal/action/rtorrent.go index b46ea2c..99dc8c5 100644 --- a/internal/action/rtorrent.go +++ b/internal/action/rtorrent.go @@ -16,32 +16,19 @@ import ( func (s *service) rtorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Debug().Msgf("action rTorrent: %s", action.Name) - var err error - - // get client for action - client, err := s.clientSvc.FindByID(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) if err != nil { - s.log.Error().Stack().Err(err).Msgf("error finding client: %d", action.ClientID) - return nil, err + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) } - if client == nil { - return nil, errors.New("could not find client by id: %d", action.ClientID) + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) } + rt := client.Client.(*rtorrent.Client) + var rejections []string - // create config - cfg := rtorrent.Config{ - Addr: client.Host, - TLSSkipVerify: client.TLSSkipVerify, - BasicUser: client.Settings.Basic.Username, - BasicPass: client.Settings.Basic.Password, - } - - // create client - rt := rtorrent.NewClient(cfg) - if release.HasMagnetUri() { var args []*rtorrent.FieldValue @@ -79,55 +66,54 @@ func (s *service) rtorrent(ctx context.Context, action *domain.Action, release d s.log.Info().Msgf("torrent from magnet successfully added to client: '%s'", client.Name) return nil, nil + } - } else { - if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFileCtx(ctx); err != nil { - s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName) - return nil, err - } + if release.TorrentTmpFile == "" { + if err := release.DownloadTorrentFileCtx(ctx); err != nil { + s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName) + return nil, err } + } - tmpFile, err := os.ReadFile(release.TorrentTmpFile) - if err != nil { - return nil, errors.Wrap(err, "could not read torrent file: %s", release.TorrentTmpFile) - } + tmpFile, err := os.ReadFile(release.TorrentTmpFile) + if err != nil { + return nil, errors.Wrap(err, "could not read torrent file: %s", release.TorrentTmpFile) + } - var args []*rtorrent.FieldValue + var args []*rtorrent.FieldValue - if action.Label != "" { + if action.Label != "" { + args = append(args, &rtorrent.FieldValue{ + Field: rtorrent.DLabel, + Value: action.Label, + }) + } + if action.SavePath != "" { + if action.ContentLayout == domain.ActionContentLayoutSubfolderNone { args = append(args, &rtorrent.FieldValue{ - Field: rtorrent.DLabel, - Value: action.Label, + Field: "d.directory_base", + Value: action.SavePath, + }) + } else { + args = append(args, &rtorrent.FieldValue{ + Field: rtorrent.DDirectory, + Value: action.SavePath, }) } - if action.SavePath != "" { - if action.ContentLayout == domain.ActionContentLayoutSubfolderNone { - args = append(args, &rtorrent.FieldValue{ - Field: "d.directory_base", - Value: action.SavePath, - }) - } else { - args = append(args, &rtorrent.FieldValue{ - Field: rtorrent.DDirectory, - Value: action.SavePath, - }) - } - } - - var addTorrentFile func(context.Context, []byte, ...*rtorrent.FieldValue) error - if action.Paused { - addTorrentFile = rt.AddTorrentStopped - } else { - addTorrentFile = rt.AddTorrent - } - - if err := addTorrentFile(ctx, tmpFile, args...); err != nil { - return nil, errors.Wrap(err, "could not add torrent file: %s", release.TorrentTmpFile) - } - - s.log.Info().Msgf("torrent successfully added to client: '%s'", client.Name) } + var addTorrentFile func(context.Context, []byte, ...*rtorrent.FieldValue) error + if action.Paused { + addTorrentFile = rt.AddTorrentStopped + } else { + addTorrentFile = rt.AddTorrent + } + + if err := addTorrentFile(ctx, tmpFile, args...); err != nil { + return nil, errors.Wrap(err, "could not add torrent file: %s", release.TorrentTmpFile) + } + + s.log.Info().Msgf("torrent successfully added to client: '%s'", client.Name) + return rejections, nil } diff --git a/internal/action/run.go b/internal/action/run.go index d7566fe..e0e5443 100644 --- a/internal/action/run.go +++ b/internal/action/run.go @@ -19,7 +19,6 @@ import ( ) func (s *service) RunAction(ctx context.Context, action *domain.Action, release *domain.Release) ([]string, error) { - var ( err error rejections []string @@ -33,6 +32,10 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release } }() + if action.ClientID > 0 && action.Client != nil && !action.Client.Enabled { + return nil, errors.New("action %s client %s %s not enabled, skipping", action.Name, action.Client.Type, action.Client.Name) + } + // if set, try to resolve MagnetURI before parsing macros // to allow webhook and exec to get the magnet_uri if err := release.ResolveMagnetUri(ctx); err != nil { diff --git a/internal/action/sabnzbd.go b/internal/action/sabnzbd.go index 76476dd..580722e 100644 --- a/internal/action/sabnzbd.go +++ b/internal/action/sabnzbd.go @@ -18,29 +18,16 @@ func (s *service) sabnzbd(ctx context.Context, action *domain.Action, release do return nil, errors.New("action type: %s invalid protocol: %s", action.Type, release.Protocol) } - // get client for action - client, err := s.clientSvc.FindByID(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) if err != nil { - return nil, errors.Wrap(err, "sonarr could not find client: %d", action.ClientID) + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) } - // return early if no client found - if client == nil { - return nil, errors.New("no sabnzbd client found by id: %d", action.ClientID) + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) } - opts := sabnzbd.Options{ - Addr: client.Host, - ApiKey: client.Settings.APIKey, - Log: nil, - } - - if client.Settings.Basic.Auth { - opts.BasicUser = client.Settings.Basic.Username - opts.BasicPass = client.Settings.Basic.Password - } - - sab := sabnzbd.New(opts) + sab := client.Client.(*sabnzbd.Client) ids, err := sab.AddFromUrl(ctx, sabnzbd.AddNzbRequest{Url: release.DownloadURL, Category: action.Category}) if err != nil { diff --git a/internal/action/service.go b/internal/action/service.go index a83d042..6d1bc18 100644 --- a/internal/action/service.go +++ b/internal/action/service.go @@ -21,9 +21,10 @@ import ( type Service interface { Store(ctx context.Context, action domain.Action) (*domain.Action, error) + StoreFilterActions(ctx context.Context, filterID int64, actions []*domain.Action) ([]*domain.Action, error) List(ctx context.Context) ([]domain.Action, error) Get(ctx context.Context, req *domain.GetActionRequest) (*domain.Action, error) - FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) + FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*domain.Action, error) Delete(ctx context.Context, req *domain.DeleteActionRequest) error DeleteByFilterID(ctx context.Context, filterID int) error ToggleEnabled(actionID int) error @@ -63,6 +64,10 @@ func (s *service) Store(ctx context.Context, action domain.Action) (*domain.Acti return s.repo.Store(ctx, action) } +func (s *service) StoreFilterActions(ctx context.Context, filterID int64, actions []*domain.Action) ([]*domain.Action, error) { + return s.repo.StoreFilterActions(ctx, filterID, actions) +} + func (s *service) List(ctx context.Context) ([]domain.Action, error) { return s.repo.List(ctx) } @@ -86,8 +91,8 @@ func (s *service) Get(ctx context.Context, req *domain.GetActionRequest) (*domai return a, nil } -func (s *service) FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) { - return s.repo.FindByFilterID(ctx, filterID, active) +func (s *service) FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*domain.Action, error) { + return s.repo.FindByFilterID(ctx, filterID, active, withClient) } func (s *service) Delete(ctx context.Context, req *domain.DeleteActionRequest) error { diff --git a/internal/action/sonarr.go b/internal/action/sonarr.go index 6012034..3fa77ba 100644 --- a/internal/action/sonarr.go +++ b/internal/action/sonarr.go @@ -17,40 +17,16 @@ func (s *service) sonarr(ctx context.Context, action *domain.Action, release dom // TODO validate data - // get client for action - client, err := s.clientSvc.FindByID(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) if err != nil { - return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID) + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) } - // return early if no client found - if client == nil { - return nil, errors.New("no client found") + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) } - // initial config - cfg := sonarr.Config{ - Hostname: client.Host, - APIKey: client.Settings.APIKey, - Log: s.subLogger, - } - - // only set basic auth if enabled - if client.Settings.Basic.Auth { - cfg.BasicAuth = client.Settings.Basic.Auth - cfg.Username = client.Settings.Basic.Username - cfg.Password = client.Settings.Basic.Password - } - - externalClientId := client.Settings.ExternalDownloadClientId - if action.ExternalDownloadClientID > 0 { - externalClientId = int(action.ExternalDownloadClientID) - } - - externalClient := client.Settings.ExternalDownloadClient - if action.ExternalDownloadClient != "" { - externalClient = action.ExternalDownloadClient - } + arr := client.Client.(sonarr.Client) r := sonarr.Release{ Title: release.TorrentName, @@ -59,14 +35,20 @@ func (s *service) sonarr(ctx context.Context, action *domain.Action, release dom MagnetUrl: release.MagnetURI, Size: int64(release.Size), Indexer: release.Indexer.GetExternalIdentifier(), - DownloadClientId: externalClientId, - DownloadClient: externalClient, + DownloadClientId: client.Settings.ExternalDownloadClientId, + DownloadClient: client.Settings.ExternalDownloadClient, DownloadProtocol: release.Protocol.String(), Protocol: release.Protocol.String(), PublishDate: time.Now().Format(time.RFC3339), } - arr := sonarr.New(cfg) + if action.ExternalDownloadClientID > 0 { + r.DownloadClientId = int(action.ExternalDownloadClientID) + } + + if action.ExternalDownloadClient != "" { + r.DownloadClient = action.ExternalDownloadClient + } rejections, err := arr.Push(ctx, r) if err != nil { diff --git a/internal/action/transmission.go b/internal/action/transmission.go index 989fb53..a03a368 100644 --- a/internal/action/transmission.go +++ b/internal/action/transmission.go @@ -6,14 +6,11 @@ package action import ( "context" "fmt" - "net/url" "strings" "time" "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/pkg/errors" - "github.com/autobrr/autobrr/pkg/transmission" - "github.com/hekmon/transmissionrpc/v3" ) @@ -28,38 +25,16 @@ var TrTrue = true func (s *service) transmission(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Debug().Msgf("action Transmission: %s", action.Name) - var err error - - // get client for action - client, err := s.clientSvc.FindByID(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) if err != nil { - s.log.Error().Stack().Err(err).Msgf("error finding client: %d", action.ClientID) - return nil, err + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) } - if client == nil { - return nil, errors.New("could not find client by id: %d", action.ClientID) + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) } - scheme := "http" - if client.TLS { - scheme = "https" - } - - u, err := url.Parse(fmt.Sprintf("%s://%s:%d/transmission/rpc", scheme, client.Host, client.Port)) - if err != nil { - return nil, err - } - - tbt, err := transmission.New(u, &transmission.Config{ - UserAgent: "autobrr", - Username: client.Username, - Password: client.Password, - TLSSkipVerify: client.TLSSkipVerify, - }) - if err != nil { - return nil, errors.Wrap(err, "error logging into client: %s", client.Host) - } + tbt := client.Client.(*transmissionrpc.Client) rejections, err := s.transmissionCheckRulesCanDownload(ctx, action, client, tbt) if err != nil { diff --git a/internal/action/whisparr.go b/internal/action/whisparr.go index a131ac5..5110c08 100644 --- a/internal/action/whisparr.go +++ b/internal/action/whisparr.go @@ -17,40 +17,16 @@ func (s *service) whisparr(ctx context.Context, action *domain.Action, release d // TODO validate data - // get client for action - client, err := s.clientSvc.FindByID(ctx, action.ClientID) + client, err := s.clientSvc.GetClient(ctx, action.ClientID) if err != nil { - return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID) + return nil, errors.Wrap(err, "could not get client with id %d", action.ClientID) } - // return early if no client found - if client == nil { - return nil, errors.New("could not find client by id: %v", action.ClientID) + if !client.Enabled { + return nil, errors.New("client %s %s not enabled", client.Type, client.Name) } - // initial config - cfg := whisparr.Config{ - Hostname: client.Host, - APIKey: client.Settings.APIKey, - Log: s.subLogger, - } - - // only set basic auth if enabled - if client.Settings.Basic.Auth { - cfg.BasicAuth = client.Settings.Basic.Auth - cfg.Username = client.Settings.Basic.Username - cfg.Password = client.Settings.Basic.Password - } - - externalClientId := client.Settings.ExternalDownloadClientId - if action.ExternalDownloadClientID > 0 { - externalClientId = int(action.ExternalDownloadClientID) - } - - externalClient := client.Settings.ExternalDownloadClient - if action.ExternalDownloadClient != "" { - externalClient = action.ExternalDownloadClient - } + arr := client.Client.(whisparr.Client) r := whisparr.Release{ Title: release.TorrentName, @@ -59,14 +35,20 @@ func (s *service) whisparr(ctx context.Context, action *domain.Action, release d MagnetUrl: release.MagnetURI, Size: int64(release.Size), Indexer: release.Indexer.GetExternalIdentifier(), - DownloadClientId: externalClientId, - DownloadClient: externalClient, + DownloadClientId: client.Settings.ExternalDownloadClientId, + DownloadClient: client.Settings.ExternalDownloadClient, DownloadProtocol: release.Protocol.String(), Protocol: release.Protocol.String(), PublishDate: time.Now().Format(time.RFC3339), } - arr := whisparr.New(cfg) + if action.ExternalDownloadClientID > 0 { + r.DownloadClientId = int(action.ExternalDownloadClientID) + } + + if action.ExternalDownloadClient != "" { + r.DownloadClient = action.ExternalDownloadClient + } rejections, err := arr.Push(ctx, r) if err != nil { diff --git a/internal/database/action.go b/internal/database/action.go index cac27c3..a27c7a3 100644 --- a/internal/database/action.go +++ b/internal/database/action.go @@ -30,7 +30,266 @@ func NewActionRepo(log logger.Logger, db *DB, clientRepo domain.DownloadClientRe } } -func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) { +func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*domain.Action, error) { + if withClient { + return r.findByFilterIDWithClient(ctx, filterID, active) + } + + return r.findByFilterID(ctx, filterID, active) +} + +func (r *ActionRepo) findByFilterID(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) { + queryBuilder := r.db.squirrel. + Select( + "a.id", + "a.name", + "a.type", + "a.enabled", + "a.exec_cmd", + "a.exec_args", + "a.watch_folder", + "a.category", + "a.tags", + "a.label", + "a.save_path", + "a.paused", + "a.ignore_rules", + "a.first_last_piece_prio", + "a.skip_hash_check", + "a.content_layout", + "a.priority", + "a.limit_download_speed", + "a.limit_upload_speed", + "a.limit_ratio", + "a.limit_seed_time", + "a.reannounce_skip", + "a.reannounce_delete", + "a.reannounce_interval", + "a.reannounce_max_attempts", + "a.webhook_host", + "a.webhook_type", + "a.webhook_method", + "a.webhook_data", + "a.external_client_id", + "a.external_client", + "a.client_id", + ). + From("action a"). + Where(sq.Eq{"a.filter_id": filterID}) + + if active != nil { + queryBuilder = queryBuilder.Where(sq.Eq{"enabled": *active}) + } + + query, args, err := queryBuilder.ToSql() + if err != nil { + return nil, errors.Wrap(err, "error building query") + } + + rows, err := r.db.handler.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.Wrap(err, "error executing query") + } + + defer rows.Close() + + actions := make([]*domain.Action, 0) + for rows.Next() { + var a domain.Action + + var execCmd, execArgs, watchFolder, category, tags, label, savePath, contentLayout, priorityLayout, webhookHost, webhookType, webhookMethod, webhookData, externalClient sql.NullString + var limitUl, limitDl, limitSeedTime sql.NullInt64 + var limitRatio sql.NullFloat64 + + var externalClientID, clientID sql.NullInt32 + var paused, ignoreRules sql.NullBool + + if err := rows.Scan(&a.ID, &a.Name, &a.Type, &a.Enabled, &execCmd, &execArgs, &watchFolder, &category, &tags, &label, &savePath, &paused, &ignoreRules, &a.FirstLastPiecePrio, &a.SkipHashCheck, &contentLayout, &priorityLayout, &limitDl, &limitUl, &limitRatio, &limitSeedTime, &a.ReAnnounceSkip, &a.ReAnnounceDelete, &a.ReAnnounceInterval, &a.ReAnnounceMaxAttempts, &webhookHost, &webhookType, &webhookMethod, &webhookData, &externalClientID, &externalClient, &clientID); err != nil { + return nil, errors.Wrap(err, "error scanning row") + } + + a.ExecCmd = execCmd.String + a.ExecArgs = execArgs.String + a.WatchFolder = watchFolder.String + a.Category = category.String + a.Tags = tags.String + a.Label = label.String + a.SavePath = savePath.String + a.Paused = paused.Bool + a.IgnoreRules = ignoreRules.Bool + a.ContentLayout = domain.ActionContentLayout(contentLayout.String) + a.PriorityLayout = domain.PriorityLayout(priorityLayout.String) + + a.LimitDownloadSpeed = limitDl.Int64 + a.LimitUploadSpeed = limitUl.Int64 + a.LimitRatio = limitRatio.Float64 + a.LimitSeedTime = limitSeedTime.Int64 + + a.WebhookHost = webhookHost.String + a.WebhookType = webhookType.String + a.WebhookMethod = webhookMethod.String + a.WebhookData = webhookData.String + + a.ExternalDownloadClientID = externalClientID.Int32 + a.ExternalDownloadClient = externalClient.String + a.ClientID = clientID.Int32 + + actions = append(actions, &a) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(err, "row error") + } + + return actions, nil +} + +func (r *ActionRepo) findByFilterIDWithClient(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) { + queryBuilder := r.db.squirrel. + Select( + "a.id", + "a.name", + "a.type", + "a.enabled", + "a.exec_cmd", + "a.exec_args", + "a.watch_folder", + "a.category", + "a.tags", + "a.label", + "a.save_path", + "a.paused", + "a.ignore_rules", + "a.first_last_piece_prio", + "a.skip_hash_check", + "a.content_layout", + "a.priority", + "a.limit_download_speed", + "a.limit_upload_speed", + "a.limit_ratio", + "a.limit_seed_time", + "a.reannounce_skip", + "a.reannounce_delete", + "a.reannounce_interval", + "a.reannounce_max_attempts", + "a.webhook_host", + "a.webhook_type", + "a.webhook_method", + "a.webhook_data", + "a.external_client_id", + "a.external_client", + "a.client_id", + "c.id", + "c.name", + "c.type", + "c.enabled", + "c.host", + "c.port", + "c.tls", + "c.tls_skip_verify", + "c.username", + "c.password", + "c.settings", + ). + From("action a"). + Join("client c ON a.client_id = c.id"). + Where(sq.Eq{"a.filter_id": filterID}) + + if active != nil { + queryBuilder = queryBuilder.Where(sq.Eq{"enabled": *active}) + } + + query, args, err := queryBuilder.ToSql() + if err != nil { + return nil, errors.Wrap(err, "error building query") + } + + rows, err := r.db.handler.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.Wrap(err, "error executing query") + } + + defer rows.Close() + + actions := make([]*domain.Action, 0) + for rows.Next() { + var a domain.Action + var c domain.DownloadClient + + var execCmd, execArgs, watchFolder, category, tags, label, savePath, contentLayout, priorityLayout, webhookHost, webhookType, webhookMethod, webhookData, externalClient sql.NullString + var limitUl, limitDl, limitSeedTime sql.NullInt64 + var limitRatio sql.NullFloat64 + + var externalClientID, clientID sql.NullInt32 + var paused, ignoreRules sql.NullBool + + var clientClientId, clientPort sql.Null[int32] + var clientName, clientType, clientHost, clientUsername, clientPassword, clientSettings sql.Null[string] + var clientEnabled, clientTLS, clientTLSSkip sql.Null[bool] + + if err := rows.Scan(&a.ID, &a.Name, &a.Type, &a.Enabled, &execCmd, &execArgs, &watchFolder, &category, &tags, &label, &savePath, &paused, &ignoreRules, &a.FirstLastPiecePrio, &a.SkipHashCheck, &contentLayout, &priorityLayout, &limitDl, &limitUl, &limitRatio, &limitSeedTime, &a.ReAnnounceSkip, &a.ReAnnounceDelete, &a.ReAnnounceInterval, &a.ReAnnounceMaxAttempts, &webhookHost, &webhookType, &webhookMethod, &webhookData, &externalClientID, &externalClient, &clientID, &clientClientId, &clientName, &clientType, &clientEnabled, &clientHost, &clientPort, &clientTLS, &clientTLSSkip, &clientUsername, &clientPassword, &clientSettings); err != nil { + return nil, errors.Wrap(err, "error scanning row") + } + + a.ExecCmd = execCmd.String + a.ExecArgs = execArgs.String + a.WatchFolder = watchFolder.String + a.Category = category.String + a.Tags = tags.String + a.Label = label.String + a.SavePath = savePath.String + a.Paused = paused.Bool + a.IgnoreRules = ignoreRules.Bool + a.ContentLayout = domain.ActionContentLayout(contentLayout.String) + a.PriorityLayout = domain.PriorityLayout(priorityLayout.String) + + a.LimitDownloadSpeed = limitDl.Int64 + a.LimitUploadSpeed = limitUl.Int64 + a.LimitRatio = limitRatio.Float64 + a.LimitSeedTime = limitSeedTime.Int64 + + a.WebhookHost = webhookHost.String + a.WebhookType = webhookType.String + a.WebhookMethod = webhookMethod.String + a.WebhookData = webhookData.String + + a.ExternalDownloadClientID = externalClientID.Int32 + a.ExternalDownloadClient = externalClient.String + a.ClientID = clientID.Int32 + + c.ID = clientClientId.V + c.Name = clientName.V + c.Type = domain.DownloadClientType(clientType.V) + c.Enabled = clientEnabled.V + c.Host = clientHost.V + c.Port = int(clientPort.V) + c.TLS = clientTLS.V + c.TLSSkipVerify = clientTLSSkip.V + c.Username = clientUsername.V + c.Password = clientPassword.V + //c.Settings = clientSettings.String + + if a.ClientID > 0 { + if clientSettings.Valid { + if err := json.Unmarshal([]byte(clientSettings.V), &c.Settings); err != nil { + return nil, errors.Wrap(err, "could not unmarshal download client settings: %v", clientSettings.V) + } + } + + a.Client = &c + } + + actions = append(actions, &a) + } + + if err := rows.Err(); err != nil { + return nil, errors.Wrap(err, "row error") + } + + return actions, nil +} + +func (r *ActionRepo) FindByFilterIDTx(ctx context.Context, filterID int, active *bool) ([]*domain.Action, error) { tx, err := r.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted}) if err != nil { return nil, err @@ -38,7 +297,7 @@ func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *b defer tx.Rollback() - actions, err := r.findByFilterID(ctx, tx, filterID, active) + actions, err := r.findByFilterIDTx(ctx, tx, filterID, active) if err != nil { return nil, err } @@ -59,7 +318,7 @@ func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int, active *b return actions, nil } -func (r *ActionRepo) findByFilterID(ctx context.Context, tx *Tx, filterID int, active *bool) ([]*domain.Action, error) { +func (r *ActionRepo) findByFilterIDTx(ctx context.Context, tx *Tx, filterID int, active *bool) ([]*domain.Action, error) { queryBuilder := r.db.squirrel. Select( "id", diff --git a/internal/database/action_test.go b/internal/database/action_test.go index 76c2806..471f809 100644 --- a/internal/database/action_test.go +++ b/internal/database/action_test.go @@ -62,9 +62,10 @@ func TestActionRepo_Store(t *testing.T) { t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -73,7 +74,7 @@ func TestActionRepo_Store(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, createdFilters) - mockData.ClientID = int32(createdClient.ID) + mockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID // Actual test for Store @@ -84,7 +85,7 @@ func TestActionRepo_Store(t *testing.T) { // Cleanup _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("Store_Succeeds_With_Missing_or_empty_fields [%s]", dbType), func(t *testing.T) { @@ -125,9 +126,10 @@ func TestActionRepo_StoreFilterActions(t *testing.T) { t.Run(fmt.Sprintf("StoreFilterActions_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -136,7 +138,7 @@ func TestActionRepo_StoreFilterActions(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, createdFilters) - mockData.ClientID = int32(createdClient.ID) + mockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID // Actual test for StoreFilterActions @@ -148,7 +150,7 @@ func TestActionRepo_StoreFilterActions(t *testing.T) { // Cleanup _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("StoreFilterActions_Fails_Invalid_FilterID [%s]", dbType), func(t *testing.T) { @@ -203,9 +205,10 @@ func TestActionRepo_FindByFilterID(t *testing.T) { t.Run(fmt.Sprintf("FindByFilterID_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -214,13 +217,13 @@ func TestActionRepo_FindByFilterID(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, createdFilters) - mockData.ClientID = int32(createdClient.ID) + mockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) assert.NoError(t, err) // Actual test for FindByFilterID - actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil) + actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil, false) assert.NoError(t, err) assert.NotNil(t, actions) assert.Equal(t, 1, len(actions)) @@ -228,7 +231,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) { // Cleanup _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("FindByFilterID_Fails_No_Actions [%s]", dbType), func(t *testing.T) { @@ -241,7 +244,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) { assert.NotNil(t, createdFilters) // Actual test for FindByFilterID - actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil) + actions, err := repo.FindByFilterID(context.Background(), createdFilters[0].ID, nil, false) assert.NoError(t, err) assert.Equal(t, 0, len(actions)) @@ -250,7 +253,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) { }) t.Run(fmt.Sprintf("FindByFilterID_Succeeds_With_Invalid_FilterID [%s]", dbType), func(t *testing.T) { - actions, err := repo.FindByFilterID(context.Background(), 9999, nil) // 9999 is an invalid filter ID + actions, err := repo.FindByFilterID(context.Background(), 9999, nil, false) // 9999 is an invalid filter ID assert.NoError(t, err) assert.NotNil(t, actions) assert.Equal(t, 0, len(actions)) @@ -260,7 +263,7 @@ func TestActionRepo_FindByFilterID(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() - actions, err := repo.FindByFilterID(ctx, 1, nil) + actions, err := repo.FindByFilterID(ctx, 1, nil, false) assert.Error(t, err) assert.Nil(t, actions) }) @@ -277,9 +280,10 @@ func TestActionRepo_List(t *testing.T) { t.Run(fmt.Sprintf("List_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -288,7 +292,7 @@ func TestActionRepo_List(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, createdFilters) - mockData.ClientID = int32(createdClient.ID) + mockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) assert.NoError(t, err) @@ -302,7 +306,7 @@ func TestActionRepo_List(t *testing.T) { // Cleanup _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("List_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) { @@ -326,9 +330,10 @@ func TestActionRepo_Get(t *testing.T) { t.Run(fmt.Sprintf("Get_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -337,7 +342,7 @@ func TestActionRepo_Get(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, createdFilters) - mockData.ClientID = int32(createdClient.ID) + mockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) assert.NoError(t, err) @@ -351,7 +356,7 @@ func TestActionRepo_Get(t *testing.T) { // Cleanup _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("Get_Fails_No_Record [%s]", dbType), func(t *testing.T) { @@ -382,9 +387,10 @@ func TestActionRepo_Delete(t *testing.T) { t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -393,7 +399,7 @@ func TestActionRepo_Delete(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, createdFilters) - mockData.ClientID = int32(createdClient.ID) + mockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) assert.NoError(t, err) @@ -411,7 +417,7 @@ func TestActionRepo_Delete(t *testing.T) { // Cleanup _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("Delete_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) { @@ -435,9 +441,10 @@ func TestActionRepo_DeleteByFilterID(t *testing.T) { t.Run(fmt.Sprintf("DeleteByFilterID_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -446,7 +453,7 @@ func TestActionRepo_DeleteByFilterID(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, createdFilters) - mockData.ClientID = int32(createdClient.ID) + mockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) assert.NoError(t, err) @@ -463,7 +470,7 @@ func TestActionRepo_DeleteByFilterID(t *testing.T) { // Cleanup _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("DeleteByFilterID_Fails_Context_Timeout [%s]", dbType), func(t *testing.T) { @@ -486,9 +493,10 @@ func TestActionRepo_ToggleEnabled(t *testing.T) { t.Run(fmt.Sprintf("ToggleEnabled_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -497,7 +505,7 @@ func TestActionRepo_ToggleEnabled(t *testing.T) { assert.NoError(t, err) assert.NotNil(t, createdFilters) - mockData.ClientID = int32(createdClient.ID) + mockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID mockData.Enabled = false createdActions, err := repo.StoreFilterActions(context.Background(), int64(createdFilters[0].ID), []*domain.Action{&mockData}) @@ -515,7 +523,7 @@ func TestActionRepo_ToggleEnabled(t *testing.T) { // Cleanup _ = repo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdActions[0].ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("ToggleEnabled_Fails_No_Record [%s]", dbType), func(t *testing.T) { diff --git a/internal/database/download_client.go b/internal/database/download_client.go index 39aa1c3..5609689 100644 --- a/internal/database/download_client.go +++ b/internal/database/download_client.go @@ -7,7 +7,6 @@ import ( "context" "database/sql" "encoding/json" - "sync" "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" @@ -18,55 +17,18 @@ import ( ) type DownloadClientRepo struct { - log zerolog.Logger - db *DB - cache *clientCache -} - -type clientCache struct { - mu sync.RWMutex - clients map[int]*domain.DownloadClient -} - -func NewClientCache() *clientCache { - return &clientCache{ - clients: make(map[int]*domain.DownloadClient, 0), - } -} - -func (c *clientCache) Set(id int, client *domain.DownloadClient) { - c.mu.Lock() - c.clients[id] = client - c.mu.Unlock() -} - -func (c *clientCache) Get(id int) *domain.DownloadClient { - c.mu.RLock() - defer c.mu.RUnlock() - v, ok := c.clients[id] - if ok { - return v - } - return nil -} - -func (c *clientCache) Pop(id int) { - c.mu.Lock() - delete(c.clients, id) - c.mu.Unlock() + log zerolog.Logger + db *DB } func NewDownloadClientRepo(log logger.Logger, db *DB) domain.DownloadClientRepo { return &DownloadClientRepo{ - log: log.With().Str("repo", "action").Logger(), - db: db, - cache: NewClientCache(), + log: log.With().Str("repo", "action").Logger(), + db: db, } } func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, error) { - clients := make([]domain.DownloadClient, 0) - queryBuilder := r.db.squirrel. Select( "id", @@ -100,6 +62,8 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, } }(rows) + clients := make([]domain.DownloadClient, 0) + for rows.Next() { var f domain.DownloadClient var settingsJsonStr string @@ -124,12 +88,6 @@ func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, } func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) { - // get client from cache - c := r.cache.Get(int(id)) - if c != nil { - return c, nil - } - queryBuilder := r.db.squirrel. Select( "id", @@ -153,7 +111,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do } row := r.db.handler.QueryRowContext(ctx, query, args...) - if err != nil { + if err := row.Err(); err != nil { return nil, errors.Wrap(err, "error executing query") } @@ -177,9 +135,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do return &client, nil } -func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { - var err error - +func (r *DownloadClientRepo) Store(ctx context.Context, client *domain.DownloadClient) error { settings := domain.DownloadClientSettings{ APIKey: client.Settings.APIKey, Basic: client.Settings.Basic, @@ -190,7 +146,7 @@ func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadCl settingsJson, err := json.Marshal(&settings) if err != nil { - return nil, errors.Wrap(err, "error marshal download client settings %+v", settings) + return errors.Wrap(err, "error marshal download client settings %+v", settings) } queryBuilder := r.db.squirrel. @@ -204,22 +160,17 @@ func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadCl err = queryBuilder.QueryRowContext(ctx).Scan(&retID) if err != nil { - return nil, errors.Wrap(err, "error executing query") + return errors.Wrap(err, "error executing query") } - client.ID = retID + client.ID = int32(retID) r.log.Debug().Msgf("download_client.store: %d", client.ID) - // save to cache - r.cache.Set(client.ID, &client) - - return &client, nil + return nil } -func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { - var err error - +func (r *DownloadClientRepo) Update(ctx context.Context, client *domain.DownloadClient) error { settings := domain.DownloadClientSettings{ APIKey: client.Settings.APIKey, Basic: client.Settings.Basic, @@ -230,7 +181,7 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC settingsJson, err := json.Marshal(&settings) if err != nil { - return nil, errors.Wrap(err, "error marshal download client settings %+v", settings) + return errors.Wrap(err, "error marshal download client settings %+v", settings) } queryBuilder := r.db.squirrel. @@ -249,32 +200,29 @@ func (r *DownloadClientRepo) Update(ctx context.Context, client domain.DownloadC query, args, err := queryBuilder.ToSql() if err != nil { - return nil, errors.Wrap(err, "error building query") + return errors.Wrap(err, "error building query") } result, err := r.db.handler.ExecContext(ctx, query, args...) if err != nil { - return nil, errors.Wrap(err, "error executing query") + return errors.Wrap(err, "error executing query") } rowsAffected, err := result.RowsAffected() if err != nil { - return nil, errors.Wrap(err, "error getting rows affected") + return errors.Wrap(err, "error getting rows affected") } if rowsAffected == 0 { - return nil, errors.New("no rows updated") + return errors.New("no rows updated") } r.log.Debug().Msgf("download_client.update: %d", client.ID) - // save to cache - r.cache.Set(client.ID, &client) - - return &client, nil + return nil } -func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error { +func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int32) error { tx, err := r.db.BeginTx(ctx, &sql.TxOptions{}) if err != nil { return err @@ -311,10 +259,11 @@ func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error { } r.log.Debug().Msgf("delete download client: %d", clientID) + return nil } -func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) error { +func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int32) error { queryBuilder := r.db.squirrel. Delete("client"). Where(sq.Eq{"id": clientID}) @@ -329,9 +278,6 @@ func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) e return errors.Wrap(err, "error executing query") } - // remove from cache - r.cache.Pop(clientID) - rows, _ := res.RowsAffected() if rows == 0 { return errors.New("no rows affected") @@ -342,9 +288,7 @@ func (r *DownloadClientRepo) delete(ctx context.Context, tx *Tx, clientID int) e return nil } -func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx, clientID int) error { - var err error - +func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx, clientID int32) error { queryBuilder := r.db.squirrel. Update("action"). Set("enabled", false). @@ -355,12 +299,14 @@ func (r *DownloadClientRepo) deleteClientFromAction(ctx context.Context, tx *Tx, // return values var filterID int - if err = queryBuilder.QueryRowContext(ctx).Scan(&filterID); err != nil { + err := queryBuilder.QueryRowContext(ctx).Scan(&filterID) + if err != nil { // this will throw when the client is not connected to any actions // it is not an error in this case if errors.Is(err, sql.ErrNoRows) { return nil } + return errors.Wrap(err, "error executing query") } diff --git a/internal/database/download_client_test.go b/internal/database/download_client_test.go index cd4a919..309f50c 100644 --- a/internal/database/download_client_test.go +++ b/internal/database/download_client_test.go @@ -8,10 +8,12 @@ package database import ( "context" "fmt" - "github.com/autobrr/autobrr/internal/domain" - "github.com/stretchr/testify/assert" "testing" "time" + + "github.com/autobrr/autobrr/internal/domain" + + "github.com/stretchr/testify/assert" ) func getMockDownloadClient() domain.DownloadClient { @@ -54,13 +56,14 @@ func TestDownloadClientRepo_List(t *testing.T) { t.Run(fmt.Sprintf("List_Succeeds_With_No_Filters [%s]", dbType), func(t *testing.T) { // Insert mock data - createdClient, err := repo.Store(context.Background(), mockData) + mock := &mockData + err := repo.Store(context.Background(), mock) clients, err := repo.List(context.Background()) assert.NoError(t, err) assert.NotEmpty(t, clients) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("List_Succeeds_With_Empty_Database [%s]", dbType), func(t *testing.T) { @@ -77,32 +80,34 @@ func TestDownloadClientRepo_List(t *testing.T) { }) t.Run(fmt.Sprintf("List_Succeeds_With_Data_Integrity [%s]", dbType), func(t *testing.T) { - createdClient, err := repo.Store(context.Background(), mockData) + mock := &mockData + err := repo.Store(context.Background(), mock) clients, err := repo.List(context.Background()) assert.NoError(t, err) assert.Equal(t, 1, len(clients)) - assert.Equal(t, createdClient.Name, clients[0].Name) + assert.Equal(t, mock.Name, clients[0].Name) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("List_Succeeds_With_Boundary_Value_For_Port [%s]", dbType), func(t *testing.T) { - mockData.Port = 65535 - createdClient, err := repo.Store(context.Background(), mockData) + mock := &mockData + mock.Port = 65535 + err := repo.Store(context.Background(), mock) clients, err := repo.List(context.Background()) assert.NoError(t, err) assert.Equal(t, 65535, clients[0].Port) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("List_Succeeds_With_Boolean_Flags_Set_To_False [%s]", dbType), func(t *testing.T) { mockData.Enabled = false mockData.TLS = false mockData.TLSSkipVerify = false - createdClient, err := repo.Store(context.Background(), mockData) + err := repo.Store(context.Background(), &mockData) clients, err := repo.List(context.Background()) assert.NoError(t, err) assert.Equal(t, false, clients[0].Enabled) @@ -110,18 +115,18 @@ func TestDownloadClientRepo_List(t *testing.T) { assert.Equal(t, false, clients[0].TLSSkipVerify) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mockData.ID) }) t.Run(fmt.Sprintf("List_Succeeds_With_Special_Characters_In_Name [%s]", dbType), func(t *testing.T) { mockData.Name = "Special$Name" - createdClient, err := repo.Store(context.Background(), mockData) + err := repo.Store(context.Background(), &mockData) clients, err := repo.List(context.Background()) assert.NoError(t, err) assert.Equal(t, "Special$Name", clients[0].Name) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mockData.ID) }) } } @@ -133,13 +138,14 @@ func TestDownloadClientRepo_FindByID(t *testing.T) { mockData := getMockDownloadClient() t.Run(fmt.Sprintf("FindByID_Succeeds [%s]", dbType), func(t *testing.T) { - createdClient, _ := repo.Store(context.Background(), mockData) - foundClient, err := repo.FindByID(context.Background(), int32(createdClient.ID)) + mock := &mockData + _ = repo.Store(context.Background(), mock) + foundClient, err := repo.FindByID(context.Background(), mock.ID) assert.NoError(t, err) assert.NotNil(t, foundClient) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("FindByID_Fails_With_Nonexistent_ID [%s]", dbType), func(t *testing.T) { @@ -156,40 +162,44 @@ func TestDownloadClientRepo_FindByID(t *testing.T) { t.Run(fmt.Sprintf("FindByID_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) { ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() + _, err := repo.FindByID(ctx, 1) assert.Error(t, err) }) t.Run(fmt.Sprintf("FindByID_Fails_After_Client_Deleted [%s]", dbType), func(t *testing.T) { - createdClient, _ := repo.Store(context.Background(), mockData) - _ = repo.Delete(context.Background(), createdClient.ID) - _, err := repo.FindByID(context.Background(), int32(createdClient.ID)) + mock := &mockData + _ = repo.Store(context.Background(), mock) + _ = repo.Delete(context.Background(), mock.ID) + _, err := repo.FindByID(context.Background(), mock.ID) assert.Error(t, err) assert.Equal(t, "no client configured", err.Error()) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("FindByID_Succeeds_With_Data_Integrity [%s]", dbType), func(t *testing.T) { - createdClient, _ := repo.Store(context.Background(), mockData) - foundClient, err := repo.FindByID(context.Background(), int32(createdClient.ID)) + mock := &mockData + _ = repo.Store(context.Background(), mock) + foundClient, err := repo.FindByID(context.Background(), mock.ID) assert.NoError(t, err) - assert.Equal(t, createdClient.Name, foundClient.Name) + assert.Equal(t, mock.Name, foundClient.Name) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mock.ID) }) t.Run(fmt.Sprintf("FindByID_Succeeds_From_Cache [%s]", dbType), func(t *testing.T) { - createdClient, _ := repo.Store(context.Background(), mockData) - foundClient1, _ := repo.FindByID(context.Background(), int32(createdClient.ID)) - foundClient2, err := repo.FindByID(context.Background(), int32(createdClient.ID)) + mock := &mockData + _ = repo.Store(context.Background(), mock) + foundClient1, _ := repo.FindByID(context.Background(), mock.ID) + foundClient2, err := repo.FindByID(context.Background(), mock.ID) assert.NoError(t, err) assert.Equal(t, foundClient1, foundClient2) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mock.ID) }) } } @@ -201,17 +211,17 @@ func TestDownloadClientRepo_Store(t *testing.T) { t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) { mockData := getMockDownloadClient() - createdClient, err := repo.Store(context.Background(), mockData) + err := repo.Store(context.Background(), &mockData) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mockData) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mockData.ID) }) //TODO: Is this okay? Should we be able to store a client with no name (empty string)? t.Run(fmt.Sprintf("Store_Succeeds?_With_Missing_Required_Fields [%s]", dbType), func(t *testing.T) { - badMockData := domain.DownloadClient{ + badMockData := &domain.DownloadClient{ Type: "", Enabled: false, Host: "", @@ -222,30 +232,30 @@ func TestDownloadClientRepo_Store(t *testing.T) { Password: "", Settings: domain.DownloadClientSettings{}, } - createdClient, err := repo.Store(context.Background(), badMockData) + err := repo.Store(context.Background(), badMockData) assert.NoError(t, err) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), badMockData.ID) }) t.Run(fmt.Sprintf("Store_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) { mockData := getMockDownloadClient() ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() - _, err := repo.Store(ctx, mockData) + err := repo.Store(ctx, &mockData) assert.Error(t, err) }) t.Run(fmt.Sprintf("Store_Succeeds_And_Caches [%s]", dbType), func(t *testing.T) { mockData := getMockDownloadClient() - createdClient, _ := repo.Store(context.Background(), mockData) + _ = repo.Store(context.Background(), &mockData) - cachedClient, _ := repo.FindByID(context.Background(), int32(createdClient.ID)) - assert.Equal(t, createdClient, cachedClient) + cachedClient, _ := repo.FindByID(context.Background(), mockData.ID) + assert.Equal(t, &mockData, cachedClient) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mockData.ID) }) } } @@ -258,22 +268,22 @@ func TestDownloadClientRepo_Update(t *testing.T) { t.Run(fmt.Sprintf("Update_Successfully_Updates_Record [%s]", dbType), func(t *testing.T) { mockClient := getMockDownloadClient() - createdClient, _ := repo.Store(context.Background(), mockClient) - createdClient.Name = "updatedName" - updatedClient, err := repo.Update(context.Background(), *createdClient) + _ = repo.Store(context.Background(), &mockClient) + mockClient.Name = "updatedName" + err := repo.Update(context.Background(), &mockClient) assert.NoError(t, err) - assert.Equal(t, "updatedName", updatedClient.Name) + assert.Equal(t, "updatedName", mockClient.Name) // Cleanup - _ = repo.Delete(context.Background(), updatedClient.ID) + _ = repo.Delete(context.Background(), mockClient.ID) }) t.Run(fmt.Sprintf("Update_Fails_With_Missing_ID [%s]", dbType), func(t *testing.T) { badMockData := getMockDownloadClient() badMockData.ID = 0 - _, err := repo.Update(context.Background(), badMockData) + err := repo.Update(context.Background(), &badMockData) assert.Error(t, err) @@ -283,7 +293,7 @@ func TestDownloadClientRepo_Update(t *testing.T) { badMockData := getMockDownloadClient() badMockData.ID = 9999 - _, err := repo.Update(context.Background(), badMockData) + err := repo.Update(context.Background(), &badMockData) assert.Error(t, err) }) @@ -291,7 +301,7 @@ func TestDownloadClientRepo_Update(t *testing.T) { t.Run(fmt.Sprintf("Update_Fails_With_Missing_Required_Fields [%s]", dbType), func(t *testing.T) { badMockData := domain.DownloadClient{} - _, err := repo.Update(context.Background(), badMockData) + err := repo.Update(context.Background(), &badMockData) assert.Error(t, err) }) @@ -305,13 +315,13 @@ func TestDownloadClientRepo_Delete(t *testing.T) { t.Run(fmt.Sprintf("Delete_Successfully_Deletes_Client [%s]", dbType), func(t *testing.T) { mockClient := getMockDownloadClient() - createdClient, _ := repo.Store(context.Background(), mockClient) + _ = repo.Store(context.Background(), &mockClient) - err := repo.Delete(context.Background(), createdClient.ID) + err := repo.Delete(context.Background(), mockClient.ID) assert.NoError(t, err) // Verify client was deleted - _, err = repo.FindByID(context.Background(), int32(createdClient.ID)) + _, err = repo.FindByID(context.Background(), mockClient.ID) assert.Error(t, err) }) @@ -322,16 +332,16 @@ func TestDownloadClientRepo_Delete(t *testing.T) { t.Run(fmt.Sprintf("Delete_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) { mockClient := getMockDownloadClient() - createdClient, _ := repo.Store(context.Background(), mockClient) + _ = repo.Store(context.Background(), &mockClient) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) defer cancel() - err := repo.Delete(ctx, createdClient.ID) + err := repo.Delete(ctx, mockClient.ID) assert.Error(t, err) // Cleanup - _ = repo.Delete(context.Background(), createdClient.ID) + _ = repo.Delete(context.Background(), mockClient.ID) }) } } diff --git a/internal/database/filter.go b/internal/database/filter.go index c0ca5dc..e54aa1b 100644 --- a/internal/database/filter.go +++ b/internal/database/filter.go @@ -255,13 +255,12 @@ func (r *FilterRepo) FindByID(ctx context.Context, filterID int) (*domain.Filter } row := r.db.handler.QueryRowContext(ctx, query, args...) - - if row.Err() != nil { - if errors.Is(row.Err(), sql.ErrNoRows) { + if err := row.Err(); err != nil { + if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrRecordNotFound } - return nil, errors.Wrap(row.Err(), "error row") + return nil, errors.Wrap(err, "error row") } var f domain.Filter diff --git a/internal/database/filter_test.go b/internal/database/filter_test.go index 6dade3c..204f2fc 100644 --- a/internal/database/filter_test.go +++ b/internal/database/filter_test.go @@ -791,12 +791,14 @@ func TestFilterRepo_GetDownloadsByFilterId(t *testing.T) { err := repo.Store(context.Background(), mockData) assert.NoError(t, err) - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mockClient := getMockDownloadClient() + + err = downloadClientRepo.Store(context.Background(), &mockClient) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mockClient) mockAction.FilterID = mockData.ID - mockAction.ClientID = int32(createdClient.ID) + mockAction.ClientID = mockClient.ID action, err := actionRepo.Store(context.Background(), mockAction) @@ -827,7 +829,7 @@ func TestFilterRepo_GetDownloadsByFilterId(t *testing.T) { // Cleanup _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: action.ID}) _ = repo.Delete(context.Background(), mockData.ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mockClient.ID) _ = releaseRepo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) }) diff --git a/internal/database/release_test.go b/internal/database/release_test.go index 83668ae..b31fe25 100644 --- a/internal/database/release_test.go +++ b/internal/database/release_test.go @@ -89,9 +89,10 @@ func TestReleaseRepo_Store(t *testing.T) { t.Run(fmt.Sprintf("StoreReleaseActionStatus_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -101,7 +102,7 @@ func TestReleaseRepo_Store(t *testing.T) { assert.NotNil(t, createdFilters) actionMockData.FilterID = createdFilters[0].ID - actionMockData.ClientID = int32(createdClient.ID) + actionMockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID // Execute @@ -124,7 +125,7 @@ func TestReleaseRepo_Store(t *testing.T) { _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) } } @@ -144,9 +145,10 @@ func TestReleaseRepo_StoreReleaseActionStatus(t *testing.T) { t.Run(fmt.Sprintf("StoreReleaseActionStatus_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -156,7 +158,7 @@ func TestReleaseRepo_StoreReleaseActionStatus(t *testing.T) { assert.NotNil(t, createdFilters) actionMockData.FilterID = createdFilters[0].ID - actionMockData.ClientID = int32(createdClient.ID) + actionMockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID // Execute @@ -179,7 +181,7 @@ func TestReleaseRepo_StoreReleaseActionStatus(t *testing.T) { _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) } } @@ -199,9 +201,10 @@ func TestReleaseRepo_Find(t *testing.T) { t.Run(fmt.Sprintf("FindReleases_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -211,7 +214,7 @@ func TestReleaseRepo_Find(t *testing.T) { assert.NotNil(t, createdFilters) actionMockData.FilterID = createdFilters[0].ID - actionMockData.ClientID = int32(createdClient.ID) + actionMockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID // Execute @@ -238,7 +241,7 @@ func TestReleaseRepo_Find(t *testing.T) { // Cleanup _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) } } @@ -258,9 +261,10 @@ func TestReleaseRepo_FindRecent(t *testing.T) { t.Run(fmt.Sprintf("FindRecent_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -270,7 +274,7 @@ func TestReleaseRepo_FindRecent(t *testing.T) { assert.NotNil(t, createdFilters) actionMockData.FilterID = createdFilters[0].ID - actionMockData.ClientID = int32(createdClient.ID) + actionMockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID // Execute @@ -286,7 +290,7 @@ func TestReleaseRepo_FindRecent(t *testing.T) { // Cleanup _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) } } @@ -306,9 +310,10 @@ func TestReleaseRepo_GetIndexerOptions(t *testing.T) { t.Run(fmt.Sprintf("GetIndexerOptions_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -318,7 +323,7 @@ func TestReleaseRepo_GetIndexerOptions(t *testing.T) { assert.NotNil(t, createdFilters) actionMockData.FilterID = createdFilters[0].ID - actionMockData.ClientID = int32(createdClient.ID) + actionMockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID err = repo.Store(context.Background(), mockData) @@ -344,7 +349,7 @@ func TestReleaseRepo_GetIndexerOptions(t *testing.T) { _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) } } @@ -364,9 +369,10 @@ func TestReleaseRepo_GetActionStatusByReleaseID(t *testing.T) { t.Run(fmt.Sprintf("GetActionStatusByReleaseID_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -376,7 +382,7 @@ func TestReleaseRepo_GetActionStatusByReleaseID(t *testing.T) { assert.NotNil(t, createdFilters) actionMockData.FilterID = createdFilters[0].ID - actionMockData.ClientID = int32(createdClient.ID) + actionMockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID err = repo.Store(context.Background(), mockData) @@ -403,7 +409,7 @@ func TestReleaseRepo_GetActionStatusByReleaseID(t *testing.T) { _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) } } @@ -423,9 +429,10 @@ func TestReleaseRepo_Get(t *testing.T) { t.Run(fmt.Sprintf("Get_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -435,7 +442,7 @@ func TestReleaseRepo_Get(t *testing.T) { assert.NotNil(t, createdFilters) actionMockData.FilterID = createdFilters[0].ID - actionMockData.ClientID = int32(createdClient.ID) + actionMockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID err = repo.Store(context.Background(), mockData) @@ -462,7 +469,7 @@ func TestReleaseRepo_Get(t *testing.T) { _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) } } @@ -482,9 +489,10 @@ func TestReleaseRepo_Stats(t *testing.T) { t.Run(fmt.Sprintf("Stats_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -494,7 +502,7 @@ func TestReleaseRepo_Stats(t *testing.T) { assert.NotNil(t, createdFilters) actionMockData.FilterID = createdFilters[0].ID - actionMockData.ClientID = int32(createdClient.ID) + actionMockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID err = repo.Store(context.Background(), mockData) @@ -521,7 +529,7 @@ func TestReleaseRepo_Stats(t *testing.T) { _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) } } @@ -541,9 +549,10 @@ func TestReleaseRepo_Delete(t *testing.T) { t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -553,7 +562,7 @@ func TestReleaseRepo_Delete(t *testing.T) { assert.NotNil(t, createdFilters) actionMockData.FilterID = createdFilters[0].ID - actionMockData.ClientID = int32(createdClient.ID) + actionMockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID err = repo.Store(context.Background(), mockData) @@ -577,7 +586,7 @@ func TestReleaseRepo_Delete(t *testing.T) { // Cleanup _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) } } @@ -597,9 +606,10 @@ func TestReleaseRepo_CheckSmartEpisodeCanDownloadShow(t *testing.T) { t.Run(fmt.Sprintf("Check_Smart_Episode_Can_Download [%s]", dbType), func(t *testing.T) { // Setup - createdClient, err := downloadClientRepo.Store(context.Background(), getMockDownloadClient()) + mock := getMockDownloadClient() + err := downloadClientRepo.Store(context.Background(), &mock) assert.NoError(t, err) - assert.NotNil(t, createdClient) + assert.NotNil(t, mock) err = filterRepo.Store(context.Background(), getMockFilter()) assert.NoError(t, err) @@ -609,7 +619,7 @@ func TestReleaseRepo_CheckSmartEpisodeCanDownloadShow(t *testing.T) { assert.NotNil(t, createdFilters) actionMockData.FilterID = createdFilters[0].ID - actionMockData.ClientID = int32(createdClient.ID) + actionMockData.ClientID = mock.ID mockData.FilterID = createdFilters[0].ID err = repo.Store(context.Background(), mockData) @@ -644,7 +654,7 @@ func TestReleaseRepo_CheckSmartEpisodeCanDownloadShow(t *testing.T) { _ = repo.Delete(context.Background(), &domain.DeleteReleaseRequest{OlderThan: 0}) _ = actionRepo.Delete(context.Background(), &domain.DeleteActionRequest{ActionId: createdAction.ID}) _ = filterRepo.Delete(context.Background(), createdFilters[0].ID) - _ = downloadClientRepo.Delete(context.Background(), createdClient.ID) + _ = downloadClientRepo.Delete(context.Background(), mock.ID) }) } } diff --git a/internal/domain/action.go b/internal/domain/action.go index 173ef50..c171292 100644 --- a/internal/domain/action.go +++ b/internal/domain/action.go @@ -14,7 +14,7 @@ import ( type ActionRepo interface { Store(ctx context.Context, action Action) (*Action, error) StoreFilterActions(ctx context.Context, filterID int64, actions []*Action) ([]*Action, error) - FindByFilterID(ctx context.Context, filterID int, active *bool) ([]*Action, error) + FindByFilterID(ctx context.Context, filterID int, active *bool, withClient bool) ([]*Action, error) List(ctx context.Context) ([]Action, error) Get(ctx context.Context, req *GetActionRequest) (*Action, error) Delete(ctx context.Context, req *DeleteActionRequest) error diff --git a/internal/domain/client.go b/internal/domain/client.go index f9c05bf..f9d667e 100644 --- a/internal/domain/client.go +++ b/internal/domain/client.go @@ -9,20 +9,18 @@ import ( "net/url" "github.com/autobrr/autobrr/pkg/errors" - - "github.com/autobrr/go-qbittorrent" ) type DownloadClientRepo interface { List(ctx context.Context) ([]DownloadClient, error) FindByID(ctx context.Context, id int32) (*DownloadClient, error) - Store(ctx context.Context, client DownloadClient) (*DownloadClient, error) - Update(ctx context.Context, client DownloadClient) (*DownloadClient, error) - Delete(ctx context.Context, clientID int) error + Store(ctx context.Context, client *DownloadClient) error + Update(ctx context.Context, client *DownloadClient) error + Delete(ctx context.Context, clientID int32) error } type DownloadClient struct { - ID int `json:"id"` + ID int32 `json:"id"` Name string `json:"name"` Type DownloadClientType `json:"type"` Enabled bool `json:"enabled"` @@ -33,11 +31,9 @@ type DownloadClient struct { Username string `json:"username"` Password string `json:"password"` Settings DownloadClientSettings `json:"settings,omitempty"` -} -type DownloadClientCached struct { - Dc *DownloadClient - Qbt *qbittorrent.Client + // cached http client + Client any } type DownloadClientSettings struct { diff --git a/internal/domain/client_test.go b/internal/domain/client_test.go index 551e1c6..25e362d 100644 --- a/internal/domain/client_test.go +++ b/internal/domain/client_test.go @@ -11,7 +11,7 @@ import ( func TestDownloadClient_qbitBuildLegacyHost(t *testing.T) { type fields struct { - ID int + ID int32 Name string Type DownloadClientType Enabled bool diff --git a/internal/download_client/cache.go b/internal/download_client/cache.go new file mode 100644 index 0000000..dd2784e --- /dev/null +++ b/internal/download_client/cache.go @@ -0,0 +1,48 @@ +package download_client + +import ( + "sync" + + "github.com/autobrr/autobrr/internal/domain" +) + +type ClientCacheStore interface { + Set(id int32, client *domain.DownloadClient) + Get(id int32) *domain.DownloadClient + Pop(id int32) +} + +type ClientCache struct { + mu sync.RWMutex + clients map[int32]*domain.DownloadClient +} + +func NewClientCache() *ClientCache { + return &ClientCache{ + clients: make(map[int32]*domain.DownloadClient), + } +} + +func (c *ClientCache) Set(id int32, client *domain.DownloadClient) { + if client != nil { + c.mu.Lock() + c.clients[id] = client + c.mu.Unlock() + } +} + +func (c *ClientCache) Get(id int32) *domain.DownloadClient { + c.mu.RLock() + defer c.mu.RUnlock() + v, ok := c.clients[id] + if ok { + return v + } + return nil +} + +func (c *ClientCache) Pop(id int32) { + c.mu.Lock() + delete(c.clients, id) + c.mu.Unlock() +} diff --git a/internal/download_client/service.go b/internal/download_client/service.go index 7d58754..9c58865 100644 --- a/internal/download_client/service.go +++ b/internal/download_client/service.go @@ -5,13 +5,27 @@ package download_client import ( "context" + "fmt" "log" + "net/url" "sync" + "time" "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" + "github.com/autobrr/autobrr/pkg/errors" + "github.com/autobrr/autobrr/pkg/lidarr" + "github.com/autobrr/autobrr/pkg/porla" + "github.com/autobrr/autobrr/pkg/radarr" + "github.com/autobrr/autobrr/pkg/readarr" + "github.com/autobrr/autobrr/pkg/sabnzbd" + "github.com/autobrr/autobrr/pkg/sonarr" + "github.com/autobrr/autobrr/pkg/transmission" + "github.com/autobrr/autobrr/pkg/whisparr" + "github.com/autobrr/go-deluge" "github.com/autobrr/go-qbittorrent" + "github.com/autobrr/go-rtorrent" "github.com/dcarbone/zadapters/zstdlog" "github.com/rs/zerolog" ) @@ -19,12 +33,12 @@ import ( type Service interface { List(ctx context.Context) ([]domain.DownloadClient, error) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) - Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) - Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) - Delete(ctx context.Context, clientID int) error + Store(ctx context.Context, client *domain.DownloadClient) error + Update(ctx context.Context, client *domain.DownloadClient) error + Delete(ctx context.Context, clientID int32) error Test(ctx context.Context, client domain.DownloadClient) error - GetCachedClient(ctx context.Context, clientId int32) *domain.DownloadClientCached + GetClient(ctx context.Context, clientId int32) (*domain.DownloadClient, error) } type service struct { @@ -32,8 +46,8 @@ type service struct { repo domain.DownloadClientRepo subLogger *log.Logger - qbitClients map[int32]*domain.DownloadClientCached - m sync.RWMutex + cache *ClientCache + m sync.RWMutex } func NewService(log logger.Logger, repo domain.DownloadClientRepo) Service { @@ -41,8 +55,8 @@ func NewService(log logger.Logger, repo domain.DownloadClientRepo) Service { log: log.With().Str("module", "download_client").Logger(), repo: repo, - qbitClients: map[int32]*domain.DownloadClientCached{}, - m: sync.RWMutex{}, + cache: NewClientCache(), + m: sync.RWMutex{}, } s.subLogger = zstdlog.NewStdLoggerWithLevel(s.log.With().Logger(), zerolog.TraceLevel) @@ -61,6 +75,13 @@ func (s *service) List(ctx context.Context) ([]domain.DownloadClient, error) { } func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) { + client := s.cache.Get(id) + if client != nil { + return client, nil + } + + s.log.Trace().Msgf("cache miss for client id %d, continue to repo lookup", id) + client, err := s.repo.FindByID(ctx, id) if err != nil { s.log.Error().Err(err).Msgf("could not find download client by id: %v", id) @@ -70,53 +91,49 @@ func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClien return client, nil } -func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { +func (s *service) Store(ctx context.Context, client *domain.DownloadClient) error { // basic validation of client if err := client.Validate(); err != nil { - return nil, err + return err } // store - c, err := s.repo.Store(ctx, client) + err := s.repo.Store(ctx, client) if err != nil { s.log.Error().Err(err).Msgf("could not store download client: %+v", client) - return nil, err + return err } - return c, err + s.cache.Set(client.ID, client) + + return err } -func (s *service) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { +func (s *service) Update(ctx context.Context, client *domain.DownloadClient) error { // basic validation of client if err := client.Validate(); err != nil { - return nil, err + return err } // update - c, err := s.repo.Update(ctx, client) + err := s.repo.Update(ctx, client) if err != nil { s.log.Error().Err(err).Msgf("could not update download client: %+v", client) - return nil, err + return err } - if client.Type == domain.DownloadClientTypeQbittorrent { - s.m.Lock() - delete(s.qbitClients, int32(client.ID)) - s.m.Unlock() - } + s.cache.Set(client.ID, client) - return c, err + return err } -func (s *service) Delete(ctx context.Context, clientID int) error { +func (s *service) Delete(ctx context.Context, clientID int32) error { if err := s.repo.Delete(ctx, clientID); err != nil { s.log.Error().Err(err).Msgf("could not delete download client: %v", clientID) return err } - s.m.Lock() - delete(s.qbitClients, int32(clientID)) - s.m.Unlock() + s.cache.Pop(clientID) return nil } @@ -136,53 +153,165 @@ func (s *service) Test(ctx context.Context, client domain.DownloadClient) error return nil } -func (s *service) GetCachedClient(ctx context.Context, clientId int32) *domain.DownloadClientCached { - - // check if client exists in cache - s.m.RLock() - cached, ok := s.qbitClients[clientId] - s.m.RUnlock() - - if ok { - return cached - } - - // get client for action - client, err := s.FindByID(ctx, clientId) - if err != nil { - return nil - } +// GetClient get client from cache or repo and attach downloadClient implementation +func (s *service) GetClient(ctx context.Context, clientId int32) (*domain.DownloadClient, error) { + l := s.log.With().Str("cache", "download-client").Logger() + client := s.cache.Get(clientId) if client == nil { - return nil + l.Trace().Msgf("cache miss for client id %d, continue to repo lookup", clientId) + + var err error + client, err = s.repo.FindByID(ctx, clientId) + if err != nil { + return nil, errors.Wrap(err, "could not find client repo.FindByID") + } } - qbtSettings := qbittorrent.Config{ - Host: client.BuildLegacyHost(), - Username: client.Username, - Password: client.Password, - TLSSkipVerify: client.TLSSkipVerify, + // if we have the client return it + if client.Client != nil { + l.Trace().Msgf("cache hit for client id %d %s", clientId, client.Name) + return client, nil } - // setup sub logger adapter which is compatible with *log.Logger - qbtSettings.Log = zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "qBittorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel) + l.Trace().Msgf("init cache client id %d %s", clientId, client.Name) - // only set basic auth if enabled - if client.Settings.Basic.Auth { - qbtSettings.BasicUser = client.Settings.Basic.Username - qbtSettings.BasicPass = client.Settings.Basic.Password + switch client.Type { + case domain.DownloadClientTypeQbittorrent: + client.Client = qbittorrent.NewClient(qbittorrent.Config{ + Host: client.BuildLegacyHost(), + Username: client.Username, + Password: client.Password, + TLSSkipVerify: client.TLSSkipVerify, + Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "qBittorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel), + BasicUser: client.Settings.Basic.Username, + BasicPass: client.Settings.Basic.Password, + }) + + case domain.DownloadClientTypePorla: + client.Client = porla.NewClient(porla.Config{ + Hostname: client.Host, + AuthToken: client.Settings.APIKey, + TLSSkipVerify: client.TLSSkipVerify, + BasicUser: client.Settings.Basic.Username, + BasicPass: client.Settings.Basic.Password, + Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Porla").Str("client", client.Name).Logger(), zerolog.TraceLevel), + }) + + case domain.DownloadClientTypeDelugeV1: + client.Client = deluge.NewV1(deluge.Settings{ + Hostname: client.Host, + Port: uint(client.Port), + Login: client.Username, + Password: client.Password, + DebugServerResponses: true, + ReadWriteTimeout: time.Second * 60, + }) + + case domain.DownloadClientTypeDelugeV2: + client.Client = deluge.NewV2(deluge.Settings{ + Hostname: client.Host, + Port: uint(client.Port), + Login: client.Username, + Password: client.Password, + DebugServerResponses: true, + ReadWriteTimeout: time.Second * 60, + }) + + case domain.DownloadClientTypeTransmission: + scheme := "http" + if client.TLS { + scheme = "https" + } + + transmissionURL, err := url.Parse(fmt.Sprintf("%s://%s:%d/transmission/rpc", scheme, client.Host, client.Port)) + if err != nil { + return nil, errors.Wrap(err, "could not parse transmission url") + } + + tbt, err := transmission.New(transmissionURL, &transmission.Config{ + UserAgent: "autobrr", + Username: client.Username, + Password: client.Password, + TLSSkipVerify: client.TLSSkipVerify, + }) + if err != nil { + return nil, errors.Wrap(err, "error logging into transmission client: %s", client.Host) + } + client.Client = tbt + + case domain.DownloadClientTypeRTorrent: + client.Client = rtorrent.NewClient(rtorrent.Config{ + Addr: client.Host, + TLSSkipVerify: client.TLSSkipVerify, + BasicUser: client.Settings.Basic.Username, + BasicPass: client.Settings.Basic.Password, + Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "rTorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel), + }) + + case domain.DownloadClientTypeLidarr: + client.Client = lidarr.New(lidarr.Config{ + Hostname: client.Host, + APIKey: client.Settings.APIKey, + Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Lidarr").Str("client", client.Name).Logger(), zerolog.TraceLevel), + BasicAuth: client.Settings.Basic.Auth, + Username: client.Settings.Basic.Username, + Password: client.Settings.Basic.Password, + }) + + case domain.DownloadClientTypeRadarr: + client.Client = radarr.New(radarr.Config{ + Hostname: client.Host, + APIKey: client.Settings.APIKey, + Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Radarr").Str("client", client.Name).Logger(), zerolog.TraceLevel), + BasicAuth: client.Settings.Basic.Auth, + Username: client.Settings.Basic.Username, + Password: client.Settings.Basic.Password, + }) + + case domain.DownloadClientTypeReadarr: + client.Client = readarr.New(readarr.Config{ + Hostname: client.Host, + APIKey: client.Settings.APIKey, + Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Readarr").Str("client", client.Name).Logger(), zerolog.TraceLevel), + BasicAuth: client.Settings.Basic.Auth, + Username: client.Settings.Basic.Username, + Password: client.Settings.Basic.Password, + }) + + case domain.DownloadClientTypeSonarr: + client.Client = sonarr.New(sonarr.Config{ + Hostname: client.Host, + APIKey: client.Settings.APIKey, + Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Sonarr").Str("client", client.Name).Logger(), zerolog.TraceLevel), + BasicAuth: client.Settings.Basic.Auth, + Username: client.Settings.Basic.Username, + Password: client.Settings.Basic.Password, + }) + + case domain.DownloadClientTypeWhisparr: + client.Client = whisparr.New(whisparr.Config{ + Hostname: client.Host, + APIKey: client.Settings.APIKey, + Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Whisparr").Str("client", client.Name).Logger(), zerolog.TraceLevel), + BasicAuth: client.Settings.Basic.Auth, + Username: client.Settings.Basic.Username, + Password: client.Settings.Basic.Password, + }) + + case domain.DownloadClientTypeSabnzbd: + client.Client = sabnzbd.New(sabnzbd.Options{ + Addr: client.Host, + ApiKey: client.Settings.APIKey, + Log: nil, + BasicUser: client.Settings.Basic.Username, + BasicPass: client.Settings.Basic.Password, + }) } - qc := &domain.DownloadClientCached{ - Dc: client, - Qbt: qbittorrent.NewClient(qbtSettings), - } + l.Trace().Msgf("set cache client id %d %s", clientId, client.Name) - cached = qc + s.cache.Set(clientId, client) - s.m.Lock() - s.qbitClients[clientId] = cached - s.m.Unlock() - - return cached + return client, nil } diff --git a/internal/filter/service.go b/internal/filter/service.go index ff3e818..05189df 100644 --- a/internal/filter/service.go +++ b/internal/filter/service.go @@ -7,6 +7,7 @@ import ( "bytes" "context" "fmt" + "github.com/autobrr/autobrr/internal/action" "io" "net/http" "os" @@ -46,24 +47,24 @@ type Service interface { } type service struct { - log zerolog.Logger - repo domain.FilterRepo - actionRepo domain.ActionRepo - releaseRepo domain.ReleaseRepo - indexerSvc indexer.Service - apiService indexer.APIService + log zerolog.Logger + repo domain.FilterRepo + actionService action.Service + releaseRepo domain.ReleaseRepo + indexerSvc indexer.Service + apiService indexer.APIService httpClient *http.Client } -func NewService(log logger.Logger, repo domain.FilterRepo, actionRepo domain.ActionRepo, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service) Service { +func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Service, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service) Service { return &service{ - log: log.With().Str("module", "filter").Logger(), - repo: repo, - actionRepo: actionRepo, - releaseRepo: releaseRepo, - apiService: apiService, - indexerSvc: indexerSvc, + log: log.With().Str("module", "filter").Logger(), + repo: repo, + releaseRepo: releaseRepo, + actionService: actionSvc, + apiService: apiService, + indexerSvc: indexerSvc, httpClient: &http.Client{ Timeout: time.Second * 120, Transport: sharedhttp.TransportTLSInsecure, @@ -130,7 +131,7 @@ func (s *service) FindByID(ctx context.Context, filterID int) (*domain.Filter, e } filter.External = externalFilters - actions, err := s.actionRepo.FindByFilterID(ctx, filter.ID, nil) + actions, err := s.actionService.FindByFilterID(ctx, filter.ID, nil, false) if err != nil { s.log.Error().Err(err).Msgf("could not find filter actions for filter id: %v", filter.ID) } @@ -222,7 +223,7 @@ func (s *service) Update(ctx context.Context, filter *domain.Filter) error { } // take care of filter actions - actions, err := s.actionRepo.StoreFilterActions(ctx, int64(filter.ID), filter.Actions) + actions, err := s.actionService.StoreFilterActions(ctx, int64(filter.ID), filter.Actions) if err != nil { s.log.Error().Err(err).Msgf("could not store filter actions: %s", filter.Name) return err @@ -267,7 +268,7 @@ func (s *service) UpdatePartial(ctx context.Context, filter domain.FilterUpdate) if filter.Actions != nil { // take care of filter actions - if _, err := s.actionRepo.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil { + if _, err := s.actionService.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil { s.log.Error().Err(err).Msgf("could not store filter actions: %v", filter.ID) return err } @@ -308,7 +309,7 @@ func (s *service) Duplicate(ctx context.Context, filterID int) (*domain.Filter, } // take care of filter actions - if _, err := s.actionRepo.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil { + if _, err := s.actionService.StoreFilterActions(ctx, int64(filter.ID), filter.Actions); err != nil { s.log.Error().Err(err).Msgf("could not store filter actions: %s", filter.Name) return nil, err } @@ -340,7 +341,7 @@ func (s *service) Delete(ctx context.Context, filterID int) error { } // take care of filter actions - if err := s.actionRepo.DeleteByFilterID(ctx, filterID); err != nil { + if err := s.actionService.DeleteByFilterID(ctx, filterID); err != nil { s.log.Error().Err(err).Msg("could not delete filter actions") return err } diff --git a/internal/http/download_client.go b/internal/http/download_client.go index 5673ba1..7e05609 100644 --- a/internal/http/download_client.go +++ b/internal/http/download_client.go @@ -17,9 +17,9 @@ import ( type downloadClientService interface { List(ctx context.Context) ([]domain.DownloadClient, error) - Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) - Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) - Delete(ctx context.Context, clientID int) error + Store(ctx context.Context, client *domain.DownloadClient) error + Update(ctx context.Context, client *domain.DownloadClient) error + Delete(ctx context.Context, clientID int32) error Test(ctx context.Context, client domain.DownloadClient) error } @@ -56,20 +56,20 @@ func (h downloadClientHandler) listDownloadClients(w http.ResponseWriter, r *htt } func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) { - var data domain.DownloadClient + var data *domain.DownloadClient if err := json.NewDecoder(r.Body).Decode(&data); err != nil { h.encoder.Error(w, err) return } - client, err := h.service.Store(r.Context(), data) + err := h.service.Store(r.Context(), data) if err != nil { h.encoder.Error(w, err) return } - h.encoder.StatusResponse(w, http.StatusCreated, client) + h.encoder.StatusResponse(w, http.StatusCreated, data) } func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) { @@ -89,20 +89,20 @@ func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) { } func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) { - var data domain.DownloadClient + var data *domain.DownloadClient if err := json.NewDecoder(r.Body).Decode(&data); err != nil { h.encoder.Error(w, err) return } - client, err := h.service.Update(r.Context(), data) + err := h.service.Update(r.Context(), data) if err != nil { h.encoder.Error(w, err) return } - h.encoder.StatusResponse(w, http.StatusCreated, client) + h.encoder.StatusResponse(w, http.StatusCreated, data) } func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) { @@ -113,13 +113,13 @@ func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) { return } - id, err := strconv.Atoi(clientID) + id, err := strconv.ParseInt(clientID, 10, 32) if err != nil { h.encoder.Error(w, err) return } - if err = h.service.Delete(r.Context(), id); err != nil { + if err = h.service.Delete(r.Context(), int32(id)); err != nil { h.encoder.Error(w, err) return } diff --git a/internal/release/service.go b/internal/release/service.go index 7054264..0f9eaff 100644 --- a/internal/release/service.go +++ b/internal/release/service.go @@ -221,7 +221,7 @@ func (s *service) processFilters(ctx context.Context, filters []*domain.Filter, // found matching filter, lets find the filter actions and attach active := true - actions, err := s.actionSvc.FindByFilterID(ctx, f.ID, &active) + actions, err := s.actionSvc.FindByFilterID(ctx, f.ID, &active, false) if err != nil { s.log.Error().Err(err).Msgf("release.Process: error finding actions for filter: %s", f.Name) return err diff --git a/pkg/transmission/transmission.go b/pkg/transmission/transmission.go index af69c8c..62297df 100644 --- a/pkg/transmission/transmission.go +++ b/pkg/transmission/transmission.go @@ -17,7 +17,7 @@ type Config struct { Username string Password string TLSSkipVerify bool - Timeout int + Timeout time.Duration } func New(endpoint *url.URL, cfg *Config) (*transmissionrpc.Client, error) {