diff --git a/cmd/autobrr/main.go b/cmd/autobrr/main.go
index 053a55b..e998503 100644
--- a/cmd/autobrr/main.go
+++ b/cmd/autobrr/main.go
@@ -59,7 +59,7 @@ func main() {
     log.Info().Msgf("Log-level: %v", cfg.LogLevel)
 
     // open database connection
-    db := database.NewSqliteDB(configPath)
+    db, _ := database.NewDB(cfg)
     if err := db.Open(); err != nil {
         log.Fatal().Err(err).Msg("could not open db connection")
     }
diff --git a/cmd/autobrrctl/main.go b/cmd/autobrrctl/main.go
index 6f250c8..41b629d 100644
--- a/cmd/autobrrctl/main.go
+++ b/cmd/autobrrctl/main.go
@@ -39,7 +39,7 @@ func main() {
     }
 
     // open database connection
-    db := database.NewSqliteDB(configPath)
+    db, _ := database.NewDB(domain.Config{ConfigPath: configPath})
     if err := db.Open(); err != nil {
         log.Fatal("could not open db connection")
     }
diff --git a/docker-compose.yml b/docker-compose.yml
index c99ae79..5206389 100644
--- a/docker-compose.yml
+++ b/docker-compose.yml
@@ -8,4 +8,19 @@ services:
       - ./config:/config
     ports:
       - "7474:7474"
-    restart: unless-stopped
\ No newline at end of file
+    restart: unless-stopped
+  postgres:
+    image: postgres:12.10
+    container_name: postgres
+    volumes:
+      - postgres:/var/lib/postgresql/data
+    ports:
+      - "5432:5432"
+    environment:
+      - POSTGRES_USER=autobrr
+      - POSTGRES_PASSWORD=postgres
+      - POSTGRES_DB=autobrr
+
+
+volumes:
+  postgres:
\ No newline at end of file
diff --git a/internal/config/config.go b/internal/config/config.go
index 421d7a1..0bf5254 100644
--- a/internal/config/config.go
+++ b/internal/config/config.go
@@ -23,6 +23,12 @@ func Defaults() domain.Config {
         BaseURL:           "/",
         SessionSecret:     "secret-session-key",
         CustomDefinitions: "",
+        DatabaseType:      "sqlite",
+        PostgresHost:      "",
+        PostgresPort:      0,
+        PostgresDatabase:  "",
+        PostgresUser:      "",
+        PostgresPass:      "",
     }
 }
 
@@ -127,6 +133,7 @@
         }
 
         viper.SetConfigFile(path.Join(configPath, "config.toml"))
+        config.ConfigPath = configPath
     } else {
         viper.SetConfigName("config")
 
diff --git a/internal/database/action.go b/internal/database/action.go
index 596844b..ba4df7c 100644
--- a/internal/database/action.go
+++ b/internal/database/action.go
@@ -10,10 +10,10 @@ import (
 )
 
 type ActionRepo struct {
-    db *SqliteDB
+    db *DB
 }
 
-func NewActionRepo(db *SqliteDB) domain.ActionRepo {
+func NewActionRepo(db *DB) domain.ActionRepo {
     return &ActionRepo{db: db}
 }
 
diff --git a/internal/database/database.go b/internal/database/database.go
new file mode 100644
index 0000000..08a6619
--- /dev/null
+++ b/internal/database/database.go
@@ -0,0 +1,94 @@
+package database
+
+import (
+    "context"
+    "database/sql"
+    "fmt"
+    "sync"
+
+    "github.com/rs/zerolog/log"
+
+    "github.com/autobrr/autobrr/internal/domain"
+)
+
+type DB struct {
+    handler *sql.DB
+    lock    sync.RWMutex
+    ctx     context.Context
+    cancel  func()
+
+    Driver string
+    DSN    string
+}
+
+func NewDB(cfg domain.Config) (*DB, error) {
+    db := &DB{}
+    db.ctx, db.cancel = context.WithCancel(context.Background())
+
+    switch cfg.DatabaseType {
+    case "sqlite":
+        db.Driver = "sqlite"
+        db.DSN = dataSourceName(cfg.ConfigPath, "autobrr.db")
+    case "postgres":
+        if cfg.PostgresHost == "" || cfg.PostgresPort == 0 || cfg.PostgresDatabase == "" {
+            return nil, fmt.Errorf("postgres: bad variables")
+        }
+        db.DSN = fmt.Sprintf("postgres://%v:%v@%v:%d/%v?sslmode=disable", cfg.PostgresUser, cfg.PostgresPass, cfg.PostgresHost, cfg.PostgresPort, cfg.PostgresDatabase)
+        db.Driver = "postgres"
+    default:
+        return nil, fmt.Errorf("unsupported database: %v", cfg.DatabaseType)
+    }
+
+    return db, nil
+}
+
+func (db *DB) Open() error {
+    if db.DSN == "" {
+        return fmt.Errorf("DSN required")
+    }
+
+    var err error
+
+    switch db.Driver {
+    case "sqlite":
+        if err = db.openSQLite(); err != nil {
+            log.Fatal().Err(err).Msg("could not open sqlite db connection")
+            return err
+        }
+    case "postgres":
+        if err = db.openPostgres(); err != nil {
+            log.Fatal().Err(err).Msg("could not open postgres db connection")
+            return err
+        }
+    }
+
+    return nil
+}
+
+func (db *DB) Close() error {
+    // cancel background context
+    db.cancel()
+
+    // close database
+    if db.handler != nil {
+        return db.handler.Close()
+    }
+    return nil
+}
+
+func (db *DB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
+    tx, err := db.handler.BeginTx(ctx, opts)
+    if err != nil {
+        return nil, err
+    }
+
+    return &Tx{
+        Tx:      tx,
+        handler: db,
+    }, nil
+}
+
+type Tx struct {
+    *sql.Tx
+    handler *DB
+}
diff --git a/internal/database/download_client.go b/internal/database/download_client.go
index 63a5761..041dc20 100644
--- a/internal/database/download_client.go
+++ b/internal/database/download_client.go
@@ -12,7 +12,7 @@ import (
 )
 
 type DownloadClientRepo struct {
-    db    *SqliteDB
+    db    *DB
     cache *clientCache
 }
 
@@ -49,7 +49,7 @@ func (c *clientCache) Pop(id int) {
     c.mu.Unlock()
 }
 
-func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo {
+func NewDownloadClientRepo(db *DB) domain.DownloadClientRepo {
     return &DownloadClientRepo{
         db:    db,
         cache: NewClientCache(),
diff --git a/internal/database/filter.go b/internal/database/filter.go
index 11c7840..7cbbcac 100644
--- a/internal/database/filter.go
+++ b/internal/database/filter.go
@@ -12,10 +12,10 @@ import (
 )
 
 type FilterRepo struct {
-    db *SqliteDB
+    db *DB
 }
 
-func NewFilterRepo(db *SqliteDB) domain.FilterRepo {
+func NewFilterRepo(db *DB) domain.FilterRepo {
     return &FilterRepo{db: db}
 }
 
diff --git a/internal/database/indexer.go b/internal/database/indexer.go
index 49c6b6d..fd947ad 100644
--- a/internal/database/indexer.go
+++ b/internal/database/indexer.go
@@ -8,10 +8,10 @@ import (
 )
 
 type IndexerRepo struct {
-    db *SqliteDB
+    db *DB
 }
 
-func NewIndexerRepo(db *SqliteDB) domain.IndexerRepo {
+func NewIndexerRepo(db *DB) domain.IndexerRepo {
     return &IndexerRepo{
         db: db,
     }
 }
diff --git a/internal/database/irc.go b/internal/database/irc.go
index 0e75e88..1a9417a 100644
--- a/internal/database/irc.go
+++ b/internal/database/irc.go
@@ -12,10 +12,10 @@ import (
 )
 
 type IrcRepo struct {
-    db *SqliteDB
+    db *DB
 }
 
-func NewIrcRepo(db *SqliteDB) domain.IrcRepo {
+func NewIrcRepo(db *DB) domain.IrcRepo {
     return &IrcRepo{db: db}
 }
 
diff --git a/internal/database/migrate.go b/internal/database/migrate.go
index 1bd86f9..9bd1e48 100644
--- a/internal/database/migrate.go
+++ b/internal/database/migrate.go
@@ -1,12 +1,5 @@
 package database
 
-import (
-    "database/sql"
-    "fmt"
-
-    "github.com/lib/pq"
-)
-
 const schema = `
 CREATE TABLE users
 (
@@ -376,144 +369,3 @@ var migrations = []string{
     ADD COLUMN webhook_headers TEXT [] DEFAULT '{}';
     `,
 }
-
-func (db *SqliteDB) migrate() error {
-    db.lock.Lock()
-    defer db.lock.Unlock()
-
-    var version int
-    if err := db.handler.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
-        return fmt.Errorf("failed to query schema version: %v", err)
-    }
-
-    if version == len(migrations) {
-        return nil
-    } else if version > len(migrations) {
-        return fmt.Errorf("autobrr (version %d) older than schema (version: %d)", len(migrations), version)
-    }
-
-    tx, err := db.handler.Begin()
-    if err != nil {
-        return err
-    }
-    defer tx.Rollback()
-
-    if version == 0 {
-        if _, err := tx.Exec(schema); err != nil {
-            return fmt.Errorf("failed to initialize schema: %v", err)
-        }
-    } else {
-        for i := version; i < len(migrations); i++ {
-            if _, err := tx.Exec(migrations[i]); err != nil {
-                return fmt.Errorf("failed to execute migration #%v: %v", i, err)
-            }
-        }
-    }
-
-    // temp custom data migration
-    // get data from filter.sources, check if specific types, move to new table and clear
-    // if migration 6
-    // TODO 2022-01-30 remove this in future version
-    if version == 5 && len(migrations) == 6 {
-        if err := customMigrateCopySourcesToMedia(tx); err != nil {
-            return fmt.Errorf("could not run custom data migration: %v", err)
-        }
-    }
-
-    _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
-    if err != nil {
-        return fmt.Errorf("failed to bump schema version: %v", err)
-    }
-
-    return tx.Commit()
-}
-
-// customMigrateCopySourcesToMedia move music specific sources to media
-func customMigrateCopySourcesToMedia(tx *sql.Tx) error {
-    rows, err := tx.Query(`
-    SELECT id, sources
-    FROM filter
-    WHERE sources LIKE '%"CD"%'
-       OR sources LIKE '%"WEB"%'
-       OR sources LIKE '%"DVD"%'
-       OR sources LIKE '%"Vinyl"%'
-       OR sources LIKE '%"Soundboard"%'
-       OR sources LIKE '%"DAT"%'
-       OR sources LIKE '%"Cassette"%'
-       OR sources LIKE '%"Blu-Ray"%'
-       OR sources LIKE '%"SACD"%'
-    ;`)
-    if err != nil {
-        return fmt.Errorf("could not run custom data migration: %v", err)
-    }
-
-    defer rows.Close()
-
-    type tmpDataStruct struct {
-        id      int
-        sources []string
-    }
-
-    var tmpData []tmpDataStruct
-
-    // scan data
-    for rows.Next() {
-        var t tmpDataStruct
-
-        if err := rows.Scan(&t.id, pq.Array(&t.sources)); err != nil {
-            return err
-        }
-
-        tmpData = append(tmpData, t)
-    }
-    if err := rows.Err(); err != nil {
-        return err
-    }
-
-    // manipulate data
-    for _, d := range tmpData {
-        // create new slice with only music source if they exist in d.sources
-        mediaSources := []string{}
-        for _, source := range d.sources {
-            switch source {
-            case "CD":
-                mediaSources = append(mediaSources, source)
-            case "DVD":
-                mediaSources = append(mediaSources, source)
-            case "Vinyl":
-                mediaSources = append(mediaSources, source)
-            case "Soundboard":
-                mediaSources = append(mediaSources, source)
-            case "DAT":
-                mediaSources = append(mediaSources, source)
-            case "Cassette":
-                mediaSources = append(mediaSources, source)
-            case "Blu-Ray":
-                mediaSources = append(mediaSources, source)
-            case "SACD":
-                mediaSources = append(mediaSources, source)
-            }
-        }
-        _, err = tx.Exec(`UPDATE filter SET media = ? WHERE id = ?`, pq.Array(mediaSources), d.id)
-        if err != nil {
-            return err
-        }
-
-        // remove all music specific sources
-        cleanSources := []string{}
-        for _, source := range d.sources {
-            switch source {
-            case "CD", "WEB", "DVD", "Vinyl", "Soundboard", "DAT", "Cassette", "Blu-Ray", "SACD":
-                continue
-            }
-            cleanSources = append(cleanSources, source)
-        }
-        _, err := tx.Exec(`UPDATE filter SET sources = ? WHERE id = ?`, pq.Array(cleanSources), d.id)
-        if err != nil {
-            return err
-        }
-
-    }
-
-    return nil
-}
diff --git a/internal/database/postgres.go b/internal/database/postgres.go
new file mode 100644
index 0000000..749e2df
--- /dev/null
+++ b/internal/database/postgres.go
@@ -0,0 +1,83 @@
+package database
+
+import (
+    "database/sql"
+    "errors"
+    "fmt"
+
+    _ "github.com/lib/pq"
+    "github.com/rs/zerolog/log"
+)
+
+func (db *DB) openPostgres() error {
+    var err error
+
+    // open database connection
+    if db.handler, err = sql.Open("postgres", db.DSN); err != nil {
+        log.Fatal().Err(err).Msg("could not open postgres connection")
+        return err
+    }
+
+    err = db.handler.Ping()
+    if err != nil {
+        log.Fatal().Err(err).Msg("could not ping postgres database")
+        return err
+    }
+
+    // migrate db
+    if err = db.migratePostgres(); err != nil {
+        log.Fatal().Err(err).Msg("could not migrate postgres database")
+        return err
+    }
+
+    return nil
+}
+
+func (db *DB) migratePostgres() error {
+    tx, err := db.handler.Begin()
+    if err != nil {
+        return err
+    }
+    defer tx.Rollback()
+
+    initialSchema := `CREATE TABLE IF NOT EXISTS schema_migrations (
+    id INTEGER PRIMARY KEY,
+    version INTEGER NOT NULL
+);`
+
+    if _, err := tx.Exec(initialSchema); err != nil {
+        return fmt.Errorf("failed to create schema_migrations table: %s", err)
+    }
+
+    var version int
+    err = tx.QueryRow(`SELECT version FROM schema_migrations`).Scan(&version)
+    if err != nil && !errors.Is(err, sql.ErrNoRows) {
+        return err
+    }
+
+    if version == len(migrations) {
+        return nil
+    }
+    if version > len(migrations) {
+        return fmt.Errorf("autobrr (version %d) older than schema (version: %d)", len(migrations), version)
+    }
+
+    if version == 0 {
+        if _, err := tx.Exec(schema); err != nil {
+            return fmt.Errorf("failed to initialize schema: %v", err)
+        }
+    } else {
+        for i := version; i < len(migrations); i++ {
+            if _, err := tx.Exec(migrations[i]); err != nil {
+                return fmt.Errorf("failed to execute migration #%v: %v", i, err)
+            }
+        }
+    }
+
+    _, err = tx.Exec(`INSERT INTO schema_migrations (id, version) VALUES (1, $1) ON CONFLICT (id) DO UPDATE SET version = $1`, len(migrations))
+    if err != nil {
+        return fmt.Errorf("failed to bump schema version: %v", err)
+    }
+
+    return tx.Commit()
+}
diff --git a/internal/database/release.go b/internal/database/release.go
index 26ee396..7fd4286 100644
--- a/internal/database/release.go
+++ b/internal/database/release.go
@@ -11,10 +11,10 @@ import (
 )
 
 type ReleaseRepo struct {
-    db *SqliteDB
+    db *DB
 }
 
-func NewReleaseRepo(db *SqliteDB) domain.ReleaseRepo {
+func NewReleaseRepo(db *DB) domain.ReleaseRepo {
     return &ReleaseRepo{db: db}
 }
 
diff --git a/internal/database/sqlite.go b/internal/database/sqlite.go
index 13ba047..e4e546e 100644
--- a/internal/database/sqlite.go
+++ b/internal/database/sqlite.go
@@ -1,35 +1,15 @@
 package database
 
 import (
-    "context"
     "database/sql"
     "fmt"
-    "sync"
 
+    "github.com/lib/pq"
     "github.com/rs/zerolog/log"
     _ "modernc.org/sqlite"
 )
 
-type SqliteDB struct {
-    lock    sync.RWMutex
-    handler *sql.DB
-    ctx     context.Context
-    cancel  func()
-
-    DSN string
-}
-
-func NewSqliteDB(source string) *SqliteDB {
-    db := &SqliteDB{
-        DSN: dataSourceName(source, "autobrr.db"),
-    }
-
-    db.ctx, db.cancel = context.WithCancel(context.Background())
-
-    return db
-}
-
-func (db *SqliteDB) Open() error {
+func (db *DB) openSQLite() error {
     if db.DSN == "" {
         return fmt.Errorf("DSN required")
     }
@@ -61,7 +41,7 @@ func (db *SqliteDB) Open() error {
     }
 
     // migrate db
-    if err = db.migrate(); err != nil {
+    if err = db.migrateSQLite(); err != nil {
         log.Fatal().Err(err).Msg("could not migrate db")
         return err
     }
 
@@ -69,30 +49,143 @@ func (db *SqliteDB) Open() error {
     return nil
 }
 
-func (db *SqliteDB) Close() error {
-    // cancel background context
-    db.cancel()
+func (db *DB) migrateSQLite() error {
+    db.lock.Lock()
+    defer db.lock.Unlock()
 
-    // close database
-    if db.handler != nil {
-        return db.handler.Close()
+    var version int
+    if err := db.handler.QueryRow("PRAGMA user_version").Scan(&version); err != nil {
+        return fmt.Errorf("failed to query schema version: %v", err)
     }
+
+    if version == len(migrations) {
+        return nil
+    } else if version > len(migrations) {
+        return fmt.Errorf("autobrr (version %d) older than schema (version: %d)", len(migrations), version)
+    }
+
+    tx, err := db.handler.Begin()
+    if err != nil {
+        return err
+    }
+    defer tx.Rollback()
+
+    if version == 0 {
+        if _, err := tx.Exec(schema); err != nil {
+            return fmt.Errorf("failed to initialize schema: %v", err)
+        }
+    } else {
+        for i := version; i < len(migrations); i++ {
+            if _, err := tx.Exec(migrations[i]); err != nil {
+                return fmt.Errorf("failed to execute migration #%v: %v", i, err)
+            }
+        }
+    }
+
+    // temp custom data migration
+    // get data from filter.sources, check if specific types, move to new table and clear
+    // if migration 6
+    // TODO 2022-01-30 remove this in future version
+    if version == 5 && len(migrations) == 6 {
+        if err := customMigrateCopySourcesToMedia(tx); err != nil {
+            return fmt.Errorf("could not run custom data migration: %v", err)
+        }
+    }
+
+    _, err = tx.Exec(fmt.Sprintf("PRAGMA user_version = %d", len(migrations)))
+    if err != nil {
+        return fmt.Errorf("failed to bump schema version: %v", err)
+    }
+
+    return tx.Commit()
+}
+
+// customMigrateCopySourcesToMedia move music specific sources to media
+func customMigrateCopySourcesToMedia(tx *sql.Tx) error {
+    rows, err := tx.Query(`
+    SELECT id, sources
+    FROM filter
+    WHERE sources LIKE '%"CD"%'
+       OR sources LIKE '%"WEB"%'
+       OR sources LIKE '%"DVD"%'
+       OR sources LIKE '%"Vinyl"%'
+       OR sources LIKE '%"Soundboard"%'
+       OR sources LIKE '%"DAT"%'
+       OR sources LIKE '%"Cassette"%'
+       OR sources LIKE '%"Blu-Ray"%'
+       OR sources LIKE '%"SACD"%'
+    ;`)
+    if err != nil {
+        return fmt.Errorf("could not run custom data migration: %v", err)
+    }
+
+    defer rows.Close()
+
+    type tmpDataStruct struct {
+        id      int
+        sources []string
+    }
+
+    var tmpData []tmpDataStruct
+
+    // scan data
+    for rows.Next() {
+        var t tmpDataStruct
+
+        if err := rows.Scan(&t.id, pq.Array(&t.sources)); err != nil {
+            return err
+        }
+
+        tmpData = append(tmpData, t)
+    }
+    if err := rows.Err(); err != nil {
+        return err
+    }
+
+    // manipulate data
+    for _, d := range tmpData {
+        // create new slice with only music source if they exist in d.sources
+        mediaSources := []string{}
+        for _, source := range d.sources {
+            switch source {
+            case "CD":
+                mediaSources = append(mediaSources, source)
+            case "DVD":
+                mediaSources = append(mediaSources, source)
+            case "Vinyl":
+                mediaSources = append(mediaSources, source)
+            case "Soundboard":
+                mediaSources = append(mediaSources, source)
+            case "DAT":
+                mediaSources = append(mediaSources, source)
+            case "Cassette":
+                mediaSources = append(mediaSources, source)
+            case "Blu-Ray":
+                mediaSources = append(mediaSources, source)
+            case "SACD":
+                mediaSources = append(mediaSources, source)
+            }
+        }
+        _, err = tx.Exec(`UPDATE filter SET media = ? WHERE id = ?`, pq.Array(mediaSources), d.id)
+        if err != nil {
+            return err
+        }
+
+        // remove all music specific sources
+        cleanSources := []string{}
+        for _, source := range d.sources {
+            switch source {
+            case "CD", "WEB", "DVD", "Vinyl", "Soundboard", "DAT", "Cassette", "Blu-Ray", "SACD":
+                continue
+            }
+            cleanSources = append(cleanSources, source)
+        }
+        _, err := tx.Exec(`UPDATE filter SET sources = ? WHERE id = ?`, pq.Array(cleanSources), d.id)
+        if err != nil {
+            return err
+        }
+
+    }
+
     return nil
 }
-
-func (db *SqliteDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Tx, error) {
-    tx, err := db.handler.BeginTx(ctx, opts)
-    if err != nil {
-        return nil, err
-    }
-
-    return &Tx{
-        Tx:      tx,
-        handler: db,
-    }, nil
-}
-
-type Tx struct {
-    *sql.Tx
-    handler *SqliteDB
-}
diff --git a/internal/database/user.go b/internal/database/user.go
index bb1127d..71e12b4 100644
--- a/internal/database/user.go
+++ b/internal/database/user.go
@@ -8,10 +8,10 @@ import (
 )
 
 type UserRepo struct {
-    db *SqliteDB
+    db *DB
 }
 
-func NewUserRepo(db *SqliteDB) domain.UserRepo {
+func NewUserRepo(db *DB) domain.UserRepo {
     return &UserRepo{db: db}
 }
 
diff --git a/internal/domain/config.go b/internal/domain/config.go
index 4cb3a73..9e69ea4 100644
--- a/internal/domain/config.go
+++ b/internal/domain/config.go
@@ -1,6 +1,7 @@
 package domain
 
 type Config struct {
+    ConfigPath        string
     Host              string `toml:"host"`
     Port              int    `toml:"port"`
     LogLevel          string `toml:"logLevel"`
@@ -8,4 +9,10 @@ type Config struct {
     BaseURL           string `toml:"baseUrl"`
     SessionSecret     string `toml:"sessionSecret"`
     CustomDefinitions string `toml:"customDefinitions"`
+    DatabaseType      string `toml:"databaseType"`
+    PostgresHost      string `toml:"postgresHost"`
+    PostgresPort      int    `toml:"postgresPort"`
+    PostgresDatabase  string `toml:"postgresDatabase"`
+    PostgresUser      string `toml:"postgresUser"`
+    PostgresPass      string `toml:"postgresPass"`
 }
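
Note: with these changes the backend is selected through the new config keys; databaseType defaults to "sqlite", so existing setups are unaffected. A minimal config.toml snippet for running against Postgres might look like the following (a sketch based on the toml tags added in internal/domain/config.go; the host and credentials are placeholders chosen to match the postgres service added to docker-compose.yml above):

    databaseType = "postgres"
    postgresHost = "localhost"
    postgresPort = 5432
    postgresDatabase = "autobrr"
    postgresUser = "autobrr"
    postgresPass = "postgres"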