diff --git a/internal/action/qbittorrent.go b/internal/action/qbittorrent.go index b3d5e91..3355731 100644 --- a/internal/action/qbittorrent.go +++ b/internal/action/qbittorrent.go @@ -99,11 +99,12 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool, } qbtSettings := qbittorrent.Settings{ - Hostname: client.Host, - Port: uint(client.Port), - Username: client.Username, - Password: client.Password, - SSL: client.SSL, + Hostname: client.Host, + Port: uint(client.Port), + Username: client.Username, + Password: client.Password, + TLS: client.TLS, + TLSSkipVerify: client.TLSSkipVerify, } qbt := qbittorrent.NewClient(qbtSettings) diff --git a/internal/database/download_client.go b/internal/database/download_client.go index a4d9b7c..bc79e47 100644 --- a/internal/database/download_client.go +++ b/internal/database/download_client.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "encoding/json" + "sync" "github.com/autobrr/autobrr/internal/domain" @@ -16,6 +17,7 @@ type DownloadClientRepo struct { } type clientCache struct { + mu sync.RWMutex clients map[int]*domain.DownloadClient } @@ -26,10 +28,14 @@ func NewClientCache() *clientCache { } func (c *clientCache) Set(id int, client *domain.DownloadClient) { + c.mu.Lock() c.clients[id] = client + c.mu.Unlock() } func (c *clientCache) Get(id int) *domain.DownloadClient { + c.mu.RLock() + defer c.mu.Unlock() v, ok := c.clients[id] if ok { return v @@ -38,7 +44,9 @@ func (c *clientCache) Get(id int) *domain.DownloadClient { } func (c *clientCache) Pop(id int) { + c.mu.Lock() delete(c.clients, id) + c.mu.Unlock() } func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo { @@ -48,33 +56,32 @@ func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo { } } -func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) { +func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, error) { //r.db.lock.RLock() //defer r.db.lock.RUnlock() + clients := make([]domain.DownloadClient, 0) - rows, err := r.db.handler.Query("SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client") + rows, err := r.db.handler.QueryContext(ctx, "SELECT id, name, type, enabled, host, port, tls, tls_skip_verify, username, password, settings FROM client") if err != nil { log.Error().Stack().Err(err).Msg("could not query download client rows") - return nil, err + return clients, err } defer rows.Close() - clients := make([]domain.DownloadClient, 0) - for rows.Next() { var f domain.DownloadClient var settingsJsonStr string - if err := rows.Scan(&f.ID, &f.Name, &f.Type, &f.Enabled, &f.Host, &f.Port, &f.SSL, &f.Username, &f.Password, &settingsJsonStr); err != nil { + if err := rows.Scan(&f.ID, &f.Name, &f.Type, &f.Enabled, &f.Host, &f.Port, &f.TLS, &f.TLSSkipVerify, &f.Username, &f.Password, &settingsJsonStr); err != nil { log.Error().Stack().Err(err).Msg("could not scan download client to struct") - return nil, err + return clients, err } if settingsJsonStr != "" { if err := json.Unmarshal([]byte(settingsJsonStr), &f.Settings); err != nil { log.Error().Stack().Err(err).Msgf("could not marshal download client settings %v", settingsJsonStr) - return nil, err + return clients, err } } @@ -82,7 +89,7 @@ func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) { } if err := rows.Err(); err != nil { log.Error().Stack().Err(err).Msg("could not query download client rows") - return nil, err + return clients, err } return clients, nil @@ -98,9 +105,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do return c, nil } - query := ` - SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client WHERE id = ? - ` + query := `SELECT id, name, type, enabled, host, port, tls, tls_skip_verify, username, password, settings FROM client WHERE id = ?` row := r.db.handler.QueryRowContext(ctx, query, id) if err := row.Err(); err != nil { @@ -111,7 +116,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do var client domain.DownloadClient var settingsJsonStr string - if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.SSL, &client.Username, &client.Password, &settingsJsonStr); err != nil { + if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.TLS, &client.TLSSkipVerify, &client.Username, &client.Password, &settingsJsonStr); err != nil { log.Error().Stack().Err(err).Msg("could not scan download client to struct") return nil, err } @@ -126,7 +131,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do return &client, nil } -func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.DownloadClient, error) { +func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { //r.db.lock.RLock() //defer r.db.lock.RUnlock() @@ -145,7 +150,7 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo } if client.ID != 0 { - _, err = r.db.handler.Exec(` + _, err = r.db.handler.ExecContext(ctx, ` UPDATE client SET @@ -154,7 +159,8 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo enabled = ?, host = ?, port = ?, - ssl = ?, + tls = ?, + tls_skip_verify = ?, username = ?, password = ?, settings = (?) @@ -165,7 +171,8 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo client.Enabled, client.Host, client.Port, - client.SSL, + client.TLS, + client.TLSSkipVerify, client.Username, client.Password, string(settingsJson), @@ -178,24 +185,26 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo } else { var res sql.Result - res, err = r.db.handler.Exec(`INSERT INTO + res, err = r.db.handler.ExecContext(ctx, `INSERT INTO client( name, type, enabled, host, port, - ssl, + tls, + tls_skip_verify, username, password, settings) - VALUES (?, ?, ?, ?, ?, ? , ?, ?, ?) ON CONFLICT DO NOTHING`, + VALUES (?, ?, ?, ?, ?, ? , ?, ?, ?, ?) ON CONFLICT DO NOTHING`, client.Name, client.Type, client.Enabled, client.Host, client.Port, - client.SSL, + client.TLS, + client.TLSSkipVerify, client.Username, client.Password, string(settingsJson), @@ -220,11 +229,11 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo return &client, nil } -func (r *DownloadClientRepo) Delete(clientID int) error { +func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error { //r.db.lock.RLock() //defer r.db.lock.RUnlock() - res, err := r.db.handler.Exec(`DELETE FROM client WHERE client.id = ?`, clientID) + res, err := r.db.handler.ExecContext(ctx, `DELETE FROM client WHERE client.id = ?`, clientID) if err != nil { log.Error().Stack().Err(err).Msgf("could not delete download client: %d", clientID) return err @@ -234,7 +243,6 @@ func (r *DownloadClientRepo) Delete(clientID int) error { r.cache.Pop(clientID) rows, _ := res.RowsAffected() - if rows == 0 { return err } diff --git a/internal/database/migrate.go b/internal/database/migrate.go index 37b2714..3a98e1e 100644 --- a/internal/database/migrate.go +++ b/internal/database/migrate.go @@ -3,6 +3,7 @@ package database import ( "database/sql" "fmt" + "github.com/lib/pq" ) @@ -119,16 +120,17 @@ CREATE TABLE filter_indexer CREATE TABLE client ( - id INTEGER PRIMARY KEY, - name TEXT NOT NULL, - enabled BOOLEAN, - type TEXT, - host TEXT NOT NULL, - port INTEGER, - ssl BOOLEAN, - username TEXT, - password TEXT, - settings JSON + id INTEGER PRIMARY KEY, + name TEXT NOT NULL, + enabled BOOLEAN, + type TEXT, + host TEXT NOT NULL, + port INTEGER, + ssl BOOLEAN, + tls_skip_verify BOOLEAN, + username TEXT, + password TEXT, + settings JSON ); CREATE TABLE action @@ -345,6 +347,13 @@ var migrations = []string{ ALTER TABLE "filter" ADD COLUMN priority INTEGER DEFAULT 0 NOT NULL; `, + ` + ALTER TABLE "client" + ADD COLUMN tls_skip_verify BOOLEAN DEFAULT FALSE; + + ALTER TABLE "client" + RENAME COLUMN ssl TO tls; + `, } func (db *SqliteDB) migrate() error { diff --git a/internal/domain/client.go b/internal/domain/client.go index e8f8767..04aafcb 100644 --- a/internal/domain/client.go +++ b/internal/domain/client.go @@ -4,23 +4,24 @@ import "context" type DownloadClientRepo interface { //FindByActionID(actionID int) ([]DownloadClient, error) - List() ([]DownloadClient, error) + List(ctx context.Context) ([]DownloadClient, error) FindByID(ctx context.Context, id int32) (*DownloadClient, error) - Store(client DownloadClient) (*DownloadClient, error) - Delete(clientID int) error + Store(ctx context.Context, client DownloadClient) (*DownloadClient, error) + Delete(ctx context.Context, clientID int) error } type DownloadClient struct { - ID int `json:"id"` - Name string `json:"name"` - Type DownloadClientType `json:"type"` - Enabled bool `json:"enabled"` - Host string `json:"host"` - Port int `json:"port"` - SSL bool `json:"ssl"` - Username string `json:"username"` - Password string `json:"password"` - Settings DownloadClientSettings `json:"settings,omitempty"` + ID int `json:"id"` + Name string `json:"name"` + Type DownloadClientType `json:"type"` + Enabled bool `json:"enabled"` + Host string `json:"host"` + Port int `json:"port"` + TLS bool `json:"tls"` + TLSSkipVerify bool `json:"tls_skip_verify"` + Username string `json:"username"` + Password string `json:"password"` + Settings DownloadClientSettings `json:"settings,omitempty"` } type DownloadClientSettings struct { diff --git a/internal/download_client/connection.go b/internal/download_client/connection.go index f2aa1a4..c11380e 100644 --- a/internal/download_client/connection.go +++ b/internal/download_client/connection.go @@ -36,11 +36,12 @@ func (s *service) testConnection(client domain.DownloadClient) error { func (s *service) testQbittorrentConnection(client domain.DownloadClient) error { qbtSettings := qbittorrent.Settings{ - Hostname: client.Host, - Port: uint(client.Port), - Username: client.Username, - Password: client.Password, - SSL: client.SSL, + Hostname: client.Host, + Port: uint(client.Port), + Username: client.Username, + Password: client.Password, + TLS: client.TLS, + TLSSkipVerify: client.TLSSkipVerify, } qbt := qbittorrent.NewClient(qbtSettings) diff --git a/internal/download_client/service.go b/internal/download_client/service.go index 8fee056..a5d2b02 100644 --- a/internal/download_client/service.go +++ b/internal/download_client/service.go @@ -5,15 +5,13 @@ import ( "errors" "github.com/autobrr/autobrr/internal/domain" - - "github.com/rs/zerolog/log" ) type Service interface { - List() ([]domain.DownloadClient, error) + List(ctx context.Context) ([]domain.DownloadClient, error) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) - Store(client domain.DownloadClient) (*domain.DownloadClient, error) - Delete(clientID int) error + Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) + Delete(ctx context.Context, clientID int) error Test(client domain.DownloadClient) error } @@ -25,25 +23,15 @@ func NewService(repo domain.DownloadClientRepo) Service { return &service{repo: repo} } -func (s *service) List() ([]domain.DownloadClient, error) { - clients, err := s.repo.List() - if err != nil { - return nil, err - } - - return clients, nil +func (s *service) List(ctx context.Context) ([]domain.DownloadClient, error) { + return s.repo.List(ctx) } func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) { - client, err := s.repo.FindByID(ctx, id) - if err != nil { - return nil, err - } - - return client, nil + return s.repo.FindByID(ctx, id) } -func (s *service) Store(client domain.DownloadClient) (*domain.DownloadClient, error) { +func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) { // validate data if client.Host == "" { return nil, errors.New("validation error: no host") @@ -52,22 +40,11 @@ func (s *service) Store(client domain.DownloadClient) (*domain.DownloadClient, e } // store - c, err := s.repo.Store(client) - if err != nil { - return nil, err - } - - return c, nil + return s.repo.Store(ctx, client) } -func (s *service) Delete(clientID int) error { - if err := s.repo.Delete(clientID); err != nil { - return err - } - - log.Debug().Msgf("delete client: %v", clientID) - - return nil +func (s *service) Delete(ctx context.Context, clientID int) error { + return s.repo.Delete(ctx, clientID) } func (s *service) Test(client domain.DownloadClient) error { @@ -79,10 +56,5 @@ func (s *service) Test(client domain.DownloadClient) error { } // test - err := s.testConnection(client) - if err != nil { - return err - } - - return nil + return s.testConnection(client) } diff --git a/internal/http/download_client.go b/internal/http/download_client.go index e233528..1d90d9e 100644 --- a/internal/http/download_client.go +++ b/internal/http/download_client.go @@ -1,7 +1,9 @@ package http import ( + "context" "encoding/json" + "errors" "net/http" "strconv" @@ -11,9 +13,9 @@ import ( ) type downloadClientService interface { - List() ([]domain.DownloadClient, error) - Store(client domain.DownloadClient) (*domain.DownloadClient, error) - Delete(clientID int) error + List(ctx context.Context) ([]domain.DownloadClient, error) + Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) + Delete(ctx context.Context, clientID int) error Test(client domain.DownloadClient) error } @@ -40,87 +42,84 @@ func (h downloadClientHandler) Routes(r chi.Router) { func (h downloadClientHandler) listDownloadClients(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - clients, err := h.service.List() + clients, err := h.service.List(ctx) if err != nil { - // + h.encoder.Error(w, err) + return } h.encoder.StatusResponse(ctx, w, clients, http.StatusOK) } func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - data domain.DownloadClient - ) + var data domain.DownloadClient if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - // encode error + h.encoder.Error(w, err) return } - client, err := h.service.Store(data) + client, err := h.service.Store(r.Context(), data) if err != nil { - // encode error + h.encoder.Error(w, err) + return } - h.encoder.StatusResponse(ctx, w, client, http.StatusCreated) + h.encoder.StatusResponse(r.Context(), w, client, http.StatusCreated) } func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - data domain.DownloadClient - ) + var data domain.DownloadClient if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - // encode error - h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest) + h.encoder.Error(w, err) return } err := h.service.Test(data) if err != nil { - // encode error - h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest) + h.encoder.Error(w, err) return } - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + h.encoder.NoContent(w) } func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - data domain.DownloadClient - ) + var data domain.DownloadClient if err := json.NewDecoder(r.Body).Decode(&data); err != nil { - // encode error + h.encoder.Error(w, err) return } - client, err := h.service.Store(data) + client, err := h.service.Store(r.Context(), data) if err != nil { - // encode error + h.encoder.Error(w, err) + return } - h.encoder.StatusResponse(ctx, w, client, http.StatusCreated) + h.encoder.StatusResponse(r.Context(), w, client, http.StatusCreated) } func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) { - var ( - ctx = r.Context() - clientID = chi.URLParam(r, "clientID") - ) + var clientID = chi.URLParam(r, "clientID") - // if !clientID return error - - id, _ := strconv.Atoi(clientID) - - if err := h.service.Delete(id); err != nil { - // encode error + if clientID == "" { + h.encoder.Error(w, errors.New("no clientID given")) + return } - h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) + id, err := strconv.Atoi(clientID) + if err != nil { + h.encoder.Error(w, err) + return + } + + if err = h.service.Delete(r.Context(), id); err != nil { + h.encoder.Error(w, err) + return + } + + h.encoder.NoContent(w) } diff --git a/pkg/qbittorrent/client.go b/pkg/qbittorrent/client.go index 8e5b4b7..1406c98 100644 --- a/pkg/qbittorrent/client.go +++ b/pkg/qbittorrent/client.go @@ -2,6 +2,7 @@ package qbittorrent import ( "bytes" + "crypto/tls" "fmt" "io" "mime/multipart" @@ -32,12 +33,13 @@ type Client struct { } type Settings struct { - Hostname string - Port uint - Username string - Password string - SSL bool - protocol string + Hostname string + Port uint + Username string + Password string + TLS bool + TLSSkipVerify bool + protocol string } func NewClient(s Settings) *Client { @@ -58,10 +60,19 @@ func NewClient(s Settings) *Client { } c.settings.protocol = "http" - if c.settings.SSL { + if c.settings.TLS { c.settings.protocol = "https" } + if c.settings.TLSSkipVerify { + //skip TLS verification + tr := &http.Transport{ + TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, + } + + c.http.Transport = tr + } + return c } diff --git a/web/src/forms/settings/DownloadClientForms.tsx b/web/src/forms/settings/DownloadClientForms.tsx index e2505e4..e3757ef 100644 --- a/web/src/forms/settings/DownloadClientForms.tsx +++ b/web/src/forms/settings/DownloadClientForms.tsx @@ -36,7 +36,8 @@ interface InitialValues { enabled: boolean; host: string; port: number; - ssl: boolean; + tls: boolean; + tls_skip_verify: boolean; username: string; password: string; settings: InitialValuesSettings; @@ -44,14 +45,24 @@ interface InitialValues { function FormFieldsDefault() { + const { + values: { tls }, + } = useFormikContext(); + return ( -
- +
+ + + {tls && ( + + + + )}
@@ -325,7 +336,8 @@ export function DownloadClientAddForm({ isOpen, toggle }: any) { enabled: true, host: "", port: 10000, - ssl: false, + tls: false, + tls_skip_verify: false, username: "", password: "", settings: {} @@ -512,7 +524,8 @@ export function DownloadClientUpdateForm({ client, isOpen, toggle }: any) { enabled: client.enabled, host: client.host, port: client.port, - ssl: client.ssl, + tls: client.tls, + tls_skip_verify: client.tls_skip_verify, username: client.username, password: client.password, settings: client.settings, diff --git a/web/src/types/Download.d.ts b/web/src/types/Download.d.ts index fec119c..8d5871e 100644 --- a/web/src/types/Download.d.ts +++ b/web/src/types/Download.d.ts @@ -33,7 +33,8 @@ interface DownloadClient { enabled: boolean; host: string; port: number; - ssl: boolean; + tls: boolean; + tls_skip_verify: boolean; username: string; password: string; settings?: DownloadClientSettings;