feat: download clients skip tls verify option (#181)

This commit is contained in:
Ludvig Lundgren 2022-03-17 20:57:27 +01:00 committed by GitHub
parent 8bf43dc1e0
commit bb9e51f9d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 166 additions and 150 deletions

View file

@ -103,7 +103,8 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
Port: uint(client.Port), Port: uint(client.Port),
Username: client.Username, Username: client.Username,
Password: client.Password, Password: client.Password,
SSL: client.SSL, TLS: client.TLS,
TLSSkipVerify: client.TLSSkipVerify,
} }
qbt := qbittorrent.NewClient(qbtSettings) qbt := qbittorrent.NewClient(qbtSettings)

View file

@ -4,6 +4,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"sync"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
@ -16,6 +17,7 @@ type DownloadClientRepo struct {
} }
type clientCache struct { type clientCache struct {
mu sync.RWMutex
clients map[int]*domain.DownloadClient clients map[int]*domain.DownloadClient
} }
@ -26,10 +28,14 @@ func NewClientCache() *clientCache {
} }
func (c *clientCache) Set(id int, client *domain.DownloadClient) { func (c *clientCache) Set(id int, client *domain.DownloadClient) {
c.mu.Lock()
c.clients[id] = client c.clients[id] = client
c.mu.Unlock()
} }
func (c *clientCache) Get(id int) *domain.DownloadClient { func (c *clientCache) Get(id int) *domain.DownloadClient {
c.mu.RLock()
defer c.mu.Unlock()
v, ok := c.clients[id] v, ok := c.clients[id]
if ok { if ok {
return v return v
@ -38,7 +44,9 @@ func (c *clientCache) Get(id int) *domain.DownloadClient {
} }
func (c *clientCache) Pop(id int) { func (c *clientCache) Pop(id int) {
c.mu.Lock()
delete(c.clients, id) delete(c.clients, id)
c.mu.Unlock()
} }
func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo { func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo {
@ -48,33 +56,32 @@ func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo {
} }
} }
func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) { func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, error) {
//r.db.lock.RLock() //r.db.lock.RLock()
//defer r.db.lock.RUnlock() //defer r.db.lock.RUnlock()
clients := make([]domain.DownloadClient, 0)
rows, err := r.db.handler.Query("SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client") rows, err := r.db.handler.QueryContext(ctx, "SELECT id, name, type, enabled, host, port, tls, tls_skip_verify, username, password, settings FROM client")
if err != nil { if err != nil {
log.Error().Stack().Err(err).Msg("could not query download client rows") log.Error().Stack().Err(err).Msg("could not query download client rows")
return nil, err return clients, err
} }
defer rows.Close() defer rows.Close()
clients := make([]domain.DownloadClient, 0)
for rows.Next() { for rows.Next() {
var f domain.DownloadClient var f domain.DownloadClient
var settingsJsonStr string var settingsJsonStr string
if err := rows.Scan(&f.ID, &f.Name, &f.Type, &f.Enabled, &f.Host, &f.Port, &f.SSL, &f.Username, &f.Password, &settingsJsonStr); err != nil { if err := rows.Scan(&f.ID, &f.Name, &f.Type, &f.Enabled, &f.Host, &f.Port, &f.TLS, &f.TLSSkipVerify, &f.Username, &f.Password, &settingsJsonStr); err != nil {
log.Error().Stack().Err(err).Msg("could not scan download client to struct") log.Error().Stack().Err(err).Msg("could not scan download client to struct")
return nil, err return clients, err
} }
if settingsJsonStr != "" { if settingsJsonStr != "" {
if err := json.Unmarshal([]byte(settingsJsonStr), &f.Settings); err != nil { if err := json.Unmarshal([]byte(settingsJsonStr), &f.Settings); err != nil {
log.Error().Stack().Err(err).Msgf("could not marshal download client settings %v", settingsJsonStr) log.Error().Stack().Err(err).Msgf("could not marshal download client settings %v", settingsJsonStr)
return nil, err return clients, err
} }
} }
@ -82,7 +89,7 @@ func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) {
} }
if err := rows.Err(); err != nil { if err := rows.Err(); err != nil {
log.Error().Stack().Err(err).Msg("could not query download client rows") log.Error().Stack().Err(err).Msg("could not query download client rows")
return nil, err return clients, err
} }
return clients, nil return clients, nil
@ -98,9 +105,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
return c, nil return c, nil
} }
query := ` query := `SELECT id, name, type, enabled, host, port, tls, tls_skip_verify, username, password, settings FROM client WHERE id = ?`
SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client WHERE id = ?
`
row := r.db.handler.QueryRowContext(ctx, query, id) row := r.db.handler.QueryRowContext(ctx, query, id)
if err := row.Err(); err != nil { if err := row.Err(); err != nil {
@ -111,7 +116,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
var client domain.DownloadClient var client domain.DownloadClient
var settingsJsonStr string var settingsJsonStr string
if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.SSL, &client.Username, &client.Password, &settingsJsonStr); err != nil { if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.TLS, &client.TLSSkipVerify, &client.Username, &client.Password, &settingsJsonStr); err != nil {
log.Error().Stack().Err(err).Msg("could not scan download client to struct") log.Error().Stack().Err(err).Msg("could not scan download client to struct")
return nil, err return nil, err
} }
@ -126,7 +131,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
return &client, nil return &client, nil
} }
func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.DownloadClient, error) { func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
//r.db.lock.RLock() //r.db.lock.RLock()
//defer r.db.lock.RUnlock() //defer r.db.lock.RUnlock()
@ -145,7 +150,7 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
} }
if client.ID != 0 { if client.ID != 0 {
_, err = r.db.handler.Exec(` _, err = r.db.handler.ExecContext(ctx, `
UPDATE UPDATE
client client
SET SET
@ -154,7 +159,8 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
enabled = ?, enabled = ?,
host = ?, host = ?,
port = ?, port = ?,
ssl = ?, tls = ?,
tls_skip_verify = ?,
username = ?, username = ?,
password = ?, password = ?,
settings = (?) settings = (?)
@ -165,7 +171,8 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
client.Enabled, client.Enabled,
client.Host, client.Host,
client.Port, client.Port,
client.SSL, client.TLS,
client.TLSSkipVerify,
client.Username, client.Username,
client.Password, client.Password,
string(settingsJson), string(settingsJson),
@ -178,24 +185,26 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
} else { } else {
var res sql.Result var res sql.Result
res, err = r.db.handler.Exec(`INSERT INTO res, err = r.db.handler.ExecContext(ctx, `INSERT INTO
client( client(
name, name,
type, type,
enabled, enabled,
host, host,
port, port,
ssl, tls,
tls_skip_verify,
username, username,
password, password,
settings) settings)
VALUES (?, ?, ?, ?, ?, ? , ?, ?, ?) ON CONFLICT DO NOTHING`, VALUES (?, ?, ?, ?, ?, ? , ?, ?, ?, ?) ON CONFLICT DO NOTHING`,
client.Name, client.Name,
client.Type, client.Type,
client.Enabled, client.Enabled,
client.Host, client.Host,
client.Port, client.Port,
client.SSL, client.TLS,
client.TLSSkipVerify,
client.Username, client.Username,
client.Password, client.Password,
string(settingsJson), string(settingsJson),
@ -220,11 +229,11 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
return &client, nil return &client, nil
} }
func (r *DownloadClientRepo) Delete(clientID int) error { func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error {
//r.db.lock.RLock() //r.db.lock.RLock()
//defer r.db.lock.RUnlock() //defer r.db.lock.RUnlock()
res, err := r.db.handler.Exec(`DELETE FROM client WHERE client.id = ?`, clientID) res, err := r.db.handler.ExecContext(ctx, `DELETE FROM client WHERE client.id = ?`, clientID)
if err != nil { if err != nil {
log.Error().Stack().Err(err).Msgf("could not delete download client: %d", clientID) log.Error().Stack().Err(err).Msgf("could not delete download client: %d", clientID)
return err return err
@ -234,7 +243,6 @@ func (r *DownloadClientRepo) Delete(clientID int) error {
r.cache.Pop(clientID) r.cache.Pop(clientID)
rows, _ := res.RowsAffected() rows, _ := res.RowsAffected()
if rows == 0 { if rows == 0 {
return err return err
} }

View file

@ -3,6 +3,7 @@ package database
import ( import (
"database/sql" "database/sql"
"fmt" "fmt"
"github.com/lib/pq" "github.com/lib/pq"
) )
@ -126,6 +127,7 @@ CREATE TABLE client
host TEXT NOT NULL, host TEXT NOT NULL,
port INTEGER, port INTEGER,
ssl BOOLEAN, ssl BOOLEAN,
tls_skip_verify BOOLEAN,
username TEXT, username TEXT,
password TEXT, password TEXT,
settings JSON settings JSON
@ -345,6 +347,13 @@ var migrations = []string{
ALTER TABLE "filter" ALTER TABLE "filter"
ADD COLUMN priority INTEGER DEFAULT 0 NOT NULL; ADD COLUMN priority INTEGER DEFAULT 0 NOT NULL;
`, `,
`
ALTER TABLE "client"
ADD COLUMN tls_skip_verify BOOLEAN DEFAULT FALSE;
ALTER TABLE "client"
RENAME COLUMN ssl TO tls;
`,
} }
func (db *SqliteDB) migrate() error { func (db *SqliteDB) migrate() error {

View file

@ -4,10 +4,10 @@ import "context"
type DownloadClientRepo interface { type DownloadClientRepo interface {
//FindByActionID(actionID int) ([]DownloadClient, error) //FindByActionID(actionID int) ([]DownloadClient, error)
List() ([]DownloadClient, error) List(ctx context.Context) ([]DownloadClient, error)
FindByID(ctx context.Context, id int32) (*DownloadClient, error) FindByID(ctx context.Context, id int32) (*DownloadClient, error)
Store(client DownloadClient) (*DownloadClient, error) Store(ctx context.Context, client DownloadClient) (*DownloadClient, error)
Delete(clientID int) error Delete(ctx context.Context, clientID int) error
} }
type DownloadClient struct { type DownloadClient struct {
@ -17,7 +17,8 @@ type DownloadClient struct {
Enabled bool `json:"enabled"` Enabled bool `json:"enabled"`
Host string `json:"host"` Host string `json:"host"`
Port int `json:"port"` Port int `json:"port"`
SSL bool `json:"ssl"` TLS bool `json:"tls"`
TLSSkipVerify bool `json:"tls_skip_verify"`
Username string `json:"username"` Username string `json:"username"`
Password string `json:"password"` Password string `json:"password"`
Settings DownloadClientSettings `json:"settings,omitempty"` Settings DownloadClientSettings `json:"settings,omitempty"`

View file

@ -40,7 +40,8 @@ func (s *service) testQbittorrentConnection(client domain.DownloadClient) error
Port: uint(client.Port), Port: uint(client.Port),
Username: client.Username, Username: client.Username,
Password: client.Password, Password: client.Password,
SSL: client.SSL, TLS: client.TLS,
TLSSkipVerify: client.TLSSkipVerify,
} }
qbt := qbittorrent.NewClient(qbtSettings) qbt := qbittorrent.NewClient(qbtSettings)

View file

@ -5,15 +5,13 @@ import (
"errors" "errors"
"github.com/autobrr/autobrr/internal/domain" "github.com/autobrr/autobrr/internal/domain"
"github.com/rs/zerolog/log"
) )
type Service interface { type Service interface {
List() ([]domain.DownloadClient, error) List(ctx context.Context) ([]domain.DownloadClient, error)
FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error)
Store(client domain.DownloadClient) (*domain.DownloadClient, error) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(clientID int) error Delete(ctx context.Context, clientID int) error
Test(client domain.DownloadClient) error Test(client domain.DownloadClient) error
} }
@ -25,25 +23,15 @@ func NewService(repo domain.DownloadClientRepo) Service {
return &service{repo: repo} return &service{repo: repo}
} }
func (s *service) List() ([]domain.DownloadClient, error) { func (s *service) List(ctx context.Context) ([]domain.DownloadClient, error) {
clients, err := s.repo.List() return s.repo.List(ctx)
if err != nil {
return nil, err
}
return clients, nil
} }
func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) { func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
client, err := s.repo.FindByID(ctx, id) return s.repo.FindByID(ctx, id)
if err != nil {
return nil, err
} }
return client, nil func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
}
func (s *service) Store(client domain.DownloadClient) (*domain.DownloadClient, error) {
// validate data // validate data
if client.Host == "" { if client.Host == "" {
return nil, errors.New("validation error: no host") return nil, errors.New("validation error: no host")
@ -52,22 +40,11 @@ func (s *service) Store(client domain.DownloadClient) (*domain.DownloadClient, e
} }
// store // store
c, err := s.repo.Store(client) return s.repo.Store(ctx, client)
if err != nil {
return nil, err
} }
return c, nil func (s *service) Delete(ctx context.Context, clientID int) error {
} return s.repo.Delete(ctx, clientID)
func (s *service) Delete(clientID int) error {
if err := s.repo.Delete(clientID); err != nil {
return err
}
log.Debug().Msgf("delete client: %v", clientID)
return nil
} }
func (s *service) Test(client domain.DownloadClient) error { func (s *service) Test(client domain.DownloadClient) error {
@ -79,10 +56,5 @@ func (s *service) Test(client domain.DownloadClient) error {
} }
// test // test
err := s.testConnection(client) return s.testConnection(client)
if err != nil {
return err
}
return nil
} }

View file

@ -1,7 +1,9 @@
package http package http
import ( import (
"context"
"encoding/json" "encoding/json"
"errors"
"net/http" "net/http"
"strconv" "strconv"
@ -11,9 +13,9 @@ import (
) )
type downloadClientService interface { type downloadClientService interface {
List() ([]domain.DownloadClient, error) List(ctx context.Context) ([]domain.DownloadClient, error)
Store(client domain.DownloadClient) (*domain.DownloadClient, error) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(clientID int) error Delete(ctx context.Context, clientID int) error
Test(client domain.DownloadClient) error Test(client domain.DownloadClient) error
} }
@ -40,87 +42,84 @@ func (h downloadClientHandler) Routes(r chi.Router) {
func (h downloadClientHandler) listDownloadClients(w http.ResponseWriter, r *http.Request) { func (h downloadClientHandler) listDownloadClients(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
clients, err := h.service.List() clients, err := h.service.List(ctx)
if err != nil { if err != nil {
// h.encoder.Error(w, err)
return
} }
h.encoder.StatusResponse(ctx, w, clients, http.StatusOK) h.encoder.StatusResponse(ctx, w, clients, http.StatusOK)
} }
func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) { func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) {
var ( var data domain.DownloadClient
ctx = r.Context()
data domain.DownloadClient
)
if err := json.NewDecoder(r.Body).Decode(&data); err != nil { if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
// encode error h.encoder.Error(w, err)
return return
} }
client, err := h.service.Store(data) client, err := h.service.Store(r.Context(), data)
if err != nil { if err != nil {
// encode error h.encoder.Error(w, err)
return
} }
h.encoder.StatusResponse(ctx, w, client, http.StatusCreated) h.encoder.StatusResponse(r.Context(), w, client, http.StatusCreated)
} }
func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) { func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) {
var ( var data domain.DownloadClient
ctx = r.Context()
data domain.DownloadClient
)
if err := json.NewDecoder(r.Body).Decode(&data); err != nil { if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
// encode error h.encoder.Error(w, err)
h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest)
return return
} }
err := h.service.Test(data) err := h.service.Test(data)
if err != nil { if err != nil {
// encode error h.encoder.Error(w, err)
h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest)
return return
} }
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) h.encoder.NoContent(w)
} }
func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) { func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) {
var ( var data domain.DownloadClient
ctx = r.Context()
data domain.DownloadClient
)
if err := json.NewDecoder(r.Body).Decode(&data); err != nil { if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
// encode error h.encoder.Error(w, err)
return return
} }
client, err := h.service.Store(data) client, err := h.service.Store(r.Context(), data)
if err != nil { if err != nil {
// encode error h.encoder.Error(w, err)
return
} }
h.encoder.StatusResponse(ctx, w, client, http.StatusCreated) h.encoder.StatusResponse(r.Context(), w, client, http.StatusCreated)
} }
func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) { func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) {
var ( var clientID = chi.URLParam(r, "clientID")
ctx = r.Context()
clientID = chi.URLParam(r, "clientID")
)
// if !clientID return error if clientID == "" {
h.encoder.Error(w, errors.New("no clientID given"))
id, _ := strconv.Atoi(clientID) return
if err := h.service.Delete(id); err != nil {
// encode error
} }
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent) id, err := strconv.Atoi(clientID)
if err != nil {
h.encoder.Error(w, err)
return
}
if err = h.service.Delete(r.Context(), id); err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
} }

View file

@ -2,6 +2,7 @@ package qbittorrent
import ( import (
"bytes" "bytes"
"crypto/tls"
"fmt" "fmt"
"io" "io"
"mime/multipart" "mime/multipart"
@ -36,7 +37,8 @@ type Settings struct {
Port uint Port uint
Username string Username string
Password string Password string
SSL bool TLS bool
TLSSkipVerify bool
protocol string protocol string
} }
@ -58,10 +60,19 @@ func NewClient(s Settings) *Client {
} }
c.settings.protocol = "http" c.settings.protocol = "http"
if c.settings.SSL { if c.settings.TLS {
c.settings.protocol = "https" c.settings.protocol = "https"
} }
if c.settings.TLSSkipVerify {
//skip TLS verification
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
c.http.Transport = tr
}
return c return c
} }

View file

@ -36,7 +36,8 @@ interface InitialValues {
enabled: boolean; enabled: boolean;
host: string; host: string;
port: number; port: number;
ssl: boolean; tls: boolean;
tls_skip_verify: boolean;
username: string; username: string;
password: string; password: string;
settings: InitialValuesSettings; settings: InitialValuesSettings;
@ -44,14 +45,24 @@ interface InitialValues {
function FormFieldsDefault() { function FormFieldsDefault() {
const {
values: { tls },
} = useFormikContext<InitialValues>();
return ( return (
<Fragment> <Fragment>
<TextFieldWide name="host" label="Host" help="Eg. client.domain.ltd, domain.ltd/client, domain.ltd:port" /> <TextFieldWide name="host" label="Host" help="Eg. client.domain.ltd, domain.ltd/client, domain.ltd:port" />
<NumberFieldWide name="port" label="Port" help="WebUI port for qBittorrent and daemon port for Deluge" /> <NumberFieldWide name="port" label="Port" help="WebUI port for qBittorrent and daemon port for Deluge" />
<div className="py-6 px-6 space-y-6 sm:py-0 sm:space-y-0 sm:divide-y sm:divide-gray-200"> <div className="py-6 px-6 space-y-6 sm:py-0 sm:space-y-0 sm:divide-y sm:divide-gray-200 dark:divide-gray-700">
<SwitchGroupWide name="ssl" label="SSL" /> <SwitchGroupWide name="tls" label="TLS" />
{tls && (
<Fragment>
<SwitchGroupWide name="tls_skip_verify" label="Skip TLS verification (insecure)" />
</Fragment>
)}
</div> </div>
<TextFieldWide name="username" label="Username" /> <TextFieldWide name="username" label="Username" />
@ -325,7 +336,8 @@ export function DownloadClientAddForm({ isOpen, toggle }: any) {
enabled: true, enabled: true,
host: "", host: "",
port: 10000, port: 10000,
ssl: false, tls: false,
tls_skip_verify: false,
username: "", username: "",
password: "", password: "",
settings: {} settings: {}
@ -512,7 +524,8 @@ export function DownloadClientUpdateForm({ client, isOpen, toggle }: any) {
enabled: client.enabled, enabled: client.enabled,
host: client.host, host: client.host,
port: client.port, port: client.port,
ssl: client.ssl, tls: client.tls,
tls_skip_verify: client.tls_skip_verify,
username: client.username, username: client.username,
password: client.password, password: client.password,
settings: client.settings, settings: client.settings,

View file

@ -33,7 +33,8 @@ interface DownloadClient {
enabled: boolean; enabled: boolean;
host: string; host: string;
port: number; port: number;
ssl: boolean; tls: boolean;
tls_skip_verify: boolean;
username: string; username: string;
password: string; password: string;
settings?: DownloadClientSettings; settings?: DownloadClientSettings;