Commit ae66d2a

client: add UpdateGroups to HeartbeatSender and PIMServer
Make UpdateGroups apply group changes in-place instead of tearing down and restarting the goroutine/connection.

HeartbeatSender: mutex-protected dsts field; the goroutine reads it under the lock on each tick, and UpdateGroups swaps dsts and sends an immediate heartbeat.

PIMServer: channel-based update via updateCh; the goroutine computes the added/removed groups and sends targeted join/prune messages.

This eliminates the need for callers to create a new RawConner on every update. Remove the conn param from the PIMWriter.UpdateGroups interface and simplify the multicast service update path accordingly.
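For illustration, a minimal caller-side sketch of the new update path. Only the UpdateGroups signature, PIMServer, HeartbeatSender, and RawConner come from this commit; the interface and struct names below are assumptions, not code from the repo:

package example

import (
    "fmt"
    "net"
)

// pimWriter and heartbeatWriter mirror the updated interfaces: UpdateGroups
// no longer takes a conn, so the caller needs no fresh RawConner per update.
// These interface and struct names are illustrative.
type pimWriter interface {
    UpdateGroups(groups []net.IP) error
}

type heartbeatWriter interface {
    UpdateGroups(groups []net.IP) error
}

type multicastService struct {
    pim       pimWriter
    heartbeat heartbeatWriter
}

// applyGroups pushes a group change to both running components in-place.
func (s *multicastService) applyGroups(groups []net.IP) error {
    if err := s.pim.UpdateGroups(groups); err != nil {
        return fmt.Errorf("pim update: %w", err)
    }
    if err := s.heartbeat.UpdateGroups(groups); err != nil {
        return fmt.Errorf("heartbeat update: %w", err)
    }
    return nil
}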
1 parent 3944a66 commit ae66d2a

9 files changed

Lines changed: 289 additions & 58 deletions

client/doublezerod/internal/manager/http_test.go

Lines changed: 8 additions & 0 deletions
@@ -449,6 +449,10 @@ func (m *MockPIMServer) Start(conn pim.RawConner, iface string, tunnelAddr net.I
     return nil
 }
 
+func (m *MockPIMServer) UpdateGroups(groups []net.IP) error {
+    return nil
+}
+
 func (m *MockPIMServer) Close() error {
     return nil
 }
@@ -459,6 +463,10 @@ func (m *MockHeartbeatSender) Start(iface string, srcIP net.IP, groups []net.IP,
     return nil
 }
 
+func (m *MockHeartbeatSender) UpdateGroups(groups []net.IP) error {
+    return nil
+}
+
 func (m *MockHeartbeatSender) Close() error {
     return nil
 }

client/doublezerod/internal/manager/reconciler_test.go

Lines changed: 3 additions & 1 deletion
@@ -69,14 +69,16 @@ func (m *mockBgpServer) GetPeerStatus(net.IP) bgp.Session { return bgp.
 type mockPIMServer struct{}
 
 func (m *mockPIMServer) Start(pim.RawConner, string, net.IP, []net.IP) error { return nil }
+func (m *mockPIMServer) UpdateGroups([]net.IP) error { return nil }
 func (m *mockPIMServer) Close() error { return nil }
 
 type mockHeartbeatSender struct{}
 
 func (m *mockHeartbeatSender) Start(string, net.IP, []net.IP, int, time.Duration) error {
     return nil
 }
-func (m *mockHeartbeatSender) Close() error { return nil }
+func (m *mockHeartbeatSender) UpdateGroups([]net.IP) error { return nil }
+func (m *mockHeartbeatSender) Close() error { return nil }
 
 // --- test helpers ---
 

client/doublezerod/internal/multicast/heartbeat.go

Lines changed: 34 additions & 2 deletions
@@ -39,6 +39,9 @@ type PacketConner interface {
 type HeartbeatSender struct {
     done chan struct{}
     wg   *sync.WaitGroup
+    mu   sync.Mutex
+    dsts []*net.UDPAddr
+    conn PacketConner
 }
 
 func NewHeartbeatSender() *HeartbeatSender {
@@ -62,6 +65,24 @@ func (h *HeartbeatSender) Start(iface string, srcIP net.IP, groups []net.IP, ttl
     return h.startWithConn(p, intf, groups, ttl, interval)
 }
 
+// UpdateGroups applies a new set of multicast groups in-place without
+// restarting the goroutine or connection. It sends an immediate heartbeat
+// to the new destinations so callers don't have to wait for the next tick.
+func (h *HeartbeatSender) UpdateGroups(groups []net.IP) error {
+    dsts := make([]*net.UDPAddr, len(groups))
+    for i, group := range groups {
+        dsts[i] = &net.UDPAddr{IP: group, Port: HeartbeatPort}
+    }
+
+    h.mu.Lock()
+    h.dsts = dsts
+    conn := h.conn
+    h.mu.Unlock()
+
+    sendHeartbeats(conn, dsts)
+    return nil
+}
+
 // startWithConn is the internal start method that accepts a pre-built connection for testing.
 func (h *HeartbeatSender) startWithConn(p PacketConner, intf *net.Interface, groups []net.IP, ttl int, interval time.Duration) error {
     if err := p.SetMulticastTTL(ttl); err != nil {
@@ -78,21 +99,32 @@ func (h *HeartbeatSender) startWithConn(p PacketConner, intf *net.Interface, gro
         dsts[i] = &net.UDPAddr{IP: group, Port: HeartbeatPort}
     }
 
+    h.mu.Lock()
+    h.dsts = dsts
+    h.conn = p
+    h.mu.Unlock()
+
     h.wg = &sync.WaitGroup{}
     h.wg.Add(1)
     go func() {
         defer p.Close()
         defer h.wg.Done()
 
         // Send immediately before starting ticker so we don't delay by the interval.
-        sendHeartbeats(p, dsts)
+        h.mu.Lock()
+        currentDsts := h.dsts
+        h.mu.Unlock()
+        sendHeartbeats(p, currentDsts)
 
         ticker := time.NewTicker(interval)
         defer ticker.Stop()
         for {
             select {
             case <-ticker.C:
-                sendHeartbeats(p, dsts)
+                h.mu.Lock()
+                currentDsts := h.dsts
+                h.mu.Unlock()
+                sendHeartbeats(p, currentDsts)
             case <-h.done:
                 return
             }
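The goroutine above copies the slice header under the lock and performs the sends outside it, so the mutex is never held during network I/O and UpdateGroups can never mutate a slice mid-iteration. A self-contained sketch of the same copy-under-lock pattern (all names here are illustrative, not from the diff):

package main

import (
    "fmt"
    "sync"
    "time"
)

// sender demonstrates the pattern used by HeartbeatSender: writers replace
// the slice wholesale; the ticker loop copies the slice header under the
// lock and iterates outside it.
type sender struct {
    mu   sync.Mutex
    dsts []string
}

func (s *sender) update(dsts []string) {
    s.mu.Lock()
    s.dsts = dsts // replace, never append in place
    s.mu.Unlock()
}

func (s *sender) tick() {
    s.mu.Lock()
    current := s.dsts // cheap copy of the slice header
    s.mu.Unlock()
    for _, d := range current {
        fmt.Println("send to", d) // stand-in for the UDP write
    }
}

func main() {
    s := &sender{dsts: []string{"239.0.0.1"}}
    go func() {
        time.Sleep(10 * time.Millisecond)
        s.update([]string{"239.0.0.1", "239.0.0.2"})
    }()
    for i := 0; i < 3; i++ {
        s.tick()
        time.Sleep(10 * time.Millisecond)
    }
}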

client/doublezerod/internal/multicast/heartbeat_test.go

Lines changed: 58 additions & 0 deletions
@@ -264,6 +264,64 @@ func TestHeartbeatSender_DoubleClose(t *testing.T) {
     sender.Close()
 }
 
+func TestHeartbeatSender_UpdateGroups(t *testing.T) {
+    conn := newMockPacketConn()
+    sender := NewHeartbeatSender()
+
+    // Start with one group.
+    intf := &net.Interface{Index: 1, Name: "lo0"}
+    err := sender.startWithConn(conn, intf, []net.IP{net.IPv4(239, 0, 0, 1)}, 32, 10*time.Second)
+    if err != nil {
+        t.Fatalf("failed to start: %v", err)
+    }
+    // Drain the immediate send.
+    <-conn.writeCh
+
+    // Update to two groups on the same running sender/connection.
+    if err := sender.UpdateGroups([]net.IP{net.IPv4(239, 0, 0, 1), net.IPv4(239, 0, 0, 2)}); err != nil {
+        t.Fatalf("UpdateGroups failed: %v", err)
+    }
+
+    // UpdateGroups sends an immediate heartbeat to all new destinations.
+    for i := range 2 {
+        select {
+        case <-conn.writeCh:
+        case <-time.After(2 * time.Second):
+            t.Fatalf("timed out waiting for heartbeat %d after update", i)
+        }
+    }
+
+    writes := conn.getWrites()
+    // At least 3 writes: 1 initial + 2 from UpdateGroups immediate send.
+    if len(writes) < 3 {
+        t.Fatalf("expected at least 3 writes, got %d", len(writes))
+    }
+
+    // Verify the last 2 writes went to both groups (from UpdateGroups immediate send).
+    seen := map[string]bool{}
+    for _, w := range writes[len(writes)-2:] {
+        udpAddr := w.dst.(*net.UDPAddr)
+        seen[udpAddr.IP.String()] = true
+    }
+    if !seen["239.0.0.1"] {
+        t.Error("missing heartbeat to 239.0.0.1 after update")
+    }
+    if !seen["239.0.0.2"] {
+        t.Error("missing heartbeat to 239.0.0.2 after update")
+    }
+
+    sender.Close()
+
+    // Verify the connection was NOT closed during UpdateGroups (same conn reused).
+    // It should only be closed after sender.Close().
+    conn.mu.Lock()
+    closed := conn.closed
+    conn.mu.Unlock()
+    if !closed {
+        t.Error("connection should be closed after sender.Close()")
+    }
+}
+
 func TestHeartbeatSender_CloseStopsSending(t *testing.T) {
     conn := newMockPacketConn()
     sender := NewHeartbeatSender()

client/doublezerod/internal/pim/server.go

Lines changed: 82 additions & 41 deletions
@@ -30,21 +30,27 @@ type RawConner interface {
 }
 
 type PIMServer struct {
-    iface  string
-    groups []net.IP
-    done   chan struct{}
-    conn   RawConner
-    wg     *sync.WaitGroup
+    iface      string
+    groups     []net.IP
+    done       chan struct{}
+    conn       RawConner
+    wg         *sync.WaitGroup
+    tunnelAddr net.IP
+    updateCh   chan []net.IP
 }
 
 func NewPIMServer() *PIMServer {
-    return &PIMServer{done: make(chan struct{})}
+    return &PIMServer{
+        done:     make(chan struct{}),
+        updateCh: make(chan []net.IP),
+    }
 }
 
 func (s *PIMServer) Start(conn RawConner, iface string, tunnelAddr net.IP, groups []net.IP) error {
     s.iface = iface
     s.groups = groups
     s.conn = conn
+    s.tunnelAddr = tunnelAddr
 
     intf, err := net.InterfaceByName(s.iface)
     if err != nil {
@@ -58,61 +64,96 @@ func (s *PIMServer) Start(conn RawConner, iface string, tunnelAddr net.IP, group
     s.wg.Add(1)
     go func() {
         defer s.conn.Close()
+        defer s.wg.Done()
+
         // send before we start ticker so we don't delay provisioning time by ticker interval
-        helloMsgBuf, err := constructHelloMessage()
-        if err != nil {
-            slog.Error("failed to serialize PIM hello msg", "error", err)
-        }
-        err = sendMsg(helloMsgBuf, intf, s.conn)
-        if err != nil {
-            slog.Error("failed to send PIM hello msg", "error", err)
-        }
-        joinPruneMsgBuf, err := constructJoinPruneMessage(tunnelAddr, groups, RpAddress, nil, joinHoldtime)
-        if err != nil {
-            slog.Error("failed to serialize PIM join msg", "error", err)
-        }
-        err = sendMsg(joinPruneMsgBuf, intf, s.conn)
-        if err != nil {
-            slog.Error("failed to send PIM join msg", "error", err)
-        }
+        sendHelloAndJoin(intf, s.conn, tunnelAddr, s.groups)
 
         ticker := time.NewTicker(time.Second * 30)
+        defer ticker.Stop()
         for {
             select {
             case <-ticker.C:
-                helloMsgBuf, err := constructHelloMessage()
-                if err != nil {
-                    slog.Error("failed to serialize PIM hello msg", "error", err)
-                }
-                err = sendMsg(helloMsgBuf, intf, s.conn)
-                if err != nil {
-                    slog.Error("failed to send PIM hello msg", "error", err)
-                }
-                joinPruneMsgBuf, err := constructJoinPruneMessage(tunnelAddr, groups, RpAddress, nil, joinHoldtime)
-                if err != nil {
-                    slog.Error("failed to serialize PIM join msg", "error", err)
+                sendHelloAndJoin(intf, s.conn, tunnelAddr, s.groups)
+            case newGroups := <-s.updateCh:
+                added, removed := ipDiff(s.groups, newGroups)
+                if len(removed) > 0 {
+                    buf, err := constructJoinPruneMessage(tunnelAddr, removed, nil, RpAddress, pruneHoldtime)
+                    if err != nil {
+                        slog.Error("failed to serialize PIM prune msg for removed groups", "error", err)
+                    } else if err := sendMsg(buf, intf, s.conn); err != nil {
+                        slog.Error("failed to send PIM prune msg for removed groups", "error", err)
+                    }
                 }
-                err = sendMsg(joinPruneMsgBuf, intf, s.conn)
-                if err != nil {
-                    slog.Error("failed to send PIM join msg", "error", err)
+                if len(added) > 0 {
+                    buf, err := constructJoinPruneMessage(tunnelAddr, added, RpAddress, nil, joinHoldtime)
+                    if err != nil {
+                        slog.Error("failed to serialize PIM join msg for added groups", "error", err)
+                    } else if err := sendMsg(buf, intf, s.conn); err != nil {
+                        slog.Error("failed to send PIM join msg for added groups", "error", err)
+                    }
                 }
+                s.groups = newGroups
             case <-s.done:
-                joinPruneMsgBuf, err := constructJoinPruneMessage(tunnelAddr, groups, nil, RpAddress, pruneHoldtime)
+                joinPruneMsgBuf, err := constructJoinPruneMessage(tunnelAddr, s.groups, nil, RpAddress, pruneHoldtime)
                 if err != nil {
                     slog.Error("failed to serialize PIM prune msg", "error", err)
-                }
-                err = sendMsg(joinPruneMsgBuf, intf, s.conn)
-                if err != nil {
+                } else if err := sendMsg(joinPruneMsgBuf, intf, s.conn); err != nil {
                     slog.Error("failed to send PIM prune msg", "error", err)
                 }
-                s.wg.Done()
                 return
             }
         }
     }()
     return nil
 }
 
+// UpdateGroups applies a new set of multicast groups in-place without
+// restarting the goroutine or connection. The goroutine computes the diff
+// and sends targeted join/prune messages for added/removed groups.
+func (s *PIMServer) UpdateGroups(groups []net.IP) error {
+    s.updateCh <- groups
+    return nil
+}
+
+func sendHelloAndJoin(intf *net.Interface, conn RawConner, tunnelAddr net.IP, groups []net.IP) {
+    helloMsgBuf, err := constructHelloMessage()
+    if err != nil {
+        slog.Error("failed to serialize PIM hello msg", "error", err)
+    } else if err := sendMsg(helloMsgBuf, intf, conn); err != nil {
+        slog.Error("failed to send PIM hello msg", "error", err)
+    }
+    joinPruneMsgBuf, err := constructJoinPruneMessage(tunnelAddr, groups, RpAddress, nil, joinHoldtime)
+    if err != nil {
+        slog.Error("failed to serialize PIM join msg", "error", err)
+    } else if err := sendMsg(joinPruneMsgBuf, intf, conn); err != nil {
+        slog.Error("failed to send PIM join msg", "error", err)
+    }
+}
+
+// ipDiff returns IPs that were added and removed when transitioning from old to new.
+func ipDiff(old, new []net.IP) (added, removed []net.IP) {
+    oldSet := make(map[string]bool, len(old))
+    for _, ip := range old {
+        oldSet[ip.String()] = true
+    }
+    newSet := make(map[string]bool, len(new))
+    for _, ip := range new {
+        newSet[ip.String()] = true
+    }
+    for _, ip := range new {
+        if !oldSet[ip.String()] {
+            added = append(added, ip)
+        }
+    }
+    for _, ip := range old {
+        if !newSet[ip.String()] {
+            removed = append(removed, ip)
+        }
+    }
+    return
+}
+
 func (s *PIMServer) Close() error {
     s.done <- struct{}{}
     s.wg.Wait()