package main import ( "net" "net/http" "sync" "golang.org/x/time/rate" ) // IP limiter code taken from // https://medium.com/@pliutau/rate-limiting-http-requests-in-go-based-on-ip-address-4e66d1bea4cf var limiter = NewIPRateLimiter(1, 10) // IPRateLimiter type IPRateLimiter struct { ips map[string]*rate.Limiter mu *sync.RWMutex r rate.Limit b int } // NewIPRateLimiter func NewIPRateLimiter(r rate.Limit, b int) *IPRateLimiter { i := &IPRateLimiter{ ips: make(map[string]*rate.Limiter), mu: &sync.RWMutex{}, r: r, b: b, } return i } // AddIP creates a new rate limiter and adds it to the ips map, // using the IP address as the key func (i *IPRateLimiter) AddIP(ip string) *rate.Limiter { i.mu.Lock() defer i.mu.Unlock() limiter := rate.NewLimiter(i.r, i.b) i.ips[ip] = limiter return limiter } // GetLimiter returns the rate limiter for the provided IP address if it exists. // Otherwise calls AddIP to add IP address to the map func (i *IPRateLimiter) GetLimiter(ip string) *rate.Limiter { i.mu.Lock() limiter, exists := i.ips[ip] if !exists { i.mu.Unlock() return i.AddIP(ip) } i.mu.Unlock() return limiter } func limitMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ip, _, err := net.SplitHostPort(r.RemoteAddr) if err != nil { w.WriteHeader(http.StatusBadRequest) return } limiter := limiter.GetLimiter(ip) if !limiter.Allow() { http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) return } next.ServeHTTP(w, r) }) }