mirror of
https://github.com/idanoo/autobrr
synced 2025-07-22 16:29:12 +00:00
Fix: Performance issues and sqlite locking (#74)
* fix: performance issues and sqlite locking * fix: dashboard release stats was reversed * refactor: open and migrate db * chore: cleanup
This commit is contained in:
parent
d8c37dde2f
commit
f466657ed4
25 changed files with 362 additions and 658 deletions
|
@ -4,8 +4,6 @@ before:
|
|||
|
||||
builds:
|
||||
- id: autobrr
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
- windows
|
||||
|
@ -23,8 +21,6 @@ builds:
|
|||
main: ./cmd/autobrr/main.go
|
||||
binary: autobrr
|
||||
- id: autobrrctl
|
||||
env:
|
||||
- CGO_ENABLED=0
|
||||
goos:
|
||||
- linux
|
||||
- windows
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
@ -10,7 +9,6 @@ import (
|
|||
"github.com/r3labs/sse/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/spf13/pflag"
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/action"
|
||||
"github.com/autobrr/autobrr/internal/auth"
|
||||
|
@ -60,19 +58,11 @@ func main() {
|
|||
log.Info().Msgf("Version: %v", version)
|
||||
log.Info().Msgf("Log-level: %v", cfg.LogLevel)
|
||||
|
||||
// if configPath is set then put database inside that path, otherwise create wherever it's run
|
||||
var dataSource = database.DataSourceName(configPath, "autobrr.db")
|
||||
|
||||
// open database connection
|
||||
db, err := sql.Open("sqlite", dataSource)
|
||||
if err != nil {
|
||||
db := database.NewSqliteDB(configPath)
|
||||
if err := db.Open(); err != nil {
|
||||
log.Fatal().Err(err).Msg("could not open db connection")
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err = database.Migrate(db); err != nil {
|
||||
log.Fatal().Err(err).Msg("could not migrate db")
|
||||
}
|
||||
|
||||
// setup repos
|
||||
var (
|
||||
|
@ -125,12 +115,15 @@ func main() {
|
|||
case syscall.SIGHUP:
|
||||
log.Print("shutting down server sighup")
|
||||
srv.Shutdown()
|
||||
db.Close()
|
||||
os.Exit(1)
|
||||
case syscall.SIGINT, syscall.SIGQUIT:
|
||||
srv.Shutdown()
|
||||
db.Close()
|
||||
os.Exit(1)
|
||||
case syscall.SIGKILL, syscall.SIGTERM:
|
||||
srv.Shutdown()
|
||||
db.Close()
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,14 +3,13 @@ package main
|
|||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"database/sql"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"golang.org/x/crypto/ssh/terminal"
|
||||
_ "modernc.org/sqlite"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/database"
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
|
@ -39,18 +38,10 @@ func main() {
|
|||
log.Fatal("--config required")
|
||||
}
|
||||
|
||||
// if configPath is set then put database inside that path, otherwise create wherever it's run
|
||||
var dataSource = database.DataSourceName(configPath, "autobrr.db")
|
||||
|
||||
// open database connection
|
||||
db, err := sql.Open("sqlite", dataSource)
|
||||
if err != nil {
|
||||
log.Fatalf("failed to open database: %v", err)
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
if err = database.Migrate(db); err != nil {
|
||||
log.Fatalf("could not migrate db: %v", err)
|
||||
db := database.NewSqliteDB(configPath)
|
||||
if err := db.Open(); err != nil {
|
||||
log.Fatal("could not open db connection")
|
||||
}
|
||||
|
||||
userRepo := database.NewUserRepo(db)
|
||||
|
|
8
go.mod
8
go.mod
|
@ -4,16 +4,22 @@ go 1.16
|
|||
|
||||
require (
|
||||
github.com/Masterminds/squirrel v1.5.1
|
||||
github.com/anacrolix/torrent v1.38.0
|
||||
github.com/anacrolix/dht/v2 v2.5.1 // indirect
|
||||
github.com/anacrolix/missinggo v1.3.0 // indirect
|
||||
github.com/anacrolix/missinggo/v2 v2.5.2 // indirect
|
||||
github.com/anacrolix/torrent v1.11.0
|
||||
github.com/asaskevich/EventBus v0.0.0-20200907212545-49d423059eef
|
||||
github.com/dustin/go-humanize v1.0.0
|
||||
github.com/gdm85/go-libdeluge v0.5.5
|
||||
github.com/go-chi/chi v1.5.4
|
||||
github.com/gorilla/sessions v1.2.1
|
||||
github.com/kr/pretty v0.3.0 // indirect
|
||||
github.com/lib/pq v1.10.4
|
||||
github.com/mattn/go-isatty v0.0.14 // indirect
|
||||
github.com/mattn/go-sqlite3 v1.14.10
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/r3labs/sse/v2 v2.7.2
|
||||
github.com/rogpeppe/go-internal v1.8.0 // indirect
|
||||
github.com/rs/cors v1.8.0
|
||||
github.com/rs/zerolog v1.26.0
|
||||
github.com/spf13/pflag v1.0.5
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package action
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
|
@ -18,7 +19,7 @@ func (s *service) deluge(action domain.Action, torrentFile string) error {
|
|||
var err error
|
||||
|
||||
// get client for action
|
||||
client, err := s.clientSvc.FindByID(action.ClientID)
|
||||
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID)
|
||||
return err
|
||||
|
@ -56,7 +57,7 @@ func (s *service) delugeCheckRulesCanDownload(action domain.Action) (bool, error
|
|||
log.Trace().Msgf("action Deluge: %v check rules", action.Name)
|
||||
|
||||
// get client for action
|
||||
client, err := s.clientSvc.FindByID(action.ClientID)
|
||||
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("error finding client: %v ID %v", action.Name, action.ClientID)
|
||||
return false, err
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package action
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
|
@ -15,7 +16,7 @@ func (s *service) lidarr(release domain.Release, action domain.Action) error {
|
|||
// TODO validate data
|
||||
|
||||
// get client for action
|
||||
client, err := s.clientSvc.FindByID(action.ClientID)
|
||||
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("error finding client: %v", action.ClientID)
|
||||
return err
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package action
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
|
@ -12,36 +13,9 @@ import (
|
|||
const REANNOUNCE_MAX_ATTEMPTS = 30
|
||||
const REANNOUNCE_INTERVAL = 7000
|
||||
|
||||
func (s *service) qbittorrent(action domain.Action, hash string, torrentFile string) error {
|
||||
func (s *service) qbittorrent(qbt *qbittorrent.Client, action domain.Action, hash string, torrentFile string) error {
|
||||
log.Debug().Msgf("action qBittorrent: %v", action.Name)
|
||||
|
||||
// get client for action
|
||||
client, err := s.clientSvc.FindByID(action.ClientID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("error finding client: %v ID %v", action.Name, action.ClientID)
|
||||
return err
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return err
|
||||
}
|
||||
|
||||
qbtSettings := qbittorrent.Settings{
|
||||
Hostname: client.Host,
|
||||
Port: uint(client.Port),
|
||||
Username: client.Username,
|
||||
Password: client.Password,
|
||||
SSL: client.SSL,
|
||||
}
|
||||
|
||||
qbt := qbittorrent.NewClient(qbtSettings)
|
||||
// save cookies?
|
||||
err = qbt.Login()
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("error logging into client: %v %v", client.Name, client.Host)
|
||||
return err
|
||||
}
|
||||
|
||||
options := map[string]string{}
|
||||
|
||||
if action.Paused {
|
||||
|
@ -66,9 +40,9 @@ func (s *service) qbittorrent(action domain.Action, hash string, torrentFile str
|
|||
|
||||
log.Trace().Msgf("action qBittorrent options: %+v", options)
|
||||
|
||||
err = qbt.AddTorrentFromFile(torrentFile, options)
|
||||
err := qbt.AddTorrentFromFile(torrentFile, options)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("could not add torrent %v to client: %v", torrentFile, client.Name)
|
||||
log.Error().Stack().Err(err).Msgf("could not add torrent %v to client: %v", torrentFile, qbt.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -80,23 +54,23 @@ func (s *service) qbittorrent(action domain.Action, hash string, torrentFile str
|
|||
}
|
||||
}
|
||||
|
||||
log.Info().Msgf("torrent with hash %v successfully added to client: '%v'", hash, client.Name)
|
||||
log.Info().Msgf("torrent with hash %v successfully added to client: '%v'", hash, qbt.Name)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool, error) {
|
||||
func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool, *qbittorrent.Client, error) {
|
||||
log.Trace().Msgf("action qBittorrent: %v check rules", action.Name)
|
||||
|
||||
// get client for action
|
||||
client, err := s.clientSvc.FindByID(action.ClientID)
|
||||
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("error finding client: %v", action.ClientID)
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
if client == nil {
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
qbtSettings := qbittorrent.Settings{
|
||||
|
@ -108,11 +82,12 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
|
|||
}
|
||||
|
||||
qbt := qbittorrent.NewClient(qbtSettings)
|
||||
qbt.Name = client.Name
|
||||
// save cookies?
|
||||
err = qbt.Login()
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("error logging into client: %v", client.Host)
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
// check for active downloads and other rules
|
||||
|
@ -120,7 +95,7 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
|
|||
activeDownloads, err := qbt.GetTorrentsFilter(qbittorrent.TorrentFilterDownloading)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("could not fetch downloading torrents")
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
// make sure it's not set to 0 by default
|
||||
|
@ -133,14 +108,14 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
|
|||
info, err := qbt.GetTransferInfo()
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("could not get transfer info")
|
||||
return false, err
|
||||
return false, nil, err
|
||||
}
|
||||
|
||||
// if current transfer speed is more than threshold return out and skip
|
||||
// DlInfoSpeed is in bytes so lets convert to KB to match DownloadSpeedThreshold
|
||||
if info.DlInfoSpeed/1024 >= client.Settings.Rules.DownloadSpeedThreshold {
|
||||
log.Debug().Msg("max active downloads reached, skipping")
|
||||
return false, nil
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
log.Debug().Msg("active downloads are slower than set limit, lets add it")
|
||||
|
@ -149,7 +124,7 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
|
|||
}
|
||||
}
|
||||
|
||||
return true, nil
|
||||
return true, qbt, nil
|
||||
}
|
||||
|
||||
func checkTrackerStatus(qb qbittorrent.Client, hash string) error {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package action
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
|
@ -15,7 +16,7 @@ func (s *service) radarr(release domain.Release, action domain.Action) error {
|
|||
// TODO validate data
|
||||
|
||||
// get client for action
|
||||
client, err := s.clientSvc.FindByID(action.ClientID)
|
||||
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("error finding client: %v", action.ClientID)
|
||||
return err
|
||||
|
|
|
@ -102,7 +102,7 @@ func (s *service) RunActions(actions []domain.Action, release domain.Release) er
|
|||
}(action, tmpFile)
|
||||
|
||||
case domain.ActionTypeQbittorrent:
|
||||
canDownload, err := s.qbittorrentCheckRulesCanDownload(action)
|
||||
canDownload, client, err := s.qbittorrentCheckRulesCanDownload(action)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("error checking client rules: %v", action.Name)
|
||||
continue
|
||||
|
@ -131,7 +131,7 @@ func (s *service) RunActions(actions []domain.Action, release domain.Release) er
|
|||
}
|
||||
|
||||
go func(action domain.Action, hash string, tmpFile string) {
|
||||
err = s.qbittorrent(action, hash, tmpFile)
|
||||
err = s.qbittorrent(client, action, hash, tmpFile)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error sending torrent to qBittorrent")
|
||||
}
|
||||
|
@ -206,7 +206,7 @@ func (s *service) CheckCanDownload(actions []domain.Action) bool {
|
|||
return true
|
||||
|
||||
case domain.ActionTypeQbittorrent:
|
||||
canDownload, err := s.qbittorrentCheckRulesCanDownload(action)
|
||||
canDownload, _, err := s.qbittorrentCheckRulesCanDownload(action)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("error checking client rules: %v", action.Name)
|
||||
continue
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package action
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
|
@ -15,7 +16,7 @@ func (s *service) sonarr(release domain.Release, action domain.Action) error {
|
|||
// TODO validate data
|
||||
|
||||
// get client for action
|
||||
client, err := s.clientSvc.FindByID(action.ClientID)
|
||||
client, err := s.clientSvc.FindByID(context.TODO(), action.ClientID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msgf("error finding client: %v", action.ClientID)
|
||||
return err
|
||||
|
|
|
@ -3,23 +3,24 @@ package database
|
|||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ActionRepo struct {
|
||||
db *sql.DB
|
||||
db *SqliteDB
|
||||
}
|
||||
|
||||
func NewActionRepo(db *sql.DB) domain.ActionRepo {
|
||||
func NewActionRepo(db *SqliteDB) domain.ActionRepo {
|
||||
return &ActionRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *ActionRepo) FindByFilterID(filterID int) ([]domain.Action, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := r.db.Query("SELECT id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_download_speed, limit_upload_speed, client_id FROM action WHERE action.filter_id = ?", filterID)
|
||||
rows, err := r.db.handler.Query("SELECT id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_download_speed, limit_upload_speed, client_id FROM action WHERE action.filter_id = ?", filterID)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
|
@ -66,8 +67,10 @@ func (r *ActionRepo) FindByFilterID(filterID int) ([]domain.Action, error) {
|
|||
}
|
||||
|
||||
func (r *ActionRepo) List() ([]domain.Action, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := r.db.Query("SELECT id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_download_speed, limit_upload_speed, client_id FROM action")
|
||||
rows, err := r.db.handler.Query("SELECT id, name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_download_speed, limit_upload_speed, client_id FROM action")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
|
@ -110,7 +113,10 @@ func (r *ActionRepo) List() ([]domain.Action, error) {
|
|||
}
|
||||
|
||||
func (r *ActionRepo) Delete(actionID int) error {
|
||||
res, err := r.db.Exec(`DELETE FROM action WHERE action.id = ?`, actionID)
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
res, err := r.db.handler.Exec(`DELETE FROM action WHERE action.id = ?`, actionID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -123,7 +129,10 @@ func (r *ActionRepo) Delete(actionID int) error {
|
|||
}
|
||||
|
||||
func (r *ActionRepo) DeleteByFilterID(ctx context.Context, filterID int) error {
|
||||
_, err := r.db.ExecContext(ctx, `DELETE FROM action WHERE filter_id = ?`, filterID)
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
_, err := r.db.handler.ExecContext(ctx, `DELETE FROM action WHERE filter_id = ?`, filterID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("actions: error deleting by filterid")
|
||||
return err
|
||||
|
@ -135,6 +144,8 @@ func (r *ActionRepo) DeleteByFilterID(ctx context.Context, filterID int) error {
|
|||
}
|
||||
|
||||
func (r *ActionRepo) Store(ctx context.Context, action domain.Action) (*domain.Action, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
execCmd := toNullString(action.ExecCmd)
|
||||
execArgs := toNullString(action.ExecArgs)
|
||||
|
@ -152,12 +163,12 @@ func (r *ActionRepo) Store(ctx context.Context, action domain.Action) (*domain.A
|
|||
var err error
|
||||
if action.ID != 0 {
|
||||
log.Debug().Msg("actions: update existing record")
|
||||
_, err = r.db.ExecContext(ctx, `UPDATE action SET name = ?, type = ?, enabled = ?, exec_cmd = ?, exec_args = ?, watch_folder = ? , category =? , tags = ?, label = ?, save_path = ?, paused = ?, ignore_rules = ?, limit_upload_speed = ?, limit_download_speed = ?, client_id = ?
|
||||
_, err = r.db.handler.ExecContext(ctx, `UPDATE action SET name = ?, type = ?, enabled = ?, exec_cmd = ?, exec_args = ?, watch_folder = ? , category =? , tags = ?, label = ?, save_path = ?, paused = ?, ignore_rules = ?, limit_upload_speed = ?, limit_download_speed = ?, client_id = ?
|
||||
WHERE id = ?`, action.Name, action.Type, action.Enabled, execCmd, execArgs, watchFolder, category, tags, label, savePath, action.Paused, action.IgnoreRules, limitUL, limitDL, clientID, action.ID)
|
||||
} else {
|
||||
var res sql.Result
|
||||
|
||||
res, err = r.db.ExecContext(ctx, `INSERT INTO action(name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_upload_speed, limit_download_speed, client_id, filter_id)
|
||||
res, err = r.db.handler.ExecContext(ctx, `INSERT INTO action(name, type, enabled, exec_cmd, exec_args, watch_folder, category, tags, label, save_path, paused, ignore_rules, limit_upload_speed, limit_download_speed, client_id, filter_id)
|
||||
VALUES (?, ?, ?, ?, ?,? ,?, ?,?,?,?,?,?,?,?,?) ON CONFLICT DO NOTHING`, action.Name, action.Type, action.Enabled, execCmd, execArgs, watchFolder, category, tags, label, savePath, action.Paused, action.IgnoreRules, limitUL, limitDL, clientID, filterID)
|
||||
if err != nil {
|
||||
log.Error().Err(err)
|
||||
|
@ -173,8 +184,10 @@ func (r *ActionRepo) Store(ctx context.Context, action domain.Action) (*domain.A
|
|||
}
|
||||
|
||||
func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []domain.Action, filterID int64) ([]domain.Action, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
tx, err := r.db.handler.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -227,11 +240,13 @@ func (r *ActionRepo) StoreFilterActions(ctx context.Context, actions []domain.Ac
|
|||
}
|
||||
|
||||
func (r *ActionRepo) ToggleEnabled(actionID int) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
var err error
|
||||
var res sql.Result
|
||||
|
||||
res, err = r.db.Exec(`UPDATE action SET enabled = NOT enabled WHERE id = ?`, actionID)
|
||||
res, err = r.db.handler.Exec(`UPDATE action SET enabled = NOT enabled WHERE id = ?`, actionID)
|
||||
if err != nil {
|
||||
log.Error().Err(err)
|
||||
return err
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
|
||||
|
@ -10,16 +11,48 @@ import (
|
|||
)
|
||||
|
||||
type DownloadClientRepo struct {
|
||||
db *sql.DB
|
||||
db *SqliteDB
|
||||
cache *clientCache
|
||||
}
|
||||
|
||||
func NewDownloadClientRepo(db *sql.DB) domain.DownloadClientRepo {
|
||||
return &DownloadClientRepo{db: db}
|
||||
type clientCache struct {
|
||||
clients map[int]*domain.DownloadClient
|
||||
}
|
||||
|
||||
func NewClientCache() *clientCache {
|
||||
return &clientCache{
|
||||
clients: make(map[int]*domain.DownloadClient, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (c *clientCache) Set(id int, client *domain.DownloadClient) {
|
||||
c.clients[id] = client
|
||||
}
|
||||
|
||||
func (c *clientCache) Get(id int) *domain.DownloadClient {
|
||||
v, ok := c.clients[id]
|
||||
if ok {
|
||||
return v
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *clientCache) Pop(id int) {
|
||||
delete(c.clients, id)
|
||||
}
|
||||
|
||||
func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo {
|
||||
return &DownloadClientRepo{
|
||||
db: db,
|
||||
cache: NewClientCache(),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := r.db.Query("SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client")
|
||||
rows, err := r.db.handler.Query("SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client")
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("could not query download client rows")
|
||||
return nil, err
|
||||
|
@ -55,13 +88,21 @@ func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) {
|
|||
return clients, nil
|
||||
}
|
||||
|
||||
func (r *DownloadClientRepo) FindByID(id int32) (*domain.DownloadClient, error) {
|
||||
func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
// get client from cache
|
||||
c := r.cache.Get(int(id))
|
||||
if c != nil {
|
||||
return c, nil
|
||||
}
|
||||
|
||||
query := `
|
||||
SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client WHERE id = ?
|
||||
`
|
||||
|
||||
row := r.db.QueryRow(query, id)
|
||||
row := r.db.handler.QueryRowContext(ctx, query, id)
|
||||
if err := row.Err(); err != nil {
|
||||
log.Error().Stack().Err(err).Msg("could not query download client rows")
|
||||
return nil, err
|
||||
|
@ -86,6 +127,9 @@ func (r *DownloadClientRepo) FindByID(id int32) (*domain.DownloadClient, error)
|
|||
}
|
||||
|
||||
func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.DownloadClient, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
var err error
|
||||
|
||||
settings := domain.DownloadClientSettings{
|
||||
|
@ -101,7 +145,7 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
|
|||
}
|
||||
|
||||
if client.ID != 0 {
|
||||
_, err = r.db.Exec(`
|
||||
_, err = r.db.handler.Exec(`
|
||||
UPDATE
|
||||
client
|
||||
SET
|
||||
|
@ -134,7 +178,7 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
|
|||
} else {
|
||||
var res sql.Result
|
||||
|
||||
res, err = r.db.Exec(`INSERT INTO
|
||||
res, err = r.db.handler.Exec(`INSERT INTO
|
||||
client(
|
||||
name,
|
||||
type,
|
||||
|
@ -170,16 +214,25 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
|
|||
log.Info().Msgf("store download client: %v", client.Name)
|
||||
log.Trace().Msgf("store download client: %+v", client)
|
||||
|
||||
// save to cache
|
||||
r.cache.Set(client.ID, &client)
|
||||
|
||||
return &client, nil
|
||||
}
|
||||
|
||||
func (r *DownloadClientRepo) Delete(clientID int) error {
|
||||
res, err := r.db.Exec(`DELETE FROM client WHERE client.id = ?`, clientID)
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
res, err := r.db.handler.Exec(`DELETE FROM client WHERE client.id = ?`, clientID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("could not delete download client: %d", clientID)
|
||||
return err
|
||||
}
|
||||
|
||||
// remove from cache
|
||||
r.cache.Pop(clientID)
|
||||
|
||||
rows, _ := res.RowsAffected()
|
||||
|
||||
if rows == 0 {
|
||||
|
|
|
@ -13,16 +13,18 @@ import (
|
|||
)
|
||||
|
||||
type FilterRepo struct {
|
||||
db *sql.DB
|
||||
db *SqliteDB
|
||||
}
|
||||
|
||||
func NewFilterRepo(db *sql.DB) domain.FilterRepo {
|
||||
func NewFilterRepo(db *SqliteDB) domain.FilterRepo {
|
||||
return &FilterRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *FilterRepo) ListFilters() ([]domain.Filter, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := r.db.Query("SELECT id, enabled, name, match_releases, except_releases, created_at, updated_at FROM filter")
|
||||
rows, err := r.db.handler.Query("SELECT id, enabled, name, match_releases, except_releases, created_at, updated_at FROM filter")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
|
@ -62,8 +64,10 @@ func (r *FilterRepo) ListFilters() ([]domain.Filter, error) {
|
|||
}
|
||||
|
||||
func (r *FilterRepo) FindByID(filterID int) (*domain.Filter, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
row := r.db.QueryRow("SELECT id, enabled, name, min_size, max_size, delay, match_releases, except_releases, use_regex, match_release_groups, except_release_groups, scene, freeleech, freeleech_percent, shows, seasons, episodes, resolutions, codecs, sources, containers, years, match_categories, except_categories, match_uploaders, except_uploaders, tags, except_tags, created_at, updated_at FROM filter WHERE id = ?", filterID)
|
||||
row := r.db.handler.QueryRow("SELECT id, enabled, name, min_size, max_size, delay, match_releases, except_releases, use_regex, match_release_groups, except_release_groups, scene, freeleech, freeleech_percent, shows, seasons, episodes, resolutions, codecs, sources, containers, years, match_categories, except_categories, match_uploaders, except_uploaders, tags, except_tags, created_at, updated_at FROM filter WHERE id = ?", filterID)
|
||||
|
||||
var f domain.Filter
|
||||
|
||||
|
@ -114,8 +118,10 @@ func (r *FilterRepo) FindByID(filterID int) (*domain.Filter, error) {
|
|||
|
||||
// TODO remove
|
||||
func (r *FilterRepo) FindFiltersForSite(site string) ([]domain.Filter, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := r.db.Query("SELECT id, enabled, name, match_releases, except_releases, created_at, updated_at FROM filter WHERE match_sites LIKE ?", site)
|
||||
rows, err := r.db.handler.Query("SELECT id, enabled, name, match_releases, except_releases, created_at, updated_at FROM filter WHERE match_sites LIKE ?", site)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
|
@ -144,8 +150,10 @@ func (r *FilterRepo) FindFiltersForSite(site string) ([]domain.Filter, error) {
|
|||
|
||||
// FindByIndexerIdentifier find active filters only
|
||||
func (r *FilterRepo) FindByIndexerIdentifier(indexer string) ([]domain.Filter, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := r.db.Query(`
|
||||
rows, err := r.db.handler.Query(`
|
||||
SELECT
|
||||
f.id,
|
||||
f.enabled,
|
||||
|
@ -241,6 +249,8 @@ func (r *FilterRepo) FindByIndexerIdentifier(indexer string) ([]domain.Filter, e
|
|||
}
|
||||
|
||||
func (r *FilterRepo) Store(filter domain.Filter) (*domain.Filter, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
var err error
|
||||
if filter.ID != 0 {
|
||||
|
@ -248,7 +258,7 @@ func (r *FilterRepo) Store(filter domain.Filter) (*domain.Filter, error) {
|
|||
} else {
|
||||
var res sql.Result
|
||||
|
||||
res, err = r.db.Exec(`INSERT INTO filter (
|
||||
res, err = r.db.handler.Exec(`INSERT INTO filter (
|
||||
name,
|
||||
enabled,
|
||||
min_size,
|
||||
|
@ -319,11 +329,13 @@ func (r *FilterRepo) Store(filter domain.Filter) (*domain.Filter, error) {
|
|||
}
|
||||
|
||||
func (r *FilterRepo) Update(ctx context.Context, filter domain.Filter) (*domain.Filter, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
//var res sql.Result
|
||||
|
||||
var err error
|
||||
_, err = r.db.ExecContext(ctx, `
|
||||
_, err = r.db.handler.ExecContext(ctx, `
|
||||
UPDATE filter SET
|
||||
name = ?,
|
||||
enabled = ?,
|
||||
|
@ -392,9 +404,11 @@ func (r *FilterRepo) Update(ctx context.Context, filter domain.Filter) (*domain.
|
|||
}
|
||||
|
||||
func (r *FilterRepo) ToggleEnabled(ctx context.Context, filterID int, enabled bool) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
var err error
|
||||
_, err = r.db.ExecContext(ctx, `
|
||||
_, err = r.db.handler.ExecContext(ctx, `
|
||||
UPDATE filter SET
|
||||
enabled = ?,
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
|
@ -411,8 +425,10 @@ func (r *FilterRepo) ToggleEnabled(ctx context.Context, filterID int, enabled bo
|
|||
}
|
||||
|
||||
func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int, indexers []domain.Indexer) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
tx, err := r.db.BeginTx(ctx, nil)
|
||||
tx, err := r.db.handler.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -447,8 +463,11 @@ func (r *FilterRepo) StoreIndexerConnections(ctx context.Context, filterID int,
|
|||
}
|
||||
|
||||
func (r *FilterRepo) StoreIndexerConnection(ctx context.Context, filterID int, indexerID int) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
query := `INSERT INTO filter_indexer (filter_id, indexer_id) VALUES ($1, $2)`
|
||||
_, err := r.db.ExecContext(ctx, query, filterID, indexerID)
|
||||
_, err := r.db.handler.ExecContext(ctx, query, filterID, indexerID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error executing query")
|
||||
return err
|
||||
|
@ -458,9 +477,11 @@ func (r *FilterRepo) StoreIndexerConnection(ctx context.Context, filterID int, i
|
|||
}
|
||||
|
||||
func (r *FilterRepo) DeleteIndexerConnections(ctx context.Context, filterID int) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
query := `DELETE FROM filter_indexer WHERE filter_id = ?`
|
||||
_, err := r.db.ExecContext(ctx, query, filterID)
|
||||
_, err := r.db.handler.ExecContext(ctx, query, filterID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error executing query")
|
||||
return err
|
||||
|
@ -470,8 +491,10 @@ func (r *FilterRepo) DeleteIndexerConnections(ctx context.Context, filterID int)
|
|||
}
|
||||
|
||||
func (r *FilterRepo) Delete(ctx context.Context, filterID int) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
_, err := r.db.ExecContext(ctx, `DELETE FROM filter WHERE id = ?`, filterID)
|
||||
_, err := r.db.handler.ExecContext(ctx, `DELETE FROM filter WHERE id = ?`, filterID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error executing query")
|
||||
return err
|
||||
|
|
|
@ -2,23 +2,24 @@ package database
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type IndexerRepo struct {
|
||||
db *sql.DB
|
||||
db *SqliteDB
|
||||
}
|
||||
|
||||
func NewIndexerRepo(db *sql.DB) domain.IndexerRepo {
|
||||
func NewIndexerRepo(db *SqliteDB) domain.IndexerRepo {
|
||||
return &IndexerRepo{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *IndexerRepo) Store(indexer domain.Indexer) (*domain.Indexer, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
settings, err := json.Marshal(indexer.Settings)
|
||||
if err != nil {
|
||||
|
@ -26,7 +27,7 @@ func (r *IndexerRepo) Store(indexer domain.Indexer) (*domain.Indexer, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
res, err := r.db.Exec(`INSERT INTO indexer (enabled, name, identifier, settings) VALUES (?, ?, ?, ?)`, indexer.Enabled, indexer.Name, indexer.Identifier, settings)
|
||||
res, err := r.db.handler.Exec(`INSERT INTO indexer (enabled, name, identifier, settings) VALUES (?, ?, ?, ?)`, indexer.Enabled, indexer.Name, indexer.Identifier, settings)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error executing query")
|
||||
return nil, err
|
||||
|
@ -39,6 +40,8 @@ func (r *IndexerRepo) Store(indexer domain.Indexer) (*domain.Indexer, error) {
|
|||
}
|
||||
|
||||
func (r *IndexerRepo) Update(indexer domain.Indexer) (*domain.Indexer, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
sett, err := json.Marshal(indexer.Settings)
|
||||
if err != nil {
|
||||
|
@ -46,7 +49,7 @@ func (r *IndexerRepo) Update(indexer domain.Indexer) (*domain.Indexer, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
_, err = r.db.Exec(`UPDATE indexer SET enabled = ?, name = ?, settings = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, indexer.Enabled, indexer.Name, sett, indexer.ID)
|
||||
_, err = r.db.handler.Exec(`UPDATE indexer SET enabled = ?, name = ?, settings = ?, updated_at = CURRENT_TIMESTAMP WHERE id = ?`, indexer.Enabled, indexer.Name, sett, indexer.ID)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error executing query")
|
||||
return nil, err
|
||||
|
@ -56,8 +59,10 @@ func (r *IndexerRepo) Update(indexer domain.Indexer) (*domain.Indexer, error) {
|
|||
}
|
||||
|
||||
func (r *IndexerRepo) List() ([]domain.Indexer, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := r.db.Query("SELECT id, enabled, name, identifier, settings FROM indexer ORDER BY name ASC")
|
||||
rows, err := r.db.handler.Query("SELECT id, enabled, name, identifier, settings FROM indexer ORDER BY name ASC")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
|
@ -96,7 +101,10 @@ func (r *IndexerRepo) List() ([]domain.Indexer, error) {
|
|||
}
|
||||
|
||||
func (r *IndexerRepo) FindByFilterID(id int) ([]domain.Indexer, error) {
|
||||
rows, err := r.db.Query(`
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := r.db.handler.Query(`
|
||||
SELECT i.id, i.enabled, i.name, i.identifier
|
||||
FROM indexer i
|
||||
JOIN filter_indexer fi on i.id = fi.indexer_id
|
||||
|
@ -140,9 +148,12 @@ func (r *IndexerRepo) FindByFilterID(id int) ([]domain.Indexer, error) {
|
|||
}
|
||||
|
||||
func (r *IndexerRepo) Delete(ctx context.Context, id int) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
query := `DELETE FROM indexer WHERE id = ?`
|
||||
|
||||
_, err := r.db.ExecContext(ctx, query, id)
|
||||
_, err := r.db.handler.ExecContext(ctx, query, id)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msgf("indexer.delete: error executing query: '%v'", query)
|
||||
return err
|
||||
|
|
|
@ -12,16 +12,18 @@ import (
|
|||
)
|
||||
|
||||
type IrcRepo struct {
|
||||
db *sql.DB
|
||||
db *SqliteDB
|
||||
}
|
||||
|
||||
func NewIrcRepo(db *sql.DB) domain.IrcRepo {
|
||||
func NewIrcRepo(db *SqliteDB) domain.IrcRepo {
|
||||
return &IrcRepo{db: db}
|
||||
}
|
||||
|
||||
func (ir *IrcRepo) GetNetworkByID(id int64) (*domain.IrcNetwork, error) {
|
||||
func (r *IrcRepo) GetNetworkByID(id int64) (*domain.IrcNetwork, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
row := ir.db.QueryRow("SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network WHERE id = ?", id)
|
||||
row := r.db.handler.QueryRow("SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network WHERE id = ?", id)
|
||||
if err := row.Err(); err != nil {
|
||||
log.Fatal().Err(err)
|
||||
return nil, err
|
||||
|
@ -46,8 +48,8 @@ func (ir *IrcRepo) GetNetworkByID(id int64) (*domain.IrcNetwork, error) {
|
|||
return &n, nil
|
||||
}
|
||||
|
||||
func (ir *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
|
||||
tx, err := ir.db.BeginTx(ctx, nil)
|
||||
func (r *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
|
||||
tx, err := r.db.handler.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -76,9 +78,11 @@ func (ir *IrcRepo) DeleteNetwork(ctx context.Context, id int64) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ir *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
|
||||
func (r *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := ir.db.QueryContext(ctx, "SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network WHERE enabled = true")
|
||||
rows, err := r.db.handler.QueryContext(ctx, "SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network WHERE enabled = true")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
|
@ -109,9 +113,11 @@ func (ir *IrcRepo) FindActiveNetworks(ctx context.Context) ([]domain.IrcNetwork,
|
|||
return networks, nil
|
||||
}
|
||||
|
||||
func (ir *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
|
||||
func (r *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := ir.db.QueryContext(ctx, "SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network ORDER BY name ASC")
|
||||
rows, err := r.db.handler.QueryContext(ctx, "SELECT id, enabled, name, server, port, tls, pass, invite_command, nickserv_account, nickserv_password FROM irc_network ORDER BY name ASC")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
|
@ -142,9 +148,11 @@ func (ir *IrcRepo) ListNetworks(ctx context.Context) ([]domain.IrcNetwork, error
|
|||
return networks, nil
|
||||
}
|
||||
|
||||
func (ir *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
|
||||
func (r *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
rows, err := ir.db.Query("SELECT id, name, enabled FROM irc_channel WHERE network_id = ?", networkID)
|
||||
rows, err := r.db.handler.Query("SELECT id, name, enabled FROM irc_channel WHERE network_id = ?", networkID)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err)
|
||||
}
|
||||
|
@ -167,7 +175,9 @@ func (ir *IrcRepo) ListChannels(networkID int64) ([]domain.IrcChannel, error) {
|
|||
return channels, nil
|
||||
}
|
||||
|
||||
func (ir *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcNetwork) (*domain.IrcNetwork, error) {
|
||||
func (r *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.IrcNetwork) (*domain.IrcNetwork, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
queryBuilder := sq.
|
||||
Select("id", "enabled", "name", "server", "port", "tls", "pass", "invite_command", "nickserv_account", "nickserv_password").
|
||||
|
@ -182,7 +192,7 @@ func (ir *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.Irc
|
|||
}
|
||||
log.Trace().Str("database", "irc.check_existing_network").Msgf("query: '%v', args: '%v'", query, args)
|
||||
|
||||
row := ir.db.QueryRowContext(ctx, query, args...)
|
||||
row := r.db.handler.QueryRowContext(ctx, query, args...)
|
||||
|
||||
var net domain.IrcNetwork
|
||||
|
||||
|
@ -206,7 +216,9 @@ func (ir *IrcRepo) CheckExistingNetwork(ctx context.Context, network *domain.Irc
|
|||
return &net, nil
|
||||
}
|
||||
|
||||
func (ir *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error {
|
||||
func (r *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
netName := toNullString(network.Name)
|
||||
pass := toNullString(network.Pass)
|
||||
|
@ -218,7 +230,7 @@ func (ir *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error {
|
|||
var err error
|
||||
if network.ID != 0 {
|
||||
// update record
|
||||
_, err = ir.db.Exec(`UPDATE irc_network
|
||||
_, err = r.db.handler.Exec(`UPDATE irc_network
|
||||
SET enabled = ?,
|
||||
name = ?,
|
||||
server = ?,
|
||||
|
@ -248,7 +260,7 @@ func (ir *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error {
|
|||
} else {
|
||||
var res sql.Result
|
||||
|
||||
res, err = ir.db.Exec(`INSERT INTO irc_network (
|
||||
res, err = r.db.handler.Exec(`INSERT INTO irc_network (
|
||||
enabled,
|
||||
name,
|
||||
server,
|
||||
|
@ -280,7 +292,9 @@ func (ir *IrcRepo) StoreNetwork(network *domain.IrcNetwork) error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (ir *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error {
|
||||
func (r *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
netName := toNullString(network.Name)
|
||||
pass := toNullString(network.Pass)
|
||||
|
@ -291,7 +305,7 @@ func (ir *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork
|
|||
|
||||
var err error
|
||||
// update record
|
||||
_, err = ir.db.ExecContext(ctx, `UPDATE irc_network
|
||||
_, err = r.db.handler.ExecContext(ctx, `UPDATE irc_network
|
||||
SET enabled = ?,
|
||||
name = ?,
|
||||
server = ?,
|
||||
|
@ -324,9 +338,11 @@ func (ir *IrcRepo) UpdateNetwork(ctx context.Context, network *domain.IrcNetwork
|
|||
|
||||
// TODO create new channel handler to only add, not delete
|
||||
|
||||
func (ir *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, channels []domain.IrcChannel) error {
|
||||
func (r *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, channels []domain.IrcChannel) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
tx, err := ir.db.BeginTx(ctx, nil)
|
||||
tx, err := r.db.handler.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -373,13 +389,16 @@ func (ir *IrcRepo) StoreNetworkChannels(ctx context.Context, networkID int64, ch
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ir *IrcRepo) StoreChannel(networkID int64, channel *domain.IrcChannel) error {
|
||||
func (r *IrcRepo) StoreChannel(networkID int64, channel *domain.IrcChannel) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
pass := toNullString(channel.Password)
|
||||
|
||||
var err error
|
||||
if channel.ID != 0 {
|
||||
// update record
|
||||
_, err = ir.db.Exec(`UPDATE irc_channel
|
||||
_, err = r.db.handler.Exec(`UPDATE irc_channel
|
||||
SET
|
||||
enabled = ?,
|
||||
detached = ?,
|
||||
|
@ -396,7 +415,7 @@ func (ir *IrcRepo) StoreChannel(networkID int64, channel *domain.IrcChannel) err
|
|||
} else {
|
||||
var res sql.Result
|
||||
|
||||
res, err = ir.db.Exec(`INSERT INTO irc_channel (
|
||||
res, err = r.db.handler.Exec(`INSERT INTO irc_channel (
|
||||
enabled,
|
||||
detached,
|
||||
name,
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
|
@ -283,9 +282,12 @@ var migrations = []string{
|
|||
`,
|
||||
}
|
||||
|
||||
func Migrate(db *sql.DB) error {
|
||||
func (db *SqliteDB) migrate() error {
|
||||
db.lock.Lock()
|
||||
defer db.lock.Unlock()
|
||||
|
||||
var version int
|
||||
if err := db.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
|
||||
if err := db.handler.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
|
||||
return fmt.Errorf("failed to query schema version: %v", err)
|
||||
}
|
||||
|
||||
|
@ -295,7 +297,7 @@ func Migrate(db *sql.DB) error {
|
|||
return fmt.Errorf("autobrr (version %d) older than schema (version: %d)", len(migrations), version)
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.handler.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -13,21 +13,24 @@ import (
|
|||
)
|
||||
|
||||
type ReleaseRepo struct {
|
||||
db *sql.DB
|
||||
db *SqliteDB
|
||||
}
|
||||
|
||||
func NewReleaseRepo(db *sql.DB) domain.ReleaseRepo {
|
||||
func NewReleaseRepo(db *SqliteDB) domain.ReleaseRepo {
|
||||
return &ReleaseRepo{db: db}
|
||||
}
|
||||
|
||||
func (repo *ReleaseRepo) Store(ctx context.Context, r *domain.Release) (*domain.Release, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
query, args, err := sq.
|
||||
Insert("release").
|
||||
Columns("filter_status", "rejections", "indexer", "filter", "protocol", "implementation", "timestamp", "group_id", "torrent_id", "torrent_name", "size", "raw", "title", "category", "season", "episode", "year", "resolution", "source", "codec", "container", "hdr", "audio", "release_group", "region", "language", "edition", "unrated", "hybrid", "proper", "repack", "website", "artists", "type", "format", "bitrate", "log_score", "has_log", "has_cue", "is_scene", "origin", "tags", "freeleech", "freeleech_percent", "uploader", "pre_time").
|
||||
Values(r.FilterStatus, pq.Array(r.Rejections), r.Indexer, r.FilterName, r.Protocol, r.Implementation, r.Timestamp, r.GroupID, r.TorrentID, r.TorrentName, r.Size, r.Raw, r.Title, r.Category, r.Season, r.Episode, r.Year, r.Resolution, r.Source, r.Codec, r.Container, r.HDR, r.Audio, r.Group, r.Region, r.Language, r.Edition, r.Unrated, r.Hybrid, r.Proper, r.Repack, r.Website, pq.Array(r.Artists), r.Type, r.Format, r.Bitrate, r.LogScore, r.HasLog, r.HasCue, r.IsScene, r.Origin, pq.Array(r.Tags), r.Freeleech, r.FreeleechPercent, r.Uploader, r.PreTime).
|
||||
ToSql()
|
||||
|
||||
res, err := repo.db.ExecContext(ctx, query, args...)
|
||||
res, err := repo.db.handler.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error inserting release")
|
||||
return nil, err
|
||||
|
@ -42,6 +45,8 @@ func (repo *ReleaseRepo) Store(ctx context.Context, r *domain.Release) (*domain.
|
|||
}
|
||||
|
||||
func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain.ReleaseActionStatus) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
if a.ID != 0 {
|
||||
query, args, err := sq.
|
||||
|
@ -53,7 +58,7 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain
|
|||
Where("release_id = ?", a.ReleaseID).
|
||||
ToSql()
|
||||
|
||||
_, err = repo.db.ExecContext(ctx, query, args...)
|
||||
_, err = repo.db.handler.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error updating status of release")
|
||||
return err
|
||||
|
@ -66,7 +71,7 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain
|
|||
Values(a.Status, a.Action, a.Type, pq.Array(a.Rejections), a.Timestamp, a.ReleaseID).
|
||||
ToSql()
|
||||
|
||||
res, err := repo.db.ExecContext(ctx, query, args...)
|
||||
res, err := repo.db.handler.ExecContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error inserting status of release")
|
||||
return err
|
||||
|
@ -82,6 +87,8 @@ func (repo *ReleaseRepo) StoreReleaseActionStatus(ctx context.Context, a *domain
|
|||
}
|
||||
|
||||
func (repo *ReleaseRepo) Find(ctx context.Context, params domain.QueryParams) ([]domain.Release, int64, int64, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
queryBuilder := sq.
|
||||
Select("id", "filter_status", "rejections", "indexer", "filter", "protocol", "title", "torrent_name", "size", "timestamp", "COUNT() OVER() AS total_count").
|
||||
|
@ -116,7 +123,7 @@ func (repo *ReleaseRepo) Find(ctx context.Context, params domain.QueryParams) ([
|
|||
|
||||
res := make([]domain.Release, 0)
|
||||
|
||||
rows, err := repo.db.QueryContext(ctx, query, args...)
|
||||
rows, err := repo.db.handler.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error fetching releases")
|
||||
return res, 0, 0, nil
|
||||
|
@ -171,6 +178,8 @@ func (repo *ReleaseRepo) Find(ctx context.Context, params domain.QueryParams) ([
|
|||
}
|
||||
|
||||
func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, releaseID int64) ([]domain.ReleaseActionStatus, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
queryBuilder := sq.
|
||||
Select("id", "status", "action", "type", "rejections", "timestamp").
|
||||
|
@ -181,7 +190,7 @@ func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, release
|
|||
|
||||
res := make([]domain.ReleaseActionStatus, 0)
|
||||
|
||||
rows, err := repo.db.QueryContext(ctx, query, args...)
|
||||
rows, err := repo.db.handler.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error fetching releases")
|
||||
return res, nil
|
||||
|
@ -214,6 +223,9 @@ func (repo *ReleaseRepo) GetActionStatusByReleaseID(ctx context.Context, release
|
|||
}
|
||||
|
||||
func (repo *ReleaseRepo) Stats(ctx context.Context) (*domain.ReleaseStats, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
query := `SELECT COUNT(*) total,
|
||||
IFNULL(SUM(CASE WHEN filter_status = 'FILTER_APPROVED' THEN 1 ELSE 0 END), 0) filtered_count,
|
||||
IFNULL(SUM(CASE WHEN filter_status = 'FILTER_REJECTED' THEN 1 ELSE 0 END), 0) filter_rejected_count,
|
||||
|
@ -223,7 +235,7 @@ func (repo *ReleaseRepo) Stats(ctx context.Context) (*domain.ReleaseStats, error
|
|||
FROM "release_action_status") AS push_rejected_count
|
||||
FROM "release";`
|
||||
|
||||
row := repo.db.QueryRowContext(ctx, query)
|
||||
row := repo.db.handler.QueryRowContext(ctx, query)
|
||||
if err := row.Err(); err != nil {
|
||||
log.Error().Stack().Err(err).Msg("release.stats: error querying stats")
|
||||
return nil, err
|
||||
|
@ -231,7 +243,7 @@ FROM "release";`
|
|||
|
||||
var rls domain.ReleaseStats
|
||||
|
||||
if err := row.Scan(&rls.TotalCount, &rls.PushApprovedCount, &rls.PushRejectedCount, &rls.FilteredCount, &rls.FilterRejectedCount); err != nil {
|
||||
if err := row.Scan(&rls.TotalCount, &rls.FilteredCount, &rls.FilterRejectedCount, &rls.PushApprovedCount, &rls.PushRejectedCount); err != nil {
|
||||
log.Error().Stack().Err(err).Msg("release.stats: error scanning stats data to struct")
|
||||
return nil, err
|
||||
}
|
||||
|
|
81
internal/database/sqlite.go
Normal file
81
internal/database/sqlite.go
Normal file
|
@ -0,0 +1,81 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"sync"
|
||||
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type SqliteDB struct {
|
||||
lock sync.RWMutex
|
||||
handler *sql.DB
|
||||
ctx context.Context
|
||||
cancel func()
|
||||
|
||||
DSN string
|
||||
}
|
||||
|
||||
func NewSqliteDB(source string) *SqliteDB {
|
||||
db := &SqliteDB{
|
||||
DSN: dataSourceName(source, "autobrr.db"),
|
||||
}
|
||||
|
||||
db.ctx, db.cancel = context.WithCancel(context.Background())
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *SqliteDB) Open() error {
|
||||
if db.DSN == "" {
|
||||
return fmt.Errorf("DSN required")
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
// open database connection
|
||||
if db.handler, err = sql.Open("sqlite3", db.DSN); err != nil {
|
||||
log.Fatal().Err(err).Msg("could not open db connection")
|
||||
return err
|
||||
}
|
||||
|
||||
// Set busy timeout
|
||||
if _, err = db.handler.Exec(`PRAGMA busy_timeout = 5000;`); err != nil {
|
||||
return fmt.Errorf("busy timeout pragma")
|
||||
}
|
||||
|
||||
// Enable WAL. SQLite performs better with the WAL because it allows
|
||||
// multiple readers to operate while data is being written.
|
||||
if _, err = db.handler.Exec(`PRAGMA journal_mode = wal;`); err != nil {
|
||||
return fmt.Errorf("enable wal: %w", err)
|
||||
}
|
||||
|
||||
// Enable foreign key checks. For historical reasons, SQLite does not check
|
||||
// foreign key constraints by default. There's some overhead on inserts to
|
||||
// verify foreign key integrity, but it's definitely worth it.
|
||||
if _, err = db.handler.Exec(`PRAGMA foreign_keys = ON;`); err != nil {
|
||||
return fmt.Errorf("foreign keys pragma: %w", err)
|
||||
}
|
||||
|
||||
// migrate db
|
||||
if err = db.migrate(); err != nil {
|
||||
log.Fatal().Err(err).Msg("could not migrate db")
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *SqliteDB) Close() error {
|
||||
// cancel background context
|
||||
db.cancel()
|
||||
|
||||
// close database
|
||||
if db.handler != nil {
|
||||
return db.handler.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
|
@ -2,25 +2,26 @@ package database
|
|||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
)
|
||||
|
||||
type UserRepo struct {
|
||||
db *sql.DB
|
||||
db *SqliteDB
|
||||
}
|
||||
|
||||
func NewUserRepo(db *sql.DB) domain.UserRepo {
|
||||
func NewUserRepo(db *SqliteDB) domain.UserRepo {
|
||||
return &UserRepo{db: db}
|
||||
}
|
||||
|
||||
func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain.User, error) {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
query := `SELECT id, username, password FROM users WHERE username = ?`
|
||||
|
||||
row := r.db.QueryRowContext(ctx, query, username)
|
||||
row := r.db.handler.QueryRowContext(ctx, query, username)
|
||||
if err := row.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -36,11 +37,13 @@ func (r *UserRepo) FindByUsername(ctx context.Context, username string) (*domain
|
|||
}
|
||||
|
||||
func (r *UserRepo) Store(ctx context.Context, user domain.User) error {
|
||||
//r.db.lock.RLock()
|
||||
//defer r.db.lock.RUnlock()
|
||||
|
||||
var err error
|
||||
if user.ID != 0 {
|
||||
update := `UPDATE users SET password = ? WHERE username = ?`
|
||||
_, err = r.db.ExecContext(ctx, update, user.Password, user.Username)
|
||||
_, err = r.db.handler.ExecContext(ctx, update, user.Password, user.Username)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error executing query")
|
||||
return err
|
||||
|
@ -48,7 +51,7 @@ func (r *UserRepo) Store(ctx context.Context, user domain.User) error {
|
|||
|
||||
} else {
|
||||
query := `INSERT INTO users (username, password) VALUES (?, ?)`
|
||||
_, err = r.db.ExecContext(ctx, query, user.Username, user.Password)
|
||||
_, err = r.db.handler.ExecContext(ctx, query, user.Username, user.Password)
|
||||
if err != nil {
|
||||
log.Error().Stack().Err(err).Msg("error executing query")
|
||||
return err
|
||||
|
|
|
@ -4,7 +4,7 @@ import (
|
|||
"path"
|
||||
)
|
||||
|
||||
func DataSourceName(configPath string, name string) string {
|
||||
func dataSourceName(configPath string, name string) string {
|
||||
if configPath != "" {
|
||||
return path.Join(configPath, name)
|
||||
}
|
||||
|
|
|
@ -51,7 +51,7 @@ func TestDataSourceName(t *testing.T) {
|
|||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := DataSourceName(tt.args.configPath, tt.args.name)
|
||||
got := dataSourceName(tt.args.configPath, tt.args.name)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
package domain
|
||||
|
||||
import "context"
|
||||
|
||||
type DownloadClientRepo interface {
|
||||
//FindByActionID(actionID int) ([]DownloadClient, error)
|
||||
List() ([]DownloadClient, error)
|
||||
FindByID(id int32) (*DownloadClient, error)
|
||||
FindByID(ctx context.Context, id int32) (*DownloadClient, error)
|
||||
Store(client DownloadClient) (*DownloadClient, error)
|
||||
Delete(clientID int) error
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package download_client
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/autobrr/autobrr/internal/domain"
|
||||
|
@ -10,7 +11,7 @@ import (
|
|||
|
||||
type Service interface {
|
||||
List() ([]domain.DownloadClient, error)
|
||||
FindByID(id int32) (*domain.DownloadClient, error)
|
||||
FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error)
|
||||
Store(client domain.DownloadClient) (*domain.DownloadClient, error)
|
||||
Delete(clientID int) error
|
||||
Test(client domain.DownloadClient) error
|
||||
|
@ -33,8 +34,8 @@ func (s *service) List() ([]domain.DownloadClient, error) {
|
|||
return clients, nil
|
||||
}
|
||||
|
||||
func (s *service) FindByID(id int32) (*domain.DownloadClient, error) {
|
||||
client, err := s.repo.FindByID(id)
|
||||
func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
|
||||
client, err := s.repo.FindByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -17,6 +17,7 @@ import (
|
|||
)
|
||||
|
||||
type Client struct {
|
||||
Name string
|
||||
settings Settings
|
||||
http *http.Client
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue