From 9957da6fcbf82779aa2977b405f81c528fc59669 Mon Sep 17 00:00:00 2001 From: Jonathan Romano Date: Fri, 28 Feb 2025 18:04:34 -0500 Subject: [PATCH] Implement hungarian algorithm update with improved conflict resolution handling --- src/eterna/folding/FoldUtil.ts | 192 +++++++++++++++++++++++++-------- 1 file changed, 147 insertions(+), 45 deletions(-) diff --git a/src/eterna/folding/FoldUtil.ts b/src/eterna/folding/FoldUtil.ts index 30e1f594..db171558 100644 --- a/src/eterna/folding/FoldUtil.ts +++ b/src/eterna/folding/FoldUtil.ts @@ -91,6 +91,28 @@ export default class FoldUtil { return probUnpaired; } + public static checkBpList(bpList: number[][]) { + // Ensure that the bpList is sorted, then check for duplicated nucleotides + const newBpList = bpList.slice(); + newBpList.sort((bpA, bpB) => bpA[0] - bpB[0]); + const nts = newBpList.flat(); + if (nts.length > new Set(nts).size) { + log.warn('Some nucletotides found in more than 1 base pair'); + newBpList.forEach((bpA, i) => { + const bpB = newBpList[i + 1] || []; + if (bpA[0] === bpB[0] && bpA[1] === bpB[1]) { + log.warn(`Removing duplicated base pair: ${bpA}`); + newBpList.splice(i, 1); + } else if (bpB.includes(bpA[0])) { + log.warn(`bpA: ${bpA}, bpB: ${bpB}`); + } else if (bpB.includes(bpA[1])) { + log.warn(`bpA: ${bpA}, bpB: ${bpB}`); + } + }); + } + return newBpList; + } + public static postProcessStruct( bpList: number[][], sequenceLength: number, @@ -188,25 +210,7 @@ export default class FoldUtil { } } - // Ensure that the bpList is sorted, then check for duplicated nucleotides - bpList.sort((bpA, bpB) => bpA[0] - bpB[0]); - const nts = bpList.flat(); - if (nts.length > new Set(nts).size) { - log.warn('Some nucletotides found in more than 1 base pair'); - bpList.forEach((bpA, i) => { - const bpB = bpList[i + 1] || []; - if (bpA[0] === bpB[0] && bpA[1] === bpB[1]) { - log.warn(`Removing duplicated base pair: ${bpA}`); - bpList.splice(i, 1); - } else if (bpB.includes(bpA[0])) { - log.warn(`bpA: ${bpA}, bpB: ${bpB}`); - } else if (bpB.includes(bpA[1])) { - log.warn(`bpA: ${bpA}, bpB: ${bpB}`); - } - }); - } - - return FoldUtil.postProcessStruct(bpList, seq.length, minLenHelix); + return FoldUtil.postProcessStruct(FoldUtil.checkBpList(bpList), seq.length, minLenHelix); } public static hungarian( @@ -225,15 +229,15 @@ export default class FoldUtil { // ln = false, // allowedBulgeLen = 0, }: { - exp?: number, - sigmoidSlopeFactor?: null, probTo0ThresholdPrior?: number, probTo1ThresholdPrior?: number, theta?: number, - ln?: boolean, addPUnpaired?: boolean, - allowedBulgeLen?: number, minLenHelix?: number, + // exp?: number, + // sigmoidSlopeFactor?: null, + // ln?: boolean, + // allowedBulgeLen?: number, } = {} ): SecStruct { // Ported from arnie https://github.com/DasLab/arnie/blob/04ae74d592c240ceb8a01dfbceff55c6342f7d42/src/arnie/pk_predictors.py#L108 @@ -252,6 +256,7 @@ export default class FoldUtil { denseBpps[row * seq.length + col] = prob; denseBpps[col * seq.length + row] = prob; } + const origDenseBpps = denseBpps.slice(); if (addPUnpaired) { const allPunp = FoldUtil.pUnpaired(dotArray, seq, bppStatisticBehavior); @@ -269,24 +274,121 @@ export default class FoldUtil { } const {colInd} = FoldUtil.linearSumAssignment(denseBpps, seq.length); - const bpList = []; + // Hungarian/linear sum assignment operates on a bipartite graph such that each row is assigned to + // exactly one column and each column is assigned to exactly one row, however our case is not + // bipartite. That means some chosen assignments could conflict with others, either creating + // a "chain" (eg [(0,5), (5,10)]) or cycle (eg [(0,5), (5,10), (10, 0)]). We resolve these + // conflicts by solving for the maximum weight independent set. (Note that if we have + // two assignments like [(0,5) and (5,0)] we only need to deduplicate, hence the usage of set). + const bpAssignmentsSet = new Set(); for (const [col, row] of colInd.entries()) { - for (let ii = 0; ii < dotArray.data.length; ii += 3) { - if ( - dotArray.data[ii] === col + 1 - && dotArray.data[ii + 1] === row + 1 - && dotArray.data[ii + 2] > theta - && col < row - // This is a condition that is NOT used in arnie - but we want to ensure that - // we never return invalid pairs - && EPars.pairType(seq.nt(col), seq.nt(row)) - ) { - bpList.push([col + 1, row + 1]); + if ( + origDenseBpps[col * seq.length + row] > theta + && col !== row + // This is a condition that is NOT used in arnie - but we want to ensure that + // we never return invalid pairs + && EPars.pairType(seq.nt(col), seq.nt(row)) + ) { + bpAssignmentsSet.add([col, row].sort((a, b) => a - b).join(',')); + } + } + const bpAssignments = Array.from(bpAssignmentsSet).map( + (str) => str.split(',').map((s) => +s) as [number, number] + ); + const bpList = []; + while (true) { + const newBp = bpAssignments.pop(); + if (!newBp) break; + const bps = [newBp]; + // Start building a chain to the "left" + let checkNt = bps[0][0]; + while (true) { + // eslint-disable-next-line no-loop-func + const conflict = bpAssignments.find((bp) => bp.includes(checkNt)); + if (!conflict) break; + bps.unshift(conflict); + bpAssignments.filter((bp) => bp !== conflict); + // eslint-disable-next-line no-loop-func + checkNt = conflict.find((nt) => nt !== checkNt) as number; + } + checkNt = bps[bps.length - 1][0]; + while (true) { + // eslint-disable-next-line no-loop-func + const conflict = bpAssignments.find((bp) => bp.includes(checkNt)); + if (!conflict) break; + bps.push(conflict); + bpAssignments.filter((bp) => bp !== conflict); + // eslint-disable-next-line no-loop-func + checkNt = conflict.find((nt) => nt !== checkNt) as number; + } + + if (bps.length === 1) { + bpList.push(...bps); + } else if (bps.length > 2 && (bps[bps.length].includes(bps[0][0]) || bps[bps.length].includes(bps[0][1]))) { + // We have a cycle. We need to try both excluding the first element and excluding + // the last element (only one or the other, or neither, can be present since they conflict) + const {bps: bpListA, prob: probA} = FoldUtil.maxWeightIndependentSet( + bps.slice(1), origDenseBpps, seq.length + ); + const {bps: bpListB, prob: probB} = FoldUtil.maxWeightIndependentSet( + bps.slice(0, bps.length - 1), origDenseBpps, seq.length + ); + if (probA > probB) { + bpList.push(...bpListA); + } else { + bpList.push(...bpListB); + } + } else { + const {bps: maxBpList} = FoldUtil.maxWeightIndependentSet(bps, origDenseBpps, seq.length); + bpList.push(...maxBpList); + } + } + + return FoldUtil.postProcessStruct( + FoldUtil.checkBpList(bpList.map((bp) => bp.map((nt) => nt + 1))), + seq.length, + minLenHelix + ); + } + + public static maxWeightIndependentSet(pairs: number[][], probs: number[], seqlen: number) { + const maxSets: {prob: number, bps: number[][]}[] = []; + for (const bp of pairs) { + const bpProb = probs[bp[0] * seqlen + bp[1]]; + + if (maxSets.length === 0) { + maxSets.push({prob: bpProb, bps: [bp]}); + } else if (maxSets.length === 1) { + if (maxSets[0]['prob'] > bpProb) { + maxSets.push(maxSets[0]); + } else if (bpProb > maxSets[0]['prob']) { + maxSets.push({prob: bpProb, bps: [bp]}); + } else if (Math.abs(maxSets[0]['bps'][0][0] - maxSets[0]['bps'][0][1]) <= Math.abs(bp[0] - bp[1])) { + maxSets.push(maxSets[0]); + } else { + maxSets.push({prob: bpProb, bps: [bp]}); } + } else if (maxSets[-1]['prob'] > maxSets[-2]['prob'] + bpProb) { + maxSets.push(maxSets[maxSets.length - 1]); + } else if (maxSets[maxSets.length - 2]['prob'] + bpProb > maxSets[maxSets.length - 1]['prob']) { + maxSets.push({ + prob: maxSets[maxSets.length - 2]['prob'] + bpProb, + bps: [...maxSets[maxSets.length - 2]['bps'], bp] + }); + } else if ( + Math.abs(maxSets[maxSets.length - 1]['bps'][0][0] - maxSets[maxSets.length - 1]['bps'][0][1]) + <= Math.abs(bp[0] - bp[1]) + ) { + maxSets.push(maxSets[maxSets.length - 1]); + } else { + maxSets.push({ + prob: maxSets[maxSets.length - 2]['prob'] + bpProb, + bps: [...maxSets[maxSets.length - 2]['bps'], bp] + }); } } - return FoldUtil.postProcessStruct(bpList, seq.length, minLenHelix); + return maxSets[maxSets.length - 1]; } /** @@ -297,15 +399,15 @@ export default class FoldUtil { // Ported from SciPy https://github.com/scipy/scipy/blob/b421cd64d98c16811f84efbfab7701b335f811be/scipy/optimize/_lsap.c#L37 // We can take some shortcuts since we know this is a square matrix and never call it with maximize - const u = Array(size).fill(0); - const v = Array(size).fill(0); - const shortestPathCosts = Array(size).fill(0); - const path = Array(size).fill(-1); - const col4row = Array(size).fill(-1); - const row4col = Array(size).fill(-1); - const SR = Array(size).fill(false); - const SC = Array(size).fill(false); - const remaining = Array(size).fill(0); + const u = Array(size).fill(0); + const v = Array(size).fill(0); + const shortestPathCosts = Array(size).fill(0); + const path = Array(size).fill(-1); + const col4row = Array(size).fill(-1); + const row4col = Array(size).fill(-1); + const SR = Array(size).fill(false); + const SC = Array(size).fill(false); + const remaining = Array(size).fill(0); for (let curRow = 0; curRow < size; curRow++) { // ----- Start augmenting_path -----