diff --git a/esiclient/tests/unit/v1/test_port_forwarding.py b/esiclient/tests/unit/v1/test_port_forwarding.py index 35c0da0..e082e20 100644 --- a/esiclient/tests/unit/v1/test_port_forwarding.py +++ b/esiclient/tests/unit/v1/test_port_forwarding.py @@ -144,22 +144,25 @@ def setUp(self): self.netops.app.client_manager.sdk_connection = self.connection def test_find_port_given_port(self): - assert self.netops.find_port("myport") == "myport" + assert self.netops.find_port("myport", None) == "myport" def test_find_port_given_address(self): self.connection.network.ports.return_value = [self.port_1] - assert self.netops.find_port(ipaddress.ip_address("10.10.10.10")) == self.port_1 + assert ( + self.netops.find_port(ipaddress.ip_address("10.10.10.10"), None) + == self.port_1 + ) def test_find_port_given_missing_address(self): self.connection.network.ports.return_value = [] - self.assertRaises( - KeyError, self.netops.find_port, ipaddress.ip_address("10.10.10.10") + self.assertIsNone( + self.netops.find_port(ipaddress.ip_address("10.10.10.10"), None) ) def test_find_port_given_multiple_matches(self): self.connection.network.ports.return_value = [self.port_1, self.port_1] self.assertRaises( - ValueError, self.netops.find_port, ipaddress.ip_address("10.10.10.10") + ValueError, self.netops.find_port, ipaddress.ip_address("10.10.10.10"), None ) def test_find_or_create_port_given_existing_address(self): @@ -276,7 +279,16 @@ def test_create_take_action(self): self.forward_1 ) parser = self.cmd.get_parser("test") - args = parser.parse_args(["-p", "22", "10.10.10.10", "111.111.111.111"]) + args = parser.parse_args( + [ + "--internal-ip-network", + "testnetwork", + "-p", + "22", + "10.10.10.10", + "111.111.111.111", + ] + ) res = self.cmd.take_action(args) assert res == ( [ diff --git a/esiclient/v1/port_forwarding.py b/esiclient/v1/port_forwarding.py index 7cf263a..ae65b25 100644 --- a/esiclient/v1/port_forwarding.py +++ b/esiclient/v1/port_forwarding.py @@ -178,11 +178,16 @@ def find_or_create_floating_ip(self, address): return fip - def find_port(self, address): + def find_port(self, address, internal_ip_network): connection = self.app.client_manager.sdk_connection if isinstance(address, (ipaddress.IPv4Address, ipaddress.IPv6Address)): # see if there exists a port with the given internal ip - ports = list(connection.network.ports(fixed_ips=f"ip_address={address}")) + ports = list( + connection.network.ports( + network_id=internal_ip_network.id if internal_ip_network else None, + fixed_ips=f"ip_address={address}", + ) + ) # error out if we find multiple matches if len(ports) > 1: @@ -191,49 +196,51 @@ def find_port(self, address): # if there was a single port, use it if len(ports) == 1: return ports[0] - - raise KeyError(f"unable to find port with address {address}") else: # we already have a port, so just return it return address + # we found nothing + return None + def find_or_create_port( self, address, internal_ip_network=None, internal_ip_subnet=None ): connection = self.app.client_manager.sdk_connection - try: - return self.find_port(address) - except KeyError: - # we need to create a port, which means we need to know the appropriate internal network - if internal_ip_network is None: - if internal_ip_subnet is None: - raise ValueError( - "unable to create a port because --internal-ip-network is unset" - ) - internal_network_id = internal_ip_subnet.network_id - else: - internal_network_id = internal_ip_network.id - # if we were given a subnet name, use it, otherwise search through subnets for an appropriate match - if internal_ip_subnet: - subnet = internal_ip_subnet - else: - for subnet in connection.network.subnets( - network_id=internal_network_id, - ): - if subnet.ip_version != address.version: - continue - cidr = ipaddress.ip_network(subnet.cidr) - if address in cidr: - break - else: - raise KeyError(f"unable to find a subnet for address {address}") - - return connection.network.create_port( - name=f"esi-autocreated-{address}", + if port := self.find_port(address, internal_ip_network): + return port + + # we need to know the appropriate network in order to find or create a port + if internal_ip_network is None: + if internal_ip_subnet is None: + raise ValueError( + "unable to create a port because --internal-ip-network is unset" + ) + internal_network_id = internal_ip_subnet.network_id + else: + internal_network_id = internal_ip_network.id + + # if we were given a subnet name, use it, otherwise search through subnets for an appropriate match + if internal_ip_subnet: + subnet = internal_ip_subnet + else: + for subnet in connection.network.subnets( network_id=internal_network_id, - fixed_ips=[{"subnet_id": subnet.id, "ip_address": str(address)}], - ) + ): + if subnet.ip_version != address.version: + continue + cidr = ipaddress.ip_network(subnet.cidr) + if address in cidr: + break + else: + raise KeyError(f"unable to find a subnet for address {address}") + + return connection.network.create_port( + name=f"esi-autocreated-{address}", + network_id=internal_network_id, + fixed_ips=[{"subnet_id": subnet.id, "ip_address": str(address)}], + ) def port_forwarding_exists(fip, internal_ip_address, port): @@ -375,8 +382,8 @@ def get_parser(self, prog_name: str): parser.add_argument("--port", "-p", type=PortSpec.from_spec, action="append") parser.add_argument( "internal_ip_descriptor", - type=AddressOrPortArg(self), - help="ip address, port name, or port uuid", + type=ipaddress.ip_address, + help="internal ip address", ) parser.add_argument( "external_ip_descriptor", @@ -391,18 +398,7 @@ def take_action(self, parsed_args: argparse.Namespace): forwards = [] fip = self.find_floating_ip(parsed_args.external_ip_descriptor) - internal_port = self.find_port(parsed_args.internal_ip_descriptor) - - if isinstance( - parsed_args.internal_ip_descriptor, - (ipaddress.IPv4Address, ipaddress.IPv6Address), - ): - internal_ip_address = str(parsed_args.internal_ip_descriptor) - else: - # if we were given a port name, always pick the first fixed ip. if the user - # wants to forward to a specific address, they should specify the address - # rather than the port. - internal_ip_address = internal_port.fixed_ips[0]["ip_address"] + internal_ip_address = str(parsed_args.internal_ip_descriptor) for port in parsed_args.port: for fwd in self.app.client_manager.sdk_connection.network.floating_ip_port_forwardings( @@ -411,7 +407,6 @@ def take_action(self, parsed_args: argparse.Namespace): if ( fwd.external_port == port.external_port and fwd.internal_ip_address == internal_ip_address - and fwd.internal_port == port.internal_port ): forwards.append((fip, fwd)) break