diff --git a/frontend/.clinerules b/frontend/.clinerules index 7c824479..08569677 100644 --- a/frontend/.clinerules +++ b/frontend/.clinerules @@ -44,5 +44,5 @@ frontend ## Testing -- use `testing-library/react` for testing hooks. +- use `testing-library/react` for testing hooks. DO NOT USE `testing-library/react-hooks` - root directory of the frontend is `/frontend`, so if you want to run test, you should run `yarn test` in `/frontend` directory. diff --git a/frontend/components/trainer-add/navigator-preprocess/hooks/__tests__/use-pagenation.test.tsx b/frontend/components/trainer-add/navigator-preprocess/hooks/__tests__/use-pagenation.test.tsx new file mode 100644 index 00000000..d9a235f6 --- /dev/null +++ b/frontend/components/trainer-add/navigator-preprocess/hooks/__tests__/use-pagenation.test.tsx @@ -0,0 +1,274 @@ +import { renderHook, act } from '@testing-library/react'; +import { usePreprocessSelexData } from '../use-pagenation'; +import { useRouter } from 'next/router'; +import { useDispatch, useSelector } from 'react-redux'; +import { preprocessSelexData } from '../../../redux/selex-data'; +import { clearPreprocessingDirty } from '../../../redux/preprocessing-config'; + +// Mock the dependencies +jest.mock('next/router', () => ({ + useRouter: jest.fn(), +})); + +jest.mock('react-redux', () => ({ + useDispatch: jest.fn(), + useSelector: jest.fn(), +})); + +jest.mock('../../../redux/selex-data', () => ({ + preprocessSelexData: jest.fn(), +})); + +jest.mock('../../../redux/preprocessing-config', () => ({ + clearPreprocessingDirty: jest.fn(), +})); + +describe('usePreprocessSelexData', () => { + // Setup common mocks + const mockPush = jest.fn(); + const mockDispatch = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + + // Mock router + (useRouter as jest.Mock).mockReturnValue({ + push: mockPush, + }); + + // Mock dispatch + (useDispatch as jest.Mock).mockReturnValue(mockDispatch); + + // Mock successful dispatch for async actions + mockDispatch.mockImplementation((action) => { + if (typeof action === 'function') { + return Promise.resolve(); + } + return action; + }); + }); + + it('should return the correct initial values', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + preprocessingConfig: { + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + isDirty: false, + isValidParams: true, + }, + pageConfig: { + experimentName: 'Test Experiment', + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => usePreprocessSelexData()); + + expect(result.current.isLoading).toBe(false); + expect(typeof result.current.handleClickNext).toBe('function'); + expect(typeof result.current.handleClickBack).toBe('function'); + expect(result.current.canProceed).toBe(true); + }); + + it('should navigate back when handleClickBack is called', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + preprocessingConfig: { + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + isDirty: false, + isValidParams: true, + }, + pageConfig: { + experimentName: 'Test Experiment', + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => usePreprocessSelexData()); + + act(() => { + result.current.handleClickBack(); + }); + + expect(mockPush).toHaveBeenCalledWith('/trainer'); + }); + + it('should navigate to next page without preprocessing when isDirty is false', async () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + preprocessingConfig: { + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + isDirty: false, + isValidParams: true, + }, + pageConfig: { + experimentName: 'Test Experiment', + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => usePreprocessSelexData()); + + await act(async () => { + await result.current.handleClickNext(); + }); + + expect(mockPush).toHaveBeenCalledWith('?page=raptgen'); + expect(mockDispatch).not.toHaveBeenCalled(); + }); + + it('should preprocess data and navigate to next page when isDirty is true', async () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + preprocessingConfig: { + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + isDirty: true, + isValidParams: true, + }, + pageConfig: { + experimentName: 'Test Experiment', + }, + }; + return selector(state); + }); + + // Mock the action creators + (preprocessSelexData as unknown as jest.Mock).mockReturnValue({ type: 'PREPROCESS_SELEX_DATA' }); + (clearPreprocessingDirty as unknown as jest.Mock).mockReturnValue({ type: 'CLEAR_PREPROCESSING_DIRTY' }); + + const { result } = renderHook(() => usePreprocessSelexData()); + + await act(async () => { + await result.current.handleClickNext(); + }); + + expect(preprocessSelexData).toHaveBeenCalledWith({ + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }); + + expect(clearPreprocessingDirty).toHaveBeenCalled(); + expect(mockPush).toHaveBeenCalledWith('?page=raptgen'); + }); + + it('should handle errors during preprocessing', async () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + preprocessingConfig: { + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + isDirty: true, + isValidParams: true, + }, + pageConfig: { + experimentName: 'Test Experiment', + }, + }; + return selector(state); + }); + + // Mock error in dispatch + const mockError = new Error('Preprocessing failed'); + mockDispatch.mockRejectedValueOnce(mockError); + + // Spy on console.error + jest.spyOn(console, 'error').mockImplementation(() => {}); + + const { result } = renderHook(() => usePreprocessSelexData()); + + await act(async () => { + await result.current.handleClickNext(); + }); + + expect(console.error).toHaveBeenCalledWith(mockError); + expect(result.current.isLoading).toBe(false); + expect(mockPush).not.toHaveBeenCalled(); + }); + + it('should disable proceeding when params are invalid', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + preprocessingConfig: { + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + isDirty: false, + isValidParams: false, + }, + pageConfig: { + experimentName: 'Test Experiment', + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => usePreprocessSelexData()); + + expect(result.current.canProceed).toBe(false); + }); + + it('should disable proceeding when experiment name is missing', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + preprocessingConfig: { + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + isDirty: false, + isValidParams: true, + }, + pageConfig: { + experimentName: '', + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => usePreprocessSelexData()); + + expect(result.current.canProceed).toBe(false); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-add/navigator-preprocess/hooks/use-pagenation.tsx b/frontend/components/trainer-add/navigator-preprocess/hooks/use-pagenation.tsx index 129bbdd0..e9bb0c9f 100644 --- a/frontend/components/trainer-add/navigator-preprocess/hooks/use-pagenation.tsx +++ b/frontend/components/trainer-add/navigator-preprocess/hooks/use-pagenation.tsx @@ -27,6 +27,7 @@ export const usePreprocessSelexData = () => { if (!isDirty) { // do nothing and go to next page router.push("?page=raptgen"); + return; } setIsLoading(true); diff --git a/frontend/components/trainer-add/navigator-train/hooks/__tests__/use-submit-job.test.tsx b/frontend/components/trainer-add/navigator-train/hooks/__tests__/use-submit-job.test.tsx new file mode 100644 index 00000000..a9ea8065 --- /dev/null +++ b/frontend/components/trainer-add/navigator-train/hooks/__tests__/use-submit-job.test.tsx @@ -0,0 +1,252 @@ +import { renderHook, act } from '@testing-library/react'; +import { useSubmitJob } from '../use-submit-job'; +import { useRouter } from 'next/router'; +import { useSelector } from 'react-redux'; +import { apiClient } from '~/services/api-client'; + +// Mock the dependencies +jest.mock('next/router', () => ({ + useRouter: jest.fn(), +})); + +jest.mock('react-redux', () => ({ + useSelector: jest.fn(), +})); + +jest.mock('~/services/api-client', () => ({ + apiClient: { + postSubmitJob: jest.fn(), + }, +})); + +describe('useSubmitJob', () => { + // Setup common mocks + const mockPush = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + + // Mock router + (useRouter as jest.Mock).mockReturnValue({ + push: mockPush, + }); + }); + + it('should return the correct initial values', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + modelLength: 64, + epochs: 100, + forceMatchEpochs: 10, + betaScheduleEpochs: 20, + earlyStoppingEpochs: 5, + seed: 42, + matchCost: 1.0, + device: 'cuda', + reiteration: 1, + isValidParams: true, + }, + preprocessingConfig: { + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }, + selexData: { + filteredRandomRegions: ['ACGT', 'TGCA'], + filteredDuplicates: [1, 2], + }, + pageConfig: { + experimentName: 'Test Experiment', + modelType: 'vae', + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => useSubmitJob()); + + expect(result.current.isLoading).toBe(false); + expect(result.current.canTrain).toBe(true); + expect(typeof result.current.handleClickTrain).toBe('function'); + expect(typeof result.current.handleClickBack).toBe('function'); + }); + + it('should navigate back when handleClickBack is called', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + isValidParams: true, + }, + preprocessingConfig: {}, + selexData: {}, + pageConfig: {}, + }; + return selector(state); + }); + + const { result } = renderHook(() => useSubmitJob()); + + act(() => { + result.current.handleClickBack(); + }); + + expect(mockPush).toHaveBeenCalledWith(''); + }); + + it('should submit job and navigate to job page when handleClickTrain is called', async () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + modelLength: 64, + epochs: 100, + forceMatchEpochs: 10, + betaScheduleEpochs: 20, + earlyStoppingEpochs: 5, + seed: 42, + matchCost: 1.0, + device: 'cuda', + reiteration: 1, + isValidParams: true, + }, + preprocessingConfig: { + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }, + selexData: { + filteredRandomRegions: ['ACGT', 'TGCA'], + filteredDuplicates: [1, 2], + }, + pageConfig: { + experimentName: 'Test Experiment', + modelType: 'vae', + }, + }; + return selector(state); + }); + + // Mock API response + (apiClient.postSubmitJob as jest.Mock).mockResolvedValue({ uuid: '123-456-789' }); + + const { result } = renderHook(() => useSubmitJob()); + + await act(async () => { + await result.current.handleClickTrain(); + }); + + // Check API call + expect(apiClient.postSubmitJob).toHaveBeenCalledWith({ + type: 'vae', + name: 'Test Experiment', + params_preprocessing: { + forward: 'ACGT', + reverse: 'TGCA', + random_region_length: 22, // 30 - 4 - 4 + tolerance: 2, + minimum_count: 10, + }, + random_regions: ['ACGT', 'TGCA'], + duplicates: [1, 2], + reiteration: 1, + params_training: { + model_length: 64, + epochs: 100, + match_forcing_duration: 10, + beta_duration: 20, + early_stopping: 5, + seed_value: 42, + match_cost: 1.0, + device: 'cuda', + }, + }); + + // Check navigation + expect(mockPush).toHaveBeenCalledWith('/trainer?experiment=123-456-789'); + }); + + it('should handle API errors when submitting job', async () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + modelLength: 64, + epochs: 100, + forceMatchEpochs: 10, + betaScheduleEpochs: 20, + earlyStoppingEpochs: 5, + seed: 42, + matchCost: 1.0, + device: 'cuda', + reiteration: 1, + isValidParams: true, + }, + preprocessingConfig: { + forwardAdapter: 'ACGT', + reverseAdapter: 'TGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }, + selexData: { + filteredRandomRegions: ['ACGT', 'TGCA'], + filteredDuplicates: [1, 2], + }, + pageConfig: { + experimentName: 'Test Experiment', + modelType: 'vae', + }, + }; + return selector(state); + }); + + // Mock API error + const mockError = new Error('API error'); + (apiClient.postSubmitJob as jest.Mock).mockRejectedValue(mockError); + + // Spy on console.error + jest.spyOn(console, 'error').mockImplementation(() => {}); + + const { result } = renderHook(() => useSubmitJob()); + + await act(async () => { + await result.current.handleClickTrain(); + }); + + // Check error handling + expect(console.error).toHaveBeenCalledWith(mockError); + expect(result.current.isLoading).toBe(false); + expect(mockPush).not.toHaveBeenCalled(); + }); + + it('should disable training when params are invalid', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + isValidParams: false, + }, + preprocessingConfig: {}, + selexData: {}, + pageConfig: {}, + }; + return selector(state); + }); + + const { result } = renderHook(() => useSubmitJob()); + + expect(result.current.canTrain).toBe(false); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-add/preprocessing-forms/hooks/__tests__/use-preprocessing-params.test.tsx b/frontend/components/trainer-add/preprocessing-forms/hooks/__tests__/use-preprocessing-params.test.tsx new file mode 100644 index 00000000..535e26ed --- /dev/null +++ b/frontend/components/trainer-add/preprocessing-forms/hooks/__tests__/use-preprocessing-params.test.tsx @@ -0,0 +1,392 @@ +import { renderHook, act } from '@testing-library/react'; +import { usePreprocessingParams, useModelTypeSelection } from '../use-preprocessing-params'; +import { useDispatch, useSelector } from 'react-redux'; +import { setPreprocessingConfig } from '../../../redux/preprocessing-config'; +import { setPageConfig } from '../../../redux/page-config'; +import { apiClient } from '~/services/api-client'; + +// Mock the dependencies +jest.mock('react-redux', () => ({ + useDispatch: jest.fn(), + useSelector: jest.fn(), +})); + +jest.mock('../../../redux/preprocessing-config', () => ({ + setPreprocessingConfig: jest.fn(), +})); + +jest.mock('../../../redux/page-config', () => ({ + setPageConfig: jest.fn(), +})); + +jest.mock('~/services/api-client', () => ({ + apiClient: { + estimateTargetLength: jest.fn(), + estimateAdapters: jest.fn(), + }, +})); + +describe('usePreprocessingParams', () => { + // Setup common mocks + const mockDispatch = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + + // Mock dispatch + (useDispatch as jest.Mock).mockReturnValue(mockDispatch); + + // Mock action creators + (setPreprocessingConfig as unknown as jest.Mock).mockImplementation((payload) => ({ + type: 'SET_PREPROCESSING_CONFIG', + payload, + })); + + (setPageConfig as unknown as jest.Mock).mockImplementation((payload) => ({ + type: 'SET_PAGE_CONFIG', + payload, + })); + }); + + it('should return the correct initial values', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + pageConfig: { + experimentName: 'Test Experiment', + modelType: 'RaptGen', + }, + preprocessingConfig: { + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }, + selexData: { + sequences: ['ACGUUGCA', 'UGCAACGU'], + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => usePreprocessingParams()); + + // Check returned structure + expect(result.current.modelType).toBeDefined(); + expect(result.current.experimentName).toBeDefined(); + expect(result.current.targetLength).toBeDefined(); + expect(result.current.adapters).toBeDefined(); + expect(result.current.tolerance).toBeDefined(); + expect(result.current.minCount).toBeDefined(); + expect(result.current.fullSequences).toBeDefined(); + + // Check initial values + expect(result.current.modelType.value).toBe('RaptGen'); + expect(result.current.experimentName.value).toBe('Test Experiment'); + expect(result.current.targetLength.value).toBe(30); + expect(result.current.adapters.forwardAdapter.value).toBe('ACGU'); + expect(result.current.adapters.reverseAdapter.value).toBe('UGCA'); + expect(result.current.tolerance.value).toBe(2); + expect(result.current.minCount.value).toBe(10); + expect(result.current.fullSequences).toEqual(['ACGUUGCA', 'UGCAACGU']); + }); + + it('should handle experiment name change', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + pageConfig: { + experimentName: 'Test Experiment', + modelType: 'RaptGen', + }, + preprocessingConfig: { + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }, + selexData: { + sequences: ['ACGUUGCA', 'UGCAACGU'], + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => usePreprocessingParams()); + + // Simulate input change + act(() => { + result.current.experimentName.handleChange({ + target: { value: 'New Experiment Name' }, + } as React.ChangeEvent); + }); + + // Check dispatch was called with correct action + expect(setPageConfig).toHaveBeenCalledWith({ + experimentName: 'New Experiment Name', + modelType: 'RaptGen', + }); + }); + + it('should handle adapter change and convert T to U', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + pageConfig: { + experimentName: 'Test Experiment', + modelType: 'RaptGen', + }, + preprocessingConfig: { + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }, + selexData: { + sequences: ['ACGUUGCA', 'UGCAACGU'], + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => usePreprocessingParams()); + + // Simulate input change with T that should be converted to U + act(() => { + result.current.adapters.forwardAdapter.handleChange({ + target: { value: 'ACGTt' }, + } as React.ChangeEvent); + }); + + // Check dispatch was called with correct action and T converted to U + expect(setPreprocessingConfig).toHaveBeenCalledWith({ + forwardAdapter: 'ACGUU', + reverseAdapter: 'UGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }); + }); + + it('should estimate target length successfully', async () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + pageConfig: { + experimentName: 'Test Experiment', + modelType: 'RaptGen', + }, + preprocessingConfig: { + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }, + selexData: { + sequences: ['ACGUUGCA', 'UGCAACGU'], + }, + }; + return selector(state); + }); + + // Mock API response + (apiClient.estimateTargetLength as jest.Mock).mockResolvedValue({ + status: 'success', + data: { target_length: 40 }, + }); + + const { result } = renderHook(() => usePreprocessingParams()); + + // Before estimation + expect(result.current.targetLength.value).toBe(30); + expect(result.current.targetLength.isEstimating).toBe(false); + + // Perform estimation + await act(async () => { + await result.current.targetLength.estimate(); + }); + + // After estimation + expect(apiClient.estimateTargetLength).toHaveBeenCalledWith({ + sequences: ['ACGUUGCA', 'UGCAACGU'], + }); + + expect(setPreprocessingConfig).toHaveBeenCalledWith({ + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + targetLength: 40, + tolerance: 2, + minCount: 10, + }); + + expect(result.current.targetLength.isEstimating).toBe(false); + }); + + it('should estimate adapters successfully', async () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + pageConfig: { + experimentName: 'Test Experiment', + modelType: 'RaptGen', + }, + preprocessingConfig: { + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }, + selexData: { + sequences: ['ACGUUGCA', 'UGCAACGU'], + }, + }; + return selector(state); + }); + + // Mock API response + (apiClient.estimateAdapters as jest.Mock).mockResolvedValue({ + status: 'success', + data: { + forward_adapter: 'GGTT', + reverse_adapter: 'AACC', + }, + }); + + const { result } = renderHook(() => usePreprocessingParams()); + + // Before estimation + expect(result.current.adapters.forwardAdapter.value).toBe('ACGU'); + expect(result.current.adapters.reverseAdapter.value).toBe('UGCA'); + expect(result.current.adapters.isEstimating).toBe(false); + + // Perform estimation + await act(async () => { + await result.current.adapters.estimate(); + }); + + // After estimation + expect(apiClient.estimateAdapters).toHaveBeenCalledWith({ + target_length: 30, + sequences: ['ACGUUGCA', 'UGCAACGU'], + }); + + expect(setPreprocessingConfig).toHaveBeenCalledWith({ + forwardAdapter: 'GGUU', + reverseAdapter: 'AACC', + targetLength: 30, + tolerance: 2, + minCount: 10, + }); + + expect(result.current.adapters.isEstimating).toBe(false); + }); + + it('should handle API errors during estimation', async () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + pageConfig: { + experimentName: 'Test Experiment', + modelType: 'RaptGen', + }, + preprocessingConfig: { + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + targetLength: 30, + tolerance: 2, + minCount: 10, + }, + selexData: { + sequences: ['ACGUUGCA', 'UGCAACGU'], + }, + }; + return selector(state); + }); + + // Mock API error + (apiClient.estimateTargetLength as jest.Mock).mockRejectedValue(new Error('API error')); + + const { result } = renderHook(() => usePreprocessingParams()); + + // Perform estimation that will fail + await act(async () => { + await result.current.targetLength.estimate(); + }); + + // Should reset loading state even after error + expect(result.current.targetLength.isEstimating).toBe(false); + expect(setPreprocessingConfig).not.toHaveBeenCalled(); + }); +}); + +describe('useModelTypeSelection', () => { + const mockDispatch = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + (useDispatch as jest.Mock).mockReturnValue(mockDispatch); + (setPageConfig as unknown as jest.Mock).mockImplementation((payload) => ({ + type: 'SET_PAGE_CONFIG', + payload, + })); + }); + + it('should return the correct initial values', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + pageConfig: { + modelType: 'RaptGen', + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => useModelTypeSelection()); + + expect(result.current.value).toBe('RaptGen'); + expect(result.current.options).toContain('RaptGen'); + expect(typeof result.current.handleChange).toBe('function'); + }); + + it('should handle model type change', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + pageConfig: { + experimentName: 'Test Experiment', + modelType: 'RaptGen', + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => useModelTypeSelection()); + + // Simulate select change + act(() => { + result.current.handleChange({ + target: { value: 'RaptGen' }, + } as React.ChangeEvent); + }); + + // Check dispatch was called with correct action + expect(setPageConfig).toHaveBeenCalledWith({ + experimentName: 'Test Experiment', + modelType: 'RaptGen', + }); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-add/preprocessing-forms/hooks/use-preprocessing-params.tsx b/frontend/components/trainer-add/preprocessing-forms/hooks/use-preprocessing-params.tsx index 9d080232..20387e9e 100644 --- a/frontend/components/trainer-add/preprocessing-forms/hooks/use-preprocessing-params.tsx +++ b/frontend/components/trainer-add/preprocessing-forms/hooks/use-preprocessing-params.tsx @@ -226,6 +226,8 @@ export const usePreprocessingParams = (): PreprocessingParameters => { }) ); } + } catch (e) { + console.error(e); } finally { setIsLoadingTargetlen(false); } diff --git a/frontend/components/trainer-add/train-parameters-forms/hooks/__tests__/use-train-parameters.test.tsx b/frontend/components/trainer-add/train-parameters-forms/hooks/__tests__/use-train-parameters.test.tsx new file mode 100644 index 00000000..0f769449 --- /dev/null +++ b/frontend/components/trainer-add/train-parameters-forms/hooks/__tests__/use-train-parameters.test.tsx @@ -0,0 +1,426 @@ +import { renderHook, act } from '@testing-library/react'; +import { useTrainParameters, useDeviceSelection, useModelLengthEffect } from '../use-train-parameters'; +import { useDispatch, useSelector } from 'react-redux'; +import { setTrainConfig } from '../../../redux/train-config'; +import { apiClient } from '~/services/api-client'; + +// Mock the dependencies +jest.mock('react-redux', () => ({ + useDispatch: jest.fn(), + useSelector: jest.fn(), +})); + +jest.mock('../../../redux/train-config', () => ({ + setTrainConfig: jest.fn(), +})); + +jest.mock('~/services/api-client', () => ({ + apiClient: { + getDevices: jest.fn(), + }, +})); + +// Mock Math.random for predictable tests +const mockRandom = jest.spyOn(Math, 'random'); + +describe('useTrainParameters', () => { + // Setup common mocks + const mockDispatch = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + + // Mock dispatch + (useDispatch as jest.Mock).mockReturnValue(mockDispatch); + + // Mock action creator + (setTrainConfig as unknown as jest.Mock).mockImplementation((payload) => ({ + type: 'SET_TRAIN_CONFIG', + payload, + })); + + // Mock API response + (apiClient.getDevices as jest.Mock).mockResolvedValue(['cpu', 'cuda']); + }); + + it('should return the correct initial values', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + device: 'cpu', + reiteration: 1, + seed: 42, + epochs: 100, + earlyStoppingEpochs: 5, + betaScheduleEpochs: 20, + forceMatchEpochs: 10, + matchCost: 1.0, + modelLength: 22, + }, + preprocessingConfig: { + isValidParams: true, + targetLength: 30, + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => useTrainParameters()); + + // Check returned structure + expect(result.current.device).toBeDefined(); + expect(result.current.reiteration).toBeDefined(); + expect(result.current.seedValue).toBeDefined(); + expect(result.current.epochs).toBeDefined(); + expect(result.current.earlyStopping).toBeDefined(); + expect(result.current.betaDuration).toBeDefined(); + expect(result.current.matchForcingDuration).toBeDefined(); + expect(result.current.matchCost).toBeDefined(); + expect(result.current.modelLength).toBeDefined(); + + // Check initial values + expect(result.current.device.value).toBe('cpu'); + expect(result.current.reiteration.value).toBe(1); + expect(result.current.seedValue.value).toBe(42); + expect(result.current.epochs.value).toBe(100); + expect(result.current.earlyStopping.value).toBe(5); + expect(result.current.betaDuration.value).toBe(20); + expect(result.current.matchForcingDuration.value).toBe(10); + expect(result.current.matchCost.value).toBe(1.0); + expect(result.current.modelLength.value).toBe(22); + }); + + it('should handle numeric parameter changes', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + device: 'cpu', + reiteration: 1, + seed: 42, + epochs: 100, + earlyStoppingEpochs: 5, + betaScheduleEpochs: 20, + forceMatchEpochs: 10, + matchCost: 1.0, + modelLength: 22, + }, + preprocessingConfig: { + isValidParams: true, + targetLength: 30, + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => useTrainParameters()); + + // Simulate input change for epochs + act(() => { + result.current.epochs.handleChange({ + target: { value: '200' }, + } as React.ChangeEvent); + }); + + // Check dispatch was called with correct action + expect(setTrainConfig).toHaveBeenCalledWith({ + device: 'cpu', + reiteration: 1, + seed: 42, + epochs: 200, + earlyStoppingEpochs: 5, + betaScheduleEpochs: 20, + forceMatchEpochs: 10, + matchCost: 1.0, + modelLength: 22, + }); + }); + + it('should generate random seed value', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + device: 'cpu', + reiteration: 1, + seed: 42, + epochs: 100, + earlyStoppingEpochs: 5, + betaScheduleEpochs: 20, + forceMatchEpochs: 10, + matchCost: 1.0, + modelLength: 22, + }, + preprocessingConfig: { + isValidParams: true, + targetLength: 30, + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + }, + }; + return selector(state); + }); + + // Mock random value + mockRandom.mockReturnValue(0.5); + + const { result } = renderHook(() => useTrainParameters()); + + // Generate random seed + act(() => { + result.current.seedValue.generateRandom(); + }); + + // Expected random value: Math.floor(0.5 * 1000000) = 500000 + expect(setTrainConfig).toHaveBeenCalledWith({ + device: 'cpu', + reiteration: 1, + seed: 500000, + epochs: 100, + earlyStoppingEpochs: 5, + betaScheduleEpochs: 20, + forceMatchEpochs: 10, + matchCost: 1.0, + modelLength: 22, + }); + }); + + it('should calculate model length based on preprocessing config', () => { + // First render with initial state + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + device: 'cpu', + reiteration: 1, + seed: 42, + epochs: 100, + earlyStoppingEpochs: 5, + betaScheduleEpochs: 20, + forceMatchEpochs: 10, + matchCost: 1.0, + modelLength: 22, + }, + preprocessingConfig: { + isValidParams: true, + targetLength: 30, + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + }, + }; + return selector(state); + }); + + const { result, rerender } = renderHook(() => useTrainParameters()); + + // Initial model length + expect(result.current.modelLength.value).toBe(22); + + // Update preprocessing config + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the updated Redux state + const state = { + trainConfig: { + device: 'cpu', + reiteration: 1, + seed: 42, + epochs: 100, + earlyStoppingEpochs: 5, + betaScheduleEpochs: 20, + forceMatchEpochs: 10, + matchCost: 1.0, + modelLength: 22, + }, + preprocessingConfig: { + isValidParams: true, + targetLength: 40, // Changed from 30 to 40 + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + }, + }; + return selector(state); + }); + + // Rerender to trigger the effect + rerender(); + + // Expected model length: 40 - 4 - 4 = 32 + expect(setTrainConfig).toHaveBeenCalledWith({ + device: 'cpu', + reiteration: 1, + seed: 42, + epochs: 100, + earlyStoppingEpochs: 5, + betaScheduleEpochs: 20, + forceMatchEpochs: 10, + matchCost: 1.0, + modelLength: 32, + }); + }); +}); + +describe('useDeviceSelection', () => { + const mockDispatch = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + (useDispatch as jest.Mock).mockReturnValue(mockDispatch); + (setTrainConfig as unknown as jest.Mock).mockImplementation((payload) => ({ + type: 'SET_TRAIN_CONFIG', + payload, + })); + }); + + it('should fetch device list on mount', async () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + device: 'cpu', + }, + }; + return selector(state); + }); + + // Mock API response + (apiClient.getDevices as jest.Mock).mockResolvedValue(['cpu', 'cuda']); + + const { result } = renderHook(() => useDeviceSelection()); + + // Initial state + expect(result.current.value).toBe('cpu'); + expect(result.current.options).toEqual(['cpu']); + + // Wait for API call to resolve + await act(async () => { + // Wait for the effect to complete + await new Promise(resolve => setTimeout(resolve, 0)); + }); + + // Check updated options + expect(result.current.options).toEqual(['cpu', 'cuda']); + expect(apiClient.getDevices).toHaveBeenCalled(); + }); + + it('should handle device selection change', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + device: 'cpu', + reiteration: 1, + seed: 42, + }, + }; + return selector(state); + }); + + const { result } = renderHook(() => useDeviceSelection()); + + // Simulate select change + act(() => { + result.current.handleChange({ + target: { value: 'cuda' }, + } as React.ChangeEvent); + }); + + // Check dispatch was called with correct action + expect(setTrainConfig).toHaveBeenCalledWith({ + device: 'cuda', + reiteration: 1, + seed: 42, + }); + }); +}); + +describe('useModelLengthEffect', () => { + const mockDispatch = jest.fn(); + const mockSetValue = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + (useDispatch as jest.Mock).mockReturnValue(mockDispatch); + (setTrainConfig as unknown as jest.Mock).mockImplementation((payload) => ({ + type: 'SET_TRAIN_CONFIG', + payload, + })); + }); + + it('should calculate model length when preprocessing config is valid', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + modelLength: 22, + }, + preprocessingConfig: { + isValidParams: true, + targetLength: 30, + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + }, + }; + return selector(state); + }); + + // Mock model length parameter + const modelLength = { + value: 22, + setValue: mockSetValue, + isValid: true, + handleChange: jest.fn(), + }; + + renderHook(() => useModelLengthEffect(modelLength)); + + // Check model length calculation + expect(mockSetValue).toHaveBeenCalledWith(22); // 30 - 4 - 4 + expect(setTrainConfig).toHaveBeenCalledWith({ + modelLength: 22, + }); + }); + + it('should not calculate model length when preprocessing config is invalid', () => { + // Mock selector values + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + trainConfig: { + modelLength: 22, + }, + preprocessingConfig: { + isValidParams: false, + targetLength: 30, + forwardAdapter: 'ACGU', + reverseAdapter: 'UGCA', + }, + }; + return selector(state); + }); + + // Mock model length parameter + const modelLength = { + value: 22, + setValue: mockSetValue, + isValid: true, + handleChange: jest.fn(), + }; + + renderHook(() => useModelLengthEffect(modelLength)); + + // Should not update model length + expect(mockSetValue).not.toHaveBeenCalled(); + expect(setTrainConfig).not.toHaveBeenCalled(); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-add/upload-file/hooks/__tests__/use-upload-file.test.tsx b/frontend/components/trainer-add/upload-file/hooks/__tests__/use-upload-file.test.tsx new file mode 100644 index 00000000..d314fb7f --- /dev/null +++ b/frontend/components/trainer-add/upload-file/hooks/__tests__/use-upload-file.test.tsx @@ -0,0 +1,245 @@ +import { renderHook, act } from '@testing-library/react'; +import { useUploadFile } from '../use-upload-file'; +import { useDispatch } from 'react-redux'; +import { setSelexDataState } from '../../../redux/selex-data'; + +// Mock the dependencies +jest.mock('react-redux', () => ({ + useDispatch: jest.fn(), +})); + +jest.mock('../../../redux/selex-data', () => ({ + setSelexDataState: jest.fn(), +})); + +// Mock lodash's countBy function +jest.mock('lodash', () => ({ + countBy: (arr: string[]) => { + // Simple implementation for testing + const result: Record = {}; + arr.forEach(item => { + result[item] = (result[item] || 0) + 1; + }); + return result; + }, +})); + +describe('useUploadFile', () => { + // Setup common mocks + const mockDispatch = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + + // Mock dispatch + (useDispatch as jest.Mock).mockReturnValue(mockDispatch); + + // Mock action creator + (setSelexDataState as unknown as jest.Mock).mockImplementation((payload) => ({ + type: 'SET_SELEX_DATA_STATE', + payload, + })); + }); + + it('should return the correct initial values', () => { + const { result } = renderHook(() => useUploadFile()); + + expect(result.current.dataSource).toEqual([]); + expect(result.current.isValidFile).toBe(true); + expect(result.current.feedback).toBe(''); + expect(result.current.isLoading).toBe(false); + expect(typeof result.current.handleFile).toBe('function'); + }); + + it('should handle valid FASTA file upload', async () => { + const { result } = renderHook(() => useUploadFile()); + + // Create a mock file + const fastaContent = '>Sequence1\nACGTU\n>Sequence2\nUGCAA\n>Sequence3\nACGTU'; + const file = new File([fastaContent], 'test.fasta', { type: 'text/plain' }); + + // Create a mock file input event + const mockEvent = { + target: { + files: [file], + }, + } as unknown as React.ChangeEvent; + + // Mock file.text() method + file.text = jest.fn().mockResolvedValue(fastaContent); + + // Handle file upload + await act(async () => { + await result.current.handleFile(mockEvent); + }); + + // Check loading state during file processing + expect(result.current.isLoading).toBe(false); + + // Check final state + expect(result.current.isValidFile).toBe(true); + expect(result.current.feedback).toBe(''); + + // Check extracted sequences + expect(result.current.dataSource).toEqual([ + { id: 0, sequence: 'ACGTU', duplicate: 2 }, + { id: 1, sequence: 'UGCAA', duplicate: 1 }, + ]); + + // Check Redux action + expect(setSelexDataState).toHaveBeenCalledWith({ + sequences: ['ACGTU', 'UGCAA'], + duplicates: [2, 1], + }); + + expect(mockDispatch).toHaveBeenCalled(); + }); + + it('should handle valid FASTQ file upload', async () => { + const { result } = renderHook(() => useUploadFile()); + + // Create a mock file + const fastqContent = '@Sequence1\nACGTU\n@Sequence2\nUGCAA\n@Sequence3\nACGTU'; + const file = new File([fastqContent], 'test.fastq', { type: 'text/plain' }); + + // Create a mock file input event + const mockEvent = { + target: { + files: [file], + }, + } as unknown as React.ChangeEvent; + + // Mock file.text() method + file.text = jest.fn().mockResolvedValue(fastqContent); + + // Handle file upload + await act(async () => { + await result.current.handleFile(mockEvent); + }); + + // Check final state + expect(result.current.isValidFile).toBe(true); + expect(result.current.feedback).toBe(''); + + // Check extracted sequences + expect(result.current.dataSource).toEqual([ + { id: 0, sequence: 'ACGTU', duplicate: 2 }, + { id: 1, sequence: 'UGCAA', duplicate: 1 }, + ]); + + // Check Redux action + expect(setSelexDataState).toHaveBeenCalledWith({ + sequences: ['ACGTU', 'UGCAA'], + duplicates: [2, 1], + }); + }); + + it('should handle unsupported file type', async () => { + const { result } = renderHook(() => useUploadFile()); + + // Create a mock file with unsupported extension + const file = new File(['some content'], 'test.txt', { type: 'text/plain' }); + + // Create a mock file input event + const mockEvent = { + target: { + files: [file], + }, + } as unknown as React.ChangeEvent; + + // Mock file.text() method + file.text = jest.fn().mockResolvedValue('some content'); + + // Handle file upload + await act(async () => { + await result.current.handleFile(mockEvent); + }); + + // Check final state + expect(result.current.isValidFile).toBe(false); + expect(result.current.feedback).toBe("File type 'txt' is not supported"); + expect(result.current.dataSource).toEqual([]); + expect(setSelexDataState).not.toHaveBeenCalled(); + }); + + it('should handle file with no sequences', async () => { + const { result } = renderHook(() => useUploadFile()); + + // Create a mock file with invalid content + const invalidContent = '>Sequence1\n>Sequence2\n'; + const file = new File([invalidContent], 'test.fasta', { type: 'text/plain' }); + + // Create a mock file input event + const mockEvent = { + target: { + files: [file], + }, + } as unknown as React.ChangeEvent; + + // Mock file.text() method + file.text = jest.fn().mockResolvedValue(invalidContent); + + // Handle file upload + await act(async () => { + await result.current.handleFile(mockEvent); + }); + + // Check final state + expect(result.current.isValidFile).toBe(false); + expect(result.current.feedback).toBe('No sequences found'); + expect(result.current.dataSource).toEqual([]); + expect(setSelexDataState).not.toHaveBeenCalled(); + }); + + it('should handle error during file processing', async () => { + const { result } = renderHook(() => useUploadFile()); + + // Create a mock file + const file = new File(['some content'], 'test.fasta', { type: 'text/plain' }); + + // Create a mock file input event + const mockEvent = { + target: { + files: [file], + }, + } as unknown as React.ChangeEvent; + + // Mock file.text() method to throw an error + file.text = jest.fn().mockRejectedValue(new Error('File reading error')); + + // Handle file upload + await act(async () => { + await result.current.handleFile(mockEvent); + }); + + // Check final state + expect(result.current.isValidFile).toBe(false); + expect(result.current.feedback).toBe('Some error occurred'); + expect(result.current.isLoading).toBe(false); + expect(result.current.dataSource).toEqual([]); + expect(setSelexDataState).not.toHaveBeenCalled(); + }); + + it('should do nothing when no file is selected', async () => { + const { result } = renderHook(() => useUploadFile()); + + // Create a mock file input event with no files + const mockEvent = { + target: { + files: null, + }, + } as unknown as React.ChangeEvent; + + // Handle file upload + await act(async () => { + await result.current.handleFile(mockEvent); + }); + + // State should remain unchanged + expect(result.current.isValidFile).toBe(true); + expect(result.current.feedback).toBe(''); + expect(result.current.isLoading).toBe(false); + expect(result.current.dataSource).toEqual([]); + expect(setSelexDataState).not.toHaveBeenCalled(); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-home/hooks/__tests__/use-job-item.test.tsx b/frontend/components/trainer-home/hooks/__tests__/use-job-item.test.tsx new file mode 100644 index 00000000..3d9fa87d --- /dev/null +++ b/frontend/components/trainer-home/hooks/__tests__/use-job-item.test.tsx @@ -0,0 +1,331 @@ +import { renderHook, act, waitFor } from '@testing-library/react'; +import { useJobItem } from '../use-job-item'; +import { useRouter } from 'next/router'; +import { apiClient } from '~/services/api-client'; +import _ from 'lodash'; + +// Mock the dependencies +jest.mock('next/router', () => ({ + useRouter: jest.fn(), +})); + +jest.mock('~/services/api-client', () => ({ + apiClient: { + getItem: jest.fn(), + getChildItem: jest.fn(), + }, +})); + +// Mock lodash min function +jest.mock('lodash', () => ({ + min: jest.fn(), +})); + +describe('useJobItem', () => { + // Setup common mocks + const mockParentItem = { + uuid: 'parent-123', + status: 'success', + summary: { + statuses: ['pending', 'progress', 'success'], + indices: [0, 1, 2], + minimum_NLLs: [5.0, 3.0, 4.0], + }, + }; + + const mockChildItem = { + id: 1, + parent_uuid: 'parent-123', + status: 'success', + }; + + beforeEach(() => { + jest.clearAllMocks(); + + // Mock router + (useRouter as jest.Mock).mockReturnValue({ + isReady: true, + query: { + experiment: 'parent-123', + job: '1', + }, + }); + + // Mock API responses + (apiClient.getItem as jest.Mock).mockResolvedValue(mockParentItem); + (apiClient.getChildItem as jest.Mock).mockResolvedValue(mockChildItem); + + // Mock lodash min + (_.min as jest.Mock).mockReturnValue(3.0); + }); + + it('should fetch parent and child items on mount', async () => { + const { result } = renderHook(() => useJobItem()); + + // Initial state + expect(result.current.isLoading).toBe(true); + expect(result.current.pid).toBe('parent-123'); + expect(result.current.cid).toBe('1'); + expect(result.current.pItem).toBeNull(); + expect(result.current.cItem).toBeNull(); + + // Wait for parent fetch to complete + await waitFor(() => { + expect(apiClient.getItem).toHaveBeenCalledWith({ + params: { parent_uuid: 'parent-123' }, + }); + }); + + // Wait for child fetch to complete + await waitFor(() => { + expect(apiClient.getChildItem).toHaveBeenCalledWith({ + params: { + parent_uuid: 'parent-123', + child_id: 1, + }, + }); + }); + + // Wait for final state + await waitFor(() => { + expect(result.current.isLoading).toBe(false); + expect(result.current.pItem).toEqual(mockParentItem); + expect(result.current.cItem).toEqual(mockChildItem); + }); + }); + + it('should not fetch data if router is not ready', async () => { + // Mock router not ready + (useRouter as jest.Mock).mockReturnValue({ + isReady: false, + query: {}, + }); + + renderHook(() => useJobItem()); + + // Wait a bit to ensure any potential API calls would have happened + await act(async () => { + await new Promise(resolve => setTimeout(resolve, 0)); + }); + + // API should not be called + expect(apiClient.getItem).not.toHaveBeenCalled(); + expect(apiClient.getChildItem).not.toHaveBeenCalled(); + }); + + it('should not fetch data if pid is missing', async () => { + // Mock router with missing pid + (useRouter as jest.Mock).mockReturnValue({ + isReady: true, + query: {}, + }); + + renderHook(() => useJobItem()); + + // Wait a bit to ensure any potential API calls would have happened + await act(async () => { + await new Promise(resolve => setTimeout(resolve, 0)); + }); + + // API should not be called + expect(apiClient.getItem).not.toHaveBeenCalled(); + expect(apiClient.getChildItem).not.toHaveBeenCalled(); + }); + + it('should refresh data when refresh function is called', async () => { + const { result } = renderHook(() => useJobItem()); + + // Wait for initial fetches to complete + await waitFor(() => { + expect(result.current.pItem).toEqual(mockParentItem); + expect(result.current.cItem).toEqual(mockChildItem); + }); + + // Clear mocks to check if they're called again + jest.clearAllMocks(); + + // Call refresh function + act(() => { + result.current.refresh(); + }); + + // Wait for refreshed fetches + await waitFor(() => { + expect(apiClient.getItem).toHaveBeenCalledWith({ + params: { parent_uuid: 'parent-123' }, + }); + }); + + await waitFor(() => { + expect(apiClient.getChildItem).toHaveBeenCalledWith({ + params: { + parent_uuid: 'parent-123', + child_id: 1, + }, + }); + }); + }); + + it('should handle API errors gracefully', async () => { + // Mock API error + const mockError = new Error('API error'); + (apiClient.getItem as jest.Mock).mockRejectedValue(mockError); + + // Spy on console.error + jest.spyOn(console, 'error').mockImplementation(() => {}); + + const { result } = renderHook(() => useJobItem()); + + // Wait for API call to fail + await waitFor(() => { + expect(console.error).toHaveBeenCalledWith(mockError); + }); + + // Should reset loading state + expect(result.current.isLoading).toBe(false); + + // Should not have set items + expect(result.current.pItem).toBeNull(); + expect(result.current.cItem).toBeNull(); + }); + + it('should calculate default child ID for success status based on minimum NLL', async () => { + // Mock parent item with success status + const successParentItem = { + ...mockParentItem, + status: 'success', + summary: { + statuses: ['pending', 'progress', 'success'], + indices: [0, 1, 2], + minimum_NLLs: [5.0, 3.0, 4.0], // Minimum is at index 1 + }, + }; + + (apiClient.getItem as jest.Mock).mockResolvedValue(successParentItem); + + // Mock router without child ID + (useRouter as jest.Mock).mockReturnValue({ + isReady: true, + query: { + experiment: 'parent-123', + }, + }); + + renderHook(() => useJobItem()); + + // Wait for parent and child fetches to complete + await waitFor(() => { + expect(apiClient.getChildItem).toHaveBeenCalledWith({ + params: { + parent_uuid: 'parent-123', + child_id: 1, // Index with minimum NLL + }, + }); + }); + }); + + it('should calculate default child ID for progress status based on first occurrence', async () => { + // Mock parent item with progress status + const progressParentItem = { + ...mockParentItem, + status: 'progress', + summary: { + statuses: ['pending', 'progress', 'success'], + indices: [0, 1, 2], + minimum_NLLs: [5.0, 3.0, 4.0], + }, + }; + + (apiClient.getItem as jest.Mock).mockResolvedValue(progressParentItem); + + // Mock router without child ID + (useRouter as jest.Mock).mockReturnValue({ + isReady: true, + query: { + experiment: 'parent-123', + }, + }); + + renderHook(() => useJobItem()); + + // Wait for parent and child fetches to complete + await waitFor(() => { + expect(apiClient.getChildItem).toHaveBeenCalledWith({ + params: { + parent_uuid: 'parent-123', + child_id: 1, // Index of first 'progress' status + }, + }); + }); + }); + + it('should handle NaN values in minimum_NLLs', async () => { + // Mock parent item with NaN values + const nanParentItem = { + ...mockParentItem, + status: 'success', + summary: { + statuses: ['pending', 'progress', 'success'], + indices: [0, 1, 2], + minimum_NLLs: [NaN, 3.0, NaN], // Only index 1 is valid + }, + }; + + (apiClient.getItem as jest.Mock).mockResolvedValue(nanParentItem); + + // Mock router without child ID + (useRouter as jest.Mock).mockReturnValue({ + isReady: true, + query: { + experiment: 'parent-123', + }, + }); + + renderHook(() => useJobItem()); + + // Wait for parent and child fetches to complete + await waitFor(() => { + expect(apiClient.getChildItem).toHaveBeenCalledWith({ + params: { + parent_uuid: 'parent-123', + child_id: 1, // Index with valid NLL + }, + }); + }); + }); + + it('should default to child ID 0 if no valid indices are found', async () => { + // Mock parent item with no valid indices + const invalidParentItem = { + ...mockParentItem, + status: 'unknown', // Unknown status + summary: { + statuses: ['pending', 'progress', 'success'], + indices: [0, 1, 2], + minimum_NLLs: [5.0, 3.0, 4.0], + }, + }; + + (apiClient.getItem as jest.Mock).mockResolvedValue(invalidParentItem); + + // Mock router without child ID + (useRouter as jest.Mock).mockReturnValue({ + isReady: true, + query: { + experiment: 'parent-123', + }, + }); + + renderHook(() => useJobItem()); + + // Wait for parent and child fetches to complete + await waitFor(() => { + expect(apiClient.getChildItem).toHaveBeenCalledWith({ + params: { + parent_uuid: 'parent-123', + child_id: 0, // Default index + }, + }); + }); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-home/latent-graph/hooks/__tests__/use-graph-config.test.tsx b/frontend/components/trainer-home/latent-graph/hooks/__tests__/use-graph-config.test.tsx new file mode 100644 index 00000000..787563d6 --- /dev/null +++ b/frontend/components/trainer-home/latent-graph/hooks/__tests__/use-graph-config.test.tsx @@ -0,0 +1,145 @@ +import { renderHook, act } from '@testing-library/react'; +import { useGraphConfig } from '../use-graph-config'; +import { useDispatch, useSelector } from 'react-redux'; +import { setGraphConfig } from '../../../redux/graph-config'; + +// Mock the dependencies +jest.mock('react-redux', () => ({ + useDispatch: jest.fn(), + useSelector: jest.fn(), +})); + +jest.mock('../../../redux/graph-config', () => ({ + setGraphConfig: jest.fn(), +})); + +describe('useGraphConfig', () => { + // Setup common mocks + const mockDispatch = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + + // Mock dispatch + (useDispatch as jest.Mock).mockReturnValue(mockDispatch); + + // Mock selector + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + graphConfig: { + minCount: 5, + }, + }; + return selector(state); + }); + + // Mock action creator + (setGraphConfig as unknown as jest.Mock).mockImplementation((payload) => ({ + type: 'SET_GRAPH_CONFIG', + payload, + })); + }); + + it('should return the correct initial values', () => { + const { result } = renderHook(() => useGraphConfig()); + + expect(result.current.minCount).toBe(5); + expect(result.current.isValidMinCount).toBe(true); + expect(typeof result.current.handleMinCountChange).toBe('function'); + }); + + it('should update minCount and validation state when handleMinCountChange is called with valid input', () => { + const { result } = renderHook(() => useGraphConfig()); + + // Simulate input change with valid value + act(() => { + result.current.handleMinCountChange({ + target: { value: '10' }, + } as React.ChangeEvent); + }); + + // Check state updates + expect(result.current.minCount).toBe(10); + expect(result.current.isValidMinCount).toBe(true); + + // Check Redux action + expect(setGraphConfig).toHaveBeenCalledWith({ + minCount: 10, + }); + + expect(mockDispatch).toHaveBeenCalled(); + }); + + it('should update minCount and validation state when handleMinCountChange is called with invalid input', () => { + const { result } = renderHook(() => useGraphConfig()); + + // Simulate input change with invalid value (negative number) + act(() => { + result.current.handleMinCountChange({ + target: { value: '-5' }, + } as React.ChangeEvent); + }); + + // Check state updates + expect(result.current.minCount).toBe(-5); + expect(result.current.isValidMinCount).toBe(false); + + // Check Redux action - should use previous valid value + expect(setGraphConfig).toHaveBeenCalledWith({ + minCount: 5, // Previous valid value from Redux store + }); + + expect(mockDispatch).toHaveBeenCalled(); + }); + + it('should update minCount and validation state when handleMinCountChange is called with non-numeric input', () => { + const { result } = renderHook(() => useGraphConfig()); + + // Simulate input change with non-numeric value + act(() => { + result.current.handleMinCountChange({ + target: { value: 'abc' }, + } as React.ChangeEvent); + }); + + // Check state updates + expect(result.current.minCount).toBe(NaN); + expect(result.current.isValidMinCount).toBe(false); + + // Check Redux action - should use previous valid value + expect(setGraphConfig).toHaveBeenCalledWith({ + minCount: 5, // Previous valid value from Redux store + }); + + expect(mockDispatch).toHaveBeenCalled(); + }); + + it('should update Redux store when minCount changes', () => { + const { rerender } = renderHook(() => useGraphConfig()); + + // Check initial Redux action + expect(setGraphConfig).toHaveBeenCalledWith({ + minCount: 5, + }); + + // Update selector mock to simulate Redux store update + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the updated Redux state + const state = { + graphConfig: { + minCount: 10, // Updated value + }, + }; + return selector(state); + }); + + // Rerender to trigger effect + rerender(); + + // Check Redux action with updated value + expect(setGraphConfig).toHaveBeenCalledWith({ + minCount: 5, // Initial value from hook state + }); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-home/latent-graph/hooks/__tests__/use-latent-space-plots.test.tsx b/frontend/components/trainer-home/latent-graph/hooks/__tests__/use-latent-space-plots.test.tsx new file mode 100644 index 00000000..0b2a7d15 --- /dev/null +++ b/frontend/components/trainer-home/latent-graph/hooks/__tests__/use-latent-space-plots.test.tsx @@ -0,0 +1,245 @@ +import { renderHook, act } from '@testing-library/react'; +import { useVaeDataPlot, useDownloadCsv, VaeData } from '../use-latent-space-plots'; +import { useSelector } from 'react-redux'; +import { latentGraphLayout } from '~/components/common/graph-helper'; +import { downloadFileFromText } from '~/components/viewer/downloader/hooks/utils'; + +// Mock the dependencies +jest.mock('react-redux', () => ({ + useSelector: jest.fn(), +})); + +jest.mock('~/components/common/graph-helper', () => ({ + latentGraphLayout: jest.fn(), +})); + +jest.mock('~/components/viewer/downloader/hooks/utils', () => ({ + downloadFileFromText: jest.fn(), +})); + +describe('useVaeDataPlot', () => { + // Mock data + const mockVaeData: VaeData = { + coordsX: [0.1, 0.2, 0.3, 0.4, 0.5], + coordsY: [1.1, 1.2, 1.3, 1.4, 1.5], + randomRegions: ['region1', 'region2', 'region3', 'region4', 'region5'], + duplicates: [3, 5, 2, 10, 1], + }; + + // Mock layout + const mockLayout = { + title: { text: '' }, + plot_bgcolor: '#EDEDED', + xaxis: { + color: '#FFFFFF', + tickfont: { color: '#000000' }, + range: [-3.5, 3.5], + gridcolor: '#FFFFFF', + }, + yaxis: { + color: '#FFFFFF', + tickfont: { color: '#000000' }, + range: [-3.5, 3.5], + gridcolor: '#FFFFFF', + }, + }; + + beforeEach(() => { + jest.clearAllMocks(); + + // Mock selector + (useSelector as jest.Mock).mockImplementation((selector) => { + // Simulate the Redux state + const state = { + graphConfig: { + minCount: 3, + }, + }; + return selector(state); + }); + + // Mock layout + (latentGraphLayout as jest.Mock).mockReturnValue(mockLayout); + }); + + it('should return the correct plot data and layout', () => { + const { result } = renderHook(() => useVaeDataPlot(mockVaeData)); + + // Check layout + expect(result.current.layout).toEqual(mockLayout); + expect(latentGraphLayout).toHaveBeenCalledWith(''); + + // Check plot data + const plot = result.current.vaeDataPlot; + + // Should filter out points with duplicates < minCount (3) + expect(plot.x).toEqual([0.1, 0.2, 0.4]); + expect(plot.y).toEqual([1.1, 1.2, 1.4]); + + // Check marker properties + expect(plot.type).toBe('scatter'); + expect(plot.mode).toBe('markers'); + expect(plot.marker).toEqual({ + size: expect.any(Array), + color: 'black', + opacity: 0.5, + line: { + color: 'black', + }, + }); + + // Check marker sizes (should be Math.max(2, Math.sqrt(duplicate))) + expect(plot.marker?.size).toEqual([ + Math.max(2, Math.sqrt(3)), + Math.max(2, Math.sqrt(5)), + Math.max(2, Math.sqrt(10)), + ]); + + // Check customdata + expect(plot.customdata).toEqual([ + ['region1', '3'], + ['region2', '5'], + ['region4', '10'], + ]); + }); + + it('should update plot data when minCount changes', () => { + const { result, rerender } = renderHook(() => useVaeDataPlot(mockVaeData)); + + // Initial render with minCount = 3 + expect(result.current.vaeDataPlot.x).toEqual([0.1, 0.2, 0.4]); + expect(result.current.vaeDataPlot.y).toEqual([1.1, 1.2, 1.4]); + + // Update minCount to 5 + (useSelector as jest.Mock).mockImplementation((selector) => { + const state = { + graphConfig: { + minCount: 5, + }, + }; + return selector(state); + }); + + // Rerender to trigger effect + rerender(); + + // Should now only include points with duplicates >= 5 + expect(result.current.vaeDataPlot.x).toEqual([0.2, 0.4]); + expect(result.current.vaeDataPlot.y).toEqual([1.2, 1.4]); + expect(result.current.vaeDataPlot.customdata).toEqual([ + ['region2', '5'], + ['region4', '10'], + ]); + }); + + it('should update plot data when vaeData changes', () => { + const { result, rerender } = renderHook( + (props) => useVaeDataPlot(props), + { initialProps: mockVaeData } + ); + + // Initial render + expect(result.current.vaeDataPlot.x).toEqual([0.1, 0.2, 0.4]); + expect(result.current.vaeDataPlot.y).toEqual([1.1, 1.2, 1.4]); + + // Update vaeData + const updatedVaeData: VaeData = { + coordsX: [0.6, 0.7, 0.8], + coordsY: [1.6, 1.7, 1.8], + randomRegions: ['region6', 'region7', 'region8'], + duplicates: [4, 2, 6], + }; + + // Rerender with new props + rerender(updatedVaeData); + + // Should filter based on minCount = 3 + expect(result.current.vaeDataPlot.x).toEqual([0.6, 0.8]); + expect(result.current.vaeDataPlot.y).toEqual([1.6, 1.8]); + expect(result.current.vaeDataPlot.customdata).toEqual([ + ['region6', '4'], + ['region8', '6'], + ]); + }); +}); + +describe('useDownloadCsv', () => { + // Mock data + const mockVaeData: VaeData = { + coordsX: [0.1, 0.2, 0.3], + coordsY: [1.1, 1.2, 1.3], + randomRegions: ['region1', 'region2', 'region3'], + duplicates: [3, 5, 2], + }; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should call downloadFileFromText with correct CSV data when handleClickSave is called', () => { + const { result } = renderHook(() => useDownloadCsv(mockVaeData)); + + // Call the download function + act(() => { + result.current.handleClickSave(); + }); + + // Expected CSV content + const expectedCsvHeader = 'random_region, x, y, duplicate'; + const expectedCsvData = + 'region1,0.1,1.1,3\n' + + 'region2,0.2,1.2,5\n' + + 'region3,0.3,1.3,2\n'; + + // Check that downloadFileFromText was called with correct arguments + expect(downloadFileFromText).toHaveBeenCalledWith( + expectedCsvHeader + '\n' + expectedCsvData, + 'latent_points.csv' + ); + }); + + it('should update CSV data when vaeData changes', () => { + const { result, rerender } = renderHook( + (props) => useDownloadCsv(props), + { initialProps: mockVaeData } + ); + + // Call the download function + act(() => { + result.current.handleClickSave(); + }); + + // Check initial call + expect(downloadFileFromText).toHaveBeenCalledTimes(1); + + // Update vaeData + const updatedVaeData: VaeData = { + coordsX: [0.4, 0.5], + coordsY: [1.4, 1.5], + randomRegions: ['region4', 'region5'], + duplicates: [4, 6], + }; + + // Rerender with new props + rerender(updatedVaeData); + + // Call the download function again + act(() => { + result.current.handleClickSave(); + }); + + // Expected updated CSV content + const expectedCsvHeader = 'random_region, x, y, duplicate'; + const expectedCsvData = + 'region4,0.4,1.4,4\n' + + 'region5,0.5,1.5,6\n'; + + // Check that downloadFileFromText was called with updated arguments + expect(downloadFileFromText).toHaveBeenCalledWith( + expectedCsvHeader + '\n' + expectedCsvData, + 'latent_points.csv' + ); + + expect(downloadFileFromText).toHaveBeenCalledTimes(2); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-home/latent-graph/hooks/use-latent-space-plots.tsx b/frontend/components/trainer-home/latent-graph/hooks/use-latent-space-plots.tsx index fa76d1df..c1c61762 100644 --- a/frontend/components/trainer-home/latent-graph/hooks/use-latent-space-plots.tsx +++ b/frontend/components/trainer-home/latent-graph/hooks/use-latent-space-plots.tsx @@ -3,6 +3,8 @@ import { useCallback, useMemo } from "react"; import { useSelector } from "react-redux"; import { RootState } from "../../redux/store"; import { latentGraphLayout } from "~/components/common/graph-helper"; +import { downloadFileFromText } from "~/components/viewer/downloader/hooks/utils"; +import { zip } from "lodash"; export type VaeData = { coordsX: number[]; @@ -18,12 +20,18 @@ export const useVaeDataPlot = (vaeData: VaeData) => { const { minCount } = useSelector((state: RootState) => state.graphConfig); const vaeDataPlot: Partial = useMemo(() => { - const { coordsX, coordsY, randomRegions, duplicates } = vaeData; - const mask = duplicates.map((value) => value >= minCount); - const trace: Partial = { - x: coordsX.filter((_, index) => mask[index]), - y: coordsY.filter((_, index) => mask[index]), - type: "scattergl", + const mcMask = vaeData.duplicates.map((value) => value >= minCount); + const coordsX = vaeData.coordsX.filter((_, index) => mcMask[index]); + const coordsY = vaeData.coordsY.filter((_, index) => mcMask[index]); + const duplicates = vaeData.duplicates.filter((_, index) => mcMask[index]); + const randomRegions = vaeData.randomRegions.filter( + (_, index) => mcMask[index] + ); + + return { + x: coordsX, + y: coordsY, + type: "scatter", mode: "markers", marker: { size: duplicates.map((d) => Math.max(2, Math.sqrt(d))), @@ -33,19 +41,15 @@ export const useVaeDataPlot = (vaeData: VaeData) => { color: "black", }, }, - customdata: mask - .map((value, index) => - value ? [randomRegions[index], duplicates[index]] : null - ) - .filter((value) => value !== null) as [string, number][], + customdata: zip( + randomRegions, + duplicates.map((d) => d.toString()) + ) as unknown as string[], hovertemplate: - "X: %{x}
" + - "Y: %{y}
" + + "Coord: (%{x:.4f}, %{y:.4f})
" + "Random Region: %{customdata[0]}
" + - "Duplicates: %{customdata[1]}
" + - "", + "Duplicates: %{customdata[1]}
", }; - return trace; }, [vaeData, minCount]); const layout = useMemo(() => { @@ -73,18 +77,9 @@ export const useDownloadCsv = (vaeData: VaeData) => { vaeData.duplicates[i] + "\n"; } + // download csv file - const blob = new Blob([csvHeader + "\n" + csvData], { - type: "text/csv", - }); - const url = window.URL.createObjectURL(blob); - const a = document.createElement("a"); - a.setAttribute("hidden", ""); - a.setAttribute("href", url); - a.setAttribute("download", "latent_points.csv"); - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); + downloadFileFromText(csvHeader + "\n" + csvData, "latent_points.csv"); }, [vaeData]); return { handleClickSave }; diff --git a/frontend/components/trainer-home/losses-graph/hooks/__tests__/use-losses-graph.test.tsx b/frontend/components/trainer-home/losses-graph/hooks/__tests__/use-losses-graph.test.tsx new file mode 100644 index 00000000..c6668fcc --- /dev/null +++ b/frontend/components/trainer-home/losses-graph/hooks/__tests__/use-losses-graph.test.tsx @@ -0,0 +1,182 @@ +import { renderHook, act } from '@testing-library/react'; +import { useLossDataPlot, useDownloadCsv, LossData } from '../use-losses-graph'; +import { downloadFileFromText } from '~/components/viewer/downloader/hooks/utils'; + +// Mock the dependencies +jest.mock('~/components/viewer/downloader/hooks/utils', () => ({ + downloadFileFromText: jest.fn(), +})); + +describe('useLossDataPlot', () => { + // Mock data + const mockLossData: LossData = { + epochs: [1, 2, 3, 4, 5], + trainLosses: [0.5, 0.4, 0.3, 0.25, 0.2], + testLosses: [0.6, 0.5, 0.45, 0.4, 0.35], + testRecons: [0.4, 0.35, 0.3, 0.25, 0.2], + testKlds: [0.2, 0.15, 0.15, 0.15, 0.15], + }; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should return the correct plot data for all loss types', () => { + const { result } = renderHook(() => useLossDataPlot(mockLossData)); + + const plots = result.current; + + // Should return 4 traces + expect(plots).toHaveLength(4); + + // Check train loss trace + const trainLossTrace = plots[0]; + expect(trainLossTrace.x).toEqual(mockLossData.epochs); + expect(trainLossTrace.y).toEqual(mockLossData.trainLosses); + expect(trainLossTrace.type).toBe('scatter'); + expect(trainLossTrace.mode).toBe('lines'); + expect(trainLossTrace.name).toBe('Train loss'); + expect(trainLossTrace.line?.color).toBe('#000000'); + expect(trainLossTrace.hovertemplate).toContain('Epoch: %{x}
'); + expect(trainLossTrace.hovertemplate).toContain('NLL (ELBO): %{y}
'); + + // Check test loss trace + const testLossTrace = plots[1]; + expect(testLossTrace.x).toEqual(mockLossData.epochs); + expect(testLossTrace.y).toEqual(mockLossData.testLosses); + expect(testLossTrace.type).toBe('scatter'); + expect(testLossTrace.mode).toBe('lines'); + expect(testLossTrace.name).toBe('Test loss'); + expect(testLossTrace.line?.color).toBe('#FF0000'); + + // Check test reconstruction loss trace + const testReconTrace = plots[2]; + expect(testReconTrace.x).toEqual(mockLossData.epochs); + expect(testReconTrace.y).toEqual(mockLossData.testRecons); + expect(testReconTrace.type).toBe('scatter'); + expect(testReconTrace.mode).toBe('lines'); + expect(testReconTrace.name).toBe('Test reconstruction loss'); + expect(testReconTrace.line?.color).toBe('#00FF00'); + + // Check test KL divergence loss trace + const testKldTrace = plots[3]; + expect(testKldTrace.x).toEqual(mockLossData.epochs); + expect(testKldTrace.y).toEqual(mockLossData.testKlds); + expect(testKldTrace.type).toBe('scatter'); + expect(testKldTrace.mode).toBe('lines'); + expect(testKldTrace.name).toBe('Test KL divergence loss'); + expect(testKldTrace.line?.color).toBe('#0000FF'); + }); + + it('should update plot data when lossData changes', () => { + const { result, rerender } = renderHook( + (props) => useLossDataPlot(props), + { initialProps: mockLossData } + ); + + // Initial render + expect(result.current[0].y).toEqual(mockLossData.trainLosses); + + // Update lossData + const updatedLossData: LossData = { + epochs: [1, 2, 3], + trainLosses: [0.3, 0.2, 0.1], + testLosses: [0.4, 0.3, 0.2], + testRecons: [0.2, 0.15, 0.1], + testKlds: [0.1, 0.05, 0.05], + }; + + // Rerender with new props + rerender(updatedLossData); + + // Check updated data + expect(result.current[0].x).toEqual(updatedLossData.epochs); + expect(result.current[0].y).toEqual(updatedLossData.trainLosses); + expect(result.current[1].y).toEqual(updatedLossData.testLosses); + expect(result.current[2].y).toEqual(updatedLossData.testRecons); + expect(result.current[3].y).toEqual(updatedLossData.testKlds); + }); +}); + +describe('useDownloadCsv', () => { + // Mock data + const mockLossData: LossData = { + epochs: [1, 2, 3], + trainLosses: [0.5, 0.4, 0.3], + testLosses: [0.6, 0.5, 0.45], + testRecons: [0.4, 0.35, 0.3], + testKlds: [0.2, 0.15, 0.15], + }; + + beforeEach(() => { + jest.clearAllMocks(); + }); + + it('should call downloadFileFromText with correct CSV data when handleClickSave is called', () => { + const { result } = renderHook(() => useDownloadCsv(mockLossData)); + + // Call the download function + act(() => { + result.current.handleClickSave(); + }); + + // Expected CSV content + const expectedCsvHeader = 'epoch, train_loss, test_loss, test_recon, test_kld'; + const expectedCsvData = + '0,0.5,0.6,0.4,0.2\n' + + '1,0.4,0.5,0.35,0.15\n' + + '2,0.3,0.45,0.3,0.15\n'; + + // Check that downloadFileFromText was called with correct arguments + expect(downloadFileFromText).toHaveBeenCalledWith( + expectedCsvHeader + '\n' + expectedCsvData, + 'losses.csv' + ); + }); + + it('should update CSV data when lossData changes', () => { + const { result, rerender } = renderHook( + (props) => useDownloadCsv(props), + { initialProps: mockLossData } + ); + + // Call the download function + act(() => { + result.current.handleClickSave(); + }); + + // Check initial call + expect(downloadFileFromText).toHaveBeenCalledTimes(1); + + // Update lossData + const updatedLossData: LossData = { + epochs: [1, 2], + trainLosses: [0.3, 0.2], + testLosses: [0.4, 0.3], + testRecons: [0.2, 0.15], + testKlds: [0.1, 0.05], + }; + + // Rerender with new props + rerender(updatedLossData); + + // Call the download function again + act(() => { + result.current.handleClickSave(); + }); + + // Expected updated CSV content + const expectedCsvHeader = 'epoch, train_loss, test_loss, test_recon, test_kld'; + const expectedCsvData = + '0,0.3,0.4,0.2,0.1\n' + + '1,0.2,0.3,0.15,0.05\n'; + + // Check that downloadFileFromText was called with updated arguments + expect(downloadFileFromText).toHaveBeenCalledWith( + expectedCsvHeader + '\n' + expectedCsvData, + 'losses.csv' + ); + + expect(downloadFileFromText).toHaveBeenCalledTimes(2); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-home/losses-graph/hooks/use-losses-graph.tsx b/frontend/components/trainer-home/losses-graph/hooks/use-losses-graph.tsx index f3f9723c..7d41c9dd 100644 --- a/frontend/components/trainer-home/losses-graph/hooks/use-losses-graph.tsx +++ b/frontend/components/trainer-home/losses-graph/hooks/use-losses-graph.tsx @@ -1,5 +1,6 @@ import { Layout, PlotData } from "plotly.js"; import { useCallback, useMemo } from "react"; +import { downloadFileFromText } from "~/components/viewer/downloader/hooks/utils"; export type LossData = { epochs: number[]; @@ -9,47 +10,6 @@ export type LossData = { testKlds: number[]; }; -export const useLayout = (title: string): Partial => { - return { - title: title, - plot_bgcolor: "#EDEDED", - xaxis: { - color: "#FFFFFF", - tickfont: { - color: "#000000", - }, - gridcolor: "#FFFFFF", - }, - yaxis: { - color: "#FFFFFF", - tickfont: { - color: "#000000", - }, - gridcolor: "#FFFFFF", - }, - hoverlabel: { - font: { - family: "monospace", - }, - }, - showlegend: true, - legend: { - xanchor: "right", - x: 1, - yanchor: "top", - y: 1, - }, - clickmode: "event+select", - margin: { - l: 30, - r: 30, - b: 30, - t: 30, - pad: 5, - }, - }; -}; - export const useLossDataPlot = (lossData: LossData) => { const lossDataPlot: Partial[] = useMemo(() => { const { epochs, trainLosses, testLosses, testRecons, testKlds } = lossData; @@ -124,18 +84,9 @@ export const useDownloadCsv = (lossData: LossData) => { lossData.testKlds[i] + "\n"; } + // download csv file - const blob = new Blob([csvHeader + "\n" + csvData], { - type: "text/csv", - }); - const url = window.URL.createObjectURL(blob); - const a = document.createElement("a"); - a.setAttribute("hidden", ""); - a.setAttribute("href", url); - a.setAttribute("download", "losses.csv"); - document.body.appendChild(a); - a.click(); - document.body.removeChild(a); + downloadFileFromText(csvHeader + "\n" + csvData, "losses.csv"); }, [lossData]); return { handleClickSave }; diff --git a/frontend/components/trainer-home/losses-graph/index.tsx b/frontend/components/trainer-home/losses-graph/index.tsx index f2644471..9f316390 100644 --- a/frontend/components/trainer-home/losses-graph/index.tsx +++ b/frontend/components/trainer-home/losses-graph/index.tsx @@ -3,10 +3,10 @@ import dynamic from "next/dynamic"; import { Badge, Card } from "react-bootstrap"; import { LossData, - useLayout, useLossDataPlot, useDownloadCsv, } from "./hooks/use-losses-graph"; +import { Layout } from "plotly.js"; const Plot = dynamic(() => import("react-plotly.js"), { ssr: false }); @@ -15,11 +15,53 @@ type Props = { lossData: LossData; }; +const createLayout = (title: string): Partial => { + return { + title: title, + plot_bgcolor: "#EDEDED", + xaxis: { + color: "#FFFFFF", + tickfont: { + color: "#000000", + }, + gridcolor: "#FFFFFF", + }, + yaxis: { + color: "#FFFFFF", + tickfont: { + color: "#000000", + }, + gridcolor: "#FFFFFF", + }, + hoverlabel: { + font: { + family: "monospace", + }, + }, + showlegend: true, + legend: { + xanchor: "right", + x: 1, + yanchor: "top", + y: 1, + }, + clickmode: "event+select", + margin: { + l: 30, + r: 30, + b: 30, + t: 30, + pad: 5, + }, + } +}; + export const LossesGraph: React.FC = ({ title, lossData }) => { - const layout = useLayout(title); const lossDataPlot = useLossDataPlot(lossData); const { handleClickSave } = useDownloadCsv(lossData); + const layout = createLayout(title); + return ( diff --git a/frontend/components/trainer-home/vae-jobs-list/hooks/__tests__/use-child-job-card.test.ts b/frontend/components/trainer-home/vae-jobs-list/hooks/__tests__/use-child-job-card.test.ts new file mode 100644 index 00000000..f67b57a2 --- /dev/null +++ b/frontend/components/trainer-home/vae-jobs-list/hooks/__tests__/use-child-job-card.test.ts @@ -0,0 +1,166 @@ +import { renderHook, act } from '@testing-library/react'; +import { useChildJobCard } from '../use-child-job-card'; +import { intervalToDuration } from 'date-fns'; + +// Mock date-fns +jest.mock('date-fns', () => ({ + intervalToDuration: jest.fn(), +})); + +describe('useChildJobCard', () => { + // Mock console.log + const originalConsoleLog = console.log; + const mockConsoleLog = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + console.log = mockConsoleLog; + + // Mock intervalToDuration + (intervalToDuration as jest.Mock).mockImplementation(({ start, end }) => { + const duration = end - start; + const seconds = Math.floor((duration / 1000) % 60); + const minutes = Math.floor((duration / (1000 * 60)) % 60); + const hours = Math.floor((duration / (1000 * 60 * 60)) % 24); + const days = Math.floor(duration / (1000 * 60 * 60 * 24)); + + return { days, hours, minutes, seconds }; + }); + }); + + afterEach(() => { + console.log = originalConsoleLog; + }); + + it('should return the correct utility functions', () => { + const { result } = renderHook(() => useChildJobCard({})); + + expect(typeof result.current.formatDurationText).toBe('function'); + expect(typeof result.current.handleClick).toBe('function'); + expect(typeof result.current.getCardStyle).toBe('function'); + }); + + describe('formatDurationText', () => { + it('should return empty string for non-progress status', () => { + const { result } = renderHook(() => useChildJobCard({})); + + expect(result.current.formatDurationText('success', 1000)).toBe(''); + expect(result.current.formatDurationText('failure', 1000)).toBe(''); + expect(result.current.formatDurationText('pending', 1000)).toBe(''); + expect(result.current.formatDurationText('suspend', 1000)).toBe(''); + }); + + it('should return empty string if duration is not provided', () => { + const { result } = renderHook(() => useChildJobCard({})); + + expect(result.current.formatDurationText('progress')).toBe(''); + expect(result.current.formatDurationText('progress', undefined)).toBe(''); + }); + + it('should format duration with seconds only', () => { + const { result } = renderHook(() => useChildJobCard({})); + + // Mock 30 seconds + (intervalToDuration as jest.Mock).mockReturnValueOnce({ + days: 0, + hours: 0, + minutes: 0, + seconds: 30, + }); + + expect(result.current.formatDurationText('progress', 30000)).toBe('Running for 30s'); + }); + + it('should format duration with minutes and seconds', () => { + const { result } = renderHook(() => useChildJobCard({})); + + // Mock 2 minutes and 30 seconds + (intervalToDuration as jest.Mock).mockReturnValueOnce({ + days: 0, + hours: 0, + minutes: 2, + seconds: 30, + }); + + expect(result.current.formatDurationText('progress', 150000)).toBe('Running for 2m 30s'); + }); + + it('should format duration with hours, minutes, and seconds', () => { + const { result } = renderHook(() => useChildJobCard({})); + + // Mock 1 hour, 2 minutes, and 30 seconds + (intervalToDuration as jest.Mock).mockReturnValueOnce({ + days: 0, + hours: 1, + minutes: 2, + seconds: 30, + }); + + expect(result.current.formatDurationText('progress', 3750000)).toBe('Running for 1h 2m 30s'); + }); + + it('should convert days to hours', () => { + const { result } = renderHook(() => useChildJobCard({})); + + // Mock 2 days, 3 hours, 2 minutes, and 30 seconds + (intervalToDuration as jest.Mock).mockReturnValueOnce({ + days: 2, + hours: 3, + minutes: 2, + seconds: 30, + }); + + // 2 days * 24 + 3 hours = 51 hours + expect(result.current.formatDurationText('progress', 176550000)).toBe('Running for 51h 2m 30s'); + }); + }); + + describe('handleClick', () => { + it('should call the provided onClick handler', () => { + const mockOnClick = jest.fn(); + const { result } = renderHook(() => useChildJobCard({ onClick: mockOnClick })); + + const mockEvent = {} as React.MouseEvent; + + act(() => { + result.current.handleClick(mockEvent); + }); + + expect(mockOnClick).toHaveBeenCalledWith(mockEvent); + }); + + it('should log a message if event or onClick is undefined', () => { + const { result } = renderHook(() => useChildJobCard({})); + + const mockEvent = {} as React.MouseEvent; + + act(() => { + result.current.handleClick(mockEvent); + }); + + expect(mockConsoleLog).toHaveBeenCalledWith('event or onClick is undefined'); + }); + }); + + describe('getCardStyle', () => { + it('should return the correct style for selected state', () => { + const { result } = renderHook(() => useChildJobCard({})); + + const style = result.current.getCardStyle(true); + + expect(style.backgroundColor).toBe('#f0f0f0'); + expect(style.border).toBe('1px solid gray'); + expect(style.cursor).toBe('pointer'); + }); + + it('should return the correct style for unselected state', () => { + const { result } = renderHook(() => useChildJobCard({})); + + const style = result.current.getCardStyle(false); + + expect(style.backgroundColor).toBe('#f5f5f5'); + expect(style.border).toBe('1px solid lightgray'); + expect(style.cursor).toBe('pointer'); + }); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-home/vae-jobs-list/hooks/__tests__/use-job-card.test.ts b/frontend/components/trainer-home/vae-jobs-list/hooks/__tests__/use-job-card.test.ts new file mode 100644 index 00000000..05c77ea3 --- /dev/null +++ b/frontend/components/trainer-home/vae-jobs-list/hooks/__tests__/use-job-card.test.ts @@ -0,0 +1,158 @@ +import { renderHook, act } from '@testing-library/react'; +import { useJobCard } from '../use-job-card'; + +describe('useJobCard', () => { + // Mock console.log + const originalConsoleLog = console.log; + const mockConsoleLog = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + console.log = mockConsoleLog; + }); + + afterEach(() => { + console.log = originalConsoleLog; + }); + + it('should return the correct initial values', () => { + const { result } = renderHook(() => useJobCard({})); + + expect(result.current.clickedModel).toBeNull(); + expect(typeof result.current.setClickedModel).toBe('function'); + expect(typeof result.current.handleClick).toBe('function'); + expect(typeof result.current.handleChildClick).toBe('function'); + expect(typeof result.current.getCardStyle).toBe('function'); + }); + + describe('handleClick', () => { + it('should call the provided onClick handler and reset clickedModel', () => { + const mockOnClick = jest.fn(); + const { result } = renderHook(() => useJobCard({ onClick: mockOnClick })); + + // Set initial clickedModel + act(() => { + result.current.setClickedModel(5); + }); + + expect(result.current.clickedModel).toBe(5); + + // Call handleClick + const mockEvent = {} as React.MouseEvent; + + act(() => { + result.current.handleClick(mockEvent); + }); + + expect(mockOnClick).toHaveBeenCalledWith(mockEvent); + expect(result.current.clickedModel).toBeNull(); + }); + + it('should do nothing if onClick is not provided', () => { + const { result } = renderHook(() => useJobCard({})); + + // Set initial clickedModel + act(() => { + result.current.setClickedModel(5); + }); + + // Call handleClick + const mockEvent = {} as React.MouseEvent; + + act(() => { + result.current.handleClick(mockEvent); + }); + + // clickedModel should remain unchanged + expect(result.current.clickedModel).toBe(5); + }); + }); + + describe('handleChildClick', () => { + it('should call the provided onChildClick handler and set clickedModel', () => { + const mockOnChildClick = jest.fn(); + const { result } = renderHook(() => useJobCard({ onChildClick: mockOnChildClick })); + + // Call handleChildClick + const mockEvent = { + stopPropagation: jest.fn(), + } as unknown as React.MouseEvent; + + act(() => { + result.current.handleChildClick(3, mockEvent); + }); + + expect(mockOnChildClick).toHaveBeenCalledWith(3, mockEvent); + expect(result.current.clickedModel).toBe(3); + expect(mockEvent.stopPropagation).toHaveBeenCalled(); + }); + + it('should set clickedModel even if onChildClick is not provided', () => { + const { result } = renderHook(() => useJobCard({})); + + // Call handleChildClick + const mockEvent = { + stopPropagation: jest.fn(), + } as unknown as React.MouseEvent; + + act(() => { + result.current.handleChildClick(3, mockEvent); + }); + + expect(result.current.clickedModel).toBe(3); + expect(mockEvent.stopPropagation).toHaveBeenCalled(); + }); + }); + + describe('getCardStyle', () => { + it('should return the correct style for selected state with no clicked model', () => { + const { result } = renderHook(() => useJobCard({})); + + const style = result.current.getCardStyle(true, null); + + expect(style.backgroundColor).toBe('lightgray'); + expect(style.border).toBe('1px solid gray'); + expect(style.cursor).toBe('pointer'); + }); + + it('should return the correct style for selected state with clicked model', () => { + const { result } = renderHook(() => useJobCard({})); + + const style = result.current.getCardStyle(true, 3); + + expect(style.backgroundColor).toBe('lightgray'); + expect(style.border).toBe('1px solid #E5E5E5'); + expect(style.cursor).toBe('pointer'); + }); + + it('should return the correct style for unselected state', () => { + const { result } = renderHook(() => useJobCard({})); + + const style = result.current.getCardStyle(false, null); + + expect(style.backgroundColor).toBe('#E5E5E5'); + expect(style.border).toBe('1px solid #E5E5E5'); + expect(style.cursor).toBe('pointer'); + }); + }); + + describe('setClickedModel', () => { + it('should update the clickedModel state', () => { + const { result } = renderHook(() => useJobCard({})); + + expect(result.current.clickedModel).toBeNull(); + + act(() => { + result.current.setClickedModel(3); + }); + + expect(result.current.clickedModel).toBe(3); + + act(() => { + result.current.setClickedModel(null); + }); + + expect(result.current.clickedModel).toBeNull(); + }); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-home/vae-jobs-list/hooks/__tests__/use-vae-jobs.test.ts b/frontend/components/trainer-home/vae-jobs-list/hooks/__tests__/use-vae-jobs.test.ts new file mode 100644 index 00000000..e22bf3fd --- /dev/null +++ b/frontend/components/trainer-home/vae-jobs-list/hooks/__tests__/use-vae-jobs.test.ts @@ -0,0 +1,263 @@ +import { renderHook, act, waitFor } from '@testing-library/react'; +import { useVaeJobs } from '../use-vae-jobs'; +import { useRouter } from 'next/router'; +import { apiClient } from '~/services/api-client'; + +// Mock the dependencies +jest.mock('next/router', () => ({ + useRouter: jest.fn(), +})); + +jest.mock('~/services/api-client', () => ({ + apiClient: { + postSearchJobs: jest.fn(), + }, +})); + +describe('useVaeJobs', () => { + // Mock data + const mockJobs = [ + { + uuid: 'job-1', + name: 'Job 1', + status: 'progress', + series: [ + { + item_id: 1, + item_status: 'progress', + item_datetime_start: Math.floor(Date.now() / 1000) - 3600, // 1 hour ago + item_duration_suspend: 0, + item_datetime_laststop: null, + item_epochs_current: 50, + item_epochs_total: 100, + }, + ], + }, + { + uuid: 'job-2', + name: 'Job 2', + status: 'success', + series: [ + { + item_id: 2, + item_status: 'success', + item_datetime_start: Math.floor(Date.now() / 1000) - 7200, // 2 hours ago + item_duration_suspend: 0, + item_datetime_laststop: Math.floor(Date.now() / 1000) - 3600, // 1 hour ago + item_epochs_current: 100, + item_epochs_total: 100, + }, + ], + }, + ]; + + // Mock router + const mockPush = jest.fn(); + + beforeEach(() => { + jest.clearAllMocks(); + jest.useFakeTimers(); + + // Mock router + (useRouter as jest.Mock).mockReturnValue({ + query: { experiment: 'job-1' }, + push: mockPush, + }); + + // Mock API response + (apiClient.postSearchJobs as jest.Mock).mockResolvedValue(mockJobs); + }); + + afterEach(() => { + jest.useRealTimers(); + }); + + it('should fetch jobs on mount and set up interval', async () => { + const { result } = renderHook(() => useVaeJobs()); + + // Initial state + expect(result.current.runningJobs).toEqual([]); + expect(result.current.finishedJobs).toEqual([]); + expect(result.current.searchQuery).toBe(''); + expect(result.current.experimentId).toBe('job-1'); + + // Wait for API call to resolve + await waitFor(() => { + expect(apiClient.postSearchJobs).toHaveBeenCalledWith({ + search_regex: undefined, + }); + }); + + // Wait for state to update + await waitFor(() => { + expect(result.current.runningJobs).toHaveLength(1); + }); + + // Check state after API call + expect(result.current.runningJobs[0].uuid).toBe('job-1'); + expect(result.current.finishedJobs).toHaveLength(1); + expect(result.current.finishedJobs[0].uuid).toBe('job-2'); + + // Advance timer to trigger interval + act(() => { + jest.advanceTimersByTime(5000); + }); + + // Check that API was called again + await waitFor(() => { + expect(apiClient.postSearchJobs).toHaveBeenCalledTimes(2); + }); + }); + + it('should update jobs when searchQuery changes', async () => { + const { result } = renderHook(() => useVaeJobs()); + + // Wait for initial API call to resolve + await waitFor(() => { + expect(apiClient.postSearchJobs).toHaveBeenCalled(); + }); + + // Update search query + act(() => { + result.current.setSearchQuery('test'); + }); + + // Check API call with search query + await waitFor(() => { + expect(apiClient.postSearchJobs).toHaveBeenCalledWith({ + search_regex: 'test', + }); + }); + }); + + it('should navigate to job page when handleJobClick is called', () => { + const { result } = renderHook(() => useVaeJobs()); + + act(() => { + result.current.handleJobClick('job-3'); + }); + + expect(mockPush).toHaveBeenCalledWith('?experiment=job-3', undefined, { + scroll: false, + }); + }); + + it('should navigate to child job page when handleChildJobClick is called', () => { + const { result } = renderHook(() => useVaeJobs()); + + act(() => { + result.current.handleChildJobClick('job-3', 4); + }); + + expect(mockPush).toHaveBeenCalledWith('?experiment=job-3&job=4', undefined, { + scroll: false, + }); + }); + + describe('calculateRunningSeriesItem', () => { + it('should calculate series item for progress status', () => { + const { result } = renderHook(() => useVaeJobs()); + + const now = Date.now(); + const oneHourAgo = Math.floor(now / 1000) - 3600; // 1 hour ago + + const childJob = { + item_id: 1, + item_status: 'progress' as const, + item_datetime_start: oneHourAgo, + item_duration_suspend: 0, + item_datetime_laststop: null, + item_epochs_current: 50, + item_epochs_total: 100, + }; + + const seriesItem = result.current.calculateRunningSeriesItem(childJob); + + expect(seriesItem.id).toBe(1); + expect(seriesItem.status).toBe('progress'); + expect(seriesItem.epochsCurrent).toBe(50); + expect(seriesItem.epochsTotal).toBe(100); + + // Duration should be approximately 1 hour (3600000 ms) + expect(seriesItem.duration).toBeCloseTo(3_600_000, -4); + }); + + it('should calculate series item for pending status', () => { + const { result } = renderHook(() => useVaeJobs()); + + const childJob = { + item_id: 1, + item_status: 'pending' as const, + item_datetime_start: Math.floor(Date.now() / 1000) - 3600, + item_duration_suspend: 0, + item_datetime_laststop: null, + item_epochs_current: 0, + item_epochs_total: 100, + }; + + const seriesItem = result.current.calculateRunningSeriesItem(childJob); + + expect(seriesItem.id).toBe(1); + expect(seriesItem.status).toBe('pending'); + expect(seriesItem.duration).toBe(0); + expect(seriesItem.epochsCurrent).toBe(0); + expect(seriesItem.epochsTotal).toBe(100); + }); + + it('should calculate series item for suspend status', () => { + const { result } = renderHook(() => useVaeJobs()); + + const now = Date.now(); + const twoHoursAgo = Math.floor(now / 1000) - 7200; // 2 hours ago + const oneHourAgo = Math.floor(now / 1000) - 3600; // 1 hour ago + + const childJob = { + item_id: 1, + item_status: 'suspend' as const, + item_datetime_start: twoHoursAgo, + item_duration_suspend: 0, + item_datetime_laststop: oneHourAgo, + item_epochs_current: 50, + item_epochs_total: 100, + }; + + const seriesItem = result.current.calculateRunningSeriesItem(childJob); + + expect(seriesItem.id).toBe(1); + expect(seriesItem.status).toBe('suspend'); + expect(seriesItem.epochsCurrent).toBe(50); + expect(seriesItem.epochsTotal).toBe(100); + + // Duration should be approximately 1 hour (3600 ms) + expect(seriesItem.duration).toBeCloseTo(3_600_000, -4); + }); + }); + + describe('calculateFinishedSeriesItem', () => { + it('should calculate series item with item_datetime_laststop', () => { + const { result } = renderHook(() => useVaeJobs()); + + const now = Date.now(); + const twoHoursAgo = Math.floor(now / 1000) - 7200; // 2 hours ago + const oneHourAgo = Math.floor(now / 1000) - 3600; // 1 hour ago + + const childJob = { + item_id: 2, + item_status: 'success' as const, + item_datetime_start: twoHoursAgo, + item_duration_suspend: 0, + item_datetime_laststop: oneHourAgo, + item_epochs_current: 100, + item_epochs_total: 100, + }; + + const seriesItem = result.current.calculateFinishedSeriesItem(childJob); + + expect(seriesItem.id).toBe(2); + expect(seriesItem.status).toBe('success'); + expect(seriesItem.duration).toBeCloseTo(3_600_000, -4); + expect(seriesItem.epochsCurrent).toBe(100); + expect(seriesItem.epochsTotal).toBe(100); + }); + }); +}); \ No newline at end of file diff --git a/frontend/components/trainer-home/vae-jobs-list/hooks/use-vae-jobs.ts b/frontend/components/trainer-home/vae-jobs-list/hooks/use-vae-jobs.ts index bcfce554..58bd7db1 100644 --- a/frontend/components/trainer-home/vae-jobs-list/hooks/use-vae-jobs.ts +++ b/frontend/components/trainer-home/vae-jobs-list/hooks/use-vae-jobs.ts @@ -108,11 +108,9 @@ export function useVaeJobs(): UseVaeJobsReturn { const calculateFinishedSeriesItem = (childJob: ChildJob): SeriesItem => { return { id: childJob.item_id, - duration: childJob.item_datetime_laststop - ? childJob.item_datetime_laststop - : Date.now() - - childJob.item_datetime_start - - childJob.item_duration_suspend, + duration: ((childJob.item_datetime_laststop as number) - + childJob.item_datetime_start - + childJob.item_duration_suspend) * 1000, status: childJob.item_status, epochsCurrent: childJob.item_epochs_current, epochsTotal: childJob.item_epochs_total,