feat: add support for proxies to use with IRC and Indexers (#1421)

* feat: add support for proxies

* fix(http): release handler

* fix(migrations): define proxy early

* fix(migrations): pg proxy

* fix(proxy): list update delete

* fix(proxy): remove log and imports

* feat(irc): use proxy

* feat(irc): tests

* fix(web): update imports for ProxyForms.tsx

* fix(database): migration

* feat(proxy): test

* feat(proxy): validate proxy type

* feat(proxy): validate and test

* feat(proxy): improve validate and test

* feat(proxy): fix db schema

* feat(proxy): add db tests

* feat(proxy): handle http errors

* fix(http): imports

* feat(proxy): use proxy for indexer downloads

* feat(proxy): indexerforms select proxy

* feat(proxy): handle torrent download

* feat(proxy): skip if disabled

* feat(proxy): imports

* feat(proxy): implement in Feeds

* feat(proxy): update helper text indexer proxy

* feat(proxy): add internal cache
This commit is contained in:
ze0s 2024-09-02 11:10:45 +02:00 committed by GitHub
parent 472d327308
commit bc0f4cc055
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
59 changed files with 2533 additions and 371 deletions

View file

@ -24,7 +24,9 @@ import (
"github.com/autobrr/autobrr/internal/irc" "github.com/autobrr/autobrr/internal/irc"
"github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/notification" "github.com/autobrr/autobrr/internal/notification"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/internal/release" "github.com/autobrr/autobrr/internal/release"
"github.com/autobrr/autobrr/internal/releasedownload"
"github.com/autobrr/autobrr/internal/scheduler" "github.com/autobrr/autobrr/internal/scheduler"
"github.com/autobrr/autobrr/internal/server" "github.com/autobrr/autobrr/internal/server"
"github.com/autobrr/autobrr/internal/update" "github.com/autobrr/autobrr/internal/update"
@ -103,24 +105,27 @@ func main() {
notificationRepo = database.NewNotificationRepo(log, db) notificationRepo = database.NewNotificationRepo(log, db)
releaseRepo = database.NewReleaseRepo(log, db) releaseRepo = database.NewReleaseRepo(log, db)
userRepo = database.NewUserRepo(log, db) userRepo = database.NewUserRepo(log, db)
proxyRepo = database.NewProxyRepo(log, db)
) )
// setup services // setup services
var ( var (
apiService = api.NewService(log, apikeyRepo) apiService = api.NewService(log, apikeyRepo)
notificationService = notification.NewService(log, notificationRepo)
updateService = update.NewUpdate(log, cfg.Config) updateService = update.NewUpdate(log, cfg.Config)
notificationService = notification.NewService(log, notificationRepo)
schedulingService = scheduler.NewService(log, cfg.Config, notificationService, updateService) schedulingService = scheduler.NewService(log, cfg.Config, notificationService, updateService)
indexerAPIService = indexer.NewAPIService(log) indexerAPIService = indexer.NewAPIService(log)
userService = user.NewService(userRepo) userService = user.NewService(userRepo)
authService = auth.NewService(log, userService) authService = auth.NewService(log, userService)
proxyService = proxy.NewService(log, proxyRepo)
downloadService = releasedownload.NewDownloadService(log, releaseRepo, indexerRepo, proxyService)
downloadClientService = download_client.NewService(log, downloadClientRepo) downloadClientService = download_client.NewService(log, downloadClientRepo)
actionService = action.NewService(log, actionRepo, downloadClientService, bus) actionService = action.NewService(log, actionRepo, downloadClientService, downloadService, bus)
indexerService = indexer.NewService(log, cfg.Config, indexerRepo, releaseRepo, indexerAPIService, schedulingService) indexerService = indexer.NewService(log, cfg.Config, indexerRepo, releaseRepo, indexerAPIService, schedulingService)
filterService = filter.NewService(log, filterRepo, actionService, releaseRepo, indexerAPIService, indexerService) filterService = filter.NewService(log, filterRepo, actionService, releaseRepo, indexerAPIService, indexerService, downloadService)
releaseService = release.NewService(log, releaseRepo, actionService, filterService, indexerService) releaseService = release.NewService(log, releaseRepo, actionService, filterService, indexerService)
ircService = irc.NewService(log, serverEvents, ircRepo, releaseService, indexerService, notificationService) ircService = irc.NewService(log, serverEvents, ircRepo, releaseService, indexerService, notificationService, proxyService)
feedService = feed.NewService(log, feedRepo, feedCacheRepo, releaseService, schedulingService) feedService = feed.NewService(log, feedRepo, feedCacheRepo, releaseService, proxyService, schedulingService)
) )
// register event subscribers // register event subscribers
@ -146,6 +151,7 @@ func main() {
indexerService, indexerService,
ircService, ircService,
notificationService, notificationService,
proxyService,
releaseService, releaseService,
updateService, updateService,
) )

View file

@ -140,9 +140,8 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a
return nil, nil return nil, nil
} else { } else {
if release.TorrentTmpFile == "" { if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFileCtx(ctx); err != nil { if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName) return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
return nil, err
} }
} }
@ -243,11 +242,8 @@ func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, a
return nil, nil return nil, nil
} else { } else {
if release.TorrentTmpFile == "" { if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
if err := release.DownloadTorrentFileCtx(ctx); err != nil { return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName)
return nil, err
}
} }
t, err := os.ReadFile(release.TorrentTmpFile) t, err := os.ReadFile(release.TorrentTmpFile)

View file

@ -74,10 +74,8 @@ func (s *service) porla(ctx context.Context, action *domain.Action, release doma
return nil, nil return nil, nil
} else { } else {
if release.TorrentTmpFile == "" { if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
if err := release.DownloadTorrentFileCtx(ctx); err != nil { return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
return nil, errors.Wrap(err, "error downloading torrent file for release: %s", release.TorrentName)
}
} }
file, err := os.Open(release.TorrentTmpFile) file, err := os.Open(release.TorrentTmpFile)

View file

@ -57,10 +57,8 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas
return nil, nil return nil, nil
} }
if release.TorrentTmpFile == "" { if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
if err := release.DownloadTorrentFileCtx(ctx); err != nil { return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
return nil, errors.Wrap(err, "error downloading torrent file for release: %s", release.TorrentName)
}
} }
options, err := s.prepareQbitOptions(action) options, err := s.prepareQbitOptions(action)

View file

@ -68,11 +68,8 @@ func (s *service) rtorrent(ctx context.Context, action *domain.Action, release d
return nil, nil return nil, nil
} }
if release.TorrentTmpFile == "" { if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
if err := release.DownloadTorrentFileCtx(ctx); err != nil { return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName)
return nil, err
}
} }
tmpFile, err := os.ReadFile(release.TorrentTmpFile) tmpFile, err := os.ReadFile(release.TorrentTmpFile)

View file

@ -36,9 +36,8 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release
return nil, errors.New("action %s client %s %s not enabled, skipping", action.Name, action.Client.Type, action.Client.Name) return nil, errors.New("action %s client %s %s not enabled, skipping", action.Name, action.Client.Type, action.Client.Name)
} }
// if set, try to resolve MagnetURI before parsing macros // Check preconditions: download torrent file if needed
// to allow webhook and exec to get the magnet_uri if err := s.CheckActionPreconditions(ctx, action, release); err != nil {
if err := release.ResolveMagnetUri(ctx); err != nil {
return nil, err return nil, err
} }
@ -137,6 +136,30 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release
return rejections, err return rejections, err
} }
func (s *service) CheckActionPreconditions(ctx context.Context, action *domain.Action, release *domain.Release) error {
if err := s.downloadSvc.ResolveMagnetURI(ctx, release); err != nil {
return errors.Wrap(err, "could not resolve magnet uri: %s", release.MagnetURI)
}
// parse all macros in one go
if action.CheckMacrosNeedTorrentTmpFile(release) {
if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil {
return errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
}
}
if action.CheckMacrosNeedRawDataBytes(release) {
tmpFile, err := os.ReadFile(release.TorrentTmpFile)
if err != nil {
return errors.Wrap(err, "could not read torrent file: %v", release.TorrentTmpFile)
}
release.TorrentDataRawBytes = tmpFile
}
return nil
}
func (s *service) test(name string) { func (s *service) test(name string) {
s.log.Info().Msgf("action TEST: %v", name) s.log.Info().Msgf("action TEST: %v", name)
} }

View file

@ -12,6 +12,7 @@ import (
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/download_client" "github.com/autobrr/autobrr/internal/download_client"
"github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/releasedownload"
"github.com/autobrr/autobrr/pkg/sharedhttp" "github.com/autobrr/autobrr/pkg/sharedhttp"
"github.com/asaskevich/EventBus" "github.com/asaskevich/EventBus"
@ -33,21 +34,23 @@ type Service interface {
} }
type service struct { type service struct {
log zerolog.Logger log zerolog.Logger
subLogger *log.Logger subLogger *log.Logger
repo domain.ActionRepo repo domain.ActionRepo
clientSvc download_client.Service clientSvc download_client.Service
bus EventBus.Bus downloadSvc *releasedownload.DownloadService
bus EventBus.Bus
httpClient *http.Client httpClient *http.Client
} }
func NewService(log logger.Logger, repo domain.ActionRepo, clientSvc download_client.Service, bus EventBus.Bus) Service { func NewService(log logger.Logger, repo domain.ActionRepo, clientSvc download_client.Service, downloadSvc *releasedownload.DownloadService, bus EventBus.Bus) Service {
s := &service{ s := &service{
log: log.With().Str("module", "action").Logger(), log: log.With().Str("module", "action").Logger(),
repo: repo, repo: repo,
clientSvc: clientSvc, clientSvc: clientSvc,
bus: bus, downloadSvc: downloadSvc,
bus: bus,
httpClient: &http.Client{ httpClient: &http.Client{
Timeout: time.Second * 120, Timeout: time.Second * 120,

View file

@ -107,11 +107,8 @@ func (s *service) transmission(ctx context.Context, action *domain.Action, relea
return nil, nil return nil, nil
} }
if release.TorrentTmpFile == "" { if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
if err := release.DownloadTorrentFileCtx(ctx); err != nil { return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName)
return nil, err
}
} }
b64, err := transmissionrpc.File2Base64(release.TorrentTmpFile) b64, err := transmissionrpc.File2Base64(release.TorrentTmpFile)

View file

@ -36,6 +36,8 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) {
"i.identifier", "i.identifier",
"i.identifier_external", "i.identifier_external",
"i.name", "i.name",
"i.use_proxy",
"i.proxy_id",
"f.name", "f.name",
"f.type", "f.type",
"f.enabled", "f.enabled",
@ -66,8 +68,9 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) {
var f domain.Feed var f domain.Feed
var apiKey, cookie, settings sql.NullString var apiKey, cookie, settings sql.NullString
var proxyID sql.NullInt64
if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil { if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.UseProxy, &proxyID, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound return nil, domain.ErrRecordNotFound
} }
@ -75,6 +78,7 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) {
return nil, errors.Wrap(err, "error scanning row") return nil, errors.Wrap(err, "error scanning row")
} }
f.ProxyID = proxyID.Int64
f.ApiKey = apiKey.String f.ApiKey = apiKey.String
f.Cookie = cookie.String f.Cookie = cookie.String
@ -98,6 +102,8 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string)
"i.identifier", "i.identifier",
"i.identifier_external", "i.identifier_external",
"i.name", "i.name",
"i.use_proxy",
"i.proxy_id",
"f.name", "f.name",
"f.type", "f.type",
"f.enabled", "f.enabled",
@ -128,8 +134,9 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string)
var f domain.Feed var f domain.Feed
var apiKey, cookie, settings sql.NullString var apiKey, cookie, settings sql.NullString
var proxyID sql.NullInt64
if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil { if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.UseProxy, &proxyID, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound return nil, domain.ErrRecordNotFound
} }
@ -137,6 +144,7 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string)
return nil, errors.Wrap(err, "error scanning row") return nil, errors.Wrap(err, "error scanning row")
} }
f.ProxyID = proxyID.Int64
f.ApiKey = apiKey.String f.ApiKey = apiKey.String
f.Cookie = cookie.String f.Cookie = cookie.String
@ -158,6 +166,8 @@ func (r *FeedRepo) Find(ctx context.Context) ([]domain.Feed, error) {
"i.identifier", "i.identifier",
"i.identifier_external", "i.identifier_external",
"i.name", "i.name",
"i.use_proxy",
"i.proxy_id",
"f.name", "f.name",
"f.type", "f.type",
"f.enabled", "f.enabled",
@ -196,10 +206,13 @@ func (r *FeedRepo) Find(ctx context.Context) ([]domain.Feed, error) {
var apiKey, cookie, lastRunData, settings sql.NullString var apiKey, cookie, lastRunData, settings sql.NullString
var lastRun sql.NullTime var lastRun sql.NullTime
if err := rows.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &lastRun, &lastRunData, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil { var proxyID sql.NullInt64
if err := rows.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.UseProxy, &proxyID, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &lastRun, &lastRunData, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
return nil, errors.Wrap(err, "error scanning row") return nil, errors.Wrap(err, "error scanning row")
} }
f.ProxyID = proxyID.Int64
f.LastRun = lastRun.Time f.LastRun = lastRun.Time
f.LastRunData = lastRunData.String f.LastRunData = lastRunData.String
f.ApiKey = apiKey.String f.ApiKey = apiKey.String

View file

@ -36,8 +36,8 @@ func (r *IndexerRepo) Store(ctx context.Context, indexer domain.Indexer) (*domai
} }
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Insert("indexer").Columns("enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings"). Insert("indexer").Columns("enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings").
Values(indexer.Enabled, indexer.Name, indexer.Identifier, indexer.IdentifierExternal, indexer.Implementation, indexer.BaseURL, settings). Values(indexer.Enabled, indexer.Name, indexer.Identifier, indexer.IdentifierExternal, indexer.Implementation, indexer.BaseURL, indexer.UseProxy, toNullInt64(indexer.ProxyID), settings).
Suffix("RETURNING id").RunWith(r.db.handler) Suffix("RETURNING id").RunWith(r.db.handler)
// return values // return values
@ -61,6 +61,8 @@ func (r *IndexerRepo) Update(ctx context.Context, indexer domain.Indexer) (*doma
Set("name", indexer.Name). Set("name", indexer.Name).
Set("identifier_external", indexer.IdentifierExternal). Set("identifier_external", indexer.IdentifierExternal).
Set("base_url", indexer.BaseURL). Set("base_url", indexer.BaseURL).
Set("use_proxy", indexer.UseProxy).
Set("proxy_id", toNullInt64(indexer.ProxyID)).
Set("settings", settings). Set("settings", settings).
Set("updated_at", time.Now().Format(time.RFC3339)). Set("updated_at", time.Now().Format(time.RFC3339)).
Where(sq.Eq{"id": indexer.ID}) Where(sq.Eq{"id": indexer.ID})
@ -70,16 +72,26 @@ func (r *IndexerRepo) Update(ctx context.Context, indexer domain.Indexer) (*doma
return nil, errors.Wrap(err, "error building query") return nil, errors.Wrap(err, "error building query")
} }
if _, err = r.db.handler.ExecContext(ctx, query, args...); err != nil { result, err := r.db.handler.ExecContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query") return nil, errors.Wrap(err, "error executing query")
} }
rowsAffected, err := result.RowsAffected()
if err != nil {
return nil, errors.Wrap(err, "error rows affected")
}
if rowsAffected == 0 {
return nil, domain.ErrUpdateFailed
}
return &indexer, nil return &indexer, nil
} }
func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) { func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) {
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings"). Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings").
From("indexer"). From("indexer").
OrderBy("name ASC") OrderBy("name ASC")
@ -98,27 +110,29 @@ func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) {
indexers := make([]domain.Indexer, 0) indexers := make([]domain.Indexer, 0)
for rows.Next() { for rows.Next() {
var f domain.Indexer var i domain.Indexer
var identifierExternal, implementation, baseURL sql.Null[string] var identifierExternal, implementation, baseURL sql.Null[string]
var proxyID sql.Null[int64]
var settings string var settings string
var settingsMap map[string]string var settingsMap map[string]string
if err := rows.Scan(&f.ID, &f.Enabled, &f.Name, &f.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil { if err := rows.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil {
return nil, errors.Wrap(err, "error scanning row") return nil, errors.Wrap(err, "error scanning row")
} }
f.IdentifierExternal = identifierExternal.V i.IdentifierExternal = identifierExternal.V
f.Implementation = implementation.V i.Implementation = implementation.V
f.BaseURL = baseURL.V i.BaseURL = baseURL.V
i.ProxyID = proxyID.V
if err = json.Unmarshal([]byte(settings), &settingsMap); err != nil { if err = json.Unmarshal([]byte(settings), &settingsMap); err != nil {
return nil, errors.Wrap(err, "error unmarshal settings") return nil, errors.Wrap(err, "error unmarshal settings")
} }
f.Settings = settingsMap i.Settings = settingsMap
indexers = append(indexers, f) indexers = append(indexers, i)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "error rows") return nil, errors.Wrap(err, "error rows")
@ -129,7 +143,7 @@ func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) {
func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, error) { func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, error) {
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings"). Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings").
From("indexer"). From("indexer").
Where(sq.Eq{"id": id}) Where(sq.Eq{"id": id})
@ -146,8 +160,9 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er
var i domain.Indexer var i domain.Indexer
var identifierExternal, implementation, baseURL, settings sql.Null[string] var identifierExternal, implementation, baseURL, settings sql.Null[string]
var proxyID sql.Null[int64]
if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil { if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound return nil, domain.ErrRecordNotFound
} }
@ -158,6 +173,7 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er
i.IdentifierExternal = identifierExternal.V i.IdentifierExternal = identifierExternal.V
i.Implementation = implementation.V i.Implementation = implementation.V
i.BaseURL = baseURL.V i.BaseURL = baseURL.V
i.ProxyID = proxyID.V
var settingsMap map[string]string var settingsMap map[string]string
if err = json.Unmarshal([]byte(settings.V), &settingsMap); err != nil { if err = json.Unmarshal([]byte(settings.V), &settingsMap); err != nil {
@ -171,7 +187,7 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er
func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (*domain.Indexer, error) { func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (*domain.Indexer, error) {
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings"). Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings").
From("indexer") From("indexer")
if req.ID > 0 { if req.ID > 0 {
@ -195,8 +211,9 @@ func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (
var i domain.Indexer var i domain.Indexer
var identifierExternal, implementation, baseURL, settings sql.Null[string] var identifierExternal, implementation, baseURL, settings sql.Null[string]
var proxyID sql.Null[int64]
if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil { if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound return nil, domain.ErrRecordNotFound
} }
@ -207,6 +224,7 @@ func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (
i.IdentifierExternal = identifierExternal.V i.IdentifierExternal = identifierExternal.V
i.Implementation = implementation.V i.Implementation = implementation.V
i.BaseURL = baseURL.V i.BaseURL = baseURL.V
i.ProxyID = proxyID.V
var settingsMap map[string]string var settingsMap map[string]string
if err = json.Unmarshal([]byte(settings.V), &settingsMap); err != nil { if err = json.Unmarshal([]byte(settings.V), &settingsMap); err != nil {
@ -220,7 +238,7 @@ func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (
func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Indexer, error) { func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Indexer, error) {
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "identifier", "identifier_external", "base_url", "settings"). Select("id", "enabled", "name", "identifier", "identifier_external", "base_url", "use_proxy", "proxy_id", "settings").
From("indexer"). From("indexer").
Join("filter_indexer ON indexer.id = filter_indexer.indexer_id"). Join("filter_indexer ON indexer.id = filter_indexer.indexer_id").
Where(sq.Eq{"filter_indexer.filter_id": id}) Where(sq.Eq{"filter_indexer.filter_id": id})
@ -239,13 +257,14 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde
indexers := make([]domain.Indexer, 0) indexers := make([]domain.Indexer, 0)
for rows.Next() { for rows.Next() {
var f domain.Indexer var i domain.Indexer
var settings string var settings string
var settingsMap map[string]string var settingsMap map[string]string
var identifierExternal, baseURL sql.Null[string] var identifierExternal, baseURL sql.Null[string]
var proxyID sql.Null[int64]
if err := rows.Scan(&f.ID, &f.Enabled, &f.Name, &f.Identifier, &identifierExternal, &baseURL, &settings); err != nil { if err := rows.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil {
return nil, errors.Wrap(err, "error scanning row") return nil, errors.Wrap(err, "error scanning row")
} }
@ -253,11 +272,12 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde
return nil, errors.Wrap(err, "error unmarshal settings") return nil, errors.Wrap(err, "error unmarshal settings")
} }
f.IdentifierExternal = identifierExternal.V i.IdentifierExternal = identifierExternal.V
f.BaseURL = baseURL.V i.BaseURL = baseURL.V
f.Settings = settingsMap i.ProxyID = proxyID.V
i.Settings = settingsMap
indexers = append(indexers, f) indexers = append(indexers, i)
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "error rows") return nil, errors.Wrap(err, "error rows")
@ -282,13 +302,13 @@ func (r *IndexerRepo) Delete(ctx context.Context, id int) error {
return errors.Wrap(err, "error executing query") return errors.Wrap(err, "error executing query")
} }
rows, err := result.RowsAffected() rowsAffected, err := result.RowsAffected()
if err != nil { if err != nil {
return errors.Wrap(err, "error rows affected") return errors.Wrap(err, "error rows affected")
} }
if rows != 1 { if rowsAffected == 0 {
return errors.New("error deleting row") return domain.ErrRecordNotFound
} }
r.log.Debug().Str("method", "delete").Msgf("successfully deleted indexer with id %v", id) r.log.Debug().Str("method", "delete").Msgf("successfully deleted indexer with id %v", id)
@ -297,8 +317,6 @@ func (r *IndexerRepo) Delete(ctx context.Context, id int) error {
} }
func (r *IndexerRepo) ToggleEnabled(ctx context.Context, indexerID int, enabled bool) error { func (r *IndexerRepo) ToggleEnabled(ctx context.Context, indexerID int, enabled bool) error {
var err error
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Update("indexer"). Update("indexer").
Set("enabled", enabled). Set("enabled", enabled).
@ -310,10 +328,19 @@ func (r *IndexerRepo) ToggleEnabled(ctx context.Context, indexerID int, enabled
return errors.Wrap(err, "error building query") return errors.Wrap(err, "error building query")
} }
_, err = r.db.handler.ExecContext(ctx, query, args...) result, err := r.db.handler.ExecContext(ctx, query, args...)
if err != nil { if err != nil {
return errors.Wrap(err, "error executing query") return errors.Wrap(err, "error executing query")
} }
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.Wrap(err, "error rows affected")
}
if rowsAffected == 0 {
return domain.ErrUpdateFailed
}
return nil return nil
} }

View file

@ -30,7 +30,7 @@ func NewIrcRepo(log logger.Logger, db *DB) domain.IrcRepo {
func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetwork, error) { func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetwork, error) {
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode"). Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id").
From("irc_network"). From("irc_network").
Where(sq.Eq{"id": id}) Where(sq.Eq{"id": id})
@ -42,26 +42,27 @@ func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetw
var n domain.IrcNetwork var n domain.IrcNetwork
var pass, nick, inviteCmd, bouncerAddr sql.NullString var pass, nick, inviteCmd, bouncerAddr sql.Null[string]
var account, password sql.NullString var account, password sql.Null[string]
var tls sql.NullBool var tls sql.Null[bool]
var proxyId sql.Null[int64]
row := r.db.handler.QueryRowContext(ctx, query, args...) row := r.db.handler.QueryRowContext(ctx, query, args...)
if err := row.Scan(&n.ID, &n.Enabled, &n.Name, &n.Server, &n.Port, &tls, &pass, &nick, &n.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &n.UseBouncer, &n.BotMode); err != nil { if err := row.Scan(&n.ID, &n.Enabled, &n.Name, &n.Server, &n.Port, &tls, &pass, &nick, &n.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &n.UseBouncer, &n.BotMode, &n.UseProxy, &proxyId); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound return nil, domain.ErrRecordNotFound
} }
return nil, errors.Wrap(err, "error scanning row") return nil, errors.Wrap(err, "error scanning row")
} }
n.TLS = tls.Bool n.TLS = tls.V
n.Pass = pass.String n.Pass = pass.V
n.Nick = nick.String n.Nick = nick.V
n.InviteCommand = inviteCmd.String n.InviteCommand = inviteCmd.V
n.Auth.Account = account.String n.BouncerAddr = bouncerAddr.V
n.Auth.Password = password.String n.Auth.Account = account.V
n.BouncerAddr = bouncerAddr.String n.Auth.Password = password.V
n.ProxyId = proxyId.V
return &n, nil return &n, nil
} }
@ -111,7 +112,7 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, error) { func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode"). Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id").
From("irc_network"). From("irc_network").
Where(sq.Eq{"enabled": true}) Where(sq.Eq{"enabled": true})
@ -131,22 +132,24 @@ func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork,
for rows.Next() { for rows.Next() {
var net domain.IrcNetwork var net domain.IrcNetwork
var pass, nick, inviteCmd, bouncerAddr sql.NullString var pass, nick, inviteCmd, bouncerAddr sql.Null[string]
var account, password sql.NullString var account, password sql.Null[string]
var tls sql.NullBool var tls sql.Null[bool]
var proxyId sql.Null[int64]
if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode); err != nil { if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode, &net.UseProxy, &proxyId); err != nil {
return nil, errors.Wrap(err, "error scanning row") return nil, errors.Wrap(err, "error scanning row")
} }
net.TLS = tls.Bool net.TLS = tls.V
net.Pass = pass.String net.Pass = pass.V
net.Nick = nick.String net.Nick = nick.V
net.InviteCommand = inviteCmd.String net.InviteCommand = inviteCmd.V
net.BouncerAddr = bouncerAddr.String net.BouncerAddr = bouncerAddr.V
net.Auth.Account = account.V
net.Auth.Password = password.V
net.Auth.Account = account.String net.ProxyId = proxyId.V
net.Auth.Password = password.String
networks = append(networks, net) networks = append(networks, net)
} }
@ -159,7 +162,7 @@ func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork,
func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) { func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode"). Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id").
From("irc_network"). From("irc_network").
OrderBy("name ASC") OrderBy("name ASC")
@ -179,22 +182,24 @@ func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error)
for rows.Next() { for rows.Next() {
var net domain.IrcNetwork var net domain.IrcNetwork
var pass, nick, inviteCmd, bouncerAddr sql.NullString var pass, nick, inviteCmd, bouncerAddr sql.Null[string]
var account, password sql.NullString var account, password sql.Null[string]
var tls sql.NullBool var tls sql.Null[bool]
var proxyId sql.Null[int64]
if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode); err != nil { if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode, &net.UseProxy, &proxyId); err != nil {
return nil, errors.Wrap(err, "error scanning row") return nil, errors.Wrap(err, "error scanning row")
} }
net.TLS = tls.Bool net.TLS = tls.V
net.Pass = pass.String net.Pass = pass.V
net.Nick = nick.String net.Nick = nick.V
net.InviteCommand = inviteCmd.String net.InviteCommand = inviteCmd.V
net.BouncerAddr = bouncerAddr.String net.BouncerAddr = bouncerAddr.V
net.Auth.Account = account.V
net.Auth.Password = password.V
net.Auth.Account = account.String net.ProxyId = proxyId.V
net.Auth.Password = password.String
networks = append(networks, net) networks = append(networks, net)
} }
@ -225,13 +230,13 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
var channels []domain.IrcChannel var channels []domain.IrcChannel
for rows.Next() { for rows.Next() {
var ch domain.IrcChannel var ch domain.IrcChannel
var pass sql.NullString var pass sql.Null[string]
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Enabled, &pass); err != nil { if err := rows.Scan(&ch.ID, &ch.Name, &ch.Enabled, &pass); err != nil {
return nil, errors.Wrap(err, "error scanning row") return nil, errors.Wrap(err, "error scanning row")
} }
ch.Password = pass.String ch.Password = pass.V
channels = append(channels, ch) channels = append(channels, ch)
} }
@ -244,7 +249,7 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcNetwork) (*domain.IrcNetwork, error) { func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcNetwork) (*domain.IrcNetwork, error) {
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode"). Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id").
From("irc_network"). From("irc_network").
Where(sq.Eq{"server": network.Server}). Where(sq.Eq{"server": network.Server}).
Where(sq.Eq{"port": network.Port}). Where(sq.Eq{"port": network.Port}).
@ -263,11 +268,12 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN
var net domain.IrcNetwork var net domain.IrcNetwork
var pass, nick, inviteCmd, bouncerAddr sql.NullString var pass, nick, inviteCmd, bouncerAddr sql.Null[string]
var account, password sql.NullString var account, password sql.Null[string]
var tls sql.NullBool var tls sql.Null[bool]
var proxyId sql.Null[int64]
if err = row.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode); err != nil { if err = row.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode, &net.UseProxy, &proxyId); err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
// no result is not an error in our case // no result is not an error in our case
return nil, nil return nil, nil
@ -276,29 +282,20 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN
return nil, errors.Wrap(err, "error scanning row") return nil, errors.Wrap(err, "error scanning row")
} }
net.TLS = tls.Bool net.TLS = tls.V
net.Pass = pass.String net.Pass = pass.V
net.Nick = nick.String net.Nick = nick.V
net.InviteCommand = inviteCmd.String net.InviteCommand = inviteCmd.V
net.BouncerAddr = bouncerAddr.String net.BouncerAddr = bouncerAddr.V
net.Auth.Account = account.String net.Auth.Account = account.V
net.Auth.Password = password.String net.Auth.Password = password.V
net.ProxyId = proxyId.V
return &net, nil return &net, nil
} }
func (r *IrcRepo) StoreNetwork(ctx context.Context, network *domain.IrcNetwork) error { func (r *IrcRepo) StoreNetwork(ctx context.Context, network *domain.IrcNetwork) error {
netName := toNullString(network.Name)
pass := toNullString(network.Pass)
nick := toNullString(network.Nick)
inviteCmd := toNullString(network.InviteCommand)
bouncerAddr := toNullString(network.BouncerAddr)
account := toNullString(network.Auth.Account)
password := toNullString(network.Auth.Password)
var retID int64
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Insert("irc_network"). Insert("irc_network").
Columns( Columns(
@ -319,60 +316,49 @@ func (r *IrcRepo) StoreNetwork(ctx context.Context, network *domain.IrcNetwork)
). ).
Values( Values(
network.Enabled, network.Enabled,
netName, toNullString(network.Name),
network.Server, network.Server,
network.Port, network.Port,
network.TLS, network.TLS,
pass, toNullString(network.Pass),
nick, toNullString(network.Nick),
network.Auth.Mechanism, network.Auth.Mechanism,
account, toNullString(network.Auth.Account),
password, toNullString(network.Auth.Password),
inviteCmd, toNullString(network.InviteCommand),
bouncerAddr, toNullString(network.BouncerAddr),
network.UseBouncer, network.UseBouncer,
network.BotMode, network.BotMode,
). ).
Suffix("RETURNING id"). Suffix("RETURNING id").
RunWith(r.db.handler) RunWith(r.db.handler)
if err := queryBuilder.QueryRowContext(ctx).Scan(&retID); err != nil { if err := queryBuilder.QueryRowContext(ctx).Scan(&network.ID); err != nil {
return errors.Wrap(err, "error executing query") return errors.Wrap(err, "error executing query")
} }
network.ID = retID
return nil return nil
} }
func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error { func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error {
netName := toNullString(network.Name)
pass := toNullString(network.Pass)
nick := toNullString(network.Nick)
inviteCmd := toNullString(network.InviteCommand)
bouncerAddr := toNullString(network.BouncerAddr)
account := toNullString(network.Auth.Account)
password := toNullString(network.Auth.Password)
var err error
queryBuilder := r.db.squirrel. queryBuilder := r.db.squirrel.
Update("irc_network"). Update("irc_network").
Set("enabled", network.Enabled). Set("enabled", network.Enabled).
Set("name", netName). Set("name", toNullString(network.Name)).
Set("server", network.Server). Set("server", network.Server).
Set("port", network.Port). Set("port", network.Port).
Set("tls", network.TLS). Set("tls", network.TLS).
Set("pass", pass). Set("pass", toNullString(network.Pass)).
Set("nick", nick). Set("nick", toNullString(network.Nick)).
Set("auth_mechanism", network.Auth.Mechanism). Set("auth_mechanism", network.Auth.Mechanism).
Set("auth_account", account). Set("auth_account", toNullString(network.Auth.Account)).
Set("auth_password", password). Set("auth_password", toNullString(network.Auth.Password)).
Set("invite_command", inviteCmd). Set("invite_command", toNullString(network.InviteCommand)).
Set("bouncer_addr", bouncerAddr). Set("bouncer_addr", toNullString(network.BouncerAddr)).
Set("use_bouncer", network.UseBouncer). Set("use_bouncer", network.UseBouncer).
Set("bot_mode", network.BotMode). Set("bot_mode", network.BotMode).
Set("use_proxy", network.UseProxy).
Set("proxy_id", toNullInt64(network.ProxyId)).
Set("updated_at", time.Now().Format(time.RFC3339)). Set("updated_at", time.Now().Format(time.RFC3339)).
Where(sq.Eq{"id": network.ID}) Where(sq.Eq{"id": network.ID})
@ -414,7 +400,6 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha
for _, channel := range channels { for _, channel := range channels {
// values // values
pass := toNullString(channel.Password)
channelQueryBuilder := r.db.squirrel. channelQueryBuilder := r.db.squirrel.
Insert("irc_channel"). Insert("irc_channel").
@ -429,21 +414,17 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha
channel.Enabled, channel.Enabled,
true, true,
channel.Name, channel.Name,
pass, toNullString(channel.Password),
networkID, networkID,
). ).
Suffix("RETURNING id"). Suffix("RETURNING id").
RunWith(tx) RunWith(tx)
// returning // returning
var retID int64 if err = channelQueryBuilder.QueryRowContext(ctx).Scan(&channel.ID); err != nil {
if err = channelQueryBuilder.QueryRowContext(ctx).Scan(&retID); err != nil {
return errors.Wrap(err, "error executing query storeNetworkChannels") return errors.Wrap(err, "error executing query storeNetworkChannels")
} }
channel.ID = retID
//channelQuery, channelArgs, err := channelQueryBuilder.ToSql() //channelQuery, channelArgs, err := channelQueryBuilder.ToSql()
//if err != nil { //if err != nil {
// r.log.Error().Stack().Err(err).Msg("irc.storeNetworkChannels: error building query") // r.log.Error().Stack().Err(err).Msg("irc.storeNetworkChannels: error building query")
@ -467,8 +448,6 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha
} }
func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *domain.IrcChannel) error { func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *domain.IrcChannel) error {
pass := toNullString(channel.Password)
if channel.ID != 0 { if channel.ID != 0 {
// update record // update record
channelQueryBuilder := r.db.squirrel. channelQueryBuilder := r.db.squirrel.
@ -476,7 +455,7 @@ func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *do
Set("enabled", channel.Enabled). Set("enabled", channel.Enabled).
Set("detached", channel.Detached). Set("detached", channel.Detached).
Set("name", channel.Name). Set("name", channel.Name).
Set("password", pass). Set("password", toNullString(channel.Password)).
Where(sq.Eq{"id": channel.ID}) Where(sq.Eq{"id": channel.ID})
query, args, err := channelQueryBuilder.ToSql() query, args, err := channelQueryBuilder.ToSql()
@ -501,21 +480,17 @@ func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *do
channel.Enabled, channel.Enabled,
true, true,
channel.Name, channel.Name,
pass, toNullString(channel.Password),
networkID, networkID,
). ).
Suffix("RETURNING id"). Suffix("RETURNING id").
RunWith(r.db.handler) RunWith(r.db.handler)
// returning // returning
var retID int64 if err := queryBuilder.QueryRowContext(ctx).Scan(&channel.ID); err != nil {
if err := queryBuilder.QueryRowContext(ctx).Scan(&retID); err != nil {
return errors.Wrap(err, "error executing query") return errors.Wrap(err, "error executing query")
} }
channel.ID = retID
//channelQuery, channelArgs, err := channelQueryBuilder.ToSql() //channelQuery, channelArgs, err := channelQueryBuilder.ToSql()
//if err != nil { //if err != nil {
// r.log.Error().Stack().Err(err).Msg("irc.storeChannel: error building query") // r.log.Error().Stack().Err(err).Msg("irc.storeChannel: error building query")
@ -536,15 +511,13 @@ func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *do
} }
func (r *IrcRepo) UpdateChannel(channel *domain.IrcChannel) error { func (r *IrcRepo) UpdateChannel(channel *domain.IrcChannel) error {
pass := toNullString(channel.Password)
// update record // update record
channelQueryBuilder := r.db.squirrel. channelQueryBuilder := r.db.squirrel.
Update("irc_channel"). Update("irc_channel").
Set("enabled", channel.Enabled). Set("enabled", channel.Enabled).
Set("detached", channel.Detached). Set("detached", channel.Detached).
Set("name", channel.Name). Set("name", channel.Name).
Set("password", pass). Set("password", toNullString(channel.Password)).
Where(sq.Eq{"id": channel.ID}) Where(sq.Eq{"id": channel.ID})
query, args, err := channelQueryBuilder.ToSql() query, args, err := channelQueryBuilder.ToSql()
@ -561,7 +534,6 @@ func (r *IrcRepo) UpdateChannel(channel *domain.IrcChannel) error {
} }
func (r *IrcRepo) UpdateInviteCommand(networkID int64, invite string) error { func (r *IrcRepo) UpdateInviteCommand(networkID int64, invite string) error {
// update record // update record
channelQueryBuilder := r.db.squirrel. channelQueryBuilder := r.db.squirrel.
Update("irc_network"). Update("irc_network").

View file

@ -14,6 +14,20 @@ CREATE TABLE users
UNIQUE (username) UNIQUE (username)
); );
CREATE TABLE proxy
(
id SERIAL PRIMARY KEY,
enabled BOOLEAN,
name TEXT NOT NULL,
type TEXT NOT NULL,
addr TEXT NOT NULL,
auth_user TEXT,
auth_pass TEXT,
timeout INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE indexer CREATE TABLE indexer
( (
id SERIAL PRIMARY KEY, id SERIAL PRIMARY KEY,
@ -24,8 +38,11 @@ CREATE TABLE indexer
enabled BOOLEAN, enabled BOOLEAN,
name TEXT NOT NULL, name TEXT NOT NULL,
settings TEXT, settings TEXT,
use_proxy BOOLEAN DEFAULT FALSE,
proxy_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL,
UNIQUE (identifier) UNIQUE (identifier)
); );
@ -51,8 +68,11 @@ CREATE TABLE irc_network
bot_mode BOOLEAN DEFAULT FALSE, bot_mode BOOLEAN DEFAULT FALSE,
connected BOOLEAN, connected BOOLEAN,
connected_since TIMESTAMP, connected_since TIMESTAMP,
use_proxy BOOLEAN DEFAULT FALSE,
proxy_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL,
UNIQUE (server, port, nick) UNIQUE (server, port, nick)
); );
@ -900,5 +920,39 @@ ADD COLUMN months TEXT;
ALTER TABLE filter ALTER TABLE filter
ADD COLUMN days TEXT; ADD COLUMN days TEXT;
`,
`CREATE TABLE proxy
(
id SERIAL PRIMARY KEY,
enabled BOOLEAN,
name TEXT NOT NULL,
type TEXT NOT NULL,
addr TEXT NOT NULL,
auth_user TEXT,
auth_pass TEXT,
timeout INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE indexer
ADD COLUMN proxy_id INTEGER;
ALTER TABLE indexer
ADD COLUMN use_proxy BOOLEAN DEFAULT FALSE;
ALTER TABLE indexer
ADD FOREIGN KEY (proxy_id) REFERENCES proxy
ON DELETE SET NULL;
ALTER TABLE irc_network
ADD COLUMN proxy_id INTEGER;
ALTER TABLE irc_network
ADD COLUMN use_proxy BOOLEAN DEFAULT FALSE;
ALTER TABLE irc_network
ADD FOREIGN KEY (proxy_id) REFERENCES proxy
ON DELETE SET NULL;
`, `,
} }

265
internal/database/proxy.go Normal file
View file

@ -0,0 +1,265 @@
// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
package database
import (
"context"
"database/sql"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
sq "github.com/Masterminds/squirrel"
"github.com/rs/zerolog"
)
type ProxyRepo struct {
log zerolog.Logger
db *DB
}
func NewProxyRepo(log logger.Logger, db *DB) domain.ProxyRepo {
return &ProxyRepo{
log: log.With().Str("repo", "proxy").Logger(),
db: db,
}
}
func (r *ProxyRepo) Store(ctx context.Context, p *domain.Proxy) error {
queryBuilder := r.db.squirrel.
Insert("proxy").
Columns(
"enabled",
"name",
"type",
"addr",
"auth_user",
"auth_pass",
"timeout",
).
Values(
p.Enabled,
p.Name,
p.Type,
toNullString(p.Addr),
toNullString(p.User),
toNullString(p.Pass),
p.Timeout,
).
Suffix("RETURNING id").
RunWith(r.db.handler)
var retID int64
err := queryBuilder.QueryRowContext(ctx).Scan(&retID)
if err != nil {
return errors.Wrap(err, "error executing query")
}
p.ID = retID
return nil
}
func (r *ProxyRepo) Update(ctx context.Context, p *domain.Proxy) error {
queryBuilder := r.db.squirrel.
Update("proxy").
Set("enabled", p.Enabled).
Set("name", p.Name).
Set("type", p.Type).
Set("addr", p.Addr).
Set("auth_user", toNullString(p.User)).
Set("auth_pass", toNullString(p.Pass)).
Set("timeout", p.Timeout).
Set("updated_at", time.Now().Format(time.RFC3339)).
Where(sq.Eq{"id": p.ID})
query, args, err := queryBuilder.ToSql()
if err != nil {
return errors.Wrap(err, "error building query")
}
// update record
res, err := r.db.handler.ExecContext(ctx, query, args...)
if err != nil {
return errors.Wrap(err, "error executing query")
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return errors.Wrap(err, "error getting affected rows")
}
if rowsAffected == 0 {
return domain.ErrUpdateFailed
}
return err
}
func (r *ProxyRepo) List(ctx context.Context) ([]domain.Proxy, error) {
queryBuilder := r.db.squirrel.
Select(
"id",
"enabled",
"name",
"type",
"addr",
"auth_user",
"auth_pass",
"timeout",
).
From("proxy").
OrderBy("name ASC")
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, errors.Wrap(err, "error building query")
}
rows, err := r.db.handler.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
defer rows.Close()
proxies := make([]domain.Proxy, 0)
for rows.Next() {
var proxy domain.Proxy
var user, pass sql.NullString
if err := rows.Scan(&proxy.ID, &proxy.Enabled, &proxy.Name, &proxy.Type, &proxy.Addr, &user, &pass, &proxy.Timeout); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
proxy.User = user.String
proxy.Pass = pass.String
proxies = append(proxies, proxy)
}
err = rows.Err()
if err != nil {
return nil, errors.Wrap(err, "error row")
}
return proxies, nil
}
func (r *ProxyRepo) Delete(ctx context.Context, id int64) error {
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return errors.Wrap(err, "error begin transaction")
}
defer tx.Rollback()
queryBuilder := r.db.squirrel.
Delete("proxy").
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
return errors.Wrap(err, "error building query")
}
res, err := tx.ExecContext(ctx, query, args...)
if err != nil {
return errors.Wrap(err, "error executing query")
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return errors.Wrap(err, "error getting affected rows")
}
if rowsAffected == 0 {
return domain.ErrDeleteFailed
}
err = tx.Commit()
if err != nil {
return errors.Wrap(err, "error commit deleting proxy")
}
return nil
}
func (r *ProxyRepo) FindByID(ctx context.Context, id int64) (*domain.Proxy, error) {
queryBuilder := r.db.squirrel.
Select(
"id",
"enabled",
"name",
"type",
"addr",
"auth_user",
"auth_pass",
"timeout",
).
From("proxy").
OrderBy("name ASC").
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, errors.Wrap(err, "error building query")
}
row := r.db.handler.QueryRowContext(ctx, query, args...)
if err := row.Err(); err != nil {
return nil, errors.Wrap(err, "error executing query")
}
var proxy domain.Proxy
var user, pass sql.NullString
err = row.Scan(&proxy.ID, &proxy.Enabled, &proxy.Name, &proxy.Type, &proxy.Addr, &user, &pass, &proxy.Timeout)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}
proxy.User = user.String
proxy.Pass = pass.String
return &proxy, nil
}
func (r *ProxyRepo) ToggleEnabled(ctx context.Context, id int64, enabled bool) error {
queryBuilder := r.db.squirrel.
Update("proxy").
Set("enabled", enabled).
Set("updated_at", time.Now().Format(time.RFC3339)).
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
return errors.Wrap(err, "error building query")
}
// update record
res, err := r.db.handler.ExecContext(ctx, query, args...)
if err != nil {
return errors.Wrap(err, "error executing query")
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return errors.Wrap(err, "error getting affected rows")
}
if rowsAffected == 0 {
return domain.ErrUpdateFailed
}
return nil
}

View file

@ -0,0 +1,220 @@
// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (
"context"
"fmt"
"testing"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/stretchr/testify/assert"
)
func getMockProxy() *domain.Proxy {
return &domain.Proxy{
//ID: 0,
Name: "Proxy",
Enabled: true,
Type: domain.ProxyTypeSocks5,
Addr: "socks5://127.0.0.1:1080",
User: "",
Pass: "",
Timeout: 0,
}
}
func TestProxyRepo_Store(t *testing.T) {
for dbType, db := range testDBs {
log := setupLoggerForTest()
repo := NewProxyRepo(log, db)
mockData := getMockProxy()
t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotNil(t, proxies)
assert.Equal(t, mockData.Name, proxies[0].Name)
// Cleanup
_ = repo.Delete(context.Background(), mockData.ID)
})
t.Run(fmt.Sprintf("Store_Fails_With_Missing_or_empty_fields [%s]", dbType), func(t *testing.T) {
mockData := domain.Proxy{}
err := repo.Store(context.Background(), &mockData)
assert.Error(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.Empty(t, proxies)
//assert.Nil(t, proxies)
// Cleanup
// No cleanup needed
})
t.Run(fmt.Sprintf("Store_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()
err := repo.Store(ctx, mockData)
assert.Error(t, err)
})
}
}
func TestProxyRepo_Update(t *testing.T) {
for dbType, db := range testDBs {
log := setupLoggerForTest()
repo := NewProxyRepo(log, db)
mockData := getMockProxy()
t.Run(fmt.Sprintf("Update_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
// Update mockData
updatedProxy := mockData
updatedProxy.Name = "Updated Proxy"
updatedProxy.Enabled = false
// Execute
err = repo.Update(context.Background(), updatedProxy)
assert.NoError(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotNil(t, proxies)
assert.Equal(t, "Updated Proxy", proxies[0].Name)
assert.Equal(t, false, proxies[0].Enabled)
// Cleanup
_ = repo.Delete(context.Background(), proxies[0].ID)
})
t.Run(fmt.Sprintf("Update_Fails_Invalid_ID [%s]", dbType), func(t *testing.T) {
mockData.ID = -1
err := repo.Update(context.Background(), mockData)
assert.Error(t, err)
assert.ErrorIs(t, err, domain.ErrUpdateFailed)
})
}
}
func TestProxyRepo_Delete(t *testing.T) {
for dbType, db := range testDBs {
log := setupLoggerForTest()
repo := NewProxyRepo(log, db)
mockData := getMockProxy()
t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotNil(t, proxies)
assert.Equal(t, mockData.Name, proxies[0].Name)
// Execute
err = repo.Delete(context.Background(), proxies[0].ID)
assert.NoError(t, err)
// Verify that the proxy is deleted and return error ErrRecordNotFound
proxy, err := repo.FindByID(context.Background(), proxies[0].ID)
assert.ErrorIs(t, err, domain.ErrRecordNotFound)
assert.Nil(t, proxy)
})
t.Run(fmt.Sprintf("Delete_Fails_No_Record [%s]", dbType), func(t *testing.T) {
err := repo.Delete(context.Background(), 9999)
assert.Error(t, err)
assert.ErrorIs(t, err, domain.ErrDeleteFailed)
})
}
}
func TestProxyRepo_ToggleEnabled(t *testing.T) {
for dbType, db := range testDBs {
log := setupLoggerForTest()
repo := NewProxyRepo(log, db)
mockData := getMockProxy()
t.Run(fmt.Sprintf("ToggleEnabled_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotNil(t, proxies)
assert.Equal(t, true, proxies[0].Enabled)
// Execute
err = repo.ToggleEnabled(context.Background(), mockData.ID, false)
assert.NoError(t, err)
// Verify that the proxy is updated
proxy, err := repo.FindByID(context.Background(), proxies[0].ID)
assert.NoError(t, err)
assert.NotNil(t, proxy)
assert.Equal(t, false, proxy.Enabled)
// Cleanup
_ = repo.Delete(context.Background(), proxies[0].ID)
})
t.Run(fmt.Sprintf("ToggleEnabled_Fails_Invalid_ID [%s]", dbType), func(t *testing.T) {
err := repo.ToggleEnabled(context.Background(), -1, false)
assert.Error(t, err)
assert.ErrorIs(t, err, domain.ErrUpdateFailed)
})
}
}
func TestProxyRepo_FindByID(t *testing.T) {
for dbType, db := range testDBs {
log := setupLoggerForTest()
repo := NewProxyRepo(log, db)
mockData := getMockProxy()
t.Run(fmt.Sprintf("FindByID_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotNil(t, proxies)
// Execute
proxy, err := repo.FindByID(context.Background(), proxies[0].ID)
assert.NoError(t, err)
assert.NotNil(t, proxy)
assert.Equal(t, proxies[0].ID, proxy.ID)
// Cleanup
_ = repo.Delete(context.Background(), proxies[0].ID)
})
t.Run(fmt.Sprintf("FindByID_Fails_Invalid_ID [%s]", dbType), func(t *testing.T) {
// Test using an invalid ID
proxy, err := repo.FindByID(context.Background(), -1)
assert.ErrorIs(t, err, domain.ErrRecordNotFound) // should return an error
assert.Nil(t, proxy) // should be nil
})
}
}

View file

@ -14,6 +14,20 @@ CREATE TABLE users
UNIQUE (username) UNIQUE (username)
); );
CREATE TABLE proxy
(
id INTEGER PRIMARY KEY,
enabled BOOLEAN,
name TEXT NOT NULL,
type TEXT NOT NULL,
addr TEXT NOT NULL,
auth_user TEXT,
auth_pass TEXT,
timeout INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE indexer CREATE TABLE indexer
( (
id INTEGER PRIMARY KEY, id INTEGER PRIMARY KEY,
@ -24,8 +38,11 @@ CREATE TABLE indexer
enabled BOOLEAN, enabled BOOLEAN,
name TEXT NOT NULL, name TEXT NOT NULL,
settings TEXT, settings TEXT,
use_proxy BOOLEAN DEFAULT FALSE,
proxy_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL,
UNIQUE (identifier) UNIQUE (identifier)
); );
@ -51,8 +68,11 @@ CREATE TABLE irc_network
bot_mode BOOLEAN DEFAULT FALSE, bot_mode BOOLEAN DEFAULT FALSE,
connected BOOLEAN, connected BOOLEAN,
connected_since TIMESTAMP, connected_since TIMESTAMP,
use_proxy BOOLEAN DEFAULT FALSE,
proxy_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL,
UNIQUE (server, port, nick) UNIQUE (server, port, nick)
); );
@ -1538,5 +1558,37 @@ ADD COLUMN months TEXT;
ALTER TABLE filter ALTER TABLE filter
ADD COLUMN days TEXT; ADD COLUMN days TEXT;
`,
`CREATE TABLE proxy
(
id INTEGER PRIMARY KEY,
enabled BOOLEAN,
name TEXT NOT NULL,
type TEXT NOT NULL,
addr TEXT NOT NULL,
auth_user TEXT,
auth_pass TEXT,
timeout INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE indexer
ADD proxy_id INTEGER
CONSTRAINT indexer_proxy_id_fk
REFERENCES proxy(id)
ON DELETE SET NULL;
ALTER TABLE indexer
ADD use_proxy BOOLEAN DEFAULT FALSE;
ALTER TABLE irc_network
ADD use_proxy BOOLEAN DEFAULT FALSE;
ALTER TABLE irc_network
ADD proxy_id INTEGER
CONSTRAINT irc_network_proxy_id_fk
REFERENCES proxy(id)
ON DELETE SET NULL;
`, `,
} }

View file

@ -16,10 +16,10 @@ func dataSourceName(configPath string, name string) string {
return name return name
} }
func toNullString(s string) sql.NullString { func toNullString(s string) sql.Null[string] {
return sql.NullString{ return sql.Null[string]{
String: s, V: s,
Valid: s != "", Valid: s != "",
} }
} }

View file

@ -5,7 +5,6 @@ package domain
import ( import (
"context" "context"
"os"
"strings" "strings"
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
@ -60,43 +59,69 @@ type Action struct {
Client *DownloadClient `json:"client,omitempty"` Client *DownloadClient `json:"client,omitempty"`
} }
// ParseMacros parse all macros on action // CheckMacrosNeedTorrentTmpFile check if macros needs torrent downloaded
func (a *Action) ParseMacros(release *Release) error { func (a *Action) CheckMacrosNeedTorrentTmpFile(release *Release) bool {
var err error
if release.TorrentTmpFile == "" && if release.TorrentTmpFile == "" &&
(strings.Contains(a.ExecArgs, "TorrentPathName") || strings.Contains(a.ExecArgs, "TorrentDataRawBytes") || (strings.Contains(a.ExecArgs, "TorrentPathName") ||
strings.Contains(a.WebhookData, "TorrentPathName") || strings.Contains(a.WebhookData, "TorrentDataRawBytes") || strings.Contains(a.ExecArgs, "TorrentDataRawBytes") ||
strings.Contains(a.SavePath, "TorrentPathName") || a.Type == ActionTypeWatchFolder) { strings.Contains(a.WebhookData, "TorrentPathName") ||
if err := release.DownloadTorrentFile(); err != nil { strings.Contains(a.WebhookData, "TorrentDataRawBytes") ||
return errors.Wrap(err, "webhook: could not download torrent file for release: %v", release.TorrentName) strings.Contains(a.SavePath, "TorrentPathName") ||
} a.Type == ActionTypeWatchFolder) {
return true
} }
return false
}
func (a *Action) CheckMacrosNeedRawDataBytes(release *Release) bool {
// if webhook data contains TorrentDataRawBytes, lets read the file into bytes we can then use in the macro // if webhook data contains TorrentDataRawBytes, lets read the file into bytes we can then use in the macro
if len(release.TorrentDataRawBytes) == 0 && if len(release.TorrentDataRawBytes) == 0 &&
(strings.Contains(a.ExecArgs, "TorrentDataRawBytes") || strings.Contains(a.WebhookData, "TorrentDataRawBytes") || (strings.Contains(a.ExecArgs, "TorrentDataRawBytes") || strings.Contains(a.WebhookData, "TorrentDataRawBytes") ||
a.Type == ActionTypeWatchFolder) { a.Type == ActionTypeWatchFolder) {
t, err := os.ReadFile(release.TorrentTmpFile) return true
if err != nil {
return errors.Wrap(err, "could not read torrent file: %v", release.TorrentTmpFile)
}
release.TorrentDataRawBytes = t
} }
return false
}
// ParseMacros parse all macros on action
func (a *Action) ParseMacros(release *Release) error {
var err error
m := NewMacro(*release) m := NewMacro(*release)
a.ExecArgs, err = m.Parse(a.ExecArgs) a.ExecArgs, err = m.Parse(a.ExecArgs)
a.WatchFolder, err = m.Parse(a.WatchFolder)
a.Category, err = m.Parse(a.Category)
a.Tags, err = m.Parse(a.Tags)
a.Label, err = m.Parse(a.Label)
a.SavePath, err = m.Parse(a.SavePath)
a.WebhookData, err = m.Parse(a.WebhookData)
if err != nil { if err != nil {
return errors.Wrap(err, "could not parse macros for action: %v", a.Name) return errors.Wrap(err, "could not parse exec args")
}
a.WatchFolder, err = m.Parse(a.WatchFolder)
if err != nil {
return errors.Wrap(err, "could not parse watch folder")
}
a.Category, err = m.Parse(a.Category)
if err != nil {
return errors.Wrap(err, "could not parse category")
}
a.Tags, err = m.Parse(a.Tags)
if err != nil {
return errors.Wrap(err, "could not parse tags")
}
a.Label, err = m.Parse(a.Label)
if err != nil {
return errors.Wrap(err, "could not parse label")
}
a.SavePath, err = m.Parse(a.SavePath)
if err != nil {
return errors.Wrap(err, "could not parse save_path")
}
a.WebhookData, err = m.Parse(a.WebhookData)
if err != nil {
return errors.Wrap(err, "could not parse webhook_data")
} }
return nil return nil

View file

@ -3,8 +3,14 @@
package domain package domain
import "database/sql" import (
"database/sql"
"github.com/autobrr/autobrr/pkg/errors"
)
var ( var (
ErrRecordNotFound = sql.ErrNoRows ErrRecordNotFound = sql.ErrNoRows
ErrUpdateFailed = errors.New("update failed")
ErrDeleteFailed = errors.New("delete failed")
) )

View file

@ -53,6 +53,11 @@ type Feed struct {
LastRun time.Time `json:"last_run"` LastRun time.Time `json:"last_run"`
LastRunData string `json:"last_run_data"` LastRunData string `json:"last_run_data"`
NextRun time.Time `json:"next_run"` NextRun time.Time `json:"next_run"`
// belongs to Indexer
ProxyID int64
UseProxy bool
Proxy *Proxy
} }
type FeedSettingsJSON struct { type FeedSettingsJSON struct {

View file

@ -34,6 +34,9 @@ type Indexer struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Implementation string `json:"implementation"` Implementation string `json:"implementation"`
BaseURL string `json:"base_url,omitempty"` BaseURL string `json:"base_url,omitempty"`
UseProxy bool `json:"use_proxy"`
Proxy *Proxy `json:"proxy"`
ProxyID int64 `json:"proxy_id"`
Settings map[string]string `json:"settings,omitempty"` Settings map[string]string `json:"settings,omitempty"`
} }
@ -66,6 +69,8 @@ type IndexerDefinition struct {
Protocol string `json:"protocol"` Protocol string `json:"protocol"`
URLS []string `json:"urls"` URLS []string `json:"urls"`
Supports []string `json:"supports"` Supports []string `json:"supports"`
UseProxy bool `json:"use_proxy"`
ProxyID int64 `json:"proxy_id"`
Settings []IndexerSetting `json:"settings,omitempty"` Settings []IndexerSetting `json:"settings,omitempty"`
SettingsMap map[string]string `json:"-"` SettingsMap map[string]string `json:"-"`
IRC *IndexerIRC `json:"irc,omitempty"` IRC *IndexerIRC `json:"irc,omitempty"`

View file

@ -48,6 +48,9 @@ type IrcNetwork struct {
InviteCommand string `json:"invite_command"` InviteCommand string `json:"invite_command"`
UseBouncer bool `json:"use_bouncer"` UseBouncer bool `json:"use_bouncer"`
BouncerAddr string `json:"bouncer_addr"` BouncerAddr string `json:"bouncer_addr"`
UseProxy bool `json:"use_proxy"`
ProxyId int64 `json:"proxy_id"`
Proxy *Proxy `json:"proxy"`
BotMode bool `json:"bot_mode"` BotMode bool `json:"bot_mode"`
Channels []IrcChannel `json:"channels"` Channels []IrcChannel `json:"channels"`
Connected bool `json:"connected"` Connected bool `json:"connected"`
@ -70,6 +73,9 @@ type IrcNetworkWithHealth struct {
BotMode bool `json:"bot_mode"` BotMode bool `json:"bot_mode"`
CurrentNick string `json:"current_nick"` CurrentNick string `json:"current_nick"`
PreferredNick string `json:"preferred_nick"` PreferredNick string `json:"preferred_nick"`
UseProxy bool `json:"use_proxy"`
ProxyId int64 `json:"proxy_id"`
Proxy *Proxy `json:"proxy"`
Channels []ChannelWithHealth `json:"channels"` Channels []ChannelWithHealth `json:"channels"`
Connected bool `json:"connected"` Connected bool `json:"connected"`
ConnectedSince time.Time `json:"connected_since"` ConnectedSince time.Time `json:"connected_since"`

View file

@ -163,6 +163,8 @@ func (m Macro) Parse(text string) (string, error) {
return "", nil return "", nil
} }
// TODO implement template cache
// setup template // setup template
tmpl, err := template.New("macro").Funcs(sprig.TxtFuncMap()).Parse(text) tmpl, err := template.New("macro").Funcs(sprig.TxtFuncMap()).Parse(text)
if err != nil { if err != nil {

78
internal/domain/proxy.go Normal file
View file

@ -0,0 +1,78 @@
// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
package domain
import (
"context"
"net/url"
"github.com/autobrr/autobrr/pkg/errors"
)
type ProxyRepo interface {
Store(ctx context.Context, p *Proxy) error
Update(ctx context.Context, p *Proxy) error
List(ctx context.Context) ([]Proxy, error)
Delete(ctx context.Context, id int64) error
FindByID(ctx context.Context, id int64) (*Proxy, error)
ToggleEnabled(ctx context.Context, id int64, enabled bool) error
}
type Proxy struct {
ID int64 `json:"id"`
Name string `json:"name"`
Enabled bool `json:"enabled"`
Type ProxyType `json:"type"`
Addr string `json:"addr"`
User string `json:"user"`
Pass string `json:"pass"`
Timeout int `json:"timeout"`
}
type ProxyType string
const (
ProxyTypeSocks5 = "SOCKS5"
)
func (p Proxy) ValidProxyType() bool {
if p.Type == ProxyTypeSocks5 {
return true
}
return false
}
func (p Proxy) Validate() error {
if !p.ValidProxyType() {
return errors.New("invalid proxy type: %s", p.Type)
}
if err := ValidateProxyAddr(p.Addr); err != nil {
return err
}
if p.Name == "" {
return errors.New("name is required")
}
return nil
}
func ValidateProxyAddr(addr string) error {
if addr == "" {
return errors.New("addr is required")
}
proxyUrl, err := url.Parse(addr)
if err != nil {
return errors.Wrap(err, "could not parse proxy url: %s", addr)
}
if proxyUrl.Scheme != "socks5" && proxyUrl.Scheme != "socks5h" {
return errors.New("proxy url scheme must be socks5 or socks5h")
}
return nil
}

View file

@ -356,8 +356,6 @@ func (r *Release) ParseString(title string) {
r.ParseReleaseTagsString(r.ReleaseTags) r.ParseReleaseTagsString(r.ReleaseTags)
} }
var ErrUnrecoverableError = errors.New("unrecoverable error")
func (r *Release) ParseReleaseTagsString(tags string) { func (r *Release) ParseReleaseTagsString(tags string) {
cleanTags := CleanReleaseTags(tags) cleanTags := CleanReleaseTags(tags)
t := ParseReleaseTagString(cleanTags) t := ParseReleaseTagString(cleanTags)
@ -432,10 +430,6 @@ func (r *Release) DownloadTorrentFileCtx(ctx context.Context) error {
return r.downloadTorrentFile(ctx) return r.downloadTorrentFile(ctx)
} }
func (r *Release) DownloadTorrentFile() error {
return r.downloadTorrentFile(context.Background())
}
func (r *Release) downloadTorrentFile(ctx context.Context) error { func (r *Release) downloadTorrentFile(ctx context.Context) error {
if r.HasMagnetUri() { if r.HasMagnetUri() {
return errors.New("downloading magnet links is not supported: %s", r.MagnetURI) return errors.New("downloading magnet links is not supported: %s", r.MagnetURI)
@ -592,7 +586,7 @@ func (r *Release) downloadTorrentFile(ctx context.Context) error {
} }
func (r *Release) CleanupTemporaryFiles() { func (r *Release) CleanupTemporaryFiles() {
if len(r.TorrentTmpFile) == 0 { if r.TorrentTmpFile == "" {
return return
} }
@ -600,54 +594,15 @@ func (r *Release) CleanupTemporaryFiles() {
r.TorrentTmpFile = "" r.TorrentTmpFile = ""
} }
// HasMagnetUri check uf MagnetURI is set or empty // HasMagnetUri check uf MagnetURI is set and valid or empty
func (r *Release) HasMagnetUri() bool { func (r *Release) HasMagnetUri() bool {
return r.MagnetURI != "" if r.MagnetURI != "" && strings.HasPrefix(r.MagnetURI, MagnetURIPrefix) {
return true
}
return false
} }
func (r *Release) ResolveMagnetUri(ctx context.Context) error { const MagnetURIPrefix = "magnet:?"
if r.MagnetURI == "" {
return nil
} else if strings.HasPrefix(r.MagnetURI, "magnet:?") {
return nil
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.MagnetURI, nil)
if err != nil {
return errors.Wrap(err, "could not build request to resolve magnet uri")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "autobrr")
client := &http.Client{
Timeout: time.Second * 45,
Transport: sharedhttp.MagnetTransport,
}
res, err := client.Do(req)
if err != nil {
return errors.Wrap(err, "could not make request to resolve magnet uri")
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return errors.New("unexpected status code: %d", res.StatusCode)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return errors.Wrap(err, "could not read response body")
}
magnet := string(body)
if magnet != "" {
r.MagnetURI = magnet
}
return nil
}
func (r *Release) addRejection(reason string) { func (r *Release) addRejection(reason string) {
r.Rejections = append(r.Rejections, reason) r.Rejections = append(r.Rejections, reason)

View file

@ -6,6 +6,7 @@
package domain package domain
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -290,7 +291,7 @@ func TestRelease_DownloadTorrentFile(t *testing.T) {
Filter: tt.fields.Filter, Filter: tt.fields.Filter,
ActionStatus: tt.fields.ActionStatus, ActionStatus: tt.fields.ActionStatus,
} }
err := r.DownloadTorrentFile() err := r.DownloadTorrentFileCtx(context.Background())
if err == nil && tt.wantErr { if err == nil && tt.wantErr {
fmt.Println("error") fmt.Println("error")
} }

View file

@ -42,10 +42,23 @@ func NewFeedParser(timeout time.Duration, cookie string) *RSSParser {
} }
c.http.Timeout = timeout c.http.Timeout = timeout
c.parser.Client = httpClient
return c return c
} }
func (c *RSSParser) WithHTTPClient(client *http.Client) {
httpClient := client
if client.Jar == nil {
jarOptions := &cookiejar.Options{PublicSuffixList: publicsuffix.List}
jar, _ := cookiejar.New(jarOptions)
httpClient.Jar = jar
}
c.http = httpClient
c.parser.Client = httpClient
}
func (c *RSSParser) ParseURLWithContext(ctx context.Context, feedURL string) (feed *gofeed.Feed, err error) { func (c *RSSParser) ParseURLWithContext(ctx context.Context, feedURL string) (feed *gofeed.Feed, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, feedURL, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, feedURL, nil)
if err != nil { if err != nil {

View file

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/internal/release" "github.com/autobrr/autobrr/internal/release"
"github.com/autobrr/autobrr/internal/scheduler" "github.com/autobrr/autobrr/internal/scheduler"
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
@ -129,6 +130,18 @@ func (j *NewznabJob) process(ctx context.Context) error {
} }
func (j *NewznabJob) getFeed(ctx context.Context) ([]newznab.FeedItem, error) { func (j *NewznabJob) getFeed(ctx context.Context) ([]newznab.FeedItem, error) {
// add proxy if enabled and exists
if j.Feed.UseProxy && j.Feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(j.Feed.Proxy)
if err != nil {
return nil, errors.Wrap(err, "could not get proxy client")
}
j.Client.WithHTTPClient(proxyClient)
j.Log.Debug().Msgf("using proxy %s for feed %s", j.Feed.Proxy.Name, j.Feed.Name)
}
// get feed // get feed
feed, err := j.Client.GetFeed(ctx) feed, err := j.Client.GetFeed(ctx)
if err != nil { if err != nil {
@ -156,36 +169,34 @@ func (j *NewznabJob) getFeed(ctx context.Context) ([]newznab.FeedItem, error) {
// set ttl to 1 month // set ttl to 1 month
ttl := time.Now().AddDate(0, 1, 0) ttl := time.Now().AddDate(0, 1, 0)
for _, i := range feed.Channel.Items { for _, item := range feed.Channel.Items {
i := i if item.GUID == "" {
if i.GUID == "" {
j.Log.Error().Msgf("missing GUID from feed: %s", j.Feed.Name) j.Log.Error().Msgf("missing GUID from feed: %s", j.Feed.Name)
continue continue
} }
exists, err := j.CacheRepo.Exists(j.Feed.ID, i.GUID) exists, err := j.CacheRepo.Exists(j.Feed.ID, item.GUID)
if err != nil { if err != nil {
j.Log.Error().Err(err).Msg("could not check if item exists") j.Log.Error().Err(err).Msg("could not check if item exists")
continue continue
} }
if exists { if exists {
j.Log.Trace().Msgf("cache item exists, skipping release: %s", i.Title) j.Log.Trace().Msgf("cache item exists, skipping release: %s", item.Title)
continue continue
} }
j.Log.Debug().Msgf("found new release: %s", i.Title) j.Log.Debug().Msgf("found new release: %s", item.Title)
toCache = append(toCache, domain.FeedCacheItem{ toCache = append(toCache, domain.FeedCacheItem{
FeedId: strconv.Itoa(j.Feed.ID), FeedId: strconv.Itoa(j.Feed.ID),
Key: i.GUID, Key: item.GUID,
Value: []byte(i.Title), Value: []byte(item.Title),
TTL: ttl, TTL: ttl,
}) })
// only append if we successfully added to cache // only append if we successfully added to cache
items = append(items, *i) items = append(items, *item)
} }
if len(toCache) > 0 { if len(toCache) > 0 {

View file

@ -13,6 +13,7 @@ import (
"time" "time"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/internal/release" "github.com/autobrr/autobrr/internal/release"
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
@ -93,7 +94,6 @@ func (j *RSSJob) process(ctx context.Context) error {
releases := make([]*domain.Release, 0) releases := make([]*domain.Release, 0)
for _, item := range items { for _, item := range items {
item := item
j.Log.Debug().Msgf("item: %v", item.Title) j.Log.Debug().Msgf("item: %v", item.Title)
rls := j.processItem(item) rls := j.processItem(item)
@ -139,7 +139,7 @@ func (j *RSSJob) processItem(item *gofeed.Item) *domain.Release {
} }
if j.Feed.Settings != nil && j.Feed.Settings.DownloadType == domain.FeedDownloadTypeMagnet { if j.Feed.Settings != nil && j.Feed.Settings.DownloadType == domain.FeedDownloadTypeMagnet {
if !strings.HasPrefix(rls.MagnetURI, "magnet:?") && strings.HasPrefix(e.URL, "magnet:?") { if !strings.HasPrefix(rls.MagnetURI, domain.MagnetURIPrefix) && strings.HasPrefix(e.URL, domain.MagnetURIPrefix) {
rls.MagnetURI = e.URL rls.MagnetURI = e.URL
rls.DownloadURL = "" rls.DownloadURL = ""
} }
@ -232,7 +232,20 @@ func (j *RSSJob) getFeed(ctx context.Context) (items []*gofeed.Item, err error)
ctx, cancel := context.WithTimeout(ctx, j.Timeout) ctx, cancel := context.WithTimeout(ctx, j.Timeout)
defer cancel() defer cancel()
feed, err := NewFeedParser(j.Timeout, j.Feed.Cookie).ParseURLWithContext(ctx, j.URL) feedParser := NewFeedParser(j.Timeout, j.Feed.Cookie)
if j.Feed.UseProxy && j.Feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(j.Feed.Proxy)
if err != nil {
return nil, errors.Wrap(err, "could not get proxy client")
}
feedParser.WithHTTPClient(proxyClient)
j.Log.Debug().Msgf("using proxy %s for feed %s", j.Feed.Proxy.Name, j.Feed.Name)
}
feed, err := feedParser.ParseURLWithContext(ctx, j.URL)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error fetching rss feed items") return nil, errors.Wrap(err, "error fetching rss feed items")
} }
@ -257,9 +270,7 @@ func (j *RSSJob) getFeed(ctx context.Context) (items []*gofeed.Item, err error)
// set ttl to 1 month // set ttl to 1 month
ttl := time.Now().AddDate(0, 1, 0) ttl := time.Now().AddDate(0, 1, 0)
for _, i := range feed.Items { for _, item := range feed.Items {
item := i
key := item.GUID key := item.GUID
if len(key) == 0 { if len(key) == 0 {
key = item.Link key = item.Link
@ -278,12 +289,12 @@ func (j *RSSJob) getFeed(ctx context.Context) (items []*gofeed.Item, err error)
continue continue
} }
j.Log.Debug().Msgf("found new release: %s", i.Title) j.Log.Debug().Msgf("found new release: %s", item.Title)
toCache = append(toCache, domain.FeedCacheItem{ toCache = append(toCache, domain.FeedCacheItem{
FeedId: strconv.Itoa(j.Feed.ID), FeedId: strconv.Itoa(j.Feed.ID),
Key: key, Key: key,
Value: []byte(i.Title), Value: []byte(item.Title),
TTL: ttl, TTL: ttl,
}) })

View file

@ -11,6 +11,7 @@ import (
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/internal/release" "github.com/autobrr/autobrr/internal/release"
"github.com/autobrr/autobrr/internal/scheduler" "github.com/autobrr/autobrr/internal/scheduler"
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
@ -68,16 +69,18 @@ type service struct {
repo domain.FeedRepo repo domain.FeedRepo
cacheRepo domain.FeedCacheRepo cacheRepo domain.FeedCacheRepo
releaseSvc release.Service releaseSvc release.Service
proxySvc proxy.Service
scheduler scheduler.Service scheduler scheduler.Service
} }
func NewService(log logger.Logger, repo domain.FeedRepo, cacheRepo domain.FeedCacheRepo, releaseSvc release.Service, scheduler scheduler.Service) Service { func NewService(log logger.Logger, repo domain.FeedRepo, cacheRepo domain.FeedCacheRepo, releaseSvc release.Service, proxySvc proxy.Service, scheduler scheduler.Service) Service {
return &service{ return &service{
log: log.With().Str("module", "feed").Logger(), log: log.With().Str("module", "feed").Logger(),
jobs: map[string]int{}, jobs: map[string]int{},
repo: repo, repo: repo,
cacheRepo: cacheRepo, cacheRepo: cacheRepo,
releaseSvc: releaseSvc, releaseSvc: releaseSvc,
proxySvc: proxySvc,
scheduler: scheduler, scheduler: scheduler,
} }
} }
@ -150,6 +153,13 @@ func (s *service) update(ctx context.Context, feed *domain.Feed) error {
return err return err
} }
// get Feed again for ProxyID and UseProxy to be correctly populated
feed, err := s.repo.FindByID(ctx, feed.ID)
if err != nil {
s.log.Error().Err(err).Msg("error finding feed")
return err
}
if err := s.restartJob(feed); err != nil { if err := s.restartJob(feed); err != nil {
s.log.Error().Err(err).Msg("error restarting feed") s.log.Error().Err(err).Msg("error restarting feed")
return err return err
@ -227,6 +237,18 @@ func (s *service) test(ctx context.Context, feed *domain.Feed) error {
// create sub logger // create sub logger
subLogger := zstdlog.NewStdLoggerWithLevel(s.log.With().Logger(), zerolog.DebugLevel) subLogger := zstdlog.NewStdLoggerWithLevel(s.log.With().Logger(), zerolog.DebugLevel)
// add proxy conf
if feed.UseProxy {
proxyConf, err := s.proxySvc.FindByID(ctx, feed.ProxyID)
if err != nil {
return errors.Wrap(err, "could not find proxy for indexer feed")
}
if proxyConf.Enabled {
feed.Proxy = proxyConf
}
}
// test feeds // test feeds
switch feed.Type { switch feed.Type {
case string(domain.FeedTypeTorznab): case string(domain.FeedTypeTorznab):
@ -254,13 +276,27 @@ func (s *service) test(ctx context.Context, feed *domain.Feed) error {
} }
func (s *service) testRSS(ctx context.Context, feed *domain.Feed) error { func (s *service) testRSS(ctx context.Context, feed *domain.Feed) error {
f, err := NewFeedParser(time.Duration(feed.Timeout)*time.Second, feed.Cookie).ParseURLWithContext(ctx, feed.URL) feedParser := NewFeedParser(time.Duration(feed.Timeout)*time.Second, feed.Cookie)
// add proxy if enabled and exists
if feed.UseProxy && feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(feed.Proxy)
if err != nil {
return errors.Wrap(err, "could not get proxy client")
}
feedParser.WithHTTPClient(proxyClient)
s.log.Debug().Msgf("using proxy %s for feed %s", feed.Proxy.Name, feed.Name)
}
feedResponse, err := feedParser.ParseURLWithContext(ctx, feed.URL)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("error fetching rss feed items") s.log.Error().Err(err).Msgf("error fetching rss feed items")
return errors.Wrap(err, "error fetching rss feed items") return errors.Wrap(err, "error fetching rss feed items")
} }
s.log.Info().Msgf("refreshing rss feed: %s, found (%d) items", feed.Name, len(f.Items)) s.log.Info().Msgf("refreshing rss feed: %s, found (%d) items", feed.Name, len(feedResponse.Items))
return nil return nil
} }
@ -269,6 +305,18 @@ func (s *service) testTorznab(ctx context.Context, feed *domain.Feed, subLogger
// setup torznab Client // setup torznab Client
c := torznab.NewClient(torznab.Config{Host: feed.URL, ApiKey: feed.ApiKey, Log: subLogger}) c := torznab.NewClient(torznab.Config{Host: feed.URL, ApiKey: feed.ApiKey, Log: subLogger})
// add proxy if enabled and exists
if feed.UseProxy && feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(feed.Proxy)
if err != nil {
return errors.Wrap(err, "could not get proxy client")
}
c.WithHTTPClient(proxyClient)
s.log.Debug().Msgf("using proxy %s for feed %s", feed.Proxy.Name, feed.Name)
}
items, err := c.FetchFeed(ctx) items, err := c.FetchFeed(ctx)
if err != nil { if err != nil {
s.log.Error().Err(err).Msg("error getting torznab feed") s.log.Error().Err(err).Msg("error getting torznab feed")
@ -284,6 +332,18 @@ func (s *service) testNewznab(ctx context.Context, feed *domain.Feed, subLogger
// setup newznab Client // setup newznab Client
c := newznab.NewClient(newznab.Config{Host: feed.URL, ApiKey: feed.ApiKey, Log: subLogger}) c := newznab.NewClient(newznab.Config{Host: feed.URL, ApiKey: feed.ApiKey, Log: subLogger})
// add proxy if enabled and exists
if feed.UseProxy && feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(feed.Proxy)
if err != nil {
return errors.Wrap(err, "could not get proxy client")
}
c.WithHTTPClient(proxyClient)
s.log.Debug().Msgf("using proxy %s for feed %s", feed.Proxy.Name, feed.Name)
}
items, err := c.GetFeed(ctx) items, err := c.GetFeed(ctx)
if err != nil { if err != nil {
s.log.Error().Err(err).Msg("error getting newznab feed") s.log.Error().Err(err).Msg("error getting newznab feed")
@ -316,8 +376,6 @@ func (s *service) start() error {
s.log.Debug().Msgf("preparing staggered start of %d feeds", len(feeds)) s.log.Debug().Msgf("preparing staggered start of %d feeds", len(feeds))
for _, feed := range feeds { for _, feed := range feeds {
feed := feed
if !feed.Enabled { if !feed.Enabled {
s.log.Trace().Msgf("feed disabled, skipping... %s", feed.Name) s.log.Trace().Msgf("feed disabled, skipping... %s", feed.Name)
continue continue
@ -408,6 +466,18 @@ func (s *service) startJob(f *domain.Feed) error {
return errors.New("no URL provided for feed: %s", f.Name) return errors.New("no URL provided for feed: %s", f.Name)
} }
// add proxy conf
if f.UseProxy {
proxyConf, err := s.proxySvc.FindByID(context.Background(), f.ProxyID)
if err != nil {
return errors.Wrap(err, "could not find proxy for indexer feed")
}
if proxyConf.Enabled {
f.Proxy = proxyConf
}
}
fi := newFeedInstance(f) fi := newFeedInstance(f)
job, err := s.initializeFeedJob(fi) job, err := s.initializeFeedJob(fi)

View file

@ -5,6 +5,7 @@ package feed
import ( import (
"context" "context"
"github.com/autobrr/autobrr/internal/proxy"
"math" "math"
"sort" "sort"
"strconv" "strconv"
@ -224,6 +225,18 @@ func mapFreeleechToBonus(percentage int) string {
} }
func (j *TorznabJob) getFeed(ctx context.Context) ([]torznab.FeedItem, error) { func (j *TorznabJob) getFeed(ctx context.Context) ([]torznab.FeedItem, error) {
// add proxy if enabled and exists
if j.Feed.UseProxy && j.Feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(j.Feed.Proxy)
if err != nil {
return nil, errors.Wrap(err, "could not get proxy client")
}
j.Client.WithHTTPClient(proxyClient)
j.Log.Debug().Msgf("using proxy %s for feed %s", j.Feed.Proxy.Name, j.Feed.Name)
}
// get feed // get feed
feed, err := j.Client.FetchFeed(ctx) feed, err := j.Client.FetchFeed(ctx)
if err != nil { if err != nil {
@ -251,35 +264,33 @@ func (j *TorznabJob) getFeed(ctx context.Context) ([]torznab.FeedItem, error) {
// set ttl to 1 month // set ttl to 1 month
ttl := time.Now().AddDate(0, 1, 0) ttl := time.Now().AddDate(0, 1, 0)
for _, i := range feed.Channel.Items { for _, item := range feed.Channel.Items {
i := i if item.GUID == "" {
if i.GUID == "" {
j.Log.Error().Msgf("missing GUID from feed: %s", j.Feed.Name) j.Log.Error().Msgf("missing GUID from feed: %s", j.Feed.Name)
continue continue
} }
exists, err := j.CacheRepo.Exists(j.Feed.ID, i.GUID) exists, err := j.CacheRepo.Exists(j.Feed.ID, item.GUID)
if err != nil { if err != nil {
j.Log.Error().Err(err).Msg("could not check if item exists") j.Log.Error().Err(err).Msg("could not check if item exists")
continue continue
} }
if exists { if exists {
j.Log.Trace().Msgf("cache item exists, skipping release: %s", i.Title) j.Log.Trace().Msgf("cache item exists, skipping release: %s", item.Title)
continue continue
} }
j.Log.Debug().Msgf("found new release: %s", i.Title) j.Log.Debug().Msgf("found new release: %s", item.Title)
toCache = append(toCache, domain.FeedCacheItem{ toCache = append(toCache, domain.FeedCacheItem{
FeedId: strconv.Itoa(j.Feed.ID), FeedId: strconv.Itoa(j.Feed.ID),
Key: i.GUID, Key: item.GUID,
Value: []byte(i.Title), Value: []byte(item.Title),
TTL: ttl, TTL: ttl,
}) })
// only append if we successfully added to cache // only append if we successfully added to cache
items = append(items, *i) items = append(items, *item)
} }
if len(toCache) > 0 { if len(toCache) > 0 {

View file

@ -7,7 +7,6 @@ import (
"bytes" "bytes"
"context" "context"
"fmt" "fmt"
"github.com/autobrr/autobrr/internal/action"
"io" "io"
"net/http" "net/http"
"os" "os"
@ -17,9 +16,11 @@ import (
"strings" "strings"
"time" "time"
"github.com/autobrr/autobrr/internal/action"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/indexer" "github.com/autobrr/autobrr/internal/indexer"
"github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/releasedownload"
"github.com/autobrr/autobrr/internal/utils" "github.com/autobrr/autobrr/internal/utils"
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/sharedhttp" "github.com/autobrr/autobrr/pkg/sharedhttp"
@ -53,11 +54,12 @@ type service struct {
releaseRepo domain.ReleaseRepo releaseRepo domain.ReleaseRepo
indexerSvc indexer.Service indexerSvc indexer.Service
apiService indexer.APIService apiService indexer.APIService
downloadSvc *releasedownload.DownloadService
httpClient *http.Client httpClient *http.Client
} }
func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Service, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service) Service { func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Service, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service, downloadSvc *releasedownload.DownloadService) Service {
return &service{ return &service{
log: log.With().Str("module", "filter").Logger(), log: log.With().Str("module", "filter").Logger(),
repo: repo, repo: repo,
@ -65,6 +67,7 @@ func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Serv
actionService: actionSvc, actionService: actionSvc,
apiService: apiService, apiService: apiService,
indexerSvc: indexerSvc, indexerSvc: indexerSvc,
downloadSvc: downloadSvc,
httpClient: &http.Client{ httpClient: &http.Client{
Timeout: time.Second * 120, Timeout: time.Second * 120,
Transport: sharedhttp.TransportTLSInsecure, Transport: sharedhttp.TransportTLSInsecure,
@ -504,9 +507,9 @@ func (s *service) AdditionalSizeCheck(ctx context.Context, f *domain.Filter, rel
l.Trace().Msgf("(%s) preparing to download torrent metafile", f.Name) l.Trace().Msgf("(%s) preparing to download torrent metafile", f.Name)
// if indexer doesn't have api, download torrent and add to tmpPath // if indexer doesn't have api, download torrent and add to tmpPath
if err := release.DownloadTorrentFileCtx(ctx); err != nil { if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil {
l.Error().Err(err).Msgf("(%s) could not download torrent file with id: '%s' from: %s", f.Name, release.TorrentID, release.Indexer.Identifier) l.Error().Err(err).Msgf("(%s) could not download torrent file with id: '%s' from: %s", f.Name, release.TorrentID, release.Indexer.Identifier)
return false, err return false, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
} }
} }
@ -584,8 +587,8 @@ func (s *service) execCmd(ctx context.Context, external domain.FilterExternal, r
s.log.Trace().Msgf("filter exec release: %s", release.TorrentName) s.log.Trace().Msgf("filter exec release: %s", release.TorrentName)
if release.TorrentTmpFile == "" && strings.Contains(external.ExecArgs, "TorrentPathName") { if release.TorrentTmpFile == "" && strings.Contains(external.ExecArgs, "TorrentPathName") {
if err := release.DownloadTorrentFileCtx(ctx); err != nil { if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil {
return 0, errors.Wrap(err, "error downloading torrent file for release: %s", release.TorrentName) return 0, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
} }
} }
@ -686,7 +689,7 @@ func (s *service) webhook(ctx context.Context, external domain.FilterExternal, r
// if webhook data contains TorrentPathName or TorrentDataRawBytes, lets download the torrent file // if webhook data contains TorrentPathName or TorrentDataRawBytes, lets download the torrent file
if release.TorrentTmpFile == "" && (strings.Contains(external.WebhookData, "TorrentPathName") || strings.Contains(external.WebhookData, "TorrentDataRawBytes")) { if release.TorrentTmpFile == "" && (strings.Contains(external.WebhookData, "TorrentPathName") || strings.Contains(external.WebhookData, "TorrentDataRawBytes")) {
if err := release.DownloadTorrentFileCtx(ctx); err != nil { if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil {
return 0, errors.Wrap(err, "webhook: could not download torrent file for release: %s", release.TorrentName) return 0, errors.Wrap(err, "webhook: could not download torrent file for release: %s", release.TorrentName)
} }
} }

View file

@ -6,11 +6,11 @@ package http
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/autobrr/autobrr/pkg/errors"
"net/http" "net/http"
"strconv" "strconv"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
) )

View file

@ -7,12 +7,12 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"github.com/autobrr/autobrr/pkg/errors"
"net/http" "net/http"
"strconv" "strconv"
"strings" "strings"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/r3labs/sse/v2" "github.com/r3labs/sse/v2"

View file

@ -6,11 +6,11 @@ package http
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"github.com/autobrr/autobrr/pkg/errors"
"net/http" "net/http"
"strconv" "strconv"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
) )

151
internal/http/proxy.go Normal file
View file

@ -0,0 +1,151 @@
// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
package http
import (
"context"
"encoding/json"
"net/http"
"strconv"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/go-chi/chi/v5"
)
type proxyService interface {
Store(ctx context.Context, p *domain.Proxy) error
Update(ctx context.Context, p *domain.Proxy) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context) ([]domain.Proxy, error)
FindByID(ctx context.Context, id int64) (*domain.Proxy, error)
Test(ctx context.Context, p *domain.Proxy) error
}
type proxyHandler struct {
encoder encoder
service proxyService
}
func newProxyHandler(encoder encoder, service proxyService) *proxyHandler {
return &proxyHandler{
encoder: encoder,
service: service,
}
}
func (h proxyHandler) Routes(r chi.Router) {
r.Get("/", h.list)
r.Post("/", h.store)
r.Post("/test", h.test)
r.Route("/{proxyID}", func(r chi.Router) {
r.Get("/", h.findByID)
r.Put("/", h.update)
r.Delete("/", h.delete)
})
}
func (h proxyHandler) store(w http.ResponseWriter, r *http.Request) {
var data domain.Proxy
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.Error(w, err)
return
}
if err := h.service.Store(r.Context(), &data); err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
}
func (h proxyHandler) update(w http.ResponseWriter, r *http.Request) {
var data domain.Proxy
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.Error(w, err)
return
}
if err := h.service.Update(r.Context(), &data); err != nil {
if errors.Is(err, domain.ErrUpdateFailed) {
h.encoder.StatusError(w, http.StatusBadRequest, err)
return
}
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
}
func (h proxyHandler) list(w http.ResponseWriter, r *http.Request) {
proxies, err := h.service.List(r.Context())
if err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(w, http.StatusOK, proxies)
}
func (h proxyHandler) findByID(w http.ResponseWriter, r *http.Request) {
proxyID, err := strconv.Atoi(chi.URLParam(r, "proxyID"))
if err != nil {
h.encoder.Error(w, err)
return
}
proxies, err := h.service.FindByID(r.Context(), int64(proxyID))
if err != nil {
if errors.Is(err, domain.ErrRecordNotFound) {
h.encoder.NotFoundErr(w, errors.New("could not find proxy with id %d", proxyID))
return
}
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(w, http.StatusOK, proxies)
}
func (h proxyHandler) delete(w http.ResponseWriter, r *http.Request) {
proxyID, err := strconv.Atoi(chi.URLParam(r, "proxyID"))
if err != nil {
h.encoder.Error(w, err)
return
}
err = h.service.Delete(r.Context(), int64(proxyID))
if err != nil {
if errors.Is(err, domain.ErrDeleteFailed) {
h.encoder.StatusError(w, http.StatusBadRequest, err)
return
}
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
}
func (h proxyHandler) test(w http.ResponseWriter, r *http.Request) {
var data domain.Proxy
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.Error(w, err)
return
}
if err := h.service.Test(r.Context(), &data); err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
}

View file

@ -43,11 +43,12 @@ type Server struct {
indexerService indexerService indexerService indexerService
ircService ircService ircService ircService
notificationService notificationService notificationService notificationService
proxyService proxyService
releaseService releaseService releaseService releaseService
updateService updateService updateService updateService
} }
func NewServer(log logger.Logger, config *config.AppConfig, sse *sse.Server, db *database.DB, version string, commit string, date string, actionService actionService, apiService apikeyService, authService authService, downloadClientSvc downloadClientService, filterSvc filterService, feedSvc feedService, indexerSvc indexerService, ircSvc ircService, notificationSvc notificationService, releaseSvc releaseService, updateSvc updateService) Server { func NewServer(log logger.Logger, config *config.AppConfig, sse *sse.Server, db *database.DB, version string, commit string, date string, actionService actionService, apiService apikeyService, authService authService, downloadClientSvc downloadClientService, filterSvc filterService, feedSvc feedService, indexerSvc indexerService, ircSvc ircService, notificationSvc notificationService, proxySvc proxyService, releaseSvc releaseService, updateSvc updateService) Server {
return Server{ return Server{
log: log.With().Str("module", "http").Logger(), log: log.With().Str("module", "http").Logger(),
config: config, config: config,
@ -68,6 +69,7 @@ func NewServer(log logger.Logger, config *config.AppConfig, sse *sse.Server, db
indexerService: indexerSvc, indexerService: indexerSvc,
ircService: ircSvc, ircService: ircSvc,
notificationService: notificationSvc, notificationService: notificationSvc,
proxyService: proxySvc,
releaseService: releaseSvc, releaseService: releaseSvc,
updateService: updateSvc, updateService: updateSvc,
} }
@ -142,6 +144,7 @@ func (s Server) Handler() http.Handler {
r.Route("/keys", newAPIKeyHandler(encoder, s.apiService).Routes) r.Route("/keys", newAPIKeyHandler(encoder, s.apiService).Routes)
r.Route("/logs", newLogsHandler(s.config).Routes) r.Route("/logs", newLogsHandler(s.config).Routes)
r.Route("/notification", newNotificationHandler(encoder, s.notificationService).Routes) r.Route("/notification", newNotificationHandler(encoder, s.notificationService).Routes)
r.Route("/proxy", newProxyHandler(encoder, s.proxyService).Routes)
r.Route("/release", newReleaseHandler(encoder, s.releaseService).Routes) r.Route("/release", newReleaseHandler(encoder, s.releaseService).Routes)
r.Route("/updates", newUpdateHandler(encoder, s.updateService).Routes) r.Route("/updates", newUpdateHandler(encoder, s.updateService).Routes)

View file

@ -295,6 +295,9 @@ func (s *service) mapIndexer(indexer domain.Indexer) (*domain.IndexerDefinition,
d.BaseURL = indexer.BaseURL d.BaseURL = indexer.BaseURL
d.Enabled = indexer.Enabled d.Enabled = indexer.Enabled
d.UseProxy = indexer.UseProxy
d.ProxyID = indexer.ProxyID
if d.SettingsMap == nil { if d.SettingsMap == nil {
d.SettingsMap = make(map[string]string) d.SettingsMap = make(map[string]string)
} }
@ -332,6 +335,9 @@ func (s *service) updateMapIndexer(indexer domain.Indexer) (*domain.IndexerDefin
d.BaseURL = indexer.BaseURL d.BaseURL = indexer.BaseURL
d.Enabled = indexer.Enabled d.Enabled = indexer.Enabled
d.UseProxy = indexer.UseProxy
d.ProxyID = indexer.ProxyID
if d.SettingsMap == nil { if d.SettingsMap == nil {
d.SettingsMap = make(map[string]string) d.SettingsMap = make(map[string]string)
} }

View file

@ -6,6 +6,8 @@ package irc
import ( import (
"crypto/tls" "crypto/tls"
"fmt" "fmt"
"golang.org/x/net/proxy"
"net/url"
"slices" "slices"
"strings" "strings"
"time" "time"
@ -220,6 +222,37 @@ func (h *Handler) Run() (err error) {
Log: subLogger, Log: subLogger,
} }
if h.network.UseProxy && h.network.Proxy != nil {
if !h.network.Proxy.Enabled {
h.log.Debug().Msgf("proxy disabled, skip")
} else {
if h.network.Proxy.Addr == "" {
return errors.New("proxy addr missing")
}
proxyUrl, err := url.Parse(h.network.Proxy.Addr)
if err != nil {
return errors.Wrap(err, "could not parse proxy url: %s", h.network.Proxy.Addr)
}
// set user and pass if not empty
if h.network.Proxy.User != "" && h.network.Proxy.Pass != "" {
proxyUrl.User = url.UserPassword(h.network.Proxy.User, h.network.Proxy.Pass)
}
proxyDialer, err := proxy.FromURL(proxyUrl, proxy.Direct)
if err != nil {
return errors.Wrap(err, "could not create proxy dialer from url: %s", h.network.Proxy.Addr)
}
proxyContextDialer, ok := proxyDialer.(proxy.ContextDialer)
if !ok {
return errors.Wrap(err, "proxy dialer does not expose DialContext(): %v", proxyDialer)
}
client.DialContext = proxyContextDialer.DialContext
}
}
if h.network.Auth.Mechanism == domain.IRCAuthMechanismSASLPlain { if h.network.Auth.Mechanism == domain.IRCAuthMechanismSASLPlain {
if h.network.Auth.Account != "" && h.network.Auth.Password != "" { if h.network.Auth.Account != "" && h.network.Auth.Password != "" {
client.SASLLogin = h.network.Auth.Account client.SASLLogin = h.network.Auth.Account

View file

@ -14,6 +14,7 @@ import (
"github.com/autobrr/autobrr/internal/indexer" "github.com/autobrr/autobrr/internal/indexer"
"github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/notification" "github.com/autobrr/autobrr/internal/notification"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/internal/release" "github.com/autobrr/autobrr/internal/release"
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
@ -47,8 +48,10 @@ type service struct {
releaseService release.Service releaseService release.Service
indexerService indexer.Service indexerService indexer.Service
notificationService notification.Service notificationService notification.Service
indexerMap map[string]string proxyService proxy.Service
handlers map[int64]*Handler
indexerMap map[string]string
handlers map[int64]*Handler
stopWG sync.WaitGroup stopWG sync.WaitGroup
lock sync.RWMutex lock sync.RWMutex
@ -56,7 +59,7 @@ type service struct {
const sseMaxEntries = 1000 const sseMaxEntries = 1000
func NewService(log logger.Logger, sse *sse.Server, repo domain.IrcRepo, releaseSvc release.Service, indexerSvc indexer.Service, notificationSvc notification.Service) Service { func NewService(log logger.Logger, sse *sse.Server, repo domain.IrcRepo, releaseSvc release.Service, indexerSvc indexer.Service, notificationSvc notification.Service, proxySvc proxy.Service) Service {
return &service{ return &service{
log: log.With().Str("module", "irc").Logger(), log: log.With().Str("module", "irc").Logger(),
sse: sse, sse: sse,
@ -64,6 +67,7 @@ func NewService(log logger.Logger, sse *sse.Server, repo domain.IrcRepo, release
releaseService: releaseSvc, releaseService: releaseSvc,
indexerService: indexerSvc, indexerService: indexerSvc,
notificationService: notificationSvc, notificationService: notificationSvc,
proxyService: proxySvc,
handlers: make(map[int64]*Handler), handlers: make(map[int64]*Handler),
} }
} }
@ -79,6 +83,15 @@ func (s *service) StartHandlers() {
continue continue
} }
if network.ProxyId != 0 {
networkProxy, err := s.proxyService.FindByID(context.Background(), network.ProxyId)
if err != nil {
s.log.Error().Err(err).Msgf("failed to get proxy for network: %s", network.Server)
return
}
network.Proxy = networkProxy
}
channels, err := s.repo.ListChannels(network.ID) channels, err := s.repo.ListChannels(network.ID)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("failed to list channels for network: %s", network.Server) s.log.Error().Err(err).Msgf("failed to list channels for network: %s", network.Server)
@ -215,6 +228,14 @@ func (s *service) checkIfNetworkRestartNeeded(network *domain.IrcNetwork) error
restartNeeded = true restartNeeded = true
fieldsChanged = append(fieldsChanged, "bot mode") fieldsChanged = append(fieldsChanged, "bot mode")
} }
if handler.UseProxy != network.UseProxy {
restartNeeded = true
fieldsChanged = append(fieldsChanged, "use proxy")
}
if handler.ProxyId != network.ProxyId {
restartNeeded = true
fieldsChanged = append(fieldsChanged, "proxy id")
}
if handler.Auth.Mechanism != network.Auth.Mechanism { if handler.Auth.Mechanism != network.Auth.Mechanism {
restartNeeded = true restartNeeded = true
fieldsChanged = append(fieldsChanged, "auth mechanism") fieldsChanged = append(fieldsChanged, "auth mechanism")
@ -476,12 +497,16 @@ func (s *service) GetNetworksWithHealth(ctx context.Context) ([]domain.IrcNetwor
BouncerAddr: n.BouncerAddr, BouncerAddr: n.BouncerAddr,
UseBouncer: n.UseBouncer, UseBouncer: n.UseBouncer,
BotMode: n.BotMode, BotMode: n.BotMode,
UseProxy: n.UseProxy,
ProxyId: n.ProxyId,
Connected: false, Connected: false,
Channels: []domain.ChannelWithHealth{}, Channels: []domain.ChannelWithHealth{},
ConnectionErrors: []string{}, ConnectionErrors: []string{},
} }
s.lock.RLock()
handler, ok := s.handlers[n.ID] handler, ok := s.handlers[n.ID]
s.lock.RUnlock()
if ok { if ok {
handler.ReportStatus(&netw) handler.ReportStatus(&netw)
} }
@ -566,6 +591,18 @@ func (s *service) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork)
} }
s.log.Debug().Msgf("irc.service: update network: %s", network.Name) s.log.Debug().Msgf("irc.service: update network: %s", network.Name)
network.Proxy = nil
// attach proxy
if network.UseProxy && network.ProxyId != 0 {
networkProxy, err := s.proxyService.FindByID(context.Background(), network.ProxyId)
if err != nil {
s.log.Error().Err(err).Msgf("failed to get proxy for network: %s", network.Server)
return errors.Wrap(err, "could not get proxy for network: %s", network.Server)
}
network.Proxy = networkProxy
}
// stop or start network // stop or start network
// TODO get current state to see if enabled or not? // TODO get current state to see if enabled or not?
if network.Enabled { if network.Enabled {

194
internal/proxy/service.go Normal file
View file

@ -0,0 +1,194 @@
// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
package proxy
import (
"context"
"net/http"
"net/url"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/sharedhttp"
"github.com/rs/zerolog"
netProxy "golang.org/x/net/proxy"
)
type Service interface {
List(ctx context.Context) ([]domain.Proxy, error)
FindByID(ctx context.Context, id int64) (*domain.Proxy, error)
Store(ctx context.Context, p *domain.Proxy) error
Update(ctx context.Context, p *domain.Proxy) error
Delete(ctx context.Context, id int64) error
Test(ctx context.Context, p *domain.Proxy) error
}
type service struct {
log zerolog.Logger
repo domain.ProxyRepo
cache map[int64]*domain.Proxy
}
func NewService(log logger.Logger, repo domain.ProxyRepo) Service {
return &service{
log: log.With().Str("module", "proxy").Logger(),
repo: repo,
cache: make(map[int64]*domain.Proxy),
}
}
func (s *service) Store(ctx context.Context, proxy *domain.Proxy) error {
if err := proxy.Validate(); err != nil {
return errors.Wrap(err, "validation error")
}
err := s.repo.Store(ctx, proxy)
if err != nil {
return err
}
s.cache[proxy.ID] = proxy
return nil
}
func (s *service) Update(ctx context.Context, proxy *domain.Proxy) error {
if err := proxy.Validate(); err != nil {
return errors.Wrap(err, "validation error")
}
err := s.repo.Update(ctx, proxy)
if err != nil {
return err
}
s.cache[proxy.ID] = proxy
// TODO update IRC handlers
return nil
}
func (s *service) FindByID(ctx context.Context, id int64) (*domain.Proxy, error) {
if proxy, ok := s.cache[id]; ok {
return proxy, nil
}
return s.repo.FindByID(ctx, id)
}
func (s *service) List(ctx context.Context) ([]domain.Proxy, error) {
return s.repo.List(ctx)
}
func (s *service) ToggleEnabled(ctx context.Context, id int64, enabled bool) error {
err := s.repo.ToggleEnabled(ctx, id, enabled)
if err != nil {
return err
}
v, ok := s.cache[id]
if !ok {
v.Enabled = !enabled
s.cache[id] = v
}
// TODO update IRC handlers
return nil
}
func (s *service) Delete(ctx context.Context, id int64) error {
err := s.repo.Delete(ctx, id)
if err != nil {
return err
}
delete(s.cache, id)
// TODO update IRC handlers
return nil
}
func (s *service) Test(ctx context.Context, proxy *domain.Proxy) error {
if !proxy.ValidProxyType() {
return errors.New("invalid proxy type %s", proxy.Type)
}
if proxy.Addr == "" {
return errors.New("proxy addr missing")
}
httpClient, err := GetProxiedHTTPClient(proxy)
if err != nil {
return errors.Wrap(err, "could not get http client")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://autobrr.com", nil)
if err != nil {
return errors.Wrap(err, "could not create proxy request")
}
resp, err := httpClient.Do(req)
if err != nil {
return errors.Wrap(err, "could not connect to proxy server: %s", proxy.Addr)
}
if resp.StatusCode != http.StatusOK {
return errors.New(resp.Status)
}
s.log.Debug().Msgf("proxy %s test OK!", proxy.Addr)
return nil
}
func GetProxiedHTTPClient(p *domain.Proxy) (*http.Client, error) {
proxyUrl, err := url.Parse(p.Addr)
if err != nil {
return nil, errors.Wrap(err, "could not parse proxy url: %s", p.Addr)
}
// set user and pass if not empty
if p.User != "" && p.Pass != "" {
proxyUrl.User = url.UserPassword(p.User, p.Pass)
}
transport := sharedhttp.TransportTLSInsecure
// set user and pass if not empty
if p.User != "" && p.Pass != "" {
proxyUrl.User = url.UserPassword(p.User, p.Pass)
}
switch p.Type {
case domain.ProxyTypeSocks5:
proxyDialer, err := netProxy.FromURL(proxyUrl, netProxy.Direct)
if err != nil {
return nil, errors.Wrap(err, "could not create proxy dialer from url: %s", p.Addr)
}
proxyContextDialer, ok := proxyDialer.(netProxy.ContextDialer)
if !ok {
return nil, errors.Wrap(err, "proxy dialer does not expose DialContext(): %v", proxyDialer)
}
transport.DialContext = proxyContextDialer.DialContext
default:
return nil, errors.New("invalid proxy type: %s", p.Type)
}
client := &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}
return client, nil
}

View file

@ -0,0 +1,316 @@
package releasedownload
import (
"bufio"
"bytes"
"context"
"io"
"net"
"net/http"
"net/http/cookiejar"
"os"
"strings"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/sharedhttp"
"github.com/anacrolix/torrent/bencode"
"github.com/anacrolix/torrent/metainfo"
"github.com/avast/retry-go/v4"
"github.com/rs/zerolog"
"golang.org/x/net/publicsuffix"
)
type DownloadService struct {
log zerolog.Logger
repo domain.ReleaseRepo
indexerRepo domain.IndexerRepo
proxySvc proxy.Service
}
func NewDownloadService(log logger.Logger, repo domain.ReleaseRepo, indexerRepo domain.IndexerRepo, proxySvc proxy.Service) *DownloadService {
return &DownloadService{
log: log.With().Str("module", "release-download").Logger(),
repo: repo,
indexerRepo: indexerRepo,
proxySvc: proxySvc,
}
}
func (s *DownloadService) DownloadRelease(ctx context.Context, rls *domain.Release) error {
if rls.HasMagnetUri() {
return errors.New("downloading magnet links is not supported: %s", rls.MagnetURI)
} else if rls.Protocol != domain.ReleaseProtocolTorrent {
return errors.New("could not download file: protocol %s is not supported", rls.Protocol)
}
if rls.DownloadURL == "" {
return errors.New("download_file: url can't be empty")
} else if rls.TorrentTmpFile != "" {
// already downloaded
return nil
}
// get indexer
indexer, err := s.indexerRepo.FindByID(ctx, rls.Indexer.ID)
if err != nil {
return err
}
// get proxy
if indexer.UseProxy {
proxyConf, err := s.proxySvc.FindByID(ctx, indexer.ProxyID)
if err != nil {
return err
}
if proxyConf.Enabled {
s.log.Debug().Msgf("using proxy: %s", proxyConf.Name)
indexer.Proxy = proxyConf
} else {
s.log.Debug().Msgf("proxy disabled, skip: %s", proxyConf.Name)
}
}
// download release
err = s.downloadTorrentFile(ctx, indexer, rls)
if err != nil {
return err
}
return nil
}
func (s *DownloadService) downloadTorrentFile(ctx context.Context, indexer *domain.Indexer, r *domain.Release) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.DownloadURL, nil)
if err != nil {
return errors.Wrap(err, "error downloading file")
}
req.Header.Set("User-Agent", "autobrr")
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: sharedhttp.TransportTLSInsecure,
}
// handle proxy
if indexer.Proxy != nil {
s.log.Debug().Msgf("using proxy: %s", indexer.Proxy.Name)
proxiedClient, err := proxy.GetProxiedHTTPClient(indexer.Proxy)
if err != nil {
return errors.Wrap(err, "could not get proxied http client")
}
httpClient = proxiedClient
}
if r.RawCookie != "" {
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
if err != nil {
return errors.Wrap(err, "could not create cookiejar")
}
httpClient.Jar = jar
// set the cookie on the header instead of req.AddCookie
// since we have a raw cookie like "uid=10; pass=000"
req.Header.Set("Cookie", r.RawCookie)
}
tmpFilePattern := "autobrr-"
tmpDir := os.TempDir()
// Create tmp file
// TODO check if tmp file is wanted
tmpFile, err := os.CreateTemp(tmpDir, tmpFilePattern)
if err != nil {
if os.IsNotExist(err) {
if mkdirErr := os.MkdirAll(tmpDir, os.ModePerm); mkdirErr != nil {
return errors.Wrap(mkdirErr, "could not create TMP dir: %s", tmpDir)
}
tmpFile, err = os.CreateTemp(tmpDir, tmpFilePattern)
if err != nil {
return errors.Wrap(err, "error creating tmp file in: %s", tmpDir)
}
} else {
return errors.Wrap(err, "error creating tmp file")
}
}
defer tmpFile.Close()
errFunc := retry.Do(retryableRequest(httpClient, req, r, tmpFile), retry.Delay(time.Second*3), retry.Attempts(3), retry.MaxJitter(time.Second*1))
return errFunc
}
func retryableRequest(httpClient *http.Client, req *http.Request, r *domain.Release, tmpFile *os.File) func() error {
return func() error {
// Get the data
resp, err := httpClient.Do(req)
if err != nil {
if errors.As(err, net.OpError{}) {
return retry.Unrecoverable(errors.Wrap(err, "issue from proxy"))
}
return errors.Wrap(err, "error downloading file")
}
defer resp.Body.Close()
// Check server response
switch resp.StatusCode {
case http.StatusOK:
// Continue processing the response
break
//case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
// // Handle redirect
// return retry.Unrecoverable(errors.New("redirect encountered for torrent (%s) file (%s) - status code: %d - check indexer keys for %s", r.TorrentName, r.DownloadURL, resp.StatusCode, r.Indexer.Name))
case http.StatusUnauthorized, http.StatusForbidden:
return retry.Unrecoverable(errors.New("unrecoverable error downloading torrent (%s) file (%s) - status code: %d - check indexer keys for %s", r.TorrentName, r.DownloadURL, resp.StatusCode, r.Indexer.Name))
case http.StatusMethodNotAllowed:
return retry.Unrecoverable(errors.New("unrecoverable error downloading torrent (%s) file (%s) from '%s' - status code: %d. Check if the request method is correct", r.TorrentName, r.DownloadURL, r.Indexer.Name, resp.StatusCode))
case http.StatusNotFound:
return errors.New("torrent %s not found on %s (%d) - retrying", r.TorrentName, r.Indexer.Name, resp.StatusCode)
case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
return errors.New("server error (%d) encountered while downloading torrent (%s) file (%s) from '%s' - retrying", resp.StatusCode, r.TorrentName, r.DownloadURL, r.Indexer.Name)
case http.StatusInternalServerError:
return errors.New("server error (%d) encountered while downloading torrent (%s) file (%s) - check indexer keys for %s", resp.StatusCode, r.TorrentName, r.DownloadURL, r.Indexer.Name)
default:
return retry.Unrecoverable(errors.New("unexpected status code %d: check indexer keys for %s", resp.StatusCode, r.Indexer.Name))
}
resetTmpFile := func() {
tmpFile.Seek(0, io.SeekStart)
tmpFile.Truncate(0)
}
// Read the body into bytes
bodyBytes, err := io.ReadAll(bufio.NewReader(resp.Body))
if err != nil {
return errors.Wrap(err, "error reading response body")
}
// Create a new reader for bodyBytes
bodyReader := bytes.NewReader(bodyBytes)
// Try to decode as torrent file
meta, err := metainfo.Load(bodyReader)
if err != nil {
resetTmpFile()
// explicitly check for unexpected content type that match html
var bse *bencode.SyntaxError
if errors.As(err, &bse) {
// regular error so we can retry if we receive html first run
return errors.Wrap(err, "metainfo unexpected content type, got HTML expected a bencoded torrent. check indexer keys for %s - %s", r.Indexer.Name, r.TorrentName)
}
return retry.Unrecoverable(errors.Wrap(err, "metainfo unexpected content type. check indexer keys for %s - %s", r.Indexer.Name, r.TorrentName))
}
torrentMetaInfo, err := meta.UnmarshalInfo()
if err != nil {
resetTmpFile()
return retry.Unrecoverable(errors.Wrap(err, "metainfo could not unmarshal info from torrent: %s", tmpFile.Name()))
}
hashInfoBytes := meta.HashInfoBytes().Bytes()
if len(hashInfoBytes) < 1 {
resetTmpFile()
return retry.Unrecoverable(errors.New("could not read infohash"))
}
// Write the body to file
// TODO move to io.Reader and pass around in the future
if _, err := tmpFile.Write(bodyBytes); err != nil {
resetTmpFile()
return errors.Wrap(err, "error writing downloaded file: %s", tmpFile.Name())
}
r.TorrentTmpFile = tmpFile.Name()
r.TorrentHash = meta.HashInfoBytes().String()
r.Size = uint64(torrentMetaInfo.TotalLength())
return nil
}
}
func (s *DownloadService) ResolveMagnetURI(ctx context.Context, r *domain.Release) error {
if r.MagnetURI == "" {
return nil
} else if strings.HasPrefix(r.MagnetURI, domain.MagnetURIPrefix) {
return nil
}
// get indexer
indexer, err := s.indexerRepo.FindByID(ctx, r.Indexer.ID)
if err != nil {
return err
}
httpClient := &http.Client{
Timeout: time.Second * 45,
Transport: sharedhttp.MagnetTransport,
}
// get proxy
if indexer.UseProxy {
proxyConf, err := s.proxySvc.FindByID(ctx, indexer.ProxyID)
if err != nil {
return err
}
s.log.Debug().Msgf("using proxy: %s", proxyConf.Name)
proxiedClient, err := proxy.GetProxiedHTTPClient(proxyConf)
if err != nil {
return errors.Wrap(err, "could not get proxied http client")
}
httpClient = proxiedClient
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.MagnetURI, nil)
if err != nil {
return errors.Wrap(err, "could not build request to resolve magnet uri")
}
//req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "autobrr")
res, err := httpClient.Do(req)
if err != nil {
return errors.Wrap(err, "could not make request to resolve magnet uri")
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return errors.New("unexpected status code: %d", res.StatusCode)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return errors.Wrap(err, "could not read response body")
}
magnet := string(body)
if magnet != "" {
r.MagnetURI = magnet
}
return nil
}

View file

@ -25,6 +25,7 @@ type Client interface {
GetFeed(ctx context.Context) (*Feed, error) GetFeed(ctx context.Context) (*Feed, error)
GetCaps(ctx context.Context) (*Caps, error) GetCaps(ctx context.Context) (*Caps, error)
Caps() *Caps Caps() *Caps
WithHTTPClient(client *http.Client)
} }
type client struct { type client struct {
@ -41,6 +42,10 @@ type client struct {
Log *log.Logger Log *log.Logger
} }
func (c *client) WithHTTPClient(client *http.Client) {
c.http = client
}
type BasicAuth struct { type BasicAuth struct {
Username string Username string
Password string Password string
@ -99,6 +104,9 @@ func (c *client) get(ctx context.Context, endpoint string, queryParams map[strin
} }
u, err := url.Parse(c.Host) u, err := url.Parse(c.Host)
if err != nil {
return 0, nil, err
}
u.Path = strings.TrimSuffix(u.Path, "/") u.Path = strings.TrimSuffix(u.Path, "/")
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
reqUrl := u.String() reqUrl := u.String()
@ -273,6 +281,9 @@ func (c *client) getCaps(ctx context.Context, endpoint string, opts map[string]s
} }
u, err := url.Parse(c.Host) u, err := url.Parse(c.Host)
if err != nil {
return 0, nil, err
}
u.Path = strings.TrimSuffix(u.Path, "/") u.Path = strings.TrimSuffix(u.Path, "/")
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
reqUrl := u.String() reqUrl := u.String()
@ -325,7 +336,6 @@ func (c *client) getCaps(ctx context.Context, endpoint string, opts map[string]s
} }
func (c *client) GetCaps(ctx context.Context) (*Caps, error) { func (c *client) GetCaps(ctx context.Context) (*Caps, error) {
status, res, err := c.getCaps(ctx, "?t=caps", nil) status, res, err := c.getCaps(ctx, "?t=caps", nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not get caps for feed") return nil, errors.Wrap(err, "could not get caps for feed")

View file

@ -23,6 +23,7 @@ type Client interface {
FetchFeed(ctx context.Context) (*Feed, error) FetchFeed(ctx context.Context) (*Feed, error)
FetchCaps(ctx context.Context) (*Caps, error) FetchCaps(ctx context.Context) (*Caps, error)
GetCaps() *Caps GetCaps() *Caps
WithHTTPClient(client *http.Client)
} }
type client struct { type client struct {
@ -39,6 +40,10 @@ type client struct {
Log *log.Logger Log *log.Logger
} }
func (c *client) WithHTTPClient(client *http.Client) {
c.http = client
}
type BasicAuth struct { type BasicAuth struct {
Username string Username string
Password string Password string
@ -90,6 +95,10 @@ func (c *client) get(ctx context.Context, endpoint string, opts map[string]strin
} }
u, err := url.Parse(c.Host) u, err := url.Parse(c.Host)
if err != nil {
return 0, nil, err
}
u.Path = strings.TrimSuffix(u.Path, "/") u.Path = strings.TrimSuffix(u.Path, "/")
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
reqUrl := u.String() reqUrl := u.String()
@ -177,6 +186,9 @@ func (c *client) getCaps(ctx context.Context, endpoint string, opts map[string]s
} }
u, err := url.Parse(c.Host) u, err := url.Parse(c.Host)
if err != nil {
return 0, nil, err
}
u.Path = strings.TrimSuffix(u.Path, "/") u.Path = strings.TrimSuffix(u.Path, "/")
u.RawQuery = params.Encode() u.RawQuery = params.Encode()
reqUrl := u.String() reqUrl := u.String()
@ -229,7 +241,6 @@ func (c *client) getCaps(ctx context.Context, endpoint string, opts map[string]s
} }
func (c *client) FetchCaps(ctx context.Context) (*Caps, error) { func (c *client) FetchCaps(ctx context.Context) (*Caps, error) {
status, res, err := c.getCaps(ctx, "?t=caps", nil) status, res, err := c.getCaps(ctx, "?t=caps", nil)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not get caps for feed") return nil, errors.Wrap(err, "could not get caps for feed")

View file

@ -389,6 +389,20 @@ export const APIClient = {
body: notification body: notification
}) })
}, },
proxy: {
list: () => appClient.Get<Proxy[]>("api/proxy"),
getByID: (id: number) => appClient.Get<Proxy>(`api/proxy/${id}`),
store: (proxy: ProxyCreate) => appClient.Post("api/proxy", {
body: proxy
}),
update: (proxy: Proxy) => appClient.Put(`api/proxy/${proxy.id}`, {
body: proxy
}),
delete: (id: number) => appClient.Delete(`api/proxy/${id}`),
test: (proxy: Proxy) => appClient.Post("api/proxy/test", {
body: proxy
})
},
release: { release: {
find: (query?: string) => appClient.Get<ReleaseFindResponse>(`api/release${query}`), find: (query?: string) => appClient.Get<ReleaseFindResponse>(`api/release${query}`),
findRecent: () => appClient.Get<ReleaseFindResponse>("api/release/recent"), findRecent: () => appClient.Get<ReleaseFindResponse>("api/release/recent"),

View file

@ -11,7 +11,7 @@ import {
FeedKeys, FeedKeys,
FilterKeys, FilterKeys,
IndexerKeys, IndexerKeys,
IrcKeys, NotificationKeys, IrcKeys, NotificationKeys, ProxyKeys,
ReleaseKeys, ReleaseKeys,
SettingsKeys SettingsKeys
} from "@api/query_keys"; } from "@api/query_keys";
@ -137,3 +137,17 @@ export const ReleasesIndexersQueryOptions = () =>
placeholderData: keepPreviousData, placeholderData: keepPreviousData,
staleTime: Infinity staleTime: Infinity
}); });
export const ProxiesQueryOptions = () =>
queryOptions({
queryKey: ProxyKeys.lists(),
queryFn: () => APIClient.proxy.list(),
refetchOnWindowFocus: false
});
export const ProxyByIdQueryOptions = (proxyId: number) =>
queryOptions({
queryKey: ProxyKeys.detail(proxyId),
queryFn: async ({queryKey}) => await APIClient.proxy.getByID(queryKey[2]),
retry: false,
});

View file

@ -80,3 +80,10 @@ export const NotificationKeys = {
details: () => [...NotificationKeys.all, "detail"] as const, details: () => [...NotificationKeys.all, "detail"] as const,
detail: (id: number) => [...NotificationKeys.details(), id] as const detail: (id: number) => [...NotificationKeys.details(), id] as const
}; };
export const ProxyKeys = {
all: ["proxy"] as const,
lists: () => [...ProxyKeys.all, "list"] as const,
details: () => [...ProxyKeys.all, "detail"] as const,
detail: (id: number) => [...ProxyKeys.details(), id] as const
};

View file

@ -17,6 +17,7 @@ interface SelectFieldProps<T> {
label: string; label: string;
help?: string; help?: string;
placeholder?: string; placeholder?: string;
required?: boolean;
defaultValue?: OptionBasicTyped<T>; defaultValue?: OptionBasicTyped<T>;
tooltip?: JSX.Element; tooltip?: JSX.Element;
options: OptionBasicTyped<T>[]; options: OptionBasicTyped<T>[];
@ -158,7 +159,7 @@ export function SelectField<T>({ name, label, help, placeholder, options }: Sele
); );
} }
export function SelectFieldBasic<T>({ name, label, help, placeholder, tooltip, defaultValue, options }: SelectFieldProps<T>) { export function SelectFieldBasic<T>({ name, label, help, placeholder, required, tooltip, defaultValue, options }: SelectFieldProps<T>) {
return ( return (
<div className="space-y-1 p-4 sm:space-y-0 sm:grid sm:grid-cols-3 sm:gap-4"> <div className="space-y-1 p-4 sm:space-y-0 sm:grid sm:grid-cols-3 sm:gap-4">
<div> <div>
@ -182,6 +183,7 @@ export function SelectFieldBasic<T>({ name, label, help, placeholder, tooltip, d
<Select <Select
{...field} {...field}
id={name} id={name}
required={required}
components={{ components={{
Input: common.SelectInput, Input: common.SelectInput,
Control: common.SelectControl, Control: common.SelectControl,

View file

@ -79,4 +79,33 @@ const SwitchGroup = ({
</Field> </Field>
); );
export { SwitchGroup }; interface SwitchButtonProps {
name: string;
defaultValue?: boolean;
className?: string;
}
const SwitchButton = ({ name, defaultValue }: SwitchButtonProps) => (
<Field as="div" className="flex items-center justify-between">
<FormikField
name={name}
defaultValue={defaultValue as boolean}
type="checkbox"
>
{({
field,
form: { setFieldValue }
}: FieldProps) => (
<Checkbox
{...field}
value={!!field.checked}
setValue={(value) => {
setFieldValue(field?.name ?? "", value);
}}
/>
)}
</FormikField>
</Field>
);
export { SwitchGroup, SwitchButton };

View file

@ -577,3 +577,10 @@ export const ExternalFilterWebhookMethodOptions: OptionBasicTyped<WebhookMethod>
{ label: "PATCH", value: "PATCH" }, { label: "PATCH", value: "PATCH" },
{ label: "DELETE", value: "DELETE" } { label: "DELETE", value: "DELETE" }
]; ];
export const ProxyTypeOptions: OptionBasicTyped<ProxyType>[] = [
{
label: "SOCKS5",
value: "SOCKS5"
},
];

View file

@ -16,14 +16,15 @@ import { classNames, sleep } from "@utils";
import { DEBUG } from "@components/debug"; import { DEBUG } from "@components/debug";
import { APIClient } from "@api/APIClient"; import { APIClient } from "@api/APIClient";
import { FeedKeys, IndexerKeys, ReleaseKeys } from "@api/query_keys"; import { FeedKeys, IndexerKeys, ReleaseKeys } from "@api/query_keys";
import { IndexersSchemaQueryOptions } from "@api/queries"; import { IndexersSchemaQueryOptions, ProxiesQueryOptions } from "@api/queries";
import { SlideOver } from "@components/panels"; import { SlideOver } from "@components/panels";
import Toast from "@components/notifications/Toast"; import Toast from "@components/notifications/Toast";
import { PasswordFieldWide, SwitchGroupWide, TextFieldWide } from "@components/inputs"; import { PasswordFieldWide, SwitchButton, SwitchGroupWide, TextFieldWide } from "@components/inputs";
import { SelectFieldBasic, SelectFieldCreatable } from "@components/inputs/select_wide"; import { SelectFieldBasic, SelectFieldCreatable } from "@components/inputs/select_wide";
import { FeedDownloadTypeOptions } from "@domain/constants"; import { FeedDownloadTypeOptions } from "@domain/constants";
import { DocsLink } from "@components/ExternalLink"; import { DocsLink } from "@components/ExternalLink";
import * as common from "@components/inputs/common"; import * as common from "@components/inputs/common";
import { SelectField } from "@forms/settings/IrcForms";
// const isRequired = (message: string) => (value?: string | undefined) => (!!value ? undefined : message); // const isRequired = (message: string) => (value?: string | undefined) => (!!value ? undefined : message);
@ -254,7 +255,7 @@ type SelectValue = {
value: string; value: string;
}; };
interface AddProps { export interface AddProps {
isOpen: boolean; isOpen: boolean;
toggle: () => void; toggle: () => void;
} }
@ -718,6 +719,8 @@ interface IndexerUpdateInitialValues {
identifier_external: string; identifier_external: string;
implementation: string; implementation: string;
base_url: string; base_url: string;
use_proxy?: boolean;
proxy_id?: number;
settings: { settings: {
api_key?: string; api_key?: string;
api_user?: string; api_user?: string;
@ -735,6 +738,8 @@ interface UpdateProps {
export function IndexerUpdateForm({ isOpen, toggle, indexer }: UpdateProps) { export function IndexerUpdateForm({ isOpen, toggle, indexer }: UpdateProps) {
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const proxies = useQuery(ProxiesQueryOptions());
const mutation = useMutation({ const mutation = useMutation({
mutationFn: (indexer: Indexer) => APIClient.indexers.update(indexer), mutationFn: (indexer: Indexer) => APIClient.indexers.update(indexer),
onSuccess: () => { onSuccess: () => {
@ -813,6 +818,8 @@ export function IndexerUpdateForm({ isOpen, toggle, indexer }: UpdateProps) {
identifier_external: indexer.identifier_external, identifier_external: indexer.identifier_external,
implementation: indexer.implementation, implementation: indexer.implementation,
base_url: indexer.base_url, base_url: indexer.base_url,
use_proxy: indexer.use_proxy,
proxy_id: indexer.proxy_id,
settings: indexer.settings?.reduce( settings: indexer.settings?.reduce(
(o: Record<string, string>, obj: IndexerSetting) => ({ (o: Record<string, string>, obj: IndexerSetting) => ({
...o, ...o,
@ -833,7 +840,7 @@ export function IndexerUpdateForm({ isOpen, toggle, indexer }: UpdateProps) {
initialValues={initialValues} initialValues={initialValues}
extraButtons={(values) => <TestApiButton values={values as FormikValues} show={indexer.implementation === "irc" && indexer.supports.includes("api")} />} extraButtons={(values) => <TestApiButton values={values as FormikValues} show={indexer.implementation === "irc" && indexer.supports.includes("api")} />}
> >
{() => ( {(values) => (
<div className="py-2 space-y-6 sm:py-0 sm:space-y-0 divide-y divide-gray-200 dark:divide-gray-700"> <div className="py-2 space-y-6 sm:py-0 sm:space-y-0 divide-y divide-gray-200 dark:divide-gray-700">
<div className="space-y-1 p-4 sm:space-y-0 sm:grid sm:grid-cols-3 sm:gap-4"> <div className="space-y-1 p-4 sm:space-y-0 sm:grid sm:grid-cols-3 sm:gap-4">
<label <label
@ -863,14 +870,15 @@ export function IndexerUpdateForm({ isOpen, toggle, indexer }: UpdateProps) {
tooltip={ tooltip={
<div> <div>
<p>External Identifier for use with ARRs to get features like seed limits working.</p> <p>External Identifier for use with ARRs to get features like seed limits working.</p>
<br /> <br/>
<p>This needs to match the indexer name in your ARR. If using Prowlarr it will likely be "{indexer.name} (Prowlarr)"</p> <p>This needs to match the indexer name in your ARR. If using Prowlarr it will likely be
<br /> "{indexer.name} (Prowlarr)"</p>
<DocsLink href="https://autobrr.com/configuration/indexers#setup" /> <br/>
<DocsLink href="https://autobrr.com/configuration/indexers#setup"/>
</div> </div>
} }
/> />
<SwitchGroupWide name="enabled" label="Enabled" /> <SwitchGroupWide name="enabled" label="Enabled"/>
{indexer.implementation == "irc" && ( {indexer.implementation == "irc" && (
<SelectFieldCreatable <SelectFieldCreatable
@ -882,6 +890,31 @@ export function IndexerUpdateForm({ isOpen, toggle, indexer }: UpdateProps) {
)} )}
{renderSettingFields(indexer.settings)} {renderSettingFields(indexer.settings)}
<div className="border-t border-gray-200 dark:border-gray-700 py-4">
<div className="flex justify-between px-4">
<div className="space-y-1">
<DialogTitle className="text-lg font-medium text-gray-900 dark:text-white">
Proxy
</DialogTitle>
<p className="text-sm text-gray-500 dark:text-gray-400">
Set a proxy to be used for downloads of .torrent files and feeds.
</p>
</div>
<SwitchButton name="use_proxy" />
</div>
{values.use_proxy === true && (
<div className="py-4 pt-6">
<SelectField<number>
name="proxy_id"
label="Select proxy"
placeholder="Select a proxy"
options={proxies.data ? proxies.data.map((p) => ({ label: p.name, value: p.id })) : []}
/>
</div>
)}
</div>
</div> </div>
)} )}
</SlideOver> </SlideOver>

View file

@ -3,7 +3,7 @@
* SPDX-License-Identifier: GPL-2.0-or-later * SPDX-License-Identifier: GPL-2.0-or-later
*/ */
import { useMutation, useQueryClient } from "@tanstack/react-query"; import { useMutation, useQuery, useQueryClient } from "@tanstack/react-query";
import { toast } from "react-hot-toast"; import { toast } from "react-hot-toast";
import { XMarkIcon } from "@heroicons/react/24/solid"; import { XMarkIcon } from "@heroicons/react/24/solid";
import type { FieldProps } from "formik"; import type { FieldProps } from "formik";
@ -16,11 +16,12 @@ import { DialogTitle } from "@headlessui/react";
import { IrcAuthMechanismTypeOptions, OptionBasicTyped } from "@domain/constants"; import { IrcAuthMechanismTypeOptions, OptionBasicTyped } from "@domain/constants";
import { APIClient } from "@api/APIClient"; import { APIClient } from "@api/APIClient";
import { IrcKeys } from "@api/query_keys"; import { IrcKeys } from "@api/query_keys";
import { NumberFieldWide, PasswordFieldWide, SwitchGroupWide, TextFieldWide } from "@components/inputs"; import { NumberFieldWide, PasswordFieldWide, SwitchButton, SwitchGroupWide, TextFieldWide } from "@components/inputs";
import { SlideOver } from "@components/panels"; import { SlideOver } from "@components/panels";
import Toast from "@components/notifications/Toast"; import Toast from "@components/notifications/Toast";
import * as common from "@components/inputs/common"; import * as common from "@components/inputs/common";
import { classNames } from "@utils"; import { classNames } from "@utils";
import { ProxiesQueryOptions } from "@api/queries";
interface ChannelsFieldArrayProps { interface ChannelsFieldArrayProps {
channels: IrcChannel[]; channels: IrcChannel[];
@ -270,6 +271,8 @@ interface IrcNetworkUpdateFormValues {
bouncer_addr: string; bouncer_addr: string;
bot_mode: boolean; bot_mode: boolean;
channels: Array<IrcChannel>; channels: Array<IrcChannel>;
use_proxy: boolean;
proxy_id: number;
} }
interface IrcNetworkUpdateFormProps { interface IrcNetworkUpdateFormProps {
@ -285,6 +288,8 @@ export function IrcNetworkUpdateForm({
}: IrcNetworkUpdateFormProps) { }: IrcNetworkUpdateFormProps) {
const queryClient = useQueryClient(); const queryClient = useQueryClient();
const proxies = useQuery(ProxiesQueryOptions());
const updateMutation = useMutation({ const updateMutation = useMutation({
mutationFn: (network: IrcNetwork) => APIClient.irc.updateNetwork(network), mutationFn: (network: IrcNetwork) => APIClient.irc.updateNetwork(network),
onSuccess: () => { onSuccess: () => {
@ -325,7 +330,9 @@ export function IrcNetworkUpdateForm({
use_bouncer: network.use_bouncer, use_bouncer: network.use_bouncer,
bouncer_addr: network.bouncer_addr, bouncer_addr: network.bouncer_addr,
bot_mode: network.bot_mode, bot_mode: network.bot_mode,
channels: network.channels channels: network.channels,
use_proxy: network.use_proxy,
proxy_id: network.proxy_id,
}; };
return ( return (
@ -348,7 +355,7 @@ export function IrcNetworkUpdateForm({
required={true} required={true}
/> />
<SwitchGroupWide name="enabled" label="Enabled" /> <SwitchGroupWide name="enabled" label="Enabled"/>
<TextFieldWide <TextFieldWide
name="server" name="server"
label="Server" label="Server"
@ -362,7 +369,7 @@ export function IrcNetworkUpdateForm({
required={true} required={true}
/> />
<SwitchGroupWide name="tls" label="TLS" /> <SwitchGroupWide name="tls" label="TLS"/>
<PasswordFieldWide <PasswordFieldWide
name="pass" name="pass"
@ -377,7 +384,7 @@ export function IrcNetworkUpdateForm({
required={true} required={true}
/> />
<SwitchGroupWide name="use_bouncer" label="Bouncer (BNC)" /> <SwitchGroupWide name="use_bouncer" label="Bouncer (BNC)"/>
{values.use_bouncer && ( {values.use_bouncer && (
<TextFieldWide <TextFieldWide
name="bouncer_addr" name="bouncer_addr"
@ -386,7 +393,32 @@ export function IrcNetworkUpdateForm({
/> />
)} )}
<SwitchGroupWide name="bot_mode" label="IRCv3 Bot Mode" /> <SwitchGroupWide name="bot_mode" label="IRCv3 Bot Mode"/>
<div className="border-t border-gray-200 dark:border-gray-700 py-4">
<div className="flex justify-between px-4">
<div className="space-y-1">
<DialogTitle className="text-lg font-medium text-gray-900 dark:text-white">
Proxy
</DialogTitle>
<p className="text-sm text-gray-500 dark:text-gray-400">
Set a proxy to be used for connecting to the irc server.
</p>
</div>
<SwitchButton name="use_proxy"/>
</div>
{values.use_proxy === true && (
<div className="py-4 pt-6">
<SelectField<number>
name="proxy_id"
label="Select proxy"
placeholder="Select a proxy"
options={proxies.data ? proxies.data.map((p) => ({ label: p.name, value: p.id })) : []}
/>
</div>
)}
</div>
<div className="border-t border-gray-200 dark:border-gray-700 py-5"> <div className="border-t border-gray-200 dark:border-gray-700 py-5">
<div className="px-4 space-y-1 mb-8"> <div className="px-4 space-y-1 mb-8">
@ -416,17 +448,17 @@ export function IrcNetworkUpdateForm({
/> />
</div> </div>
<PasswordFieldWide name="invite_command" label="Invite command" /> <PasswordFieldWide name="invite_command" label="Invite command"/>
<div className="border-t border-gray-200 dark:border-gray-700 py-5"> <div className="border-t border-gray-200 dark:border-gray-700 py-5">
<div className="px-4 space-y-1 mb-8"> <div className="px-4 space-y-1 mb-8">
<DialogTitle className="text-lg font-medium text-gray-900 dark:text-white">Channels</DialogTitle> <DialogTitle className="text-lg font-medium text-gray-900 dark:text-white">Channels</DialogTitle>
<p className="text-sm text-gray-500 dark:text-gray-400"> <p className="text-sm text-gray-500 dark:text-gray-400">
Channels are added when you setup IRC indexers. Do not edit unless you know what you are doing. Channels are added when you setup IRC indexers. Do not edit unless you know what you are doing.
</p> </p>
</div> </div>
<ChannelsFieldArray channels={values.channels} /> <ChannelsFieldArray channels={values.channels}/>
</div> </div>
</div> </div>
)} )}
@ -438,9 +470,10 @@ interface SelectFieldProps<T> {
name: string; name: string;
label: string; label: string;
options: OptionBasicTyped<T>[] options: OptionBasicTyped<T>[]
placeholder?: string;
} }
function SelectField<T>({ name, label, options }: SelectFieldProps<T>) { export function SelectField<T>({ name, label, options, placeholder }: SelectFieldProps<T>) {
return ( return (
<div className="flex items-center justify-between space-y-1 px-4 sm:space-y-0 sm:grid sm:grid-cols-3 sm:gap-4"> <div className="flex items-center justify-between space-y-1 px-4 sm:space-y-0 sm:grid sm:grid-cols-3 sm:gap-4">
<div> <div>
@ -454,9 +487,9 @@ function SelectField<T>({ name, label, options }: SelectFieldProps<T>) {
<div className="sm:col-span-2"> <div className="sm:col-span-2">
<Field name={name} type="select"> <Field name={name} type="select">
{({ {({
field, field,
form: { setFieldValue, resetForm } form: { setFieldValue }
}: FieldProps) => ( }: FieldProps) => (
<Select <Select
{...field} {...field}
id={name} id={name}
@ -470,7 +503,7 @@ function SelectField<T>({ name, label, options }: SelectFieldProps<T>) {
IndicatorSeparator: common.IndicatorSeparator, IndicatorSeparator: common.IndicatorSeparator,
DropdownIndicator: common.DropdownIndicator DropdownIndicator: common.DropdownIndicator
}} }}
placeholder="Choose a type" placeholder={placeholder ?? "Choose a type"}
styles={{ styles={{
singleValue: (base) => ({ singleValue: (base) => ({
...base, ...base,
@ -487,14 +520,18 @@ function SelectField<T>({ name, label, options }: SelectFieldProps<T>) {
})} })}
value={field?.value && options.find(o => o.value == field?.value)} value={field?.value && options.find(o => o.value == field?.value)}
onChange={(option) => { onChange={(option) => {
resetForm(); // resetForm();
// const opt = option as SelectOption; if (option !== null) {
// setFieldValue("name", option?.label ?? "") // const opt = option as SelectOption;
setFieldValue( // setFieldValue("name", option?.label ?? "")
field.name, setFieldValue(
option.value ?? "" field.name,
); option.value ?? ""
);
} else {
setFieldValue(field.name, undefined);
}
}} }}
options={options} options={options}
/> />

View file

@ -0,0 +1,265 @@
import { Fragment } from "react";
import { Form, Formik, FormikValues } from "formik";
import { Dialog, DialogPanel, DialogTitle, Transition, TransitionChild } from "@headlessui/react";
import { XMarkIcon } from "@heroicons/react/24/solid";
import { useMutation, useQueryClient } from "@tanstack/react-query";
import { toast } from "react-hot-toast";
import { AddProps } from "@forms/settings/IndexerForms";
import { DEBUG } from "@components/debug.tsx";
import { PasswordFieldWide, SwitchGroupWide, TextFieldWide } from "@components/inputs";
import { SelectFieldBasic } from "@components/inputs/select_wide";
import { ProxyTypeOptions } from "@domain/constants";
import { APIClient } from "@api/APIClient";
import { ProxyKeys } from "@api/query_keys";
import Toast from "@components/notifications/Toast";
import { SlideOver } from "@components/panels";
export function ProxyAddForm({ isOpen, toggle }: AddProps) {
const queryClient = useQueryClient();
const createMutation = useMutation({
mutationFn: (req: ProxyCreate) => APIClient.proxy.store(req),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ProxyKeys.lists() });
toast.custom((t) => <Toast type="success" body="Proxy added!" t={t} />);
toggle();
},
onError: () => {
toast.custom((t) => <Toast type="error" body="Proxy could not be added" t={t} />);
}
});
const onSubmit = (formData: FormikValues) => {
createMutation.mutate(formData as ProxyCreate);
}
const testMutation = useMutation({
mutationFn: (data: Proxy) => APIClient.proxy.test(data),
onError: (err) => {
console.error(err);
}
});
const testProxy = (data: unknown) => testMutation.mutate(data as Proxy);
const initialValues: ProxyCreate = {
enabled: true,
name: "Proxy",
type: "SOCKS5",
addr: "socks5://ip:port",
user: "",
pass: "",
}
return (
<Transition show={isOpen} as={Fragment}>
<Dialog
as="div"
static
className="fixed inset-0 overflow-hidden"
open={isOpen}
onClose={toggle}
>
<div className="absolute inset-0 overflow-hidden">
<DialogPanel className="absolute inset-y-0 right-0 pl-10 max-w-full flex sm:pl-16">
<TransitionChild
as={Fragment}
enter="transform transition ease-in-out duration-500 sm:duration-700"
enterFrom="translate-x-full"
enterTo="translate-x-0"
leave="transform transition ease-in-out duration-500 sm:duration-700"
leaveFrom="translate-x-0"
leaveTo="translate-x-full"
>
<div className="w-screen max-w-2xl dark:border-gray-700 border-l">
<Formik
enableReinitialize={true}
initialValues={initialValues}
onSubmit={onSubmit}
>
{({ values }) => (
<Form className="h-full flex flex-col bg-white dark:bg-gray-800 shadow-xl overflow-y-auto">
<div className="flex-1">
<div className="px-4 py-6 bg-gray-50 dark:bg-gray-900 sm:px-6">
<div className="flex items-start justify-between space-x-3">
<div className="space-y-1">
<DialogTitle className="text-lg font-medium text-gray-900 dark:text-white">
Add proxy
</DialogTitle>
<p className="text-sm text-gray-500 dark:text-gray-200">
Add proxy to be used with Indexers or IRC.
</p>
</div>
<div className="h-7 flex items-center">
<button
type="button"
className="bg-white dark:bg-gray-700 rounded-md text-gray-400 hover:text-gray-500 focus:outline-none focus:ring-2 focus:ring-blue-500"
onClick={toggle}
>
<span className="sr-only">Close panel</span>
<XMarkIcon className="h-6 w-6" aria-hidden="true" />
</button>
</div>
</div>
</div>
<div className="py-6 space-y-4 divide-y divide-gray-200 dark:divide-gray-700">
<SwitchGroupWide name="enabled" label="Enabled" />
<TextFieldWide name="name" label="Name" defaultValue="" required={true} />
<SelectFieldBasic
name="type"
label="Proxy type"
options={ProxyTypeOptions}
tooltip={<span>Proxy type. Commonly SOCKS5.</span>}
help="Usually SOCKS5"
/>
<TextFieldWide name="addr" label="Addr" required={true} help="Addr: scheme://ip:port or scheme://domain" autoComplete="off"/>
</div>
<div>
<TextFieldWide name="user" label="User" help="auth: username" autoComplete="off" />
<PasswordFieldWide name="pass" label="Pass" help="auth: password" autoComplete="off"/>
</div>
</div>
<div
className="flex-shrink-0 px-4 border-t border-gray-200 dark:border-gray-700 py-5 sm:px-6">
<div className="space-x-3 flex justify-end">
<button
type="button"
className="bg-white dark:bg-gray-700 py-2 px-4 border border-gray-300 dark:border-gray-600 rounded-md shadow-sm text-sm font-medium text-gray-700 dark:text-gray-200 hover:bg-gray-50 dark:hover:bg-gray-600 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 dark:focus:ring-blue-500"
onClick={() => testProxy(values)}
>
Test
</button>
<button
type="button"
className="bg-white dark:bg-gray-700 py-2 px-4 border border-gray-300 dark:border-gray-600 rounded-md shadow-sm text-sm font-medium text-gray-700 dark:text-gray-200 hover:bg-gray-50 dark:hover:bg-gray-600 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 dark:focus:ring-blue-500"
onClick={toggle}
>
Cancel
</button>
<button
type="submit"
className="inline-flex justify-center py-2 px-4 border border-transparent shadow-sm text-sm font-medium rounded-md text-white bg-blue-600 dark:bg-blue-600 hover:bg-blue-700 dark:hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 dark:focus:ring-blue-500"
>
Save
</button>
</div>
</div>
<DEBUG values={values}/>
</Form>
)}
</Formik>
</div>
</TransitionChild>
</DialogPanel>
</div>
</Dialog>
</Transition>
);
}
interface UpdateFormProps<T> {
isOpen: boolean;
toggle: () => void;
data: T;
}
export function ProxyUpdateForm({ isOpen, toggle, data }: UpdateFormProps<Proxy>) {
const queryClient = useQueryClient();
const updateMutation = useMutation({
mutationFn: (req: Proxy) => APIClient.proxy.update(req),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ProxyKeys.lists() });
toast.custom((t) => <Toast type="success" body={`Proxy ${data.name} updated!`} t={t} />);
toggle();
},
onError: () => {
toast.custom((t) => <Toast type="error" body="Proxy could not be updated" t={t} />);
}
});
const onSubmit = (formData: unknown) => {
updateMutation.mutate(formData as Proxy);
}
const deleteMutation = useMutation({
mutationFn: (proxyId: number) => APIClient.proxy.delete(proxyId),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ProxyKeys.lists() });
toast.custom((t) => <Toast type="success" body={`Proxy ${data.name} was deleted.`} t={t}/>);
}
});
const deleteFn = () => deleteMutation.mutate(data.id);
const testMutation = useMutation({
mutationFn: (data: Proxy) => APIClient.proxy.test(data),
onError: (err) => {
console.error(err);
}
});
const testProxy = (data: unknown) => testMutation.mutate(data as Proxy);
const initialValues: Proxy = {
id: data.id,
enabled: data.enabled,
name: data.name,
type: data.type,
addr: data.addr,
user: data.user,
pass: data.pass,
}
return (
<SlideOver<Proxy>
title="Proxy"
initialValues={initialValues}
onSubmit={onSubmit}
deleteAction={deleteFn}
testFn={testProxy}
isOpen={isOpen}
toggle={toggle}
type="UPDATE"
>
{() => (
<div>
<div className="py-6 space-y-4 divide-y divide-gray-200 dark:divide-gray-700">
<SwitchGroupWide name="enabled" label="Enabled"/>
<TextFieldWide name="name" label="Name" defaultValue="" required={true}/>
<SelectFieldBasic
name="type"
label="Proxy type"
required={true}
options={ProxyTypeOptions}
tooltip={<span>Proxy type. Commonly SOCKS5.</span>}
help="Usually SOCKS5"
/>
<TextFieldWide name="addr" label="Addr" required={true} help="Addr: scheme://ip:port or scheme://domain" autoComplete="off"/>
</div>
<div>
<TextFieldWide name="user" label="User" help="auth: username" autoComplete="off"/>
<PasswordFieldWide name="pass" label="Pass" help="auth: password" autoComplete="off"/>
</div>
</div>
)}
</SlideOver>
);
}

View file

@ -11,7 +11,7 @@ import {
notFound, notFound,
Outlet, Outlet,
redirect, redirect,
} from "@tanstack/react-router"; } from "@tanstack/react-router";
import { z } from "zod"; import { z } from "zod";
import { QueryClient } from "@tanstack/react-query"; import { QueryClient } from "@tanstack/react-query";
@ -30,7 +30,8 @@ import {
FilterByIdQueryOptions, FilterByIdQueryOptions,
IndexersQueryOptions, IndexersQueryOptions,
IrcQueryOptions, IrcQueryOptions,
NotificationsQueryOptions NotificationsQueryOptions,
ProxiesQueryOptions
} from "@api/queries"; } from "@api/queries";
import LogSettings from "@screens/settings/Logs"; import LogSettings from "@screens/settings/Logs";
import NotificationSettings from "@screens/settings/Notifications"; import NotificationSettings from "@screens/settings/Notifications";
@ -50,6 +51,7 @@ import { AuthContext, SettingsContext } from "@utils/Context";
import { TanStackRouterDevtools } from "@tanstack/router-devtools"; import { TanStackRouterDevtools } from "@tanstack/router-devtools";
import { ReactQueryDevtools } from "@tanstack/react-query-devtools"; import { ReactQueryDevtools } from "@tanstack/react-query-devtools";
import { queryClient } from "@api/QueryClient"; import { queryClient } from "@api/QueryClient";
import ProxySettings from "@screens/settings/Proxy";
import { ErrorPage } from "@components/alerts"; import { ErrorPage } from "@components/alerts";
@ -212,6 +214,13 @@ export const SettingsApiRoute = createRoute({
component: APISettings component: APISettings
}); });
export const SettingsProxiesRoute = createRoute({
getParentRoute: () => SettingsRoute,
path: 'proxies',
loader: (opts) => opts.context.queryClient.ensureQueryData(ProxiesQueryOptions()),
component: ProxySettings
});
export const SettingsReleasesRoute = createRoute({ export const SettingsReleasesRoute = createRoute({
getParentRoute: () => SettingsRoute, getParentRoute: () => SettingsRoute,
path: 'releases', path: 'releases',
@ -339,7 +348,7 @@ export const RootRoute = createRootRouteWithContext<{
}); });
const filterRouteTree = FiltersRoute.addChildren([FilterIndexRoute, FilterGetByIdRoute.addChildren([FilterGeneralRoute, FilterMoviesTvRoute, FilterMusicRoute, FilterAdvancedRoute, FilterExternalRoute, FilterActionsRoute])]) const filterRouteTree = FiltersRoute.addChildren([FilterIndexRoute, FilterGetByIdRoute.addChildren([FilterGeneralRoute, FilterMoviesTvRoute, FilterMusicRoute, FilterAdvancedRoute, FilterExternalRoute, FilterActionsRoute])])
const settingsRouteTree = SettingsRoute.addChildren([SettingsIndexRoute, SettingsLogRoute, SettingsIndexersRoute, SettingsIrcRoute, SettingsFeedsRoute, SettingsClientsRoute, SettingsNotificationsRoute, SettingsApiRoute, SettingsReleasesRoute, SettingsAccountRoute]) const settingsRouteTree = SettingsRoute.addChildren([SettingsIndexRoute, SettingsLogRoute, SettingsIndexersRoute, SettingsIrcRoute, SettingsFeedsRoute, SettingsClientsRoute, SettingsNotificationsRoute, SettingsApiRoute, SettingsProxiesRoute, SettingsReleasesRoute, SettingsAccountRoute])
const authenticatedTree = AuthRoute.addChildren([AuthIndexRoute.addChildren([DashboardRoute, filterRouteTree, ReleasesRoute, settingsRouteTree, LogsRoute])]) const authenticatedTree = AuthRoute.addChildren([AuthIndexRoute.addChildren([DashboardRoute, filterRouteTree, ReleasesRoute, settingsRouteTree, LogsRoute])])
const routeTree = RootRoute.addChildren([ const routeTree = RootRoute.addChildren([
authenticatedTree, authenticatedTree,

View file

@ -8,6 +8,7 @@ import {
ChatBubbleLeftRightIcon, ChatBubbleLeftRightIcon,
CogIcon, CogIcon,
FolderArrowDownIcon, FolderArrowDownIcon,
GlobeAltIcon,
KeyIcon, KeyIcon,
RectangleStackIcon, RectangleStackIcon,
RssIcon, RssIcon,
@ -34,6 +35,7 @@ const subNavigation: NavTabType[] = [
{ name: "Clients", href: "/settings/clients", icon: FolderArrowDownIcon }, { name: "Clients", href: "/settings/clients", icon: FolderArrowDownIcon },
{ name: "Notifications", href: "/settings/notifications", icon: BellIcon }, { name: "Notifications", href: "/settings/notifications", icon: BellIcon },
{ name: "API keys", href: "/settings/api", icon: KeyIcon }, { name: "API keys", href: "/settings/api", icon: KeyIcon },
{ name: "Proxies", href: "/settings/proxies", icon: GlobeAltIcon },
{ name: "Releases", href: "/settings/releases", icon: RectangleStackIcon }, { name: "Releases", href: "/settings/releases", icon: RectangleStackIcon },
{ name: "Account", href: "/settings/account", icon: UserCircleIcon } { name: "Account", href: "/settings/account", icon: UserCircleIcon }
// {name: 'Regex Playground', href: 'regex-playground', icon: CogIcon, current: false} // {name: 'Regex Playground', href: 'regex-playground', icon: CogIcon, current: false}

View file

@ -0,0 +1,140 @@
import { useToggle } from "@hooks/hooks.ts";
import { useMutation, useQueryClient, useSuspenseQuery } from "@tanstack/react-query";
import { PlusIcon } from "@heroicons/react/24/solid";
import { toast } from "react-hot-toast";
import { APIClient } from "@api/APIClient";
import { ProxyKeys } from "@api/query_keys";
import { ProxiesQueryOptions } from "@api/queries";
import { Section } from "./_components";
import { EmptySimple } from "@components/emptystates";
import { Checkbox } from "@components/Checkbox";
import { ProxyAddForm, ProxyUpdateForm } from "@forms/settings/ProxyForms";
import Toast from "@components/notifications/Toast";
interface ListItemProps {
proxy: Proxy;
}
function ListItem({ proxy }: ListItemProps) {
const [isOpen, toggleUpdate] = useToggle(false);
const queryClient = useQueryClient();
const updateMutation = useMutation({
mutationFn: (req: Proxy) => APIClient.proxy.update(req),
onSuccess: () => {
queryClient.invalidateQueries({ queryKey: ProxyKeys.lists() });
toast.custom(t => <Toast type="success" body={`Proxy ${proxy.name} was ${proxy.enabled ? "enabled" : "disabled"} successfully.`} t={t} />);
},
onError: () => {
toast.custom((t) => <Toast type="error" body="Proxy state could not be updated" t={t} />);
}
});
const onToggleMutation = (newState: boolean) => {
updateMutation.mutate({
...proxy,
enabled: newState
});
};
return (
<li>
<ProxyUpdateForm isOpen={isOpen} toggle={toggleUpdate} data={proxy} />
<div className="grid grid-cols-12 items-center py-1.5">
<div className="col-span-2 sm:col-span-1 flex pl-1 sm:pl-5 items-center">
<Checkbox value={proxy.enabled ?? false} setValue={onToggleMutation} />
</div>
<div className="col-span-7 sm:col-span-8 pl-12 sm:pr-6 py-3 block flex-col text-sm font-medium text-gray-900 dark:text-white truncate">
{proxy.name}
</div>
<div className="hidden md:block col-span-2 pr-6 py-3 text-left items-center whitespace-nowrap text-sm text-gray-500 dark:text-gray-400 truncate">
{proxy.type}
</div>
<div className="col-span-1 flex first-letter:px-6 py-3 whitespace-nowrap text-right text-sm font-medium">
<span
className="col-span-1 px-6 text-blue-600 dark:text-gray-300 hover:text-blue-900 dark:hover:text-blue-500 cursor-pointer"
onClick={toggleUpdate}
>
Edit
</span>
</div>
</div>
</li>
);
}
function ProxySettings() {
const [addProxyIsOpen, toggleAddProxy] = useToggle(false);
const proxiesQuery = useSuspenseQuery(ProxiesQueryOptions())
const proxies = proxiesQuery.data
return (
<Section
title="Proxies"
description={
<>
Proxies that can be used with Indexers, feeds and IRC.<br/>
</>
}
rightSide={
<button
type="button"
onClick={toggleAddProxy}
className="relative inline-flex items-center px-4 py-2 border border-transparent shadow-sm text-sm font-medium rounded-md text-white bg-blue-600 dark:bg-blue-600 hover:bg-blue-700 dark:hover:bg-blue-700 focus:outline-none focus:ring-2 focus:ring-offset-2 focus:ring-blue-500 dark:focus:ring-blue-500"
>
<PlusIcon className="h-5 w-5 mr-1"/>
Add new
</button>
}
>
<ProxyAddForm isOpen={addProxyIsOpen} toggle={toggleAddProxy} />
<div className="flex flex-col">
{proxies.length ? (
<ul className="min-w-full relative">
<li className="grid grid-cols-12 border-b border-gray-200 dark:border-gray-700">
<div
className="flex col-span-2 sm:col-span-1 pl-0 sm:pl-3 py-3 text-left text-xs font-medium text-gray-500 dark:text-gray-400 hover:text-gray-800 hover:dark:text-gray-250 transition-colors uppercase tracking-wider cursor-pointer"
// onClick={() => sortedIndexers.requestSort("enabled")}
>
Enabled
{/*<span className="sort-indicator">{sortedIndexers.getSortIndicator("enabled")}</span>*/}
</div>
<div
className="col-span-7 sm:col-span-8 pl-12 py-3 text-left text-xs font-medium text-gray-500 dark:text-gray-400 hover:text-gray-800 hover:dark:text-gray-250 transition-colors uppercase tracking-wider cursor-pointer"
// onClick={() => sortedIndexers.requestSort("name")}
>
Name
{/*<span className="sort-indicator">{sortedIndexers.getSortIndicator("name")}</span>*/}
</div>
<div
className="hidden md:flex col-span-1 py-3 text-left text-xs font-medium text-gray-500 dark:text-gray-400 hover:text-gray-800 hover:dark:text-gray-250 transition-colors uppercase tracking-wider cursor-pointer"
// onClick={() => sortedIndexers.requestSort("implementation")}
>
Type
{/*<span className="sort-indicator">{sortedIndexers.getSortIndicator("implementation")}</span>*/}
</div>
</li>
{proxies.map((proxy) => (
<ListItem proxy={proxy} key={proxy.id}/>
))}
</ul>
) : (
<EmptySimple
title="No proxies"
subtitle=""
buttonText="Add new proxy"
buttonAction={toggleAddProxy}
/>
)}
</div>
</Section>
);
}
export default ProxySettings;

View file

@ -11,6 +11,7 @@ export { default as Indexer } from "./Indexer";
export { default as Irc } from "./Irc"; export { default as Irc } from "./Irc";
export { default as Logs } from "./Logs"; export { default as Logs } from "./Logs";
export { default as Notification } from "./Notifications"; export { default as Notification } from "./Notifications";
export { default as Proxy } from "./Proxy";
export { default as Release } from "./Releases"; export { default as Release } from "./Releases";
export { default as RegexPlayground } from "./RegexPlayground"; export { default as RegexPlayground } from "./RegexPlayground";
export { default as Account } from "./Account"; export { default as Account } from "./Account";

View file

@ -11,6 +11,8 @@ interface Indexer {
enabled: boolean; enabled: boolean;
implementation: string; implementation: string;
base_url: string; base_url: string;
use_proxy?: boolean;
proxy_id?: number;
settings: Array<IndexerSetting>; settings: Array<IndexerSetting>;
} }
@ -35,6 +37,8 @@ interface IndexerDefinition {
protocol: string; protocol: string;
urls: string[]; urls: string[];
supports: string[]; supports: string[];
use_proxy?: boolean;
proxy_id?: number;
settings: IndexerSetting[]; settings: IndexerSetting[];
irc: IndexerIRC; irc: IndexerIRC;
torznab: IndexerTorznab; torznab: IndexerTorznab;

View file

@ -20,6 +20,8 @@ interface IrcNetwork {
channels: IrcChannel[]; channels: IrcChannel[];
connected: boolean; connected: boolean;
connected_since: string; connected_since: string;
use_proxy: boolean;
proxy_id: number;
} }
interface IrcNetworkCreate { interface IrcNetworkCreate {
@ -53,23 +55,8 @@ interface IrcChannelWithHealth extends IrcChannel {
last_announce: string; last_announce: string;
} }
interface IrcNetworkWithHealth { interface IrcNetworkWithHealth extends IrcNetwork {
id: number;
name: string;
enabled: boolean;
server: string;
port: number;
tls: boolean;
pass: string;
nick: string;
auth: IrcAuth; // optional
invite_command: string;
use_bouncer: boolean;
bouncer_addr: string;
bot_mode: boolean;
channels: IrcChannelWithHealth[]; channels: IrcChannelWithHealth[];
connected: boolean;
connected_since: string;
connection_errors: string[]; connection_errors: string[];
healthy: boolean; healthy: boolean;
} }

27
web/src/types/Proxy.d.ts vendored Normal file
View file

@ -0,0 +1,27 @@
/*
* Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
* SPDX-License-Identifier: GPL-2.0-or-later
*/
interface Proxy {
id: number;
name: string;
enabled: boolean;
type: ProxyType;
addr: string;
user?: string;
pass?: string;
timeout?: number;
}
interface ProxyCreate {
name: string;
enabled: boolean;
type: ProxyType;
addr: string;
user?: string;
pass?: string;
timeout?: number;
}
type ProxyType = "SOCKS5" | "HTTP";