Skip to content

Commit 0a97f1c

Browse files
committed
security: harden daemon, DNS, forwarder, and Windows script generation
- Sanitize project/service names in PowerShell script output (M1) - Add mutex to daemon maps to prevent concurrent write panic (M2) - Restrict Unix socket permissions to 0660 (M3) - Validate loopback IPs before netlink add/remove (L2) - Restrict collision file permissions to 0600 (L3) - Replace pgrep with PID file for daemon management (L4) - Add 5s dial timeout to TCP forwarder (L5) - Validate Docker Compose labels against allowed charset (L6) - Fix test race conditions in DNS and forwarder test helpers
1 parent 5664bfa commit 0a97f1c

File tree

12 files changed

+173
-38
lines changed

12 files changed

+173
-38
lines changed

cmd/devproxy/main.go

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,9 @@ import (
1111
"os"
1212
"os/exec"
1313
"path/filepath"
14+
"strconv"
1415
"strings"
16+
"syscall"
1517

1618
"github.com/alysnnix/devproxy/internal/daemon"
1719
"github.com/alysnnix/devproxy/internal/ipman"
@@ -70,34 +72,30 @@ func runDaemon() {
7072
}
7173

7274
func runDown() {
73-
// Find and kill all devproxy daemon processes
74-
out, err := exec.Command("pgrep", "-f", "devproxy daemon").Output()
75+
pidPath := "/run/devproxy/devproxy.pid"
76+
77+
data, err := os.ReadFile(pidPath)
7578
if err != nil {
7679
fmt.Println("No devproxy daemon running")
7780
return
7881
}
7982

80-
pids := strings.Fields(strings.TrimSpace(string(out)))
81-
myPid := fmt.Sprintf("%d", os.Getpid())
82-
83-
killed := 0
84-
for _, pid := range pids {
85-
if pid == myPid {
86-
continue
87-
}
88-
if err := exec.Command("kill", "-9", pid).Run(); err != nil {
89-
fmt.Fprintf(os.Stderr, "failed to kill PID %s: %v\n", pid, err)
90-
} else {
91-
fmt.Printf("killed devproxy daemon (PID %s)\n", pid)
92-
killed++
93-
}
83+
pid, err := strconv.Atoi(strings.TrimSpace(string(data)))
84+
if err != nil {
85+
fmt.Fprintf(os.Stderr, "invalid PID file: %v\n", err)
86+
os.Remove(pidPath)
87+
return
9488
}
9589

96-
if killed == 0 {
97-
fmt.Println("No devproxy daemon running")
90+
if err := syscall.Kill(pid, syscall.SIGTERM); err != nil {
91+
fmt.Fprintf(os.Stderr, "failed to stop daemon (PID %d): %v\n", pid, err)
92+
os.Remove(pidPath)
9893
return
9994
}
10095

96+
fmt.Printf("stopped devproxy daemon (PID %d)\n", pid)
97+
os.Remove(pidPath)
98+
10199
// Cleanup stale state
102100
s := state.New("")
103101
m := ipman.New(s)

internal/api/api.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ func (a *API) ListenAndServe() error {
4747
return err
4848
}
4949
// Allow non-root users to connect
50-
os.Chmod(a.sockPath, 0666)
50+
os.Chmod(a.sockPath, 0660)
5151
a.listener = ln
5252

5353
mux := http.NewServeMux()

internal/daemon/daemon.go

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ import (
66
"log/slog"
77
"os"
88
"os/signal"
9+
"path/filepath"
910
"sort"
11+
"sync"
1012
"syscall"
1113
"time"
1214

@@ -21,6 +23,7 @@ import (
2123
)
2224

2325
type Daemon struct {
26+
mu sync.Mutex
2427
state *state.State
2528
ipman *ipman.IPManager
2629
dns *dns.Server
@@ -52,6 +55,11 @@ func (d *Daemon) Run() error {
5255
sigCh := make(chan os.Signal, 1)
5356
signal.Notify(sigCh, syscall.SIGTERM, syscall.SIGINT, syscall.SIGHUP)
5457

58+
// Write PID file
59+
pidPath := filepath.Join(filepath.Dir(d.socketPath), "devproxy.pid")
60+
os.WriteFile(pidPath, []byte(fmt.Sprintf("%d", os.Getpid())), 0644)
61+
defer os.Remove(pidPath)
62+
5563
// Cleanup stale state
5664
slog.Info("cleaning up stale state...")
5765
d.ipman.CleanupLoopbackIPs()
@@ -111,6 +119,9 @@ func (d *Daemon) Run() error {
111119
}
112120

113121
func (d *Daemon) OnContainerStart(ctx context.Context, info watcher.ContainerInfo) error {
122+
d.mu.Lock()
123+
defer d.mu.Unlock()
124+
114125
project := info.Project
115126

116127
// Allocate IP if not already assigned
@@ -220,6 +231,9 @@ func (d *Daemon) OnContainerStart(ctx context.Context, info watcher.ContainerInf
220231
}
221232

222233
func (d *Daemon) OnContainerDie(ctx context.Context, info watcher.ContainerInfo) error {
234+
d.mu.Lock()
235+
defer d.mu.Unlock()
236+
223237
project := info.Project
224238
service := info.Service
225239

@@ -263,6 +277,9 @@ func (d *Daemon) OnContainerDie(ctx context.Context, info watcher.ContainerInfo)
263277
}
264278

265279
func (d *Daemon) teardownAll() {
280+
d.mu.Lock()
281+
defer d.mu.Unlock()
282+
266283
for _, cancel := range d.cancelFns {
267284
cancel()
268285
}

internal/dns/dns_test.go

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,19 @@ func startTestServer(t *testing.T) (*Server, string) {
1111
t.Helper()
1212
s := New("127.0.0.1:0") // random port
1313
go s.ListenAndServe()
14-
time.Sleep(50 * time.Millisecond)
15-
addr := s.Addr()
14+
15+
var addr string
16+
for i := 0; i < 100; i++ {
17+
addr = s.Addr()
18+
if addr != "127.0.0.1:0" {
19+
break
20+
}
21+
time.Sleep(5 * time.Millisecond)
22+
}
23+
if addr == "127.0.0.1:0" {
24+
t.Fatal("DNS server did not start in time")
25+
}
26+
1627
t.Cleanup(func() { s.Shutdown() })
1728
return s, addr
1829
}

internal/forwarder/forwarder.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"log/slog"
77
"net"
88
"sync"
9+
"time"
910
)
1011

1112
type Forwarder struct {
@@ -62,7 +63,7 @@ func (f *Forwarder) Start(ctx context.Context) error {
6263
}
6364

6465
func (f *Forwarder) handleConn(ctx context.Context, clientConn net.Conn) {
65-
targetConn, err := net.Dial("tcp", f.targetAddr)
66+
targetConn, err := net.DialTimeout("tcp", f.targetAddr, 5*time.Second)
6667
if err != nil {
6768
slog.Error("failed to connect to target", "target", f.targetAddr, "error", err)
6869
clientConn.Close()

internal/forwarder/forwarder_test.go

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,16 +33,29 @@ func startEchoServer(t *testing.T) string {
3333
return ln.Addr().String()
3434
}
3535

36+
func waitForForwarder(t *testing.T, f *Forwarder) string {
37+
t.Helper()
38+
for i := 0; i < 100; i++ {
39+
addr := f.ListenAddr()
40+
if addr != "127.0.0.1:0" {
41+
return addr
42+
}
43+
time.Sleep(5 * time.Millisecond)
44+
}
45+
t.Fatal("forwarder did not start listening in time")
46+
return ""
47+
}
48+
3649
func TestForwardBidirectional(t *testing.T) {
3750
echoAddr := startEchoServer(t)
3851
ctx, cancel := context.WithCancel(context.Background())
3952
defer cancel()
4053

4154
f := New("127.0.0.1:0", echoAddr)
4255
go f.Start(ctx)
43-
time.Sleep(50 * time.Millisecond)
56+
addr := waitForForwarder(t, f)
4457

45-
conn, err := net.Dial("tcp", f.ListenAddr())
58+
conn, err := net.Dial("tcp", addr)
4659
if err != nil {
4760
t.Fatal(err)
4861
}
@@ -67,12 +80,12 @@ func TestForwardShutdown(t *testing.T) {
6780

6881
f := New("127.0.0.1:0", echoAddr)
6982
go f.Start(ctx)
70-
time.Sleep(50 * time.Millisecond)
83+
addr := waitForForwarder(t, f)
7184

7285
cancel()
7386
time.Sleep(50 * time.Millisecond)
7487

75-
_, err := net.DialTimeout("tcp", f.ListenAddr(), 100*time.Millisecond)
88+
_, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
7689
if err == nil {
7790
t.Error("expected connection refused after shutdown")
7891
}
@@ -85,10 +98,10 @@ func TestForwardMultipleConnections(t *testing.T) {
8598

8699
f := New("127.0.0.1:0", echoAddr)
87100
go f.Start(ctx)
88-
time.Sleep(50 * time.Millisecond)
101+
addr := waitForForwarder(t, f)
89102

90103
for i := 0; i < 5; i++ {
91-
conn, err := net.Dial("tcp", f.ListenAddr())
104+
conn, err := net.Dial("tcp", addr)
92105
if err != nil {
93106
t.Fatal(err)
94107
}
@@ -112,10 +125,9 @@ func TestForwardPortConflict(t *testing.T) {
112125
f1 := New("127.0.0.1:0", echoAddr)
113126
errCh1 := make(chan error, 1)
114127
go func() { errCh1 <- f1.Start(ctx) }()
115-
time.Sleep(50 * time.Millisecond)
128+
boundAddr := waitForForwarder(t, f1)
116129

117130
// Sanity-check: the first forwarder should be working.
118-
boundAddr := f1.ListenAddr()
119131
conn, err := net.Dial("tcp", boundAddr)
120132
if err != nil {
121133
t.Fatalf("first forwarder unreachable: %v", err)
@@ -170,9 +182,9 @@ func TestForwardHalfClose(t *testing.T) {
170182

171183
f := New("127.0.0.1:0", serverAddr)
172184
go f.Start(ctx)
173-
time.Sleep(50 * time.Millisecond)
185+
addr := waitForForwarder(t, f)
174186

175-
conn, err := net.Dial("tcp", f.ListenAddr())
187+
conn, err := net.Dial("tcp", addr)
176188
if err != nil {
177189
t.Fatal(err)
178190
}

internal/ipman/netlink.go

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,11 @@ import (
99
)
1010

1111
func (m *IPManager) addLoopbackIPNetlink(ip string) error {
12+
parsed := net.ParseIP(ip)
13+
if parsed == nil || !parsed.IsLoopback() {
14+
return fmt.Errorf("refusing to add non-loopback IP %s to lo", ip)
15+
}
16+
1217
lo, err := netlink.LinkByName("lo")
1318
if err != nil {
1419
return fmt.Errorf("failed to get loopback interface: %w", err)
@@ -28,6 +33,11 @@ func (m *IPManager) addLoopbackIPNetlink(ip string) error {
2833
}
2934

3035
func (m *IPManager) removeLoopbackIPNetlink(ip string) error {
36+
parsed := net.ParseIP(ip)
37+
if parsed == nil || !parsed.IsLoopback() {
38+
return fmt.Errorf("refusing to remove non-loopback IP %s from lo", ip)
39+
}
40+
3141
lo, err := netlink.LinkByName("lo")
3242
if err != nil {
3343
return fmt.Errorf("failed to get loopback interface: %w", err)

internal/state/state.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,7 @@ func (s *State) SaveCollision(project, ip string) error {
9292
if err != nil {
9393
return err
9494
}
95-
return os.WriteFile(s.collisionsPath(), data, 0644)
95+
return os.WriteFile(s.collisionsPath(), data, 0600)
9696
}
9797

9898
func (s *State) LoadCollisions() (map[string]string, error) {

internal/watcher/watcher.go

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"fmt"
66
"log/slog"
7+
"regexp"
78
"sync"
89
"time"
910

@@ -12,6 +13,8 @@ import (
1213
"github.com/docker/docker/api/types/filters"
1314
)
1415

16+
var validNameRe = regexp.MustCompile(`^[a-zA-Z0-9_-]+$`)
17+
1518
// DockerClient is the subset of the Docker API used by the Watcher.
1619
// The full client.APIClient satisfies this interface.
1720
type DockerClient interface {
@@ -50,7 +53,17 @@ func (w *Watcher) projectMutex(project string) *sync.Mutex {
5053
}
5154

5255
func extractComposeInfo(labels map[string]string) (project, service string) {
53-
return labels["com.docker.compose.project"], labels["com.docker.compose.service"]
56+
project = labels["com.docker.compose.project"]
57+
service = labels["com.docker.compose.service"]
58+
if project != "" && (!validNameRe.MatchString(project) || len(project) > 128) {
59+
slog.Warn("invalid compose project name in container label", "project", project)
60+
return "", ""
61+
}
62+
if service != "" && (!validNameRe.MatchString(service) || len(service) > 128) {
63+
slog.Warn("invalid compose service name in container label", "service", service)
64+
return "", ""
65+
}
66+
return project, service
5467
}
5568

5669
func backoffSequence() []time.Duration {

internal/watcher/watcher_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,43 @@ func TestExtractComposeProjectMissing(t *testing.T) {
2727
}
2828
}
2929

30+
func TestExtractComposeInfoRejectsInvalidProject(t *testing.T) {
31+
labels := map[string]string{
32+
"com.docker.compose.project": `foo"; rm -rf /; #`,
33+
"com.docker.compose.service": "postgres",
34+
}
35+
project, service := extractComposeInfo(labels)
36+
if project != "" || service != "" {
37+
t.Errorf("expected empty strings for invalid project, got project=%q service=%q", project, service)
38+
}
39+
}
40+
41+
func TestExtractComposeInfoRejectsInvalidService(t *testing.T) {
42+
labels := map[string]string{
43+
"com.docker.compose.project": "myapp",
44+
"com.docker.compose.service": "pg;drop table",
45+
}
46+
project, service := extractComposeInfo(labels)
47+
if project != "" || service != "" {
48+
t.Errorf("expected empty strings for invalid service, got project=%q service=%q", project, service)
49+
}
50+
}
51+
52+
func TestExtractComposeInfoRejectsTooLongName(t *testing.T) {
53+
long := ""
54+
for i := 0; i < 129; i++ {
55+
long += "a"
56+
}
57+
labels := map[string]string{
58+
"com.docker.compose.project": long,
59+
"com.docker.compose.service": "svc",
60+
}
61+
project, _ := extractComposeInfo(labels)
62+
if project != "" {
63+
t.Error("expected empty project for name exceeding 128 chars")
64+
}
65+
}
66+
3067
func TestBackoffSequence(t *testing.T) {
3168
durations := backoffSequence()
3269
expected := []time.Duration{

0 commit comments

Comments
 (0)