mirror of
https://github.com/idanoo/autobrr
synced 2025-07-22 16:29:12 +00:00
fix(actions): reject if client is disabled (#1626)
* fix(actions): error on disabled client * fix(actions): sql scan args * refactor: download client cache for actions * fix: tests client store * fix: tests client store and int conversion * fix: tests revert findbyid ctx timeout * fix: tests row.err * feat: add logging to download client cache
This commit is contained in:
parent
77e1c2c305
commit
861f30c144
30 changed files with 928 additions and 680 deletions
48
internal/download_client/cache.go
Normal file
48
internal/download_client/cache.go
Normal file
|
@ -0,0 +1,48 @@
|
|||
package download_client
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
)
|
||||
|
||||
type ClientCacheStore interface {
|
||||
Set(id int32, client *domain.DownloadClient)
|
||||
Get(id int32) *domain.DownloadClient
|
||||
Pop(id int32)
|
||||
}
|
||||
|
||||
type ClientCache struct {
|
||||
mu sync.RWMutex
|
||||
clients map[int32]*domain.DownloadClient
|
||||
}
|
||||
|
||||
func NewClientCache() *ClientCache {
|
||||
return &ClientCache{
|
||||
clients: make(map[int32]*domain.DownloadClient),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCache) Set(id int32, client *domain.DownloadClient) {
|
||||
if client != nil {
|
||||
c.mu.Lock()
|
||||
c.clients[id] = client
|
||||
c.mu.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ClientCache) Get(id int32) *domain.DownloadClient {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
v, ok := c.clients[id]
|
||||
if ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *ClientCache) Pop(id int32) {
|
||||
c.mu.Lock()
|
||||
delete(c.clients, id)
|
||||
c.mu.Unlock()
|
||||
}
|
|
@ -5,13 +5,27 @@ package download_client
|
|||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/url"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
"github.com/autobrr/autobrr/internal/logger"
|
||||
"github.com/autobrr/autobrr/pkg/errors"
|
||||
"github.com/autobrr/autobrr/pkg/lidarr"
|
||||
"github.com/autobrr/autobrr/pkg/porla"
|
||||
"github.com/autobrr/autobrr/pkg/radarr"
|
||||
"github.com/autobrr/autobrr/pkg/readarr"
|
||||
"github.com/autobrr/autobrr/pkg/sabnzbd"
|
||||
"github.com/autobrr/autobrr/pkg/sonarr"
|
||||
"github.com/autobrr/autobrr/pkg/transmission"
|
||||
"github.com/autobrr/autobrr/pkg/whisparr"
|
||||
|
||||
"github.com/autobrr/go-deluge"
|
||||
"github.com/autobrr/go-qbittorrent"
|
||||
"github.com/autobrr/go-rtorrent"
|
||||
"github.com/dcarbone/zadapters/zstdlog"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
@ -19,12 +33,12 @@ import (
|
|||
type Service interface {
|
||||
List(ctx context.Context) ([]domain.DownloadClient, error)
|
||||
FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error)
|
||||
Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
|
||||
Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
|
||||
Delete(ctx context.Context, clientID int) error
|
||||
Store(ctx context.Context, client *domain.DownloadClient) error
|
||||
Update(ctx context.Context, client *domain.DownloadClient) error
|
||||
Delete(ctx context.Context, clientID int32) error
|
||||
Test(ctx context.Context, client domain.DownloadClient) error
|
||||
|
||||
GetCachedClient(ctx context.Context, clientId int32) *domain.DownloadClientCached
|
||||
GetClient(ctx context.Context, clientId int32) (*domain.DownloadClient, error)
|
||||
}
|
||||
|
||||
type service struct {
|
||||
|
@ -32,8 +46,8 @@ type service struct {
|
|||
repo domain.DownloadClientRepo
|
||||
subLogger *log.Logger
|
||||
|
||||
qbitClients map[int32]*domain.DownloadClientCached
|
||||
m sync.RWMutex
|
||||
cache *ClientCache
|
||||
m sync.RWMutex
|
||||
}
|
||||
|
||||
func NewService(log logger.Logger, repo domain.DownloadClientRepo) Service {
|
||||
|
@ -41,8 +55,8 @@ func NewService(log logger.Logger, repo domain.DownloadClientRepo) Service {
|
|||
log: log.With().Str("module", "download_client").Logger(),
|
||||
repo: repo,
|
||||
|
||||
qbitClients: map[int32]*domain.DownloadClientCached{},
|
||||
m: sync.RWMutex{},
|
||||
cache: NewClientCache(),
|
||||
m: sync.RWMutex{},
|
||||
}
|
||||
|
||||
s.subLogger = zstdlog.NewStdLoggerWithLevel(s.log.With().Logger(), zerolog.TraceLevel)
|
||||
|
@ -61,6 +75,13 @@ func (s *service) List(ctx context.Context) ([]domain.DownloadClient, error) {
|
|||
}
|
||||
|
||||
func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
|
||||
client := s.cache.Get(id)
|
||||
if client != nil {
|
||||
return client, nil
|
||||
}
|
||||
|
||||
s.log.Trace().Msgf("cache miss for client id %d, continue to repo lookup", id)
|
||||
|
||||
client, err := s.repo.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msgf("could not find download client by id: %v", id)
|
||||
|
@ -70,53 +91,49 @@ func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClien
|
|||
return client, nil
|
||||
}
|
||||
|
||||
func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
|
||||
func (s *service) Store(ctx context.Context, client *domain.DownloadClient) error {
|
||||
// basic validation of client
|
||||
if err := client.Validate(); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// store
|
||||
c, err := s.repo.Store(ctx, client)
|
||||
err := s.repo.Store(ctx, client)
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msgf("could not store download client: %+v", client)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
return c, err
|
||||
s.cache.Set(client.ID, client)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *service) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
|
||||
func (s *service) Update(ctx context.Context, client *domain.DownloadClient) error {
|
||||
// basic validation of client
|
||||
if err := client.Validate(); err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
// update
|
||||
c, err := s.repo.Update(ctx, client)
|
||||
err := s.repo.Update(ctx, client)
|
||||
if err != nil {
|
||||
s.log.Error().Err(err).Msgf("could not update download client: %+v", client)
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
if client.Type == domain.DownloadClientTypeQbittorrent {
|
||||
s.m.Lock()
|
||||
delete(s.qbitClients, int32(client.ID))
|
||||
s.m.Unlock()
|
||||
}
|
||||
s.cache.Set(client.ID, client)
|
||||
|
||||
return c, err
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *service) Delete(ctx context.Context, clientID int) error {
|
||||
func (s *service) Delete(ctx context.Context, clientID int32) error {
|
||||
if err := s.repo.Delete(ctx, clientID); err != nil {
|
||||
s.log.Error().Err(err).Msgf("could not delete download client: %v", clientID)
|
||||
return err
|
||||
}
|
||||
|
||||
s.m.Lock()
|
||||
delete(s.qbitClients, int32(clientID))
|
||||
s.m.Unlock()
|
||||
s.cache.Pop(clientID)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -136,53 +153,165 @@ func (s *service) Test(ctx context.Context, client domain.DownloadClient) error
|
|||
return nil
|
||||
}
|
||||
|
||||
func (s *service) GetCachedClient(ctx context.Context, clientId int32) *domain.DownloadClientCached {
|
||||
|
||||
// check if client exists in cache
|
||||
s.m.RLock()
|
||||
cached, ok := s.qbitClients[clientId]
|
||||
s.m.RUnlock()
|
||||
|
||||
if ok {
|
||||
return cached
|
||||
}
|
||||
|
||||
// get client for action
|
||||
client, err := s.FindByID(ctx, clientId)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
// GetClient get client from cache or repo and attach downloadClient implementation
|
||||
func (s *service) GetClient(ctx context.Context, clientId int32) (*domain.DownloadClient, error) {
|
||||
l := s.log.With().Str("cache", "download-client").Logger()
|
||||
|
||||
client := s.cache.Get(clientId)
|
||||
if client == nil {
|
||||
return nil
|
||||
l.Trace().Msgf("cache miss for client id %d, continue to repo lookup", clientId)
|
||||
|
||||
var err error
|
||||
client, err = s.repo.FindByID(ctx, clientId)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not find client repo.FindByID")
|
||||
}
|
||||
}
|
||||
|
||||
qbtSettings := qbittorrent.Config{
|
||||
Host: client.BuildLegacyHost(),
|
||||
Username: client.Username,
|
||||
Password: client.Password,
|
||||
TLSSkipVerify: client.TLSSkipVerify,
|
||||
// if we have the client return it
|
||||
if client.Client != nil {
|
||||
l.Trace().Msgf("cache hit for client id %d %s", clientId, client.Name)
|
||||
return client, nil
|
||||
}
|
||||
|
||||
// setup sub logger adapter which is compatible with *log.Logger
|
||||
qbtSettings.Log = zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "qBittorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel)
|
||||
l.Trace().Msgf("init cache client id %d %s", clientId, client.Name)
|
||||
|
||||
// only set basic auth if enabled
|
||||
if client.Settings.Basic.Auth {
|
||||
qbtSettings.BasicUser = client.Settings.Basic.Username
|
||||
qbtSettings.BasicPass = client.Settings.Basic.Password
|
||||
switch client.Type {
|
||||
case domain.DownloadClientTypeQbittorrent:
|
||||
client.Client = qbittorrent.NewClient(qbittorrent.Config{
|
||||
Host: client.BuildLegacyHost(),
|
||||
Username: client.Username,
|
||||
Password: client.Password,
|
||||
TLSSkipVerify: client.TLSSkipVerify,
|
||||
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "qBittorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel),
|
||||
BasicUser: client.Settings.Basic.Username,
|
||||
BasicPass: client.Settings.Basic.Password,
|
||||
})
|
||||
|
||||
case domain.DownloadClientTypePorla:
|
||||
client.Client = porla.NewClient(porla.Config{
|
||||
Hostname: client.Host,
|
||||
AuthToken: client.Settings.APIKey,
|
||||
TLSSkipVerify: client.TLSSkipVerify,
|
||||
BasicUser: client.Settings.Basic.Username,
|
||||
BasicPass: client.Settings.Basic.Password,
|
||||
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Porla").Str("client", client.Name).Logger(), zerolog.TraceLevel),
|
||||
})
|
||||
|
||||
case domain.DownloadClientTypeDelugeV1:
|
||||
client.Client = deluge.NewV1(deluge.Settings{
|
||||
Hostname: client.Host,
|
||||
Port: uint(client.Port),
|
||||
Login: client.Username,
|
||||
Password: client.Password,
|
||||
DebugServerResponses: true,
|
||||
ReadWriteTimeout: time.Second * 60,
|
||||
})
|
||||
|
||||
case domain.DownloadClientTypeDelugeV2:
|
||||
client.Client = deluge.NewV2(deluge.Settings{
|
||||
Hostname: client.Host,
|
||||
Port: uint(client.Port),
|
||||
Login: client.Username,
|
||||
Password: client.Password,
|
||||
DebugServerResponses: true,
|
||||
ReadWriteTimeout: time.Second * 60,
|
||||
})
|
||||
|
||||
case domain.DownloadClientTypeTransmission:
|
||||
scheme := "http"
|
||||
if client.TLS {
|
||||
scheme = "https"
|
||||
}
|
||||
|
||||
transmissionURL, err := url.Parse(fmt.Sprintf("%s://%s:%d/transmission/rpc", scheme, client.Host, client.Port))
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not parse transmission url")
|
||||
}
|
||||
|
||||
tbt, err := transmission.New(transmissionURL, &transmission.Config{
|
||||
UserAgent: "autobrr",
|
||||
Username: client.Username,
|
||||
Password: client.Password,
|
||||
TLSSkipVerify: client.TLSSkipVerify,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "error logging into transmission client: %s", client.Host)
|
||||
}
|
||||
client.Client = tbt
|
||||
|
||||
case domain.DownloadClientTypeRTorrent:
|
||||
client.Client = rtorrent.NewClient(rtorrent.Config{
|
||||
Addr: client.Host,
|
||||
TLSSkipVerify: client.TLSSkipVerify,
|
||||
BasicUser: client.Settings.Basic.Username,
|
||||
BasicPass: client.Settings.Basic.Password,
|
||||
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "rTorrent").Str("client", client.Name).Logger(), zerolog.TraceLevel),
|
||||
})
|
||||
|
||||
case domain.DownloadClientTypeLidarr:
|
||||
client.Client = lidarr.New(lidarr.Config{
|
||||
Hostname: client.Host,
|
||||
APIKey: client.Settings.APIKey,
|
||||
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Lidarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
|
||||
BasicAuth: client.Settings.Basic.Auth,
|
||||
Username: client.Settings.Basic.Username,
|
||||
Password: client.Settings.Basic.Password,
|
||||
})
|
||||
|
||||
case domain.DownloadClientTypeRadarr:
|
||||
client.Client = radarr.New(radarr.Config{
|
||||
Hostname: client.Host,
|
||||
APIKey: client.Settings.APIKey,
|
||||
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Radarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
|
||||
BasicAuth: client.Settings.Basic.Auth,
|
||||
Username: client.Settings.Basic.Username,
|
||||
Password: client.Settings.Basic.Password,
|
||||
})
|
||||
|
||||
case domain.DownloadClientTypeReadarr:
|
||||
client.Client = readarr.New(readarr.Config{
|
||||
Hostname: client.Host,
|
||||
APIKey: client.Settings.APIKey,
|
||||
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Readarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
|
||||
BasicAuth: client.Settings.Basic.Auth,
|
||||
Username: client.Settings.Basic.Username,
|
||||
Password: client.Settings.Basic.Password,
|
||||
})
|
||||
|
||||
case domain.DownloadClientTypeSonarr:
|
||||
client.Client = sonarr.New(sonarr.Config{
|
||||
Hostname: client.Host,
|
||||
APIKey: client.Settings.APIKey,
|
||||
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Sonarr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
|
||||
BasicAuth: client.Settings.Basic.Auth,
|
||||
Username: client.Settings.Basic.Username,
|
||||
Password: client.Settings.Basic.Password,
|
||||
})
|
||||
|
||||
case domain.DownloadClientTypeWhisparr:
|
||||
client.Client = whisparr.New(whisparr.Config{
|
||||
Hostname: client.Host,
|
||||
APIKey: client.Settings.APIKey,
|
||||
Log: zstdlog.NewStdLoggerWithLevel(s.log.With().Str("type", "Whisparr").Str("client", client.Name).Logger(), zerolog.TraceLevel),
|
||||
BasicAuth: client.Settings.Basic.Auth,
|
||||
Username: client.Settings.Basic.Username,
|
||||
Password: client.Settings.Basic.Password,
|
||||
})
|
||||
|
||||
case domain.DownloadClientTypeSabnzbd:
|
||||
client.Client = sabnzbd.New(sabnzbd.Options{
|
||||
Addr: client.Host,
|
||||
ApiKey: client.Settings.APIKey,
|
||||
Log: nil,
|
||||
BasicUser: client.Settings.Basic.Username,
|
||||
BasicPass: client.Settings.Basic.Password,
|
||||
})
|
||||
}
|
||||
|
||||
qc := &domain.DownloadClientCached{
|
||||
Dc: client,
|
||||
Qbt: qbittorrent.NewClient(qbtSettings),
|
||||
}
|
||||
l.Trace().Msgf("set cache client id %d %s", clientId, client.Name)
|
||||
|
||||
cached = qc
|
||||
s.cache.Set(clientId, client)
|
||||
|
||||
s.m.Lock()
|
||||
s.qbitClients[clientId] = cached
|
||||
s.m.Unlock()
|
||||
|
||||
return cached
|
||||
return client, nil
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue