Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions python/coinbase-agentkit/changelog.d/968.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed asyncio event loop conflicts in CdpEvmWalletProvider, CdpSmartWalletProvider, and CdpSolanaWalletProvider when used within existing async contexts (e.g., Jupyter notebooks, async frameworks)
Original file line number Diff line number Diff line change
Expand Up @@ -66,9 +66,9 @@ def __init__(self, config: CdpEvmWalletProviderConfig):

client = self.get_client()
if config.address:
account = asyncio.run(self._get_account(client, config.address))
account = self._run_async(self._get_account(client, config.address))
else:
account = asyncio.run(self._create_account(client))
account = self._run_async(self._create_account(client))

self._account = account

Expand Down Expand Up @@ -366,8 +366,29 @@ def _run_async(self, coroutine):

"""
try:
loop = asyncio.get_event_loop()
# Check if we're in an existing event loop
loop = asyncio.get_running_loop()
# If we reach this point, there's already a running event loop
# We need to run the coroutine in a new thread with a new event loop
import concurrent.futures

def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(coroutine)
finally:
new_loop.close()

with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()

except RuntimeError:
# No running event loop, safe to create and use a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coroutine)
try:
return loop.run_until_complete(coroutine)
finally:
loop.close()
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,12 @@ async def initialize_accounts():
smart_account = await cdp.evm.create_smart_account(owner=owner)
return owner, smart_account

owner, smart_account = asyncio.run(initialize_accounts())
owner, smart_account = self._run_async(initialize_accounts())
self._address = smart_account.address
self._owner = owner

finally:
asyncio.run(client.close())
self._run_async(client.close())

self._gas_limit_multiplier = (
max(config.gas.gas_limit_multiplier, 1)
Expand Down Expand Up @@ -171,11 +171,32 @@ def _run_async(self, coroutine):

"""
try:
loop = asyncio.get_event_loop()
# Check if we're in an existing event loop
loop = asyncio.get_running_loop()
# If we reach this point, there's already a running event loop
# We need to run the coroutine in a new thread with a new event loop
import concurrent.futures

def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(coroutine)
finally:
new_loop.close()

with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()

except RuntimeError:
# No running event loop, safe to create and use a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coroutine)
try:
return loop.run_until_complete(coroutine)
finally:
loop.close()

async def _get_smart_account(self, cdp):
"""Get the smart account, handling server wallet owners differently.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ async def initialize_wallet():
wallet = await cdp.solana.create_account()
return wallet

wallet = asyncio.run(initialize_wallet())
wallet = self._run_async(initialize_wallet())
self._address = wallet.address

finally:
asyncio.run(client.close())
self._run_async(client.close())

except ImportError as e:
raise ImportError(
Expand Down Expand Up @@ -113,11 +113,32 @@ def _run_async(self, coroutine):

"""
try:
loop = asyncio.get_event_loop()
# Check if we're in an existing event loop
loop = asyncio.get_running_loop()
# If we reach this point, there's already a running event loop
# We need to run the coroutine in a new thread with a new event loop
import concurrent.futures

def run_in_thread():
new_loop = asyncio.new_event_loop()
asyncio.set_event_loop(new_loop)
try:
return new_loop.run_until_complete(coroutine)
finally:
new_loop.close()

with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(run_in_thread)
return future.result()

except RuntimeError:
# No running event loop, safe to create and use a new one
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
return loop.run_until_complete(coroutine)
try:
return loop.run_until_complete(coroutine)
finally:
loop.close()

def get_address(self) -> str:
"""Get the wallet address.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,12 +135,16 @@ def mocked_wallet_provider(mock_cdp_client, mock_account, mock_web3, mock_wallet
network_id=MOCK_NETWORK_ID,
)

# Patch the async run to return the mock account directly
with patch("asyncio.run", return_value=mock_account):
# Patch _run_async only during init to return the mock account
original_run_async = CdpEvmWalletProvider._run_async
CdpEvmWalletProvider._run_async = lambda self, coro: mock_account
try:
provider = CdpEvmWalletProvider(config)
finally:
CdpEvmWalletProvider._run_async = original_run_async

# Manually set account and wallet attributes
provider._account = mock_account
provider._wallet = mock_wallet
# Manually set account and wallet attributes
provider._account = mock_account
provider._wallet = mock_wallet

yield provider
yield provider
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

def test_init_with_config(mock_cdp_client, mock_account):
"""Test initialization with config."""
with patch("asyncio.run") as mock_run:
with patch.object(CdpEvmWalletProvider, "_run_async") as mock_run:
mock_run.return_value = mock_account

config = CdpEvmWalletProviderConfig(
Expand All @@ -44,7 +44,7 @@ def test_init_with_config(mock_cdp_client, mock_account):
def test_init_with_env_vars(mock_cdp_client, mock_account):
"""Test initialization with environment variables."""
with (
patch("asyncio.run") as mock_run,
patch.object(CdpEvmWalletProvider, "_run_async") as mock_run,
patch.dict(
os.environ,
{
Expand All @@ -66,7 +66,7 @@ def test_init_with_env_vars(mock_cdp_client, mock_account):
def test_init_with_default_network(mock_cdp_client, mock_account):
"""Test initialization with default network when no network ID is provided."""
with (
patch("asyncio.run") as mock_run,
patch.object(CdpEvmWalletProvider, "_run_async") as mock_run,
patch(
"os.getenv",
side_effect=lambda key, default=None: "base-sepolia" if key == "NETWORK_ID" else None,
Expand Down Expand Up @@ -99,7 +99,9 @@ def test_init_with_missing_credentials():
def test_init_with_invalid_network(mock_cdp_client):
"""Test initialization with invalid network."""
# Use a known invalid network ID
with patch("asyncio.run", side_effect=ValueError("Invalid network ID")):
with patch.object(
CdpEvmWalletProvider, "_run_async", side_effect=ValueError("Invalid network ID")
):
config = CdpEvmWalletProviderConfig(
api_key_id=MOCK_API_KEY_ID,
api_key_secret=MOCK_API_KEY_SECRET,
Expand All @@ -113,7 +115,9 @@ def test_init_with_invalid_network(mock_cdp_client):

def test_init_with_account_creation_error(mock_cdp_client):
"""Test initialization when account creation fails."""
with patch("asyncio.run", side_effect=Exception("Failed to create account")):
with patch.object(
CdpEvmWalletProvider, "_run_async", side_effect=Exception("Failed to create account")
):
config = CdpEvmWalletProviderConfig(
api_key_id=MOCK_API_KEY_ID,
api_key_secret=MOCK_API_KEY_SECRET,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def mock_asyncio():
# Create a side effect for asyncio.run that handles the initialization coroutine
def run_side_effect(coro):
# Special handling for the initialization_accounts coroutine
if coro.__name__ == "initialize_accounts":
if hasattr(coro, "__name__") and coro.__name__ == "initialize_accounts":
# Return a tuple of mock owner and smart account
mock_owner = Mock(spec=Account)
mock_owner.address = "0x123456789012345678901234567890123456789012"
Expand All @@ -127,7 +127,12 @@ def run_side_effect(coro):
return None

mock_asyncio.run = Mock(side_effect=run_side_effect)

# Also patch _run_async on the class so __init__ works with our mocks
original_run_async = CdpSmartWalletProvider._run_async
CdpSmartWalletProvider._run_async = lambda self, coro: run_side_effect(coro)
yield mock_asyncio
CdpSmartWalletProvider._run_async = original_run_async


@pytest.fixture
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,13 +124,14 @@ def test_init_with_private_key_owner(mock_cdp_client, mock_asyncio, mock_network
) as mock_account_class:
mock_account_class.from_key.return_value = mock_owner

# Update mock_asyncio.run to return a tuple with the owner and smart_account for initialization
def run_side_effect(coro):
if coro.__name__ == "initialize_accounts":
# Update _run_async to return the right values for initialization
def run_async_side_effect(self, coro):
if hasattr(coro, "__name__") and coro.__name__ == "initialize_accounts":
return (mock_owner, mock_smart_account)
return None

mock_asyncio.run.side_effect = run_side_effect
# Patch _run_async on the class for this test
CdpSmartWalletProvider._run_async = run_async_side_effect

async def create_smart_account_mock(*args, **kwargs):
return mock_smart_account
Expand All @@ -147,7 +148,7 @@ async def create_smart_account_mock(*args, **kwargs):

provider = CdpSmartWalletProvider(config)

# Don't check if from_key was called, since we're mocking asyncio.run
# Don't check if from_key was called, since we're mocking _run_async
# We only care that the provider was initialized correctly
assert provider.get_address() == MOCK_ADDRESS
assert provider._owner == mock_owner
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,13 @@ def mocked_wallet_provider(mock_cdp_client, mock_solana_client, mock_public_key)
network_id=MOCK_NETWORK_ID,
)

# Patch asyncio.run to return the mock account
# Patch _run_async only during init, then restore it
mock_account = Mock(address=MOCK_ADDRESS)
with patch("asyncio.run", return_value=mock_account):
original_run_async = CdpSolanaWalletProvider._run_async
CdpSolanaWalletProvider._run_async = lambda self, coro: mock_account
try:
provider = CdpSolanaWalletProvider(config)
finally:
CdpSolanaWalletProvider._run_async = original_run_async

yield provider
yield provider
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,9 @@ def test_get_network_different_networks():
with (
patch("coinbase_agentkit.wallet_providers.cdp_solana_wallet_provider.CdpClient"),
patch("coinbase_agentkit.wallet_providers.cdp_solana_wallet_provider.SolanaClient"),
patch("asyncio.run", return_value=Mock(address=MOCK_ADDRESS)),
patch.object(
CdpSolanaWalletProvider, "_run_async", return_value=Mock(address=MOCK_ADDRESS)
),
):
config = CdpSolanaWalletProviderConfig(
api_key_id="test_key",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

import pytest

from coinbase_agentkit.wallet_providers.cdp_solana_wallet_provider import CdpSolanaWalletProvider

# =========================================================
# error handling tests
# =========================================================
Expand All @@ -23,7 +25,7 @@ async def raise_connection_error(*args, **kwargs):
mock_cdp_client.solana.get_account = AsyncMock(return_value=mock_wallet)

with (
patch("asyncio.run", side_effect=ConnectionError(error_msg)),
patch.object(CdpSolanaWalletProvider, "_run_async", side_effect=ConnectionError(error_msg)),
pytest.raises(ConnectionError, match=error_msg),
):
mocked_wallet_provider.native_transfer("SomeAddress", Decimal("0.5"))
Expand All @@ -32,7 +34,7 @@ async def raise_connection_error(*args, **kwargs):
mock_cdp_client.solana.sign_message = raise_connection_error

with (
patch("asyncio.run", side_effect=ConnectionError(error_msg)),
patch.object(CdpSolanaWalletProvider, "_run_async", side_effect=ConnectionError(error_msg)),
pytest.raises(ConnectionError, match=error_msg),
):
mocked_wallet_provider.sign_message("test message")
Expand All @@ -51,7 +53,7 @@ async def raise_value_error(*args, **kwargs):
mock_cdp_client.solana.get_account = AsyncMock(return_value=mock_wallet)

with (
patch("asyncio.run", side_effect=ValueError(address_error)),
patch.object(CdpSolanaWalletProvider, "_run_async", side_effect=ValueError(address_error)),
pytest.raises(ValueError, match=address_error),
):
mocked_wallet_provider.native_transfer("invalid_address", Decimal("1.0"))
Expand All @@ -75,7 +77,7 @@ async def raise_auth_error(*args, **kwargs):
mock_cdp_client.solana.get_account = raise_auth_error

with (
patch("asyncio.run", side_effect=Exception(auth_error)),
patch.object(CdpSolanaWalletProvider, "_run_async", side_effect=Exception(auth_error)),
pytest.raises(Exception, match=auth_error),
):
mocked_wallet_provider.native_transfer("SomeAddress", Decimal("1.0"))
Expand All @@ -85,7 +87,9 @@ async def raise_auth_error(*args, **kwargs):
mock_cdp_client.solana.sign_message = AsyncMock(side_effect=Exception(rate_limit_error))

with (
patch("asyncio.run", side_effect=Exception(rate_limit_error)),
patch.object(
CdpSolanaWalletProvider, "_run_async", side_effect=Exception(rate_limit_error)
),
pytest.raises(Exception, match=rate_limit_error),
):
mocked_wallet_provider.sign_message("test")
Expand All @@ -104,7 +108,7 @@ async def raise_balance_error(*args, **kwargs):
mock_cdp_client.solana.get_account = AsyncMock(return_value=mock_wallet)

with (
patch("asyncio.run", side_effect=Exception(balance_error)),
patch.object(CdpSolanaWalletProvider, "_run_async", side_effect=Exception(balance_error)),
pytest.raises(Exception, match=balance_error),
):
mocked_wallet_provider.native_transfer("SomeAddress", Decimal("1000000"))
Expand All @@ -118,7 +122,9 @@ async def raise_timeout(*args, **kwargs):
mock_wallet.transfer = raise_timeout

with (
patch("asyncio.run", side_effect=TimeoutError(timeout_error)),
patch.object(
CdpSolanaWalletProvider, "_run_async", side_effect=TimeoutError(timeout_error)
),
pytest.raises(TimeoutError, match=timeout_error),
):
mocked_wallet_provider.native_transfer("SomeAddress", Decimal("1.0"))
Expand Down
Loading