src/eterna/folding/FoldUtil.ts: 147 additions & 45 deletions
@@ -91,6 +91,28 @@ export default class FoldUtil {
return probUnpaired;
}

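/**
 * Sort a copy of the given base pair list by its first nucleotide, drop exact
 * duplicate pairs, and warn about any nucleotide that still appears in more
 * than one pair.
 */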
public static checkBpList(bpList: number[][]) {
// Ensure that the bpList is sorted, then check for duplicated nucleotides
const newBpList = bpList.slice();
newBpList.sort((bpA, bpB) => bpA[0] - bpB[0]);
const nts = newBpList.flat();
if (nts.length > new Set(nts).size) {
log.warn('Some nucleotides found in more than 1 base pair');
newBpList.forEach((bpA, i) => {
const bpB = newBpList[i + 1] || [];
if (bpA[0] === bpB[0] && bpA[1] === bpB[1]) {
log.warn(`Removing duplicated base pair: ${bpA}`);
newBpList.splice(i, 1);
} else if (bpB.includes(bpA[0])) {
log.warn(`bpA: ${bpA}, bpB: ${bpB}`);
} else if (bpB.includes(bpA[1])) {
log.warn(`bpA: ${bpA}, bpB: ${bpB}`);
}
});
}
return newBpList;
}

public static postProcessStruct(
bpList: number[][],
sequenceLength: number,
@@ -188,25 +210,7 @@ export default class FoldUtil {
}
}

// Ensure that the bpList is sorted, then check for duplicated nucleotides
bpList.sort((bpA, bpB) => bpA[0] - bpB[0]);
const nts = bpList.flat();
if (nts.length > new Set(nts).size) {
log.warn('Some nucleotides found in more than 1 base pair');
bpList.forEach((bpA, i) => {
const bpB = bpList[i + 1] || [];
if (bpA[0] === bpB[0] && bpA[1] === bpB[1]) {
log.warn(`Removing duplicated base pair: ${bpA}`);
bpList.splice(i, 1);
} else if (bpB.includes(bpA[0])) {
log.warn(`bpA: ${bpA}, bpB: ${bpB}`);
} else if (bpB.includes(bpA[1])) {
log.warn(`bpA: ${bpA}, bpB: ${bpB}`);
}
});
}

return FoldUtil.postProcessStruct(bpList, seq.length, minLenHelix);
return FoldUtil.postProcessStruct(FoldUtil.checkBpList(bpList), seq.length, minLenHelix);
}

public static hungarian(
@@ -225,15 +229,15 @@
// ln = false,
// allowedBulgeLen = 0,
}: {
exp?: number,
sigmoidSlopeFactor?: null,
probTo0ThresholdPrior?: number,
probTo1ThresholdPrior?: number,
theta?: number,
ln?: boolean,
addPUnpaired?: boolean,
allowedBulgeLen?: number,
minLenHelix?: number,
// exp?: number,
// sigmoidSlopeFactor?: null,
// ln?: boolean,
// allowedBulgeLen?: number,
} = {}
): SecStruct {
// Ported from arnie https://github.com/DasLab/arnie/blob/04ae74d592c240ceb8a01dfbceff55c6342f7d42/src/arnie/pk_predictors.py#L108
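// Overview: build a dense base pair probability matrix from dotArray, run a linear
// sum assignment over it, keep assignments above theta that form valid pairs, resolve
// conflicting assignments via a maximum weight independent set, and post-process the
// surviving pairs into a SecStruct.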
@@ -252,6 +256,7 @@
denseBpps[row * seq.length + col] = prob;
denseBpps[col * seq.length + row] = prob;
}
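// Keep an untouched copy of the probabilities so that the theta check and the
// conflict resolution below score pairs against the raw values, since denseBpps
// may be adjusted before the assignment (eg by the addPUnpaired handling below).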
const origDenseBpps = denseBpps.slice();

if (addPUnpaired) {
const allPunp = FoldUtil.pUnpaired(dotArray, seq, bppStatisticBehavior);
@@ -269,24 +274,121 @@
}

const {colInd} = FoldUtil.linearSumAssignment(denseBpps, seq.length);
const bpList = [];
// Hungarian/linear sum assignment operates on a bipartite graph such that each row is assigned to
// exactly one column and each column is assigned to exactly one row, however our case is not
// bipartite. That means some chosen assignments could conflict with others, either creating
// a "chain" (eg [(0,5), (5,10)]) or cycle (eg [(0,5), (5,10), (10, 0)]). We resolve these
// conflicts by solving for the maximum weight independent set. (Note that if we have
// two assignments like [(0,5) and (5,0)] we only need to deduplicate, hence the usage of set).
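// Pairs are stored as canonical 'min,max' strings so that (a,b) and (b,a) collapse
// to a single entry in the set.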
const bpAssignmentsSet = new Set<string>();
for (const [col, row] of colInd.entries()) {
for (let ii = 0; ii < dotArray.data.length; ii += 3) {
if (
dotArray.data[ii] === col + 1
&& dotArray.data[ii + 1] === row + 1
&& dotArray.data[ii + 2] > theta
&& col < row
// This is a condition that is NOT used in arnie - but we want to ensure that
// we never return invalid pairs
&& EPars.pairType(seq.nt(col), seq.nt(row))
) {
bpList.push([col + 1, row + 1]);
if (
origDenseBpps[col * seq.length + row] > theta
&& col !== row
// This is a condition that is NOT used in arnie - but we want to ensure that
// we never return invalid pairs
&& EPars.pairType(seq.nt(col), seq.nt(row))
) {
bpAssignmentsSet.add([col, row].sort((a, b) => a - b).join(','));
}
}
const bpAssignments = Array.from(bpAssignmentsSet).map(
(str) => str.split(',').map((s) => +s) as [number, number]
);
const bpList = [];
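// Group the assignments into chains of mutually conflicting pairs (pairs sharing a
// nucleotide), then keep the best non-conflicting subset of each chain.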
while (true) {
const newBp = bpAssignments.pop();
if (!newBp) break;
const bps = [newBp];
// Start building a chain to the "left"
let checkNt = bps[0][0];
while (true) {
// eslint-disable-next-line no-loop-func
const conflict = bpAssignments.find((bp) => bp.includes(checkNt));
if (!conflict) break;
bps.unshift(conflict);
bpAssignments.splice(bpAssignments.indexOf(conflict), 1);
// eslint-disable-next-line no-loop-func
checkNt = conflict.find((nt) => nt !== checkNt) as number;
}
// Then extend the chain to the "right" from the other end of the seed pair
checkNt = bps[bps.length - 1][1];
while (true) {
// eslint-disable-next-line no-loop-func
const conflict = bpAssignments.find((bp) => bp.includes(checkNt));
if (!conflict) break;
bps.push(conflict);
bpAssignments.splice(bpAssignments.indexOf(conflict), 1);
// eslint-disable-next-line no-loop-func
checkNt = conflict.find((nt) => nt !== checkNt) as number;
}

if (bps.length === 1) {
bpList.push(...bps);
} else if (
    bps.length > 2
    && (bps[bps.length - 1].includes(bps[0][0]) || bps[bps.length - 1].includes(bps[0][1]))
) {
// We have a cycle. We need to try both excluding the first element and excluding
// the last element (only one or the other, or neither, can be present since they conflict)
const {bps: bpListA, prob: probA} = FoldUtil.maxWeightIndependentSet(
bps.slice(1), origDenseBpps, seq.length
);
const {bps: bpListB, prob: probB} = FoldUtil.maxWeightIndependentSet(
bps.slice(0, bps.length - 1), origDenseBpps, seq.length
);
if (probA > probB) {
bpList.push(...bpListA);
} else {
bpList.push(...bpListB);
}
} else {
const {bps: maxBpList} = FoldUtil.maxWeightIndependentSet(bps, origDenseBpps, seq.length);
bpList.push(...maxBpList);
}
}

return FoldUtil.postProcessStruct(
FoldUtil.checkBpList(bpList.map((bp) => bp.map((nt) => nt + 1))),
seq.length,
minLenHelix
);
}

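/**
 * Given a chain of base pairs in which consecutive entries conflict (share a
 * nucleotide), pick the non-conflicting subset with the highest total probability
 * via a linear DP: maxSets[i] is the best set over pairs[0..i], formed either from
 * maxSets[i - 1] (skip pairs[i]) or from maxSets[i - 2] plus pairs[i]. Ties are
 * broken in favor of shorter-range pairs.
 */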
public static maxWeightIndependentSet(pairs: number[][], probs: number[], seqlen: number) {
const maxSets: {prob: number, bps: number[][]}[] = [];
for (const bp of pairs) {
const bpProb = probs[bp[0] * seqlen + bp[1]];

if (maxSets.length === 0) {
maxSets.push({prob: bpProb, bps: [bp]});
} else if (maxSets.length === 1) {
if (maxSets[0]['prob'] > bpProb) {
maxSets.push(maxSets[0]);
} else if (bpProb > maxSets[0]['prob']) {
maxSets.push({prob: bpProb, bps: [bp]});
} else if (Math.abs(maxSets[0]['bps'][0][0] - maxSets[0]['bps'][0][1]) <= Math.abs(bp[0] - bp[1])) {
maxSets.push(maxSets[0]);
} else {
maxSets.push({prob: bpProb, bps: [bp]});
}
} else if (
    maxSets[maxSets.length - 1]['prob'] > maxSets[maxSets.length - 2]['prob'] + bpProb
) {
maxSets.push(maxSets[maxSets.length - 1]);
} else if (maxSets[maxSets.length - 2]['prob'] + bpProb > maxSets[maxSets.length - 1]['prob']) {
maxSets.push({
prob: maxSets[maxSets.length - 2]['prob'] + bpProb,
bps: [...maxSets[maxSets.length - 2]['bps'], bp]
});
} else if (
Math.abs(maxSets[maxSets.length - 1]['bps'][0][0] - maxSets[maxSets.length - 1]['bps'][0][1])
<= Math.abs(bp[0] - bp[1])
) {
maxSets.push(maxSets[maxSets.length - 1]);
} else {
maxSets.push({
prob: maxSets[maxSets.length - 2]['prob'] + bpProb,
bps: [...maxSets[maxSets.length - 2]['bps'], bp]
});
}
}

return FoldUtil.postProcessStruct(bpList, seq.length, minLenHelix);
return maxSets[maxSets.length - 1];
}

/**
@@ -297,15 +399,15 @@
// Ported from SciPy https://github.com/scipy/scipy/blob/b421cd64d98c16811f84efbfab7701b335f811be/scipy/optimize/_lsap.c#L37
// We can take some shortcuts since we know this is a square matrix and never call it with maximize
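// Working arrays for the shortest augmenting path algorithm: u/v hold the dual
// potentials for rows/columns, col4row/row4col hold the current assignment,
// shortestPathCosts and path track tentative distances and predecessors,
// SR/SC mark rows/columns already visited, and remaining lists the columns
// not yet scanned.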

const u = Array(size).fill(0);
const v = Array(size).fill(0);
const shortestPathCosts = Array(size).fill(0);
const path = Array(size).fill(-1);
const col4row = Array(size).fill(-1);
const row4col = Array(size).fill(-1);
const SR = Array(size).fill(false);
const SC = Array(size).fill(false);
const remaining = Array(size).fill(0);
const u = Array<number>(size).fill(0);
const v = Array<number>(size).fill(0);
const shortestPathCosts = Array<number>(size).fill(0);
const path = Array<number>(size).fill(-1);
const col4row = Array<number>(size).fill(-1);
const row4col = Array<number>(size).fill(-1);
const SR = Array<boolean>(size).fill(false);
const SC = Array<boolean>(size).fill(false);
const remaining = Array<number>(size).fill(0);

for (let curRow = 0; curRow < size; curRow++) {
// ----- Start augmenting_path -----