33import functools
44import json
55import re
6- from typing import Any , Callable , Iterable
6+ from typing import Any , Callable
77
8+ # , Iterable
89from aws_lambda_powertools .utilities .data_masking .constants import DATA_MASKING_STRING
910
1011PRESERVE_CHARS = set ("-_. " )
@@ -69,14 +70,14 @@ def decrypt(self, data, provider_options: dict | None = None, **encryption_conte
6970
7071 def erase (
7172 self ,
72- data ,
73+ data : Any ,
7374 dynamic_mask : bool | None = None ,
7475 custom_mask : str | None = None ,
7576 regex_pattern : str | None = None ,
7677 mask_format : str | None = None ,
7778 masking_rules : dict | None = None ,
7879 ** kwargs ,
79- ) -> Iterable [ str ] :
80+ ) -> str | dict | list | tuple | set :
8081 """
8182 This method irreversibly erases data.
8283
@@ -85,47 +86,68 @@ def erase(
8586
8687 If the data to be erased is of an iterable type like `list`, `tuple`,
8788 or `set`, this method will return a new object of the same type as the
88- input data but with each element replaced by the string "*****" or following one of the custom masks .
89+ input data but with each element masked according to the specified rules .
8990 """
90- result = DATA_MASKING_STRING
91-
92- if data :
93- if isinstance (data , str ):
94- if dynamic_mask :
95- result = self ._custom_erase (data , ** kwargs )
96- if custom_mask :
97- result = self ._pattern_mask (data , custom_mask )
98- if regex_pattern and mask_format :
99- result = self ._regex_mask (data , regex_pattern , mask_format )
100- elif isinstance (data , dict ):
101- if masking_rules :
102- result = self ._apply_masking_rules (data , masking_rules )
103- elif isinstance (data , (list , tuple , set )):
104- result = type (data )(
105- self .erase (
106- item ,
107- dynamic_mask = dynamic_mask ,
108- custom_mask = custom_mask ,
109- regex_pattern = regex_pattern ,
110- mask_format = mask_format ,
111- masking_rules = masking_rules ,
112- ** kwargs ,
113- )
114- for item in data
91+ result = None
92+
93+ # Handle empty or None data
94+ if not data :
95+ result = DATA_MASKING_STRING if isinstance (data , (str , bytes )) else data
96+
97+ # Handle string data
98+ elif isinstance (data , str ):
99+ if regex_pattern and mask_format :
100+ result = self ._regex_mask (data , regex_pattern , mask_format )
101+ elif custom_mask :
102+ result = self ._pattern_mask (data , custom_mask )
103+ elif dynamic_mask :
104+ result = self ._custom_erase (data , ** kwargs )
105+ else :
106+ result = DATA_MASKING_STRING
107+
108+ # Handle dictionary data
109+ elif isinstance (data , dict ):
110+ result = (
111+ self ._apply_masking_rules (data , masking_rules )
112+ if masking_rules
113+ else {k : DATA_MASKING_STRING for k in data }
114+ )
115+
116+ # Handle iterable data (list, tuple, set)
117+ elif isinstance (data , (list , tuple , set )):
118+ masked_data = (
119+ self .erase (
120+ item ,
121+ dynamic_mask = dynamic_mask ,
122+ custom_mask = custom_mask ,
123+ regex_pattern = regex_pattern ,
124+ mask_format = mask_format ,
125+ masking_rules = masking_rules ,
126+ ** kwargs ,
115127 )
128+ for item in data
129+ )
130+ result = type (data )(masked_data )
131+
132+ # Default case
133+ else :
134+ result = DATA_MASKING_STRING
116135
117136 return result
118137
119138 def _apply_masking_rules (self , data : dict , masking_rules : dict ) -> dict :
139+ """Apply masking rules to dictionary data."""
120140 return {
121141 key : self .erase (str (value ), ** masking_rules [key ]) if key in masking_rules else str (value )
122142 for key , value in data .items ()
123143 }
124144
125145 def _pattern_mask (self , data : str , pattern : str ) -> str :
146+ """Apply pattern masking to string data."""
126147 return pattern [: len (data )] if len (pattern ) >= len (data ) else pattern
127148
128149 def _regex_mask (self , data : str , regex_pattern : str , mask_format : str ) -> str :
150+ """Apply regex masking to string data."""
129151 try :
130152 if regex_pattern not in _regex_cache :
131153 _regex_cache [regex_pattern ] = re .compile (regex_pattern )
@@ -137,5 +159,4 @@ def _custom_erase(self, data: str, **kwargs) -> str:
137159 if not data :
138160 return ""
139161
140- # Use join with list comprehension instead of building list incrementally
141162 return "" .join ("*" if char not in PRESERVE_CHARS else char for char in data )
0 commit comments