Skip to content
44 changes: 24 additions & 20 deletions client_integration_test.go
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
package wgctrl_test

import (
"bytes"
"errors"
"fmt"
"net"
"net/netip"
"os"
"sort"
"strings"
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/mikioh/ipaddr"
"golang.zx2c4.com/wireguard/wgctrl"
"golang.zx2c4.com/wireguard/wgctrl/internal/wginternal"
Expand Down Expand Up @@ -144,9 +145,9 @@ func testGet(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
var (
port = 8888
ips = []net.IPNet{
wgtest.MustCIDR("192.0.2.0/32"),
wgtest.MustCIDR("2001:db8::/128"),
ips = []netip.Prefix{
netip.MustParsePrefix("192.0.2.0/32"),
netip.MustParsePrefix("2001:db8::/128"),
}

priv = wgtest.MustPrivateKey()
Expand Down Expand Up @@ -194,11 +195,11 @@ func testConfigure(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
for i := range dn.Peers {
ips := dn.Peers[i].AllowedIPs
sort.Slice(ips, func(i, j int) bool {
return bytes.Compare(ips[i].IP, ips[j].IP) > 0
return ips[i].Addr().Less(ips[j].Addr())
})
}

if diff := cmp.Diff(d, dn); diff != "" {
if diff := cmp.Diff(d, dn, cmpopts.EquateComparable(netip.Prefix{})); diff != "" {
t.Fatalf("unexpected Device from Device (-want +got):\n%s", diff)
}

Expand Down Expand Up @@ -229,17 +230,19 @@ func testConfigureManyIPs(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
t.Fatalf("failed to create cursor: %v", err)
}

var ips []net.IPNet
var ips []netip.Prefix
for pos := cur.Next(); pos != nil; pos = cur.Next() {
bits := 128
if pos.IP.To4() != nil {
bits = 32
}

ips = append(ips, net.IPNet{
IP: pos.IP,
Mask: net.CIDRMask(bits, bits),
})
addr, ok := netip.AddrFromSlice(pos.IP)
if !ok {
t.Fatalf("failed to convert net.IP to netip.Addr: %s", pos.IP)
}

ips = append(ips, netip.PrefixFrom(addr, bits))
}

peers = append(peers, wgtypes.PeerConfig{
Expand Down Expand Up @@ -291,7 +294,7 @@ func testConfigureManyPeers(t *testing.T, c *wgctrl.Client, d *wgtypes.Device) {
PresharedKey: &pk,
ReplaceAllowedIPs: true,
Endpoint: &net.UDPAddr{
IP: ips[0].IP,
IP: ips[0].Addr().AsSlice(),
Port: 1111,
},
PersistentKeepaliveInterval: &dur,
Expand Down Expand Up @@ -370,7 +373,6 @@ func testConfigurePeersUpdateOnly(t *testing.T, c *wgctrl.Client, d *wgtypes.Dev
t.Skip("FreeBSD kernel devices do not support UpdateOnly flag")
}


t.Fatalf("failed to configure second time on %q: %v", d.Name, err)
}

Expand Down Expand Up @@ -428,7 +430,7 @@ func countPeerIPs(d *wgtypes.Device) int {
return count
}

func ipsString(ipns []net.IPNet) string {
func ipsString(ipns []netip.Prefix) string {
ss := make([]string, 0, len(ipns))
for _, ipn := range ipns {
ss = append(ss, ipn.String())
Expand All @@ -437,23 +439,25 @@ func ipsString(ipns []net.IPNet) string {
return strings.Join(ss, ", ")
}

func generateIPs(n int) []net.IPNet {
func generateIPs(n int) []netip.Prefix {
cur, err := ipaddr.Parse("2001:db8::/64")
if err != nil {
panicf("failed to create cursor: %v", err)
}

ips := make([]net.IPNet, 0, n)
ips := make([]netip.Prefix, 0, n)
for i := 0; i < n; i++ {
pos := cur.Next()
if pos == nil {
panic("hit nil IP during IP generation")
}

ips = append(ips, net.IPNet{
IP: pos.IP,
Mask: net.CIDRMask(128, 128),
})
addr, ok := netip.AddrFromSlice(pos.IP)
if !ok {
panicf("failed to convert net.IP to netip.Addr: %s", pos.IP)
}

ips = append(ips, netip.PrefixFrom(addr, 128))
}

return ips
Expand Down
4 changes: 2 additions & 2 deletions cmd/wgctrl/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import (
"flag"
"fmt"
"log"
"net"
"net/netip"
"strings"

"golang.zx2c4.com/wireguard/wgctrl"
Expand Down Expand Up @@ -83,7 +83,7 @@ func printPeer(p wgtypes.Peer) {
)
}

func ipsString(ipns []net.IPNet) string {
func ipsString(ipns []netip.Prefix) string {
ss := make([]string, 0, len(ipns))
for _, ipn := range ipns {
ss = append(ss, ipn.String())
Expand Down
67 changes: 32 additions & 35 deletions internal/wgfreebsd/client_freebsd.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,10 @@ import (
"encoding/binary"
"fmt"
"net"
"net/netip"
"os"
"runtime"
"strconv"
"time"
"unsafe"

Expand Down Expand Up @@ -275,24 +277,18 @@ func parseEndpoint(ep []byte) *net.UDPAddr {
case unix.AF_INET:
sa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&ep[0]))

ep := &net.UDPAddr{
IP: make(net.IP, net.IPv4len),
Port: ntohs(sa.Port),
}
copy(ep.IP, sa.Addr[:])

return ep
return net.UDPAddrFromAddrPort(netip.AddrPortFrom(netip.AddrFrom4(sa.Addr), sa.Port))
case unix.AF_INET6:
sa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&ep[0]))

// TODO(mdlayher): IPv6 zone?
ep := &net.UDPAddr{
IP: make(net.IP, net.IPv6len),
Port: ntohs(sa.Port),
}
copy(ep.IP, sa.Addr[:])
addr := netip.AddrFrom16(sa.Addr)

return ep
// If the address is an IPv6 link-local address and the scope ID is non-zero
// then use the scope ID as the zone
if addr.Is6() && addr.IsLinkLocalUnicast() && sa.Scope_id != 0 {
addr = addr.WithZone(strconv.FormatUint(uint64(sa.Scope_id), 10))
}
return net.UDPAddrFromAddrPort(netip.AddrPortFrom(addr, sa.Port))
default:
// No endpoint configured.
return nil
Expand All @@ -302,54 +298,55 @@ func parseEndpoint(ep []byte) *net.UDPAddr {
func unparseEndpoint(ep net.UDPAddr) []byte {
var b []byte

if v4 := ep.IP.To4(); v4 != nil {
addrPort := ep.AddrPort()
addr := addrPort.Addr().Unmap()

switch {
case addr.Is4():
b = make([]byte, unsafe.Sizeof(unix.RawSockaddrInet4{}))
sa := (*unix.RawSockaddrInet4)(unsafe.Pointer(&b[0]))

sa.Family = unix.AF_INET
sa.Port = htons(ep.Port)
copy(sa.Addr[:], v4)
} else if v6 := ep.IP.To16(); v6 != nil {
sa.Addr = addr.As4()
case addr.Is6():
b = make([]byte, unsafe.Sizeof(unix.RawSockaddrInet6{}))
sa := (*unix.RawSockaddrInet6)(unsafe.Pointer(&b[0]))

sa.Family = unix.AF_INET6
sa.Port = htons(ep.Port)
copy(sa.Addr[:], v6)
sa.Addr = addr.As16()
}

return b
}

// parseAllowedIP unpacks a net.IPNet from a WGAIP structure.
func parseAllowedIP(aip nv.List) net.IPNet {
func parseAllowedIP(aip nv.List) netip.Prefix {
cidr := int(aip["cidr"].(uint64))
if ip, ok := aip["ipv4"]; ok {
return net.IPNet{
IP: net.IP(ip.([]byte)),
Mask: net.CIDRMask(cidr, 32),
}
addr, _ := netip.AddrFromSlice(ip.([]byte))
return netip.PrefixFrom(addr, cidr)
} else if ip, ok := aip["ipv6"]; ok {
return net.IPNet{
IP: net.IP(ip.([]byte)),
Mask: net.CIDRMask(cidr, 128),
}
addr, _ := netip.AddrFromSlice(ip.([]byte))
return netip.PrefixFrom(addr, cidr)
} else {
panicf("wgfreebsd: invalid address family for allowed IP: %+v", aip)
return net.IPNet{}
return netip.Prefix{}
}
}

func unparseAllowedIP(aip net.IPNet) nv.List {
func unparseAllowedIP(aip netip.Prefix) nv.List {
m := nv.List{}

ones, _ := aip.Mask.Size()
m["cidr"] = uint64(ones)
m["cidr"] = uint64(aip.Bits())

if v4 := aip.IP.To4(); v4 != nil {
m["ipv4"] = []byte(v4)
} else if v6 := aip.IP.To16(); v6 != nil {
m["ipv6"] = []byte(v6)
addr := aip.Addr().Unmap()
switch {
case addr.Is4():
m["ipv4"] = addr.AsSlice()
case addr.Is6():
m["ipv6"] = addr.AsSlice()
}

return m
Expand Down
4 changes: 2 additions & 2 deletions internal/wglinux/client_linux_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ package wglinux
import (
"errors"
"fmt"
"net"
"net/netip"
"os"
"os/user"
"syscall"
Expand Down Expand Up @@ -325,7 +325,7 @@ func diffAttrs(x, y []netlink.Attribute) string {
return cmp.Diff(xPrime, yPrime)
}

func mustAllowedIPs(ipns []net.IPNet) []byte {
func mustAllowedIPs(ipns []netip.Prefix) []byte {
ae := netlink.NewAttributeEncoder()
if err := encodeAllowedIPs(ipns)(ae); err != nil {
panicf("failed to create allowed IP attributes: %v", err)
Expand Down
Loading