diff --git a/cmd/autobrr/main.go b/cmd/autobrr/main.go index a575d0a..d7f55c5 100644 --- a/cmd/autobrr/main.go +++ b/cmd/autobrr/main.go @@ -24,7 +24,9 @@ import ( "github.com/autobrr/autobrr/internal/irc" "github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/internal/notification" + "github.com/autobrr/autobrr/internal/proxy" "github.com/autobrr/autobrr/internal/release" + "github.com/autobrr/autobrr/internal/releasedownload" "github.com/autobrr/autobrr/internal/scheduler" "github.com/autobrr/autobrr/internal/server" "github.com/autobrr/autobrr/internal/update" @@ -103,24 +105,27 @@ func main() { notificationRepo = database.NewNotificationRepo(log, db) releaseRepo = database.NewReleaseRepo(log, db) userRepo = database.NewUserRepo(log, db) + proxyRepo = database.NewProxyRepo(log, db) ) // setup services var ( apiService = api.NewService(log, apikeyRepo) - notificationService = notification.NewService(log, notificationRepo) updateService = update.NewUpdate(log, cfg.Config) + notificationService = notification.NewService(log, notificationRepo) schedulingService = scheduler.NewService(log, cfg.Config, notificationService, updateService) indexerAPIService = indexer.NewAPIService(log) userService = user.NewService(userRepo) authService = auth.NewService(log, userService) + proxyService = proxy.NewService(log, proxyRepo) + downloadService = releasedownload.NewDownloadService(log, releaseRepo, indexerRepo, proxyService) downloadClientService = download_client.NewService(log, downloadClientRepo) - actionService = action.NewService(log, actionRepo, downloadClientService, bus) + actionService = action.NewService(log, actionRepo, downloadClientService, downloadService, bus) indexerService = indexer.NewService(log, cfg.Config, indexerRepo, releaseRepo, indexerAPIService, schedulingService) - filterService = filter.NewService(log, filterRepo, actionService, releaseRepo, indexerAPIService, indexerService) + filterService = filter.NewService(log, filterRepo, actionService, releaseRepo, indexerAPIService, indexerService, downloadService) releaseService = release.NewService(log, releaseRepo, actionService, filterService, indexerService) - ircService = irc.NewService(log, serverEvents, ircRepo, releaseService, indexerService, notificationService) - feedService = feed.NewService(log, feedRepo, feedCacheRepo, releaseService, schedulingService) + ircService = irc.NewService(log, serverEvents, ircRepo, releaseService, indexerService, notificationService, proxyService) + feedService = feed.NewService(log, feedRepo, feedCacheRepo, releaseService, proxyService, schedulingService) ) // register event subscribers @@ -146,6 +151,7 @@ func main() { indexerService, ircService, notificationService, + proxyService, releaseService, updateService, ) diff --git a/internal/action/deluge.go b/internal/action/deluge.go index 4bf74c3..77f8961 100644 --- a/internal/action/deluge.go +++ b/internal/action/deluge.go @@ -140,9 +140,8 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a return nil, nil } else { if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFileCtx(ctx); err != nil { - s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName) - return nil, err + if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil { + return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName) } } @@ -243,11 +242,8 @@ func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, a return nil, nil } else { - if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFileCtx(ctx); err != nil { - s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName) - return nil, err - } + if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil { + return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName) } t, err := os.ReadFile(release.TorrentTmpFile) diff --git a/internal/action/porla.go b/internal/action/porla.go index 7c8d472..6c05a5d 100644 --- a/internal/action/porla.go +++ b/internal/action/porla.go @@ -74,10 +74,8 @@ func (s *service) porla(ctx context.Context, action *domain.Action, release doma return nil, nil } else { - if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFileCtx(ctx); err != nil { - return nil, errors.Wrap(err, "error downloading torrent file for release: %s", release.TorrentName) - } + if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil { + return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName) } file, err := os.Open(release.TorrentTmpFile) diff --git a/internal/action/qbittorrent.go b/internal/action/qbittorrent.go index 8f5c2f0..0a233b2 100644 --- a/internal/action/qbittorrent.go +++ b/internal/action/qbittorrent.go @@ -57,10 +57,8 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas return nil, nil } - if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFileCtx(ctx); err != nil { - return nil, errors.Wrap(err, "error downloading torrent file for release: %s", release.TorrentName) - } + if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil { + return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName) } options, err := s.prepareQbitOptions(action) diff --git a/internal/action/rtorrent.go b/internal/action/rtorrent.go index 99dc8c5..bbb9efc 100644 --- a/internal/action/rtorrent.go +++ b/internal/action/rtorrent.go @@ -68,11 +68,8 @@ func (s *service) rtorrent(ctx context.Context, action *domain.Action, release d return nil, nil } - if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFileCtx(ctx); err != nil { - s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName) - return nil, err - } + if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil { + return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName) } tmpFile, err := os.ReadFile(release.TorrentTmpFile) diff --git a/internal/action/run.go b/internal/action/run.go index e0e5443..c61027d 100644 --- a/internal/action/run.go +++ b/internal/action/run.go @@ -36,9 +36,8 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release return nil, errors.New("action %s client %s %s not enabled, skipping", action.Name, action.Client.Type, action.Client.Name) } - // if set, try to resolve MagnetURI before parsing macros - // to allow webhook and exec to get the magnet_uri - if err := release.ResolveMagnetUri(ctx); err != nil { + // Check preconditions: download torrent file if needed + if err := s.CheckActionPreconditions(ctx, action, release); err != nil { return nil, err } @@ -137,6 +136,30 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release return rejections, err } +func (s *service) CheckActionPreconditions(ctx context.Context, action *domain.Action, release *domain.Release) error { + if err := s.downloadSvc.ResolveMagnetURI(ctx, release); err != nil { + return errors.Wrap(err, "could not resolve magnet uri: %s", release.MagnetURI) + } + + // parse all macros in one go + if action.CheckMacrosNeedTorrentTmpFile(release) { + if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil { + return errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName) + } + } + + if action.CheckMacrosNeedRawDataBytes(release) { + tmpFile, err := os.ReadFile(release.TorrentTmpFile) + if err != nil { + return errors.Wrap(err, "could not read torrent file: %v", release.TorrentTmpFile) + } + + release.TorrentDataRawBytes = tmpFile + } + + return nil +} + func (s *service) test(name string) { s.log.Info().Msgf("action TEST: %v", name) } diff --git a/internal/action/service.go b/internal/action/service.go index 6d1bc18..351869f 100644 --- a/internal/action/service.go +++ b/internal/action/service.go @@ -12,6 +12,7 @@ import ( "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/download_client" "github.com/autobrr/autobrr/internal/logger" + "github.com/autobrr/autobrr/internal/releasedownload" "github.com/autobrr/autobrr/pkg/sharedhttp" "github.com/asaskevich/EventBus" @@ -33,21 +34,23 @@ type Service interface { } type service struct { - log zerolog.Logger - subLogger *log.Logger - repo domain.ActionRepo - clientSvc download_client.Service - bus EventBus.Bus + log zerolog.Logger + subLogger *log.Logger + repo domain.ActionRepo + clientSvc download_client.Service + downloadSvc *releasedownload.DownloadService + bus EventBus.Bus httpClient *http.Client } -func NewService(log logger.Logger, repo domain.ActionRepo, clientSvc download_client.Service, bus EventBus.Bus) Service { +func NewService(log logger.Logger, repo domain.ActionRepo, clientSvc download_client.Service, downloadSvc *releasedownload.DownloadService, bus EventBus.Bus) Service { s := &service{ - log: log.With().Str("module", "action").Logger(), - repo: repo, - clientSvc: clientSvc, - bus: bus, + log: log.With().Str("module", "action").Logger(), + repo: repo, + clientSvc: clientSvc, + downloadSvc: downloadSvc, + bus: bus, httpClient: &http.Client{ Timeout: time.Second * 120, diff --git a/internal/action/transmission.go b/internal/action/transmission.go index a03a368..4172232 100644 --- a/internal/action/transmission.go +++ b/internal/action/transmission.go @@ -107,11 +107,8 @@ func (s *service) transmission(ctx context.Context, action *domain.Action, relea return nil, nil } - if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFileCtx(ctx); err != nil { - s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName) - return nil, err - } + if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil { + return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName) } b64, err := transmissionrpc.File2Base64(release.TorrentTmpFile) diff --git a/internal/database/feed.go b/internal/database/feed.go index 8d70a1b..6ffd9e0 100644 --- a/internal/database/feed.go +++ b/internal/database/feed.go @@ -36,6 +36,8 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) { "i.identifier", "i.identifier_external", "i.name", + "i.use_proxy", + "i.proxy_id", "f.name", "f.type", "f.enabled", @@ -66,8 +68,9 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) { var f domain.Feed var apiKey, cookie, settings sql.NullString + var proxyID sql.NullInt64 - if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil { + if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.UseProxy, &proxyID, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrRecordNotFound } @@ -75,6 +78,7 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) { return nil, errors.Wrap(err, "error scanning row") } + f.ProxyID = proxyID.Int64 f.ApiKey = apiKey.String f.Cookie = cookie.String @@ -98,6 +102,8 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string) "i.identifier", "i.identifier_external", "i.name", + "i.use_proxy", + "i.proxy_id", "f.name", "f.type", "f.enabled", @@ -128,8 +134,9 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string) var f domain.Feed var apiKey, cookie, settings sql.NullString + var proxyID sql.NullInt64 - if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil { + if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.UseProxy, &proxyID, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrRecordNotFound } @@ -137,6 +144,7 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string) return nil, errors.Wrap(err, "error scanning row") } + f.ProxyID = proxyID.Int64 f.ApiKey = apiKey.String f.Cookie = cookie.String @@ -158,6 +166,8 @@ func (r *FeedRepo) Find(ctx context.Context) ([]domain.Feed, error) { "i.identifier", "i.identifier_external", "i.name", + "i.use_proxy", + "i.proxy_id", "f.name", "f.type", "f.enabled", @@ -196,10 +206,13 @@ func (r *FeedRepo) Find(ctx context.Context) ([]domain.Feed, error) { var apiKey, cookie, lastRunData, settings sql.NullString var lastRun sql.NullTime - if err := rows.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &lastRun, &lastRunData, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil { + var proxyID sql.NullInt64 + + if err := rows.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.UseProxy, &proxyID, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &lastRun, &lastRunData, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil { return nil, errors.Wrap(err, "error scanning row") } + f.ProxyID = proxyID.Int64 f.LastRun = lastRun.Time f.LastRunData = lastRunData.String f.ApiKey = apiKey.String diff --git a/internal/database/indexer.go b/internal/database/indexer.go index 80b8e1e..fc46cc7 100644 --- a/internal/database/indexer.go +++ b/internal/database/indexer.go @@ -36,8 +36,8 @@ func (r *IndexerRepo) Store(ctx context.Context, indexer domain.Indexer) (*domai } queryBuilder := r.db.squirrel. - Insert("indexer").Columns("enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings"). - Values(indexer.Enabled, indexer.Name, indexer.Identifier, indexer.IdentifierExternal, indexer.Implementation, indexer.BaseURL, settings). + Insert("indexer").Columns("enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings"). + Values(indexer.Enabled, indexer.Name, indexer.Identifier, indexer.IdentifierExternal, indexer.Implementation, indexer.BaseURL, indexer.UseProxy, toNullInt64(indexer.ProxyID), settings). Suffix("RETURNING id").RunWith(r.db.handler) // return values @@ -61,6 +61,8 @@ func (r *IndexerRepo) Update(ctx context.Context, indexer domain.Indexer) (*doma Set("name", indexer.Name). Set("identifier_external", indexer.IdentifierExternal). Set("base_url", indexer.BaseURL). + Set("use_proxy", indexer.UseProxy). + Set("proxy_id", toNullInt64(indexer.ProxyID)). Set("settings", settings). Set("updated_at", time.Now().Format(time.RFC3339)). Where(sq.Eq{"id": indexer.ID}) @@ -70,16 +72,26 @@ func (r *IndexerRepo) Update(ctx context.Context, indexer domain.Indexer) (*doma return nil, errors.Wrap(err, "error building query") } - if _, err = r.db.handler.ExecContext(ctx, query, args...); err != nil { + result, err := r.db.handler.ExecContext(ctx, query, args...) + if err != nil { return nil, errors.Wrap(err, "error executing query") } + rowsAffected, err := result.RowsAffected() + if err != nil { + return nil, errors.Wrap(err, "error rows affected") + } + + if rowsAffected == 0 { + return nil, domain.ErrUpdateFailed + } + return &indexer, nil } func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) { queryBuilder := r.db.squirrel. - Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings"). + Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings"). From("indexer"). OrderBy("name ASC") @@ -98,27 +110,29 @@ func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) { indexers := make([]domain.Indexer, 0) for rows.Next() { - var f domain.Indexer + var i domain.Indexer var identifierExternal, implementation, baseURL sql.Null[string] + var proxyID sql.Null[int64] var settings string var settingsMap map[string]string - if err := rows.Scan(&f.ID, &f.Enabled, &f.Name, &f.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil { + if err := rows.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil { return nil, errors.Wrap(err, "error scanning row") } - f.IdentifierExternal = identifierExternal.V - f.Implementation = implementation.V - f.BaseURL = baseURL.V + i.IdentifierExternal = identifierExternal.V + i.Implementation = implementation.V + i.BaseURL = baseURL.V + i.ProxyID = proxyID.V if err = json.Unmarshal([]byte(settings), &settingsMap); err != nil { return nil, errors.Wrap(err, "error unmarshal settings") } - f.Settings = settingsMap + i.Settings = settingsMap - indexers = append(indexers, f) + indexers = append(indexers, i) } if err := rows.Err(); err != nil { return nil, errors.Wrap(err, "error rows") @@ -129,7 +143,7 @@ func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) { func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, error) { queryBuilder := r.db.squirrel. - Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings"). + Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings"). From("indexer"). Where(sq.Eq{"id": id}) @@ -146,8 +160,9 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er var i domain.Indexer var identifierExternal, implementation, baseURL, settings sql.Null[string] + var proxyID sql.Null[int64] - if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil { + if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrRecordNotFound } @@ -158,6 +173,7 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er i.IdentifierExternal = identifierExternal.V i.Implementation = implementation.V i.BaseURL = baseURL.V + i.ProxyID = proxyID.V var settingsMap map[string]string if err = json.Unmarshal([]byte(settings.V), &settingsMap); err != nil { @@ -171,7 +187,7 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (*domain.Indexer, error) { queryBuilder := r.db.squirrel. - Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings"). + Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings"). From("indexer") if req.ID > 0 { @@ -195,8 +211,9 @@ func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) ( var i domain.Indexer var identifierExternal, implementation, baseURL, settings sql.Null[string] + var proxyID sql.Null[int64] - if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil { + if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrRecordNotFound } @@ -207,6 +224,7 @@ func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) ( i.IdentifierExternal = identifierExternal.V i.Implementation = implementation.V i.BaseURL = baseURL.V + i.ProxyID = proxyID.V var settingsMap map[string]string if err = json.Unmarshal([]byte(settings.V), &settingsMap); err != nil { @@ -220,7 +238,7 @@ func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) ( func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Indexer, error) { queryBuilder := r.db.squirrel. - Select("id", "enabled", "name", "identifier", "identifier_external", "base_url", "settings"). + Select("id", "enabled", "name", "identifier", "identifier_external", "base_url", "use_proxy", "proxy_id", "settings"). From("indexer"). Join("filter_indexer ON indexer.id = filter_indexer.indexer_id"). Where(sq.Eq{"filter_indexer.filter_id": id}) @@ -239,13 +257,14 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde indexers := make([]domain.Indexer, 0) for rows.Next() { - var f domain.Indexer + var i domain.Indexer var settings string var settingsMap map[string]string var identifierExternal, baseURL sql.Null[string] + var proxyID sql.Null[int64] - if err := rows.Scan(&f.ID, &f.Enabled, &f.Name, &f.Identifier, &identifierExternal, &baseURL, &settings); err != nil { + if err := rows.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil { return nil, errors.Wrap(err, "error scanning row") } @@ -253,11 +272,12 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde return nil, errors.Wrap(err, "error unmarshal settings") } - f.IdentifierExternal = identifierExternal.V - f.BaseURL = baseURL.V - f.Settings = settingsMap + i.IdentifierExternal = identifierExternal.V + i.BaseURL = baseURL.V + i.ProxyID = proxyID.V + i.Settings = settingsMap - indexers = append(indexers, f) + indexers = append(indexers, i) } if err := rows.Err(); err != nil { return nil, errors.Wrap(err, "error rows") @@ -282,13 +302,13 @@ func (r *IndexerRepo) Delete(ctx context.Context, id int) error { return errors.Wrap(err, "error executing query") } - rows, err := result.RowsAffected() + rowsAffected, err := result.RowsAffected() if err != nil { return errors.Wrap(err, "error rows affected") } - if rows != 1 { - return errors.New("error deleting row") + if rowsAffected == 0 { + return domain.ErrRecordNotFound } r.log.Debug().Str("method", "delete").Msgf("successfully deleted indexer with id %v", id) @@ -297,8 +317,6 @@ func (r *IndexerRepo) Delete(ctx context.Context, id int) error { } func (r *IndexerRepo) ToggleEnabled(ctx context.Context, indexerID int, enabled bool) error { - var err error - queryBuilder := r.db.squirrel. Update("indexer"). Set("enabled", enabled). @@ -310,10 +328,19 @@ func (r *IndexerRepo) ToggleEnabled(ctx context.Context, indexerID int, enabled return errors.Wrap(err, "error building query") } - _, err = r.db.handler.ExecContext(ctx, query, args...) + result, err := r.db.handler.ExecContext(ctx, query, args...) if err != nil { return errors.Wrap(err, "error executing query") } + rowsAffected, err := result.RowsAffected() + if err != nil { + return errors.Wrap(err, "error rows affected") + } + + if rowsAffected == 0 { + return domain.ErrUpdateFailed + } + return nil } diff --git a/internal/database/irc.go b/internal/database/irc.go index 4e49b8b..b30f01f 100644 --- a/internal/database/irc.go +++ b/internal/database/irc.go @@ -30,7 +30,7 @@ func NewIrcRepo(log logger.Logger, db *DB) domain.IrcRepo { func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetwork, error) { queryBuilder := r.db.squirrel. - Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode"). + Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id"). From("irc_network"). Where(sq.Eq{"id": id}) @@ -42,26 +42,27 @@ func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetw var n domain.IrcNetwork - var pass, nick, inviteCmd, bouncerAddr sql.NullString - var account, password sql.NullString - var tls sql.NullBool + var pass, nick, inviteCmd, bouncerAddr sql.Null[string] + var account, password sql.Null[string] + var tls sql.Null[bool] + var proxyId sql.Null[int64] row := r.db.handler.QueryRowContext(ctx, query, args...) - if err := row.Scan(&n.ID, &n.Enabled, &n.Name, &n.Server, &n.Port, &tls, &pass, &nick, &n.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &n.UseBouncer, &n.BotMode); err != nil { + if err := row.Scan(&n.ID, &n.Enabled, &n.Name, &n.Server, &n.Port, &tls, &pass, &nick, &n.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &n.UseBouncer, &n.BotMode, &n.UseProxy, &proxyId); err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, domain.ErrRecordNotFound } - return nil, errors.Wrap(err, "error scanning row") } - n.TLS = tls.Bool - n.Pass = pass.String - n.Nick = nick.String - n.InviteCommand = inviteCmd.String - n.Auth.Account = account.String - n.Auth.Password = password.String - n.BouncerAddr = bouncerAddr.String + n.TLS = tls.V + n.Pass = pass.V + n.Nick = nick.V + n.InviteCommand = inviteCmd.V + n.BouncerAddr = bouncerAddr.V + n.Auth.Account = account.V + n.Auth.Password = password.V + n.ProxyId = proxyId.V return &n, nil } @@ -111,7 +112,7 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error { func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, error) { queryBuilder := r.db.squirrel. - Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode"). + Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id"). From("irc_network"). Where(sq.Eq{"enabled": true}) @@ -131,22 +132,24 @@ func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, for rows.Next() { var net domain.IrcNetwork - var pass, nick, inviteCmd, bouncerAddr sql.NullString - var account, password sql.NullString - var tls sql.NullBool + var pass, nick, inviteCmd, bouncerAddr sql.Null[string] + var account, password sql.Null[string] + var tls sql.Null[bool] + var proxyId sql.Null[int64] - if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode); err != nil { + if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode, &net.UseProxy, &proxyId); err != nil { return nil, errors.Wrap(err, "error scanning row") } - net.TLS = tls.Bool - net.Pass = pass.String - net.Nick = nick.String - net.InviteCommand = inviteCmd.String - net.BouncerAddr = bouncerAddr.String + net.TLS = tls.V + net.Pass = pass.V + net.Nick = nick.V + net.InviteCommand = inviteCmd.V + net.BouncerAddr = bouncerAddr.V + net.Auth.Account = account.V + net.Auth.Password = password.V - net.Auth.Account = account.String - net.Auth.Password = password.String + net.ProxyId = proxyId.V networks = append(networks, net) } @@ -159,7 +162,7 @@ func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) { queryBuilder := r.db.squirrel. - Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode"). + Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id"). From("irc_network"). OrderBy("name ASC") @@ -179,22 +182,24 @@ func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) for rows.Next() { var net domain.IrcNetwork - var pass, nick, inviteCmd, bouncerAddr sql.NullString - var account, password sql.NullString - var tls sql.NullBool + var pass, nick, inviteCmd, bouncerAddr sql.Null[string] + var account, password sql.Null[string] + var tls sql.Null[bool] + var proxyId sql.Null[int64] - if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode); err != nil { + if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode, &net.UseProxy, &proxyId); err != nil { return nil, errors.Wrap(err, "error scanning row") } - net.TLS = tls.Bool - net.Pass = pass.String - net.Nick = nick.String - net.InviteCommand = inviteCmd.String - net.BouncerAddr = bouncerAddr.String + net.TLS = tls.V + net.Pass = pass.V + net.Nick = nick.V + net.InviteCommand = inviteCmd.V + net.BouncerAddr = bouncerAddr.V + net.Auth.Account = account.V + net.Auth.Password = password.V - net.Auth.Account = account.String - net.Auth.Password = password.String + net.ProxyId = proxyId.V networks = append(networks, net) } @@ -225,13 +230,13 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) { var channels []domain.IrcChannel for rows.Next() { var ch domain.IrcChannel - var pass sql.NullString + var pass sql.Null[string] if err := rows.Scan(&ch.ID, &ch.Name, &ch.Enabled, &pass); err != nil { return nil, errors.Wrap(err, "error scanning row") } - ch.Password = pass.String + ch.Password = pass.V channels = append(channels, ch) } @@ -244,7 +249,7 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) { func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcNetwork) (*domain.IrcNetwork, error) { queryBuilder := r.db.squirrel. - Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode"). + Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id"). From("irc_network"). Where(sq.Eq{"server": network.Server}). Where(sq.Eq{"port": network.Port}). @@ -263,11 +268,12 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN var net domain.IrcNetwork - var pass, nick, inviteCmd, bouncerAddr sql.NullString - var account, password sql.NullString - var tls sql.NullBool + var pass, nick, inviteCmd, bouncerAddr sql.Null[string] + var account, password sql.Null[string] + var tls sql.Null[bool] + var proxyId sql.Null[int64] - if err = row.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode); err != nil { + if err = row.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode, &net.UseProxy, &proxyId); err != nil { if errors.Is(err, sql.ErrNoRows) { // no result is not an error in our case return nil, nil @@ -276,29 +282,20 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN return nil, errors.Wrap(err, "error scanning row") } - net.TLS = tls.Bool - net.Pass = pass.String - net.Nick = nick.String - net.InviteCommand = inviteCmd.String - net.BouncerAddr = bouncerAddr.String - net.Auth.Account = account.String - net.Auth.Password = password.String + net.TLS = tls.V + net.Pass = pass.V + net.Nick = nick.V + net.InviteCommand = inviteCmd.V + net.BouncerAddr = bouncerAddr.V + net.Auth.Account = account.V + net.Auth.Password = password.V + + net.ProxyId = proxyId.V return &net, nil } func (r *IrcRepo) StoreNetwork(ctx context.Context, network *domain.IrcNetwork) error { - netName := toNullString(network.Name) - pass := toNullString(network.Pass) - nick := toNullString(network.Nick) - inviteCmd := toNullString(network.InviteCommand) - bouncerAddr := toNullString(network.BouncerAddr) - - account := toNullString(network.Auth.Account) - password := toNullString(network.Auth.Password) - - var retID int64 - queryBuilder := r.db.squirrel. Insert("irc_network"). Columns( @@ -319,60 +316,49 @@ func (r *IrcRepo) StoreNetwork(ctx context.Context, network *domain.IrcNetwork) ). Values( network.Enabled, - netName, + toNullString(network.Name), network.Server, network.Port, network.TLS, - pass, - nick, + toNullString(network.Pass), + toNullString(network.Nick), network.Auth.Mechanism, - account, - password, - inviteCmd, - bouncerAddr, + toNullString(network.Auth.Account), + toNullString(network.Auth.Password), + toNullString(network.InviteCommand), + toNullString(network.BouncerAddr), network.UseBouncer, network.BotMode, ). Suffix("RETURNING id"). RunWith(r.db.handler) - if err := queryBuilder.QueryRowContext(ctx).Scan(&retID); err != nil { + if err := queryBuilder.QueryRowContext(ctx).Scan(&network.ID); err != nil { return errors.Wrap(err, "error executing query") } - network.ID = retID - return nil } func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error { - netName := toNullString(network.Name) - pass := toNullString(network.Pass) - nick := toNullString(network.Nick) - inviteCmd := toNullString(network.InviteCommand) - bouncerAddr := toNullString(network.BouncerAddr) - - account := toNullString(network.Auth.Account) - password := toNullString(network.Auth.Password) - - var err error - queryBuilder := r.db.squirrel. Update("irc_network"). Set("enabled", network.Enabled). - Set("name", netName). + Set("name", toNullString(network.Name)). Set("server", network.Server). Set("port", network.Port). Set("tls", network.TLS). - Set("pass", pass). - Set("nick", nick). + Set("pass", toNullString(network.Pass)). + Set("nick", toNullString(network.Nick)). Set("auth_mechanism", network.Auth.Mechanism). - Set("auth_account", account). - Set("auth_password", password). - Set("invite_command", inviteCmd). - Set("bouncer_addr", bouncerAddr). + Set("auth_account", toNullString(network.Auth.Account)). + Set("auth_password", toNullString(network.Auth.Password)). + Set("invite_command", toNullString(network.InviteCommand)). + Set("bouncer_addr", toNullString(network.BouncerAddr)). Set("use_bouncer", network.UseBouncer). Set("bot_mode", network.BotMode). + Set("use_proxy", network.UseProxy). + Set("proxy_id", toNullInt64(network.ProxyId)). Set("updated_at", time.Now().Format(time.RFC3339)). Where(sq.Eq{"id": network.ID}) @@ -414,7 +400,6 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha for _, channel := range channels { // values - pass := toNullString(channel.Password) channelQueryBuilder := r.db.squirrel. Insert("irc_channel"). @@ -429,21 +414,17 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha channel.Enabled, true, channel.Name, - pass, + toNullString(channel.Password), networkID, ). Suffix("RETURNING id"). RunWith(tx) // returning - var retID int64 - - if err = channelQueryBuilder.QueryRowContext(ctx).Scan(&retID); err != nil { + if err = channelQueryBuilder.QueryRowContext(ctx).Scan(&channel.ID); err != nil { return errors.Wrap(err, "error executing query storeNetworkChannels") } - channel.ID = retID - //channelQuery, channelArgs, err := channelQueryBuilder.ToSql() //if err != nil { // r.log.Error().Stack().Err(err).Msg("irc.storeNetworkChannels: error building query") @@ -467,8 +448,6 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha } func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *domain.IrcChannel) error { - pass := toNullString(channel.Password) - if channel.ID != 0 { // update record channelQueryBuilder := r.db.squirrel. @@ -476,7 +455,7 @@ func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *do Set("enabled", channel.Enabled). Set("detached", channel.Detached). Set("name", channel.Name). - Set("password", pass). + Set("password", toNullString(channel.Password)). Where(sq.Eq{"id": channel.ID}) query, args, err := channelQueryBuilder.ToSql() @@ -501,21 +480,17 @@ func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *do channel.Enabled, true, channel.Name, - pass, + toNullString(channel.Password), networkID, ). Suffix("RETURNING id"). RunWith(r.db.handler) // returning - var retID int64 - - if err := queryBuilder.QueryRowContext(ctx).Scan(&retID); err != nil { + if err := queryBuilder.QueryRowContext(ctx).Scan(&channel.ID); err != nil { return errors.Wrap(err, "error executing query") } - channel.ID = retID - //channelQuery, channelArgs, err := channelQueryBuilder.ToSql() //if err != nil { // r.log.Error().Stack().Err(err).Msg("irc.storeChannel: error building query") @@ -536,15 +511,13 @@ func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *do } func (r *IrcRepo) UpdateChannel(channel *domain.IrcChannel) error { - pass := toNullString(channel.Password) - // update record channelQueryBuilder := r.db.squirrel. Update("irc_channel"). Set("enabled", channel.Enabled). Set("detached", channel.Detached). Set("name", channel.Name). - Set("password", pass). + Set("password", toNullString(channel.Password)). Where(sq.Eq{"id": channel.ID}) query, args, err := channelQueryBuilder.ToSql() @@ -561,7 +534,6 @@ func (r *IrcRepo) UpdateChannel(channel *domain.IrcChannel) error { } func (r *IrcRepo) UpdateInviteCommand(networkID int64, invite string) error { - // update record channelQueryBuilder := r.db.squirrel. Update("irc_network"). diff --git a/internal/database/postgres_migrate.go b/internal/database/postgres_migrate.go index cf2d51c..c2e246a 100644 --- a/internal/database/postgres_migrate.go +++ b/internal/database/postgres_migrate.go @@ -14,6 +14,20 @@ CREATE TABLE users UNIQUE (username) ); +CREATE TABLE proxy +( + id SERIAL PRIMARY KEY, + enabled BOOLEAN, + name TEXT NOT NULL, + type TEXT NOT NULL, + addr TEXT NOT NULL, + auth_user TEXT, + auth_pass TEXT, + timeout INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + CREATE TABLE indexer ( id SERIAL PRIMARY KEY, @@ -24,8 +38,11 @@ CREATE TABLE indexer enabled BOOLEAN, name TEXT NOT NULL, settings TEXT, + use_proxy BOOLEAN DEFAULT FALSE, + proxy_id INTEGER, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL, UNIQUE (identifier) ); @@ -51,8 +68,11 @@ CREATE TABLE irc_network bot_mode BOOLEAN DEFAULT FALSE, connected BOOLEAN, connected_since TIMESTAMP, + use_proxy BOOLEAN DEFAULT FALSE, + proxy_id INTEGER, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL, UNIQUE (server, port, nick) ); @@ -900,5 +920,39 @@ ADD COLUMN months TEXT; ALTER TABLE filter ADD COLUMN days TEXT; +`, + `CREATE TABLE proxy +( + id SERIAL PRIMARY KEY, + enabled BOOLEAN, + name TEXT NOT NULL, + type TEXT NOT NULL, + addr TEXT NOT NULL, + auth_user TEXT, + auth_pass TEXT, + timeout INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +ALTER TABLE indexer + ADD COLUMN proxy_id INTEGER; + +ALTER TABLE indexer + ADD COLUMN use_proxy BOOLEAN DEFAULT FALSE; + +ALTER TABLE indexer + ADD FOREIGN KEY (proxy_id) REFERENCES proxy + ON DELETE SET NULL; + +ALTER TABLE irc_network + ADD COLUMN proxy_id INTEGER; + +ALTER TABLE irc_network + ADD COLUMN use_proxy BOOLEAN DEFAULT FALSE; + +ALTER TABLE irc_network + ADD FOREIGN KEY (proxy_id) REFERENCES proxy + ON DELETE SET NULL; `, } diff --git a/internal/database/proxy.go b/internal/database/proxy.go new file mode 100644 index 0000000..456e733 --- /dev/null +++ b/internal/database/proxy.go @@ -0,0 +1,265 @@ +// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +package database + +import ( + "context" + "database/sql" + "time" + + "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/internal/logger" + "github.com/autobrr/autobrr/pkg/errors" + + sq "github.com/Masterminds/squirrel" + "github.com/rs/zerolog" +) + +type ProxyRepo struct { + log zerolog.Logger + db *DB +} + +func NewProxyRepo(log logger.Logger, db *DB) domain.ProxyRepo { + return &ProxyRepo{ + log: log.With().Str("repo", "proxy").Logger(), + db: db, + } +} + +func (r *ProxyRepo) Store(ctx context.Context, p *domain.Proxy) error { + queryBuilder := r.db.squirrel. + Insert("proxy"). + Columns( + "enabled", + "name", + "type", + "addr", + "auth_user", + "auth_pass", + "timeout", + ). + Values( + p.Enabled, + p.Name, + p.Type, + toNullString(p.Addr), + toNullString(p.User), + toNullString(p.Pass), + p.Timeout, + ). + Suffix("RETURNING id"). + RunWith(r.db.handler) + + var retID int64 + err := queryBuilder.QueryRowContext(ctx).Scan(&retID) + if err != nil { + return errors.Wrap(err, "error executing query") + } + + p.ID = retID + + return nil +} + +func (r *ProxyRepo) Update(ctx context.Context, p *domain.Proxy) error { + queryBuilder := r.db.squirrel. + Update("proxy"). + Set("enabled", p.Enabled). + Set("name", p.Name). + Set("type", p.Type). + Set("addr", p.Addr). + Set("auth_user", toNullString(p.User)). + Set("auth_pass", toNullString(p.Pass)). + Set("timeout", p.Timeout). + Set("updated_at", time.Now().Format(time.RFC3339)). + Where(sq.Eq{"id": p.ID}) + + query, args, err := queryBuilder.ToSql() + if err != nil { + return errors.Wrap(err, "error building query") + } + + // update record + res, err := r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + return errors.Wrap(err, "error executing query") + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return errors.Wrap(err, "error getting affected rows") + } + + if rowsAffected == 0 { + return domain.ErrUpdateFailed + } + + return err +} + +func (r *ProxyRepo) List(ctx context.Context) ([]domain.Proxy, error) { + queryBuilder := r.db.squirrel. + Select( + "id", + "enabled", + "name", + "type", + "addr", + "auth_user", + "auth_pass", + "timeout", + ). + From("proxy"). + OrderBy("name ASC") + + query, args, err := queryBuilder.ToSql() + if err != nil { + return nil, errors.Wrap(err, "error building query") + } + + rows, err := r.db.handler.QueryContext(ctx, query, args...) + if err != nil { + return nil, errors.Wrap(err, "error executing query") + } + + defer rows.Close() + + proxies := make([]domain.Proxy, 0) + for rows.Next() { + var proxy domain.Proxy + + var user, pass sql.NullString + + if err := rows.Scan(&proxy.ID, &proxy.Enabled, &proxy.Name, &proxy.Type, &proxy.Addr, &user, &pass, &proxy.Timeout); err != nil { + return nil, errors.Wrap(err, "error scanning row") + } + + proxy.User = user.String + proxy.Pass = pass.String + + proxies = append(proxies, proxy) + } + + err = rows.Err() + if err != nil { + return nil, errors.Wrap(err, "error row") + } + + return proxies, nil +} + +func (r *ProxyRepo) Delete(ctx context.Context, id int64) error { + tx, err := r.db.BeginTx(ctx, nil) + if err != nil { + return errors.Wrap(err, "error begin transaction") + } + + defer tx.Rollback() + + queryBuilder := r.db.squirrel. + Delete("proxy"). + Where(sq.Eq{"id": id}) + + query, args, err := queryBuilder.ToSql() + if err != nil { + return errors.Wrap(err, "error building query") + } + + res, err := tx.ExecContext(ctx, query, args...) + if err != nil { + return errors.Wrap(err, "error executing query") + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return errors.Wrap(err, "error getting affected rows") + } + + if rowsAffected == 0 { + return domain.ErrDeleteFailed + } + + err = tx.Commit() + if err != nil { + return errors.Wrap(err, "error commit deleting proxy") + } + + return nil +} + +func (r *ProxyRepo) FindByID(ctx context.Context, id int64) (*domain.Proxy, error) { + queryBuilder := r.db.squirrel. + Select( + "id", + "enabled", + "name", + "type", + "addr", + "auth_user", + "auth_pass", + "timeout", + ). + From("proxy"). + OrderBy("name ASC"). + Where(sq.Eq{"id": id}) + + query, args, err := queryBuilder.ToSql() + if err != nil { + return nil, errors.Wrap(err, "error building query") + } + + row := r.db.handler.QueryRowContext(ctx, query, args...) + if err := row.Err(); err != nil { + return nil, errors.Wrap(err, "error executing query") + } + + var proxy domain.Proxy + + var user, pass sql.NullString + + err = row.Scan(&proxy.ID, &proxy.Enabled, &proxy.Name, &proxy.Type, &proxy.Addr, &user, &pass, &proxy.Timeout) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, domain.ErrRecordNotFound + } + + return nil, errors.Wrap(err, "error scanning row") + } + + proxy.User = user.String + proxy.Pass = pass.String + + return &proxy, nil +} + +func (r *ProxyRepo) ToggleEnabled(ctx context.Context, id int64, enabled bool) error { + queryBuilder := r.db.squirrel. + Update("proxy"). + Set("enabled", enabled). + Set("updated_at", time.Now().Format(time.RFC3339)). + Where(sq.Eq{"id": id}) + + query, args, err := queryBuilder.ToSql() + if err != nil { + return errors.Wrap(err, "error building query") + } + + // update record + res, err := r.db.handler.ExecContext(ctx, query, args...) + if err != nil { + return errors.Wrap(err, "error executing query") + } + + rowsAffected, err := res.RowsAffected() + if err != nil { + return errors.Wrap(err, "error getting affected rows") + } + + if rowsAffected == 0 { + return domain.ErrUpdateFailed + } + + return nil +} diff --git a/internal/database/proxy_test.go b/internal/database/proxy_test.go new file mode 100644 index 0000000..d9b2250 --- /dev/null +++ b/internal/database/proxy_test.go @@ -0,0 +1,220 @@ +// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +//go:build integration + +package database + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/autobrr/autobrr/internal/domain" + + "github.com/stretchr/testify/assert" +) + +func getMockProxy() *domain.Proxy { + return &domain.Proxy{ + //ID: 0, + Name: "Proxy", + Enabled: true, + Type: domain.ProxyTypeSocks5, + Addr: "socks5://127.0.0.1:1080", + User: "", + Pass: "", + Timeout: 0, + } +} + +func TestProxyRepo_Store(t *testing.T) { + for dbType, db := range testDBs { + log := setupLoggerForTest() + repo := NewProxyRepo(log, db) + mockData := getMockProxy() + + t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) { + // Setup + err := repo.Store(context.Background(), mockData) + assert.NoError(t, err) + + proxies, err := repo.List(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, proxies) + assert.Equal(t, mockData.Name, proxies[0].Name) + + // Cleanup + _ = repo.Delete(context.Background(), mockData.ID) + }) + + t.Run(fmt.Sprintf("Store_Fails_With_Missing_or_empty_fields [%s]", dbType), func(t *testing.T) { + mockData := domain.Proxy{} + err := repo.Store(context.Background(), &mockData) + assert.Error(t, err) + + proxies, err := repo.List(context.Background()) + assert.NoError(t, err) + assert.Empty(t, proxies) + //assert.Nil(t, proxies) + + // Cleanup + // No cleanup needed + }) + + t.Run(fmt.Sprintf("Store_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) { + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond) + defer cancel() + + err := repo.Store(ctx, mockData) + assert.Error(t, err) + }) + } +} + +func TestProxyRepo_Update(t *testing.T) { + for dbType, db := range testDBs { + log := setupLoggerForTest() + repo := NewProxyRepo(log, db) + mockData := getMockProxy() + + t.Run(fmt.Sprintf("Update_Succeeds [%s]", dbType), func(t *testing.T) { + // Setup + err := repo.Store(context.Background(), mockData) + assert.NoError(t, err) + + // Update mockData + updatedProxy := mockData + updatedProxy.Name = "Updated Proxy" + updatedProxy.Enabled = false + + // Execute + err = repo.Update(context.Background(), updatedProxy) + assert.NoError(t, err) + + proxies, err := repo.List(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, proxies) + assert.Equal(t, "Updated Proxy", proxies[0].Name) + assert.Equal(t, false, proxies[0].Enabled) + + // Cleanup + _ = repo.Delete(context.Background(), proxies[0].ID) + }) + + t.Run(fmt.Sprintf("Update_Fails_Invalid_ID [%s]", dbType), func(t *testing.T) { + mockData.ID = -1 + err := repo.Update(context.Background(), mockData) + assert.Error(t, err) + assert.ErrorIs(t, err, domain.ErrUpdateFailed) + }) + } +} + +func TestProxyRepo_Delete(t *testing.T) { + for dbType, db := range testDBs { + log := setupLoggerForTest() + repo := NewProxyRepo(log, db) + mockData := getMockProxy() + + t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) { + // Setup + err := repo.Store(context.Background(), mockData) + assert.NoError(t, err) + + proxies, err := repo.List(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, proxies) + assert.Equal(t, mockData.Name, proxies[0].Name) + + // Execute + err = repo.Delete(context.Background(), proxies[0].ID) + assert.NoError(t, err) + + // Verify that the proxy is deleted and return error ErrRecordNotFound + proxy, err := repo.FindByID(context.Background(), proxies[0].ID) + assert.ErrorIs(t, err, domain.ErrRecordNotFound) + assert.Nil(t, proxy) + }) + + t.Run(fmt.Sprintf("Delete_Fails_No_Record [%s]", dbType), func(t *testing.T) { + err := repo.Delete(context.Background(), 9999) + assert.Error(t, err) + assert.ErrorIs(t, err, domain.ErrDeleteFailed) + }) + } +} + +func TestProxyRepo_ToggleEnabled(t *testing.T) { + for dbType, db := range testDBs { + log := setupLoggerForTest() + repo := NewProxyRepo(log, db) + mockData := getMockProxy() + + t.Run(fmt.Sprintf("ToggleEnabled_Succeeds [%s]", dbType), func(t *testing.T) { + // Setup + err := repo.Store(context.Background(), mockData) + assert.NoError(t, err) + + proxies, err := repo.List(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, proxies) + assert.Equal(t, true, proxies[0].Enabled) + + // Execute + err = repo.ToggleEnabled(context.Background(), mockData.ID, false) + assert.NoError(t, err) + + // Verify that the proxy is updated + proxy, err := repo.FindByID(context.Background(), proxies[0].ID) + assert.NoError(t, err) + assert.NotNil(t, proxy) + assert.Equal(t, false, proxy.Enabled) + + // Cleanup + _ = repo.Delete(context.Background(), proxies[0].ID) + }) + + t.Run(fmt.Sprintf("ToggleEnabled_Fails_Invalid_ID [%s]", dbType), func(t *testing.T) { + err := repo.ToggleEnabled(context.Background(), -1, false) + assert.Error(t, err) + assert.ErrorIs(t, err, domain.ErrUpdateFailed) + }) + } +} + +func TestProxyRepo_FindByID(t *testing.T) { + for dbType, db := range testDBs { + log := setupLoggerForTest() + repo := NewProxyRepo(log, db) + mockData := getMockProxy() + + t.Run(fmt.Sprintf("FindByID_Succeeds [%s]", dbType), func(t *testing.T) { + // Setup + err := repo.Store(context.Background(), mockData) + assert.NoError(t, err) + + proxies, err := repo.List(context.Background()) + assert.NoError(t, err) + assert.NotNil(t, proxies) + + // Execute + proxy, err := repo.FindByID(context.Background(), proxies[0].ID) + assert.NoError(t, err) + assert.NotNil(t, proxy) + assert.Equal(t, proxies[0].ID, proxy.ID) + + // Cleanup + _ = repo.Delete(context.Background(), proxies[0].ID) + }) + + t.Run(fmt.Sprintf("FindByID_Fails_Invalid_ID [%s]", dbType), func(t *testing.T) { + // Test using an invalid ID + proxy, err := repo.FindByID(context.Background(), -1) + assert.ErrorIs(t, err, domain.ErrRecordNotFound) // should return an error + assert.Nil(t, proxy) // should be nil + }) + + } +} diff --git a/internal/database/sqlite_migrate.go b/internal/database/sqlite_migrate.go index 79cad4f..26084ff 100644 --- a/internal/database/sqlite_migrate.go +++ b/internal/database/sqlite_migrate.go @@ -14,6 +14,20 @@ CREATE TABLE users UNIQUE (username) ); +CREATE TABLE proxy +( + id INTEGER PRIMARY KEY, + enabled BOOLEAN, + name TEXT NOT NULL, + type TEXT NOT NULL, + addr TEXT NOT NULL, + auth_user TEXT, + auth_pass TEXT, + timeout INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + CREATE TABLE indexer ( id INTEGER PRIMARY KEY, @@ -24,8 +38,11 @@ CREATE TABLE indexer enabled BOOLEAN, name TEXT NOT NULL, settings TEXT, + use_proxy BOOLEAN DEFAULT FALSE, + proxy_id INTEGER, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL, UNIQUE (identifier) ); @@ -51,8 +68,11 @@ CREATE TABLE irc_network bot_mode BOOLEAN DEFAULT FALSE, connected BOOLEAN, connected_since TIMESTAMP, + use_proxy BOOLEAN DEFAULT FALSE, + proxy_id INTEGER, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL, UNIQUE (server, port, nick) ); @@ -1538,5 +1558,37 @@ ADD COLUMN months TEXT; ALTER TABLE filter ADD COLUMN days TEXT; +`, + `CREATE TABLE proxy +( + id INTEGER PRIMARY KEY, + enabled BOOLEAN, + name TEXT NOT NULL, + type TEXT NOT NULL, + addr TEXT NOT NULL, + auth_user TEXT, + auth_pass TEXT, + timeout INTEGER, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP +); + +ALTER TABLE indexer + ADD proxy_id INTEGER + CONSTRAINT indexer_proxy_id_fk + REFERENCES proxy(id) + ON DELETE SET NULL; + +ALTER TABLE indexer + ADD use_proxy BOOLEAN DEFAULT FALSE; + +ALTER TABLE irc_network + ADD use_proxy BOOLEAN DEFAULT FALSE; + +ALTER TABLE irc_network + ADD proxy_id INTEGER + CONSTRAINT irc_network_proxy_id_fk + REFERENCES proxy(id) + ON DELETE SET NULL; `, } diff --git a/internal/database/utils.go b/internal/database/utils.go index 2b34f63..c5d6b2e 100644 --- a/internal/database/utils.go +++ b/internal/database/utils.go @@ -16,10 +16,10 @@ func dataSourceName(configPath string, name string) string { return name } -func toNullString(s string) sql.NullString { - return sql.NullString{ - String: s, - Valid: s != "", +func toNullString(s string) sql.Null[string] { + return sql.Null[string]{ + V: s, + Valid: s != "", } } diff --git a/internal/domain/action.go b/internal/domain/action.go index c171292..a4ba728 100644 --- a/internal/domain/action.go +++ b/internal/domain/action.go @@ -5,7 +5,6 @@ package domain import ( "context" - "os" "strings" "github.com/autobrr/autobrr/pkg/errors" @@ -60,43 +59,69 @@ type Action struct { Client *DownloadClient `json:"client,omitempty"` } -// ParseMacros parse all macros on action -func (a *Action) ParseMacros(release *Release) error { - var err error - +// CheckMacrosNeedTorrentTmpFile check if macros needs torrent downloaded +func (a *Action) CheckMacrosNeedTorrentTmpFile(release *Release) bool { if release.TorrentTmpFile == "" && - (strings.Contains(a.ExecArgs, "TorrentPathName") || strings.Contains(a.ExecArgs, "TorrentDataRawBytes") || - strings.Contains(a.WebhookData, "TorrentPathName") || strings.Contains(a.WebhookData, "TorrentDataRawBytes") || - strings.Contains(a.SavePath, "TorrentPathName") || a.Type == ActionTypeWatchFolder) { - if err := release.DownloadTorrentFile(); err != nil { - return errors.Wrap(err, "webhook: could not download torrent file for release: %v", release.TorrentName) - } + (strings.Contains(a.ExecArgs, "TorrentPathName") || + strings.Contains(a.ExecArgs, "TorrentDataRawBytes") || + strings.Contains(a.WebhookData, "TorrentPathName") || + strings.Contains(a.WebhookData, "TorrentDataRawBytes") || + strings.Contains(a.SavePath, "TorrentPathName") || + a.Type == ActionTypeWatchFolder) { + return true } + return false +} + +func (a *Action) CheckMacrosNeedRawDataBytes(release *Release) bool { // if webhook data contains TorrentDataRawBytes, lets read the file into bytes we can then use in the macro if len(release.TorrentDataRawBytes) == 0 && (strings.Contains(a.ExecArgs, "TorrentDataRawBytes") || strings.Contains(a.WebhookData, "TorrentDataRawBytes") || a.Type == ActionTypeWatchFolder) { - t, err := os.ReadFile(release.TorrentTmpFile) - if err != nil { - return errors.Wrap(err, "could not read torrent file: %v", release.TorrentTmpFile) - } - - release.TorrentDataRawBytes = t + return true } + return false +} + +// ParseMacros parse all macros on action +func (a *Action) ParseMacros(release *Release) error { + var err error + m := NewMacro(*release) a.ExecArgs, err = m.Parse(a.ExecArgs) - a.WatchFolder, err = m.Parse(a.WatchFolder) - a.Category, err = m.Parse(a.Category) - a.Tags, err = m.Parse(a.Tags) - a.Label, err = m.Parse(a.Label) - a.SavePath, err = m.Parse(a.SavePath) - a.WebhookData, err = m.Parse(a.WebhookData) - if err != nil { - return errors.Wrap(err, "could not parse macros for action: %v", a.Name) + return errors.Wrap(err, "could not parse exec args") + } + + a.WatchFolder, err = m.Parse(a.WatchFolder) + if err != nil { + return errors.Wrap(err, "could not parse watch folder") + } + + a.Category, err = m.Parse(a.Category) + if err != nil { + return errors.Wrap(err, "could not parse category") + } + + a.Tags, err = m.Parse(a.Tags) + if err != nil { + return errors.Wrap(err, "could not parse tags") + } + + a.Label, err = m.Parse(a.Label) + if err != nil { + return errors.Wrap(err, "could not parse label") + } + a.SavePath, err = m.Parse(a.SavePath) + if err != nil { + return errors.Wrap(err, "could not parse save_path") + } + a.WebhookData, err = m.Parse(a.WebhookData) + if err != nil { + return errors.Wrap(err, "could not parse webhook_data") } return nil diff --git a/internal/domain/error.go b/internal/domain/error.go index 4ba682b..cab573c 100644 --- a/internal/domain/error.go +++ b/internal/domain/error.go @@ -3,8 +3,14 @@ package domain -import "database/sql" +import ( + "database/sql" + + "github.com/autobrr/autobrr/pkg/errors" +) var ( ErrRecordNotFound = sql.ErrNoRows + ErrUpdateFailed = errors.New("update failed") + ErrDeleteFailed = errors.New("delete failed") ) diff --git a/internal/domain/feed.go b/internal/domain/feed.go index 1a4eebd..29a4b23 100644 --- a/internal/domain/feed.go +++ b/internal/domain/feed.go @@ -53,6 +53,11 @@ type Feed struct { LastRun time.Time `json:"last_run"` LastRunData string `json:"last_run_data"` NextRun time.Time `json:"next_run"` + + // belongs to Indexer + ProxyID int64 + UseProxy bool + Proxy *Proxy } type FeedSettingsJSON struct { diff --git a/internal/domain/indexer.go b/internal/domain/indexer.go index 7a516c8..63d410d 100644 --- a/internal/domain/indexer.go +++ b/internal/domain/indexer.go @@ -34,6 +34,9 @@ type Indexer struct { Enabled bool `json:"enabled"` Implementation string `json:"implementation"` BaseURL string `json:"base_url,omitempty"` + UseProxy bool `json:"use_proxy"` + Proxy *Proxy `json:"proxy"` + ProxyID int64 `json:"proxy_id"` Settings map[string]string `json:"settings,omitempty"` } @@ -66,6 +69,8 @@ type IndexerDefinition struct { Protocol string `json:"protocol"` URLS []string `json:"urls"` Supports []string `json:"supports"` + UseProxy bool `json:"use_proxy"` + ProxyID int64 `json:"proxy_id"` Settings []IndexerSetting `json:"settings,omitempty"` SettingsMap map[string]string `json:"-"` IRC *IndexerIRC `json:"irc,omitempty"` diff --git a/internal/domain/irc.go b/internal/domain/irc.go index 7a67c92..2775230 100644 --- a/internal/domain/irc.go +++ b/internal/domain/irc.go @@ -48,6 +48,9 @@ type IrcNetwork struct { InviteCommand string `json:"invite_command"` UseBouncer bool `json:"use_bouncer"` BouncerAddr string `json:"bouncer_addr"` + UseProxy bool `json:"use_proxy"` + ProxyId int64 `json:"proxy_id"` + Proxy *Proxy `json:"proxy"` BotMode bool `json:"bot_mode"` Channels []IrcChannel `json:"channels"` Connected bool `json:"connected"` @@ -70,6 +73,9 @@ type IrcNetworkWithHealth struct { BotMode bool `json:"bot_mode"` CurrentNick string `json:"current_nick"` PreferredNick string `json:"preferred_nick"` + UseProxy bool `json:"use_proxy"` + ProxyId int64 `json:"proxy_id"` + Proxy *Proxy `json:"proxy"` Channels []ChannelWithHealth `json:"channels"` Connected bool `json:"connected"` ConnectedSince time.Time `json:"connected_since"` diff --git a/internal/domain/macros.go b/internal/domain/macros.go index 8adb0de..e1a7225 100644 --- a/internal/domain/macros.go +++ b/internal/domain/macros.go @@ -163,6 +163,8 @@ func (m Macro) Parse(text string) (string, error) { return "", nil } + // TODO implement template cache + // setup template tmpl, err := template.New("macro").Funcs(sprig.TxtFuncMap()).Parse(text) if err != nil { diff --git a/internal/domain/proxy.go b/internal/domain/proxy.go new file mode 100644 index 0000000..7a5ad30 --- /dev/null +++ b/internal/domain/proxy.go @@ -0,0 +1,78 @@ +// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +package domain + +import ( + "context" + "net/url" + + "github.com/autobrr/autobrr/pkg/errors" +) + +type ProxyRepo interface { + Store(ctx context.Context, p *Proxy) error + Update(ctx context.Context, p *Proxy) error + List(ctx context.Context) ([]Proxy, error) + Delete(ctx context.Context, id int64) error + FindByID(ctx context.Context, id int64) (*Proxy, error) + ToggleEnabled(ctx context.Context, id int64, enabled bool) error +} + +type Proxy struct { + ID int64 `json:"id"` + Name string `json:"name"` + Enabled bool `json:"enabled"` + Type ProxyType `json:"type"` + Addr string `json:"addr"` + User string `json:"user"` + Pass string `json:"pass"` + Timeout int `json:"timeout"` +} + +type ProxyType string + +const ( + ProxyTypeSocks5 = "SOCKS5" +) + +func (p Proxy) ValidProxyType() bool { + if p.Type == ProxyTypeSocks5 { + return true + } + + return false +} + +func (p Proxy) Validate() error { + if !p.ValidProxyType() { + return errors.New("invalid proxy type: %s", p.Type) + } + + if err := ValidateProxyAddr(p.Addr); err != nil { + return err + } + + if p.Name == "" { + return errors.New("name is required") + } + + return nil +} + +func ValidateProxyAddr(addr string) error { + if addr == "" { + return errors.New("addr is required") + } + + proxyUrl, err := url.Parse(addr) + if err != nil { + return errors.Wrap(err, "could not parse proxy url: %s", addr) + } + + if proxyUrl.Scheme != "socks5" && proxyUrl.Scheme != "socks5h" { + return errors.New("proxy url scheme must be socks5 or socks5h") + } + + return nil +} diff --git a/internal/domain/release.go b/internal/domain/release.go index aaa607b..9cdba2e 100644 --- a/internal/domain/release.go +++ b/internal/domain/release.go @@ -356,8 +356,6 @@ func (r *Release) ParseString(title string) { r.ParseReleaseTagsString(r.ReleaseTags) } -var ErrUnrecoverableError = errors.New("unrecoverable error") - func (r *Release) ParseReleaseTagsString(tags string) { cleanTags := CleanReleaseTags(tags) t := ParseReleaseTagString(cleanTags) @@ -432,10 +430,6 @@ func (r *Release) DownloadTorrentFileCtx(ctx context.Context) error { return r.downloadTorrentFile(ctx) } -func (r *Release) DownloadTorrentFile() error { - return r.downloadTorrentFile(context.Background()) -} - func (r *Release) downloadTorrentFile(ctx context.Context) error { if r.HasMagnetUri() { return errors.New("downloading magnet links is not supported: %s", r.MagnetURI) @@ -592,7 +586,7 @@ func (r *Release) downloadTorrentFile(ctx context.Context) error { } func (r *Release) CleanupTemporaryFiles() { - if len(r.TorrentTmpFile) == 0 { + if r.TorrentTmpFile == "" { return } @@ -600,54 +594,15 @@ func (r *Release) CleanupTemporaryFiles() { r.TorrentTmpFile = "" } -// HasMagnetUri check uf MagnetURI is set or empty +// HasMagnetUri check uf MagnetURI is set and valid or empty func (r *Release) HasMagnetUri() bool { - return r.MagnetURI != "" + if r.MagnetURI != "" && strings.HasPrefix(r.MagnetURI, MagnetURIPrefix) { + return true + } + return false } -func (r *Release) ResolveMagnetUri(ctx context.Context) error { - if r.MagnetURI == "" { - return nil - } else if strings.HasPrefix(r.MagnetURI, "magnet:?") { - return nil - } - - req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.MagnetURI, nil) - if err != nil { - return errors.Wrap(err, "could not build request to resolve magnet uri") - } - - req.Header.Set("Content-Type", "application/json") - req.Header.Set("User-Agent", "autobrr") - - client := &http.Client{ - Timeout: time.Second * 45, - Transport: sharedhttp.MagnetTransport, - } - - res, err := client.Do(req) - if err != nil { - return errors.Wrap(err, "could not make request to resolve magnet uri") - } - - defer res.Body.Close() - - if res.StatusCode != http.StatusOK { - return errors.New("unexpected status code: %d", res.StatusCode) - } - - body, err := io.ReadAll(res.Body) - if err != nil { - return errors.Wrap(err, "could not read response body") - } - - magnet := string(body) - if magnet != "" { - r.MagnetURI = magnet - } - - return nil -} +const MagnetURIPrefix = "magnet:?" func (r *Release) addRejection(reason string) { r.Rejections = append(r.Rejections, reason) diff --git a/internal/domain/release_download_test.go b/internal/domain/release_download_test.go index aa9634b..234593a 100644 --- a/internal/domain/release_download_test.go +++ b/internal/domain/release_download_test.go @@ -6,6 +6,7 @@ package domain import ( + "context" "fmt" "net/http" "net/http/httptest" @@ -290,7 +291,7 @@ func TestRelease_DownloadTorrentFile(t *testing.T) { Filter: tt.fields.Filter, ActionStatus: tt.fields.ActionStatus, } - err := r.DownloadTorrentFile() + err := r.DownloadTorrentFileCtx(context.Background()) if err == nil && tt.wantErr { fmt.Println("error") } diff --git a/internal/feed/client.go b/internal/feed/client.go index ba578d4..86eac71 100644 --- a/internal/feed/client.go +++ b/internal/feed/client.go @@ -42,10 +42,23 @@ func NewFeedParser(timeout time.Duration, cookie string) *RSSParser { } c.http.Timeout = timeout + c.parser.Client = httpClient return c } +func (c *RSSParser) WithHTTPClient(client *http.Client) { + httpClient := client + if client.Jar == nil { + jarOptions := &cookiejar.Options{PublicSuffixList: publicsuffix.List} + jar, _ := cookiejar.New(jarOptions) + httpClient.Jar = jar + } + + c.http = httpClient + c.parser.Client = httpClient +} + func (c *RSSParser) ParseURLWithContext(ctx context.Context, feedURL string) (feed *gofeed.Feed, err error) { req, err := http.NewRequestWithContext(ctx, http.MethodGet, feedURL, nil) if err != nil { diff --git a/internal/feed/newznab.go b/internal/feed/newznab.go index 67b390f..c2a3f49 100644 --- a/internal/feed/newznab.go +++ b/internal/feed/newznab.go @@ -10,6 +10,7 @@ import ( "time" "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/internal/proxy" "github.com/autobrr/autobrr/internal/release" "github.com/autobrr/autobrr/internal/scheduler" "github.com/autobrr/autobrr/pkg/errors" @@ -129,6 +130,18 @@ func (j *NewznabJob) process(ctx context.Context) error { } func (j *NewznabJob) getFeed(ctx context.Context) ([]newznab.FeedItem, error) { + // add proxy if enabled and exists + if j.Feed.UseProxy && j.Feed.Proxy != nil { + proxyClient, err := proxy.GetProxiedHTTPClient(j.Feed.Proxy) + if err != nil { + return nil, errors.Wrap(err, "could not get proxy client") + } + + j.Client.WithHTTPClient(proxyClient) + + j.Log.Debug().Msgf("using proxy %s for feed %s", j.Feed.Proxy.Name, j.Feed.Name) + } + // get feed feed, err := j.Client.GetFeed(ctx) if err != nil { @@ -156,36 +169,34 @@ func (j *NewznabJob) getFeed(ctx context.Context) ([]newznab.FeedItem, error) { // set ttl to 1 month ttl := time.Now().AddDate(0, 1, 0) - for _, i := range feed.Channel.Items { - i := i - - if i.GUID == "" { + for _, item := range feed.Channel.Items { + if item.GUID == "" { j.Log.Error().Msgf("missing GUID from feed: %s", j.Feed.Name) continue } - exists, err := j.CacheRepo.Exists(j.Feed.ID, i.GUID) + exists, err := j.CacheRepo.Exists(j.Feed.ID, item.GUID) if err != nil { j.Log.Error().Err(err).Msg("could not check if item exists") continue } if exists { - j.Log.Trace().Msgf("cache item exists, skipping release: %s", i.Title) + j.Log.Trace().Msgf("cache item exists, skipping release: %s", item.Title) continue } - j.Log.Debug().Msgf("found new release: %s", i.Title) + j.Log.Debug().Msgf("found new release: %s", item.Title) toCache = append(toCache, domain.FeedCacheItem{ FeedId: strconv.Itoa(j.Feed.ID), - Key: i.GUID, - Value: []byte(i.Title), + Key: item.GUID, + Value: []byte(item.Title), TTL: ttl, }) // only append if we successfully added to cache - items = append(items, *i) + items = append(items, *item) } if len(toCache) > 0 { diff --git a/internal/feed/rss.go b/internal/feed/rss.go index 000f588..ace6739 100644 --- a/internal/feed/rss.go +++ b/internal/feed/rss.go @@ -13,6 +13,7 @@ import ( "time" "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/internal/proxy" "github.com/autobrr/autobrr/internal/release" "github.com/autobrr/autobrr/pkg/errors" @@ -93,7 +94,6 @@ func (j *RSSJob) process(ctx context.Context) error { releases := make([]*domain.Release, 0) for _, item := range items { - item := item j.Log.Debug().Msgf("item: %v", item.Title) rls := j.processItem(item) @@ -139,7 +139,7 @@ func (j *RSSJob) processItem(item *gofeed.Item) *domain.Release { } if j.Feed.Settings != nil && j.Feed.Settings.DownloadType == domain.FeedDownloadTypeMagnet { - if !strings.HasPrefix(rls.MagnetURI, "magnet:?") && strings.HasPrefix(e.URL, "magnet:?") { + if !strings.HasPrefix(rls.MagnetURI, domain.MagnetURIPrefix) && strings.HasPrefix(e.URL, domain.MagnetURIPrefix) { rls.MagnetURI = e.URL rls.DownloadURL = "" } @@ -232,7 +232,20 @@ func (j *RSSJob) getFeed(ctx context.Context) (items []*gofeed.Item, err error) ctx, cancel := context.WithTimeout(ctx, j.Timeout) defer cancel() - feed, err := NewFeedParser(j.Timeout, j.Feed.Cookie).ParseURLWithContext(ctx, j.URL) + feedParser := NewFeedParser(j.Timeout, j.Feed.Cookie) + + if j.Feed.UseProxy && j.Feed.Proxy != nil { + proxyClient, err := proxy.GetProxiedHTTPClient(j.Feed.Proxy) + if err != nil { + return nil, errors.Wrap(err, "could not get proxy client") + } + + feedParser.WithHTTPClient(proxyClient) + + j.Log.Debug().Msgf("using proxy %s for feed %s", j.Feed.Proxy.Name, j.Feed.Name) + } + + feed, err := feedParser.ParseURLWithContext(ctx, j.URL) if err != nil { return nil, errors.Wrap(err, "error fetching rss feed items") } @@ -257,9 +270,7 @@ func (j *RSSJob) getFeed(ctx context.Context) (items []*gofeed.Item, err error) // set ttl to 1 month ttl := time.Now().AddDate(0, 1, 0) - for _, i := range feed.Items { - item := i - + for _, item := range feed.Items { key := item.GUID if len(key) == 0 { key = item.Link @@ -278,12 +289,12 @@ func (j *RSSJob) getFeed(ctx context.Context) (items []*gofeed.Item, err error) continue } - j.Log.Debug().Msgf("found new release: %s", i.Title) + j.Log.Debug().Msgf("found new release: %s", item.Title) toCache = append(toCache, domain.FeedCacheItem{ FeedId: strconv.Itoa(j.Feed.ID), Key: key, - Value: []byte(i.Title), + Value: []byte(item.Title), TTL: ttl, }) diff --git a/internal/feed/service.go b/internal/feed/service.go index f2afba6..fb8b34d 100644 --- a/internal/feed/service.go +++ b/internal/feed/service.go @@ -11,6 +11,7 @@ import ( "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" + "github.com/autobrr/autobrr/internal/proxy" "github.com/autobrr/autobrr/internal/release" "github.com/autobrr/autobrr/internal/scheduler" "github.com/autobrr/autobrr/pkg/errors" @@ -68,16 +69,18 @@ type service struct { repo domain.FeedRepo cacheRepo domain.FeedCacheRepo releaseSvc release.Service + proxySvc proxy.Service scheduler scheduler.Service } -func NewService(log logger.Logger, repo domain.FeedRepo, cacheRepo domain.FeedCacheRepo, releaseSvc release.Service, scheduler scheduler.Service) Service { +func NewService(log logger.Logger, repo domain.FeedRepo, cacheRepo domain.FeedCacheRepo, releaseSvc release.Service, proxySvc proxy.Service, scheduler scheduler.Service) Service { return &service{ log: log.With().Str("module", "feed").Logger(), jobs: map[string]int{}, repo: repo, cacheRepo: cacheRepo, releaseSvc: releaseSvc, + proxySvc: proxySvc, scheduler: scheduler, } } @@ -150,6 +153,13 @@ func (s *service) update(ctx context.Context, feed *domain.Feed) error { return err } + // get Feed again for ProxyID and UseProxy to be correctly populated + feed, err := s.repo.FindByID(ctx, feed.ID) + if err != nil { + s.log.Error().Err(err).Msg("error finding feed") + return err + } + if err := s.restartJob(feed); err != nil { s.log.Error().Err(err).Msg("error restarting feed") return err @@ -227,6 +237,18 @@ func (s *service) test(ctx context.Context, feed *domain.Feed) error { // create sub logger subLogger := zstdlog.NewStdLoggerWithLevel(s.log.With().Logger(), zerolog.DebugLevel) + // add proxy conf + if feed.UseProxy { + proxyConf, err := s.proxySvc.FindByID(ctx, feed.ProxyID) + if err != nil { + return errors.Wrap(err, "could not find proxy for indexer feed") + } + + if proxyConf.Enabled { + feed.Proxy = proxyConf + } + } + // test feeds switch feed.Type { case string(domain.FeedTypeTorznab): @@ -254,13 +276,27 @@ func (s *service) test(ctx context.Context, feed *domain.Feed) error { } func (s *service) testRSS(ctx context.Context, feed *domain.Feed) error { - f, err := NewFeedParser(time.Duration(feed.Timeout)*time.Second, feed.Cookie).ParseURLWithContext(ctx, feed.URL) + feedParser := NewFeedParser(time.Duration(feed.Timeout)*time.Second, feed.Cookie) + + // add proxy if enabled and exists + if feed.UseProxy && feed.Proxy != nil { + proxyClient, err := proxy.GetProxiedHTTPClient(feed.Proxy) + if err != nil { + return errors.Wrap(err, "could not get proxy client") + } + + feedParser.WithHTTPClient(proxyClient) + + s.log.Debug().Msgf("using proxy %s for feed %s", feed.Proxy.Name, feed.Name) + } + + feedResponse, err := feedParser.ParseURLWithContext(ctx, feed.URL) if err != nil { s.log.Error().Err(err).Msgf("error fetching rss feed items") return errors.Wrap(err, "error fetching rss feed items") } - s.log.Info().Msgf("refreshing rss feed: %s, found (%d) items", feed.Name, len(f.Items)) + s.log.Info().Msgf("refreshing rss feed: %s, found (%d) items", feed.Name, len(feedResponse.Items)) return nil } @@ -269,6 +305,18 @@ func (s *service) testTorznab(ctx context.Context, feed *domain.Feed, subLogger // setup torznab Client c := torznab.NewClient(torznab.Config{Host: feed.URL, ApiKey: feed.ApiKey, Log: subLogger}) + // add proxy if enabled and exists + if feed.UseProxy && feed.Proxy != nil { + proxyClient, err := proxy.GetProxiedHTTPClient(feed.Proxy) + if err != nil { + return errors.Wrap(err, "could not get proxy client") + } + + c.WithHTTPClient(proxyClient) + + s.log.Debug().Msgf("using proxy %s for feed %s", feed.Proxy.Name, feed.Name) + } + items, err := c.FetchFeed(ctx) if err != nil { s.log.Error().Err(err).Msg("error getting torznab feed") @@ -284,6 +332,18 @@ func (s *service) testNewznab(ctx context.Context, feed *domain.Feed, subLogger // setup newznab Client c := newznab.NewClient(newznab.Config{Host: feed.URL, ApiKey: feed.ApiKey, Log: subLogger}) + // add proxy if enabled and exists + if feed.UseProxy && feed.Proxy != nil { + proxyClient, err := proxy.GetProxiedHTTPClient(feed.Proxy) + if err != nil { + return errors.Wrap(err, "could not get proxy client") + } + + c.WithHTTPClient(proxyClient) + + s.log.Debug().Msgf("using proxy %s for feed %s", feed.Proxy.Name, feed.Name) + } + items, err := c.GetFeed(ctx) if err != nil { s.log.Error().Err(err).Msg("error getting newznab feed") @@ -316,8 +376,6 @@ func (s *service) start() error { s.log.Debug().Msgf("preparing staggered start of %d feeds", len(feeds)) for _, feed := range feeds { - feed := feed - if !feed.Enabled { s.log.Trace().Msgf("feed disabled, skipping... %s", feed.Name) continue @@ -408,6 +466,18 @@ func (s *service) startJob(f *domain.Feed) error { return errors.New("no URL provided for feed: %s", f.Name) } + // add proxy conf + if f.UseProxy { + proxyConf, err := s.proxySvc.FindByID(context.Background(), f.ProxyID) + if err != nil { + return errors.Wrap(err, "could not find proxy for indexer feed") + } + + if proxyConf.Enabled { + f.Proxy = proxyConf + } + } + fi := newFeedInstance(f) job, err := s.initializeFeedJob(fi) diff --git a/internal/feed/torznab.go b/internal/feed/torznab.go index ce717e5..6a4567e 100644 --- a/internal/feed/torznab.go +++ b/internal/feed/torznab.go @@ -5,6 +5,7 @@ package feed import ( "context" + "github.com/autobrr/autobrr/internal/proxy" "math" "sort" "strconv" @@ -224,6 +225,18 @@ func mapFreeleechToBonus(percentage int) string { } func (j *TorznabJob) getFeed(ctx context.Context) ([]torznab.FeedItem, error) { + // add proxy if enabled and exists + if j.Feed.UseProxy && j.Feed.Proxy != nil { + proxyClient, err := proxy.GetProxiedHTTPClient(j.Feed.Proxy) + if err != nil { + return nil, errors.Wrap(err, "could not get proxy client") + } + + j.Client.WithHTTPClient(proxyClient) + + j.Log.Debug().Msgf("using proxy %s for feed %s", j.Feed.Proxy.Name, j.Feed.Name) + } + // get feed feed, err := j.Client.FetchFeed(ctx) if err != nil { @@ -251,35 +264,33 @@ func (j *TorznabJob) getFeed(ctx context.Context) ([]torznab.FeedItem, error) { // set ttl to 1 month ttl := time.Now().AddDate(0, 1, 0) - for _, i := range feed.Channel.Items { - i := i - - if i.GUID == "" { + for _, item := range feed.Channel.Items { + if item.GUID == "" { j.Log.Error().Msgf("missing GUID from feed: %s", j.Feed.Name) continue } - exists, err := j.CacheRepo.Exists(j.Feed.ID, i.GUID) + exists, err := j.CacheRepo.Exists(j.Feed.ID, item.GUID) if err != nil { j.Log.Error().Err(err).Msg("could not check if item exists") continue } if exists { - j.Log.Trace().Msgf("cache item exists, skipping release: %s", i.Title) + j.Log.Trace().Msgf("cache item exists, skipping release: %s", item.Title) continue } - j.Log.Debug().Msgf("found new release: %s", i.Title) + j.Log.Debug().Msgf("found new release: %s", item.Title) toCache = append(toCache, domain.FeedCacheItem{ FeedId: strconv.Itoa(j.Feed.ID), - Key: i.GUID, - Value: []byte(i.Title), + Key: item.GUID, + Value: []byte(item.Title), TTL: ttl, }) // only append if we successfully added to cache - items = append(items, *i) + items = append(items, *item) } if len(toCache) > 0 { diff --git a/internal/filter/service.go b/internal/filter/service.go index f4448eb..856efcf 100644 --- a/internal/filter/service.go +++ b/internal/filter/service.go @@ -7,7 +7,6 @@ import ( "bytes" "context" "fmt" - "github.com/autobrr/autobrr/internal/action" "io" "net/http" "os" @@ -17,9 +16,11 @@ import ( "strings" "time" + "github.com/autobrr/autobrr/internal/action" "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/indexer" "github.com/autobrr/autobrr/internal/logger" + "github.com/autobrr/autobrr/internal/releasedownload" "github.com/autobrr/autobrr/internal/utils" "github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/sharedhttp" @@ -53,11 +54,12 @@ type service struct { releaseRepo domain.ReleaseRepo indexerSvc indexer.Service apiService indexer.APIService + downloadSvc *releasedownload.DownloadService httpClient *http.Client } -func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Service, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service) Service { +func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Service, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service, downloadSvc *releasedownload.DownloadService) Service { return &service{ log: log.With().Str("module", "filter").Logger(), repo: repo, @@ -65,6 +67,7 @@ func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Serv actionService: actionSvc, apiService: apiService, indexerSvc: indexerSvc, + downloadSvc: downloadSvc, httpClient: &http.Client{ Timeout: time.Second * 120, Transport: sharedhttp.TransportTLSInsecure, @@ -504,9 +507,9 @@ func (s *service) AdditionalSizeCheck(ctx context.Context, f *domain.Filter, rel l.Trace().Msgf("(%s) preparing to download torrent metafile", f.Name) // if indexer doesn't have api, download torrent and add to tmpPath - if err := release.DownloadTorrentFileCtx(ctx); err != nil { + if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil { l.Error().Err(err).Msgf("(%s) could not download torrent file with id: '%s' from: %s", f.Name, release.TorrentID, release.Indexer.Identifier) - return false, err + return false, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName) } } @@ -584,8 +587,8 @@ func (s *service) execCmd(ctx context.Context, external domain.FilterExternal, r s.log.Trace().Msgf("filter exec release: %s", release.TorrentName) if release.TorrentTmpFile == "" && strings.Contains(external.ExecArgs, "TorrentPathName") { - if err := release.DownloadTorrentFileCtx(ctx); err != nil { - return 0, errors.Wrap(err, "error downloading torrent file for release: %s", release.TorrentName) + if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil { + return 0, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName) } } @@ -686,7 +689,7 @@ func (s *service) webhook(ctx context.Context, external domain.FilterExternal, r // if webhook data contains TorrentPathName or TorrentDataRawBytes, lets download the torrent file if release.TorrentTmpFile == "" && (strings.Contains(external.WebhookData, "TorrentPathName") || strings.Contains(external.WebhookData, "TorrentDataRawBytes")) { - if err := release.DownloadTorrentFileCtx(ctx); err != nil { + if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil { return 0, errors.Wrap(err, "webhook: could not download torrent file for release: %s", release.TorrentName) } } diff --git a/internal/http/feed.go b/internal/http/feed.go index 7c93a3d..c221632 100644 --- a/internal/http/feed.go +++ b/internal/http/feed.go @@ -6,11 +6,11 @@ package http import ( "context" "encoding/json" - "github.com/autobrr/autobrr/pkg/errors" "net/http" "strconv" "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/pkg/errors" "github.com/go-chi/chi/v5" ) diff --git a/internal/http/irc.go b/internal/http/irc.go index 68ba490..03c73d5 100644 --- a/internal/http/irc.go +++ b/internal/http/irc.go @@ -7,12 +7,12 @@ import ( "context" "encoding/json" "fmt" - "github.com/autobrr/autobrr/pkg/errors" "net/http" "strconv" "strings" "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/pkg/errors" "github.com/go-chi/chi/v5" "github.com/r3labs/sse/v2" diff --git a/internal/http/notification.go b/internal/http/notification.go index 78dc819..8d6b0bc 100644 --- a/internal/http/notification.go +++ b/internal/http/notification.go @@ -6,11 +6,11 @@ package http import ( "context" "encoding/json" - "github.com/autobrr/autobrr/pkg/errors" "net/http" "strconv" "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/pkg/errors" "github.com/go-chi/chi/v5" ) diff --git a/internal/http/proxy.go b/internal/http/proxy.go new file mode 100644 index 0000000..9a2d478 --- /dev/null +++ b/internal/http/proxy.go @@ -0,0 +1,151 @@ +// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +package http + +import ( + "context" + "encoding/json" + "net/http" + "strconv" + + "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/pkg/errors" + + "github.com/go-chi/chi/v5" +) + +type proxyService interface { + Store(ctx context.Context, p *domain.Proxy) error + Update(ctx context.Context, p *domain.Proxy) error + Delete(ctx context.Context, id int64) error + List(ctx context.Context) ([]domain.Proxy, error) + FindByID(ctx context.Context, id int64) (*domain.Proxy, error) + Test(ctx context.Context, p *domain.Proxy) error +} + +type proxyHandler struct { + encoder encoder + service proxyService +} + +func newProxyHandler(encoder encoder, service proxyService) *proxyHandler { + return &proxyHandler{ + encoder: encoder, + service: service, + } +} + +func (h proxyHandler) Routes(r chi.Router) { + r.Get("/", h.list) + r.Post("/", h.store) + r.Post("/test", h.test) + + r.Route("/{proxyID}", func(r chi.Router) { + r.Get("/", h.findByID) + r.Put("/", h.update) + r.Delete("/", h.delete) + }) +} + +func (h proxyHandler) store(w http.ResponseWriter, r *http.Request) { + var data domain.Proxy + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + h.encoder.Error(w, err) + return + } + + if err := h.service.Store(r.Context(), &data); err != nil { + h.encoder.Error(w, err) + return + } + + h.encoder.NoContent(w) +} + +func (h proxyHandler) update(w http.ResponseWriter, r *http.Request) { + var data domain.Proxy + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + h.encoder.Error(w, err) + return + } + + if err := h.service.Update(r.Context(), &data); err != nil { + if errors.Is(err, domain.ErrUpdateFailed) { + h.encoder.StatusError(w, http.StatusBadRequest, err) + return + } + + h.encoder.Error(w, err) + return + } + + h.encoder.NoContent(w) +} + +func (h proxyHandler) list(w http.ResponseWriter, r *http.Request) { + proxies, err := h.service.List(r.Context()) + if err != nil { + h.encoder.Error(w, err) + return + } + + h.encoder.StatusResponse(w, http.StatusOK, proxies) +} + +func (h proxyHandler) findByID(w http.ResponseWriter, r *http.Request) { + proxyID, err := strconv.Atoi(chi.URLParam(r, "proxyID")) + if err != nil { + h.encoder.Error(w, err) + return + } + + proxies, err := h.service.FindByID(r.Context(), int64(proxyID)) + if err != nil { + if errors.Is(err, domain.ErrRecordNotFound) { + h.encoder.NotFoundErr(w, errors.New("could not find proxy with id %d", proxyID)) + return + } + + h.encoder.Error(w, err) + return + } + + h.encoder.StatusResponse(w, http.StatusOK, proxies) +} + +func (h proxyHandler) delete(w http.ResponseWriter, r *http.Request) { + proxyID, err := strconv.Atoi(chi.URLParam(r, "proxyID")) + if err != nil { + h.encoder.Error(w, err) + return + } + + err = h.service.Delete(r.Context(), int64(proxyID)) + if err != nil { + if errors.Is(err, domain.ErrDeleteFailed) { + h.encoder.StatusError(w, http.StatusBadRequest, err) + return + } + + h.encoder.Error(w, err) + return + } + + h.encoder.NoContent(w) +} + +func (h proxyHandler) test(w http.ResponseWriter, r *http.Request) { + var data domain.Proxy + if err := json.NewDecoder(r.Body).Decode(&data); err != nil { + h.encoder.Error(w, err) + return + } + + if err := h.service.Test(r.Context(), &data); err != nil { + h.encoder.Error(w, err) + return + } + + h.encoder.NoContent(w) +} diff --git a/internal/http/server.go b/internal/http/server.go index 0b8f300..ab449e1 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -43,11 +43,12 @@ type Server struct { indexerService indexerService ircService ircService notificationService notificationService + proxyService proxyService releaseService releaseService updateService updateService } -func NewServer(log logger.Logger, config *config.AppConfig, sse *sse.Server, db *database.DB, version string, commit string, date string, actionService actionService, apiService apikeyService, authService authService, downloadClientSvc downloadClientService, filterSvc filterService, feedSvc feedService, indexerSvc indexerService, ircSvc ircService, notificationSvc notificationService, releaseSvc releaseService, updateSvc updateService) Server { +func NewServer(log logger.Logger, config *config.AppConfig, sse *sse.Server, db *database.DB, version string, commit string, date string, actionService actionService, apiService apikeyService, authService authService, downloadClientSvc downloadClientService, filterSvc filterService, feedSvc feedService, indexerSvc indexerService, ircSvc ircService, notificationSvc notificationService, proxySvc proxyService, releaseSvc releaseService, updateSvc updateService) Server { return Server{ log: log.With().Str("module", "http").Logger(), config: config, @@ -68,6 +69,7 @@ func NewServer(log logger.Logger, config *config.AppConfig, sse *sse.Server, db indexerService: indexerSvc, ircService: ircSvc, notificationService: notificationSvc, + proxyService: proxySvc, releaseService: releaseSvc, updateService: updateSvc, } @@ -142,6 +144,7 @@ func (s Server) Handler() http.Handler { r.Route("/keys", newAPIKeyHandler(encoder, s.apiService).Routes) r.Route("/logs", newLogsHandler(s.config).Routes) r.Route("/notification", newNotificationHandler(encoder, s.notificationService).Routes) + r.Route("/proxy", newProxyHandler(encoder, s.proxyService).Routes) r.Route("/release", newReleaseHandler(encoder, s.releaseService).Routes) r.Route("/updates", newUpdateHandler(encoder, s.updateService).Routes) diff --git a/internal/indexer/service.go b/internal/indexer/service.go index 47c9852..40085b6 100644 --- a/internal/indexer/service.go +++ b/internal/indexer/service.go @@ -295,6 +295,9 @@ func (s *service) mapIndexer(indexer domain.Indexer) (*domain.IndexerDefinition, d.BaseURL = indexer.BaseURL d.Enabled = indexer.Enabled + d.UseProxy = indexer.UseProxy + d.ProxyID = indexer.ProxyID + if d.SettingsMap == nil { d.SettingsMap = make(map[string]string) } @@ -332,6 +335,9 @@ func (s *service) updateMapIndexer(indexer domain.Indexer) (*domain.IndexerDefin d.BaseURL = indexer.BaseURL d.Enabled = indexer.Enabled + d.UseProxy = indexer.UseProxy + d.ProxyID = indexer.ProxyID + if d.SettingsMap == nil { d.SettingsMap = make(map[string]string) } diff --git a/internal/irc/handler.go b/internal/irc/handler.go index 8d2582d..f554c2b 100644 --- a/internal/irc/handler.go +++ b/internal/irc/handler.go @@ -6,6 +6,8 @@ package irc import ( "crypto/tls" "fmt" + "golang.org/x/net/proxy" + "net/url" "slices" "strings" "time" @@ -220,6 +222,37 @@ func (h *Handler) Run() (err error) { Log: subLogger, } + if h.network.UseProxy && h.network.Proxy != nil { + if !h.network.Proxy.Enabled { + h.log.Debug().Msgf("proxy disabled, skip") + } else { + if h.network.Proxy.Addr == "" { + return errors.New("proxy addr missing") + } + + proxyUrl, err := url.Parse(h.network.Proxy.Addr) + if err != nil { + return errors.Wrap(err, "could not parse proxy url: %s", h.network.Proxy.Addr) + } + + // set user and pass if not empty + if h.network.Proxy.User != "" && h.network.Proxy.Pass != "" { + proxyUrl.User = url.UserPassword(h.network.Proxy.User, h.network.Proxy.Pass) + } + + proxyDialer, err := proxy.FromURL(proxyUrl, proxy.Direct) + if err != nil { + return errors.Wrap(err, "could not create proxy dialer from url: %s", h.network.Proxy.Addr) + } + proxyContextDialer, ok := proxyDialer.(proxy.ContextDialer) + if !ok { + return errors.Wrap(err, "proxy dialer does not expose DialContext(): %v", proxyDialer) + } + + client.DialContext = proxyContextDialer.DialContext + } + } + if h.network.Auth.Mechanism == domain.IRCAuthMechanismSASLPlain { if h.network.Auth.Account != "" && h.network.Auth.Password != "" { client.SASLLogin = h.network.Auth.Account diff --git a/internal/irc/service.go b/internal/irc/service.go index e80b4ca..e719bca 100644 --- a/internal/irc/service.go +++ b/internal/irc/service.go @@ -14,6 +14,7 @@ import ( "github.com/autobrr/autobrr/internal/indexer" "github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/internal/notification" + "github.com/autobrr/autobrr/internal/proxy" "github.com/autobrr/autobrr/internal/release" "github.com/autobrr/autobrr/pkg/errors" @@ -47,8 +48,10 @@ type service struct { releaseService release.Service indexerService indexer.Service notificationService notification.Service - indexerMap map[string]string - handlers map[int64]*Handler + proxyService proxy.Service + + indexerMap map[string]string + handlers map[int64]*Handler stopWG sync.WaitGroup lock sync.RWMutex @@ -56,7 +59,7 @@ type service struct { const sseMaxEntries = 1000 -func NewService(log logger.Logger, sse *sse.Server, repo domain.IrcRepo, releaseSvc release.Service, indexerSvc indexer.Service, notificationSvc notification.Service) Service { +func NewService(log logger.Logger, sse *sse.Server, repo domain.IrcRepo, releaseSvc release.Service, indexerSvc indexer.Service, notificationSvc notification.Service, proxySvc proxy.Service) Service { return &service{ log: log.With().Str("module", "irc").Logger(), sse: sse, @@ -64,6 +67,7 @@ func NewService(log logger.Logger, sse *sse.Server, repo domain.IrcRepo, release releaseService: releaseSvc, indexerService: indexerSvc, notificationService: notificationSvc, + proxyService: proxySvc, handlers: make(map[int64]*Handler), } } @@ -79,6 +83,15 @@ func (s *service) StartHandlers() { continue } + if network.ProxyId != 0 { + networkProxy, err := s.proxyService.FindByID(context.Background(), network.ProxyId) + if err != nil { + s.log.Error().Err(err).Msgf("failed to get proxy for network: %s", network.Server) + return + } + network.Proxy = networkProxy + } + channels, err := s.repo.ListChannels(network.ID) if err != nil { s.log.Error().Err(err).Msgf("failed to list channels for network: %s", network.Server) @@ -215,6 +228,14 @@ func (s *service) checkIfNetworkRestartNeeded(network *domain.IrcNetwork) error restartNeeded = true fieldsChanged = append(fieldsChanged, "bot mode") } + if handler.UseProxy != network.UseProxy { + restartNeeded = true + fieldsChanged = append(fieldsChanged, "use proxy") + } + if handler.ProxyId != network.ProxyId { + restartNeeded = true + fieldsChanged = append(fieldsChanged, "proxy id") + } if handler.Auth.Mechanism != network.Auth.Mechanism { restartNeeded = true fieldsChanged = append(fieldsChanged, "auth mechanism") @@ -476,12 +497,16 @@ func (s *service) GetNetworksWithHealth(ctx context.Context) ([]domain.IrcNetwor BouncerAddr: n.BouncerAddr, UseBouncer: n.UseBouncer, BotMode: n.BotMode, + UseProxy: n.UseProxy, + ProxyId: n.ProxyId, Connected: false, Channels: []domain.ChannelWithHealth{}, ConnectionErrors: []string{}, } + s.lock.RLock() handler, ok := s.handlers[n.ID] + s.lock.RUnlock() if ok { handler.ReportStatus(&netw) } @@ -566,6 +591,18 @@ func (s *service) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) } s.log.Debug().Msgf("irc.service: update network: %s", network.Name) + network.Proxy = nil + + // attach proxy + if network.UseProxy && network.ProxyId != 0 { + networkProxy, err := s.proxyService.FindByID(context.Background(), network.ProxyId) + if err != nil { + s.log.Error().Err(err).Msgf("failed to get proxy for network: %s", network.Server) + return errors.Wrap(err, "could not get proxy for network: %s", network.Server) + } + network.Proxy = networkProxy + } + // stop or start network // TODO get current state to see if enabled or not? if network.Enabled { diff --git a/internal/proxy/service.go b/internal/proxy/service.go new file mode 100644 index 0000000..af3ad48 --- /dev/null +++ b/internal/proxy/service.go @@ -0,0 +1,194 @@ +// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors. +// SPDX-License-Identifier: GPL-2.0-or-later + +package proxy + +import ( + "context" + "net/http" + "net/url" + "time" + + "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/internal/logger" + "github.com/autobrr/autobrr/pkg/errors" + "github.com/autobrr/autobrr/pkg/sharedhttp" + + "github.com/rs/zerolog" + netProxy "golang.org/x/net/proxy" +) + +type Service interface { + List(ctx context.Context) ([]domain.Proxy, error) + FindByID(ctx context.Context, id int64) (*domain.Proxy, error) + Store(ctx context.Context, p *domain.Proxy) error + Update(ctx context.Context, p *domain.Proxy) error + Delete(ctx context.Context, id int64) error + Test(ctx context.Context, p *domain.Proxy) error +} + +type service struct { + log zerolog.Logger + + repo domain.ProxyRepo + cache map[int64]*domain.Proxy +} + +func NewService(log logger.Logger, repo domain.ProxyRepo) Service { + return &service{ + log: log.With().Str("module", "proxy").Logger(), + repo: repo, + cache: make(map[int64]*domain.Proxy), + } +} + +func (s *service) Store(ctx context.Context, proxy *domain.Proxy) error { + if err := proxy.Validate(); err != nil { + return errors.Wrap(err, "validation error") + } + + err := s.repo.Store(ctx, proxy) + if err != nil { + return err + } + + s.cache[proxy.ID] = proxy + + return nil +} + +func (s *service) Update(ctx context.Context, proxy *domain.Proxy) error { + if err := proxy.Validate(); err != nil { + return errors.Wrap(err, "validation error") + } + + err := s.repo.Update(ctx, proxy) + if err != nil { + return err + } + + s.cache[proxy.ID] = proxy + + // TODO update IRC handlers + + return nil +} + +func (s *service) FindByID(ctx context.Context, id int64) (*domain.Proxy, error) { + if proxy, ok := s.cache[id]; ok { + return proxy, nil + } + + return s.repo.FindByID(ctx, id) +} + +func (s *service) List(ctx context.Context) ([]domain.Proxy, error) { + return s.repo.List(ctx) +} + +func (s *service) ToggleEnabled(ctx context.Context, id int64, enabled bool) error { + err := s.repo.ToggleEnabled(ctx, id, enabled) + if err != nil { + return err + } + + v, ok := s.cache[id] + if !ok { + v.Enabled = !enabled + s.cache[id] = v + } + + // TODO update IRC handlers + + return nil +} + +func (s *service) Delete(ctx context.Context, id int64) error { + err := s.repo.Delete(ctx, id) + if err != nil { + return err + } + + delete(s.cache, id) + + // TODO update IRC handlers + + return nil +} + +func (s *service) Test(ctx context.Context, proxy *domain.Proxy) error { + if !proxy.ValidProxyType() { + return errors.New("invalid proxy type %s", proxy.Type) + } + + if proxy.Addr == "" { + return errors.New("proxy addr missing") + } + + httpClient, err := GetProxiedHTTPClient(proxy) + if err != nil { + return errors.Wrap(err, "could not get http client") + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://autobrr.com", nil) + if err != nil { + return errors.Wrap(err, "could not create proxy request") + } + + resp, err := httpClient.Do(req) + if err != nil { + return errors.Wrap(err, "could not connect to proxy server: %s", proxy.Addr) + } + + if resp.StatusCode != http.StatusOK { + return errors.New(resp.Status) + } + + s.log.Debug().Msgf("proxy %s test OK!", proxy.Addr) + + return nil +} + +func GetProxiedHTTPClient(p *domain.Proxy) (*http.Client, error) { + proxyUrl, err := url.Parse(p.Addr) + if err != nil { + return nil, errors.Wrap(err, "could not parse proxy url: %s", p.Addr) + } + + // set user and pass if not empty + if p.User != "" && p.Pass != "" { + proxyUrl.User = url.UserPassword(p.User, p.Pass) + } + + transport := sharedhttp.TransportTLSInsecure + + // set user and pass if not empty + if p.User != "" && p.Pass != "" { + proxyUrl.User = url.UserPassword(p.User, p.Pass) + } + + switch p.Type { + case domain.ProxyTypeSocks5: + proxyDialer, err := netProxy.FromURL(proxyUrl, netProxy.Direct) + if err != nil { + return nil, errors.Wrap(err, "could not create proxy dialer from url: %s", p.Addr) + } + + proxyContextDialer, ok := proxyDialer.(netProxy.ContextDialer) + if !ok { + return nil, errors.Wrap(err, "proxy dialer does not expose DialContext(): %v", proxyDialer) + } + + transport.DialContext = proxyContextDialer.DialContext + + default: + return nil, errors.New("invalid proxy type: %s", p.Type) + } + + client := &http.Client{ + Timeout: 30 * time.Second, + Transport: transport, + } + + return client, nil +} diff --git a/internal/releasedownload/download_service.go b/internal/releasedownload/download_service.go new file mode 100644 index 0000000..d83c696 --- /dev/null +++ b/internal/releasedownload/download_service.go @@ -0,0 +1,316 @@ +package releasedownload + +import ( + "bufio" + "bytes" + "context" + "io" + "net" + "net/http" + "net/http/cookiejar" + "os" + "strings" + "time" + + "github.com/autobrr/autobrr/internal/domain" + "github.com/autobrr/autobrr/internal/logger" + "github.com/autobrr/autobrr/internal/proxy" + "github.com/autobrr/autobrr/pkg/errors" + "github.com/autobrr/autobrr/pkg/sharedhttp" + + "github.com/anacrolix/torrent/bencode" + "github.com/anacrolix/torrent/metainfo" + "github.com/avast/retry-go/v4" + "github.com/rs/zerolog" + "golang.org/x/net/publicsuffix" +) + +type DownloadService struct { + log zerolog.Logger + repo domain.ReleaseRepo + + indexerRepo domain.IndexerRepo + + proxySvc proxy.Service +} + +func NewDownloadService(log logger.Logger, repo domain.ReleaseRepo, indexerRepo domain.IndexerRepo, proxySvc proxy.Service) *DownloadService { + return &DownloadService{ + log: log.With().Str("module", "release-download").Logger(), + repo: repo, + indexerRepo: indexerRepo, + proxySvc: proxySvc, + } +} + +func (s *DownloadService) DownloadRelease(ctx context.Context, rls *domain.Release) error { + if rls.HasMagnetUri() { + return errors.New("downloading magnet links is not supported: %s", rls.MagnetURI) + } else if rls.Protocol != domain.ReleaseProtocolTorrent { + return errors.New("could not download file: protocol %s is not supported", rls.Protocol) + } + + if rls.DownloadURL == "" { + return errors.New("download_file: url can't be empty") + } else if rls.TorrentTmpFile != "" { + // already downloaded + return nil + } + + // get indexer + indexer, err := s.indexerRepo.FindByID(ctx, rls.Indexer.ID) + if err != nil { + return err + } + + // get proxy + if indexer.UseProxy { + proxyConf, err := s.proxySvc.FindByID(ctx, indexer.ProxyID) + if err != nil { + return err + } + + if proxyConf.Enabled { + s.log.Debug().Msgf("using proxy: %s", proxyConf.Name) + + indexer.Proxy = proxyConf + } else { + s.log.Debug().Msgf("proxy disabled, skip: %s", proxyConf.Name) + } + } + + // download release + err = s.downloadTorrentFile(ctx, indexer, rls) + if err != nil { + return err + } + + return nil +} + +func (s *DownloadService) downloadTorrentFile(ctx context.Context, indexer *domain.Indexer, r *domain.Release) error { + req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.DownloadURL, nil) + if err != nil { + return errors.Wrap(err, "error downloading file") + } + + req.Header.Set("User-Agent", "autobrr") + + httpClient := &http.Client{ + Timeout: 30 * time.Second, + Transport: sharedhttp.TransportTLSInsecure, + } + + // handle proxy + if indexer.Proxy != nil { + s.log.Debug().Msgf("using proxy: %s", indexer.Proxy.Name) + + proxiedClient, err := proxy.GetProxiedHTTPClient(indexer.Proxy) + if err != nil { + return errors.Wrap(err, "could not get proxied http client") + } + + httpClient = proxiedClient + } + + if r.RawCookie != "" { + jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + if err != nil { + return errors.Wrap(err, "could not create cookiejar") + } + httpClient.Jar = jar + + // set the cookie on the header instead of req.AddCookie + // since we have a raw cookie like "uid=10; pass=000" + req.Header.Set("Cookie", r.RawCookie) + } + + tmpFilePattern := "autobrr-" + tmpDir := os.TempDir() + + // Create tmp file + // TODO check if tmp file is wanted + tmpFile, err := os.CreateTemp(tmpDir, tmpFilePattern) + if err != nil { + if os.IsNotExist(err) { + if mkdirErr := os.MkdirAll(tmpDir, os.ModePerm); mkdirErr != nil { + return errors.Wrap(mkdirErr, "could not create TMP dir: %s", tmpDir) + } + + tmpFile, err = os.CreateTemp(tmpDir, tmpFilePattern) + if err != nil { + return errors.Wrap(err, "error creating tmp file in: %s", tmpDir) + } + } else { + return errors.Wrap(err, "error creating tmp file") + } + } + defer tmpFile.Close() + + errFunc := retry.Do(retryableRequest(httpClient, req, r, tmpFile), retry.Delay(time.Second*3), retry.Attempts(3), retry.MaxJitter(time.Second*1)) + + return errFunc +} + +func retryableRequest(httpClient *http.Client, req *http.Request, r *domain.Release, tmpFile *os.File) func() error { + return func() error { + // Get the data + resp, err := httpClient.Do(req) + if err != nil { + if errors.As(err, net.OpError{}) { + return retry.Unrecoverable(errors.Wrap(err, "issue from proxy")) + } + return errors.Wrap(err, "error downloading file") + } + defer resp.Body.Close() + + // Check server response + switch resp.StatusCode { + case http.StatusOK: + // Continue processing the response + break + + //case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect: + // // Handle redirect + // return retry.Unrecoverable(errors.New("redirect encountered for torrent (%s) file (%s) - status code: %d - check indexer keys for %s", r.TorrentName, r.DownloadURL, resp.StatusCode, r.Indexer.Name)) + + case http.StatusUnauthorized, http.StatusForbidden: + return retry.Unrecoverable(errors.New("unrecoverable error downloading torrent (%s) file (%s) - status code: %d - check indexer keys for %s", r.TorrentName, r.DownloadURL, resp.StatusCode, r.Indexer.Name)) + + case http.StatusMethodNotAllowed: + return retry.Unrecoverable(errors.New("unrecoverable error downloading torrent (%s) file (%s) from '%s' - status code: %d. Check if the request method is correct", r.TorrentName, r.DownloadURL, r.Indexer.Name, resp.StatusCode)) + case http.StatusNotFound: + return errors.New("torrent %s not found on %s (%d) - retrying", r.TorrentName, r.Indexer.Name, resp.StatusCode) + + case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout: + return errors.New("server error (%d) encountered while downloading torrent (%s) file (%s) from '%s' - retrying", resp.StatusCode, r.TorrentName, r.DownloadURL, r.Indexer.Name) + + case http.StatusInternalServerError: + return errors.New("server error (%d) encountered while downloading torrent (%s) file (%s) - check indexer keys for %s", resp.StatusCode, r.TorrentName, r.DownloadURL, r.Indexer.Name) + + default: + return retry.Unrecoverable(errors.New("unexpected status code %d: check indexer keys for %s", resp.StatusCode, r.Indexer.Name)) + } + + resetTmpFile := func() { + tmpFile.Seek(0, io.SeekStart) + tmpFile.Truncate(0) + } + + // Read the body into bytes + bodyBytes, err := io.ReadAll(bufio.NewReader(resp.Body)) + if err != nil { + return errors.Wrap(err, "error reading response body") + } + + // Create a new reader for bodyBytes + bodyReader := bytes.NewReader(bodyBytes) + + // Try to decode as torrent file + meta, err := metainfo.Load(bodyReader) + if err != nil { + resetTmpFile() + + // explicitly check for unexpected content type that match html + var bse *bencode.SyntaxError + if errors.As(err, &bse) { + // regular error so we can retry if we receive html first run + return errors.Wrap(err, "metainfo unexpected content type, got HTML expected a bencoded torrent. check indexer keys for %s - %s", r.Indexer.Name, r.TorrentName) + } + + return retry.Unrecoverable(errors.Wrap(err, "metainfo unexpected content type. check indexer keys for %s - %s", r.Indexer.Name, r.TorrentName)) + } + + torrentMetaInfo, err := meta.UnmarshalInfo() + if err != nil { + resetTmpFile() + return retry.Unrecoverable(errors.Wrap(err, "metainfo could not unmarshal info from torrent: %s", tmpFile.Name())) + } + + hashInfoBytes := meta.HashInfoBytes().Bytes() + if len(hashInfoBytes) < 1 { + resetTmpFile() + return retry.Unrecoverable(errors.New("could not read infohash")) + } + + // Write the body to file + // TODO move to io.Reader and pass around in the future + if _, err := tmpFile.Write(bodyBytes); err != nil { + resetTmpFile() + return errors.Wrap(err, "error writing downloaded file: %s", tmpFile.Name()) + } + + r.TorrentTmpFile = tmpFile.Name() + r.TorrentHash = meta.HashInfoBytes().String() + r.Size = uint64(torrentMetaInfo.TotalLength()) + + return nil + } +} + +func (s *DownloadService) ResolveMagnetURI(ctx context.Context, r *domain.Release) error { + if r.MagnetURI == "" { + return nil + } else if strings.HasPrefix(r.MagnetURI, domain.MagnetURIPrefix) { + return nil + } + + // get indexer + indexer, err := s.indexerRepo.FindByID(ctx, r.Indexer.ID) + if err != nil { + return err + } + + httpClient := &http.Client{ + Timeout: time.Second * 45, + Transport: sharedhttp.MagnetTransport, + } + + // get proxy + if indexer.UseProxy { + proxyConf, err := s.proxySvc.FindByID(ctx, indexer.ProxyID) + if err != nil { + return err + } + + s.log.Debug().Msgf("using proxy: %s", proxyConf.Name) + + proxiedClient, err := proxy.GetProxiedHTTPClient(proxyConf) + if err != nil { + return errors.Wrap(err, "could not get proxied http client") + } + + httpClient = proxiedClient + } + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.MagnetURI, nil) + if err != nil { + return errors.Wrap(err, "could not build request to resolve magnet uri") + } + + //req.Header.Set("Content-Type", "application/json") + req.Header.Set("User-Agent", "autobrr") + + res, err := httpClient.Do(req) + if err != nil { + return errors.Wrap(err, "could not make request to resolve magnet uri") + } + + defer res.Body.Close() + + if res.StatusCode != http.StatusOK { + return errors.New("unexpected status code: %d", res.StatusCode) + } + + body, err := io.ReadAll(res.Body) + if err != nil { + return errors.Wrap(err, "could not read response body") + } + + magnet := string(body) + if magnet != "" { + r.MagnetURI = magnet + } + + return nil +} diff --git a/pkg/newznab/newznab.go b/pkg/newznab/newznab.go index c195df0..65713e7 100644 --- a/pkg/newznab/newznab.go +++ b/pkg/newznab/newznab.go @@ -25,6 +25,7 @@ type Client interface { GetFeed(ctx context.Context) (*Feed, error) GetCaps(ctx context.Context) (*Caps, error) Caps() *Caps + WithHTTPClient(client *http.Client) } type client struct { @@ -41,6 +42,10 @@ type client struct { Log *log.Logger } +func (c *client) WithHTTPClient(client *http.Client) { + c.http = client +} + type BasicAuth struct { Username string Password string @@ -99,6 +104,9 @@ func (c *client) get(ctx context.Context, endpoint string, queryParams map[strin } u, err := url.Parse(c.Host) + if err != nil { + return 0, nil, err + } u.Path = strings.TrimSuffix(u.Path, "/") u.RawQuery = params.Encode() reqUrl := u.String() @@ -273,6 +281,9 @@ func (c *client) getCaps(ctx context.Context, endpoint string, opts map[string]s } u, err := url.Parse(c.Host) + if err != nil { + return 0, nil, err + } u.Path = strings.TrimSuffix(u.Path, "/") u.RawQuery = params.Encode() reqUrl := u.String() @@ -325,7 +336,6 @@ func (c *client) getCaps(ctx context.Context, endpoint string, opts map[string]s } func (c *client) GetCaps(ctx context.Context) (*Caps, error) { - status, res, err := c.getCaps(ctx, "?t=caps", nil) if err != nil { return nil, errors.Wrap(err, "could not get caps for feed") diff --git a/pkg/torznab/torznab.go b/pkg/torznab/torznab.go index 3109984..7e6ae1b 100644 --- a/pkg/torznab/torznab.go +++ b/pkg/torznab/torznab.go @@ -23,6 +23,7 @@ type Client interface { FetchFeed(ctx context.Context) (*Feed, error) FetchCaps(ctx context.Context) (*Caps, error) GetCaps() *Caps + WithHTTPClient(client *http.Client) } type client struct { @@ -39,6 +40,10 @@ type client struct { Log *log.Logger } +func (c *client) WithHTTPClient(client *http.Client) { + c.http = client +} + type BasicAuth struct { Username string Password string @@ -90,6 +95,10 @@ func (c *client) get(ctx context.Context, endpoint string, opts map[string]strin } u, err := url.Parse(c.Host) + if err != nil { + return 0, nil, err + } + u.Path = strings.TrimSuffix(u.Path, "/") u.RawQuery = params.Encode() reqUrl := u.String() @@ -177,6 +186,9 @@ func (c *client) getCaps(ctx context.Context, endpoint string, opts map[string]s } u, err := url.Parse(c.Host) + if err != nil { + return 0, nil, err + } u.Path = strings.TrimSuffix(u.Path, "/") u.RawQuery = params.Encode() reqUrl := u.String() @@ -229,7 +241,6 @@ func (c *client) getCaps(ctx context.Context, endpoint string, opts map[string]s } func (c *client) FetchCaps(ctx context.Context) (*Caps, error) { - status, res, err := c.getCaps(ctx, "?t=caps", nil) if err != nil { return nil, errors.Wrap(err, "could not get caps for feed") diff --git a/web/src/api/APIClient.ts b/web/src/api/APIClient.ts index 28a1df3..bca32d9 100644 --- a/web/src/api/APIClient.ts +++ b/web/src/api/APIClient.ts @@ -389,6 +389,20 @@ export const APIClient = { body: notification }) }, + proxy: { + list: () => appClient.Get("api/proxy"), + getByID: (id: number) => appClient.Get(`api/proxy/${id}`), + store: (proxy: ProxyCreate) => appClient.Post("api/proxy", { + body: proxy + }), + update: (proxy: Proxy) => appClient.Put(`api/proxy/${proxy.id}`, { + body: proxy + }), + delete: (id: number) => appClient.Delete(`api/proxy/${id}`), + test: (proxy: Proxy) => appClient.Post("api/proxy/test", { + body: proxy + }) + }, release: { find: (query?: string) => appClient.Get(`api/release${query}`), findRecent: () => appClient.Get("api/release/recent"), diff --git a/web/src/api/queries.ts b/web/src/api/queries.ts index b4f863d..4d1981b 100644 --- a/web/src/api/queries.ts +++ b/web/src/api/queries.ts @@ -11,7 +11,7 @@ import { FeedKeys, FilterKeys, IndexerKeys, - IrcKeys, NotificationKeys, + IrcKeys, NotificationKeys, ProxyKeys, ReleaseKeys, SettingsKeys } from "@api/query_keys"; @@ -137,3 +137,17 @@ export const ReleasesIndexersQueryOptions = () => placeholderData: keepPreviousData, staleTime: Infinity }); + +export const ProxiesQueryOptions = () => + queryOptions({ + queryKey: ProxyKeys.lists(), + queryFn: () => APIClient.proxy.list(), + refetchOnWindowFocus: false + }); + +export const ProxyByIdQueryOptions = (proxyId: number) => + queryOptions({ + queryKey: ProxyKeys.detail(proxyId), + queryFn: async ({queryKey}) => await APIClient.proxy.getByID(queryKey[2]), + retry: false, + }); diff --git a/web/src/api/query_keys.ts b/web/src/api/query_keys.ts index d8260aa..d091106 100644 --- a/web/src/api/query_keys.ts +++ b/web/src/api/query_keys.ts @@ -79,4 +79,11 @@ export const NotificationKeys = { lists: () => [...NotificationKeys.all, "list"] as const, details: () => [...NotificationKeys.all, "detail"] as const, detail: (id: number) => [...NotificationKeys.details(), id] as const -}; \ No newline at end of file +}; + +export const ProxyKeys = { + all: ["proxy"] as const, + lists: () => [...ProxyKeys.all, "list"] as const, + details: () => [...ProxyKeys.all, "detail"] as const, + detail: (id: number) => [...ProxyKeys.details(), id] as const +}; diff --git a/web/src/components/inputs/select_wide.tsx b/web/src/components/inputs/select_wide.tsx index ce42482..d6df8bb 100644 --- a/web/src/components/inputs/select_wide.tsx +++ b/web/src/components/inputs/select_wide.tsx @@ -17,6 +17,7 @@ interface SelectFieldProps { label: string; help?: string; placeholder?: string; + required?: boolean; defaultValue?: OptionBasicTyped; tooltip?: JSX.Element; options: OptionBasicTyped[]; @@ -158,7 +159,7 @@ export function SelectField({ name, label, help, placeholder, options }: Sele ); } -export function SelectFieldBasic({ name, label, help, placeholder, tooltip, defaultValue, options }: SelectFieldProps) { +export function SelectFieldBasic({ name, label, help, placeholder, required, tooltip, defaultValue, options }: SelectFieldProps) { return (
@@ -182,6 +183,7 @@ export function SelectFieldBasic({ name, label, help, placeholder, tooltip, d ({ name, label, options }: SelectFieldProps) { IndicatorSeparator: common.IndicatorSeparator, DropdownIndicator: common.DropdownIndicator }} - placeholder="Choose a type" + placeholder={placeholder ?? "Choose a type"} styles={{ singleValue: (base) => ({ ...base, @@ -487,14 +520,18 @@ function SelectField({ name, label, options }: SelectFieldProps) { })} value={field?.value && options.find(o => o.value == field?.value)} onChange={(option) => { - resetForm(); + // resetForm(); - // const opt = option as SelectOption; - // setFieldValue("name", option?.label ?? "") - setFieldValue( - field.name, - option.value ?? "" - ); + if (option !== null) { + // const opt = option as SelectOption; + // setFieldValue("name", option?.label ?? "") + setFieldValue( + field.name, + option.value ?? "" + ); + } else { + setFieldValue(field.name, undefined); + } }} options={options} /> diff --git a/web/src/forms/settings/ProxyForms.tsx b/web/src/forms/settings/ProxyForms.tsx new file mode 100644 index 0000000..c34f5d2 --- /dev/null +++ b/web/src/forms/settings/ProxyForms.tsx @@ -0,0 +1,265 @@ +import { Fragment } from "react"; +import { Form, Formik, FormikValues } from "formik"; +import { Dialog, DialogPanel, DialogTitle, Transition, TransitionChild } from "@headlessui/react"; +import { XMarkIcon } from "@heroicons/react/24/solid"; +import { useMutation, useQueryClient } from "@tanstack/react-query"; +import { toast } from "react-hot-toast"; + +import { AddProps } from "@forms/settings/IndexerForms"; +import { DEBUG } from "@components/debug.tsx"; +import { PasswordFieldWide, SwitchGroupWide, TextFieldWide } from "@components/inputs"; +import { SelectFieldBasic } from "@components/inputs/select_wide"; +import { ProxyTypeOptions } from "@domain/constants"; +import { APIClient } from "@api/APIClient"; +import { ProxyKeys } from "@api/query_keys"; +import Toast from "@components/notifications/Toast"; +import { SlideOver } from "@components/panels"; + +export function ProxyAddForm({ isOpen, toggle }: AddProps) { + const queryClient = useQueryClient(); + + const createMutation = useMutation({ + mutationFn: (req: ProxyCreate) => APIClient.proxy.store(req), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ProxyKeys.lists() }); + + toast.custom((t) => ); + toggle(); + }, + onError: () => { + toast.custom((t) => ); + } + }); + + const onSubmit = (formData: FormikValues) => { + createMutation.mutate(formData as ProxyCreate); + } + + const testMutation = useMutation({ + mutationFn: (data: Proxy) => APIClient.proxy.test(data), + onError: (err) => { + console.error(err); + } + }); + + const testProxy = (data: unknown) => testMutation.mutate(data as Proxy); + + const initialValues: ProxyCreate = { + enabled: true, + name: "Proxy", + type: "SOCKS5", + addr: "socks5://ip:port", + user: "", + pass: "", + } + + return ( + + +
+ + +
+ + {({ values }) => ( +
+
+
+
+
+ + Add proxy + +

+ Add proxy to be used with Indexers or IRC. +

+
+
+ +
+
+
+ +
+ + + + Proxy type. Commonly SOCKS5.} + help="Usually SOCKS5" + /> + + +
+ +
+ + +
+
+ +
+
+ + + +
+
+ + + + )} +
+
+ +
+
+
+
+
+ ); +} + + +interface UpdateFormProps { + isOpen: boolean; + toggle: () => void; + data: T; +} + +export function ProxyUpdateForm({ isOpen, toggle, data }: UpdateFormProps) { + const queryClient = useQueryClient(); + + const updateMutation = useMutation({ + mutationFn: (req: Proxy) => APIClient.proxy.update(req), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ProxyKeys.lists() }); + + toast.custom((t) => ); + toggle(); + }, + onError: () => { + toast.custom((t) => ); + } + }); + + const onSubmit = (formData: unknown) => { + updateMutation.mutate(formData as Proxy); + } + + const deleteMutation = useMutation({ + mutationFn: (proxyId: number) => APIClient.proxy.delete(proxyId), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ProxyKeys.lists() }); + + toast.custom((t) => ); + } + }); + + const deleteFn = () => deleteMutation.mutate(data.id); + + const testMutation = useMutation({ + mutationFn: (data: Proxy) => APIClient.proxy.test(data), + onError: (err) => { + console.error(err); + } + }); + + const testProxy = (data: unknown) => testMutation.mutate(data as Proxy); + + const initialValues: Proxy = { + id: data.id, + enabled: data.enabled, + name: data.name, + type: data.type, + addr: data.addr, + user: data.user, + pass: data.pass, + } + + return ( + + title="Proxy" + initialValues={initialValues} + onSubmit={onSubmit} + deleteAction={deleteFn} + testFn={testProxy} + isOpen={isOpen} + toggle={toggle} + type="UPDATE" + > + {() => ( +
+
+ + + + + + Proxy type. Commonly SOCKS5.} + help="Usually SOCKS5" + /> + + +
+ +
+ + +
+
+ )} + + ); +} diff --git a/web/src/routes.tsx b/web/src/routes.tsx index c5e8909..53fb526 100644 --- a/web/src/routes.tsx +++ b/web/src/routes.tsx @@ -11,7 +11,7 @@ import { notFound, Outlet, redirect, - } from "@tanstack/react-router"; +} from "@tanstack/react-router"; import { z } from "zod"; import { QueryClient } from "@tanstack/react-query"; @@ -30,7 +30,8 @@ import { FilterByIdQueryOptions, IndexersQueryOptions, IrcQueryOptions, - NotificationsQueryOptions + NotificationsQueryOptions, + ProxiesQueryOptions } from "@api/queries"; import LogSettings from "@screens/settings/Logs"; import NotificationSettings from "@screens/settings/Notifications"; @@ -50,6 +51,7 @@ import { AuthContext, SettingsContext } from "@utils/Context"; import { TanStackRouterDevtools } from "@tanstack/router-devtools"; import { ReactQueryDevtools } from "@tanstack/react-query-devtools"; import { queryClient } from "@api/QueryClient"; +import ProxySettings from "@screens/settings/Proxy"; import { ErrorPage } from "@components/alerts"; @@ -212,6 +214,13 @@ export const SettingsApiRoute = createRoute({ component: APISettings }); +export const SettingsProxiesRoute = createRoute({ + getParentRoute: () => SettingsRoute, + path: 'proxies', + loader: (opts) => opts.context.queryClient.ensureQueryData(ProxiesQueryOptions()), + component: ProxySettings +}); + export const SettingsReleasesRoute = createRoute({ getParentRoute: () => SettingsRoute, path: 'releases', @@ -339,7 +348,7 @@ export const RootRoute = createRootRouteWithContext<{ }); const filterRouteTree = FiltersRoute.addChildren([FilterIndexRoute, FilterGetByIdRoute.addChildren([FilterGeneralRoute, FilterMoviesTvRoute, FilterMusicRoute, FilterAdvancedRoute, FilterExternalRoute, FilterActionsRoute])]) -const settingsRouteTree = SettingsRoute.addChildren([SettingsIndexRoute, SettingsLogRoute, SettingsIndexersRoute, SettingsIrcRoute, SettingsFeedsRoute, SettingsClientsRoute, SettingsNotificationsRoute, SettingsApiRoute, SettingsReleasesRoute, SettingsAccountRoute]) +const settingsRouteTree = SettingsRoute.addChildren([SettingsIndexRoute, SettingsLogRoute, SettingsIndexersRoute, SettingsIrcRoute, SettingsFeedsRoute, SettingsClientsRoute, SettingsNotificationsRoute, SettingsApiRoute, SettingsProxiesRoute, SettingsReleasesRoute, SettingsAccountRoute]) const authenticatedTree = AuthRoute.addChildren([AuthIndexRoute.addChildren([DashboardRoute, filterRouteTree, ReleasesRoute, settingsRouteTree, LogsRoute])]) const routeTree = RootRoute.addChildren([ authenticatedTree, diff --git a/web/src/screens/Settings.tsx b/web/src/screens/Settings.tsx index 4252c7c..6c9d197 100644 --- a/web/src/screens/Settings.tsx +++ b/web/src/screens/Settings.tsx @@ -8,6 +8,7 @@ import { ChatBubbleLeftRightIcon, CogIcon, FolderArrowDownIcon, + GlobeAltIcon, KeyIcon, RectangleStackIcon, RssIcon, @@ -34,6 +35,7 @@ const subNavigation: NavTabType[] = [ { name: "Clients", href: "/settings/clients", icon: FolderArrowDownIcon }, { name: "Notifications", href: "/settings/notifications", icon: BellIcon }, { name: "API keys", href: "/settings/api", icon: KeyIcon }, + { name: "Proxies", href: "/settings/proxies", icon: GlobeAltIcon }, { name: "Releases", href: "/settings/releases", icon: RectangleStackIcon }, { name: "Account", href: "/settings/account", icon: UserCircleIcon } // {name: 'Regex Playground', href: 'regex-playground', icon: CogIcon, current: false} diff --git a/web/src/screens/settings/Proxy.tsx b/web/src/screens/settings/Proxy.tsx new file mode 100644 index 0000000..072065d --- /dev/null +++ b/web/src/screens/settings/Proxy.tsx @@ -0,0 +1,140 @@ +import { useToggle } from "@hooks/hooks.ts"; +import { useMutation, useQueryClient, useSuspenseQuery } from "@tanstack/react-query"; +import { PlusIcon } from "@heroicons/react/24/solid"; +import { toast } from "react-hot-toast"; + +import { APIClient } from "@api/APIClient"; +import { ProxyKeys } from "@api/query_keys"; +import { ProxiesQueryOptions } from "@api/queries"; +import { Section } from "./_components"; +import { EmptySimple } from "@components/emptystates"; +import { Checkbox } from "@components/Checkbox"; +import { ProxyAddForm, ProxyUpdateForm } from "@forms/settings/ProxyForms"; +import Toast from "@components/notifications/Toast"; + +interface ListItemProps { + proxy: Proxy; +} + +function ListItem({ proxy }: ListItemProps) { + const [isOpen, toggleUpdate] = useToggle(false); + + const queryClient = useQueryClient(); + + const updateMutation = useMutation({ + mutationFn: (req: Proxy) => APIClient.proxy.update(req), + onSuccess: () => { + queryClient.invalidateQueries({ queryKey: ProxyKeys.lists() }); + + toast.custom(t => ); + }, + onError: () => { + toast.custom((t) => ); + } + }); + + const onToggleMutation = (newState: boolean) => { + updateMutation.mutate({ + ...proxy, + enabled: newState + }); + }; + + return ( +
  • + + +
    +
    + +
    +
    + {proxy.name} +
    +
    + {proxy.type} +
    +
    + + Edit + +
    +
    +
  • + ); +} + +function ProxySettings() { + const [addProxyIsOpen, toggleAddProxy] = useToggle(false); + + const proxiesQuery = useSuspenseQuery(ProxiesQueryOptions()) + const proxies = proxiesQuery.data + + return ( +
    + Proxies that can be used with Indexers, feeds and IRC.
    + + } + rightSide={ + + } + > + + +
    + {proxies.length ? ( +
      +
    • +
      sortedIndexers.requestSort("enabled")} + > + Enabled + {/*{sortedIndexers.getSortIndicator("enabled")}*/} +
      +
      sortedIndexers.requestSort("name")} + > + Name + {/*{sortedIndexers.getSortIndicator("name")}*/} +
      +
      sortedIndexers.requestSort("implementation")} + > + Type + {/*{sortedIndexers.getSortIndicator("implementation")}*/} +
      +
    • + {proxies.map((proxy) => ( + + ))} +
    + ) : ( + + )} +
    +
    + ); +} + +export default ProxySettings; \ No newline at end of file diff --git a/web/src/screens/settings/index.ts b/web/src/screens/settings/index.ts index 1cffcf3..5653a1f 100644 --- a/web/src/screens/settings/index.ts +++ b/web/src/screens/settings/index.ts @@ -11,6 +11,7 @@ export { default as Indexer } from "./Indexer"; export { default as Irc } from "./Irc"; export { default as Logs } from "./Logs"; export { default as Notification } from "./Notifications"; +export { default as Proxy } from "./Proxy"; export { default as Release } from "./Releases"; export { default as RegexPlayground } from "./RegexPlayground"; export { default as Account } from "./Account"; diff --git a/web/src/types/Indexer.d.ts b/web/src/types/Indexer.d.ts index 1a06cb8..b4fb163 100644 --- a/web/src/types/Indexer.d.ts +++ b/web/src/types/Indexer.d.ts @@ -11,6 +11,8 @@ interface Indexer { enabled: boolean; implementation: string; base_url: string; + use_proxy?: boolean; + proxy_id?: number; settings: Array; } @@ -35,6 +37,8 @@ interface IndexerDefinition { protocol: string; urls: string[]; supports: string[]; + use_proxy?: boolean; + proxy_id?: number; settings: IndexerSetting[]; irc: IndexerIRC; torznab: IndexerTorznab; diff --git a/web/src/types/Irc.d.ts b/web/src/types/Irc.d.ts index 983625c..5aadcf0 100644 --- a/web/src/types/Irc.d.ts +++ b/web/src/types/Irc.d.ts @@ -20,6 +20,8 @@ interface IrcNetwork { channels: IrcChannel[]; connected: boolean; connected_since: string; + use_proxy: boolean; + proxy_id: number; } interface IrcNetworkCreate { @@ -53,23 +55,8 @@ interface IrcChannelWithHealth extends IrcChannel { last_announce: string; } -interface IrcNetworkWithHealth { - id: number; - name: string; - enabled: boolean; - server: string; - port: number; - tls: boolean; - pass: string; - nick: string; - auth: IrcAuth; // optional - invite_command: string; - use_bouncer: boolean; - bouncer_addr: string; - bot_mode: boolean; +interface IrcNetworkWithHealth extends IrcNetwork { channels: IrcChannelWithHealth[]; - connected: boolean; - connected_since: string; connection_errors: string[]; healthy: boolean; } diff --git a/web/src/types/Proxy.d.ts b/web/src/types/Proxy.d.ts new file mode 100644 index 0000000..ca7e434 --- /dev/null +++ b/web/src/types/Proxy.d.ts @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors. + * SPDX-License-Identifier: GPL-2.0-or-later + */ + +interface Proxy { + id: number; + name: string; + enabled: boolean; + type: ProxyType; + addr: string; + user?: string; + pass?: string; + timeout?: number; +} + +interface ProxyCreate { + name: string; + enabled: boolean; + type: ProxyType; + addr: string; + user?: string; + pass?: string; + timeout?: number; +} + +type ProxyType = "SOCKS5" | "HTTP";