Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ This changelog is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.
- Fixed the topic_message_query integarion test
- good first issue template yaml rendering
- Fixed solo workflow defaulting to zero
- TLS Hostname Mismatch & Certificate Verification Failure for Nodes

### Breaking Change

Expand Down
52 changes: 36 additions & 16 deletions src/hiero_sdk_python/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import socket
import ssl # Python's ssl module implements TLS (despite the name)
import grpc
from typing import Optional, Callable
from typing import Optional
from hiero_sdk_python.account.account_id import AccountId
from hiero_sdk_python.channels import _Channel
from hiero_sdk_python.address_book.node_address import NodeAddress
Expand Down Expand Up @@ -92,6 +92,7 @@ def __init__(self, account_id: AccountId, address: str, address_book: NodeAddres
self._verify_certificates: bool = True
self._root_certificates: Optional[bytes] = None
self._authority_override: Optional[str] = self._determine_authority_override()
self._node_pem_cert: Optional[bytes] = None

def _close(self):
"""
Expand All @@ -115,13 +116,23 @@ def _get_channel(self):
return self._channel

if self._address._is_transport_security():
if self._root_certificates:
# Use the certificate that provider
self._node_pem_cert = self._root_certificates
else:
# Fetch pem_cert for the node
self._node_pem_cert = self._fetch_server_certificate_pem()

# Validate certificate if verification is enabled
if self._verify_certificates:
self._validate_tls_certificate_with_trust_manager()

if not self._node_pem_cert:
raise ValueError("No certificate available.")

options = self._build_channel_options()
credentials = grpc.ssl_channel_credentials(
root_certificates=self._root_certificates,
root_certificates=self._node_pem_cert,
private_key=None,
certificate_chain=None,
)
Expand All @@ -141,7 +152,9 @@ def _apply_transport_security(self, enabled: bool):
return
if not enabled and not self._address._is_transport_security():
return

self._close()

if enabled:
self._address = self._address._to_secure()
else:
Expand All @@ -154,13 +167,16 @@ def _set_root_certificates(self, root_certificates: Optional[bytes]):
self._root_certificates = root_certificates
if self._channel and self._address._is_transport_security():
self._close()

def _set_verify_certificates(self, verify: bool):
"""
Set whether TLS certificates should be verified.
"""
if self._verify_certificates == verify:
return

self._verify_certificates = verify

if verify and self._channel and self._address._is_transport_security():
# Force channel recreation to ensure certificates are revalidated.
self._close()
Expand All @@ -173,20 +189,25 @@ def _determine_authority_override(self) -> Optional[str]:
return None
for endpoint in self._address_book._addresses: # pylint: disable=protected-access
domain = endpoint.get_domain_name()

if domain:
return domain

return None

def _build_channel_options(self):
"""
Build gRPC channel options for TLS connections.
"""
if not self._authority_override:
return None
host = self._address._get_host()
if host == self._authority_override:
return None
return [('grpc.ssl_target_name_override', self._authority_override)]
options = [
("grpc.default_authority", "127.0.0.1"),
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is it necessary to set this to loca host ? Can we set a=a to something non existent ?

("grpc.ssl_target_name_override", "127.0.0.1"),
("grpc.keepalive_time_ms", 100000),
("grpc.keepalive_timeout_ms", 10000),
("grpc.keepalive_permit_without_calls", 1)
]

return options

def _validate_tls_certificate_with_trust_manager(self):
"""
Expand All @@ -197,9 +218,7 @@ def _validate_tls_certificate_with_trust_manager(self):
Note: If verification is enabled but no cert hash is available (e.g., in unit tests
without address books), validation is skipped rather than raising an error.
"""
if not self._address._is_transport_security():
return
if not self._verify_certificates:
if not self._address._is_transport_security() or not self._verify_certificates:
return

cert_hash = None
Expand All @@ -214,10 +233,7 @@ def _validate_tls_certificate_with_trust_manager(self):

# Create trust manager and validate certificate
trust_manager = _HederaTrustManager(cert_hash, self._verify_certificates)

# Fetch server certificate and validate
pem_cert = self._fetch_server_certificate_pem()
trust_manager.check_server_trusted(pem_cert)
trust_manager.check_server_trusted(self._node_pem_cert)

@staticmethod
def _normalize_cert_hash(cert_hash: bytes) -> str:
Expand All @@ -228,6 +244,7 @@ def _normalize_cert_hash(cert_hash: bytes) -> str:
decoded = cert_hash.decode('utf-8').strip().lower()
if decoded.startswith("0x"):
decoded = decoded[2:]

return decoded
except UnicodeDecodeError:
return cert_hash.hex()
Expand All @@ -239,6 +256,9 @@ def _fetch_server_certificate_pem(self) -> bytes:
Returns:
bytes: PEM-encoded certificate bytes
"""
if not self._address_book:
return None

host = self._address._get_host()
port = self._address._get_port()
server_hostname = self._authority_override or host
Expand All @@ -254,4 +274,4 @@ def _fetch_server_certificate_pem(self) -> bytes:

# Convert DER to PEM format (matching Java's PEM encoding)
pem_cert = ssl.DER_cert_to_PEM_cert(der_cert).encode('utf-8')
return pem_cert
return pem_cert
20 changes: 18 additions & 2 deletions tests/unit/conftest.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
import hashlib
import time

import pytest

from hiero_sdk_python.account.account_id import AccountId
from hiero_sdk_python.address_book.node_address import NodeAddress
from hiero_sdk_python.client.client import Client
from hiero_sdk_python.client.network import Network
from hiero_sdk_python.consensus.topic_id import TopicId
from hiero_sdk_python.contract.contract_id import ContractId
from hiero_sdk_python.crypto.private_key import PrivateKey
from hiero_sdk_python.file.file_id import FileId
from hiero_sdk_python.hapi.services import timestamp_pb2
from hiero_sdk_python.logger.log_level import LogLevel
from hiero_sdk_python.node import _Node
from hiero_sdk_python.tokens.token_id import TokenId
Expand All @@ -18,6 +19,12 @@
from hiero_sdk_python.transaction.transaction_id import TransactionId


FAKE_CERT_PEM = b"""-----BEGIN CERTIFICATE-----
MIIBszCCAVmgAwIBAgIUQFakeFakeFakeFakeFakeFakeFakewCgYIKoZIzj0EAwIw
-----END CERTIFICATE-----"""

FAKE_CERT_HASH = hashlib.sha384(FAKE_CERT_PEM).hexdigest().encode("utf-8")

@pytest.fixture
def mock_account_ids():
"""Fixture to provide mock account IDs and token IDs."""
Expand Down Expand Up @@ -78,7 +85,16 @@ def contract_id():
@pytest.fixture
def mock_client():
"""Fixture to provide a mock client with hardcoded nodes for testing purposes."""
nodes = [_Node(AccountId(0, 0, 3), "node1.example.com:50211", None)]
# Mock Node
node = _Node(
AccountId(0, 0, 3),
"node1.example.com:50211",
address_book=NodeAddress(cert_hash=FAKE_CERT_HASH, addresses=[])
)
node._fetch_server_certificate_pem = lambda: FAKE_CERT_PEM

nodes = [node]

network = Network(nodes=nodes)
client = Client(network)
client.logger.set_level(LogLevel.DISABLED)
Expand Down
82 changes: 51 additions & 31 deletions tests/unit/test_node_tls.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
"""Unit tests for TLS functionality in _Node."""
import hashlib
import socket
import ssl
from unittest.mock import Mock, patch, MagicMock
import pytest
import grpc
from src.hiero_sdk_python.node import _Node, _HederaTrustManager
from src.hiero_sdk_python.node import _Node
from src.hiero_sdk_python.account.account_id import AccountId
from src.hiero_sdk_python.address_book.node_address import NodeAddress
from src.hiero_sdk_python.address_book.endpoint import Endpoint
Expand Down Expand Up @@ -89,7 +86,7 @@ def test_node_apply_transport_security_closes_channel(mock_node_with_address_boo
node._verify_certificates = False

# Create a channel first
with patch('grpc.secure_channel') as mock_secure:
with patch('grpc.secure_channel') as mock_secure, patch.object(node, "_fetch_server_certificate_pem", return_value=b"dummy-cert"):
mock_channel = Mock()
mock_secure.return_value = mock_channel
node._get_channel()
Expand Down Expand Up @@ -121,7 +118,7 @@ def test_node_set_verify_certificates_idempotent(mock_node_with_address_book):
assert node._verify_certificates == initial_state


def test_node_build_channel_options_with_hostname_override(mock_address_book):
def test_node_build_channel_options_with_hostname_not_override(mock_address_book):
"""Test channel options include hostname override when domain differs from address."""
endpoint = Endpoint(address=b"127.0.0.1", port=50212, domain_name="node.example.com")
address_book = NodeAddress(
Expand All @@ -133,7 +130,7 @@ def test_node_build_channel_options_with_hostname_override(mock_address_book):

options = node._build_channel_options()
assert options is not None
assert ('grpc.ssl_target_name_override', 'node.example.com') in options
assert ('grpc.ssl_target_name_override', 'node.example.com') not in options


def test_node_build_channel_options_no_override_when_same(mock_address_book):
Expand All @@ -147,14 +144,26 @@ def test_node_build_channel_options_no_override_when_same(mock_address_book):
node = _Node(AccountId(0, 0, 3), "node.example.com:50212", address_book)

options = node._build_channel_options()
assert options is None
assert options == [
("grpc.default_authority", "127.0.0.1"),
("grpc.ssl_target_name_override", "127.0.0.1"),
("grpc.keepalive_time_ms", 100000),
("grpc.keepalive_timeout_ms", 10000),
("grpc.keepalive_permit_without_calls", 1)
]


def test_node_build_channel_options_no_override_without_address_book(mock_node_without_address_book):
def test_node_build_channel_options_override_localhost_without_address_book(mock_node_without_address_book):
"""Test channel options don't include override without address book."""
node = mock_node_without_address_book
options = node._build_channel_options()
assert options is None
assert options == [
("grpc.default_authority", "127.0.0.1"),
("grpc.ssl_target_name_override", "127.0.0.1"),
("grpc.keepalive_time_ms", 100000),
("grpc.keepalive_timeout_ms", 10000),
("grpc.keepalive_permit_without_calls", 1)
]


@patch('socket.create_connection')
Expand Down Expand Up @@ -189,6 +198,7 @@ def test_node_validate_tls_certificate_with_trust_manager(mock_node_with_address

# Update address book with matching hash
node._address_book._cert_hash = cert_hash.encode('utf-8')
node._node_pem_cert = pem_cert

with patch.object(node, '_fetch_server_certificate_pem', return_value=pem_cert):
# Should not raise
Expand All @@ -203,10 +213,10 @@ def test_node_validate_tls_certificate_hash_mismatch(mock_node_with_address_book
pem_cert = b"-----BEGIN CERTIFICATE-----\nTEST\n-----END CERTIFICATE-----\n"
wrong_hash = b"wrong_hash"
node._address_book._cert_hash = wrong_hash
node._node_pem_cert = pem_cert

with patch.object(node, '_fetch_server_certificate_pem', return_value=pem_cert):
with pytest.raises(ValueError, match="Failed to confirm the server's certificate"):
node._validate_tls_certificate_with_trust_manager()
with pytest.raises(ValueError, match="Failed to confirm the server's certificate"):
node._validate_tls_certificate_with_trust_manager()


def test_node_validate_tls_certificate_no_verification(mock_node_with_address_book):
Expand Down Expand Up @@ -236,17 +246,18 @@ def test_node_get_channel_secure(mock_insecure, mock_secure, mock_node_with_addr
node = mock_node_with_address_book
node._address = node._address._to_secure() # Ensure TLS is enabled

mock_channel = Mock()
mock_secure.return_value = mock_channel
with patch.object(node, "_fetch_server_certificate_pem", return_value=b"dummy-cert"):
mock_channel = Mock()
mock_secure.return_value = mock_channel

# Skip certificate validation for this test
node._verify_certificates = False
# Skip certificate validation for this test
node._verify_certificates = False

channel = node._get_channel()
channel = node._get_channel()

mock_secure.assert_called_once()
mock_insecure.assert_not_called()
assert channel is not None
mock_secure.assert_called_once()
mock_insecure.assert_not_called()
assert channel is not None


@patch('grpc.secure_channel')
Expand All @@ -272,15 +283,16 @@ def test_node_get_channel_reuses_existing(mock_insecure, mock_secure, mock_node_
node = mock_node_with_address_book
node._verify_certificates = False

mock_channel = Mock()
mock_secure.return_value = mock_channel

channel1 = node._get_channel()
channel2 = node._get_channel()

# Should only create channel once
assert mock_secure.call_count == 1
assert channel1 is channel2
with patch.object(node, "_fetch_server_certificate_pem", return_value=b"dummy-cert"):
mock_channel = Mock()
mock_secure.return_value = mock_channel

channel1 = node._get_channel()
channel2 = node._get_channel()

# Should only create channel once
assert mock_secure.call_count == 1
assert channel1 is channel2


def test_node_set_root_certificates(mock_node_with_address_book):
Expand All @@ -297,7 +309,8 @@ def test_node_set_root_certificates_closes_channel(mock_node_with_address_book):
node = mock_node_with_address_book
node._verify_certificates = False

with patch('grpc.secure_channel') as mock_secure:
with patch('grpc.secure_channel') as mock_secure, patch.object(node, "_fetch_server_certificate_pem", return_value=b"dummy-cert"):

mock_channel = Mock()
mock_secure.return_value = mock_channel
node._get_channel()
Expand All @@ -307,3 +320,10 @@ def test_node_set_root_certificates_closes_channel(mock_node_with_address_book):
# Channel should be closed to force recreation
assert node._channel is None

def test_secure_coonect_raise_error_if_no_certificate_is_available(mock_node_without_address_book):
"""Test get channel raise error if no certificate available if transport security true."""
node = mock_node_without_address_book
node._apply_transport_security(True)

with pytest.raises(ValueError, match="No certificate available."):
node._get_channel()
Loading