diff --git a/internal/api/service.go b/internal/api/service.go index 66a8480..5911c1f 100644 --- a/internal/api/service.go +++ b/internal/api/service.go @@ -10,6 +10,7 @@ import ( "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" + "github.com/autobrr/autobrr/pkg/errors" "github.com/rs/zerolog" ) @@ -17,7 +18,6 @@ import ( type Service interface { List(ctx context.Context) ([]domain.APIKey, error) Store(ctx context.Context, key *domain.APIKey) error - Update(ctx context.Context, key *domain.APIKey) error Delete(ctx context.Context, key string) error ValidateAPIKey(ctx context.Context, token string) bool } @@ -26,63 +26,75 @@ type service struct { log zerolog.Logger repo domain.APIRepo - keyCache []domain.APIKey + keyCache map[string]domain.APIKey } func NewService(log logger.Logger, repo domain.APIRepo) Service { return &service{ log: log.With().Str("module", "api").Logger(), repo: repo, - keyCache: []domain.APIKey{}, + keyCache: map[string]domain.APIKey{}, } } func (s *service) List(ctx context.Context) ([]domain.APIKey, error) { if len(s.keyCache) > 0 { - return s.keyCache, nil + keys := make([]domain.APIKey, 0, len(s.keyCache)) + + for _, key := range s.keyCache { + keys = append(keys, key) + } + + return keys, nil } - return s.repo.GetKeys(ctx) + return s.repo.GetAllAPIKeys(ctx) } -func (s *service) Store(ctx context.Context, key *domain.APIKey) error { - key.Key = GenerateSecureToken(16) +func (s *service) Store(ctx context.Context, apiKey *domain.APIKey) error { + apiKey.Key = GenerateSecureToken(16) - if err := s.repo.Store(ctx, key); err != nil { + if err := s.repo.Store(ctx, apiKey); err != nil { return err } if len(s.keyCache) > 0 { - // set new key - s.keyCache = append(s.keyCache, *key) + // set new apiKey + s.keyCache[apiKey.Key] = *apiKey } return nil } -func (s *service) Update(ctx context.Context, key *domain.APIKey) error { - return nil -} - func (s *service) Delete(ctx context.Context, key string) error { - // reset - s.keyCache = []domain.APIKey{} + err := s.repo.Delete(ctx, key) + if err != nil { + return errors.Wrap(err, "could not delete api key: %s", key) + } - return s.repo.Delete(ctx, key) + // remove key from cache + delete(s.keyCache, key) + + return nil } func (s *service) ValidateAPIKey(ctx context.Context, key string) bool { - keys, err := s.repo.GetKeys(ctx) + if _, ok := s.keyCache[key]; ok { + s.log.Trace().Msgf("api service key cache hit: %s", key) + return true + } + + apiKey, err := s.repo.GetKey(ctx, key) if err != nil { + s.log.Trace().Msgf("api service key cache invalid key: %s", key) return false } - for _, k := range keys { - if k.Key == key { - return true - } - } - return false + s.log.Trace().Msgf("api service key cache miss: %s", key) + + s.keyCache[key] = *apiKey + + return true } func GenerateSecureToken(length int) string { diff --git a/internal/database/api.go b/internal/database/api.go index c2569ce..256eaaa 100644 --- a/internal/database/api.go +++ b/internal/database/api.go @@ -25,9 +25,8 @@ func NewAPIRepo(log logger.Logger, db *DB) domain.APIRepo { } type APIRepo struct { - log zerolog.Logger - db *DB - cache map[string]domain.APIKey + log zerolog.Logger + db *DB } func (r *APIRepo) Store(ctx context.Context, key *domain.APIKey) error { @@ -57,9 +56,7 @@ func (r *APIRepo) Store(ctx context.Context, key *domain.APIKey) error { } func (r *APIRepo) Delete(ctx context.Context, key string) error { - queryBuilder := r.db.squirrel. - Delete("api_key"). - Where(sq.Eq{"key": key}) + queryBuilder := r.db.squirrel.Delete("api_key").Where(sq.Eq{"key": key}) query, args, err := queryBuilder.ToSql() if err != nil { @@ -76,14 +73,9 @@ func (r *APIRepo) Delete(ctx context.Context, key string) error { return nil } -func (r *APIRepo) GetKeys(ctx context.Context) ([]domain.APIKey, error) { +func (r *APIRepo) GetAllAPIKeys(ctx context.Context) ([]domain.APIKey, error) { queryBuilder := r.db.squirrel. - Select( - "name", - "key", - "scopes", - "created_at", - ). + Select("name", "key", "scopes", "created_at"). From("api_key") query, args, err := queryBuilder.ToSql() @@ -116,3 +108,35 @@ func (r *APIRepo) GetKeys(ctx context.Context) ([]domain.APIKey, error) { return keys, nil } + +func (r *APIRepo) GetKey(ctx context.Context, key string) (*domain.APIKey, error) { + queryBuilder := r.db.squirrel. + Select("name", "key", "scopes", "created_at"). + From("api_key"). + Where(sq.Eq{"key": key}) + + query, args, err := queryBuilder.ToSql() + if err != nil { + return nil, errors.Wrap(err, "error building query") + } + + row := r.db.handler.QueryRowContext(ctx, query, args...) + if err := row.Err(); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, domain.ErrRecordNotFound + } + return nil, errors.Wrap(err, "error executing query") + } + + var apiKey domain.APIKey + + var name sql.NullString + + if err := row.Scan(&name, &apiKey.Key, pq.Array(&apiKey.Scopes), &apiKey.CreatedAt); err != nil { + return nil, errors.Wrap(err, "error scanning row") + } + + apiKey.Name = name.String + + return &apiKey, nil +} diff --git a/internal/database/api_test.go b/internal/database/api_test.go index 9c240db..a6f9ae0 100644 --- a/internal/database/api_test.go +++ b/internal/database/api_test.go @@ -69,7 +69,7 @@ func TestAPIRepo_Delete(t *testing.T) { } -func TestAPIRepo_GetKeys(t *testing.T) { +func TestAPIRepo_GetAllAPIKeys(t *testing.T) { for dbType, db := range testDBs { log := setupLoggerForTest() repo := NewAPIRepo(log, db) @@ -77,7 +77,7 @@ func TestAPIRepo_GetKeys(t *testing.T) { t.Run(fmt.Sprintf("GetKeys_Returns_Keys_If_Exists [%s]", dbType), func(t *testing.T) { key := &domain.APIKey{Name: "TestKey", Key: "123", Scopes: []string{"read", "write"}} _ = repo.Store(context.Background(), key) - keys, err := repo.GetKeys(context.Background()) + keys, err := repo.GetAllAPIKeys(context.Background()) assert.NoError(t, err) assert.Greater(t, len(keys), 0) // Cleanup @@ -85,9 +85,32 @@ func TestAPIRepo_GetKeys(t *testing.T) { }) t.Run(fmt.Sprintf("GetKeys_Returns_Empty_If_No_Keys [%s]", dbType), func(t *testing.T) { - keys, err := repo.GetKeys(context.Background()) + keys, err := repo.GetAllAPIKeys(context.Background()) assert.NoError(t, err) assert.Equal(t, 0, len(keys)) }) } } + +func TestAPIRepo_GetKey(t *testing.T) { + for dbType, db := range testDBs { + log := setupLoggerForTest() + repo := NewAPIRepo(log, db) + + t.Run(fmt.Sprintf("GetKey_Returns_Key_If_Exists [%s]", dbType), func(t *testing.T) { + key := &domain.APIKey{Name: "TestKey", Key: "123", Scopes: []string{"read", "write"}} + _ = repo.Store(context.Background(), key) + apiKey, err := repo.GetKey(context.Background(), key.Key) + assert.NoError(t, err) + assert.NotNil(t, apiKey) + // Cleanup + _ = repo.Delete(context.Background(), key.Key) + }) + + t.Run(fmt.Sprintf("GetKeys_Returns_Empty_If_No_Keys [%s]", dbType), func(t *testing.T) { + key, err := repo.GetKey(context.Background(), "nonexistent") + assert.ErrorIs(t, err, domain.ErrRecordNotFound) + assert.Nil(t, key) + }) + } +} diff --git a/internal/domain/api.go b/internal/domain/api.go index d9a488c..808a207 100644 --- a/internal/domain/api.go +++ b/internal/domain/api.go @@ -11,7 +11,8 @@ import ( type APIRepo interface { Store(ctx context.Context, key *APIKey) error Delete(ctx context.Context, key string) error - GetKeys(ctx context.Context) ([]APIKey, error) + GetAllAPIKeys(ctx context.Context) ([]APIKey, error) + GetKey(ctx context.Context, key string) (*APIKey, error) } type APIKey struct { diff --git a/internal/http/apikey.go b/internal/http/apikey.go index 049fdb0..6fef6e3 100644 --- a/internal/http/apikey.go +++ b/internal/http/apikey.go @@ -17,7 +17,6 @@ import ( type apikeyService interface { List(ctx context.Context) ([]domain.APIKey, error) Store(ctx context.Context, key *domain.APIKey) error - Update(ctx context.Context, key *domain.APIKey) error Delete(ctx context.Context, key string) error ValidateAPIKey(ctx context.Context, token string) bool } @@ -51,18 +50,14 @@ func (h apikeyHandler) list(w http.ResponseWriter, r *http.Request) { } func (h apikeyHandler) store(w http.ResponseWriter, r *http.Request) { - - var ( - ctx = r.Context() - data domain.APIKey - ) + var data domain.APIKey if err := json.NewDecoder(r.Body).Decode(&data); err != nil { h.encoder.Error(w, err) return } - if err := h.service.Store(ctx, &data); err != nil { + if err := h.service.Store(r.Context(), &data); err != nil { h.encoder.Error(w, err) return } diff --git a/internal/http/encoder.go b/internal/http/encoder.go index 697ebaf..87ff439 100644 --- a/internal/http/encoder.go +++ b/internal/http/encoder.go @@ -24,6 +24,7 @@ func (e encoder) StatusResponse(w http.ResponseWriter, status int, response inte if response != nil { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(response); err != nil { w.WriteHeader(http.StatusInternalServerError) return @@ -37,6 +38,7 @@ func (e encoder) StatusResponseMessage(w http.ResponseWriter, status int, messag if message != "" { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(statusResponse{Message: message}); err != nil { w.WriteHeader(http.StatusInternalServerError) return @@ -53,6 +55,7 @@ func (e encoder) StatusCreated(w http.ResponseWriter) { func (e encoder) StatusCreatedData(w http.ResponseWriter, data interface{}) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(http.StatusCreated) + if err := json.NewEncoder(w).Encode(data); err != nil { w.WriteHeader(http.StatusInternalServerError) return @@ -74,7 +77,11 @@ func (e encoder) NotFoundErr(w http.ResponseWriter, err error) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(http.StatusNotFound) - json.NewEncoder(w).Encode(res) + + if err := json.NewEncoder(w).Encode(res); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } } func (e encoder) StatusInternalError(w http.ResponseWriter) { @@ -88,7 +95,11 @@ func (e encoder) Error(w http.ResponseWriter, err error) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(http.StatusInternalServerError) - json.NewEncoder(w).Encode(res) + + if err := json.NewEncoder(w).Encode(res); err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } } func (e encoder) StatusError(w http.ResponseWriter, status int, err error) { @@ -98,6 +109,7 @@ func (e encoder) StatusError(w http.ResponseWriter, status int, err error) { w.Header().Set("Content-Type", "application/json; charset=utf-8") w.WriteHeader(status) + if err := json.NewEncoder(w).Encode(res); err != nil { w.WriteHeader(http.StatusInternalServerError) return