diff --git a/internal/database/notification.go b/internal/database/notification.go index bbe9e96..9f30d97 100644 --- a/internal/database/notification.go +++ b/internal/database/notification.go @@ -179,15 +179,7 @@ func (r *NotificationRepo) FindByID(ctx context.Context, id int) (*domain.Notifi return &n, nil } -func (r *NotificationRepo) Store(ctx context.Context, notification domain.Notification) (*domain.Notification, error) { - webhook := toNullString(notification.Webhook) - token := toNullString(notification.Token) - apiKey := toNullString(notification.APIKey) - channel := toNullString(notification.Channel) - topic := toNullString(notification.Topic) - host := toNullString(notification.Host) - username := toNullString(notification.Username) - +func (r *NotificationRepo) Store(ctx context.Context, notification *domain.Notification) error { queryBuilder := r.db.squirrel. Insert("notification"). Columns( @@ -209,68 +201,56 @@ func (r *NotificationRepo) Store(ctx context.Context, notification domain.Notifi notification.Type, notification.Enabled, pq.Array(notification.Events), - webhook, - token, - apiKey, - channel, + toNullString(notification.Webhook), + toNullString(notification.Token), + toNullString(notification.APIKey), + toNullString(notification.Channel), notification.Priority, - topic, - host, - username, + toNullString(notification.Topic), + toNullString(notification.Host), + toNullString(notification.Username), ). Suffix("RETURNING id").RunWith(r.db.handler) - // return values - var retID int64 - - if err := queryBuilder.QueryRowContext(ctx).Scan(&retID); err != nil { - return nil, errors.Wrap(err, "error executing query") + if err := queryBuilder.QueryRowContext(ctx).Scan(¬ification.ID); err != nil { + return errors.Wrap(err, "error executing query") } - r.log.Debug().Msgf("notification.store: added new %v", retID) - notification.ID = int(retID) + r.log.Debug().Msgf("notification.store: added new %v", notification.ID) - return ¬ification, nil + return nil } -func (r *NotificationRepo) Update(ctx context.Context, notification domain.Notification) (*domain.Notification, error) { - webhook := toNullString(notification.Webhook) - token := toNullString(notification.Token) - apiKey := toNullString(notification.APIKey) - channel := toNullString(notification.Channel) - topic := toNullString(notification.Topic) - host := toNullString(notification.Host) - username := toNullString(notification.Username) - +func (r *NotificationRepo) Update(ctx context.Context, notification *domain.Notification) error { queryBuilder := r.db.squirrel. Update("notification"). Set("name", notification.Name). Set("type", notification.Type). Set("enabled", notification.Enabled). Set("events", pq.Array(notification.Events)). - Set("webhook", webhook). - Set("token", token). - Set("api_key", apiKey). - Set("channel", channel). + Set("webhook", toNullString(notification.Webhook)). + Set("token", toNullString(notification.Token)). + Set("api_key", toNullString(notification.APIKey)). + Set("channel", toNullString(notification.Channel)). Set("priority", notification.Priority). - Set("topic", topic). - Set("host", host). - Set("username", username). + Set("topic", toNullString(notification.Topic)). + Set("host", toNullString(notification.Host)). + Set("username", toNullString(notification.Username)). Set("updated_at", sq.Expr("CURRENT_TIMESTAMP")). Where(sq.Eq{"id": notification.ID}) query, args, err := queryBuilder.ToSql() if err != nil { - return nil, errors.Wrap(err, "error building query") + return errors.Wrap(err, "error building query") } if _, err = r.db.handler.ExecContext(ctx, query, args...); err != nil { - return nil, errors.Wrap(err, "error executing query") + return errors.Wrap(err, "error executing query") } r.log.Debug().Msgf("notification.update: %v", notification.Name) - return ¬ification, nil + return nil } func (r *NotificationRepo) Delete(ctx context.Context, notificationID int) error { diff --git a/internal/database/notification_test.go b/internal/database/notification_test.go index 54da418..9ee2bee 100644 --- a/internal/database/notification_test.go +++ b/internal/database/notification_test.go @@ -54,8 +54,10 @@ func TestNotificationRepo_Store(t *testing.T) { // Setup assert.NotNil(t, mockData) + notification := getMockNotification() + // Execute - notification, err := repo.Store(context.Background(), mockData) + err := repo.Store(context.Background(), ¬ification) // Verify assert.NoError(t, err) @@ -77,28 +79,32 @@ func TestNotificationRepo_Update(t *testing.T) { t.Run(fmt.Sprintf("Update_Succeeds [%s]", dbType), func(t *testing.T) { // Initial setup and Store - notification, err := repo.Store(context.Background(), mockData) + err := repo.Store(context.Background(), &mockData) assert.NoError(t, err) - assert.NotNil(t, notification) + assert.NotNil(t, &mockData) // Modify some fields - updatedMockData := *notification - updatedMockData.Name = "UpdatedName" - updatedMockData.Type = domain.NotificationTypeTelegram - updatedMockData.Priority = 2 + newName := "UpdatedName" + newType := domain.NotificationTypeTelegram + newPriority := int32(2) + + updatedMockData := &mockData + updatedMockData.Name = newName + updatedMockData.Type = newType + updatedMockData.Priority = newPriority // Execute Update - updatedNotification, err := repo.Update(context.Background(), updatedMockData) + err = repo.Update(context.Background(), updatedMockData) // Verify assert.NoError(t, err) - assert.NotNil(t, updatedNotification) - assert.Equal(t, updatedMockData.Name, updatedNotification.Name) - assert.Equal(t, updatedMockData.Type, updatedNotification.Type) - assert.Equal(t, updatedMockData.Priority, updatedNotification.Priority) + assert.NotNil(t, &mockData) + assert.Equal(t, updatedMockData.Name, newName) + assert.Equal(t, updatedMockData.Type, newType) + assert.Equal(t, updatedMockData.Priority, newPriority) // Cleanup - _ = repo.Delete(context.Background(), updatedNotification.ID) + _ = repo.Delete(context.Background(), mockData.ID) }) } } @@ -108,11 +114,13 @@ func TestNotificationRepo_Delete(t *testing.T) { log := setupLoggerForTest() repo := NewNotificationRepo(log, db) - mockData := getMockNotification() + //mockData := getMockNotification() t.Run(fmt.Sprintf("Delete_Succeeds [%s]", dbType), func(t *testing.T) { + notification := getMockNotification() + // Initial setup and Store - notification, err := repo.Store(context.Background(), mockData) + err := repo.Store(context.Background(), ¬ification) assert.NoError(t, err) assert.NotNil(t, notification) @@ -148,11 +156,11 @@ func TestNotificationRepo_Find(t *testing.T) { _ = repo.Delete(context.Background(), notification.ID) } - _, err := repo.Store(context.Background(), mockData1) + err := repo.Store(context.Background(), &mockData1) assert.NoError(t, err) - _, err = repo.Store(context.Background(), mockData2) + err = repo.Store(context.Background(), &mockData2) assert.NoError(t, err) - _, err = repo.Store(context.Background(), mockData3) + err = repo.Store(context.Background(), &mockData3) assert.NoError(t, err) // Setup query params @@ -188,11 +196,13 @@ func TestNotificationRepo_FindByID(t *testing.T) { t.Run(fmt.Sprintf("FindByID_Succeeds [%s]", dbType), func(t *testing.T) { // Setup + //notification := getMockNotification() + assert.NotNil(t, mockData) - notification, err := repo.Store(context.Background(), mockData) + err := repo.Store(context.Background(), &mockData) // Execute - notification, err = repo.FindByID(context.Background(), notification.ID) + notification, err := repo.FindByID(context.Background(), mockData.ID) // Verify assert.NoError(t, err) @@ -221,7 +231,7 @@ func TestNotificationRepo_List(t *testing.T) { } for i := 0; i < 10; i++ { - _, err := repo.Store(context.Background(), mockData) + err := repo.Store(context.Background(), &mockData) assert.NoError(t, err) } diff --git a/internal/domain/notification.go b/internal/domain/notification.go index d282412..6b1b297 100644 --- a/internal/domain/notification.go +++ b/internal/domain/notification.go @@ -12,8 +12,8 @@ type NotificationRepo interface { List(ctx context.Context) ([]Notification, error) Find(ctx context.Context, params NotificationQueryParams) ([]Notification, int, error) FindByID(ctx context.Context, id int) (*Notification, error) - Store(ctx context.Context, notification Notification) (*Notification, error) - Update(ctx context.Context, notification Notification) (*Notification, error) + Store(ctx context.Context, notification *Notification) error + Update(ctx context.Context, notification *Notification) error Delete(ctx context.Context, notificationID int) error } diff --git a/internal/http/notification.go b/internal/http/notification.go index 8d6b0bc..685dfd2 100644 --- a/internal/http/notification.go +++ b/internal/http/notification.go @@ -18,10 +18,10 @@ import ( type notificationService interface { Find(context.Context, domain.NotificationQueryParams) ([]domain.Notification, int, error) FindByID(ctx context.Context, id int) (*domain.Notification, error) - Store(ctx context.Context, n domain.Notification) (*domain.Notification, error) - Update(ctx context.Context, n domain.Notification) (*domain.Notification, error) + Store(ctx context.Context, notification *domain.Notification) error + Update(ctx context.Context, notification *domain.Notification) error Delete(ctx context.Context, id int) error - Test(ctx context.Context, notification domain.Notification) error + Test(ctx context.Context, notification *domain.Notification) error } type notificationHandler struct { @@ -59,19 +59,19 @@ func (h notificationHandler) list(w http.ResponseWriter, r *http.Request) { } func (h notificationHandler) store(w http.ResponseWriter, r *http.Request) { - var data domain.Notification + var data *domain.Notification if err := json.NewDecoder(r.Body).Decode(&data); err != nil { h.encoder.Error(w, err) return } - filter, err := h.service.Store(r.Context(), data) + err := h.service.Store(r.Context(), data) if err != nil { h.encoder.Error(w, err) return } - h.encoder.StatusResponse(w, http.StatusCreated, filter) + h.encoder.StatusResponse(w, http.StatusCreated, data) } func (h notificationHandler) findByID(w http.ResponseWriter, r *http.Request) { @@ -96,19 +96,19 @@ func (h notificationHandler) findByID(w http.ResponseWriter, r *http.Request) { } func (h notificationHandler) update(w http.ResponseWriter, r *http.Request) { - var data domain.Notification + var data *domain.Notification if err := json.NewDecoder(r.Body).Decode(&data); err != nil { h.encoder.Error(w, err) return } - filter, err := h.service.Update(r.Context(), data) + err := h.service.Update(r.Context(), data) if err != nil { h.encoder.Error(w, err) return } - h.encoder.StatusResponse(w, http.StatusOK, filter) + h.encoder.StatusResponse(w, http.StatusOK, data) } func (h notificationHandler) delete(w http.ResponseWriter, r *http.Request) { @@ -127,7 +127,7 @@ func (h notificationHandler) delete(w http.ResponseWriter, r *http.Request) { } func (h notificationHandler) test(w http.ResponseWriter, r *http.Request) { - var data domain.Notification + var data *domain.Notification if err := json.NewDecoder(r.Body).Decode(&data); err != nil { h.encoder.Error(w, err) return diff --git a/internal/notification/discord.go b/internal/notification/discord.go index 6a63163..42fcc55 100644 --- a/internal/notification/discord.go +++ b/internal/notification/discord.go @@ -50,7 +50,7 @@ const ( type discordSender struct { log zerolog.Logger - Settings domain.Notification + Settings *domain.Notification httpClient *http.Client } @@ -59,7 +59,7 @@ func (a *discordSender) Name() string { return "discord" } -func NewDiscordSender(log zerolog.Logger, settings domain.Notification) domain.NotificationSender { +func NewDiscordSender(log zerolog.Logger, settings *domain.Notification) domain.NotificationSender { return &discordSender{ log: log.With().Str("sender", "discord").Logger(), Settings: settings, diff --git a/internal/notification/gotify.go b/internal/notification/gotify.go index 50697ca..c3cc166 100644 --- a/internal/notification/gotify.go +++ b/internal/notification/gotify.go @@ -26,7 +26,7 @@ type gotifyMessage struct { type gotifySender struct { log zerolog.Logger - Settings domain.Notification + Settings *domain.Notification builder MessageBuilderPlainText httpClient *http.Client @@ -36,7 +36,7 @@ func (s *gotifySender) Name() string { return "gotify" } -func NewGotifySender(log zerolog.Logger, settings domain.Notification) domain.NotificationSender { +func NewGotifySender(log zerolog.Logger, settings *domain.Notification) domain.NotificationSender { return &gotifySender{ log: log.With().Str("sender", "gotify").Logger(), Settings: settings, diff --git a/internal/notification/lunasea.go b/internal/notification/lunasea.go index b2fecb2..832d5ee 100644 --- a/internal/notification/lunasea.go +++ b/internal/notification/lunasea.go @@ -27,7 +27,7 @@ type LunaSeaMessage struct { type lunaSeaSender struct { log zerolog.Logger - Settings domain.Notification + Settings *domain.Notification builder MessageBuilderPlainText httpClient *http.Client @@ -43,7 +43,7 @@ func (s *lunaSeaSender) rewriteWebhookURL(url string) string { return lunaWebhook.ReplaceAllString(url, "/custom/") } // `custom` is not mentioned in their docs, so I thought this would be a good idea to add to avoid user errors -func NewLunaSeaSender(log zerolog.Logger, settings domain.Notification) domain.NotificationSender { +func NewLunaSeaSender(log zerolog.Logger, settings *domain.Notification) domain.NotificationSender { return &lunaSeaSender{ log: log.With().Str("sender", "lunasea").Logger(), Settings: settings, diff --git a/internal/notification/notifiarr.go b/internal/notification/notifiarr.go index 432de16..0961267 100644 --- a/internal/notification/notifiarr.go +++ b/internal/notification/notifiarr.go @@ -44,7 +44,7 @@ type notifiarrMessageData struct { type notifiarrSender struct { log zerolog.Logger - Settings domain.Notification + Settings *domain.Notification baseUrl string httpClient *http.Client @@ -54,7 +54,7 @@ func (s *notifiarrSender) Name() string { return "notifiarr" } -func NewNotifiarrSender(log zerolog.Logger, settings domain.Notification) domain.NotificationSender { +func NewNotifiarrSender(log zerolog.Logger, settings *domain.Notification) domain.NotificationSender { return ¬ifiarrSender{ log: log.With().Str("sender", "notifiarr").Logger(), Settings: settings, diff --git a/internal/notification/ntfy.go b/internal/notification/ntfy.go index 02bd8b7..e1d2d87 100644 --- a/internal/notification/ntfy.go +++ b/internal/notification/ntfy.go @@ -25,7 +25,7 @@ type ntfyMessage struct { type ntfySender struct { log zerolog.Logger - Settings domain.Notification + Settings *domain.Notification builder MessageBuilderPlainText httpClient *http.Client @@ -35,7 +35,7 @@ func (s *ntfySender) Name() string { return "ntfy" } -func NewNtfySender(log zerolog.Logger, settings domain.Notification) domain.NotificationSender { +func NewNtfySender(log zerolog.Logger, settings *domain.Notification) domain.NotificationSender { return &ntfySender{ log: log.With().Str("sender", "ntfy").Logger(), Settings: settings, diff --git a/internal/notification/pushover.go b/internal/notification/pushover.go index ed0fff1..9a150ff 100644 --- a/internal/notification/pushover.go +++ b/internal/notification/pushover.go @@ -32,7 +32,7 @@ type pushoverMessage struct { type pushoverSender struct { log zerolog.Logger - Settings domain.Notification + Settings *domain.Notification baseUrl string builder MessageBuilderHTML @@ -43,7 +43,7 @@ func (s *pushoverSender) Name() string { return "pushover" } -func NewPushoverSender(log zerolog.Logger, settings domain.Notification) domain.NotificationSender { +func NewPushoverSender(log zerolog.Logger, settings *domain.Notification) domain.NotificationSender { return &pushoverSender{ log: log.With().Str("sender", "pushover").Logger(), Settings: settings, diff --git a/internal/notification/service.go b/internal/notification/service.go index 1f570f5..ac912e3 100644 --- a/internal/notification/service.go +++ b/internal/notification/service.go @@ -19,11 +19,11 @@ import ( type Service interface { Find(ctx context.Context, params domain.NotificationQueryParams) ([]domain.Notification, int, error) FindByID(ctx context.Context, id int) (*domain.Notification, error) - Store(ctx context.Context, n domain.Notification) (*domain.Notification, error) - Update(ctx context.Context, n domain.Notification) (*domain.Notification, error) + Store(ctx context.Context, notification *domain.Notification) error + Update(ctx context.Context, notification *domain.Notification) error Delete(ctx context.Context, id int) error Send(event domain.NotificationEvent, payload domain.NotificationPayload) - Test(ctx context.Context, notification domain.Notification) error + Test(ctx context.Context, notification *domain.Notification) error } type service struct { @@ -64,30 +64,30 @@ func (s *service) FindByID(ctx context.Context, id int) (*domain.Notification, e return notification, err } -func (s *service) Store(ctx context.Context, notification domain.Notification) (*domain.Notification, error) { - _, err := s.repo.Store(ctx, notification) +func (s *service) Store(ctx context.Context, notification *domain.Notification) error { + err := s.repo.Store(ctx, notification) if err != nil { s.log.Error().Err(err).Msgf("could not store notification: %+v", notification) - return nil, err + return err } // register sender s.registerSender(notification) - return nil, nil + return nil } -func (s *service) Update(ctx context.Context, notification domain.Notification) (*domain.Notification, error) { - _, err := s.repo.Update(ctx, notification) +func (s *service) Update(ctx context.Context, notification *domain.Notification) error { + err := s.repo.Update(ctx, notification) if err != nil { s.log.Error().Err(err).Msgf("could not update notification: %+v", notification) - return nil, err + return err } // register sender s.registerSender(notification) - return nil, nil + return nil } func (s *service) Delete(ctx context.Context, id int) error { @@ -111,33 +111,36 @@ func (s *service) registerSenders() { } for _, notificationSender := range notificationSenders { - s.registerSender(notificationSender) + s.registerSender(¬ificationSender) } return } // registerSender registers an enabled notification via it's id -func (s *service) registerSender(notification domain.Notification) { - if notification.Enabled { - switch notification.Type { - case domain.NotificationTypeDiscord: - s.senders[notification.ID] = NewDiscordSender(s.log, notification) - case domain.NotificationTypeGotify: - s.senders[notification.ID] = NewGotifySender(s.log, notification) - case domain.NotificationTypeLunaSea: - s.senders[notification.ID] = NewLunaSeaSender(s.log, notification) - case domain.NotificationTypeNotifiarr: - s.senders[notification.ID] = NewNotifiarrSender(s.log, notification) - case domain.NotificationTypeNtfy: - s.senders[notification.ID] = NewNtfySender(s.log, notification) - case domain.NotificationTypePushover: - s.senders[notification.ID] = NewPushoverSender(s.log, notification) - case domain.NotificationTypeShoutrrr: - s.senders[notification.ID] = NewShoutrrrSender(s.log, notification) - case domain.NotificationTypeTelegram: - s.senders[notification.ID] = NewTelegramSender(s.log, notification) - } +func (s *service) registerSender(notification *domain.Notification) { + if !notification.Enabled { + delete(s.senders, notification.ID) + return + } + + switch notification.Type { + case domain.NotificationTypeDiscord: + s.senders[notification.ID] = NewDiscordSender(s.log, notification) + case domain.NotificationTypeGotify: + s.senders[notification.ID] = NewGotifySender(s.log, notification) + case domain.NotificationTypeLunaSea: + s.senders[notification.ID] = NewLunaSeaSender(s.log, notification) + case domain.NotificationTypeNotifiarr: + s.senders[notification.ID] = NewNotifiarrSender(s.log, notification) + case domain.NotificationTypeNtfy: + s.senders[notification.ID] = NewNtfySender(s.log, notification) + case domain.NotificationTypePushover: + s.senders[notification.ID] = NewPushoverSender(s.log, notification) + case domain.NotificationTypeShoutrrr: + s.senders[notification.ID] = NewShoutrrrSender(s.log, notification) + case domain.NotificationTypeTelegram: + s.senders[notification.ID] = NewTelegramSender(s.log, notification) } return @@ -163,7 +166,7 @@ func (s *service) Send(event domain.NotificationEvent, payload domain.Notificati return } -func (s *service) Test(ctx context.Context, notification domain.Notification) error { +func (s *service) Test(ctx context.Context, notification *domain.Notification) error { var agent domain.NotificationSender // send test events diff --git a/internal/notification/shoutrrr.go b/internal/notification/shoutrrr.go index 5fd247c..18cef51 100644 --- a/internal/notification/shoutrrr.go +++ b/internal/notification/shoutrrr.go @@ -9,7 +9,7 @@ import ( type shoutrrrSender struct { log zerolog.Logger - Settings domain.Notification + Settings *domain.Notification builder MessageBuilderPlainText } @@ -17,7 +17,7 @@ func (s *shoutrrrSender) Name() string { return "shoutrrr" } -func NewShoutrrrSender(log zerolog.Logger, settings domain.Notification) domain.NotificationSender { +func NewShoutrrrSender(log zerolog.Logger, settings *domain.Notification) domain.NotificationSender { return &shoutrrrSender{ log: log.With().Str("sender", "shoutrrr").Logger(), Settings: settings, diff --git a/internal/notification/telegram.go b/internal/notification/telegram.go index ebd6558..a5611a6 100644 --- a/internal/notification/telegram.go +++ b/internal/notification/telegram.go @@ -30,7 +30,7 @@ type TelegramMessage struct { type telegramSender struct { log zerolog.Logger - Settings domain.Notification + Settings *domain.Notification ThreadID int builder MessageBuilderHTML @@ -41,7 +41,7 @@ func (s *telegramSender) Name() string { return "telegram" } -func NewTelegramSender(log zerolog.Logger, settings domain.Notification) domain.NotificationSender { +func NewTelegramSender(log zerolog.Logger, settings *domain.Notification) domain.NotificationSender { threadID := 0 if t := settings.Topic; t != "" { var err error