diff --git a/packages/transaction-pay-controller/CHANGELOG.md b/packages/transaction-pay-controller/CHANGELOG.md index 3a66a07a7c..735685211f 100644 --- a/packages/transaction-pay-controller/CHANGELOG.md +++ b/packages/transaction-pay-controller/CHANGELOG.md @@ -10,6 +10,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Changed - Bump `@metamask/gas-fee-controller` from `^26.2.1` to `^26.2.2` ([#8834](https://github.com/MetaMask/core/pull/8834)) +- `getLiveTokenBalance` now prefers the Infura RPC endpoint for a chain when querying live token balances, falling back to the chain's default endpoint if no Infura endpoint is configured ([#XXXX](https://github.com/MetaMask/core/pull/XXXX)) ### Fixed diff --git a/packages/transaction-pay-controller/src/strategy/across/across-submit.ts b/packages/transaction-pay-controller/src/strategy/across/across-submit.ts index 3c52f2395d..cc0bb7067a 100644 --- a/packages/transaction-pay-controller/src/strategy/across/across-submit.ts +++ b/packages/transaction-pay-controller/src/strategy/across/across-submit.ts @@ -21,6 +21,7 @@ import type { import { accountSupports7702 } from '../../utils/7702'; import { getPayStrategiesConfig } from '../../utils/feature-flags'; import { getGasBuffer } from '../../utils/feature-flags'; +import { getNetworkClientId } from '../../utils/provider'; import { collectTransactionIds, getTransaction, @@ -141,10 +142,7 @@ async function submitTransactions( const transactionCount = orderedTransactions.length + (shouldPrependOriginalTransaction ? 1 : 0); - const networkClientId = messenger.call( - 'NetworkController:findNetworkClientIdByChainId', - chainId, - ); + const networkClientId = getNetworkClientId(messenger, chainId); const is7702Batch = is7702 && transactionCount > 1; const canUseQuotedBatchGasLimit = diff --git a/packages/transaction-pay-controller/src/strategy/relay/relay-submit.ts b/packages/transaction-pay-controller/src/strategy/relay/relay-submit.ts index 372a016ab5..98c098e081 100644 --- a/packages/transaction-pay-controller/src/strategy/relay/relay-submit.ts +++ b/packages/transaction-pay-controller/src/strategy/relay/relay-submit.ts @@ -20,6 +20,7 @@ import { getRelayPollingInterval, getRelayPollingTimeout, } from '../../utils/feature-flags'; +import { getNetworkClientId } from '../../utils/provider'; import { getLiveTokenBalance, normalizeTokenAddress, @@ -488,10 +489,7 @@ async function submitViaRelayExecute( const { from, sourceChainId } = quote.request; const { requestId } = quote.original.steps[0]; - const networkClientId = messenger.call( - 'NetworkController:findNetworkClientIdByChainId', - sourceChainId, - ); + const networkClientId = getNetworkClientId(messenger, sourceChainId); const sourceCallTransaction = { ...transaction, @@ -580,10 +578,7 @@ async function submitViaTransactionController( const { from, sourceChainId, sourceTokenAddress } = quote.request; const { isPostQuote } = quote.request; - const networkClientId = messenger.call( - 'NetworkController:findNetworkClientIdByChainId', - sourceChainId, - ); + const networkClientId = getNetworkClientId(messenger, sourceChainId); log('Adding transactions', { normalizedParams: allParams, diff --git a/packages/transaction-pay-controller/src/tests/messenger-mock.ts b/packages/transaction-pay-controller/src/tests/messenger-mock.ts index 1931aa202a..1ee06b397a 100644 --- a/packages/transaction-pay-controller/src/tests/messenger-mock.ts +++ b/packages/transaction-pay-controller/src/tests/messenger-mock.ts @@ -15,6 +15,7 @@ import type { import { Messenger, MOCK_ANY_NAMESPACE } from '@metamask/messenger'; import type { NetworkControllerGetNetworkClientByIdAction } from '@metamask/network-controller'; import type { NetworkControllerFindNetworkClientIdByChainIdAction } from '@metamask/network-controller'; +import type { NetworkControllerGetNetworkConfigurationByChainIdAction } from '@metamask/network-controller'; import type { RemoteFeatureFlagControllerGetStateAction } from '@metamask/remote-feature-flag-controller'; import type { TransactionControllerAddTransactionAction, @@ -116,6 +117,10 @@ export function getMessengerMock({ NetworkControllerGetNetworkClientByIdAction['handler'] > = jest.fn(); + const getNetworkConfigurationByChainIdMock: jest.MockedFn< + NetworkControllerGetNetworkConfigurationByChainIdAction['handler'] + > = jest.fn(); + const getDelegationTransactionMock: jest.MockedFn< TransactionPayControllerGetDelegationTransactionAction['handler'] > = jest.fn(); @@ -250,6 +255,11 @@ export function getMessengerMock({ getNetworkClientByIdMock, ); + messenger.registerActionHandler( + 'NetworkController:getNetworkConfigurationByChainId', + getNetworkConfigurationByChainIdMock, + ); + messenger.registerActionHandler( 'TransactionPayController:getDelegationTransaction', getDelegationTransactionMock, @@ -310,6 +320,7 @@ export function getMessengerMock({ getGasFeeTokensMock, getKeyringControllerStateMock, getNetworkClientByIdMock, + getNetworkConfigurationByChainIdMock, getRemoteFeatureFlagControllerStateMock, getStrategyMock, getTokenBalanceControllerStateMock, diff --git a/packages/transaction-pay-controller/src/types.ts b/packages/transaction-pay-controller/src/types.ts index 8844400512..e3a3948c95 100644 --- a/packages/transaction-pay-controller/src/types.ts +++ b/packages/transaction-pay-controller/src/types.ts @@ -31,6 +31,7 @@ import type { import type { Messenger } from '@metamask/messenger'; import type { NetworkControllerFindNetworkClientIdByChainIdAction } from '@metamask/network-controller'; import type { NetworkControllerGetNetworkClientByIdAction } from '@metamask/network-controller'; +import type { NetworkControllerGetNetworkConfigurationByChainIdAction } from '@metamask/network-controller'; import type { Quote as RampsQuote } from '@metamask/ramps-controller'; import type { RampsControllerGetOrderAction, @@ -73,6 +74,7 @@ export type AllowedActions = | KeyringControllerSignTypedMessageAction | NetworkControllerFindNetworkClientIdByChainIdAction | NetworkControllerGetNetworkClientByIdAction + | NetworkControllerGetNetworkConfigurationByChainIdAction | RampsControllerGetOrderAction | RampsControllerGetQuotesAction | RampsControllerGetStateAction diff --git a/packages/transaction-pay-controller/src/utils/gas.ts b/packages/transaction-pay-controller/src/utils/gas.ts index fceb4f5ae0..fc98784d8a 100644 --- a/packages/transaction-pay-controller/src/utils/gas.ts +++ b/packages/transaction-pay-controller/src/utils/gas.ts @@ -11,6 +11,7 @@ import type { TransactionPayControllerMessenger } from '..'; import { createModuleLogger, projectLogger } from '../logger'; import type { Amount } from '../types'; import { getFallbackGas, getGasBuffer } from './feature-flags'; +import { getNetworkClientId } from './provider'; import { getNativeToken, getTokenBalance, getTokenFiatRate } from './token'; const log = createModuleLogger(projectLogger, 'gas'); @@ -227,10 +228,7 @@ export async function estimateGasLimit({ error?: unknown; }> { const gasBuffer = getGasBuffer(messenger, chainId); - const networkClientId = messenger.call( - 'NetworkController:findNetworkClientIdByChainId', - chainId, - ); + const networkClientId = getNetworkClientId(messenger, chainId); let estimateGasError: unknown; let simulationError: Error | undefined; diff --git a/packages/transaction-pay-controller/src/utils/provider.test.ts b/packages/transaction-pay-controller/src/utils/provider.test.ts new file mode 100644 index 0000000000..e178aa0a71 --- /dev/null +++ b/packages/transaction-pay-controller/src/utils/provider.test.ts @@ -0,0 +1,186 @@ +import type { Provider } from '@metamask/network-controller'; +import { RpcEndpointType } from '@metamask/network-controller'; +import type { NetworkConfiguration } from '@metamask/network-controller'; +import type { Hex } from '@metamask/utils'; + +import { getMessengerMock } from '../tests/messenger-mock'; +import { getNetworkClientId, rpcRequest } from './provider'; + +const CHAIN_ID_MOCK = '0x1' as Hex; +const DEFAULT_NETWORK_CLIENT_ID_MOCK = 'default-client-id'; +const INFURA_NETWORK_CLIENT_ID_MOCK = 'mainnet'; +const PROVIDER_MOCK = { request: jest.fn() } as unknown as Provider; + +describe('provider utils', () => { + const { + messenger, + findNetworkClientIdByChainIdMock, + getNetworkClientByIdMock, + getNetworkConfigurationByChainIdMock, + } = getMessengerMock(); + + beforeEach(() => { + jest.resetAllMocks(); + + findNetworkClientIdByChainIdMock.mockReturnValue( + DEFAULT_NETWORK_CLIENT_ID_MOCK, + ); + + getNetworkClientByIdMock.mockReturnValue({ + provider: PROVIDER_MOCK, + } as never); + + getNetworkConfigurationByChainIdMock.mockReturnValue(undefined); + }); + + describe('getNetworkClientId', () => { + it('returns default network client ID when preferInfura is false', () => { + const result = getNetworkClientId(messenger, CHAIN_ID_MOCK); + + expect(result).toBe(DEFAULT_NETWORK_CLIENT_ID_MOCK); + expect(findNetworkClientIdByChainIdMock).toHaveBeenCalledWith( + CHAIN_ID_MOCK, + ); + expect(getNetworkConfigurationByChainIdMock).not.toHaveBeenCalled(); + }); + + it('returns Infura network client ID when preferInfura is true and Infura endpoint exists', () => { + getNetworkConfigurationByChainIdMock.mockReturnValue({ + rpcEndpoints: [ + { + type: RpcEndpointType.Infura, + networkClientId: INFURA_NETWORK_CLIENT_ID_MOCK, + }, + ], + } as NetworkConfiguration); + + const result = getNetworkClientId(messenger, CHAIN_ID_MOCK, { + preferInfura: true, + }); + + expect(result).toBe(INFURA_NETWORK_CLIENT_ID_MOCK); + expect(findNetworkClientIdByChainIdMock).not.toHaveBeenCalled(); + }); + + it('falls back to default network client ID when preferInfura is true but no Infura endpoint exists', () => { + getNetworkConfigurationByChainIdMock.mockReturnValue({ + rpcEndpoints: [ + { + type: RpcEndpointType.Custom, + networkClientId: 'custom-rpc-id', + }, + ], + } as NetworkConfiguration); + + const result = getNetworkClientId(messenger, CHAIN_ID_MOCK, { + preferInfura: true, + }); + + expect(result).toBe(DEFAULT_NETWORK_CLIENT_ID_MOCK); + expect(findNetworkClientIdByChainIdMock).toHaveBeenCalledWith( + CHAIN_ID_MOCK, + ); + }); + + it('falls back to default network client ID when preferInfura is true but getNetworkConfigurationByChainId throws', () => { + getNetworkConfigurationByChainIdMock.mockImplementation(() => { + throw new Error('Configuration not found'); + }); + + const result = getNetworkClientId(messenger, CHAIN_ID_MOCK, { + preferInfura: true, + }); + + expect(result).toBe(DEFAULT_NETWORK_CLIENT_ID_MOCK); + expect(findNetworkClientIdByChainIdMock).toHaveBeenCalledWith( + CHAIN_ID_MOCK, + ); + }); + + it('falls back to default network client ID when preferInfura is true but network configuration is undefined', () => { + getNetworkConfigurationByChainIdMock.mockReturnValue(undefined); + + const result = getNetworkClientId(messenger, CHAIN_ID_MOCK, { + preferInfura: true, + }); + + expect(result).toBe(DEFAULT_NETWORK_CLIENT_ID_MOCK); + expect(findNetworkClientIdByChainIdMock).toHaveBeenCalledWith( + CHAIN_ID_MOCK, + ); + }); + }); + + describe('rpcRequest', () => { + it('calls provider.request with method and params', async () => { + const requestMock = jest.fn().mockResolvedValue('0xabc'); + getNetworkClientByIdMock.mockReturnValue({ + provider: { request: requestMock }, + } as never); + + const result = await rpcRequest(messenger, CHAIN_ID_MOCK, 'eth_chainId', [ + 'latest', + ]); + + expect(result).toBe('0xabc'); + expect(requestMock).toHaveBeenCalledWith({ + method: 'eth_chainId', + params: ['latest'], + }); + }); + + it('calls provider.request without params when omitted', async () => { + const requestMock = jest.fn().mockResolvedValue('0x10'); + getNetworkClientByIdMock.mockReturnValue({ + provider: { request: requestMock }, + } as never); + + await rpcRequest(messenger, CHAIN_ID_MOCK, 'eth_blockNumber'); + + expect(requestMock).toHaveBeenCalledWith({ + method: 'eth_blockNumber', + params: undefined, + }); + }); + + it('propagates provider errors', async () => { + const error = new Error('RPC failed'); + const requestMock = jest.fn().mockRejectedValue(error); + getNetworkClientByIdMock.mockReturnValue({ + provider: { request: requestMock }, + } as never); + + await expect( + rpcRequest(messenger, CHAIN_ID_MOCK, 'eth_blockNumber'), + ).rejects.toBe(error); + }); + + it('uses Infura network client when preferInfura is true', async () => { + getNetworkConfigurationByChainIdMock.mockReturnValue({ + rpcEndpoints: [ + { + type: RpcEndpointType.Infura, + networkClientId: INFURA_NETWORK_CLIENT_ID_MOCK, + }, + ], + } as NetworkConfiguration); + + const requestMock = jest.fn().mockResolvedValue('0x1'); + getNetworkClientByIdMock.mockReturnValue({ + provider: { request: requestMock }, + } as never); + + await rpcRequest( + messenger, + CHAIN_ID_MOCK, + 'eth_chainId', + [], + { preferInfura: true }, + ); + + expect(getNetworkClientByIdMock).toHaveBeenCalledWith( + INFURA_NETWORK_CLIENT_ID_MOCK, + ); + }); + }); +}); diff --git a/packages/transaction-pay-controller/src/utils/provider.ts b/packages/transaction-pay-controller/src/utils/provider.ts new file mode 100644 index 0000000000..9a72c955a2 --- /dev/null +++ b/packages/transaction-pay-controller/src/utils/provider.ts @@ -0,0 +1,98 @@ +import type { NetworkClientId, Provider } from '@metamask/network-controller'; +import { RpcEndpointType } from '@metamask/network-controller'; +import type { Hex } from '@metamask/utils'; +import { createModuleLogger } from '@metamask/utils'; + +import { projectLogger } from '../logger'; +import type { TransactionPayControllerMessenger } from '../types'; + +const log = createModuleLogger(projectLogger, 'provider'); + +type ProviderRequestParams = Parameters[0]['params']; + +/** + * Options for network client resolution. + */ +export type GetNetworkClientIdOptions = { + /** + * When true, attempts to resolve to an Infura endpoint for the chain before + * falling back to the default selected endpoint. Useful for calls that use + * block tags (e.g. `pending`) that may not be supported by custom RPCs. + */ + preferInfura?: boolean; +}; + +/** + * Resolve the network client ID for a chain. + * + * When `preferInfura` is true the method tries to locate an Infura endpoint + * in the chain's network configuration and returns its `networkClientId`. + * If no Infura endpoint is configured, or if the configuration lookup throws, + * it falls back to `findNetworkClientIdByChainId`. + * + * @param messenger - The TransactionPayController messenger. + * @param chainId - The chain ID to resolve. + * @param options - Resolution options. + * @param options.preferInfura - Prefer the Infura endpoint when available. + * @returns The resolved network client ID. + */ +export function getNetworkClientId( + messenger: TransactionPayControllerMessenger, + chainId: Hex, + { preferInfura = false }: GetNetworkClientIdOptions = {}, +): NetworkClientId { + if (preferInfura) { + try { + const networkConfiguration = messenger.call( + 'NetworkController:getNetworkConfigurationByChainId', + chainId, + ); + + const infuraEndpoint = networkConfiguration?.rpcEndpoints.find( + (endpoint) => endpoint.type === RpcEndpointType.Infura, + ); + + if (infuraEndpoint) { + return infuraEndpoint.networkClientId; + } + } catch { + // empty + } + } + + return messenger.call( + 'NetworkController:findNetworkClientIdByChainId', + chainId, + ); +} + +/** + * Send an RPC request to the network for the specified chain. + * + * @param messenger - The TransactionPayController messenger. + * @param chainId - The chain ID to resolve. + * @param method - The JSON-RPC method name. + * @param params - Optional parameters for the RPC call. + * @param options - Resolution options (forwarded to {@link getNetworkClientId}). + * @returns The RPC response. + */ +export async function rpcRequest( + messenger: TransactionPayControllerMessenger, + chainId: Hex, + method: string, + params?: ProviderRequestParams, + options?: GetNetworkClientIdOptions, +): Promise { + const networkClientId = getNetworkClientId(messenger, chainId, options); + + const { provider } = messenger.call( + 'NetworkController:getNetworkClientById', + networkClientId, + ); + + const response = await provider.request({ method, params }); + + log(method, { params, response }); + + return response; +} diff --git a/packages/transaction-pay-controller/src/utils/token.test.ts b/packages/transaction-pay-controller/src/utils/token.test.ts index ffed34cbc1..990ee93593 100644 --- a/packages/transaction-pay-controller/src/utils/token.test.ts +++ b/packages/transaction-pay-controller/src/utils/token.test.ts @@ -1,8 +1,10 @@ -import { Contract } from '@ethersproject/contracts'; -import { Web3Provider } from '@ethersproject/providers'; +import { Interface } from '@ethersproject/abi'; import type { TokensControllerState } from '@metamask/assets-controllers'; import type { AccountTrackerControllerState } from '@metamask/assets-controllers'; import type { TokenRatesControllerState } from '@metamask/assets-controllers'; +import { abiERC20 } from '@metamask/metamask-eth-abis'; +import { RpcEndpointType } from '@metamask/network-controller'; +import type { NetworkConfiguration } from '@metamask/network-controller'; import type { Hex } from '@metamask/utils'; import { getDefaultRemoteFeatureFlagControllerState } from '../../../remote-feature-flag-controller/src/remote-feature-flag-controller'; @@ -26,15 +28,6 @@ import { TokenAddressTarget, } from './token'; -jest.mock('@ethersproject/contracts', () => ({ - ...jest.requireActual('@ethersproject/contracts'), - Contract: jest.fn(), -})); - -jest.mock('@ethersproject/providers', () => ({ - ...jest.requireActual('@ethersproject/providers'), - Web3Provider: jest.fn(), -})); const TOKEN_ADDRESS_MOCK = '0x559B65722aD62AD6DAC4Fa5a1c6B23A2e8ce57Ec' as Hex; const TOKEN_ADDRESS_2_MOCK = '0x123456789abcdef1234567890abcdef12345678' as Hex; @@ -43,6 +36,7 @@ const DECIMALS_MOCK = 6; const BALANCE_MOCK = '0x123' as Hex; const FROM_MOCK = '0x456' as Hex; const NETWORK_CLIENT_ID_MOCK = '123-456'; +const INFURA_NETWORK_CLIENT_ID_MOCK = 'mainnet'; const TICKER_MOCK = 'TST'; const SYMBOL_MOCK = 'TEST'; const ACCOUNT_MOCK = '0x1234567890abcdef1234567890abcdef12345678' as Hex; @@ -56,6 +50,7 @@ describe('Token Utils', () => { getRemoteFeatureFlagControllerStateMock, getTokensControllerStateMock, getNetworkClientByIdMock, + getNetworkConfigurationByChainIdMock, getTokenBalanceControllerStateMock, getAccountTrackerControllerStateMock, getTokenRatesControllerStateMock, @@ -63,33 +58,20 @@ describe('Token Utils', () => { findNetworkClientIdByChainIdMock, } = getMessengerMock(); - let mockBalanceOf: jest.Mock; - let mockGetBalance: jest.Mock; - beforeEach(() => { jest.resetAllMocks(); - mockBalanceOf = jest.fn(); - mockGetBalance = jest.fn(); - getRemoteFeatureFlagControllerStateMock.mockReturnValue({ ...getDefaultRemoteFeatureFlagControllerState(), }); findNetworkClientIdByChainIdMock.mockReturnValue(NETWORK_CLIENT_ID_MOCK); + getNetworkConfigurationByChainIdMock.mockReturnValue(undefined); getNetworkClientByIdMock.mockReturnValue({ configuration: { ticker: TICKER_MOCK }, provider: PROVIDER_MOCK, } as never); - - (Contract as unknown as jest.Mock).mockImplementation(() => ({ - balanceOf: mockBalanceOf, - })); - - (Web3Provider as unknown as jest.Mock).mockImplementation(() => ({ - getBalance: mockGetBalance, - })); }); function enableAssetsUnifyState(): void { @@ -631,8 +613,8 @@ describe('Token Utils', () => { }); describe('getLiveTokenBalance', () => { - it('returns ERC-20 balance via contract balanceOf', async () => { - mockBalanceOf.mockResolvedValue({ toString: () => '5000000' }); + it('returns ERC-20 balance via eth_call', async () => { + (PROVIDER_MOCK.request as jest.Mock).mockResolvedValue('0x4C4B40'); const result = await getLiveTokenBalance( messenger, @@ -648,21 +630,22 @@ describe('Token Utils', () => { expect(getNetworkClientByIdMock).toHaveBeenCalledWith( NETWORK_CLIENT_ID_MOCK, ); - expect(Web3Provider).toHaveBeenCalledWith(PROVIDER_MOCK); - expect(Contract).toHaveBeenCalledWith( - ERC20_ADDRESS_MOCK, - expect.anything(), - expect.anything(), - ); - expect(mockBalanceOf).toHaveBeenCalledWith(ACCOUNT_MOCK, { - blockTag: 'pending', + expect(PROVIDER_MOCK.request).toHaveBeenCalledWith({ + method: 'eth_call', + params: [ + { + to: ERC20_ADDRESS_MOCK, + data: new Interface(abiERC20).encodeFunctionData('balanceOf', [ + ACCOUNT_MOCK, + ]), + }, + 'pending', + ], }); }); - it('returns native balance via ethersProvider.getBalance', async () => { - mockGetBalance.mockResolvedValue({ - toString: () => '1000000000000000000', - }); + it('returns native balance via eth_getBalance', async () => { + (PROVIDER_MOCK.request as jest.Mock).mockResolvedValue('0xde0b6b3a7640000'); const result = await getLiveTokenBalance( messenger, @@ -672,14 +655,14 @@ describe('Token Utils', () => { ); expect(result).toBe('1000000000000000000'); - expect(mockGetBalance).toHaveBeenCalledWith(ACCOUNT_MOCK, 'pending'); - expect(Contract).not.toHaveBeenCalled(); + expect(PROVIDER_MOCK.request).toHaveBeenCalledWith({ + method: 'eth_getBalance', + params: [ACCOUNT_MOCK, 'pending'], + }); }); it('returns native balance for polygon native address', async () => { - mockGetBalance.mockResolvedValue({ - toString: () => '2000000000000000000', - }); + (PROVIDER_MOCK.request as jest.Mock).mockResolvedValue('0x1bc16d674ec80000'); const result = await getLiveTokenBalance( messenger, @@ -689,12 +672,14 @@ describe('Token Utils', () => { ); expect(result).toBe('2000000000000000000'); - expect(mockGetBalance).toHaveBeenCalledWith(ACCOUNT_MOCK, 'pending'); - expect(Contract).not.toHaveBeenCalled(); + expect(PROVIDER_MOCK.request).toHaveBeenCalledWith({ + method: 'eth_getBalance', + params: [ACCOUNT_MOCK, 'pending'], + }); }); it('treats native address comparison as case-insensitive', async () => { - mockGetBalance.mockResolvedValue({ toString: () => '500' }); + (PROVIDER_MOCK.request as jest.Mock).mockResolvedValue('0x1f4'); const result = await getLiveTokenBalance( messenger, @@ -704,8 +689,90 @@ describe('Token Utils', () => { ); expect(result).toBe('500'); - expect(mockGetBalance).toHaveBeenCalledWith(ACCOUNT_MOCK, 'pending'); - expect(Contract).not.toHaveBeenCalled(); + expect(PROVIDER_MOCK.request).toHaveBeenCalledWith({ + method: 'eth_getBalance', + params: [ACCOUNT_MOCK, 'pending'], + }); + }); + + it('uses Infura network client when Infura endpoint is available', async () => { + (PROVIDER_MOCK.request as jest.Mock).mockResolvedValue('0x895440'); + + getNetworkConfigurationByChainIdMock.mockReturnValue({ + rpcEndpoints: [ + { + type: RpcEndpointType.Infura, + networkClientId: INFURA_NETWORK_CLIENT_ID_MOCK, + }, + ], + } as NetworkConfiguration); + + const result = await getLiveTokenBalance( + messenger, + ACCOUNT_MOCK, + CHAIN_ID_MOCK, + ERC20_ADDRESS_MOCK, + ); + + expect(result).toBe('9000000'); + expect(getNetworkConfigurationByChainIdMock).toHaveBeenCalledWith( + CHAIN_ID_MOCK, + ); + expect(getNetworkClientByIdMock).toHaveBeenCalledWith( + INFURA_NETWORK_CLIENT_ID_MOCK, + ); + expect(findNetworkClientIdByChainIdMock).not.toHaveBeenCalled(); + }); + + it('falls back to default network client when no Infura endpoint is configured', async () => { + (PROVIDER_MOCK.request as jest.Mock).mockResolvedValue('0x6ACFC0'); + + getNetworkConfigurationByChainIdMock.mockReturnValue({ + rpcEndpoints: [ + { + type: RpcEndpointType.Custom, + networkClientId: 'custom-rpc-id', + }, + ], + } as NetworkConfiguration); + + const result = await getLiveTokenBalance( + messenger, + ACCOUNT_MOCK, + CHAIN_ID_MOCK, + ERC20_ADDRESS_MOCK, + ); + + expect(result).toBe('7000000'); + expect(findNetworkClientIdByChainIdMock).toHaveBeenCalledWith( + CHAIN_ID_MOCK, + ); + expect(getNetworkClientByIdMock).toHaveBeenCalledWith( + NETWORK_CLIENT_ID_MOCK, + ); + }); + + it('falls back to default network client when getNetworkConfigurationByChainId throws', async () => { + (PROVIDER_MOCK.request as jest.Mock).mockResolvedValue('0x2DC6C0'); + + getNetworkConfigurationByChainIdMock.mockImplementation(() => { + throw new Error('Network configuration not found'); + }); + + const result = await getLiveTokenBalance( + messenger, + ACCOUNT_MOCK, + CHAIN_ID_MOCK, + ERC20_ADDRESS_MOCK, + ); + + expect(result).toBe('3000000'); + expect(findNetworkClientIdByChainIdMock).toHaveBeenCalledWith( + CHAIN_ID_MOCK, + ); + expect(getNetworkClientByIdMock).toHaveBeenCalledWith( + NETWORK_CLIENT_ID_MOCK, + ); }); }); diff --git a/packages/transaction-pay-controller/src/utils/token.ts b/packages/transaction-pay-controller/src/utils/token.ts index 4534a048c9..1eec031c88 100644 --- a/packages/transaction-pay-controller/src/utils/token.ts +++ b/packages/transaction-pay-controller/src/utils/token.ts @@ -1,5 +1,4 @@ -import { Contract } from '@ethersproject/contracts'; -import { Web3Provider } from '@ethersproject/providers'; +import { Interface } from '@ethersproject/abi'; import { TokensControllerState } from '@metamask/assets-controllers'; import { toChecksumHexAddress } from '@metamask/controller-utils'; import { abiERC20 } from '@metamask/metamask-eth-abis'; @@ -15,6 +14,7 @@ import { } from '../constants'; import type { FiatRates, TransactionPayControllerMessenger } from '../types'; import { getAssetsUnifyStateFeature } from './feature-flags'; +import { getNetworkClientId, rpcRequest } from './provider'; /** * Check if two tokens are the same (same address and chain). @@ -306,6 +306,10 @@ export function getNativeToken(chainId: Hex): Hex { * Unlike {@link getTokenBalance}, this bypasses the cached state in * `TokenBalancesController` and reads directly from the chain. * + * Uses the Infura RPC endpoint for the chain when one is configured, falling + * back to the chain's default endpoint. This avoids errors on custom mainnet + * RPC endpoints that may not support pending block queries. + * * @param messenger - Controller messenger. * @param account - Address of the account. * @param chainId - Chain ID. @@ -318,31 +322,35 @@ export async function getLiveTokenBalance( chainId: Hex, tokenAddress: Hex, ): Promise { - const networkClientId = messenger.call( - 'NetworkController:findNetworkClientIdByChainId', - chainId, - ); - - const { provider } = messenger.call( - 'NetworkController:getNetworkClientById', - networkClientId, - ); - - const ethersProvider = new Web3Provider(provider); + const options = { preferInfura: true }; const isNative = tokenAddress.toLowerCase() === getNativeToken(chainId).toLowerCase(); - // Use `pending` blockTag to bypass the RPC block-cache middleware so callers - // always observe the latest balance instead of a value pinned to the last - // polled block. if (isNative) { - const balance = await ethersProvider.getBalance(account, 'pending'); - return balance.toString(); + const result = await rpcRequest( + messenger, + chainId, + 'eth_getBalance', + [account, 'pending'], + options, + ); + + return new BigNumber(result as string, 16).toString(10); } - const contract = new Contract(tokenAddress, abiERC20, ethersProvider); - const balance = await contract.balanceOf(account, { blockTag: 'pending' }); - return balance.toString(); + const calldata = new Interface(abiERC20).encodeFunctionData('balanceOf', [ + account, + ]) as Hex; + + const result = await rpcRequest( + messenger, + chainId, + 'eth_call', + [{ to: tokenAddress, data: calldata }, 'pending'], + options, + ); + + return new BigNumber(result as string, 16).toString(10); } /** @@ -385,17 +393,12 @@ function getTicker( messenger: TransactionPayControllerMessenger, ): string | undefined { try { - const networkClientId = messenger.call( - 'NetworkController:findNetworkClientIdByChainId', - chainId, - ); + const networkClientId = getNetworkClientId(messenger, chainId); - const networkConfiguration = messenger.call( + return messenger.call( 'NetworkController:getNetworkClientById', networkClientId, - ); - - return networkConfiguration.configuration.ticker; + ).configuration.ticker; } catch { return undefined; } @@ -447,3 +450,5 @@ export function normalizeTokenAddress( return tokenAddress; } + + diff --git a/packages/transaction-pay-controller/src/utils/transaction.test.ts b/packages/transaction-pay-controller/src/utils/transaction.test.ts index a2328da53f..6189adb600 100644 --- a/packages/transaction-pay-controller/src/utils/transaction.test.ts +++ b/packages/transaction-pay-controller/src/utils/transaction.test.ts @@ -1,5 +1,4 @@ import { Interface } from '@ethersproject/abi'; -import { Web3Provider } from '@ethersproject/providers'; import { abiERC20 } from '@metamask/metamask-eth-abis'; import { TransactionStatus, @@ -33,10 +32,6 @@ import { jest.mock('./feature-flags'); jest.mock('./required-tokens'); -jest.mock('@ethersproject/providers', () => ({ - ...jest.requireActual('@ethersproject/providers'), - Web3Provider: jest.fn(), -})); const TRANSACTION_ID_MOCK = '123-456'; const ERROR_MESSAGE_MOCK = 'Test error'; @@ -707,32 +702,18 @@ describe('getTransferredAmountFromTxHash', () => { getNetworkClientByIdMock: receiptGetNetworkMock, } = getMessengerMock(); - let mockGetTransactionReceipt: jest.Mock; - let mockSend: jest.Mock; - let mockGetTx: jest.Mock; - beforeEach(() => { jest.resetAllMocks(); - mockGetTransactionReceipt = jest.fn(); - mockSend = jest.fn(); - mockGetTx = jest.fn(); - receiptFindNetworkMock.mockReturnValue(NETWORK_CLIENT_ID_RECEIPT_MOCK); receiptGetNetworkMock.mockReturnValue({ provider: PROVIDER_RECEIPT_MOCK, } as never); - - (Web3Provider as unknown as jest.Mock).mockImplementation(() => ({ - getTransactionReceipt: mockGetTransactionReceipt, - send: mockSend, - getTransaction: mockGetTx, - })); }); describe('native token', () => { it('returns amount from debug_traceTransaction for direct transfer', async () => { - mockSend.mockResolvedValue({ + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue({ to: WALLET_ADDRESS_RECEIPT_MOCK.toLowerCase(), value: '0xde0b6b3a7640000', calls: [], @@ -747,14 +728,14 @@ describe('getTransferredAmountFromTxHash', () => { }); expect(result).toBe('1000000000000000000'); - expect(mockSend).toHaveBeenCalledWith('debug_traceTransaction', [ - TX_HASH_MOCK, - { tracer: 'callTracer' }, - ]); + expect(PROVIDER_RECEIPT_MOCK.request).toHaveBeenCalledWith({ + method: 'debug_traceTransaction', + params: [TX_HASH_MOCK, { tracer: 'callTracer' }], + }); }); it('sums native value from nested internal calls', async () => { - mockSend.mockResolvedValue({ + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue({ to: '0xcontract', value: '0x0', calls: [ @@ -788,11 +769,17 @@ describe('getTransferredAmountFromTxHash', () => { }); it('falls back to tx.value when debug_traceTransaction is unsupported', async () => { - mockSend.mockRejectedValue(new Error('Method not found')); - mockGetTx.mockResolvedValue({ - to: WALLET_ADDRESS_RECEIPT_MOCK.toLowerCase(), - value: { toString: () => '1500000000000000000' }, - }); + PROVIDER_RECEIPT_MOCK.request.mockImplementation( + ({ method }: { method: string }) => { + if (method === 'debug_traceTransaction') { + return Promise.reject(new Error('Method not found')); + } + return Promise.resolve({ + to: WALLET_ADDRESS_RECEIPT_MOCK.toLowerCase(), + value: '0x14d1120d7b160000', + }); + }, + ); const result = await getTransferredAmountFromTxHash({ messenger: receiptMessenger, @@ -806,14 +793,17 @@ describe('getTransferredAmountFromTxHash', () => { }); it('returns undefined when trace returns zero value and tx.to does not match wallet', async () => { - mockSend.mockResolvedValue({ - to: '0xcontract', - value: '0x0', - }); - mockGetTx.mockResolvedValue({ - to: '0xcontract', - value: { toString: () => '1000000000000000000' }, - }); + PROVIDER_RECEIPT_MOCK.request.mockImplementation( + ({ method }: { method: string }) => { + if (method === 'debug_traceTransaction') { + return Promise.resolve({ to: '0xcontract', value: '0x0' }); + } + return Promise.resolve({ + to: '0xcontract', + value: '0xde0b6b3a7640000', + }); + }, + ); const result = await getTransferredAmountFromTxHash({ messenger: receiptMessenger, @@ -827,8 +817,14 @@ describe('getTransferredAmountFromTxHash', () => { }); it('returns undefined when trace is unsupported and transaction is not found', async () => { - mockSend.mockRejectedValue(new Error('Method not found')); - mockGetTx.mockResolvedValue(null); + PROVIDER_RECEIPT_MOCK.request.mockImplementation( + ({ method }: { method: string }) => { + if (method === 'debug_traceTransaction') { + return Promise.reject(new Error('Method not found')); + } + return Promise.resolve(null); + }, + ); const result = await getTransferredAmountFromTxHash({ messenger: receiptMessenger, @@ -842,11 +838,17 @@ describe('getTransferredAmountFromTxHash', () => { }); it('returns undefined when trace is unsupported and native tx.value is zero', async () => { - mockSend.mockRejectedValue(new Error('Method not found')); - mockGetTx.mockResolvedValue({ - to: WALLET_ADDRESS_RECEIPT_MOCK.toLowerCase(), - value: { toString: () => '0' }, - }); + PROVIDER_RECEIPT_MOCK.request.mockImplementation( + ({ method }: { method: string }) => { + if (method === 'debug_traceTransaction') { + return Promise.reject(new Error('Method not found')); + } + return Promise.resolve({ + to: WALLET_ADDRESS_RECEIPT_MOCK.toLowerCase(), + value: '0x0', + }); + }, + ); const result = await getTransferredAmountFromTxHash({ messenger: receiptMessenger, @@ -860,14 +862,20 @@ describe('getTransferredAmountFromTxHash', () => { }); it('ignores trace value with 0x0', async () => { - mockSend.mockResolvedValue({ - to: WALLET_ADDRESS_RECEIPT_MOCK.toLowerCase(), - value: '0x0', - }); - mockGetTx.mockResolvedValue({ - to: WALLET_ADDRESS_RECEIPT_MOCK.toLowerCase(), - value: { toString: () => '500' }, - }); + PROVIDER_RECEIPT_MOCK.request.mockImplementation( + ({ method }: { method: string }) => { + if (method === 'debug_traceTransaction') { + return Promise.resolve({ + to: WALLET_ADDRESS_RECEIPT_MOCK.toLowerCase(), + value: '0x0', + }); + } + return Promise.resolve({ + to: WALLET_ADDRESS_RECEIPT_MOCK.toLowerCase(), + value: '0x1f4', + }); + }, + ); const result = await getTransferredAmountFromTxHash({ messenger: receiptMessenger, @@ -883,7 +891,7 @@ describe('getTransferredAmountFromTxHash', () => { describe('ERC-20 token', () => { it('decodes transfer amount from receipt logs', async () => { - mockGetTransactionReceipt.mockResolvedValue({ + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue({ logs: [encodeTransferLog(WALLET_ADDRESS_RECEIPT_MOCK, '5000000')], }); @@ -899,7 +907,7 @@ describe('getTransferredAmountFromTxHash', () => { }); it('sums multiple Transfer events to the same wallet', async () => { - mockGetTransactionReceipt.mockResolvedValue({ + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue({ logs: [ encodeTransferLog(WALLET_ADDRESS_RECEIPT_MOCK, '3000000'), encodeTransferLog(WALLET_ADDRESS_RECEIPT_MOCK, '2000000'), @@ -919,7 +927,7 @@ describe('getTransferredAmountFromTxHash', () => { it('ignores Transfer events to other addresses', async () => { const otherAddress = '0x3333333333333333333333333333333333333333' as Hex; - mockGetTransactionReceipt.mockResolvedValue({ + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue({ logs: [ encodeTransferLog(otherAddress, '9000000'), encodeTransferLog(WALLET_ADDRESS_RECEIPT_MOCK, '1000000'), @@ -943,7 +951,7 @@ describe('getTransferredAmountFromTxHash', () => { WALLET_ADDRESS_RECEIPT_MOCK, '5000000', ); - mockGetTransactionReceipt.mockResolvedValue({ + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue({ logs: [ { ...transferLog, address: otherToken }, encodeTransferLog(WALLET_ADDRESS_RECEIPT_MOCK, '1000000'), @@ -962,7 +970,7 @@ describe('getTransferredAmountFromTxHash', () => { }); it('ignores logs with non-Transfer event topics', async () => { - mockGetTransactionReceipt.mockResolvedValue({ + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue({ logs: [ { address: ERC20_ADDRESS_RECEIPT_MOCK, @@ -989,7 +997,7 @@ describe('getTransferredAmountFromTxHash', () => { }); it('returns undefined when receipt is not found', async () => { - mockGetTransactionReceipt.mockResolvedValue(null); + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue(null); const result = await getTransferredAmountFromTxHash({ messenger: receiptMessenger, @@ -1003,9 +1011,7 @@ describe('getTransferredAmountFromTxHash', () => { }); it('returns undefined when no matching Transfer logs exist', async () => { - mockGetTransactionReceipt.mockResolvedValue({ - logs: [], - }); + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue({ logs: [] }); const result = await getTransferredAmountFromTxHash({ messenger: receiptMessenger, @@ -1019,7 +1025,7 @@ describe('getTransferredAmountFromTxHash', () => { }); it('skips malformed log entries gracefully', async () => { - mockGetTransactionReceipt.mockResolvedValue({ + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue({ logs: [ { address: ERC20_ADDRESS_RECEIPT_MOCK, @@ -1042,7 +1048,7 @@ describe('getTransferredAmountFromTxHash', () => { }); it('returns undefined when all Transfer amounts are zero', async () => { - mockGetTransactionReceipt.mockResolvedValue({ + PROVIDER_RECEIPT_MOCK.request.mockResolvedValue({ logs: [encodeTransferLog(WALLET_ADDRESS_RECEIPT_MOCK, '0')], }); @@ -1059,7 +1065,7 @@ describe('getTransferredAmountFromTxHash', () => { }); it('propagates provider errors for ERC-20', async () => { - mockGetTransactionReceipt.mockRejectedValue(new Error('RPC error')); + PROVIDER_RECEIPT_MOCK.request.mockRejectedValue(new Error('RPC error')); await expect( getTransferredAmountFromTxHash({ @@ -1073,8 +1079,7 @@ describe('getTransferredAmountFromTxHash', () => { }); it('propagates provider errors for native when both trace and getTransaction fail', async () => { - mockSend.mockRejectedValue(new Error('Trace failed')); - mockGetTx.mockRejectedValue(new Error('RPC error')); + PROVIDER_RECEIPT_MOCK.request.mockRejectedValue(new Error('RPC error')); await expect( getTransferredAmountFromTxHash({ diff --git a/packages/transaction-pay-controller/src/utils/transaction.ts b/packages/transaction-pay-controller/src/utils/transaction.ts index 97e772e666..8c66b35764 100644 --- a/packages/transaction-pay-controller/src/utils/transaction.ts +++ b/packages/transaction-pay-controller/src/utils/transaction.ts @@ -1,5 +1,4 @@ import { Interface } from '@ethersproject/abi'; -import { Web3Provider } from '@ethersproject/providers'; import { abiERC20 } from '@metamask/metamask-eth-abis'; import { TransactionStatus, @@ -19,6 +18,7 @@ import type { UpdateTransactionDataCallback, } from '../types'; import { getAssetsUnifyStateFeature } from './feature-flags'; +import { getNetworkClientId, rpcRequest } from './provider'; import { parseRequiredTokens } from './required-tokens'; import { getNativeToken } from './token'; @@ -396,17 +396,16 @@ export async function getTransferredAmountFromTxHash({ tokenAddress: Hex; walletAddress: Hex; }): Promise { - const provider = getEthersProvider(messenger, chainId); - const isNative = tokenAddress.toLowerCase() === getNativeToken(chainId).toLowerCase(); if (isNative) { - return await getNativeTransferAmount(provider, txHash, walletAddress); + return await getNativeTransferAmount(messenger, chainId, txHash, walletAddress); } return await getErc20TransferAmount( - provider, + messenger, + chainId, txHash, tokenAddress, walletAddress, @@ -421,23 +420,25 @@ export async function getTransferredAmountFromTxHash({ * 2. Falls back to the top-level `tx.value` when the wallet is the direct * recipient and the trace RPC is unavailable or errors. * - * @param provider - Ethers Web3Provider. + * @param messenger - Controller messenger. + * @param chainId - Chain ID where the transaction was executed. * @param txHash - Transaction hash. * @param walletAddress - Recipient wallet address. * @returns Raw amount as a decimal string, or `undefined`. */ async function getNativeTransferAmount( - provider: Web3Provider, + messenger: TransactionPayControllerMessenger, + chainId: Hex, txHash: string, walletAddress: Hex, ): Promise { try { - const trace = await provider.send('debug_traceTransaction', [ + const trace = await rpcRequest(messenger, chainId, 'debug_traceTransaction', [ txHash, { tracer: 'callTracer' }, ]); - const amount = sumNativeValueFromTrace(trace, walletAddress); + const amount = sumNativeValueFromTrace(trace as CallTrace, walletAddress); if (amount.gt(0)) { return amount.toFixed(0); } @@ -445,7 +446,11 @@ async function getNativeTransferAmount( // debug_traceTransaction not supported — fall through to tx.value } - const tx = await provider.getTransaction(txHash); + const tx = await rpcRequest(messenger, chainId, 'eth_getTransactionByHash', [txHash]) as { + to?: string; + value: string; + } | null; + if (!tx) { return undefined; } @@ -454,26 +459,30 @@ async function getNativeTransferAmount( return undefined; } - return positiveOrUndefined(tx.value.toString()); + return positiveOrUndefined(new BigNumber(tx.value).toFixed(0)); } /** * Resolves the ERC-20 token amount received by a wallet from a transaction * by decoding `Transfer` event logs from the transaction receipt. * - * @param provider - Ethers Web3Provider. + * @param messenger - Controller messenger. + * @param chainId - Chain ID where the transaction was executed. * @param txHash - Transaction hash. * @param tokenAddress - ERC-20 token contract address. * @param walletAddress - Recipient wallet address. * @returns Raw amount as a decimal string, or `undefined`. */ async function getErc20TransferAmount( - provider: Web3Provider, + messenger: TransactionPayControllerMessenger, + chainId: Hex, txHash: string, tokenAddress: Hex, walletAddress: Hex, ): Promise { - const receipt = await provider.getTransactionReceipt(txHash); + const receipt = await rpcRequest(messenger, chainId, 'eth_getTransactionReceipt', [txHash]) as { + logs: { address: string; topics: string[]; data: string }[]; + } | null; if (!receipt) { return undefined; @@ -544,23 +553,6 @@ function sumNativeValueFromTrace( return total; } -function getEthersProvider( - messenger: TransactionPayControllerMessenger, - chainId: Hex, -): Web3Provider { - const networkClientId = messenger.call( - 'NetworkController:findNetworkClientIdByChainId', - chainId, - ); - - const { provider } = messenger.call( - 'NetworkController:getNetworkClientById', - networkClientId, - ); - - return new Web3Provider(provider); -} - function positiveOrUndefined(amount: string): string | undefined { return new BigNumber(amount).gt(0) ? amount : undefined; }