Skip to content

Commit 91dc85b

Browse files
committed
Improve MhcCleanupPipelineJob and include group collapsing code in SequenceBasedTypingAnalysis
1 parent 92bf1b0 commit 91dc85b

File tree

1 file changed

+244
-3
lines changed

1 file changed

+244
-3
lines changed

SequenceAnalysis/src/org/labkey/sequenceanalysis/run/analysis/SequenceBasedTypingAnalysis.java

Lines changed: 244 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,26 @@
22

33
import au.com.bytecode.opencsv.CSVWriter;
44
import htsjdk.samtools.filter.DuplicateReadFilter;
5+
import org.apache.commons.collections.CollectionUtils;
56
import org.apache.logging.log4j.Logger;
67
import org.json.JSONObject;
8+
import org.labkey.api.collections.CaseInsensitiveHashMap;
9+
import org.labkey.api.data.CompareType;
10+
import org.labkey.api.data.Container;
11+
import org.labkey.api.data.DbSchema;
12+
import org.labkey.api.data.DbSchemaType;
713
import org.labkey.api.data.DbScope;
814
import org.labkey.api.data.SQLFragment;
915
import org.labkey.api.data.Selector;
16+
import org.labkey.api.data.SimpleFilter;
1017
import org.labkey.api.data.SqlSelector;
18+
import org.labkey.api.data.Table;
19+
import org.labkey.api.data.TableInfo;
20+
import org.labkey.api.data.TableSelector;
1121
import org.labkey.api.pipeline.PipelineJobException;
22+
import org.labkey.api.query.FieldKey;
23+
import org.labkey.api.query.QueryService;
24+
import org.labkey.api.security.User;
1225
import org.labkey.api.sequenceanalysis.SequenceAnalysisService;
1326
import org.labkey.api.sequenceanalysis.model.AnalysisModel;
1427
import org.labkey.api.sequenceanalysis.model.ReadData;
@@ -24,6 +37,7 @@
2437
import org.labkey.api.sequenceanalysis.pipeline.ToolParameterDescriptor;
2538
import org.labkey.api.util.Compress;
2639
import org.labkey.api.util.FileUtil;
40+
import org.labkey.api.util.PageFlowUtil;
2741
import org.labkey.api.util.Pair;
2842
import org.labkey.api.writer.PrintWriters;
2943
import org.labkey.sequenceanalysis.SequenceAnalysisSchema;
@@ -34,9 +48,15 @@
3448
import java.sql.SQLException;
3549
import java.util.ArrayList;
3650
import java.util.Arrays;
51+
import java.util.Collections;
52+
import java.util.Comparator;
3753
import java.util.HashMap;
3854
import java.util.List;
55+
import java.util.ListIterator;
3956
import java.util.Map;
57+
import java.util.Set;
58+
import java.util.TreeSet;
59+
import java.util.stream.Collectors;
4060

4161
/**
4262
* User: bimber
@@ -45,7 +65,7 @@
4565
*/
4666
public class SequenceBasedTypingAnalysis extends AbstractPipelineStep implements AnalysisStep
4767
{
48-
public SequenceBasedTypingAnalysis(PipelineStepProvider provider, PipelineContext ctx)
68+
public SequenceBasedTypingAnalysis(PipelineStepProvider<?> provider, PipelineContext ctx)
4969
{
5070
super(provider, ctx);
5171
}
@@ -182,8 +202,6 @@ public void exec(ResultSet rs) throws SQLException
182202
@Override
183203
public Output performAnalysisPerSampleLocal(AnalysisModel model, File inputBam, File referenceFasta, File outDir) throws PipelineJobException
184204
{
185-
//TODO: store pct of mapped matching MHC
186-
187205
File expectedTxt = getSBTSummaryFile(outDir, inputBam);
188206
if (expectedTxt.exists())
189207
{
@@ -196,6 +214,9 @@ public Output performAnalysisPerSampleLocal(AnalysisModel model, File inputBam,
196214
{
197215
expectedTxt.delete();
198216
}
217+
218+
// Perform second pass to collapse groups:
219+
new AlignmentGroupCompare(model.getAnalysisId(), getPipelineCtx().getJob().getContainer(), getPipelineCtx().getJob().getUser()).collapseGroups(getPipelineCtx().getLogger(), getPipelineCtx().getJob().getUser());
199220
}
200221
else
201222
{
@@ -353,4 +374,224 @@ protected File getSBTSummaryFile(File outputDir, File bam)
353374
{
354375
return new File(outputDir, FileUtil.getBaseName(bam) + ".sbt_hits.txt");
355376
}
377+
378+
public static class AlignmentGroupCompare
379+
{
380+
private final int analysisId;
381+
private final List<AlignmentGroup> groups = new ArrayList<>();
382+
383+
public AlignmentGroupCompare(final int analysisId, Container c, User u)
384+
{
385+
this.analysisId = analysisId;
386+
387+
new TableSelector(QueryService.get().getUserSchema(u, c, "sequenceanalysis").getTable("alignment_summary_grouped"), PageFlowUtil.set("analysis_id", "alleles", "lineages", "totalLineages", "total_reads", "total_forward", "total_reverse", "valid_pairs", "rowids"), new SimpleFilter(FieldKey.fromString("analysis_id"), analysisId), null).forEachResults(rs -> {
388+
if (rs.getString(FieldKey.fromString("alleles")) == null)
389+
{
390+
return;
391+
}
392+
393+
AlignmentGroup g = new AlignmentGroup();
394+
g.analysisId = analysisId;
395+
g.alleles.addAll(Arrays.stream(rs.getString(FieldKey.fromString("alleles")).split("\n")).toList());
396+
g.lineages = rs.getString(FieldKey.fromString("lineages"));
397+
g.totalLineages = rs.getInt(FieldKey.fromString("totalLineages"));
398+
g.totalReads = rs.getInt(FieldKey.fromString("total_reads"));
399+
g.totalForward = rs.getInt(FieldKey.fromString("total_forward"));
400+
g.totalReverse = rs.getInt(FieldKey.fromString("total_reverse"));
401+
g.validPairs = rs.getInt(FieldKey.fromString("valid_pairs"));
402+
g.rowIds.addAll(Arrays.stream(rs.getString(FieldKey.fromString("rowids")).split(",")).map(Integer::parseInt).toList());
403+
404+
groups.add(g);
405+
});
406+
407+
sortGroups();
408+
}
409+
410+
private void sortGroups()
411+
{
412+
groups.sort(Comparator.comparingInt(o -> o.alleles.size()));
413+
Collections.reverse(groups);
414+
}
415+
416+
public Pair<Integer, Integer> collapseGroups(Logger log, User user)
417+
{
418+
final long initialCounts = groups.stream().map(x -> x.totalReads).mapToInt(Integer::intValue).sum();
419+
420+
if (groups.isEmpty())
421+
{
422+
return null;
423+
}
424+
425+
Pair<Integer, Integer> ret = Pair.of(0, 0);
426+
while (doCollapse(log))
427+
{
428+
//do work. each time we have any groups collapsed, we will restart. once there are no collapsed allele groups, we finish
429+
sortGroups();
430+
}
431+
432+
final int endCounts = groups.stream().map(x -> x.totalReads).mapToInt(Integer::intValue).sum();
433+
if (initialCounts != endCounts)
434+
{
435+
throw new IllegalStateException("Starting/ending counts not equal: " + initialCounts + " / " + endCounts);
436+
}
437+
438+
List<Integer> alignmentIdsToDelete = groups.stream().map(x -> x.rowIdsToDelete).flatMap(List::stream).toList();
439+
List<AlignmentGroup> alignmentGroupsToUpdate = groups.stream().filter(g -> !g.rowIdsToDelete.isEmpty()).toList();
440+
log.info("Alignment IDs to delete: " + alignmentIdsToDelete.size());
441+
log.info("Alignment groups to update counts: " + alignmentGroupsToUpdate.size());
442+
443+
if (!alignmentGroupsToUpdate.isEmpty())
444+
{
445+
log.info("Updating counts in " + alignmentGroupsToUpdate.size() + " groups after collapse");
446+
TableInfo alignmentSummary = DbSchema.get("sequenceanalysis", DbSchemaType.Module).getTable("alignment_summary");
447+
448+
alignmentGroupsToUpdate.forEach(ag -> {
449+
Map<String, Object> toUpdate = new CaseInsensitiveHashMap<>();
450+
toUpdate.put("rowId", ag.rowIds.get(0));
451+
toUpdate.put("total", ag.totalReads);
452+
toUpdate.put("total_forward", ag.totalForward);
453+
toUpdate.put("total_reverse", ag.totalReverse);
454+
toUpdate.put("valid_pairs", ag.validPairs);
455+
Table.update(user, alignmentSummary, toUpdate, ag.rowIds.get(0));
456+
457+
if (ag.rowIds.size() > 1) {
458+
log.info("The following IDs are redundant and will also be removed: " + ag.rowIds.subList(1, ag.rowIds.size()).stream().map(String::valueOf).collect(Collectors.joining(", ")));
459+
alignmentIdsToDelete.addAll(ag.rowIds.subList(1, ag.rowIds.size()));
460+
}
461+
});
462+
}
463+
464+
if (!alignmentIdsToDelete.isEmpty())
465+
{
466+
log.info("Deleting " + alignmentIdsToDelete.size() + " alignment_summary records after collapse");
467+
468+
TableInfo alignmentSummary = DbSchema.get("sequenceanalysis", DbSchemaType.Module).getTable("alignment_summary");
469+
TableInfo alignmentSummaryJunction = DbSchema.get("sequenceanalysis", DbSchemaType.Module).getTable("alignment_summary_junction");
470+
471+
alignmentIdsToDelete.forEach(rowId -> {
472+
Table.delete(alignmentSummary, rowId);
473+
});
474+
ret.first += alignmentIdsToDelete.size();
475+
476+
// also junction records:
477+
SimpleFilter alignmentIdFilter = new SimpleFilter(FieldKey.fromString("analysis_id"), analysisId, CompareType.EQUAL);
478+
alignmentIdFilter.addCondition(FieldKey.fromString("alignment_id"), alignmentIdsToDelete, CompareType.IN);
479+
List<Integer> junctionRecordsToDelete = new TableSelector(alignmentSummaryJunction, PageFlowUtil.set("rowid"), alignmentIdFilter, null).getArrayList(Integer.class);
480+
log.info("Deleting " + junctionRecordsToDelete.size() + " alignment_summary_junction records");
481+
if (!junctionRecordsToDelete.isEmpty())
482+
{
483+
junctionRecordsToDelete.forEach(rowId -> {
484+
Table.delete(alignmentSummaryJunction, rowId);
485+
});
486+
ret.second += junctionRecordsToDelete.size();
487+
}
488+
}
489+
490+
return ret;
491+
}
492+
493+
private boolean doCollapse(Logger log)
494+
{
495+
ListIterator<AlignmentGroup> it = groups.listIterator();
496+
AlignmentGroup g1 = it.next();
497+
while (it.hasNext())
498+
{
499+
if (compareGroupToOthers(g1))
500+
{
501+
log.info("Collapsed: " + g1.lineages + ", with: " + g1.alleles.size());
502+
return true; // abort and restart the process with a new list iterator
503+
}
504+
505+
g1 = it.next();
506+
}
507+
508+
return false;
509+
}
510+
511+
private boolean compareGroupToOthers(AlignmentGroup g1)
512+
{
513+
boolean didCollapse = false;
514+
int idx = groups.indexOf(g1);
515+
if (idx == groups.size() - 1)
516+
{
517+
return false;
518+
}
519+
520+
List<AlignmentGroup> groupsClone = new ArrayList<>(groups.subList(idx + 1, groups.size()));
521+
ListIterator<AlignmentGroup> it = groupsClone.listIterator();
522+
while (it.hasNext())
523+
{
524+
AlignmentGroup g2 = it.next();
525+
if (g2.equals(g1))
526+
{
527+
throw new IllegalStateException("Should not happen");
528+
}
529+
530+
if (g1.canCombine(g2))
531+
{
532+
AlignmentGroup combined = g1.combine(g2);
533+
groups.remove(g1);
534+
groups.add(idx, combined);
535+
g1 = combined;
536+
groups.remove(g2);
537+
538+
didCollapse = true;
539+
}
540+
}
541+
542+
return didCollapse;
543+
}
544+
545+
public static class AlignmentGroup
546+
{
547+
int analysisId;
548+
Set<String> alleles = new TreeSet<>();
549+
String lineages;
550+
int totalLineages;
551+
int totalReads;
552+
int totalForward;
553+
int totalReverse;
554+
int validPairs;
555+
List<Integer> rowIds = new ArrayList<>();
556+
557+
List<Integer> rowIdsToDelete = new ArrayList<>();
558+
559+
public boolean canCombine(AlignmentGroup g2)
560+
{
561+
if (this.totalLineages > 1 || g2.totalLineages > 1 || this.alleles.size() < 4 || g2.alleles.size() < 4)
562+
{
563+
return false;
564+
}
565+
566+
return CollectionUtils.disjunction(this.alleles, g2.alleles).size() == 1;
567+
}
568+
569+
public AlignmentGroup combine(AlignmentGroup g2)
570+
{
571+
// Take the larger allele set:
572+
if (g2.alleles.size() > this.alleles.size())
573+
{
574+
g2.rowIdsToDelete.addAll(this.rowIds);
575+
g2.rowIdsToDelete.addAll(this.rowIdsToDelete);
576+
g2.totalReads = g2.totalReads + totalReads;
577+
g2.totalForward = g2.totalForward + totalForward;
578+
g2.totalReverse = g2.totalReverse + totalReverse;
579+
g2.validPairs = g2.validPairs + validPairs;
580+
581+
return g2;
582+
}
583+
else
584+
{
585+
this.rowIdsToDelete.addAll(g2.rowIds);
586+
this.rowIdsToDelete.addAll(g2.rowIdsToDelete);
587+
this.totalReads = g2.totalReads + totalReads;
588+
this.totalForward = g2.totalForward + totalForward;
589+
this.totalReverse = g2.totalReverse + totalReverse;
590+
this.validPairs = g2.validPairs + validPairs;
591+
592+
return this;
593+
}
594+
}
595+
}
596+
}
356597
}

0 commit comments

Comments
 (0)