-
Notifications
You must be signed in to change notification settings - Fork 8
fix(keras): prevent spoofed built-in registered_name from hiding non-allowlisted modules #736
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
666358a
cdb440e
82abd75
37d08a6
903f518
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -135,47 +135,60 @@ def _is_allowlisted_keras_module(module_value: Any) -> bool: | |
| return False | ||
| return module_value.strip().split(".")[0] in _SAFE_KERAS_MODULE_ROOTS | ||
|
|
||
| def _layer_uses_allowlisted_module(self, layer: dict[str, Any]) -> bool: | ||
| def _iter_layer_module_references(self, layer: dict[str, Any]) -> list[str]: | ||
| layer_config = layer.get("config", {}) | ||
| if not isinstance(layer_config, dict): | ||
| layer_config = {} | ||
|
|
||
| module_references: list[str] = [] | ||
| for key in ("module", "fn_module"): | ||
| if self._is_allowlisted_keras_module(layer.get(key)): | ||
| return True | ||
| if self._is_allowlisted_keras_module(layer_config.get(key)): | ||
| return True | ||
| return False | ||
| for value in (layer.get(key), layer_config.get(key)): | ||
| if isinstance(value, str) and value.strip(): | ||
| module_references.append(value.strip()) | ||
| return module_references | ||
|
|
||
| def _layer_uses_allowlisted_module(self, layer: dict[str, Any]) -> bool: | ||
| return any( | ||
| self._is_allowlisted_keras_module(module_value) | ||
| for module_value in self._iter_layer_module_references(layer) | ||
| ) | ||
|
|
||
| def _layer_uses_non_allowlisted_module(self, layer: dict[str, Any]) -> bool: | ||
| return any( | ||
| not self._is_allowlisted_keras_module(module_value) | ||
| for module_value in self._iter_layer_module_references(layer) | ||
| ) | ||
|
|
||
| @staticmethod | ||
| def _is_known_safe_allowlisted_registered_object(identifier: Any) -> bool: | ||
| return isinstance(identifier, str) and identifier.strip().lower() in _SAFE_ALLOWLISTED_REGISTERED_OBJECTS | ||
|
|
||
| def _is_known_safe_serialized_layer(self, layer: dict[str, Any]) -> bool: | ||
| layer_class = layer.get("class_name") | ||
| if is_known_safe_keras_layer_class(layer_class): | ||
| return True | ||
|
|
||
| return self._layer_uses_allowlisted_module(layer) and self._is_known_safe_allowlisted_registered_object( | ||
| if is_known_safe_keras_layer_class(layer_class) or self._is_known_safe_allowlisted_registered_object( | ||
| layer_class | ||
| ) | ||
| ): | ||
| return not self._layer_uses_non_allowlisted_module(layer) | ||
|
|
||
|
Comment on lines
+168
to
+172
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Match allowlisted registered objects against Line 151 and Line 167 still feed 🔧 Suggested direction def _is_known_safe_serialized_layer(self, layer: dict[str, Any]) -> bool:
layer_class = layer.get("class_name")
+ registered_name = layer.get("registered_name")
if is_known_safe_keras_layer_class(layer_class) or self._is_known_safe_allowlisted_registered_object(
- layer_class
+ registered_name
+ ) or self._is_known_safe_allowlisted_registered_object(
+ layer_class
):
return not self._layer_uses_non_allowlisted_module(layer)
return False
...
layer_class = layer.get("class_name")
if isinstance(layer_class, str) and normalized_registered_name == layer_class.strip():
- if self._is_known_safe_serialized_layer(layer) or self._is_known_safe_allowlisted_registered_object(
- layer_class
- ):
+ if self._is_known_safe_serialized_layer(layer) or self._is_known_safe_allowlisted_registered_object(
+ normalized_registered_name
+ ):
return has_non_allowlisted_module
return TrueBased on learnings: Preserve or strengthen security detections; test both benign and malicious samples when adding scanner/feature changes. Also applies to: 167-170 🤖 Prompt for AI Agents |
||
| return False | ||
|
|
||
| def _should_flag_registered_object(self, layer: dict[str, Any]) -> bool: | ||
| registered_name = layer.get("registered_name") | ||
| if not isinstance(registered_name, str) or not registered_name.strip(): | ||
| return False | ||
|
|
||
| if is_known_safe_keras_layer_class(registered_name): | ||
| return False | ||
|
|
||
| normalized_registered_name = registered_name.strip() | ||
| has_non_allowlisted_module = self._layer_uses_non_allowlisted_module(layer) | ||
| layer_class = layer.get("class_name") | ||
| if isinstance(layer_class, str) and registered_name.strip() == layer_class.strip(): | ||
| if self._is_known_safe_serialized_layer(layer): | ||
| return False | ||
| return not ( | ||
| self._layer_uses_allowlisted_module(layer) | ||
| and self._is_known_safe_allowlisted_registered_object(layer_class) | ||
| ) | ||
| if isinstance(layer_class, str) and normalized_registered_name == layer_class.strip(): | ||
| if self._is_known_safe_serialized_layer(layer) or self._is_known_safe_allowlisted_registered_object( | ||
| layer_class | ||
| ): | ||
| return has_non_allowlisted_module | ||
| return True | ||
|
|
||
| if is_known_safe_keras_layer_class(normalized_registered_name): | ||
| return has_non_allowlisted_module | ||
|
|
||
| return True | ||
|
|
||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.