From 8b5e08727b5c6398051ac8be954c89635f26ef66 Mon Sep 17 00:00:00 2001 From: ze0s <43699394+zze0s@users.noreply.github.com> Date: Sun, 19 Nov 2023 22:16:46 +0100 Subject: [PATCH] fix(config): load from env vars (#995) * fix(config): load from env and bind * fix(config): remove unused imports * feat: add new postgres config as vars --- internal/config/config.go | 129 +++++++++++++++++++++++++++++++------- 1 file changed, 108 insertions(+), 21 deletions(-) diff --git a/internal/config/config.go b/internal/config/config.go index 1fb1c1b..9485c1b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,6 +10,7 @@ import ( "os" "path" "path/filepath" + "strconv" "strings" "sync" "text/template" @@ -176,6 +177,7 @@ func New(configPath string, version string) *AppConfig { c.Config.ConfigPath = configPath c.load(configPath) + c.loadFromEnv() return c } @@ -205,13 +207,108 @@ func (c *AppConfig) defaults() { } -func (c *AppConfig) load(configPath string) { - // or use viper.SetDefault(val, def) - //viper.SetDefault("host", config.Host) - //viper.SetDefault("port", config.Port) - //viper.SetDefault("logLevel", config.LogLevel) - //viper.SetDefault("logPath", config.LogPath) +func (c *AppConfig) loadFromEnv() { + prefix := "AUTOBRR__" + if v := os.Getenv(prefix + "HOST"); v != "" { + c.Config.Host = v + } + + if v := os.Getenv(prefix + "PORT"); v != "" { + i, _ := strconv.ParseInt(v, 10, 32) + if i > 0 { + c.Config.Port = int(i) + } + } + + if v := os.Getenv(prefix + "BASE_URL"); v != "" { + c.Config.BaseURL = v + } + + if v := os.Getenv(prefix + "LOG_LEVEL"); v != "" { + c.Config.LogLevel = v + } + + if v := os.Getenv(prefix + "LOG_PATH"); v != "" { + c.Config.LogPath = v + } + + if v := os.Getenv(prefix + "LOG_MAX_SIZE"); v != "" { + i, _ := strconv.ParseInt(v, 10, 32) + if i > 0 { + c.Config.LogMaxSize = int(i) + } + } + + if v := os.Getenv(prefix + "LOG_MAX_BACKUPS"); v != "" { + i, _ := strconv.ParseInt(v, 10, 32) + if i > 0 { + c.Config.LogMaxBackups = int(i) + } + } + + if v := os.Getenv(prefix + "SESSION_SECRET"); v != "" { + c.Config.SessionSecret = v + } + + if v := os.Getenv(prefix + "CUSTOM_DEFINITIONS"); v != "" { + c.Config.CustomDefinitions = v + } + + if v := os.Getenv(prefix + "CHECK_FOR_UPDATES"); v != "" { + c.Config.CheckForUpdates = strings.EqualFold(strings.ToLower(v), "true") + } + + if v := os.Getenv(prefix + "DATABASE_TYPE"); v != "" { + if validDatabaseType(v) { + c.Config.DatabaseType = v + } + } + + if v := os.Getenv(prefix + "POSTGRES_HOST"); v != "" { + c.Config.PostgresHost = v + } + + if v := os.Getenv(prefix + "POSTGRES_PORT"); v != "" { + i, _ := strconv.ParseInt(v, 10, 32) + if i > 0 { + c.Config.PostgresPort = int(i) + } + } + + if v := os.Getenv(prefix + "POSTGRES_DATABASE"); v != "" { + c.Config.PostgresDatabase = v + } + + if v := os.Getenv(prefix + "POSTGRES_USER"); v != "" { + c.Config.PostgresUser = v + } + + if v := os.Getenv(prefix + "POSTGRES_PASS"); v != "" { + c.Config.PostgresPass = v + } + + if v := os.Getenv(prefix + "POSTGRES_SSLMODE"); v != "" { + c.Config.PostgresSSLMode = v + } + + if v := os.Getenv(prefix + "POSTGRES_EXTRA_PARAMS"); v != "" { + c.Config.PostgresExtraParams = v + } +} + +func validDatabaseType(v string) bool { + valid := []string{"sqlite", "postgres"} + for _, s := range valid { + if s == v { + return true + } + } + + return false +} + +func (c *AppConfig) load(configPath string) { viper.SetConfigType("toml") // clean trailing slash from configPath @@ -236,21 +333,11 @@ func (c *AppConfig) load(configPath string) { viper.AddConfigPath("$HOME/.autobrr") } - viper.SetEnvPrefix("AUTOBRR") - // read config if err := viper.ReadInConfig(); err != nil { log.Printf("config read error: %q", err) } - for _, key := range viper.AllKeys() { - envKey := strings.ToUpper(strings.ReplaceAll(key, ".", "_")) - err := viper.BindEnv(key, "AUTOBRR__"+envKey) - if err != nil { - log.Fatal("config: unable to bind env: " + err.Error()) - } - } - if err := viper.Unmarshal(c.Config); err != nil { log.Fatalf("Could not unmarshal config file: %v: err %q", viper.ConfigFileUsed(), err) } @@ -278,19 +365,19 @@ func (c *AppConfig) DynamicReload(log logger.Logger) { } func (c *AppConfig) UpdateConfig() error { - file := path.Join(c.Config.ConfigPath, "config.toml") + filePath := path.Join(c.Config.ConfigPath, "config.toml") - f, err := os.ReadFile(file) + f, err := os.ReadFile(filePath) if err != nil { - return errors.Wrap(err, "could not read config file: %s", file) + return errors.Wrap(err, "could not read config file: %s", filePath) } lines := strings.Split(string(f), "\n") lines = c.processLines(lines) output := strings.Join(lines, "\n") - if err := os.WriteFile(file, []byte(output), 0644); err != nil { - return errors.Wrap(err, "could not write config file: %s", file) + if err := os.WriteFile(filePath, []byte(output), 0644); err != nil { + return errors.Wrap(err, "could not write config file: %s", filePath) } return nil