Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function buildMockTokenPricesService(
): AbstractTokenPricesService {
return {
async fetchTokenPrices() {
return {};
return [];
},
async fetchExchangeRates() {
return {};
Expand Down
242 changes: 173 additions & 69 deletions packages/assets-controllers/src/TokenRatesController.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,10 @@ import { isEqual } from 'lodash';

import { reduceInBatchesSerially, TOKEN_PRICES_BATCH_SIZE } from './assetsUtil';
import { fetchExchangeRate as fetchNativeCurrencyExchangeRate } from './crypto-compare-service';
import type { AbstractTokenPricesService } from './token-prices-service/abstract-token-prices-service';
import type {
AbstractTokenPricesService,
EvmAssetWithMarketData,
} from './token-prices-service/abstract-token-prices-service';
import { getNativeTokenAddress } from './token-prices-service/codefi-v2';
import type {
TokensControllerGetStateAction,
Expand Down Expand Up @@ -92,6 +95,11 @@ export type MarketDataDetails = {
*/
export type ContractMarketData = Record<Hex, MarketDataDetails>;

type ChainIdAndNativeCurrency = {
chainId: Hex;
nativeCurrency: string;
};

enum PollState {
Active = 'Active',
Inactive = 'Inactive',
Expand Down Expand Up @@ -250,6 +258,8 @@ export class TokenRatesController extends StaticIntervalPollingController<TokenR

readonly #interval: number;

readonly #getSelectedCurrency: () => string;

#allTokens: TokensControllerState['allTokens'];

#allDetectedTokens: TokensControllerState['allDetectedTokens'];
Expand All @@ -263,19 +273,22 @@ export class TokenRatesController extends StaticIntervalPollingController<TokenR
* @param options.tokenPricesService - An object in charge of retrieving token price
* @param options.messenger - The messenger instance for communication
* @param options.state - Initial state to set on this controller
* @param options.getSelectedCurrency - A function to fetch the selected currency
*/
constructor({
interval = DEFAULT_INTERVAL,
disabled = false,
tokenPricesService,
messenger,
state,
getSelectedCurrency,
}: {
interval?: number;
disabled?: boolean;
tokenPricesService: AbstractTokenPricesService;
messenger: TokenRatesControllerMessenger;
state?: Partial<TokenRatesControllerState>;
getSelectedCurrency: () => string;
}) {
super({
name: controllerName,
Expand All @@ -288,6 +301,7 @@ export class TokenRatesController extends StaticIntervalPollingController<TokenR
this.#tokenPricesService = tokenPricesService;
this.#disabled = disabled;
this.#interval = interval;
this.#getSelectedCurrency = getSelectedCurrency;

const { allTokens, allDetectedTokens } = this.#getTokensControllerState();
this.#allTokens = allTokens;
Expand Down Expand Up @@ -409,7 +423,13 @@ export class TokenRatesController extends StaticIntervalPollingController<TokenR
const tokenAddresses = getTokens(this.#allTokens[chainId]);
const detectedTokenAddresses = getTokens(this.#allDetectedTokens[chainId]);

return [...new Set([...tokenAddresses, ...detectedTokenAddresses])].sort();
return [
...new Set([
...tokenAddresses,
...detectedTokenAddresses,
getNativeTokenAddress(chainId),
]),
].sort();
}

/**
Expand Down Expand Up @@ -495,14 +515,143 @@ export class TokenRatesController extends StaticIntervalPollingController<TokenR
* @param chainIdAndNativeCurrency - The chain ID and native currency.
*/
async updateExchangeRates(
chainIdAndNativeCurrency: {
chainId: Hex;
nativeCurrency: string;
}[],
chainIdAndNativeCurrency: ChainIdAndNativeCurrency[],
) {
await this.updateExchangeRatesByChainId(chainIdAndNativeCurrency);
}

async updateExchangeRatesToCurrency(chainIds: Hex[]): Promise<void> {
if (this.#disabled) {
return;
}

const currency = this.#getSelectedCurrency();

const marketData: Record<Hex, Record<Hex, MarketDataDetails>> = {};
const assets: {
chainId: Hex;
tokenAddress: Hex;
}[] = [];
for (const chainId of chainIds) {
if (this.#tokenPricesService.validateChainIdSupported(chainId)) {
this.#getTokenAddresses(chainId).forEach((tokenAddress) => {
assets.push({
chainId,
tokenAddress,
});
});
} else {
marketData[chainId] = {};
}
}

await reduceInBatchesSerially<
{ chainId: Hex; tokenAddress: Hex },
Record<Hex, Record<Hex, MarketDataDetails>>
>({
values: assets,
batchSize: TOKEN_PRICES_BATCH_SIZE,
eachBatch: async (partialMarketData, assetsBatch) => {
const batchMarketData = await this.#tokenPricesService.fetchTokenPrices(
{
assets: assetsBatch,
currency,
},
);

for (const tokenPrice of batchMarketData) {
(partialMarketData[tokenPrice.chainId] ??= {})[
tokenPrice.tokenAddress
] = tokenPrice;
}

return partialMarketData;
},
initialResult: marketData,
});

if (Object.keys(marketData).length > 0) {
this.update((state) => {
state.marketData = {
...state.marketData,
...marketData,
};
});
}
}

async updateExchangeRatesToNative(chainIds: Hex[]): Promise<void> {
if (this.#disabled) {
return;
}

const { networkConfigurationsByChainId } = this.messenger.call(
'NetworkController:getState',
);

const marketData: Record<Hex, Record<Hex, MarketDataDetails>> = {};
const assetsByNativeCurrency: Record<
string,
{
chainId: Hex;
tokenAddress: Hex;
}[]
> = {};
for (const chainId of chainIds) {
if (this.#tokenPricesService.validateChainIdSupported(chainId)) {
const { nativeCurrency } = networkConfigurationsByChainId[chainId];

this.#getTokenAddresses(chainId).forEach((tokenAddress) => {
(assetsByNativeCurrency[nativeCurrency] ??= []).push({
chainId,
tokenAddress,
});
});
} else {
marketData[chainId] = {};
}
}

await Promise.allSettled(
Object.entries(assetsByNativeCurrency).map(
async ([nativeCurrency, assets]) => {
return await reduceInBatchesSerially<
{ chainId: Hex; tokenAddress: Hex },
Record<Hex, Record<Hex, MarketDataDetails>>
>({
values: assets,
batchSize: TOKEN_PRICES_BATCH_SIZE,
eachBatch: async (partialMarketData, assetsBatch) => {
const batchMarketData =
await this.#tokenPricesService.fetchTokenPrices({
assets: assetsBatch,
currency: nativeCurrency,
});

for (const tokenPrice of batchMarketData) {
(partialMarketData[tokenPrice.chainId] ??= {})[
tokenPrice.tokenAddress
] = tokenPrice;
}

return partialMarketData;
},
initialResult: marketData,
});
},
),
);

if (Object.keys(marketData).length > 0) {
this.update((state) => {
state.marketData = {
...state.marketData,
...marketData,
};
});
}
}

/**
* Updates exchange rates for all tokens.
*
Expand All @@ -515,10 +664,7 @@ export class TokenRatesController extends StaticIntervalPollingController<TokenR
* @param chainIdAndNativeCurrency - The chain ID and native currency.
*/
async updateExchangeRatesByChainId(
chainIdAndNativeCurrency: {
chainId: Hex;
nativeCurrency: string;
}[],
chainIdAndNativeCurrency: ChainIdAndNativeCurrency[],
): Promise<void> {
if (this.#disabled) {
return;
Expand Down Expand Up @@ -655,28 +801,7 @@ export class TokenRatesController extends StaticIntervalPollingController<TokenR
* @param input.chainIds - The chain ids to poll token rates on.
*/
async _executePoll({ chainIds }: TokenRatesPollingInput): Promise<void> {
const { networkConfigurationsByChainId } = this.messenger.call(
'NetworkController:getState',
);

const chainIdAndNativeCurrency = chainIds.reduce<
{ chainId: Hex; nativeCurrency: string }[]
>((acc, chainId) => {
const networkConfiguration = networkConfigurationsByChainId[chainId];
if (!networkConfiguration) {
console.error(
`TokenRatesController: No network configuration found for chainId ${chainId}`,
);
return acc;
}
acc.push({
chainId,
nativeCurrency: networkConfiguration.nativeCurrency,
});
return acc;
}, []);

await this.updateExchangeRatesByChainId(chainIdAndNativeCurrency);
await this.updateExchangeRatesToNative(chainIds);
}

/**
Expand All @@ -700,20 +825,28 @@ export class TokenRatesController extends StaticIntervalPollingController<TokenR
chainId: Hex;
nativeCurrency: string;
}): Promise<ContractMarketData> {
let contractNativeInformations;
const tokenPricesByTokenAddress = await reduceInBatchesSerially<
return await reduceInBatchesSerially<
Hex,
Awaited<ReturnType<AbstractTokenPricesService['fetchTokenPrices']>>
Record<Hex, EvmAssetWithMarketData>
>({
values: [...tokenAddresses].sort(),
values: [...tokenAddresses, getNativeTokenAddress(chainId)].sort(),
batchSize: TOKEN_PRICES_BATCH_SIZE,
eachBatch: async (allTokenPricesByTokenAddress, batch) => {
const tokenPricesByTokenAddressForBatch =
const tokenPricesByTokenAddressForBatch = (
await this.#tokenPricesService.fetchTokenPrices({
tokenAddresses: batch,
chainId,
assets: batch.map((tokenAddress) => ({
chainId,
tokenAddress,
})),
currency: nativeCurrency,
});
})
).reduce(
(acc, tokenPrice) => {
acc[tokenPrice.tokenAddress] = tokenPrice;
return acc;
},
{} as Record<Hex, EvmAssetWithMarketData>,
);

return {
...allTokenPricesByTokenAddress,
Expand All @@ -722,35 +855,6 @@ export class TokenRatesController extends StaticIntervalPollingController<TokenR
},
initialResult: {},
});
contractNativeInformations = tokenPricesByTokenAddress;

// fetch for native token
if (tokenAddresses.length === 0) {
const contractNativeInformationsNative =
await this.#tokenPricesService.fetchTokenPrices({
tokenAddresses: [],
chainId,
currency: nativeCurrency,
});

contractNativeInformations = {
[getNativeTokenAddress(chainId)]: {
currency: nativeCurrency,
...contractNativeInformationsNative[getNativeTokenAddress(chainId)],
},
};
}
return Object.entries(contractNativeInformations).reduce(
(obj, [tokenAddress, token]) => {
obj = {
...obj,
[tokenAddress]: { ...token },
};

return obj;
},
{},
);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ import type { TokenDisplayData } from './types';
import { formatIconUrlWithProxy } from '../assetsUtil';
import type { GetCurrencyRateState } from '../CurrencyRateController';
import type { AbstractTokenPricesService } from '../token-prices-service';
import type { TokenPrice } from '../token-prices-service/abstract-token-prices-service';
import {
fetchTokenMetadata,
TOKEN_METADATA_NO_SUPPORT_ERROR,
Expand Down Expand Up @@ -172,22 +171,18 @@ export class TokenSearchDiscoveryDataController extends BaseController<
this.#fetchSwapsTokensThresholdMs = fetchSwapsTokensThresholdMs;
}

async #fetchPriceData(
chainId: Hex,
address: string,
): Promise<TokenPrice<Hex, string> | null> {
async #fetchPriceData(chainId: Hex, address: string) {
const { currentCurrency } = this.messenger.call(
'CurrencyRateController:getState',
);

try {
const pricesData = await this.#tokenPricesService.fetchTokenPrices({
chainId,
tokenAddresses: [address as Hex],
assets: [{ chainId, tokenAddress: address as Hex }],
currency: currentCurrency,
});

return pricesData[address as Hex] ?? null;
return pricesData[0] ?? null;
} catch (error) {
console.error(error);
return null;
Expand Down
Loading
Loading