88import htsjdk .samtools .SAMSequenceRecord ;
99import htsjdk .samtools .util .Interval ;
1010import htsjdk .variant .utils .SAMSequenceDictionaryExtractor ;
11- import org .apache .logging .log4j .Logger ;
12- import org .apache .logging .log4j .LogManager ;
1311import org .apache .commons .lang3 .StringUtils ;
12+ import org .apache .logging .log4j .LogManager ;
13+ import org .apache .logging .log4j .Logger ;
1414import org .jetbrains .annotations .Nullable ;
1515import org .json .JSONObject ;
1616import org .junit .Assert ;
2929import org .labkey .api .sequenceanalysis .SequenceOutputFile ;
3030import org .labkey .api .sequenceanalysis .pipeline .ReferenceGenome ;
3131import org .labkey .api .sequenceanalysis .pipeline .SequenceOutputHandler ;
32+ import org .labkey .api .sequenceanalysis .pipeline .VariantProcessingStep ;
3233import org .labkey .api .writer .PrintWriters ;
3334import org .labkey .sequenceanalysis .util .ScatterGatherUtils ;
3435
4647
4748public class VariantProcessingJob extends SequenceOutputHandlerJob
4849{
49- private ScatterGatherUtils .ScatterGatherMethod _scatterGatherMethod = ScatterGatherUtils .ScatterGatherMethod .none ;
50+ private VariantProcessingStep .ScatterGatherMethod _scatterGatherMethod = VariantProcessingStep .ScatterGatherMethod .none ;
5051 File _dictFile = null ;
5152 Map <String , File > _scatterOutputs = new HashMap <>();
5253 private transient LinkedHashMap <String , List <Interval >> _jobToIntervalMap ;
@@ -68,13 +69,14 @@ protected VariantProcessingJob(VariantProcessingJob parentJob, String intervalSe
6869 _intervalSetName = intervalSetName ;
6970 }
7071
71- public VariantProcessingJob (Container c , User user , @ Nullable String jobName , PipeRoot pipeRoot , SequenceOutputHandler handler , List <SequenceOutputFile > files , JSONObject jsonParams , ScatterGatherUtils .ScatterGatherMethod scatterGatherMethod ) throws IOException , PipelineJobException
72+ public VariantProcessingJob (Container c , User user , @ Nullable String jobName , PipeRoot pipeRoot , SequenceOutputHandler handler , List <SequenceOutputFile > files , JSONObject jsonParams , VariantProcessingStep .ScatterGatherMethod scatterGatherMethod ) throws IOException , PipelineJobException
7273 {
7374 super (c , user , jobName , pipeRoot , handler , files , jsonParams );
7475 _scatterGatherMethod = scatterGatherMethod ;
7576
7677 if (isScatterJob ())
7778 {
79+ validateScatterForTask ();
7880 Set <Integer > genomeIds = new HashSet <>();
7981 for (SequenceOutputFile so : files )
8082 {
@@ -94,30 +96,47 @@ public VariantProcessingJob(Container c, User user, @Nullable String jobName, Pi
9496 }
9597 }
9698
99+ private void validateScatterForTask ()
100+ {
101+ if (!isScatterJob ())
102+ {
103+ return ;
104+ }
105+
106+ if (!(getHandler () instanceof VariantProcessingStep .SupportsScatterGather ))
107+ {
108+ throw new IllegalArgumentException ("Task doe not support Scatter/Gather: " + getHandler ().getName ());
109+ }
110+
111+
112+ VariantProcessingStep .SupportsScatterGather sg = (VariantProcessingStep .SupportsScatterGather )getHandler ();
113+ sg .validateScatter (getScatterGatherMethod (), this );
114+ }
115+
97116 private LinkedHashMap <String , List <Interval >> establishIntervals ()
98117 {
99118 LinkedHashMap <String , List <Interval >> ret ;
100119 SAMSequenceDictionary dict = SAMSequenceDictionaryExtractor .extractDictionary (_dictFile .toPath ());
101- if (_scatterGatherMethod == ScatterGatherUtils .ScatterGatherMethod .contig )
120+ if (_scatterGatherMethod == VariantProcessingStep .ScatterGatherMethod .contig )
102121 {
103122 ret = new LinkedHashMap <>();
104123 for (SAMSequenceRecord rec : dict .getSequences ())
105124 {
106125 ret .put (rec .getSequenceName (), Collections .singletonList (new Interval (rec .getSequenceName (), 1 , rec .getSequenceLength ())));
107126 }
108127 }
109- else if (_scatterGatherMethod == ScatterGatherUtils .ScatterGatherMethod .chunked )
128+ else if (_scatterGatherMethod == VariantProcessingStep .ScatterGatherMethod .chunked )
110129 {
111130 int basesPerJob = getParameterJson ().getInt ("scatterGather.basesPerJob" );
112- boolean allowSplitChromosomes = getParameterJson (). optBoolean ( "scatterGather.allowSplitChromosomes" , true );
131+ boolean allowSplitChromosomes = doAllowSplitContigs ( );
113132 int maxContigsPerJob = getParameterJson ().optInt ("scatterGather.maxContigsPerJob" , -1 );
114133 getLogger ().info ("Creating jobs with target bp size: " + basesPerJob + " mbp. allow splitting configs: " + allowSplitChromosomes + ", max contigs per job: " + maxContigsPerJob );
115134
116135 basesPerJob = basesPerJob * 1000000 ;
117136 ret = ScatterGatherUtils .divideGenome (dict , basesPerJob , allowSplitChromosomes , maxContigsPerJob );
118137
119138 }
120- else if (_scatterGatherMethod == ScatterGatherUtils .ScatterGatherMethod .fixedJobs )
139+ else if (_scatterGatherMethod == VariantProcessingStep .ScatterGatherMethod .fixedJobs )
121140 {
122141 long totalSize = dict .getReferenceLength ();
123142 int numJobs = getParameterJson ().getInt ("scatterGather.totalJobs" );
@@ -133,9 +152,14 @@ else if (_scatterGatherMethod == ScatterGatherUtils.ScatterGatherMethod.fixedJob
133152 return ret ;
134153 }
135154
155+ public boolean doAllowSplitContigs ()
156+ {
157+ return getParameterJson ().optBoolean ("scatterGather.allowSplitChromosomes" , true );
158+ }
159+
136160 public boolean isScatterJob ()
137161 {
138- return _scatterGatherMethod != ScatterGatherUtils .ScatterGatherMethod .none ;
162+ return _scatterGatherMethod != VariantProcessingStep .ScatterGatherMethod .none ;
139163 }
140164
141165 @ JsonIgnore
@@ -296,12 +320,12 @@ public TaskPipeline getTaskPipeline()
296320 return PipelineJobService .get ().getTaskPipeline (new TaskId (VariantProcessingJob .class ));
297321 }
298322
299- public ScatterGatherUtils .ScatterGatherMethod getScatterGatherMethod ()
323+ public VariantProcessingStep .ScatterGatherMethod getScatterGatherMethod ()
300324 {
301325 return _scatterGatherMethod ;
302326 }
303327
304- public void setScatterGatherMethod (ScatterGatherUtils .ScatterGatherMethod scatterGatherMethod )
328+ public void setScatterGatherMethod (VariantProcessingStep .ScatterGatherMethod scatterGatherMethod )
305329 {
306330 _scatterGatherMethod = scatterGatherMethod ;
307331 }
@@ -315,7 +339,7 @@ public void serializeTest() throws Exception
315339 {
316340 VariantProcessingJob job1 = new VariantProcessingJob ();
317341 job1 ._intervalSetName = "chr1" ;
318- job1 ._scatterGatherMethod = ScatterGatherUtils .ScatterGatherMethod .chunked ;
342+ job1 ._scatterGatherMethod = VariantProcessingStep .ScatterGatherMethod .chunked ;
319343
320344 File tmp = new File (System .getProperty ("java.io.tmpdir" ));
321345 File xml = new File (tmp , "variantProcessingJob.txt" );
@@ -342,7 +366,7 @@ public File getDataDirectory()
342366 };
343367
344368 job1 ._intervalSetName = "chr1" ;
345- job1 ._scatterGatherMethod = ScatterGatherUtils .ScatterGatherMethod .chunked ;
369+ job1 ._scatterGatherMethod = VariantProcessingStep .ScatterGatherMethod .chunked ;
346370
347371 Map <String , List <Interval >> intervalMap = new LinkedHashMap <>();
348372 intervalMap .put ("1" , Arrays .asList (new Interval ("chr1" , 1 , 10 )));
0 commit comments