feat(database): improve error handling (#1633)

This commit is contained in:
ze0s 2024-08-29 09:00:53 +02:00 committed by GitHub
parent cc0cca9f0d
commit accc875960
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 56 additions and 44 deletions

View file

@ -616,12 +616,8 @@ func (r *ActionRepo) Get(ctx context.Context, req *domain.GetActionRequest) (*do
}
row := r.db.handler.QueryRowContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
if err := row.Err(); err != nil {
return nil, errors.Wrap(err, "rows error")
return nil, errors.Wrap(err, "error executing query")
}
var a domain.Action

View file

@ -98,7 +98,6 @@ func (r *APIRepo) GetAllAPIKeys(ctx context.Context) ([]domain.APIKey, error) {
if err := rows.Scan(&name, &a.Key, pq.Array(&a.Scopes), &a.CreatedAt); err != nil {
return nil, errors.Wrap(err, "error scanning row")
}
a.Name = name.String
@ -122,9 +121,6 @@ func (r *APIRepo) GetKey(ctx context.Context, key string) (*domain.APIKey, error
row := r.db.handler.QueryRowContext(ctx, query, args...)
if err := row.Err(); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error executing query")
}
@ -133,6 +129,10 @@ func (r *APIRepo) GetKey(ctx context.Context, key string) (*domain.APIKey, error
var name sql.NullString
if err := row.Scan(&name, &apiKey.Key, pq.Array(&apiKey.Scopes), &apiKey.CreatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}

View file

@ -120,7 +120,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.TLS, &client.TLSSkipVerify, &client.Username, &client.Password, &settingsJsonStr); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, errors.New("no client configured")
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")

View file

@ -157,7 +157,7 @@ func TestDownloadClientRepo_FindByID(t *testing.T) {
t.Run(fmt.Sprintf("FindByID_Fails_With_Nonexistent_ID [%s]", dbType), func(t *testing.T) {
_, err := repo.FindByID(context.Background(), 9999)
assert.Error(t, err)
assert.Equal(t, "no client configured", err.Error())
assert.ErrorIs(t, err, domain.ErrRecordNotFound)
})
t.Run(fmt.Sprintf("FindByID_Fails_With_Negative_ID [%s]", dbType), func(t *testing.T) {
@ -179,7 +179,7 @@ func TestDownloadClientRepo_FindByID(t *testing.T) {
_ = repo.Delete(context.Background(), mock.ID)
_, err := repo.FindByID(context.Background(), mock.ID)
assert.Error(t, err)
assert.Equal(t, "no client configured", err.Error())
assert.ErrorIs(t, err, domain.ErrRecordNotFound)
// Cleanup
_ = repo.Delete(context.Background(), mock.ID)

View file

@ -68,6 +68,10 @@ func (r *FeedRepo) FindByID(ctx context.Context, id int) (*domain.Feed, error) {
var apiKey, cookie, settings sql.NullString
if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}
@ -126,6 +130,10 @@ func (r *FeedRepo) FindByIndexerIdentifier(ctx context.Context, indexer string)
var apiKey, cookie, settings sql.NullString
if err := row.Scan(&f.ID, &f.Indexer.ID, &f.Indexer.Identifier, &f.Indexer.IdentifierExternal, &f.Indexer.Name, &f.Name, &f.Type, &f.Enabled, &f.URL, &f.Interval, &f.Timeout, &f.MaxAge, &apiKey, &cookie, &settings, &f.CreatedAt, &f.UpdatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}
@ -237,6 +245,10 @@ func (r *FeedRepo) GetLastRunDataByID(ctx context.Context, id int) (string, erro
var data sql.NullString
if err := row.Scan(&data); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return "", domain.ErrRecordNotFound
}
return "", errors.Wrap(err, "error scanning row")
}

View file

@ -117,7 +117,7 @@ func (r *FeedCacheRepo) GetCountByFeed(ctx context.Context, feedId int) (int, er
}
row := r.db.handler.QueryRowContext(ctx, query, args...)
if err != nil {
if err := row.Err(); err != nil {
return 0, errors.Wrap(err, "error executing query")
}

View file

@ -32,26 +32,15 @@ func NewFilterRepo(log logger.Logger, db *DB) domain.FilterRepo {
}
func (r *FilterRepo) Find(ctx context.Context, params domain.FilterQueryParams) ([]domain.Filter, error) {
tx, err := r.db.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
if err != nil {
return nil, errors.Wrap(err, "error begin transaction")
}
defer tx.Rollback()
filters, err := r.find(ctx, tx, params)
filters, err := r.find(ctx, params)
if err != nil {
return nil, err
}
if err := tx.Commit(); err != nil {
return nil, errors.Wrap(err, "error commit transaction find releases")
}
return filters, nil
}
func (r *FilterRepo) find(ctx context.Context, tx *Tx, params domain.FilterQueryParams) ([]domain.Filter, error) {
func (r *FilterRepo) find(ctx context.Context, params domain.FilterQueryParams) ([]domain.Filter, error) {
actionCountQuery := r.db.squirrel.
Select("COUNT(*)").
From("action a").
@ -104,7 +93,7 @@ func (r *FilterRepo) find(ctx context.Context, tx *Tx, params domain.FilterQuery
return nil, errors.Wrap(err, "error building query")
}
rows, err := tx.QueryContext(ctx, query, args...)
rows, err := r.db.handler.QueryContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
@ -1347,6 +1336,10 @@ WHERE (release_action_status.status = 'PUSH_APPROVED' OR release_action_status.s
var f domain.FilterDownloads
if err := row.Scan(&f.HourCount, &f.DayCount, &f.WeekCount, &f.MonthCount, &f.TotalCount); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning stats data sqlite")
}
@ -1373,6 +1366,10 @@ WHERE (release_action_status.status = 'PUSH_APPROVED' OR release_action_status.s
var f domain.FilterDownloads
if err := row.Scan(&f.HourCount, &f.DayCount, &f.WeekCount, &f.MonthCount, &f.TotalCount); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning stats data postgres")
}

View file

@ -148,6 +148,10 @@ func (r *IndexerRepo) FindByID(ctx context.Context, id int) (*domain.Indexer, er
var identifierExternal, implementation, baseURL, settings sql.Null[string]
if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}
@ -193,6 +197,10 @@ func (r *IndexerRepo) GetBy(ctx context.Context, req domain.GetIndexerRequest) (
var identifierExternal, implementation, baseURL, settings sql.Null[string]
if err := row.Scan(&i.ID, &i.Enabled, &i.Name, &i.Identifier, &identifierExternal, &implementation, &baseURL, &settings); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}

View file

@ -48,6 +48,10 @@ func (r *IrcRepo) GetNetworkByID(ctx context.Context, id int64) (*domain.IrcNetw
row := r.db.handler.QueryRowContext(ctx, query, args...)
if err := row.Scan(&n.ID, &n.Enabled, &n.Name, &n.Server, &n.Port, &tls, &pass, &nick, &n.Auth.Mechanism, &account, &password, &inviteCmd, &bouncerAddr, &n.UseBouncer, &n.BotMode); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}
@ -253,6 +257,9 @@ func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcN
r.log.Trace().Str("database", "irc.checkExistingNetwork").Msgf("query: '%s', args: '%v'", query, args)
row := r.db.handler.QueryRowContext(ctx, query, args...)
if err := row.Err(); err != nil {
return nil, errors.Wrap(err, "error executing query")
}
var net domain.IrcNetwork

View file

@ -29,7 +29,6 @@ func NewNotificationRepo(log logger.Logger, db *DB) domain.NotificationRepo {
}
func (r *NotificationRepo) Find(ctx context.Context, params domain.NotificationQueryParams) ([]domain.Notification, int, error) {
queryBuilder := r.db.squirrel.
Select("id", "name", "type", "enabled", "events", "webhook", "token", "api_key", "channel", "priority", "topic", "host", "created_at", "updated_at", "COUNT(*) OVER() AS total_count").
From("notification").
@ -75,7 +74,6 @@ func (r *NotificationRepo) Find(ctx context.Context, params domain.NotificationQ
}
func (r *NotificationRepo) List(ctx context.Context) ([]domain.Notification, error) {
rows, err := r.db.handler.QueryContext(ctx, "SELECT id, name, type, enabled, events, token, api_key, webhook, title, icon, host, username, password, channel, targets, devices, priority, topic, created_at, updated_at FROM notification ORDER BY name ASC")
if err != nil {
return nil, errors.Wrap(err, "error executing query")
@ -117,7 +115,6 @@ func (r *NotificationRepo) List(ctx context.Context) ([]domain.Notification, err
}
func (r *NotificationRepo) FindByID(ctx context.Context, id int) (*domain.Notification, error) {
queryBuilder := r.db.squirrel.
Select(
"id",
@ -158,6 +155,10 @@ func (r *NotificationRepo) FindByID(ctx context.Context, id int) (*domain.Notifi
var token, apiKey, webhook, title, icon, host, username, password, channel, targets, devices, topic sql.NullString
if err := row.Scan(&n.ID, &n.Name, &n.Type, &n.Enabled, pq.Array(&n.Events), &token, &apiKey, &webhook, &title, &icon, &host, &username, &password, &channel, &targets, &devices, &n.Priority, &topic, &n.CreatedAt, &n.UpdatedAt); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}

View file

@ -423,12 +423,8 @@ func (repo *ReleaseRepo) Get(ctx context.Context, req *domain.GetReleaseRequest)
repo.log.Trace().Str("database", "release.find").Msgf("query: '%s', args: '%v'", query, args)
row := repo.db.handler.QueryRowContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
if err := row.Err(); err != nil {
return nil, errors.Wrap(err, "error rows find release")
return nil, errors.Wrap(err, "error executing query")
}
var rls domain.Release
@ -438,8 +434,9 @@ func (repo *ReleaseRepo) Get(ctx context.Context, req *domain.GetReleaseRequest)
if err := row.Scan(&rls.ID, &rls.FilterStatus, pq.Array(&rls.Rejections), &indexerName, &filterName, &filterId, &rls.Protocol, &rls.Implementation, &infoUrl, &downloadUrl, &rls.Title, &rls.TorrentName, &category, &rls.Size, &groupId, &torrentId, &uploader, &rls.Timestamp); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
}
@ -469,13 +466,8 @@ func (repo *ReleaseRepo) GetActionStatus(ctx context.Context, req *domain.GetRel
}
row := repo.db.handler.QueryRowContext(ctx, query, args...)
if err != nil {
return nil, errors.Wrap(err, "error executing query")
}
if err := row.Err(); err != nil {
repo.log.Error().Stack().Err(err)
return nil, err
return nil, errors.Wrap(err, "error executing query")
}
var rls domain.ReleaseActionStatus
@ -485,7 +477,7 @@ func (repo *ReleaseRepo) GetActionStatus(ctx context.Context, req *domain.GetRel
if err := row.Scan(&rls.ID, &rls.Status, &rls.Action, &actionId, &rls.Type, &client, &filter, &filterId, &rls.ReleaseID, pq.Array(&rls.Rejections), &rls.Timestamp); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
return nil, domain.ErrRecordNotFound
}
return nil, errors.Wrap(err, "error scanning row")
@ -545,7 +537,6 @@ func (repo *ReleaseRepo) attachActionStatus(ctx context.Context, tx *Tx, release
}
func (repo *ReleaseRepo) Stats(ctx context.Context) (*domain.ReleaseStats, error) {
query := `SELECT *
FROM (
SELECT