diff --git a/internal/action/deluge.go b/internal/action/deluge.go index ae9f0d5..7ea236e 100644 --- a/internal/action/deluge.go +++ b/internal/action/deluge.go @@ -12,13 +12,13 @@ import ( delugeClient "github.com/gdm85/go-libdeluge" ) -func (s *service) deluge(action domain.Action, release domain.Release) ([]string, error) { +func (s *service) deluge(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Debug().Msgf("action Deluge: %v", action.Name) var err error // get client for action - client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) + client, err := s.clientSvc.FindByID(ctx, action.ClientID) if err != nil { s.log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID) return nil, err @@ -32,16 +32,16 @@ func (s *service) deluge(action domain.Action, release domain.Release) ([]string switch client.Type { case "DELUGE_V1": - rejections, err = s.delugeV1(client, action, release) + rejections, err = s.delugeV1(ctx, client, action, release) case "DELUGE_V2": - rejections, err = s.delugeV2(client, action, release) + rejections, err = s.delugeV2(ctx, client, action, release) } return rejections, err } -func (s *service) delugeCheckRulesCanDownload(deluge delugeClient.DelugeClient, client *domain.DownloadClient, action domain.Action) ([]string, error) { +func (s *service) delugeCheckRulesCanDownload(deluge delugeClient.DelugeClient, client *domain.DownloadClient, action *domain.Action) ([]string, error) { s.log.Trace().Msgf("action Deluge: %v check rules", action.Name) // check for active downloads and other rules @@ -86,7 +86,7 @@ func (s *service) delugeCheckRulesCanDownload(deluge delugeClient.DelugeClient, return nil, nil } -func (s *service) delugeV1(client *domain.DownloadClient, action domain.Action, release domain.Release) ([]string, error) { +func (s *service) delugeV1(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) { settings := delugeClient.Settings{ Hostname: client.Host, Port: uint(client.Port), @@ -117,7 +117,7 @@ func (s *service) delugeV1(client *domain.DownloadClient, action domain.Action, } if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFile(); err != nil { + if err := release.DownloadTorrentFileCtx(ctx); err != nil { s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName) return nil, err } @@ -134,10 +134,7 @@ func (s *service) delugeV1(client *domain.DownloadClient, action domain.Action, return nil, errors.Wrap(err, "could not encode torrent file: %v", release.TorrentTmpFile) } - // macros handle args and replace vars - m := domain.NewMacro(release) - - options, err := s.prepareDelugeOptions(action, m) + options, err := s.prepareDelugeOptions(action) if err != nil { return nil, errors.Wrap(err, "could not prepare options") } @@ -155,15 +152,9 @@ func (s *service) delugeV1(client *domain.DownloadClient, action domain.Action, return nil, errors.Wrap(err, "could not load label plugin for client: %v", client.Name) } - // parse and replace values in argument string before continuing - labelArgs, err := m.Parse(action.Label) - if err != nil { - return nil, errors.Wrap(err, "could not parse macro label: %v", action.Label) - } - if labelPluginActive != nil { // TODO first check if label exists, if not, add it, otherwise set - err = labelPluginActive.SetTorrentLabel(torrentHash, labelArgs) + err = labelPluginActive.SetTorrentLabel(torrentHash, action.Label) if err != nil { return nil, errors.Wrap(err, "could not set label: %v on client: %v", action.Label, client.Name) } @@ -175,7 +166,7 @@ func (s *service) delugeV1(client *domain.DownloadClient, action domain.Action, return nil, nil } -func (s *service) delugeV2(client *domain.DownloadClient, action domain.Action, release domain.Release) ([]string, error) { +func (s *service) delugeV2(ctx context.Context, client *domain.DownloadClient, action *domain.Action, release domain.Release) ([]string, error) { settings := delugeClient.Settings{ Hostname: client.Host, Port: uint(client.Port), @@ -206,7 +197,7 @@ func (s *service) delugeV2(client *domain.DownloadClient, action domain.Action, } if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFile(); err != nil { + if err := release.DownloadTorrentFileCtx(ctx); err != nil { s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName) return nil, err } @@ -223,11 +214,8 @@ func (s *service) delugeV2(client *domain.DownloadClient, action domain.Action, return nil, errors.Wrap(err, "could not encode torrent file: %v", release.TorrentTmpFile) } - // macros handle args and replace vars - m := domain.NewMacro(release) - // set options - options, err := s.prepareDelugeOptions(action, m) + options, err := s.prepareDelugeOptions(action) if err != nil { return nil, errors.Wrap(err, "could not prepare options") } @@ -245,15 +233,9 @@ func (s *service) delugeV2(client *domain.DownloadClient, action domain.Action, return nil, errors.Wrap(err, "could not load label plugin for client: %v", client.Name) } - // parse and replace values in argument string before continuing - labelArgs, err := m.Parse(action.Label) - if err != nil { - return nil, errors.Wrap(err, "could not parse macro label: %v", action.Label) - } - if labelPluginActive != nil { // TODO first check if label exists, if not, add it, otherwise set - err = labelPluginActive.SetTorrentLabel(torrentHash, labelArgs) + err = labelPluginActive.SetTorrentLabel(torrentHash, action.Label) if err != nil { return nil, errors.Wrap(err, "could not set label: %v on client: %v", action.Label, client.Name) } @@ -265,7 +247,7 @@ func (s *service) delugeV2(client *domain.DownloadClient, action domain.Action, return nil, nil } -func (s *service) prepareDelugeOptions(action domain.Action, m domain.Macro) (delugeClient.Options, error) { +func (s *service) prepareDelugeOptions(action *domain.Action) (delugeClient.Options, error) { // set options options := delugeClient.Options{} @@ -274,13 +256,7 @@ func (s *service) prepareDelugeOptions(action domain.Action, m domain.Macro) (de options.AddPaused = &action.Paused } if action.SavePath != "" { - // parse and replace values in argument string before continuing - savePathArgs, err := m.Parse(action.SavePath) - if err != nil { - return options, errors.Wrap(err, "could not parse save path macro: %v", action.SavePath) - } - - options.DownloadLocation = &savePathArgs + options.DownloadLocation = &action.SavePath } if action.LimitDownloadSpeed > 0 { maxDL := int(action.LimitDownloadSpeed) diff --git a/internal/action/exec.go b/internal/action/exec.go index 71f73be..7bc942e 100644 --- a/internal/action/exec.go +++ b/internal/action/exec.go @@ -1,6 +1,7 @@ package action import ( + "context" "os" "os/exec" "strings" @@ -12,11 +13,11 @@ import ( "github.com/mattn/go-shellwords" ) -func (s *service) execCmd(action domain.Action, release domain.Release) error { +func (s *service) execCmd(ctx context.Context, action *domain.Action, release domain.Release) error { s.log.Debug().Msgf("action exec: %v release: %v", action.Name, release.TorrentName) if release.TorrentTmpFile == "" && strings.Contains(action.ExecArgs, "TorrentPathName") { - if err := release.DownloadTorrentFile(); err != nil { + if err := release.DownloadTorrentFileCtx(ctx); err != nil { return errors.Wrap(err, "error downloading torrent file for release: %v", release.TorrentName) } } @@ -37,7 +38,9 @@ func (s *service) execCmd(action domain.Action, release domain.Release) error { return errors.Wrap(err, "exec failed, could not find program: %v", action.ExecCmd) } - args, err := s.parseExecArgs(release, action.ExecArgs) + p := shellwords.NewParser() + p.ParseBacktick = true + args, err := p.Parse(action.ExecArgs) if err != nil { return errors.Wrap(err, "could not parse exec args: %v", action.ExecArgs) } @@ -47,7 +50,7 @@ func (s *service) execCmd(action domain.Action, release domain.Release) error { start := time.Now() // setup command and args - command := exec.Command(cmd, args...) + command := exec.CommandContext(ctx, cmd, args...) // execute command output, err := command.CombinedOutput() @@ -64,23 +67,3 @@ func (s *service) execCmd(action domain.Action, release domain.Release) error { return nil } - -func (s *service) parseExecArgs(release domain.Release, execArgs string) ([]string, error) { - // handle args and replace vars - m := domain.NewMacro(release) - - // parse and replace values in argument string before continuing - parsedArgs, err := m.Parse(execArgs) - if err != nil { - return nil, errors.Wrap(err, "could not parse macro") - } - - p := shellwords.NewParser() - p.ParseBacktick = true - args, err := p.Parse(parsedArgs) - if err != nil { - return nil, errors.Wrap(err, "could not parse into shell-words") - } - - return args, nil -} diff --git a/internal/action/exec_test.go b/internal/action/exec_test.go index 1b384e1..10e1393 100644 --- a/internal/action/exec_test.go +++ b/internal/action/exec_test.go @@ -9,55 +9,48 @@ import ( "github.com/stretchr/testify/assert" ) -func Test_service_parseExecArgs(t *testing.T) { +func Test_service_parseMacros(t *testing.T) { type args struct { - release domain.Release - execArgs string + release domain.Release + action *domain.Action } tests := []struct { name string args args - want []string + want string wantErr bool }{ { name: "test_1", args: args{ - release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"}, - execArgs: `echo "{{ .TorrentName }}"`, - }, - want: []string{ - "echo", - "Sally Goes to the Mall S04E29", + release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"}, + action: &domain.Action{ + ExecArgs: `echo "{{ .TorrentName }}"`, + }, }, + want: `echo "Sally Goes to the Mall S04E29"`, wantErr: false, }, { name: "test_2", args: args{ - release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"}, - execArgs: `"{{ .TorrentName }}"`, - }, - want: []string{ - "Sally Goes to the Mall S04E29", + release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"}, + action: &domain.Action{ + ExecArgs: `"{{ .TorrentName }}"`, + }, }, + want: `"Sally Goes to the Mall S04E29"`, wantErr: false, }, { name: "test_3", args: args{ - release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"}, - execArgs: `--header "Content-Type: application/json" --request POST --data '{"release":"{{ .TorrentName }}"}' http://localhost:3000/api/release`, - }, - want: []string{ - "--header", - "Content-Type: application/json", - "--request", - "POST", - "--data", - `{"release":"Sally Goes to the Mall S04E29"}`, - "http://localhost:3000/api/release", + release: domain.Release{TorrentName: "Sally Goes to the Mall S04E29"}, + action: &domain.Action{ + ExecArgs: `--header "Content-Type: application/json" --request POST --data '{"release":"{{ .TorrentName }}"}' http://localhost:3000/api/release`, + }, }, + want: `--header "Content-Type: application/json" --request POST --data '{"release":"Sally Goes to the Mall S04E29"}' http://localhost:3000/api/release`, wantErr: false, }, } @@ -69,8 +62,8 @@ func Test_service_parseExecArgs(t *testing.T) { clientSvc: nil, bus: nil, } - got, _ := s.parseExecArgs(tt.args.release, tt.args.execArgs) - assert.Equalf(t, tt.want, got, "parseExecArgs(%v, %v)", tt.args.release, tt.args.execArgs) + _ = s.parseMacros(tt.args.action, tt.args.release) + assert.Equalf(t, tt.want, tt.args.action.ExecArgs, "parseMacros(%v, %v)", tt.args.action, tt.args.release) }) } } @@ -78,7 +71,7 @@ func Test_service_parseExecArgs(t *testing.T) { func Test_service_execCmd(t *testing.T) { type args struct { release domain.Release - action domain.Action + action *domain.Action } tests := []struct { name string @@ -92,7 +85,7 @@ func Test_service_execCmd(t *testing.T) { TorrentTmpFile: "tmp-10000", Indexer: "mock", }, - action: domain.Action{ + action: &domain.Action{ Name: "echo", ExecCmd: "echo", ExecArgs: "hello", @@ -108,7 +101,7 @@ func Test_service_execCmd(t *testing.T) { clientSvc: nil, bus: nil, } - s.execCmd(tt.args.action, tt.args.release) + s.execCmd(nil, tt.args.action, tt.args.release) }) } } diff --git a/internal/action/lidarr.go b/internal/action/lidarr.go index e07c9e3..8cc8595 100644 --- a/internal/action/lidarr.go +++ b/internal/action/lidarr.go @@ -10,13 +10,13 @@ import ( "github.com/autobrr/autobrr/pkg/lidarr" ) -func (s *service) lidarr(action domain.Action, release domain.Release) ([]string, error) { +func (s *service) lidarr(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Trace().Msg("action LIDARR") // TODO validate data // get client for action - client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) + client, err := s.clientSvc.FindByID(ctx, action.ClientID) if err != nil { s.log.Error().Err(err).Msgf("lidarr: error finding client: %v", action.ClientID) return nil, err @@ -59,9 +59,9 @@ func (s *service) lidarr(action domain.Action, release domain.Release) ([]string r.Title = fmt.Sprintf("%v (%d)", release.TorrentName, release.Year) } - rejections, err := arr.Push(r) + rejections, err := arr.Push(ctx, r) if err != nil { - s.log.Error().Stack().Err(err).Msgf("lidarr: failed to push release: %v", r) + s.log.Error().Err(err).Msgf("lidarr: failed to push release: %v", r) return nil, err } diff --git a/internal/action/qbittorrent.go b/internal/action/qbittorrent.go index 0a7ba9e..861175f 100644 --- a/internal/action/qbittorrent.go +++ b/internal/action/qbittorrent.go @@ -9,7 +9,7 @@ import ( "github.com/autobrr/go-qbittorrent" ) -func (s *service) qbittorrent(ctx context.Context, action domain.Action, release domain.Release) ([]string, error) { +func (s *service) qbittorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Debug().Msgf("action qBittorrent: %v", action.Name) c := s.clientSvc.GetCachedClient(ctx, action.ClientID) @@ -29,10 +29,7 @@ func (s *service) qbittorrent(ctx context.Context, action domain.Action, release } } - // macros handle args and replace vars - m := domain.NewMacro(release) - - options, err := s.prepareQbitOptions(action, m) + options, err := s.prepareQbitOptions(action) if err != nil { return nil, errors.Wrap(err, "could not prepare options") } @@ -59,7 +56,7 @@ func (s *service) qbittorrent(ctx context.Context, action domain.Action, release return nil, nil } -func (s *service) prepareQbitOptions(action domain.Action, m domain.Macro) (map[string]string, error) { +func (s *service) prepareQbitOptions(action *domain.Action) (map[string]string, error) { opts := &qbittorrent.TorrentAddOptions{} opts.Paused = false @@ -78,32 +75,14 @@ func (s *service) prepareQbitOptions(action domain.Action, m domain.Macro) (map[ // if ORIGINAL then leave empty } if action.SavePath != "" { - // parse and replace values in argument string before continuing - actionArgs, err := m.Parse(action.SavePath) - if err != nil { - return nil, errors.Wrap(err, "could not parse savepath macro: %v", action.SavePath) - } - - opts.SavePath = actionArgs + opts.SavePath = action.SavePath opts.AutoTMM = false } if action.Category != "" { - // parse and replace values in argument string before continuing - categoryArgs, err := m.Parse(action.Category) - if err != nil { - return nil, errors.Wrap(err, "could not parse category macro: %v", action.Category) - } - - opts.Category = categoryArgs + opts.Category = action.Category } if action.Tags != "" { - // parse and replace values in argument string before continuing - tagsArgs, err := m.Parse(action.Tags) - if err != nil { - return nil, errors.Wrap(err, "could not parse tags macro: %v", action.Tags) - } - - opts.Tags = tagsArgs + opts.Tags = action.Tags } if action.LimitUploadSpeed > 0 { opts.LimitUploadSpeed = action.LimitUploadSpeed @@ -121,7 +100,7 @@ func (s *service) prepareQbitOptions(action domain.Action, m domain.Macro) (map[ return opts.Prepare(), nil } -func (s *service) qbittorrentCheckRulesCanDownload(ctx context.Context, action domain.Action, client *domain.DownloadClient, qbt *qbittorrent.Client) ([]string, error) { +func (s *service) qbittorrentCheckRulesCanDownload(ctx context.Context, action *domain.Action, client *domain.DownloadClient, qbt *qbittorrent.Client) ([]string, error) { s.log.Trace().Msgf("action qBittorrent: %v check rules", action.Name) // check for active downloads and other rules diff --git a/internal/action/radarr.go b/internal/action/radarr.go index b2a38c8..3e92762 100644 --- a/internal/action/radarr.go +++ b/internal/action/radarr.go @@ -9,13 +9,13 @@ import ( "github.com/autobrr/autobrr/pkg/radarr" ) -func (s *service) radarr(action domain.Action, release domain.Release) ([]string, error) { +func (s *service) radarr(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Trace().Msg("action RADARR") // TODO validate data // get client for action - client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) + client, err := s.clientSvc.FindByID(ctx, action.ClientID) if err != nil { return nil, errors.Wrap(err, "error finding client: %v", action.ClientID) } @@ -51,7 +51,7 @@ func (s *service) radarr(action domain.Action, release domain.Release) ([]string PublishDate: time.Now().Format(time.RFC3339), } - rejections, err := arr.Push(r) + rejections, err := arr.Push(ctx, r) if err != nil { return nil, errors.Wrap(err, "radarr failed to push release: %v", r) } diff --git a/internal/action/readarr.go b/internal/action/readarr.go index 2f031ec..a175b82 100644 --- a/internal/action/readarr.go +++ b/internal/action/readarr.go @@ -9,13 +9,13 @@ import ( "github.com/autobrr/autobrr/pkg/readarr" ) -func (s *service) readarr(action domain.Action, release domain.Release) ([]string, error) { +func (s *service) readarr(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Trace().Msg("action READARR") // TODO validate data // get client for action - client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) + client, err := s.clientSvc.FindByID(ctx, action.ClientID) if err != nil { return nil, errors.Wrap(err, "readarr could not find client: %v", action.ClientID) } @@ -51,7 +51,7 @@ func (s *service) readarr(action domain.Action, release domain.Release) ([]strin PublishDate: time.Now().Format(time.RFC3339), } - rejections, err := arr.Push(r) + rejections, err := arr.Push(ctx, r) if err != nil { return nil, errors.Wrap(err, "readarr: failed to push release: %v", r) } diff --git a/internal/action/rtorrent.go b/internal/action/rtorrent.go index 2ce44aa..f8f0798 100644 --- a/internal/action/rtorrent.go +++ b/internal/action/rtorrent.go @@ -9,13 +9,13 @@ import ( "github.com/mrobinsn/go-rtorrent/rtorrent" ) -func (s *service) rtorrent(action domain.Action, release domain.Release) ([]string, error) { +func (s *service) rtorrent(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Debug().Msgf("action rTorrent: %v", action.Name) var err error // get client for action - client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) + client, err := s.clientSvc.FindByID(ctx, action.ClientID) if err != nil { s.log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID) return nil, err @@ -28,7 +28,7 @@ func (s *service) rtorrent(action domain.Action, release domain.Release) ([]stri var rejections []string if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFile(); err != nil { + if err := release.DownloadTorrentFileCtx(ctx); err != nil { s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName) return nil, err } diff --git a/internal/action/run.go b/internal/action/run.go index 66bef17..b8412b3 100644 --- a/internal/action/run.go +++ b/internal/action/run.go @@ -30,45 +30,50 @@ func (s *service) RunAction(ctx context.Context, action *domain.Action, release } }() + // parse all macros in one go + if err := action.ParseMacros(release); err != nil { + return nil, err + } + switch action.Type { case domain.ActionTypeTest: s.test(action.Name) case domain.ActionTypeExec: - err = s.execCmd(*action, release) + err = s.execCmd(ctx, action, release) case domain.ActionTypeWatchFolder: - err = s.watchFolder(*action, release) + err = s.watchFolder(ctx, action, release) case domain.ActionTypeWebhook: - err = s.webhook(ctx, *action, release) + err = s.webhook(ctx, action, release) case domain.ActionTypeDelugeV1, domain.ActionTypeDelugeV2: - rejections, err = s.deluge(*action, release) + rejections, err = s.deluge(ctx, action, release) case domain.ActionTypeQbittorrent: - rejections, err = s.qbittorrent(ctx, *action, release) + rejections, err = s.qbittorrent(ctx, action, release) case domain.ActionTypeRTorrent: - rejections, err = s.rtorrent(*action, release) + rejections, err = s.rtorrent(ctx, action, release) case domain.ActionTypeTransmission: - rejections, err = s.transmission(*action, release) + rejections, err = s.transmission(ctx, action, release) case domain.ActionTypeRadarr: - rejections, err = s.radarr(*action, release) + rejections, err = s.radarr(ctx, action, release) case domain.ActionTypeSonarr: - rejections, err = s.sonarr(*action, release) + rejections, err = s.sonarr(ctx, action, release) case domain.ActionTypeLidarr: - rejections, err = s.lidarr(*action, release) + rejections, err = s.lidarr(ctx, action, release) case domain.ActionTypeWhisparr: - rejections, err = s.whisparr(*action, release) + rejections, err = s.whisparr(ctx, action, release) case domain.ActionTypeReadarr: - rejections, err = s.readarr(*action, release) + rejections, err = s.readarr(ctx, action, release) default: s.log.Warn().Msgf("unsupported action type: %v", action.Type) @@ -137,9 +142,9 @@ func (s *service) test(name string) { s.log.Info().Msgf("action TEST: %v", name) } -func (s *service) watchFolder(action domain.Action, release domain.Release) error { +func (s *service) watchFolder(ctx context.Context, action *domain.Action, release domain.Release) error { if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFile(); err != nil { + if err := release.DownloadTorrentFileCtx(ctx); err != nil { return errors.Wrap(err, "watch folder: could not download torrent file for release: %v", release.TorrentName) } } @@ -153,19 +158,7 @@ func (s *service) watchFolder(action domain.Action, release domain.Release) erro release.TorrentDataRawBytes = t } - m := domain.NewMacro(release) - - // parse and replace values in argument string before continuing - // /mnt/watch/{{.Indexer}} - // /mnt/watch/mock - // /mnt/watch/{{.Indexer}}-{{.TorrentName}}.torrent - // /mnt/watch/mock-Torrent.Name-GROUP.torrent - watchFolderArgs, err := m.Parse(action.WatchFolder) - if err != nil { - return errors.Wrap(err, "could not parse watch folder macro: %v", action.WatchFolder) - } - - s.log.Trace().Msgf("action WATCH_FOLDER: %v file: %v", watchFolderArgs, release.TorrentTmpFile) + s.log.Trace().Msgf("action WATCH_FOLDER: %v file: %v", action.WatchFolder, release.TorrentTmpFile) // Open original file original, err := os.Open(release.TorrentTmpFile) @@ -175,16 +168,20 @@ func (s *service) watchFolder(action domain.Action, release domain.Release) erro defer original.Close() // default dir to watch folder - dir := watchFolderArgs - newFileName := watchFolderArgs + // /mnt/watch/{{.Indexer}} + // /mnt/watch/mock + // /mnt/watch/{{.Indexer}}-{{.TorrentName}}.torrent + // /mnt/watch/mock-Torrent.Name-GROUP.torrent + dir := action.WatchFolder + newFileName := action.WatchFolder // if watchFolderArgs does not contain .torrent, create - if !strings.HasSuffix(watchFolderArgs, ".torrent") { + if !strings.HasSuffix(action.WatchFolder, ".torrent") { _, tmpFileName := filepath.Split(release.TorrentTmpFile) - newFileName = filepath.Join(watchFolderArgs, tmpFileName+".torrent") + newFileName = filepath.Join(action.WatchFolder, tmpFileName+".torrent") } else { - dir, _ = filepath.Split(watchFolderArgs) + dir, _ = filepath.Split(action.WatchFolder) } // Create folder @@ -209,7 +206,7 @@ func (s *service) watchFolder(action domain.Action, release domain.Release) erro return nil } -func (s *service) webhook(ctx context.Context, action domain.Action, release domain.Release) error { +func (s *service) webhook(ctx context.Context, action *domain.Action, release domain.Release) error { // if webhook data contains TorrentPathName or TorrentDataRawBytes, lets download the torrent file if release.TorrentTmpFile == "" && (strings.Contains(action.WebhookData, "TorrentPathName") || strings.Contains(action.WebhookData, "TorrentDataRawBytes")) { if err := release.DownloadTorrentFileCtx(ctx); err != nil { @@ -227,14 +224,6 @@ func (s *service) webhook(ctx context.Context, action domain.Action, release dom release.TorrentDataRawBytes = t } - m := domain.NewMacro(release) - - // parse and replace values in argument string before continuing - dataArgs, err := m.Parse(action.WebhookData) - if err != nil { - return errors.Wrap(err, "could not parse webhook data macro: %v", action.WebhookData) - } - s.log.Trace().Msgf("action WEBHOOK: '%v' file: %v", action.Name, release.TorrentName) s.log.Trace().Msgf("webhook action '%v' - host: %v data: %v", action.Name, action.WebhookHost, action.WebhookData) @@ -246,7 +235,7 @@ func (s *service) webhook(ctx context.Context, action domain.Action, release dom client := http.Client{Transport: t, Timeout: 15 * time.Second} - req, err := http.NewRequestWithContext(ctx, http.MethodPost, action.WebhookHost, bytes.NewBufferString(dataArgs)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, action.WebhookHost, bytes.NewBufferString(action.WebhookData)) if err != nil { return errors.Wrap(err, "could not build request for webhook") } @@ -261,7 +250,11 @@ func (s *service) webhook(ctx context.Context, action domain.Action, release dom defer res.Body.Close() - s.log.Info().Msgf("successfully ran webhook action: '%v' to: %v payload: %v", action.Name, action.WebhookHost, dataArgs) + s.log.Info().Msgf("successfully ran webhook action: '%v' to: %v payload: %v", action.Name, action.WebhookHost, action.WebhookData) return nil } +func (s *service) parseMacros(action *domain.Action, release domain.Release) error { + // parse all macros in one go + return action.ParseMacros(release) +} diff --git a/internal/action/sonarr.go b/internal/action/sonarr.go index 398fb15..8b9b431 100644 --- a/internal/action/sonarr.go +++ b/internal/action/sonarr.go @@ -9,13 +9,13 @@ import ( "github.com/autobrr/autobrr/pkg/sonarr" ) -func (s *service) sonarr(action domain.Action, release domain.Release) ([]string, error) { +func (s *service) sonarr(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Trace().Msg("action SONARR") // TODO validate data // get client for action - client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) + client, err := s.clientSvc.FindByID(ctx, action.ClientID) if err != nil { return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID) } @@ -51,7 +51,7 @@ func (s *service) sonarr(action domain.Action, release domain.Release) ([]string PublishDate: time.Now().Format(time.RFC3339), } - rejections, err := arr.Push(r) + rejections, err := arr.Push(ctx, r) if err != nil { return nil, errors.Wrap(err, "sonarr: failed to push release: %v", r) } diff --git a/internal/action/transmission.go b/internal/action/transmission.go index fb6b3d5..ba83dd8 100644 --- a/internal/action/transmission.go +++ b/internal/action/transmission.go @@ -9,13 +9,13 @@ import ( "github.com/hekmon/transmissionrpc/v2" ) -func (s *service) transmission(action domain.Action, release domain.Release) ([]string, error) { +func (s *service) transmission(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Debug().Msgf("action Transmission: %v", action.Name) var err error // get client for action - client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) + client, err := s.clientSvc.FindByID(ctx, action.ClientID) if err != nil { s.log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID) return nil, err @@ -28,7 +28,7 @@ func (s *service) transmission(action domain.Action, release domain.Release) ([] var rejections []string if release.TorrentTmpFile == "" { - if err := release.DownloadTorrentFile(); err != nil { + if err := release.DownloadTorrentFileCtx(ctx); err != nil { s.log.Error().Err(err).Msgf("could not download torrent file for release: %v", release.TorrentName) return nil, err } @@ -58,7 +58,7 @@ func (s *service) transmission(action domain.Action, release domain.Release) ([] } // Prepare and send payload - torrent, err := tbt.TorrentAdd(context.TODO(), payload) + torrent, err := tbt.TorrentAdd(ctx, payload) if err != nil { return nil, errors.Wrap(err, "could not add torrent %v to client: %v", release.TorrentTmpFile, client.Host) } diff --git a/internal/action/whisparr.go b/internal/action/whisparr.go index 3b5232d..5ab05bb 100644 --- a/internal/action/whisparr.go +++ b/internal/action/whisparr.go @@ -9,13 +9,13 @@ import ( "github.com/autobrr/autobrr/pkg/whisparr" ) -func (s *service) whisparr(action domain.Action, release domain.Release) ([]string, error) { +func (s *service) whisparr(ctx context.Context, action *domain.Action, release domain.Release) ([]string, error) { s.log.Trace().Msg("action WHISPARR") // TODO validate data // get client for action - client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID) + client, err := s.clientSvc.FindByID(ctx, action.ClientID) if err != nil { return nil, errors.Wrap(err, "sonarr could not find client: %v", action.ClientID) } @@ -51,7 +51,7 @@ func (s *service) whisparr(action domain.Action, release domain.Release) ([]stri PublishDate: time.Now().Format(time.RFC3339), } - rejections, err := arr.Push(r) + rejections, err := arr.Push(ctx, r) if err != nil { return nil, errors.Wrap(err, "whisparr: failed to push release: %v", r) } diff --git a/internal/domain/action.go b/internal/domain/action.go index 4991877..dde97ba 100644 --- a/internal/domain/action.go +++ b/internal/domain/action.go @@ -1,6 +1,9 @@ package domain -import "context" +import ( + "context" + "github.com/autobrr/autobrr/pkg/errors" +) type ActionRepo interface { Store(ctx context.Context, action Action) (*Action, error) @@ -46,6 +49,27 @@ type Action struct { Client DownloadClient `json:"client,omitempty"` } +// ParseMacros parse all macros on action +func (a *Action) ParseMacros(release Release) error { + var err error + + m := NewMacro(release) + + a.ExecArgs, err = m.Parse(a.ExecArgs) + a.WatchFolder, err = m.Parse(a.WatchFolder) + a.Category, err = m.Parse(a.Category) + a.Tags, err = m.Parse(a.Tags) + a.Label, err = m.Parse(a.Label) + a.SavePath, err = m.Parse(a.SavePath) + a.WebhookData, err = m.Parse(a.WebhookData) + + if err != nil { + return errors.Wrap(err, "could not parse macros for action: %v", a.Name) + } + + return nil +} + type ActionType string const ( diff --git a/internal/domain/client.go b/internal/domain/client.go index 9c6fa24..61650ed 100644 --- a/internal/domain/client.go +++ b/internal/domain/client.go @@ -5,6 +5,8 @@ import ( "fmt" "net/url" + "github.com/autobrr/autobrr/pkg/errors" + "github.com/autobrr/go-qbittorrent" ) @@ -69,6 +71,18 @@ const ( DownloadClientTypeReadarr DownloadClientType = "READARR" ) +// Validate basic validation of client +func (c DownloadClient) Validate() error { + // basic validation of client + if c.Host == "" { + return errors.New("validation error: missing host") + } else if c.Type == "" { + return errors.New("validation error: missing type") + } + + return nil +} + func (c DownloadClient) BuildLegacyHost() string { if c.Type == DownloadClientTypeQbittorrent { return c.qbitBuildLegacyHost() diff --git a/internal/domain/macros.go b/internal/domain/macros.go index 260c38a..98fcc4c 100644 --- a/internal/domain/macros.go +++ b/internal/domain/macros.go @@ -10,54 +10,54 @@ import ( ) type Macro struct { - TorrentName string - TorrentPathName string - TorrentHash string - TorrentUrl string - TorrentDataRawBytes []byte - Indexer string - Title string - Resolution string - Source string - HDR string - FilterName string - Size uint64 - Season int - Episode int - Year int - CurrentYear int - CurrentMonth int - CurrentDay int - CurrentHour int - CurrentMinute int - CurrentSecond int + TorrentName string + TorrentPathName string + TorrentHash string + TorrentUrl string + TorrentDataRawBytes []byte + Indexer string + Title string + Resolution string + Source string + HDR string + FilterName string + Size uint64 + Season int + Episode int + Year int + CurrentYear int + CurrentMonth int + CurrentDay int + CurrentHour int + CurrentMinute int + CurrentSecond int } func NewMacro(release Release) Macro { currentTime := time.Now() ma := Macro{ - TorrentName: release.TorrentName, - TorrentUrl: release.TorrentURL, - TorrentPathName: release.TorrentTmpFile, + TorrentName: release.TorrentName, + TorrentUrl: release.TorrentURL, + TorrentPathName: release.TorrentTmpFile, TorrentDataRawBytes: release.TorrentDataRawBytes, - TorrentHash: release.TorrentHash, - Indexer: release.Indexer, - Title: release.Title, - Resolution: release.Resolution, - Source: release.Source, - HDR: strings.Join(release.HDR, ", "), - FilterName: release.FilterName, - Size: release.Size, - Season: release.Season, - Episode: release.Episode, - Year: release.Year, - CurrentYear: currentTime.Year(), - CurrentMonth: int(currentTime.Month()), - CurrentDay: currentTime.Day(), - CurrentHour: currentTime.Hour(), - CurrentMinute: currentTime.Minute(), - CurrentSecond: currentTime.Second(), + TorrentHash: release.TorrentHash, + Indexer: release.Indexer, + Title: release.Title, + Resolution: release.Resolution, + Source: release.Source, + HDR: strings.Join(release.HDR, ", "), + FilterName: release.FilterName, + Size: release.Size, + Season: release.Season, + Episode: release.Episode, + Year: release.Year, + CurrentYear: currentTime.Year(), + CurrentMonth: int(currentTime.Month()), + CurrentDay: currentTime.Day(), + CurrentHour: currentTime.Hour(), + CurrentMinute: currentTime.Minute(), + CurrentSecond: currentTime.Second(), } return ma @@ -65,6 +65,9 @@ func NewMacro(release Release) Macro { // Parse takes a string and replaces valid vars func (m Macro) Parse(text string) (string, error) { + if text == "" { + return "", nil + } // setup template tmpl, err := template.New("macro").Parse(text) @@ -80,3 +83,24 @@ func (m Macro) Parse(text string) (string, error) { return tpl.String(), nil } + +// MustParse takes a string and replaces valid vars +func (m Macro) MustParse(text string) string { + if text == "" { + return "" + } + + // setup template + tmpl, err := template.New("macro").Parse(text) + if err != nil { + return "" + } + + var tpl bytes.Buffer + err = tmpl.Execute(&tpl, m) + if err != nil { + return "" + } + + return tpl.String() +} diff --git a/internal/download_client/connection.go b/internal/download_client/connection.go index 7f6ea31..d9b9928 100644 --- a/internal/download_client/connection.go +++ b/internal/download_client/connection.go @@ -30,22 +30,23 @@ func (s *service) testConnection(ctx context.Context, client domain.DownloadClie return s.testRTorrentConnection(client) case domain.DownloadClientTypeTransmission: - return s.testTransmissionConnection(client) + return s.testTransmissionConnection(ctx, client) case domain.DownloadClientTypeRadarr: - return s.testRadarrConnection(client) + return s.testRadarrConnection(ctx, client) case domain.DownloadClientTypeSonarr: - return s.testSonarrConnection(client) + return s.testSonarrConnection(ctx, client) case domain.DownloadClientTypeLidarr: - return s.testLidarrConnection(client) + return s.testLidarrConnection(ctx, client) case domain.DownloadClientTypeWhisparr: - return s.testWhisparrConnection(client) + return s.testWhisparrConnection(ctx, client) case domain.DownloadClientTypeReadarr: - return s.testReadarrConnection(client) + return s.testReadarrConnection(ctx, client) + default: return errors.New("unsupported client") } @@ -138,7 +139,7 @@ func (s *service) testRTorrentConnection(client domain.DownloadClient) error { return nil } -func (s *service) testTransmissionConnection(client domain.DownloadClient) error { +func (s *service) testTransmissionConnection(ctx context.Context, client domain.DownloadClient) error { tbt, err := transmissionrpc.New(client.Host, client.Username, client.Password, &transmissionrpc.AdvancedConfig{ HTTPS: client.TLS, Port: uint16(client.Port), @@ -147,7 +148,7 @@ func (s *service) testTransmissionConnection(client domain.DownloadClient) error return errors.Wrap(err, "error logging into client: %v", client.Host) } - ok, version, _, err := tbt.RPCVersion(context.TODO()) + ok, version, _, err := tbt.RPCVersion(ctx) if err != nil { return errors.Wrap(err, "error getting rpc info: %v", client.Host) } @@ -163,7 +164,7 @@ func (s *service) testTransmissionConnection(client domain.DownloadClient) error return nil } -func (s *service) testRadarrConnection(client domain.DownloadClient) error { +func (s *service) testRadarrConnection(ctx context.Context, client domain.DownloadClient) error { r := radarr.New(radarr.Config{ Hostname: client.Host, APIKey: client.Settings.APIKey, @@ -173,8 +174,7 @@ func (s *service) testRadarrConnection(client domain.DownloadClient) error { Log: s.subLogger, }) - _, err := r.Test() - if err != nil { + if _, err := r.Test(ctx); err != nil { return errors.Wrap(err, "radarr: connection test failed: %v", client.Host) } @@ -183,7 +183,7 @@ func (s *service) testRadarrConnection(client domain.DownloadClient) error { return nil } -func (s *service) testSonarrConnection(client domain.DownloadClient) error { +func (s *service) testSonarrConnection(ctx context.Context, client domain.DownloadClient) error { r := sonarr.New(sonarr.Config{ Hostname: client.Host, APIKey: client.Settings.APIKey, @@ -193,8 +193,7 @@ func (s *service) testSonarrConnection(client domain.DownloadClient) error { Log: s.subLogger, }) - _, err := r.Test() - if err != nil { + if _, err := r.Test(ctx); err != nil { return errors.Wrap(err, "sonarr: connection test failed: %v", client.Host) } @@ -203,7 +202,7 @@ func (s *service) testSonarrConnection(client domain.DownloadClient) error { return nil } -func (s *service) testLidarrConnection(client domain.DownloadClient) error { +func (s *service) testLidarrConnection(ctx context.Context, client domain.DownloadClient) error { r := lidarr.New(lidarr.Config{ Hostname: client.Host, APIKey: client.Settings.APIKey, @@ -213,8 +212,7 @@ func (s *service) testLidarrConnection(client domain.DownloadClient) error { Log: s.subLogger, }) - _, err := r.Test() - if err != nil { + if _, err := r.Test(ctx); err != nil { return errors.Wrap(err, "lidarr: connection test failed: %v", client.Host) } @@ -223,7 +221,7 @@ func (s *service) testLidarrConnection(client domain.DownloadClient) error { return nil } -func (s *service) testWhisparrConnection(client domain.DownloadClient) error { +func (s *service) testWhisparrConnection(ctx context.Context, client domain.DownloadClient) error { r := whisparr.New(whisparr.Config{ Hostname: client.Host, APIKey: client.Settings.APIKey, @@ -233,8 +231,7 @@ func (s *service) testWhisparrConnection(client domain.DownloadClient) error { Log: s.subLogger, }) - _, err := r.Test() - if err != nil { + if _, err := r.Test(ctx); err != nil { return errors.Wrap(err, "whisparr: connection test failed: %v", client.Host) } @@ -243,7 +240,7 @@ func (s *service) testWhisparrConnection(client domain.DownloadClient) error { return nil } -func (s *service) testReadarrConnection(client domain.DownloadClient) error { +func (s *service) testReadarrConnection(ctx context.Context, client domain.DownloadClient) error { r := readarr.New(readarr.Config{ Hostname: client.Host, APIKey: client.Settings.APIKey, @@ -253,8 +250,7 @@ func (s *service) testReadarrConnection(client domain.DownloadClient) error { Log: s.subLogger, }) - _, err := r.Test() - if err != nil { + if _, err := r.Test(ctx); err != nil { return errors.Wrap(err, "readarr: connection test failed: %v", client.Host) } diff --git a/internal/download_client/service.go b/internal/download_client/service.go index e9c7ed3..a0b242a 100644 --- a/internal/download_client/service.go +++ b/internal/download_client/service.go @@ -2,7 +2,6 @@ package download_client import ( "context" - "errors" "log" "sync" @@ -69,11 +68,9 @@ func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClien } func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { - // validate data - if client.Host == "" { - return nil, errors.New("validation error: no host") - } else if client.Type == "" { - return nil, errors.New("validation error: no type") + // basic validation of client + if err := client.Validate(); err != nil { + return nil, err } // store @@ -87,11 +84,9 @@ func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*dom } func (s *service) Update(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { - // validate data - if client.Host == "" { - return nil, errors.New("validation error: no host") - } else if client.Type == "" { - return nil, errors.New("validation error: no type") + // basic validation of client + if err := client.Validate(); err != nil { + return nil, err } // update @@ -125,10 +120,8 @@ func (s *service) Delete(ctx context.Context, clientID int) error { func (s *service) Test(ctx context.Context, client domain.DownloadClient) error { // basic validation of client - if client.Host == "" { - return errors.New("validation error: no host") - } else if client.Type == "" { - return errors.New("validation error: no type") + if err := client.Validate(); err != nil { + return err } // test diff --git a/pkg/lidarr/client.go b/pkg/lidarr/client.go index 7e5b98c..cd2b6d8 100644 --- a/pkg/lidarr/client.go +++ b/pkg/lidarr/client.go @@ -2,6 +2,7 @@ package lidarr import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -11,12 +12,12 @@ import ( "github.com/autobrr/autobrr/pkg/errors" ) -func (c *client) get(endpoint string) (int, []byte, error) { +func (c *client) get(ctx context.Context, endpoint string) (int, []byte, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v1/", endpoint) reqUrl := u.String() - req, err := http.NewRequest(http.MethodGet, reqUrl, http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, http.NoBody) if err != nil { return 0, nil, errors.Wrap(err, "lidarr client request error : %v", reqUrl) } @@ -42,7 +43,7 @@ func (c *client) get(endpoint string) (int, []byte, error) { return resp.StatusCode, buf.Bytes(), nil } -func (c *client) post(endpoint string, data interface{}) (*http.Response, error) { +func (c *client) post(ctx context.Context, endpoint string, data interface{}) (*http.Response, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v1/", endpoint) reqUrl := u.String() @@ -52,7 +53,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error) return nil, errors.Wrap(err, "lidarr client could not marshal data: %v", reqUrl) } - req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) if err != nil { return nil, errors.Wrap(err, "lidarr client request error: %v", reqUrl) } @@ -81,7 +82,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error) return res, nil } -func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error) { +func (c *client) postBody(ctx context.Context, endpoint string, data interface{}) (int, []byte, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v1/", endpoint) reqUrl := u.String() @@ -91,7 +92,7 @@ func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error return 0, nil, errors.Wrap(err, "lidarr client could not marshal data: %v", reqUrl) } - req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) if err != nil { return 0, nil, errors.Wrap(err, "lidarr client request error: %v", reqUrl) } diff --git a/pkg/lidarr/lidarr.go b/pkg/lidarr/lidarr.go index 9297064..3e9ddd1 100644 --- a/pkg/lidarr/lidarr.go +++ b/pkg/lidarr/lidarr.go @@ -1,6 +1,7 @@ package lidarr import ( + "context" "encoding/json" "fmt" "io" @@ -25,8 +26,8 @@ type Config struct { } type Client interface { - Test() (*SystemStatusResponse, error) - Push(release Release) ([]string, error) + Test(ctx context.Context) (*SystemStatusResponse, error) + Push(ctx context.Context, release Release) ([]string, error) } type client struct { @@ -89,8 +90,8 @@ type SystemStatusResponse struct { Version string `json:"version"` } -func (c *client) Test() (*SystemStatusResponse, error) { - status, res, err := c.get("system/status") +func (c *client) Test(ctx context.Context) (*SystemStatusResponse, error) { + status, res, err := c.get(ctx, "system/status") if err != nil { return nil, errors.Wrap(err, "lidarr client get error") } @@ -110,8 +111,8 @@ func (c *client) Test() (*SystemStatusResponse, error) { return &response, nil } -func (c *client) Push(release Release) ([]string, error) { - status, res, err := c.postBody("release/push", release) +func (c *client) Push(ctx context.Context, release Release) ([]string, error) { + status, res, err := c.postBody(ctx, "release/push", release) if err != nil { return nil, errors.Wrap(err, "lidarr client post error") } diff --git a/pkg/lidarr/lidarr_test.go b/pkg/lidarr/lidarr_test.go index 840b06d..7eb65c4 100644 --- a/pkg/lidarr/lidarr_test.go +++ b/pkg/lidarr/lidarr_test.go @@ -1,6 +1,7 @@ package lidarr import ( + "context" "net/http" "net/http/httptest" "os" @@ -101,7 +102,7 @@ func Test_client_Push(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := New(tt.fields.config) - rejections, err := c.Push(tt.args.release) + rejections, err := c.Push(context.Background(), tt.args.release) assert.Equal(t, tt.rejections, rejections) if tt.wantErr && assert.Error(t, err) { assert.Equal(t, tt.err, err) @@ -170,7 +171,7 @@ func Test_client_Test(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := New(tt.cfg) - got, err := c.Test() + got, err := c.Test(context.Background()) if tt.wantErr && assert.Error(t, err) { assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err) } diff --git a/pkg/radarr/client.go b/pkg/radarr/client.go index cdc503c..d3abde1 100644 --- a/pkg/radarr/client.go +++ b/pkg/radarr/client.go @@ -2,6 +2,7 @@ package radarr import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -11,12 +12,12 @@ import ( "github.com/autobrr/autobrr/pkg/errors" ) -func (c *client) get(endpoint string) (int, []byte, error) { +func (c *client) get(ctx context.Context, endpoint string) (int, []byte, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v3/", endpoint) reqUrl := u.String() - req, err := http.NewRequest(http.MethodGet, reqUrl, http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, http.NoBody) if err != nil { return 0, nil, errors.Wrap(err, "could not build request: %v", reqUrl) } @@ -42,7 +43,7 @@ func (c *client) get(endpoint string) (int, []byte, error) { return resp.StatusCode, buf.Bytes(), nil } -func (c *client) post(endpoint string, data interface{}) (*http.Response, error) { +func (c *client) post(ctx context.Context, endpoint string, data interface{}) (*http.Response, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v3/", endpoint) reqUrl := u.String() @@ -52,7 +53,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error) return nil, errors.Wrap(err, "could not marshal data: %+v", data) } - req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) if err != nil { return nil, errors.Wrap(err, "could not build request: %v", reqUrl) } @@ -83,7 +84,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error) return res, nil } -func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error) { +func (c *client) postBody(ctx context.Context, endpoint string, data interface{}) (int, []byte, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v3/", endpoint) reqUrl := u.String() @@ -93,7 +94,7 @@ func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error return 0, nil, errors.Wrap(err, "could not marshal data: %+v", data) } - req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) if err != nil { return 0, nil, errors.Wrap(err, "could not build request: %v", reqUrl) } diff --git a/pkg/radarr/radarr.go b/pkg/radarr/radarr.go index fd06f28..4841479 100644 --- a/pkg/radarr/radarr.go +++ b/pkg/radarr/radarr.go @@ -1,6 +1,7 @@ package radarr import ( + "context" "encoding/json" "fmt" "io" @@ -25,8 +26,8 @@ type Config struct { } type Client interface { - Test() (*SystemStatusResponse, error) - Push(release Release) ([]string, error) + Test(ctx context.Context) (*SystemStatusResponse, error) + Push(ctx context.Context, release Release) ([]string, error) } type client struct { @@ -88,8 +89,8 @@ func (r *BadRequestResponse) String() string { return fmt.Sprintf("[%v: %v] %v: %v - got value: %v", r.Severity, r.ErrorCode, r.PropertyName, r.ErrorMessage, r.AttemptedValue) } -func (c *client) Test() (*SystemStatusResponse, error) { - status, res, err := c.get("system/status") +func (c *client) Test(ctx context.Context) (*SystemStatusResponse, error) { + status, res, err := c.get(ctx, "system/status") if err != nil { return nil, errors.Wrap(err, "radarr error running test") } @@ -108,8 +109,8 @@ func (c *client) Test() (*SystemStatusResponse, error) { return &response, nil } -func (c *client) Push(release Release) ([]string, error) { - status, res, err := c.postBody("release/push", release) +func (c *client) Push(ctx context.Context, release Release) ([]string, error) { + status, res, err := c.postBody(ctx, "release/push", release) if err != nil { return nil, errors.Wrap(err, "error push release") } diff --git a/pkg/radarr/radarr_test.go b/pkg/radarr/radarr_test.go index 346ee10..d59e1bc 100644 --- a/pkg/radarr/radarr_test.go +++ b/pkg/radarr/radarr_test.go @@ -1,6 +1,7 @@ package radarr import ( + "context" "io" "net/http" "net/http/httptest" @@ -141,7 +142,7 @@ func Test_client_Push(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := New(tt.fields.config) - rejections, err := c.Push(tt.args.release) + rejections, err := c.Push(context.Background(), tt.args.release) assert.Equal(t, tt.rejections, rejections) if tt.wantErr && assert.Error(t, err) { assert.Equal(t, tt.err, err) @@ -223,7 +224,7 @@ func Test_client_Test(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := New(tt.cfg) - got, err := c.Test() + got, err := c.Test(context.Background()) if tt.wantErr && assert.Error(t, err) { assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err) } diff --git a/pkg/readarr/client.go b/pkg/readarr/client.go index e0b1045..832a2fc 100644 --- a/pkg/readarr/client.go +++ b/pkg/readarr/client.go @@ -2,6 +2,7 @@ package readarr import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -11,12 +12,12 @@ import ( "github.com/autobrr/autobrr/pkg/errors" ) -func (c *client) get(endpoint string) (int, []byte, error) { +func (c *client) get(ctx context.Context, endpoint string) (int, []byte, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v1/", endpoint) reqUrl := u.String() - req, err := http.NewRequest(http.MethodGet, reqUrl, http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, http.NoBody) if err != nil { return 0, nil, errors.Wrap(err, "could not build request") } @@ -42,7 +43,7 @@ func (c *client) get(endpoint string) (int, []byte, error) { return resp.StatusCode, buf.Bytes(), nil } -func (c *client) post(endpoint string, data interface{}) (*http.Response, error) { +func (c *client) post(ctx context.Context, endpoint string, data interface{}) (*http.Response, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v1/", endpoint) reqUrl := u.String() @@ -52,7 +53,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error) return nil, errors.Wrap(err, "could not marshal data: %+v", data) } - req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) if err != nil { return nil, errors.Wrap(err, "could not build request") } @@ -81,7 +82,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error) return res, nil } -func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error) { +func (c *client) postBody(ctx context.Context, endpoint string, data interface{}) (int, []byte, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v1/", endpoint) reqUrl := u.String() @@ -93,7 +94,7 @@ func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error c.Log.Printf("readarr push JSON: %s\n", string(jsonData)) - req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) if err != nil { return 0, nil, errors.Wrap(err, "could not build request") } diff --git a/pkg/readarr/readarr.go b/pkg/readarr/readarr.go index 65f3736..75697fa 100644 --- a/pkg/readarr/readarr.go +++ b/pkg/readarr/readarr.go @@ -1,6 +1,7 @@ package readarr import ( + "context" "encoding/json" "fmt" "io" @@ -26,8 +27,8 @@ type Config struct { } type Client interface { - Test() (*SystemStatusResponse, error) - Push(release Release) ([]string, error) + Test(ctx context.Context) (*SystemStatusResponse, error) + Push(ctx context.Context, release Release) ([]string, error) } type client struct { @@ -92,8 +93,8 @@ type SystemStatusResponse struct { Version string `json:"version"` } -func (c *client) Test() (*SystemStatusResponse, error) { - status, res, err := c.get("system/status") +func (c *client) Test(ctx context.Context) (*SystemStatusResponse, error) { + status, res, err := c.get(ctx, "system/status") if err != nil { return nil, errors.Wrap(err, "could not make Test") } @@ -112,8 +113,8 @@ func (c *client) Test() (*SystemStatusResponse, error) { return &response, nil } -func (c *client) Push(release Release) ([]string, error) { - status, res, err := c.postBody("release/push", release) +func (c *client) Push(ctx context.Context, release Release) ([]string, error) { + status, res, err := c.postBody(ctx, "release/push", release) if err != nil { return nil, errors.Wrap(err, "could not push release to readarr") } diff --git a/pkg/readarr/readarr_test.go b/pkg/readarr/readarr_test.go index 9897b76..1d8cc01 100644 --- a/pkg/readarr/readarr_test.go +++ b/pkg/readarr/readarr_test.go @@ -1,6 +1,7 @@ package readarr import ( + "context" "net/http" "net/http/httptest" "os" @@ -78,7 +79,7 @@ func Test_client_Push(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := New(tt.fields.config) - rejections, err := c.Push(tt.args.release) + rejections, err := c.Push(context.Background(), tt.args.release) assert.Equal(t, tt.rejections, rejections) if tt.wantErr && assert.Error(t, err) { assert.Equal(t, tt.err, err) @@ -147,7 +148,7 @@ func Test_client_Test(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := New(tt.cfg) - got, err := c.Test() + got, err := c.Test(context.Background()) if tt.wantErr && assert.Error(t, err) { assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err) } diff --git a/pkg/sonarr/client.go b/pkg/sonarr/client.go index 9a3151b..53a4b10 100644 --- a/pkg/sonarr/client.go +++ b/pkg/sonarr/client.go @@ -2,6 +2,7 @@ package sonarr import ( "bytes" + "context" "encoding/json" "io" "net/http" @@ -11,12 +12,12 @@ import ( "github.com/autobrr/autobrr/pkg/errors" ) -func (c *client) get(endpoint string) (int, []byte, error) { +func (c *client) get(ctx context.Context, endpoint string) (int, []byte, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v3/", endpoint) reqUrl := u.String() - req, err := http.NewRequest(http.MethodGet, reqUrl, http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, http.NoBody) if err != nil { return 0, nil, errors.Wrap(err, "could not build request") } @@ -42,7 +43,7 @@ func (c *client) get(endpoint string) (int, []byte, error) { return resp.StatusCode, buf.Bytes(), nil } -func (c *client) post(endpoint string, data interface{}) (*http.Response, error) { +func (c *client) post(ctx context.Context, endpoint string, data interface{}) (*http.Response, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v3/", endpoint) reqUrl := u.String() @@ -52,7 +53,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error) return nil, errors.Wrap(err, "could not marshal data: %+v", data) } - req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) if err != nil { return nil, errors.Wrap(err, "could not build request") } @@ -81,7 +82,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error) return res, nil } -func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error) { +func (c *client) postBody(ctx context.Context, endpoint string, data interface{}) (int, []byte, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v3/", endpoint) reqUrl := u.String() @@ -91,7 +92,7 @@ func (c *client) postBody(endpoint string, data interface{}) (int, []byte, error return 0, nil, errors.Wrap(err, "could not marshal data: %+v", data) } - req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) if err != nil { return 0, nil, errors.Wrap(err, "could not build request") } diff --git a/pkg/sonarr/sonarr.go b/pkg/sonarr/sonarr.go index d3bb9ef..de29007 100644 --- a/pkg/sonarr/sonarr.go +++ b/pkg/sonarr/sonarr.go @@ -1,6 +1,7 @@ package sonarr import ( + "context" "encoding/json" "fmt" "io" @@ -26,8 +27,8 @@ type Config struct { } type Client interface { - Test() (*SystemStatusResponse, error) - Push(release Release) ([]string, error) + Test(ctx context.Context) (*SystemStatusResponse, error) + Push(ctx context.Context, release Release) ([]string, error) } type client struct { @@ -91,8 +92,8 @@ type SystemStatusResponse struct { Version string `json:"version"` } -func (c *client) Test() (*SystemStatusResponse, error) { - status, res, err := c.get("system/status") +func (c *client) Test(ctx context.Context) (*SystemStatusResponse, error) { + status, res, err := c.get(ctx, "system/status") if err != nil { return nil, errors.Wrap(err, "could not make Test") } @@ -111,8 +112,8 @@ func (c *client) Test() (*SystemStatusResponse, error) { return &response, nil } -func (c *client) Push(release Release) ([]string, error) { - status, res, err := c.postBody("release/push", release) +func (c *client) Push(ctx context.Context, release Release) ([]string, error) { + status, res, err := c.postBody(ctx, "release/push", release) if err != nil { return nil, errors.Wrap(err, "could not push release to sonarr") } diff --git a/pkg/sonarr/sonarr_test.go b/pkg/sonarr/sonarr_test.go index 5c24d4d..7811bfd 100644 --- a/pkg/sonarr/sonarr_test.go +++ b/pkg/sonarr/sonarr_test.go @@ -1,6 +1,7 @@ package sonarr import ( + "context" "io/ioutil" "log" "net/http" @@ -109,7 +110,7 @@ func Test_client_Push(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := New(tt.fields.config) - rejections, err := c.Push(tt.args.release) + rejections, err := c.Push(context.Background(), tt.args.release) assert.Equal(t, tt.rejections, rejections) if tt.wantErr && assert.Error(t, err) { assert.Equal(t, tt.err, err) @@ -179,7 +180,7 @@ func Test_client_Test(t *testing.T) { t.Run(tt.name, func(t *testing.T) { c := New(tt.cfg) - got, err := c.Test() + got, err := c.Test(context.Background()) if tt.wantErr && assert.Error(t, err) { assert.EqualErrorf(t, err, tt.expectedErr, "Error should be: %v, got: %v", tt.wantErr, err) } diff --git a/pkg/whisparr/client.go b/pkg/whisparr/client.go index 2391edc..0055973 100644 --- a/pkg/whisparr/client.go +++ b/pkg/whisparr/client.go @@ -2,6 +2,7 @@ package whisparr import ( "bytes" + "context" "encoding/json" "net/http" "net/url" @@ -10,12 +11,12 @@ import ( "github.com/autobrr/autobrr/pkg/errors" ) -func (c *client) get(endpoint string) (*http.Response, error) { +func (c *client) get(ctx context.Context, endpoint string) (*http.Response, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v3/", endpoint) reqUrl := u.String() - req, err := http.NewRequest(http.MethodGet, reqUrl, http.NoBody) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqUrl, http.NoBody) if err != nil { return nil, errors.Wrap(err, "could not build request") } @@ -39,7 +40,7 @@ func (c *client) get(endpoint string) (*http.Response, error) { return res, nil } -func (c *client) post(endpoint string, data interface{}) (*http.Response, error) { +func (c *client) post(ctx context.Context, endpoint string, data interface{}) (*http.Response, error) { u, err := url.Parse(c.config.Hostname) u.Path = path.Join(u.Path, "/api/v3/", endpoint) reqUrl := u.String() @@ -49,7 +50,7 @@ func (c *client) post(endpoint string, data interface{}) (*http.Response, error) return nil, errors.Wrap(err, "could not marshal data: %+v", data) } - req, err := http.NewRequest(http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) + req, err := http.NewRequestWithContext(ctx, http.MethodPost, reqUrl, bytes.NewBuffer(jsonData)) if err != nil { return nil, errors.Wrap(err, "could not build request") } diff --git a/pkg/whisparr/whisparr.go b/pkg/whisparr/whisparr.go index d37881c..667aa57 100644 --- a/pkg/whisparr/whisparr.go +++ b/pkg/whisparr/whisparr.go @@ -1,6 +1,7 @@ package whisparr import ( + "context" "encoding/json" "io" "log" @@ -24,8 +25,8 @@ type Config struct { } type Client interface { - Test() (*SystemStatusResponse, error) - Push(release Release) ([]string, error) + Test(ctx context.Context) (*SystemStatusResponse, error) + Push(ctx context.Context, release Release) ([]string, error) } type client struct { @@ -75,8 +76,8 @@ type SystemStatusResponse struct { Version string `json:"version"` } -func (c *client) Test() (*SystemStatusResponse, error) { - res, err := c.get("system/status") +func (c *client) Test(ctx context.Context) (*SystemStatusResponse, error) { + res, err := c.get(ctx, "system/status") if err != nil { return nil, errors.Wrap(err, "could not test whisparr") } @@ -99,8 +100,8 @@ func (c *client) Test() (*SystemStatusResponse, error) { return &response, nil } -func (c *client) Push(release Release) ([]string, error) { - res, err := c.post("release/push", release) +func (c *client) Push(ctx context.Context, release Release) ([]string, error) { + res, err := c.post(ctx, "release/push", release) if err != nil { return nil, errors.Wrap(err, "could not push release to whisparr: %+v", release) }