feat(actions): simplify macro parsing (#560)

* refactor(action): parse macros

* feat(action): add ctx to arr clients and test
This commit is contained in:
ze0s 2022-12-10 21:48:19 +01:00 committed by GitHub
parent f6e68fae2b
commit 839eb9f3f3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
31 changed files with 323 additions and 334 deletions

View file

@ -12,13 +12,13 @@ import (
delugeClient "github.com/gdm85/go-libdeluge" delugeClient "github.com/gdm85/go-libdeluge"
) )
func (s *service) deluge(action domain.Action, release domain.Release) ([]string, error) { func (s *service) deluge(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action Deluge: %v", action.Name) s.log.Debug().Msgf("action Deluge: %v", action.Name)
var err error var err error
// get client for action // get client for action
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
s.log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID) s.log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID)
return nil, err return nil, err
@ -32,16 +32,16 @@ func (s *service) deluge(action domain.Action, release domain.Release) ([]string
switch client.Type { switch client.Type {
case "DELUGE_V1": case "DELUGE_V1":
rejections, err = s.delugeV1(client, action, release) rejections, err = s.delugeV1(ctx, client, action, release)
case "DELUGE_V2": case "DELUGE_V2":
rejections, err = s.delugeV2(client, action, release) rejections, err = s.delugeV2(ctx, client, action, release)
} }
return rejections, err return rejections, err
} }
func (s *service) delugeCheckRulesCanDownload(deluge delugeClient.DelugeClient, client *domain.DownloadClient, action domain.Action) ([]string, error) { func (s *service) delugeCheckRulesCanDownload(deluge delugeClient.DelugeClient, client *domain.DownloadClient, action *domain.Action) ([]string, error) {
s.log.Trace().Msgf("action Deluge: %v check rules", action.Name) s.log.Trace().Msgf("action Deluge: %v check rules", action.Name)
// check for active downloads and other rules // check for active downloads and other rules
@ -86,7 +86,7 @@ func (s *service) delugeCheckRulesCanDownload(deluge delugeClient.DelugeClient,
return nil, nil return nil, nil
} }
func (s *service) delugeV1(client *domain.DownloadClient, action domain.Action, release domain.Release) ([]string, error) { func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) {
settings := delugeClient.Settings{ settings := delugeClient.Settings{
Hostname: client.Host, Hostname: client.Host,
Port: uint(client.Port), Port: uint(client.Port),
@ -117,7 +117,7 @@ func (s *service) delugeV1(client *domain.DownloadClient, action domain.Action,
} }
if release.TorrentTmpFile == "" { if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFile(); err != nil { if err := release.DownloadTorrentFileCtx(ctx); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName) s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName)
return nil, err return nil, err
} }
@ -134,10 +134,7 @@ func (s *service) delugeV1(client *domain.DownloadClient, action domain.Action,
return nil, errors.Wrap(err, "could not encode torrent file: %v", release.TorrentTmpFile) return nil, errors.Wrap(err, "could not encode torrent file: %v", release.TorrentTmpFile)
} }
// macros handle args and replace vars options, err := s.prepareDelugeOptions(action)
m := domain.NewMacro(release)
options, err := s.prepareDelugeOptions(action, m)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not prepare options") return nil, errors.Wrap(err, "could not prepare options")
} }
@ -155,15 +152,9 @@ func (s *service) delugeV1(client *domain.DownloadClient, action domain.Action,
return nil, errors.Wrap(err, "could not load label plugin for client: %v", client.Name) return nil, errors.Wrap(err, "could not load label plugin for client: %v", client.Name)
} }
// parse and replace values in argument string before continuing
labelArgs, err := m.Parse(action.Label)
if err != nil {
return nil, errors.Wrap(err, "could not parse macro label: %v", action.Label)
}
if labelPluginActive != nil { if labelPluginActive != nil {
// TODO first check if label exists, if not, add it, otherwise set // TODO first check if label exists, if not, add it, otherwise set
err = labelPluginActive.SetTorrentLabel(torrentHash, labelArgs) err = labelPluginActive.SetTorrentLabel(torrentHash, action.Label)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not set label: %v on client: %v", action.Label, client.Name) return nil, errors.Wrap(err, "could not set label: %v on client: %v", action.Label, client.Name)
} }
@ -175,7 +166,7 @@ func (s *service) delugeV1(client *domain.DownloadClient, action domain.Action,
return nil, nil return nil, nil
} }
func (s *service) delugeV2(client *domain.DownloadClient, action domain.Action, release domain.Release) ([]string, error) { func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) {
settings := delugeClient.Settings{ settings := delugeClient.Settings{
Hostname: client.Host, Hostname: client.Host,
Port: uint(client.Port), Port: uint(client.Port),
@ -206,7 +197,7 @@ func (s *service) delugeV2(client *domain.DownloadClient, action domain.Action,
} }
if release.TorrentTmpFile == "" { if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFile(); err != nil { if err := release.DownloadTorrentFileCtx(ctx); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName) s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName)
return nil, err return nil, err
} }
@ -223,11 +214,8 @@ func (s *service) delugeV2(client *domain.DownloadClient, action domain.Action,
return nil, errors.Wrap(err, "could not encode torrent file: %v", release.TorrentTmpFile) return nil, errors.Wrap(err, "could not encode torrent file: %v", release.TorrentTmpFile)
} }
// macros handle args and replace vars
m := domain.NewMacro(release)
// set options // set options
options, err := s.prepareDelugeOptions(action, m) options, err := s.prepareDelugeOptions(action)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not prepare options") return nil, errors.Wrap(err, "could not prepare options")
} }
@ -245,15 +233,9 @@ func (s *service) delugeV2(client *domain.DownloadClient, action domain.Action,
return nil, errors.Wrap(err, "could not load label plugin for client: %v", client.Name) return nil, errors.Wrap(err, "could not load label plugin for client: %v", client.Name)
} }
// parse and replace values in argument string before continuing
labelArgs, err := m.Parse(action.Label)
if err != nil {
return nil, errors.Wrap(err, "could not parse macro label: %v", action.Label)
}
if labelPluginActive != nil { if labelPluginActive != nil {
// TODO first check if label exists, if not, add it, otherwise set // TODO first check if label exists, if not, add it, otherwise set
err = labelPluginActive.SetTorrentLabel(torrentHash, labelArgs) err = labelPluginActive.SetTorrentLabel(torrentHash, action.Label)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not set label: %v on client: %v", action.Label, client.Name) return nil, errors.Wrap(err, "could not set label: %v on client: %v", action.Label, client.Name)
} }
@ -265,7 +247,7 @@ func (s *service) delugeV2(client *domain.DownloadClient, action domain.Action,
return nil, nil return nil, nil
} }
func (s *service) prepareDelugeOptions(action domain.Action, m domain.Macro) (delugeClient.Options, error) { func (s *service) prepareDelugeOptions(action *domain.Action) (delugeClient.Options, error) {
// set options // set options
options := delugeClient.Options{} options := delugeClient.Options{}
@ -274,13 +256,7 @@ func (s *service) prepareDelugeOptions(action domain.Action, m domain.Macro) (de
options.AddPaused = &action.Paused options.AddPaused = &action.Paused
} }
if action.SavePath != "" { if action.SavePath != "" {
// parse and replace values in argument string before continuing options.DownloadLocation = &action.SavePath
savePathArgs, err := m.Parse(action.SavePath)
if err != nil {
return options, errors.Wrap(err, "could not parse save path macro: %v", action.SavePath)
}
options.DownloadLocation = &savePathArgs
} }
if action.LimitDownloadSpeed > 0 { if action.LimitDownloadSpeed > 0 {
maxDL := int(action.LimitDownloadSpeed) maxDL := int(action.LimitDownloadSpeed)

View file

@ -1,6 +1,7 @@
package action package action
import ( import (
"context"
"os" "os"
"os/exec" "os/exec"
"strings" "strings"
@ -12,11 +13,11 @@ import (
"github.com/mattn/go-shellwords" "github.com/mattn/go-shellwords"
) )
func (s *service) execCmd(action domain.Action, release domain.Release) error { func (s *service) execCmd(ctx context.Context, action *domain.Action, release domain.Release) error {
s.log.Debug().Msgf("action exec: %v release: %v", action.Name, release.TorrentName) s.log.Debug().Msgf("action exec: %v release: %v", action.Name, release.TorrentName)
if release.TorrentTmpFile == "" && strings.Contains(action.ExecArgs, "TorrentPathName") { if release.TorrentTmpFile == "" && strings.Contains(action.ExecArgs, "TorrentPathName") {
if err := release.DownloadTorrentFile(); err != nil { if err := release.DownloadTorrentFileCtx(ctx); err != nil {
return errors.Wrap(err, "error downloading torrent file for release: %v", release.TorrentName) return errors.Wrap(err, "error downloading torrent file for release: %v", release.TorrentName)
} }
} }
@ -37,7 +38,9 @@ func (s *service) execCmd(action domain.Action, release domain.Release) error {
return errors.Wrap(err, "exec failed, could not find program: %v", action.ExecCmd) return errors.Wrap(err, "exec failed, could not find program: %v", action.ExecCmd)
} }
args, err := s.parseExecArgs(release, action.ExecArgs) p := shellwords.NewParser()
p.ParseBacktick = true
args, err := p.Parse(action.ExecArgs)
if err != nil { if err != nil {
return errors.Wrap(err, "could not parse exec args: %v", action.ExecArgs) return errors.Wrap(err, "could not parse exec args: %v", action.ExecArgs)
} }
@ -47,7 +50,7 @@ func (s *service) execCmd(action domain.Action, release domain.Release) error {
start := time.Now() start := time.Now()
// setup command and args // setup command and args
command := exec.Command(cmd, args...) command := exec.CommandContext(ctx, cmd, args...)
// execute command // execute command
output, err := command.CombinedOutput() output, err := command.CombinedOutput()
@ -64,23 +67,3 @@ func (s *service) execCmd(action domain.Action, release domain.Release) error {
return nil return nil
} }
func (s *service) parseExecArgs(release domain.Release, execArgs string) ([]string, error) {
// handle args and replace vars
m := domain.NewMacro(release)
// parse and replace values in argument string before continuing
parsedArgs, err := m.Parse(execArgs)
if err != nil {
return nil, errors.Wrap(err, "could not parse macro")
}
p := shellwords.NewParser()
p.ParseBacktick = true
args, err := p.Parse(parsedArgs)
if err != nil {
return nil, errors.Wrap(err, "could not parse into shell-words")
}
return args, nil
}

View file

@ -9,55 +9,48 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func Test_service_parseExecArgs(t *testing.T) { func Test_service_parseMacros(t *testing.T) {
type args struct { type args struct {
release domain.Release release domain.Release
execArgs string action *domain.Action
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want []string want string
wantErr bool wantErr bool
}{ }{
{ {
name: "test_1", name: "test_1",
args: args{ args: args{
release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"}, release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"},
execArgs: `echo "{{ .TorrentName }}"`, action: &domain.Action{
}, ExecArgs: `echo "{{ .TorrentName }}"`,
want: []string{ },
"echo",
"Sally Goes to the Mall S04E29",
}, },
want: `echo "Sally Goes to the Mall S04E29"`,
wantErr: false, wantErr: false,
}, },
{ {
name: "test_2", name: "test_2",
args: args{ args: args{
release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"}, release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"},
execArgs: `"{{ .TorrentName }}"`, action: &domain.Action{
}, ExecArgs: `"{{ .TorrentName }}"`,
want: []string{ },
"Sally Goes to the Mall S04E29",
}, },
want: `"Sally Goes to the Mall S04E29"`,
wantErr: false, wantErr: false,
}, },
{ {
name: "test_3", name: "test_3",
args: args{ args: args{
release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"}, release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"},
execArgs: `--header "Content-Type: application/json" --request POST --data '{"release":"{{ .TorrentName }}"}' http://localhost:3000/api/release`, action: &domain.Action{
}, ExecArgs: `--header "Content-Type: application/json" --request POST --data '{"release":"{{ .TorrentName }}"}' http://localhost:3000/api/release`,
want: []string{ },
"--header",
"Content-Type: application/json",
"--request",
"POST",
"--data",
`{"release":"Sally Goes to the Mall S04E29"}`,
"http://localhost:3000/api/release",
}, },
want: `--header "Content-Type: application/json" --request POST --data '{"release":"Sally Goes to the Mall S04E29"}' http://localhost:3000/api/release`,
wantErr: false, wantErr: false,
}, },
} }
@ -69,8 +62,8 @@ func Test_service_parseExecArgs(t *testing.T) {
clientSvc: nil, clientSvc: nil,
bus: nil, bus: nil,
} }
got, _ := s.parseExecArgs(tt.args.release, tt.args.execArgs) _ = s.parseMacros(tt.args.action, tt.args.release)
assert.Equalf(t, tt.want, got, "parseExecArgs(%v, %v)", tt.args.release, tt.args.execArgs) assert.Equalf(t, tt.want, tt.args.action.ExecArgs, "parseMacros(%v, %v)", tt.args.action, tt.args.release)
}) })
} }
} }
@ -78,7 +71,7 @@ func Test_service_parseExecArgs(t *testing.T) {
func Test_service_execCmd(t *testing.T) { func Test_service_execCmd(t *testing.T) {
type args struct { type args struct {
release domain.Release release domain.Release
action domain.Action action *domain.Action
} }
tests := []struct { tests := []struct {
name string name string
@ -92,7 +85,7 @@ func Test_service_execCmd(t *testing.T) {
TorrentTmpFile: "tmp-10000", TorrentTmpFile: "tmp-10000",
Indexer: "mock", Indexer: "mock",
}, },
action: domain.Action{ action: &domain.Action{
Name: "echo", Name: "echo",
ExecCmd: "echo", ExecCmd: "echo",
ExecArgs: "hello", ExecArgs: "hello",
@ -108,7 +101,7 @@ func Test_service_execCmd(t *testing.T) {
clientSvc: nil, clientSvc: nil,
bus: nil, bus: nil,
} }
s.execCmd(tt.args.action, tt.args.release) s.execCmd(nil, tt.args.action, tt.args.release)
}) })
} }
} }

View file

@ -10,13 +10,13 @@ import (
"github.com/autobrr/autobrr/pkg/lidarr" "github.com/autobrr/autobrr/pkg/lidarr"
) )
func (s *service) lidarr(action domain.Action, release domain.Release) ([]string, error) { func (s *service) lidarr(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Trace().Msg("action LIDARR") s.log.Trace().Msg("action LIDARR")
// TODO validate data // TODO validate data
// get client for action // get client for action
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
s.log.Error().Err(err).Msgf("lidarr: error finding client: %v", action.ClientID) s.log.Error().Err(err).Msgf("lidarr: error finding client: %v", action.ClientID)
return nil, err return nil, err
@ -59,9 +59,9 @@ func (s *service) lidarr(action domain.Action, release domain.Release) ([]string
r.Title = fmt.Sprintf("%v (%d)", release.TorrentName, release.Year) r.Title = fmt.Sprintf("%v (%d)", release.TorrentName, release.Year)
} }
rejections, err := arr.Push(r) rejections, err := arr.Push(ctx, r)
if err != nil { if err != nil {
s.log.Error().Stack().Err(err).Msgf("lidarr: failed to push release: %v", r) s.log.Error().Err(err).Msgf("lidarr: failed to push release: %v", r)
return nil, err return nil, err
} }

View file

@ -9,7 +9,7 @@ import (
"github.com/autobrr/go-qbittorrent" "github.com/autobrr/go-qbittorrent"
) )
func (s *service) qbittorrent(ctx context.Context, action domain.Action, release domain.Release) ([]string, error) { func (s *service) qbittorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action qBittorrent: %v", action.Name) s.log.Debug().Msgf("action qBittorrent: %v", action.Name)
c := s.clientSvc.GetCachedClient(ctx, action.ClientID) c := s.clientSvc.GetCachedClient(ctx, action.ClientID)
@ -29,10 +29,7 @@ func (s *service) qbittorrent(ctx context.Context, action domain.Action, release
} }
} }
// macros handle args and replace vars options, err := s.prepareQbitOptions(action)
m := domain.NewMacro(release)
options, err := s.prepareQbitOptions(action, m)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not prepare options") return nil, errors.Wrap(err, "could not prepare options")
} }
@ -59,7 +56,7 @@ func (s *service) qbittorrent(ctx context.Context, action domain.Action, release
return nil, nil return nil, nil
} }
func (s *service) prepareQbitOptions(action domain.Action, m domain.Macro) (map[string]string, error) { func (s *service) prepareQbitOptions(action *domain.Action) (map[string]string, error) {
opts := &qbittorrent.TorrentAddOptions{} opts := &qbittorrent.TorrentAddOptions{}
opts.Paused = false opts.Paused = false
@ -78,32 +75,14 @@ func (s *service) prepareQbitOptions(action domain.Action, m domain.Macro) (map[
// if ORIGINAL then leave empty // if ORIGINAL then leave empty
} }
if action.SavePath != "" { if action.SavePath != "" {
// parse and replace values in argument string before continuing opts.SavePath = action.SavePath
actionArgs, err := m.Parse(action.SavePath)
if err != nil {
return nil, errors.Wrap(err, "could not parse savepath macro: %v", action.SavePath)
}
opts.SavePath = actionArgs
opts.AutoTMM = false opts.AutoTMM = false
} }
if action.Category != "" { if action.Category != "" {
// parse and replace values in argument string before continuing opts.Category = action.Category
categoryArgs, err := m.Parse(action.Category)
if err != nil {
return nil, errors.Wrap(err, "could not parse category macro: %v", action.Category)
}
opts.Category = categoryArgs
} }
if action.Tags != "" { if action.Tags != "" {
// parse and replace values in argument string before continuing opts.Tags = action.Tags
tagsArgs, err := m.Parse(action.Tags)
if err != nil {
return nil, errors.Wrap(err, "could not parse tags macro: %v", action.Tags)
}
opts.Tags = tagsArgs
} }
if action.LimitUploadSpeed > 0 { if action.LimitUploadSpeed > 0 {
opts.LimitUploadSpeed = action.LimitUploadSpeed opts.LimitUploadSpeed = action.LimitUploadSpeed
@ -121,7 +100,7 @@ func (s *service) prepareQbitOptions(action domain.Action, m domain.Macro) (map[
return opts.Prepare(), nil return opts.Prepare(), nil
} }
func (s *service) qbittorrentCheckRulesCanDownload(ctx context.Context, action domain.Action, client *domain.DownloadClient, qbt *qbittorrent.Client) ([]string, error) { func (s *service) qbittorrentCheckRulesCanDownload(ctx context.Context, action *domain.Action, client *domain.DownloadClient, qbt *qbittorrent.Client) ([]string, error) {
s.log.Trace().Msgf("action qBittorrent: %v check rules", action.Name) s.log.Trace().Msgf("action qBittorrent: %v check rules", action.Name)
// check for active downloads and other rules // check for active downloads and other rules

View file

@ -9,13 +9,13 @@ import (
"github.com/autobrr/autobrr/pkg/radarr" "github.com/autobrr/autobrr/pkg/radarr"
) )
func (s *service) radarr(action domain.Action, release domain.Release) ([]string, error) { func (s *service) radarr(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Trace().Msg("action RADARR") s.log.Trace().Msg("action RADARR")
// TODO validate data // TODO validate data
// get client for action // get client for action
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error finding client: %v", action.ClientID) return nil, errors.Wrap(err, "error finding client: %v", action.ClientID)
} }
@ -51,7 +51,7 @@ func (s *service) radarr(action domain.Action, release domain.Release) ([]string
PublishDate: time.Now().Format(time.RFC3339), PublishDate: time.Now().Format(time.RFC3339),
} }
rejections, err := arr.Push(r) rejections, err := arr.Push(ctx, r)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "radarr failed to push release: %v", r) return nil, errors.Wrap(err, "radarr failed to push release: %v", r)
} }

View file

@ -9,13 +9,13 @@ import (
"github.com/autobrr/autobrr/pkg/readarr" "github.com/autobrr/autobrr/pkg/readarr"
) )
func (s *service) readarr(action domain.Action, release domain.Release) ([]string, error) { func (s *service) readarr(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Trace().Msg("action READARR") s.log.Trace().Msg("action READARR")
// TODO validate data // TODO validate data
// get client for action // get client for action
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "readarr could not find client: %v", action.ClientID) return nil, errors.Wrap(err, "readarr could not find client: %v", action.ClientID)
} }
@ -51,7 +51,7 @@ func (s *service) readarr(action domain.Action, release domain.Release) ([]strin
PublishDate: time.Now().Format(time.RFC3339), PublishDate: time.Now().Format(time.RFC3339),
} }
rejections, err := arr.Push(r) rejections, err := arr.Push(ctx, r)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "readarr: failed to push release: %v", r) return nil, errors.Wrap(err, "readarr: failed to push release: %v", r)
} }

View file

@ -9,13 +9,13 @@ import (
"github.com/mrobinsn/go-rtorrent/rtorrent" "github.com/mrobinsn/go-rtorrent/rtorrent"
) )
func (s *service) rtorrent(action domain.Action, release domain.Release) ([]string, error) { func (s *service) rtorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action rTorrent: %v", action.Name) s.log.Debug().Msgf("action rTorrent: %v", action.Name)
var err error var err error
// get client for action // get client for action
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
s.log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID) s.log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID)
return nil, err return nil, err
@ -28,7 +28,7 @@ func (s *service) rtorrent(action domain.Action, release domain.Release) ([]stri
var rejections []string var rejections []string
if release.TorrentTmpFile == "" { if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFile(); err != nil { if err := release.DownloadTorrentFileCtx(ctx); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName) s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName)
return nil, err return nil, err
} }

View file

@ -30,45 +30,50 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release
} }
}() }()
// parse all macros in one go
if err := action.ParseMacros(release); err != nil {
return nil, err
}
switch action.Type { switch action.Type {
case domain.ActionTypeTest: case domain.ActionTypeTest:
s.test(action.Name) s.test(action.Name)
case domain.ActionTypeExec: case domain.ActionTypeExec:
err = s.execCmd(*action, release) err = s.execCmd(ctx, action, release)
case domain.ActionTypeWatchFolder: case domain.ActionTypeWatchFolder:
err = s.watchFolder(*action, release) err = s.watchFolder(ctx, action, release)
case domain.ActionTypeWebhook: case domain.ActionTypeWebhook:
err = s.webhook(ctx, *action, release) err = s.webhook(ctx, action, release)
case domain.ActionTypeDelugeV1, domain.ActionTypeDelugeV2: case domain.ActionTypeDelugeV1, domain.ActionTypeDelugeV2:
rejections, err = s.deluge(*action, release) rejections, err = s.deluge(ctx, action, release)
case domain.ActionTypeQbittorrent: case domain.ActionTypeQbittorrent:
rejections, err = s.qbittorrent(ctx, *action, release) rejections, err = s.qbittorrent(ctx, action, release)
case domain.ActionTypeRTorrent: case domain.ActionTypeRTorrent:
rejections, err = s.rtorrent(*action, release) rejections, err = s.rtorrent(ctx, action, release)
case domain.ActionTypeTransmission: case domain.ActionTypeTransmission:
rejections, err = s.transmission(*action, release) rejections, err = s.transmission(ctx, action, release)
case domain.ActionTypeRadarr: case domain.ActionTypeRadarr:
rejections, err = s.radarr(*action, release) rejections, err = s.radarr(ctx, action, release)
case domain.ActionTypeSonarr: case domain.ActionTypeSonarr:
rejections, err = s.sonarr(*action, release) rejections, err = s.sonarr(ctx, action, release)
case domain.ActionTypeLidarr: case domain.ActionTypeLidarr:
rejections, err = s.lidarr(*action, release) rejections, err = s.lidarr(ctx, action, release)
case domain.ActionTypeWhisparr: case domain.ActionTypeWhisparr:
rejections, err = s.whisparr(*action, release) rejections, err = s.whisparr(ctx, action, release)
case domain.ActionTypeReadarr: case domain.ActionTypeReadarr:
rejections, err = s.readarr(*action, release) rejections, err = s.readarr(ctx, action, release)
default: default:
s.log.Warn().Msgf("unsupported action type: %v", action.Type) s.log.Warn().Msgf("unsupported action type: %v", action.Type)
@ -137,9 +142,9 @@ func (s *service) test(name string) {
s.log.Info().Msgf("action TEST: %v", name) s.log.Info().Msgf("action TEST: %v", name)
} }
func (s *service) watchFolder(action domain.Action, release domain.Release) error { func (s *service) watchFolder(ctx context.Context, action *domain.Action, release domain.Release) error {
if release.TorrentTmpFile == "" { if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFile(); err != nil { if err := release.DownloadTorrentFileCtx(ctx); err != nil {
return errors.Wrap(err, "watch folder: could not download torrent file for release: %v", release.TorrentName) return errors.Wrap(err, "watch folder: could not download torrent file for release: %v", release.TorrentName)
} }
} }
@ -153,19 +158,7 @@ func (s *service) watchFolder(action domain.Action, release domain.Release) erro
release.TorrentDataRawBytes = t release.TorrentDataRawBytes = t
} }
m := domain.NewMacro(release) s.log.Trace().Msgf("action WATCH_FOLDER: %v file: %v", action.WatchFolder, release.TorrentTmpFile)
// parse and replace values in argument string before continuing
// /mnt/watch/{{.Indexer}}
// /mnt/watch/mock
// /mnt/watch/{{.Indexer}}-{{.TorrentName}}.torrent
// /mnt/watch/mock-Torrent.Name-GROUP.torrent
watchFolderArgs, err := m.Parse(action.WatchFolder)
if err != nil {
return errors.Wrap(err, "could not parse watch folder macro: %v", action.WatchFolder)
}
s.log.Trace().Msgf("action WATCH_FOLDER: %v file: %v", watchFolderArgs, release.TorrentTmpFile)
// Open original file // Open original file
original, err := os.Open(release.TorrentTmpFile) original, err := os.Open(release.TorrentTmpFile)
@ -175,16 +168,20 @@ func (s *service) watchFolder(action domain.Action, release domain.Release) erro
defer original.Close() defer original.Close()
// default dir to watch folder // default dir to watch folder
dir := watchFolderArgs // /mnt/watch/{{.Indexer}}
newFileName := watchFolderArgs // /mnt/watch/mock
// /mnt/watch/{{.Indexer}}-{{.TorrentName}}.torrent
// /mnt/watch/mock-Torrent.Name-GROUP.torrent
dir := action.WatchFolder
newFileName := action.WatchFolder
// if watchFolderArgs does not contain .torrent, create // if watchFolderArgs does not contain .torrent, create
if !strings.HasSuffix(watchFolderArgs, ".torrent") { if !strings.HasSuffix(action.WatchFolder, ".torrent") {
_, tmpFileName := filepath.Split(release.TorrentTmpFile) _, tmpFileName := filepath.Split(release.TorrentTmpFile)
newFileName = filepath.Join(watchFolderArgs, tmpFileName+".torrent") newFileName = filepath.Join(action.WatchFolder, tmpFileName+".torrent")
} else { } else {
dir, _ = filepath.Split(watchFolderArgs) dir, _ = filepath.Split(action.WatchFolder)
} }
// Create folder // Create folder
@ -209,7 +206,7 @@ func (s *service) watchFolder(action domain.Action, release domain.Release) erro
return nil return nil
} }
func (s *service) webhook(ctx context.Context, action domain.Action, release domain.Release) error { func (s *service) webhook(ctx context.Context, action *domain.Action, release domain.Release) error {
// if webhook data contains TorrentPathName or TorrentDataRawBytes, lets download the torrent file // if webhook data contains TorrentPathName or TorrentDataRawBytes, lets download the torrent file
if release.TorrentTmpFile == "" && (strings.Contains(action.WebhookData, "TorrentPathName") || strings.Contains(action.WebhookData, "TorrentDataRawBytes")) { if release.TorrentTmpFile == "" && (strings.Contains(action.WebhookData, "TorrentPathName") || strings.Contains(action.WebhookData, "TorrentDataRawBytes")) {
if err := release.DownloadTorrentFileCtx(ctx); err != nil { if err := release.DownloadTorrentFileCtx(ctx); err != nil {
@ -227,14 +224,6 @@ func (s *service) webhook(ctx context.Context, action domain.Action, release dom
release.TorrentDataRawBytes = t release.TorrentDataRawBytes = t
} }
m := domain.NewMacro(release)
// parse and replace values in argument string before continuing
dataArgs, err := m.Parse(action.WebhookData)
if err != nil {
return errors.Wrap(err, "could not parse webhook data macro: %v", action.WebhookData)
}
s.log.Trace().Msgf("action WEBHOOK: '%v' file: %v", action.Name, release.TorrentName) s.log.Trace().Msgf("action WEBHOOK: '%v' file: %v", action.Name, release.TorrentName)
s.log.Trace().Msgf("webhook action '%v' - host: %v data: %v", action.Name, action.WebhookHost, action.WebhookData) s.log.Trace().Msgf("webhook action '%v' - host: %v data: %v", action.Name, action.WebhookHost, action.WebhookData)
@ -246,7 +235,7 @@ func (s *service) webhook(ctx context.Context, action domain.Action, release dom
client := http.Client{Transport: t, Timeout: 15 * time.Second} client := http.Client{Transport: t, Timeout: 15 * time.Second}
req, err := http.NewRequestWithContext(ctx, http.MethodPost, action.WebhookHost, bytes.NewBufferString(dataArgs)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, action.WebhookHost, bytes.NewBufferString(action.WebhookData))
if err != nil { if err != nil {
return errors.Wrap(err, "could not build request for webhook") return errors.Wrap(err, "could not build request for webhook")
} }
@ -261,7 +250,11 @@ func (s *service) webhook(ctx context.Context, action domain.Action, release dom
defer res.Body.Close() defer res.Body.Close()
s.log.Info().Msgf("successfully ran webhook action: '%v' to: %v payload: %v", action.Name, action.WebhookHost, dataArgs) s.log.Info().Msgf("successfully ran webhook action: '%v' to: %v payload: %v", action.Name, action.WebhookHost, action.WebhookData)
return nil return nil
} }
func (s *service) parseMacros(action *domain.Action, release domain.Release) error {
// parse all macros in one go
return action.ParseMacros(release)
}

View file

@ -9,13 +9,13 @@ import (
"github.com/autobrr/autobrr/pkg/sonarr" "github.com/autobrr/autobrr/pkg/sonarr"
) )
func (s *service) sonarr(action domain.Action, release domain.Release) ([]string, error) { func (s *service) sonarr(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Trace().Msg("action SONARR") s.log.Trace().Msg("action SONARR")
// TODO validate data // TODO validate data
// get client for action // get client for action
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID) return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID)
} }
@ -51,7 +51,7 @@ func (s *service) sonarr(action domain.Action, release domain.Release) ([]string
PublishDate: time.Now().Format(time.RFC3339), PublishDate: time.Now().Format(time.RFC3339),
} }
rejections, err := arr.Push(r) rejections, err := arr.Push(ctx, r)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sonarr: failed to push release: %v", r) return nil, errors.Wrap(err, "sonarr: failed to push release: %v", r)
} }

View file

@ -9,13 +9,13 @@ import (
"github.com/hekmon/transmissionrpc/v2" "github.com/hekmon/transmissionrpc/v2"
) )
func (s *service) transmission(action domain.Action, release domain.Release) ([]string, error) { func (s *service) transmission(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Debug().Msgf("action Transmission: %v", action.Name) s.log.Debug().Msgf("action Transmission: %v", action.Name)
var err error var err error
// get client for action // get client for action
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
s.log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID) s.log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID)
return nil, err return nil, err
@ -28,7 +28,7 @@ func (s *service) transmission(action domain.Action, release domain.Release) ([]
var rejections []string var rejections []string
if release.TorrentTmpFile == "" { if release.TorrentTmpFile == "" {
if err := release.DownloadTorrentFile(); err != nil { if err := release.DownloadTorrentFileCtx(ctx); err != nil {
s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName) s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName)
return nil, err return nil, err
} }
@ -58,7 +58,7 @@ func (s *service) transmission(action domain.Action, release domain.Release) ([]
} }
// Prepare and send payload // Prepare and send payload
torrent, err := tbt.TorrentAdd(context.TODO(), payload) torrent, err := tbt.TorrentAdd(ctx, payload)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not add torrent %v to client: %v", release.TorrentTmpFile, client.Host) return nil, errors.Wrap(err, "could not add torrent %v to client: %v", release.TorrentTmpFile, client.Host)
} }

View file

@ -9,13 +9,13 @@ import (
"github.com/autobrr/autobrr/pkg/whisparr" "github.com/autobrr/autobrr/pkg/whisparr"
) )
func (s *service) whisparr(action domain.Action, release domain.Release) ([]string, error) { func (s *service) whisparr(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) {
s.log.Trace().Msg("action WHISPARR") s.log.Trace().Msg("action WHISPARR")
// TODO validate data // TODO validate data
// get client for action // get client for action
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) client, err := s.clientSvc.FindByID(ctx, action.ClientID)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID) return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID)
} }
@ -51,7 +51,7 @@ func (s *service) whisparr(action domain.Action, release domain.Release) ([]stri
PublishDate: time.Now().Format(time.RFC3339), PublishDate: time.Now().Format(time.RFC3339),
} }
rejections, err := arr.Push(r) rejections, err := arr.Push(ctx, r)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "whisparr: failed to push release: %v", r) return nil, errors.Wrap(err, "whisparr: failed to push release: %v", r)
} }

View file

@ -1,6 +1,9 @@
package domain package domain
import "context" import (
"context"
"github.com/autobrr/autobrr/pkg/errors"
)
type ActionRepo interface { type ActionRepo interface {
Store(ctx context.Context, action Action) (*Action, error) Store(ctx context.Context, action Action) (*Action, error)
@ -46,6 +49,27 @@ type Action struct {
Client DownloadClient `json:"client,omitempty"` Client DownloadClient `json:"client,omitempty"`
} }
// ParseMacros parse all macros on action
func (a *Action) ParseMacros(release Release) error {
var err error
m := NewMacro(release)
a.ExecArgs, err = m.Parse(a.ExecArgs)
a.WatchFolder, err = m.Parse(a.WatchFolder)
a.Category, err = m.Parse(a.Category)
a.Tags, err = m.Parse(a.Tags)
a.Label, err = m.Parse(a.Label)
a.SavePath, err = m.Parse(a.SavePath)
a.WebhookData, err = m.Parse(a.WebhookData)
if err != nil {
return errors.Wrap(err, "could not parse macros for action: %v", a.Name)
}
return nil
}
type ActionType string type ActionType string
const ( const (

View file

@ -5,6 +5,8 @@ import (
"fmt" "fmt"
"net/url" "net/url"
"github.com/autobrr/autobrr/pkg/errors"
"github.com/autobrr/go-qbittorrent" "github.com/autobrr/go-qbittorrent"
) )
@ -69,6 +71,18 @@ const (
DownloadClientTypeReadarr DownloadClientType = "READARR" DownloadClientTypeReadarr DownloadClientType = "READARR"
) )
// Validate basic validation of client
func (c DownloadClient) Validate() error {
// basic validation of client
if c.Host == "" {
return errors.New("validation error: missing host")
} else if c.Type == "" {
return errors.New("validation error: missing type")
}
return nil
}
func (c DownloadClient) BuildLegacyHost() string { func (c DownloadClient) BuildLegacyHost() string {
if c.Type == DownloadClientTypeQbittorrent { if c.Type == DownloadClientTypeQbittorrent {
return c.qbitBuildLegacyHost() return c.qbitBuildLegacyHost()

View file

@ -10,54 +10,54 @@ import (
) )
type Macro struct { type Macro struct {
TorrentName string TorrentName string
TorrentPathName string TorrentPathName string
TorrentHash string TorrentHash string
TorrentUrl string TorrentUrl string
TorrentDataRawBytes []byte TorrentDataRawBytes []byte
Indexer string Indexer string
Title string Title string
Resolution string Resolution string
Source string Source string
HDR string HDR string
FilterName string FilterName string
Size uint64 Size uint64
Season int Season int
Episode int Episode int
Year int Year int
CurrentYear int CurrentYear int
CurrentMonth int CurrentMonth int
CurrentDay int CurrentDay int
CurrentHour int CurrentHour int
CurrentMinute int CurrentMinute int
CurrentSecond int CurrentSecond int
} }
func NewMacro(release Release) Macro { func NewMacro(release Release) Macro {
currentTime := time.Now() currentTime := time.Now()
ma := Macro{ ma := Macro{
TorrentName: release.TorrentName, TorrentName: release.TorrentName,
TorrentUrl: release.TorrentURL, TorrentUrl: release.TorrentURL,
TorrentPathName: release.TorrentTmpFile, TorrentPathName: release.TorrentTmpFile,
TorrentDataRawBytes: release.TorrentDataRawBytes, TorrentDataRawBytes: release.TorrentDataRawBytes,
TorrentHash: release.TorrentHash, TorrentHash: release.TorrentHash,
Indexer: release.Indexer, Indexer: release.Indexer,
Title: release.Title, Title: release.Title,
Resolution: release.Resolution, Resolution: release.Resolution,
Source: release.Source, Source: release.Source,
HDR: strings.Join(release.HDR, ", "), HDR: strings.Join(release.HDR, ", "),
FilterName: release.FilterName, FilterName: release.FilterName,
Size: release.Size, Size: release.Size,
Season: release.Season, Season: release.Season,
Episode: release.Episode, Episode: release.Episode,
Year: release.Year, Year: release.Year,
CurrentYear: currentTime.Year(), CurrentYear: currentTime.Year(),
CurrentMonth: int(currentTime.Month()), CurrentMonth: int(currentTime.Month()),
CurrentDay: currentTime.Day(), CurrentDay: currentTime.Day(),
CurrentHour: currentTime.Hour(), CurrentHour: currentTime.Hour(),
CurrentMinute: currentTime.Minute(), CurrentMinute: currentTime.Minute(),
CurrentSecond: currentTime.Second(), CurrentSecond: currentTime.Second(),
} }
return ma return ma
@ -65,6 +65,9 @@ func NewMacro(release Release) Macro {
// Parse takes a string and replaces valid vars // Parse takes a string and replaces valid vars
func (m Macro) Parse(text string) (string, error) { func (m Macro) Parse(text string) (string, error) {
if text == "" {
return "", nil
}
// setup template // setup template
tmpl, err := template.New("macro").Parse(text) tmpl, err := template.New("macro").Parse(text)
@ -80,3 +83,24 @@ func (m Macro) Parse(text string) (string, error) {
return tpl.String(), nil return tpl.String(), nil
} }
// MustParse takes a string and replaces valid vars
func (m Macro) MustParse(text string) string {
if text == "" {
return ""
}
// setup template
tmpl, err := template.New("macro").Parse(text)
if err != nil {
return ""
}
var tpl bytes.Buffer
err = tmpl.Execute(&tpl, m)
if err != nil {
return ""
}
return tpl.String()
}

View file

@ -30,22 +30,23 @@ func (s *service) testConnection(ctx context.Context, client domain.DownloadClie
return s.testRTorrentConnection(client) return s.testRTorrentConnection(client)
case domain.DownloadClientTypeTransmission: case domain.DownloadClientTypeTransmission:
return s.testTransmissionConnection(client) return s.testTransmissionConnection(ctx, client)
case domain.DownloadClientTypeRadarr: case domain.DownloadClientTypeRadarr:
return s.testRadarrConnection(client) return s.testRadarrConnection(ctx, client)
case domain.DownloadClientTypeSonarr: case domain.DownloadClientTypeSonarr:
return s.testSonarrConnection(client) return s.testSonarrConnection(ctx, client)
case domain.DownloadClientTypeLidarr: case domain.DownloadClientTypeLidarr:
return s.testLidarrConnection(client) return s.testLidarrConnection(ctx, client)
case domain.DownloadClientTypeWhisparr: case domain.DownloadClientTypeWhisparr:
return s.testWhisparrConnection(client) return s.testWhisparrConnection(ctx, client)
case domain.DownloadClientTypeReadarr: case domain.DownloadClientTypeReadarr:
return s.testReadarrConnection(client) return s.testReadarrConnection(ctx, client)
default: default:
return errors.New("unsupported client") return errors.New("unsupported client")
} }
@ -138,7 +139,7 @@ func (s *service) testRTorrentConnection(client domain.DownloadClient) error {
return nil return nil
} }
func (s *service) testTransmissionConnection(client domain.DownloadClient) error { func (s *service) testTransmissionConnection(ctx context.Context, client domain.DownloadClient) error {
tbt, err := transmissionrpc.New(client.Host, client.Username, client.Password, &transmissionrpc.AdvancedConfig{ tbt, err := transmissionrpc.New(client.Host, client.Username, client.Password, &transmissionrpc.AdvancedConfig{
HTTPS: client.TLS, HTTPS: client.TLS,
Port: uint16(client.Port), Port: uint16(client.Port),
@ -147,7 +148,7 @@ func (s *service) testTransmissionConnection(client domain.DownloadClient) error
return errors.Wrap(err, "error logging into client: %v", client.Host) return errors.Wrap(err, "error logging into client: %v", client.Host)
} }
ok, version, _, err := tbt.RPCVersion(context.TODO()) ok, version, _, err := tbt.RPCVersion(ctx)
if err != nil { if err != nil {
return errors.Wrap(err, "error getting rpc info: %v", client.Host) return errors.Wrap(err, "error getting rpc info: %v", client.Host)
} }
@ -163,7 +164,7 @@ func (s *service) testTransmissionConnection(client domain.DownloadClient) error
return nil return nil
} }
func (s *service) testRadarrConnection(client domain.DownloadClient) error { func (s *service) testRadarrConnection(ctx context.Context, client domain.DownloadClient) error {
r := radarr.New(radarr.Config{ r := radarr.New(radarr.Config{
Hostname: client.Host, Hostname: client.Host,
APIKey: client.Settings.APIKey, APIKey: client.Settings.APIKey,
@ -173,8 +174,7 @@ func (s *service) testRadarrConnection(client domain.DownloadClient) error {
Log: s.subLogger, Log: s.subLogger,
}) })
_, err := r.Test() if _, err := r.Test(ctx); err != nil {
if err != nil {
return errors.Wrap(err, "radarr: connection test failed: %v", client.Host) return errors.Wrap(err, "radarr: connection test failed: %v", client.Host)
} }
@ -183,7 +183,7 @@ func (s *service) testRadarrConnection(client domain.DownloadClient) error {
return nil return nil
} }
func (s *service) testSonarrConnection(client domain.DownloadClient) error { func (s *service) testSonarrConnection(ctx context.Context, client domain.DownloadClient) error {
r := sonarr.New(sonarr.Config{ r := sonarr.New(sonarr.Config{
Hostname: client.Host, Hostname: client.Host,
APIKey: client.Settings.APIKey, APIKey: client.Settings.APIKey,
@ -193,8 +193,7 @@ func (s *service) testSonarrConnection(client domain.DownloadClient) error {
Log: s.subLogger, Log: s.subLogger,
}) })
_, err := r.Test() if _, err := r.Test(ctx); err != nil {
if err != nil {
return errors.Wrap(err, "sonarr: connection test failed: %v", client.Host) return errors.Wrap(err, "sonarr: connection test failed: %v", client.Host)
} }
@ -203,7 +202,7 @@ func (s *service) testSonarrConnection(client domain.DownloadClient) error {
return nil return nil
} }
func (s *service) testLidarrConnection(client domain.DownloadClient) error { func (s *service) testLidarrConnection(ctx context.Context, client domain.DownloadClient) error {
r := lidarr.New(lidarr.Config{ r := lidarr.New(lidarr.Config{
Hostname: client.Host, Hostname: client.Host,
APIKey: client.Settings.APIKey, APIKey: client.Settings.APIKey,
@ -213,8 +212,7 @@ func (s *service) testLidarrConnection(client domain.DownloadClient) error {
Log: s.subLogger, Log: s.subLogger,
}) })
_, err := r.Test() if _, err := r.Test(ctx); err != nil {
if err != nil {
return errors.Wrap(err, "lidarr: connection test failed: %v", client.Host) return errors.Wrap(err, "lidarr: connection test failed: %v", client.Host)
} }
@ -223,7 +221,7 @@ func (s *service) testLidarrConnection(client domain.DownloadClient) error {
return nil return nil
} }
func (s *service) testWhisparrConnection(client domain.DownloadClient) error { func (s *service) testWhisparrConnection(ctx context.Context, client domain.DownloadClient) error {
r := whisparr.New(whisparr.Config{ r := whisparr.New(whisparr.Config{
Hostname: client.Host, Hostname: client.Host,
APIKey: client.Settings.APIKey, APIKey: client.Settings.APIKey,
@ -233,8 +231,7 @@ func (s *service) testWhisparrConnection(client domain.DownloadClient) error {
Log: s.subLogger, Log: s.subLogger,
}) })
_, err := r.Test() if _, err := r.Test(ctx); err != nil {
if err != nil {
return errors.Wrap(err, "whisparr: connection test failed: %v", client.Host) return errors.Wrap(err, "whisparr: connection test failed: %v", client.Host)
} }
@ -243,7 +240,7 @@ func (s *service) testWhisparrConnection(client domain.DownloadClient) error {
return nil return nil
} }
func (s *service) testReadarrConnection(client domain.DownloadClient) error { func (s *service) testReadarrConnection(ctx context.Context, client domain.DownloadClient) error {
r := readarr.New(readarr.Config{ r := readarr.New(readarr.Config{
Hostname: client.Host, Hostname: client.Host,
APIKey: client.Settings.APIKey, APIKey: client.Settings.APIKey,
@ -253,8 +250,7 @@ func (s *service) testReadarrConnection(client domain.DownloadClient) error {
Log: s.subLogger, Log: s.subLogger,
}) })
_, err := r.Test() if _, err := r.Test(ctx); err != nil {
if err != nil {
return errors.Wrap(err, "readarr: connection test failed: %v", client.Host) return errors.Wrap(err, "readarr: connection test failed: %v", client.Host)
} }

View file

@ -2,7 +2,6 @@ package download_client
import ( import (
"context" "context"
"errors"
"log" "log"
"sync" "sync"
@ -69,11 +68,9 @@ func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClien
} }
func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
// validate data // basic validation of client
if client.Host == "" { if err := client.Validate(); err != nil {
return nil, errors.New("validation error: no host") return nil, err
} else if client.Type == "" {
return nil, errors.New("validation error: no type")
} }
// store // store
@ -87,11 +84,9 @@ func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*dom
} }
func (s *service) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { func (s *service) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
// validate data // basic validation of client
if client.Host == "" { if err := client.Validate(); err != nil {
return nil, errors.New("validation error: no host") return nil, err
} else if client.Type == "" {
return nil, errors.New("validation error: no type")
} }
// update // update
@ -125,10 +120,8 @@ func (s *service) Delete(ctx context.Context, clientID int) error {
func (s *service) Test(ctx context.Context, client domain.DownloadClient) error { func (s *service) Test(ctx context.Context, client domain.DownloadClient) error {
// basic validation of client // basic validation of client
if client.Host == "" { if err := client.Validate(); err != nil {
return errors.New("validation error: no host") return err
} else if client.Type == "" {
return errors.New("validation error: no type")
} }
// test // test

View file

@ -2,6 +2,7 @@ package lidarr
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@ -11,12 +12,12 @@ import (
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
) )
func (c *client) get(endpoint string) (int, []byte, error) { func (c *client) get(ctx context.Context, endpoint string) (int, []byte, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v1/", endpoint) u.Path = path.Join(u.Path, "/api/v1/", endpoint)
reqUrl := u.String() reqUrl := u.String()
req, err := http.NewRequest(http.MethodGet, reqUrl, http.NoBody) req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, http.NoBody)
if err != nil { if err != nil {
return 0, nil, errors.Wrap(err, "lidarr client request error : %v", reqUrl) return 0, nil, errors.Wrap(err, "lidarr client request error : %v", reqUrl)
} }
@ -42,7 +43,7 @@ func (c *client) get(endpoint string) (int, []byte, error) {
return resp.StatusCode, buf.Bytes(), nil return resp.StatusCode, buf.Bytes(), nil
} }
func (c *client) post(endpoint string, data interface{}) (*http.Response, error) { func (c *client) post(ctx context.Context, endpoint string, data interface{}) (*http.Response, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v1/", endpoint) u.Path = path.Join(u.Path, "/api/v1/", endpoint)
reqUrl := u.String() reqUrl := u.String()
@ -52,7 +53,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error)
return nil, errors.Wrap(err, "lidarr client could not marshal data: %v", reqUrl) return nil, errors.Wrap(err, "lidarr client could not marshal data: %v", reqUrl)
} }
req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "lidarr client request error: %v", reqUrl) return nil, errors.Wrap(err, "lidarr client request error: %v", reqUrl)
} }
@ -81,7 +82,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error)
return res, nil return res, nil
} }
func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error) { func (c *client) postBody(ctx context.Context, endpoint string, data interface{}) (int, []byte, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v1/", endpoint) u.Path = path.Join(u.Path, "/api/v1/", endpoint)
reqUrl := u.String() reqUrl := u.String()
@ -91,7 +92,7 @@ func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error
return 0, nil, errors.Wrap(err, "lidarr client could not marshal data: %v", reqUrl) return 0, nil, errors.Wrap(err, "lidarr client could not marshal data: %v", reqUrl)
} }
req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return 0, nil, errors.Wrap(err, "lidarr client request error: %v", reqUrl) return 0, nil, errors.Wrap(err, "lidarr client request error: %v", reqUrl)
} }

View file

@ -1,6 +1,7 @@
package lidarr package lidarr
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -25,8 +26,8 @@ type Config struct {
} }
type Client interface { type Client interface {
Test() (*SystemStatusResponse, error) Test(ctx context.Context) (*SystemStatusResponse, error)
Push(release Release) ([]string, error) Push(ctx context.Context, release Release) ([]string, error)
} }
type client struct { type client struct {
@ -89,8 +90,8 @@ type SystemStatusResponse struct {
Version string `json:"version"` Version string `json:"version"`
} }
func (c *client) Test() (*SystemStatusResponse, error) { func (c *client) Test(ctx context.Context) (*SystemStatusResponse, error) {
status, res, err := c.get("system/status") status, res, err := c.get(ctx, "system/status")
if err != nil { if err != nil {
return nil, errors.Wrap(err, "lidarr client get error") return nil, errors.Wrap(err, "lidarr client get error")
} }
@ -110,8 +111,8 @@ func (c *client) Test() (*SystemStatusResponse, error) {
return &response, nil return &response, nil
} }
func (c *client) Push(release Release) ([]string, error) { func (c *client) Push(ctx context.Context, release Release) ([]string, error) {
status, res, err := c.postBody("release/push", release) status, res, err := c.postBody(ctx, "release/push", release)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "lidarr client post error") return nil, errors.Wrap(err, "lidarr client post error")
} }

View file

@ -1,6 +1,7 @@
package lidarr package lidarr
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@ -101,7 +102,7 @@ func Test_client_Push(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := New(tt.fields.config) c := New(tt.fields.config)
rejections, err := c.Push(tt.args.release) rejections, err := c.Push(context.Background(), tt.args.release)
assert.Equal(t, tt.rejections, rejections) assert.Equal(t, tt.rejections, rejections)
if tt.wantErr && assert.Error(t, err) { if tt.wantErr && assert.Error(t, err) {
assert.Equal(t, tt.err, err) assert.Equal(t, tt.err, err)
@ -170,7 +171,7 @@ func Test_client_Test(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := New(tt.cfg) c := New(tt.cfg)
got, err := c.Test() got, err := c.Test(context.Background())
if tt.wantErr && assert.Error(t, err) { if tt.wantErr && assert.Error(t, err) {
assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err) assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err)
} }

View file

@ -2,6 +2,7 @@ package radarr
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@ -11,12 +12,12 @@ import (
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
) )
func (c *client) get(endpoint string) (int, []byte, error) { func (c *client) get(ctx context.Context, endpoint string) (int, []byte, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v3/", endpoint) u.Path = path.Join(u.Path, "/api/v3/", endpoint)
reqUrl := u.String() reqUrl := u.String()
req, err := http.NewRequest(http.MethodGet, reqUrl, http.NoBody) req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, http.NoBody)
if err != nil { if err != nil {
return 0, nil, errors.Wrap(err, "could not build request: %v", reqUrl) return 0, nil, errors.Wrap(err, "could not build request: %v", reqUrl)
} }
@ -42,7 +43,7 @@ func (c *client) get(endpoint string) (int, []byte, error) {
return resp.StatusCode, buf.Bytes(), nil return resp.StatusCode, buf.Bytes(), nil
} }
func (c *client) post(endpoint string, data interface{}) (*http.Response, error) { func (c *client) post(ctx context.Context, endpoint string, data interface{}) (*http.Response, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v3/", endpoint) u.Path = path.Join(u.Path, "/api/v3/", endpoint)
reqUrl := u.String() reqUrl := u.String()
@ -52,7 +53,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error)
return nil, errors.Wrap(err, "could not marshal data: %+v", data) return nil, errors.Wrap(err, "could not marshal data: %+v", data)
} }
req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not build request: %v", reqUrl) return nil, errors.Wrap(err, "could not build request: %v", reqUrl)
} }
@ -83,7 +84,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error)
return res, nil return res, nil
} }
func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error) { func (c *client) postBody(ctx context.Context, endpoint string, data interface{}) (int, []byte, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v3/", endpoint) u.Path = path.Join(u.Path, "/api/v3/", endpoint)
reqUrl := u.String() reqUrl := u.String()
@ -93,7 +94,7 @@ func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error
return 0, nil, errors.Wrap(err, "could not marshal data: %+v", data) return 0, nil, errors.Wrap(err, "could not marshal data: %+v", data)
} }
req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return 0, nil, errors.Wrap(err, "could not build request: %v", reqUrl) return 0, nil, errors.Wrap(err, "could not build request: %v", reqUrl)
} }

View file

@ -1,6 +1,7 @@
package radarr package radarr
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -25,8 +26,8 @@ type Config struct {
} }
type Client interface { type Client interface {
Test() (*SystemStatusResponse, error) Test(ctx context.Context) (*SystemStatusResponse, error)
Push(release Release) ([]string, error) Push(ctx context.Context, release Release) ([]string, error)
} }
type client struct { type client struct {
@ -88,8 +89,8 @@ func (r *BadRequestResponse) String() string {
return fmt.Sprintf("[%v: %v] %v: %v - got value: %v", r.Severity, r.ErrorCode, r.PropertyName, r.ErrorMessage, r.AttemptedValue) return fmt.Sprintf("[%v: %v] %v: %v - got value: %v", r.Severity, r.ErrorCode, r.PropertyName, r.ErrorMessage, r.AttemptedValue)
} }
func (c *client) Test() (*SystemStatusResponse, error) { func (c *client) Test(ctx context.Context) (*SystemStatusResponse, error) {
status, res, err := c.get("system/status") status, res, err := c.get(ctx, "system/status")
if err != nil { if err != nil {
return nil, errors.Wrap(err, "radarr error running test") return nil, errors.Wrap(err, "radarr error running test")
} }
@ -108,8 +109,8 @@ func (c *client) Test() (*SystemStatusResponse, error) {
return &response, nil return &response, nil
} }
func (c *client) Push(release Release) ([]string, error) { func (c *client) Push(ctx context.Context, release Release) ([]string, error) {
status, res, err := c.postBody("release/push", release) status, res, err := c.postBody(ctx, "release/push", release)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "error push release") return nil, errors.Wrap(err, "error push release")
} }

View file

@ -1,6 +1,7 @@
package radarr package radarr
import ( import (
"context"
"io" "io"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
@ -141,7 +142,7 @@ func Test_client_Push(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := New(tt.fields.config) c := New(tt.fields.config)
rejections, err := c.Push(tt.args.release) rejections, err := c.Push(context.Background(), tt.args.release)
assert.Equal(t, tt.rejections, rejections) assert.Equal(t, tt.rejections, rejections)
if tt.wantErr && assert.Error(t, err) { if tt.wantErr && assert.Error(t, err) {
assert.Equal(t, tt.err, err) assert.Equal(t, tt.err, err)
@ -223,7 +224,7 @@ func Test_client_Test(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := New(tt.cfg) c := New(tt.cfg)
got, err := c.Test() got, err := c.Test(context.Background())
if tt.wantErr && assert.Error(t, err) { if tt.wantErr && assert.Error(t, err) {
assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err) assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err)
} }

View file

@ -2,6 +2,7 @@ package readarr
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@ -11,12 +12,12 @@ import (
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
) )
func (c *client) get(endpoint string) (int, []byte, error) { func (c *client) get(ctx context.Context, endpoint string) (int, []byte, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v1/", endpoint) u.Path = path.Join(u.Path, "/api/v1/", endpoint)
reqUrl := u.String() reqUrl := u.String()
req, err := http.NewRequest(http.MethodGet, reqUrl, http.NoBody) req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, http.NoBody)
if err != nil { if err != nil {
return 0, nil, errors.Wrap(err, "could not build request") return 0, nil, errors.Wrap(err, "could not build request")
} }
@ -42,7 +43,7 @@ func (c *client) get(endpoint string) (int, []byte, error) {
return resp.StatusCode, buf.Bytes(), nil return resp.StatusCode, buf.Bytes(), nil
} }
func (c *client) post(endpoint string, data interface{}) (*http.Response, error) { func (c *client) post(ctx context.Context, endpoint string, data interface{}) (*http.Response, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v1/", endpoint) u.Path = path.Join(u.Path, "/api/v1/", endpoint)
reqUrl := u.String() reqUrl := u.String()
@ -52,7 +53,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error)
return nil, errors.Wrap(err, "could not marshal data: %+v", data) return nil, errors.Wrap(err, "could not marshal data: %+v", data)
} }
req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not build request") return nil, errors.Wrap(err, "could not build request")
} }
@ -81,7 +82,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error)
return res, nil return res, nil
} }
func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error) { func (c *client) postBody(ctx context.Context, endpoint string, data interface{}) (int, []byte, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v1/", endpoint) u.Path = path.Join(u.Path, "/api/v1/", endpoint)
reqUrl := u.String() reqUrl := u.String()
@ -93,7 +94,7 @@ func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error
c.Log.Printf("readarr push JSON: %s\n", string(jsonData)) c.Log.Printf("readarr push JSON: %s\n", string(jsonData))
req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return 0, nil, errors.Wrap(err, "could not build request") return 0, nil, errors.Wrap(err, "could not build request")
} }

View file

@ -1,6 +1,7 @@
package readarr package readarr
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -26,8 +27,8 @@ type Config struct {
} }
type Client interface { type Client interface {
Test() (*SystemStatusResponse, error) Test(ctx context.Context) (*SystemStatusResponse, error)
Push(release Release) ([]string, error) Push(ctx context.Context, release Release) ([]string, error)
} }
type client struct { type client struct {
@ -92,8 +93,8 @@ type SystemStatusResponse struct {
Version string `json:"version"` Version string `json:"version"`
} }
func (c *client) Test() (*SystemStatusResponse, error) { func (c *client) Test(ctx context.Context) (*SystemStatusResponse, error) {
status, res, err := c.get("system/status") status, res, err := c.get(ctx, "system/status")
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not make Test") return nil, errors.Wrap(err, "could not make Test")
} }
@ -112,8 +113,8 @@ func (c *client) Test() (*SystemStatusResponse, error) {
return &response, nil return &response, nil
} }
func (c *client) Push(release Release) ([]string, error) { func (c *client) Push(ctx context.Context, release Release) ([]string, error) {
status, res, err := c.postBody("release/push", release) status, res, err := c.postBody(ctx, "release/push", release)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not push release to readarr") return nil, errors.Wrap(err, "could not push release to readarr")
} }

View file

@ -1,6 +1,7 @@
package readarr package readarr
import ( import (
"context"
"net/http" "net/http"
"net/http/httptest" "net/http/httptest"
"os" "os"
@ -78,7 +79,7 @@ func Test_client_Push(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := New(tt.fields.config) c := New(tt.fields.config)
rejections, err := c.Push(tt.args.release) rejections, err := c.Push(context.Background(), tt.args.release)
assert.Equal(t, tt.rejections, rejections) assert.Equal(t, tt.rejections, rejections)
if tt.wantErr && assert.Error(t, err) { if tt.wantErr && assert.Error(t, err) {
assert.Equal(t, tt.err, err) assert.Equal(t, tt.err, err)
@ -147,7 +148,7 @@ func Test_client_Test(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := New(tt.cfg) c := New(tt.cfg)
got, err := c.Test() got, err := c.Test(context.Background())
if tt.wantErr && assert.Error(t, err) { if tt.wantErr && assert.Error(t, err) {
assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err) assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err)
} }

View file

@ -2,6 +2,7 @@ package sonarr
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"io" "io"
"net/http" "net/http"
@ -11,12 +12,12 @@ import (
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
) )
func (c *client) get(endpoint string) (int, []byte, error) { func (c *client) get(ctx context.Context, endpoint string) (int, []byte, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v3/", endpoint) u.Path = path.Join(u.Path, "/api/v3/", endpoint)
reqUrl := u.String() reqUrl := u.String()
req, err := http.NewRequest(http.MethodGet, reqUrl, http.NoBody) req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, http.NoBody)
if err != nil { if err != nil {
return 0, nil, errors.Wrap(err, "could not build request") return 0, nil, errors.Wrap(err, "could not build request")
} }
@ -42,7 +43,7 @@ func (c *client) get(endpoint string) (int, []byte, error) {
return resp.StatusCode, buf.Bytes(), nil return resp.StatusCode, buf.Bytes(), nil
} }
func (c *client) post(endpoint string, data interface{}) (*http.Response, error) { func (c *client) post(ctx context.Context, endpoint string, data interface{}) (*http.Response, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v3/", endpoint) u.Path = path.Join(u.Path, "/api/v3/", endpoint)
reqUrl := u.String() reqUrl := u.String()
@ -52,7 +53,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error)
return nil, errors.Wrap(err, "could not marshal data: %+v", data) return nil, errors.Wrap(err, "could not marshal data: %+v", data)
} }
req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not build request") return nil, errors.Wrap(err, "could not build request")
} }
@ -81,7 +82,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error)
return res, nil return res, nil
} }
func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error) { func (c *client) postBody(ctx context.Context, endpoint string, data interface{}) (int, []byte, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v3/", endpoint) u.Path = path.Join(u.Path, "/api/v3/", endpoint)
reqUrl := u.String() reqUrl := u.String()
@ -91,7 +92,7 @@ func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error
return 0, nil, errors.Wrap(err, "could not marshal data: %+v", data) return 0, nil, errors.Wrap(err, "could not marshal data: %+v", data)
} }
req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return 0, nil, errors.Wrap(err, "could not build request") return 0, nil, errors.Wrap(err, "could not build request")
} }

View file

@ -1,6 +1,7 @@
package sonarr package sonarr
import ( import (
"context"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@ -26,8 +27,8 @@ type Config struct {
} }
type Client interface { type Client interface {
Test() (*SystemStatusResponse, error) Test(ctx context.Context) (*SystemStatusResponse, error)
Push(release Release) ([]string, error) Push(ctx context.Context, release Release) ([]string, error)
} }
type client struct { type client struct {
@ -91,8 +92,8 @@ type SystemStatusResponse struct {
Version string `json:"version"` Version string `json:"version"`
} }
func (c *client) Test() (*SystemStatusResponse, error) { func (c *client) Test(ctx context.Context) (*SystemStatusResponse, error) {
status, res, err := c.get("system/status") status, res, err := c.get(ctx, "system/status")
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not make Test") return nil, errors.Wrap(err, "could not make Test")
} }
@ -111,8 +112,8 @@ func (c *client) Test() (*SystemStatusResponse, error) {
return &response, nil return &response, nil
} }
func (c *client) Push(release Release) ([]string, error) { func (c *client) Push(ctx context.Context, release Release) ([]string, error) {
status, res, err := c.postBody("release/push", release) status, res, err := c.postBody(ctx, "release/push", release)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not push release to sonarr") return nil, errors.Wrap(err, "could not push release to sonarr")
} }

View file

@ -1,6 +1,7 @@
package sonarr package sonarr
import ( import (
"context"
"io/ioutil" "io/ioutil"
"log" "log"
"net/http" "net/http"
@ -109,7 +110,7 @@ func Test_client_Push(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := New(tt.fields.config) c := New(tt.fields.config)
rejections, err := c.Push(tt.args.release) rejections, err := c.Push(context.Background(), tt.args.release)
assert.Equal(t, tt.rejections, rejections) assert.Equal(t, tt.rejections, rejections)
if tt.wantErr && assert.Error(t, err) { if tt.wantErr && assert.Error(t, err) {
assert.Equal(t, tt.err, err) assert.Equal(t, tt.err, err)
@ -179,7 +180,7 @@ func Test_client_Test(t *testing.T) {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
c := New(tt.cfg) c := New(tt.cfg)
got, err := c.Test() got, err := c.Test(context.Background())
if tt.wantErr && assert.Error(t, err) { if tt.wantErr && assert.Error(t, err) {
assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err) assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err)
} }

View file

@ -2,6 +2,7 @@ package whisparr
import ( import (
"bytes" "bytes"
"context"
"encoding/json" "encoding/json"
"net/http" "net/http"
"net/url" "net/url"
@ -10,12 +11,12 @@ import (
"github.com/autobrr/autobrr/pkg/errors" "github.com/autobrr/autobrr/pkg/errors"
) )
func (c *client) get(endpoint string) (*http.Response, error) { func (c *client) get(ctx context.Context, endpoint string) (*http.Response, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v3/", endpoint) u.Path = path.Join(u.Path, "/api/v3/", endpoint)
reqUrl := u.String() reqUrl := u.String()
req, err := http.NewRequest(http.MethodGet, reqUrl, http.NoBody) req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, http.NoBody)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not build request") return nil, errors.Wrap(err, "could not build request")
} }
@ -39,7 +40,7 @@ func (c *client) get(endpoint string) (*http.Response, error) {
return res, nil return res, nil
} }
func (c *client) post(endpoint string, data interface{}) (*http.Response, error) { func (c *client) post(ctx context.Context, endpoint string, data interface{}) (*http.Response, error) {
u, err := url.Parse(c.config.Hostname) u, err := url.Parse(c.config.Hostname)
u.Path = path.Join(u.Path, "/api/v3/", endpoint) u.Path = path.Join(u.Path, "/api/v3/", endpoint)
reqUrl := u.String() reqUrl := u.String()
@ -49,7 +50,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error)
return nil, errors.Wrap(err, "could not marshal data: %+v", data) return nil, errors.Wrap(err, "could not marshal data: %+v", data)
} }
req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData))
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not build request") return nil, errors.Wrap(err, "could not build request")
} }

View file

@ -1,6 +1,7 @@
package whisparr package whisparr
import ( import (
"context"
"encoding/json" "encoding/json"
"io" "io"
"log" "log"
@ -24,8 +25,8 @@ type Config struct {
} }
type Client interface { type Client interface {
Test() (*SystemStatusResponse, error) Test(ctx context.Context) (*SystemStatusResponse, error)
Push(release Release) ([]string, error) Push(ctx context.Context, release Release) ([]string, error)
} }
type client struct { type client struct {
@ -75,8 +76,8 @@ type SystemStatusResponse struct {
Version string `json:"version"` Version string `json:"version"`
} }
func (c *client) Test() (*SystemStatusResponse, error) { func (c *client) Test(ctx context.Context) (*SystemStatusResponse, error) {
res, err := c.get("system/status") res, err := c.get(ctx, "system/status")
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not test whisparr") return nil, errors.Wrap(err, "could not test whisparr")
} }
@ -99,8 +100,8 @@ func (c *client) Test() (*SystemStatusResponse, error) {
return &response, nil return &response, nil
} }
func (c *client) Push(release Release) ([]string, error) { func (c *client) Push(ctx context.Context, release Release) ([]string, error) {
res, err := c.post("release/push", release) res, err := c.post(ctx, "release/push", release)
if err != nil { if err != nil {
return nil, errors.Wrap(err, "could not push release to whisparr: %+v", release) return nil, errors.Wrap(err, "could not push release to whisparr: %+v", release)
} }