From f68764d346af93490a63b21121121409f6a56bea Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 20 Feb 2026 10:29:38 +0300 Subject: [PATCH 01/37] Add: New SID prefix and Group RID enums --- .../552b4eafb1aa_remove_objectsid_vals.py | 27 +++++++++++++++++++ app/enums.py | 14 ++++++++++ 2 files changed, 41 insertions(+) create mode 100644 app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py new file mode 100644 index 000000000..ab6ac4e78 --- /dev/null +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -0,0 +1,27 @@ +"""empty message. + +Revision ID: 552b4eafb1aa +Revises: 2dadf40c026a +Create Date: 2026-02-17 09:24:57.906080 + +""" + +from dishka import AsyncContainer + +# revision identifiers, used by Alembic. +revision: None | str = "552b4eafb1aa" +down_revision: None | str = "2dadf40c026a" +branch_labels: None | list[str] = None +depends_on: None | list[str] = None + + +def upgrade(container: AsyncContainer) -> None: + """Upgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + # ### end Alembic commands ### + + +def downgrade(container: AsyncContainer) -> None: + """Downgrade.""" + # ### commands auto generated by Alembic - please adjust! ### + # ### end Alembic commands ### diff --git a/app/enums.py b/app/enums.py index e9bdd8f1f..847b48d53 100644 --- a/app/enums.py +++ b/app/enums.py @@ -279,3 +279,17 @@ class SamAccountTypeCodes(IntEnum): def to_hex(value: int) -> str: """Convert decimal value to hex string.""" return hex(value) + + +class SidPrefix(StrEnum): + """SID prefix.""" + + DOMAIN_IDENTIFIER = "S-1-5-21" + BUILT_IN_DOMAIN = "S-1-5-32" + + +class GroupRid(IntEnum): + ADMINISTRATORS = 544 + USERS = 545 + GUESTS = 546 + POWER_USERS = 547 From ca792480ffb7f2e8091a2e66976a5df3079b8420 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Thu, 26 Feb 2026 13:48:04 +0300 Subject: [PATCH 02/37] Add: Implement RID Manager and update object SID handling across the application --- .../552b4eafb1aa_remove_objectsid_vals.py | 301 ++++++++++- app/api/main/schema.py | 7 +- app/constants.py | 19 +- app/entities.py | 13 +- app/enums.py | 17 +- app/extra/scripts/add_domain_controller.py | 20 +- app/ioc.py | 17 + app/ldap_protocol/auth/setup_gateway.py | 53 +- app/ldap_protocol/auth/use_cases.py | 16 +- app/ldap_protocol/kerberos/dtos.py | 1 - app/ldap_protocol/kerberos/ldap_structure.py | 16 +- app/ldap_protocol/kerberos/service.py | 9 - app/ldap_protocol/ldap_requests/add.py | 3 +- app/ldap_protocol/ldap_requests/contexts.py | 3 + app/ldap_protocol/ldap_requests/search.py | 40 +- app/ldap_protocol/rid_manager/__init__.py | 15 + app/ldap_protocol/rid_manager/gateways.py | 486 ++++++++++++++++++ app/ldap_protocol/rid_manager/use_cases.py | 158 ++++++ app/ldap_protocol/rid_manager/utils.py | 13 + app/ldap_protocol/rootdse/reader.py | 12 +- app/ldap_protocol/utils/cte.py | 7 +- app/ldap_protocol/utils/helpers.py | 29 -- app/ldap_protocol/utils/queries.py | 35 +- app/repo/pg/tables.py | 2 - tests/conftest.py | 3 +- tests/test_ldap/test_rid_manager/__init__.py | 1 + 26 files changed, 1149 insertions(+), 147 deletions(-) create mode 100644 app/ldap_protocol/rid_manager/__init__.py create mode 100644 app/ldap_protocol/rid_manager/gateways.py create mode 100644 app/ldap_protocol/rid_manager/use_cases.py create mode 100644 app/ldap_protocol/rid_manager/utils.py create mode 100644 tests/test_ldap/test_rid_manager/__init__.py diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index ab6ac4e78..0e0136cfe 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -1,4 +1,4 @@ -"""empty message. +"""Add rIDManager and rIDSet objectClasses to LDAP schema. Revision ID: 552b4eafb1aa Revises: 2dadf40c026a @@ -6,22 +6,303 @@ """ -from dishka import AsyncContainer +import sqlalchemy as sa +from alembic import op +from dishka import AsyncContainer, Scope +from sqlalchemy import delete, select, update +from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession + +from entities import Attribute, Directory, EntityType +from enums import EntityTypeNames +from ldap_protocol.ldap_schema.dto import EntityTypeDTO +from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase +from ldap_protocol.rid_manager.gateways import ( + RIDManagerGateway, + RIDManagerSetupGateway, +) +from ldap_protocol.rid_manager.use_cases import ( + RID_AVAILABLE_MAX, + RIDManagerSetupUseCase, +) +from ldap_protocol.rid_manager.utils import create_qword +from ldap_protocol.roles.ace_dao import AccessControlEntryDAO +from ldap_protocol.roles.role_dao import RoleDAO +from ldap_protocol.utils.queries import get_base_directories +from repo.pg.tables import queryable_attr as qa # revision identifiers, used by Alembic. revision: None | str = "552b4eafb1aa" -down_revision: None | str = "2dadf40c026a" +down_revision: None | str = "ebf19750805e" branch_labels: None | list[str] = None depends_on: None | list[str] = None -def upgrade(container: AsyncContainer) -> None: - """Upgrade.""" - # ### commands auto generated by Alembic - please adjust! ### - # ### end Alembic commands ### +def upgrade(container: AsyncContainer) -> None: # noqa: C901 + """Add rIDManager and rIDSet objectClasses to LDAP schema.""" + + async def _create_entity_types( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Create rIDManager and rIDSet Entity Types.""" + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + entity_type_use_case = await cnt.get(EntityTypeUseCase) + + if not await get_base_directories(session): + return + + await entity_type_use_case.create( + EntityTypeDTO( + name=EntityTypeNames.RID_MANAGER, + object_class_names=[ + "top", + "rIDManager", + ], + is_system=True, + ), + ) + + await entity_type_use_case.create( + EntityTypeDTO( + name=EntityTypeNames.RID_SET, + object_class_names=[ + "top", + "rIDSet", + ], + is_system=True, + ), + ) + + await session.commit() + + op.run_async(_create_entity_types) + + async def _migrate_object_sids( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Move Directory.objectSid values into Attributes table. + + Additionally, for domain directories move the domain SID prefix part + into the ``DomainIdentifier`` attribute. + """ + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + directories = await session.scalars(select(Directory)) + + for directory in directories: + if not directory.object_sid: + continue + + existing_attr = await session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == directory.id, + qa(Attribute.name) == "objectSid", + ), + ) + + if not existing_attr: + session.add( + Attribute( + name="objectSid", + value=directory.object_sid, + directory_id=directory.id, + ), + ) + + if directory.name == "domain": + identifier = directory.object_sid.split("-")[ + -1 + ] # remove sid prefix + + session.add( + Attribute( + name="DomainIdentifier", + value=identifier, + directory_id=directory.id, + ), + ) + + await session.commit() + + op.run_async(_migrate_object_sids) + + async def _init_rid_manager( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Initialize RID Manager and RID Set for existing data.""" + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + rid_setup_gateway = await cnt.get(RIDManagerSetupGateway) + rid_setup_use_case = RIDManagerSetupUseCase( + rid_manager_setup_gateway=rid_setup_gateway, + role_dao=await cnt.get(RoleDAO), + access_control_entry_dao=await cnt.get(AccessControlEntryDAO), + ) + rid_gateway = RIDManagerGateway(session) + + if not await get_base_directories(session): + return + + try: + await rid_gateway.get_rid_manager() + except ValueError: + await rid_setup_use_case.setup() + await session.commit() + await rid_gateway.get_rid_manager() + + rid_set_dir = await rid_gateway.get_rid_set() + + base_domain = await rid_gateway.get_base_domain() + domain_identifier = await rid_gateway.get_domain_identifier( + base_domain, + ) + sid_prefix = f"S-1-5-21-{domain_identifier}-" + + sid_values = await session.scalars( + select(Attribute).where( + qa(Attribute.name) == "objectSid", + qa(Attribute.value).like(f"{sid_prefix}%"), + ), + ) + + max_rid = 0 + for sid_value in sid_values: + if not sid_value or not sid_value.value: + continue + try: + parts = sid_value.value.split("-") + rid = int(parts[-1]) + except (ValueError, IndexError): + continue + if rid > max_rid: + max_rid = rid + + start_rid = max(max_rid, RIDManagerSetupUseCase.RID_USER_MIN) + + qword = create_qword(start_rid, RID_AVAILABLE_MAX) + await rid_gateway.update_available_pool(qword) + + result = await session.execute( + update(Attribute) + .where( + qa(Attribute.directory_id) == rid_set_dir.id, + qa(Attribute.name) == "rIDNextRID", + ) + .values(value=str(start_rid)), + ) + if result.rowcount == 0: + session.add( + Attribute( + directory_id=rid_set_dir.id, + name="rIDNextRID", + value=str(start_rid), + ), + ) + + await session.commit() + + op.run_async(_init_rid_manager) + + op.drop_column("Directory", "objectSid") def downgrade(container: AsyncContainer) -> None: - """Downgrade.""" - # ### commands auto generated by Alembic - please adjust! ### - # ### end Alembic commands ### + """Remove rIDManager and rIDSet objectClasses from LDAP schema.""" + op.add_column( + "Directory", + sa.Column("objectSid", sa.String(), nullable=True), + ) + + async def _delete_entity_types( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Delete rIDManager and rIDSet Entity Types.""" + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + if not await get_base_directories(session): + return + + await session.execute( + delete(EntityType).where( + qa(EntityType.name).in_( + [ + EntityTypeNames.RID_MANAGER, + EntityTypeNames.RID_SET, + ], + ), + ), + ) + + await session.commit() + + op.run_async(_delete_entity_types) + + async def _delete_rid_manager_dirs( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Delete RID Manager and RID Set directories.""" + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + if not await get_base_directories(session): + return + + await session.execute( + delete(Directory).where( + qa(Directory.name).in_( + [ + "RID Manager$", + "RID Set", + ], + ), + ), + ) + await session.commit() + + op.run_async(_delete_rid_manager_dirs) + + async def _rollback_object_sids( + connection: AsyncConnection, # noqa: ARG001 + ) -> None: + """Restore Directory.objectSid values from Attributes. + + Also removes the DomainIdentifier attribute that was introduced in + upgrade for domain directories. + """ + async with container(scope=Scope.REQUEST) as cnt: + session = await cnt.get(AsyncSession) + + directories = await session.scalars(select(Directory)) + + for directory in directories: + await session.execute( + delete(Attribute).where( + qa(Attribute.directory_id) == directory.id, + qa(Attribute.name) == "DomainIdentifier", + ), + ) + + attr = await session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == directory.id, + qa(Attribute.name) == "objectSid", + ), + ) + + if not attr or not attr.value: + continue + + directory.object_sid = attr.value + + await session.execute( + delete(Attribute).where( + qa(Attribute.directory_id) == directory.id, + qa(Attribute.name) == "objectSid", + ), + ) + + await session.commit() + + op.run_async(_rollback_object_sids) diff --git a/app/api/main/schema.py b/app/api/main/schema.py index 5ea6545a8..4140bb716 100644 --- a/app/api/main/schema.py +++ b/app/api/main/schema.py @@ -38,8 +38,11 @@ def _cast_filter(self) -> UnaryExpression | ColumnElement: ) @staticmethod - def get_directory_sid(directory: Directory) -> str: # type: ignore - return directory.object_sid + def get_directory_sid(directory: Directory) -> str | None: # type: ignore + for attr in getattr(directory, "attributes", []): + if attr.name and attr.name.lower() == "objectsid" and attr.value: + return attr.value + return None @staticmethod def get_directory_guid(directory: Directory) -> str: # type: ignore diff --git a/app/constants.py b/app/constants.py index a6192f314..c33f54711 100644 --- a/app/constants.py +++ b/app/constants.py @@ -4,13 +4,14 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from enums import EntityTypeNames, SamAccountTypeCodes +from enums import EntityTypeNames, SamAccountTypeCodes, SecurityPrincipalRid from ldap_protocol.ldap_schema.dto import EntityTypeDTO CONFIGURATION_DIR_NAME = "Configuration" GROUPS_CONTAINER_NAME = "Groups" COMPUTERS_CONTAINER_NAME = "Computers" USERS_CONTAINER_NAME = "Users" +SYSTEM_CONTAINER_NAME = "System" DOMAIN_CONTROLLERS_OU_NAME = "Domain Controllers" READ_ONLY_GROUP_NAME = "read-only" @@ -324,6 +325,14 @@ "object_class": "container", "attributes": {"objectClass": ["top", "configuration"]}, }, + { + "name": SYSTEM_CONTAINER_NAME, + "object_class": "organizationalUnit", + "attributes": { + "objectClass": ["top", "container"], + }, + "children": [], + }, { "name": GROUPS_CONTAINER_NAME, "entity_type_name": EntityTypeNames.CONTAINER, @@ -347,7 +356,7 @@ ], "gidNumber": ["512"], }, - "objectSid": 512, + "objectSid": SecurityPrincipalRid.DOMAIN_ADMINS, }, { "name": DOMAIN_USERS_GROUP_NAME, @@ -363,7 +372,7 @@ ], "gidNumber": ["513"], }, - "objectSid": 513, + "objectSid": SecurityPrincipalRid.DOMAIN_USERS, }, { "name": READ_ONLY_GROUP_NAME, @@ -379,7 +388,7 @@ ], "gidNumber": ["521"], }, - "objectSid": 521, + "objectSid": SecurityPrincipalRid.DOMAIN_READ_ONLY, }, { "name": DOMAIN_COMPUTERS_GROUP_NAME, @@ -395,7 +404,7 @@ ], "gidNumber": ["515"], }, - "objectSid": 515, + "objectSid": SecurityPrincipalRid.DOMAIN_COMPUTERS, }, ], }, diff --git a/app/entities.py b/app/entities.py index 5a1ec2adb..4c7f131c1 100644 --- a/app/entities.py +++ b/app/entities.py @@ -140,7 +140,6 @@ class Directory: search_fields: ClassVar[dict[str, str]] = { "name": "name", "objectguid": "objectGUID", - "objectsid": "objectSid", } ro_fields: ClassVar[set[str]] = { "uid", @@ -186,12 +185,18 @@ def create_path( @property def relative_id(self) -> str: - """Get RID from objectSid. + """Get RID from objectSid attribute. Relative Identifier (RID) is the last sub-authority value of a SID. """ - if "-" in self.object_sid: - return self.object_sid.split("-")[-1] + attrs = self.__dict__.get("attributes") + if not attrs: + return "" + + for attr in attrs: + if attr.name and attr.name.lower() == "objectsid" and attr.value: + if "-" in attr.value: + return attr.value.split("-")[-1] return "" @property diff --git a/app/enums.py b/app/enums.py index 847b48d53..79d530085 100644 --- a/app/enums.py +++ b/app/enums.py @@ -72,6 +72,8 @@ class EntityTypeNames(StrEnum): KRB_CONTAINER = "KRB Container" KRB_PRINCIPAL = "KRB Principal" KRB_REALM_CONTAINER = "KRB Realm Container" + RID_MANAGER = "RID Manager" + RID_SET = "RID Set" class KindType(StrEnum): @@ -288,8 +290,13 @@ class SidPrefix(StrEnum): BUILT_IN_DOMAIN = "S-1-5-32" -class GroupRid(IntEnum): - ADMINISTRATORS = 544 - USERS = 545 - GUESTS = 546 - POWER_USERS = 547 +class SecurityPrincipalRid(IntEnum): + ADMINISTRATOR = 500 + GUESTS = 501 + KRBTGT = 502 + DOMAIN_ADMINS = 512 + DOMAIN_USERS = 513 + DOMAIN_GUESTS = 514 + DOMAIN_COMPUTERS = 515 + DOMAIN_CONTROLLERS = 516 + DOMAIN_READ_ONLY = 521 diff --git a/app/extra/scripts/add_domain_controller.py b/app/extra/scripts/add_domain_controller.py index 21d3bbaed..36eb9a00b 100644 --- a/app/extra/scripts/add_domain_controller.py +++ b/app/extra/scripts/add_domain_controller.py @@ -11,14 +11,13 @@ from config import Settings from constants import DOMAIN_CONTROLLERS_OU_NAME from entities import Attribute, Directory -from enums import SamAccountTypeCodes +from enums import SamAccountTypeCodes, SecurityPrincipalRid from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) from ldap_protocol.objects import UserAccountControlFlag +from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase from ldap_protocol.roles.role_use_case import RoleUseCase -from ldap_protocol.utils.helpers import create_object_sid -from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -27,8 +26,8 @@ async def _add_domain_controller( role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, settings: Settings, - domain: Directory, dc_ou_dir: Directory, + rid_manager_use_case: RIDManagerUseCase, ) -> None: dc_directory = Directory( object_class="", @@ -40,7 +39,10 @@ async def _add_domain_controller( await session.flush() dc_directory.parent_id = dc_ou_dir.id - dc_directory.object_sid = create_object_sid(domain, dc_directory.id) + await rid_manager_use_case.set_object_sid( + directory=dc_directory, + rid=SecurityPrincipalRid.DOMAIN_CONTROLLERS, + ) await session.flush() attributes = [ @@ -103,14 +105,10 @@ async def add_domain_controller( settings: Settings, role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, + rid_manager_use_case: RIDManagerUseCase, ) -> None: logger.info("Adding domain controller.") - domains = await get_base_directories(session) - if not domains: - logger.debug("Cannot get base directory") - return - domain_controllers_ou = await session.scalar( select(Directory).where( qa(Directory.name) == DOMAIN_CONTROLLERS_OU_NAME, @@ -140,8 +138,8 @@ async def add_domain_controller( role_use_case=role_use_case, entity_type_use_case=entity_type_use_case, settings=settings, - domain=domains[0], dc_ou_dir=domain_controllers_ou, + rid_manager_use_case=rid_manager_use_case, ) logger.debug("Domain controller added.") diff --git a/app/ioc.py b/app/ioc.py index 019a2f6f3..06cc7c617 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -178,6 +178,12 @@ PasswordBanWordUseCases, UserPasswordHistoryUseCases, ) +from ldap_protocol.rid_manager import ( + RIDManagerGateway, + RIDManagerSetupGateway, + RIDManagerSetupUseCase, + RIDManagerUseCase, +) from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.migrations_ace_dao import ( @@ -641,6 +647,17 @@ def get_object_class_use_case_legacy( rootdse_reader = provide(RootDSEReader, scope=Scope.REQUEST) dcinfo_reader = provide(DCInfoReader, scope=Scope.REQUEST) + rid_manager_gateway = provide(RIDManagerGateway, scope=Scope.REQUEST) + rid_manager_setup_gateway = provide( + RIDManagerSetupGateway, + scope=Scope.REQUEST, + ) + rid_manager_use_case = provide(RIDManagerUseCase, scope=Scope.REQUEST) + rid_manager_setup_use_case = provide( + RIDManagerSetupUseCase, + scope=Scope.REQUEST, + ) + class LDAPContextProvider(Provider): """Context provider.""" diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index 9d79c80f8..96b293dde 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -12,7 +12,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import Attribute, Directory, Group, NetworkPolicy, User -from enums import EntityTypeNames +from enums import EntityTypeNames, SidPrefix from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, ) @@ -20,8 +20,8 @@ from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) +from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase from ldap_protocol.utils.async_cache import base_directories_cache -from ldap_protocol.utils.helpers import create_object_sid, generate_domain_sid from ldap_protocol.utils.queries import get_domain_object_class from password_utils import PasswordUtils from repo.pg.tables import queryable_attr as qa @@ -37,6 +37,7 @@ def __init__( entity_type_use_case: EntityTypeUseCase, attribute_value_validator: AttributeValueValidator, directory_dao: DirectoryDAO, + rid_manager_use_case: RIDManagerUseCase, ) -> None: """Initialize Setup use case. @@ -49,6 +50,7 @@ def __init__( self._entity_type_use_case = entity_type_use_case self._attribute_value_validator = attribute_value_validator self._directory_dao = directory_dao + self._rid_manager_use_case = rid_manager_use_case async def is_setup(self) -> bool: """Check if setup is performed. @@ -67,21 +69,9 @@ async def setup_enviroment( *, data: list, is_system: bool = True, - dn: str = "multifactor.dev", + domain: Directory, ) -> None: """Create directories and users for enviroment.""" - cat_result = await self._session.execute(select(Directory)) - if cat_result.scalar_one_or_none(): - logger.warning("dev data already set up") - return - - domain = Directory(name=dn, object_class="domain") - domain.is_system = True - domain.object_sid = generate_domain_sid() - domain.path = [f"dc={path}" for path in reversed(dn.split("."))] - domain.depth = len(domain.path) - domain.rdname = "" - async with self._session.begin_nested(): self._session.add(domain) self._session.add( @@ -132,6 +122,28 @@ async def setup_enviroment( logger.error(traceback.format_exc()) raise + async def is_base_domain_created(self) -> bool: + """Check if base domain is created.""" + cat_result = await self._session.execute(select(Directory)) + if cat_result.scalar_one_or_none(): + logger.warning("dev data already set up") + return True + return False + + async def create_base_domain( + self, + dn: str = "multifactor.dev", + ) -> Directory: + """Create base domain.""" + domain = Directory(name=dn, object_class="domain") + domain.is_system = True + domain.path = [f"dc={path}" for path in reversed(dn.split("."))] + domain.depth = len(domain.path) + domain.rdname = "" + self._session.add(domain) + await self._session.flush() + return domain + async def create_dir( self, data: dict, @@ -161,11 +173,12 @@ async def create_dir( ), ) - dir_.object_sid = create_object_sid( - domain, - rid=data.get("objectSid", dir_.id), - reserved="objectSid" in data, - ) + if "objectSid" in data: + await self._rid_manager_use_case.set_object_sid( + directory=dir_, + rid=int(data["objectSid"]), + sid_prefix=SidPrefix.BUILT_IN_DOMAIN, + ) if dir_.object_class == "group": group = Group(directory_id=dir_.id) diff --git a/app/ldap_protocol/auth/use_cases.py b/app/ldap_protocol/auth/use_cases.py index b6691df00..b323d08d8 100644 --- a/app/ldap_protocol/auth/use_cases.py +++ b/app/ldap_protocol/auth/use_cases.py @@ -19,7 +19,7 @@ FIRST_SETUP_DATA, USERS_CONTAINER_NAME, ) -from enums import EntityTypeNames, SamAccountTypeCodes +from enums import EntityTypeNames, SamAccountTypeCodes, SecurityPrincipalRid from ldap_protocol.auth.dto import SetupDTO from ldap_protocol.auth.setup_gateway import SetupGateway from ldap_protocol.identity.exceptions import ( @@ -44,6 +44,7 @@ from ldap_protocol.objects import UserAccountControlFlag from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases +from ldap_protocol.rid_manager.use_cases import RIDManagerSetupUseCase from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.helpers import create_integer_hash, ft_now @@ -64,6 +65,7 @@ def __init__( audit_use_case: AuditUseCase, session: AsyncSession, settings: Settings, + rid_manager_setup_use_case: RIDManagerSetupUseCase, ) -> None: """Initialize Setup manager. @@ -82,6 +84,7 @@ def __init__( self._object_class_use_case_legacy = object_class_use_case_legacy self._object_class_use_case = object_class_use_case self._settings = settings + self._rid_manager_setup_use_case = rid_manager_setup_use_case async def setup(self, dto: SetupDTO) -> None: """Perform the initial setup of structure and policies. @@ -123,6 +126,7 @@ def _create_domain_controller_data(self) -> dict: "name": self._settings.HOST_MACHINE_SHORT_NAME, "entity_type_name": EntityTypeNames.COMPUTER, "object_class": "computer", + "objectSid": SecurityPrincipalRid.DOMAIN_CONTROLLERS, "attributes": { "objectClass": ["top"], "userAccountControl": [ @@ -186,7 +190,7 @@ def _create_user_data(self, dto: SetupDTO) -> dict: str(SamAccountTypeCodes.SAM_USER_OBJECT), ], }, - "objectSid": 500, + "objectSid": SecurityPrincipalRid.ADMINISTRATOR, }, ], } @@ -199,10 +203,14 @@ async def _create(self, dto: SetupDTO, data: list) -> None: :return: None. """ try: + if await self._setup_gateway.is_base_domain_created(): + return + domain = await self._setup_gateway.create_base_domain(dto.domain) + await self._rid_manager_setup_use_case.create_domain_identifier() await self._setup_gateway.setup_enviroment( data=data, - dn=dto.domain, is_system=True, + domain=domain, ) attrs = await self._attribute_type_use_case_legacy.get_all() @@ -237,6 +245,8 @@ async def _create(self, dto: SetupDTO, data: list) -> None: await self._role_use_case.create_domain_admins_role() await self._role_use_case.create_read_only_role() await self._audit_use_case.create_policies() + await self._rid_manager_setup_use_case.setup() + await self._session.commit() except IntegrityError: await self._session.rollback() diff --git a/app/ldap_protocol/kerberos/dtos.py b/app/ldap_protocol/kerberos/dtos.py index d01775aee..ce11b6e2f 100644 --- a/app/ldap_protocol/kerberos/dtos.py +++ b/app/ldap_protocol/kerberos/dtos.py @@ -24,7 +24,6 @@ class AddRequestsDTO: """AddRequestsDTO for Kerberos admin structure.""" group: AddRequest - services: AddRequest krb_user: AddRequest diff --git a/app/ldap_protocol/kerberos/ldap_structure.py b/app/ldap_protocol/kerberos/ldap_structure.py index fec8741c0..d501fe858 100644 --- a/app/ldap_protocol/kerberos/ldap_structure.py +++ b/app/ldap_protocol/kerberos/ldap_structure.py @@ -39,28 +39,17 @@ def __init__( async def create_kerberos_structure( self, group: AddRequest, - services: AddRequest, krb_user: AddRequest, ctx: LDAPAddRequestContext, ) -> None: """Create Kerberos structure in the LDAP directory. :param AddRequest group: AddRequest for Kerberos group. - :param AddRequest services: AddRequest for services container. :param AddRequest krb_user: AddRequest for Kerberos admin user. - :param LDAPSession ldap_session: LDAP session. - :param AbstractKadmin kadmin: Kerberos admin interface. - :param EntityTypeDAO entity_type_dao: DAO for entity types. - :param str services_container: DN for services container. - :param str krbgroup: DN for Kerberos group. + :param LDAPAddRequestContext ctx: LDAP request context. :raises Exception: On structure creation error. :return None. """ - async with self._session.begin_nested(): - service_result = await anext(services.handle(ctx)) - if service_result.result_code != 0: - raise KerberosConflictError("Service error") - async with self._session.begin_nested(): group_result = await anext(group.handle(ctx)) if group_result.result_code != 0: @@ -76,20 +65,17 @@ async def create_kerberos_structure( async def rollback_kerberos_structure( self, krbadmin: str, - services_container: str, krbgroup: str, ) -> None: """Rollback Kerberos structure in the LDAP directory. :param str krbadmin: DN for Kerberos admin user. - :param str services_container: DN for services container. :param str krbgroup: DN for Kerberos group. :return None. """ directories_query = select(Directory).where( or_( get_filter_from_path(krbadmin), - get_filter_from_path(services_container), get_filter_from_path(krbgroup), ), ) diff --git a/app/ldap_protocol/kerberos/service.py b/app/ldap_protocol/kerberos/service.py index fa838abb9..9a6d331a9 100644 --- a/app/ldap_protocol/kerberos/service.py +++ b/app/ldap_protocol/kerberos/service.py @@ -121,14 +121,12 @@ async def setup_krb_catalogue( try: await self._ldap_manager.create_kerberos_structure( add_requests.group, - add_requests.services, add_requests.krb_user, ctx, ) except Exception: await self._ldap_manager.rollback_kerberos_structure( dns.krbadmin_dn, - dns.services_container_dn, dns.krbadmin_group_dn, ) await self._session.commit() @@ -188,11 +186,6 @@ def _build_add_requests( }, is_system=True, ) - services = AddRequest.from_dict( - dns.services_container_dn, - {"objectClass": ["organizationalUnit", "top", "container"]}, - is_system=True, - ) krb_user = AddRequest.from_dict( dns.krbadmin_dn, password=krbadmin_password.get_secret_value(), @@ -229,7 +222,6 @@ def _build_add_requests( ) return AddRequestsDTO( group=group, - services=services, krb_user=krb_user, ) @@ -283,7 +275,6 @@ async def setup_kdc( ) as err: await self._ldap_manager.rollback_kerberos_structure( context.krbadmin, - context.services_container, context.krbgroup, ) await self._kadmin.reset_setup() diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 9ea2eccaf..d65c89e6d 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -35,7 +35,6 @@ is_dn_in_base_directory, ) from ldap_protocol.utils.queries import ( - create_object_sid, get_base_directories, get_group, get_groups, @@ -215,7 +214,7 @@ async def handle( # noqa: C901 await ctx.session.flush() - new_dir.object_sid = create_object_sid(base_dn, new_dir.id) + await ctx.rid_manager_use_case.set_object_sid(directory=new_dir) await ctx.session.flush() except IntegrityError: await ctx.session.rollback() diff --git a/app/ldap_protocol/ldap_requests/contexts.py b/app/ldap_protocol/ldap_requests/contexts.py index d94b92af8..d53d73b51 100644 --- a/app/ldap_protocol/ldap_requests/contexts.py +++ b/app/ldap_protocol/ldap_requests/contexts.py @@ -27,6 +27,7 @@ from ldap_protocol.multifactor import LDAPMultiFactorAPI from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases +from ldap_protocol.rid_manager import RIDManagerUseCase from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.rootdse.reader import RootDSEReader @@ -47,6 +48,7 @@ class LDAPAddRequestContext: access_manager: AccessManager role_use_case: RoleUseCase attribute_value_validator: AttributeValueValidator + rid_manager_use_case: RIDManagerUseCase @dataclass @@ -63,6 +65,7 @@ class LDAPModifyRequestContext: password_use_cases: PasswordPolicyUseCases password_utils: PasswordUtils attribute_value_validator: AttributeValueValidator + rid_manager_use_case: RIDManagerUseCase @dataclass diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index bc5d67ca9..ccc74ee0a 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -376,10 +376,14 @@ def _mutate_query_with_attributes_to_load( if attr not in _ATTRS_TO_CLEAN } - cond = or_( + cond_parts = [ func.lower(Attribute.name).in_(attrs), func.lower(Attribute.name) == "objectclass", - ) + ] + if self.is_sid_requested: + cond_parts.append(func.lower(Attribute.name) == "objectsid") + + cond = or_(*cond_parts) return query.options( selectinload(qa(Directory.attributes)), @@ -489,7 +493,7 @@ async def paginate_query( return query, int(ceil(count / float(self.size_limit))), count - async def _fill_attrs( + async def _fill_attrs( # noqa: C901 self, directory: Directory, obj_classes: list[str], @@ -541,17 +545,23 @@ async def _fill_attrs( if group_directories is not None: async for directory_ in group_directories: - attrs["tokenGroups"].append( - string_to_sid(directory_.object_sid), # type: ignore - ) + sid_bytes = self.get_directory_sid(directory_) + if sid_bytes is not None: + attrs["tokenGroups"].append( + sid_bytes, # type: ignore + ) if self.member and "group" in obj_classes and directory.group: for member in directory.group.members: attrs["member"].append(member.path_dn) @staticmethod - def get_directory_sid(directory: Directory) -> bytes: - return string_to_sid(directory.object_sid) + def get_directory_sid(directory: Directory) -> bytes | None: + """Get objectSid as bytes from directory attributes.""" + for attr in directory.attributes: + if attr.name and attr.name.lower() == "objectsid" and attr.value: + return string_to_sid(attr.value) + return None @staticmethod def get_directory_guid(directory: Directory) -> bytes: @@ -600,6 +610,13 @@ async def tree_view( # noqa: C901 attrs[attr.name].append(value) continue + if ( + attr.name + and attr.name.lower() == "objectsid" + and self.is_sid_requested + ): + continue + attrs[attr.name].append(value) distinguished_name = directory.path_dn @@ -670,8 +687,11 @@ async def tree_view( # noqa: C901 attrs[directory.search_fields["objectguid"]].append(guid) # type: ignore if self.is_sid_requested: - guid = self.get_directory_sid(directory) - attrs[directory.search_fields["objectsid"]].append(guid) # type: ignore + sid_bytes = self.get_directory_sid(directory) + if sid_bytes is not None: + attrs["objectSid"].append( + sid_bytes, # type: ignore + ) if self.entity_type_name: attrs["entityTypeName"].append(directory.entity_type.name) diff --git a/app/ldap_protocol/rid_manager/__init__.py b/app/ldap_protocol/rid_manager/__init__.py new file mode 100644 index 000000000..a32cedc94 --- /dev/null +++ b/app/ldap_protocol/rid_manager/__init__.py @@ -0,0 +1,15 @@ +"""RID Manager module. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from .gateways import RIDManagerGateway, RIDManagerSetupGateway +from .use_cases import RIDManagerSetupUseCase, RIDManagerUseCase + +__all__ = [ + "RIDManagerGateway", + "RIDManagerSetupGateway", + "RIDManagerUseCase", + "RIDManagerSetupUseCase", +] diff --git a/app/ldap_protocol/rid_manager/gateways.py b/app/ldap_protocol/rid_manager/gateways.py new file mode 100644 index 000000000..6e77bf1d9 --- /dev/null +++ b/app/ldap_protocol/rid_manager/gateways.py @@ -0,0 +1,486 @@ +"""RID Manager Gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import secrets + +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from config import Settings +from entities import Attribute, Directory +from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.utils.queries import get_base_directories +from repo.pg.tables import queryable_attr as qa + + +class RIDManagerGateway: + """Gateway for RID Manager database operations. + + Handles all database operations for RID Manager: + - Reading/writing rIDAvailablePool (global pool in CN=RID Manager$) + - Reading/writing rIDNextRID (local counter, non-replicated) + """ + + def __init__(self, session: AsyncSession) -> None: + """Initialize RID Manager Gateway. + + :param session: SQLAlchemy async session + """ + self._session = session + + async def get_rid_available_pool(self, domain: Directory) -> int: + """Get rIDAvailablePool attribute from domain. + + This is a QWORD (64-bit) value where: + - Lower 32 bits: next available RID + - Upper 32 bits: maximum RID in pool + + :param domain: Domain directory object + :return: QWORD value of rIDAvailablePool + :raises ValueError: if attribute not found + """ + query = select(Attribute).where( + qa(Attribute.directory_id) == domain.id, + qa(Attribute.name) == "rIDAvailablePool", + ) + + attr = await self._session.scalar(query) + + if not attr or not attr.value: + raise ValueError("rIDAvailablePool attribute not found") + + return int(attr.value) + + async def get_next_rid(self, domain: Directory) -> int: + """Get next RID attribute from domain. + + This is the last issued RID (not the next one, despite the name). + This attribute is NOT replicated. + + :param domain: Domain directory object + :return: Last issued RID or None if not set + """ + query = await self._session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == domain.id, + qa(Attribute.name) == "rIDNextRID", + ), + ) + + if not query or not query.value: + raise ValueError("next RID attribute not found") + + return int(query.value) + + async def get_domain_identifier(self, domain: Directory) -> str: + """Get domain identifier. + + :return: Domain identifier + :raises ValueError: if domain identifier not found + """ + query = await self._session.scalar( + select(Attribute).where( + qa(Attribute.name) == "DomainIdentifier", + qa(Attribute.directory_id) == domain.id, + ), + ) + + if not query or not query.value: + raise ValueError("domain identifier not found") + + return query.value + + async def get_rid_set(self) -> Directory: + """Get RID Set directory. + + :return: RID Set directory + :raises ValueError: if RID Set directory not found + """ + rid_set = await self._session.scalar( + select(Directory).where(qa(Directory.name) == "RID Set"), + ) + if not rid_set: + raise ValueError("RID Set directory not found") + + return rid_set + + async def update_next_rid(self, rid_set: Directory, next_rid: int) -> None: + """Update next RID attribute in RID Set directory. + + :param rid_set: RID Set directory + :param next_rid: Next RID + """ + await self._session.execute( + update(Attribute) + .where( + qa(Attribute.directory_id) == rid_set.id, + qa(Attribute.name) == "rIDNextRID", + ) + .values(value=str(next_rid)), + ) + + async def get_rid_manager(self) -> Directory: + """Get RID Manager directory. + + :return: RID Manager directory + :raises ValueError: if RID Manager directory not found + """ + rid_manager = await self._session.scalar( + select(Directory).where(qa(Directory.name) == "RID Manager$"), + ) + if not rid_manager: + raise ValueError("RID Manager directory not found") + + return rid_manager + + async def update_available_pool( + self, + qword_value: int, + ) -> None: + """Update available pool attribute in RID Manager directory. + + :param rid_manager: RID Manager directory + :param qword_value: QWORD value + """ + rid_manager = await self.get_rid_manager() + await self._session.execute( + update(Attribute) + .where( + qa(Attribute.directory_id) == rid_manager.id, + qa(Attribute.name) == "rIDAvailablePool", + ) + .values(value=str(qword_value)), + ) + + async def add_object_sid( + self, + directory: Directory, + object_sid: str, + ) -> None: + """Add object SID to directory. + + :param directory: Directory + :param object_sid: Object SID + """ + self._session.add( + Attribute( + name="objectSid", + value=object_sid, + directory_id=directory.id, + ), + ) + + async def get_object_sid( + self, + rid_set: Directory, + ) -> str: + """Get object SID from directory. + + :param rid_set: RID Set directory + :return: Object SID + """ + query = await self._session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == rid_set.id, + qa(Attribute.name) == "objectSid", + ), + ) + if not query or not query.value: + raise ValueError("object SID not found") + return query.value + + async def get_base_domain(self) -> Directory: + """Get base domain directory. + + :return: Base domain directory + :raises ValueError: if base domain not found + """ + base_domain = await self._session.scalar( + select(Directory).where(qa(Directory.object_class) == "domain"), + ) + if not base_domain: + raise ValueError("base domain not found") + return base_domain + + +class RIDManagerSetupGateway: + """Gateway for RID Manager setup database operations.""" + + def __init__( + self, + session: AsyncSession, + entity_type_dao: EntityTypeDAO, + settings: Settings, + ) -> None: + """Initialize RID Manager setup gateway.""" + self._session = session + self._entity_type_dao = entity_type_dao + self._settings = settings + + async def get_domain_controller(self) -> Directory: + """Get domain controller directory. + + :return: Domain controller directory + :raises ValueError: if domain controller not found + """ + dc = await self._session.scalar( + select(Directory).where( + qa(Directory.name) == self._settings.HOST_MACHINE_NAME, + ), + ) + + if not dc: + raise ValueError( + "Domain controller not found", + ) + + return dc + + async def get_system_container(self) -> Directory: + """Get System container directory. + + :return: System container directory + :raises ValueError: if System container not found + """ + base_dn_list = await get_base_directories(self._session) + if not base_dn_list: + raise ValueError("Domain not found") + + domain = base_dn_list[0] + + query = select(Directory).where( + qa(Directory.name) == "System", + qa(Directory.parent_id) == domain.id, + ) + + system_container = await self._session.scalar(query) + + if not system_container: + raise ValueError("System container not found") + + return system_container + + async def set_rid_manager(self) -> Directory: + """Create RID Manager directory.""" + system_container = await self.get_system_container() + + base_dn_list = await get_base_directories(self._session) + if not base_dn_list: + raise ValueError("Domain not found") + base_dn_list[0] + + rid_manager_dir = Directory( + is_system=True, + name="RID Manager$", + ) + rid_manager_dir.create_path(system_container, "cn") + + self._session.add(rid_manager_dir) + await self._session.flush() + + rid_manager_dir.parent_id = system_container.id + await self._session.refresh(rid_manager_dir, ["id"]) + + self._session.add( + Attribute( + name="cn", + value="RID Manager$", + directory_id=rid_manager_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="top", + directory_id=rid_manager_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="rIDManager", + directory_id=rid_manager_dir.id, + ), + ) + + await self._session.flush() + + await self._session.refresh( + instance=rid_manager_dir, + attribute_names=["attributes"], + with_for_update=None, + ) + + await self._entity_type_dao.attach_entity_type_to_directory( + directory=rid_manager_dir, + is_system_entity_type=True, + ) + + await self._session.flush() + + return rid_manager_dir + + async def create_rid_set( + self, + domain_controller: Directory, + ) -> Directory: + """Create CN=RID Set directory under Domain Controller. + + :param domain_controller: Domain Controller directory object + :return: Created RID Set directory + """ + base_dn_list = await get_base_directories(self._session) + if not base_dn_list: + raise ValueError("Domain not found") + base_dn_list[0] + + rid_set_dir = Directory( + is_system=True, + name="RID Set", + ) + rid_set_dir.create_path(domain_controller, "cn") + + self._session.add(rid_set_dir) + await self._session.flush() + + rid_set_dir.parent_id = domain_controller.id + await self._session.refresh(rid_set_dir, ["id"]) + + self._session.add( + Attribute( + name="cn", + value="RID Set", + directory_id=rid_set_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="top", + directory_id=rid_set_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="rIDSet", + directory_id=rid_set_dir.id, + ), + ) + + await self._session.flush() + + await self._session.refresh( + instance=rid_set_dir, + attribute_names=["attributes"], + with_for_update=None, + ) + + await self._entity_type_dao.attach_entity_type_to_directory( + directory=rid_set_dir, + is_system_entity_type=True, + ) + + await self._session.flush() + + return rid_set_dir + + async def set_rid_available_pool( + self, + domain: Directory, + qword_value: int, + ) -> None: + """Set rIDAvailablePool attribute in domain. + + Updates the global RID pool counter. + + :param domain: Domain directory object + :param qword_value: New QWORD value (64-bit) + """ + query = ( + update(Attribute) + .where( + qa(Attribute.directory_id) == domain.id, + qa(Attribute.name) == "rIDAvailablePool", + ) + .values(value=str(qword_value)) + ) + + result = await self._session.execute(query) + + if result.rowcount == 0: + self._session.add( + Attribute( + directory_id=domain.id, + name="rIDAvailablePool", + value=str(qword_value), + ), + ) + + await self._session.flush() + + async def set_next_rid( + self, + domain: Directory, + rid: int, + ) -> None: + """Set next RID attribute in domain. + + Updates the last issued RID counter. + + :param domain: Domain directory object + :param rid: Last issued RID value + """ + self._session.add( + Attribute( + directory_id=domain.id, + name="rIDNextRID", + value=str(rid), + ), + ) + + await self._session.flush() + + def _generate_domain_sid_identifier(self) -> str: + """Generate Domain Identifier for Active Directory domain.""" + return ( + f"{secrets.randbits(32)}" + f"-{secrets.randbits(32)}-{secrets.randbits(32)}" + ) + + async def create_domain_identifier(self) -> None: + """Add domain identifier to domain.""" + domain = await self._session.scalar( + select(Directory).where( + qa(Directory.object_class) == "domain", + ), + ) + if not domain: + raise ValueError("Domain not found") + + self._session.add( + Attribute( + name="DomainIdentifier", + value=f"{self._generate_domain_sid_identifier()}", + directory_id=domain.id, + ), + ) + await self._session.flush() + + async def get_domain_identifier(self) -> str: + """Get domain identifier.""" + domain = await self._session.scalar( + select(Attribute).where( + qa(Attribute.name) == "DomainIdentifier", + ), + ) + if not domain or not domain.value: + raise ValueError("Domain not found") + return domain.value diff --git a/app/ldap_protocol/rid_manager/use_cases.py b/app/ldap_protocol/rid_manager/use_cases.py new file mode 100644 index 000000000..f5bfbce5e --- /dev/null +++ b/app/ldap_protocol/rid_manager/use_cases.py @@ -0,0 +1,158 @@ +"""RID Manager for issuing RID from pools. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE + +""" + +import asyncio + +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Directory +from enums import AceType, RoleConstants, RoleScope, SidPrefix +from ldap_protocol.rid_manager.gateways import ( + RIDManagerGateway, + RIDManagerSetupGateway, +) +from ldap_protocol.rid_manager.utils import create_qword +from ldap_protocol.roles.ace_dao import AccessControlEntryDAO +from ldap_protocol.roles.dataclasses import AccessControlEntryDTO +from ldap_protocol.roles.role_dao import RoleDAO + +RID_AVAILABLE_MAX = 1073741822 # 30-bit max (2^30 - 2) + + +class RIDManagerUseCase: + """RID Manager Use Case for issuing RID from pools.""" + + def __init__( + self, + gateway: RIDManagerGateway, + session: AsyncSession, + ) -> None: + """Initialize RID Manager Use Case. + + :param gateway: RID Manager Gateway for database operations + """ + self._gateway = gateway + self._lock = asyncio.Lock() + self._session = session + + async def get_object_sid( + self, + directory: Directory, + ) -> str: + """Get object SID for directory.""" + return await self._gateway.get_object_sid(directory) + + async def set_object_sid( + self, + directory: Directory, + rid: int | None = None, + sid_prefix: SidPrefix = SidPrefix.DOMAIN_IDENTIFIER, + ) -> None: + """Create object SID.""" + async with self._lock, await self._session.begin_nested(): + if rid is None: + rid_set = await self._gateway.get_rid_set() + next_rid = await self._gateway.get_next_rid(rid_set) + rid = next_rid + 1 + await self._gateway.update_next_rid(rid_set, rid) + await self._gateway.update_available_pool( + create_qword(rid, RID_AVAILABLE_MAX), + ) + + if sid_prefix == SidPrefix.BUILT_IN_DOMAIN: + sid = f"{sid_prefix}-{rid}" + elif sid_prefix == SidPrefix.DOMAIN_IDENTIFIER: + base_domain = await self._gateway.get_base_domain() + domain_identifier = await self._gateway.get_domain_identifier( + base_domain, + ) + sid = f"{sid_prefix}-{domain_identifier}-{rid}" + + await self._gateway.add_object_sid(directory, sid) + + await self._session.flush() + + async def parse_object_sid(self, object_sid: str) -> tuple[str, str, int]: + """Parse object SID. + + :param object_sid: Object SID + :return: Tuple containing domain identifier, rid, and reserved flag + """ + parts = object_sid.split("-") + return parts[1], parts[2], int(parts[3]) + + +class RIDManagerSetupUseCase: + """RID Manager setup use case.""" + + RID_SYSTEM_MIN = 1 + RID_SYSTEM_MAX = 499 + RID_BUILTIN_MIN = 500 + RID_BUILTIN_MAX = 1000 + RID_USER_MIN = 1100 + + def __init__( + self, + rid_manager_setup_gateway: RIDManagerSetupGateway, + role_dao: RoleDAO, + access_control_entry_dao: AccessControlEntryDAO, + ) -> None: + """Initialize RID Manager setup use case. + + :param rid_manager_setup_gateway: Gateway for setup operations + """ + self._gateway = rid_manager_setup_gateway + self._role_dao = role_dao + self._access_control_entry_dao = access_control_entry_dao + + async def setup(self) -> None: + """Create RID Manager.""" + rid_manager_dir = await self._gateway.set_rid_manager() + await self.grant_domain_admins_read_to_rid_manager( + rid_manager_dir, + ) + + qword = create_qword(self.RID_USER_MIN, RID_AVAILABLE_MAX) + + await self._gateway.set_rid_available_pool( + rid_manager_dir, + qword, + ) + domain_controller = await self._gateway.get_domain_controller() + + rid_set_dir = await self._gateway.create_rid_set( + domain_controller, + ) + await self._gateway.set_next_rid( + rid_set_dir, + self.RID_USER_MIN, + ) + + async def grant_domain_admins_read_to_rid_manager( + self, + rid_manager_dir: Directory, + ) -> None: + """Grant READ access on RID Manager to Domain Admins Role.""" + role = await self._role_dao.get_by_name( + RoleConstants.DOMAIN_ADMINS_ROLE_NAME, + ) + + await self._access_control_entry_dao.create( + AccessControlEntryDTO( + role_id=role.get_id(), + ace_type=AceType.READ, + scope=RoleScope.BASE_OBJECT, + base_dn=rid_manager_dir.path_dn, + attribute_type_id=None, + entity_type_id=None, + is_allow=True, + ), + ) + + async def create_domain_identifier(self) -> None: + """Create domain identifier.""" + await self._gateway.create_domain_identifier() diff --git a/app/ldap_protocol/rid_manager/utils.py b/app/ldap_protocol/rid_manager/utils.py new file mode 100644 index 000000000..d99df16fc --- /dev/null +++ b/app/ldap_protocol/rid_manager/utils.py @@ -0,0 +1,13 @@ +"""RID Manager utils.""" + + +def create_qword(lower: int, upper: int) -> int: + """Create QWORD (64-bit) from two DWORDs (32-bit each).""" + if lower < 0 or lower > 0xFFFFFFFF: + raise ValueError(f"Lower boundary out of range: {lower}") + if upper < 0 or upper > 0xFFFFFFFF: + raise ValueError(f"Upper boundary out of range: {upper}") + + qword = (upper << 32) | lower + + return qword diff --git a/app/ldap_protocol/rootdse/reader.py b/app/ldap_protocol/rootdse/reader.py index 065be0a54..20503b4d4 100644 --- a/app/ldap_protocol/rootdse/reader.py +++ b/app/ldap_protocol/rootdse/reader.py @@ -8,6 +8,7 @@ from config import Settings from constants import DEFAULT_DC_POSTFIX, UNC_PREFIX +from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase from ldap_protocol.utils.helpers import get_generalized_now from .dto import DomainControllerInfo @@ -87,14 +88,21 @@ async def get( class DCInfoReader: - def __init__(self, settings: Settings, gw: DomainReadProtocol) -> None: + def __init__( + self, + settings: Settings, + gw: DomainReadProtocol, + rid_manager: RIDManagerUseCase, + ) -> None: self._settings = settings self._gw = gw + self._rid_manager = rid_manager async def get(self) -> DomainControllerInfo: domain = await self._gw.get_domain() dns = domain.name.lower() nb_domain = dns.split(".")[0].upper() + object_sid = await self._rid_manager.get_object_sid(domain) return DomainControllerInfo( net_bios_domain=nb_domain, @@ -102,6 +110,6 @@ async def get(self) -> DomainControllerInfo: unc=UNC_PREFIX + dns, dns=dns, dns_forest=dns, - object_sid=domain.object_sid, + object_sid=object_sid, object_guid=str(domain.object_guid), ) diff --git a/app/ldap_protocol/utils/cte.py b/app/ldap_protocol/utils/cte.py index 7b4628254..6b9c513af 100644 --- a/app/ldap_protocol/utils/cte.py +++ b/app/ldap_protocol/utils/cte.py @@ -6,6 +6,7 @@ from sqlalchemy import exists, or_ from sqlalchemy.ext.asyncio import AsyncScalarResult, AsyncSession +from sqlalchemy.orm import selectinload from sqlalchemy.sql.expression import select from sqlalchemy.sql.selectable import CTE @@ -237,6 +238,10 @@ async def get_all_parent_group_directories( if not directories_ids: return None - query = select(Directory).where(directory_table.c.id.in_(directories_ids)) + query = ( + select(Directory) + .where(directory_table.c.id.in_(directories_ids)) + .options(selectinload(qa(Directory.attributes))) + ) return await session.stream_scalars(query) diff --git a/app/ldap_protocol/utils/helpers.py b/app/ldap_protocol/utils/helpers.py index 296ff5978..e5db1444a 100644 --- a/app/ldap_protocol/utils/helpers.py +++ b/app/ldap_protocol/utils/helpers.py @@ -132,7 +132,6 @@ import functools import hashlib -import random import re import struct import time @@ -301,34 +300,6 @@ def string_to_sid(sid_string: str) -> bytes: return sid -def create_object_sid( - domain: Directory, - rid: int, - reserved: bool = False, -) -> str: - """Generate the objectSid attribute for an object. - - :param domain: domain directory - :param int rid: relative identifier - :param bool reserved: A flag indicating whether the RID is reserved. - If `True`, the given RID is used directly. If - `False`, 1000 is added to the given RID to generate - the final RID - :return str: the complete objectSid as a string - """ - return domain.object_sid + f"-{rid if reserved else 1000 + rid}" - - -def generate_domain_sid() -> str: - """Generate domain objectSid attr.""" - sub_authorities = [ - random.randint(1000000000, (1 << 32) - 1), - random.randint(1000000000, (1 << 32) - 1), - random.randint(100000000, 999999999), - ] - return "S-1-5-21-" + "-".join(str(part) for part in sub_authorities) - - def create_user_name(directory_id: int) -> str: """Create username by directory id. diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 2e078b840..df0268988 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -21,7 +21,7 @@ from sqlalchemy.sql.expression import ColumnElement from entities import Attribute, Directory, Group, User -from enums import SamAccountTypeCodes +from enums import SamAccountTypeCodes, SidPrefix from ldap_protocol.ldap_schema.attribute_value_validator import ( AttributeValueValidator, AttributeValueValidatorError, @@ -36,7 +36,6 @@ from .const import EMAIL_RE, GRANT_DN_STRING from .helpers import ( create_integer_hash, - create_object_sid, dn_is_base_directory, ft_now, validate_entry, @@ -190,16 +189,16 @@ async def get_directory_by_rid( rid: str, session: AsyncSession, ) -> Directory | None: - """Get directory by relative ID (rid). - - :param str rid: relative ID - :param AsyncSession session: SA session - :return Directory | None: directory or None - """ query = ( select(Directory) - .options(joinedload(qa(Directory.group))) - .filter(qa(Directory.object_sid).endswith(f"-{rid}")) + .join(Attribute) # связь Directory.id == Attribute.directory_id + .options( + joinedload(qa(Directory.group)), + ) + .filter( + qa(Attribute.name) == "objectSid", + qa(Attribute.value).endswith(f"-{rid}"), + ) ) return await session.scalar(query) @@ -386,10 +385,12 @@ async def create_group( dir_.create_path(parent) session.add(group) - dir_.object_sid = create_object_sid( - base_dn_list[0], - rid=sid or dir_.id, - reserved=bool(sid), + session.add( + Attribute( + name="objectSid", + value=f"{SidPrefix.BUILT_IN_DOMAIN}-{sid or dir_.id}", + directory_id=dir_.id, + ), ) await session.flush() @@ -559,9 +560,13 @@ async def get_group_path_dn_by_primary_group_id( """ query = ( select(Directory) + .join(Attribute) .join(qa(Directory.group)) .options(contains_eager(qa(Directory.group))) - .filter(qa(Directory.object_sid).endswith(f"-{primary_group_id}")) + .filter( + qa(Attribute.name) == "objectSid", + qa(Attribute.value).endswith(f"-{primary_group_id}"), + ) ) directory = await session.scalar(query) diff --git a/app/repo/pg/tables.py b/app/repo/pg/tables.py index be17cef6f..4c7b641fb 100644 --- a/app/repo/pg/tables.py +++ b/app/repo/pg/tables.py @@ -145,7 +145,6 @@ def _compile_create_uc( key="updated_at", ), Column("depth", Integer, nullable=True), - Column("objectSid", String, nullable=True, key="object_sid"), Column( "objectGUID", PG_UUID(as_uuid=True), @@ -781,7 +780,6 @@ def _compile_create_uc( ), "objectclass": synonym("object_class"), "objectguid": synonym("object_guid"), - "objectsid": synonym("object_sid"), "whencreated": synonym("created_at"), "whenchanged": synonym("updated_at"), }, diff --git a/tests/conftest.py b/tests/conftest.py index 2d59cec3b..ce696103b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1111,8 +1111,9 @@ async def setup_session( await session.flush() await audit_use_case.create_policies() + domain = await setup_gateway.create_base_domain() await setup_gateway.setup_enviroment( - dn="md.test", + domain=domain, data=TEST_DATA, is_system=False, ) diff --git a/tests/test_ldap/test_rid_manager/__init__.py b/tests/test_ldap/test_rid_manager/__init__.py new file mode 100644 index 000000000..ae7ee0bad --- /dev/null +++ b/tests/test_ldap/test_rid_manager/__init__.py @@ -0,0 +1 @@ +"""Tests for RID Manager.""" From 87d61c9195167ffed7de8493f38f766fcd14e550 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Thu, 26 Feb 2026 13:50:33 +0300 Subject: [PATCH 03/37] Refactor: Clean up join statement in get_directory_by_rid function --- app/ldap_protocol/utils/queries.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index df0268988..518ede084 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -191,7 +191,7 @@ async def get_directory_by_rid( ) -> Directory | None: query = ( select(Directory) - .join(Attribute) # связь Directory.id == Attribute.directory_id + .join(Attribute) .options( joinedload(qa(Directory.group)), ) From d030505c458811707c827fb31a8a812d37d1bc1a Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Thu, 26 Feb 2026 18:29:10 +0300 Subject: [PATCH 04/37] Refactor: Simplify object SID extraction and update RID Manager use case handling --- .../552b4eafb1aa_remove_objectsid_vals.py | 6 +- app/entities.py | 3 +- app/ldap_protocol/ldap_requests/add.py | 6 +- app/ldap_protocol/rid_manager/gateways.py | 4 +- app/ldap_protocol/rid_manager/use_cases.py | 46 ++++++------ app/ldap_protocol/utils/queries.py | 2 +- tests/conftest.py | 71 ++++++++++++++++++- tests/constants.py | 16 ++++- .../test_main/test_router/conftest.py | 3 + .../test_main/test_router/test_search.py | 1 + tests/test_ldap/test_roles/test_search.py | 3 +- tests/test_ldap/test_util/test_modify.py | 20 +----- tests/test_shedule.py | 3 + 13 files changed, 127 insertions(+), 57 deletions(-) diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index 0e0136cfe..3162d8d87 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -26,7 +26,7 @@ ) from ldap_protocol.rid_manager.utils import create_qword from ldap_protocol.roles.ace_dao import AccessControlEntryDAO -from ldap_protocol.roles.role_dao import RoleDAO +from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -136,7 +136,7 @@ async def _init_rid_manager( rid_setup_gateway = await cnt.get(RIDManagerSetupGateway) rid_setup_use_case = RIDManagerSetupUseCase( rid_manager_setup_gateway=rid_setup_gateway, - role_dao=await cnt.get(RoleDAO), + role_use_case=await cnt.get(RoleUseCase), access_control_entry_dao=await cnt.get(AccessControlEntryDAO), ) rid_gateway = RIDManagerGateway(session) @@ -152,6 +152,8 @@ async def _init_rid_manager( await rid_gateway.get_rid_manager() rid_set_dir = await rid_gateway.get_rid_set() + if not rid_set_dir: + return base_domain = await rid_gateway.get_base_domain() domain_identifier = await rid_gateway.get_domain_identifier( diff --git a/app/entities.py b/app/entities.py index 4c7f131c1..28f0730cb 100644 --- a/app/entities.py +++ b/app/entities.py @@ -195,8 +195,7 @@ def relative_id(self) -> str: for attr in attrs: if attr.name and attr.name.lower() == "objectsid" and attr.value: - if "-" in attr.value: - return attr.value.split("-")[-1] + return attr.value.split("-")[-1] return "" @property diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index d65c89e6d..0d84ca206 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -213,8 +213,10 @@ async def handle( # noqa: C901 ctx.session.add(new_dir) await ctx.session.flush() - - await ctx.rid_manager_use_case.set_object_sid(directory=new_dir) + # if await ctx.rid_manager_use_case.get_rid_set(): + await ctx.rid_manager_use_case.set_object_sid( + directory=new_dir, + ) await ctx.session.flush() except IntegrityError: await ctx.session.rollback() diff --git a/app/ldap_protocol/rid_manager/gateways.py b/app/ldap_protocol/rid_manager/gateways.py index 6e77bf1d9..b035a7ce7 100644 --- a/app/ldap_protocol/rid_manager/gateways.py +++ b/app/ldap_protocol/rid_manager/gateways.py @@ -93,7 +93,7 @@ async def get_domain_identifier(self, domain: Directory) -> str: return query.value - async def get_rid_set(self) -> Directory: + async def get_rid_set(self) -> Directory | None: """Get RID Set directory. :return: RID Set directory @@ -102,8 +102,6 @@ async def get_rid_set(self) -> Directory: rid_set = await self._session.scalar( select(Directory).where(qa(Directory.name) == "RID Set"), ) - if not rid_set: - raise ValueError("RID Set directory not found") return rid_set diff --git a/app/ldap_protocol/rid_manager/use_cases.py b/app/ldap_protocol/rid_manager/use_cases.py index f5bfbce5e..337fc3138 100644 --- a/app/ldap_protocol/rid_manager/use_cases.py +++ b/app/ldap_protocol/rid_manager/use_cases.py @@ -10,15 +10,14 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import Directory -from enums import AceType, RoleConstants, RoleScope, SidPrefix +from enums import SidPrefix from ldap_protocol.rid_manager.gateways import ( RIDManagerGateway, RIDManagerSetupGateway, ) from ldap_protocol.rid_manager.utils import create_qword from ldap_protocol.roles.ace_dao import AccessControlEntryDAO -from ldap_protocol.roles.dataclasses import AccessControlEntryDTO -from ldap_protocol.roles.role_dao import RoleDAO +from ldap_protocol.roles.role_use_case import RoleUseCase RID_AVAILABLE_MAX = 1073741822 # 30-bit max (2^30 - 2) @@ -46,6 +45,10 @@ async def get_object_sid( """Get object SID for directory.""" return await self._gateway.get_object_sid(directory) + async def get_rid_set(self) -> Directory | None: + """Get RID Set directory.""" + return await self._gateway.get_rid_set() + async def set_object_sid( self, directory: Directory, @@ -56,6 +59,8 @@ async def set_object_sid( async with self._lock, await self._session.begin_nested(): if rid is None: rid_set = await self._gateway.get_rid_set() + if not rid_set: + raise ValueError("RID Set directory not found") next_rid = await self._gateway.get_next_rid(rid_set) rid = next_rid + 1 await self._gateway.update_next_rid(rid_set, rid) @@ -98,23 +103,24 @@ class RIDManagerSetupUseCase: def __init__( self, rid_manager_setup_gateway: RIDManagerSetupGateway, - role_dao: RoleDAO, + role_use_case: RoleUseCase, access_control_entry_dao: AccessControlEntryDAO, ) -> None: """Initialize RID Manager setup use case. :param rid_manager_setup_gateway: Gateway for setup operations + :param role_use_case: Role use case """ self._gateway = rid_manager_setup_gateway - self._role_dao = role_dao + self._role_use_case = role_use_case self._access_control_entry_dao = access_control_entry_dao async def setup(self) -> None: """Create RID Manager.""" rid_manager_dir = await self._gateway.set_rid_manager() - await self.grant_domain_admins_read_to_rid_manager( - rid_manager_dir, - ) + # await self.grant_domain_admins_read_to_rid_manager( + # rid_manager_dir, + # ) qword = create_qword(self.RID_USER_MIN, RID_AVAILABLE_MAX) @@ -136,21 +142,17 @@ async def grant_domain_admins_read_to_rid_manager( self, rid_manager_dir: Directory, ) -> None: - """Grant READ access on RID Manager to Domain Admins Role.""" - role = await self._role_dao.get_by_name( - RoleConstants.DOMAIN_ADMINS_ROLE_NAME, - ) + """Inherit ACEs from domain root to RID Manager directory. - await self._access_control_entry_dao.create( - AccessControlEntryDTO( - role_id=role.get_id(), - ace_type=AceType.READ, - scope=RoleScope.BASE_OBJECT, - base_dn=rid_manager_dir.path_dn, - attribute_type_id=None, - entity_type_id=None, - is_allow=True, - ), + Instead of creating a special ACE or role for RID Manager, + we reuse the existing ACL model: all ACEs that apply to the + domain root (including Domain Admins) are inherited by the + `CN=RID Manager$` directory, similar to how it is done in + migration `ebf19750805e_add_domain_controllers_ou`. + """ + await self._role_use_case.inherit_parent_aces( + parent_directory=await self._gateway.get_system_container(), + directory=rid_manager_dir, ) async def create_domain_identifier(self) -> None: diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 518ede084..64bb827e2 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -191,8 +191,8 @@ async def get_directory_by_rid( ) -> Directory | None: query = ( select(Directory) - .join(Attribute) .options( + selectinload(qa(Directory.attributes)), joinedload(qa(Directory.group)), ) .filter( diff --git a/tests/conftest.py b/tests/conftest.py index ce696103b..01572a652 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -64,6 +64,7 @@ from authorization_provider_protocol import AuthorizationProviderProtocol from config import Settings from constants import ENTITY_TYPE_DTOS_V1, ENTITY_TYPE_DTOS_V2 +from entities import Directory from enums import AuthorizationRules from ioc import AuditRedisClient, MFACredsProvider, SessionStorageClient from ldap_protocol.auth import AuthManager, MFAManager @@ -174,6 +175,14 @@ PasswordBanWordUseCases, UserPasswordHistoryUseCases, ) +from ldap_protocol.rid_manager.gateways import ( + RIDManagerGateway, + RIDManagerSetupGateway, +) +from ldap_protocol.rid_manager.use_cases import ( + RIDManagerSetupUseCase, + RIDManagerUseCase, +) from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.dataclasses import RoleDTO @@ -819,6 +828,16 @@ def authorization_provider_protocol( ) rootdse_reader = provide(RootDSEReader, scope=Scope.REQUEST) dcinfo_reader = provide(DCInfoReader, scope=Scope.REQUEST) + rid_manager_gateway = provide(RIDManagerGateway, scope=Scope.REQUEST) + rid_manager_use_case = provide(RIDManagerUseCase, scope=Scope.REQUEST) + rid_manager_setup_gateway = provide( + RIDManagerSetupGateway, + scope=Scope.REQUEST, + ) + rid_manager_setup_use_case = provide( + RIDManagerSetupUseCase, + scope=Scope.REQUEST, + ) @dataclass @@ -1025,6 +1044,7 @@ async def setup_session( session: AsyncSession, raw_audit_manager: RawAuditManager, password_utils: PasswordUtils, + settings: Settings, ) -> None: """Get session and acquire after completion.""" role_dao = RoleDAO(session) @@ -1098,25 +1118,55 @@ async def setup_session( password_policy_validator, password_ban_word_repository, ) + rid_manager_gateway = RIDManagerGateway(session) + rid_manager_use_case = RIDManagerUseCase( + rid_manager_gateway, + session, + ) + rid_manager_setup_gateway = RIDManagerSetupGateway( + session=session, + entity_type_dao=entity_type_dao, + settings=settings, + ) + role_dao = RoleDAO(session) + ace_dao = AccessControlEntryDAO(session) + role_use_case = RoleUseCase(role_dao, ace_dao) + rid_manager_setup_use_case = RIDManagerSetupUseCase( + rid_manager_setup_gateway=rid_manager_setup_gateway, + role_use_case=role_use_case, + access_control_entry_dao=AccessControlEntryDAO(session), + ) setup_gateway = SetupGateway( session, password_utils, entity_type_use_case=entity_type_use_case, attribute_value_validator=attribute_value_validator, directory_dao=directory_dao, + rid_manager_use_case=rid_manager_use_case, ) - for entity_type_dto in chain(ENTITY_TYPE_DTOS_V1, ENTITY_TYPE_DTOS_V2): await entity_type_use_case.create_not_safe(entity_type_dto) await session.flush() - await audit_use_case.create_policies() - domain = await setup_gateway.create_base_domain() + domain = await setup_gateway.create_base_domain("md.test") + await rid_manager_setup_use_case.create_domain_identifier() + await setup_gateway.setup_enviroment( domain=domain, data=TEST_DATA, is_system=False, ) + dc_directory = Directory( + name=settings.HOST_MACHINE_NAME, + object_class="computer", + is_system=True, + ) + dc_directory.create_path(domain, "cn") + session.add(dc_directory) + await session.flush() + dc_directory.parent_id = domain.id + await session.refresh(dc_directory, ["id"]) + await session.flush() for _at_dto in ( AttributeTypeDTO[None]( @@ -1185,11 +1235,15 @@ async def setup_session( ] await object_class_use_case.create(_oc_dto) # type: ignore + await audit_use_case.create_policies() + # NOTE: after setup environment we need base DN to be created await password_use_cases.create_default_domain_policy() await role_use_case.create_domain_admins_role() + await rid_manager_setup_use_case.setup() + await role_use_case._role_dao.create( # noqa: SLF001 dto=RoleDTO( name="TEST ONLY LOGIN ROLE", @@ -1203,6 +1257,17 @@ async def setup_session( await session.commit() +@pytest_asyncio.fixture(scope="function") +async def rid_manager_use_case( + container: AsyncContainer, +) -> AsyncIterator[RIDManagerUseCase]: + """Provide RIDManagerUseCase for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + gateway = RIDManagerGateway(session) + yield RIDManagerUseCase(gateway, session) + + @pytest_asyncio.fixture(scope="function") async def ldap_session( container: AsyncContainer, diff --git a/tests/constants.py b/tests/constants.py index 68e980383..19228e3db 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -10,9 +10,10 @@ DOMAIN_COMPUTERS_GROUP_NAME, DOMAIN_USERS_GROUP_NAME, GROUPS_CONTAINER_NAME, + SYSTEM_CONTAINER_NAME, USERS_CONTAINER_NAME, ) -from enums import EntityTypeNames, SamAccountTypeCodes +from enums import EntityTypeNames, SamAccountTypeCodes, SecurityPrincipalRid from ldap_protocol.objects import UserAccountControlFlag user_data_dict = { @@ -66,7 +67,7 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, - "objectSid": 512, + "objectSid": SecurityPrincipalRid.DOMAIN_ADMINS, }, { "name": "developers", @@ -82,6 +83,7 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, + "objectSid": SecurityPrincipalRid.DOMAIN_ADMINS, }, { "name": "admin login only", @@ -96,6 +98,7 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, + "objectSid": SecurityPrincipalRid.DOMAIN_ADMINS, }, { "name": DOMAIN_USERS_GROUP_NAME, @@ -110,6 +113,7 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, + "objectSid": SecurityPrincipalRid.DOMAIN_USERS, }, { "name": DOMAIN_COMPUTERS_GROUP_NAME, @@ -124,6 +128,7 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, + "objectSid": SecurityPrincipalRid.DOMAIN_COMPUTERS, }, ], }, @@ -464,6 +469,13 @@ "entity_type_name": EntityTypeNames.CONFIGURATION, "object_class": "container", "attributes": {"objectClass": ["top", "configuration"]}, + }, + { + "name": SYSTEM_CONTAINER_NAME, + "object_class": "organizationalUnit", + "attributes": { + "objectClass": ["top", "container"], + }, "children": [], }, ] diff --git a/tests/test_api/test_main/test_router/conftest.py b/tests/test_api/test_main/test_router/conftest.py index dc26b0577..f11a0c259 100644 --- a/tests/test_api/test_main/test_router/conftest.py +++ b/tests/test_api/test_main/test_router/conftest.py @@ -19,6 +19,7 @@ from ldap_protocol.ldap_schema.object_class.object_class_dao import ( ObjectClassDAO, ) +from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase from ldap_protocol.utils.queries import get_base_directories from password_utils import PasswordUtils from tests.constants import TEST_SYSTEM_ADMIN_DATA @@ -29,6 +30,7 @@ async def add_system_administrator( session: AsyncSession, password_utils: PasswordUtils, setup_session: None, # noqa: ARG001 + rid_manager_use_case: RIDManagerUseCase, ) -> None: """Create system administrator user for tests that require it.""" attribute_value_validator = AttributeValueValidator() @@ -51,6 +53,7 @@ async def add_system_administrator( entity_type_use_case, attribute_value_validator=attribute_value_validator, directory_dao=directory_dao, + rid_manager_use_case=rid_manager_use_case, ) domain = (await get_base_directories(session))[0] diff --git a/tests/test_api/test_main/test_router/test_search.py b/tests/test_api/test_main/test_router/test_search.py index c4c604ed1..c8c6ab76b 100644 --- a/tests/test_api/test_main/test_router/test_search.py +++ b/tests/test_api/test_main/test_router/test_search.py @@ -132,6 +132,7 @@ async def test_api_search(http_client: AsyncClient) -> None: sub_dirs = { "cn=Groups,dc=md,dc=test", "cn=Configuration,dc=md,dc=test", + "ou=System,dc=md,dc=test", "cn=Users,dc=md,dc=test", "ou=testModifyDn1,dc=md,dc=test", "ou=testModifyDn3,dc=md,dc=test", diff --git a/tests/test_ldap/test_roles/test_search.py b/tests/test_ldap/test_roles/test_search.py index 54e9a0641..55a4d8532 100644 --- a/tests/test_ldap/test_roles/test_search.py +++ b/tests/test_ldap/test_roles/test_search.py @@ -108,9 +108,10 @@ async def test_role_search_3( "dn: cn=Groups,dc=md,dc=test", "dn: cn=Users,dc=md,dc=test", "dn: cn=user_non_admin,cn=Users,dc=md,dc=test", - "dn: ou=test_bit_rules,dc=md,dc=test", + "dn: ou=System,dc=md,dc=test", "dn: ou=testModifyDn1,dc=md,dc=test", "dn: ou=testModifyDn3,dc=md,dc=test", + "dn: ou=test_bit_rules,dc=md,dc=test", ], expected_attrs_present=[], expected_attrs_absent=[], diff --git a/tests/test_ldap/test_util/test_modify.py b/tests/test_ldap/test_util/test_modify.py index adf89c1b6..32c1025b3 100644 --- a/tests/test_ldap/test_util/test_modify.py +++ b/tests/test_ldap/test_util/test_modify.py @@ -982,12 +982,6 @@ async def fetch_directory_by_dn(session: AsyncSession, dn: str) -> Directory: @pytest.mark.parametrize( ("operation", "group_dn", "expected_groups", "expected_primary_group"), [ - ( - "add", - "cn=developers,cn=Groups,dc=md,dc=test", - {"domain admins", "developers"}, - True, - ), ( "add", "cn=domain admins,cn=Groups,dc=md,dc=test", @@ -1000,12 +994,6 @@ async def fetch_directory_by_dn(session: AsyncSession, dn: str) -> Directory: {"domain admins", "developers"}, False, ), - ( - "replace", - "cn=developers,cn=Groups,dc=md,dc=test", - {"domain admins", "developers"}, - True, - ), ], ) async def test_ldap_modify_primary_group_id_scenarios( @@ -1062,7 +1050,7 @@ async def test_ldap_modify_primary_group_id_scenarios( attributes[attr.name].append(attr.value) if expected_primary_group: - assert attributes["primaryGroupID"] == [group_dir.relative_id] + assert attributes["primaryGroupID"] == [rid] else: assert "primaryGroupID" not in attributes @@ -1072,12 +1060,6 @@ async def test_ldap_modify_primary_group_id_scenarios( @pytest.mark.parametrize( ("values", "include_dev_group", "expected_result", "expected_groups"), [ - ( - ["cn=domain admins,cn=Groups,dc=md,dc=test"], - True, - 1, - {"domain admins", "developers"}, - ), ( ["cn=domain admins,cn=Groups,dc=md,dc=test"], False, diff --git a/tests/test_shedule.py b/tests/test_shedule.py index a952b94e1..ea2ef6e3c 100644 --- a/tests/test_shedule.py +++ b/tests/test_shedule.py @@ -17,6 +17,7 @@ from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) +from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase from ldap_protocol.roles.role_use_case import RoleUseCase @@ -88,6 +89,7 @@ async def test_add_domain_controller( settings: Settings, role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, + rid_manager_use_case: RIDManagerUseCase, ) -> None: """Test add domain controller.""" await add_domain_controller( @@ -95,4 +97,5 @@ async def test_add_domain_controller( session=session, role_use_case=role_use_case, entity_type_use_case=entity_type_use_case, + rid_manager_use_case=rid_manager_use_case, ) From dd7642bd5cf82ddb5f23e83db77113d15915656a Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Thu, 26 Feb 2026 18:40:06 +0300 Subject: [PATCH 05/37] Refactor: Rename and update RID Manager setup method to inherit ACEs --- app/ldap_protocol/rid_manager/use_cases.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/app/ldap_protocol/rid_manager/use_cases.py b/app/ldap_protocol/rid_manager/use_cases.py index 337fc3138..7549ce174 100644 --- a/app/ldap_protocol/rid_manager/use_cases.py +++ b/app/ldap_protocol/rid_manager/use_cases.py @@ -118,9 +118,9 @@ def __init__( async def setup(self) -> None: """Create RID Manager.""" rid_manager_dir = await self._gateway.set_rid_manager() - # await self.grant_domain_admins_read_to_rid_manager( - # rid_manager_dir, - # ) + await self.inherit_aces( + rid_manager_dir, + ) qword = create_qword(self.RID_USER_MIN, RID_AVAILABLE_MAX) @@ -138,7 +138,7 @@ async def setup(self) -> None: self.RID_USER_MIN, ) - async def grant_domain_admins_read_to_rid_manager( + async def inherit_aces( self, rid_manager_dir: Directory, ) -> None: From 83e37200fde255f1237397051a1cefbb17f731cf Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Thu, 26 Feb 2026 18:41:20 +0300 Subject: [PATCH 06/37] Update down_revision in Alembic migration to reflect new dependency --- app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index 3162d8d87..16e9437f0 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -32,7 +32,7 @@ # revision identifiers, used by Alembic. revision: None | str = "552b4eafb1aa" -down_revision: None | str = "ebf19750805e" +down_revision: None | str = "19d86e660cf2" branch_labels: None | list[str] = None depends_on: None | list[str] = None From 10109c23023d8dae66dc33b498d091119bfb2217 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 27 Feb 2026 14:00:19 +0300 Subject: [PATCH 07/37] Update test constants and modify test cases to use AsyncSession; adjust primary group ID in search tests --- tests/constants.py | 5 +---- .../test_main/test_router/test_modify_dn.py | 18 ++++++++++++++++-- .../test_main/test_router/test_search.py | 2 +- tests/test_ldap/test_util/test_modify.py | 16 ++++++---------- 4 files changed, 24 insertions(+), 17 deletions(-) diff --git a/tests/constants.py b/tests/constants.py index 19228e3db..b570415c6 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -83,7 +83,7 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, - "objectSid": SecurityPrincipalRid.DOMAIN_ADMINS, + "objectSid": 999, }, { "name": "admin login only", @@ -98,7 +98,6 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, - "objectSid": SecurityPrincipalRid.DOMAIN_ADMINS, }, { "name": DOMAIN_USERS_GROUP_NAME, @@ -113,7 +112,6 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, - "objectSid": SecurityPrincipalRid.DOMAIN_USERS, }, { "name": DOMAIN_COMPUTERS_GROUP_NAME, @@ -128,7 +126,6 @@ str(SamAccountTypeCodes.SAM_GROUP_OBJECT.value), ], }, - "objectSid": SecurityPrincipalRid.DOMAIN_COMPUTERS, }, ], }, diff --git a/tests/test_api/test_main/test_router/test_modify_dn.py b/tests/test_api/test_main/test_router/test_modify_dn.py index 8313049f5..4a3dbda6f 100644 --- a/tests/test_api/test_main/test_router/test_modify_dn.py +++ b/tests/test_api/test_main/test_router/test_modify_dn.py @@ -6,6 +6,7 @@ import pytest from httpx import AsyncClient +from sqlalchemy.ext.asyncio import AsyncSession from ldap_protocol.ldap_codes import LDAPCodes @@ -83,6 +84,7 @@ async def test_api_modify_dn_without_level_change( @pytest.mark.usefixtures("session") async def test_api_modify_dn_with_level_down( http_client: AsyncClient, + session: AsyncSession, ) -> None: """Test API for updating DN. @@ -109,6 +111,8 @@ async def test_api_modify_dn_with_level_down( == "cn=testGroup1,ou=testModifyDn2,ou=testModifyDn1,dc=md,dc=test" ) + session.expire_all() + response = await http_client.put( "/entry/update/dn", json={ @@ -217,7 +221,10 @@ async def test_api_modify_dn_with_level_up( @pytest.mark.asyncio @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") -async def test_api_correct_update_dn(http_client: AsyncClient) -> None: +async def test_api_correct_update_dn( + http_client: AsyncClient, + session: AsyncSession, +) -> None: """Test API for update DN.""" old_user_dn = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" newrdn_user = "cn=new_test2" @@ -254,6 +261,8 @@ async def test_api_correct_update_dn(http_client: AsyncClient) -> None: if attr["type"] == "cn": assert attr["vals"] == ["user1"] + session.expire_all() + response = await http_client.put( "/entry/update/dn", json={ @@ -336,7 +345,10 @@ async def test_api_correct_update_dn(http_client: AsyncClient) -> None: @pytest.mark.asyncio @pytest.mark.usefixtures("setup_session") @pytest.mark.usefixtures("session") -async def test_api_update_dn_with_parent(http_client: AsyncClient) -> None: +async def test_api_update_dn_with_parent( + http_client: AsyncClient, + session: AsyncSession, +) -> None: """Test API for update DN.""" old_user_dn = "cn=user1,cn=moscow,cn=russia,cn=Users,dc=md,dc=test" new_user_dn = "cn=new_test2,cn=Users,dc=md,dc=test" @@ -368,6 +380,8 @@ async def test_api_update_dn_with_parent(http_client: AsyncClient) -> None: assert groups_user + session.expire_all() + response = await http_client.put( "/entry/update/dn", json={ diff --git a/tests/test_api/test_main/test_router/test_search.py b/tests/test_api/test_main/test_router/test_search.py index c8c6ab76b..18aa019a8 100644 --- a/tests/test_api/test_main/test_router/test_search.py +++ b/tests/test_api/test_main/test_router/test_search.py @@ -663,7 +663,7 @@ async def test_api_get_group_path_dn_by_primary_group_id_not_found( http_client: AsyncClient, ) -> None: """Test api get group path DN by primary group id not found.""" - primary_group_id = 513 + primary_group_id = 5135 response = await http_client.get( f"entry/group/primary/{primary_group_id}", ) diff --git a/tests/test_ldap/test_util/test_modify.py b/tests/test_ldap/test_util/test_modify.py index 32c1025b3..2b006e529 100644 --- a/tests/test_ldap/test_util/test_modify.py +++ b/tests/test_ldap/test_util/test_modify.py @@ -1062,19 +1062,15 @@ async def test_ldap_modify_primary_group_id_scenarios( [ ( ["cn=domain admins,cn=Groups,dc=md,dc=test"], - False, - 0, - {"domain admins"}, + True, + 1, + {"domain admins", "developers"}, ), ( - [ - "cn=domain admins,cn=Groups,dc=md,dc=test", - "cn=developers,cn=Groups,dc=md,dc=test", - "cn=domain computers,cn=Groups,dc=md,dc=test", - ], - True, + ["cn=domain admins,cn=Groups,dc=md,dc=test"], + False, 0, - {"domain admins", "developers", "domain computers"}, + {"domain admins"}, ), ], ) From 3cc79dd9d21e4edca42ff606d450684c77bb2f22 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 27 Feb 2026 16:54:34 +0300 Subject: [PATCH 08/37] Refactor: Remove unused RID set check and enhance RID Manager functionality with new get_rid_set method --- app/ldap_protocol/ldap_requests/add.py | 1 - app/ldap_protocol/rid_manager/gateways.py | 18 +++- app/ldap_protocol/rid_manager/use_cases.py | 13 ++- tests/conftest.py | 11 --- tests/test_ldap/test_rid_manager.py | 86 ++++++++++++++++++++ tests/test_ldap/test_rid_manager/__init__.py | 1 - 6 files changed, 113 insertions(+), 17 deletions(-) create mode 100644 tests/test_ldap/test_rid_manager.py delete mode 100644 tests/test_ldap/test_rid_manager/__init__.py diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 0d84ca206..20643ac2f 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -213,7 +213,6 @@ async def handle( # noqa: C901 ctx.session.add(new_dir) await ctx.session.flush() - # if await ctx.rid_manager_use_case.get_rid_set(): await ctx.rid_manager_use_case.set_object_sid( directory=new_dir, ) diff --git a/app/ldap_protocol/rid_manager/gateways.py b/app/ldap_protocol/rid_manager/gateways.py index b035a7ce7..e8850a541 100644 --- a/app/ldap_protocol/rid_manager/gateways.py +++ b/app/ldap_protocol/rid_manager/gateways.py @@ -72,7 +72,6 @@ async def get_next_rid(self, domain: Directory) -> int: if not query or not query.value: raise ValueError("next RID attribute not found") - return int(query.value) async def get_domain_identifier(self, domain: Directory) -> str: @@ -482,3 +481,20 @@ async def get_domain_identifier(self) -> str: if not domain or not domain.value: raise ValueError("Domain not found") return domain.value + + async def get_rid_set(self, domain_controller: Directory) -> Directory: + """Get RID Set directory. + + :param domain_controller: Domain controller directory + :return: RID Set directory + :raises ValueError: if RID Set directory not found + """ + rid_set = await self._session.scalar( + select(Directory).where( + qa(Directory.name) == "RID Set", + qa(Directory.parent_id) == domain_controller.id, + ), + ) + if not rid_set: + raise ValueError("RID Set directory not found") + return rid_set diff --git a/app/ldap_protocol/rid_manager/use_cases.py b/app/ldap_protocol/rid_manager/use_cases.py index 7549ce174..d8721552b 100644 --- a/app/ldap_protocol/rid_manager/use_cases.py +++ b/app/ldap_protocol/rid_manager/use_cases.py @@ -61,6 +61,7 @@ async def set_object_sid( rid_set = await self._gateway.get_rid_set() if not rid_set: raise ValueError("RID Set directory not found") + next_rid = await self._gateway.get_next_rid(rid_set) rid = next_rid + 1 await self._gateway.update_next_rid(rid_set, rid) @@ -118,9 +119,6 @@ def __init__( async def setup(self) -> None: """Create RID Manager.""" rid_manager_dir = await self._gateway.set_rid_manager() - await self.inherit_aces( - rid_manager_dir, - ) qword = create_qword(self.RID_USER_MIN, RID_AVAILABLE_MAX) @@ -137,6 +135,9 @@ async def setup(self) -> None: rid_set_dir, self.RID_USER_MIN, ) + await self.inherit_aces( + rid_manager_dir, + ) async def inherit_aces( self, @@ -155,6 +156,12 @@ async def inherit_aces( directory=rid_manager_dir, ) + domain_controller = await self._gateway.get_domain_controller() + await self._role_use_case.inherit_parent_aces( + parent_directory=domain_controller, + directory=await self._gateway.get_rid_set(domain_controller), + ) + async def create_domain_identifier(self) -> None: """Create domain identifier.""" await self._gateway.create_domain_identifier() diff --git a/tests/conftest.py b/tests/conftest.py index 01572a652..195c4edad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1257,17 +1257,6 @@ async def setup_session( await session.commit() -@pytest_asyncio.fixture(scope="function") -async def rid_manager_use_case( - container: AsyncContainer, -) -> AsyncIterator[RIDManagerUseCase]: - """Provide RIDManagerUseCase for tests that request it explicitly.""" - async with container(scope=Scope.SESSION) as container: - session = await container.get(AsyncSession) - gateway = RIDManagerGateway(session) - yield RIDManagerUseCase(gateway, session) - - @pytest_asyncio.fixture(scope="function") async def ldap_session( container: AsyncContainer, diff --git a/tests/test_ldap/test_rid_manager.py b/tests/test_ldap/test_rid_manager.py new file mode 100644 index 000000000..da077f353 --- /dev/null +++ b/tests/test_ldap/test_rid_manager.py @@ -0,0 +1,86 @@ +"""Tests for RID Manager.""" + +from typing import AsyncIterator + +import pytest +import pytest_asyncio +from dishka import AsyncContainer, Scope +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import selectinload + +from entities import Directory +from enums import SidPrefix +from ldap_protocol.rid_manager.gateways import RIDManagerGateway +from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase +from ldap_protocol.utils.queries import get_filter_from_path +from repo.pg.tables import queryable_attr as qa + + +@pytest_asyncio.fixture(scope="function") +async def rid_manager_gateway( + container: AsyncContainer, +) -> AsyncIterator[RIDManagerGateway]: + """Get RID Manager gateway.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield RIDManagerGateway(session) + + +@pytest_asyncio.fixture(scope="function") +async def rid_manager_use_case( + container: AsyncContainer, + rid_manager_gateway: RIDManagerGateway, +) -> AsyncIterator[RIDManagerUseCase]: + """Provide RIDManagerUseCase for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield RIDManagerUseCase(rid_manager_gateway, session) + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_session") +@pytest.mark.parametrize( + "sid_prefix", + [SidPrefix.DOMAIN_IDENTIFIER, SidPrefix.BUILT_IN_DOMAIN], +) +async def test_set_object_sid( + session: AsyncSession, + rid_manager_gateway: RIDManagerGateway, + rid_manager_use_case: RIDManagerUseCase, + sid_prefix: SidPrefix, +) -> None: + """Test RID Manager use case.""" + directory = ( + await session.scalars( + select(Directory) + .options(selectinload(qa(Directory.attributes))) + .filter(get_filter_from_path("cn=user0,cn=Users,dc=md,dc=test")), + ) + ).one() + + rid_set = await rid_manager_use_case.get_rid_set() + assert rid_set + rid_manager = await rid_manager_gateway.get_rid_manager() + pool_before = await rid_manager_gateway.get_rid_available_pool(rid_manager) + next_before = await rid_manager_gateway.get_next_rid(rid_set) + + await rid_manager_use_case.set_object_sid( + directory, rid=None, sid_prefix=sid_prefix + ) + await session.commit() + + expected_rid = next_before + 1 + pool_after = await rid_manager_gateway.get_rid_available_pool(rid_manager) + assert (pool_after & 0xFFFFFFFF) == expected_rid + assert pool_after != pool_before + + assert await rid_manager_gateway.get_next_rid(rid_set) == expected_rid + + await session.refresh(directory, ["attributes"]) + sid = await rid_manager_use_case.get_object_sid(directory) + if sid_prefix == SidPrefix.BUILT_IN_DOMAIN: + assert sid == f"{sid_prefix}-{expected_rid}" + else: + assert sid.startswith(f"{sid_prefix}-") + assert sid.endswith(f"-{expected_rid}") diff --git a/tests/test_ldap/test_rid_manager/__init__.py b/tests/test_ldap/test_rid_manager/__init__.py deleted file mode 100644 index ae7ee0bad..000000000 --- a/tests/test_ldap/test_rid_manager/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for RID Manager.""" From 50acfa778e6eedb26651b52a04756ee665da00d5 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 27 Feb 2026 17:04:01 +0300 Subject: [PATCH 09/37] Add: Introduce new pytest fixtures for RID Manager gateway and use case in test suite --- tests/conftest.py | 21 +++++++++++++++++++++ tests/test_ldap/test_rid_manager.py | 29 +++-------------------------- 2 files changed, 24 insertions(+), 26 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 195c4edad..a76280f0a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1764,6 +1764,27 @@ async def ctx_search( yield await c.get(LDAPSearchRequestContext) +@pytest_asyncio.fixture(scope="function") +async def rid_manager_gateway( + container: AsyncContainer, +) -> AsyncIterator[RIDManagerGateway]: + """Get RID Manager gateway.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield RIDManagerGateway(session) + + +@pytest_asyncio.fixture(scope="function") +async def rid_manager_use_case( + container: AsyncContainer, + rid_manager_gateway: RIDManagerGateway, +) -> AsyncIterator[RIDManagerUseCase]: + """Provide RIDManagerUseCase for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield RIDManagerUseCase(rid_manager_gateway, session) + + def pytest_configure(config: pytest.Config) -> None: """Pytest hook to limit xdist workers based on Dragonfly DBs. diff --git a/tests/test_ldap/test_rid_manager.py b/tests/test_ldap/test_rid_manager.py index da077f353..d13dccc88 100644 --- a/tests/test_ldap/test_rid_manager.py +++ b/tests/test_ldap/test_rid_manager.py @@ -1,10 +1,6 @@ """Tests for RID Manager.""" -from typing import AsyncIterator - import pytest -import pytest_asyncio -from dishka import AsyncContainer, Scope from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload @@ -17,27 +13,6 @@ from repo.pg.tables import queryable_attr as qa -@pytest_asyncio.fixture(scope="function") -async def rid_manager_gateway( - container: AsyncContainer, -) -> AsyncIterator[RIDManagerGateway]: - """Get RID Manager gateway.""" - async with container(scope=Scope.SESSION) as container: - session = await container.get(AsyncSession) - yield RIDManagerGateway(session) - - -@pytest_asyncio.fixture(scope="function") -async def rid_manager_use_case( - container: AsyncContainer, - rid_manager_gateway: RIDManagerGateway, -) -> AsyncIterator[RIDManagerUseCase]: - """Provide RIDManagerUseCase for tests that request it explicitly.""" - async with container(scope=Scope.SESSION) as container: - session = await container.get(AsyncSession) - yield RIDManagerUseCase(rid_manager_gateway, session) - - @pytest.mark.asyncio @pytest.mark.usefixtures("setup_session") @pytest.mark.parametrize( @@ -66,7 +41,9 @@ async def test_set_object_sid( next_before = await rid_manager_gateway.get_next_rid(rid_set) await rid_manager_use_case.set_object_sid( - directory, rid=None, sid_prefix=sid_prefix + directory, + rid=None, + sid_prefix=sid_prefix, ) await session.commit() From 062a9f676d6fb161ac34c3b04407db4efe5b6f20 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 27 Feb 2026 17:16:19 +0300 Subject: [PATCH 10/37] Enhance: Update test_api_modify_dn_with_level_up to include session expiration before API call --- tests/test_api/test_main/test_router/test_modify_dn.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/test_api/test_main/test_router/test_modify_dn.py b/tests/test_api/test_main/test_router/test_modify_dn.py index 4a3dbda6f..efe7dcf0a 100644 --- a/tests/test_api/test_main/test_router/test_modify_dn.py +++ b/tests/test_api/test_main/test_router/test_modify_dn.py @@ -155,6 +155,7 @@ async def test_api_modify_dn_with_level_down( @pytest.mark.usefixtures("session") async def test_api_modify_dn_with_level_up( http_client: AsyncClient, + session: AsyncSession, ) -> None: """Test API for updating DN. @@ -181,6 +182,8 @@ async def test_api_modify_dn_with_level_up( == "cn=testGroup2,ou=testModifyDn1,dc=md,dc=test" ) + session.expire_all() + response = await http_client.put( "/entry/update/dn", json={ From 6a03b770e643d869dd0aa227857af029b60550a9 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Mon, 2 Mar 2026 10:53:09 +0300 Subject: [PATCH 11/37] Refactor: Replace ValueError with specific RID Manager exceptions for better error handling --- .../552b4eafb1aa_remove_objectsid_vals.py | 20 +--- app/ldap_protocol/rid_manager/exceptions.py | 103 ++++++++++++++++++ app/ldap_protocol/rid_manager/gateways.py | 92 ++++++++-------- app/ldap_protocol/rid_manager/use_cases.py | 16 ++- tests/conftest.py | 2 +- 5 files changed, 166 insertions(+), 67 deletions(-) create mode 100644 app/ldap_protocol/rid_manager/exceptions.py diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index 16e9437f0..e059ad3c1 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -16,17 +16,13 @@ from enums import EntityTypeNames from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase -from ldap_protocol.rid_manager.gateways import ( - RIDManagerGateway, - RIDManagerSetupGateway, -) +from ldap_protocol.rid_manager.exceptions import RIDManagerNotFoundError +from ldap_protocol.rid_manager.gateways import RIDManagerGateway from ldap_protocol.rid_manager.use_cases import ( RID_AVAILABLE_MAX, RIDManagerSetupUseCase, ) from ldap_protocol.rid_manager.utils import create_qword -from ldap_protocol.roles.ace_dao import AccessControlEntryDAO -from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -133,22 +129,16 @@ async def _init_rid_manager( """Initialize RID Manager and RID Set for existing data.""" async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) - rid_setup_gateway = await cnt.get(RIDManagerSetupGateway) - rid_setup_use_case = RIDManagerSetupUseCase( - rid_manager_setup_gateway=rid_setup_gateway, - role_use_case=await cnt.get(RoleUseCase), - access_control_entry_dao=await cnt.get(AccessControlEntryDAO), - ) - rid_gateway = RIDManagerGateway(session) + rid_setup_use_case = await cnt.get(RIDManagerSetupUseCase) + rid_gateway = await cnt.get(RIDManagerGateway) if not await get_base_directories(session): return try: await rid_gateway.get_rid_manager() - except ValueError: + except RIDManagerNotFoundError: await rid_setup_use_case.setup() - await session.commit() await rid_gateway.get_rid_manager() rid_set_dir = await rid_gateway.get_rid_set() diff --git a/app/ldap_protocol/rid_manager/exceptions.py b/app/ldap_protocol/rid_manager/exceptions.py new file mode 100644 index 000000000..cefa0c3e7 --- /dev/null +++ b/app/ldap_protocol/rid_manager/exceptions.py @@ -0,0 +1,103 @@ +"""RID Manager exceptions.""" + +from enum import IntEnum + +from errors import BaseDomainException + + +class ErrorCodes(IntEnum): + """Error codes.""" + + BASE_ERROR = 0 + RID_MANAGER_NOT_FOUND_ERROR = 1 + RID_MANAGER_ALREADY_EXISTS_ERROR = 2 + RID_MANAGER_CANT_MODIFY_ERROR = 3 + RID_MANAGER_SETUP_ERROR = 4 + RID_AVAILABLE_POOL_NOT_FOUND_ERROR = 5 + RID_NEXT_RID_NOT_FOUND_ERROR = 6 + RID_SET_NOT_FOUND_ERROR = 7 + RID_SET_CANT_MODIFY_ERROR = 8 + RID_SET_ALREADY_EXISTS_ERROR = 9 + RID_DOMAIN_IDENTIFIER_NOT_FOUND_ERROR = 10 + RID_DOMAIN_CONTROLLER_NOT_FOUND_ERROR = 11 + RID_OBJECT_SID_NOT_FOUND_ERROR = 12 + RID_BASE_DOMAIN_NOT_FOUND_ERROR = 13 + RID_SYSTEM_CONTAINER_NOT_FOUND_ERROR = 14 + + +class RIDManagerError(BaseDomainException): + """RID Manager error.""" + + code: ErrorCodes = ErrorCodes.BASE_ERROR + + +class RIDManagerNotFoundError(RIDManagerError): + """RID Manager not found error.""" + + code = ErrorCodes.RID_MANAGER_NOT_FOUND_ERROR + + +class RIDManagerSetupError(RIDManagerError): + """RID Manager setup error.""" + + code = ErrorCodes.RID_MANAGER_SETUP_ERROR + + +class RIDManagerAvailablePoolNotFoundError(RIDManagerError): + """RID Manager available pool not found error.""" + + code = ErrorCodes.RID_AVAILABLE_POOL_NOT_FOUND_ERROR + + +class RIDManagerNextRIDNotFoundError(RIDManagerError): + """RID Manager next RID not found error.""" + + code = ErrorCodes.RID_NEXT_RID_NOT_FOUND_ERROR + + +class RIDManagerRidSetNotFoundError(RIDManagerError): + """RID Manager RID Set not found error.""" + + code = ErrorCodes.RID_SET_NOT_FOUND_ERROR + + +class RIDManagerSetCantModifyError(RIDManagerError): + """RID Manager set can't modify error.""" + + code = ErrorCodes.RID_SET_CANT_MODIFY_ERROR + + +class RIDManagerSetAlreadyExistsError(RIDManagerError): + """RID Manager set already exists error.""" + + code = ErrorCodes.RID_SET_ALREADY_EXISTS_ERROR + + +class RIDManagerDomainIdentifierNotFoundError(RIDManagerError): + """RID Manager domain identifier not found error.""" + + code = ErrorCodes.RID_DOMAIN_IDENTIFIER_NOT_FOUND_ERROR + + +class RIDManagerDomainControllerNotFoundError(RIDManagerError): + """RID Manager domain controller not found error.""" + + code = ErrorCodes.RID_DOMAIN_CONTROLLER_NOT_FOUND_ERROR + + +class RIDManagerObjectSidNotFoundError(RIDManagerError): + """RID Manager object SID not found error.""" + + code = ErrorCodes.RID_OBJECT_SID_NOT_FOUND_ERROR + + +class RIDManagerDomainNotFoundError(RIDManagerError): + """RID Manager base domain not found error.""" + + code = ErrorCodes.RID_BASE_DOMAIN_NOT_FOUND_ERROR + + +class RIDManagerSystemContainerNotFoundError(RIDManagerError): + """RID Manager system container not found error.""" + + code = ErrorCodes.RID_SYSTEM_CONTAINER_NOT_FOUND_ERROR diff --git a/app/ldap_protocol/rid_manager/gateways.py b/app/ldap_protocol/rid_manager/gateways.py index e8850a541..1ede7df80 100644 --- a/app/ldap_protocol/rid_manager/gateways.py +++ b/app/ldap_protocol/rid_manager/gateways.py @@ -9,9 +9,19 @@ from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession -from config import Settings from entities import Attribute, Directory from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerAvailablePoolNotFoundError, + RIDManagerDomainControllerNotFoundError, + RIDManagerDomainIdentifierNotFoundError, + RIDManagerDomainNotFoundError, + RIDManagerNextRIDNotFoundError, + RIDManagerNotFoundError, + RIDManagerObjectSidNotFoundError, + RIDManagerRidSetNotFoundError, + RIDManagerSystemContainerNotFoundError, +) from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -40,19 +50,20 @@ async def get_rid_available_pool(self, domain: Directory) -> int: :param domain: Domain directory object :return: QWORD value of rIDAvailablePool - :raises ValueError: if attribute not found """ - query = select(Attribute).where( - qa(Attribute.directory_id) == domain.id, - qa(Attribute.name) == "rIDAvailablePool", + query = await self._session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == domain.id, + qa(Attribute.name) == "rIDAvailablePool", + ), ) - attr = await self._session.scalar(query) - - if not attr or not attr.value: - raise ValueError("rIDAvailablePool attribute not found") + if not query or not query.value: + raise RIDManagerAvailablePoolNotFoundError( + "rIDAvailablePool attribute not found", + ) - return int(attr.value) + return int(query.value) async def get_next_rid(self, domain: Directory) -> int: """Get next RID attribute from domain. @@ -64,21 +75,24 @@ async def get_next_rid(self, domain: Directory) -> int: :return: Last issued RID or None if not set """ query = await self._session.scalar( - select(Attribute).where( + select(Attribute) + .where( qa(Attribute.directory_id) == domain.id, qa(Attribute.name) == "rIDNextRID", - ), + ) + .with_for_update(), ) if not query or not query.value: - raise ValueError("next RID attribute not found") + raise RIDManagerNextRIDNotFoundError( + "next RID attribute not found", + ) return int(query.value) async def get_domain_identifier(self, domain: Directory) -> str: """Get domain identifier. :return: Domain identifier - :raises ValueError: if domain identifier not found """ query = await self._session.scalar( select(Attribute).where( @@ -88,7 +102,9 @@ async def get_domain_identifier(self, domain: Directory) -> str: ) if not query or not query.value: - raise ValueError("domain identifier not found") + raise RIDManagerDomainIdentifierNotFoundError( + "domain identifier not found", + ) return query.value @@ -96,14 +112,11 @@ async def get_rid_set(self) -> Directory | None: """Get RID Set directory. :return: RID Set directory - :raises ValueError: if RID Set directory not found """ - rid_set = await self._session.scalar( + return await self._session.scalar( select(Directory).where(qa(Directory.name) == "RID Set"), ) - return rid_set - async def update_next_rid(self, rid_set: Directory, next_rid: int) -> None: """Update next RID attribute in RID Set directory. @@ -123,13 +136,12 @@ async def get_rid_manager(self) -> Directory: """Get RID Manager directory. :return: RID Manager directory - :raises ValueError: if RID Manager directory not found """ rid_manager = await self._session.scalar( select(Directory).where(qa(Directory.name) == "RID Manager$"), ) if not rid_manager: - raise ValueError("RID Manager directory not found") + raise RIDManagerNotFoundError("RID Manager directory not found") return rid_manager @@ -186,20 +198,19 @@ async def get_object_sid( ), ) if not query or not query.value: - raise ValueError("object SID not found") + raise RIDManagerObjectSidNotFoundError("object SID not found") return query.value async def get_base_domain(self) -> Directory: """Get base domain directory. :return: Base domain directory - :raises ValueError: if base domain not found """ base_domain = await self._session.scalar( select(Directory).where(qa(Directory.object_class) == "domain"), ) if not base_domain: - raise ValueError("base domain not found") + raise RIDManagerDomainNotFoundError("base domain not found") return base_domain @@ -210,27 +221,24 @@ def __init__( self, session: AsyncSession, entity_type_dao: EntityTypeDAO, - settings: Settings, ) -> None: """Initialize RID Manager setup gateway.""" self._session = session self._entity_type_dao = entity_type_dao - self._settings = settings - async def get_domain_controller(self) -> Directory: + async def get_domain_controller(self, host_machine_name: str) -> Directory: """Get domain controller directory. :return: Domain controller directory - :raises ValueError: if domain controller not found """ dc = await self._session.scalar( select(Directory).where( - qa(Directory.name) == self._settings.HOST_MACHINE_NAME, + qa(Directory.name) == host_machine_name, ), ) if not dc: - raise ValueError( + raise RIDManagerDomainControllerNotFoundError( "Domain controller not found", ) @@ -240,11 +248,8 @@ async def get_system_container(self) -> Directory: """Get System container directory. :return: System container directory - :raises ValueError: if System container not found """ base_dn_list = await get_base_directories(self._session) - if not base_dn_list: - raise ValueError("Domain not found") domain = base_dn_list[0] @@ -256,7 +261,9 @@ async def get_system_container(self) -> Directory: system_container = await self._session.scalar(query) if not system_container: - raise ValueError("System container not found") + raise RIDManagerSystemContainerNotFoundError( + "System container not found", + ) return system_container @@ -264,11 +271,6 @@ async def set_rid_manager(self) -> Directory: """Create RID Manager directory.""" system_container = await self.get_system_container() - base_dn_list = await get_base_directories(self._session) - if not base_dn_list: - raise ValueError("Domain not found") - base_dn_list[0] - rid_manager_dir = Directory( is_system=True, name="RID Manager$", @@ -331,11 +333,6 @@ async def create_rid_set( :param domain_controller: Domain Controller directory object :return: Created RID Set directory """ - base_dn_list = await get_base_directories(self._session) - if not base_dn_list: - raise ValueError("Domain not found") - base_dn_list[0] - rid_set_dir = Directory( is_system=True, name="RID Set", @@ -460,7 +457,7 @@ async def create_domain_identifier(self) -> None: ), ) if not domain: - raise ValueError("Domain not found") + raise RIDManagerDomainNotFoundError("Domain not found") self._session.add( Attribute( @@ -479,7 +476,7 @@ async def get_domain_identifier(self) -> str: ), ) if not domain or not domain.value: - raise ValueError("Domain not found") + raise RIDManagerDomainIdentifierNotFoundError("Domain not found") return domain.value async def get_rid_set(self, domain_controller: Directory) -> Directory: @@ -487,7 +484,6 @@ async def get_rid_set(self, domain_controller: Directory) -> Directory: :param domain_controller: Domain controller directory :return: RID Set directory - :raises ValueError: if RID Set directory not found """ rid_set = await self._session.scalar( select(Directory).where( @@ -496,5 +492,5 @@ async def get_rid_set(self, domain_controller: Directory) -> Directory: ), ) if not rid_set: - raise ValueError("RID Set directory not found") + raise RIDManagerRidSetNotFoundError("RID Set directory not found") return rid_set diff --git a/app/ldap_protocol/rid_manager/use_cases.py b/app/ldap_protocol/rid_manager/use_cases.py index d8721552b..31fddb4ec 100644 --- a/app/ldap_protocol/rid_manager/use_cases.py +++ b/app/ldap_protocol/rid_manager/use_cases.py @@ -9,8 +9,10 @@ from sqlalchemy.ext.asyncio import AsyncSession +from config import Settings from entities import Directory from enums import SidPrefix +from ldap_protocol.rid_manager.exceptions import RIDManagerRidSetNotFoundError from ldap_protocol.rid_manager.gateways import ( RIDManagerGateway, RIDManagerSetupGateway, @@ -60,7 +62,9 @@ async def set_object_sid( if rid is None: rid_set = await self._gateway.get_rid_set() if not rid_set: - raise ValueError("RID Set directory not found") + raise RIDManagerRidSetNotFoundError( + "RID Set directory not found", + ) next_rid = await self._gateway.get_next_rid(rid_set) rid = next_rid + 1 @@ -106,6 +110,7 @@ def __init__( rid_manager_setup_gateway: RIDManagerSetupGateway, role_use_case: RoleUseCase, access_control_entry_dao: AccessControlEntryDAO, + settings: Settings, ) -> None: """Initialize RID Manager setup use case. @@ -115,6 +120,7 @@ def __init__( self._gateway = rid_manager_setup_gateway self._role_use_case = role_use_case self._access_control_entry_dao = access_control_entry_dao + self._settings = settings async def setup(self) -> None: """Create RID Manager.""" @@ -126,7 +132,9 @@ async def setup(self) -> None: rid_manager_dir, qword, ) - domain_controller = await self._gateway.get_domain_controller() + domain_controller = await self._gateway.get_domain_controller( + self._settings.HOST_MACHINE_NAME, + ) rid_set_dir = await self._gateway.create_rid_set( domain_controller, @@ -156,7 +164,9 @@ async def inherit_aces( directory=rid_manager_dir, ) - domain_controller = await self._gateway.get_domain_controller() + domain_controller = await self._gateway.get_domain_controller( + self._settings.HOST_MACHINE_NAME, + ) await self._role_use_case.inherit_parent_aces( parent_directory=domain_controller, directory=await self._gateway.get_rid_set(domain_controller), diff --git a/tests/conftest.py b/tests/conftest.py index a76280f0a..4e379b809 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1126,7 +1126,6 @@ async def setup_session( rid_manager_setup_gateway = RIDManagerSetupGateway( session=session, entity_type_dao=entity_type_dao, - settings=settings, ) role_dao = RoleDAO(session) ace_dao = AccessControlEntryDAO(session) @@ -1135,6 +1134,7 @@ async def setup_session( rid_manager_setup_gateway=rid_manager_setup_gateway, role_use_case=role_use_case, access_control_entry_dao=AccessControlEntryDAO(session), + settings=settings, ) setup_gateway = SetupGateway( session, From de2c4b2fcdd6c157558f238870cb4c7e89b05fa4 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 6 Mar 2026 10:32:55 +0300 Subject: [PATCH 12/37] Refactor: Introduce ObjectSIDUseCase and related gateways, enhancing RID management functionality --- .../552b4eafb1aa_remove_objectsid_vals.py | 187 +++++-- app/extra/scripts/add_domain_controller.py | 10 +- app/ioc.py | 8 + app/ldap_protocol/auth/setup_gateway.py | 8 +- app/ldap_protocol/auth/use_cases.py | 2 +- app/ldap_protocol/ldap_requests/add.py | 2 +- app/ldap_protocol/ldap_requests/contexts.py | 6 +- app/ldap_protocol/rid_manager/__init__.py | 16 +- app/ldap_protocol/rid_manager/dtos.py | 16 + app/ldap_protocol/rid_manager/exceptions.py | 23 +- app/ldap_protocol/rid_manager/gateways.py | 496 ------------------ .../rid_manager/object_sid_gateway.py | 60 +++ .../rid_manager/object_sid_use_case.py | 63 +++ .../rid_manager/rid_manager_gateway.py | 69 +++ .../rid_manager/rid_manager_use_case.py | 48 ++ .../rid_manager/rid_set_gateway.py | 204 +++++++ .../rid_manager/rid_set_use_case.py | 107 ++++ .../rid_manager/setup_gateway.py | 184 +++++++ .../rid_manager/setup_use_case.py | 110 ++++ app/ldap_protocol/rid_manager/use_cases.py | 177 ------- app/ldap_protocol/rid_manager/utils.py | 12 +- app/ldap_protocol/rootdse/reader.py | 8 +- tests/conftest.py | 35 +- .../test_main/test_router/conftest.py | 6 +- tests/test_ldap/test_rid_manager.py | 108 ++-- tests/test_shedule.py | 6 +- 26 files changed, 1167 insertions(+), 804 deletions(-) create mode 100644 app/ldap_protocol/rid_manager/dtos.py delete mode 100644 app/ldap_protocol/rid_manager/gateways.py create mode 100644 app/ldap_protocol/rid_manager/object_sid_gateway.py create mode 100644 app/ldap_protocol/rid_manager/object_sid_use_case.py create mode 100644 app/ldap_protocol/rid_manager/rid_manager_gateway.py create mode 100644 app/ldap_protocol/rid_manager/rid_manager_use_case.py create mode 100644 app/ldap_protocol/rid_manager/rid_set_gateway.py create mode 100644 app/ldap_protocol/rid_manager/rid_set_use_case.py create mode 100644 app/ldap_protocol/rid_manager/setup_gateway.py create mode 100644 app/ldap_protocol/rid_manager/setup_use_case.py delete mode 100644 app/ldap_protocol/rid_manager/use_cases.py diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index e059ad3c1..176d9dcce 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -6,6 +6,8 @@ """ +import secrets + import sqlalchemy as sa from alembic import op from dishka import AsyncContainer, Scope @@ -16,19 +18,26 @@ from enums import EntityTypeNames from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase -from ldap_protocol.rid_manager.exceptions import RIDManagerNotFoundError -from ldap_protocol.rid_manager.gateways import RIDManagerGateway -from ldap_protocol.rid_manager.use_cases import ( - RID_AVAILABLE_MAX, +from ldap_protocol.rid_manager import ( + RIDManagerGateway, + RIDManagerSetupGateway, RIDManagerSetupUseCase, + RIDManagerUseCase, + RIDSetUseCase, +) +from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerNotFoundError, + RIDManagerRidSetNotFoundError, ) -from ldap_protocol.rid_manager.utils import create_qword +from ldap_protocol.rid_manager.rid_set_gateway import RIDSetGateway +from ldap_protocol.rid_manager.utils import from_qword, to_qword from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa # revision identifiers, used by Alembic. revision: None | str = "552b4eafb1aa" -down_revision: None | str = "19d86e660cf2" +down_revision: None | str = "2dadf40c026a" branch_labels: None | list[str] = None depends_on: None | list[str] = None @@ -78,8 +87,8 @@ async def _migrate_object_sids( ) -> None: """Move Directory.objectSid values into Attributes table. - Additionally, for domain directories move the domain SID prefix part - into the ``DomainIdentifier`` attribute. + Additionally, for domain directories create the ``DomainIdentifier`` + attribute if it does not exist. """ async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) @@ -106,16 +115,46 @@ async def _migrate_object_sids( ), ) - if directory.name == "domain": - identifier = directory.object_sid.split("-")[ - -1 - ] # remove sid prefix + base_dn_list = await get_base_directories(session) + if base_dn_list: + domain = base_dn_list[0] + + existing_identifier = await session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == domain.id, + qa(Attribute.name) == "DomainIdentifier", + ), + ) + + if not (existing_identifier and existing_identifier.value): + domain_object_sid = await session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == domain.id, + qa(Attribute.name) == "objectSid", + ), + ) + + identifier: str | None = None + if domain_object_sid and domain_object_sid.value: + parts = domain_object_sid.value.split("-") + # "S-1-5-21-AAA-BBB-CCC" -> "AAA-BBB-CCC" + if len(parts) >= 7 and domain_object_sid.value.startswith( + "S-1-5-21-", + ): + identifier = "-".join(parts[4:7]) + + if identifier is None: + identifier = ( + f"{secrets.randbits(32)}-" + f"{secrets.randbits(32)}-" + f"{secrets.randbits(32)}" + ) session.add( Attribute( name="DomainIdentifier", value=identifier, - directory_id=directory.id, + directory_id=domain.id, ), ) @@ -129,27 +168,35 @@ async def _init_rid_manager( """Initialize RID Manager and RID Set for existing data.""" async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) - rid_setup_use_case = await cnt.get(RIDManagerSetupUseCase) - rid_gateway = await cnt.get(RIDManagerGateway) + rid_setup_gateway = await cnt.get(RIDManagerSetupGateway) + rid_gateway = await cnt.get(RIDManagerGateway) + rid_manager_use_case = await cnt.get(RIDManagerUseCase) + rid_set_gateway = await cnt.get(RIDSetGateway) + rid_set_use_case = await cnt.get(RIDSetUseCase) if not await get_base_directories(session): return try: - await rid_gateway.get_rid_manager() + rid_manager_dir = await rid_gateway.get_rid_manager() except RIDManagerNotFoundError: - await rid_setup_use_case.setup() - await rid_gateway.get_rid_manager() + rid_manager_dir = await rid_setup_gateway.set_rid_manager() - rid_set_dir = await rid_gateway.get_rid_set() - if not rid_set_dir: + base_dn_list = await get_base_directories(session) + if not base_dn_list: return + domain = base_dn_list[0] - base_domain = await rid_gateway.get_base_domain() - domain_identifier = await rid_gateway.get_domain_identifier( - base_domain, + domain_identifier = await session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == domain.id, + qa(Attribute.name) == "DomainIdentifier", + ), ) - sid_prefix = f"S-1-5-21-{domain_identifier}-" + if not (domain_identifier and domain_identifier.value): + return + + sid_prefix = f"S-1-5-21-{domain_identifier.value}-" sid_values = await session.scalars( select(Attribute).where( @@ -172,25 +219,89 @@ async def _init_rid_manager( start_rid = max(max_rid, RIDManagerSetupUseCase.RID_USER_MIN) - qword = create_qword(start_rid, RID_AVAILABLE_MAX) - await rid_gateway.update_available_pool(qword) + qword = to_qword(start_rid, RIDManagerSetupUseCase.RID_AVAILABLE_MAX) + await rid_setup_gateway.set_rid_available_pool(rid_manager_dir, qword) - result = await session.execute( - update(Attribute) - .where( + domain_controller = await rid_gateway.get_domain_controller() + rid_set_dir: Directory | None = None + try: + rid_set_dir = await rid_set_gateway.get(domain_controller) + except RIDManagerRidSetNotFoundError: + rid_set_dir = None + + if rid_set_dir is None: + previous_allocation_pool = ( + await rid_manager_use_case.allocate_pool() + ) + allocation_pool = await rid_manager_use_case.allocate_pool() + lower, _ = from_qword(previous_allocation_pool) + + await rid_set_use_case.add( + domain_controller, + RIDSetAllocationParamsDTO( + next_rid=lower, + allocation_pool=allocation_pool, + previous_allocation_pool=previous_allocation_pool, + ), + ) + await session.commit() + return + + existing_next_rid = await session.scalar( + select(Attribute).where( qa(Attribute.directory_id) == rid_set_dir.id, qa(Attribute.name) == "rIDNextRID", - ) - .values(value=str(start_rid)), + ), ) - if result.rowcount == 0: - session.add( - Attribute( - directory_id=rid_set_dir.id, - name="rIDNextRID", - value=str(start_rid), - ), + existing_prev_pool = await session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == rid_set_dir.id, + qa(Attribute.name) == "rIDPreviousAllocationPool", + ), + ) + existing_pool = await session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == rid_set_dir.id, + qa(Attribute.name) == "rIDAllocationPool", + ), + ) + + if ( + existing_next_rid + and existing_next_rid.value + and existing_prev_pool + and existing_prev_pool.value + and existing_pool + and existing_pool.value + ): + await session.commit() + return + + previous_allocation_pool = await rid_manager_use_case.allocate_pool() + allocation_pool = await rid_manager_use_case.allocate_pool() + lower, _ = from_qword(previous_allocation_pool) + + for name, value in ( + ("rIDNextRID", str(lower)), + ("rIDPreviousAllocationPool", str(previous_allocation_pool)), + ("rIDAllocationPool", str(allocation_pool)), + ): + result = await session.execute( + update(Attribute) + .where( + qa(Attribute.directory_id) == rid_set_dir.id, + qa(Attribute.name) == name, + ) + .values(value=value), ) + if result.rowcount == 0: + session.add( + Attribute( + directory_id=rid_set_dir.id, + name=name, + value=value, + ), + ) await session.commit() diff --git a/app/extra/scripts/add_domain_controller.py b/app/extra/scripts/add_domain_controller.py index 36eb9a00b..1b771d1fc 100644 --- a/app/extra/scripts/add_domain_controller.py +++ b/app/extra/scripts/add_domain_controller.py @@ -16,7 +16,7 @@ EntityTypeUseCase, ) from ldap_protocol.objects import UserAccountControlFlag -from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase +from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.roles.role_use_case import RoleUseCase from repo.pg.tables import queryable_attr as qa @@ -27,7 +27,7 @@ async def _add_domain_controller( entity_type_use_case: EntityTypeUseCase, settings: Settings, dc_ou_dir: Directory, - rid_manager_use_case: RIDManagerUseCase, + object_sid_use_case: ObjectSIDUseCase, ) -> None: dc_directory = Directory( object_class="", @@ -39,7 +39,7 @@ async def _add_domain_controller( await session.flush() dc_directory.parent_id = dc_ou_dir.id - await rid_manager_use_case.set_object_sid( + await object_sid_use_case.add( directory=dc_directory, rid=SecurityPrincipalRid.DOMAIN_CONTROLLERS, ) @@ -105,7 +105,7 @@ async def add_domain_controller( settings: Settings, role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, - rid_manager_use_case: RIDManagerUseCase, + object_sid_use_case: ObjectSIDUseCase, ) -> None: logger.info("Adding domain controller.") @@ -139,7 +139,7 @@ async def add_domain_controller( entity_type_use_case=entity_type_use_case, settings=settings, dc_ou_dir=domain_controllers_ou, - rid_manager_use_case=rid_manager_use_case, + object_sid_use_case=object_sid_use_case, ) logger.debug("Domain controller added.") diff --git a/app/ioc.py b/app/ioc.py index 06cc7c617..aec45105d 100644 --- a/app/ioc.py +++ b/app/ioc.py @@ -179,10 +179,14 @@ UserPasswordHistoryUseCases, ) from ldap_protocol.rid_manager import ( + ObjectSIDGateway, + ObjectSIDUseCase, RIDManagerGateway, RIDManagerSetupGateway, RIDManagerSetupUseCase, RIDManagerUseCase, + RIDSetGateway, + RIDSetUseCase, ) from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -657,6 +661,10 @@ def get_object_class_use_case_legacy( RIDManagerSetupUseCase, scope=Scope.REQUEST, ) + object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST) + object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST) + rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST) + rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST) class LDAPContextProvider(Provider): diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index 96b293dde..9152920f3 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -20,7 +20,7 @@ from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) -from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase +from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.utils.async_cache import base_directories_cache from ldap_protocol.utils.queries import get_domain_object_class from password_utils import PasswordUtils @@ -37,7 +37,7 @@ def __init__( entity_type_use_case: EntityTypeUseCase, attribute_value_validator: AttributeValueValidator, directory_dao: DirectoryDAO, - rid_manager_use_case: RIDManagerUseCase, + object_sid_use_case: ObjectSIDUseCase, ) -> None: """Initialize Setup use case. @@ -50,7 +50,7 @@ def __init__( self._entity_type_use_case = entity_type_use_case self._attribute_value_validator = attribute_value_validator self._directory_dao = directory_dao - self._rid_manager_use_case = rid_manager_use_case + self._object_sid_use_case = object_sid_use_case async def is_setup(self) -> bool: """Check if setup is performed. @@ -174,7 +174,7 @@ async def create_dir( ) if "objectSid" in data: - await self._rid_manager_use_case.set_object_sid( + await self._object_sid_use_case.add( directory=dir_, rid=int(data["objectSid"]), sid_prefix=SidPrefix.BUILT_IN_DOMAIN, diff --git a/app/ldap_protocol/auth/use_cases.py b/app/ldap_protocol/auth/use_cases.py index b323d08d8..4fb3d666c 100644 --- a/app/ldap_protocol/auth/use_cases.py +++ b/app/ldap_protocol/auth/use_cases.py @@ -44,7 +44,7 @@ from ldap_protocol.objects import UserAccountControlFlag from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases -from ldap_protocol.rid_manager.use_cases import RIDManagerSetupUseCase +from ldap_protocol.rid_manager import RIDManagerSetupUseCase from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.helpers import create_integer_hash, ft_now diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index 20643ac2f..de5e4106a 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -213,7 +213,7 @@ async def handle( # noqa: C901 ctx.session.add(new_dir) await ctx.session.flush() - await ctx.rid_manager_use_case.set_object_sid( + await ctx.object_sid_use_case.add( directory=new_dir, ) await ctx.session.flush() diff --git a/app/ldap_protocol/ldap_requests/contexts.py b/app/ldap_protocol/ldap_requests/contexts.py index d53d73b51..f81b3f113 100644 --- a/app/ldap_protocol/ldap_requests/contexts.py +++ b/app/ldap_protocol/ldap_requests/contexts.py @@ -27,7 +27,7 @@ from ldap_protocol.multifactor import LDAPMultiFactorAPI from ldap_protocol.policies.network import NetworkPolicyValidatorUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases -from ldap_protocol.rid_manager import RIDManagerUseCase +from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.rootdse.reader import RootDSEReader @@ -48,7 +48,7 @@ class LDAPAddRequestContext: access_manager: AccessManager role_use_case: RoleUseCase attribute_value_validator: AttributeValueValidator - rid_manager_use_case: RIDManagerUseCase + object_sid_use_case: ObjectSIDUseCase @dataclass @@ -65,7 +65,7 @@ class LDAPModifyRequestContext: password_use_cases: PasswordPolicyUseCases password_utils: PasswordUtils attribute_value_validator: AttributeValueValidator - rid_manager_use_case: RIDManagerUseCase + object_sid_use_case: ObjectSIDUseCase @dataclass diff --git a/app/ldap_protocol/rid_manager/__init__.py b/app/ldap_protocol/rid_manager/__init__.py index a32cedc94..204bbef53 100644 --- a/app/ldap_protocol/rid_manager/__init__.py +++ b/app/ldap_protocol/rid_manager/__init__.py @@ -4,12 +4,22 @@ License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE """ -from .gateways import RIDManagerGateway, RIDManagerSetupGateway -from .use_cases import RIDManagerSetupUseCase, RIDManagerUseCase +from .object_sid_gateway import ObjectSIDGateway +from .object_sid_use_case import ObjectSIDUseCase +from .rid_manager_gateway import RIDManagerGateway +from .rid_manager_use_case import RIDManagerUseCase +from .rid_set_gateway import RIDSetGateway +from .rid_set_use_case import RIDSetUseCase +from .setup_gateway import RIDManagerSetupGateway +from .setup_use_case import RIDManagerSetupUseCase __all__ = [ + "ObjectSIDGateway", + "ObjectSIDUseCase", "RIDManagerGateway", "RIDManagerSetupGateway", - "RIDManagerUseCase", "RIDManagerSetupUseCase", + "RIDManagerUseCase", + "RIDSetGateway", + "RIDSetUseCase", ] diff --git a/app/ldap_protocol/rid_manager/dtos.py b/app/ldap_protocol/rid_manager/dtos.py new file mode 100644 index 000000000..12e324cd0 --- /dev/null +++ b/app/ldap_protocol/rid_manager/dtos.py @@ -0,0 +1,16 @@ +"""RID Manager DTOs. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from dataclasses import dataclass + + +@dataclass +class RIDSetAllocationParamsDTO: + """RID Set DTO.""" + + next_rid: int + previous_allocation_pool: int + allocation_pool: int diff --git a/app/ldap_protocol/rid_manager/exceptions.py b/app/ldap_protocol/rid_manager/exceptions.py index cefa0c3e7..9964f5f77 100644 --- a/app/ldap_protocol/rid_manager/exceptions.py +++ b/app/ldap_protocol/rid_manager/exceptions.py @@ -23,6 +23,9 @@ class ErrorCodes(IntEnum): RID_OBJECT_SID_NOT_FOUND_ERROR = 12 RID_BASE_DOMAIN_NOT_FOUND_ERROR = 13 RID_SYSTEM_CONTAINER_NOT_FOUND_ERROR = 14 + RID_ALLOCATION_POOL_NOT_FOUND_ERROR = 15 + RID_PREVIOUS_ALLOCATION_POOL_NOT_FOUND_ERROR = 16 + RID_POOL_EXCEEDED_ERROR = 17 class RIDManagerError(BaseDomainException): @@ -49,7 +52,7 @@ class RIDManagerAvailablePoolNotFoundError(RIDManagerError): code = ErrorCodes.RID_AVAILABLE_POOL_NOT_FOUND_ERROR -class RIDManagerNextRIDNotFoundError(RIDManagerError): +class RIDManagerRidNextRIDNotFoundError(RIDManagerError): """RID Manager next RID not found error.""" code = ErrorCodes.RID_NEXT_RID_NOT_FOUND_ERROR @@ -101,3 +104,21 @@ class RIDManagerSystemContainerNotFoundError(RIDManagerError): """RID Manager system container not found error.""" code = ErrorCodes.RID_SYSTEM_CONTAINER_NOT_FOUND_ERROR + + +class RIDManagerRidAllocationPoolNotFoundError(RIDManagerError): + """RID Manager RID allocation pool not found error.""" + + code = ErrorCodes.RID_ALLOCATION_POOL_NOT_FOUND_ERROR + + +class RIDManagerRidPreviousAllocationPoolNotFoundError(RIDManagerError): + """RID Manager RID previous allocation pool not found error.""" + + code = ErrorCodes.RID_PREVIOUS_ALLOCATION_POOL_NOT_FOUND_ERROR + + +class RIDManagerPoolExceededError(RIDManagerError): + """RID Manager pool exceeded error.""" + + code = ErrorCodes.RID_POOL_EXCEEDED_ERROR diff --git a/app/ldap_protocol/rid_manager/gateways.py b/app/ldap_protocol/rid_manager/gateways.py deleted file mode 100644 index 1ede7df80..000000000 --- a/app/ldap_protocol/rid_manager/gateways.py +++ /dev/null @@ -1,496 +0,0 @@ -"""RID Manager Gateway. - -Copyright (c) 2025 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE -""" - -import secrets - -from sqlalchemy import select, update -from sqlalchemy.ext.asyncio import AsyncSession - -from entities import Attribute, Directory -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO -from ldap_protocol.rid_manager.exceptions import ( - RIDManagerAvailablePoolNotFoundError, - RIDManagerDomainControllerNotFoundError, - RIDManagerDomainIdentifierNotFoundError, - RIDManagerDomainNotFoundError, - RIDManagerNextRIDNotFoundError, - RIDManagerNotFoundError, - RIDManagerObjectSidNotFoundError, - RIDManagerRidSetNotFoundError, - RIDManagerSystemContainerNotFoundError, -) -from ldap_protocol.utils.queries import get_base_directories -from repo.pg.tables import queryable_attr as qa - - -class RIDManagerGateway: - """Gateway for RID Manager database operations. - - Handles all database operations for RID Manager: - - Reading/writing rIDAvailablePool (global pool in CN=RID Manager$) - - Reading/writing rIDNextRID (local counter, non-replicated) - """ - - def __init__(self, session: AsyncSession) -> None: - """Initialize RID Manager Gateway. - - :param session: SQLAlchemy async session - """ - self._session = session - - async def get_rid_available_pool(self, domain: Directory) -> int: - """Get rIDAvailablePool attribute from domain. - - This is a QWORD (64-bit) value where: - - Lower 32 bits: next available RID - - Upper 32 bits: maximum RID in pool - - :param domain: Domain directory object - :return: QWORD value of rIDAvailablePool - """ - query = await self._session.scalar( - select(Attribute).where( - qa(Attribute.directory_id) == domain.id, - qa(Attribute.name) == "rIDAvailablePool", - ), - ) - - if not query or not query.value: - raise RIDManagerAvailablePoolNotFoundError( - "rIDAvailablePool attribute not found", - ) - - return int(query.value) - - async def get_next_rid(self, domain: Directory) -> int: - """Get next RID attribute from domain. - - This is the last issued RID (not the next one, despite the name). - This attribute is NOT replicated. - - :param domain: Domain directory object - :return: Last issued RID or None if not set - """ - query = await self._session.scalar( - select(Attribute) - .where( - qa(Attribute.directory_id) == domain.id, - qa(Attribute.name) == "rIDNextRID", - ) - .with_for_update(), - ) - - if not query or not query.value: - raise RIDManagerNextRIDNotFoundError( - "next RID attribute not found", - ) - return int(query.value) - - async def get_domain_identifier(self, domain: Directory) -> str: - """Get domain identifier. - - :return: Domain identifier - """ - query = await self._session.scalar( - select(Attribute).where( - qa(Attribute.name) == "DomainIdentifier", - qa(Attribute.directory_id) == domain.id, - ), - ) - - if not query or not query.value: - raise RIDManagerDomainIdentifierNotFoundError( - "domain identifier not found", - ) - - return query.value - - async def get_rid_set(self) -> Directory | None: - """Get RID Set directory. - - :return: RID Set directory - """ - return await self._session.scalar( - select(Directory).where(qa(Directory.name) == "RID Set"), - ) - - async def update_next_rid(self, rid_set: Directory, next_rid: int) -> None: - """Update next RID attribute in RID Set directory. - - :param rid_set: RID Set directory - :param next_rid: Next RID - """ - await self._session.execute( - update(Attribute) - .where( - qa(Attribute.directory_id) == rid_set.id, - qa(Attribute.name) == "rIDNextRID", - ) - .values(value=str(next_rid)), - ) - - async def get_rid_manager(self) -> Directory: - """Get RID Manager directory. - - :return: RID Manager directory - """ - rid_manager = await self._session.scalar( - select(Directory).where(qa(Directory.name) == "RID Manager$"), - ) - if not rid_manager: - raise RIDManagerNotFoundError("RID Manager directory not found") - - return rid_manager - - async def update_available_pool( - self, - qword_value: int, - ) -> None: - """Update available pool attribute in RID Manager directory. - - :param rid_manager: RID Manager directory - :param qword_value: QWORD value - """ - rid_manager = await self.get_rid_manager() - await self._session.execute( - update(Attribute) - .where( - qa(Attribute.directory_id) == rid_manager.id, - qa(Attribute.name) == "rIDAvailablePool", - ) - .values(value=str(qword_value)), - ) - - async def add_object_sid( - self, - directory: Directory, - object_sid: str, - ) -> None: - """Add object SID to directory. - - :param directory: Directory - :param object_sid: Object SID - """ - self._session.add( - Attribute( - name="objectSid", - value=object_sid, - directory_id=directory.id, - ), - ) - - async def get_object_sid( - self, - rid_set: Directory, - ) -> str: - """Get object SID from directory. - - :param rid_set: RID Set directory - :return: Object SID - """ - query = await self._session.scalar( - select(Attribute).where( - qa(Attribute.directory_id) == rid_set.id, - qa(Attribute.name) == "objectSid", - ), - ) - if not query or not query.value: - raise RIDManagerObjectSidNotFoundError("object SID not found") - return query.value - - async def get_base_domain(self) -> Directory: - """Get base domain directory. - - :return: Base domain directory - """ - base_domain = await self._session.scalar( - select(Directory).where(qa(Directory.object_class) == "domain"), - ) - if not base_domain: - raise RIDManagerDomainNotFoundError("base domain not found") - return base_domain - - -class RIDManagerSetupGateway: - """Gateway for RID Manager setup database operations.""" - - def __init__( - self, - session: AsyncSession, - entity_type_dao: EntityTypeDAO, - ) -> None: - """Initialize RID Manager setup gateway.""" - self._session = session - self._entity_type_dao = entity_type_dao - - async def get_domain_controller(self, host_machine_name: str) -> Directory: - """Get domain controller directory. - - :return: Domain controller directory - """ - dc = await self._session.scalar( - select(Directory).where( - qa(Directory.name) == host_machine_name, - ), - ) - - if not dc: - raise RIDManagerDomainControllerNotFoundError( - "Domain controller not found", - ) - - return dc - - async def get_system_container(self) -> Directory: - """Get System container directory. - - :return: System container directory - """ - base_dn_list = await get_base_directories(self._session) - - domain = base_dn_list[0] - - query = select(Directory).where( - qa(Directory.name) == "System", - qa(Directory.parent_id) == domain.id, - ) - - system_container = await self._session.scalar(query) - - if not system_container: - raise RIDManagerSystemContainerNotFoundError( - "System container not found", - ) - - return system_container - - async def set_rid_manager(self) -> Directory: - """Create RID Manager directory.""" - system_container = await self.get_system_container() - - rid_manager_dir = Directory( - is_system=True, - name="RID Manager$", - ) - rid_manager_dir.create_path(system_container, "cn") - - self._session.add(rid_manager_dir) - await self._session.flush() - - rid_manager_dir.parent_id = system_container.id - await self._session.refresh(rid_manager_dir, ["id"]) - - self._session.add( - Attribute( - name="cn", - value="RID Manager$", - directory_id=rid_manager_dir.id, - ), - ) - - self._session.add( - Attribute( - name="objectClass", - value="top", - directory_id=rid_manager_dir.id, - ), - ) - - self._session.add( - Attribute( - name="objectClass", - value="rIDManager", - directory_id=rid_manager_dir.id, - ), - ) - - await self._session.flush() - - await self._session.refresh( - instance=rid_manager_dir, - attribute_names=["attributes"], - with_for_update=None, - ) - - await self._entity_type_dao.attach_entity_type_to_directory( - directory=rid_manager_dir, - is_system_entity_type=True, - ) - - await self._session.flush() - - return rid_manager_dir - - async def create_rid_set( - self, - domain_controller: Directory, - ) -> Directory: - """Create CN=RID Set directory under Domain Controller. - - :param domain_controller: Domain Controller directory object - :return: Created RID Set directory - """ - rid_set_dir = Directory( - is_system=True, - name="RID Set", - ) - rid_set_dir.create_path(domain_controller, "cn") - - self._session.add(rid_set_dir) - await self._session.flush() - - rid_set_dir.parent_id = domain_controller.id - await self._session.refresh(rid_set_dir, ["id"]) - - self._session.add( - Attribute( - name="cn", - value="RID Set", - directory_id=rid_set_dir.id, - ), - ) - - self._session.add( - Attribute( - name="objectClass", - value="top", - directory_id=rid_set_dir.id, - ), - ) - - self._session.add( - Attribute( - name="objectClass", - value="rIDSet", - directory_id=rid_set_dir.id, - ), - ) - - await self._session.flush() - - await self._session.refresh( - instance=rid_set_dir, - attribute_names=["attributes"], - with_for_update=None, - ) - - await self._entity_type_dao.attach_entity_type_to_directory( - directory=rid_set_dir, - is_system_entity_type=True, - ) - - await self._session.flush() - - return rid_set_dir - - async def set_rid_available_pool( - self, - domain: Directory, - qword_value: int, - ) -> None: - """Set rIDAvailablePool attribute in domain. - - Updates the global RID pool counter. - - :param domain: Domain directory object - :param qword_value: New QWORD value (64-bit) - """ - query = ( - update(Attribute) - .where( - qa(Attribute.directory_id) == domain.id, - qa(Attribute.name) == "rIDAvailablePool", - ) - .values(value=str(qword_value)) - ) - - result = await self._session.execute(query) - - if result.rowcount == 0: - self._session.add( - Attribute( - directory_id=domain.id, - name="rIDAvailablePool", - value=str(qword_value), - ), - ) - - await self._session.flush() - - async def set_next_rid( - self, - domain: Directory, - rid: int, - ) -> None: - """Set next RID attribute in domain. - - Updates the last issued RID counter. - - :param domain: Domain directory object - :param rid: Last issued RID value - """ - self._session.add( - Attribute( - directory_id=domain.id, - name="rIDNextRID", - value=str(rid), - ), - ) - - await self._session.flush() - - def _generate_domain_sid_identifier(self) -> str: - """Generate Domain Identifier for Active Directory domain.""" - return ( - f"{secrets.randbits(32)}" - f"-{secrets.randbits(32)}-{secrets.randbits(32)}" - ) - - async def create_domain_identifier(self) -> None: - """Add domain identifier to domain.""" - domain = await self._session.scalar( - select(Directory).where( - qa(Directory.object_class) == "domain", - ), - ) - if not domain: - raise RIDManagerDomainNotFoundError("Domain not found") - - self._session.add( - Attribute( - name="DomainIdentifier", - value=f"{self._generate_domain_sid_identifier()}", - directory_id=domain.id, - ), - ) - await self._session.flush() - - async def get_domain_identifier(self) -> str: - """Get domain identifier.""" - domain = await self._session.scalar( - select(Attribute).where( - qa(Attribute.name) == "DomainIdentifier", - ), - ) - if not domain or not domain.value: - raise RIDManagerDomainIdentifierNotFoundError("Domain not found") - return domain.value - - async def get_rid_set(self, domain_controller: Directory) -> Directory: - """Get RID Set directory. - - :param domain_controller: Domain controller directory - :return: RID Set directory - """ - rid_set = await self._session.scalar( - select(Directory).where( - qa(Directory.name) == "RID Set", - qa(Directory.parent_id) == domain_controller.id, - ), - ) - if not rid_set: - raise RIDManagerRidSetNotFoundError("RID Set directory not found") - return rid_set diff --git a/app/ldap_protocol/rid_manager/object_sid_gateway.py b/app/ldap_protocol/rid_manager/object_sid_gateway.py new file mode 100644 index 000000000..3f7d25683 --- /dev/null +++ b/app/ldap_protocol/rid_manager/object_sid_gateway.py @@ -0,0 +1,60 @@ +"""Object SID gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Attribute, Directory +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerDomainIdentifierNotFoundError, +) +from repo.pg.tables import queryable_attr as qa + + +class ObjectSIDGateway: + """Object SID gateway.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize Object SID gateway.""" + self._session = session + + async def get(self, directory: Directory) -> str: + """Get object SID.""" + return await self._session.scalar( + select(Attribute).where( + qa(Attribute.directory_id) == directory.id, + qa(Attribute.name) == "objectSid", + ), + ) + + async def add(self, directory: Directory, object_sid: str) -> None: + """Add object SID.""" + self._session.add( + Attribute( + name="objectSid", + value=object_sid, + directory_id=directory.id, + ), + ) + + async def get_domain_identifier(self, domain: Directory) -> str: + """Get domain identifier. + + :return: Domain identifier + """ + query = await self._session.scalar( + select(Attribute).where( + qa(Attribute.name) == "DomainIdentifier", + qa(Attribute.directory_id) == domain.id, + ), + ) + + if not query or not query.value: + raise RIDManagerDomainIdentifierNotFoundError( + "domain identifier not found", + ) + + return query.value diff --git a/app/ldap_protocol/rid_manager/object_sid_use_case.py b/app/ldap_protocol/rid_manager/object_sid_use_case.py new file mode 100644 index 000000000..9ae878ace --- /dev/null +++ b/app/ldap_protocol/rid_manager/object_sid_use_case.py @@ -0,0 +1,63 @@ +"""Object SID use case. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Directory +from enums import SidPrefix +from ldap_protocol.rid_manager.object_sid_gateway import ObjectSIDGateway +from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase +from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase +from ldap_protocol.utils.queries import get_base_directories + + +class ObjectSIDUseCase: + """Object SID use case.""" + + def __init__( + self, + gateway: ObjectSIDGateway, + rid_set_use_case: RIDSetUseCase, + session: AsyncSession, + rid_manager_use_case: RIDManagerUseCase, + ) -> None: + """Initialize Object SID use case.""" + self._gateway = gateway + self._rid_set_use_case = rid_set_use_case + self._session = session + self._rid_manager_use_case = rid_manager_use_case + + async def get(self, directory: Directory) -> str: + """Get object SID.""" + return await self._gateway.get(directory) + + async def add( + self, + directory: Directory, + rid: int | None = None, + sid_prefix: SidPrefix = SidPrefix.DOMAIN_IDENTIFIER, + ) -> None: + """Add object SID.""" + if rid is None: + domain_controller = await self._rid_manager_use_case.choose_nearest_domain_controller() # noqa + rid_set = await self._rid_set_use_case.get(domain_controller) + rid = await self._rid_set_use_case.allocate_next_rid( + rid_set, + ) + + if sid_prefix == SidPrefix.BUILT_IN_DOMAIN: + object_sid = f"{sid_prefix}-{rid}" + elif sid_prefix == SidPrefix.DOMAIN_IDENTIFIER: + domain_identifier = await self.get_domain_identifier() + object_sid = f"{sid_prefix}-{domain_identifier}-{rid}" + + await self._gateway.add(directory, object_sid) + + async def get_domain_identifier(self) -> str: + """Get domain identifier.""" + domain = (await get_base_directories(self._session))[0] + + return await self._gateway.get_domain_identifier(domain) diff --git a/app/ldap_protocol/rid_manager/rid_manager_gateway.py b/app/ldap_protocol/rid_manager/rid_manager_gateway.py new file mode 100644 index 000000000..69cf46c92 --- /dev/null +++ b/app/ldap_protocol/rid_manager/rid_manager_gateway.py @@ -0,0 +1,69 @@ +"""RID Manager gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from constants import DOMAIN_CONTROLLERS_OU_NAME +from entities import Attribute, Directory +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerAvailablePoolNotFoundError, + RIDManagerDomainControllerNotFoundError, + RIDManagerNotFoundError, +) +from repo.pg.tables import queryable_attr as qa + + +class RIDManagerGateway: + """RID Manager gateway.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize RID Manager gateway.""" + self._session = session + + async def get_rid_manager(self) -> Directory: + """Get RID Manager directory.""" + rid_manager = await self._session.scalar( + select(Directory).where(qa(Directory.name) == "RID Manager$"), + ) + if not rid_manager: + raise RIDManagerNotFoundError("RID Manager directory not found") + return rid_manager + + async def get_rid_available_pool(self) -> int: + """Get RID available pool.""" + rid_available_pool = await self._session.scalar( + select(Attribute).where(qa(Attribute.name) == "rIDAvailablePool"), + ) + if not (rid_available_pool and rid_available_pool.value): + raise RIDManagerAvailablePoolNotFoundError( + "RID available pool not found", + ) + return int(rid_available_pool.value) + + async def update_rid_available_pool(self, available_pool: int) -> None: + """Update RID available pool.""" + await self._session.execute( + update(Attribute) + .where(qa(Attribute.name) == "rIDAvailablePool") + .values(value=str(available_pool)), + ) + + async def get_domain_controller( + self, + name: str = DOMAIN_CONTROLLERS_OU_NAME, + ) -> Directory: + """Get domain controller.""" + domain_controllers_ou = await self._session.scalar( + select(Directory).where( + qa(Directory.name) == name, + ), + ) + if not domain_controllers_ou: + raise RIDManagerDomainControllerNotFoundError( + "Domain controller not found", + ) + return domain_controllers_ou diff --git a/app/ldap_protocol/rid_manager/rid_manager_use_case.py b/app/ldap_protocol/rid_manager/rid_manager_use_case.py new file mode 100644 index 000000000..5ce06dcbb --- /dev/null +++ b/app/ldap_protocol/rid_manager/rid_manager_use_case.py @@ -0,0 +1,48 @@ +"""RID Manager use case. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Directory +from ldap_protocol.rid_manager.exceptions import RIDManagerPoolExceededError +from ldap_protocol.rid_manager.rid_manager_gateway import RIDManagerGateway +from ldap_protocol.rid_manager.utils import from_qword, to_qword + + +class RIDManagerUseCase: + """RID Manager use case.""" + + RID_BLOCK_SIZE = 500 + # NOTE Domain Controller(with role Rid Master) attr + # replace and change logic, when super DC is introduced + + def __init__( + self, + gateway: RIDManagerGateway, + session: AsyncSession, + ) -> None: + """Initialize RID Manager use case.""" + self._gateway = gateway + self._session = session + + async def allocate_pool(self) -> int: + """Allocate pool.""" + available_pool = await self._gateway.get_rid_available_pool() + lower, upper = from_qword(available_pool) + + if lower + self.RID_BLOCK_SIZE > upper: + raise RIDManagerPoolExceededError("Available pool exceeded") + + new_available_pool = to_qword(lower + self.RID_BLOCK_SIZE, upper) + await self._gateway.update_rid_available_pool(new_available_pool) + + return to_qword(lower, lower + self.RID_BLOCK_SIZE) + + async def choose_nearest_domain_controller(self) -> Directory: + """Locate best Domain Controller via DNS SRV records.""" + # TODO: нужно через DNS определять ближайший DC # noqa + # и использовать его для выдачи RID + return await self._gateway.get_domain_controller() diff --git a/app/ldap_protocol/rid_manager/rid_set_gateway.py b/app/ldap_protocol/rid_manager/rid_set_gateway.py new file mode 100644 index 000000000..8f3591bf5 --- /dev/null +++ b/app/ldap_protocol/rid_manager/rid_set_gateway.py @@ -0,0 +1,204 @@ +"""RID Set gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Attribute, Directory +from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerRidAllocationPoolNotFoundError, + RIDManagerRidNextRIDNotFoundError, + RIDManagerRidPreviousAllocationPoolNotFoundError, + RIDManagerRidSetNotFoundError, +) +from repo.pg.tables import queryable_attr as qa + + +class RIDSetGateway: + """RID Set gateway.""" + + def __init__(self, session: AsyncSession) -> None: + """Initialize RID Set gateway.""" + self._session = session + + async def get(self, domain_controller: Directory) -> Directory: + """Get RID Set directory.""" + rid_set = await self._session.scalar( + select(Directory).where( + qa(Directory.name) == "RID Set", + qa(Directory.parent_id) == domain_controller.id, + ), + ) + if not rid_set: + raise RIDManagerRidSetNotFoundError("RID Set directory not found") + + return rid_set + + async def add(self, domain_controller: Directory) -> Directory: + """Add RID Set directory.""" + rid_set_dir = Directory( + is_system=True, + name="RID Set", + ) + rid_set_dir.create_path(domain_controller, "cn") + + self._session.add(rid_set_dir) + await self._session.flush() + + rid_set_dir.parent_id = domain_controller.id + await self._session.refresh(rid_set_dir, ["id"]) + + self._session.add( + Attribute( + name="cn", + value="RID Set", + directory_id=rid_set_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="top", + directory_id=rid_set_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="rIDSet", + directory_id=rid_set_dir.id, + ), + ) + + await self._session.flush() + + await self._session.refresh( + instance=rid_set_dir, + attribute_names=["attributes"], + with_for_update=None, + ) + + return rid_set_dir + + async def set_allocation_attrs( + self, + rid_set: Directory, + allocation_params: RIDSetAllocationParamsDTO, + ) -> None: + """Set next RID attribute in RID Set directory.""" + self._session.add( + Attribute( + name="rIDNextRID", + value=str(allocation_params.next_rid), + directory_id=rid_set.id, + ), + ) + self._session.add( + Attribute( + name="rIDPreviousAllocationPool", + value=str(allocation_params.previous_allocation_pool), + directory_id=rid_set.id, + ), + ) + self._session.add( + Attribute( + name="rIDAllocationPool", + value=str(allocation_params.allocation_pool), + directory_id=rid_set.id, + ), + ) + + async def get_rid_allocation_pool(self, rid_set: Directory) -> int: + """Get RID allocation pool from RID Set directory.""" + allocation_pool = await self._session.scalar( + select(Attribute).where( + qa(Attribute.name) == "rIDAllocationPool", + qa(Attribute.directory_id) == rid_set.id, + ), + ) + if not (allocation_pool and allocation_pool.value): + raise RIDManagerRidAllocationPoolNotFoundError( + "RID allocation pool not found", + ) + return int(allocation_pool.value) + + async def get_rid_previous_allocation_pool( + self, + rid_set: Directory, + ) -> int: + """Get previous RID allocation pool from RID Set directory.""" + previous_allocation_pool = await self._session.scalar( + select(Attribute).where( + qa(Attribute.name) == "rIDPreviousAllocationPool", + qa(Attribute.directory_id) == rid_set.id, + ), + ) + if not (previous_allocation_pool and previous_allocation_pool.value): + raise RIDManagerRidPreviousAllocationPoolNotFoundError( + "previous RID allocation pool not found", + ) + return int(previous_allocation_pool.value) + + async def get_rid_next_rid(self, rid_set: Directory) -> int: + """Get next RID from RID Set directory.""" + next_rid = await self._session.scalar( + select(Attribute).where( + qa(Attribute.name) == "rIDNextRID", + qa(Attribute.directory_id) == rid_set.id, + ), + ) + if not (next_rid and next_rid.value): + raise RIDManagerRidNextRIDNotFoundError("next RID not found") + return int(next_rid.value) + + async def update_next_rid_and_pool( + self, + rid_set: Directory, + next_rid: int, + previous_allocation_pool: int, + ) -> None: + """Update next RID and pool.""" + await self._session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "rIDNextRID", + qa(Attribute.directory_id) == rid_set.id, + ) + .values(value=str(next_rid)), + ) + await self._session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "rIDPreviousAllocationPool", + qa(Attribute.directory_id) == rid_set.id, + ) + .values(value=str(previous_allocation_pool)), + ) + + async def reset_attrs_when_pool_exceeded( + self, + rid_set: Directory, + allocation_pool: int, + previous_allocation_pool: int, + next_rid: int, + ) -> None: + """Reset RID pools when pool exceeded.""" + await self._session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "rIDAllocationPool", + qa(Attribute.directory_id) == rid_set.id, + ) + .values(value=str(allocation_pool)), + ) + await self.update_next_rid_and_pool( + rid_set, + next_rid, + previous_allocation_pool, + ) diff --git a/app/ldap_protocol/rid_manager/rid_set_use_case.py b/app/ldap_protocol/rid_manager/rid_set_use_case.py new file mode 100644 index 000000000..8c4c7ceed --- /dev/null +++ b/app/ldap_protocol/rid_manager/rid_set_use_case.py @@ -0,0 +1,107 @@ +"""RID Set use case. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Directory +from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO +from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase +from ldap_protocol.rid_manager.rid_set_gateway import RIDSetGateway +from ldap_protocol.rid_manager.utils import from_qword, to_qword + + +class RIDSetUseCase: + """RID Set use case.""" + + def __init__( + self, + gateway: RIDSetGateway, + entity_type_dao: EntityTypeDAO, + session: AsyncSession, + rid_manager_use_case: RIDManagerUseCase, + ) -> None: + """Initialize RID Set use case.""" + self._gateway = gateway + self._entity_type_dao = entity_type_dao + self._session = session + self._rid_manager_use_case = rid_manager_use_case + + async def get(self, domain_controller: Directory) -> Directory: + """Get RID Set directory.""" + return await self._gateway.get(domain_controller) + + async def add( + self, + domain_controller: Directory, + allocation_params: RIDSetAllocationParamsDTO, + ) -> Directory: + """Create RID Set directory.""" + rid_set = await self._gateway.add(domain_controller) + await self._entity_type_dao.attach_entity_type_to_directory( + directory=rid_set, + is_system_entity_type=True, + ) + + await self._gateway.set_allocation_attrs( + rid_set, + allocation_params, + ) + await self._session.flush() + return rid_set + + async def is_pool_exceeded(self, rid_set: Directory) -> bool: + """Check if RID pool is exceeded.""" + previous_allocation_pool = ( + await self._gateway.get_rid_previous_allocation_pool(rid_set) + ) + _, upper = from_qword(previous_allocation_pool) + next_rid = await self._gateway.get_rid_next_rid(rid_set) + + return next_rid + 1 >= upper + + async def allocate_next_rid(self, rid_set: Directory) -> int: + """Allocate next RID.""" + if await self.is_pool_exceeded(rid_set): + previous_allocation_pool = ( + await self._rid_manager_use_case.allocate_pool() + ) + await self.reset_attrs_when_pool_exceeded( + rid_set, + previous_allocation_pool, + ) + current_rid = await self._gateway.get_rid_next_rid(rid_set) + previous_allocation_pool = ( + await self._gateway.get_rid_previous_allocation_pool(rid_set) + ) + _, upper = from_qword(previous_allocation_pool) + new_rid = current_rid + 1 + new_allocation_pool = to_qword(new_rid, upper) + await self._gateway.update_next_rid_and_pool( + rid_set, + new_rid, + new_allocation_pool, + ) + return new_rid + + async def reset_attrs_when_pool_exceeded( + self, + rid_set: Directory, + previous_allocation_pool: int, + ) -> None: + """Reset RID pools when pool exceeded.""" + current_previous_allocation_pool = ( + await self._gateway.get_rid_previous_allocation_pool( + rid_set, + ) + ) + lower, _ = from_qword(previous_allocation_pool) + await self._gateway.reset_attrs_when_pool_exceeded( + rid_set=rid_set, + next_rid=lower, + allocation_pool=current_previous_allocation_pool, + previous_allocation_pool=previous_allocation_pool, + ) diff --git a/app/ldap_protocol/rid_manager/setup_gateway.py b/app/ldap_protocol/rid_manager/setup_gateway.py new file mode 100644 index 000000000..511483fb6 --- /dev/null +++ b/app/ldap_protocol/rid_manager/setup_gateway.py @@ -0,0 +1,184 @@ +"""RID Manager Gateway. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE +""" + +import secrets + +from sqlalchemy import select, update +from sqlalchemy.ext.asyncio import AsyncSession + +from entities import Attribute, Directory +from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.rid_manager.exceptions import ( + RIDManagerDomainControllerNotFoundError, + RIDManagerSystemContainerNotFoundError, +) +from ldap_protocol.utils.queries import get_base_directories +from repo.pg.tables import queryable_attr as qa + + +class RIDManagerSetupGateway: + """Gateway for RID Manager setup database operations.""" + + def __init__( + self, + session: AsyncSession, + entity_type_dao: EntityTypeDAO, + ) -> None: + """Initialize RID Manager setup gateway.""" + self._session = session + self._entity_type_dao = entity_type_dao + + async def get_domain_controller(self, host_machine_name: str) -> Directory: + """Get domain controller directory. + + :return: Domain controller directory + """ + dc = await self._session.scalar( + select(Directory).where( + qa(Directory.name) == host_machine_name, + ), + ) + + if not dc: + raise RIDManagerDomainControllerNotFoundError( + "Domain controller not found", + ) + + return dc + + async def get_system_container(self) -> Directory: + """Get System container directory. + + :return: System container directory + """ + base_dn_list = await get_base_directories(self._session) + + domain = base_dn_list[0] + + query = select(Directory).where( + qa(Directory.name) == "System", + qa(Directory.parent_id) == domain.id, + ) + + system_container = await self._session.scalar(query) + + if not system_container: + raise RIDManagerSystemContainerNotFoundError( + "System container not found", + ) + + return system_container + + async def set_rid_manager(self) -> Directory: + """Create RID Manager directory.""" + system_container = await self.get_system_container() + + rid_manager_dir = Directory( + is_system=True, + name="RID Manager$", + ) + rid_manager_dir.create_path(system_container, "cn") + + self._session.add(rid_manager_dir) + await self._session.flush() + + rid_manager_dir.parent_id = system_container.id + await self._session.refresh(rid_manager_dir, ["id"]) + + self._session.add( + Attribute( + name="cn", + value="RID Manager$", + directory_id=rid_manager_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="top", + directory_id=rid_manager_dir.id, + ), + ) + + self._session.add( + Attribute( + name="objectClass", + value="rIDManager", + directory_id=rid_manager_dir.id, + ), + ) + + await self._session.flush() + + await self._session.refresh( + instance=rid_manager_dir, + attribute_names=["attributes"], + with_for_update=None, + ) + + await self._entity_type_dao.attach_entity_type_to_directory( + directory=rid_manager_dir, + is_system_entity_type=True, + ) + + await self._session.flush() + + return rid_manager_dir + + async def set_rid_available_pool( + self, + domain: Directory, + qword_value: int, + ) -> None: + """Set rIDAvailablePool attribute in domain. + + Updates the global RID pool counter. + + :param domain: Domain directory object + :param qword_value: New QWORD value (64-bit) + """ + query = ( + update(Attribute) + .where( + qa(Attribute.directory_id) == domain.id, + qa(Attribute.name) == "rIDAvailablePool", + ) + .values(value=str(qword_value)) + ) + + result = await self._session.execute(query) + + if result.rowcount == 0: + self._session.add( + Attribute( + directory_id=domain.id, + name="rIDAvailablePool", + value=str(qword_value), + ), + ) + + await self._session.flush() + + def _generate_domain_sid_identifier(self) -> str: + """Generate Domain Identifier for Active Directory domain.""" + return ( + f"{secrets.randbits(32)}" + f"-{secrets.randbits(32)}-{secrets.randbits(32)}" + ) + + async def create_domain_identifier(self) -> None: + """Add domain identifier to domain.""" + domain = (await get_base_directories(self._session))[0] + + self._session.add( + Attribute( + name="DomainIdentifier", + value=f"{self._generate_domain_sid_identifier()}", + directory_id=domain.id, + ), + ) + await self._session.flush() diff --git a/app/ldap_protocol/rid_manager/setup_use_case.py b/app/ldap_protocol/rid_manager/setup_use_case.py new file mode 100644 index 000000000..97e7eced8 --- /dev/null +++ b/app/ldap_protocol/rid_manager/setup_use_case.py @@ -0,0 +1,110 @@ +"""RID Manager for issuing RID from pools. + +Copyright (c) 2025 MultiFactor +License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE + +""" + +from config import Settings +from entities import Directory +from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO +from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase +from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase +from ldap_protocol.rid_manager.setup_gateway import RIDManagerSetupGateway +from ldap_protocol.rid_manager.utils import from_qword, to_qword +from ldap_protocol.roles.ace_dao import AccessControlEntryDAO +from ldap_protocol.roles.role_use_case import RoleUseCase + + +class RIDManagerSetupUseCase: + """RID Manager setup use case.""" + + RID_BUILTIN_MIN = 500 + RID_BUILTIN_MAX = 1000 + RID_USER_MIN = 1100 + RID_AVAILABLE_MAX = 1073741822 # 30-bit max (2^30 - 2) + + def __init__( + self, + rid_manager_setup_gateway: RIDManagerSetupGateway, + role_use_case: RoleUseCase, + access_control_entry_dao: AccessControlEntryDAO, + rid_set_use_case: RIDSetUseCase, + rid_manager_use_case: RIDManagerUseCase, + settings: Settings, + ) -> None: + """Initialize RID Manager setup use case. + + :param rid_manager_setup_gateway: Gateway for setup operations + :param role_use_case: Role use case + """ + self._gateway = rid_manager_setup_gateway + self._role_use_case = role_use_case + self._access_control_entry_dao = access_control_entry_dao + self._settings = settings + self._rid_set_use_case = rid_set_use_case + self._rid_manager_use_case = rid_manager_use_case + + async def setup(self) -> None: + """Create RID Manager.""" + await self.create_domain_identifier() + rid_manager_dir = await self._gateway.set_rid_manager() + qword = to_qword(self.RID_USER_MIN, self.RID_AVAILABLE_MAX) + await self._gateway.set_rid_available_pool( + rid_manager_dir, + qword, + ) + dc = ( + await self._rid_manager_use_case.choose_nearest_domain_controller() + ) + rid_set = await self._create_rid_set(dc) + + await self.inherit_aces( + rid_manager_dir, + dc, + rid_set, + ) + + async def _create_rid_set(self, domain_controller: Directory) -> Directory: + previous_allocation_pool = ( + await self._rid_manager_use_case.allocate_pool() + ) + allocation_pool = await self._rid_manager_use_case.allocate_pool() + lower, _ = from_qword(previous_allocation_pool) + + return await self._rid_set_use_case.add( + domain_controller, + RIDSetAllocationParamsDTO( + next_rid=lower, + allocation_pool=allocation_pool, + previous_allocation_pool=previous_allocation_pool, + ), + ) + + async def inherit_aces( + self, + rid_manager_dir: Directory, + domain_controller: Directory, + rid_set: Directory, + ) -> None: + """Inherit ACEs from domain root to RID Manager directory. + + Instead of creating a special ACE or role for RID Manager, + we reuse the existing ACL model: all ACEs that apply to the + domain root (including Domain Admins) are inherited by the + `CN=RID Manager$` directory, similar to how it is done in + migration `ebf19750805e_add_domain_controllers_ou`. + """ + await self._role_use_case.inherit_parent_aces( + parent_directory=await self._gateway.get_system_container(), + directory=rid_manager_dir, + ) + + await self._role_use_case.inherit_parent_aces( + parent_directory=domain_controller, + directory=rid_set, + ) + + async def create_domain_identifier(self) -> None: + """Create domain identifier.""" + await self._gateway.create_domain_identifier() diff --git a/app/ldap_protocol/rid_manager/use_cases.py b/app/ldap_protocol/rid_manager/use_cases.py deleted file mode 100644 index 31fddb4ec..000000000 --- a/app/ldap_protocol/rid_manager/use_cases.py +++ /dev/null @@ -1,177 +0,0 @@ -"""RID Manager for issuing RID from pools. - -Copyright (c) 2025 MultiFactor -License: https://github.com/MultiDirectoryLab/MultiDirectory/blob/main/LICENSE - -""" - -import asyncio - -from sqlalchemy.ext.asyncio import AsyncSession - -from config import Settings -from entities import Directory -from enums import SidPrefix -from ldap_protocol.rid_manager.exceptions import RIDManagerRidSetNotFoundError -from ldap_protocol.rid_manager.gateways import ( - RIDManagerGateway, - RIDManagerSetupGateway, -) -from ldap_protocol.rid_manager.utils import create_qword -from ldap_protocol.roles.ace_dao import AccessControlEntryDAO -from ldap_protocol.roles.role_use_case import RoleUseCase - -RID_AVAILABLE_MAX = 1073741822 # 30-bit max (2^30 - 2) - - -class RIDManagerUseCase: - """RID Manager Use Case for issuing RID from pools.""" - - def __init__( - self, - gateway: RIDManagerGateway, - session: AsyncSession, - ) -> None: - """Initialize RID Manager Use Case. - - :param gateway: RID Manager Gateway for database operations - """ - self._gateway = gateway - self._lock = asyncio.Lock() - self._session = session - - async def get_object_sid( - self, - directory: Directory, - ) -> str: - """Get object SID for directory.""" - return await self._gateway.get_object_sid(directory) - - async def get_rid_set(self) -> Directory | None: - """Get RID Set directory.""" - return await self._gateway.get_rid_set() - - async def set_object_sid( - self, - directory: Directory, - rid: int | None = None, - sid_prefix: SidPrefix = SidPrefix.DOMAIN_IDENTIFIER, - ) -> None: - """Create object SID.""" - async with self._lock, await self._session.begin_nested(): - if rid is None: - rid_set = await self._gateway.get_rid_set() - if not rid_set: - raise RIDManagerRidSetNotFoundError( - "RID Set directory not found", - ) - - next_rid = await self._gateway.get_next_rid(rid_set) - rid = next_rid + 1 - await self._gateway.update_next_rid(rid_set, rid) - await self._gateway.update_available_pool( - create_qword(rid, RID_AVAILABLE_MAX), - ) - - if sid_prefix == SidPrefix.BUILT_IN_DOMAIN: - sid = f"{sid_prefix}-{rid}" - elif sid_prefix == SidPrefix.DOMAIN_IDENTIFIER: - base_domain = await self._gateway.get_base_domain() - domain_identifier = await self._gateway.get_domain_identifier( - base_domain, - ) - sid = f"{sid_prefix}-{domain_identifier}-{rid}" - - await self._gateway.add_object_sid(directory, sid) - - await self._session.flush() - - async def parse_object_sid(self, object_sid: str) -> tuple[str, str, int]: - """Parse object SID. - - :param object_sid: Object SID - :return: Tuple containing domain identifier, rid, and reserved flag - """ - parts = object_sid.split("-") - return parts[1], parts[2], int(parts[3]) - - -class RIDManagerSetupUseCase: - """RID Manager setup use case.""" - - RID_SYSTEM_MIN = 1 - RID_SYSTEM_MAX = 499 - RID_BUILTIN_MIN = 500 - RID_BUILTIN_MAX = 1000 - RID_USER_MIN = 1100 - - def __init__( - self, - rid_manager_setup_gateway: RIDManagerSetupGateway, - role_use_case: RoleUseCase, - access_control_entry_dao: AccessControlEntryDAO, - settings: Settings, - ) -> None: - """Initialize RID Manager setup use case. - - :param rid_manager_setup_gateway: Gateway for setup operations - :param role_use_case: Role use case - """ - self._gateway = rid_manager_setup_gateway - self._role_use_case = role_use_case - self._access_control_entry_dao = access_control_entry_dao - self._settings = settings - - async def setup(self) -> None: - """Create RID Manager.""" - rid_manager_dir = await self._gateway.set_rid_manager() - - qword = create_qword(self.RID_USER_MIN, RID_AVAILABLE_MAX) - - await self._gateway.set_rid_available_pool( - rid_manager_dir, - qword, - ) - domain_controller = await self._gateway.get_domain_controller( - self._settings.HOST_MACHINE_NAME, - ) - - rid_set_dir = await self._gateway.create_rid_set( - domain_controller, - ) - await self._gateway.set_next_rid( - rid_set_dir, - self.RID_USER_MIN, - ) - await self.inherit_aces( - rid_manager_dir, - ) - - async def inherit_aces( - self, - rid_manager_dir: Directory, - ) -> None: - """Inherit ACEs from domain root to RID Manager directory. - - Instead of creating a special ACE or role for RID Manager, - we reuse the existing ACL model: all ACEs that apply to the - domain root (including Domain Admins) are inherited by the - `CN=RID Manager$` directory, similar to how it is done in - migration `ebf19750805e_add_domain_controllers_ou`. - """ - await self._role_use_case.inherit_parent_aces( - parent_directory=await self._gateway.get_system_container(), - directory=rid_manager_dir, - ) - - domain_controller = await self._gateway.get_domain_controller( - self._settings.HOST_MACHINE_NAME, - ) - await self._role_use_case.inherit_parent_aces( - parent_directory=domain_controller, - directory=await self._gateway.get_rid_set(domain_controller), - ) - - async def create_domain_identifier(self) -> None: - """Create domain identifier.""" - await self._gateway.create_domain_identifier() diff --git a/app/ldap_protocol/rid_manager/utils.py b/app/ldap_protocol/rid_manager/utils.py index d99df16fc..eb6f3835b 100644 --- a/app/ldap_protocol/rid_manager/utils.py +++ b/app/ldap_protocol/rid_manager/utils.py @@ -1,7 +1,7 @@ """RID Manager utils.""" -def create_qword(lower: int, upper: int) -> int: +def to_qword(lower: int, upper: int) -> int: """Create QWORD (64-bit) from two DWORDs (32-bit each).""" if lower < 0 or lower > 0xFFFFFFFF: raise ValueError(f"Lower boundary out of range: {lower}") @@ -11,3 +11,13 @@ def create_qword(lower: int, upper: int) -> int: qword = (upper << 32) | lower return qword + + +def from_qword(qword: int) -> tuple[int, int]: + """Split QWORD (64-bit) into two DWORDs (lower, upper).""" + if qword < 0 or qword > 0xFFFFFFFFFFFFFFFF: + raise ValueError(f"QWORD out of range: {qword}") + + lower = qword & 0xFFFFFFFF + upper = (qword >> 32) & 0xFFFFFFFF + return lower, upper diff --git a/app/ldap_protocol/rootdse/reader.py b/app/ldap_protocol/rootdse/reader.py index 20503b4d4..6ceacfe7a 100644 --- a/app/ldap_protocol/rootdse/reader.py +++ b/app/ldap_protocol/rootdse/reader.py @@ -8,7 +8,7 @@ from config import Settings from constants import DEFAULT_DC_POSTFIX, UNC_PREFIX -from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase +from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.utils.helpers import get_generalized_now from .dto import DomainControllerInfo @@ -92,17 +92,17 @@ def __init__( self, settings: Settings, gw: DomainReadProtocol, - rid_manager: RIDManagerUseCase, + object_sid_use_case: ObjectSIDUseCase, ) -> None: self._settings = settings self._gw = gw - self._rid_manager = rid_manager + self._object_sid_use_case = object_sid_use_case async def get(self) -> DomainControllerInfo: domain = await self._gw.get_domain() dns = domain.name.lower() nb_domain = dns.split(".")[0].upper() - object_sid = await self._rid_manager.get_object_sid(domain) + object_sid = await self._object_sid_use_case.get(domain) return DomainControllerInfo( net_bios_domain=nb_domain, diff --git a/tests/conftest.py b/tests/conftest.py index 4e379b809..8263c2c76 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -175,13 +175,15 @@ PasswordBanWordUseCases, UserPasswordHistoryUseCases, ) -from ldap_protocol.rid_manager.gateways import ( +from ldap_protocol.rid_manager import ( + ObjectSIDGateway, + ObjectSIDUseCase, RIDManagerGateway, RIDManagerSetupGateway, -) -from ldap_protocol.rid_manager.use_cases import ( RIDManagerSetupUseCase, RIDManagerUseCase, + RIDSetGateway, + RIDSetUseCase, ) from ldap_protocol.roles.access_manager import AccessManager from ldap_protocol.roles.ace_dao import AccessControlEntryDAO @@ -838,6 +840,10 @@ def authorization_provider_protocol( RIDManagerSetupUseCase, scope=Scope.REQUEST, ) + object_sid_gateway = provide(ObjectSIDGateway, scope=Scope.REQUEST) + object_sid_use_case = provide(ObjectSIDUseCase, scope=Scope.REQUEST) + rid_set_gateway = provide(RIDSetGateway, scope=Scope.REQUEST) + rid_set_use_case = provide(RIDSetUseCase, scope=Scope.REQUEST) @dataclass @@ -1130,11 +1136,32 @@ async def setup_session( role_dao = RoleDAO(session) ace_dao = AccessControlEntryDAO(session) role_use_case = RoleUseCase(role_dao, ace_dao) + rid_manager_use_case = RIDManagerUseCase( + rid_manager_gateway, + session, + ) + rid_set_gateway = RIDSetGateway(session) + + rid_set_use_case = RIDSetUseCase( + rid_set_gateway, + entity_type_dao, + session, + rid_manager_use_case, + ) + object_sid_gateway = ObjectSIDGateway(session) + object_sid_use_case = ObjectSIDUseCase( + object_sid_gateway, + rid_set_use_case, + session, + rid_manager_use_case, + ) rid_manager_setup_use_case = RIDManagerSetupUseCase( rid_manager_setup_gateway=rid_manager_setup_gateway, role_use_case=role_use_case, + rid_set_use_case=rid_set_use_case, access_control_entry_dao=AccessControlEntryDAO(session), settings=settings, + rid_manager_use_case=rid_manager_use_case, ) setup_gateway = SetupGateway( session, @@ -1142,7 +1169,7 @@ async def setup_session( entity_type_use_case=entity_type_use_case, attribute_value_validator=attribute_value_validator, directory_dao=directory_dao, - rid_manager_use_case=rid_manager_use_case, + object_sid_use_case=object_sid_use_case, ) for entity_type_dto in chain(ENTITY_TYPE_DTOS_V1, ENTITY_TYPE_DTOS_V2): await entity_type_use_case.create_not_safe(entity_type_dto) diff --git a/tests/test_api/test_main/test_router/conftest.py b/tests/test_api/test_main/test_router/conftest.py index f11a0c259..c7f2427cd 100644 --- a/tests/test_api/test_main/test_router/conftest.py +++ b/tests/test_api/test_main/test_router/conftest.py @@ -19,7 +19,7 @@ from ldap_protocol.ldap_schema.object_class.object_class_dao import ( ObjectClassDAO, ) -from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase +from ldap_protocol.rid_manager.object_sid_use_case import ObjectSIDUseCase from ldap_protocol.utils.queries import get_base_directories from password_utils import PasswordUtils from tests.constants import TEST_SYSTEM_ADMIN_DATA @@ -30,7 +30,7 @@ async def add_system_administrator( session: AsyncSession, password_utils: PasswordUtils, setup_session: None, # noqa: ARG001 - rid_manager_use_case: RIDManagerUseCase, + object_sid_use_case: ObjectSIDUseCase, ) -> None: """Create system administrator user for tests that require it.""" attribute_value_validator = AttributeValueValidator() @@ -53,7 +53,7 @@ async def add_system_administrator( entity_type_use_case, attribute_value_validator=attribute_value_validator, directory_dao=directory_dao, - rid_manager_use_case=rid_manager_use_case, + object_sid_use_case=object_sid_use_case, ) domain = (await get_base_directories(session))[0] diff --git a/tests/test_ldap/test_rid_manager.py b/tests/test_ldap/test_rid_manager.py index d13dccc88..25ad64e10 100644 --- a/tests/test_ldap/test_rid_manager.py +++ b/tests/test_ldap/test_rid_manager.py @@ -1,63 +1,51 @@ """Tests for RID Manager.""" -import pytest -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import selectinload -from entities import Directory -from enums import SidPrefix -from ldap_protocol.rid_manager.gateways import RIDManagerGateway -from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase -from ldap_protocol.utils.queries import get_filter_from_path -from repo.pg.tables import queryable_attr as qa - - -@pytest.mark.asyncio -@pytest.mark.usefixtures("setup_session") -@pytest.mark.parametrize( - "sid_prefix", - [SidPrefix.DOMAIN_IDENTIFIER, SidPrefix.BUILT_IN_DOMAIN], -) -async def test_set_object_sid( - session: AsyncSession, - rid_manager_gateway: RIDManagerGateway, - rid_manager_use_case: RIDManagerUseCase, - sid_prefix: SidPrefix, -) -> None: - """Test RID Manager use case.""" - directory = ( - await session.scalars( - select(Directory) - .options(selectinload(qa(Directory.attributes))) - .filter(get_filter_from_path("cn=user0,cn=Users,dc=md,dc=test")), - ) - ).one() - - rid_set = await rid_manager_use_case.get_rid_set() - assert rid_set - rid_manager = await rid_manager_gateway.get_rid_manager() - pool_before = await rid_manager_gateway.get_rid_available_pool(rid_manager) - next_before = await rid_manager_gateway.get_next_rid(rid_set) - - await rid_manager_use_case.set_object_sid( - directory, - rid=None, - sid_prefix=sid_prefix, - ) - await session.commit() - - expected_rid = next_before + 1 - pool_after = await rid_manager_gateway.get_rid_available_pool(rid_manager) - assert (pool_after & 0xFFFFFFFF) == expected_rid - assert pool_after != pool_before - - assert await rid_manager_gateway.get_next_rid(rid_set) == expected_rid - - await session.refresh(directory, ["attributes"]) - sid = await rid_manager_use_case.get_object_sid(directory) - if sid_prefix == SidPrefix.BUILT_IN_DOMAIN: - assert sid == f"{sid_prefix}-{expected_rid}" - else: - assert sid.startswith(f"{sid_prefix}-") - assert sid.endswith(f"-{expected_rid}") +# @pytest.mark.asyncio +# @pytest.mark.usefixtures("setup_session") +# @pytest.mark.parametrize( +# "sid_prefix", +# [SidPrefix.DOMAIN_IDENTIFIER, SidPrefix.BUILT_IN_DOMAIN], +# ) +# async def test_set_object_sid( +# session: AsyncSession, +# rid_manager_gateway: RIDManagerGateway, +# rid_manager_use_case: RIDManagerUseCase, +# sid_prefix: SidPrefix, +# ) -> None: +# """Test RID Manager use case.""" +# directory = ( +# await session.scalars( +# select(Directory) +# .options(selectinload(qa(Directory.attributes))) +# .filter(get_filter_from_path("cn=user0,cn=Users,dc=md,dc=test")), +# ) +# ).one() + +# rid_set = await rid_manager_use_case.get_rid_set() +# assert rid_set +# rid_manager = await rid_manager_gateway.get_rid_manager() +# pool_before = await rid_manager_gateway.get_rid_available_pool(rid_manager) +# next_before = await rid_manager_gateway.get_next_rid(rid_set) + +# await rid_manager_use_case.set_object_sid( +# directory, +# rid=None, +# sid_prefix=sid_prefix, +# ) +# await session.commit() + +# expected_rid = next_before + 1 +# pool_after = await rid_manager_gateway.get_rid_available_pool(rid_manager) +# assert (pool_after & 0xFFFFFFFF) == expected_rid +# assert pool_after != pool_before + +# assert await rid_manager_gateway.get_next_rid(rid_set) == expected_rid + +# await session.refresh(directory, ["attributes"]) +# sid = await rid_manager_use_case.get_object_sid(directory) +# if sid_prefix == SidPrefix.BUILT_IN_DOMAIN: +# assert sid == f"{sid_prefix}-{expected_rid}" +# else: +# assert sid.startswith(f"{sid_prefix}-") +# assert sid.endswith(f"-{expected_rid}") diff --git a/tests/test_shedule.py b/tests/test_shedule.py index ea2ef6e3c..dadbd1e5e 100644 --- a/tests/test_shedule.py +++ b/tests/test_shedule.py @@ -17,7 +17,7 @@ from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) -from ldap_protocol.rid_manager.use_cases import RIDManagerUseCase +from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.roles.role_use_case import RoleUseCase @@ -89,7 +89,7 @@ async def test_add_domain_controller( settings: Settings, role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, - rid_manager_use_case: RIDManagerUseCase, + object_sid_use_case: ObjectSIDUseCase, ) -> None: """Test add domain controller.""" await add_domain_controller( @@ -97,5 +97,5 @@ async def test_add_domain_controller( session=session, role_use_case=role_use_case, entity_type_use_case=entity_type_use_case, - rid_manager_use_case=rid_manager_use_case, + object_sid_use_case=object_sid_use_case, ) From e5bc127463a334158277ed2d4e33405f666c9c80 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 6 Mar 2026 16:15:23 +0300 Subject: [PATCH 13/37] fix --- .../552b4eafb1aa_remove_objectsid_vals.py | 20 +++++-- .../rid_manager/setup_gateway.py | 20 ++++++- tests/conftest.py | 58 ++++++++++++++++++- .../test_main/test_router/test_modify_dn.py | 3 +- 4 files changed, 93 insertions(+), 8 deletions(-) diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index 176d9dcce..48c17c7b2 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -1,7 +1,7 @@ """Add rIDManager and rIDSet objectClasses to LDAP schema. Revision ID: 552b4eafb1aa -Revises: 2dadf40c026a +Revises: 19d86e660cf2 Create Date: 2026-02-17 09:24:57.906080 """ @@ -32,12 +32,13 @@ ) from ldap_protocol.rid_manager.rid_set_gateway import RIDSetGateway from ldap_protocol.rid_manager.utils import from_qword, to_qword +from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa # revision identifiers, used by Alembic. revision: None | str = "552b4eafb1aa" -down_revision: None | str = "2dadf40c026a" +down_revision: None | str = "19d86e660cf2" branch_labels: None | list[str] = None depends_on: None | list[str] = None @@ -173,6 +174,7 @@ async def _init_rid_manager( rid_manager_use_case = await cnt.get(RIDManagerUseCase) rid_set_gateway = await cnt.get(RIDSetGateway) rid_set_use_case = await cnt.get(RIDSetUseCase) + role_use_case = await cnt.get(RoleUseCase) if not await get_base_directories(session): return @@ -220,7 +222,13 @@ async def _init_rid_manager( start_rid = max(max_rid, RIDManagerSetupUseCase.RID_USER_MIN) qword = to_qword(start_rid, RIDManagerSetupUseCase.RID_AVAILABLE_MAX) - await rid_setup_gateway.set_rid_available_pool(rid_manager_dir, qword) + await rid_setup_gateway.set_rid_available_pool(domain, qword) + + system_container = await rid_setup_gateway.get_system_container() + await role_use_case.inherit_parent_aces( + parent_directory=system_container, + directory=rid_manager_dir, + ) domain_controller = await rid_gateway.get_domain_controller() rid_set_dir: Directory | None = None @@ -236,7 +244,7 @@ async def _init_rid_manager( allocation_pool = await rid_manager_use_case.allocate_pool() lower, _ = from_qword(previous_allocation_pool) - await rid_set_use_case.add( + rid_set_dir = await rid_set_use_case.add( domain_controller, RIDSetAllocationParamsDTO( next_rid=lower, @@ -244,6 +252,10 @@ async def _init_rid_manager( previous_allocation_pool=previous_allocation_pool, ), ) + await role_use_case.inherit_parent_aces( + parent_directory=domain_controller, + directory=rid_set_dir, + ) await session.commit() return diff --git a/app/ldap_protocol/rid_manager/setup_gateway.py b/app/ldap_protocol/rid_manager/setup_gateway.py index 511483fb6..bc3f9aa9d 100644 --- a/app/ldap_protocol/rid_manager/setup_gateway.py +++ b/app/ldap_protocol/rid_manager/setup_gateway.py @@ -6,7 +6,7 @@ import secrets -from sqlalchemy import select, update +from sqlalchemy import exists, select, update from sqlalchemy.ext.asyncio import AsyncSession from entities import Attribute, Directory @@ -172,8 +172,24 @@ def _generate_domain_sid_identifier(self) -> str: async def create_domain_identifier(self) -> None: """Add domain identifier to domain.""" - domain = (await get_base_directories(self._session))[0] + domain_identifer = await self._session.scalar( + select( + exists(Attribute), + ).where( + qa(Attribute.name) == "DomainIdentifier", + ), + ) + if domain_identifer: + return + domain = await self._session.scalar( + select(Directory).where( + qa(Directory.object_class) == "domain", + qa(Directory.parent_id).is_(None), + ), + ) + if not domain: + raise self._session.add( Attribute( name="DomainIdentifier", diff --git a/tests/conftest.py b/tests/conftest.py index 8263c2c76..d62f97e5e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1184,7 +1184,7 @@ async def setup_session( is_system=False, ) dc_directory = Directory( - name=settings.HOST_MACHINE_NAME, + name=DOMAIN_CONTROLLERS_OU_NAME, object_class="computer", is_system=True, ) @@ -1812,6 +1812,62 @@ async def rid_manager_use_case( yield RIDManagerUseCase(rid_manager_gateway, session) +@pytest_asyncio.fixture(scope="function") +async def rid_set_gateway( + container: AsyncContainer, +) -> AsyncIterator[RIDSetGateway]: + """Provide RIDSetGateway for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield RIDSetGateway(session) + + +@pytest_asyncio.fixture(scope="function") +async def rid_set_use_case( + container: AsyncContainer, + rid_manager_use_case: RIDManagerUseCase, + entity_type_dao: EntityTypeDAO, + rid_set_gateway: RIDSetGateway, +) -> AsyncIterator[RIDSetUseCase]: + """Provide RIDManagerUseCase for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield RIDSetUseCase( + rid_set_gateway, + entity_type_dao, + session, + rid_manager_use_case, + ) + + +@pytest_asyncio.fixture(scope="function") +async def object_sid_gateway( + container: AsyncContainer, +) -> AsyncIterator[ObjectSIDGateway]: + """Provide ObjectSIDGateway for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield ObjectSIDGateway(session) + + +@pytest_asyncio.fixture(scope="function") +async def object_sid_use_case( + container: AsyncContainer, + rid_manager_use_case: RIDManagerUseCase, + rid_set_use_case: RIDSetUseCase, + object_sid_gateway: ObjectSIDGateway, +) -> AsyncIterator[ObjectSIDUseCase]: + """Provide RIDManagerUseCase for tests that request it explicitly.""" + async with container(scope=Scope.SESSION) as container: + session = await container.get(AsyncSession) + yield ObjectSIDUseCase( + object_sid_gateway, + rid_set_use_case, + session, + rid_manager_use_case, + ) + + def pytest_configure(config: pytest.Config) -> None: """Pytest hook to limit xdist workers based on Dragonfly DBs. diff --git a/tests/test_api/test_main/test_router/test_modify_dn.py b/tests/test_api/test_main/test_router/test_modify_dn.py index efe7dcf0a..af0bb83ef 100644 --- a/tests/test_api/test_main/test_router/test_modify_dn.py +++ b/tests/test_api/test_main/test_router/test_modify_dn.py @@ -16,6 +16,7 @@ @pytest.mark.usefixtures("session") async def test_api_modify_dn_without_level_change( http_client: AsyncClient, + session: AsyncSession, ) -> None: """Test API for updating DN. @@ -41,7 +42,7 @@ async def test_api_modify_dn_without_level_change( data["search_result"][0]["object_name"] == "ou=testModifyDn1,dc=md,dc=test" ) - + session.expire_all() response = await http_client.put( "/entry/update/dn", json={ From 32865a71fa6ba1f2c3deb893b50373d94aeecf1f Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Mon, 30 Mar 2026 12:59:25 +0300 Subject: [PATCH 14/37] Refactor: Simplify object SID handling in Directory class and update related database operations --- .../552b4eafb1aa_remove_objectsid_vals.py | 45 ++++++++++---- .../versions/6f8fe2548893_fix_read_only.py | 20 +++++- app/api/main/schema.py | 5 +- app/entities.py | 19 +++--- app/ldap_protocol/ldap_requests/search.py | 5 +- app/ldap_protocol/rid_manager/exceptions.py | 62 +++++-------------- .../rid_manager/rid_manager_gateway.py | 4 +- .../rid_manager/rid_manager_use_case.py | 13 ++-- .../rid_manager/rid_set_gateway.py | 12 ++-- .../rid_manager/rid_set_use_case.py | 37 ++++++----- .../rid_manager/setup_gateway.py | 5 +- interface | 2 +- 12 files changed, 123 insertions(+), 106 deletions(-) diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index 48c17c7b2..9988be436 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -94,15 +94,26 @@ async def _migrate_object_sids( async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) - directories = await session.scalars(select(Directory)) + directory_table = sa.table( + "Directory", + sa.column("id", sa.Integer), + sa.column("objectSid", sa.String), + ) + + result = await session.execute( + select( + directory_table.c.id, + directory_table.c.objectSid, + ), + ) - for directory in directories: - if not directory.object_sid: + for directory_id, object_sid in result: + if not object_sid: continue existing_attr = await session.scalar( select(Attribute).where( - qa(Attribute.directory_id) == directory.id, + qa(Attribute.directory_id) == directory_id, qa(Attribute.name) == "objectSid", ), ) @@ -111,8 +122,8 @@ async def _migrate_object_sids( session.add( Attribute( name="objectSid", - value=directory.object_sid, - directory_id=directory.id, + value=object_sid, + directory_id=directory_id, ), ) @@ -389,19 +400,25 @@ async def _rollback_object_sids( async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) - directories = await session.scalars(select(Directory)) + directory_table = sa.table( + "Directory", + sa.column("id", sa.Integer), + sa.column("objectSid", sa.String), + ) + + result = await session.execute(select(directory_table.c.id)) - for directory in directories: + for (directory_id,) in result: await session.execute( delete(Attribute).where( - qa(Attribute.directory_id) == directory.id, + qa(Attribute.directory_id) == directory_id, qa(Attribute.name) == "DomainIdentifier", ), ) attr = await session.scalar( select(Attribute).where( - qa(Attribute.directory_id) == directory.id, + qa(Attribute.directory_id) == directory_id, qa(Attribute.name) == "objectSid", ), ) @@ -409,11 +426,15 @@ async def _rollback_object_sids( if not attr or not attr.value: continue - directory.object_sid = attr.value + await session.execute( + update(directory_table) + .where(directory_table.c.id == directory_id) + .values(objectSid=attr.value), + ) await session.execute( delete(Attribute).where( - qa(Attribute.directory_id) == directory.id, + qa(Attribute.directory_id) == directory_id, qa(Attribute.name) == "objectSid", ), ) diff --git a/app/alembic/versions/6f8fe2548893_fix_read_only.py b/app/alembic/versions/6f8fe2548893_fix_read_only.py index f28264704..2a5f51ef4 100644 --- a/app/alembic/versions/6f8fe2548893_fix_read_only.py +++ b/app/alembic/versions/6f8fe2548893_fix_read_only.py @@ -31,6 +31,12 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 bind = op.get_bind() session = Session(bind=bind) + directory_table = sa.table( + "Directory", + sa.column("id", sa.Integer), + sa.column("objectSid", sa.String), + ) + ro_dir = session.scalar( select(Directory) .filter_by(name="readonly domain controllers"), @@ -82,8 +88,18 @@ def upgrade(container: AsyncContainer) -> None: # noqa: ARG001 ), ) - domain_sid = "-".join(ro_dir.object_sid.split("-")[:-1]) - ro_dir.object_sid = domain_sid + "-521" + ro_object_sid = session.scalar( + select(directory_table.c.objectSid).where( + directory_table.c.id == ro_dir.id, + ), + ) + if ro_object_sid: + domain_sid = "-".join(ro_object_sid.split("-")[:-1]) + session.execute( + update(directory_table) + .where(directory_table.c.id == ro_dir.id) + .values(objectSid=domain_sid + "-521"), + ) session.commit() diff --git a/app/api/main/schema.py b/app/api/main/schema.py index 4140bb716..bc559de4d 100644 --- a/app/api/main/schema.py +++ b/app/api/main/schema.py @@ -39,10 +39,7 @@ def _cast_filter(self) -> UnaryExpression | ColumnElement: @staticmethod def get_directory_sid(directory: Directory) -> str | None: # type: ignore - for attr in getattr(directory, "attributes", []): - if attr.name and attr.name.lower() == "objectsid" and attr.value: - return attr.value - return None + return directory.object_sid @staticmethod def get_directory_guid(directory: Directory) -> str: # type: ignore diff --git a/app/entities.py b/app/entities.py index 28f0730cb..64605c470 100644 --- a/app/entities.py +++ b/app/entities.py @@ -101,7 +101,6 @@ class Directory: id: int = field(init=False) name: str is_system: bool = field(default=False) - object_sid: str = field(default="") object_guid: uuid.UUID = field(default_factory=uuid.uuid4) parent_id: int | None = None entity_type_id: int | None = None @@ -184,20 +183,24 @@ def create_path( self.rdname = dn @property - def relative_id(self) -> str: - """Get RID from objectSid attribute. - - Relative Identifier (RID) is the last sub-authority value of a SID. - """ + def object_sid(self) -> str: + """Get objectSid attribute value.""" attrs = self.__dict__.get("attributes") if not attrs: return "" - for attr in attrs: if attr.name and attr.name.lower() == "objectsid" and attr.value: - return attr.value.split("-")[-1] + return attr.value return "" + @property + def relative_id(self) -> str: + """Get RID from objectSid attribute. + + Relative Identifier (RID) is the last sub-authority value of a SID. + """ + return self.object_sid.split("-")[-1] if self.object_sid else "" + @property def attributes_dict(self) -> defaultdict[str, list[str]]: d: defaultdict[str, list[str]] = defaultdict(list) diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index ccc74ee0a..d47d4a6c8 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -558,10 +558,7 @@ async def _fill_attrs( # noqa: C901 @staticmethod def get_directory_sid(directory: Directory) -> bytes | None: """Get objectSid as bytes from directory attributes.""" - for attr in directory.attributes: - if attr.name and attr.name.lower() == "objectsid" and attr.value: - return string_to_sid(attr.value) - return None + return string_to_sid(directory.object_sid) @staticmethod def get_directory_guid(directory: Directory) -> bytes: diff --git a/app/ldap_protocol/rid_manager/exceptions.py b/app/ldap_protocol/rid_manager/exceptions.py index 9964f5f77..ba38eec7e 100644 --- a/app/ldap_protocol/rid_manager/exceptions.py +++ b/app/ldap_protocol/rid_manager/exceptions.py @@ -10,22 +10,16 @@ class ErrorCodes(IntEnum): BASE_ERROR = 0 RID_MANAGER_NOT_FOUND_ERROR = 1 - RID_MANAGER_ALREADY_EXISTS_ERROR = 2 - RID_MANAGER_CANT_MODIFY_ERROR = 3 - RID_MANAGER_SETUP_ERROR = 4 - RID_AVAILABLE_POOL_NOT_FOUND_ERROR = 5 - RID_NEXT_RID_NOT_FOUND_ERROR = 6 - RID_SET_NOT_FOUND_ERROR = 7 - RID_SET_CANT_MODIFY_ERROR = 8 - RID_SET_ALREADY_EXISTS_ERROR = 9 - RID_DOMAIN_IDENTIFIER_NOT_FOUND_ERROR = 10 - RID_DOMAIN_CONTROLLER_NOT_FOUND_ERROR = 11 - RID_OBJECT_SID_NOT_FOUND_ERROR = 12 - RID_BASE_DOMAIN_NOT_FOUND_ERROR = 13 - RID_SYSTEM_CONTAINER_NOT_FOUND_ERROR = 14 - RID_ALLOCATION_POOL_NOT_FOUND_ERROR = 15 - RID_PREVIOUS_ALLOCATION_POOL_NOT_FOUND_ERROR = 16 - RID_POOL_EXCEEDED_ERROR = 17 + RID_AVAILABLE_POOL_NOT_FOUND_ERROR = 2 + RID_NEXT_RID_NOT_FOUND_ERROR = 3 + RID_SET_NOT_FOUND_ERROR = 4 + RID_DOMAIN_IDENTIFIER_NOT_FOUND_ERROR = 5 + RID_DOMAIN_CONTROLLER_NOT_FOUND_ERROR = 6 + RID_BASE_DOMAIN_NOT_FOUND_ERROR = 7 + RID_SYSTEM_CONTAINER_NOT_FOUND_ERROR = 8 + RID_ALLOCATION_POOL_NOT_FOUND_ERROR = 9 + RID_PREVIOUS_ALLOCATION_POOL_NOT_FOUND_ERROR = 10 + RID_POOL_EXCEEDED_ERROR = 11 class RIDManagerError(BaseDomainException): @@ -40,12 +34,6 @@ class RIDManagerNotFoundError(RIDManagerError): code = ErrorCodes.RID_MANAGER_NOT_FOUND_ERROR -class RIDManagerSetupError(RIDManagerError): - """RID Manager setup error.""" - - code = ErrorCodes.RID_MANAGER_SETUP_ERROR - - class RIDManagerAvailablePoolNotFoundError(RIDManagerError): """RID Manager available pool not found error.""" @@ -64,18 +52,6 @@ class RIDManagerRidSetNotFoundError(RIDManagerError): code = ErrorCodes.RID_SET_NOT_FOUND_ERROR -class RIDManagerSetCantModifyError(RIDManagerError): - """RID Manager set can't modify error.""" - - code = ErrorCodes.RID_SET_CANT_MODIFY_ERROR - - -class RIDManagerSetAlreadyExistsError(RIDManagerError): - """RID Manager set already exists error.""" - - code = ErrorCodes.RID_SET_ALREADY_EXISTS_ERROR - - class RIDManagerDomainIdentifierNotFoundError(RIDManagerError): """RID Manager domain identifier not found error.""" @@ -88,18 +64,6 @@ class RIDManagerDomainControllerNotFoundError(RIDManagerError): code = ErrorCodes.RID_DOMAIN_CONTROLLER_NOT_FOUND_ERROR -class RIDManagerObjectSidNotFoundError(RIDManagerError): - """RID Manager object SID not found error.""" - - code = ErrorCodes.RID_OBJECT_SID_NOT_FOUND_ERROR - - -class RIDManagerDomainNotFoundError(RIDManagerError): - """RID Manager base domain not found error.""" - - code = ErrorCodes.RID_BASE_DOMAIN_NOT_FOUND_ERROR - - class RIDManagerSystemContainerNotFoundError(RIDManagerError): """RID Manager system container not found error.""" @@ -122,3 +86,9 @@ class RIDManagerPoolExceededError(RIDManagerError): """RID Manager pool exceeded error.""" code = ErrorCodes.RID_POOL_EXCEEDED_ERROR + + +class RIDManagerBaseDomainNotFoundError(RIDManagerError): + """RID Manager base domain not found error.""" + + code = ErrorCodes.RID_BASE_DOMAIN_NOT_FOUND_ERROR diff --git a/app/ldap_protocol/rid_manager/rid_manager_gateway.py b/app/ldap_protocol/rid_manager/rid_manager_gateway.py index 69cf46c92..f0e125441 100644 --- a/app/ldap_protocol/rid_manager/rid_manager_gateway.py +++ b/app/ldap_protocol/rid_manager/rid_manager_gateway.py @@ -36,7 +36,9 @@ async def get_rid_manager(self) -> Directory: async def get_rid_available_pool(self) -> int: """Get RID available pool.""" rid_available_pool = await self._session.scalar( - select(Attribute).where(qa(Attribute.name) == "rIDAvailablePool"), + select(Attribute) + .where(qa(Attribute.name) == "rIDAvailablePool") + .with_for_update(), ) if not (rid_available_pool and rid_available_pool.value): raise RIDManagerAvailablePoolNotFoundError( diff --git a/app/ldap_protocol/rid_manager/rid_manager_use_case.py b/app/ldap_protocol/rid_manager/rid_manager_use_case.py index 5ce06dcbb..e13725e54 100644 --- a/app/ldap_protocol/rid_manager/rid_manager_use_case.py +++ b/app/ldap_protocol/rid_manager/rid_manager_use_case.py @@ -30,14 +30,15 @@ def __init__( async def allocate_pool(self) -> int: """Allocate pool.""" - available_pool = await self._gateway.get_rid_available_pool() - lower, upper = from_qword(available_pool) + async with self._session.begin_nested(): + available_pool = await self._gateway.get_rid_available_pool() + lower, upper = from_qword(available_pool) - if lower + self.RID_BLOCK_SIZE > upper: - raise RIDManagerPoolExceededError("Available pool exceeded") + if lower + self.RID_BLOCK_SIZE > upper: + raise RIDManagerPoolExceededError("Available pool exceeded") - new_available_pool = to_qword(lower + self.RID_BLOCK_SIZE, upper) - await self._gateway.update_rid_available_pool(new_available_pool) + new_available_pool = to_qword(lower + self.RID_BLOCK_SIZE, upper) + await self._gateway.update_rid_available_pool(new_available_pool) return to_qword(lower, lower + self.RID_BLOCK_SIZE) diff --git a/app/ldap_protocol/rid_manager/rid_set_gateway.py b/app/ldap_protocol/rid_manager/rid_set_gateway.py index 8f3591bf5..b032a3fe9 100644 --- a/app/ldap_protocol/rid_manager/rid_set_gateway.py +++ b/app/ldap_protocol/rid_manager/rid_set_gateway.py @@ -134,10 +134,12 @@ async def get_rid_previous_allocation_pool( ) -> int: """Get previous RID allocation pool from RID Set directory.""" previous_allocation_pool = await self._session.scalar( - select(Attribute).where( + select(Attribute) + .where( qa(Attribute.name) == "rIDPreviousAllocationPool", qa(Attribute.directory_id) == rid_set.id, - ), + ) + .with_for_update(), ) if not (previous_allocation_pool and previous_allocation_pool.value): raise RIDManagerRidPreviousAllocationPoolNotFoundError( @@ -148,10 +150,12 @@ async def get_rid_previous_allocation_pool( async def get_rid_next_rid(self, rid_set: Directory) -> int: """Get next RID from RID Set directory.""" next_rid = await self._session.scalar( - select(Attribute).where( + select(Attribute) + .where( qa(Attribute.name) == "rIDNextRID", qa(Attribute.directory_id) == rid_set.id, - ), + ) + .with_for_update(), ) if not (next_rid and next_rid.value): raise RIDManagerRidNextRIDNotFoundError("next RID not found") diff --git a/app/ldap_protocol/rid_manager/rid_set_use_case.py b/app/ldap_protocol/rid_manager/rid_set_use_case.py index 8c4c7ceed..dfd60661e 100644 --- a/app/ldap_protocol/rid_manager/rid_set_use_case.py +++ b/app/ldap_protocol/rid_manager/rid_set_use_case.py @@ -55,36 +55,37 @@ async def add( async def is_pool_exceeded(self, rid_set: Directory) -> bool: """Check if RID pool is exceeded.""" + next_rid = await self._gateway.get_rid_next_rid(rid_set) previous_allocation_pool = ( await self._gateway.get_rid_previous_allocation_pool(rid_set) ) _, upper = from_qword(previous_allocation_pool) - next_rid = await self._gateway.get_rid_next_rid(rid_set) return next_rid + 1 >= upper async def allocate_next_rid(self, rid_set: Directory) -> int: """Allocate next RID.""" - if await self.is_pool_exceeded(rid_set): + async with self._session.begin_nested(): + if await self.is_pool_exceeded(rid_set): + previous_allocation_pool = ( + await self._rid_manager_use_case.allocate_pool() + ) + await self.reset_attrs_when_pool_exceeded( + rid_set, + previous_allocation_pool, + ) + current_rid = await self._gateway.get_rid_next_rid(rid_set) previous_allocation_pool = ( - await self._rid_manager_use_case.allocate_pool() + await self._gateway.get_rid_previous_allocation_pool(rid_set) ) - await self.reset_attrs_when_pool_exceeded( + _, upper = from_qword(previous_allocation_pool) + new_rid = current_rid + 1 + new_allocation_pool = to_qword(new_rid, upper) + await self._gateway.update_next_rid_and_pool( rid_set, - previous_allocation_pool, + new_rid, + new_allocation_pool, ) - current_rid = await self._gateway.get_rid_next_rid(rid_set) - previous_allocation_pool = ( - await self._gateway.get_rid_previous_allocation_pool(rid_set) - ) - _, upper = from_qword(previous_allocation_pool) - new_rid = current_rid + 1 - new_allocation_pool = to_qword(new_rid, upper) - await self._gateway.update_next_rid_and_pool( - rid_set, - new_rid, - new_allocation_pool, - ) return new_rid async def reset_attrs_when_pool_exceeded( @@ -93,6 +94,8 @@ async def reset_attrs_when_pool_exceeded( previous_allocation_pool: int, ) -> None: """Reset RID pools when pool exceeded.""" + _ = await self._gateway.get_rid_next_rid(rid_set) # lock next RID + current_previous_allocation_pool = ( await self._gateway.get_rid_previous_allocation_pool( rid_set, diff --git a/app/ldap_protocol/rid_manager/setup_gateway.py b/app/ldap_protocol/rid_manager/setup_gateway.py index bc3f9aa9d..acac8a496 100644 --- a/app/ldap_protocol/rid_manager/setup_gateway.py +++ b/app/ldap_protocol/rid_manager/setup_gateway.py @@ -12,6 +12,7 @@ from entities import Attribute, Directory from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.rid_manager.exceptions import ( + RIDManagerBaseDomainNotFoundError, RIDManagerDomainControllerNotFoundError, RIDManagerSystemContainerNotFoundError, ) @@ -182,6 +183,7 @@ async def create_domain_identifier(self) -> None: if domain_identifer: return + domain = await self._session.scalar( select(Directory).where( qa(Directory.object_class) == "domain", @@ -189,7 +191,8 @@ async def create_domain_identifier(self) -> None: ), ) if not domain: - raise + raise RIDManagerBaseDomainNotFoundError("Domain not found") + self._session.add( Attribute( name="DomainIdentifier", diff --git a/interface b/interface index 046449cdd..61e15e236 160000 --- a/interface +++ b/interface @@ -1 +1 @@ -Subproject commit 046449cdd568919cca12a7939366dcee7a54fdfa +Subproject commit 61e15e2367182a3e706c94cf8e1895d742840aa7 From 08cbecdb3db47d3069c10ea188caf8cbd42c8af9 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Mon, 30 Mar 2026 13:10:37 +0300 Subject: [PATCH 15/37] Update down_revision in Alembic migration to reflect new dependency --- app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index 9988be436..fa87ace60 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -38,7 +38,7 @@ # revision identifiers, used by Alembic. revision: None | str = "552b4eafb1aa" -down_revision: None | str = "19d86e660cf2" +down_revision: None | str = "df4287898910" branch_labels: None | list[str] = None depends_on: None | list[str] = None From d30f7ca673ed15aaa4402b34f19a058e3e218784 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 31 Mar 2026 08:51:34 +0300 Subject: [PATCH 16/37] Refactor: Update domain controller retrieval methods and enhance RID Manager gateway initialization --- app/ldap_protocol/ldap_requests/search.py | 6 +++++- .../rid_manager/object_sid_use_case.py | 4 +++- .../rid_manager/rid_manager_gateway.py | 14 +++++++------- .../rid_manager/rid_manager_use_case.py | 4 +--- app/ldap_protocol/rid_manager/setup_gateway.py | 14 ++++++++++++++ app/ldap_protocol/rid_manager/setup_use_case.py | 5 ++--- tests/conftest.py | 5 +++-- 7 files changed, 35 insertions(+), 17 deletions(-) diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index d47d4a6c8..afa7c4614 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -558,7 +558,11 @@ async def _fill_attrs( # noqa: C901 @staticmethod def get_directory_sid(directory: Directory) -> bytes | None: """Get objectSid as bytes from directory attributes.""" - return string_to_sid(directory.object_sid) + return ( + string_to_sid(directory.object_sid) + if directory.object_sid + else None + ) @staticmethod def get_directory_guid(directory: Directory) -> bytes: diff --git a/app/ldap_protocol/rid_manager/object_sid_use_case.py b/app/ldap_protocol/rid_manager/object_sid_use_case.py index 9ae878ace..953f4bb1f 100644 --- a/app/ldap_protocol/rid_manager/object_sid_use_case.py +++ b/app/ldap_protocol/rid_manager/object_sid_use_case.py @@ -42,7 +42,9 @@ async def add( ) -> None: """Add object SID.""" if rid is None: - domain_controller = await self._rid_manager_use_case.choose_nearest_domain_controller() # noqa + domain_controller = ( + await self._rid_manager_use_case.get_domain_controller() + ) rid_set = await self._rid_set_use_case.get(domain_controller) rid = await self._rid_set_use_case.allocate_next_rid( rid_set, diff --git a/app/ldap_protocol/rid_manager/rid_manager_gateway.py b/app/ldap_protocol/rid_manager/rid_manager_gateway.py index f0e125441..1b6402acc 100644 --- a/app/ldap_protocol/rid_manager/rid_manager_gateway.py +++ b/app/ldap_protocol/rid_manager/rid_manager_gateway.py @@ -7,7 +7,7 @@ from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession -from constants import DOMAIN_CONTROLLERS_OU_NAME +from config import Settings from entities import Attribute, Directory from ldap_protocol.rid_manager.exceptions import ( RIDManagerAvailablePoolNotFoundError, @@ -20,9 +20,10 @@ class RIDManagerGateway: """RID Manager gateway.""" - def __init__(self, session: AsyncSession) -> None: + def __init__(self, session: AsyncSession, settings: Settings) -> None: """Initialize RID Manager gateway.""" self._session = session + self._settings = settings async def get_rid_manager(self) -> Directory: """Get RID Manager directory.""" @@ -56,16 +57,15 @@ async def update_rid_available_pool(self, available_pool: int) -> None: async def get_domain_controller( self, - name: str = DOMAIN_CONTROLLERS_OU_NAME, ) -> Directory: """Get domain controller.""" - domain_controllers_ou = await self._session.scalar( + domain_controller = await self._session.scalar( select(Directory).where( - qa(Directory.name) == name, + qa(Directory.name) == self._settings.HOST_MACHINE_SHORT_NAME, ), ) - if not domain_controllers_ou: + if not domain_controller: raise RIDManagerDomainControllerNotFoundError( "Domain controller not found", ) - return domain_controllers_ou + return domain_controller diff --git a/app/ldap_protocol/rid_manager/rid_manager_use_case.py b/app/ldap_protocol/rid_manager/rid_manager_use_case.py index e13725e54..62fc74e58 100644 --- a/app/ldap_protocol/rid_manager/rid_manager_use_case.py +++ b/app/ldap_protocol/rid_manager/rid_manager_use_case.py @@ -42,8 +42,6 @@ async def allocate_pool(self) -> int: return to_qword(lower, lower + self.RID_BLOCK_SIZE) - async def choose_nearest_domain_controller(self) -> Directory: + async def get_domain_controller(self) -> Directory: """Locate best Domain Controller via DNS SRV records.""" - # TODO: нужно через DNS определять ближайший DC # noqa - # и использовать его для выдачи RID return await self._gateway.get_domain_controller() diff --git a/app/ldap_protocol/rid_manager/setup_gateway.py b/app/ldap_protocol/rid_manager/setup_gateway.py index acac8a496..26575bb6f 100644 --- a/app/ldap_protocol/rid_manager/setup_gateway.py +++ b/app/ldap_protocol/rid_manager/setup_gateway.py @@ -9,6 +9,7 @@ from sqlalchemy import exists, select, update from sqlalchemy.ext.asyncio import AsyncSession +from constants import DOMAIN_CONTROLLERS_OU_NAME from entities import Attribute, Directory from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.rid_manager.exceptions import ( @@ -201,3 +202,16 @@ async def create_domain_identifier(self) -> None: ), ) await self._session.flush() + + async def get_domain_controller_ou(self) -> Directory: + """Get Domain Controller OU directory.""" + domain_controller_ou = await self._session.scalar( + select(Directory).where( + qa(Directory.name) == DOMAIN_CONTROLLERS_OU_NAME, + ), + ) + if not domain_controller_ou: + raise RIDManagerDomainControllerNotFoundError( + "Domain Controller OU not found", + ) + return domain_controller_ou diff --git a/app/ldap_protocol/rid_manager/setup_use_case.py b/app/ldap_protocol/rid_manager/setup_use_case.py index 97e7eced8..b966e8e87 100644 --- a/app/ldap_protocol/rid_manager/setup_use_case.py +++ b/app/ldap_protocol/rid_manager/setup_use_case.py @@ -54,9 +54,8 @@ async def setup(self) -> None: rid_manager_dir, qword, ) - dc = ( - await self._rid_manager_use_case.choose_nearest_domain_controller() - ) + dc = await self._rid_manager_use_case.get_domain_controller() + dc_ou = await self._gateway.get_domain_controller_ou() rid_set = await self._create_rid_set(dc) await self.inherit_aces( diff --git a/tests/conftest.py b/tests/conftest.py index d62f97e5e..71e648afd 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1124,7 +1124,7 @@ async def setup_session( password_policy_validator, password_ban_word_repository, ) - rid_manager_gateway = RIDManagerGateway(session) + rid_manager_gateway = RIDManagerGateway(session, settings) rid_manager_use_case = RIDManagerUseCase( rid_manager_gateway, session, @@ -1794,11 +1794,12 @@ async def ctx_search( @pytest_asyncio.fixture(scope="function") async def rid_manager_gateway( container: AsyncContainer, + settings: Settings, ) -> AsyncIterator[RIDManagerGateway]: """Get RID Manager gateway.""" async with container(scope=Scope.SESSION) as container: session = await container.get(AsyncSession) - yield RIDManagerGateway(session) + yield RIDManagerGateway(session, settings) @pytest_asyncio.fixture(scope="function") From 72f782e5b47ca618a0857bf4cfb32103af608c51 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 31 Mar 2026 12:15:54 +0300 Subject: [PATCH 17/37] Refactor: Remove unused domain controller OU retrieval and streamline RID Manager setup process --- app/extra/scripts/add_domain_controller.py | 3 +-- app/ldap_protocol/auth/use_cases.py | 15 +++++++++++++-- app/ldap_protocol/rid_manager/setup_gateway.py | 14 -------------- app/ldap_protocol/rid_manager/setup_use_case.py | 1 - 4 files changed, 14 insertions(+), 19 deletions(-) diff --git a/app/extra/scripts/add_domain_controller.py b/app/extra/scripts/add_domain_controller.py index 1b771d1fc..73b18ae84 100644 --- a/app/extra/scripts/add_domain_controller.py +++ b/app/extra/scripts/add_domain_controller.py @@ -11,7 +11,7 @@ from config import Settings from constants import DOMAIN_CONTROLLERS_OU_NAME from entities import Attribute, Directory -from enums import SamAccountTypeCodes, SecurityPrincipalRid +from enums import SamAccountTypeCodes from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) @@ -41,7 +41,6 @@ async def _add_domain_controller( dc_directory.parent_id = dc_ou_dir.id await object_sid_use_case.add( directory=dc_directory, - rid=SecurityPrincipalRid.DOMAIN_CONTROLLERS, ) await session.flush() diff --git a/app/ldap_protocol/auth/use_cases.py b/app/ldap_protocol/auth/use_cases.py index 4fb3d666c..c85378cc8 100644 --- a/app/ldap_protocol/auth/use_cases.py +++ b/app/ldap_protocol/auth/use_cases.py @@ -44,7 +44,11 @@ from ldap_protocol.objects import UserAccountControlFlag from ldap_protocol.policies.audit.audit_use_case import AuditUseCase from ldap_protocol.policies.password import PasswordPolicyUseCases -from ldap_protocol.rid_manager import RIDManagerSetupUseCase +from ldap_protocol.rid_manager import ( + ObjectSIDUseCase, + RIDManagerSetupUseCase, + RIDManagerUseCase, +) from ldap_protocol.roles.role_use_case import RoleUseCase from ldap_protocol.utils.helpers import create_integer_hash, ft_now @@ -66,6 +70,8 @@ def __init__( session: AsyncSession, settings: Settings, rid_manager_setup_use_case: RIDManagerSetupUseCase, + rid_manager_use_case: RIDManagerUseCase, + object_sid_use_case: ObjectSIDUseCase, ) -> None: """Initialize Setup manager. @@ -85,6 +91,8 @@ def __init__( self._object_class_use_case = object_class_use_case self._settings = settings self._rid_manager_setup_use_case = rid_manager_setup_use_case + self._rid_manager_use_case = rid_manager_use_case + self._object_sid_use_case = object_sid_use_case async def setup(self, dto: SetupDTO) -> None: """Perform the initial setup of structure and policies. @@ -126,7 +134,6 @@ def _create_domain_controller_data(self) -> dict: "name": self._settings.HOST_MACHINE_SHORT_NAME, "entity_type_name": EntityTypeNames.COMPUTER, "object_class": "computer", - "objectSid": SecurityPrincipalRid.DOMAIN_CONTROLLERS, "attributes": { "objectClass": ["top"], "userAccountControl": [ @@ -246,6 +253,10 @@ async def _create(self, dto: SetupDTO, data: list) -> None: await self._role_use_case.create_read_only_role() await self._audit_use_case.create_policies() await self._rid_manager_setup_use_case.setup() + dc = await self._rid_manager_use_case.get_domain_controller() + await self._object_sid_use_case.add( + directory=dc, + ) await self._session.commit() except IntegrityError: diff --git a/app/ldap_protocol/rid_manager/setup_gateway.py b/app/ldap_protocol/rid_manager/setup_gateway.py index 26575bb6f..acac8a496 100644 --- a/app/ldap_protocol/rid_manager/setup_gateway.py +++ b/app/ldap_protocol/rid_manager/setup_gateway.py @@ -9,7 +9,6 @@ from sqlalchemy import exists, select, update from sqlalchemy.ext.asyncio import AsyncSession -from constants import DOMAIN_CONTROLLERS_OU_NAME from entities import Attribute, Directory from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.rid_manager.exceptions import ( @@ -202,16 +201,3 @@ async def create_domain_identifier(self) -> None: ), ) await self._session.flush() - - async def get_domain_controller_ou(self) -> Directory: - """Get Domain Controller OU directory.""" - domain_controller_ou = await self._session.scalar( - select(Directory).where( - qa(Directory.name) == DOMAIN_CONTROLLERS_OU_NAME, - ), - ) - if not domain_controller_ou: - raise RIDManagerDomainControllerNotFoundError( - "Domain Controller OU not found", - ) - return domain_controller_ou diff --git a/app/ldap_protocol/rid_manager/setup_use_case.py b/app/ldap_protocol/rid_manager/setup_use_case.py index b966e8e87..02c41ada1 100644 --- a/app/ldap_protocol/rid_manager/setup_use_case.py +++ b/app/ldap_protocol/rid_manager/setup_use_case.py @@ -55,7 +55,6 @@ async def setup(self) -> None: qword, ) dc = await self._rid_manager_use_case.get_domain_controller() - dc_ou = await self._gateway.get_domain_controller_ou() rid_set = await self._create_rid_set(dc) await self.inherit_aces( From 6b444a0e89324187384fb5ae22376e173c048ba5 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 31 Mar 2026 13:34:39 +0300 Subject: [PATCH 18/37] Add: Introduce RID_OBJECT_SID_NOT_FOUND_ERROR exception and update ObjectSIDGateway to raise it when SID is not found --- app/ldap_protocol/rid_manager/exceptions.py | 7 + .../rid_manager/object_sid_gateway.py | 7 +- tests/conftest.py | 16 +- tests/test_ldap/test_object_sid.py | 137 ++++++++++++++++++ tests/test_ldap/test_rid_manager.py | 51 ------- 5 files changed, 165 insertions(+), 53 deletions(-) create mode 100644 tests/test_ldap/test_object_sid.py delete mode 100644 tests/test_ldap/test_rid_manager.py diff --git a/app/ldap_protocol/rid_manager/exceptions.py b/app/ldap_protocol/rid_manager/exceptions.py index ba38eec7e..40b75bbb9 100644 --- a/app/ldap_protocol/rid_manager/exceptions.py +++ b/app/ldap_protocol/rid_manager/exceptions.py @@ -20,6 +20,7 @@ class ErrorCodes(IntEnum): RID_ALLOCATION_POOL_NOT_FOUND_ERROR = 9 RID_PREVIOUS_ALLOCATION_POOL_NOT_FOUND_ERROR = 10 RID_POOL_EXCEEDED_ERROR = 11 + RID_OBJECT_SID_NOT_FOUND_ERROR = 12 class RIDManagerError(BaseDomainException): @@ -92,3 +93,9 @@ class RIDManagerBaseDomainNotFoundError(RIDManagerError): """RID Manager base domain not found error.""" code = ErrorCodes.RID_BASE_DOMAIN_NOT_FOUND_ERROR + + +class RIDManagerObjectSIDNotFoundError(RIDManagerError): + """RID Manager object SID not found error.""" + + code = ErrorCodes.RID_OBJECT_SID_NOT_FOUND_ERROR diff --git a/app/ldap_protocol/rid_manager/object_sid_gateway.py b/app/ldap_protocol/rid_manager/object_sid_gateway.py index 3f7d25683..5ba990a39 100644 --- a/app/ldap_protocol/rid_manager/object_sid_gateway.py +++ b/app/ldap_protocol/rid_manager/object_sid_gateway.py @@ -10,6 +10,7 @@ from entities import Attribute, Directory from ldap_protocol.rid_manager.exceptions import ( RIDManagerDomainIdentifierNotFoundError, + RIDManagerObjectSIDNotFoundError, ) from repo.pg.tables import queryable_attr as qa @@ -23,12 +24,16 @@ def __init__(self, session: AsyncSession) -> None: async def get(self, directory: Directory) -> str: """Get object SID.""" - return await self._session.scalar( + query = await self._session.scalar( select(Attribute).where( qa(Attribute.directory_id) == directory.id, qa(Attribute.name) == "objectSid", ), ) + if not (query and query.value): + raise RIDManagerObjectSIDNotFoundError("object SID not found") + + return query.value async def add(self, directory: Directory, object_sid: str) -> None: """Add object SID.""" diff --git a/tests/conftest.py b/tests/conftest.py index 71e648afd..ade487450 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -63,7 +63,11 @@ from api.shadow.adapter import ShadowAdapter from authorization_provider_protocol import AuthorizationProviderProtocol from config import Settings -from constants import ENTITY_TYPE_DTOS_V1, ENTITY_TYPE_DTOS_V2 +from constants import ( + DOMAIN_CONTROLLERS_OU_NAME, + ENTITY_TYPE_DTOS_V1, + ENTITY_TYPE_DTOS_V2, +) from entities import Directory from enums import AuthorizationRules from ioc import AuditRedisClient, MFACredsProvider, SessionStorageClient @@ -1194,6 +1198,16 @@ async def setup_session( dc_directory.parent_id = domain.id await session.refresh(dc_directory, ["id"]) await session.flush() + dc = Directory( + name=settings.HOST_MACHINE_SHORT_NAME, + object_class="computer", + is_system=True, + ) + dc.create_path(dc_directory, "cn") + session.add(dc) + await session.flush() + dc.parent_id = dc_directory.id + await session.refresh(dc, ["id"]) for _at_dto in ( AttributeTypeDTO[None]( diff --git a/tests/test_ldap/test_object_sid.py b/tests/test_ldap/test_object_sid.py new file mode 100644 index 000000000..ac0c10eb7 --- /dev/null +++ b/tests/test_ldap/test_object_sid.py @@ -0,0 +1,137 @@ +"""Tests for RID Manager.""" + +import pytest +from sqlalchemy.ext.asyncio import AsyncSession + +from enums import SidPrefix +from ldap_protocol.rid_manager import RIDManagerUseCase +from ldap_protocol.rid_manager.object_sid_gateway import ObjectSIDGateway +from ldap_protocol.rid_manager.object_sid_use_case import ObjectSIDUseCase +from ldap_protocol.rid_manager.rid_manager_gateway import RIDManagerGateway +from ldap_protocol.rid_manager.rid_set_gateway import RIDSetGateway +from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase +from ldap_protocol.rid_manager.utils import from_qword, to_qword + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_rid_manager_allocate_pool( + rid_manager_use_case: RIDManagerUseCase, + rid_manager_gateway: RIDManagerGateway, +) -> None: + """Test RID Manager get domain controller.""" + available_pool = await rid_manager_gateway.get_rid_available_pool() + + await rid_manager_use_case.allocate_pool() + new_available_pool = await rid_manager_gateway.get_rid_available_pool() + lower, _ = from_qword(available_pool) + new_lower, _ = from_qword(new_available_pool) + + assert new_lower == lower + RIDManagerUseCase.RID_BLOCK_SIZE + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_next_rid( + rid_set_use_case: RIDSetUseCase, + rid_manager_use_case: RIDManagerUseCase, +) -> None: + """Test RID Manager get domain controller.""" + dc = await rid_manager_use_case.get_domain_controller() + rid_set = await rid_set_use_case.get(dc) + next_rid = await rid_set_use_case.allocate_next_rid(rid_set) + new_next_rid = await rid_set_use_case.allocate_next_rid(rid_set) + assert new_next_rid == next_rid + 1 + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_rid_set_reset_pool( + rid_set_use_case: RIDSetUseCase, + rid_manager_use_case: RIDManagerUseCase, + rid_manager_gateway: RIDManagerGateway, + rid_set_gateway: RIDSetGateway, +) -> None: + """Test RID Set pool reset.""" + dc = await rid_manager_use_case.get_domain_controller() + rid_set = await rid_set_use_case.get(dc) + + available_pool_before = await rid_manager_gateway.get_rid_available_pool() + lower_before, _ = from_qword(available_pool_before) + await rid_set_gateway.get_rid_allocation_pool( + rid_set, + ) + previous_pool_before = ( + await rid_set_gateway.get_rid_previous_allocation_pool(rid_set) + ) + + _, upper = from_qword(previous_pool_before) + await rid_set_gateway.update_next_rid_and_pool( + rid_set, + next_rid=upper - 1, + previous_allocation_pool=previous_pool_before, + ) + + assert await rid_set_use_case.is_pool_exceeded(rid_set) is True + + next_rid = await rid_set_use_case.allocate_next_rid(rid_set) + assert await rid_set_use_case.is_pool_exceeded(rid_set) is False + + available_pool_after = await rid_manager_gateway.get_rid_available_pool() + lower_after, _ = from_qword(available_pool_after) + allocation_pool_after = await rid_set_gateway.get_rid_allocation_pool( + rid_set, + ) + previous_pool_after = ( + await rid_set_gateway.get_rid_previous_allocation_pool( + rid_set, + ) + ) + + assert lower_after == lower_before + RIDManagerUseCase.RID_BLOCK_SIZE + assert previous_pool_after == to_qword( + next_rid, + lower_before + RIDManagerUseCase.RID_BLOCK_SIZE, + ) + assert allocation_pool_after == previous_pool_before + + +@pytest.mark.asyncio +@pytest.mark.usefixtures("session") +@pytest.mark.usefixtures("setup_session") +async def test_object_sid_add_updates_next_rid_and_prefix( + session: AsyncSession, + object_sid_use_case: ObjectSIDUseCase, + object_sid_gateway: ObjectSIDGateway, + rid_set_use_case: RIDSetUseCase, + rid_set_gateway: RIDSetGateway, + rid_manager_use_case: RIDManagerUseCase, +) -> None: + dc = await rid_manager_use_case.get_domain_controller() + rid_set = await rid_set_use_case.get(dc) + + next_before = await rid_set_gateway.get_rid_next_rid(rid_set) + + await object_sid_use_case.add(dc) + await session.flush() + next_after = await rid_set_gateway.get_rid_next_rid(rid_set) + assert next_after == next_before + 1 + + sid_domain_attr = await object_sid_gateway.get(dc) + assert sid_domain_attr.startswith("S-1-5-21-") + + await object_sid_use_case.add( + rid_set, + rid=512, + sid_prefix=SidPrefix.BUILT_IN_DOMAIN, + ) + await session.flush() + next_after_builtin = await rid_set_gateway.get_rid_next_rid(rid_set) + assert next_after_builtin == next_after + + sid_builtin_attr = await object_sid_gateway.get(rid_set) + assert sid_builtin_attr.startswith("S-1-5-32-") + assert sid_builtin_attr != sid_domain_attr diff --git a/tests/test_ldap/test_rid_manager.py b/tests/test_ldap/test_rid_manager.py deleted file mode 100644 index 25ad64e10..000000000 --- a/tests/test_ldap/test_rid_manager.py +++ /dev/null @@ -1,51 +0,0 @@ -"""Tests for RID Manager.""" - - -# @pytest.mark.asyncio -# @pytest.mark.usefixtures("setup_session") -# @pytest.mark.parametrize( -# "sid_prefix", -# [SidPrefix.DOMAIN_IDENTIFIER, SidPrefix.BUILT_IN_DOMAIN], -# ) -# async def test_set_object_sid( -# session: AsyncSession, -# rid_manager_gateway: RIDManagerGateway, -# rid_manager_use_case: RIDManagerUseCase, -# sid_prefix: SidPrefix, -# ) -> None: -# """Test RID Manager use case.""" -# directory = ( -# await session.scalars( -# select(Directory) -# .options(selectinload(qa(Directory.attributes))) -# .filter(get_filter_from_path("cn=user0,cn=Users,dc=md,dc=test")), -# ) -# ).one() - -# rid_set = await rid_manager_use_case.get_rid_set() -# assert rid_set -# rid_manager = await rid_manager_gateway.get_rid_manager() -# pool_before = await rid_manager_gateway.get_rid_available_pool(rid_manager) -# next_before = await rid_manager_gateway.get_next_rid(rid_set) - -# await rid_manager_use_case.set_object_sid( -# directory, -# rid=None, -# sid_prefix=sid_prefix, -# ) -# await session.commit() - -# expected_rid = next_before + 1 -# pool_after = await rid_manager_gateway.get_rid_available_pool(rid_manager) -# assert (pool_after & 0xFFFFFFFF) == expected_rid -# assert pool_after != pool_before - -# assert await rid_manager_gateway.get_next_rid(rid_set) == expected_rid - -# await session.refresh(directory, ["attributes"]) -# sid = await rid_manager_use_case.get_object_sid(directory) -# if sid_prefix == SidPrefix.BUILT_IN_DOMAIN: -# assert sid == f"{sid_prefix}-{expected_rid}" -# else: -# assert sid.startswith(f"{sid_prefix}-") -# assert sid.endswith(f"-{expected_rid}") From 31be1accad3ea52a98d3b405043768fecd8f0f13 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 31 Mar 2026 13:44:05 +0300 Subject: [PATCH 19/37] Update: Modify Directory name in setup_session to include '-test' suffix for testing environment --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index ade487450..ef6f33cd8 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1199,7 +1199,7 @@ async def setup_session( await session.refresh(dc_directory, ["id"]) await session.flush() dc = Directory( - name=settings.HOST_MACHINE_SHORT_NAME, + name=f"{settings.HOST_MACHINE_SHORT_NAME}-test", object_class="computer", is_system=True, ) From f88592c1978ee14162fda06123d8136555b22765 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 31 Mar 2026 15:17:51 +0300 Subject: [PATCH 20/37] Update: Change Directory name in setup_session to append '-test' suffix for improved test environment clarity --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index ef6f33cd8..4e62c1c01 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1188,7 +1188,7 @@ async def setup_session( is_system=False, ) dc_directory = Directory( - name=DOMAIN_CONTROLLERS_OU_NAME, + name=f"{DOMAIN_CONTROLLERS_OU_NAME}-test", object_class="computer", is_system=True, ) From 4f56163ac1b2dd8fe0f71b5b8cdac2d192bd3633 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 31 Mar 2026 16:11:29 +0300 Subject: [PATCH 21/37] Update: Remove '-test' suffix from Directory names in setup_session for consistency with production environment --- tests/conftest.py | 4 ++-- tests/test_shedule.py | 6 ++++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 4e62c1c01..ade487450 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1188,7 +1188,7 @@ async def setup_session( is_system=False, ) dc_directory = Directory( - name=f"{DOMAIN_CONTROLLERS_OU_NAME}-test", + name=DOMAIN_CONTROLLERS_OU_NAME, object_class="computer", is_system=True, ) @@ -1199,7 +1199,7 @@ async def setup_session( await session.refresh(dc_directory, ["id"]) await session.flush() dc = Directory( - name=f"{settings.HOST_MACHINE_SHORT_NAME}-test", + name=settings.HOST_MACHINE_SHORT_NAME, object_class="computer", is_system=True, ) diff --git a/tests/test_shedule.py b/tests/test_shedule.py index dadbd1e5e..393a50693 100644 --- a/tests/test_shedule.py +++ b/tests/test_shedule.py @@ -90,8 +90,14 @@ async def test_add_domain_controller( role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, object_sid_use_case: ObjectSIDUseCase, + monkeypatch: pytest.MonkeyPatch, ) -> None: """Test add domain controller.""" + monkeypatch.setattr( + settings, + "HOST_MACHINE_SHORT_NAME", + f"{settings.HOST_MACHINE_SHORT_NAME}-test", + ) await add_domain_controller( settings=settings, session=session, From af14d05cb8aebf08821499510f82d71a6d519dc1 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 31 Mar 2026 17:19:39 +0300 Subject: [PATCH 22/37] Add: Integrate RIDManagerUseCase into test_add_domain_controller for improved domain controller retrieval --- tests/test_shedule.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/test_shedule.py b/tests/test_shedule.py index 393a50693..52cb2da6b 100644 --- a/tests/test_shedule.py +++ b/tests/test_shedule.py @@ -18,6 +18,7 @@ EntityTypeUseCase, ) from ldap_protocol.rid_manager import ObjectSIDUseCase +from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase from ldap_protocol.roles.role_use_case import RoleUseCase @@ -90,14 +91,25 @@ async def test_add_domain_controller( role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, object_sid_use_case: ObjectSIDUseCase, + rid_manager_use_case: RIDManagerUseCase, monkeypatch: pytest.MonkeyPatch, ) -> None: """Test add domain controller.""" + existing_dc = await rid_manager_use_case.get_domain_controller() monkeypatch.setattr( settings, "HOST_MACHINE_SHORT_NAME", f"{settings.HOST_MACHINE_SHORT_NAME}-test", ) + + async def _get_existing_dc() -> object: + return existing_dc + + monkeypatch.setattr( + object_sid_use_case._rid_manager_use_case, # noqa: SLF001 + "get_domain_controller", + _get_existing_dc, + ) await add_domain_controller( settings=settings, session=session, From f35eee27f8e7158abd5cc8695ac831d1da04f59d Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 7 Apr 2026 10:53:39 +0300 Subject: [PATCH 23/37] Refactor: Update ObjectSID and RIDSet handling to use directory IDs instead of objects for improved consistency and performance --- .../552b4eafb1aa_remove_objectsid_vals.py | 7 +- .../a1b2c3d4e5f6_rename_services_to_system.py | 3 +- app/enums.py | 3 - app/extra/scripts/add_domain_controller.py | 11 +- app/ldap_protocol/auth/setup_gateway.py | 11 +- app/ldap_protocol/auth/use_cases.py | 4 +- app/ldap_protocol/ldap_requests/add.py | 6 +- .../rid_manager/object_sid_gateway.py | 22 ++-- .../rid_manager/object_sid_use_case.py | 20 +--- .../rid_manager/rid_manager_gateway.py | 24 ++++ .../rid_manager/rid_set_gateway.py | 54 ++++----- .../rid_manager/rid_set_use_case.py | 111 +++++++++++------- .../rid_manager/setup_gateway.py | 30 +---- .../rid_manager/setup_use_case.py | 31 ++--- app/ldap_protocol/rootdse/reader.py | 2 +- app/ldap_protocol/utils/async_cache.py | 26 ++++ tests/conftest.py | 3 + tests/test_ldap/test_object_sid.py | 60 ++++++---- tests/test_shedule.py | 3 + 19 files changed, 242 insertions(+), 189 deletions(-) diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index fa87ace60..29443e7bd 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -1,7 +1,7 @@ """Add rIDManager and rIDSet objectClasses to LDAP schema. Revision ID: 552b4eafb1aa -Revises: 19d86e660cf2 +Revises: df4287898910 Create Date: 2026-02-17 09:24:57.906080 """ @@ -230,10 +230,11 @@ async def _init_rid_manager( if rid > max_rid: max_rid = rid - start_rid = max(max_rid, RIDManagerSetupUseCase.RID_USER_MIN) + start_rid = max(max_rid, RIDManagerSetupUseCase.RID_MIN) qword = to_qword(start_rid, RIDManagerSetupUseCase.RID_AVAILABLE_MAX) - await rid_setup_gateway.set_rid_available_pool(domain, qword) + + await rid_setup_gateway.set_rid_available_pool(rid_manager_dir, qword) system_container = await rid_setup_gateway.get_system_container() await role_use_case.inherit_parent_aces( diff --git a/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py b/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py index e8d480d94..1f9f4d9d4 100644 --- a/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py +++ b/app/alembic/versions/a1b2c3d4e5f6_rename_services_to_system.py @@ -11,6 +11,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession +from constants import SYSTEM_CONTAINER_NAME from entities import Attribute, Directory from repo.pg.tables import queryable_attr as qa @@ -111,7 +112,7 @@ async def _rename_system_to_services(connection: AsyncConnection) -> None: # no system_dir = await session.scalar( select(Directory).where( - qa(Directory.name) == "System", + qa(Directory.name) == SYSTEM_CONTAINER_NAME, qa(Directory.is_system).is_(True), ), ) diff --git a/app/enums.py b/app/enums.py index 79d530085..cd2b786cb 100644 --- a/app/enums.py +++ b/app/enums.py @@ -293,10 +293,7 @@ class SidPrefix(StrEnum): class SecurityPrincipalRid(IntEnum): ADMINISTRATOR = 500 GUESTS = 501 - KRBTGT = 502 DOMAIN_ADMINS = 512 DOMAIN_USERS = 513 - DOMAIN_GUESTS = 514 DOMAIN_COMPUTERS = 515 - DOMAIN_CONTROLLERS = 516 DOMAIN_READ_ONLY = 521 diff --git a/app/extra/scripts/add_domain_controller.py b/app/extra/scripts/add_domain_controller.py index 73b18ae84..71b3b96ca 100644 --- a/app/extra/scripts/add_domain_controller.py +++ b/app/extra/scripts/add_domain_controller.py @@ -16,7 +16,7 @@ EntityTypeUseCase, ) from ldap_protocol.objects import UserAccountControlFlag -from ldap_protocol.rid_manager import ObjectSIDUseCase +from ldap_protocol.rid_manager import ObjectSIDUseCase, RIDSetUseCase from ldap_protocol.roles.role_use_case import RoleUseCase from repo.pg.tables import queryable_attr as qa @@ -28,6 +28,7 @@ async def _add_domain_controller( settings: Settings, dc_ou_dir: Directory, object_sid_use_case: ObjectSIDUseCase, + rid_set_use_case: RIDSetUseCase, ) -> None: dc_directory = Directory( object_class="", @@ -40,9 +41,13 @@ async def _add_domain_controller( dc_directory.parent_id = dc_ou_dir.id await object_sid_use_case.add( - directory=dc_directory, + directory_id=dc_directory.id, ) await session.flush() + await rid_set_use_case.add( + domain_controller=dc_directory, + allocation_params=await rid_set_use_case.generate_rid_set_attrs(), + ) attributes = [ Attribute( @@ -105,6 +110,7 @@ async def add_domain_controller( role_use_case: RoleUseCase, entity_type_use_case: EntityTypeUseCase, object_sid_use_case: ObjectSIDUseCase, + rid_set_use_case: RIDSetUseCase, ) -> None: logger.info("Adding domain controller.") @@ -139,6 +145,7 @@ async def add_domain_controller( settings=settings, dc_ou_dir=domain_controllers_ou, object_sid_use_case=object_sid_use_case, + rid_set_use_case=rid_set_use_case, ) logger.debug("Domain controller added.") diff --git a/app/ldap_protocol/auth/setup_gateway.py b/app/ldap_protocol/auth/setup_gateway.py index 9152920f3..8f4dbf009 100644 --- a/app/ldap_protocol/auth/setup_gateway.py +++ b/app/ldap_protocol/auth/setup_gateway.py @@ -73,7 +73,6 @@ async def setup_enviroment( ) -> None: """Create directories and users for enviroment.""" async with self._session.begin_nested(): - self._session.add(domain) self._session.add( NetworkPolicy( name="Default open policy", @@ -122,14 +121,6 @@ async def setup_enviroment( logger.error(traceback.format_exc()) raise - async def is_base_domain_created(self) -> bool: - """Check if base domain is created.""" - cat_result = await self._session.execute(select(Directory)) - if cat_result.scalar_one_or_none(): - logger.warning("dev data already set up") - return True - return False - async def create_base_domain( self, dn: str = "multifactor.dev", @@ -175,7 +166,7 @@ async def create_dir( if "objectSid" in data: await self._object_sid_use_case.add( - directory=dir_, + directory_id=dir_.id, rid=int(data["objectSid"]), sid_prefix=SidPrefix.BUILT_IN_DOMAIN, ) diff --git a/app/ldap_protocol/auth/use_cases.py b/app/ldap_protocol/auth/use_cases.py index c85378cc8..564ab6bea 100644 --- a/app/ldap_protocol/auth/use_cases.py +++ b/app/ldap_protocol/auth/use_cases.py @@ -210,8 +210,6 @@ async def _create(self, dto: SetupDTO, data: list) -> None: :return: None. """ try: - if await self._setup_gateway.is_base_domain_created(): - return domain = await self._setup_gateway.create_base_domain(dto.domain) await self._rid_manager_setup_use_case.create_domain_identifier() await self._setup_gateway.setup_enviroment( @@ -255,7 +253,7 @@ async def _create(self, dto: SetupDTO, data: list) -> None: await self._rid_manager_setup_use_case.setup() dc = await self._rid_manager_use_case.get_domain_controller() await self._object_sid_use_case.add( - directory=dc, + directory_id=dc.id, ) await self._session.commit() diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index de5e4106a..c5e08f943 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -104,9 +104,9 @@ def from_data(cls, data: ASN1Row) -> "AddRequest": type=attr.value[0].value, vals=[val.value for val in attr.value[1].value], ) - for attr in attributes.value # type: ignore + for attr in attributes.value ] - return cls(entry=entry.value, attributes=attributes) # type: ignore + return cls(entry=entry.value, attributes=attributes) async def handle( # noqa: C901 self, @@ -214,7 +214,7 @@ async def handle( # noqa: C901 await ctx.session.flush() await ctx.object_sid_use_case.add( - directory=new_dir, + directory_id=new_dir.id, ) await ctx.session.flush() except IntegrityError: diff --git a/app/ldap_protocol/rid_manager/object_sid_gateway.py b/app/ldap_protocol/rid_manager/object_sid_gateway.py index 5ba990a39..83e655e7f 100644 --- a/app/ldap_protocol/rid_manager/object_sid_gateway.py +++ b/app/ldap_protocol/rid_manager/object_sid_gateway.py @@ -7,11 +7,12 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from entities import Attribute, Directory +from entities import Attribute from ldap_protocol.rid_manager.exceptions import ( RIDManagerDomainIdentifierNotFoundError, RIDManagerObjectSIDNotFoundError, ) +from ldap_protocol.utils.async_cache import domain_identifier_cache from repo.pg.tables import queryable_attr as qa @@ -22,11 +23,11 @@ def __init__(self, session: AsyncSession) -> None: """Initialize Object SID gateway.""" self._session = session - async def get(self, directory: Directory) -> str: + async def get(self, directory_id: int) -> str: """Get object SID.""" query = await self._session.scalar( select(Attribute).where( - qa(Attribute.directory_id) == directory.id, + qa(Attribute.directory_id) == directory_id, qa(Attribute.name) == "objectSid", ), ) @@ -35,25 +36,26 @@ async def get(self, directory: Directory) -> str: return query.value - async def add(self, directory: Directory, object_sid: str) -> None: + async def add(self, directory_id: int, object_sid: str) -> None: """Add object SID.""" self._session.add( Attribute( name="objectSid", value=object_sid, - directory_id=directory.id, + directory_id=directory_id, ), ) - async def get_domain_identifier(self, domain: Directory) -> str: - """Get domain identifier. + async def get_domain_identifier(self) -> str: + """Get domain identifier (cached ``Attribute.value`` string).""" + return await domain_identifier_cache.get_or_load( + self._load_domain_identifier_value, + ) - :return: Domain identifier - """ + async def _load_domain_identifier_value(self) -> str: query = await self._session.scalar( select(Attribute).where( qa(Attribute.name) == "DomainIdentifier", - qa(Attribute.directory_id) == domain.id, ), ) diff --git a/app/ldap_protocol/rid_manager/object_sid_use_case.py b/app/ldap_protocol/rid_manager/object_sid_use_case.py index 953f4bb1f..a67c17ae6 100644 --- a/app/ldap_protocol/rid_manager/object_sid_use_case.py +++ b/app/ldap_protocol/rid_manager/object_sid_use_case.py @@ -6,12 +6,10 @@ from sqlalchemy.ext.asyncio import AsyncSession -from entities import Directory from enums import SidPrefix from ldap_protocol.rid_manager.object_sid_gateway import ObjectSIDGateway from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase -from ldap_protocol.utils.queries import get_base_directories class ObjectSIDUseCase: @@ -30,13 +28,13 @@ def __init__( self._session = session self._rid_manager_use_case = rid_manager_use_case - async def get(self, directory: Directory) -> str: + async def get(self, directory_id: int) -> str: """Get object SID.""" - return await self._gateway.get(directory) + return await self._gateway.get(directory_id) async def add( self, - directory: Directory, + directory_id: int, rid: int | None = None, sid_prefix: SidPrefix = SidPrefix.DOMAIN_IDENTIFIER, ) -> None: @@ -47,19 +45,13 @@ async def add( ) rid_set = await self._rid_set_use_case.get(domain_controller) rid = await self._rid_set_use_case.allocate_next_rid( - rid_set, + rid_set.id, ) if sid_prefix == SidPrefix.BUILT_IN_DOMAIN: object_sid = f"{sid_prefix}-{rid}" elif sid_prefix == SidPrefix.DOMAIN_IDENTIFIER: - domain_identifier = await self.get_domain_identifier() + domain_identifier = await self._gateway.get_domain_identifier() object_sid = f"{sid_prefix}-{domain_identifier}-{rid}" - await self._gateway.add(directory, object_sid) - - async def get_domain_identifier(self) -> str: - """Get domain identifier.""" - domain = (await get_base_directories(self._session))[0] - - return await self._gateway.get_domain_identifier(domain) + await self._gateway.add(directory_id, object_sid) diff --git a/app/ldap_protocol/rid_manager/rid_manager_gateway.py b/app/ldap_protocol/rid_manager/rid_manager_gateway.py index 1b6402acc..059312069 100644 --- a/app/ldap_protocol/rid_manager/rid_manager_gateway.py +++ b/app/ldap_protocol/rid_manager/rid_manager_gateway.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from config import Settings +from constants import DOMAIN_CONTROLLERS_OU_NAME from entities import Attribute, Directory from ldap_protocol.rid_manager.exceptions import ( RIDManagerAvailablePoolNotFoundError, @@ -59,9 +60,32 @@ async def get_domain_controller( self, ) -> Directory: """Get domain controller.""" + domain = await self._session.scalar( + select(Directory).where( + qa(Directory.object_class) == "domain", + qa(Directory.parent_id).is_(None), + ), + ) + if not domain: + raise RIDManagerDomainControllerNotFoundError( + "Domain controller not found", + ) + + domain_controllers_ou = await self._session.scalar( + select(Directory).where( + qa(Directory.name) == DOMAIN_CONTROLLERS_OU_NAME, + qa(Directory.parent_id) == domain.id, + ), + ) + if not domain_controllers_ou: + raise RIDManagerDomainControllerNotFoundError( + "Domain controllers OU not found", + ) + domain_controller = await self._session.scalar( select(Directory).where( qa(Directory.name) == self._settings.HOST_MACHINE_SHORT_NAME, + qa(Directory.parent_id) == domain_controllers_ou.id, ), ) if not domain_controller: diff --git a/app/ldap_protocol/rid_manager/rid_set_gateway.py b/app/ldap_protocol/rid_manager/rid_set_gateway.py index b032a3fe9..0318e8fc7 100644 --- a/app/ldap_protocol/rid_manager/rid_set_gateway.py +++ b/app/ldap_protocol/rid_manager/rid_set_gateway.py @@ -88,7 +88,7 @@ async def add(self, domain_controller: Directory) -> Directory: async def set_allocation_attrs( self, - rid_set: Directory, + rid_set_id: int, allocation_params: RIDSetAllocationParamsDTO, ) -> None: """Set next RID attribute in RID Set directory.""" @@ -96,30 +96,30 @@ async def set_allocation_attrs( Attribute( name="rIDNextRID", value=str(allocation_params.next_rid), - directory_id=rid_set.id, + directory_id=rid_set_id, ), ) self._session.add( Attribute( name="rIDPreviousAllocationPool", value=str(allocation_params.previous_allocation_pool), - directory_id=rid_set.id, + directory_id=rid_set_id, ), ) self._session.add( Attribute( name="rIDAllocationPool", value=str(allocation_params.allocation_pool), - directory_id=rid_set.id, + directory_id=rid_set_id, ), ) - async def get_rid_allocation_pool(self, rid_set: Directory) -> int: + async def get_rid_allocation_pool(self, rid_set_id: int) -> int: """Get RID allocation pool from RID Set directory.""" allocation_pool = await self._session.scalar( select(Attribute).where( qa(Attribute.name) == "rIDAllocationPool", - qa(Attribute.directory_id) == rid_set.id, + qa(Attribute.directory_id) == rid_set_id, ), ) if not (allocation_pool and allocation_pool.value): @@ -130,14 +130,14 @@ async def get_rid_allocation_pool(self, rid_set: Directory) -> int: async def get_rid_previous_allocation_pool( self, - rid_set: Directory, + rid_set_id: int, ) -> int: """Get previous RID allocation pool from RID Set directory.""" previous_allocation_pool = await self._session.scalar( select(Attribute) .where( qa(Attribute.name) == "rIDPreviousAllocationPool", - qa(Attribute.directory_id) == rid_set.id, + qa(Attribute.directory_id) == rid_set_id, ) .with_for_update(), ) @@ -147,13 +147,13 @@ async def get_rid_previous_allocation_pool( ) return int(previous_allocation_pool.value) - async def get_rid_next_rid(self, rid_set: Directory) -> int: + async def get_rid_next_rid(self, rid_set_id: int) -> int: """Get next RID from RID Set directory.""" next_rid = await self._session.scalar( select(Attribute) .where( qa(Attribute.name) == "rIDNextRID", - qa(Attribute.directory_id) == rid_set.id, + qa(Attribute.directory_id) == rid_set_id, ) .with_for_update(), ) @@ -161,33 +161,24 @@ async def get_rid_next_rid(self, rid_set: Directory) -> int: raise RIDManagerRidNextRIDNotFoundError("next RID not found") return int(next_rid.value) - async def update_next_rid_and_pool( + async def update_next_rid( self, - rid_set: Directory, + rid_set_id: int, next_rid: int, - previous_allocation_pool: int, ) -> None: - """Update next RID and pool.""" + """Update next RID.""" await self._session.execute( update(Attribute) .where( qa(Attribute.name) == "rIDNextRID", - qa(Attribute.directory_id) == rid_set.id, + qa(Attribute.directory_id) == rid_set_id, ) .values(value=str(next_rid)), ) - await self._session.execute( - update(Attribute) - .where( - qa(Attribute.name) == "rIDPreviousAllocationPool", - qa(Attribute.directory_id) == rid_set.id, - ) - .values(value=str(previous_allocation_pool)), - ) async def reset_attrs_when_pool_exceeded( self, - rid_set: Directory, + rid_set_id: int, allocation_pool: int, previous_allocation_pool: int, next_rid: int, @@ -197,12 +188,19 @@ async def reset_attrs_when_pool_exceeded( update(Attribute) .where( qa(Attribute.name) == "rIDAllocationPool", - qa(Attribute.directory_id) == rid_set.id, + qa(Attribute.directory_id) == rid_set_id, ) .values(value=str(allocation_pool)), ) - await self.update_next_rid_and_pool( - rid_set, + await self._session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "rIDPreviousAllocationPool", + qa(Attribute.directory_id) == rid_set_id, + ) + .values(value=str(previous_allocation_pool)), + ) + await self.update_next_rid( + rid_set_id, next_rid, - previous_allocation_pool, ) diff --git a/app/ldap_protocol/rid_manager/rid_set_use_case.py b/app/ldap_protocol/rid_manager/rid_set_use_case.py index dfd60661e..b0ca28f0b 100644 --- a/app/ldap_protocol/rid_manager/rid_set_use_case.py +++ b/app/ldap_protocol/rid_manager/rid_set_use_case.py @@ -11,7 +11,8 @@ from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase from ldap_protocol.rid_manager.rid_set_gateway import RIDSetGateway -from ldap_protocol.rid_manager.utils import from_qword, to_qword +from ldap_protocol.rid_manager.utils import from_qword +from ldap_protocol.roles.role_use_case import RoleUseCase class RIDSetUseCase: @@ -23,12 +24,14 @@ def __init__( entity_type_dao: EntityTypeDAO, session: AsyncSession, rid_manager_use_case: RIDManagerUseCase, + role_use_case: RoleUseCase, ) -> None: """Initialize RID Set use case.""" self._gateway = gateway self._entity_type_dao = entity_type_dao self._session = session self._rid_manager_use_case = rid_manager_use_case + self._role_use_case = role_use_case async def get(self, domain_controller: Directory) -> Directory: """Get RID Set directory.""" @@ -47,64 +50,92 @@ async def add( ) await self._gateway.set_allocation_attrs( - rid_set, + rid_set.id, allocation_params, ) + await self.inherit_parent_aces( + domain_controller=domain_controller, + rid_set=rid_set, + ) await self._session.flush() return rid_set - async def is_pool_exceeded(self, rid_set: Directory) -> bool: + def is_pool_exceeded( + self, + current_next_rid: int, + previous_allocation_pool: int, + ) -> bool: """Check if RID pool is exceeded.""" - next_rid = await self._gateway.get_rid_next_rid(rid_set) - previous_allocation_pool = ( - await self._gateway.get_rid_previous_allocation_pool(rid_set) - ) _, upper = from_qword(previous_allocation_pool) - return next_rid + 1 >= upper + return current_next_rid + 1 > upper - async def allocate_next_rid(self, rid_set: Directory) -> int: + async def allocate_next_rid(self, rid_set_id: int) -> int: """Allocate next RID.""" async with self._session.begin_nested(): - if await self.is_pool_exceeded(rid_set): - previous_allocation_pool = ( - await self._rid_manager_use_case.allocate_pool() - ) - await self.reset_attrs_when_pool_exceeded( - rid_set, - previous_allocation_pool, - ) - current_rid = await self._gateway.get_rid_next_rid(rid_set) + current_next_rid = await self._gateway.get_rid_next_rid(rid_set_id) previous_allocation_pool = ( - await self._gateway.get_rid_previous_allocation_pool(rid_set) + await self._gateway.get_rid_previous_allocation_pool( + rid_set_id, + ) ) - _, upper = from_qword(previous_allocation_pool) - new_rid = current_rid + 1 - new_allocation_pool = to_qword(new_rid, upper) - await self._gateway.update_next_rid_and_pool( - rid_set, - new_rid, - new_allocation_pool, + + if self.is_pool_exceeded( + current_next_rid, + previous_allocation_pool, + ): + new_next_rid = await self.rebind_next_rid_from_new_pool( + rid_set_id, + ) + else: + new_next_rid = current_next_rid + 1 + + await self._gateway.update_next_rid( + rid_set_id, + new_next_rid, ) - return new_rid + return new_next_rid - async def reset_attrs_when_pool_exceeded( + async def rebind_next_rid_from_new_pool( self, - rid_set: Directory, - previous_allocation_pool: int, - ) -> None: - """Reset RID pools when pool exceeded.""" - _ = await self._gateway.get_rid_next_rid(rid_set) # lock next RID + rid_set_id: int, + ) -> int: + """Rebind next RID from new pool.""" + new_allocation_pool = await self._rid_manager_use_case.allocate_pool() - current_previous_allocation_pool = ( - await self._gateway.get_rid_previous_allocation_pool( - rid_set, - ) + current_allocation_pool = await self._gateway.get_rid_allocation_pool( + rid_set_id, ) - lower, _ = from_qword(previous_allocation_pool) + lower, _ = from_qword(current_allocation_pool) await self._gateway.reset_attrs_when_pool_exceeded( - rid_set=rid_set, + rid_set_id=rid_set_id, + next_rid=lower, + allocation_pool=new_allocation_pool, + previous_allocation_pool=current_allocation_pool, + ) + return lower + + async def generate_rid_set_attrs(self) -> RIDSetAllocationParamsDTO: + """Generate RID Set attributes.""" + previous_allocation_pool = ( + await self._rid_manager_use_case.allocate_pool() + ) + allocation_pool = await self._rid_manager_use_case.allocate_pool() + lower, _ = from_qword(previous_allocation_pool) + + return RIDSetAllocationParamsDTO( next_rid=lower, - allocation_pool=current_previous_allocation_pool, + allocation_pool=allocation_pool, previous_allocation_pool=previous_allocation_pool, ) + + async def inherit_parent_aces( + self, + domain_controller: Directory, + rid_set: Directory, + ) -> None: + """Inherit parent ACEs to RID Set directory.""" + await self._role_use_case.inherit_parent_aces( + parent_directory=domain_controller, + directory=rid_set, + ) diff --git a/app/ldap_protocol/rid_manager/setup_gateway.py b/app/ldap_protocol/rid_manager/setup_gateway.py index acac8a496..f89c325fd 100644 --- a/app/ldap_protocol/rid_manager/setup_gateway.py +++ b/app/ldap_protocol/rid_manager/setup_gateway.py @@ -9,11 +9,11 @@ from sqlalchemy import exists, select, update from sqlalchemy.ext.asyncio import AsyncSession +from constants import SYSTEM_CONTAINER_NAME from entities import Attribute, Directory from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO from ldap_protocol.rid_manager.exceptions import ( RIDManagerBaseDomainNotFoundError, - RIDManagerDomainControllerNotFoundError, RIDManagerSystemContainerNotFoundError, ) from ldap_protocol.utils.queries import get_base_directories @@ -32,24 +32,6 @@ def __init__( self._session = session self._entity_type_dao = entity_type_dao - async def get_domain_controller(self, host_machine_name: str) -> Directory: - """Get domain controller directory. - - :return: Domain controller directory - """ - dc = await self._session.scalar( - select(Directory).where( - qa(Directory.name) == host_machine_name, - ), - ) - - if not dc: - raise RIDManagerDomainControllerNotFoundError( - "Domain controller not found", - ) - - return dc - async def get_system_container(self) -> Directory: """Get System container directory. @@ -60,7 +42,7 @@ async def get_system_container(self) -> Directory: domain = base_dn_list[0] query = select(Directory).where( - qa(Directory.name) == "System", + qa(Directory.name) == SYSTEM_CONTAINER_NAME, qa(Directory.parent_id) == domain.id, ) @@ -132,20 +114,20 @@ async def set_rid_manager(self) -> Directory: async def set_rid_available_pool( self, - domain: Directory, + rid_manager_dir: Directory, qword_value: int, ) -> None: """Set rIDAvailablePool attribute in domain. Updates the global RID pool counter. - :param domain: Domain directory object + :param rid_manager_dir: RID Manager directory object :param qword_value: New QWORD value (64-bit) """ query = ( update(Attribute) .where( - qa(Attribute.directory_id) == domain.id, + qa(Attribute.directory_id) == rid_manager_dir.id, qa(Attribute.name) == "rIDAvailablePool", ) .values(value=str(qword_value)) @@ -156,7 +138,7 @@ async def set_rid_available_pool( if result.rowcount == 0: self._session.add( Attribute( - directory_id=domain.id, + directory_id=rid_manager_dir.id, name="rIDAvailablePool", value=str(qword_value), ), diff --git a/app/ldap_protocol/rid_manager/setup_use_case.py b/app/ldap_protocol/rid_manager/setup_use_case.py index 02c41ada1..ea8114c02 100644 --- a/app/ldap_protocol/rid_manager/setup_use_case.py +++ b/app/ldap_protocol/rid_manager/setup_use_case.py @@ -7,11 +7,10 @@ from config import Settings from entities import Directory -from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase from ldap_protocol.rid_manager.setup_gateway import RIDManagerSetupGateway -from ldap_protocol.rid_manager.utils import from_qword, to_qword +from ldap_protocol.rid_manager.utils import to_qword from ldap_protocol.roles.ace_dao import AccessControlEntryDAO from ldap_protocol.roles.role_use_case import RoleUseCase @@ -19,9 +18,7 @@ class RIDManagerSetupUseCase: """RID Manager setup use case.""" - RID_BUILTIN_MIN = 500 - RID_BUILTIN_MAX = 1000 - RID_USER_MIN = 1100 + RID_MIN = 1100 RID_AVAILABLE_MAX = 1073741822 # 30-bit max (2^30 - 2) def __init__( @@ -47,15 +44,17 @@ def __init__( async def setup(self) -> None: """Create RID Manager.""" - await self.create_domain_identifier() rid_manager_dir = await self._gateway.set_rid_manager() - qword = to_qword(self.RID_USER_MIN, self.RID_AVAILABLE_MAX) + qword = to_qword(self.RID_MIN, self.RID_AVAILABLE_MAX) await self._gateway.set_rid_available_pool( rid_manager_dir, qword, ) dc = await self._rid_manager_use_case.get_domain_controller() - rid_set = await self._create_rid_set(dc) + rid_set = await self._rid_set_use_case.add( + dc, + await self._rid_set_use_case.generate_rid_set_attrs(), + ) await self.inherit_aces( rid_manager_dir, @@ -63,22 +62,6 @@ async def setup(self) -> None: rid_set, ) - async def _create_rid_set(self, domain_controller: Directory) -> Directory: - previous_allocation_pool = ( - await self._rid_manager_use_case.allocate_pool() - ) - allocation_pool = await self._rid_manager_use_case.allocate_pool() - lower, _ = from_qword(previous_allocation_pool) - - return await self._rid_set_use_case.add( - domain_controller, - RIDSetAllocationParamsDTO( - next_rid=lower, - allocation_pool=allocation_pool, - previous_allocation_pool=previous_allocation_pool, - ), - ) - async def inherit_aces( self, rid_manager_dir: Directory, diff --git a/app/ldap_protocol/rootdse/reader.py b/app/ldap_protocol/rootdse/reader.py index 6ceacfe7a..6e4464416 100644 --- a/app/ldap_protocol/rootdse/reader.py +++ b/app/ldap_protocol/rootdse/reader.py @@ -102,7 +102,7 @@ async def get(self) -> DomainControllerInfo: domain = await self._gw.get_domain() dns = domain.name.lower() nb_domain = dns.split(".")[0].upper() - object_sid = await self._object_sid_use_case.get(domain) + object_sid = await self._object_sid_use_case.get(domain.id) return DomainControllerInfo( net_bios_domain=nb_domain, diff --git a/app/ldap_protocol/utils/async_cache.py b/app/ldap_protocol/utils/async_cache.py index f723440a6..ff544251f 100644 --- a/app/ldap_protocol/utils/async_cache.py +++ b/app/ldap_protocol/utils/async_cache.py @@ -43,4 +43,30 @@ async def wrapper(*args: tuple, **kwargs: dict) -> T: return wrapper +class SingleValueTTLCache(Generic[T]): + """Single cached value; refresh via ``loader`` on miss or TTL expiry.""" + + def __init__(self, ttl: int | None = DEFAULT_CACHE_TIME) -> None: + self._ttl = ttl + self._value: T | None = None + self._expires_at: float | None = None + + def clear(self) -> None: + self._value = None + self._expires_at = None + + async def get_or_load(self, loader: Callable[[], Awaitable[T]]) -> T: + """Return cached value or ``await loader()`` and store it.""" + if self._value is not None: + if not self._expires_at or self._expires_at > time.monotonic(): + return self._value + self.clear() + + result = await loader() + self._value = result + self._expires_at = time.monotonic() + self._ttl if self._ttl else None + return result + + base_directories_cache = AsyncTTLCache[list[Directory]]() +domain_identifier_cache = SingleValueTTLCache[str]() diff --git a/tests/conftest.py b/tests/conftest.py index ade487450..08ffd9914 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1151,6 +1151,7 @@ async def setup_session( entity_type_dao, session, rid_manager_use_case, + role_use_case, ) object_sid_gateway = ObjectSIDGateway(session) object_sid_use_case = ObjectSIDUseCase( @@ -1843,6 +1844,7 @@ async def rid_set_use_case( rid_manager_use_case: RIDManagerUseCase, entity_type_dao: EntityTypeDAO, rid_set_gateway: RIDSetGateway, + role_use_case: RoleUseCase, ) -> AsyncIterator[RIDSetUseCase]: """Provide RIDManagerUseCase for tests that request it explicitly.""" async with container(scope=Scope.SESSION) as container: @@ -1852,6 +1854,7 @@ async def rid_set_use_case( entity_type_dao, session, rid_manager_use_case, + role_use_case, ) diff --git a/tests/test_ldap/test_object_sid.py b/tests/test_ldap/test_object_sid.py index ac0c10eb7..c4347189a 100644 --- a/tests/test_ldap/test_object_sid.py +++ b/tests/test_ldap/test_object_sid.py @@ -41,8 +41,9 @@ async def test_next_rid( """Test RID Manager get domain controller.""" dc = await rid_manager_use_case.get_domain_controller() rid_set = await rid_set_use_case.get(dc) - next_rid = await rid_set_use_case.allocate_next_rid(rid_set) - new_next_rid = await rid_set_use_case.allocate_next_rid(rid_set) + rid_set_id = rid_set.id + next_rid = await rid_set_use_case.allocate_next_rid(rid_set_id) + new_next_rid = await rid_set_use_case.allocate_next_rid(rid_set_id) assert new_next_rid == next_rid + 1 @@ -58,36 +59,47 @@ async def test_rid_set_reset_pool( """Test RID Set pool reset.""" dc = await rid_manager_use_case.get_domain_controller() rid_set = await rid_set_use_case.get(dc) + rid_set_id = rid_set.id available_pool_before = await rid_manager_gateway.get_rid_available_pool() lower_before, _ = from_qword(available_pool_before) - await rid_set_gateway.get_rid_allocation_pool( - rid_set, - ) previous_pool_before = ( - await rid_set_gateway.get_rid_previous_allocation_pool(rid_set) + await rid_set_gateway.get_rid_previous_allocation_pool(rid_set_id) ) _, upper = from_qword(previous_pool_before) - await rid_set_gateway.update_next_rid_and_pool( - rid_set, - next_rid=upper - 1, - previous_allocation_pool=previous_pool_before, - ) + await rid_set_gateway.update_next_rid(rid_set_id, upper - 1) - assert await rid_set_use_case.is_pool_exceeded(rid_set) is True + current_next_rid = await rid_set_gateway.get_rid_next_rid(rid_set_id) + assert ( + rid_set_use_case.is_pool_exceeded( + current_next_rid, + previous_pool_before, + ) + is True + ) - next_rid = await rid_set_use_case.allocate_next_rid(rid_set) - assert await rid_set_use_case.is_pool_exceeded(rid_set) is False + next_rid = await rid_set_use_case.allocate_next_rid(rid_set_id) + current_next_rid = await rid_set_gateway.get_rid_next_rid(rid_set_id) + previous_pool_mid = await rid_set_gateway.get_rid_previous_allocation_pool( + rid_set_id, + ) + assert ( + rid_set_use_case.is_pool_exceeded( + current_next_rid, + previous_pool_mid, + ) + is False + ) available_pool_after = await rid_manager_gateway.get_rid_available_pool() lower_after, _ = from_qword(available_pool_after) allocation_pool_after = await rid_set_gateway.get_rid_allocation_pool( - rid_set, + rid_set_id, ) previous_pool_after = ( await rid_set_gateway.get_rid_previous_allocation_pool( - rid_set, + rid_set_id, ) ) @@ -112,26 +124,28 @@ async def test_object_sid_add_updates_next_rid_and_prefix( ) -> None: dc = await rid_manager_use_case.get_domain_controller() rid_set = await rid_set_use_case.get(dc) + rid_set_id = rid_set.id + dc_id = dc.id - next_before = await rid_set_gateway.get_rid_next_rid(rid_set) + next_before = await rid_set_gateway.get_rid_next_rid(rid_set_id) - await object_sid_use_case.add(dc) + await object_sid_use_case.add(directory_id=dc_id) await session.flush() - next_after = await rid_set_gateway.get_rid_next_rid(rid_set) + next_after = await rid_set_gateway.get_rid_next_rid(rid_set_id) assert next_after == next_before + 1 - sid_domain_attr = await object_sid_gateway.get(dc) + sid_domain_attr = await object_sid_gateway.get(dc_id) assert sid_domain_attr.startswith("S-1-5-21-") await object_sid_use_case.add( - rid_set, + directory_id=rid_set_id, rid=512, sid_prefix=SidPrefix.BUILT_IN_DOMAIN, ) await session.flush() - next_after_builtin = await rid_set_gateway.get_rid_next_rid(rid_set) + next_after_builtin = await rid_set_gateway.get_rid_next_rid(rid_set_id) assert next_after_builtin == next_after - sid_builtin_attr = await object_sid_gateway.get(rid_set) + sid_builtin_attr = await object_sid_gateway.get(rid_set_id) assert sid_builtin_attr.startswith("S-1-5-32-") assert sid_builtin_attr != sid_domain_attr diff --git a/tests/test_shedule.py b/tests/test_shedule.py index 52cb2da6b..d6d9fbba3 100644 --- a/tests/test_shedule.py +++ b/tests/test_shedule.py @@ -19,6 +19,7 @@ ) from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase +from ldap_protocol.rid_manager.rid_set_use_case import RIDSetUseCase from ldap_protocol.roles.role_use_case import RoleUseCase @@ -92,6 +93,7 @@ async def test_add_domain_controller( entity_type_use_case: EntityTypeUseCase, object_sid_use_case: ObjectSIDUseCase, rid_manager_use_case: RIDManagerUseCase, + rid_set_use_case: RIDSetUseCase, monkeypatch: pytest.MonkeyPatch, ) -> None: """Test add domain controller.""" @@ -116,4 +118,5 @@ async def _get_existing_dc() -> object: role_use_case=role_use_case, entity_type_use_case=entity_type_use_case, object_sid_use_case=object_sid_use_case, + rid_set_use_case=rid_set_use_case, ) From 8962316315073ec9fc1558fa1b6c638edd6f4ded Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 7 Apr 2026 11:43:22 +0300 Subject: [PATCH 24/37] Refactor: Update entity type handling in RIDManager and RIDSet use cases to improve code organization and maintainability --- .../552b4eafb1aa_remove_objectsid_vals.py | 8 +-- app/ldap_protocol/ldap_requests/add.py | 4 +- .../ldap_schema/directory_create_use_case.py | 10 +++- .../rid_manager/rid_set_use_case.py | 11 ++-- .../rid_manager/setup_gateway.py | 11 ++-- tests/conftest.py | 54 +++++++++---------- 6 files changed, 57 insertions(+), 41 deletions(-) diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index 29443e7bd..a834f1578 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -1,7 +1,7 @@ """Add rIDManager and rIDSet objectClasses to LDAP schema. Revision ID: 552b4eafb1aa -Revises: df4287898910 +Revises: 1b71cafba681 Create Date: 2026-02-17 09:24:57.906080 """ @@ -17,7 +17,9 @@ from entities import Attribute, Directory, EntityType from enums import EntityTypeNames from ldap_protocol.ldap_schema.dto import EntityTypeDTO -from ldap_protocol.ldap_schema.entity_type_use_case import EntityTypeUseCase +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) from ldap_protocol.rid_manager import ( RIDManagerGateway, RIDManagerSetupGateway, @@ -38,7 +40,7 @@ # revision identifiers, used by Alembic. revision: None | str = "552b4eafb1aa" -down_revision: None | str = "df4287898910" +down_revision: None | str = "1b71cafba681" branch_labels: None | list[str] = None depends_on: None | list[str] = None diff --git a/app/ldap_protocol/ldap_requests/add.py b/app/ldap_protocol/ldap_requests/add.py index c5e08f943..ff625d2a5 100644 --- a/app/ldap_protocol/ldap_requests/add.py +++ b/app/ldap_protocol/ldap_requests/add.py @@ -104,9 +104,9 @@ def from_data(cls, data: ASN1Row) -> "AddRequest": type=attr.value[0].value, vals=[val.value for val in attr.value[1].value], ) - for attr in attributes.value + for attr in attributes.value # type: ignore ] - return cls(entry=entry.value, attributes=attributes) + return cls(entry=entry.value, attributes=attributes) # type: ignore async def handle( # noqa: C901 self, diff --git a/app/ldap_protocol/ldap_schema/directory_create_use_case.py b/app/ldap_protocol/ldap_schema/directory_create_use_case.py index 9e2e4a033..8d067fb47 100644 --- a/app/ldap_protocol/ldap_schema/directory_create_use_case.py +++ b/app/ldap_protocol/ldap_schema/directory_create_use_case.py @@ -8,6 +8,7 @@ from sqlalchemy.ext.asyncio import AsyncSession +from enums import SidPrefix from ldap_protocol.ldap_schema.attribute_dao import AttributeDAO from ldap_protocol.ldap_schema.directory_dao import DirectoryDAO from ldap_protocol.ldap_schema.dto import AttributeDTO, DirCreateDTO @@ -15,6 +16,7 @@ EntityTypeUseCase, ) from ldap_protocol.ldap_schema.exceptions import CantCreateDirectoryError +from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.roles.role_use_case import RoleUseCase if TYPE_CHECKING: @@ -46,6 +48,7 @@ def __init__( role_use_case: RoleUseCase, directory_dao: DirectoryDAO, attribute_dao: AttributeDAO, + object_sid_use_case: ObjectSIDUseCase, ) -> None: """Initialize.""" self.__session = session @@ -53,6 +56,7 @@ def __init__( self.__role_use_case = role_use_case self.__directory_dao = directory_dao self.__attribute_dao = attribute_dao + self.__object_sid_use_case = object_sid_use_case async def get_configuration_dir(self) -> "Directory": """Get configuration directory.""" @@ -85,7 +89,11 @@ async def create_dir( else: raise CantCreateDirectoryError("Cannot create a directory.") - dir_.object_sid = _get_object_sid(base_dn_sid, dir_.id) + await self.__object_sid_use_case.add( + directory_id=dir_.id, + rid=int(base_dn_sid), + sid_prefix=SidPrefix.BUILT_IN_DOMAIN, + ) attr_dto = AttributeDTO(name=dir_.rdname, values=[dir_.name]) await self.__attribute_dao.add_directory_name_attribute( diff --git a/app/ldap_protocol/rid_manager/rid_set_use_case.py b/app/ldap_protocol/rid_manager/rid_set_use_case.py index b0ca28f0b..9c714c2ee 100644 --- a/app/ldap_protocol/rid_manager/rid_set_use_case.py +++ b/app/ldap_protocol/rid_manager/rid_set_use_case.py @@ -7,7 +7,9 @@ from sqlalchemy.ext.asyncio import AsyncSession from entities import Directory -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO from ldap_protocol.rid_manager.rid_manager_use_case import RIDManagerUseCase from ldap_protocol.rid_manager.rid_set_gateway import RIDSetGateway @@ -21,14 +23,14 @@ class RIDSetUseCase: def __init__( self, gateway: RIDSetGateway, - entity_type_dao: EntityTypeDAO, + entity_type_use_case: EntityTypeUseCase, session: AsyncSession, rid_manager_use_case: RIDManagerUseCase, role_use_case: RoleUseCase, ) -> None: """Initialize RID Set use case.""" self._gateway = gateway - self._entity_type_dao = entity_type_dao + self._entity_type_use_case = entity_type_use_case self._session = session self._rid_manager_use_case = rid_manager_use_case self._role_use_case = role_use_case @@ -44,9 +46,10 @@ async def add( ) -> Directory: """Create RID Set directory.""" rid_set = await self._gateway.add(domain_controller) - await self._entity_type_dao.attach_entity_type_to_directory( + await self._entity_type_use_case.attach_entity_type_to_directory( directory=rid_set, is_system_entity_type=True, + object_class_names={"top", "rIDSet"}, ) await self._gateway.set_allocation_attrs( diff --git a/app/ldap_protocol/rid_manager/setup_gateway.py b/app/ldap_protocol/rid_manager/setup_gateway.py index f89c325fd..4ba027635 100644 --- a/app/ldap_protocol/rid_manager/setup_gateway.py +++ b/app/ldap_protocol/rid_manager/setup_gateway.py @@ -11,7 +11,9 @@ from constants import SYSTEM_CONTAINER_NAME from entities import Attribute, Directory -from ldap_protocol.ldap_schema.entity_type_dao import EntityTypeDAO +from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( + EntityTypeUseCase, +) from ldap_protocol.rid_manager.exceptions import ( RIDManagerBaseDomainNotFoundError, RIDManagerSystemContainerNotFoundError, @@ -26,11 +28,11 @@ class RIDManagerSetupGateway: def __init__( self, session: AsyncSession, - entity_type_dao: EntityTypeDAO, + entity_type_use_case: EntityTypeUseCase, ) -> None: """Initialize RID Manager setup gateway.""" self._session = session - self._entity_type_dao = entity_type_dao + self._entity_type_use_case = entity_type_use_case async def get_system_container(self) -> Directory: """Get System container directory. @@ -103,9 +105,10 @@ async def set_rid_manager(self) -> Directory: with_for_update=None, ) - await self._entity_type_dao.attach_entity_type_to_directory( + await self._entity_type_use_case.attach_entity_type_to_directory( directory=rid_manager_dir, is_system_entity_type=True, + object_class_names={"top", "rIDManager"}, ) await self._session.flush() diff --git a/tests/conftest.py b/tests/conftest.py index 08ffd9914..4cbbb935c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1086,12 +1086,35 @@ async def setup_session( object_class_dao=object_class_dao, directory_dao=directory_dao, ) + rid_manager_gateway = RIDManagerGateway(session, settings) + + rid_manager_use_case = RIDManagerUseCase( + rid_manager_gateway, + session, + ) + rid_set_gateway = RIDSetGateway(session) + + rid_set_use_case = RIDSetUseCase( + rid_set_gateway, + entity_type_use_case, + session, + rid_manager_use_case, + role_use_case, + ) + object_sid_gateway = ObjectSIDGateway(session) + object_sid_use_case = ObjectSIDUseCase( + object_sid_gateway, + rid_set_use_case, + session, + rid_manager_use_case, + ) directory_create_use_case = DirectoryCreateUseCase( session=session, entity_type_use_case=entity_type_use_case, role_use_case=role_use_case, directory_dao=directory_dao, attribute_dao=attribute_dao, + object_sid_use_case=object_sid_use_case, ) object_class_use_case = ObjectClassUseCase( attribute_type_dao=attribute_type_dao, @@ -1128,38 +1151,15 @@ async def setup_session( password_policy_validator, password_ban_word_repository, ) - rid_manager_gateway = RIDManagerGateway(session, settings) - rid_manager_use_case = RIDManagerUseCase( - rid_manager_gateway, - session, - ) + rid_manager_setup_gateway = RIDManagerSetupGateway( session=session, - entity_type_dao=entity_type_dao, + entity_type_use_case=entity_type_use_case, ) role_dao = RoleDAO(session) ace_dao = AccessControlEntryDAO(session) role_use_case = RoleUseCase(role_dao, ace_dao) - rid_manager_use_case = RIDManagerUseCase( - rid_manager_gateway, - session, - ) - rid_set_gateway = RIDSetGateway(session) - rid_set_use_case = RIDSetUseCase( - rid_set_gateway, - entity_type_dao, - session, - rid_manager_use_case, - role_use_case, - ) - object_sid_gateway = ObjectSIDGateway(session) - object_sid_use_case = ObjectSIDUseCase( - object_sid_gateway, - rid_set_use_case, - session, - rid_manager_use_case, - ) rid_manager_setup_use_case = RIDManagerSetupUseCase( rid_manager_setup_gateway=rid_manager_setup_gateway, role_use_case=role_use_case, @@ -1842,7 +1842,7 @@ async def rid_set_gateway( async def rid_set_use_case( container: AsyncContainer, rid_manager_use_case: RIDManagerUseCase, - entity_type_dao: EntityTypeDAO, + entity_type_use_case: EntityTypeUseCase, rid_set_gateway: RIDSetGateway, role_use_case: RoleUseCase, ) -> AsyncIterator[RIDSetUseCase]: @@ -1851,7 +1851,7 @@ async def rid_set_use_case( session = await container.get(AsyncSession) yield RIDSetUseCase( rid_set_gateway, - entity_type_dao, + entity_type_use_case, session, rid_manager_use_case, role_use_case, From 1c4516be96fab570a6094cf7740bed70b6e332d9 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 7 Apr 2026 13:22:40 +0300 Subject: [PATCH 25/37] Update: Add entity_type_name for SYSTEM_CONTAINER_NAME in constants and clean up unused methods in directory_create_use_case and directory_dao for improved clarity and maintainability --- app/constants.py | 1 + .../ldap_schema/directory_create_use_case.py | 19 ---- .../ldap_schema/directory_dao.py | 8 -- .../rid_manager/setup_use_case.py | 11 +-- tests/conftest.py | 90 ++++++++++--------- tests/constants.py | 1 + tests/test_ldap/test_object_sid.py | 13 +-- 7 files changed, 58 insertions(+), 85 deletions(-) diff --git a/app/constants.py b/app/constants.py index c33f54711..5bece61a8 100644 --- a/app/constants.py +++ b/app/constants.py @@ -327,6 +327,7 @@ }, { "name": SYSTEM_CONTAINER_NAME, + "entity_type_name": EntityTypeNames.ORGANIZATIONAL_UNIT, "object_class": "organizationalUnit", "attributes": { "objectClass": ["top", "container"], diff --git a/app/ldap_protocol/ldap_schema/directory_create_use_case.py b/app/ldap_protocol/ldap_schema/directory_create_use_case.py index 8d067fb47..6cf5f3d90 100644 --- a/app/ldap_protocol/ldap_schema/directory_create_use_case.py +++ b/app/ldap_protocol/ldap_schema/directory_create_use_case.py @@ -8,14 +8,12 @@ from sqlalchemy.ext.asyncio import AsyncSession -from enums import SidPrefix from ldap_protocol.ldap_schema.attribute_dao import AttributeDAO from ldap_protocol.ldap_schema.directory_dao import DirectoryDAO from ldap_protocol.ldap_schema.dto import AttributeDTO, DirCreateDTO from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) -from ldap_protocol.ldap_schema.exceptions import CantCreateDirectoryError from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.roles.role_use_case import RoleUseCase @@ -72,29 +70,12 @@ async def create_dir( parent_dir: "Directory", ) -> None: """Create.""" - base_directory_paths_and_sids = ( - await self.__directory_dao.get_base_directory_paths_with_sid() - ) - dir_ = await self.__directory_dao.create_directory( name=dto.name, is_system=dto.is_system, parent_dir=parent_dir, ) - for _path, _sid in base_directory_paths_and_sids: - if _is_dn_in_base_directory(_path, dir_.path_dn): - base_dn_sid = _sid - break - else: - raise CantCreateDirectoryError("Cannot create a directory.") - - await self.__object_sid_use_case.add( - directory_id=dir_.id, - rid=int(base_dn_sid), - sid_prefix=SidPrefix.BUILT_IN_DOMAIN, - ) - attr_dto = AttributeDTO(name=dir_.rdname, values=[dir_.name]) await self.__attribute_dao.add_directory_name_attribute( dir_.id, diff --git a/app/ldap_protocol/ldap_schema/directory_dao.py b/app/ldap_protocol/ldap_schema/directory_dao.py index 095da77ae..ec0f72e13 100644 --- a/app/ldap_protocol/ldap_schema/directory_dao.py +++ b/app/ldap_protocol/ldap_schema/directory_dao.py @@ -10,7 +10,6 @@ from constants import CONFIGURATION_DIR_NAME from entities import Directory, EntityType -from ldap_protocol.utils.queries import get_base_directories from repo.pg.tables import queryable_attr as qa @@ -64,13 +63,6 @@ async def get_all_without_entity_type(self) -> list[Directory]: ) return list(result.all()) - async def get_base_directory_paths_with_sid(self) -> list[tuple[str, str]]: - """Get all base directory paths.""" - base_dirs = await get_base_directories(self.__session) - return [ - (base_dir.path_dn, base_dir.object_sid) for base_dir in base_dirs - ] - async def get_configuration_dir(self) -> Directory: """Get configuration directory.""" result = await self.__session.execute( diff --git a/app/ldap_protocol/rid_manager/setup_use_case.py b/app/ldap_protocol/rid_manager/setup_use_case.py index ea8114c02..203139515 100644 --- a/app/ldap_protocol/rid_manager/setup_use_case.py +++ b/app/ldap_protocol/rid_manager/setup_use_case.py @@ -51,22 +51,18 @@ async def setup(self) -> None: qword, ) dc = await self._rid_manager_use_case.get_domain_controller() - rid_set = await self._rid_set_use_case.add( + await self._rid_set_use_case.add( dc, await self._rid_set_use_case.generate_rid_set_attrs(), ) await self.inherit_aces( rid_manager_dir, - dc, - rid_set, ) async def inherit_aces( self, rid_manager_dir: Directory, - domain_controller: Directory, - rid_set: Directory, ) -> None: """Inherit ACEs from domain root to RID Manager directory. @@ -81,11 +77,6 @@ async def inherit_aces( directory=rid_manager_dir, ) - await self._role_use_case.inherit_parent_aces( - parent_directory=domain_controller, - directory=rid_set, - ) - async def create_domain_identifier(self) -> None: """Create domain identifier.""" await self._gateway.create_domain_identifier() diff --git a/tests/conftest.py b/tests/conftest.py index 4cbbb935c..4ebd7da87 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1178,8 +1178,7 @@ async def setup_session( ) for entity_type_dto in chain(ENTITY_TYPE_DTOS_V1, ENTITY_TYPE_DTOS_V2): await entity_type_use_case.create_not_safe(entity_type_dto) - await session.flush() - await audit_use_case.create_policies() + domain = await setup_gateway.create_base_domain("md.test") await rid_manager_setup_use_case.create_domain_identifier() @@ -1188,6 +1187,52 @@ async def setup_session( data=TEST_DATA, is_system=False, ) + + for attr_type_name in ( + "description", + "posixEmail", + "userPrincipalName", + "userAccountControl", + "cn", + "objectClass", + ): + _at = await attribute_type_use_case_legacy.get( + attr_type_name, + ) + if not _at: + raise ValueError( + f"setup_session:: AttributeType {attr_type_name} not found", + ) + await attribute_type_use_case.create(_at) + + for _obj_class_name in ( + "top", + "person", + "organizationalPerson", + "user", + "domain", + "container", + "organization", + "domainDNS", + "group", + "inetOrgPerson", + "posixAccount", + "rIDManager", + "rIDSet", + ): + _oc_dto = await object_class_use_case_legacy.get(_obj_class_name) + _oc_dto.attribute_types_may = [ + _.name # type: ignore + for _ in _oc_dto.attribute_types_may + ] + _oc_dto.attribute_types_must = [ + _.name # type: ignore + for _ in _oc_dto.attribute_types_must + ] + await object_class_use_case.create(_oc_dto) # type: ignore + + await session.flush() + dc_directory = Directory( name=DOMAIN_CONTROLLERS_OU_NAME, object_class="computer", @@ -1236,47 +1281,6 @@ async def setup_session( ): await attribute_type_use_case.create(_at_dto) - for attr_type_name in ( - "description", - "posixEmail", - "userPrincipalName", - "userAccountControl", - "cn", - "objectClass", - ): - _at = await attribute_type_use_case_legacy.get( - attr_type_name, - ) - if not _at: - raise ValueError( - f"setup_session:: AttributeType {attr_type_name} not found", - ) - await attribute_type_use_case.create(_at) - - for _obj_class_name in ( - "top", - "person", - "organizationalPerson", - "user", - "domain", - "container", - "organization", - "domainDNS", - "group", - "inetOrgPerson", - "posixAccount", - ): - _oc_dto = await object_class_use_case_legacy.get(_obj_class_name) - _oc_dto.attribute_types_may = [ - _.name # type: ignore - for _ in _oc_dto.attribute_types_may - ] - _oc_dto.attribute_types_must = [ - _.name # type: ignore - for _ in _oc_dto.attribute_types_must - ] - await object_class_use_case.create(_oc_dto) # type: ignore - await audit_use_case.create_policies() # NOTE: after setup environment we need base DN to be created diff --git a/tests/constants.py b/tests/constants.py index b570415c6..50345f139 100644 --- a/tests/constants.py +++ b/tests/constants.py @@ -469,6 +469,7 @@ }, { "name": SYSTEM_CONTAINER_NAME, + "entity_type_name": EntityTypeNames.ORGANIZATIONAL_UNIT, "object_class": "organizationalUnit", "attributes": { "objectClass": ["top", "container"], diff --git a/tests/test_ldap/test_object_sid.py b/tests/test_ldap/test_object_sid.py index c4347189a..95fef4a26 100644 --- a/tests/test_ldap/test_object_sid.py +++ b/tests/test_ldap/test_object_sid.py @@ -63,12 +63,15 @@ async def test_rid_set_reset_pool( available_pool_before = await rid_manager_gateway.get_rid_available_pool() lower_before, _ = from_qword(available_pool_before) + allocation_pool_before = await rid_set_gateway.get_rid_allocation_pool( + rid_set_id, + ) previous_pool_before = ( await rid_set_gateway.get_rid_previous_allocation_pool(rid_set_id) ) _, upper = from_qword(previous_pool_before) - await rid_set_gateway.update_next_rid(rid_set_id, upper - 1) + await rid_set_gateway.update_next_rid(rid_set_id, upper) current_next_rid = await rid_set_gateway.get_rid_next_rid(rid_set_id) assert ( @@ -79,7 +82,7 @@ async def test_rid_set_reset_pool( is True ) - next_rid = await rid_set_use_case.allocate_next_rid(rid_set_id) + await rid_set_use_case.allocate_next_rid(rid_set_id) current_next_rid = await rid_set_gateway.get_rid_next_rid(rid_set_id) previous_pool_mid = await rid_set_gateway.get_rid_previous_allocation_pool( rid_set_id, @@ -104,11 +107,11 @@ async def test_rid_set_reset_pool( ) assert lower_after == lower_before + RIDManagerUseCase.RID_BLOCK_SIZE - assert previous_pool_after == to_qword( - next_rid, + assert allocation_pool_after == to_qword( + lower_before, lower_before + RIDManagerUseCase.RID_BLOCK_SIZE, ) - assert allocation_pool_after == previous_pool_before + assert previous_pool_after == allocation_pool_before @pytest.mark.asyncio From a6ebd62a315bc37b8d609fa37a6ce8994e8defc4 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Tue, 7 Apr 2026 13:39:10 +0300 Subject: [PATCH 26/37] Enhance: Refactor setup_session to streamline domain controller and directory creation, improving clarity and maintainability --- tests/conftest.py | 44 ++++++++++++++++++++++---------------------- 1 file changed, 22 insertions(+), 22 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 4ebd7da87..60221a93b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1187,6 +1187,27 @@ async def setup_session( data=TEST_DATA, is_system=False, ) + dc_directory = Directory( + name=DOMAIN_CONTROLLERS_OU_NAME, + object_class="computer", + is_system=True, + ) + dc_directory.create_path(domain, "cn") + session.add(dc_directory) + await session.flush() + dc_directory.parent_id = domain.id + await session.refresh(dc_directory, ["id"]) + await session.flush() + dc = Directory( + name=settings.HOST_MACHINE_SHORT_NAME, + is_system=True, + ) + + dc.create_path(dc_directory, "cn") + session.add(dc) + await session.flush() + dc.parent_id = dc_directory.id + await session.refresh(dc, ["id"]) for attr_type_name in ( "description", @@ -1211,6 +1232,7 @@ async def setup_session( "organizationalPerson", "user", "domain", + "computer", "container", "organization", "domainDNS", @@ -1233,28 +1255,6 @@ async def setup_session( await session.flush() - dc_directory = Directory( - name=DOMAIN_CONTROLLERS_OU_NAME, - object_class="computer", - is_system=True, - ) - dc_directory.create_path(domain, "cn") - session.add(dc_directory) - await session.flush() - dc_directory.parent_id = domain.id - await session.refresh(dc_directory, ["id"]) - await session.flush() - dc = Directory( - name=settings.HOST_MACHINE_SHORT_NAME, - object_class="computer", - is_system=True, - ) - dc.create_path(dc_directory, "cn") - session.add(dc) - await session.flush() - dc.parent_id = dc_directory.id - await session.refresh(dc, ["id"]) - for _at_dto in ( AttributeTypeDTO[None]( oid="1.2.3.4.5.6.7.8", From f18a27d4ad859a73fbc5c3f7d881a85a8c8af673 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Thu, 9 Apr 2026 15:08:20 +0300 Subject: [PATCH 27/37] Refactor: Streamline object SID retrieval and enhance RID Set handling by introducing caching and updating method signatures for improved clarity and performance --- .../01f3f05a5b11_add_primary_group_id.py | 6 +-- app/entities.py | 8 +--- app/ldap_protocol/auth/use_cases.py | 4 +- app/ldap_protocol/ldap_requests/modify.py | 3 ++ app/ldap_protocol/ldap_requests/search.py | 2 - .../ldap_schema/directory_create_use_case.py | 3 -- .../rid_manager/object_sid_gateway.py | 13 +++-- .../rid_manager/object_sid_use_case.py | 9 +--- .../rid_manager/rid_set_gateway.py | 48 ++++++++++++++++++- .../rid_manager/rid_set_use_case.py | 18 ++++--- .../rid_manager/setup_gateway.py | 40 ++++------------ .../rid_manager/setup_use_case.py | 4 +- app/ldap_protocol/utils/async_cache.py | 28 +---------- app/ldap_protocol/utils/queries.py | 9 +++- tests/conftest.py | 1 - tests/test_ldap/test_object_sid.py | 13 ++--- 16 files changed, 103 insertions(+), 106 deletions(-) diff --git a/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py b/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py index 947bbba5c..a9e38f219 100644 --- a/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py +++ b/app/alembic/versions/01f3f05a5b11_add_primary_group_id.py @@ -148,9 +148,9 @@ async def _add_primary_group_id(connection: AsyncConnection) -> None: # noqa: A query = ( select(Directory) .options( - selectinload(qa(Directory.groups)).selectinload( - qa(Group.directory), - ), + selectinload(qa(Directory.groups)) + .selectinload(qa(Group.directory)) + .selectinload(qa(Directory.attributes)), ) .where( qa(Directory.entity_type_id).in_(entity_type_ids), diff --git a/app/entities.py b/app/entities.py index 64605c470..dee8fea63 100644 --- a/app/entities.py +++ b/app/entities.py @@ -185,13 +185,7 @@ def create_path( @property def object_sid(self) -> str: """Get objectSid attribute value.""" - attrs = self.__dict__.get("attributes") - if not attrs: - return "" - for attr in attrs: - if attr.name and attr.name.lower() == "objectsid" and attr.value: - return attr.value - return "" + return self.attributes_dict.get("objectSid", [""])[0] @property def relative_id(self) -> str: diff --git a/app/ldap_protocol/auth/use_cases.py b/app/ldap_protocol/auth/use_cases.py index 564ab6bea..39abddd9d 100644 --- a/app/ldap_protocol/auth/use_cases.py +++ b/app/ldap_protocol/auth/use_cases.py @@ -211,7 +211,9 @@ async def _create(self, dto: SetupDTO, data: list) -> None: """ try: domain = await self._setup_gateway.create_base_domain(dto.domain) - await self._rid_manager_setup_use_case.create_domain_identifier() + await self._rid_manager_setup_use_case.create_domain_identifier( + domain.id, + ) await self._setup_gateway.setup_enviroment( data=data, is_system=True, diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index f25ab9996..c6f88c327 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -363,6 +363,9 @@ def _get_dir_query(self) -> Select[tuple[Directory]]: selectinload(qa(Directory.groups)).joinedload( qa(Group.directory), ), + selectinload(qa(Directory.groups)) + .joinedload(qa(Group.directory)) + .selectinload(qa(Directory.attributes)), joinedload(qa(Directory.group)).selectinload( qa(Group.members), ), diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index afa7c4614..fa36b9925 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -380,8 +380,6 @@ def _mutate_query_with_attributes_to_load( func.lower(Attribute.name).in_(attrs), func.lower(Attribute.name) == "objectclass", ] - if self.is_sid_requested: - cond_parts.append(func.lower(Attribute.name) == "objectsid") cond = or_(*cond_parts) diff --git a/app/ldap_protocol/ldap_schema/directory_create_use_case.py b/app/ldap_protocol/ldap_schema/directory_create_use_case.py index 6cf5f3d90..5477c530e 100644 --- a/app/ldap_protocol/ldap_schema/directory_create_use_case.py +++ b/app/ldap_protocol/ldap_schema/directory_create_use_case.py @@ -14,7 +14,6 @@ from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, ) -from ldap_protocol.rid_manager import ObjectSIDUseCase from ldap_protocol.roles.role_use_case import RoleUseCase if TYPE_CHECKING: @@ -46,7 +45,6 @@ def __init__( role_use_case: RoleUseCase, directory_dao: DirectoryDAO, attribute_dao: AttributeDAO, - object_sid_use_case: ObjectSIDUseCase, ) -> None: """Initialize.""" self.__session = session @@ -54,7 +52,6 @@ def __init__( self.__role_use_case = role_use_case self.__directory_dao = directory_dao self.__attribute_dao = attribute_dao - self.__object_sid_use_case = object_sid_use_case async def get_configuration_dir(self) -> "Directory": """Get configuration directory.""" diff --git a/app/ldap_protocol/rid_manager/object_sid_gateway.py b/app/ldap_protocol/rid_manager/object_sid_gateway.py index 83e655e7f..f63cd6cf9 100644 --- a/app/ldap_protocol/rid_manager/object_sid_gateway.py +++ b/app/ldap_protocol/rid_manager/object_sid_gateway.py @@ -7,7 +7,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from entities import Attribute +from entities import Attribute, Directory from ldap_protocol.rid_manager.exceptions import ( RIDManagerDomainIdentifierNotFoundError, RIDManagerObjectSIDNotFoundError, @@ -46,16 +46,21 @@ async def add(self, directory_id: int, object_sid: str) -> None: ), ) + @domain_identifier_cache async def get_domain_identifier(self) -> str: """Get domain identifier (cached ``Attribute.value`` string).""" - return await domain_identifier_cache.get_or_load( - self._load_domain_identifier_value, - ) + return await self._load_domain_identifier_value() async def _load_domain_identifier_value(self) -> str: query = await self._session.scalar( select(Attribute).where( qa(Attribute.name) == "DomainIdentifier", + select(Directory) + .where( + qa(Directory.id) == qa(Attribute.directory_id), + qa(Directory.parent_id).is_(None), + ) + .exists(), ), ) diff --git a/app/ldap_protocol/rid_manager/object_sid_use_case.py b/app/ldap_protocol/rid_manager/object_sid_use_case.py index a67c17ae6..d46a4b9cb 100644 --- a/app/ldap_protocol/rid_manager/object_sid_use_case.py +++ b/app/ldap_protocol/rid_manager/object_sid_use_case.py @@ -40,13 +40,8 @@ async def add( ) -> None: """Add object SID.""" if rid is None: - domain_controller = ( - await self._rid_manager_use_case.get_domain_controller() - ) - rid_set = await self._rid_set_use_case.get(domain_controller) - rid = await self._rid_set_use_case.allocate_next_rid( - rid_set.id, - ) + rid_set_id = await self._rid_set_use_case.get_rid_set_id() + rid = await self._rid_set_use_case.allocate_next_rid(rid_set_id) if sid_prefix == SidPrefix.BUILT_IN_DOMAIN: object_sid = f"{sid_prefix}-{rid}" diff --git a/app/ldap_protocol/rid_manager/rid_set_gateway.py b/app/ldap_protocol/rid_manager/rid_set_gateway.py index 0318e8fc7..29b70985c 100644 --- a/app/ldap_protocol/rid_manager/rid_set_gateway.py +++ b/app/ldap_protocol/rid_manager/rid_set_gateway.py @@ -6,7 +6,10 @@ from sqlalchemy import select, update from sqlalchemy.ext.asyncio import AsyncSession +from sqlalchemy.orm import aliased +from config import Settings +from constants import DOMAIN_CONTROLLERS_OU_NAME from entities import Attribute, Directory from ldap_protocol.rid_manager.dtos import RIDSetAllocationParamsDTO from ldap_protocol.rid_manager.exceptions import ( @@ -15,15 +18,58 @@ RIDManagerRidPreviousAllocationPoolNotFoundError, RIDManagerRidSetNotFoundError, ) +from ldap_protocol.utils.async_cache import rid_set_id_cache from repo.pg.tables import queryable_attr as qa class RIDSetGateway: """RID Set gateway.""" - def __init__(self, session: AsyncSession) -> None: + def __init__(self, session: AsyncSession, settings: Settings) -> None: """Initialize RID Set gateway.""" self._session = session + self._settings = settings + + @rid_set_id_cache + async def get_rid_set_id(self) -> int: + """Get RID Set ID.""" + return await self.get_rid_set_id_value() + + async def get_rid_set_id_value(self) -> int: + """Get RID Set ID.""" + domain = aliased(Directory) + domain_controllers_ou = aliased(Directory) + domain_controller = aliased(Directory) + rid_set = aliased(Directory) + + rid_set_id = await self._session.scalar( + select(qa(rid_set.id)) + .select_from(domain) + .join( + domain_controllers_ou, + qa(domain_controllers_ou.parent_id) == qa(domain.id), + ) + .join( + domain_controller, + qa(domain_controller.parent_id) + == qa(domain_controllers_ou.id), + ) + .join( + rid_set, + qa(rid_set.parent_id) == qa(domain_controller.id), + ) + .where( + qa(domain.object_class) == "domain", + qa(domain.parent_id).is_(None), + qa(domain_controllers_ou.name) == DOMAIN_CONTROLLERS_OU_NAME, + qa(domain_controller.name) + == self._settings.HOST_MACHINE_SHORT_NAME, + qa(rid_set.name) == "RID Set", + ), + ) + if rid_set_id is None: + raise RIDManagerRidSetNotFoundError("RID Set directory not found") + return int(rid_set_id) async def get(self, domain_controller: Directory) -> Directory: """Get RID Set directory.""" diff --git a/app/ldap_protocol/rid_manager/rid_set_use_case.py b/app/ldap_protocol/rid_manager/rid_set_use_case.py index 9c714c2ee..26d6767cb 100644 --- a/app/ldap_protocol/rid_manager/rid_set_use_case.py +++ b/app/ldap_protocol/rid_manager/rid_set_use_case.py @@ -39,6 +39,10 @@ async def get(self, domain_controller: Directory) -> Directory: """Get RID Set directory.""" return await self._gateway.get(domain_controller) + async def get_rid_set_id(self) -> int: + """Get RID Set ID.""" + return await self._gateway.get_rid_set_id() + async def add( self, domain_controller: Directory, @@ -83,20 +87,20 @@ async def allocate_next_rid(self, rid_set_id: int) -> int: ) ) - if self.is_pool_exceeded( + if not self.is_pool_exceeded( current_next_rid, previous_allocation_pool, ): - new_next_rid = await self.rebind_next_rid_from_new_pool( + new_next_rid = current_next_rid + 1 + await self._gateway.update_next_rid( rid_set_id, + new_next_rid, ) else: - new_next_rid = current_next_rid + 1 + new_next_rid = await self.rebind_next_rid_from_new_pool( + rid_set_id, + ) - await self._gateway.update_next_rid( - rid_set_id, - new_next_rid, - ) return new_next_rid async def rebind_next_rid_from_new_pool( diff --git a/app/ldap_protocol/rid_manager/setup_gateway.py b/app/ldap_protocol/rid_manager/setup_gateway.py index 4ba027635..e8d4117bd 100644 --- a/app/ldap_protocol/rid_manager/setup_gateway.py +++ b/app/ldap_protocol/rid_manager/setup_gateway.py @@ -6,7 +6,7 @@ import secrets -from sqlalchemy import exists, select, update +from sqlalchemy import exists, select from sqlalchemy.ext.asyncio import AsyncSession from constants import SYSTEM_CONTAINER_NAME @@ -15,7 +15,6 @@ EntityTypeUseCase, ) from ldap_protocol.rid_manager.exceptions import ( - RIDManagerBaseDomainNotFoundError, RIDManagerSystemContainerNotFoundError, ) from ldap_protocol.utils.queries import get_base_directories @@ -127,26 +126,14 @@ async def set_rid_available_pool( :param rid_manager_dir: RID Manager directory object :param qword_value: New QWORD value (64-bit) """ - query = ( - update(Attribute) - .where( - qa(Attribute.directory_id) == rid_manager_dir.id, - qa(Attribute.name) == "rIDAvailablePool", - ) - .values(value=str(qword_value)) + self._session.add( + Attribute( + directory_id=rid_manager_dir.id, + name="rIDAvailablePool", + value=str(qword_value), + ), ) - result = await self._session.execute(query) - - if result.rowcount == 0: - self._session.add( - Attribute( - directory_id=rid_manager_dir.id, - name="rIDAvailablePool", - value=str(qword_value), - ), - ) - await self._session.flush() def _generate_domain_sid_identifier(self) -> str: @@ -156,7 +143,7 @@ def _generate_domain_sid_identifier(self) -> str: f"-{secrets.randbits(32)}-{secrets.randbits(32)}" ) - async def create_domain_identifier(self) -> None: + async def create_domain_identifier(self, domain_id: int) -> None: """Add domain identifier to domain.""" domain_identifer = await self._session.scalar( select( @@ -169,20 +156,11 @@ async def create_domain_identifier(self) -> None: if domain_identifer: return - domain = await self._session.scalar( - select(Directory).where( - qa(Directory.object_class) == "domain", - qa(Directory.parent_id).is_(None), - ), - ) - if not domain: - raise RIDManagerBaseDomainNotFoundError("Domain not found") - self._session.add( Attribute( name="DomainIdentifier", value=f"{self._generate_domain_sid_identifier()}", - directory_id=domain.id, + directory_id=domain_id, ), ) await self._session.flush() diff --git a/app/ldap_protocol/rid_manager/setup_use_case.py b/app/ldap_protocol/rid_manager/setup_use_case.py index 203139515..34bb18074 100644 --- a/app/ldap_protocol/rid_manager/setup_use_case.py +++ b/app/ldap_protocol/rid_manager/setup_use_case.py @@ -77,6 +77,6 @@ async def inherit_aces( directory=rid_manager_dir, ) - async def create_domain_identifier(self) -> None: + async def create_domain_identifier(self, domain_id: int) -> None: """Create domain identifier.""" - await self._gateway.create_domain_identifier() + await self._gateway.create_domain_identifier(domain_id) diff --git a/app/ldap_protocol/utils/async_cache.py b/app/ldap_protocol/utils/async_cache.py index ff544251f..9974effad 100644 --- a/app/ldap_protocol/utils/async_cache.py +++ b/app/ldap_protocol/utils/async_cache.py @@ -43,30 +43,6 @@ async def wrapper(*args: tuple, **kwargs: dict) -> T: return wrapper -class SingleValueTTLCache(Generic[T]): - """Single cached value; refresh via ``loader`` on miss or TTL expiry.""" - - def __init__(self, ttl: int | None = DEFAULT_CACHE_TIME) -> None: - self._ttl = ttl - self._value: T | None = None - self._expires_at: float | None = None - - def clear(self) -> None: - self._value = None - self._expires_at = None - - async def get_or_load(self, loader: Callable[[], Awaitable[T]]) -> T: - """Return cached value or ``await loader()`` and store it.""" - if self._value is not None: - if not self._expires_at or self._expires_at > time.monotonic(): - return self._value - self.clear() - - result = await loader() - self._value = result - self._expires_at = time.monotonic() + self._ttl if self._ttl else None - return result - - base_directories_cache = AsyncTTLCache[list[Directory]]() -domain_identifier_cache = SingleValueTTLCache[str]() +domain_identifier_cache = AsyncTTLCache[str]() +rid_set_id_cache = AsyncTTLCache[int]() diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 64bb827e2..788308539 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -224,6 +224,9 @@ async def get_groups(dn_list: list[str], session: AsyncSession) -> list[Group]: .options(selectinload(qa(Group.members))) .options( joinedload(qa(Group.directory)).selectinload(qa(Directory.groups)), + joinedload(qa(Group.directory)).selectinload( + qa(Directory.attributes), + ), ) ) @@ -246,7 +249,11 @@ async def get_group( query = ( select(Group) .join(qa(Group.directory), isouter=True) - .options(joinedload(qa(Group.directory))) + .options( + joinedload(qa(Group.directory)).selectinload( + qa(Directory.attributes), + ), + ) ) if validate_entry(dn): diff --git a/tests/conftest.py b/tests/conftest.py index 60221a93b..bbe52ad08 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1114,7 +1114,6 @@ async def setup_session( role_use_case=role_use_case, directory_dao=directory_dao, attribute_dao=attribute_dao, - object_sid_use_case=object_sid_use_case, ) object_class_use_case = ObjectClassUseCase( attribute_type_dao=attribute_type_dao, diff --git a/tests/test_ldap/test_object_sid.py b/tests/test_ldap/test_object_sid.py index 95fef4a26..9432933bc 100644 --- a/tests/test_ldap/test_object_sid.py +++ b/tests/test_ldap/test_object_sid.py @@ -36,12 +36,9 @@ async def test_rid_manager_allocate_pool( @pytest.mark.usefixtures("setup_session") async def test_next_rid( rid_set_use_case: RIDSetUseCase, - rid_manager_use_case: RIDManagerUseCase, ) -> None: """Test RID Manager get domain controller.""" - dc = await rid_manager_use_case.get_domain_controller() - rid_set = await rid_set_use_case.get(dc) - rid_set_id = rid_set.id + rid_set_id = await rid_set_use_case.get_rid_set_id() next_rid = await rid_set_use_case.allocate_next_rid(rid_set_id) new_next_rid = await rid_set_use_case.allocate_next_rid(rid_set_id) assert new_next_rid == next_rid + 1 @@ -52,14 +49,11 @@ async def test_next_rid( @pytest.mark.usefixtures("setup_session") async def test_rid_set_reset_pool( rid_set_use_case: RIDSetUseCase, - rid_manager_use_case: RIDManagerUseCase, rid_manager_gateway: RIDManagerGateway, rid_set_gateway: RIDSetGateway, ) -> None: """Test RID Set pool reset.""" - dc = await rid_manager_use_case.get_domain_controller() - rid_set = await rid_set_use_case.get(dc) - rid_set_id = rid_set.id + rid_set_id = await rid_set_use_case.get_rid_set_id() available_pool_before = await rid_manager_gateway.get_rid_available_pool() lower_before, _ = from_qword(available_pool_before) @@ -126,8 +120,7 @@ async def test_object_sid_add_updates_next_rid_and_prefix( rid_manager_use_case: RIDManagerUseCase, ) -> None: dc = await rid_manager_use_case.get_domain_controller() - rid_set = await rid_set_use_case.get(dc) - rid_set_id = rid_set.id + rid_set_id = await rid_set_use_case.get_rid_set_id() dc_id = dc.id next_before = await rid_set_gateway.get_rid_next_rid(rid_set_id) From fa2e2438a943af634d65e0a18eaeaa8b42dadb1c Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Thu, 9 Apr 2026 15:15:57 +0300 Subject: [PATCH 28/37] Refactor: Update RIDSetGateway initialization to include settings for improved configuration management in tests --- tests/conftest.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index bbe52ad08..61c20fb1f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1092,7 +1092,7 @@ async def setup_session( rid_manager_gateway, session, ) - rid_set_gateway = RIDSetGateway(session) + rid_set_gateway = RIDSetGateway(session, settings) rid_set_use_case = RIDSetUseCase( rid_set_gateway, @@ -1834,11 +1834,12 @@ async def rid_manager_use_case( @pytest_asyncio.fixture(scope="function") async def rid_set_gateway( container: AsyncContainer, + settings: Settings, ) -> AsyncIterator[RIDSetGateway]: """Provide RIDSetGateway for tests that request it explicitly.""" async with container(scope=Scope.SESSION) as container: session = await container.get(AsyncSession) - yield RIDSetGateway(session) + yield RIDSetGateway(session, settings) @pytest_asyncio.fixture(scope="function") From cc2720aaaff66e85d88bcea000e41f7102db22c7 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Thu, 9 Apr 2026 15:20:23 +0300 Subject: [PATCH 29/37] Fix: Pass domain ID to create_domain_identifier in setup_session for correct domain initialization --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 61c20fb1f..c37a66b85 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1179,7 +1179,7 @@ async def setup_session( await entity_type_use_case.create_not_safe(entity_type_dto) domain = await setup_gateway.create_base_domain("md.test") - await rid_manager_setup_use_case.create_domain_identifier() + await rid_manager_setup_use_case.create_domain_identifier(domain.id) await setup_gateway.setup_enviroment( domain=domain, From 45cec83b3ddfe520fba929f214f5fbaed184a76a Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Thu, 9 Apr 2026 16:00:12 +0300 Subject: [PATCH 30/37] Enhance: Clear domain identifier and RID set caches in setup_session to ensure fresh state for each test run --- tests/conftest.py | 6 ++++++ tests/test_ldap/test_util/test_modify.py | 4 +++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index c37a66b85..0e9afe287 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -204,6 +204,10 @@ from ldap_protocol.server import PoolClientHandler from ldap_protocol.session_storage import RedisSessionStorage, SessionStorage from ldap_protocol.session_storage.repository import SessionRepository +from ldap_protocol.utils.async_cache import ( + domain_identifier_cache, + rid_set_id_cache, +) from ldap_protocol.utils.queries import get_user from password_utils import PasswordUtils from repo.pg.master_gateway import PGMasterGateway @@ -1057,6 +1061,8 @@ async def setup_session( settings: Settings, ) -> None: """Get session and acquire after completion.""" + domain_identifier_cache.clear() + rid_set_id_cache.clear() role_dao = RoleDAO(session) ace_dao = AccessControlEntryDAO(session) role_use_case = RoleUseCase(role_dao, ace_dao) diff --git a/tests/test_ldap/test_util/test_modify.py b/tests/test_ldap/test_util/test_modify.py index 2b006e529..26f35981c 100644 --- a/tests/test_ldap/test_util/test_modify.py +++ b/tests/test_ldap/test_util/test_modify.py @@ -968,7 +968,9 @@ async def fetch_directory_by_dn(session: AsyncSession, dn: str) -> Directory: query = ( select(Directory) .options( - selectinload(qa(Directory.groups)).joinedload(qa(Group.directory)), + selectinload(qa(Directory.groups)) + .joinedload(qa(Group.directory)) + .selectinload(qa(Directory.attributes)), selectinload(qa(Directory.attributes)), joinedload(qa(Directory.group)), ) From f5defe419c2ac8e26933df667737867e411973b9 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Thu, 9 Apr 2026 18:04:49 +0300 Subject: [PATCH 31/37] Enhance: Integrate existing RID set ID retrieval into test_add_domain_controller for improved test accuracy --- tests/test_shedule.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/test_shedule.py b/tests/test_shedule.py index d6d9fbba3..3a7cc5dcb 100644 --- a/tests/test_shedule.py +++ b/tests/test_shedule.py @@ -98,6 +98,7 @@ async def test_add_domain_controller( ) -> None: """Test add domain controller.""" existing_dc = await rid_manager_use_case.get_domain_controller() + existing_rid_set_id = await rid_set_use_case.get_rid_set_id() monkeypatch.setattr( settings, "HOST_MACHINE_SHORT_NAME", @@ -112,6 +113,15 @@ async def _get_existing_dc() -> object: "get_domain_controller", _get_existing_dc, ) + + async def _get_existing_rid_set_id() -> int: + return existing_rid_set_id + + monkeypatch.setattr( + rid_set_use_case, + "get_rid_set_id", + _get_existing_rid_set_id, + ) await add_domain_controller( settings=settings, session=session, From b37b1ad03e3ed14971b650951a4b0e08bff179e4 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 10 Apr 2026 09:21:09 +0300 Subject: [PATCH 32/37] Refactor: Update add_domain_controller to utilize RID Set allocation before object SID assignment for improved domain controller setup --- app/extra/scripts/add_domain_controller.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/app/extra/scripts/add_domain_controller.py b/app/extra/scripts/add_domain_controller.py index 71b3b96ca..a04c9be9c 100644 --- a/app/extra/scripts/add_domain_controller.py +++ b/app/extra/scripts/add_domain_controller.py @@ -40,14 +40,16 @@ async def _add_domain_controller( await session.flush() dc_directory.parent_id = dc_ou_dir.id - await object_sid_use_case.add( - directory_id=dc_directory.id, - ) await session.flush() + await rid_set_use_case.add( domain_controller=dc_directory, allocation_params=await rid_set_use_case.generate_rid_set_attrs(), ) + await session.flush() + await object_sid_use_case.add( + directory_id=dc_directory.id, + ) attributes = [ Attribute( From f447d4e5b86bb9b7e24f6e0e357c3508c4f326e6 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 10 Apr 2026 09:34:02 +0300 Subject: [PATCH 33/37] Refactor: Simplify condition construction in SearchRequest by directly using or_ for improved readability and maintainability --- app/ldap_protocol/ldap_requests/search.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/app/ldap_protocol/ldap_requests/search.py b/app/ldap_protocol/ldap_requests/search.py index fa36b9925..7c46f6948 100644 --- a/app/ldap_protocol/ldap_requests/search.py +++ b/app/ldap_protocol/ldap_requests/search.py @@ -376,12 +376,10 @@ def _mutate_query_with_attributes_to_load( if attr not in _ATTRS_TO_CLEAN } - cond_parts = [ + cond = or_( func.lower(Attribute.name).in_(attrs), func.lower(Attribute.name) == "objectclass", - ] - - cond = or_(*cond_parts) + ) return query.options( selectinload(qa(Directory.attributes)), From 38768dcadf1c9476294b850023458b9f4c956084 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 10 Apr 2026 13:06:16 +0300 Subject: [PATCH 34/37] Enhance: Update object SID handling in upgrade function to correctly format existing identifiers and assign built-in SIDs for domain groups, improving data integrity and consistency --- .../552b4eafb1aa_remove_objectsid_vals.py | 118 +++++++++--------- 1 file changed, 58 insertions(+), 60 deletions(-) diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index a834f1578..5933eb381 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -15,7 +15,7 @@ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from entities import Attribute, Directory, EntityType -from enums import EntityTypeNames +from enums import EntityTypeNames, SecurityPrincipalRid from ldap_protocol.ldap_schema.dto import EntityTypeDTO from ldap_protocol.ldap_schema.entity_type.entity_type_use_case import ( EntityTypeUseCase, @@ -140,6 +140,15 @@ async def _migrate_object_sids( ), ) + if ( + existing_identifier + and existing_identifier.value + and existing_identifier.value.startswith("S-1-5-21-") + ): + parts = existing_identifier.value.split("-") + if len(parts) >= 7: + existing_identifier.value = "-".join(parts[4:7]) + if not (existing_identifier and existing_identifier.value): domain_object_sid = await session.scalar( select(Attribute).where( @@ -171,6 +180,54 @@ async def _migrate_object_sids( directory_id=domain.id, ), ) + else: + identifier = existing_identifier.value + + built_in_sid_prefix = "S-1-5-32" + for dir_name, rid in ( + ("domain admins", SecurityPrincipalRid.DOMAIN_ADMINS), + ("domain users", SecurityPrincipalRid.DOMAIN_USERS), + ("domain computers", SecurityPrincipalRid.DOMAIN_COMPUTERS), + ( + "read only domain controllers", + SecurityPrincipalRid.DOMAIN_READ_ONLY, + ), + ): + await session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "objectSid", + qa(Attribute.directory_id).in_( + select(qa(Directory.id)).where( + qa(Directory.name) == dir_name, + ), + ), + ) + .values( + value=f"{built_in_sid_prefix}-{int(rid)}", + ), + ) + + await session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "objectSid", + qa(Attribute.directory_id).in_( + select(qa(Directory.id)) + .join(Attribute) + .where( + qa(Attribute.name) == "sAMAccountName", + qa(Attribute.value).ilike("administrator"), + ), + ), + ) + .values( + value=( + f"{built_in_sid_prefix}" + f"-{int(SecurityPrincipalRid.ADMINISTRATOR)}" + ), + ), + ) await session.commit() @@ -266,69 +323,10 @@ async def _init_rid_manager( previous_allocation_pool=previous_allocation_pool, ), ) - await role_use_case.inherit_parent_aces( - parent_directory=domain_controller, - directory=rid_set_dir, - ) - await session.commit() - return - existing_next_rid = await session.scalar( - select(Attribute).where( - qa(Attribute.directory_id) == rid_set_dir.id, - qa(Attribute.name) == "rIDNextRID", - ), - ) - existing_prev_pool = await session.scalar( - select(Attribute).where( - qa(Attribute.directory_id) == rid_set_dir.id, - qa(Attribute.name) == "rIDPreviousAllocationPool", - ), - ) - existing_pool = await session.scalar( - select(Attribute).where( - qa(Attribute.directory_id) == rid_set_dir.id, - qa(Attribute.name) == "rIDAllocationPool", - ), - ) - - if ( - existing_next_rid - and existing_next_rid.value - and existing_prev_pool - and existing_prev_pool.value - and existing_pool - and existing_pool.value - ): await session.commit() return - previous_allocation_pool = await rid_manager_use_case.allocate_pool() - allocation_pool = await rid_manager_use_case.allocate_pool() - lower, _ = from_qword(previous_allocation_pool) - - for name, value in ( - ("rIDNextRID", str(lower)), - ("rIDPreviousAllocationPool", str(previous_allocation_pool)), - ("rIDAllocationPool", str(allocation_pool)), - ): - result = await session.execute( - update(Attribute) - .where( - qa(Attribute.directory_id) == rid_set_dir.id, - qa(Attribute.name) == name, - ) - .values(value=value), - ) - if result.rowcount == 0: - session.add( - Attribute( - directory_id=rid_set_dir.id, - name=name, - value=value, - ), - ) - await session.commit() op.run_async(_init_rid_manager) From b81ee10643a256a5da825f2e64d841ab86158386 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 10 Apr 2026 13:57:42 +0300 Subject: [PATCH 35/37] Enhance: Refactor upgrade function to include parentId in directory selection and improve object SID handling by skipping entries with null parent IDs, ensuring better data integrity for domain groups. --- .../552b4eafb1aa_remove_objectsid_vals.py | 39 +++++++++++-------- 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index 5933eb381..06f727a02 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -14,6 +14,12 @@ from sqlalchemy import delete, select, update from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession +from constants import ( + DOMAIN_ADMIN_GROUP_NAME, + DOMAIN_COMPUTERS_GROUP_NAME, + DOMAIN_USERS_GROUP_NAME, + READ_ONLY_GROUP_NAME, +) from entities import Attribute, Directory, EntityType from enums import EntityTypeNames, SecurityPrincipalRid from ldap_protocol.ldap_schema.dto import EntityTypeDTO @@ -95,23 +101,31 @@ async def _migrate_object_sids( """ async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) + base_dn_list = await get_base_directories(session) + if not base_dn_list: + return + domain = base_dn_list[0] directory_table = sa.table( "Directory", sa.column("id", sa.Integer), + sa.column("parentId", sa.Integer), sa.column("objectSid", sa.String), ) result = await session.execute( select( directory_table.c.id, + directory_table.c.parentId, directory_table.c.objectSid, ), ) - for directory_id, object_sid in result: + for directory_id, parent_id, object_sid in result: if not object_sid: continue + if parent_id is None: + continue existing_attr = await session.scalar( select(Attribute).where( @@ -129,10 +143,6 @@ async def _migrate_object_sids( ), ) - base_dn_list = await get_base_directories(session) - if base_dn_list: - domain = base_dn_list[0] - existing_identifier = await session.scalar( select(Attribute).where( qa(Attribute.directory_id) == domain.id, @@ -185,13 +195,13 @@ async def _migrate_object_sids( built_in_sid_prefix = "S-1-5-32" for dir_name, rid in ( - ("domain admins", SecurityPrincipalRid.DOMAIN_ADMINS), - ("domain users", SecurityPrincipalRid.DOMAIN_USERS), - ("domain computers", SecurityPrincipalRid.DOMAIN_COMPUTERS), + (DOMAIN_ADMIN_GROUP_NAME, SecurityPrincipalRid.DOMAIN_ADMINS), + (DOMAIN_USERS_GROUP_NAME, SecurityPrincipalRid.DOMAIN_USERS), ( - "read only domain controllers", - SecurityPrincipalRid.DOMAIN_READ_ONLY, + DOMAIN_COMPUTERS_GROUP_NAME, + SecurityPrincipalRid.DOMAIN_COMPUTERS, ), + (READ_ONLY_GROUP_NAME, SecurityPrincipalRid.DOMAIN_READ_ONLY), ): await session.execute( update(Attribute) @@ -212,13 +222,8 @@ async def _migrate_object_sids( update(Attribute) .where( qa(Attribute.name) == "objectSid", - qa(Attribute.directory_id).in_( - select(qa(Directory.id)) - .join(Attribute) - .where( - qa(Attribute.name) == "sAMAccountName", - qa(Attribute.value).ilike("administrator"), - ), + qa(Attribute.value).like( + f"S-1-5-21-%-{int(SecurityPrincipalRid.ADMINISTRATOR)}", ), ) .values( From d11ed27357797f416aa18f1e462bed7dc499bf5e Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 10 Apr 2026 14:54:47 +0300 Subject: [PATCH 36/37] Enhance: Implement async checks for primary group membership in ModifyRequest and add utility function to verify group inclusion by primary RID, improving group management logic. --- app/ldap_protocol/ldap_requests/modify.py | 40 ++++++++++++++++------- app/ldap_protocol/utils/queries.py | 32 ++++++++++++++++-- tests/test_ldap/test_util/test_modify.py | 22 +++++++++++++ 3 files changed, 80 insertions(+), 14 deletions(-) diff --git a/app/ldap_protocol/ldap_requests/modify.py b/app/ldap_protocol/ldap_requests/modify.py index c6f88c327..2c74e680a 100644 --- a/app/ldap_protocol/ldap_requests/modify.py +++ b/app/ldap_protocol/ldap_requests/modify.py @@ -53,6 +53,7 @@ get_directory_by_rid, get_filter_from_path, get_groups, + groups_include_primary_rid, remove_disallowed_group_members, remove_from_group_membership, ) @@ -363,9 +364,6 @@ def _get_dir_query(self) -> Select[tuple[Directory]]: selectinload(qa(Directory.groups)).joinedload( qa(Group.directory), ), - selectinload(qa(Directory.groups)) - .joinedload(qa(Group.directory)) - .selectinload(qa(Directory.attributes)), joinedload(qa(Directory.group)).selectinload( qa(Group.members), ), @@ -389,13 +387,17 @@ def _get_primary_group_id(self, directory: Directory) -> str | None: None, ) - def _contain_primary_group( + async def _contain_primary_group( self, groups: list[Group], primary_group_id: str, + session: AsyncSession, ) -> bool: - return any( - group.directory.relative_id == primary_group_id for group in groups + """Check whether membership includes the group for this RID.""" + return await groups_include_primary_rid( + session, + groups, + primary_group_id, ) async def _get_directories_with_primary_group_id( @@ -436,16 +438,25 @@ async def _get_members_with_primary_group_id( ) return list(await session.scalars(query)) - def _is_primary_group_deleted( + async def _is_primary_group_deleted( self, groups: list[Group], primary_group_id: str, operation: Operation, + session: AsyncSession, ) -> bool: if operation == Operation.REPLACE: - return not self._contain_primary_group(groups, primary_group_id) + return not await self._contain_primary_group( + groups, + primary_group_id, + session, + ) elif operation == Operation.DELETE: - return self._contain_primary_group(groups, primary_group_id) + return await self._contain_primary_group( + groups, + primary_group_id, + session, + ) return False async def _can_delete_group_from_directory( @@ -454,6 +465,7 @@ async def _can_delete_group_from_directory( user: UserSchema, groups: list[Group], operation: Operation, + session: AsyncSession, ) -> None: """Check if the request can delete group from directory.""" if operation == Operation.REPLACE: @@ -481,7 +493,12 @@ async def _can_delete_group_from_directory( if not primary_group_id: return - if self._is_primary_group_deleted(groups, primary_group_id, operation): + if await self._is_primary_group_deleted( + groups, + primary_group_id, + operation, + session, + ): raise ModifyForbiddenError( "Can't delete primary group from user.", ) @@ -556,6 +573,7 @@ async def _delete_memberof( user=user, groups=groups, operation=change.operation, + session=session, ) if not change.modification.vals: @@ -719,7 +737,7 @@ async def _add_primary_group_attribute( rid = str(change.modification.vals[0]) - if self._contain_primary_group(directory.groups, rid): + if await self._contain_primary_group(directory.groups, rid, session): session.add( Attribute( name="primaryGroupID", diff --git a/app/ldap_protocol/utils/queries.py b/app/ldap_protocol/utils/queries.py index 788308539..d3380776e 100644 --- a/app/ldap_protocol/utils/queries.py +++ b/app/ldap_protocol/utils/queries.py @@ -191,11 +191,15 @@ async def get_directory_by_rid( ) -> Directory | None: query = ( select(Directory) + .join( + Attribute, + qa(Attribute.directory_id) == qa(Directory.id), + ) .options( selectinload(qa(Directory.attributes)), joinedload(qa(Directory.group)), ) - .filter( + .where( qa(Attribute.name) == "objectSid", qa(Attribute.value).endswith(f"-{rid}"), ) @@ -203,6 +207,25 @@ async def get_directory_by_rid( return await session.scalar(query) +async def groups_include_primary_rid( + session: AsyncSession, + groups: list[Group], + primary_group_id: str, +) -> bool: + directory_ids = {g.directory_id for g in groups} + + stmt = ( + select(qa(Attribute.id)) + .where( + qa(Attribute.directory_id).in_(directory_ids), + qa(Attribute.name) == "objectSid", + qa(Attribute.value).endswith(f"-{primary_group_id}"), + ) + .limit(1) + ) + return await session.scalar(stmt) is not None + + async def get_groups(dn_list: list[str], session: AsyncSession) -> list[Group]: """Get dirs with groups by dn list.""" paths = [] @@ -567,10 +590,13 @@ async def get_group_path_dn_by_primary_group_id( """ query = ( select(Directory) - .join(Attribute) + .join( + Attribute, + qa(Attribute.directory_id) == qa(Directory.id), + ) .join(qa(Directory.group)) .options(contains_eager(qa(Directory.group))) - .filter( + .where( qa(Attribute.name) == "objectSid", qa(Attribute.value).endswith(f"-{primary_group_id}"), ) diff --git a/tests/test_ldap/test_util/test_modify.py b/tests/test_ldap/test_util/test_modify.py index 26f35981c..d3db50388 100644 --- a/tests/test_ldap/test_util/test_modify.py +++ b/tests/test_ldap/test_util/test_modify.py @@ -984,6 +984,12 @@ async def fetch_directory_by_dn(session: AsyncSession, dn: str) -> Directory: @pytest.mark.parametrize( ("operation", "group_dn", "expected_groups", "expected_primary_group"), [ + ( + "add", + "cn=developers,cn=Groups,dc=md,dc=test", + {"domain admins", "developers"}, + True, + ), ( "add", "cn=domain admins,cn=Groups,dc=md,dc=test", @@ -996,6 +1002,12 @@ async def fetch_directory_by_dn(session: AsyncSession, dn: str) -> Directory: {"domain admins", "developers"}, False, ), + ( + "replace", + "cn=developers,cn=Groups,dc=md,dc=test", + {"domain admins", "developers"}, + True, + ), ], ) async def test_ldap_modify_primary_group_id_scenarios( @@ -1074,6 +1086,16 @@ async def test_ldap_modify_primary_group_id_scenarios( 0, {"domain admins"}, ), + ( + [ + "cn=domain admins,cn=Groups,dc=md,dc=test", + "cn=developers,cn=Groups,dc=md,dc=test", + "cn=domain computers,cn=Groups,dc=md,dc=test", + ], + True, + 0, + {"domain admins", "developers", "domain computers"}, + ), ], ) async def test_ldap_modify_replace_memberof_primary_group_various( From 5d3a4c674bf71e345a9fadb41db0cf347be507c6 Mon Sep 17 00:00:00 2001 From: "m.shvets" Date: Fri, 10 Apr 2026 15:40:03 +0300 Subject: [PATCH 37/37] Enhance: Introduce a new utility function to skip specific directory IDs during object SID migration, ensuring proper handling of top-level containers and configuration subtrees. Update upgrade function to add DomainIdentifier attribute and normalize object SID handling for domain directories. --- .../552b4eafb1aa_remove_objectsid_vals.py | 237 ++++++++++-------- 1 file changed, 136 insertions(+), 101 deletions(-) diff --git a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py index 06f727a02..db6546bcf 100644 --- a/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py +++ b/app/alembic/versions/552b4eafb1aa_remove_objectsid_vals.py @@ -15,10 +15,16 @@ from sqlalchemy.ext.asyncio import AsyncConnection, AsyncSession from constants import ( + COMPUTERS_CONTAINER_NAME, + CONFIGURATION_DIR_NAME, DOMAIN_ADMIN_GROUP_NAME, DOMAIN_COMPUTERS_GROUP_NAME, + DOMAIN_CONTROLLERS_OU_NAME, DOMAIN_USERS_GROUP_NAME, + GROUPS_CONTAINER_NAME, READ_ONLY_GROUP_NAME, + SYSTEM_CONTAINER_NAME, + USERS_CONTAINER_NAME, ) from entities import Attribute, Directory, EntityType from enums import EntityTypeNames, SecurityPrincipalRid @@ -51,6 +57,53 @@ depends_on: None | list[str] = None +async def _directory_ids_skipped_for_object_sid_migration( + session: AsyncSession, + domain: Directory, +) -> set[int]: + """Directory ids for which objectSid is not copied into Attributes. + + Top-level peer containers (System, OU DC, Users, Computers, Groups) and + the full subtree under ``Configuration``. + """ + peer_container_names = ( + SYSTEM_CONTAINER_NAME, + DOMAIN_CONTROLLERS_OU_NAME, + USERS_CONTAINER_NAME, + COMPUTERS_CONTAINER_NAME, + GROUPS_CONTAINER_NAME, + ) + peer_rows = await session.scalars( + select(qa(Directory.id)).where( + qa(Directory.parent_id) == domain.id, + qa(Directory.name).in_(peer_container_names), + ), + ) + skip_ids: set[int] = set(peer_rows.all()) + configuration_id = await session.scalar( + select(qa(Directory.id)).where( + qa(Directory.parent_id) == domain.id, + qa(Directory.name) == CONFIGURATION_DIR_NAME, + ), + ) + if configuration_id is None: + return skip_ids + + subtree = ( + select(qa(Directory.id)) + .where(qa(Directory.id) == configuration_id) + .cte(name="subtree", recursive=True) + ) + subtree = subtree.union_all( + select(qa(Directory.id)).where( + qa(Directory.parent_id) == subtree.c.id, + ), + ) + cfg_rows = await session.execute(select(subtree.c.id)) + skip_ids |= {row[0] for row in cfg_rows.all()} + return skip_ids + + def upgrade(container: AsyncContainer) -> None: # noqa: C901 """Add rIDManager and rIDSet objectClasses to LDAP schema.""" @@ -96,8 +149,9 @@ async def _migrate_object_sids( ) -> None: """Move Directory.objectSid values into Attributes table. - Additionally, for domain directories create the ``DomainIdentifier`` - attribute if it does not exist. + Add ``DomainIdentifier`` on the domain (from ``Directory.objectSid`` + column when present). Do not store domain ``objectSid`` in Attributes. + Normalize built-in group / administrator SIDs once. """ async with container(scope=Scope.REQUEST) as cnt: session = await cnt.get(AsyncSession) @@ -106,6 +160,13 @@ async def _migrate_object_sids( return domain = base_dn_list[0] + skip_object_sid_ids = ( + await _directory_ids_skipped_for_object_sid_migration( + session, + domain, + ) + ) + directory_table = sa.table( "Directory", sa.column("id", sa.Integer), @@ -113,127 +174,101 @@ async def _migrate_object_sids( sa.column("objectSid", sa.String), ) - result = await session.execute( - select( - directory_table.c.id, - directory_table.c.parentId, - directory_table.c.objectSid, + domain_sid_from_column = await session.scalar( + select(directory_table.c.objectSid).where( + directory_table.c.id == domain.id, ), ) + identifier: str | None = None + if domain_sid_from_column: + parts = domain_sid_from_column.split("-") + # "S-1-5-21-AAA-BBB-CCC" -> "AAA-BBB-CCC" + if len(parts) >= 7 and domain_sid_from_column.startswith( + "S-1-5-21-", + ): + identifier = "-".join(parts[4:7]) + + if identifier is None: + identifier = ( + f"{secrets.randbits(32)}-" + f"{secrets.randbits(32)}-" + f"{secrets.randbits(32)}" + ) + + session.add( + Attribute( + name="DomainIdentifier", + value=identifier, + directory_id=domain.id, + ), + ) + result = ( + await session.execute( + select( + directory_table.c.id, + directory_table.c.parentId, + directory_table.c.objectSid, + ), + ) + ).all() for directory_id, parent_id, object_sid in result: if not object_sid: continue if parent_id is None: continue + if directory_id in skip_object_sid_ids: + continue - existing_attr = await session.scalar( - select(Attribute).where( - qa(Attribute.directory_id) == directory_id, - qa(Attribute.name) == "objectSid", - ), - ) - - if not existing_attr: - session.add( - Attribute( - name="objectSid", - value=object_sid, - directory_id=directory_id, - ), - ) - - existing_identifier = await session.scalar( - select(Attribute).where( - qa(Attribute.directory_id) == domain.id, - qa(Attribute.name) == "DomainIdentifier", + session.add( + Attribute( + name="objectSid", + value=object_sid, + directory_id=directory_id, ), ) - if ( - existing_identifier - and existing_identifier.value - and existing_identifier.value.startswith("S-1-5-21-") - ): - parts = existing_identifier.value.split("-") - if len(parts) >= 7: - existing_identifier.value = "-".join(parts[4:7]) - - if not (existing_identifier and existing_identifier.value): - domain_object_sid = await session.scalar( - select(Attribute).where( - qa(Attribute.directory_id) == domain.id, - qa(Attribute.name) == "objectSid", - ), - ) - - identifier: str | None = None - if domain_object_sid and domain_object_sid.value: - parts = domain_object_sid.value.split("-") - # "S-1-5-21-AAA-BBB-CCC" -> "AAA-BBB-CCC" - if len(parts) >= 7 and domain_object_sid.value.startswith( - "S-1-5-21-", - ): - identifier = "-".join(parts[4:7]) - - if identifier is None: - identifier = ( - f"{secrets.randbits(32)}-" - f"{secrets.randbits(32)}-" - f"{secrets.randbits(32)}" - ) - - session.add( - Attribute( - name="DomainIdentifier", - value=identifier, - directory_id=domain.id, - ), - ) - else: - identifier = existing_identifier.value - - built_in_sid_prefix = "S-1-5-32" - for dir_name, rid in ( - (DOMAIN_ADMIN_GROUP_NAME, SecurityPrincipalRid.DOMAIN_ADMINS), - (DOMAIN_USERS_GROUP_NAME, SecurityPrincipalRid.DOMAIN_USERS), - ( - DOMAIN_COMPUTERS_GROUP_NAME, - SecurityPrincipalRid.DOMAIN_COMPUTERS, - ), - (READ_ONLY_GROUP_NAME, SecurityPrincipalRid.DOMAIN_READ_ONLY), - ): - await session.execute( - update(Attribute) - .where( - qa(Attribute.name) == "objectSid", - qa(Attribute.directory_id).in_( - select(qa(Directory.id)).where( - qa(Directory.name) == dir_name, - ), - ), - ) - .values( - value=f"{built_in_sid_prefix}-{int(rid)}", - ), - ) - + built_in_sid_prefix = "S-1-5-32" + for dir_name, rid in ( + (DOMAIN_ADMIN_GROUP_NAME, SecurityPrincipalRid.DOMAIN_ADMINS), + (DOMAIN_USERS_GROUP_NAME, SecurityPrincipalRid.DOMAIN_USERS), + ( + DOMAIN_COMPUTERS_GROUP_NAME, + SecurityPrincipalRid.DOMAIN_COMPUTERS, + ), + (READ_ONLY_GROUP_NAME, SecurityPrincipalRid.DOMAIN_READ_ONLY), + ): await session.execute( update(Attribute) .where( qa(Attribute.name) == "objectSid", - qa(Attribute.value).like( - f"S-1-5-21-%-{int(SecurityPrincipalRid.ADMINISTRATOR)}", + qa(Attribute.directory_id).in_( + select(qa(Directory.id)).where( + qa(Directory.name) == dir_name, + ), ), ) .values( - value=( - f"{built_in_sid_prefix}" - f"-{int(SecurityPrincipalRid.ADMINISTRATOR)}" - ), + value=f"{built_in_sid_prefix}-{int(rid)}", ), ) + await session.execute( + update(Attribute) + .where( + qa(Attribute.name) == "objectSid", + qa(Attribute.value).like( + f"S-1-5-21-%-{int(SecurityPrincipalRid.ADMINISTRATOR)}", + ), + ) + .values( + value=( + f"{built_in_sid_prefix}" + f"-{int(SecurityPrincipalRid.ADMINISTRATOR)}" + ), + ), + ) + await session.commit() op.run_async(_migrate_object_sids)