feat: download clients skip tls verify option (#181)

This commit is contained in:
Ludvig Lundgren 2022-03-17 20:57:27 +01:00 committed by GitHub
parent 8bf43dc1e0
commit bb9e51f9d3
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 166 additions and 150 deletions

View file

@ -99,11 +99,12 @@ func (s *service) qbittorrentCheckRulesCanDownload(action domain.Action) (bool,
}
qbtSettings := qbittorrent.Settings{
Hostname: client.Host,
Port: uint(client.Port),
Username: client.Username,
Password: client.Password,
SSL: client.SSL,
Hostname: client.Host,
Port: uint(client.Port),
Username: client.Username,
Password: client.Password,
TLS: client.TLS,
TLSSkipVerify: client.TLSSkipVerify,
}
qbt := qbittorrent.NewClient(qbtSettings)

View file

@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"encoding/json"
"sync"
"github.com/autobrr/autobrr/internal/domain"
@ -16,6 +17,7 @@ type DownloadClientRepo struct {
}
type clientCache struct {
mu sync.RWMutex
clients map[int]*domain.DownloadClient
}
@ -26,10 +28,14 @@ func NewClientCache() *clientCache {
}
func (c *clientCache) Set(id int, client *domain.DownloadClient) {
c.mu.Lock()
c.clients[id] = client
c.mu.Unlock()
}
func (c *clientCache) Get(id int) *domain.DownloadClient {
c.mu.RLock()
defer c.mu.Unlock()
v, ok := c.clients[id]
if ok {
return v
@ -38,7 +44,9 @@ func (c *clientCache) Get(id int) *domain.DownloadClient {
}
func (c *clientCache) Pop(id int) {
c.mu.Lock()
delete(c.clients, id)
c.mu.Unlock()
}
func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo {
@ -48,33 +56,32 @@ func NewDownloadClientRepo(db *SqliteDB) domain.DownloadClientRepo {
}
}
func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) {
func (r *DownloadClientRepo) List(ctx context.Context) ([]domain.DownloadClient, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
clients := make([]domain.DownloadClient, 0)
rows, err := r.db.handler.Query("SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client")
rows, err := r.db.handler.QueryContext(ctx, "SELECT id, name, type, enabled, host, port, tls, tls_skip_verify, username, password, settings FROM client")
if err != nil {
log.Error().Stack().Err(err).Msg("could not query download client rows")
return nil, err
return clients, err
}
defer rows.Close()
clients := make([]domain.DownloadClient, 0)
for rows.Next() {
var f domain.DownloadClient
var settingsJsonStr string
if err := rows.Scan(&f.ID, &f.Name, &f.Type, &f.Enabled, &f.Host, &f.Port, &f.SSL, &f.Username, &f.Password, &settingsJsonStr); err != nil {
if err := rows.Scan(&f.ID, &f.Name, &f.Type, &f.Enabled, &f.Host, &f.Port, &f.TLS, &f.TLSSkipVerify, &f.Username, &f.Password, &settingsJsonStr); err != nil {
log.Error().Stack().Err(err).Msg("could not scan download client to struct")
return nil, err
return clients, err
}
if settingsJsonStr != "" {
if err := json.Unmarshal([]byte(settingsJsonStr), &f.Settings); err != nil {
log.Error().Stack().Err(err).Msgf("could not marshal download client settings %v", settingsJsonStr)
return nil, err
return clients, err
}
}
@ -82,7 +89,7 @@ func (r *DownloadClientRepo) List() ([]domain.DownloadClient, error) {
}
if err := rows.Err(); err != nil {
log.Error().Stack().Err(err).Msg("could not query download client rows")
return nil, err
return clients, err
}
return clients, nil
@ -98,9 +105,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
return c, nil
}
query := `
SELECT id, name, type, enabled, host, port, ssl, username, password, settings FROM client WHERE id = ?
`
query := `SELECT id, name, type, enabled, host, port, tls, tls_skip_verify, username, password, settings FROM client WHERE id = ?`
row := r.db.handler.QueryRowContext(ctx, query, id)
if err := row.Err(); err != nil {
@ -111,7 +116,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
var client domain.DownloadClient
var settingsJsonStr string
if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.SSL, &client.Username, &client.Password, &settingsJsonStr); err != nil {
if err := row.Scan(&client.ID, &client.Name, &client.Type, &client.Enabled, &client.Host, &client.Port, &client.TLS, &client.TLSSkipVerify, &client.Username, &client.Password, &settingsJsonStr); err != nil {
log.Error().Stack().Err(err).Msg("could not scan download client to struct")
return nil, err
}
@ -126,7 +131,7 @@ func (r *DownloadClientRepo) FindByID(ctx context.Context, id int32) (*domain.Do
return &client, nil
}
func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.DownloadClient, error) {
func (r *DownloadClientRepo) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
@ -145,7 +150,7 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
}
if client.ID != 0 {
_, err = r.db.handler.Exec(`
_, err = r.db.handler.ExecContext(ctx, `
UPDATE
client
SET
@ -154,7 +159,8 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
enabled = ?,
host = ?,
port = ?,
ssl = ?,
tls = ?,
tls_skip_verify = ?,
username = ?,
password = ?,
settings = (?)
@ -165,7 +171,8 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
client.Enabled,
client.Host,
client.Port,
client.SSL,
client.TLS,
client.TLSSkipVerify,
client.Username,
client.Password,
string(settingsJson),
@ -178,24 +185,26 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
} else {
var res sql.Result
res, err = r.db.handler.Exec(`INSERT INTO
res, err = r.db.handler.ExecContext(ctx, `INSERT INTO
client(
name,
type,
enabled,
host,
port,
ssl,
tls,
tls_skip_verify,
username,
password,
settings)
VALUES (?, ?, ?, ?, ?, ? , ?, ?, ?) ON CONFLICT DO NOTHING`,
VALUES (?, ?, ?, ?, ?, ? , ?, ?, ?, ?) ON CONFLICT DO NOTHING`,
client.Name,
client.Type,
client.Enabled,
client.Host,
client.Port,
client.SSL,
client.TLS,
client.TLSSkipVerify,
client.Username,
client.Password,
string(settingsJson),
@ -220,11 +229,11 @@ func (r *DownloadClientRepo) Store(client domain.DownloadClient) (*domain.Downlo
return &client, nil
}
func (r *DownloadClientRepo) Delete(clientID int) error {
func (r *DownloadClientRepo) Delete(ctx context.Context, clientID int) error {
//r.db.lock.RLock()
//defer r.db.lock.RUnlock()
res, err := r.db.handler.Exec(`DELETE FROM client WHERE client.id = ?`, clientID)
res, err := r.db.handler.ExecContext(ctx, `DELETE FROM client WHERE client.id = ?`, clientID)
if err != nil {
log.Error().Stack().Err(err).Msgf("could not delete download client: %d", clientID)
return err
@ -234,7 +243,6 @@ func (r *DownloadClientRepo) Delete(clientID int) error {
r.cache.Pop(clientID)
rows, _ := res.RowsAffected()
if rows == 0 {
return err
}

View file

@ -3,6 +3,7 @@ package database
import (
"database/sql"
"fmt"
"github.com/lib/pq"
)
@ -119,16 +120,17 @@ CREATE TABLE filter_indexer
CREATE TABLE client
(
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
enabled BOOLEAN,
type TEXT,
host TEXT NOT NULL,
port INTEGER,
ssl BOOLEAN,
username TEXT,
password TEXT,
settings JSON
id INTEGER PRIMARY KEY,
name TEXT NOT NULL,
enabled BOOLEAN,
type TEXT,
host TEXT NOT NULL,
port INTEGER,
ssl BOOLEAN,
tls_skip_verify BOOLEAN,
username TEXT,
password TEXT,
settings JSON
);
CREATE TABLE action
@ -345,6 +347,13 @@ var migrations = []string{
ALTER TABLE "filter"
ADD COLUMN priority INTEGER DEFAULT 0 NOT NULL;
`,
`
ALTER TABLE "client"
ADD COLUMN tls_skip_verify BOOLEAN DEFAULT FALSE;
ALTER TABLE "client"
RENAME COLUMN ssl TO tls;
`,
}
func (db *SqliteDB) migrate() error {

View file

@ -4,23 +4,24 @@ import "context"
type DownloadClientRepo interface {
//FindByActionID(actionID int) ([]DownloadClient, error)
List() ([]DownloadClient, error)
List(ctx context.Context) ([]DownloadClient, error)
FindByID(ctx context.Context, id int32) (*DownloadClient, error)
Store(client DownloadClient) (*DownloadClient, error)
Delete(clientID int) error
Store(ctx context.Context, client DownloadClient) (*DownloadClient, error)
Delete(ctx context.Context, clientID int) error
}
type DownloadClient struct {
ID int `json:"id"`
Name string `json:"name"`
Type DownloadClientType `json:"type"`
Enabled bool `json:"enabled"`
Host string `json:"host"`
Port int `json:"port"`
SSL bool `json:"ssl"`
Username string `json:"username"`
Password string `json:"password"`
Settings DownloadClientSettings `json:"settings,omitempty"`
ID int `json:"id"`
Name string `json:"name"`
Type DownloadClientType `json:"type"`
Enabled bool `json:"enabled"`
Host string `json:"host"`
Port int `json:"port"`
TLS bool `json:"tls"`
TLSSkipVerify bool `json:"tls_skip_verify"`
Username string `json:"username"`
Password string `json:"password"`
Settings DownloadClientSettings `json:"settings,omitempty"`
}
type DownloadClientSettings struct {

View file

@ -36,11 +36,12 @@ func (s *service) testConnection(client domain.DownloadClient) error {
func (s *service) testQbittorrentConnection(client domain.DownloadClient) error {
qbtSettings := qbittorrent.Settings{
Hostname: client.Host,
Port: uint(client.Port),
Username: client.Username,
Password: client.Password,
SSL: client.SSL,
Hostname: client.Host,
Port: uint(client.Port),
Username: client.Username,
Password: client.Password,
TLS: client.TLS,
TLSSkipVerify: client.TLSSkipVerify,
}
qbt := qbittorrent.NewClient(qbtSettings)

View file

@ -5,15 +5,13 @@ import (
"errors"
"github.com/autobrr/autobrr/internal/domain"
"github.com/rs/zerolog/log"
)
type Service interface {
List() ([]domain.DownloadClient, error)
List(ctx context.Context) ([]domain.DownloadClient, error)
FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error)
Store(client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(clientID int) error
Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(ctx context.Context, clientID int) error
Test(client domain.DownloadClient) error
}
@ -25,25 +23,15 @@ func NewService(repo domain.DownloadClientRepo) Service {
return &service{repo: repo}
}
func (s *service) List() ([]domain.DownloadClient, error) {
clients, err := s.repo.List()
if err != nil {
return nil, err
}
return clients, nil
func (s *service) List(ctx context.Context) ([]domain.DownloadClient, error) {
return s.repo.List(ctx)
}
func (s *service) FindByID(ctx context.Context, id int32) (*domain.DownloadClient, error) {
client, err := s.repo.FindByID(ctx, id)
if err != nil {
return nil, err
}
return client, nil
return s.repo.FindByID(ctx, id)
}
func (s *service) Store(client domain.DownloadClient) (*domain.DownloadClient, error) {
func (s *service) Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error) {
// validate data
if client.Host == "" {
return nil, errors.New("validation error: no host")
@ -52,22 +40,11 @@ func (s *service) Store(client domain.DownloadClient) (*domain.DownloadClient, e
}
// store
c, err := s.repo.Store(client)
if err != nil {
return nil, err
}
return c, nil
return s.repo.Store(ctx, client)
}
func (s *service) Delete(clientID int) error {
if err := s.repo.Delete(clientID); err != nil {
return err
}
log.Debug().Msgf("delete client: %v", clientID)
return nil
func (s *service) Delete(ctx context.Context, clientID int) error {
return s.repo.Delete(ctx, clientID)
}
func (s *service) Test(client domain.DownloadClient) error {
@ -79,10 +56,5 @@ func (s *service) Test(client domain.DownloadClient) error {
}
// test
err := s.testConnection(client)
if err != nil {
return err
}
return nil
return s.testConnection(client)
}

View file

@ -1,7 +1,9 @@
package http
import (
"context"
"encoding/json"
"errors"
"net/http"
"strconv"
@ -11,9 +13,9 @@ import (
)
type downloadClientService interface {
List() ([]domain.DownloadClient, error)
Store(client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(clientID int) error
List(ctx context.Context) ([]domain.DownloadClient, error)
Store(ctx context.Context, client domain.DownloadClient) (*domain.DownloadClient, error)
Delete(ctx context.Context, clientID int) error
Test(client domain.DownloadClient) error
}
@ -40,87 +42,84 @@ func (h downloadClientHandler) Routes(r chi.Router) {
func (h downloadClientHandler) listDownloadClients(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
clients, err := h.service.List()
clients, err := h.service.List(ctx)
if err != nil {
//
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(ctx, w, clients, http.StatusOK)
}
func (h downloadClientHandler) store(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
data domain.DownloadClient
)
var data domain.DownloadClient
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
// encode error
h.encoder.Error(w, err)
return
}
client, err := h.service.Store(data)
client, err := h.service.Store(r.Context(), data)
if err != nil {
// encode error
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(ctx, w, client, http.StatusCreated)
h.encoder.StatusResponse(r.Context(), w, client, http.StatusCreated)
}
func (h downloadClientHandler) test(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
data domain.DownloadClient
)
var data domain.DownloadClient
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
// encode error
h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest)
h.encoder.Error(w, err)
return
}
err := h.service.Test(data)
if err != nil {
// encode error
h.encoder.StatusResponse(ctx, w, nil, http.StatusBadRequest)
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent)
h.encoder.NoContent(w)
}
func (h downloadClientHandler) update(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
data domain.DownloadClient
)
var data domain.DownloadClient
if err := json.NewDecoder(r.Body).Decode(&data); err != nil {
// encode error
h.encoder.Error(w, err)
return
}
client, err := h.service.Store(data)
client, err := h.service.Store(r.Context(), data)
if err != nil {
// encode error
h.encoder.Error(w, err)
return
}
h.encoder.StatusResponse(ctx, w, client, http.StatusCreated)
h.encoder.StatusResponse(r.Context(), w, client, http.StatusCreated)
}
func (h downloadClientHandler) delete(w http.ResponseWriter, r *http.Request) {
var (
ctx = r.Context()
clientID = chi.URLParam(r, "clientID")
)
var clientID = chi.URLParam(r, "clientID")
// if !clientID return error
id, _ := strconv.Atoi(clientID)
if err := h.service.Delete(id); err != nil {
// encode error
if clientID == "" {
h.encoder.Error(w, errors.New("no clientID given"))
return
}
h.encoder.StatusResponse(ctx, w, nil, http.StatusNoContent)
id, err := strconv.Atoi(clientID)
if err != nil {
h.encoder.Error(w, err)
return
}
if err = h.service.Delete(r.Context(), id); err != nil {
h.encoder.Error(w, err)
return
}
h.encoder.NoContent(w)
}

View file

@ -2,6 +2,7 @@ package qbittorrent
import (
"bytes"
"crypto/tls"
"fmt"
"io"
"mime/multipart"
@ -32,12 +33,13 @@ type Client struct {
}
type Settings struct {
Hostname string
Port uint
Username string
Password string
SSL bool
protocol string
Hostname string
Port uint
Username string
Password string
TLS bool
TLSSkipVerify bool
protocol string
}
func NewClient(s Settings) *Client {
@ -58,10 +60,19 @@ func NewClient(s Settings) *Client {
}
c.settings.protocol = "http"
if c.settings.SSL {
if c.settings.TLS {
c.settings.protocol = "https"
}
if c.settings.TLSSkipVerify {
//skip TLS verification
tr := &http.Transport{
TLSClientConfig: &tls.Config{InsecureSkipVerify: true},
}
c.http.Transport = tr
}
return c
}

View file

@ -36,7 +36,8 @@ interface InitialValues {
enabled: boolean;
host: string;
port: number;
ssl: boolean;
tls: boolean;
tls_skip_verify: boolean;
username: string;
password: string;
settings: InitialValuesSettings;
@ -44,14 +45,24 @@ interface InitialValues {
function FormFieldsDefault() {
const {
values: { tls },
} = useFormikContext<InitialValues>();
return (
<Fragment>
<TextFieldWide name="host" label="Host" help="Eg. client.domain.ltd, domain.ltd/client, domain.ltd:port" />
<NumberFieldWide name="port" label="Port" help="WebUI port for qBittorrent and daemon port for Deluge" />
<div className="py-6 px-6 space-y-6 sm:py-0 sm:space-y-0 sm:divide-y sm:divide-gray-200">
<SwitchGroupWide name="ssl" label="SSL" />
<div className="py-6 px-6 space-y-6 sm:py-0 sm:space-y-0 sm:divide-y sm:divide-gray-200 dark:divide-gray-700">
<SwitchGroupWide name="tls" label="TLS" />
{tls && (
<Fragment>
<SwitchGroupWide name="tls_skip_verify" label="Skip TLS verification (insecure)" />
</Fragment>
)}
</div>
<TextFieldWide name="username" label="Username" />
@ -325,7 +336,8 @@ export function DownloadClientAddForm({ isOpen, toggle }: any) {
enabled: true,
host: "",
port: 10000,
ssl: false,
tls: false,
tls_skip_verify: false,
username: "",
password: "",
settings: {}
@ -512,7 +524,8 @@ export function DownloadClientUpdateForm({ client, isOpen, toggle }: any) {
enabled: client.enabled,
host: client.host,
port: client.port,
ssl: client.ssl,
tls: client.tls,
tls_skip_verify: client.tls_skip_verify,
username: client.username,
password: client.password,
settings: client.settings,

View file

@ -33,7 +33,8 @@ interface DownloadClient {
enabled: boolean;
host: string;
port: number;
ssl: boolean;
tls: boolean;
tls_skip_verify: boolean;
username: string;
password: string;
settings?: DownloadClientSettings;