diff --git a/go.mod b/go.mod index 92df90fe..1181175e 100644 --- a/go.mod +++ b/go.mod @@ -23,6 +23,7 @@ require ( golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 golang.org/x/net v0.0.0-20201029221708-28c70e62bb1d // indirect golang.org/x/sys v0.0.0-20201029080932-201ba4db2418 // indirect + golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba // indirect google.golang.org/genproto v0.0.0-20201030142918-24207fddd1c3 // indirect google.golang.org/grpc v1.33.1 // indirect google.golang.org/protobuf v1.25.0 // indirect diff --git a/go.sum b/go.sum index 01c75fab..2eee44ce 100644 --- a/go.sum +++ b/go.sum @@ -97,6 +97,8 @@ golang.org/x/sys v0.0.0-20201029080932-201ba4db2418 h1:HlFl4V6pEMziuLXyRkm5BIYq1 golang.org/x/sys v0.0.0-20201029080932-201ba4db2418/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba h1:O8mE0/t419eoIwhTFpKVkHiTs/Igowgfkj25AcZrtiE= +golang.org/x/time v0.0.0-20210220033141-f8bda1e9f3ba/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20181030221726-6c7e314b6563/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190114222345-bf090417da8b/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/goscrobble/ratelimiter.go b/internal/goscrobble/ratelimiter.go new file mode 100644 index 00000000..2acc2144 --- /dev/null +++ b/internal/goscrobble/ratelimiter.go @@ -0,0 +1,56 @@ +package goscrobble + +import ( + "sync" + + "golang.org/x/time/rate" +) + +// IPRateLimiter +type IPRateLimiter struct { + ips map[string]*rate.Limiter + mu *sync.RWMutex + r rate.Limit + b int +} + +// NewIPRateLimiter +func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter { + i := &IPRateLimiter{ + ips: make(map[string]*rate.Limiter), + mu: &sync.RWMutex{}, + r: r, + b: b, + } + + return i +} + +// AddIP creates a new rate limiter and adds it to the ips map, +// using the IP address as the key +func (i *IPRateLimiter) AddIP(ip string) *rate.Limiter { + i.mu.Lock() + defer i.mu.Unlock() + + limiter := rate.NewLimiter(i.r, i.b) + + i.ips[ip] = limiter + + return limiter +} + +// GetLimiter returns the rate limiter for the provided IP address if it exists. +// Otherwise calls AddIP to add IP address to the map +func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter { + i.mu.Lock() + limiter, exists := i.ips[ip] + + if !exists { + i.mu.Unlock() + return i.AddIP(ip) + } + + i.mu.Unlock() + + return limiter +} diff --git a/internal/goscrobble/server.go b/internal/goscrobble/server.go index 45d6c143..4a5604e0 100644 --- a/internal/goscrobble/server.go +++ b/internal/goscrobble/server.go @@ -19,9 +19,13 @@ type spaHandler struct { } type jsonResponse struct { - Err string `json:"error"` + Err string `json:"error,omitempty"` + Msg string `json:"message,omitempty"` } +// Limits to 1 req / 10 sec +var limiter = NewIPRateLimiter(0.1, 1) + // HandleRequests - Boot HTTP! func HandleRequests() { // Create a new router @@ -36,7 +40,7 @@ func HandleRequests() { v1.HandleFunc("/profile/{id}", jwtMiddleware(serveEndpoint)) // No Auth - v1.HandleFunc("/register", handleRegister).Methods("POST") + v1.HandleFunc("/register", limitMiddleware(handleRegister)).Methods("POST") v1.HandleFunc("/login", serveEndpoint).Methods("POST") v1.HandleFunc("/logout", serveEndpoint).Methods("POST") @@ -62,7 +66,7 @@ func throwUnauthorized(w http.ResponseWriter, m string) { http.Error(w, err.Error(), http.StatusUnauthorized) } -// throwUnauthorized - Throws a 403 : +// throwUnauthorized - Throws a 403 func throwBadReq(w http.ResponseWriter, m string) { jr := jsonResponse{ Err: m, @@ -72,6 +76,15 @@ func throwBadReq(w http.ResponseWriter, m string) { http.Error(w, err.Error(), http.StatusBadRequest) } +// generateJsonMessage - Generates a message:str response +func generateJsonMessage(m string) []byte { + jr := jsonResponse{ + Msg: m, + } + js, _ := json.Marshal(&jr) + return js +} + // tokenMiddleware - Validates token to a user func tokenMiddleware(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { @@ -90,6 +103,19 @@ func jwtMiddleware(next http.HandlerFunc) http.HandlerFunc { } } +// limitMiddleware - Rate limits important stuff +func limitMiddleware(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + limiter := limiter.GetLimiter(r.RemoteAddr) + if !limiter.Allow() { + http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) + return + } + + next(w, r) + }) +} + // API ENDPOINT HANDLING // handleRegister - Does as it says! @@ -108,8 +134,9 @@ func handleRegister(w http.ResponseWriter, r *http.Request) { return } - // Lets trick 'em for now ;) ;) - fmt.Fprintf(w, "{}") + msg := generateJsonMessage("User created succesfully") + w.WriteHeader(http.StatusCreated) + w.Write(msg) } // serveEndpoint - API stuffs