Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 18 additions & 6 deletions esiclient/tests/unit/v1/test_port_forwarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,22 +144,25 @@ def setUp(self):
self.netops.app.client_manager.sdk_connection = self.connection

def test_find_port_given_port(self):
assert self.netops.find_port("myport") == "myport"
assert self.netops.find_port("myport", None) == "myport"

def test_find_port_given_address(self):
self.connection.network.ports.return_value = [self.port_1]
assert self.netops.find_port(ipaddress.ip_address("10.10.10.10")) == self.port_1
assert (
self.netops.find_port(ipaddress.ip_address("10.10.10.10"), None)
== self.port_1
)

def test_find_port_given_missing_address(self):
self.connection.network.ports.return_value = []
self.assertRaises(
KeyError, self.netops.find_port, ipaddress.ip_address("10.10.10.10")
self.assertIsNone(
self.netops.find_port(ipaddress.ip_address("10.10.10.10"), None)
)

def test_find_port_given_multiple_matches(self):
self.connection.network.ports.return_value = [self.port_1, self.port_1]
self.assertRaises(
ValueError, self.netops.find_port, ipaddress.ip_address("10.10.10.10")
ValueError, self.netops.find_port, ipaddress.ip_address("10.10.10.10"), None
)

def test_find_or_create_port_given_existing_address(self):
Expand Down Expand Up @@ -276,7 +279,16 @@ def test_create_take_action(self):
self.forward_1
)
parser = self.cmd.get_parser("test")
args = parser.parse_args(["-p", "22", "10.10.10.10", "111.111.111.111"])
args = parser.parse_args(
[
"--internal-ip-network",
"testnetwork",
"-p",
"22",
"10.10.10.10",
"111.111.111.111",
]
)
res = self.cmd.take_action(args)
assert res == (
[
Expand Down
95 changes: 45 additions & 50 deletions esiclient/v1/port_forwarding.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,11 +178,16 @@ def find_or_create_floating_ip(self, address):

return fip

def find_port(self, address):
def find_port(self, address, internal_ip_network):
connection = self.app.client_manager.sdk_connection
if isinstance(address, (ipaddress.IPv4Address, ipaddress.IPv6Address)):
# see if there exists a port with the given internal ip
ports = list(connection.network.ports(fixed_ips=f"ip_address={address}"))
ports = list(
connection.network.ports(
network_id=internal_ip_network.id if internal_ip_network else None,
fixed_ips=f"ip_address={address}",
)
)

# error out if we find multiple matches
if len(ports) > 1:
Expand All @@ -191,49 +196,51 @@ def find_port(self, address):
# if there was a single port, use it
if len(ports) == 1:
return ports[0]

raise KeyError(f"unable to find port with address {address}")
else:
# we already have a port, so just return it
return address

# we found nothing
return None

def find_or_create_port(
self, address, internal_ip_network=None, internal_ip_subnet=None
):
connection = self.app.client_manager.sdk_connection
try:
return self.find_port(address)
except KeyError:
# we need to create a port, which means we need to know the appropriate internal network
if internal_ip_network is None:
if internal_ip_subnet is None:
raise ValueError(
"unable to create a port because --internal-ip-network is unset"
)
internal_network_id = internal_ip_subnet.network_id
else:
internal_network_id = internal_ip_network.id

# if we were given a subnet name, use it, otherwise search through subnets for an appropriate match
if internal_ip_subnet:
subnet = internal_ip_subnet
else:
for subnet in connection.network.subnets(
network_id=internal_network_id,
):
if subnet.ip_version != address.version:
continue
cidr = ipaddress.ip_network(subnet.cidr)
if address in cidr:
break
else:
raise KeyError(f"unable to find a subnet for address {address}")

return connection.network.create_port(
name=f"esi-autocreated-{address}",
if port := self.find_port(address, internal_ip_network):
return port

# we need to know the appropriate network in order to find or create a port
if internal_ip_network is None:
if internal_ip_subnet is None:
raise ValueError(
"unable to create a port because --internal-ip-network is unset"
)
internal_network_id = internal_ip_subnet.network_id
else:
internal_network_id = internal_ip_network.id

# if we were given a subnet name, use it, otherwise search through subnets for an appropriate match
if internal_ip_subnet:
subnet = internal_ip_subnet
else:
for subnet in connection.network.subnets(
network_id=internal_network_id,
fixed_ips=[{"subnet_id": subnet.id, "ip_address": str(address)}],
)
):
if subnet.ip_version != address.version:
continue
cidr = ipaddress.ip_network(subnet.cidr)
if address in cidr:
break
else:
raise KeyError(f"unable to find a subnet for address {address}")

return connection.network.create_port(
name=f"esi-autocreated-{address}",
network_id=internal_network_id,
fixed_ips=[{"subnet_id": subnet.id, "ip_address": str(address)}],
)


def port_forwarding_exists(fip, internal_ip_address, port):
Expand Down Expand Up @@ -375,8 +382,8 @@ def get_parser(self, prog_name: str):
parser.add_argument("--port", "-p", type=PortSpec.from_spec, action="append")
parser.add_argument(
"internal_ip_descriptor",
type=AddressOrPortArg(self),
help="ip address, port name, or port uuid",
type=ipaddress.ip_address,
help="internal ip address",
)
parser.add_argument(
"external_ip_descriptor",
Expand All @@ -391,18 +398,7 @@ def take_action(self, parsed_args: argparse.Namespace):
forwards = []

fip = self.find_floating_ip(parsed_args.external_ip_descriptor)
internal_port = self.find_port(parsed_args.internal_ip_descriptor)

if isinstance(
parsed_args.internal_ip_descriptor,
(ipaddress.IPv4Address, ipaddress.IPv6Address),
):
internal_ip_address = str(parsed_args.internal_ip_descriptor)
else:
# if we were given a port name, always pick the first fixed ip. if the user
# wants to forward to a specific address, they should specify the address
# rather than the port.
internal_ip_address = internal_port.fixed_ips[0]["ip_address"]
internal_ip_address = str(parsed_args.internal_ip_descriptor)

for port in parsed_args.port:
for fwd in self.app.client_manager.sdk_connection.network.floating_ip_port_forwardings(
Expand All @@ -411,7 +407,6 @@ def take_action(self, parsed_args: argparse.Namespace):
if (
fwd.external_port == port.external_port
and fwd.internal_ip_address == internal_ip_address
and fwd.internal_port == port.internal_port
):
forwards.append((fip, fwd))
break
Expand Down