package database import ( "context" "database/sql" "encoding/json" "github.com/autobrr/autobrr/internal/domain" sq "github.com/Masterminds/squirrel" "github.com/rs/zerolog/log" ) type ActionRepo struct { db *DB clientRepo domain.DownloadClientRepo } func NewActionRepo(db *DB, clientRepo domain.DownloadClientRepo) domain.ActionRepo { return &ActionRepo{ db: db, clientRepo: clientRepo, } } func (r *ActionRepo) FindByFilterID(ctx context.Context, filterID int) ([]*domain.Action, error) { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return nil, err } defer tx.Rollback() actions, err := r.findByFilterID(ctx, tx, filterID) if err != nil { return nil, err } for _, action := range actions { if action.ClientID != 0 { client, err := r.attachDownloadClient(ctx, tx, action.ClientID) if err != nil { return nil, err } action.Client = *client } } return actions, nil } func (r *ActionRepo) findByFilterID(ctx context.Context, tx *Tx, filterID int) ([]*domain.Action, error) { queryBuilder := r.db.squirrel. Select( "id", "name", "type", "enabled", "exec_cmd", "exec_args", "watch_folder", "category", "tags", "label", "save_path", "paused", "ignore_rules", "limit_upload_speed", "limit_download_speed", "webhook_host", "webhook_type", "webhook_method", "webhook_data", "client_id", ). From("action"). Where("filter_id = ?", filterID) query, args, err := queryBuilder.ToSql() if err != nil { log.Error().Stack().Err(err).Msg("action.findByFilterID: error building query") return nil, err } rows, err := tx.QueryContext(ctx, query, args...) if err != nil { log.Error().Stack().Err(err).Msg("action.findByFilterID: query error") return nil, err } defer rows.Close() actions := make([]*domain.Action, 0) for rows.Next() { var a domain.Action var execCmd, execArgs, watchFolder, category, tags, label, savePath, host, data, webhookType, webhookMethod sql.NullString var limitUl, limitDl sql.NullInt64 var clientID sql.NullInt32 // filterID var paused, ignoreRules sql.NullBool if err := rows.Scan(&a.ID, &a.Name, &a.Type, &a.Enabled, &execCmd, &execArgs, &watchFolder, &category, &tags, &label, &savePath, &paused, &ignoreRules, &limitDl, &limitUl, &host, &data, &webhookType, &webhookMethod, &clientID); err != nil { log.Error().Stack().Err(err).Msg("action.findByFilterID: error scanning row") return nil, err } a.ExecCmd = execCmd.String a.ExecArgs = execArgs.String a.WatchFolder = watchFolder.String a.Category = category.String a.Tags = tags.String a.Label = label.String a.SavePath = savePath.String a.Paused = paused.Bool a.IgnoreRules = ignoreRules.Bool a.LimitUploadSpeed = limitUl.Int64 a.LimitDownloadSpeed = limitDl.Int64 a.WebhookHost = host.String a.WebhookData = data.String a.WebhookType = webhookType.String a.WebhookMethod = webhookMethod.String a.ClientID = clientID.Int32 actions = append(actions, &a) } if err := rows.Err(); err != nil { log.Error().Stack().Err(err).Msg("action.findByFilterID: row error") return nil, err } return actions, nil } func (r *ActionRepo) attachDownloadClient(ctx context.Context, tx *Tx, clientID int32) (*domain.DownloadClient, error) { queryBuilder := r.db.squirrel. Select( "id", "name", "type", "enabled", "host", "port", "tls", "tls_skip_verify", "username", "password", "settings", ). From("client"). Where("id = ?", clientID) query, args, err := queryBuilder.ToSql() if err != nil { log.Error().Stack().Err(err).Msg("action.attachDownloadClient: error building query") return nil, err } row := tx.QueryRowContext(ctx, query, args...) if err := row.Err(); err != nil { log.Error().Stack().Err(err).Msg("action.attachDownloadClient: error query row") return nil, err } var client domain.DownloadClient var settingsJsonStr string if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.TLS, &client.TLSSkipVerify, &client.Username, &client.Password, &settingsJsonStr); err != nil { log.Error().Stack().Err(err).Msg("action.attachDownloadClient: error scanning row") return nil, err } if settingsJsonStr != "" { if err := json.Unmarshal([]byte(settingsJsonStr), &client.Settings); err != nil { log.Error().Stack().Err(err).Msgf("action.attachDownloadClient: could not marshal download client settings %v", settingsJsonStr) return nil, err } } return &client, nil } func (r *ActionRepo) List(ctx context.Context) ([]domain.Action, error) { queryBuilder := r.db.squirrel. Select( "id", "name", "type", "enabled", "exec_cmd", "exec_args", "watch_folder", "category", "tags", "label", "save_path", "paused", "ignore_rules", "limit_upload_speed", "limit_download_speed", "webhook_host", "webhook_type", "webhook_method", "webhook_data", "client_id", ). From("action") query, args, err := queryBuilder.ToSql() if err != nil { log.Error().Stack().Err(err).Msg("action.list: error building query") return nil, err } rows, err := r.db.handler.QueryContext(ctx, query, args...) if err != nil { log.Error().Stack().Err(err).Msg("action.list: error executing query") return nil, err } defer rows.Close() actions := make([]domain.Action, 0) for rows.Next() { var a domain.Action var execCmd, execArgs, watchFolder, category, tags, label, savePath sql.NullString var limitUl, limitDl sql.NullInt64 var clientID sql.NullInt32 var paused, ignoreRules sql.NullBool if err := rows.Scan(&a.ID, &a.Name, &a.Type, &a.Enabled, &execCmd, &execArgs, &watchFolder, &category, &tags, &label, &savePath, &paused, &ignoreRules, &limitDl, &limitUl, &clientID); err != nil { log.Error().Stack().Err(err).Msg("action.list: error scanning row") return nil, err } a.Category = category.String a.Tags = tags.String a.Label = label.String a.SavePath = savePath.String a.Paused = paused.Bool a.IgnoreRules = ignoreRules.Bool a.LimitUploadSpeed = limitUl.Int64 a.LimitDownloadSpeed = limitDl.Int64 a.ClientID = clientID.Int32 actions = append(actions, a) } if err := rows.Err(); err != nil { log.Error().Stack().Err(err).Msg("action.list: row error") return nil, err } return actions, nil } func (r *ActionRepo) Delete(actionID int) error { queryBuilder := r.db.squirrel. Delete("action"). Where("id = ?", actionID) query, args, err := queryBuilder.ToSql() if err != nil { log.Error().Stack().Err(err).Msg("action.delete: error building query") return err } _, err = r.db.handler.Exec(query, args...) if err != nil { log.Error().Stack().Err(err).Msg("action.delete: error executing query") return err } log.Debug().Msgf("action.delete: %v", actionID) return nil } func (r *ActionRepo) DeleteByFilterID(ctx context.Context, filterID int) error { queryBuilder := r.db.squirrel. Delete("action"). Where("filter_id = ?", filterID) query, args, err := queryBuilder.ToSql() if err != nil { log.Error().Stack().Err(err).Msg("action.deleteByFilterID: error building query") return err } _, err = r.db.handler.ExecContext(ctx, query, args...) if err != nil { log.Error().Stack().Err(err).Msg("action.deleteByFilterID: error executing query") return err } log.Debug().Msgf("action.deleteByFilterID: %v", filterID) return nil } func (r *ActionRepo) Store(ctx context.Context, action domain.Action) (*domain.Action, error) { execCmd := toNullString(action.ExecCmd) execArgs := toNullString(action.ExecArgs) watchFolder := toNullString(action.WatchFolder) category := toNullString(action.Category) tags := toNullString(action.Tags) label := toNullString(action.Label) savePath := toNullString(action.SavePath) webhookHost := toNullString(action.WebhookHost) webhookData := toNullString(action.WebhookData) webhookType := toNullString(action.WebhookType) webhookMethod := toNullString(action.WebhookMethod) limitDL := toNullInt64(action.LimitDownloadSpeed) limitUL := toNullInt64(action.LimitUploadSpeed) clientID := toNullInt32(action.ClientID) filterID := toNullInt32(int32(action.FilterID)) queryBuilder := r.db.squirrel. Insert("action"). Columns( "name", "type", "enabled", "exec_cmd", "exec_args", "watch_folder", "category", "tags", "label", "save_path", "paused", "ignore_rules", "limit_upload_speed", "limit_download_speed", "webhook_host", "webhook_type", "webhook_method", "webhook_data", "client_id", "filter_id", ). Values( action.Name, action.Type, action.Enabled, execCmd, execArgs, watchFolder, category, tags, label, savePath, action.Paused, action.IgnoreRules, limitUL, limitDL, webhookHost, webhookType, webhookMethod, webhookData, clientID, filterID, ). Suffix("RETURNING id").RunWith(r.db.handler) // return values var retID int64 err := queryBuilder.QueryRowContext(ctx).Scan(&retID) if err != nil { log.Error().Stack().Err(err).Msg("action.store: error executing query") return nil, err } log.Debug().Msgf("action.store: added new %v", retID) action.ID = int(retID) return &action, nil } func (r *ActionRepo) Update(ctx context.Context, action domain.Action) (*domain.Action, error) { execCmd := toNullString(action.ExecCmd) execArgs := toNullString(action.ExecArgs) watchFolder := toNullString(action.WatchFolder) category := toNullString(action.Category) tags := toNullString(action.Tags) label := toNullString(action.Label) savePath := toNullString(action.SavePath) webhookHost := toNullString(action.WebhookHost) webhookData := toNullString(action.WebhookData) webhookType := toNullString(action.WebhookType) webhookMethod := toNullString(action.WebhookMethod) limitDL := toNullInt64(action.LimitDownloadSpeed) limitUL := toNullInt64(action.LimitUploadSpeed) clientID := toNullInt32(action.ClientID) filterID := toNullInt32(int32(action.FilterID)) var err error queryBuilder := r.db.squirrel. Update("action"). Set("name", action.Name). Set("type", action.Type). Set("enabled", action.Enabled). Set("exec_cmd", execCmd). Set("exec_args", execArgs). Set("watch_folder", watchFolder). Set("category", category). Set("tags", tags). Set("label", label). Set("save_path", savePath). Set("paused", action.Paused). Set("ignore_rules", action.IgnoreRules). Set("limit_upload_speed", limitUL). Set("limit_download_speed", limitDL). Set("webhook_host", webhookHost). Set("webhook_type", webhookType). Set("webhook_method", webhookMethod). Set("webhook_data", webhookData). Set("client_id", clientID). Set("filter_id", filterID). Where("id = ?", action.ID) query, args, err := queryBuilder.ToSql() if err != nil { log.Error().Stack().Err(err).Msg("action.update: error building query") return nil, err } _, err = r.db.handler.ExecContext(ctx, query, args...) if err != nil { log.Error().Stack().Err(err).Msg("action.update: error executing query") return nil, err } log.Debug().Msgf("action.update: %v", action.ID) return &action, nil } func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []*domain.Action, filterID int64) ([]*domain.Action, error) { tx, err := r.db.handler.BeginTx(ctx, nil) if err != nil { return nil, err } defer tx.Rollback() deleteQueryBuilder := r.db.squirrel. Delete("action"). Where("filter_id = ?", filterID) deleteQuery, deleteArgs, err := deleteQueryBuilder.ToSql() if err != nil { log.Error().Stack().Err(err).Msg("action.StoreFilterActions: error building query") return nil, err } _, err = tx.ExecContext(ctx, deleteQuery, deleteArgs...) if err != nil { log.Error().Stack().Err(err).Msg("action.StoreFilterActions: error executing query") return nil, err } for _, action := range actions { execCmd := toNullString(action.ExecCmd) execArgs := toNullString(action.ExecArgs) watchFolder := toNullString(action.WatchFolder) category := toNullString(action.Category) tags := toNullString(action.Tags) label := toNullString(action.Label) savePath := toNullString(action.SavePath) webhookHost := toNullString(action.WebhookHost) webhookData := toNullString(action.WebhookData) webhookType := toNullString(action.WebhookType) webhookMethod := toNullString(action.WebhookMethod) limitDL := toNullInt64(action.LimitDownloadSpeed) limitUL := toNullInt64(action.LimitUploadSpeed) clientID := toNullInt32(action.ClientID) queryBuilder := r.db.squirrel. Insert("action"). Columns( "name", "type", "enabled", "exec_cmd", "exec_args", "watch_folder", "category", "tags", "label", "save_path", "paused", "ignore_rules", "limit_upload_speed", "limit_download_speed", "webhook_host", "webhook_type", "webhook_method", "webhook_data", "client_id", "filter_id", ). Values( action.Name, action.Type, action.Enabled, execCmd, execArgs, watchFolder, category, tags, label, savePath, action.Paused, action.IgnoreRules, limitUL, limitDL, webhookHost, webhookType, webhookMethod, webhookData, clientID, filterID, ). Suffix("RETURNING id").RunWith(tx) // return values var retID int err = queryBuilder.QueryRowContext(ctx).Scan(&retID) if err != nil { log.Error().Stack().Err(err).Msg("action.StoreFilterActions: error executing query") return nil, err } action.ID = retID log.Debug().Msgf("action.StoreFilterActions: store '%v' type: '%v' on filter: %v", action.Name, action.Type, filterID) } err = tx.Commit() if err != nil { log.Error().Stack().Err(err).Msg("action.StoreFilterActions: error updating actions") return nil, err } return actions, nil } func (r *ActionRepo) ToggleEnabled(actionID int) error { var err error queryBuilder := r.db.squirrel. Update("action"). Set("enabled", sq.Expr("NOT enabled")). Where("id = ?", actionID) query, args, err := queryBuilder.ToSql() if err != nil { log.Error().Stack().Err(err).Msg("action.toggleEnabled: error building query") return err } _, err = r.db.handler.Exec(query, args...) if err != nil { log.Error().Stack().Err(err).Msg("action.toggleEnabled: error executing query") return err } log.Debug().Msgf("action.toggleEnabled: %v", actionID) return nil }