Fix: Performance issues and sqlite locking (#74)

* fix: performance issues and sqlite locking

* fix: dashboard release stats was reversed

* refactor: open and migrate db

* chore: cleanup
This commit is contained in:
Ludvig Lundgren 2022-01-11 19:35:27 +01:00 committed by GitHub
parent d8c37dde2f
commit f466657ed4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
25 changed files with 362 additions and 658 deletions

View file

@ -4,8 +4,6 @@ before:
builds:
- id: autobrr
env:
- CGO_ENABLED=0
goos:
- linux
- windows
@ -23,8 +21,6 @@ builds:
main: ./cmd/autobrr/main.go
binary: autobrr
- id: autobrrctl
env:
- CGO_ENABLED=0
goos:
- linux
- windows

View file

@ -1,7 +1,6 @@
package main
import (
"database/sql"
"os"
"os/signal"
"syscall"
@ -10,7 +9,6 @@ import (
"github.com/r3labs/sse/v2"
"github.com/rs/zerolog/log"
"github.com/spf13/pflag"
_ "modernc.org/sqlite"
"github.com/autobrr/autobrr/internal/action"
"github.com/autobrr/autobrr/internal/auth"
@ -60,19 +58,11 @@ func main() {
log.Info().Msgf("Version: %v", version)
log.Info().Msgf("Log-level: %v", cfg.LogLevel)
// if configPath is set then put database inside that path, otherwise create wherever it's run
var dataSource = database.DataSourceName(configPath, "autobrr.db")
// open database connection
db, err := sql.Open("sqlite", dataSource)
if err != nil {
db := database.NewSqliteDB(configPath)
if err := db.Open(); err != nil {
log.Fatal().Err(err).Msg("could not open db connection")
}
defer db.Close()
if err = database.Migrate(db); err != nil {
log.Fatal().Err(err).Msg("could not migrate db")
}
// setup repos
var (
@ -125,12 +115,15 @@ func main() {
case syscall.SIGHUP:
log.Print("shutting down server sighup")
srv.Shutdown()
db.Close()
os.Exit(1)
case syscall.SIGINT, syscall.SIGQUIT:
srv.Shutdown()
db.Close()
os.Exit(1)
case syscall.SIGKILL, syscall.SIGTERM:
srv.Shutdown()
db.Close()
os.Exit(1)
}
}

View file

@ -3,14 +3,13 @@ package main
import (
"bufio"
"context"
"database/sql"
"flag"
"fmt"
"log"
"os"
_ "github.com/mattn/go-sqlite3"
"golang.org/x/crypto/ssh/terminal"
_ "modernc.org/sqlite"
"github.com/autobrr/autobrr/internal/database"
"github.com/autobrr/autobrr/internal/domain"
@ -39,18 +38,10 @@ func main() {
log.Fatal("--config required")
}
// if configPath is set then put database inside that path, otherwise create wherever it's run
var dataSource = database.DataSourceName(configPath, "autobrr.db")
// open database connection
db, err := sql.Open("sqlite", dataSource)
if err != nil {
log.Fatalf("failed to open database: %v", err)
}
defer db.Close()
if err = database.Migrate(db); err != nil {
log.Fatalf("could not migrate db: %v", err)
db := database.NewSqliteDB(configPath)
if err := db.Open(); err != nil {
log.Fatal("could not open db connection")
}
userRepo := database.NewUserRepo(db)

8
go.mod
View file

@ -4,16 +4,22 @@ go 1.16
require (
github.com/Masterminds/squirrel v1.5.1
github.com/anacrolix/torrent v1.38.0
github.com/anacrolix/dht/v2 v2.5.1 // indirect
github.com/anacrolix/missinggo v1.3.0 // indirect
github.com/anacrolix/missinggo/v2 v2.5.2 // indirect
github.com/anacrolix/torrent v1.11.0
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef
github.com/dustin/go-humanize v1.0.0
github.com/gdm85/go-libdeluge v0.5.5
github.com/go-chi/chi v1.5.4
github.com/gorilla/sessions v1.2.1
github.com/kr/pretty v0.3.0 // indirect
github.com/lib/pq v1.10.4
github.com/mattn/go-isatty v0.0.14 // indirect
github.com/mattn/go-sqlite3 v1.14.10
github.com/pkg/errors v0.9.1
github.com/r3labs/sse/v2 v2.7.2
github.com/rogpeppe/go-internal v1.8.0 // indirect
github.com/rs/cors v1.8.0
github.com/rs/zerolog v1.26.0
github.com/spf13/pflag v1.0.5

494
go.sum

File diff suppressed because it is too large Load diff

View file

@ -1,6 +1,7 @@
package action
import (
"context"
"encoding/base64"
"errors"
"io/ioutil"
@ -18,7 +19,7 @@ func (s *service) deluge(action domain.Action, torrentFile string) error {
var err error
// get client for action
client, err := s.clientSvc.FindByID(action.ClientID)
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
if err != nil {
log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID)
return err
@ -56,7 +57,7 @@ func (s *service) delugeCheckRulesCanDownload(action domain.Action) (bool, error
log.Trace().Msgf("action Deluge: %v check rules", action.Name)
// get client for action
client, err := s.clientSvc.FindByID(action.ClientID)
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
if err != nil {
log.Error().Stack().Err(err).Msgf("error finding client: %v ID %v", action.Name, action.ClientID)
return false, err

View file

@ -1,6 +1,7 @@
package action
import (
"context"
"time"
"github.com/autobrr/autobrr/internal/domain"
@ -15,7 +16,7 @@ func (s *service) lidarr(release domain.Release, action domain.Action) error {
// TODO validate data
// get client for action
client, err := s.clientSvc.FindByID(action.ClientID)
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
if err != nil {
log.Error().Err(err).Msgf("error finding client: %v", action.ClientID)
return err

View file

@ -1,6 +1,7 @@
package action
import (
"context"
"strconv"
"time"
@ -12,36 +13,9 @@ import (
const REANNOUNCE_MAX_ATTEMPTS = 30
const REANNOUNCE_INTERVAL = 7000
func (s *service) qbittorrent(action domain.Action, hash string, torrentFile string) error {
func (s *service) qbittorrent(qbt *qbittorrent.Client, action domain.Action, hash string, torrentFile string) error {
log.Debug().Msgf("action qBittorrent: %v", action.Name)
// get client for action
client, err := s.clientSvc.FindByID(action.ClientID)
if err != nil {
log.Error().Stack().Err(err).Msgf("error finding client: %v ID %v", action.Name, action.ClientID)
return err
}
if client == nil {
return err
}
qbtSettings := qbittorrent.Settings{
Hostname: client.Host,
Port: uint(client.Port),
Username: client.Username,
Password: client.Password,
SSL: client.SSL,
}
qbt := qbittorrent.NewClient(qbtSettings)
// save cookies?
err = qbt.Login()
if err != nil {
log.Error().Stack().Err(err).Msgf("error logging into client: %v %v", client.Name, client.Host)
return err
}
options := map[string]string{}
if action.Paused {
@ -66,9 +40,9 @@ func (s *service) qbittorrent(action domain.Action, hash string, torrentFile str
log.Trace().Msgf("action qBittorrent options: %+v", options)
err = qbt.AddTorrentFromFile(torrentFile, options)
err := qbt.AddTorrentFromFile(torrentFile, options)
if err != nil {
log.Error().Stack().Err(err).Msgf("could not add torrent %v to client: %v", torrentFile, client.Name)
log.Error().Stack().Err(err).Msgf("could not add torrent %v to client: %v", torrentFile, qbt.Name)
return err
}
@ -80,23 +54,23 @@ func (s *service) qbittorrent(action domain.Action, hash string, torrentFile str
}
}
log.Info().Msgf("torrent with hash %v successfully added to client: '%v'", hash, client.Name)
log.Info().Msgf("torrent with hash %v successfully added to client: '%v'", hash, qbt.Name)
return nil
}
func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool, error) {
func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool, *qbittorrent.Client, error) {
log.Trace().Msgf("action qBittorrent: %v check rules", action.Name)
// get client for action
client, err := s.clientSvc.FindByID(action.ClientID)
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
if err != nil {
log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID)
return false, err
return false, nil, err
}
if client == nil {
return false, err
return false, nil, err
}
qbtSettings := qbittorrent.Settings{
@ -108,11 +82,12 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
}
qbt := qbittorrent.NewClient(qbtSettings)
qbt.Name = client.Name
// save cookies?
err = qbt.Login()
if err != nil {
log.Error().Stack().Err(err).Msgf("error logging into client: %v", client.Host)
return false, err
return false, nil, err
}
// check for active downloads and other rules
@ -120,7 +95,7 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
activeDownloads, err := qbt.GetTorrentsFilter(qbittorrent.TorrentFilterDownloading)
if err != nil {
log.Error().Stack().Err(err).Msg("could not fetch downloading torrents")
return false, err
return false, nil, err
}
// make sure it's not set to 0 by default
@ -133,14 +108,14 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
info, err := qbt.GetTransferInfo()
if err != nil {
log.Error().Err(err).Msg("could not get transfer info")
return false, err
return false, nil, err
}
// if current transfer speed is more than threshold return out and skip
// DlInfoSpeed is in bytes so lets convert to KB to match DownloadSpeedThreshold
if info.DlInfoSpeed/1024 >= client.Settings.Rules.DownloadSpeedThreshold {
log.Debug().Msg("max active downloads reached, skipping")
return false, nil
return false, nil, nil
}
log.Debug().Msg("active downloads are slower than set limit, lets add it")
@ -149,7 +124,7 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
}
}
return true, nil
return true, qbt, nil
}
func checkTrackerStatus(qb qbittorrent.Client, hash string) error {

View file

@ -1,6 +1,7 @@
package action
import (
"context"
"time"
"github.com/autobrr/autobrr/internal/domain"
@ -15,7 +16,7 @@ func (s *service) radarr(release domain.Release, action domain.Action) error {
// TODO validate data
// get client for action
client, err := s.clientSvc.FindByID(action.ClientID)
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
if err != nil {
log.Error().Err(err).Msgf("error finding client: %v", action.ClientID)
return err

View file

@ -102,7 +102,7 @@ func (s *service) RunActions(actions []domain.Action, release domain.Release) er
}(action, tmpFile)
case domain.ActionTypeQbittorrent:
canDownload, err := s.qbittorrentCheckRulesCanDownload(action)
canDownload, client, err := s.qbittorrentCheckRulesCanDownload(action)
if err != nil {
log.Error().Stack().Err(err).Msgf("error checking client rules: %v", action.Name)
continue
@ -131,7 +131,7 @@ func (s *service) RunActions(actions []domain.Action, release domain.Release) er
}
go func(action domain.Action, hash string, tmpFile string) {
err = s.qbittorrent(action, hash, tmpFile)
err = s.qbittorrent(client, action, hash, tmpFile)
if err != nil {
log.Error().Stack().Err(err).Msg("error sending torrent to qBittorrent")
}
@ -206,7 +206,7 @@ func (s *service) CheckCanDownload(actions []domain.Action) bool {
return true
case domain.ActionTypeQbittorrent:
canDownload, err := s.qbittorrentCheckRulesCanDownload(action)
canDownload, _, err := s.qbittorrentCheckRulesCanDownload(action)
if err != nil {
log.Error().Stack().Err(err).Msgf("error checking client rules: %v", action.Name)
continue

View file

@ -1,6 +1,7 @@
package action
import (
"context"
"time"
"github.com/autobrr/autobrr/internal/domain"
@ -15,7 +16,7 @@ func (s *service) sonarr(release domain.Release, action domain.Action) error {
// TODO validate data
// get client for action
client, err := s.clientSvc.FindByID(action.ClientID)
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
if err != nil {
log.Error().Err(err).Msgf("error finding client: %v", action.ClientID)
return err

View file

@ -3,23 +3,24 @@ package database
import (
"context"
"database/sql"
"github.com/autobrr/autobrr/internal/domain"
"github.com/rs/zerolog/log"
)
type ActionRepo struct {
db *sql.DB
db *SqliteDB
}
func NewActionRepo(db *sql.DB) domain.ActionRepo {
func NewActionRepo(db *SqliteDB) domain.ActionRepo {
return &ActionRepo{db: db}
}
func (r *ActionRepo) FindByFilterID(filterID int) ([]domain.Action, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := r.db.Query("SELECT id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_download_speed, limit_upload_speed, client_id FROM action WHERE action.filter_id = ?", filterID)
rows, err := r.db.handler.Query("SELECT id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_download_speed, limit_upload_speed, client_id FROM action WHERE action.filter_id = ?", filterID)
if err != nil {
log.Fatal().Err(err)
}
@ -66,8 +67,10 @@ func (r *ActionRepo) FindByFilterID(filterID int) ([]domain.Action, error) {
}
func (r *ActionRepo) List() ([]domain.Action, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := r.db.Query("SELECT id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_download_speed, limit_upload_speed, client_id FROM action")
rows, err := r.db.handler.Query("SELECT id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_download_speed, limit_upload_speed, client_id FROM action")
if err != nil {
log.Fatal().Err(err)
}
@ -110,7 +113,10 @@ func (r *ActionRepo) List() ([]domain.Action, error) {
}
func (r *ActionRepo) Delete(actionID int) error {
res, err := r.db.Exec(`DELETE FROM action WHERE action.id = ?`, actionID)
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
res, err := r.db.handler.Exec(`DELETE FROM action WHERE action.id = ?`, actionID)
if err != nil {
return err
}
@ -123,7 +129,10 @@ func (r *ActionRepo) Delete(actionID int) error {
}
func (r *ActionRepo) DeleteByFilterID(ctx context.Context, filterID int) error {
_, err := r.db.ExecContext(ctx, `DELETE FROM action WHERE filter_id = ?`, filterID)
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
_, err := r.db.handler.ExecContext(ctx, `DELETE FROM action WHERE filter_id = ?`, filterID)
if err != nil {
log.Error().Stack().Err(err).Msg("actions: error deleting by filterid")
return err
@ -135,6 +144,8 @@ func (r *ActionRepo) DeleteByFilterID(ctx context.Context, filterID int) error {
}
func (r *ActionRepo) Store(ctx context.Context, action domain.Action) (*domain.Action, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
execCmd := toNullString(action.ExecCmd)
execArgs := toNullString(action.ExecArgs)
@ -152,12 +163,12 @@ func (r *ActionRepo) Store(ctx context.Context, action domain.Action) (*domain.A
var err error
if action.ID != 0 {
log.Debug().Msg("actions: update existing record")
_, err = r.db.ExecContext(ctx, `UPDATE action SET name = ?, type = ?, enabled = ?, exec_cmd = ?, exec_args = ?, watch_folder = ? , category =? , tags = ?, label = ?, save_path = ?, paused = ?, ignore_rules = ?, limit_upload_speed = ?, limit_download_speed = ?, client_id = ?
_, err = r.db.handler.ExecContext(ctx, `UPDATE action SET name = ?, type = ?, enabled = ?, exec_cmd = ?, exec_args = ?, watch_folder = ? , category =? , tags = ?, label = ?, save_path = ?, paused = ?, ignore_rules = ?, limit_upload_speed = ?, limit_download_speed = ?, client_id = ?
WHERE id = ?`, action.Name, action.Type, action.Enabled, execCmd, execArgs, watchFolder, category, tags, label, savePath, action.Paused, action.IgnoreRules, limitUL, limitDL, clientID, action.ID)
} else {
var res sql.Result
res, err = r.db.ExecContext(ctx, `INSERT INTO action(name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_upload_speed, limit_download_speed, client_id, filter_id)
res, err = r.db.handler.ExecContext(ctx, `INSERT INTO action(name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_upload_speed, limit_download_speed, client_id, filter_id)
VALUES (?, ?, ?, ?, ?,? ,?, ?,?,?,?,?,?,?,?,?) ON CONFLICT DO NOTHING`, action.Name, action.Type, action.Enabled, execCmd, execArgs, watchFolder, category, tags, label, savePath, action.Paused, action.IgnoreRules, limitUL, limitDL, clientID, filterID)
if err != nil {
log.Error().Err(err)
@ -173,8 +184,10 @@ func (r *ActionRepo) Store(ctx context.Context, action domain.Action) (*domain.A
}
func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []domain.Action, filterID int64) ([]domain.Action, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
tx, err := r.db.BeginTx(ctx, nil)
tx, err := r.db.handler.BeginTx(ctx, nil)
if err != nil {
return nil, err
}
@ -227,11 +240,13 @@ func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []domain.Ac
}
func (r *ActionRepo) ToggleEnabled(actionID int) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
var err error
var res sql.Result
res, err = r.db.Exec(`UPDATE action SET enabled = NOT enabled WHERE id = ?`, actionID)
res, err = r.db.handler.Exec(`UPDATE action SET enabled = NOT enabled WHERE id = ?`, actionID)
if err != nil {
log.Error().Err(err)
return err

View file

@ -1,6 +1,7 @@
package database
import (
"context"
"database/sql"
"encoding/json"
@ -10,16 +11,48 @@ import (
)
type DownloadClientRepo struct {
db *sql.DB
db *SqliteDB
cache *clientCache
}
func NewDownloadClientRepo(db *sql.DB) domain.DownloadClientRepo {
return &DownloadClientRepo{db: db}
type clientCache struct {
clients map[int]*domain.DownloadClient
}
func NewClientCache() *clientCache {
return &clientCache{
clients: make(map[int]*domain.DownloadClient, 0),
}
}
func (c *clientCache) Set(id int, client *domain.DownloadClient) {
c.clients[id] = client
}
func (c *clientCache) Get(id int) *domain.DownloadClient {
v, ok := c.clients[id]
if ok {
return v
}
return nil
}
func (c *clientCache) Pop(id int) {
delete(c.clients, id)
}
func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo {
return &DownloadClientRepo{
db: db,
cache: NewClientCache(),
}
}
func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := r.db.Query("SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client")
rows, err := r.db.handler.Query("SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client")
if err != nil {
log.Error().Stack().Err(err).Msg("could not query download client rows")
return nil, err
@ -55,13 +88,21 @@ func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) {
return clients, nil
}
func (r *DownloadClientRepo) FindByID(id int32) (*domain.DownloadClient, error) {
func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
// get client from cache
c := r.cache.Get(int(id))
if c != nil {
return c, nil
}
query := `
SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client WHERE id = ?
`
row := r.db.QueryRow(query, id)
row := r.db.handler.QueryRowContext(ctx, query, id)
if err := row.Err(); err != nil {
log.Error().Stack().Err(err).Msg("could not query download client rows")
return nil, err
@ -86,6 +127,9 @@ func (r *DownloadClientRepo) FindByID(id int32) (*domain.DownloadClient, error)
}
func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.DownloadClient, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
var err error
settings := domain.DownloadClientSettings{
@ -101,7 +145,7 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
}
if client.ID != 0 {
_, err = r.db.Exec(`
_, err = r.db.handler.Exec(`
UPDATE
client
SET
@ -134,7 +178,7 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
} else {
var res sql.Result
res, err = r.db.Exec(`INSERT INTO
res, err = r.db.handler.Exec(`INSERT INTO
client(
name,
type,
@ -170,16 +214,25 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
log.Info().Msgf("store download client: %v", client.Name)
log.Trace().Msgf("store download client: %+v", client)
// save to cache
r.cache.Set(client.ID, &client)
return &client, nil
}
func (r *DownloadClientRepo) Delete(clientID int) error {
res, err := r.db.Exec(`DELETE FROM client WHERE client.id = ?`, clientID)
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
res, err := r.db.handler.Exec(`DELETE FROM client WHERE client.id = ?`, clientID)
if err != nil {
log.Error().Stack().Err(err).Msgf("could not delete download client: %d", clientID)
return err
}
// remove from cache
r.cache.Pop(clientID)
rows, _ := res.RowsAffected()
if rows == 0 {

View file

@ -13,16 +13,18 @@ import (
)
type FilterRepo struct {
db *sql.DB
db *SqliteDB
}
func NewFilterRepo(db *sql.DB) domain.FilterRepo {
func NewFilterRepo(db *SqliteDB) domain.FilterRepo {
return &FilterRepo{db: db}
}
func (r *FilterRepo) ListFilters() ([]domain.Filter, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := r.db.Query("SELECT id, enabled, name, match_releases, except_releases, created_at, updated_at FROM filter")
rows, err := r.db.handler.Query("SELECT id, enabled, name, match_releases, except_releases, created_at, updated_at FROM filter")
if err != nil {
log.Fatal().Err(err)
}
@ -62,8 +64,10 @@ func (r *FilterRepo) ListFilters() ([]domain.Filter, error) {
}
func (r *FilterRepo) FindByID(filterID int) (*domain.Filter, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
row := r.db.QueryRow("SELECT id, enabled, name, min_size, max_size, delay, match_releases, except_releases, use_regex, match_release_groups, except_release_groups, scene, freeleech, freeleech_percent, shows, seasons, episodes, resolutions, codecs, sources, containers, years, match_categories, except_categories, match_uploaders, except_uploaders, tags, except_tags, created_at, updated_at FROM filter WHERE id = ?", filterID)
row := r.db.handler.QueryRow("SELECT id, enabled, name, min_size, max_size, delay, match_releases, except_releases, use_regex, match_release_groups, except_release_groups, scene, freeleech, freeleech_percent, shows, seasons, episodes, resolutions, codecs, sources, containers, years, match_categories, except_categories, match_uploaders, except_uploaders, tags, except_tags, created_at, updated_at FROM filter WHERE id = ?", filterID)
var f domain.Filter
@ -114,8 +118,10 @@ func (r *FilterRepo) FindByID(filterID int) (*domain.Filter, error) {
// TODO remove
func (r *FilterRepo) FindFiltersForSite(site string) ([]domain.Filter, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := r.db.Query("SELECT id, enabled, name, match_releases, except_releases, created_at, updated_at FROM filter WHERE match_sites LIKE ?", site)
rows, err := r.db.handler.Query("SELECT id, enabled, name, match_releases, except_releases, created_at, updated_at FROM filter WHERE match_sites LIKE ?", site)
if err != nil {
log.Fatal().Err(err)
}
@ -144,8 +150,10 @@ func (r *FilterRepo) FindFiltersForSite(site string) ([]domain.Filter, error) {
// FindByIndexerIdentifier find active filters only
func (r *FilterRepo) FindByIndexerIdentifier(indexer string) ([]domain.Filter, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := r.db.Query(`
rows, err := r.db.handler.Query(`
SELECT
f.id,
f.enabled,
@ -241,6 +249,8 @@ func (r *FilterRepo) FindByIndexerIdentifier(indexer string) ([]domain.Filter, e
}
func (r *FilterRepo) Store(filter domain.Filter) (*domain.Filter, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
var err error
if filter.ID != 0 {
@ -248,7 +258,7 @@ func (r *FilterRepo) Store(filter domain.Filter) (*domain.Filter, error) {
} else {
var res sql.Result
res, err = r.db.Exec(`INSERT INTO filter (
res, err = r.db.handler.Exec(`INSERT INTO filter (
name,
enabled,
min_size,
@ -319,11 +329,13 @@ func (r *FilterRepo) Store(filter domain.Filter) (*domain.Filter, error) {
}
func (r *FilterRepo) Update(ctx context.Context, filter domain.Filter) (*domain.Filter, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
//var res sql.Result
var err error
_, err = r.db.ExecContext(ctx, `
_, err = r.db.handler.ExecContext(ctx, `
UPDATE filter SET
name = ?,
enabled = ?,
@ -392,9 +404,11 @@ func (r *FilterRepo) Update(ctx context.Context, filter domain.Filter) (*domain.
}
func (r *FilterRepo) ToggleEnabled(ctx context.Context, filterID int, enabled bool) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
var err error
_, err = r.db.ExecContext(ctx, `
_, err = r.db.handler.ExecContext(ctx, `
UPDATE filter SET
enabled = ?,
updated_at = CURRENT_TIMESTAMP
@ -411,8 +425,10 @@ func (r *FilterRepo) ToggleEnabled(ctx context.Context, filterID int, enabled bo
}
func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int, indexers []domain.Indexer) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
tx, err := r.db.BeginTx(ctx, nil)
tx, err := r.db.handler.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -447,8 +463,11 @@ func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int,
}
func (r *FilterRepo) StoreIndexerConnection(ctx context.Context, filterID int, indexerID int) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
query := `INSERT INTO filter_indexer (filter_id, indexer_id) VALUES ($1, $2)`
_, err := r.db.ExecContext(ctx, query, filterID, indexerID)
_, err := r.db.handler.ExecContext(ctx, query, filterID, indexerID)
if err != nil {
log.Error().Stack().Err(err).Msg("error executing query")
return err
@ -458,9 +477,11 @@ func (r *FilterRepo) StoreIndexerConnection(ctx context.Context, filterID int, i
}
func (r *FilterRepo) DeleteIndexerConnections(ctx context.Context, filterID int) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
query := `DELETE FROM filter_indexer WHERE filter_id = ?`
_, err := r.db.ExecContext(ctx, query, filterID)
_, err := r.db.handler.ExecContext(ctx, query, filterID)
if err != nil {
log.Error().Stack().Err(err).Msg("error executing query")
return err
@ -470,8 +491,10 @@ func (r *FilterRepo) DeleteIndexerConnections(ctx context.Context, filterID int)
}
func (r *FilterRepo) Delete(ctx context.Context, filterID int) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
_, err := r.db.ExecContext(ctx, `DELETE FROM filter WHERE id = ?`, filterID)
_, err := r.db.handler.ExecContext(ctx, `DELETE FROM filter WHERE id = ?`, filterID)
if err != nil {
log.Error().Stack().Err(err).Msg("error executing query")
return err

View file

@ -2,23 +2,24 @@ package database
import (
"context"
"database/sql"
"encoding/json"
"github.com/autobrr/autobrr/internal/domain"
"github.com/rs/zerolog/log"
)
type IndexerRepo struct {
db *sql.DB
db *SqliteDB
}
func NewIndexerRepo(db *sql.DB) domain.IndexerRepo {
func NewIndexerRepo(db *SqliteDB) domain.IndexerRepo {
return &IndexerRepo{
db: db,
}
}
func (r *IndexerRepo) Store(indexer domain.Indexer) (*domain.Indexer, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
settings, err := json.Marshal(indexer.Settings)
if err != nil {
@ -26,7 +27,7 @@ func (r *IndexerRepo) Store(indexer domain.Indexer) (*domain.Indexer, error) {
return nil, err
}
res, err := r.db.Exec(`INSERT INTO indexer (enabled, name, identifier, settings) VALUES (?, ?, ?, ?)`, indexer.Enabled, indexer.Name, indexer.Identifier, settings)
res, err := r.db.handler.Exec(`INSERT INTO indexer (enabled, name, identifier, settings) VALUES (?, ?, ?, ?)`, indexer.Enabled, indexer.Name, indexer.Identifier, settings)
if err != nil {
log.Error().Stack().Err(err).Msg("error executing query")
return nil, err
@ -39,6 +40,8 @@ func (r *IndexerRepo) Store(indexer domain.Indexer) (*domain.Indexer, error) {
}
func (r *IndexerRepo) Update(indexer domain.Indexer) (*domain.Indexer, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
sett, err := json.Marshal(indexer.Settings)
if err != nil {
@ -46,7 +49,7 @@ func (r *IndexerRepo) Update(indexer domain.Indexer) (*domain.Indexer, error) {
return nil, err
}
_, err = r.db.Exec(`UPDATE indexer SET enabled = ?, name = ?, settings = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, indexer.Enabled, indexer.Name, sett, indexer.ID)
_, err = r.db.handler.Exec(`UPDATE indexer SET enabled = ?, name = ?, settings = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, indexer.Enabled, indexer.Name, sett, indexer.ID)
if err != nil {
log.Error().Stack().Err(err).Msg("error executing query")
return nil, err
@ -56,8 +59,10 @@ func (r *IndexerRepo) Update(indexer domain.Indexer) (*domain.Indexer, error) {
}
func (r *IndexerRepo) List() ([]domain.Indexer, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := r.db.Query("SELECT id, enabled, name, identifier, settings FROM indexer ORDER BY name ASC")
rows, err := r.db.handler.Query("SELECT id, enabled, name, identifier, settings FROM indexer ORDER BY name ASC")
if err != nil {
log.Fatal().Err(err)
}
@ -96,7 +101,10 @@ func (r *IndexerRepo) List() ([]domain.Indexer, error) {
}
func (r *IndexerRepo) FindByFilterID(id int) ([]domain.Indexer, error) {
rows, err := r.db.Query(`
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := r.db.handler.Query(`
SELECT i.id, i.enabled, i.name, i.identifier
FROM indexer i
JOIN filter_indexer fi on i.id = fi.indexer_id
@ -140,9 +148,12 @@ func (r *IndexerRepo) FindByFilterID(id int) ([]domain.Indexer, error) {
}
func (r *IndexerRepo) Delete(ctx context.Context, id int) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
query := `DELETE FROM indexer WHERE id = ?`
_, err := r.db.ExecContext(ctx, query, id)
_, err := r.db.handler.ExecContext(ctx, query, id)
if err != nil {
log.Error().Stack().Err(err).Msgf("indexer.delete: error executing query: '%v'", query)
return err

View file

@ -12,16 +12,18 @@ import (
)
type IrcRepo struct {
db *sql.DB
db *SqliteDB
}
func NewIrcRepo(db *sql.DB) domain.IrcRepo {
func NewIrcRepo(db *SqliteDB) domain.IrcRepo {
return &IrcRepo{db: db}
}
func (ir *IrcRepo) GetNetworkByID(id int64) (*domain.IrcNetwork, error) {
func (r *IrcRepo) GetNetworkByID(id int64) (*domain.IrcNetwork, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
row := ir.db.QueryRow("SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network WHERE id = ?", id)
row := r.db.handler.QueryRow("SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network WHERE id = ?", id)
if err := row.Err(); err != nil {
log.Fatal().Err(err)
return nil, err
@ -46,8 +48,8 @@ func (ir *IrcRepo) GetNetworkByID(id int64) (*domain.IrcNetwork, error) {
return &n, nil
}
func (ir *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
tx, err := ir.db.BeginTx(ctx, nil)
func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
tx, err := r.db.handler.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -76,9 +78,11 @@ func (ir *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
return nil
}
func (ir *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := ir.db.QueryContext(ctx, "SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network WHERE enabled = true")
rows, err := r.db.handler.QueryContext(ctx, "SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network WHERE enabled = true")
if err != nil {
log.Fatal().Err(err)
}
@ -109,9 +113,11 @@ func (ir *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork,
return networks, nil
}
func (ir *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := ir.db.QueryContext(ctx, "SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network ORDER BY name ASC")
rows, err := r.db.handler.QueryContext(ctx, "SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network ORDER BY name ASC")
if err != nil {
log.Fatal().Err(err)
}
@ -142,9 +148,11 @@ func (ir *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error
return networks, nil
}
func (ir *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
rows, err := ir.db.Query("SELECT id, name, enabled FROM irc_channel WHERE network_id = ?", networkID)
rows, err := r.db.handler.Query("SELECT id, name, enabled FROM irc_channel WHERE network_id = ?", networkID)
if err != nil {
log.Fatal().Err(err)
}
@ -167,7 +175,9 @@ func (ir *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
return channels, nil
}
func (ir *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcNetwork) (*domain.IrcNetwork, error) {
func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcNetwork) (*domain.IrcNetwork, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
queryBuilder := sq.
Select("id", "enabled", "name", "server", "port", "tls", "pass", "invite_command", "nickserv_account", "nickserv_password").
@ -182,7 +192,7 @@ func (ir *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.Irc
}
log.Trace().Str("database", "irc.check_existing_network").Msgf("query: '%v', args: '%v'", query, args)
row := ir.db.QueryRowContext(ctx, query, args...)
row := r.db.handler.QueryRowContext(ctx, query, args...)
var net domain.IrcNetwork
@ -206,7 +216,9 @@ func (ir *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.Irc
return &net, nil
}
func (ir *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error {
func (r *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
netName := toNullString(network.Name)
pass := toNullString(network.Pass)
@ -218,7 +230,7 @@ func (ir *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error {
var err error
if network.ID != 0 {
// update record
_, err = ir.db.Exec(`UPDATE irc_network
_, err = r.db.handler.Exec(`UPDATE irc_network
SET enabled = ?,
name = ?,
server = ?,
@ -248,7 +260,7 @@ func (ir *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error {
} else {
var res sql.Result
res, err = ir.db.Exec(`INSERT INTO irc_network (
res, err = r.db.handler.Exec(`INSERT INTO irc_network (
enabled,
name,
server,
@ -280,7 +292,9 @@ func (ir *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error {
return err
}
func (ir *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error {
func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
netName := toNullString(network.Name)
pass := toNullString(network.Pass)
@ -291,7 +305,7 @@ func (ir *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork
var err error
// update record
_, err = ir.db.ExecContext(ctx, `UPDATE irc_network
_, err = r.db.handler.ExecContext(ctx, `UPDATE irc_network
SET enabled = ?,
name = ?,
server = ?,
@ -324,9 +338,11 @@ func (ir *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork
// TODO create new channel handler to only add, not delete
func (ir *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, channels []domain.IrcChannel) error {
func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, channels []domain.IrcChannel) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
tx, err := ir.db.BeginTx(ctx, nil)
tx, err := r.db.handler.BeginTx(ctx, nil)
if err != nil {
return err
}
@ -373,13 +389,16 @@ func (ir *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, ch
return nil
}
func (ir *IrcRepo) StoreChannel(networkID int64, channel *domain.IrcChannel) error {
func (r *IrcRepo) StoreChannel(networkID int64, channel *domain.IrcChannel) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
pass := toNullString(channel.Password)
var err error
if channel.ID != 0 {
// update record
_, err = ir.db.Exec(`UPDATE irc_channel
_, err = r.db.handler.Exec(`UPDATE irc_channel
SET
enabled = ?,
detached = ?,
@ -396,7 +415,7 @@ func (ir *IrcRepo) StoreChannel(networkID int64, channel *domain.IrcChannel) err
} else {
var res sql.Result
res, err = ir.db.Exec(`INSERT INTO irc_channel (
res, err = r.db.handler.Exec(`INSERT INTO irc_channel (
enabled,
detached,
name,

View file

@ -1,7 +1,6 @@
package database
import (
"database/sql"
"fmt"
)
@ -283,9 +282,12 @@ var migrations = []string{
`,
}
func Migrate(db *sql.DB) error {
func (db *SqliteDB) migrate() error {
db.lock.Lock()
defer db.lock.Unlock()
var version int
if err := db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
if err := db.handler.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
return fmt.Errorf("failed to query schema version: %v", err)
}
@ -295,7 +297,7 @@ func Migrate(db *sql.DB) error {
return fmt.Errorf("autobrr (version %d) older than schema (version: %d)", len(migrations), version)
}
tx, err := db.Begin()
tx, err := db.handler.Begin()
if err != nil {
return err
}

View file

@ -13,21 +13,24 @@ import (
)
type ReleaseRepo struct {
db *sql.DB
db *SqliteDB
}
func NewReleaseRepo(db *sql.DB) domain.ReleaseRepo {
func NewReleaseRepo(db *SqliteDB) domain.ReleaseRepo {
return &ReleaseRepo{db: db}
}
func (repo *ReleaseRepo) Store(ctx context.Context, r *domain.Release) (*domain.Release, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
query, args, err := sq.
Insert("release").
Columns("filter_status", "rejections", "indexer", "filter", "protocol", "implementation", "timestamp", "group_id", "torrent_id", "torrent_name", "size", "raw", "title", "category", "season", "episode", "year", "resolution", "source", "codec", "container", "hdr", "audio", "release_group", "region", "language", "edition", "unrated", "hybrid", "proper", "repack", "website", "artists", "type", "format", "bitrate", "log_score", "has_log", "has_cue", "is_scene", "origin", "tags", "freeleech", "freeleech_percent", "uploader", "pre_time").
Values(r.FilterStatus, pq.Array(r.Rejections), r.Indexer, r.FilterName, r.Protocol, r.Implementation, r.Timestamp, r.GroupID, r.TorrentID, r.TorrentName, r.Size, r.Raw, r.Title, r.Category, r.Season, r.Episode, r.Year, r.Resolution, r.Source, r.Codec, r.Container, r.HDR, r.Audio, r.Group, r.Region, r.Language, r.Edition, r.Unrated, r.Hybrid, r.Proper, r.Repack, r.Website, pq.Array(r.Artists), r.Type, r.Format, r.Bitrate, r.LogScore, r.HasLog, r.HasCue, r.IsScene, r.Origin, pq.Array(r.Tags), r.Freeleech, r.FreeleechPercent, r.Uploader, r.PreTime).
ToSql()
res, err := repo.db.ExecContext(ctx, query, args...)
res, err := repo.db.handler.ExecContext(ctx, query, args...)
if err != nil {
log.Error().Stack().Err(err).Msg("error inserting release")
return nil, err
@ -42,6 +45,8 @@ func (repo *ReleaseRepo) Store(ctx context.Context, r *domain.Release) (*domain.
}
func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain.ReleaseActionStatus) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
if a.ID != 0 {
query, args, err := sq.
@ -53,7 +58,7 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain
Where("release_id = ?", a.ReleaseID).
ToSql()
_, err = repo.db.ExecContext(ctx, query, args...)
_, err = repo.db.handler.ExecContext(ctx, query, args...)
if err != nil {
log.Error().Stack().Err(err).Msg("error updating status of release")
return err
@ -66,7 +71,7 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain
Values(a.Status, a.Action, a.Type, pq.Array(a.Rejections), a.Timestamp, a.ReleaseID).
ToSql()
res, err := repo.db.ExecContext(ctx, query, args...)
res, err := repo.db.handler.ExecContext(ctx, query, args...)
if err != nil {
log.Error().Stack().Err(err).Msg("error inserting status of release")
return err
@ -82,6 +87,8 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain
}
func (repo *ReleaseRepo) Find(ctx context.Context, params domain.QueryParams) ([]domain.Release, int64, int64, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
queryBuilder := sq.
Select("id", "filter_status", "rejections", "indexer", "filter", "protocol", "title", "torrent_name", "size", "timestamp", "COUNT() OVER() AS total_count").
@ -116,7 +123,7 @@ func (repo *ReleaseRepo) Find(ctx context.Context, params domain.QueryParams) ([
res := make([]domain.Release, 0)
rows, err := repo.db.QueryContext(ctx, query, args...)
rows, err := repo.db.handler.QueryContext(ctx, query, args...)
if err != nil {
log.Error().Stack().Err(err).Msg("error fetching releases")
return res, 0, 0, nil
@ -171,6 +178,8 @@ func (repo *ReleaseRepo) Find(ctx context.Context, params domain.QueryParams) ([
}
func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, releaseID int64) ([]domain.ReleaseActionStatus, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
queryBuilder := sq.
Select("id", "status", "action", "type", "rejections", "timestamp").
@ -181,7 +190,7 @@ func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, release
res := make([]domain.ReleaseActionStatus, 0)
rows, err := repo.db.QueryContext(ctx, query, args...)
rows, err := repo.db.handler.QueryContext(ctx, query, args...)
if err != nil {
log.Error().Stack().Err(err).Msg("error fetching releases")
return res, nil
@ -214,6 +223,9 @@ func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, release
}
func (repo *ReleaseRepo) Stats(ctx context.Context) (*domain.ReleaseStats, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
query := `SELECT COUNT(*) total,
IFNULL(SUM(CASE WHEN filter_status = 'FILTER_APPROVED' THEN 1 ELSE 0 END), 0) filtered_count,
IFNULL(SUM(CASE WHEN filter_status = 'FILTER_REJECTED' THEN 1 ELSE 0 END), 0) filter_rejected_count,
@ -223,7 +235,7 @@ func (repo *ReleaseRepo) Stats(ctx context.Context) (*domain.ReleaseStats, error
FROM "release_action_status") AS push_rejected_count
FROM "release";`
row := repo.db.QueryRowContext(ctx, query)
row := repo.db.handler.QueryRowContext(ctx, query)
if err := row.Err(); err != nil {
log.Error().Stack().Err(err).Msg("release.stats: error querying stats")
return nil, err
@ -231,7 +243,7 @@ FROM "release";`
var rls domain.ReleaseStats
if err := row.Scan(&rls.TotalCount, &rls.PushApprovedCount, &rls.PushRejectedCount, &rls.FilteredCount, &rls.FilterRejectedCount); err != nil {
if err := row.Scan(&rls.TotalCount, &rls.FilteredCount, &rls.FilterRejectedCount, &rls.PushApprovedCount, &rls.PushRejectedCount); err != nil {
log.Error().Stack().Err(err).Msg("release.stats: error scanning stats data to struct")
return nil, err
}

View file

@ -0,0 +1,81 @@
package database
import (
"context"
"database/sql"
"fmt"
"sync"
_ "github.com/mattn/go-sqlite3"
"github.com/rs/zerolog/log"
)
type SqliteDB struct {
lock sync.RWMutex
handler *sql.DB
ctx context.Context
cancel func()
DSN string
}
func NewSqliteDB(source string) *SqliteDB {
db := &SqliteDB{
DSN: dataSourceName(source, "autobrr.db"),
}
db.ctx, db.cancel = context.WithCancel(context.Background())
return db
}
func (db *SqliteDB) Open() error {
if db.DSN == "" {
return fmt.Errorf("DSN required")
}
var err error
// open database connection
if db.handler, err = sql.Open("sqlite3", db.DSN); err != nil {
log.Fatal().Err(err).Msg("could not open db connection")
return err
}
// Set busy timeout
if _, err = db.handler.Exec(`PRAGMA busy_timeout = 5000;`); err != nil {
return fmt.Errorf("busy timeout pragma")
}
// Enable WAL. SQLite performs better with the WAL because it allows
// multiple readers to operate while data is being written.
if _, err = db.handler.Exec(`PRAGMA journal_mode = wal;`); err != nil {
return fmt.Errorf("enable wal: %w", err)
}
// Enable foreign key checks. For historical reasons, SQLite does not check
// foreign key constraints by default. There's some overhead on inserts to
// verify foreign key integrity, but it's definitely worth it.
if _, err = db.handler.Exec(`PRAGMA foreign_keys = ON;`); err != nil {
return fmt.Errorf("foreign keys pragma: %w", err)
}
// migrate db
if err = db.migrate(); err != nil {
log.Fatal().Err(err).Msg("could not migrate db")
return err
}
return nil
}
func (db *SqliteDB) Close() error {
// cancel background context
db.cancel()
// close database
if db.handler != nil {
return db.handler.Close()
}
return nil
}

View file

@ -2,25 +2,26 @@ package database
import (
"context"
"database/sql"
"github.com/rs/zerolog/log"
"github.com/autobrr/autobrr/internal/domain"
)
type UserRepo struct {
db *sql.DB
db *SqliteDB
}
func NewUserRepo(db *sql.DB) domain.UserRepo {
func NewUserRepo(db *SqliteDB) domain.UserRepo {
return &UserRepo{db: db}
}
func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain.User, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
query := `SELECT id, username, password FROM users WHERE username = ?`
row := r.db.QueryRowContext(ctx, query, username)
row := r.db.handler.QueryRowContext(ctx, query, username)
if err := row.Err(); err != nil {
return nil, err
}
@ -36,11 +37,13 @@ func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain
}
func (r *UserRepo) Store(ctx context.Context, user domain.User) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
var err error
if user.ID != 0 {
update := `UPDATE users SET password = ? WHERE username = ?`
_, err = r.db.ExecContext(ctx, update, user.Password, user.Username)
_, err = r.db.handler.ExecContext(ctx, update, user.Password, user.Username)
if err != nil {
log.Error().Stack().Err(err).Msg("error executing query")
return err
@ -48,7 +51,7 @@ func (r *UserRepo) Store(ctx context.Context, user domain.User) error {
} else {
query := `INSERT INTO users (username, password) VALUES (?, ?)`
_, err = r.db.ExecContext(ctx, query, user.Username, user.Password)
_, err = r.db.handler.ExecContext(ctx, query, user.Username, user.Password)
if err != nil {
log.Error().Stack().Err(err).Msg("error executing query")
return err

View file

@ -4,7 +4,7 @@ import (
"path"
)
func DataSourceName(configPath string, name string) string {
func dataSourceName(configPath string, name string) string {
if configPath != "" {
return path.Join(configPath, name)
}

View file

@ -51,7 +51,7 @@ func TestDataSourceName(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := DataSourceName(tt.args.configPath, tt.args.name)
got := dataSourceName(tt.args.configPath, tt.args.name)
assert.Equal(t, tt.want, got)
})
}

View file

@ -1,9 +1,11 @@
package domain
import "context"
type DownloadClientRepo interface {
//FindByActionID(actionID int) ([]DownloadClient, error)
List() ([]DownloadClient, error)
FindByID(id int32) (*DownloadClient, error)
FindByID(ctx context.Context, id int32) (*DownloadClient, error)
Store(client DownloadClient) (*DownloadClient, error)
Delete(clientID int) error
}

View file

@ -1,6 +1,7 @@
package download_client
import (
"context"
"errors"
"github.com/autobrr/autobrr/internal/domain"
@ -10,7 +11,7 @@ import (
type Service interface {
List() ([]domain.DownloadClient, error)
FindByID(id int32) (*domain.DownloadClient, error)
FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error)
Store(client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(clientID int) error
Test(client domain.DownloadClient) error
@ -33,8 +34,8 @@ func (s *service) List() ([]domain.DownloadClient, error) {
return clients, nil
}
func (s *service) FindByID(id int32) (*domain.DownloadClient, error) {
client, err := s.repo.FindByID(id)
func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
client, err := s.repo.FindByID(ctx, id)
if err != nil {
return nil, err
}

View file

@ -17,6 +17,7 @@ import (
)
type Client struct {
Name string
settings Settings
http *http.Client
}