Merge pull request #89 from ddosify/develop
Refactor/speed up
fatihbaltaci authored Feb 13, 2024
2 parents 80596e7 + 9968000 commit 1c40178
Showing 3 changed files with 75 additions and 143 deletions.
211 changes: 71 additions & 140 deletions aggregator/data.go
@@ -12,6 +12,7 @@ import (
"context"
"encoding/binary"
"fmt"
"io"
"net"
"os"
"os/exec"
@@ -47,13 +48,11 @@ type Aggregator struct {

// listen to events from different sources
k8sChan <-chan interface{}
ebpfChan chan interface{}
ebpfChan <-chan interface{}
ebpfProcChan <-chan interface{}
ebpfTcpChan <-chan interface{}
tlsAttachSignalChan chan uint32

ec *ebpf.EbpfCollector

// store the service map
clusterInfo *ClusterInfo

@@ -118,18 +117,7 @@ type ClusterInfo struct {

// Pid -> SocketMap
// pid -> fd -> {saddr, sport, daddr, dport}

// shard pidToSocketMap by pid to reduce lock contention
mu0 sync.RWMutex
mu1 sync.RWMutex
mu2 sync.RWMutex
mu3 sync.RWMutex
mu4 sync.RWMutex
PidToSocketMap0 map[uint32]*SocketMap `json:"pidToSocketMap0"` // pid ending with 0-1
PidToSocketMap1 map[uint32]*SocketMap `json:"pidToSocketMap1"` // pid ending with 2-3
PidToSocketMap2 map[uint32]*SocketMap `json:"pidToSocketMap2"` // pid ending with 4-5
PidToSocketMap3 map[uint32]*SocketMap `json:"pidToSocketMap3"` // pid ending with 6-7
PidToSocketMap4 map[uint32]*SocketMap `json:"pidToSocketMap4"` // pid ending with 8-9
SocketMaps []*SocketMap // index symbolizes pid
}

// If we have information from the container runtimes
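
The hunk above is the heart of the refactor: five pid-sharded maps, each guarded by its own RWMutex, become a single SocketMaps slice indexed directly by pid. A minimal, self-contained sketch of the two lookup strategies, using simplified stand-in types rather than the repository's exact ones:

package main

import (
	"fmt"
	"sync"
)

type SocketMap struct {
	mu sync.RWMutex
	M  map[uint64]string // fd -> connection info (simplified)
}

// Old approach: pick one of five shards by the pid's last digit and take the
// shard-wide lock for every lookup.
type shardedStore struct {
	mus    [5]sync.RWMutex
	shards [5]map[uint32]*SocketMap
}

func (s *shardedStore) get(pid uint32) *SocketMap {
	i := (pid % 10) / 2 // pids ending 0-1 -> shard 0, 2-3 -> shard 1, ...
	s.mus[i].RLock()
	defer s.mus[i].RUnlock()
	return s.shards[i][pid]
}

// New approach: one preallocated slot per possible pid, so a lookup is a plain
// slice index with no shared lock and no map access at all.
type flatStore struct {
	socketMaps []*SocketMap // len == pid_max+1
}

func (s *flatStore) get(pid uint32) *SocketMap {
	return s.socketMaps[pid]
}

func main() {
	flat := &flatStore{socketMaps: make([]*SocketMap, 1024)}
	flat.socketMaps[42] = &SocketMap{M: map[uint64]string{7: "10.0.0.1:80"}}
	fmt.Println(flat.get(42).M[7]) // 10.0.0.1:80
}

With the flat slice, handlers for different pids never contend on a shared shard lock; each SocketMap carries its own mutex.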
@@ -154,10 +142,7 @@ var (
purgeTime = 10 * time.Minute
)

var usePgDs bool = false
var useBackendDs bool = true // default to true
var reverseDnsCache *cache.Cache

var re *regexp.Regexp

func init() {
@@ -179,13 +164,24 @@ func NewAggregator(parentCtx context.Context, k8sChan <-chan interface{},
clusterInfo := &ClusterInfo{
PodIPToPodUid: map[string]types.UID{},
ServiceIPToServiceUid: map[string]types.UID{},
PidToSocketMap0: make(map[uint32]*SocketMap),
PidToSocketMap1: make(map[uint32]*SocketMap),
PidToSocketMap2: make(map[uint32]*SocketMap),
PidToSocketMap3: make(map[uint32]*SocketMap),
PidToSocketMap4: make(map[uint32]*SocketMap),
}

maxPid, err := getPidMax()
if err != nil {
log.Logger.Fatal().Err(err).Msg("error getting max pid")
}
sockMaps := make([]*SocketMap, maxPid+1) // index=pid

// initialize sockMaps
for i := range sockMaps {
sockMaps[i] = &SocketMap{
M: nil, // initialized on demand later
mu: sync.RWMutex{},
}
}

clusterInfo.SocketMaps = sockMaps

a := &Aggregator{
ctx: ctx,
k8sChan: k8sChan,
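
Sizing that slice to pid_max+1 trades memory for lock-free indexing. The rough footprint under assumed pid_max values (the real value is read at startup by the new getPidMax helper shown later in this diff) can be estimated with a sketch like this; the SocketMap shape here is simplified:

package main

import (
	"fmt"
	"sync"
	"unsafe"
)

// Simplified stand-in for the aggregator's SocketMap: a lock plus a lazily
// allocated fd map that stays nil until the pid is actually observed.
type SocketMap struct {
	mu sync.RWMutex
	M  map[uint64]struct{}
}

func main() {
	// Assumed example limits; the real value comes from /proc/sys/kernel/pid_max
	// (commonly 32768, or 4194304 on many systemd-based 64-bit distros).
	for _, pidMax := range []int{32768, 4194304} {
		ptrBytes := uintptr(pidMax+1) * unsafe.Sizeof((*SocketMap)(nil))
		shellBytes := uintptr(pidMax+1) * unsafe.Sizeof(SocketMap{})
		fmt.Printf("pid_max=%d: pointer slice ~%d KiB, empty shells ~%d KiB\n",
			pidMax, ptrBytes>>10, shellBytes>>10)
	}
}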
@@ -289,6 +285,7 @@ func (a *Aggregator) Run() {
}()
go a.processk8s()

// TODO: determine the number of workers with benchmarking
cpuCount := runtime.NumCPU()
numWorker := 5 * cpuCount
if numWorker < 50 {
@@ -300,7 +297,6 @@
go a.processEbpfTcp(a.ctx)
}

// TODO: pod number may be ideal
for i := 0; i < 2*cpuCount; i++ {
go a.processHttp2Frames()
go a.processEbpfProc(a.ctx)
@@ -472,23 +468,15 @@ func (a *Aggregator) processTcpConnect(d *tcp_state.TcpConnectEvent) {
var sockMap *SocketMap
var ok bool

mu, pidToSocketMap := a.getShard(d.Pid)
mu.Lock()
sockMap, ok = pidToSocketMap[d.Pid]
if !ok {
sockMap = &SocketMap{
M: make(map[uint64]*SocketLine),
mu: sync.RWMutex{},
}
pidToSocketMap[d.Pid] = sockMap
}
mu.Unlock() // unlock for writing

sockMap = a.clusterInfo.SocketMaps[d.Pid]
var skLine *SocketLine

sockMap.mu.Lock() // lock for writing
skLine, ok = sockMap.M[d.Fd]
if sockMap.M == nil {
sockMap.M = make(map[uint64]*SocketLine)
}

skLine, ok = sockMap.M[d.Fd]
if !ok {
skLine = NewSocketLine(d.Pid, d.Fd)
sockMap.M[d.Fd] = skLine
@@ -512,24 +500,14 @@
var sockMap *SocketMap
var ok bool

mu, pidToSocketMap := a.getShard(d.Pid)
mu.Lock()
sockMap, ok = pidToSocketMap[d.Pid]
if !ok {
sockMap = &SocketMap{
M: make(map[uint64]*SocketLine),
mu: sync.RWMutex{},
}

pidToSocketMap[d.Pid] = sockMap
mu.Unlock() // unlock for writing
return
}
mu.Unlock()
sockMap = a.clusterInfo.SocketMaps[d.Pid]

var skLine *SocketLine

sockMap.mu.Lock() // lock for writing
if sockMap.M == nil {
sockMap.M = make(map[uint64]*SocketLine)
}
skLine, ok = sockMap.M[d.Fd]
if !ok {
sockMap.mu.Unlock()
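
Both handlers above now index SocketMaps[d.Pid] directly and lazily allocate the inner fd map under that entry's own lock, instead of consulting a shard. A self-contained sketch of that lookup-or-create step, with simplified stand-in types rather than the repository's exact ones:

package main

import (
	"fmt"
	"sync"
)

// Simplified stand-ins for the aggregator's types; field names are illustrative.
type SocketLine struct {
	pid uint32
	fd  uint64
}

type SocketMap struct {
	mu sync.RWMutex
	M  map[uint64]*SocketLine
}

// ensureFdMap mirrors the lookup-or-create step: index the slice by pid, then
// lazily allocate the inner fd map and the per-fd entry under the entry's own lock.
func ensureFdMap(socketMaps []*SocketMap, pid uint32, fd uint64) *SocketLine {
	sm := socketMaps[pid]

	sm.mu.Lock()
	defer sm.mu.Unlock()

	if sm.M == nil {
		sm.M = make(map[uint64]*SocketLine)
	}
	sl, ok := sm.M[fd]
	if !ok {
		sl = &SocketLine{pid: pid, fd: fd}
		sm.M[fd] = sl
	}
	return sl
}

func main() {
	socketMaps := make([]*SocketMap, 100) // in the real code the length comes from pid_max
	for i := range socketMaps {
		socketMaps[i] = &SocketMap{}
	}
	fmt.Println(ensureFdMap(socketMaps, 42, 7) != nil) // true
}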
@@ -1068,17 +1046,9 @@ func (a *Aggregator) fetchSkLine(sockMap *SocketMap, pid uint32, fd uint64) *Soc
// add it to the socket map
func (a *Aggregator) getAlreadyExistingSockets(pid uint32) {
// no need for locking because this is called first, before any other goroutine is running
_, pidToSocketMap := a.getShard(pid)
sockMap, ok := pidToSocketMap[pid]
if !ok {
sockMap = &SocketMap{
M: make(map[uint64]*SocketLine),
mu: sync.RWMutex{},
}
pidToSocketMap[pid] = sockMap
}

socks := map[string]sock{}
sockMap := a.fetchSocketMap(pid)

// Get the sockets for the process.
var err error
@@ -1140,7 +1110,12 @@ func (a *Aggregator) getAlreadyExistingSockets(pid uint32) {
skLine := NewSocketLine(pid, fd.Fd)
skLine.AddValue(0, sockInfo)

sockMap.mu.Lock()
if sockMap.M == nil {
sockMap.M = make(map[uint64]*SocketLine)
}
sockMap.M[fd.Fd] = skLine
sockMap.mu.Unlock()
}
}

@@ -1178,56 +1153,20 @@ func (a *Aggregator) fetchSkInfo(ctx context.Context, skLine *SocketLine, d *l7_
return skInfo
}

func (a *Aggregator) getShard(pid uint32) (*sync.RWMutex, map[uint32]*SocketMap) {
lastDigit := pid % 10
var mu *sync.RWMutex
var pidToSocketMap map[uint32]*SocketMap
switch lastDigit {
case 0, 1:
mu = &a.clusterInfo.mu0
pidToSocketMap = a.clusterInfo.PidToSocketMap0
case 2, 3:
mu = &a.clusterInfo.mu1
pidToSocketMap = a.clusterInfo.PidToSocketMap1
case 4, 5:
mu = &a.clusterInfo.mu2
pidToSocketMap = a.clusterInfo.PidToSocketMap2
case 6, 7:
mu = &a.clusterInfo.mu3
pidToSocketMap = a.clusterInfo.PidToSocketMap3
case 8, 9:
mu = &a.clusterInfo.mu4
pidToSocketMap = a.clusterInfo.PidToSocketMap4
}

return mu, pidToSocketMap
}

func (a *Aggregator) removeFromClusterInfo(pid uint32) {
mu, pidToSocketMap := a.getShard(pid)
mu.Lock()
delete(pidToSocketMap, pid)
mu.Unlock()
sockMap := a.clusterInfo.SocketMaps[pid]
sockMap.mu.Lock()
sockMap.M = nil
sockMap.mu.Unlock()
}

func (a *Aggregator) fetchSocketMap(pid uint32) *SocketMap {
var sockMap *SocketMap
var ok bool

mu, pidToSocketMap := a.getShard(pid) // create shard if not exists
mu.Lock() // lock for reading
sockMap, ok = pidToSocketMap[pid]
if !ok {
// initialize socket map
sockMap = &SocketMap{
M: make(map[uint64]*SocketLine),
mu: sync.RWMutex{},
}
pidToSocketMap[pid] = sockMap

go a.signalTlsAttachment(pid)
sockMap := a.clusterInfo.SocketMaps[pid]
sockMap.mu.Lock()
if sockMap.M == nil {
sockMap.M = make(map[uint64]*SocketLine)
}
mu.Unlock() // unlock for writing
sockMap.mu.Unlock()

return sockMap
}
@@ -1408,44 +1347,36 @@ func (a *Aggregator) clearSocketLines(ctx context.Context) {
}()

for range ticker.C {
a.clusterInfo.mu0.RLock()
for _, socketMap := range a.clusterInfo.PidToSocketMap0 {
for _, socketLine := range socketMap.M {
skLineCh <- socketLine
}
}
a.clusterInfo.mu0.RUnlock()

a.clusterInfo.mu1.RLock()
for _, socketMap := range a.clusterInfo.PidToSocketMap1 {
for _, socketLine := range socketMap.M {
skLineCh <- socketLine
}
}
a.clusterInfo.mu1.RUnlock()

a.clusterInfo.mu2.RLock()
for _, socketMap := range a.clusterInfo.PidToSocketMap2 {
for _, socketLine := range socketMap.M {
skLineCh <- socketLine
}
}
a.clusterInfo.mu2.RUnlock()

a.clusterInfo.mu3.RLock()
for _, socketMap := range a.clusterInfo.PidToSocketMap3 {
for _, socketLine := range socketMap.M {
skLineCh <- socketLine
}
}
a.clusterInfo.mu3.RUnlock()

a.clusterInfo.mu4.RLock()
for _, socketMap := range a.clusterInfo.PidToSocketMap4 {
for _, socketLine := range socketMap.M {
skLineCh <- socketLine
}
}
a.clusterInfo.mu4.RUnlock()

for _, sockMap := range a.clusterInfo.SocketMaps {
sockMap.mu.Lock()
if sockMap.M != nil {
for _, skLine := range sockMap.M {
skLineCh <- skLine
}
}
sockMap.mu.Unlock()
}
}
}

func getPidMax() (int, error) {
// Read the contents of the file
f, err := os.Open("/proc/sys/kernel/pid_max")
if err != nil {
fmt.Println("Error opening file:", err)
return 0, err
}
content, err := io.ReadAll(f)
if err != nil {
fmt.Println("Error reading file:", err)
return 0, err
}

// Convert the content to an integer
pidMax, err := strconv.Atoi(string(content[:len(content)-1])) // trim newline
if err != nil {
fmt.Println("Error converting to integer:", err)
return 0, err
}
return pidMax, nil
}
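
getPidMax above reads /proc/sys/kernel/pid_max by hand and trims the trailing newline with a slice expression; note that the opened file is never closed. For comparison only (not the code that ships), the same value can be read with os.ReadFile, which handles closing, and strings.TrimSpace:

package main

import (
	"fmt"
	"os"
	"strconv"
	"strings"
)

// readPidMax reads the kernel's maximum pid the same way getPidMax does, but
// via os.ReadFile (which closes the file for us) and TrimSpace.
func readPidMax() (int, error) {
	b, err := os.ReadFile("/proc/sys/kernel/pid_max")
	if err != nil {
		return 0, fmt.Errorf("reading pid_max: %w", err)
	}
	return strconv.Atoi(strings.TrimSpace(string(b)))
}

func main() {
	pidMax, err := readPidMax()
	if err != nil {
		fmt.Println("error:", err) // e.g. on non-Linux systems
		return
	}
	fmt.Println("pid_max =", pidMax)
}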
2 changes: 1 addition & 1 deletion ebpf/collector.go
@@ -72,7 +72,7 @@ func NewEbpfCollector(parentCtx context.Context) *EbpfCollector {
ctx: ctx,
done: make(chan struct{}),
ebpfEvents: make(chan interface{}, 100000), // interface is 16 bytes, 16 * 100000 = 1.6 Megabytes
ebpfProcEvents: make(chan interface{}, 100),
ebpfProcEvents: make(chan interface{}, 2000),
ebpfTcpEvents: make(chan interface{}, 1000),
tlsPidMap: make(map[uint32]struct{}),
sslWriteUprobes: make(map[uint32]link.Link),
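
The only change in this file is the larger buffer for ebpfProcEvents (100 to 2000). As a generic illustration rather than the collector's actual send path, a bigger buffer simply lets a bursty producer run further ahead of a slow consumer before sends start blocking:

package main

import (
	"fmt"
	"time"
)

// burst sends `events` items into a channel with the given capacity while a
// deliberately slow consumer drains it, and returns how long the producer
// spent sending (i.e., how much it was stalled by a full buffer).
func burst(capacity, events int) time.Duration {
	ch := make(chan struct{}, capacity)
	done := make(chan struct{})

	go func() {
		for range ch {
			time.Sleep(10 * time.Microsecond) // slow consumer
		}
		close(done)
	}()

	start := time.Now()
	for i := 0; i < events; i++ {
		ch <- struct{}{} // blocks only when the buffer is full
	}
	elapsed := time.Since(start)

	close(ch)
	<-done
	return elapsed
}

func main() {
	fmt.Println("capacity 100: ", burst(100, 5000))
	fmt.Println("capacity 2000:", burst(2000, 5000))
}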
5 changes: 3 additions & 2 deletions main.go
@@ -72,11 +72,12 @@ func main() {
var ec *ebpf.EbpfCollector
if ebpfEnabled {
ec = ebpf.NewEbpfCollector(ctx)
ec.Init()
go ec.ListenEvents()

a := aggregator.NewAggregator(ctx, kubeEvents, ec.EbpfEvents(), ec.EbpfProcEvents(), ec.EbpfTcpEvents(), ec.TlsAttachQueue(), dsBackend)
a.Run()

ec.Init()
go ec.ListenEvents()
}

go http.ListenAndServe(":8181", nil)
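
The reorder in main.go starts the aggregator, the consumer of the eBPF channels, before ec.Init() and ec.ListenEvents() begin producing into them; a plausible reading is that events should have somewhere to go before the collector starts emitting. A generic, consumer-first sketch of that ordering (names here are illustrative, not the project's API):

package main

import (
	"fmt"
	"sync"
)

func main() {
	events := make(chan int, 8) // stands in for the collector's buffered event channels
	var wg sync.WaitGroup

	// 1. Start the consumer (the aggregator's role) first, so events are
	//    drained as soon as they appear.
	wg.Add(1)
	go func() {
		defer wg.Done()
		for e := range events {
			fmt.Println("processed", e)
		}
	}()

	// 2. Only then start producing (the collector's role); nothing is emitted
	//    before a consumer is listening.
	for i := 0; i < 20; i++ {
		events <- i
	}
	close(events)
	wg.Wait()
}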
