79 lines
1.3 KiB
Go
79 lines
1.3 KiB
Go
package api
|
||
|
||
import (
|
||
"net"
|
||
"net/http"
|
||
"sync"
|
||
"time"
|
||
)
|
||
|
||
// rateLimiter is a per-key token-bucket rate limiter.
//
// Every key owns a bucket that holds at most burst tokens and is refilled
// continuously at rate tokens per second; each allowed call spends one token.
// All fields are guarded by mu, and because the struct contains a sync.Mutex
// it must not be copied after first use.
type rateLimiter struct {
	mu sync.Mutex

	// bk maps a caller-supplied key to that key's bucket.
	bk map[string]*bucket

	// rate is the refill speed in tokens per second.
	rate float64

	// burst is the bucket capacity and the initial fill of a new bucket.
	burst float64

	// window is the idle time after which a bucket becomes eligible for
	// eviction. It is also the minimum interval between eviction sweeps.
	window time.Duration

	// lastSweep is when bk was last scanned for evictable buckets. The zero
	// value makes the first call to allow perform a sweep.
	lastSweep time.Time
}

// bucket is the token-bucket state kept for a single key.
type bucket struct {
	// tokens is the balance as of last; it may be fractional.
	tokens float64

	// last is the instant at which tokens was last brought up to date.
	last time.Time
}

// newRateLimiter returns a limiter that gives every key an initial allowance
// of burst calls and refills it at rps tokens per second, never beyond burst.
// window is how long a key must stay idle before its bucket may be evicted.
func newRateLimiter(rps float64, burst int, window time.Duration) *rateLimiter {
	return &rateLimiter{
		bk:     make(map[string]*bucket),
		rate:   rps,
		burst:  float64(burst),
		window: window,
	}
}

// allow reports whether the caller identified by key may proceed, spending
// one token from that key's bucket when it returns true. A key seen for the
// first time starts with a full bucket. allow is safe for concurrent use.
func (rl *rateLimiter) allow(key string) bool {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	// Read the clock while holding the lock so that successive updates to a
	// bucket never observe time running backwards.
	now := time.Now()

	b := rl.bk[key]
	if b == nil {
		b = &bucket{tokens: rl.burst, last: now}
		rl.bk[key] = b
	}

	// Credit the tokens earned since the bucket was last touched, capped at
	// the bucket capacity.
	b.tokens += now.Sub(b.last).Seconds() * rl.rate
	if b.tokens > rl.burst {
		b.tokens = rl.burst
	}
	b.last = now

	allowed := true
	if b.tokens < 1 {
		allowed = false
	} else {
		b.tokens--
	}

	// Sweep only after the token has been spent, so the bucket just used is
	// judged by its post-request balance.
	rl.sweep(now)
	return allowed
}

// sweep evicts buckets that have been idle for longer than rl.window and have
// completely refilled in the meantime. The caller must hold rl.mu.
//
// Only fully refilled buckets are dropped: they are indistinguishable from
// the fresh bucket a returning key would receive, so eviction never changes
// what allow returns. Dropping a partially drained bucket would hand its key
// a full burst ahead of schedule. As a consequence, buckets never become
// evictable when rl.rate is not positive.
//
// The scan visits every bucket, so it runs at most once per rl.window instead
// of on every call.
func (rl *rateLimiter) sweep(now time.Time) {
	if now.Sub(rl.lastSweep) < rl.window {
		return
	}
	rl.lastSweep = now

	for k, v := range rl.bk {
		idle := now.Sub(v.last)
		if idle > rl.window && v.tokens+idle.Seconds()*rl.rate >= rl.burst {
			delete(rl.bk, k)
		}
	}
}
|
||
|
||
// min returns the smaller of a and b. When the two compare equal, or when
// either operand is NaN, b is returned.
func min(a, b float64) float64 {
	smaller := b
	if a < b {
		smaller = a
	}
	return smaller
}
|
||
|
||
// clientIP returns the address that identifies the client behind r.
//
// The CF-Connecting-IP header is preferred when it holds a valid IP address,
// which is returned in canonical textual form. Otherwise the host part of
// r.RemoteAddr is returned, or r.RemoteAddr unchanged when it is not in
// host:port form.
//
// NOTE(security): CF-Connecting-IP is an ordinary request header, so any
// client that can reach this server without passing through Cloudflare can
// set it to an address of its choosing. Confirm that ingress is restricted to
// Cloudflare before relying on the result for per-client decisions.
func clientIP(r *http.Request) string {
	// Accept the header only if it parses as an IP, so arbitrary text is
	// never returned as an identifier and each address has a single spelling.
	if ip := net.ParseIP(r.Header.Get("CF-Connecting-IP")); ip != nil {
		return ip.String()
	}

	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		// RemoteAddr is not host:port; fall back to the raw value.
		return r.RemoteAddr
	}
	return host
}
|