feat: add support for proxies to use with IRC and Indexers (#1421)

* feat: add support for proxies

* fix(http): release handler

* fix(migrations): define proxy early

* fix(migrations): pg proxy

* fix(proxy): list update delete

* fix(proxy): remove log and imports

* feat(irc): use proxy

* feat(irc): tests

* fix(web): update imports for ProxyForms.tsx

* fix(database): migration

* feat(proxy): test

* feat(proxy): validate proxy type

* feat(proxy): validate and test

* feat(proxy): improve validate and test

* feat(proxy): fix db schema

* feat(proxy): add db tests

* feat(proxy): handle http errors

* fix(http): imports

* feat(proxy): use proxy for indexer downloads

* feat(proxy): indexerforms select proxy

* feat(proxy): handle torrent download

* feat(proxy): skip if disabled

* feat(proxy): imports

* feat(proxy): implement in Feeds

* feat(proxy): update helper text indexer proxy

* feat(proxy): add internal cache
This commit is contained in:
ze0s 2024-09-02 11:10:45 +02:00 committed by GitHub
parent 472d327308
commit bc0f4cc055
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
59 changed files with 2533 additions and 371 deletions

View file

@ -140,9 +140,8 @@ func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, a
return nil, nil
} else {
if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFileCtx(ctx); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName)
return nil, err
if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
}
}
@ -243,11 +242,8 @@ func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, a
return nil, nil
} else {
if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFileCtx(ctx); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName)
return nil, err
}
if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
}
t, err := os.ReadFile(release.TorrentTmpFile)

View file

@ -74,10 +74,8 @@ func (s *service) porla(ctx context.Context, action *domain.Action, release doma
return nil, nil
} else {
if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFileCtx(ctx); err != nil {
return nil, errors.Wrap(err, "error downloading torrent file for release: %s", release.TorrentName)
}
if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
}
file, err := os.Open(release.TorrentTmpFile)

View file

@ -57,10 +57,8 @@ func (s *service) qbittorrent(ctx context.Context, action *domain.Action, releas
return nil, nil
}
if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFileCtx(ctx); err != nil {
return nil, errors.Wrap(err, "error downloading torrent file for release: %s", release.TorrentName)
}
if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
}
options, err := s.prepareQbitOptions(action)

View file

@ -68,11 +68,8 @@ func (s *service) rtorrent(ctx context.Context, action *domain.Action, release d
return nil, nil
}
if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFileCtx(ctx); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName)
return nil, err
}
if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
}
tmpFile, err := os.ReadFile(release.TorrentTmpFile)

View file

@ -36,9 +36,8 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release
return nil, errors.New("action %s client %s %s not enabled, skipping", action.Name, action.Client.Type, action.Client.Name)
}
// if set, try to resolve MagnetURI before parsing macros
// to allow webhook and exec to get the magnet_uri
if err := release.ResolveMagnetUri(ctx); err != nil {
// Check preconditions: download torrent file if needed
if err := s.CheckActionPreconditions(ctx, action, release); err != nil {
return nil, err
}
@ -137,6 +136,30 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release
return rejections, err
}
func (s *service) CheckActionPreconditions(ctx context.Context, action *domain.Action, release *domain.Release) error {
if err := s.downloadSvc.ResolveMagnetURI(ctx, release); err != nil {
return errors.Wrap(err, "could not resolve magnet uri: %s", release.MagnetURI)
}
// parse all macros in one go
if action.CheckMacrosNeedTorrentTmpFile(release) {
if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil {
return errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
}
}
if action.CheckMacrosNeedRawDataBytes(release) {
tmpFile, err := os.ReadFile(release.TorrentTmpFile)
if err != nil {
return errors.Wrap(err, "could not read torrent file: %v", release.TorrentTmpFile)
}
release.TorrentDataRawBytes = tmpFile
}
return nil
}
func (s *service) test(name string) {
s.log.Info().Msgf("action TEST: %v", name)
}

View file

@ -12,6 +12,7 @@ import (
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/download_client"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/releasedownload"
"github.com/autobrr/autobrr/pkg/sharedhttp"
"github.com/asaskevich/EventBus"
@ -33,21 +34,23 @@ type Service interface {
}
type service struct {
log zerolog.Logger
subLogger *log.Logger
repo domain.ActionRepo
clientSvc download_client.Service
bus EventBus.Bus
log zerolog.Logger
subLogger *log.Logger
repo domain.ActionRepo
clientSvc download_client.Service
downloadSvc *releasedownload.DownloadService
bus EventBus.Bus
httpClient *http.Client
}
func NewService(log logger.Logger, repo domain.ActionRepo, clientSvc download_client.Service, bus EventBus.Bus) Service {
func NewService(log logger.Logger, repo domain.ActionRepo, clientSvc download_client.Service, downloadSvc *releasedownload.DownloadService, bus EventBus.Bus) Service {
s := &service{
log: log.With().Str("module", "action").Logger(),
repo: repo,
clientSvc: clientSvc,
bus: bus,
log: log.With().Str("module", "action").Logger(),
repo: repo,
clientSvc: clientSvc,
downloadSvc: downloadSvc,
bus: bus,
httpClient: &http.Client{
Timeout: time.Second * 120,

View file

@ -107,11 +107,8 @@ func (s *service) transmission(ctx context.Context, action *domain.Action, relea
return nil, nil
}
if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFileCtx(ctx); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %s", release.TorrentName)
return nil, err
}
if err := s.downloadSvc.DownloadRelease(ctx, &release); err != nil {
return nil, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
}
b64, err := transmissionrpc.File2Base64(release.TorrentTmpFile)

View file

@ -36,6 +36,8 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) {
"i.identifier",
"i.identifier_external",
"i.name",
"i.use_proxy",
"i.proxy_id",
"f.name",
"f.type",
"f.enabled",
@ -66,8 +68,9 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) {
var f domain.Feed
var apiKey, cookie, settings sql.NullString
var proxyID sql.NullInt64
if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.UseProxy, &proxyID, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
@ -75,6 +78,7 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) {
return nil, errors.Wrap(err, "error scanning row")
}
f.ProxyID = proxyID.Int64
f.ApiKey = apiKey.String
f.Cookie = cookie.String
@ -98,6 +102,8 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string)
"i.identifier",
"i.identifier_external",
"i.name",
"i.use_proxy",
"i.proxy_id",
"f.name",
"f.type",
"f.enabled",
@ -128,8 +134,9 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string)
var f domain.Feed
var apiKey, cookie, settings sql.NullString
var proxyID sql.NullInt64
if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.UseProxy, &proxyID, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
@ -137,6 +144,7 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string)
return nil, errors.Wrap(err, "error scanning row")
}
f.ProxyID = proxyID.Int64
f.ApiKey = apiKey.String
f.Cookie = cookie.String
@ -158,6 +166,8 @@ func (r *FeedRepo) Find(ctx context.Context) ([]domain.Feed, error) {
"i.identifier",
"i.identifier_external",
"i.name",
"i.use_proxy",
"i.proxy_id",
"f.name",
"f.type",
"f.enabled",
@ -196,10 +206,13 @@ func (r *FeedRepo) Find(ctx context.Context) ([]domain.Feed, error) {
var apiKey, cookie, lastRunData, settings sql.NullString
var lastRun sql.NullTime
if err := rows.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &lastRun, &lastRunData, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
var proxyID sql.NullInt64
if err := rows.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.UseProxy, &proxyID, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &lastRun, &lastRunData, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
f.ProxyID = proxyID.Int64
f.LastRun = lastRun.Time
f.LastRunData = lastRunData.String
f.ApiKey = apiKey.String

View file

@ -36,8 +36,8 @@ func (r *IndexerRepo) Store(ctx context.Context, indexer domain.Indexer) (*domai
}
queryBuilder := r.db.squirrel.
Insert("indexer").Columns("enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings").
Values(indexer.Enabled, indexer.Name, indexer.Identifier, indexer.IdentifierExternal, indexer.Implementation, indexer.BaseURL, settings).
Insert("indexer").Columns("enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings").
Values(indexer.Enabled, indexer.Name, indexer.Identifier, indexer.IdentifierExternal, indexer.Implementation, indexer.BaseURL, indexer.UseProxy, toNullInt64(indexer.ProxyID), settings).
Suffix("RETURNING id").RunWith(r.db.handler)
// return values
@ -61,6 +61,8 @@ func (r *IndexerRepo) Update(ctx context.Context, indexer domain.Indexer) (*doma
Set("name", indexer.Name).
Set("identifier_external", indexer.IdentifierExternal).
Set("base_url", indexer.BaseURL).
Set("use_proxy", indexer.UseProxy).
Set("proxy_id", toNullInt64(indexer.ProxyID)).
Set("settings", settings).
Set("updated_at", time.Now().Format(time.RFC3339)).
Where(sq.Eq{"id": indexer.ID})
@ -70,16 +72,26 @@ func (r *IndexerRepo) Update(ctx context.Context, indexer domain.Indexer) (*doma
return nil, errors.Wrap(err, "error building query")
}
if _, err = r.db.handler.ExecContext(ctx, query, args...); err != nil {
result, err := r.db.handler.ExecContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return nil, errors.Wrap(err, "error rows affected")
}
if rowsAffected == 0 {
return nil, domain.ErrUpdateFailed
}
return &indexer, nil
}
func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) {
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings").
Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings").
From("indexer").
OrderBy("name ASC")
@ -98,27 +110,29 @@ func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) {
indexers := make([]domain.Indexer, 0)
for rows.Next() {
var f domain.Indexer
var i domain.Indexer
var identifierExternal, implementation, baseURL sql.Null[string]
var proxyID sql.Null[int64]
var settings string
var settingsMap map[string]string
if err := rows.Scan(&f.ID, &f.Enabled, &f.Name, &f.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil {
if err := rows.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
f.IdentifierExternal = identifierExternal.V
f.Implementation = implementation.V
f.BaseURL = baseURL.V
i.IdentifierExternal = identifierExternal.V
i.Implementation = implementation.V
i.BaseURL = baseURL.V
i.ProxyID = proxyID.V
if err = json.Unmarshal([]byte(settings), &settingsMap); err != nil {
return nil, errors.Wrap(err, "error unmarshal settings")
}
f.Settings = settingsMap
i.Settings = settingsMap
indexers = append(indexers, f)
indexers = append(indexers, i)
}
if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "error rows")
@ -129,7 +143,7 @@ func (r *IndexerRepo) List(ctx context.Context) ([]domain.Indexer, error) {
func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, error) {
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings").
Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings").
From("indexer").
Where(sq.Eq{"id": id})
@ -146,8 +160,9 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er
var i domain.Indexer
var identifierExternal, implementation, baseURL, settings sql.Null[string]
var proxyID sql.Null[int64]
if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil {
if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
@ -158,6 +173,7 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er
i.IdentifierExternal = identifierExternal.V
i.Implementation = implementation.V
i.BaseURL = baseURL.V
i.ProxyID = proxyID.V
var settingsMap map[string]string
if err = json.Unmarshal([]byte(settings.V), &settingsMap); err != nil {
@ -171,7 +187,7 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er
func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (*domain.Indexer, error) {
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "settings").
Select("id", "enabled", "name", "identifier", "identifier_external", "implementation", "base_url", "use_proxy", "proxy_id", "settings").
From("indexer")
if req.ID > 0 {
@ -195,8 +211,9 @@ func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (
var i domain.Indexer
var identifierExternal, implementation, baseURL, settings sql.Null[string]
var proxyID sql.Null[int64]
if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil {
if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
@ -207,6 +224,7 @@ func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (
i.IdentifierExternal = identifierExternal.V
i.Implementation = implementation.V
i.BaseURL = baseURL.V
i.ProxyID = proxyID.V
var settingsMap map[string]string
if err = json.Unmarshal([]byte(settings.V), &settingsMap); err != nil {
@ -220,7 +238,7 @@ func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (
func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Indexer, error) {
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "identifier", "identifier_external", "base_url", "settings").
Select("id", "enabled", "name", "identifier", "identifier_external", "base_url", "use_proxy", "proxy_id", "settings").
From("indexer").
Join("filter_indexer ON indexer.id = filter_indexer.indexer_id").
Where(sq.Eq{"filter_indexer.filter_id": id})
@ -239,13 +257,14 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde
indexers := make([]domain.Indexer, 0)
for rows.Next() {
var f domain.Indexer
var i domain.Indexer
var settings string
var settingsMap map[string]string
var identifierExternal, baseURL sql.Null[string]
var proxyID sql.Null[int64]
if err := rows.Scan(&f.ID, &f.Enabled, &f.Name, &f.Identifier, &identifierExternal, &baseURL, &settings); err != nil {
if err := rows.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &baseURL, &i.UseProxy, &proxyID, &settings); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
@ -253,11 +272,12 @@ func (r *IndexerRepo) FindByFilterID(ctx context.Context, id int) ([]domain.Inde
return nil, errors.Wrap(err, "error unmarshal settings")
}
f.IdentifierExternal = identifierExternal.V
f.BaseURL = baseURL.V
f.Settings = settingsMap
i.IdentifierExternal = identifierExternal.V
i.BaseURL = baseURL.V
i.ProxyID = proxyID.V
i.Settings = settingsMap
indexers = append(indexers, f)
indexers = append(indexers, i)
}
if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "error rows")
@ -282,13 +302,13 @@ func (r *IndexerRepo) Delete(ctx context.Context, id int) error {
return errors.Wrap(err, "error executing query")
}
rows, err := result.RowsAffected()
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.Wrap(err, "error rows affected")
}
if rows != 1 {
return errors.New("error deleting row")
if rowsAffected == 0 {
return domain.ErrRecordNotFound
}
r.log.Debug().Str("method", "delete").Msgf("successfully deleted indexer with id %v", id)
@ -297,8 +317,6 @@ func (r *IndexerRepo) Delete(ctx context.Context, id int) error {
}
func (r *IndexerRepo) ToggleEnabled(ctx context.Context, indexerID int, enabled bool) error {
var err error
queryBuilder := r.db.squirrel.
Update("indexer").
Set("enabled", enabled).
@ -310,10 +328,19 @@ func (r *IndexerRepo) ToggleEnabled(ctx context.Context, indexerID int, enabled
return errors.Wrap(err, "error building query")
}
_, err = r.db.handler.ExecContext(ctx, query, args...)
result, err := r.db.handler.ExecContext(ctx, query, args...)
if err != nil {
return errors.Wrap(err, "error executing query")
}
rowsAffected, err := result.RowsAffected()
if err != nil {
return errors.Wrap(err, "error rows affected")
}
if rowsAffected == 0 {
return domain.ErrUpdateFailed
}
return nil
}

View file

@ -30,7 +30,7 @@ func NewIrcRepo(log logger.Logger, db *DB) domain.IrcRepo {
func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetwork, error) {
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode").
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id").
From("irc_network").
Where(sq.Eq{"id": id})
@ -42,26 +42,27 @@ func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetw
var n domain.IrcNetwork
var pass, nick, inviteCmd, bouncerAddr sql.NullString
var account, password sql.NullString
var tls sql.NullBool
var pass, nick, inviteCmd, bouncerAddr sql.Null[string]
var account, password sql.Null[string]
var tls sql.Null[bool]
var proxyId sql.Null[int64]
row := r.db.handler.QueryRowContext(ctx, query, args...)
if err := row.Scan(&n.ID, &n.Enabled, &n.Name, &n.Server, &n.Port, &tls, &pass, &nick, &n.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &n.UseBouncer, &n.BotMode); err != nil {
if err := row.Scan(&n.ID, &n.Enabled, &n.Name, &n.Server, &n.Port, &tls, &pass, &nick, &n.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &n.UseBouncer, &n.BotMode, &n.UseProxy, &proxyId); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}
n.TLS = tls.Bool
n.Pass = pass.String
n.Nick = nick.String
n.InviteCommand = inviteCmd.String
n.Auth.Account = account.String
n.Auth.Password = password.String
n.BouncerAddr = bouncerAddr.String
n.TLS = tls.V
n.Pass = pass.V
n.Nick = nick.V
n.InviteCommand = inviteCmd.V
n.BouncerAddr = bouncerAddr.V
n.Auth.Account = account.V
n.Auth.Password = password.V
n.ProxyId = proxyId.V
return &n, nil
}
@ -111,7 +112,7 @@ func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode").
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id").
From("irc_network").
Where(sq.Eq{"enabled": true})
@ -131,22 +132,24 @@ func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork,
for rows.Next() {
var net domain.IrcNetwork
var pass, nick, inviteCmd, bouncerAddr sql.NullString
var account, password sql.NullString
var tls sql.NullBool
var pass, nick, inviteCmd, bouncerAddr sql.Null[string]
var account, password sql.Null[string]
var tls sql.Null[bool]
var proxyId sql.Null[int64]
if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode); err != nil {
if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode, &net.UseProxy, &proxyId); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
net.TLS = tls.Bool
net.Pass = pass.String
net.Nick = nick.String
net.InviteCommand = inviteCmd.String
net.BouncerAddr = bouncerAddr.String
net.TLS = tls.V
net.Pass = pass.V
net.Nick = nick.V
net.InviteCommand = inviteCmd.V
net.BouncerAddr = bouncerAddr.V
net.Auth.Account = account.V
net.Auth.Password = password.V
net.Auth.Account = account.String
net.Auth.Password = password.String
net.ProxyId = proxyId.V
networks = append(networks, net)
}
@ -159,7 +162,7 @@ func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork,
func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode").
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id").
From("irc_network").
OrderBy("name ASC")
@ -179,22 +182,24 @@ func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error)
for rows.Next() {
var net domain.IrcNetwork
var pass, nick, inviteCmd, bouncerAddr sql.NullString
var account, password sql.NullString
var tls sql.NullBool
var pass, nick, inviteCmd, bouncerAddr sql.Null[string]
var account, password sql.Null[string]
var tls sql.Null[bool]
var proxyId sql.Null[int64]
if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode); err != nil {
if err := rows.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode, &net.UseProxy, &proxyId); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
net.TLS = tls.Bool
net.Pass = pass.String
net.Nick = nick.String
net.InviteCommand = inviteCmd.String
net.BouncerAddr = bouncerAddr.String
net.TLS = tls.V
net.Pass = pass.V
net.Nick = nick.V
net.InviteCommand = inviteCmd.V
net.BouncerAddr = bouncerAddr.V
net.Auth.Account = account.V
net.Auth.Password = password.V
net.Auth.Account = account.String
net.Auth.Password = password.String
net.ProxyId = proxyId.V
networks = append(networks, net)
}
@ -225,13 +230,13 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
var channels []domain.IrcChannel
for rows.Next() {
var ch domain.IrcChannel
var pass sql.NullString
var pass sql.Null[string]
if err := rows.Scan(&ch.ID, &ch.Name, &ch.Enabled, &pass); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
ch.Password = pass.String
ch.Password = pass.V
channels = append(channels, ch)
}
@ -244,7 +249,7 @@ func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcNetwork) (*domain.IrcNetwork, error) {
queryBuilder := r.db.squirrel.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode").
Select("id", "enabled", "name", "server", "port", "tls", "pass", "nick", "auth_mechanism", "auth_account", "auth_password", "invite_command", "bouncer_addr", "use_bouncer", "bot_mode", "use_proxy", "proxy_id").
From("irc_network").
Where(sq.Eq{"server": network.Server}).
Where(sq.Eq{"port": network.Port}).
@ -263,11 +268,12 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN
var net domain.IrcNetwork
var pass, nick, inviteCmd, bouncerAddr sql.NullString
var account, password sql.NullString
var tls sql.NullBool
var pass, nick, inviteCmd, bouncerAddr sql.Null[string]
var account, password sql.Null[string]
var tls sql.Null[bool]
var proxyId sql.Null[int64]
if err = row.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode); err != nil {
if err = row.Scan(&net.ID, &net.Enabled, &net.Name, &net.Server, &net.Port, &tls, &pass, &nick, &net.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &net.UseBouncer, &net.BotMode, &net.UseProxy, &proxyId); err != nil {
if errors.Is(err, sql.ErrNoRows) {
// no result is not an error in our case
return nil, nil
@ -276,29 +282,20 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN
return nil, errors.Wrap(err, "error scanning row")
}
net.TLS = tls.Bool
net.Pass = pass.String
net.Nick = nick.String
net.InviteCommand = inviteCmd.String
net.BouncerAddr = bouncerAddr.String
net.Auth.Account = account.String
net.Auth.Password = password.String
net.TLS = tls.V
net.Pass = pass.V
net.Nick = nick.V
net.InviteCommand = inviteCmd.V
net.BouncerAddr = bouncerAddr.V
net.Auth.Account = account.V
net.Auth.Password = password.V
net.ProxyId = proxyId.V
return &net, nil
}
func (r *IrcRepo) StoreNetwork(ctx context.Context, network *domain.IrcNetwork) error {
netName := toNullString(network.Name)
pass := toNullString(network.Pass)
nick := toNullString(network.Nick)
inviteCmd := toNullString(network.InviteCommand)
bouncerAddr := toNullString(network.BouncerAddr)
account := toNullString(network.Auth.Account)
password := toNullString(network.Auth.Password)
var retID int64
queryBuilder := r.db.squirrel.
Insert("irc_network").
Columns(
@ -319,60 +316,49 @@ func (r *IrcRepo) StoreNetwork(ctx context.Context, network *domain.IrcNetwork)
).
Values(
network.Enabled,
netName,
toNullString(network.Name),
network.Server,
network.Port,
network.TLS,
pass,
nick,
toNullString(network.Pass),
toNullString(network.Nick),
network.Auth.Mechanism,
account,
password,
inviteCmd,
bouncerAddr,
toNullString(network.Auth.Account),
toNullString(network.Auth.Password),
toNullString(network.InviteCommand),
toNullString(network.BouncerAddr),
network.UseBouncer,
network.BotMode,
).
Suffix("RETURNING id").
RunWith(r.db.handler)
if err := queryBuilder.QueryRowContext(ctx).Scan(&retID); err != nil {
if err := queryBuilder.QueryRowContext(ctx).Scan(&network.ID); err != nil {
return errors.Wrap(err, "error executing query")
}
network.ID = retID
return nil
}
func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error {
netName := toNullString(network.Name)
pass := toNullString(network.Pass)
nick := toNullString(network.Nick)
inviteCmd := toNullString(network.InviteCommand)
bouncerAddr := toNullString(network.BouncerAddr)
account := toNullString(network.Auth.Account)
password := toNullString(network.Auth.Password)
var err error
queryBuilder := r.db.squirrel.
Update("irc_network").
Set("enabled", network.Enabled).
Set("name", netName).
Set("name", toNullString(network.Name)).
Set("server", network.Server).
Set("port", network.Port).
Set("tls", network.TLS).
Set("pass", pass).
Set("nick", nick).
Set("pass", toNullString(network.Pass)).
Set("nick", toNullString(network.Nick)).
Set("auth_mechanism", network.Auth.Mechanism).
Set("auth_account", account).
Set("auth_password", password).
Set("invite_command", inviteCmd).
Set("bouncer_addr", bouncerAddr).
Set("auth_account", toNullString(network.Auth.Account)).
Set("auth_password", toNullString(network.Auth.Password)).
Set("invite_command", toNullString(network.InviteCommand)).
Set("bouncer_addr", toNullString(network.BouncerAddr)).
Set("use_bouncer", network.UseBouncer).
Set("bot_mode", network.BotMode).
Set("use_proxy", network.UseProxy).
Set("proxy_id", toNullInt64(network.ProxyId)).
Set("updated_at", time.Now().Format(time.RFC3339)).
Where(sq.Eq{"id": network.ID})
@ -414,7 +400,6 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha
for _, channel := range channels {
// values
pass := toNullString(channel.Password)
channelQueryBuilder := r.db.squirrel.
Insert("irc_channel").
@ -429,21 +414,17 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha
channel.Enabled,
true,
channel.Name,
pass,
toNullString(channel.Password),
networkID,
).
Suffix("RETURNING id").
RunWith(tx)
// returning
var retID int64
if err = channelQueryBuilder.QueryRowContext(ctx).Scan(&retID); err != nil {
if err = channelQueryBuilder.QueryRowContext(ctx).Scan(&channel.ID); err != nil {
return errors.Wrap(err, "error executing query storeNetworkChannels")
}
channel.ID = retID
//channelQuery, channelArgs, err := channelQueryBuilder.ToSql()
//if err != nil {
// r.log.Error().Stack().Err(err).Msg("irc.storeNetworkChannels: error building query")
@ -467,8 +448,6 @@ func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, cha
}
func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *domain.IrcChannel) error {
pass := toNullString(channel.Password)
if channel.ID != 0 {
// update record
channelQueryBuilder := r.db.squirrel.
@ -476,7 +455,7 @@ func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *do
Set("enabled", channel.Enabled).
Set("detached", channel.Detached).
Set("name", channel.Name).
Set("password", pass).
Set("password", toNullString(channel.Password)).
Where(sq.Eq{"id": channel.ID})
query, args, err := channelQueryBuilder.ToSql()
@ -501,21 +480,17 @@ func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *do
channel.Enabled,
true,
channel.Name,
pass,
toNullString(channel.Password),
networkID,
).
Suffix("RETURNING id").
RunWith(r.db.handler)
// returning
var retID int64
if err := queryBuilder.QueryRowContext(ctx).Scan(&retID); err != nil {
if err := queryBuilder.QueryRowContext(ctx).Scan(&channel.ID); err != nil {
return errors.Wrap(err, "error executing query")
}
channel.ID = retID
//channelQuery, channelArgs, err := channelQueryBuilder.ToSql()
//if err != nil {
// r.log.Error().Stack().Err(err).Msg("irc.storeChannel: error building query")
@ -536,15 +511,13 @@ func (r *IrcRepo) StoreChannel(ctx context.Context, networkID int64, channel *do
}
func (r *IrcRepo) UpdateChannel(channel *domain.IrcChannel) error {
pass := toNullString(channel.Password)
// update record
channelQueryBuilder := r.db.squirrel.
Update("irc_channel").
Set("enabled", channel.Enabled).
Set("detached", channel.Detached).
Set("name", channel.Name).
Set("password", pass).
Set("password", toNullString(channel.Password)).
Where(sq.Eq{"id": channel.ID})
query, args, err := channelQueryBuilder.ToSql()
@ -561,7 +534,6 @@ func (r *IrcRepo) UpdateChannel(channel *domain.IrcChannel) error {
}
func (r *IrcRepo) UpdateInviteCommand(networkID int64, invite string) error {
// update record
channelQueryBuilder := r.db.squirrel.
Update("irc_network").

View file

@ -14,6 +14,20 @@ CREATE TABLE users
UNIQUE (username)
);
CREATE TABLE proxy
(
id SERIAL PRIMARY KEY,
enabled BOOLEAN,
name TEXT NOT NULL,
type TEXT NOT NULL,
addr TEXT NOT NULL,
auth_user TEXT,
auth_pass TEXT,
timeout INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE indexer
(
id SERIAL PRIMARY KEY,
@ -24,8 +38,11 @@ CREATE TABLE indexer
enabled BOOLEAN,
name TEXT NOT NULL,
settings TEXT,
use_proxy BOOLEAN DEFAULT FALSE,
proxy_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL,
UNIQUE (identifier)
);
@ -51,8 +68,11 @@ CREATE TABLE irc_network
bot_mode BOOLEAN DEFAULT FALSE,
connected BOOLEAN,
connected_since TIMESTAMP,
use_proxy BOOLEAN DEFAULT FALSE,
proxy_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL,
UNIQUE (server, port, nick)
);
@ -900,5 +920,39 @@ ADD COLUMN months TEXT;
ALTER TABLE filter
ADD COLUMN days TEXT;
`,
`CREATE TABLE proxy
(
id SERIAL PRIMARY KEY,
enabled BOOLEAN,
name TEXT NOT NULL,
type TEXT NOT NULL,
addr TEXT NOT NULL,
auth_user TEXT,
auth_pass TEXT,
timeout INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE indexer
ADD COLUMN proxy_id INTEGER;
ALTER TABLE indexer
ADD COLUMN use_proxy BOOLEAN DEFAULT FALSE;
ALTER TABLE indexer
ADD FOREIGN KEY (proxy_id) REFERENCES proxy
ON DELETE SET NULL;
ALTER TABLE irc_network
ADD COLUMN proxy_id INTEGER;
ALTER TABLE irc_network
ADD COLUMN use_proxy BOOLEAN DEFAULT FALSE;
ALTER TABLE irc_network
ADD FOREIGN KEY (proxy_id) REFERENCES proxy
ON DELETE SET NULL;
`,
}

265
internal/database/proxy.go Normal file
View file

@ -0,0 +1,265 @@
// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
package database
import (
"context"
"database/sql"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
sq "github.com/Masterminds/squirrel"
"github.com/rs/zerolog"
)
type ProxyRepo struct {
log zerolog.Logger
db *DB
}
func NewProxyRepo(log logger.Logger, db *DB) domain.ProxyRepo {
return &ProxyRepo{
log: log.With().Str("repo", "proxy").Logger(),
db: db,
}
}
func (r *ProxyRepo) Store(ctx context.Context, p *domain.Proxy) error {
queryBuilder := r.db.squirrel.
Insert("proxy").
Columns(
"enabled",
"name",
"type",
"addr",
"auth_user",
"auth_pass",
"timeout",
).
Values(
p.Enabled,
p.Name,
p.Type,
toNullString(p.Addr),
toNullString(p.User),
toNullString(p.Pass),
p.Timeout,
).
Suffix("RETURNING id").
RunWith(r.db.handler)
var retID int64
err := queryBuilder.QueryRowContext(ctx).Scan(&retID)
if err != nil {
return errors.Wrap(err, "error executing query")
}
p.ID = retID
return nil
}
func (r *ProxyRepo) Update(ctx context.Context, p *domain.Proxy) error {
queryBuilder := r.db.squirrel.
Update("proxy").
Set("enabled", p.Enabled).
Set("name", p.Name).
Set("type", p.Type).
Set("addr", p.Addr).
Set("auth_user", toNullString(p.User)).
Set("auth_pass", toNullString(p.Pass)).
Set("timeout", p.Timeout).
Set("updated_at", time.Now().Format(time.RFC3339)).
Where(sq.Eq{"id": p.ID})
query, args, err := queryBuilder.ToSql()
if err != nil {
return errors.Wrap(err, "error building query")
}
// update record
res, err := r.db.handler.ExecContext(ctx, query, args...)
if err != nil {
return errors.Wrap(err, "error executing query")
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return errors.Wrap(err, "error getting affected rows")
}
if rowsAffected == 0 {
return domain.ErrUpdateFailed
}
return err
}
func (r *ProxyRepo) List(ctx context.Context) ([]domain.Proxy, error) {
queryBuilder := r.db.squirrel.
Select(
"id",
"enabled",
"name",
"type",
"addr",
"auth_user",
"auth_pass",
"timeout",
).
From("proxy").
OrderBy("name ASC")
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, errors.Wrap(err, "error building query")
}
rows, err := r.db.handler.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
defer rows.Close()
proxies := make([]domain.Proxy, 0)
for rows.Next() {
var proxy domain.Proxy
var user, pass sql.NullString
if err := rows.Scan(&proxy.ID, &proxy.Enabled, &proxy.Name, &proxy.Type, &proxy.Addr, &user, &pass, &proxy.Timeout); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
proxy.User = user.String
proxy.Pass = pass.String
proxies = append(proxies, proxy)
}
err = rows.Err()
if err != nil {
return nil, errors.Wrap(err, "error row")
}
return proxies, nil
}
func (r *ProxyRepo) Delete(ctx context.Context, id int64) error {
tx, err := r.db.BeginTx(ctx, nil)
if err != nil {
return errors.Wrap(err, "error begin transaction")
}
defer tx.Rollback()
queryBuilder := r.db.squirrel.
Delete("proxy").
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
return errors.Wrap(err, "error building query")
}
res, err := tx.ExecContext(ctx, query, args...)
if err != nil {
return errors.Wrap(err, "error executing query")
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return errors.Wrap(err, "error getting affected rows")
}
if rowsAffected == 0 {
return domain.ErrDeleteFailed
}
err = tx.Commit()
if err != nil {
return errors.Wrap(err, "error commit deleting proxy")
}
return nil
}
func (r *ProxyRepo) FindByID(ctx context.Context, id int64) (*domain.Proxy, error) {
queryBuilder := r.db.squirrel.
Select(
"id",
"enabled",
"name",
"type",
"addr",
"auth_user",
"auth_pass",
"timeout",
).
From("proxy").
OrderBy("name ASC").
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
return nil, errors.Wrap(err, "error building query")
}
row := r.db.handler.QueryRowContext(ctx, query, args...)
if err := row.Err(); err != nil {
return nil, errors.Wrap(err, "error executing query")
}
var proxy domain.Proxy
var user, pass sql.NullString
err = row.Scan(&proxy.ID, &proxy.Enabled, &proxy.Name, &proxy.Type, &proxy.Addr, &user, &pass, &proxy.Timeout)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}
proxy.User = user.String
proxy.Pass = pass.String
return &proxy, nil
}
func (r *ProxyRepo) ToggleEnabled(ctx context.Context, id int64, enabled bool) error {
queryBuilder := r.db.squirrel.
Update("proxy").
Set("enabled", enabled).
Set("updated_at", time.Now().Format(time.RFC3339)).
Where(sq.Eq{"id": id})
query, args, err := queryBuilder.ToSql()
if err != nil {
return errors.Wrap(err, "error building query")
}
// update record
res, err := r.db.handler.ExecContext(ctx, query, args...)
if err != nil {
return errors.Wrap(err, "error executing query")
}
rowsAffected, err := res.RowsAffected()
if err != nil {
return errors.Wrap(err, "error getting affected rows")
}
if rowsAffected == 0 {
return domain.ErrUpdateFailed
}
return nil
}

View file

@ -0,0 +1,220 @@
// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
//go:build integration
package database
import (
"context"
"fmt"
"testing"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/stretchr/testify/assert"
)
func getMockProxy() *domain.Proxy {
return &domain.Proxy{
//ID: 0,
Name: "Proxy",
Enabled: true,
Type: domain.ProxyTypeSocks5,
Addr: "socks5://127.0.0.1:1080",
User: "",
Pass: "",
Timeout: 0,
}
}
func TestProxyRepo_Store(t *testing.T) {
for dbType, db := range testDBs {
log := setupLoggerForTest()
repo := NewProxyRepo(log, db)
mockData := getMockProxy()
t.Run(fmt.Sprintf("Store_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotNil(t, proxies)
assert.Equal(t, mockData.Name, proxies[0].Name)
// Cleanup
_ = repo.Delete(context.Background(), mockData.ID)
})
t.Run(fmt.Sprintf("Store_Fails_With_Missing_or_empty_fields [%s]", dbType), func(t *testing.T) {
mockData := domain.Proxy{}
err := repo.Store(context.Background(), &mockData)
assert.Error(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.Empty(t, proxies)
//assert.Nil(t, proxies)
// Cleanup
// No cleanup needed
})
t.Run(fmt.Sprintf("Store_Fails_With_Context_Timeout [%s]", dbType), func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Nanosecond)
defer cancel()
err := repo.Store(ctx, mockData)
assert.Error(t, err)
})
}
}
func TestProxyRepo_Update(t *testing.T) {
for dbType, db := range testDBs {
log := setupLoggerForTest()
repo := NewProxyRepo(log, db)
mockData := getMockProxy()
t.Run(fmt.Sprintf("Update_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
// Update mockData
updatedProxy := mockData
updatedProxy.Name = "Updated Proxy"
updatedProxy.Enabled = false
// Execute
err = repo.Update(context.Background(), updatedProxy)
assert.NoError(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotNil(t, proxies)
assert.Equal(t, "Updated Proxy", proxies[0].Name)
assert.Equal(t, false, proxies[0].Enabled)
// Cleanup
_ = repo.Delete(context.Background(), proxies[0].ID)
})
t.Run(fmt.Sprintf("Update_Fails_Invalid_ID [%s]", dbType), func(t *testing.T) {
mockData.ID = -1
err := repo.Update(context.Background(), mockData)
assert.Error(t, err)
assert.ErrorIs(t, err, domain.ErrUpdateFailed)
})
}
}
func TestProxyRepo_Delete(t *testing.T) {
for dbType, db := range testDBs {
log := setupLoggerForTest()
repo := NewProxyRepo(log, db)
mockData := getMockProxy()
t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotNil(t, proxies)
assert.Equal(t, mockData.Name, proxies[0].Name)
// Execute
err = repo.Delete(context.Background(), proxies[0].ID)
assert.NoError(t, err)
// Verify that the proxy is deleted and return error ErrRecordNotFound
proxy, err := repo.FindByID(context.Background(), proxies[0].ID)
assert.ErrorIs(t, err, domain.ErrRecordNotFound)
assert.Nil(t, proxy)
})
t.Run(fmt.Sprintf("Delete_Fails_No_Record [%s]", dbType), func(t *testing.T) {
err := repo.Delete(context.Background(), 9999)
assert.Error(t, err)
assert.ErrorIs(t, err, domain.ErrDeleteFailed)
})
}
}
func TestProxyRepo_ToggleEnabled(t *testing.T) {
for dbType, db := range testDBs {
log := setupLoggerForTest()
repo := NewProxyRepo(log, db)
mockData := getMockProxy()
t.Run(fmt.Sprintf("ToggleEnabled_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotNil(t, proxies)
assert.Equal(t, true, proxies[0].Enabled)
// Execute
err = repo.ToggleEnabled(context.Background(), mockData.ID, false)
assert.NoError(t, err)
// Verify that the proxy is updated
proxy, err := repo.FindByID(context.Background(), proxies[0].ID)
assert.NoError(t, err)
assert.NotNil(t, proxy)
assert.Equal(t, false, proxy.Enabled)
// Cleanup
_ = repo.Delete(context.Background(), proxies[0].ID)
})
t.Run(fmt.Sprintf("ToggleEnabled_Fails_Invalid_ID [%s]", dbType), func(t *testing.T) {
err := repo.ToggleEnabled(context.Background(), -1, false)
assert.Error(t, err)
assert.ErrorIs(t, err, domain.ErrUpdateFailed)
})
}
}
func TestProxyRepo_FindByID(t *testing.T) {
for dbType, db := range testDBs {
log := setupLoggerForTest()
repo := NewProxyRepo(log, db)
mockData := getMockProxy()
t.Run(fmt.Sprintf("FindByID_Succeeds [%s]", dbType), func(t *testing.T) {
// Setup
err := repo.Store(context.Background(), mockData)
assert.NoError(t, err)
proxies, err := repo.List(context.Background())
assert.NoError(t, err)
assert.NotNil(t, proxies)
// Execute
proxy, err := repo.FindByID(context.Background(), proxies[0].ID)
assert.NoError(t, err)
assert.NotNil(t, proxy)
assert.Equal(t, proxies[0].ID, proxy.ID)
// Cleanup
_ = repo.Delete(context.Background(), proxies[0].ID)
})
t.Run(fmt.Sprintf("FindByID_Fails_Invalid_ID [%s]", dbType), func(t *testing.T) {
// Test using an invalid ID
proxy, err := repo.FindByID(context.Background(), -1)
assert.ErrorIs(t, err, domain.ErrRecordNotFound) // should return an error
assert.Nil(t, proxy) // should be nil
})
}
}

View file

@ -14,6 +14,20 @@ CREATE TABLE users
UNIQUE (username)
);
CREATE TABLE proxy
(
id INTEGER PRIMARY KEY,
enabled BOOLEAN,
name TEXT NOT NULL,
type TEXT NOT NULL,
addr TEXT NOT NULL,
auth_user TEXT,
auth_pass TEXT,
timeout INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE indexer
(
id INTEGER PRIMARY KEY,
@ -24,8 +38,11 @@ CREATE TABLE indexer
enabled BOOLEAN,
name TEXT NOT NULL,
settings TEXT,
use_proxy BOOLEAN DEFAULT FALSE,
proxy_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL,
UNIQUE (identifier)
);
@ -51,8 +68,11 @@ CREATE TABLE irc_network
bot_mode BOOLEAN DEFAULT FALSE,
connected BOOLEAN,
connected_since TIMESTAMP,
use_proxy BOOLEAN DEFAULT FALSE,
proxy_id INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY (proxy_id) REFERENCES proxy(id) ON DELETE SET NULL,
UNIQUE (server, port, nick)
);
@ -1538,5 +1558,37 @@ ADD COLUMN months TEXT;
ALTER TABLE filter
ADD COLUMN days TEXT;
`,
`CREATE TABLE proxy
(
id INTEGER PRIMARY KEY,
enabled BOOLEAN,
name TEXT NOT NULL,
type TEXT NOT NULL,
addr TEXT NOT NULL,
auth_user TEXT,
auth_pass TEXT,
timeout INTEGER,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
);
ALTER TABLE indexer
ADD proxy_id INTEGER
CONSTRAINT indexer_proxy_id_fk
REFERENCES proxy(id)
ON DELETE SET NULL;
ALTER TABLE indexer
ADD use_proxy BOOLEAN DEFAULT FALSE;
ALTER TABLE irc_network
ADD use_proxy BOOLEAN DEFAULT FALSE;
ALTER TABLE irc_network
ADD proxy_id INTEGER
CONSTRAINT irc_network_proxy_id_fk
REFERENCES proxy(id)
ON DELETE SET NULL;
`,
}

View file

@ -16,10 +16,10 @@ func dataSourceName(configPath string, name string) string {
return name
}
func toNullString(s string) sql.NullString {
return sql.NullString{
String: s,
Valid: s != "",
func toNullString(s string) sql.Null[string] {
return sql.Null[string]{
V: s,
Valid: s != "",
}
}

View file

@ -5,7 +5,6 @@ package domain
import (
"context"
"os"
"strings"
"github.com/autobrr/autobrr/pkg/errors"
@ -60,43 +59,69 @@ type Action struct {
Client *DownloadClient `json:"client,omitempty"`
}
// ParseMacros parse all macros on action
func (a *Action) ParseMacros(release *Release) error {
var err error
// CheckMacrosNeedTorrentTmpFile check if macros needs torrent downloaded
func (a *Action) CheckMacrosNeedTorrentTmpFile(release *Release) bool {
if release.TorrentTmpFile == "" &&
(strings.Contains(a.ExecArgs, "TorrentPathName") || strings.Contains(a.ExecArgs, "TorrentDataRawBytes") ||
strings.Contains(a.WebhookData, "TorrentPathName") || strings.Contains(a.WebhookData, "TorrentDataRawBytes") ||
strings.Contains(a.SavePath, "TorrentPathName") || a.Type == ActionTypeWatchFolder) {
if err := release.DownloadTorrentFile(); err != nil {
return errors.Wrap(err, "webhook: could not download torrent file for release: %v", release.TorrentName)
}
(strings.Contains(a.ExecArgs, "TorrentPathName") ||
strings.Contains(a.ExecArgs, "TorrentDataRawBytes") ||
strings.Contains(a.WebhookData, "TorrentPathName") ||
strings.Contains(a.WebhookData, "TorrentDataRawBytes") ||
strings.Contains(a.SavePath, "TorrentPathName") ||
a.Type == ActionTypeWatchFolder) {
return true
}
return false
}
func (a *Action) CheckMacrosNeedRawDataBytes(release *Release) bool {
// if webhook data contains TorrentDataRawBytes, lets read the file into bytes we can then use in the macro
if len(release.TorrentDataRawBytes) == 0 &&
(strings.Contains(a.ExecArgs, "TorrentDataRawBytes") || strings.Contains(a.WebhookData, "TorrentDataRawBytes") ||
a.Type == ActionTypeWatchFolder) {
t, err := os.ReadFile(release.TorrentTmpFile)
if err != nil {
return errors.Wrap(err, "could not read torrent file: %v", release.TorrentTmpFile)
}
release.TorrentDataRawBytes = t
return true
}
return false
}
// ParseMacros parse all macros on action
func (a *Action) ParseMacros(release *Release) error {
var err error
m := NewMacro(*release)
a.ExecArgs, err = m.Parse(a.ExecArgs)
a.WatchFolder, err = m.Parse(a.WatchFolder)
a.Category, err = m.Parse(a.Category)
a.Tags, err = m.Parse(a.Tags)
a.Label, err = m.Parse(a.Label)
a.SavePath, err = m.Parse(a.SavePath)
a.WebhookData, err = m.Parse(a.WebhookData)
if err != nil {
return errors.Wrap(err, "could not parse macros for action: %v", a.Name)
return errors.Wrap(err, "could not parse exec args")
}
a.WatchFolder, err = m.Parse(a.WatchFolder)
if err != nil {
return errors.Wrap(err, "could not parse watch folder")
}
a.Category, err = m.Parse(a.Category)
if err != nil {
return errors.Wrap(err, "could not parse category")
}
a.Tags, err = m.Parse(a.Tags)
if err != nil {
return errors.Wrap(err, "could not parse tags")
}
a.Label, err = m.Parse(a.Label)
if err != nil {
return errors.Wrap(err, "could not parse label")
}
a.SavePath, err = m.Parse(a.SavePath)
if err != nil {
return errors.Wrap(err, "could not parse save_path")
}
a.WebhookData, err = m.Parse(a.WebhookData)
if err != nil {
return errors.Wrap(err, "could not parse webhook_data")
}
return nil

View file

@ -3,8 +3,14 @@
package domain
import "database/sql"
import (
"database/sql"
"github.com/autobrr/autobrr/pkg/errors"
)
var (
ErrRecordNotFound = sql.ErrNoRows
ErrUpdateFailed = errors.New("update failed")
ErrDeleteFailed = errors.New("delete failed")
)

View file

@ -53,6 +53,11 @@ type Feed struct {
LastRun time.Time `json:"last_run"`
LastRunData string `json:"last_run_data"`
NextRun time.Time `json:"next_run"`
// belongs to Indexer
ProxyID int64
UseProxy bool
Proxy *Proxy
}
type FeedSettingsJSON struct {

View file

@ -34,6 +34,9 @@ type Indexer struct {
Enabled bool `json:"enabled"`
Implementation string `json:"implementation"`
BaseURL string `json:"base_url,omitempty"`
UseProxy bool `json:"use_proxy"`
Proxy *Proxy `json:"proxy"`
ProxyID int64 `json:"proxy_id"`
Settings map[string]string `json:"settings,omitempty"`
}
@ -66,6 +69,8 @@ type IndexerDefinition struct {
Protocol string `json:"protocol"`
URLS []string `json:"urls"`
Supports []string `json:"supports"`
UseProxy bool `json:"use_proxy"`
ProxyID int64 `json:"proxy_id"`
Settings []IndexerSetting `json:"settings,omitempty"`
SettingsMap map[string]string `json:"-"`
IRC *IndexerIRC `json:"irc,omitempty"`

View file

@ -48,6 +48,9 @@ type IrcNetwork struct {
InviteCommand string `json:"invite_command"`
UseBouncer bool `json:"use_bouncer"`
BouncerAddr string `json:"bouncer_addr"`
UseProxy bool `json:"use_proxy"`
ProxyId int64 `json:"proxy_id"`
Proxy *Proxy `json:"proxy"`
BotMode bool `json:"bot_mode"`
Channels []IrcChannel `json:"channels"`
Connected bool `json:"connected"`
@ -70,6 +73,9 @@ type IrcNetworkWithHealth struct {
BotMode bool `json:"bot_mode"`
CurrentNick string `json:"current_nick"`
PreferredNick string `json:"preferred_nick"`
UseProxy bool `json:"use_proxy"`
ProxyId int64 `json:"proxy_id"`
Proxy *Proxy `json:"proxy"`
Channels []ChannelWithHealth `json:"channels"`
Connected bool `json:"connected"`
ConnectedSince time.Time `json:"connected_since"`

View file

@ -163,6 +163,8 @@ func (m Macro) Parse(text string) (string, error) {
return "", nil
}
// TODO implement template cache
// setup template
tmpl, err := template.New("macro").Funcs(sprig.TxtFuncMap()).Parse(text)
if err != nil {

78
internal/domain/proxy.go Normal file
View file

@ -0,0 +1,78 @@
// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
package domain
import (
"context"
"net/url"
"github.com/autobrr/autobrr/pkg/errors"
)
type ProxyRepo interface {
Store(ctx context.Context, p *Proxy) error
Update(ctx context.Context, p *Proxy) error
List(ctx context.Context) ([]Proxy, error)
Delete(ctx context.Context, id int64) error
FindByID(ctx context.Context, id int64) (*Proxy, error)
ToggleEnabled(ctx context.Context, id int64, enabled bool) error
}
type Proxy struct {
ID int64 `json:"id"`
Name string `json:"name"`
Enabled bool `json:"enabled"`
Type ProxyType `json:"type"`
Addr string `json:"addr"`
User string `json:"user"`
Pass string `json:"pass"`
Timeout int `json:"timeout"`
}
type ProxyType string
const (
ProxyTypeSocks5 = "SOCKS5"
)
func (p Proxy) ValidProxyType() bool {
if p.Type == ProxyTypeSocks5 {
return true
}
return false
}
func (p Proxy) Validate() error {
if !p.ValidProxyType() {
return errors.New("invalid proxy type: %s", p.Type)
}
if err := ValidateProxyAddr(p.Addr); err != nil {
return err
}
if p.Name == "" {
return errors.New("name is required")
}
return nil
}
func ValidateProxyAddr(addr string) error {
if addr == "" {
return errors.New("addr is required")
}
proxyUrl, err := url.Parse(addr)
if err != nil {
return errors.Wrap(err, "could not parse proxy url: %s", addr)
}
if proxyUrl.Scheme != "socks5" && proxyUrl.Scheme != "socks5h" {
return errors.New("proxy url scheme must be socks5 or socks5h")
}
return nil
}

View file

@ -356,8 +356,6 @@ func (r *Release) ParseString(title string) {
r.ParseReleaseTagsString(r.ReleaseTags)
}
var ErrUnrecoverableError = errors.New("unrecoverable error")
func (r *Release) ParseReleaseTagsString(tags string) {
cleanTags := CleanReleaseTags(tags)
t := ParseReleaseTagString(cleanTags)
@ -432,10 +430,6 @@ func (r *Release) DownloadTorrentFileCtx(ctx context.Context) error {
return r.downloadTorrentFile(ctx)
}
func (r *Release) DownloadTorrentFile() error {
return r.downloadTorrentFile(context.Background())
}
func (r *Release) downloadTorrentFile(ctx context.Context) error {
if r.HasMagnetUri() {
return errors.New("downloading magnet links is not supported: %s", r.MagnetURI)
@ -592,7 +586,7 @@ func (r *Release) downloadTorrentFile(ctx context.Context) error {
}
func (r *Release) CleanupTemporaryFiles() {
if len(r.TorrentTmpFile) == 0 {
if r.TorrentTmpFile == "" {
return
}
@ -600,54 +594,15 @@ func (r *Release) CleanupTemporaryFiles() {
r.TorrentTmpFile = ""
}
// HasMagnetUri check uf MagnetURI is set or empty
// HasMagnetUri check uf MagnetURI is set and valid or empty
func (r *Release) HasMagnetUri() bool {
return r.MagnetURI != ""
if r.MagnetURI != "" && strings.HasPrefix(r.MagnetURI, MagnetURIPrefix) {
return true
}
return false
}
func (r *Release) ResolveMagnetUri(ctx context.Context) error {
if r.MagnetURI == "" {
return nil
} else if strings.HasPrefix(r.MagnetURI, "magnet:?") {
return nil
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.MagnetURI, nil)
if err != nil {
return errors.Wrap(err, "could not build request to resolve magnet uri")
}
req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "autobrr")
client := &http.Client{
Timeout: time.Second * 45,
Transport: sharedhttp.MagnetTransport,
}
res, err := client.Do(req)
if err != nil {
return errors.Wrap(err, "could not make request to resolve magnet uri")
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return errors.New("unexpected status code: %d", res.StatusCode)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return errors.Wrap(err, "could not read response body")
}
magnet := string(body)
if magnet != "" {
r.MagnetURI = magnet
}
return nil
}
const MagnetURIPrefix = "magnet:?"
func (r *Release) addRejection(reason string) {
r.Rejections = append(r.Rejections, reason)

View file

@ -6,6 +6,7 @@
package domain
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
@ -290,7 +291,7 @@ func TestRelease_DownloadTorrentFile(t *testing.T) {
Filter: tt.fields.Filter,
ActionStatus: tt.fields.ActionStatus,
}
err := r.DownloadTorrentFile()
err := r.DownloadTorrentFileCtx(context.Background())
if err == nil && tt.wantErr {
fmt.Println("error")
}

View file

@ -42,10 +42,23 @@ func NewFeedParser(timeout time.Duration, cookie string) *RSSParser {
}
c.http.Timeout = timeout
c.parser.Client = httpClient
return c
}
func (c *RSSParser) WithHTTPClient(client *http.Client) {
httpClient := client
if client.Jar == nil {
jarOptions := &cookiejar.Options{PublicSuffixList: publicsuffix.List}
jar, _ := cookiejar.New(jarOptions)
httpClient.Jar = jar
}
c.http = httpClient
c.parser.Client = httpClient
}
func (c *RSSParser) ParseURLWithContext(ctx context.Context, feedURL string) (feed *gofeed.Feed, err error) {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, feedURL, nil)
if err != nil {

View file

@ -10,6 +10,7 @@ import (
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/internal/release"
"github.com/autobrr/autobrr/internal/scheduler"
"github.com/autobrr/autobrr/pkg/errors"
@ -129,6 +130,18 @@ func (j *NewznabJob) process(ctx context.Context) error {
}
func (j *NewznabJob) getFeed(ctx context.Context) ([]newznab.FeedItem, error) {
// add proxy if enabled and exists
if j.Feed.UseProxy && j.Feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(j.Feed.Proxy)
if err != nil {
return nil, errors.Wrap(err, "could not get proxy client")
}
j.Client.WithHTTPClient(proxyClient)
j.Log.Debug().Msgf("using proxy %s for feed %s", j.Feed.Proxy.Name, j.Feed.Name)
}
// get feed
feed, err := j.Client.GetFeed(ctx)
if err != nil {
@ -156,36 +169,34 @@ func (j *NewznabJob) getFeed(ctx context.Context) ([]newznab.FeedItem, error) {
// set ttl to 1 month
ttl := time.Now().AddDate(0, 1, 0)
for _, i := range feed.Channel.Items {
i := i
if i.GUID == "" {
for _, item := range feed.Channel.Items {
if item.GUID == "" {
j.Log.Error().Msgf("missing GUID from feed: %s", j.Feed.Name)
continue
}
exists, err := j.CacheRepo.Exists(j.Feed.ID, i.GUID)
exists, err := j.CacheRepo.Exists(j.Feed.ID, item.GUID)
if err != nil {
j.Log.Error().Err(err).Msg("could not check if item exists")
continue
}
if exists {
j.Log.Trace().Msgf("cache item exists, skipping release: %s", i.Title)
j.Log.Trace().Msgf("cache item exists, skipping release: %s", item.Title)
continue
}
j.Log.Debug().Msgf("found new release: %s", i.Title)
j.Log.Debug().Msgf("found new release: %s", item.Title)
toCache = append(toCache, domain.FeedCacheItem{
FeedId: strconv.Itoa(j.Feed.ID),
Key: i.GUID,
Value: []byte(i.Title),
Key: item.GUID,
Value: []byte(item.Title),
TTL: ttl,
})
// only append if we successfully added to cache
items = append(items, *i)
items = append(items, *item)
}
if len(toCache) > 0 {

View file

@ -13,6 +13,7 @@ import (
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/internal/release"
"github.com/autobrr/autobrr/pkg/errors"
@ -93,7 +94,6 @@ func (j *RSSJob) process(ctx context.Context) error {
releases := make([]*domain.Release, 0)
for _, item := range items {
item := item
j.Log.Debug().Msgf("item: %v", item.Title)
rls := j.processItem(item)
@ -139,7 +139,7 @@ func (j *RSSJob) processItem(item *gofeed.Item) *domain.Release {
}
if j.Feed.Settings != nil && j.Feed.Settings.DownloadType == domain.FeedDownloadTypeMagnet {
if !strings.HasPrefix(rls.MagnetURI, "magnet:?") && strings.HasPrefix(e.URL, "magnet:?") {
if !strings.HasPrefix(rls.MagnetURI, domain.MagnetURIPrefix) && strings.HasPrefix(e.URL, domain.MagnetURIPrefix) {
rls.MagnetURI = e.URL
rls.DownloadURL = ""
}
@ -232,7 +232,20 @@ func (j *RSSJob) getFeed(ctx context.Context) (items []*gofeed.Item, err error)
ctx, cancel := context.WithTimeout(ctx, j.Timeout)
defer cancel()
feed, err := NewFeedParser(j.Timeout, j.Feed.Cookie).ParseURLWithContext(ctx, j.URL)
feedParser := NewFeedParser(j.Timeout, j.Feed.Cookie)
if j.Feed.UseProxy && j.Feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(j.Feed.Proxy)
if err != nil {
return nil, errors.Wrap(err, "could not get proxy client")
}
feedParser.WithHTTPClient(proxyClient)
j.Log.Debug().Msgf("using proxy %s for feed %s", j.Feed.Proxy.Name, j.Feed.Name)
}
feed, err := feedParser.ParseURLWithContext(ctx, j.URL)
if err != nil {
return nil, errors.Wrap(err, "error fetching rss feed items")
}
@ -257,9 +270,7 @@ func (j *RSSJob) getFeed(ctx context.Context) (items []*gofeed.Item, err error)
// set ttl to 1 month
ttl := time.Now().AddDate(0, 1, 0)
for _, i := range feed.Items {
item := i
for _, item := range feed.Items {
key := item.GUID
if len(key) == 0 {
key = item.Link
@ -278,12 +289,12 @@ func (j *RSSJob) getFeed(ctx context.Context) (items []*gofeed.Item, err error)
continue
}
j.Log.Debug().Msgf("found new release: %s", i.Title)
j.Log.Debug().Msgf("found new release: %s", item.Title)
toCache = append(toCache, domain.FeedCacheItem{
FeedId: strconv.Itoa(j.Feed.ID),
Key: key,
Value: []byte(i.Title),
Value: []byte(item.Title),
TTL: ttl,
})

View file

@ -11,6 +11,7 @@ import (
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/internal/release"
"github.com/autobrr/autobrr/internal/scheduler"
"github.com/autobrr/autobrr/pkg/errors"
@ -68,16 +69,18 @@ type service struct {
repo domain.FeedRepo
cacheRepo domain.FeedCacheRepo
releaseSvc release.Service
proxySvc proxy.Service
scheduler scheduler.Service
}
func NewService(log logger.Logger, repo domain.FeedRepo, cacheRepo domain.FeedCacheRepo, releaseSvc release.Service, scheduler scheduler.Service) Service {
func NewService(log logger.Logger, repo domain.FeedRepo, cacheRepo domain.FeedCacheRepo, releaseSvc release.Service, proxySvc proxy.Service, scheduler scheduler.Service) Service {
return &service{
log: log.With().Str("module", "feed").Logger(),
jobs: map[string]int{},
repo: repo,
cacheRepo: cacheRepo,
releaseSvc: releaseSvc,
proxySvc: proxySvc,
scheduler: scheduler,
}
}
@ -150,6 +153,13 @@ func (s *service) update(ctx context.Context, feed *domain.Feed) error {
return err
}
// get Feed again for ProxyID and UseProxy to be correctly populated
feed, err := s.repo.FindByID(ctx, feed.ID)
if err != nil {
s.log.Error().Err(err).Msg("error finding feed")
return err
}
if err := s.restartJob(feed); err != nil {
s.log.Error().Err(err).Msg("error restarting feed")
return err
@ -227,6 +237,18 @@ func (s *service) test(ctx context.Context, feed *domain.Feed) error {
// create sub logger
subLogger := zstdlog.NewStdLoggerWithLevel(s.log.With().Logger(), zerolog.DebugLevel)
// add proxy conf
if feed.UseProxy {
proxyConf, err := s.proxySvc.FindByID(ctx, feed.ProxyID)
if err != nil {
return errors.Wrap(err, "could not find proxy for indexer feed")
}
if proxyConf.Enabled {
feed.Proxy = proxyConf
}
}
// test feeds
switch feed.Type {
case string(domain.FeedTypeTorznab):
@ -254,13 +276,27 @@ func (s *service) test(ctx context.Context, feed *domain.Feed) error {
}
func (s *service) testRSS(ctx context.Context, feed *domain.Feed) error {
f, err := NewFeedParser(time.Duration(feed.Timeout)*time.Second, feed.Cookie).ParseURLWithContext(ctx, feed.URL)
feedParser := NewFeedParser(time.Duration(feed.Timeout)*time.Second, feed.Cookie)
// add proxy if enabled and exists
if feed.UseProxy && feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(feed.Proxy)
if err != nil {
return errors.Wrap(err, "could not get proxy client")
}
feedParser.WithHTTPClient(proxyClient)
s.log.Debug().Msgf("using proxy %s for feed %s", feed.Proxy.Name, feed.Name)
}
feedResponse, err := feedParser.ParseURLWithContext(ctx, feed.URL)
if err != nil {
s.log.Error().Err(err).Msgf("error fetching rss feed items")
return errors.Wrap(err, "error fetching rss feed items")
}
s.log.Info().Msgf("refreshing rss feed: %s, found (%d) items", feed.Name, len(f.Items))
s.log.Info().Msgf("refreshing rss feed: %s, found (%d) items", feed.Name, len(feedResponse.Items))
return nil
}
@ -269,6 +305,18 @@ func (s *service) testTorznab(ctx context.Context, feed *domain.Feed, subLogger
// setup torznab Client
c := torznab.NewClient(torznab.Config{Host: feed.URL, ApiKey: feed.ApiKey, Log: subLogger})
// add proxy if enabled and exists
if feed.UseProxy && feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(feed.Proxy)
if err != nil {
return errors.Wrap(err, "could not get proxy client")
}
c.WithHTTPClient(proxyClient)
s.log.Debug().Msgf("using proxy %s for feed %s", feed.Proxy.Name, feed.Name)
}
items, err := c.FetchFeed(ctx)
if err != nil {
s.log.Error().Err(err).Msg("error getting torznab feed")
@ -284,6 +332,18 @@ func (s *service) testNewznab(ctx context.Context, feed *domain.Feed, subLogger
// setup newznab Client
c := newznab.NewClient(newznab.Config{Host: feed.URL, ApiKey: feed.ApiKey, Log: subLogger})
// add proxy if enabled and exists
if feed.UseProxy && feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(feed.Proxy)
if err != nil {
return errors.Wrap(err, "could not get proxy client")
}
c.WithHTTPClient(proxyClient)
s.log.Debug().Msgf("using proxy %s for feed %s", feed.Proxy.Name, feed.Name)
}
items, err := c.GetFeed(ctx)
if err != nil {
s.log.Error().Err(err).Msg("error getting newznab feed")
@ -316,8 +376,6 @@ func (s *service) start() error {
s.log.Debug().Msgf("preparing staggered start of %d feeds", len(feeds))
for _, feed := range feeds {
feed := feed
if !feed.Enabled {
s.log.Trace().Msgf("feed disabled, skipping... %s", feed.Name)
continue
@ -408,6 +466,18 @@ func (s *service) startJob(f *domain.Feed) error {
return errors.New("no URL provided for feed: %s", f.Name)
}
// add proxy conf
if f.UseProxy {
proxyConf, err := s.proxySvc.FindByID(context.Background(), f.ProxyID)
if err != nil {
return errors.Wrap(err, "could not find proxy for indexer feed")
}
if proxyConf.Enabled {
f.Proxy = proxyConf
}
}
fi := newFeedInstance(f)
job, err := s.initializeFeedJob(fi)

View file

@ -5,6 +5,7 @@ package feed
import (
"context"
"github.com/autobrr/autobrr/internal/proxy"
"math"
"sort"
"strconv"
@ -224,6 +225,18 @@ func mapFreeleechToBonus(percentage int) string {
}
func (j *TorznabJob) getFeed(ctx context.Context) ([]torznab.FeedItem, error) {
// add proxy if enabled and exists
if j.Feed.UseProxy && j.Feed.Proxy != nil {
proxyClient, err := proxy.GetProxiedHTTPClient(j.Feed.Proxy)
if err != nil {
return nil, errors.Wrap(err, "could not get proxy client")
}
j.Client.WithHTTPClient(proxyClient)
j.Log.Debug().Msgf("using proxy %s for feed %s", j.Feed.Proxy.Name, j.Feed.Name)
}
// get feed
feed, err := j.Client.FetchFeed(ctx)
if err != nil {
@ -251,35 +264,33 @@ func (j *TorznabJob) getFeed(ctx context.Context) ([]torznab.FeedItem, error) {
// set ttl to 1 month
ttl := time.Now().AddDate(0, 1, 0)
for _, i := range feed.Channel.Items {
i := i
if i.GUID == "" {
for _, item := range feed.Channel.Items {
if item.GUID == "" {
j.Log.Error().Msgf("missing GUID from feed: %s", j.Feed.Name)
continue
}
exists, err := j.CacheRepo.Exists(j.Feed.ID, i.GUID)
exists, err := j.CacheRepo.Exists(j.Feed.ID, item.GUID)
if err != nil {
j.Log.Error().Err(err).Msg("could not check if item exists")
continue
}
if exists {
j.Log.Trace().Msgf("cache item exists, skipping release: %s", i.Title)
j.Log.Trace().Msgf("cache item exists, skipping release: %s", item.Title)
continue
}
j.Log.Debug().Msgf("found new release: %s", i.Title)
j.Log.Debug().Msgf("found new release: %s", item.Title)
toCache = append(toCache, domain.FeedCacheItem{
FeedId: strconv.Itoa(j.Feed.ID),
Key: i.GUID,
Value: []byte(i.Title),
Key: item.GUID,
Value: []byte(item.Title),
TTL: ttl,
})
// only append if we successfully added to cache
items = append(items, *i)
items = append(items, *item)
}
if len(toCache) > 0 {

View file

@ -7,7 +7,6 @@ import (
"bytes"
"context"
"fmt"
"github.com/autobrr/autobrr/internal/action"
"io"
"net/http"
"os"
@ -17,9 +16,11 @@ import (
"strings"
"time"
"github.com/autobrr/autobrr/internal/action"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/indexer"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/releasedownload"
"github.com/autobrr/autobrr/internal/utils"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/sharedhttp"
@ -53,11 +54,12 @@ type service struct {
releaseRepo domain.ReleaseRepo
indexerSvc indexer.Service
apiService indexer.APIService
downloadSvc *releasedownload.DownloadService
httpClient *http.Client
}
func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Service, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service) Service {
func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Service, releaseRepo domain.ReleaseRepo, apiService indexer.APIService, indexerSvc indexer.Service, downloadSvc *releasedownload.DownloadService) Service {
return &service{
log: log.With().Str("module", "filter").Logger(),
repo: repo,
@ -65,6 +67,7 @@ func NewService(log logger.Logger, repo domain.FilterRepo, actionSvc action.Serv
actionService: actionSvc,
apiService: apiService,
indexerSvc: indexerSvc,
downloadSvc: downloadSvc,
httpClient: &http.Client{
Timeout: time.Second * 120,
Transport: sharedhttp.TransportTLSInsecure,
@ -504,9 +507,9 @@ func (s *service) AdditionalSizeCheck(ctx context.Context, f *domain.Filter, rel
l.Trace().Msgf("(%s) preparing to download torrent metafile", f.Name)
// if indexer doesn't have api, download torrent and add to tmpPath
if err := release.DownloadTorrentFileCtx(ctx); err != nil {
if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil {
l.Error().Err(err).Msgf("(%s) could not download torrent file with id: '%s' from: %s", f.Name, release.TorrentID, release.Indexer.Identifier)
return false, err
return false, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
}
}
@ -584,8 +587,8 @@ func (s *service) execCmd(ctx context.Context, external domain.FilterExternal, r
s.log.Trace().Msgf("filter exec release: %s", release.TorrentName)
if release.TorrentTmpFile == "" && strings.Contains(external.ExecArgs, "TorrentPathName") {
if err := release.DownloadTorrentFileCtx(ctx); err != nil {
return 0, errors.Wrap(err, "error downloading torrent file for release: %s", release.TorrentName)
if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil {
return 0, errors.Wrap(err, "could not download torrent file for release: %s", release.TorrentName)
}
}
@ -686,7 +689,7 @@ func (s *service) webhook(ctx context.Context, external domain.FilterExternal, r
// if webhook data contains TorrentPathName or TorrentDataRawBytes, lets download the torrent file
if release.TorrentTmpFile == "" && (strings.Contains(external.WebhookData, "TorrentPathName") || strings.Contains(external.WebhookData, "TorrentDataRawBytes")) {
if err := release.DownloadTorrentFileCtx(ctx); err != nil {
if err := s.downloadSvc.DownloadRelease(ctx, release); err != nil {
return 0, errors.Wrap(err, "webhook: could not download torrent file for release: %s", release.TorrentName)
}
}

View file

@ -6,11 +6,11 @@ package http
import (
"context"
"encoding/json"
"github.com/autobrr/autobrr/pkg/errors"
"net/http"
"strconv"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/go-chi/chi/v5"
)

View file

@ -7,12 +7,12 @@ import (
"context"
"encoding/json"
"fmt"
"github.com/autobrr/autobrr/pkg/errors"
"net/http"
"strconv"
"strings"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/go-chi/chi/v5"
"github.com/r3labs/sse/v2"

View file

@ -6,11 +6,11 @@ package http
import (
"context"
"encoding/json"
"github.com/autobrr/autobrr/pkg/errors"
"net/http"
"strconv"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/go-chi/chi/v5"
)

151
internal/http/proxy.go Normal file
View file

@ -0,0 +1,151 @@
// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
package http
import (
"context"
"encoding/json"
"net/http"
"strconv"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/go-chi/chi/v5"
)
type proxyService interface {
Store(ctx context.Context, p *domain.Proxy) error
Update(ctx context.Context, p *domain.Proxy) error
Delete(ctx context.Context, id int64) error
List(ctx context.Context) ([]domain.Proxy, error)
FindByID(ctx context.Context, id int64) (*domain.Proxy, error)
Test(ctx context.Context, p *domain.Proxy) error
}
type proxyHandler struct {
encoder encoder
service proxyService
}
func newProxyHandler(encoder encoder, service proxyService) *proxyHandler {
return &proxyHandler{
encoder: encoder,
service: service,
}
}
func (h proxyHandler) Routes(r chi.Router) {
r.Get("/", h.list)
r.Post("/", h.store)
r.Post("/test", h.test)
r.Route("/{proxyID}", func(r chi.Router) {
r.Get("/", h.findByID)
r.Put("/", h.update)
r.Delete("/", h.delete)
})
}
func (h proxyHandler) store(w http.ResponseWriter, r *http.Request) {
var data domain.Proxy
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.Error(w, err)
return
}
if err := h.service.Store(r.Context(), &data); err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
}
func (h proxyHandler) update(w http.ResponseWriter, r *http.Request) {
var data domain.Proxy
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.Error(w, err)
return
}
if err := h.service.Update(r.Context(), &data); err != nil {
if errors.Is(err, domain.ErrUpdateFailed) {
h.encoder.StatusError(w, http.StatusBadRequest, err)
return
}
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
}
func (h proxyHandler) list(w http.ResponseWriter, r *http.Request) {
proxies, err := h.service.List(r.Context())
if err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(w, http.StatusOK, proxies)
}
func (h proxyHandler) findByID(w http.ResponseWriter, r *http.Request) {
proxyID, err := strconv.Atoi(chi.URLParam(r, "proxyID"))
if err != nil {
h.encoder.Error(w, err)
return
}
proxies, err := h.service.FindByID(r.Context(), int64(proxyID))
if err != nil {
if errors.Is(err, domain.ErrRecordNotFound) {
h.encoder.NotFoundErr(w, errors.New("could not find proxy with id %d", proxyID))
return
}
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(w, http.StatusOK, proxies)
}
func (h proxyHandler) delete(w http.ResponseWriter, r *http.Request) {
proxyID, err := strconv.Atoi(chi.URLParam(r, "proxyID"))
if err != nil {
h.encoder.Error(w, err)
return
}
err = h.service.Delete(r.Context(), int64(proxyID))
if err != nil {
if errors.Is(err, domain.ErrDeleteFailed) {
h.encoder.StatusError(w, http.StatusBadRequest, err)
return
}
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
}
func (h proxyHandler) test(w http.ResponseWriter, r *http.Request) {
var data domain.Proxy
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
h.encoder.Error(w, err)
return
}
if err := h.service.Test(r.Context(), &data); err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
}

View file

@ -43,11 +43,12 @@ type Server struct {
indexerService indexerService
ircService ircService
notificationService notificationService
proxyService proxyService
releaseService releaseService
updateService updateService
}
func NewServer(log logger.Logger, config *config.AppConfig, sse *sse.Server, db *database.DB, version string, commit string, date string, actionService actionService, apiService apikeyService, authService authService, downloadClientSvc downloadClientService, filterSvc filterService, feedSvc feedService, indexerSvc indexerService, ircSvc ircService, notificationSvc notificationService, releaseSvc releaseService, updateSvc updateService) Server {
func NewServer(log logger.Logger, config *config.AppConfig, sse *sse.Server, db *database.DB, version string, commit string, date string, actionService actionService, apiService apikeyService, authService authService, downloadClientSvc downloadClientService, filterSvc filterService, feedSvc feedService, indexerSvc indexerService, ircSvc ircService, notificationSvc notificationService, proxySvc proxyService, releaseSvc releaseService, updateSvc updateService) Server {
return Server{
log: log.With().Str("module", "http").Logger(),
config: config,
@ -68,6 +69,7 @@ func NewServer(log logger.Logger, config *config.AppConfig, sse *sse.Server, db
indexerService: indexerSvc,
ircService: ircSvc,
notificationService: notificationSvc,
proxyService: proxySvc,
releaseService: releaseSvc,
updateService: updateSvc,
}
@ -142,6 +144,7 @@ func (s Server) Handler() http.Handler {
r.Route("/keys", newAPIKeyHandler(encoder, s.apiService).Routes)
r.Route("/logs", newLogsHandler(s.config).Routes)
r.Route("/notification", newNotificationHandler(encoder, s.notificationService).Routes)
r.Route("/proxy", newProxyHandler(encoder, s.proxyService).Routes)
r.Route("/release", newReleaseHandler(encoder, s.releaseService).Routes)
r.Route("/updates", newUpdateHandler(encoder, s.updateService).Routes)

View file

@ -295,6 +295,9 @@ func (s *service) mapIndexer(indexer domain.Indexer) (*domain.IndexerDefinition,
d.BaseURL = indexer.BaseURL
d.Enabled = indexer.Enabled
d.UseProxy = indexer.UseProxy
d.ProxyID = indexer.ProxyID
if d.SettingsMap == nil {
d.SettingsMap = make(map[string]string)
}
@ -332,6 +335,9 @@ func (s *service) updateMapIndexer(indexer domain.Indexer) (*domain.IndexerDefin
d.BaseURL = indexer.BaseURL
d.Enabled = indexer.Enabled
d.UseProxy = indexer.UseProxy
d.ProxyID = indexer.ProxyID
if d.SettingsMap == nil {
d.SettingsMap = make(map[string]string)
}

View file

@ -6,6 +6,8 @@ package irc
import (
"crypto/tls"
"fmt"
"golang.org/x/net/proxy"
"net/url"
"slices"
"strings"
"time"
@ -220,6 +222,37 @@ func (h *Handler) Run() (err error) {
Log: subLogger,
}
if h.network.UseProxy && h.network.Proxy != nil {
if !h.network.Proxy.Enabled {
h.log.Debug().Msgf("proxy disabled, skip")
} else {
if h.network.Proxy.Addr == "" {
return errors.New("proxy addr missing")
}
proxyUrl, err := url.Parse(h.network.Proxy.Addr)
if err != nil {
return errors.Wrap(err, "could not parse proxy url: %s", h.network.Proxy.Addr)
}
// set user and pass if not empty
if h.network.Proxy.User != "" && h.network.Proxy.Pass != "" {
proxyUrl.User = url.UserPassword(h.network.Proxy.User, h.network.Proxy.Pass)
}
proxyDialer, err := proxy.FromURL(proxyUrl, proxy.Direct)
if err != nil {
return errors.Wrap(err, "could not create proxy dialer from url: %s", h.network.Proxy.Addr)
}
proxyContextDialer, ok := proxyDialer.(proxy.ContextDialer)
if !ok {
return errors.Wrap(err, "proxy dialer does not expose DialContext(): %v", proxyDialer)
}
client.DialContext = proxyContextDialer.DialContext
}
}
if h.network.Auth.Mechanism == domain.IRCAuthMechanismSASLPlain {
if h.network.Auth.Account != "" && h.network.Auth.Password != "" {
client.SASLLogin = h.network.Auth.Account

View file

@ -14,6 +14,7 @@ import (
"github.com/autobrr/autobrr/internal/indexer"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/notification"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/internal/release"
"github.com/autobrr/autobrr/pkg/errors"
@ -47,8 +48,10 @@ type service struct {
releaseService release.Service
indexerService indexer.Service
notificationService notification.Service
indexerMap map[string]string
handlers map[int64]*Handler
proxyService proxy.Service
indexerMap map[string]string
handlers map[int64]*Handler
stopWG sync.WaitGroup
lock sync.RWMutex
@ -56,7 +59,7 @@ type service struct {
const sseMaxEntries = 1000
func NewService(log logger.Logger, sse *sse.Server, repo domain.IrcRepo, releaseSvc release.Service, indexerSvc indexer.Service, notificationSvc notification.Service) Service {
func NewService(log logger.Logger, sse *sse.Server, repo domain.IrcRepo, releaseSvc release.Service, indexerSvc indexer.Service, notificationSvc notification.Service, proxySvc proxy.Service) Service {
return &service{
log: log.With().Str("module", "irc").Logger(),
sse: sse,
@ -64,6 +67,7 @@ func NewService(log logger.Logger, sse *sse.Server, repo domain.IrcRepo, release
releaseService: releaseSvc,
indexerService: indexerSvc,
notificationService: notificationSvc,
proxyService: proxySvc,
handlers: make(map[int64]*Handler),
}
}
@ -79,6 +83,15 @@ func (s *service) StartHandlers() {
continue
}
if network.ProxyId != 0 {
networkProxy, err := s.proxyService.FindByID(context.Background(), network.ProxyId)
if err != nil {
s.log.Error().Err(err).Msgf("failed to get proxy for network: %s", network.Server)
return
}
network.Proxy = networkProxy
}
channels, err := s.repo.ListChannels(network.ID)
if err != nil {
s.log.Error().Err(err).Msgf("failed to list channels for network: %s", network.Server)
@ -215,6 +228,14 @@ func (s *service) checkIfNetworkRestartNeeded(network *domain.IrcNetwork) error
restartNeeded = true
fieldsChanged = append(fieldsChanged, "bot mode")
}
if handler.UseProxy != network.UseProxy {
restartNeeded = true
fieldsChanged = append(fieldsChanged, "use proxy")
}
if handler.ProxyId != network.ProxyId {
restartNeeded = true
fieldsChanged = append(fieldsChanged, "proxy id")
}
if handler.Auth.Mechanism != network.Auth.Mechanism {
restartNeeded = true
fieldsChanged = append(fieldsChanged, "auth mechanism")
@ -476,12 +497,16 @@ func (s *service) GetNetworksWithHealth(ctx context.Context) ([]domain.IrcNetwor
BouncerAddr: n.BouncerAddr,
UseBouncer: n.UseBouncer,
BotMode: n.BotMode,
UseProxy: n.UseProxy,
ProxyId: n.ProxyId,
Connected: false,
Channels: []domain.ChannelWithHealth{},
ConnectionErrors: []string{},
}
s.lock.RLock()
handler, ok := s.handlers[n.ID]
s.lock.RUnlock()
if ok {
handler.ReportStatus(&netw)
}
@ -566,6 +591,18 @@ func (s *service) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork)
}
s.log.Debug().Msgf("irc.service: update network: %s", network.Name)
network.Proxy = nil
// attach proxy
if network.UseProxy && network.ProxyId != 0 {
networkProxy, err := s.proxyService.FindByID(context.Background(), network.ProxyId)
if err != nil {
s.log.Error().Err(err).Msgf("failed to get proxy for network: %s", network.Server)
return errors.Wrap(err, "could not get proxy for network: %s", network.Server)
}
network.Proxy = networkProxy
}
// stop or start network
// TODO get current state to see if enabled or not?
if network.Enabled {

194
internal/proxy/service.go Normal file
View file

@ -0,0 +1,194 @@
// Copyright (c) 2021 - 2024, Ludvig Lundgren and the autobrr contributors.
// SPDX-License-Identifier: GPL-2.0-or-later
package proxy
import (
"context"
"net/http"
"net/url"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/sharedhttp"
"github.com/rs/zerolog"
netProxy "golang.org/x/net/proxy"
)
type Service interface {
List(ctx context.Context) ([]domain.Proxy, error)
FindByID(ctx context.Context, id int64) (*domain.Proxy, error)
Store(ctx context.Context, p *domain.Proxy) error
Update(ctx context.Context, p *domain.Proxy) error
Delete(ctx context.Context, id int64) error
Test(ctx context.Context, p *domain.Proxy) error
}
type service struct {
log zerolog.Logger
repo domain.ProxyRepo
cache map[int64]*domain.Proxy
}
func NewService(log logger.Logger, repo domain.ProxyRepo) Service {
return &service{
log: log.With().Str("module", "proxy").Logger(),
repo: repo,
cache: make(map[int64]*domain.Proxy),
}
}
func (s *service) Store(ctx context.Context, proxy *domain.Proxy) error {
if err := proxy.Validate(); err != nil {
return errors.Wrap(err, "validation error")
}
err := s.repo.Store(ctx, proxy)
if err != nil {
return err
}
s.cache[proxy.ID] = proxy
return nil
}
func (s *service) Update(ctx context.Context, proxy *domain.Proxy) error {
if err := proxy.Validate(); err != nil {
return errors.Wrap(err, "validation error")
}
err := s.repo.Update(ctx, proxy)
if err != nil {
return err
}
s.cache[proxy.ID] = proxy
// TODO update IRC handlers
return nil
}
func (s *service) FindByID(ctx context.Context, id int64) (*domain.Proxy, error) {
if proxy, ok := s.cache[id]; ok {
return proxy, nil
}
return s.repo.FindByID(ctx, id)
}
func (s *service) List(ctx context.Context) ([]domain.Proxy, error) {
return s.repo.List(ctx)
}
func (s *service) ToggleEnabled(ctx context.Context, id int64, enabled bool) error {
err := s.repo.ToggleEnabled(ctx, id, enabled)
if err != nil {
return err
}
v, ok := s.cache[id]
if !ok {
v.Enabled = !enabled
s.cache[id] = v
}
// TODO update IRC handlers
return nil
}
func (s *service) Delete(ctx context.Context, id int64) error {
err := s.repo.Delete(ctx, id)
if err != nil {
return err
}
delete(s.cache, id)
// TODO update IRC handlers
return nil
}
func (s *service) Test(ctx context.Context, proxy *domain.Proxy) error {
if !proxy.ValidProxyType() {
return errors.New("invalid proxy type %s", proxy.Type)
}
if proxy.Addr == "" {
return errors.New("proxy addr missing")
}
httpClient, err := GetProxiedHTTPClient(proxy)
if err != nil {
return errors.Wrap(err, "could not get http client")
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://autobrr.com", nil)
if err != nil {
return errors.Wrap(err, "could not create proxy request")
}
resp, err := httpClient.Do(req)
if err != nil {
return errors.Wrap(err, "could not connect to proxy server: %s", proxy.Addr)
}
if resp.StatusCode != http.StatusOK {
return errors.New(resp.Status)
}
s.log.Debug().Msgf("proxy %s test OK!", proxy.Addr)
return nil
}
func GetProxiedHTTPClient(p *domain.Proxy) (*http.Client, error) {
proxyUrl, err := url.Parse(p.Addr)
if err != nil {
return nil, errors.Wrap(err, "could not parse proxy url: %s", p.Addr)
}
// set user and pass if not empty
if p.User != "" && p.Pass != "" {
proxyUrl.User = url.UserPassword(p.User, p.Pass)
}
transport := sharedhttp.TransportTLSInsecure
// set user and pass if not empty
if p.User != "" && p.Pass != "" {
proxyUrl.User = url.UserPassword(p.User, p.Pass)
}
switch p.Type {
case domain.ProxyTypeSocks5:
proxyDialer, err := netProxy.FromURL(proxyUrl, netProxy.Direct)
if err != nil {
return nil, errors.Wrap(err, "could not create proxy dialer from url: %s", p.Addr)
}
proxyContextDialer, ok := proxyDialer.(netProxy.ContextDialer)
if !ok {
return nil, errors.Wrap(err, "proxy dialer does not expose DialContext(): %v", proxyDialer)
}
transport.DialContext = proxyContextDialer.DialContext
default:
return nil, errors.New("invalid proxy type: %s", p.Type)
}
client := &http.Client{
Timeout: 30 * time.Second,
Transport: transport,
}
return client, nil
}

View file

@ -0,0 +1,316 @@
package releasedownload
import (
"bufio"
"bytes"
"context"
"io"
"net"
"net/http"
"net/http/cookiejar"
"os"
"strings"
"time"
"github.com/autobrr/autobrr/internal/domain"
"github.com/autobrr/autobrr/internal/logger"
"github.com/autobrr/autobrr/internal/proxy"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/autobrr/pkg/sharedhttp"
"github.com/anacrolix/torrent/bencode"
"github.com/anacrolix/torrent/metainfo"
"github.com/avast/retry-go/v4"
"github.com/rs/zerolog"
"golang.org/x/net/publicsuffix"
)
type DownloadService struct {
log zerolog.Logger
repo domain.ReleaseRepo
indexerRepo domain.IndexerRepo
proxySvc proxy.Service
}
func NewDownloadService(log logger.Logger, repo domain.ReleaseRepo, indexerRepo domain.IndexerRepo, proxySvc proxy.Service) *DownloadService {
return &DownloadService{
log: log.With().Str("module", "release-download").Logger(),
repo: repo,
indexerRepo: indexerRepo,
proxySvc: proxySvc,
}
}
func (s *DownloadService) DownloadRelease(ctx context.Context, rls *domain.Release) error {
if rls.HasMagnetUri() {
return errors.New("downloading magnet links is not supported: %s", rls.MagnetURI)
} else if rls.Protocol != domain.ReleaseProtocolTorrent {
return errors.New("could not download file: protocol %s is not supported", rls.Protocol)
}
if rls.DownloadURL == "" {
return errors.New("download_file: url can't be empty")
} else if rls.TorrentTmpFile != "" {
// already downloaded
return nil
}
// get indexer
indexer, err := s.indexerRepo.FindByID(ctx, rls.Indexer.ID)
if err != nil {
return err
}
// get proxy
if indexer.UseProxy {
proxyConf, err := s.proxySvc.FindByID(ctx, indexer.ProxyID)
if err != nil {
return err
}
if proxyConf.Enabled {
s.log.Debug().Msgf("using proxy: %s", proxyConf.Name)
indexer.Proxy = proxyConf
} else {
s.log.Debug().Msgf("proxy disabled, skip: %s", proxyConf.Name)
}
}
// download release
err = s.downloadTorrentFile(ctx, indexer, rls)
if err != nil {
return err
}
return nil
}
func (s *DownloadService) downloadTorrentFile(ctx context.Context, indexer *domain.Indexer, r *domain.Release) error {
req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.DownloadURL, nil)
if err != nil {
return errors.Wrap(err, "error downloading file")
}
req.Header.Set("User-Agent", "autobrr")
httpClient := &http.Client{
Timeout: 30 * time.Second,
Transport: sharedhttp.TransportTLSInsecure,
}
// handle proxy
if indexer.Proxy != nil {
s.log.Debug().Msgf("using proxy: %s", indexer.Proxy.Name)
proxiedClient, err := proxy.GetProxiedHTTPClient(indexer.Proxy)
if err != nil {
return errors.Wrap(err, "could not get proxied http client")
}
httpClient = proxiedClient
}
if r.RawCookie != "" {
jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
if err != nil {
return errors.Wrap(err, "could not create cookiejar")
}
httpClient.Jar = jar
// set the cookie on the header instead of req.AddCookie
// since we have a raw cookie like "uid=10; pass=000"
req.Header.Set("Cookie", r.RawCookie)
}
tmpFilePattern := "autobrr-"
tmpDir := os.TempDir()
// Create tmp file
// TODO check if tmp file is wanted
tmpFile, err := os.CreateTemp(tmpDir, tmpFilePattern)
if err != nil {
if os.IsNotExist(err) {
if mkdirErr := os.MkdirAll(tmpDir, os.ModePerm); mkdirErr != nil {
return errors.Wrap(mkdirErr, "could not create TMP dir: %s", tmpDir)
}
tmpFile, err = os.CreateTemp(tmpDir, tmpFilePattern)
if err != nil {
return errors.Wrap(err, "error creating tmp file in: %s", tmpDir)
}
} else {
return errors.Wrap(err, "error creating tmp file")
}
}
defer tmpFile.Close()
errFunc := retry.Do(retryableRequest(httpClient, req, r, tmpFile), retry.Delay(time.Second*3), retry.Attempts(3), retry.MaxJitter(time.Second*1))
return errFunc
}
func retryableRequest(httpClient *http.Client, req *http.Request, r *domain.Release, tmpFile *os.File) func() error {
return func() error {
// Get the data
resp, err := httpClient.Do(req)
if err != nil {
if errors.As(err, net.OpError{}) {
return retry.Unrecoverable(errors.Wrap(err, "issue from proxy"))
}
return errors.Wrap(err, "error downloading file")
}
defer resp.Body.Close()
// Check server response
switch resp.StatusCode {
case http.StatusOK:
// Continue processing the response
break
//case http.StatusMovedPermanently, http.StatusFound, http.StatusSeeOther, http.StatusTemporaryRedirect, http.StatusPermanentRedirect:
// // Handle redirect
// return retry.Unrecoverable(errors.New("redirect encountered for torrent (%s) file (%s) - status code: %d - check indexer keys for %s", r.TorrentName, r.DownloadURL, resp.StatusCode, r.Indexer.Name))
case http.StatusUnauthorized, http.StatusForbidden:
return retry.Unrecoverable(errors.New("unrecoverable error downloading torrent (%s) file (%s) - status code: %d - check indexer keys for %s", r.TorrentName, r.DownloadURL, resp.StatusCode, r.Indexer.Name))
case http.StatusMethodNotAllowed:
return retry.Unrecoverable(errors.New("unrecoverable error downloading torrent (%s) file (%s) from '%s' - status code: %d. Check if the request method is correct", r.TorrentName, r.DownloadURL, r.Indexer.Name, resp.StatusCode))
case http.StatusNotFound:
return errors.New("torrent %s not found on %s (%d) - retrying", r.TorrentName, r.Indexer.Name, resp.StatusCode)
case http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout:
return errors.New("server error (%d) encountered while downloading torrent (%s) file (%s) from '%s' - retrying", resp.StatusCode, r.TorrentName, r.DownloadURL, r.Indexer.Name)
case http.StatusInternalServerError:
return errors.New("server error (%d) encountered while downloading torrent (%s) file (%s) - check indexer keys for %s", resp.StatusCode, r.TorrentName, r.DownloadURL, r.Indexer.Name)
default:
return retry.Unrecoverable(errors.New("unexpected status code %d: check indexer keys for %s", resp.StatusCode, r.Indexer.Name))
}
resetTmpFile := func() {
tmpFile.Seek(0, io.SeekStart)
tmpFile.Truncate(0)
}
// Read the body into bytes
bodyBytes, err := io.ReadAll(bufio.NewReader(resp.Body))
if err != nil {
return errors.Wrap(err, "error reading response body")
}
// Create a new reader for bodyBytes
bodyReader := bytes.NewReader(bodyBytes)
// Try to decode as torrent file
meta, err := metainfo.Load(bodyReader)
if err != nil {
resetTmpFile()
// explicitly check for unexpected content type that match html
var bse *bencode.SyntaxError
if errors.As(err, &bse) {
// regular error so we can retry if we receive html first run
return errors.Wrap(err, "metainfo unexpected content type, got HTML expected a bencoded torrent. check indexer keys for %s - %s", r.Indexer.Name, r.TorrentName)
}
return retry.Unrecoverable(errors.Wrap(err, "metainfo unexpected content type. check indexer keys for %s - %s", r.Indexer.Name, r.TorrentName))
}
torrentMetaInfo, err := meta.UnmarshalInfo()
if err != nil {
resetTmpFile()
return retry.Unrecoverable(errors.Wrap(err, "metainfo could not unmarshal info from torrent: %s", tmpFile.Name()))
}
hashInfoBytes := meta.HashInfoBytes().Bytes()
if len(hashInfoBytes) < 1 {
resetTmpFile()
return retry.Unrecoverable(errors.New("could not read infohash"))
}
// Write the body to file
// TODO move to io.Reader and pass around in the future
if _, err := tmpFile.Write(bodyBytes); err != nil {
resetTmpFile()
return errors.Wrap(err, "error writing downloaded file: %s", tmpFile.Name())
}
r.TorrentTmpFile = tmpFile.Name()
r.TorrentHash = meta.HashInfoBytes().String()
r.Size = uint64(torrentMetaInfo.TotalLength())
return nil
}
}
func (s *DownloadService) ResolveMagnetURI(ctx context.Context, r *domain.Release) error {
if r.MagnetURI == "" {
return nil
} else if strings.HasPrefix(r.MagnetURI, domain.MagnetURIPrefix) {
return nil
}
// get indexer
indexer, err := s.indexerRepo.FindByID(ctx, r.Indexer.ID)
if err != nil {
return err
}
httpClient := &http.Client{
Timeout: time.Second * 45,
Transport: sharedhttp.MagnetTransport,
}
// get proxy
if indexer.UseProxy {
proxyConf, err := s.proxySvc.FindByID(ctx, indexer.ProxyID)
if err != nil {
return err
}
s.log.Debug().Msgf("using proxy: %s", proxyConf.Name)
proxiedClient, err := proxy.GetProxiedHTTPClient(proxyConf)
if err != nil {
return errors.Wrap(err, "could not get proxied http client")
}
httpClient = proxiedClient
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, r.MagnetURI, nil)
if err != nil {
return errors.Wrap(err, "could not build request to resolve magnet uri")
}
//req.Header.Set("Content-Type", "application/json")
req.Header.Set("User-Agent", "autobrr")
res, err := httpClient.Do(req)
if err != nil {
return errors.Wrap(err, "could not make request to resolve magnet uri")
}
defer res.Body.Close()
if res.StatusCode != http.StatusOK {
return errors.New("unexpected status code: %d", res.StatusCode)
}
body, err := io.ReadAll(res.Body)
if err != nil {
return errors.Wrap(err, "could not read response body")
}
magnet := string(body)
if magnet != "" {
r.MagnetURI = magnet
}
return nil
}