fix(config): load from env vars (#995)

* fix(config): load from env and bind

* fix(config): remove unused imports

* feat: add new postgres config as vars
This commit is contained in:
ze0s 2023-11-19 22:16:46 +01:00 committed by GitHub
parent 70a2f2d713
commit 8b5e08727b
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -10,6 +10,7 @@ import (
"os"
"path"
"path/filepath"
"strconv"
"strings"
"sync"
"text/template"
@ -176,6 +177,7 @@ func New(configPath string, version string) *AppConfig {
c.Config.ConfigPath = configPath
c.load(configPath)
c.loadFromEnv()
return c
}
@ -205,13 +207,108 @@ func (c *AppConfig) defaults() {
}
func (c *AppConfig) load(configPath string) {
// or use viper.SetDefault(val, def)
//viper.SetDefault("host", config.Host)
//viper.SetDefault("port", config.Port)
//viper.SetDefault("logLevel", config.LogLevel)
//viper.SetDefault("logPath", config.LogPath)
func (c *AppConfig) loadFromEnv() {
prefix := "AUTOBRR__"
if v := os.Getenv(prefix + "HOST"); v != "" {
c.Config.Host = v
}
if v := os.Getenv(prefix + "PORT"); v != "" {
i, _ := strconv.ParseInt(v, 10, 32)
if i > 0 {
c.Config.Port = int(i)
}
}
if v := os.Getenv(prefix + "BASE_URL"); v != "" {
c.Config.BaseURL = v
}
if v := os.Getenv(prefix + "LOG_LEVEL"); v != "" {
c.Config.LogLevel = v
}
if v := os.Getenv(prefix + "LOG_PATH"); v != "" {
c.Config.LogPath = v
}
if v := os.Getenv(prefix + "LOG_MAX_SIZE"); v != "" {
i, _ := strconv.ParseInt(v, 10, 32)
if i > 0 {
c.Config.LogMaxSize = int(i)
}
}
if v := os.Getenv(prefix + "LOG_MAX_BACKUPS"); v != "" {
i, _ := strconv.ParseInt(v, 10, 32)
if i > 0 {
c.Config.LogMaxBackups = int(i)
}
}
if v := os.Getenv(prefix + "SESSION_SECRET"); v != "" {
c.Config.SessionSecret = v
}
if v := os.Getenv(prefix + "CUSTOM_DEFINITIONS"); v != "" {
c.Config.CustomDefinitions = v
}
if v := os.Getenv(prefix + "CHECK_FOR_UPDATES"); v != "" {
c.Config.CheckForUpdates = strings.EqualFold(strings.ToLower(v), "true")
}
if v := os.Getenv(prefix + "DATABASE_TYPE"); v != "" {
if validDatabaseType(v) {
c.Config.DatabaseType = v
}
}
if v := os.Getenv(prefix + "POSTGRES_HOST"); v != "" {
c.Config.PostgresHost = v
}
if v := os.Getenv(prefix + "POSTGRES_PORT"); v != "" {
i, _ := strconv.ParseInt(v, 10, 32)
if i > 0 {
c.Config.PostgresPort = int(i)
}
}
if v := os.Getenv(prefix + "POSTGRES_DATABASE"); v != "" {
c.Config.PostgresDatabase = v
}
if v := os.Getenv(prefix + "POSTGRES_USER"); v != "" {
c.Config.PostgresUser = v
}
if v := os.Getenv(prefix + "POSTGRES_PASS"); v != "" {
c.Config.PostgresPass = v
}
if v := os.Getenv(prefix + "POSTGRES_SSLMODE"); v != "" {
c.Config.PostgresSSLMode = v
}
if v := os.Getenv(prefix + "POSTGRES_EXTRA_PARAMS"); v != "" {
c.Config.PostgresExtraParams = v
}
}
func validDatabaseType(v string) bool {
valid := []string{"sqlite", "postgres"}
for _, s := range valid {
if s == v {
return true
}
}
return false
}
func (c *AppConfig) load(configPath string) {
viper.SetConfigType("toml")
// clean trailing slash from configPath
@ -236,21 +333,11 @@ func (c *AppConfig) load(configPath string) {
viper.AddConfigPath("$HOME/.autobrr")
}
viper.SetEnvPrefix("AUTOBRR")
// read config
if err := viper.ReadInConfig(); err != nil {
log.Printf("config read error: %q", err)
}
for _, key := range viper.AllKeys() {
envKey := strings.ToUpper(strings.ReplaceAll(key, ".", "_"))
err := viper.BindEnv(key, "AUTOBRR__"+envKey)
if err != nil {
log.Fatal("config: unable to bind env: " + err.Error())
}
}
if err := viper.Unmarshal(c.Config); err != nil {
log.Fatalf("Could not unmarshal config file: %v: err %q", viper.ConfigFileUsed(), err)
}
@ -278,19 +365,19 @@ func (c *AppConfig) DynamicReload(log logger.Logger) {
}
func (c *AppConfig) UpdateConfig() error {
file := path.Join(c.Config.ConfigPath, "config.toml")
filePath := path.Join(c.Config.ConfigPath, "config.toml")
f, err := os.ReadFile(file)
f, err := os.ReadFile(filePath)
if err != nil {
return errors.Wrap(err, "could not read config file: %s", file)
return errors.Wrap(err, "could not read config file: %s", filePath)
}
lines := strings.Split(string(f), "\n")
lines = c.processLines(lines)
output := strings.Join(lines, "\n")
if err := os.WriteFile(file, []byte(output), 0644); err != nil {
return errors.Wrap(err, "could not write config file: %s", file)
if err := os.WriteFile(filePath, []byte(output), 0644); err != nil {
return errors.Wrap(err, "could not write config file: %s", filePath)
}
return nil