diff --git a/cmd/autobrrctl/main.go b/cmd/autobrrctl/main.go index 6edaadc..88261ff 100644 --- a/cmd/autobrrctl/main.go +++ b/cmd/autobrrctl/main.go @@ -17,22 +17,35 @@ import ( "github.com/autobrr/autobrr/internal/auth" "github.com/autobrr/autobrr/internal/config" "github.com/autobrr/autobrr/internal/database" + "github.com/autobrr/autobrr/internal/database/tools" "github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/logger" "github.com/autobrr/autobrr/internal/user" "github.com/autobrr/autobrr/pkg/errors" + _ "github.com/lib/pq" "golang.org/x/term" - _ "modernc.org/sqlite" ) -const usage = `usage: autobrrctl --config path +const usage = `usage: autobrrctl [options] - create-user Create user - change-password Change password for user - version Can be run without --config - help Show this help message +Actions: + create-user Create a new user + change-password Change the password + db:seed --db-path --seed-db Seed the sqlite database + db:reset --db-path --seed-db Reset the sqlite database + db:convert --sqlite-db --postgres-url Convert SQLite to Postgres + version Display the version of autobrrctl + help Show this help message +Examples: + autobrrctl --config /path/to/config/dir create-user john + autobrrctl --config /path/to/config/dir change-password john + autobrrctl db:reset --db-path /path/to/autobrr.db --seed-db /path/to/seed + autobrrctl db:seed --db-path /path/to/autobrr.db --seed-db /path/to/seed + autobrrctl db:convert --sqlite-db /path/to/autobrr.db --postgres-url postgres://username:password@127.0.0.1:5432/autobrr + autobrrctl version + autobrrctl help ` var ( @@ -56,6 +69,7 @@ func main() { flag.Parse() switch cmd := flag.Arg(0); cmd { + case "version": fmt.Printf("Version: %v\nCommit: %v\nBuild: %v\n", version, commit, date) @@ -91,7 +105,6 @@ func main() { fmt.Printf("Latest release: %v\n", rel.TagName) case "create-user": - if configPath == "" { log.Fatal("--config required") } @@ -141,7 +154,6 @@ func main() { } case "change-password": - if configPath == "" { log.Fatal("--config required") } @@ -204,6 +216,68 @@ func main() { log.Printf("successfully updated password for user %q", username) + case "db:convert": + var sqliteDBPath, postgresDBURL string + migrateFlagSet := flag.NewFlagSet("db:convert", flag.ExitOnError) + migrateFlagSet.StringVar(&sqliteDBPath, "sqlite-db", "", "path to SQLite database file") + migrateFlagSet.StringVar(&postgresDBURL, "postgres-url", "", "URL for PostgreSQL database") + + if err := migrateFlagSet.Parse(flag.Args()[1:]); err != nil { + fmt.Printf("Error parsing flags for db:convert: %v\n", err) + migrateFlagSet.Usage() + os.Exit(1) + } + + if sqliteDBPath == "" || postgresDBURL == "" { + fmt.Println("Error: missing required flags for db:convert") + flag.Usage() + os.Exit(1) + } + + c := tools.NewConverter(sqliteDBPath, postgresDBURL) + if err := c.Convert(); err != nil { + log.Fatalf("database conversion failed: %v", err) + } + + case "db:seed", "db:reset": + var dbPath, seedDBPath string + seedResetFlagSet := flag.NewFlagSet("db:seed/db:reset", flag.ExitOnError) + seedResetFlagSet.StringVar(&dbPath, "db-path", "", "path to the database file") + seedResetFlagSet.StringVar(&seedDBPath, "seed-db", "", "path to SQL seed file") + + if err := seedResetFlagSet.Parse(flag.Args()[1:]); err != nil { + fmt.Printf("Error parsing flags for db:seed or db:reset: %v\n", err) + seedResetFlagSet.Usage() + os.Exit(1) + } + + if dbPath == "" || seedDBPath == "" { + fmt.Println("Error: missing required flags for db:seed or db:reset") + flag.Usage() + os.Exit(1) + } + + s := tools.NewSQLiteSeeder(dbPath, seedDBPath) + + if cmd == "db:seed" { + if err := s.Seed(); err != nil { + fmt.Println("Error seeding the database:", err) + os.Exit(1) + } + fmt.Println("Database seeding completed successfully!") + } else { + if err := s.Reset(); err != nil { + fmt.Println("Error resetting the database:", err) + os.Exit(1) + } + + if err := s.Seed(); err != nil { + fmt.Println("Error seeding the database:", err) + os.Exit(1) + } + fmt.Println("Database reset and reseed completed successfully!") + } + default: flag.Usage() if cmd != "help" { diff --git a/internal/database/tools/convert.go b/internal/database/tools/convert.go new file mode 100644 index 0000000..6d6dfb5 --- /dev/null +++ b/internal/database/tools/convert.go @@ -0,0 +1,163 @@ +package tools + +import ( + "database/sql" + "fmt" + "log" + "strings" + "time" + + _ "modernc.org/sqlite" +) + +var tables = []string{ + "action", + "api_key", + "client", + "feed", + "filter", + "filter_external", + "filter_indexer", + "indexer", + "irc_channel", + "irc_network", + "notification", + "release", + "release_action_status", + "users", +} + +type Converter interface { + Convert() error +} + +type SqliteToPostgresConverter struct { + sqliteDBPath, postgresDBURL string +} + +func NewConverter(sqliteDBPath, postgresDBURL string) Converter { + return &SqliteToPostgresConverter{ + sqliteDBPath: sqliteDBPath, + postgresDBURL: postgresDBURL, + } +} + +func (c *SqliteToPostgresConverter) Convert() error { + startTime := time.Now() + + sqliteDB, err := sql.Open("sqlite", c.sqliteDBPath) + if err != nil { + log.Fatalf("Failed to connect to SQLite database: %v", err) + } + defer sqliteDB.Close() + + postgresDB, err := sql.Open("postgres", c.postgresDBURL) + if err != nil { + log.Fatalf("Failed to connect to PostgreSQL database: %v", err) + } + defer postgresDB.Close() + + tables := GetTables() + + // Store all foreign key violation messages. + var allFKViolations []string + for _, table := range tables { + fkViolations := c.migrateTable(sqliteDB, postgresDB, table) + allFKViolations = append(allFKViolations, fkViolations...) + } + + c.printConversionResult(startTime, allFKViolations) + + return err +} + +func (c *SqliteToPostgresConverter) printConversionResult(startTime time.Time, allFKViolations []string) { + var sb strings.Builder + + sb.WriteString("Convert completed successfully!\n") + sb.WriteString(fmt.Sprintf("Elapsed time: %s\n", time.Since(startTime))) + if len(allFKViolations) > 0 { + sb.WriteString("\nSummary of Foreign Key Violations:\n\n") + for _, msg := range allFKViolations { + sb.WriteString(" - " + msg + "\n") + } + sb.WriteString("\nThese are due to missing references, likely because the related item in another table no longer exists.\n") + } + fmt.Print(sb.String()) +} + +func GetTables() []string { + return append([]string(nil), tables...) +} + +func (c *SqliteToPostgresConverter) migrateTable(sqliteDB, postgresDB *sql.DB, table string) []string { + var fkViolationMessages []string + + rows, err := sqliteDB.Query("SELECT * FROM ?", table) + if err != nil { + log.Fatalf("Failed to query SQLite table '%s': %v", table, err) + } + defer rows.Close() + + columns, err := rows.ColumnTypes() + if err != nil { + log.Fatalf("Failed to get column types for table '%s': %v", table, err) + } + + // Prepare the INSERT statement for PostgreSQL. + colNames, colPlaceholders := prepareColumns(columns) + insertStmt, err := postgresDB.Prepare(fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", table, colNames, colPlaceholders)) + if err != nil { + log.Fatalf("Failed to prepare INSERT statement for table '%s': %v", table, err) + } + defer insertStmt.Close() + + var rowsAffected int64 + + for rows.Next() { + values, valuePtrs := prepareValues(columns) + + if err := rows.Scan(valuePtrs...); err != nil { + log.Fatalf("Failed to scan row from SQLite table '%s': %v", table, err) + } + + _, err := insertStmt.Exec(values...) + if err != nil { + if isForeignKeyViolation(err) { + // Record foreign key violation message. + message := fmt.Sprintf("Table '%s': %v", table, err) + fkViolationMessages = append(fkViolationMessages, message) + continue + } + } else { + rowsAffected++ + } + } + log.Printf("Converted %d rows to table '%s' from SQLite to PostgreSQL\n", rowsAffected, table) + return fkViolationMessages +} + +func prepareColumns(columns []*sql.ColumnType) (colNames, colPlaceholders string) { + for i, col := range columns { + colNames += col.Name() + colPlaceholders += fmt.Sprintf("$%d", i+1) + if i < len(columns)-1 { + colNames += ", " + colPlaceholders += ", " + } + } + return +} + +func prepareValues(columns []*sql.ColumnType) ([]interface{}, []interface{}) { + values := make([]interface{}, len(columns)) + valuePtrs := make([]interface{}, len(columns)) + for i := range values { + valuePtrs[i] = &values[i] + } + return values, valuePtrs +} + +func isForeignKeyViolation(err error) bool { + return strings.Contains(err.Error(), "violates foreign key constraint") +} diff --git a/internal/database/tools/seed.go b/internal/database/tools/seed.go new file mode 100644 index 0000000..00ed9a5 --- /dev/null +++ b/internal/database/tools/seed.go @@ -0,0 +1,82 @@ +package tools + +import ( + "database/sql" + "fmt" + "os" + "strings" + + _ "modernc.org/sqlite" +) + +type Seeder interface { + Reset() error + Seed() error +} + +type SQLiteSeeder struct { + dbPath string + seedFile string +} + +func NewSQLiteSeeder(dbPath, seedFile string) *SQLiteSeeder { + return &SQLiteSeeder{ + dbPath: dbPath, + seedFile: seedFile, + } +} + +func (s *SQLiteSeeder) Reset() error { + db, err := sql.Open("sqlite", s.dbPath) + if err != nil { + return fmt.Errorf("failed to open %s database: %v", "sqlite", err) + } + defer db.Close() + + tables := GetTables() + + for _, table := range tables { + if err := s.resetTable(db, table); err != nil { + return err + } + } + + return nil +} + +func (s *SQLiteSeeder) resetTable(db *sql.DB, table string) error { + if _, err := db.Exec("DELETE FROM ?", table); err != nil { + return fmt.Errorf("failed to delete rows from table %s: %v", table, err) + } + + // Update sqlite_sequence, ignore errors for missing sqlite_sequence entry + if _, err := db.Exec("UPDATE sqlite_sequence SET seq = 0 WHERE name = ?", table); err != nil { + if !strings.Contains(err.Error(), "no such table") { + return fmt.Errorf("failed to reset primary key sequence for table %s: %v", table, err) + } + } + + return nil +} + +func (s *SQLiteSeeder) Seed() error { + sqlFile, err := os.ReadFile(s.seedFile) + if err != nil { + return fmt.Errorf("failed to read SQL file: %v", err) + } + + db, err := sql.Open("sqlite", s.dbPath) + if err != nil { + return fmt.Errorf("failed to open %s database: %v", "sqlite", err) + } + defer db.Close() + + sqlCommands := strings.Split(string(sqlFile), ";") + for _, cmd := range sqlCommands { + if _, err := db.Exec(cmd); err != nil { + return fmt.Errorf("failed to execute SQL command: %v", err) + } + } + + return nil +}