feat: download clients skip tls verify option (#181)

This commit is contained in:
Ludvig Lundgren 2022-03-17 20:57:27 +01:00 committed by GitHub
parent 8bf43dc1e0
commit bb9e51f9d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 166 additions and 150 deletions

View file

@ -103,7 +103,8 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
Port: uint(client.Port),
Username: client.Username,
Password: client.Password,
SSL: client.SSL,
TLS: client.TLS,
TLSSkipVerify: client.TLSSkipVerify,
}
qbt := qbittorrent.NewClient(qbtSettings)

View file

@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"sync"
"github.com/autobrr/autobrr/internal/domain"
@ -16,6 +17,7 @@ type DownloadClientRepo struct {
}
type clientCache struct {
mu sync.RWMutex
clients map[int]*domain.DownloadClient
}
@ -26,10 +28,14 @@ func NewClientCache() *clientCache {
}
func (c *clientCache) Set(id int, client *domain.DownloadClient) {
c.mu.Lock()
c.clients[id] = client
c.mu.Unlock()
}
func (c *clientCache) Get(id int) *domain.DownloadClient {
c.mu.RLock()
defer c.mu.Unlock()
v, ok := c.clients[id]
if ok {
return v
@ -38,7 +44,9 @@ func (c *clientCache) Get(id int) *domain.DownloadClient {
}
func (c *clientCache) Pop(id int) {
c.mu.Lock()
delete(c.clients, id)
c.mu.Unlock()
}
func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo {
@ -48,33 +56,32 @@ func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo {
}
}
func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) {
func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
clients := make([]domain.DownloadClient, 0)
rows, err := r.db.handler.Query("SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client")
rows, err := r.db.handler.QueryContext(ctx, "SELECT id, name, type, enabled, host, port, tls, tls_skip_verify, username, password, settings FROM client")
if err != nil {
log.Error().Stack().Err(err).Msg("could not query download client rows")
return nil, err
return clients, err
}
defer rows.Close()
clients := make([]domain.DownloadClient, 0)
for rows.Next() {
var f domain.DownloadClient
var settingsJsonStr string
if err := rows.Scan(&f.ID, &f.Name, &f.Type, &f.Enabled, &f.Host, &f.Port, &f.SSL, &f.Username, &f.Password, &settingsJsonStr); err != nil {
if err := rows.Scan(&f.ID, &f.Name, &f.Type, &f.Enabled, &f.Host, &f.Port, &f.TLS, &f.TLSSkipVerify, &f.Username, &f.Password, &settingsJsonStr); err != nil {
log.Error().Stack().Err(err).Msg("could not scan download client to struct")
return nil, err
return clients, err
}
if settingsJsonStr != "" {
if err := json.Unmarshal([]byte(settingsJsonStr), &f.Settings); err != nil {
log.Error().Stack().Err(err).Msgf("could not marshal download client settings %v", settingsJsonStr)
return nil, err
return clients, err
}
}
@ -82,7 +89,7 @@ func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) {
}
if err := rows.Err(); err != nil {
log.Error().Stack().Err(err).Msg("could not query download client rows")
return nil, err
return clients, err
}
return clients, nil
@ -98,9 +105,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
return c, nil
}
query := `
SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client WHERE id = ?
`
query := `SELECT id, name, type, enabled, host, port, tls, tls_skip_verify, username, password, settings FROM client WHERE id = ?`
row := r.db.handler.QueryRowContext(ctx, query, id)
if err := row.Err(); err != nil {
@ -111,7 +116,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
var client domain.DownloadClient
var settingsJsonStr string
if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.SSL, &client.Username, &client.Password, &settingsJsonStr); err != nil {
if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.TLS, &client.TLSSkipVerify, &client.Username, &client.Password, &settingsJsonStr); err != nil {
log.Error().Stack().Err(err).Msg("could not scan download client to struct")
return nil, err
}
@ -126,7 +131,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
return &client, nil
}
func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.DownloadClient, error) {
func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
@ -145,7 +150,7 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
}
if client.ID != 0 {
_, err = r.db.handler.Exec(`
_, err = r.db.handler.ExecContext(ctx, `
UPDATE
client
SET
@ -154,7 +159,8 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
enabled = ?,
host = ?,
port = ?,
ssl = ?,
tls = ?,
tls_skip_verify = ?,
username = ?,
password = ?,
settings = (?)
@ -165,7 +171,8 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
client.Enabled,
client.Host,
client.Port,
client.SSL,
client.TLS,
client.TLSSkipVerify,
client.Username,
client.Password,
string(settingsJson),
@ -178,24 +185,26 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
} else {
var res sql.Result
res, err = r.db.handler.Exec(`INSERT INTO
res, err = r.db.handler.ExecContext(ctx, `INSERT INTO
client(
name,
type,
enabled,
host,
port,
ssl,
tls,
tls_skip_verify,
username,
password,
settings)
VALUES (?, ?, ?, ?, ?, ? , ?, ?, ?) ON CONFLICT DO NOTHING`,
VALUES (?, ?, ?, ?, ?, ? , ?, ?, ?, ?) ON CONFLICT DO NOTHING`,
client.Name,
client.Type,
client.Enabled,
client.Host,
client.Port,
client.SSL,
client.TLS,
client.TLSSkipVerify,
client.Username,
client.Password,
string(settingsJson),
@ -220,11 +229,11 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
return &client, nil
}
func (r *DownloadClientRepo) Delete(clientID int) error {
func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
res, err := r.db.handler.Exec(`DELETE FROM client WHERE client.id = ?`, clientID)
res, err := r.db.handler.ExecContext(ctx, `DELETE FROM client WHERE client.id = ?`, clientID)
if err != nil {
log.Error().Stack().Err(err).Msgf("could not delete download client: %d", clientID)
return err
@ -234,7 +243,6 @@ func (r *DownloadClientRepo) Delete(clientID int) error {
r.cache.Pop(clientID)
rows, _ := res.RowsAffected()
if rows == 0 {
return err
}

View file

@ -3,6 +3,7 @@ package database
import (
"database/sql"
"fmt"
"github.com/lib/pq"
)
@ -126,6 +127,7 @@ CREATE TABLE client
host TEXT NOT NULL,
port INTEGER,
ssl BOOLEAN,
tls_skip_verify BOOLEAN,
username TEXT,
password TEXT,
settings JSON
@ -345,6 +347,13 @@ var migrations = []string{
ALTER TABLE "filter"
ADD COLUMN priority INTEGER DEFAULT 0 NOT NULL;
`,
`
ALTER TABLE "client"
ADD COLUMN tls_skip_verify BOOLEAN DEFAULT FALSE;
ALTER TABLE "client"
RENAME COLUMN ssl TO tls;
`,
}
func (db *SqliteDB) migrate() error {

View file

@ -4,10 +4,10 @@ import "context"
type DownloadClientRepo interface {
//FindByActionID(actionID int) ([]DownloadClient, error)
List() ([]DownloadClient, error)
List(ctx context.Context) ([]DownloadClient, error)
FindByID(ctx context.Context, id int32) (*DownloadClient, error)
Store(client DownloadClient) (*DownloadClient, error)
Delete(clientID int) error
Store(ctx context.Context, client DownloadClient) (*DownloadClient, error)
Delete(ctx context.Context, clientID int) error
}
type DownloadClient struct {
@ -17,7 +17,8 @@ type DownloadClient struct {
Enabled bool `json:"enabled"`
Host string `json:"host"`
Port int `json:"port"`
SSL bool `json:"ssl"`
TLS bool `json:"tls"`
TLSSkipVerify bool `json:"tls_skip_verify"`
Username string `json:"username"`
Password string `json:"password"`
Settings DownloadClientSettings `json:"settings,omitempty"`

View file

@ -40,7 +40,8 @@ func (s *service) testQbittorrentConnection(client domain.DownloadClient) error
Port: uint(client.Port),
Username: client.Username,
Password: client.Password,
SSL: client.SSL,
TLS: client.TLS,
TLSSkipVerify: client.TLSSkipVerify,
}
qbt := qbittorrent.NewClient(qbtSettings)

View file

@ -5,15 +5,13 @@ import (
"errors"
"github.com/autobrr/autobrr/internal/domain"
"github.com/rs/zerolog/log"
)
type Service interface {
List() ([]domain.DownloadClient, error)
List(ctx context.Context) ([]domain.DownloadClient, error)
FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error)
Store(client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(clientID int) error
Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(ctx context.Context, clientID int) error
Test(client domain.DownloadClient) error
}
@ -25,25 +23,15 @@ func NewService(repo domain.DownloadClientRepo) Service {
return &service{repo: repo}
}
func (s *service) List() ([]domain.DownloadClient, error) {
clients, err := s.repo.List()
if err != nil {
return nil, err
}
return clients, nil
func (s *service) List(ctx context.Context) ([]domain.DownloadClient, error) {
return s.repo.List(ctx)
}
func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
client, err := s.repo.FindByID(ctx, id)
if err != nil {
return nil, err
}
return client, nil
return s.repo.FindByID(ctx, id)
}
func (s *service) Store(client domain.DownloadClient) (*domain.DownloadClient, error) {
func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
// validate data
if client.Host == "" {
return nil, errors.New("validation error: no host")
@ -52,22 +40,11 @@ func (s *service) Store(client domain.DownloadClient) (*domain.DownloadClient, e
}
// store
c, err := s.repo.Store(client)
if err != nil {
return nil, err
}
return c, nil
return s.repo.Store(ctx, client)
}
func (s *service) Delete(clientID int) error {
if err := s.repo.Delete(clientID); err != nil {
return err
}
log.Debug().Msgf("delete client: %v", clientID)
return nil
func (s *service) Delete(ctx context.Context, clientID int) error {
return s.repo.Delete(ctx, clientID)
}
func (s *service) Test(client domain.DownloadClient) error {
@ -79,10 +56,5 @@ func (s *service) Test(client domain.DownloadClient) error {
}
// test
err := s.testConnection(client)
if err != nil {
return err
}
return nil
return s.testConnection(client)
}

View file

@ -1,7 +1,9 @@
package http
import (
"context"
"encoding/json"
"errors"
"net/http"
"strconv"
@ -11,9 +13,9 @@ import (
)
type downloadClientService interface {
List() ([]domain.DownloadClient, error)
Store(client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(clientID int) error
List(ctx context.Context) ([]domain.DownloadClient, error)
Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(ctx context.Context, clientID int) error
Test(client domain.DownloadClient) error
}
@ -40,87 +42,84 @@ func (h downloadClientHandler) Routes(r chi.Router) {
func (h downloadClientHandler) listDownloadClients(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
clients, err := h.service.List()
clients, err := h.service.List(ctx)
if err != nil {
//
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(ctx, w, clients, http.StatusOK)
}
func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
data domain.DownloadClient
)
var data domain.DownloadClient
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
// encode error
h.encoder.Error(w, err)
return
}
client, err := h.service.Store(data)
client, err := h.service.Store(r.Context(), data)
if err != nil {
// encode error
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(ctx, w, client, http.StatusCreated)
h.encoder.StatusResponse(r.Context(), w, client, http.StatusCreated)
}
func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
data domain.DownloadClient
)
var data domain.DownloadClient
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
// encode error
h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest)
h.encoder.Error(w, err)
return
}
err := h.service.Test(data)
if err != nil {
// encode error
h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest)
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent)
h.encoder.NoContent(w)
}
func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
data domain.DownloadClient
)
var data domain.DownloadClient
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
// encode error
h.encoder.Error(w, err)
return
}
client, err := h.service.Store(data)
client, err := h.service.Store(r.Context(), data)
if err != nil {
// encode error
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(ctx, w, client, http.StatusCreated)
h.encoder.StatusResponse(r.Context(), w, client, http.StatusCreated)
}
func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
clientID = chi.URLParam(r, "clientID")
)
var clientID = chi.URLParam(r, "clientID")
// if !clientID return error
id, _ := strconv.Atoi(clientID)
if err := h.service.Delete(id); err != nil {
// encode error
if clientID == "" {
h.encoder.Error(w, errors.New("no clientID given"))
return
}
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent)
id, err := strconv.Atoi(clientID)
if err != nil {
h.encoder.Error(w, err)
return
}
if err = h.service.Delete(r.Context(), id); err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
}

View file

@ -2,6 +2,7 @@ package qbittorrent
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"mime/multipart"
@ -36,7 +37,8 @@ type Settings struct {
Port uint
Username string
Password string
SSL bool
TLS bool
TLSSkipVerify bool
protocol string
}
@ -58,10 +60,19 @@ func NewClient(s Settings) *Client {
}
c.settings.protocol = "http"
if c.settings.SSL {
if c.settings.TLS {
c.settings.protocol = "https"
}
if c.settings.TLSSkipVerify {
//skip TLS verification
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
c.http.Transport = tr
}
return c
}

View file

@ -36,7 +36,8 @@ interface InitialValues {
enabled: boolean;
host: string;
port: number;
ssl: boolean;
tls: boolean;
tls_skip_verify: boolean;
username: string;
password: string;
settings: InitialValuesSettings;
@ -44,14 +45,24 @@ interface InitialValues {
function FormFieldsDefault() {
const {
values: { tls },
} = useFormikContext<InitialValues>();
return (
<Fragment>
<TextFieldWide name="host" label="Host" help="Eg. client.domain.ltd, domain.ltd/client, domain.ltd:port" />
<NumberFieldWide name="port" label="Port" help="WebUI port for qBittorrent and daemon port for Deluge" />
<div className="py-6 px-6 space-y-6 sm:py-0 sm:space-y-0 sm:divide-y sm:divide-gray-200">
<SwitchGroupWide name="ssl" label="SSL" />
<div className="py-6 px-6 space-y-6 sm:py-0 sm:space-y-0 sm:divide-y sm:divide-gray-200 dark:divide-gray-700">
<SwitchGroupWide name="tls" label="TLS" />
{tls && (
<Fragment>
<SwitchGroupWide name="tls_skip_verify" label="Skip TLS verification (insecure)" />
</Fragment>
)}
</div>
<TextFieldWide name="username" label="Username" />
@ -325,7 +336,8 @@ export function DownloadClientAddForm({ isOpen, toggle }: any) {
enabled: true,
host: "",
port: 10000,
ssl: false,
tls: false,
tls_skip_verify: false,
username: "",
password: "",
settings: {}
@ -512,7 +524,8 @@ export function DownloadClientUpdateForm({ client, isOpen, toggle }: any) {
enabled: client.enabled,
host: client.host,
port: client.port,
ssl: client.ssl,
tls: client.tls,
tls_skip_verify: client.tls_skip_verify,
username: client.username,
password: client.password,
settings: client.settings,

View file

@ -33,7 +33,8 @@ interface DownloadClient {
enabled: boolean;
host: string;
port: number;
ssl: boolean;
tls: boolean;
tls_skip_verify: boolean;
username: string;
password: string;
settings?: DownloadClientSettings;