diff --git a/settings/.env.dev b/settings/.env.dev index 359fc2a..3fd7bc1 100644 --- a/settings/.env.dev +++ b/settings/.env.dev @@ -31,3 +31,9 @@ MAVEDB_API_KEY= #################################################################################################### SEQREPO_ROOT_DIR=/usr/local/share/seqrepo/2024-12-20 + +#################################################################################################### +# Environment variables for ensembl +#################################################################################################### + +ENSEMBL_API_URL=https://rest.ensembl.org diff --git a/src/api/routers/map.py b/src/api/routers/map.py index d53cc24..01eea6a 100644 --- a/src/api/routers/map.py +++ b/src/api/routers/map.py @@ -12,6 +12,7 @@ _get_mapped_reference_sequence, _set_scoreset_layer, annotate, + compute_target_gene_info, ) from dcd_mapping.exceptions import ( AlignmentError, @@ -34,6 +35,7 @@ from dcd_mapping.schemas import ( ScoreAnnotation, ScoresetMapping, + TargetAnnotation, TargetType, TxSelectResult, VrsVersion, @@ -196,29 +198,41 @@ async def map_scoreset(urn: str, store_path: Path | None = None) -> JSONResponse try: raw_metadata = get_raw_scoreset_metadata(urn, store_path) - reference_sequences: dict[str, dict] = {} + reference_sequences: dict[str, TargetAnnotation] = {} mapped_scores: list[ScoreAnnotation] = [] for target_gene in annotated_vrs_results: preferred_layers = { _set_scoreset_layer(urn, annotated_vrs_results[target_gene]), } target_gene_name = metadata.target_genes[target_gene].target_gene_name - reference_sequences[target_gene_name] = { + reference_sequences[target_gene_name] = TargetAnnotation() + reference_sequences[target_gene_name].layers = { layer: { "computed_reference_sequence": None, "mapped_reference_sequence": None, } for layer in preferred_layers } + # sometimes Nonetype layers show up in preferred layers dict; remove these preferred_layers.discard(None) + + # Determine one gene symbol per target and its 
selection method + gene_info = await compute_target_gene_info( + target_key=target_gene, + transcripts=transcripts, + alignment_results=alignment_results, + metadata=metadata, + mapped_scores=annotated_vrs_results[target_gene], + ) + for layer in preferred_layers: - reference_sequences[target_gene_name][layer][ + reference_sequences[target_gene_name].layers[layer][ "computed_reference_sequence" ] = _get_computed_reference_sequence( metadata.target_genes[target_gene], layer, transcripts[target_gene] ) - reference_sequences[target_gene_name][layer][ + reference_sequences[target_gene_name].layers[layer][ "mapped_reference_sequence" ] = _get_mapped_reference_sequence( metadata.target_genes[target_gene], @@ -227,6 +241,9 @@ async def map_scoreset(urn: str, store_path: Path | None = None) -> JSONResponse alignment_results[target_gene], ) + if gene_info is not None: + reference_sequences[target_gene_name].gene_info = gene_info + for m in annotated_vrs_results[target_gene]: if m.pre_mapped is None: mapped_scores.append(ScoreAnnotation(**m.model_dump())) @@ -236,7 +253,7 @@ async def map_scoreset(urn: str, store_path: Path | None = None) -> JSONResponse # if genomic layer, not accession-based, and target gene type is coding, add cdna entry (just the sequence accession) to reference_sequences dict if ( - AnnotationLayer.GENOMIC in reference_sequences[target_gene_name] + AnnotationLayer.GENOMIC in reference_sequences[target_gene_name].layers and metadata.target_genes[target_gene].target_gene_category == TargetType.PROTEIN_CODING and metadata.target_genes[target_gene].target_accession_id is None @@ -244,7 +261,7 @@ async def map_scoreset(urn: str, store_path: Path | None = None) -> JSONResponse and isinstance(transcripts[target_gene], TxSelectResult) and transcripts[target_gene].nm is not None ): - reference_sequences[target_gene_name][AnnotationLayer.CDNA] = { + reference_sequences[target_gene_name].layers[AnnotationLayer.CDNA] = { "computed_reference_sequence": None, 
async def compute_target_gene_info(
    target_key: str,
    transcripts: dict[str, TxSelectResult | TxSelectError | None],
    alignment_results: dict[str, AlignmentResult | None],
    metadata: ScoresetMetadata,
    mapped_scores: list[MappedScore] | None = None,
) -> GeneInfo | None:
    """Determine a single gene symbol per target with provenance.

    Strategies are tried in priority order; the first one that yields a symbol wins:

    1. HGNC symbol from the selected transcript, when the transcript carries one
    2. Overlap-based inference from alignment hit subranges
    3. Overlap-based inference from mapped variant spans
    4. Fallback to the normalized target metadata symbol

    :param target_key: key of the target in ``metadata.target_genes``
    :param transcripts: transcript selection results, keyed by target
    :param alignment_results: alignment results, keyed by target
    :param metadata: scoreset metadata holding the target gene definitions
    :param mapped_scores: mapped scores for this target (used by strategy 3)
    :return: gene info holding the symbol and the name of the strategy that
        produced it. Targets that are not protein coding get a ``GeneInfo`` with no
        symbol and ``selection_method="target_category"``. ``None`` is returned when
        no strategy produces a symbol, or when a strategy raises (the error is
        logged and later strategies are not attempted).
    """
    # If a target is not coding, we can't select its gene.
    # TODO#66: Handle regulatory/non-coding targets more intelligently. Our current method of querying
    # Ensembl for overlap isn't robust for these cases.
    if (
        metadata.target_genes[target_key].target_gene_category
        != TargetType.PROTEIN_CODING
    ):
        _logger.info(
            "Target %s is not protein coding. Skipped computing target gene.",
            target_key,
        )
        return GeneInfo(hgnc_symbol=None, selection_method="target_category")

    # Prefer returning the HGNC symbol directly from a selected transcript.
    # A transcript that carries no symbol must not short-circuit the remaining
    # strategies, so we only return here when a symbol is actually present.
    try:
        tx = transcripts.get(target_key)
        if isinstance(tx, TxSelectResult) and tx.hgnc_symbol:
            _logger.info(
                "Using selected transcript for gene info for target %s. Computed HGNC: %s",
                target_key,
                tx.hgnc_symbol,
            )
            return GeneInfo(hgnc_symbol=tx.hgnc_symbol, selection_method="tx_selection")

    except Exception:
        _logger.exception("Error computing target gene info for target %s", target_key)
        return None

    # If we cannot compute gene info via a selected transcript, try to infer a gene symbol from alignment results.
    try:
        gene_info = _compute_target_gene_info_from_alignment(
            target_key, alignment_results.get(target_key)
        )

        if gene_info is not None:
            return gene_info

    except Exception:
        _logger.exception("Error computing target gene info for target %s", target_key)
        return None

    # If we cannot infer a gene from transcript selection or alignment results, try to compute gene info
    # from mapped variant spans.
    try:
        gene_info = _compute_target_gene_info_from_mapped_variant_spans(
            target_key, mapped_scores
        )

        if gene_info is not None:
            return gene_info

    except Exception:
        _logger.exception(
            "Error computing gene info from mapped variant spans for target %s",
            target_key,
        )
        return None

    # Fallback to target metadata normalization in cases where inference fails.
    try:
        symbol = get_gene_symbol(metadata.target_genes[target_key])
        if symbol:
            _logger.warning(
                "Using target metadata for gene info for target %s. Computed HGNC: %s",
                target_key,
                symbol,
            )
            return GeneInfo(hgnc_symbol=symbol, selection_method="target_metadata")

    except Exception:
        _logger.exception(
            "Error computing gene info from target metadata for target %s",
            target_key,
        )
        return None

    _logger.warning(
        "No gene info could be determined for target %s",
        target_key,
    )
    return None


def _merge_intervals(intervals: list[tuple[int, int]]) -> list[tuple[int, int]]:
    """Return the intervals sorted, with overlapping or touching intervals merged.

    Merging reduces the number of overlap queries that have to be performed and
    prevents the same bases from being counted more than once. The input list is
    not modified.
    """
    merged: list[tuple[int, int]] = []
    for start, end in sorted(intervals):
        if merged and start <= merged[-1][1]:
            last_start, last_end = merged[-1]
            merged[-1] = (last_start, max(last_end, end))
        else:
            merged.append((start, end))

    return merged


def _select_max_coverage_gene(
    target_key: str,
    covered_bases: dict[str, int],
    evidence: str,
    selection_method: str,
) -> GeneInfo | None:
    """Pick the gene with the most covered bases, refusing to guess on ties.

    :param target_key: target the coverage belongs to (used for logging only)
    :param covered_bases: covered base count per gene label
    :param evidence: human readable description of where the coverage came from,
        interpolated into the log message
    :param selection_method: value recorded as ``GeneInfo.selection_method``
    :return: gene info for the unique best-covered gene, or ``None`` when there is
        no coverage at all or when several genes share the maximum coverage
    """
    if not covered_bases:
        return None

    # Find the gene with the maximum coverage.
    max_cov = max(covered_bases.values())
    candidates = sorted(g for g, cov in covered_bases.items() if cov == max_cov)

    # It doesn't seem guaranteed that the first label in the candidates list will be stable, nor
    # that it will be the correct choice if there are multiple candidates.
    if len(candidates) > 1:
        _logger.warning(
            "Multiple genes with maximum coverage for target %s: %s. No gene info will be computed for this target.",
            target_key,
            candidates,
        )
        return None

    _logger.info(
        "Using %s for gene info for target %s. Computed HGNC: %s",
        evidence,
        target_key,
        candidates[0],
    )
    return GeneInfo(
        hgnc_symbol=_get_hgnc_symbol(candidates[0]),
        selection_method=selection_method,
    )


def _compute_target_gene_info_from_mapped_variant_spans(
    target_key: str, mapped_scores: list[MappedScore] | None
) -> GeneInfo | None:
    """Infer the target gene from the genomic spans of its post-mapped variants.

    :param target_key: target the scores belong to (used for logging only)
    :param mapped_scores: mapped scores of the target, or ``None`` when unavailable
    :return: gene info with ``selection_method="variants_max_covered_bases"``, or
        ``None`` when there are no scores, no usable genomic spans, no overlapping
        gene, or no unique best-covered gene
    """
    if mapped_scores is None:
        return None

    # Although multiple chromosomes being present in the same target seems exceedingly unlikely, we handle it just in case
    # given it isn't much additional complexity or extra work.
    by_chrom: dict[str, list[tuple[int, int]]] = {}
    for chrom, start, end in _iter_genomic_spans_from_mapped_scores(mapped_scores):
        by_chrom.setdefault(chrom, []).append((start, end))

    covered_bases: dict[str, int] = {}
    for chrom, intervals in by_chrom.items():
        # Since HGNC terms are unambiguous and will not exist on multiple chromosomes,
        # we can safely merge these dictionaries.
        covered_bases.update(
            _covered_bases_from_overlapping_genes_of_chromosomal_intervals(
                chrom, _merge_intervals(intervals)
            )
        )

    return _select_max_coverage_gene(
        target_key,
        covered_bases,
        "mapped variant spans",
        "variants_max_covered_bases",
    )


def _compute_target_gene_info_from_alignment(
    target_key: str, alignment_result: AlignmentResult | None
) -> GeneInfo | None:
    """Infer the target gene from the genomic hit subranges of its alignment.

    :param target_key: target the alignment belongs to (used for logging only)
    :param alignment_result: alignment of the target, or ``None`` when unavailable
    :return: gene info with ``selection_method="alignment_max_covered_bases"``, or
        ``None`` when there is no alignment, no overlapping gene, or no unique
        best-covered gene
    """
    if alignment_result is None:
        return None

    chrom = get_chromosome_identifier(alignment_result.chrom)

    # Find all 'gene' features from Ensembl which at least partially overlap each hit range.
    aligned_intervals = [(sub.start, sub.end) for sub in alignment_result.hit_subranges]
    covered_bases = _covered_bases_from_overlapping_genes_of_chromosomal_intervals(
        chrom, aligned_intervals
    )

    return _select_max_coverage_gene(
        target_key,
        covered_bases,
        "alignment results",
        "alignment_max_covered_bases",
    )


def _covered_bases_from_overlapping_genes_of_chromosomal_intervals(
    chromosome: str,
    intervals: list[tuple[int, int]],
) -> dict[str, int]:
    """Compute the number of bases from gene features that overlap given chromosomal intervals.

    This function iterates over a list of genomic intervals, queries overlapping
    gene features, and sums the number of overlapping bases per gene label
    (the feature's ``external_name``) across all intervals. Intervals that cause
    errors during feature lookup are skipped, and features missing a label, start
    or end coordinate are ignored.

    Parameters
    ----------
    chromosome : str
        Chromosome identifier (e.g., "1", "chr1", "X").
    intervals : list[tuple[int, int]]
        A list of (start, end) tuples describing 0-based chromosomal intervals,
        where start is inclusive and end is exclusive.

    Returns
    -------
    dict[str, int]
        A mapping from gene label to the total number of bases that overlap
        the provided intervals.

    Notes
    -----
    - Overlap for each feature is computed as: max(0, min(interval_end, feature_end) - max(interval_start, feature_start)).
    - If feature lookup raises an exception, that interval contributes nothing.
    - Features without a label or valid start/end are skipped.
    - Intervals are not de-duplicated here; callers that may pass overlapping
      intervals should merge them first to avoid counting bases twice.

    """
    covered_bases: dict[str, int] = {}
    for s, e in intervals:
        try:
            matches = get_overlapping_features_for_region(
                chromosome, s, e, features=["gene"]
            )
        except Exception:
            _logger.exception(
                "Error fetching overlapping gene features for region %s:%d-%d",
                chromosome,
                s,
                e,
            )
            matches = []

        # The overlapping bases of each feature within our aligned region, 8 in the example below:
        # feature:  ----------------
        # interval:         --------------
        #                   ^^^^^^^^
        # NOTE(review): intervals are documented above as 0-based half-open, while the
        # feature coordinates are used exactly as returned by the lookup. If the lookup
        # reports 1-based inclusive coordinates this is off by one at the feature start
        # -- confirm against the Ensembl REST documentation.
        for feature in matches:
            hgnc = feature.get("external_name")
            feature_start = feature.get("start")
            feature_end = feature.get("end")

            if not hgnc or feature_start is None or feature_end is None:
                _logger.warning(
                    "Skipping feature with missing HGNC symbol or invalid coordinates: %s",
                    feature,
                )
                continue

            overlapping_bases = max(0, min(e, feature_end) - max(s, feature_start))
            covered_bases[hgnc] = covered_bases.get(hgnc, 0) + overlapping_bases
            _logger.debug(
                "Feature %s overlaps region %s:%d-%d by %d bases",
                hgnc,
                chromosome,
                s,
                e,
                overlapping_bases,
            )

    return covered_bases


def _iter_genomic_spans_from_mapped_scores(
    mapped_scores: list[MappedScore],
) -> list[tuple[str, int, int]]:
    """Extract (chrom, start, end) spans from post-mapped VRS structures.

    Only considers genomic annotations. Alleles are used directly and haplotypes
    contribute one span per member allele. Locations whose sequence cannot be
    resolved to a chromosome, or whose coordinates are not plain integers, are
    skipped. Returns an empty list if none.
    """
    spans: list[tuple[str, int, int]] = []
    for ms in mapped_scores:
        if ms.annotation_layer != AnnotationLayer.GENOMIC or ms.post_mapped is None:
            continue

        if isinstance(ms.post_mapped, Allele):
            alleles = [ms.post_mapped]
        elif isinstance(ms.post_mapped, Haplotype):
            alleles = ms.post_mapped.members
        else:
            continue

        for allele in alleles:
            loc = allele.location

            # Only integer coordinates can be merged and compared downstream; skip
            # anything else (e.g. missing or ranged coordinates) instead of emitting
            # a span that would break interval handling.
            if not isinstance(loc.start, int) or not isinstance(loc.end, int):
                continue

            refget_chrom = get_chromosome_identifier_from_vrs_id(
                f"ga4gh:{loc.sequenceReference.refgetAccession}"
            )
            if refget_chrom:
                spans.append(
                    (get_ucsc_chromosome_name(refget_chrom), loc.start, loc.end)
                )

    return spans
def get_overlapping_features_for_region(
    chromosome: str, start: int, end: int, features: list[str] | None = None
) -> list[dict[str, Any]]:
    """Get features overlapping a specific genomic region from the Ensembl REST API.

    This is a best-effort lookup: request failures and undecodable responses are
    logged and reported as "no overlapping features" rather than raised.

    :param chromosome: Chromosome identifier
    :param start: Start position of the region
    :param end: End position of the region
    :param features: List of feature types to retrieve (default is ["gene"])
    :return: List of overlapping features, each a dictionary exactly as returned by
        the API; an empty list if the request fails or the response is not valid JSON
    """
    if not features:
        features = ["gene"]
        _logger.debug("No features specified, defaulting to %s", features)

    chrom = get_chromosome_identifier(chromosome)

    # ``features`` is guaranteed non-empty here, so the query string is always present.
    feature_query = "".join(f"feature={feature};" for feature in features)
    url = f"{ENSEMBL_API_URL}/overlap/region/human/{chrom}:{start}-{end}?{feature_query}"

    try:
        _logger.debug(
            "Fetching overlapping features for region %s:%d-%d with features %s",
            chromosome,
            start,
            end,
            features,
        )

        response = request_with_backoff(
            url, headers={"Content-Type": "application/json"}
        )
        response.raise_for_status()
    except requests.RequestException as e:
        _logger.error(
            "Failed to fetch overlapping features for region %s-%s on chromosome %s: %s",
            start,
            end,
            chromosome,
            e,
        )
        return []

    try:
        overlapping_features = response.json()
    except ValueError as e:
        # JSON decoding errors are ValueError subclasses; treat an undecodable body
        # like any other failed lookup instead of letting it escape to the caller.
        _logger.error(
            "Failed to decode overlapping features for region %s-%s on chromosome %s: %s",
            start,
            end,
            chromosome,
            e,
        )
        return []

    _logger.debug(
        "Successfully fetched %d overlapping features for region %s:%d-%d with features %s",
        len(overlapping_features),
        chromosome,
        start,
        end,
        features,
    )
    return overlapping_features


# Base URL of the Ensembl REST API; overridable through the environment.
ENSEMBL_API_URL = os.environ.get("ENSEMBL_API_URL", "https://rest.ensembl.org")


def request_with_backoff(
    url: str, max_retries: int = 5, backoff_factor: float = 0.3, **kwargs
) -> requests.Response:
    """HTTP GET with exponential backoff only for retryable errors.

    Retries on:
    - Connection timeout or connection errors
    - HTTP 5xx server errors
    - HTTP 429 rate limiting (respecting Retry-After when present)

    Immediately raises on other HTTP errors (e.g., 4xx client errors). Responses that
    are neither successful nor an HTTP error (e.g., 304) are returned to the caller
    unchanged.

    :param url: URL to fetch
    :param max_retries: maximum number of attempts
    :param backoff_factor: base delay in seconds; attempt ``n`` (0-based) waits
        ``backoff_factor * 2**n`` before the next try
    :param kwargs: passed through to ``requests.get``. ``timeout`` defaults to 60
        seconds when not supplied.
    :return: the response of the first attempt that does not need to be retried
    :raise requests.HTTPError: on a non-retryable HTTP error, or when the last
        attempt still yields a retryable HTTP error
    :raise requests.Timeout: when the last attempt times out
    :raise requests.ConnectionError: when the last attempt cannot connect
    :raise requests.exceptions.RetryError: when no attempt was made at all
        (``max_retries`` < 1)
    """
    # Let callers override the timeout without colliding with our default.
    kwargs.setdefault("timeout", 60)

    for attempt in range(max_retries):
        is_last_attempt = attempt == max_retries - 1
        sleep_time = backoff_factor * (2**attempt)

        try:
            response = requests.get(url, **kwargs)
        except (requests.Timeout, requests.ConnectionError):
            # Retry on transient network failures
            if is_last_attempt:
                raise
            time.sleep(sleep_time)
            continue

        status = response.status_code
        if 200 <= status < 300:
            return response

        is_rate_limited = status == 429
        is_server_error = 500 <= status < 600
        if is_last_attempt or not (is_rate_limited or is_server_error):
            # Raises for any 4xx/5xx. For everything else (e.g. 3xx) there is nothing
            # to retry, so hand the response back rather than requesting it again.
            response.raise_for_status()
            return response

        if is_rate_limited:
            # 429: Too Many Requests -- optionally use Retry-After
            retry_after = response.headers.get("Retry-After")
            if retry_after is not None:
                try:
                    requested_delay = float(retry_after)
                except ValueError:
                    # e.g. an HTTP-date; fall back to exponential backoff
                    requested_delay = None
                # Only trust delays that time.sleep accepts (rejects negatives, NaN and
                # infinity); anything else falls back to exponential backoff.
                if requested_delay is not None and 0 <= requested_delay < float("inf"):
                    sleep_time = requested_delay

        time.sleep(sleep_time)

    # Only reachable when no attempt was made. Raise a requests exception so callers
    # handling ``requests.RequestException`` see this failure too.
    msg = f"Failed to fetch {url} after {max_retries} attempts"
    raise requests.exceptions.RetryError(msg)
class GeneInfo(BaseModel):
    """Basic gene metadata for a target, including symbol and selection method."""

    # Gene symbol chosen for the target; None when no symbol was determined.
    hgnc_symbol: str | None = None
    # Name of the strategy that produced the symbol (e.g. "tx_selection").
    selection_method: str | None = None


class TargetAnnotation(BaseModel):
    """Represents annotations associated with a target: optional gene metadata
    and per-layer reference sequence annotations.

    Attributes
    ----------
    gene_info : GeneInfo | None
        Optional gene symbol for the target together with the method that was
        used to select it.

    layers : dict[AnnotationLayer, dict[str, ComputedReferenceSequence | MappedReferenceSequence | dict | None]]
        A mapping of annotation layers to the reference sequence data of that
        layer. Each inner dictionary is keyed by a string naming the entry and its
        values are one of:
            - ComputedReferenceSequence: a computed sequence representation,
            - MappedReferenceSequence: a sequence mapped to a reference,
            - dict: a generic dictionary for layer-specific payloads,
            - None: indicating missing or intentionally omitted data.

    Notes
    -----
    - The default value for 'layers' is an empty dictionary, created per instance.

    """

    gene_info: GeneInfo | None = None
    # default_factory makes the fresh-dict-per-instance behavior explicit instead
    # of relying on a shared mutable literal being copied.
    layers: dict[
        AnnotationLayer,
        dict[str, ComputedReferenceSequence | MappedReferenceSequence | dict | None],
    ] = Field(default_factory=dict)
@@ -205,14 +249,6 @@ class ScoresetMapping(BaseModel): mapped_date_utc: str = Field( default=datetime.datetime.now(tz=datetime.UTC).isoformat() ) - reference_sequences: dict[ - str, - dict[ - AnnotationLayer, - dict[ - str, ComputedReferenceSequence | MappedReferenceSequence | dict | None - ], - ], - ] | None = None + reference_sequences: dict[str, TargetAnnotation] | None = None mapped_scores: list[ScoreAnnotation] | None = None error_message: str | None = None diff --git a/src/dcd_mapping/transcripts.py b/src/dcd_mapping/transcripts.py index 00080d3..bd4e97a 100644 --- a/src/dcd_mapping/transcripts.py +++ b/src/dcd_mapping/transcripts.py @@ -173,6 +173,7 @@ async def _select_protein_reference( raise TxSelectError(msg) nm_accession = None tx_mode = None + hgnc_symbol = None else: mane_transcripts = get_mane_transcripts(common_transcripts) best_tx = _choose_best_mane_transcript(mane_transcripts) @@ -185,6 +186,7 @@ async def _select_protein_reference( nm_accession = best_tx.refseq_nuc np_accession = best_tx.refseq_prot tx_mode = best_tx.transcript_priority + hgnc_symbol = best_tx.symbol protein_sequence = _get_protein_sequence(target_gene.target_sequence) is_full_match = ref_sequence.find(protein_sequence) != -1 @@ -197,6 +199,7 @@ async def _select_protein_reference( is_full_match=is_full_match, sequence=protein_sequence, transcript_mode=tx_mode, + hgnc_symbol=hgnc_symbol, ) diff --git a/tests/test_annotate.py b/tests/test_annotate.py new file mode 100644 index 0000000..6a7e18e --- /dev/null +++ b/tests/test_annotate.py @@ -0,0 +1,283 @@ +"""Tests for dcd_mapping.annotate""" +from unittest import mock + +import pytest +from ga4gh.vrs._internal.models import ( + Allele, + LiteralSequenceExpression, + SequenceLocation, + SequenceReference, +) + +from dcd_mapping.annotate import ( + _compute_target_gene_info_from_alignment, + _compute_target_gene_info_from_mapped_variant_spans, + _covered_bases_from_overlapping_genes_of_chromosomal_intervals, + 
compute_target_gene_info, +) +from dcd_mapping.schemas import ( + AlignmentResult, + AnnotationLayer, + GeneInfo, + MappedScore, + ScoresetMetadata, + SequenceRange, + TargetGene, + TargetSequenceType, + TargetType, + TxSelectResult, +) + + +@pytest.fixture() +def target_dna_pc(): + return TargetGene( + target_gene_name="BRAF", + target_gene_category=TargetType.PROTEIN_CODING, + target_sequence="ATGGCG...", + target_sequence_type=TargetSequenceType.DNA, + target_accession_id=None, + target_uniprot_ref=None, + ) + + +@pytest.fixture() +def scoreset_metadata(target_dna_pc): + return ScoresetMetadata( + urn="urn:mavedb:TEST", + score_count=1, + target_genes={"label": target_dna_pc}, + mapped=False, + ) + + +def make_align(hit_intervals): + return AlignmentResult( + chrom="NC_000001.11", + strand=1, + coverage=None, + ident_pct=None, + query_range=SequenceRange(start=1, end=10), + query_subranges=[SequenceRange(start=1, end=10)], + hit_range=SequenceRange(start=1, end=10), + hit_subranges=[SequenceRange(start=s, end=e) for s, e in hit_intervals], + ) + + +@pytest.mark.asyncio() +async def test_compute_target_gene_info_non_coding_category(): + meta = ScoresetMetadata( + urn="urn:mavedb:TEST", + score_count=0, + target_genes={ + "t": TargetGene( + target_gene_name="REG", + target_gene_category=TargetType.OTHER_NC, + target_sequence="ACGT", + target_sequence_type=TargetSequenceType.DNA, + ) + }, + mapped=False, + ) + + res = await compute_target_gene_info("t", {}, {}, meta, None) + assert isinstance(res, GeneInfo) + assert res.hgnc_symbol is None + assert res.selection_method == "target_category" + + +@pytest.mark.asyncio() +async def test_compute_target_gene_info_tx_selection(scoreset_metadata): + tx = TxSelectResult( + nm="NM_000001.1", + np="NP_000001.1", + start=0, + is_full_match=True, + sequence="MSEQUENCE", + transcript_mode=None, + hgnc_symbol="BRAF", + ) + res = await compute_target_gene_info( + "label", {"label": tx}, {"label": None}, scoreset_metadata, None + ) 
+ assert isinstance(res, GeneInfo) + assert res.hgnc_symbol == "BRAF" + assert res.selection_method == "tx_selection" + + +def test_compute_target_gene_info_alignment_path(scoreset_metadata): + align = make_align([(100, 120)]) + with ( + mock.patch( + "dcd_mapping.annotate.get_chromosome_identifier", + side_effect=lambda c: c, + ), + mock.patch( + "dcd_mapping.annotate.get_overlapping_features_for_region", + return_value=[{"external_name": "GENE1", "start": 95, "end": 130}], + ), + mock.patch("dcd_mapping.annotate._get_hgnc_symbol", side_effect=lambda s: s), + ): + res = _compute_target_gene_info_from_alignment("label", align) + assert isinstance(res, GeneInfo) + assert res.hgnc_symbol == "GENE1" + assert res.selection_method == "alignment_max_covered_bases" + + +def test_compute_target_gene_info_alignment_tie_returns_none(scoreset_metadata): + align = make_align([(100, 120)]) + with ( + mock.patch( + "dcd_mapping.annotate.get_chromosome_identifier", + side_effect=lambda c: c, + ), + mock.patch( + "dcd_mapping.annotate.get_overlapping_features_for_region", + return_value=[ + {"external_name": "GENE1", "start": 100, "end": 120}, + {"external_name": "GENE2", "start": 100, "end": 120}, + ], + ), + ): + res = _compute_target_gene_info_from_alignment("label", align) + assert res is None + + +def test_compute_target_gene_info_mapped_variants_path(scoreset_metadata): + # Build mapped scores to yield spans that overlap single gene best + allele = Allele( + location=SequenceLocation( + sequenceReference=SequenceReference( + refgetAccession="SQ.1234567890abcdef1234567890abcdef" + ), + start=100, + end=110, + ), + state=LiteralSequenceExpression(sequence="A"), + expressions=[], + ) + ms = MappedScore( + mavedb_id="id", + accession_id="id", + pre_mapped=None, + post_mapped=allele, + annotation_layer=AnnotationLayer.GENOMIC, + score=None, + error_message=None, + ) + with ( + mock.patch( + "dcd_mapping.annotate.get_chromosome_identifier_from_vrs_id", + 
return_value="refseq:NC_000001.11", + ), + mock.patch( + "dcd_mapping.annotate.get_ucsc_chromosome_name", + return_value="NC_000001.11", + ), + mock.patch("dcd_mapping.annotate._get_hgnc_symbol", side_effect=lambda s: s), + mock.patch( + "dcd_mapping.annotate.get_overlapping_features_for_region", + return_value=[{"external_name": "GENE3", "start": 50, "end": 200}], + ), + ): + res = _compute_target_gene_info_from_mapped_variant_spans("label", [ms]) + assert isinstance(res, GeneInfo) + assert res.hgnc_symbol == "GENE3" + assert res.selection_method == "variants_max_covered_bases" + + +@pytest.mark.asyncio() +async def test_compute_target_gene_info_fallback_metadata(scoreset_metadata): + # No tx, no alignment, no mapped scores -> fallback + with mock.patch("dcd_mapping.annotate.get_gene_symbol", return_value="META"): + res = await compute_target_gene_info( + "label", {"label": None}, {"label": None}, scoreset_metadata, None + ) + assert isinstance(res, GeneInfo) + assert res.hgnc_symbol == "META" + assert res.selection_method == "target_metadata" + + +@pytest.mark.asyncio() +async def test_compute_target_gene_info_fallback_unavailable(scoreset_metadata): + # No tx, no alignment, no mapped scores -> fallback + with mock.patch("dcd_mapping.annotate.get_gene_symbol", return_value=None): + res = await compute_target_gene_info( + "label", {"label": None}, {"label": None}, scoreset_metadata, None + ) + assert res is None + + +def test_covered_bases_sums_and_skips_invalid(): + with ( + mock.patch( + "dcd_mapping.annotate.get_overlapping_features_for_region", + return_value=[ + {"external_name": "A", "start": 10, "end": 30}, + {"external_name": None, "start": 10, "end": 30}, + {"external_name": "B", "start": None, "end": 100}, + ], + ), + ): + cov = _covered_bases_from_overlapping_genes_of_chromosomal_intervals( + "NC_000001.11", [(15, 25), (20, 35)] + ) + # Overlaps: A: (15-25)=10 and (20-30)=10 => total 20 + assert cov == {"A": 20} + + +def 
test_interval_merging_overlapping_and_adjacent_merge(scoreset_metadata): + # Two intervals on same chrom: overlapping and adjacent; both should merge before coverage + + def make_ms(start, end): + allele = Allele( + location=SequenceLocation( + sequenceReference=SequenceReference( + refgetAccession="SQ.1234567890abcdef1234567890abcdef" + ), + start=start, + end=end, + ), + state=LiteralSequenceExpression(sequence="A"), + expressions=[], + ) + return MappedScore( + mavedb_id=f"id_{start}", + accession_id=f"id_{start}", + pre_mapped=None, + post_mapped=allele, + annotation_layer=AnnotationLayer.GENOMIC, + score=None, + error_message=None, + ) + + ms_list = [ + make_ms(100, 110), + make_ms(108, 120), + make_ms(120, 130), + ] # overlap then adjacent + + # After merging, intervals become [(100, 130)], expect single coverage call and length 30 contributing to GENEZ + def fake_overlap(chrom, s, e, features=None): + assert (s, e) == (100, 130) + return [{"external_name": "GENEZ", "start": 90, "end": 200}] + + with ( + mock.patch( + "dcd_mapping.annotate.get_chromosome_identifier_from_vrs_id", + return_value="refseq:NC_000001.11", + ), + mock.patch( + "dcd_mapping.annotate.get_ucsc_chromosome_name", + return_value="NC_000001.11", + ), + mock.patch("dcd_mapping.annotate._get_hgnc_symbol", side_effect=lambda s: s), + mock.patch( + "dcd_mapping.annotate.get_overlapping_features_for_region", + side_effect=fake_overlap, + ), + ): + res = _compute_target_gene_info_from_mapped_variant_spans("label", ms_list) + assert isinstance(res, GeneInfo) + assert res.hgnc_symbol == "GENEZ" + assert res.selection_method == "variants_max_covered_bases" diff --git a/tests/test_lookup.py b/tests/test_lookup.py new file mode 100644 index 0000000..52ef287 --- /dev/null +++ b/tests/test_lookup.py @@ -0,0 +1,107 @@ +"""Tests for dcd_mapping.lookup""" + +from unittest.mock import patch + +import requests + +from dcd_mapping.lookup import get_overlapping_features_for_region + +RAW_OVERLAP_RESPONSE = [ 
class _FakeResponse:
    """Minimal stand-in for a successful HTTP response.

    Wraps an arbitrary payload and exposes the three members the tests use:
    ``status_code``, ``json()`` and ``raise_for_status()``.
    """

    def __init__(self, data):
        # Report success by default.
        self.status_code = 200
        self._data = data

    def raise_for_status(self):
        """Do nothing: this fake never signals an HTTP error."""
        return None

    def json(self):
        """Return the payload exactly as it was passed to the constructor."""
        return self._data
def _sequence_side_effect(values):
    """Build a ``side_effect`` callable that replays *values* one per call.

    Each call consumes the next item: exception instances are raised and
    anything else is returned. Call arguments are accepted and ignored.
    Calling more times than there are items lets ``StopIteration`` from
    ``next`` propagate.
    """
    remaining = iter(values)

    def _replay(*_args, **_kwargs):  # noqa: ANN002
        item = next(remaining)
        if not isinstance(item, BaseException):
            return item
        raise item

    return _replay
def test_5xx_until_max_raises_http_error():
    """Three consecutive 500 responses with ``max_retries=3`` raise ``HTTPError``."""
    responses = [_DummyResponse(500) for _ in range(3)]
    get_patch = mock.patch(
        "dcd_mapping.resource_utils.requests.get",
        side_effect=_sequence_side_effect(responses),
    )
    sleep_patch = mock.patch("dcd_mapping.resource_utils.time.sleep")

    with ExitStack() as stack:
        stack.enter_context(get_patch)
        sleep_mock = stack.enter_context(sleep_patch)
        with pytest.raises(requests.HTTPError):
            request_with_backoff(
                "http://example.com/resource", max_retries=3, backoff_factor=0.1
            )

    # Slept after the first two attempts, then raised on the third.
    slept = [recorded.args[0] for recorded in sleep_mock.mock_calls]
    assert slept == [0.1, 0.2]
def test_exhausted_retries_without_response_raises_request_exception():
    """A zero retry budget fails with a "Failed to fetch" error.

    Per the original author's note, ``max_retries=0`` (no request attempted)
    is the only way to reach the function's terminal failure state.
    ``requests.get`` and ``time.sleep`` are still patched so the test can never
    touch the network or actually wait.
    """
    with ExitStack() as stack:
        stack.enter_context(
            mock.patch(
                "dcd_mapping.resource_utils.requests.get",
                return_value=_DummyResponse(500),
            )
        )
        stack.enter_context(mock.patch("dcd_mapping.resource_utils.time.sleep"))
        # `match` makes the message check part of the raises-context itself,
        # so it cannot be skipped, and it satisfies the broad-exception lint
        # without a `noqa`.
        # NOTE(review): the test name suggests requests.RequestException, but
        # the concrete type is not visible from this file, so the exception
        # class is left as broad as it was.
        with pytest.raises(Exception, match="Failed to fetch"):
            request_with_backoff("http://example.com/resource", max_retries=0)