diff --git a/pori_python/graphkb/constants.py b/pori_python/graphkb/constants.py index fe22f4a0..07f686b9 100644 --- a/pori_python/graphkb/constants.py +++ b/pori_python/graphkb/constants.py @@ -59,7 +59,10 @@ TSO500_SOURCE_NAME = 'tso500' ONCOGENE = 'oncogenic' TUMOUR_SUPPRESSIVE = 'tumour suppressive' -CANCER_GENE = 'cancer gene' +CANCER_GENE = [ + 'cancer gene', + 'tumourigenesis', +] # KBDEV-1532. tumourigenesis for backward compatibility FUSION_NAMES = ['structural variant', 'fusion'] GSC_PHARMACOGENOMIC_SOURCE_EXCLUDE_LIST = ['cancer genome interpreter', 'civic'] diff --git a/pori_python/graphkb/genes.py b/pori_python/graphkb/genes.py index 09da3ed7..376693af 100644 --- a/pori_python/graphkb/genes.py +++ b/pori_python/graphkb/genes.py @@ -2,7 +2,7 @@ from __future__ import annotations -from typing import Any, Dict, List, Sequence, Set, Tuple, cast +from typing import Any, Dict, List, Sequence, Set, Tuple, cast, Union from typing_extensions import deprecated from pori_python.types import IprGene, Ontology, Record, Statement, Variant @@ -27,8 +27,117 @@ from .vocab import get_terms_set +def get_cancer_gene_flags( + conn: GraphKBConnection, + flags: bool = False, + ignore_cache: bool = False, +) -> Union[List[Record], Dict[str, List[Record]]]: + """ + Return all cancer genes, optionally sorted by flags. + + Flag definitions: + oncogenic: relevance 'oncogenic' from OncoKB + tumourSuppressive: relevance 'tumour suppressive' from OncoKB + cancerGene: relevance 'cancer gene' AND child terms ('oncogenic', 'tumour suppressive', 'other cancer gene'), from OncoKB AND TSO500 + + Args: + conn: the graphkb connection object + flags: if the results should be sorted by flags + ignore_cache: if cache should be ignored when querying GraphKB API + + Returns (if flags=False; default): list of unique gene records + [ , , ... ] + + Returns (if flags=True): dict of flags as keys, and list of gene records as value + { + 'oncogenic': [ , , ... ], + 'tumourSuppressive': [ , , ... ], + 'cancerGene': [ , , ... ], + } + """ + # all cancer gene statements + cancer_genes = conn.get_related_terms( + terms=CANCER_GENE, + subgraphType='children', + ) + statements = cast( + List[Statement], + conn.query( + { + 'target': 'Statement', + 'filters': { + 'relevance': {'target': 'Vocabulary', 'filters': {'name': cancer_genes}} + }, + 'returnProperties': [ + 'source.name', + 'relevance.name', + *[f'subject.{prop}' for prop in GENE_RETURN_PROPERTIES], + ], + }, + ignore_cache=ignore_cache, + ), + ) + + # post-query filtering (faster) + cancerGeneStms = list( + filter( + lambda r: ( + r['subject']['@class'] == 'Feature' + and r['subject']['biotype'] == 'gene' + and r['source']['name'] in [ONCOKB_SOURCE_NAME, TSO500_SOURCE_NAME] + ), + statements, + ) + ) + oncogenicStms = list( + filter( + lambda r: ( + r['relevance']['name'] == ONCOGENE and r['source']['name'] == ONCOKB_SOURCE_NAME + ), + cancerGeneStms, + ) + ) + tumourSuppressiveStms = list( + filter( + lambda r: ( + r['relevance']['name'] == TUMOUR_SUPPRESSIVE + and r['source']['name'] == ONCOKB_SOURCE_NAME + ), + cancerGeneStms, + ) + ) + + # Returning a sorted list of unique gene records, based on iProbe requirements + # Unique by name, sorted by displayName + if not flags: + seen: set = set() + unique_genes: List[Record] = [] + for r in cancerGeneStms: + name = r['subject']['name'] + if name not in seen: + seen.add(name) + unique_genes.append(r['subject']) + + return cast( + List[Record], + sorted(unique_genes, key=lambda gene: gene['displayName']), + ) + + # Returning a Dict of flags, with list of associated gene records + # Duplicates are ok + return { + 'cancerGene': [r['subject'] for r in cancerGeneStms], + 'oncogenic': [r['subject'] for r in oncogenicStms], + 'tumourSuppressive': [r['subject'] for r in tumourSuppressiveStms], + } + + +@deprecated('functionality replaced by get_cancer_gene_flags') def _get_tumourigenesis_genes_list( - conn: GraphKBConnection, relevance: str, sources: List[str], ignore_cache: bool = False + conn: GraphKBConnection, + relevance: Union[str, List[str]], + sources: Union[str, List[str]], + ignore_cache: bool = False, ) -> List[Ontology]: statements = cast( List[Statement], @@ -57,6 +166,7 @@ def _get_tumourigenesis_genes_list( return [gene for gene in genes.values()] +@deprecated('functionality replaced by get_cancer_gene_flags') def get_oncokb_oncogenes(conn: GraphKBConnection) -> List[Ontology]: """Get the list of oncogenes stored in GraphKB derived from OncoKB. @@ -66,9 +176,10 @@ def get_oncokb_oncogenes(conn: GraphKBConnection) -> List[Ontology]: Returns: gene (Feature) records """ - return _get_tumourigenesis_genes_list(conn, ONCOGENE, [ONCOKB_SOURCE_NAME]) + return _get_tumourigenesis_genes_list(conn, ONCOGENE, ONCOKB_SOURCE_NAME) +@deprecated('functionality replaced by get_cancer_gene_flags') def get_oncokb_tumour_supressors(conn: GraphKBConnection) -> List[Ontology]: """Get the list of tumour supressor genes stored in GraphKB derived from OncoKB. @@ -78,11 +189,14 @@ def get_oncokb_tumour_supressors(conn: GraphKBConnection) -> List[Ontology]: Returns: gene (Feature) records """ - return _get_tumourigenesis_genes_list(conn, TUMOUR_SUPPRESSIVE, [ONCOKB_SOURCE_NAME]) + return _get_tumourigenesis_genes_list(conn, TUMOUR_SUPPRESSIVE, ONCOKB_SOURCE_NAME) +@deprecated('functionality replaced by get_cancer_gene_flags') def get_cancer_genes(conn: GraphKBConnection) -> List[Ontology]: - """Get the list of cancer genes stored in GraphKB derived from OncoKB & TSO500. + """ + Get the list of cancer genes stored in GraphKB derived from OncoKB & TSO500. + Cancer genes include oncogenes, tumour supressor genes and other cancer genes. Args: conn: the graphkb connection object @@ -90,8 +204,12 @@ def get_cancer_genes(conn: GraphKBConnection) -> List[Ontology]: Returns: gene (Feature) records """ + cancer_gene_terms = conn.get_related_terms( + terms=CANCER_GENE, + subgraphType='children', + ) return _get_tumourigenesis_genes_list( - conn, CANCER_GENE, [ONCOKB_SOURCE_NAME, TSO500_SOURCE_NAME] + conn, cancer_gene_terms, [ONCOKB_SOURCE_NAME, TSO500_SOURCE_NAME] ) @@ -513,12 +631,12 @@ def get_gene_information( # PositionalVariant without a reference2 implies a smallMutation type gene_flags['knownSmallMutation'].add(condition['reference1']) # type: ignore - logger.info('fetching oncogenes list') - gene_flags['oncogene'] = convert_to_rid_set(get_oncokb_oncogenes(graphkb_conn)) - logger.info('fetching tumour supressors list') - gene_flags['tumourSuppressor'] = convert_to_rid_set(get_oncokb_tumour_supressors(graphkb_conn)) - logger.info('fetching cancerGeneListMatch list') - gene_flags['cancerGeneListMatch'] = convert_to_rid_set(get_cancer_genes(graphkb_conn)) + # cancer gene flags + logger.info('fetching cancer genes') + cancer_gene_flags = get_cancer_gene_flags(graphkb_conn, flags=True) + gene_flags['oncogene'] = convert_to_rid_set(cancer_gene_flags['oncogenic']) + gene_flags['tumourSuppressor'] = convert_to_rid_set(cancer_gene_flags['tumourSuppressive']) + gene_flags['cancerGeneListMatch'] = convert_to_rid_set(cancer_gene_flags['cancerGene']) logger.info('fetching therapeutic associated genes lists') gene_flags['therapeuticAssociated'] = convert_to_rid_set( diff --git a/pori_python/graphkb/util.py b/pori_python/graphkb/util.py index 23c28963..a2a9bb14 100644 --- a/pori_python/graphkb/util.py +++ b/pori_python/graphkb/util.py @@ -354,6 +354,74 @@ def get_source(self, name: str) -> Record: raise AssertionError(f'Unable to unqiuely identify source with name {name}') return source[0] + @property + def version(self) -> Dict[str, str]: + """ + Retrieve GraphKB components version + + Returns: + Dict[str, str]: component keys with version values + + e.g. > {"api":"3.17.3","db":"production","parser":"2.1.0","schema":"4.1.1"} + """ + return self.request('version') + + def get_related_records( + self, + base: Union[str, List[str]], + ontology: str, + subgraphType: str, + returnProperties: Optional[List[str]] = None, + ) -> List[Record]: + """ + Given some base node RIDs, an ontology class and a subgraph type, + leverage the subgraphs route to return the list of related nodes. + + Args: + base: the base node RIDs to start the graph traversal from + ontology: the ontology class to traverse + subgraphType: the type of traversal. See options in API specs + returnProperties: additional record properties to return + + Returns: + list of related node record(s) traversed + """ + related = self.post( + uri=f'/subgraphs/{ontology}', + data={ + 'base': base if isinstance(base, list) else [base], + 'subgraphType': subgraphType, + 'returnProperties': returnProperties or [], + }, + ) + return related['result']['g']['nodes'] + + def get_related_terms( + self, + terms: Union[str, List[str]], + ontology: str = 'Vocabulary', + subgraphType: str = 'similar', + ) -> List[str]: + """ + Given some base term name(s), an ontology class and a subgraph type, + leverage the subgraphs route to return the list of related term name(s) + + Args: + terms: the base term name(s) to start the graph traversal from + ontology: the ontology class to traverse + subgraphType: the type of traversal + + Returns: + list of related term name(s) + """ + rids = convert_to_rid_list(self.query({'target': ontology, 'filters': {'name': terms}})) + nodes = self.get_related_records( + base=rids, + ontology=ontology, + subgraphType=subgraphType, + ) + return [x['name'] for x in nodes.values()] + def get_rid(conn: GraphKBConnection, target: str, name: str) -> str: """ diff --git a/pori_python/graphkb/vocab.py b/pori_python/graphkb/vocab.py index e9242a7a..bb96e5f5 100644 --- a/pori_python/graphkb/vocab.py +++ b/pori_python/graphkb/vocab.py @@ -1,4 +1,4 @@ -from typing import Callable, Dict, Iterable, List, Set, cast +from typing import Callable, Dict, Iterable, List, Set, cast, Union from pori_python.types import Ontology @@ -6,7 +6,7 @@ from .util import convert_to_rid_list -def query_by_name(ontology_class: str, base_term_name: str) -> Dict: +def query_by_name(ontology_class: str, base_term_name: Union[str, list[str]]) -> Dict: return {'target': ontology_class, 'filters': {'name': base_term_name}} diff --git a/pori_python/ipr/connection.py b/pori_python/ipr/connection.py index 70eaf26c..75458d41 100644 --- a/pori_python/ipr/connection.py +++ b/pori_python/ipr/connection.py @@ -93,10 +93,54 @@ def delete(self, uri: str, data: Dict = {}, **kwargs) -> Dict: **kwargs, ) + def check_upload_permission(self, project_name: str) -> None: + """Check that the current user has permission to upload to the given project. + + Fetches all projects and the current user info (including groups and + projects) up front. Checks for admin, manager, create report access, + all projects access, and project membership. + """ + projects = self.get('project') + project_exists = any(p['name'] == project_name for p in projects) + if not project_exists: + raise Exception( + f'Project {project_name} does not exist or user does not have permission to view it' + ) + + user = self.get('user/me') + user_groups = user.get('groups', []) if isinstance(user, dict) else [] + group_names = { + group.get('name', '').strip().lower() + if isinstance(group, dict) + else group.strip().lower() + for group in user_groups + } + + is_admin = 'admin' in group_names + is_manager = 'manager' in group_names + has_create_report_access = 'create report access' in group_names + has_all_projects_access = 'all projects access' in group_names + + # admins and managers can always create reports + can_create_report = is_admin or is_manager or has_create_report_access + + user_projects = user.get('projects', []) if isinstance(user, dict) else [] + has_project_access = ( + is_admin + or has_all_projects_access + or any(isinstance(p, dict) and p.get('name') == project_name for p in user_projects) + ) + + if not can_create_report: + raise Exception('User does not have report creation permission') + + if not has_project_access: + raise Exception(f'User has no permission to create report in project {project_name}') + def upload_report( self, content: Dict, - mins_to_wait: int = 5, + mins_to_wait: int = 10, async_upload: bool = False, ignore_extra_fields: bool = False, ) -> Dict: @@ -105,19 +149,6 @@ def upload_report( # or 'report'. jobStatus is no longer available once the report is successfully # uploaded. - projects = self.get('project') - project_names = [item['name'] for item in projects] - - # if project is not exist, create one - if content['project'] not in project_names: - logger.info( - f'Project not found - attempting to create project {content["project"]}' - ) - try: - self.post('project', {'name': content['project']}) - except Exception as err: - raise Exception(f'Project creation failed due to {err}') - if ignore_extra_fields: initial_result = self.post('reports-async?ignore_extra_fields=true', content) else: diff --git a/pori_python/ipr/content.spec.json b/pori_python/ipr/content.spec.json index 5a1793a2..c994ba6f 100644 --- a/pori_python/ipr/content.spec.json +++ b/pori_python/ipr/content.spec.json @@ -202,6 +202,16 @@ "number", "null" ] + }, + "flags": { + "description": "variant flags", + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] } }, "required": [ @@ -475,6 +485,16 @@ "null", "string" ] + }, + "flags": { + "description": "variant flags", + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] } }, "required": [ @@ -892,6 +912,109 @@ "example": "POG", "type": "string" }, + "seqQC": { + "default": [], + "type": "array", + "items": { + "type": "object", + "properties": { + "reads": { + "description": "Number of reads", + "example": "2534M", + "type": [ + "string", + "null" + ] + }, + "bioQC": { + "description": "Biological QC status", + "example": "passed", + "type": [ + "string", + "null" + ] + }, + "labQC": { + "description": "Lab QC status", + "example": "passed", + "type": [ + "string", + "null" + ] + }, + "sample": { + "description": "Sample identifier, e.g. Tumour DNA, Constitutional DNA", + "example": "Tumour DNA", + "type": [ + "string", + "null" + ] + }, + "library": { + "description": "Library identifier", + "example": "LIB0001", + "type": [ + "string", + "null" + ] + }, + "coverage": { + "description": "Sequencing coverage", + "example": "80x", + "type": [ + "string", + "null" + ] + }, + "inputNg": { + "description": "Input amount in nanograms", + "example": "500", + "type": [ + "string", + "number", + "integer", + "null" + ] + }, + "inputUg": { + "description": "Input amount in micrograms", + "example": "0.5", + "type": [ + "string", + "number", + "integer", + "null" + ] + }, + "protocol": { + "description": "Sequencing protocol", + "example": "WGS", + "type": [ + "string", + "null" + ] + }, + "sampleName": { + "description": "Full sample name", + "example": "SAMPLE1-FF-1", + "type": [ + "string", + "null" + ] + }, + "duplicateReadsPerc": { + "description": "Percentage of duplicate reads", + "example": "12.3", + "type": [ + "string", + "number", + "integer", + "null" + ] + } + } + } + }, "smallMutations": { "default": [], "items": { @@ -1106,6 +1229,16 @@ "string", "null" ] + }, + "flags": { + "description": "variant flags", + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] } }, "required": [ @@ -1290,6 +1423,16 @@ "integer", "null" ] + }, + "flags": { + "description": "variant flags", + "items": { + "type": "string" + }, + "type": [ + "array", + "null" + ] } }, "required": [ diff --git a/pori_python/ipr/inputs.py b/pori_python/ipr/inputs.py index f14fc696..beec4f3f 100644 --- a/pori_python/ipr/inputs.py +++ b/pori_python/ipr/inputs.py @@ -60,6 +60,7 @@ 'comments', 'library', 'germline', + 'flags', ] SMALL_MUT_REQ = ['gene', 'proteinChange'] @@ -98,6 +99,7 @@ 'tumourRefCount', 'tumourRefCopies', 'zygosity', + 'flags', ] EXP_REQ = ['gene', 'kbCategory'] @@ -130,6 +132,7 @@ 'rnaReads', 'rpkm', 'tpm', + 'flags', ] SV_REQ = [ @@ -162,6 +165,7 @@ 'tumourDepth', 'germline', 'mavis_product_id', + 'flags', ] SIGV_REQ = ['signatureName', 'variantTypeName'] @@ -278,6 +282,7 @@ def row_key(row: IprSmallMutationVariant) -> Tuple[str, ...]: return tuple(['small mutation'] + key_vals) result = validate_variant_rows(rows, SMALL_MUT_REQ, SMALL_MUT_OPTIONAL, row_key) + if not result: return [] @@ -336,6 +341,7 @@ def row_key(row: Dict) -> Tuple[str, ...]: return tuple(['expression'] + [row[key] for key in EXP_KEY]) variants = validate_variant_rows(rows, EXP_REQ, EXP_OPTIONAL, row_key) + result = [cast(IprExprVariant, var) for var in variants] float_columns = [ col @@ -371,7 +377,6 @@ def row_key(row: Dict) -> Tuple[str, ...]: if errors: raise ValueError(f'{len(errors)} Invalid expression variants in file') - return result @@ -796,6 +801,58 @@ def check_null(checker, instance): DefaultValidatingDraft7Validator = extend_with_default(jsonschema.Draft7Validator) +def normalize_seqqc(content: Dict) -> Dict: + """ + Normalize seqQC field names from production report format to schema format. + + Maps inconsistent casing and underscores in field names to match content.spec.json requirements. + For example: 'Reads' -> 'reads', 'Sample Name' -> 'sampleName', etc. + + Args: + content: Report content dictionary that may contain seqQC array + + Returns: + A new content dictionary with seqQC fields normalized + """ + content = {**content} + # Field name mapping from production/legacy format to schema format + field_mapping = { + 'Reads': 'reads', + 'Sample': 'sample', + 'Library': 'library', + 'Coverage': 'coverage', + 'Input_ng': 'inputNg', + 'Input_ug': 'inputUg', + 'Protocol': 'protocol', + 'Sample Name': 'sampleName', + 'Duplicate_Reads_Perc': 'duplicateReadsPerc', + } + normalized_keys = set(field_mapping.values()) + + if 'seqQC' in content and isinstance(content['seqQC'], list): + content['seqQC'] = list(content['seqQC']) + for i, item in enumerate(content['seqQC']): + if not isinstance(item, dict): + continue + # Preserve already-normalized keys (and unrelated keys) first so + # legacy aliases cannot overwrite them based on insertion order. + normalized_item = {} + for key, value in item.items(): + if key in normalized_keys or key not in field_mapping: + normalized_item[key] = value + + # Add legacy aliases only when the normalized key is not already + # present. This makes collision handling explicit and stable. + for old_key, new_key in field_mapping.items(): + if old_key in item and new_key not in normalized_item: + normalized_item[new_key] = item[old_key] + + # Replace the item with normalized version + content['seqQC'][i] = normalized_item + + return content + + def validate_report_content(content: Dict, schema_file: str = SPECIFICATION) -> None: """ Validate a report content input JSON object against the schema specification diff --git a/pori_python/ipr/ipr.py b/pori_python/ipr/ipr.py index f5cd6873..9a9afb38 100644 --- a/pori_python/ipr/ipr.py +++ b/pori_python/ipr/ipr.py @@ -160,7 +160,6 @@ def convert_statements_to_alterations( ) if query_result: recruitment_statuses[rid] = query_result[0]['recruitmentStatus'] # type: ignore - for statement in statements: variants = [ cast(Variant, c) for c in statement['conditions'] if c['@class'] in VARIANT_CLASSES @@ -229,6 +228,7 @@ def convert_statements_to_alterations( row['kbContextId'], 'not found' ) rows.append(row) + return rows @@ -727,3 +727,82 @@ def get_kb_disease_matches( raise ValueError(msg) return disease_matches + + +def ensure_str_list(val): + if isinstance(val, str): + return [f.strip() for f in val.split(',') if f.strip()] + if isinstance(val, list): + if not all(isinstance(item, str) for item in val): + raise TypeError('All items in flags must be strings') + return val + raise TypeError(f'Unexpected type in flags field: {type(val).__name__}') + + +def add_transcript_flags(variant_sources, transcript_flags_df): + """ + Add flags from the input transcript_flags_df to the variant_sources + records based on matching transcript keys. + - For non-fusion records with a 'transcript' field, add flags directly based on that field. + - For fusion records without a 'transcript' field but with 'ctermTranscript' and + 'ntermTranscript' fields, add flags based on both transcripts with appropriate labeling + """ + lookup = {} + + # create transcript:flags dict from input df + for _, row in ( + transcript_flags_df[['transcript', 'flags']].dropna(subset=['transcript']).iterrows() + ): + transcript = row['transcript'] + flags = lookup.setdefault(transcript, []) + for flag in ensure_str_list(str(row['flags'])): + if flag not in flags: + flags.append(flag) + + # for fusions: check both transcripts for flags and add to the same record + label_map = {'ctermTranscript': 'cterm', 'ntermTranscript': 'nterm'} + + # single pass: add plain transcript flags and labeled fusion transcript flags + for record in variant_sources: + flags = ensure_str_list(record.setdefault('flags', [])) + + if record.get('transcript'): + # non-fusion: plain transcript only, no cterm/nterm + transcript_flags = lookup.get(record['transcript']) + if transcript_flags: + for new_flag in transcript_flags: + if new_flag not in flags: + flags.append(new_flag) + else: + # fusion: check cterm/nterm transcripts with labels + for key, label in label_map.items(): + transcript = record.get(key) + transcript_flags = lookup.get(transcript) + if not transcript_flags: + continue + for flag in transcript_flags: + new_flag = f'{flag} ({label})' + if new_flag not in flags: + flags.append(new_flag) + + record['flags'] = flags + return variant_sources + + +def get_variant_flags(variant_sources): + flags = [] + for item in variant_sources: + raw_flags = item.get('flags') + if not raw_flags: # skips None and '' + continue + unique_flags = list(dict.fromkeys(f for f in ensure_str_list(raw_flags) if f)) + # create record, removing dupes from flags list + flags.append( + { + 'variant': item['key'], + 'variantType': item['variantType'], + 'flags': unique_flags, + } + ) + item.pop('flags', None) # remove after extraction + return flags diff --git a/pori_python/ipr/main.py b/pori_python/ipr/main.py index cbb7c128..d261cbd5 100644 --- a/pori_python/ipr/main.py +++ b/pori_python/ipr/main.py @@ -6,6 +6,7 @@ import jsonschema.exceptions import logging import os +import pandas as pd from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser from typing import Callable, Dict, List, Optional, Sequence, Set @@ -27,6 +28,7 @@ from .inputs import ( check_comparators, check_variant_links, + normalize_seqqc, preprocess_copy_variants, preprocess_cosmic, preprocess_expression_variants, @@ -46,6 +48,8 @@ get_kb_disease_matches, get_kb_matches_sections, select_expression_plots, + get_variant_flags, + add_transcript_flags, ) from .summary import auto_analyst_comments, get_ipr_analyst_comments from .therapeutic_options import create_therapeutic_options @@ -69,6 +73,26 @@ def timestamp() -> str: return datetime.datetime.now().strftime('%Y-%m-%dT%H:%M:%S') +def load_transcript_flags(path: str) -> pd.DataFrame: + transcript_flags_df = pd.read_csv( + path, + sep='\t', + names=['transcript', 'flags'], + dtype=str, + keep_default_na=False, + ) + if transcript_flags_df.empty: + return transcript_flags_df + + first_row = transcript_flags_df.iloc[0] + if [str(first_row[col]).strip().lower() for col in ['transcript', 'flags']] == [ + 'transcript', + 'flags', + ]: + transcript_flags_df = transcript_flags_df.iloc[1:].reset_index(drop=True) + return transcript_flags_df + + def command_interface() -> None: """Parse the ipr command from user input based on usage pattern. Parsed arguments are used to call the ipr_report() function. @@ -129,7 +153,7 @@ def command_interface() -> None: ) parser.add_argument( '--mins_to_wait', - default=5, + default=10, action='store', help='is using reports-async, number of minutes to wait before throwing error', ) @@ -157,6 +181,12 @@ def command_interface() -> None: action='store_true', help='True if ignore extra fields in json', ) + parser.add_argument( + '--transcript_flags', + required=False, + type=file_path, + help='TSV without header, with two columns: transcripts and flags (comma-separated list of flags eg "MANE"). If header is included, it will be skipped. Matching uses direct string comparison, so transcript identifiers must match exactly, including version numbers (e.g., if input variants use ENST00000390477.1, this file must also use ENST00000390477.1, not ENST00000390477).', + ) args = parser.parse_args() with open(args.content, 'r') as fh: @@ -181,6 +211,7 @@ def command_interface() -> None: upload_json=args.upload_json, validate_json=args.validate_json, ignore_extra_fields=args.ignore_extra_fields, + transcript_flags=args.transcript_flags, ) @@ -234,7 +265,7 @@ def clean_unsupported_content(upload_content: Dict, ipr_spec: Dict = {}) -> Dict for key, count in removed_keys.items(): logger.warning(f"IPR unsupported property '{key}' removed from {count} genes.") - drop_columns = ['variant', 'variantType', 'histogramImage'] + drop_columns = ['variant', 'variantType', 'histogramImage', 'flags'] # DEVSU-2034 - use a 'displayName' VARIANT_LIST_KEYS = [ 'expressionVariants', @@ -281,7 +312,6 @@ def clean_unsupported_content(upload_content: Dict, ipr_spec: Dict = {}) -> Dict # Removing cosmicSignatures. Temporary upload_content.pop('cosmicSignatures', None) - return upload_content @@ -308,7 +338,7 @@ def ipr_report( match_germline: bool = False, custom_kb_match_filter: Optional[Callable] = None, async_upload: bool = False, - mins_to_wait: int = 5, + mins_to_wait: int = 10, include_ipr_variant_text: bool = True, include_nonspecific_disease: bool = False, include_nonspecific_project: bool = False, @@ -318,6 +348,7 @@ def ipr_report( validate_json: bool = False, ignore_extra_fields: bool = False, tmb_high: float = TMB_SIGNATURE_HIGH_THRESHOLD, + transcript_flags: str = '', ) -> Dict: """Run the matching and create the report JSON for upload to IPR. @@ -347,6 +378,7 @@ def ipr_report( include_nonspecific_template: if include_ipr_variant_text is True, if no template match is found use template-nonspecific variant comment allow_partial_matches: allow matches to statements where not all conditions are satisfied tmb_high: mutation burden threshold/cutoff to qualify as 'high' + transcript_flags: path to a tsv file with two columns (no header) of transcript identifiers and flags to be added to any observed variants with matching transcript in the report upload. If header is included, it will be skipped. Matching uses direct string comparison, so transcript identifiers must match exactly, including version numbers (e.g., if input variants use ENST00000390477.1, this file must also use ENST00000390477.1, not ENST00000390477). Returns: ipr_conn.upload_report return dictionary """ @@ -365,12 +397,20 @@ def ipr_report( else: logger.warning('No ipr_url given') + # Verify upload permission before doing any expensive processing + if ipr_upload and ipr_conn: + ipr_conn.check_upload_permission(content['project']) + if validate_json: if not ipr_conn: raise ValueError('ipr_url required to validate json') ipr_result = ipr_conn.validate_json(content) return ipr_result + # seqqc normalization is a bridging measure only; + # validate_json should be called on non-normalized json + content = normalize_seqqc(content) + if upload_json: if not ipr_conn: raise ValueError('ipr_url required to upload json') @@ -386,6 +426,10 @@ def ipr_report( logger.error('Failed schema check - report variants may be corrupted or unmatched.') logger.error(f'Failed schema check: {err}') + transcript_flags_df = None + if transcript_flags: + transcript_flags_df = load_transcript_flags(transcript_flags) + # INPUT VARIANTS VALIDATION & PREPROCESSING (OBSERVED BIOMARKERS) signature_variants: List[IprSignatureVariant] = preprocess_signature_variants( [ @@ -410,6 +454,7 @@ def ipr_report( expression_variants: List[IprExprVariant] = preprocess_expression_variants( content.get('expressionVariants', []) ) + # Additional checks if expression_variants: check_comparators(content, expression_variants) @@ -459,6 +504,10 @@ def ipr_report( *structural_variants, ] # type: ignore + # ANNOTATING VARIANTS WITH TRANSCRIPT FLAGS + if transcript_flags_df is not None and not transcript_flags_df.empty: + all_variants = add_transcript_flags(all_variants, transcript_flags_df) + # GKB_MATCHES FILTERING if match_germline: # verify germline kb statements matched germline observed variants, not somatic variants @@ -531,6 +580,7 @@ def ipr_report( # thread safe deep-copy the original content output = json.loads(json.dumps(content)) output.update(kb_matched_sections) + output.update( { 'copyVariants': [ @@ -559,6 +609,25 @@ def ipr_report( 'therapeuticTarget': targets, } ) + + # ADD OBSERVED VARIANT ANNOTATIONS SECTION + annotatable_variant_sources = [ + v + for source in [ + output[section] + for section in [ + 'smallMutations', + 'copyVariants', + 'expressionVariants', + 'structuralVariants', + ] + if section in output + ] + for v in source + ] + + output['observedVariantAnnotations'] = get_variant_flags(annotatable_variant_sources) + output.setdefault('images', []).extend(select_expression_plots(gkb_matches, all_variants)) # if input includes hrdScore field, that is ok to pass to db @@ -577,6 +646,7 @@ def ipr_report( if not ipr_conn: raise ValueError('ipr_url required to upload report') ipr_spec = ipr_conn.get_spec() + output = clean_unsupported_content(output, ipr_spec) try: logger.info(f'Uploading to IPR {ipr_conn.url}') @@ -593,7 +663,7 @@ def ipr_report( if always_write_output_json: logger.info(f'Writing IPR upload json to: {output_json_path}') with open(output_json_path, 'w') as fh: - fh.write(json.dumps(output)) + json.dump(output, fh, indent=4) logger.info(f'made {graphkb_conn.request_count} requests to graphkb') logger.info(f'average load {int(graphkb_conn.load or 0)} req/s') diff --git a/pori_python/ipr/util.py b/pori_python/ipr/util.py index 69ac7024..52d8fe1c 100644 --- a/pori_python/ipr/util.py +++ b/pori_python/ipr/util.py @@ -61,6 +61,11 @@ def create_variant_name_tuple(variant: IprVariant) -> Tuple[str, str]: return (gene, str(variant.get('expressionState', ''))) elif variant_type == 'cnv': return (gene, str(variant.get('cnvState', ''))) + elif variant_type == 'sigv': + return ( + variant.get('signatureName', variant.get('displayName')), + str(variant.get('variantTypeName', '')), + ) variant_split = ( variant['variant'].split(':', 1)[1] if ':' in variant['variant'] else variant['variant'] ) diff --git a/pori_python/types.py b/pori_python/types.py index dd1ab7e5..3840cfc3 100644 --- a/pori_python/types.py +++ b/pori_python/types.py @@ -134,11 +134,12 @@ def __hash__(self): class IprVariantBase(TypedDict): - """Required properties of all variants for IPR.""" + """Required or possible properties of all variants for IPR.""" key: str variantType: str variant: str + flags: Optional[List[str]] class IprGeneVariant(IprVariantBase): diff --git a/setup.cfg b/setup.cfg index e3b7d63c..0cdeb341 100644 --- a/setup.cfg +++ b/setup.cfg @@ -18,7 +18,7 @@ known_standard_library = requests [metadata] name = pori_python -version = 1.4.0 +version = 1.5.0 url = https://github.com/bcgsc/pori_python author_email = dat@bcgsc.ca maintainer_email = dat@bcgsc.ca diff --git a/tests/test_graphkb/test_genes.py b/tests/test_graphkb/test_genes.py index 90efe5d4..c2ba87a7 100644 --- a/tests/test_graphkb/test_genes.py +++ b/tests/test_graphkb/test_genes.py @@ -8,6 +8,7 @@ from pori_python.graphkb import GraphKBConnection from pori_python.graphkb.genes import ( get_cancer_genes, + get_cancer_gene_flags, get_cancer_predisposition_info, get_gene_information, get_gene_linked_cancer_predisposition_info, @@ -27,7 +28,7 @@ CANONICAL_ONCOGENES = ['kras', 'nras', 'alk'] CANONICAL_TS = ['cdkn2a', 'tp53'] -CANONICAL_CG = ['alb'] +CANONICAL_OTHER_CG = ['alb'] CANONICAL_FUSION_GENES = ['alk', 'ewsr1', 'fli1'] CANONICAL_STRUCTURAL_VARIANT_GENES = ['brca1', 'dpyd', 'pten'] CANNONICAL_THERAPY_GENES = ['erbb2', 'brca2', 'egfr'] @@ -111,6 +112,30 @@ def conn(): return conn +@pytest.mark.skipif(EXCLUDE_ONCOKB_TESTS, reason='excluding tests that depend on oncokb data') +def test_cancer_gene_flags(conn): + # wo/ flags + result = get_cancer_gene_flags(conn) + assert [r['displayName'] for r in result] == sorted( + list({r['displayName'] for r in result}), # makes displayName unique and sorted + ) + for gene in [*CANONICAL_OTHER_CG, *CANONICAL_TS, *CANONICAL_ONCOGENES]: + assert gene in {row['name'] for row in result} + # w/ flags + result = get_cancer_gene_flags(conn, flags=True) + for gene in [*CANONICAL_OTHER_CG, *CANONICAL_TS, *CANONICAL_ONCOGENES]: + assert gene in {row['name'] for row in result['cancerGene']} + for gene in CANONICAL_TS: + assert gene in {row['name'] for row in result['tumourSuppressive']} + assert gene not in {row['name'] for row in result['oncogenic']} + for gene in CANONICAL_ONCOGENES: + assert gene in {row['name'] for row in result['oncogenic']} + assert gene not in {row['name'] for row in result['tumourSuppressive']} + for gene in [*CANONICAL_OTHER_CG]: + assert gene not in {row['name'] for row in result['oncogenic']} + assert gene not in {row['name'] for row in result['tumourSuppressive']} + + @pytest.mark.skipif(EXCLUDE_ONCOKB_TESTS, reason='excluding tests that depend on oncokb data') def test_oncogene(conn): result = get_oncokb_oncogenes(conn) @@ -119,7 +144,7 @@ def test_oncogene(conn): assert gene in names for gene in CANONICAL_TS: assert gene not in names - for gene in CANONICAL_CG: + for gene in CANONICAL_OTHER_CG: assert gene not in names @@ -131,7 +156,7 @@ def test_tumour_supressors(conn): assert gene in names for gene in CANONICAL_ONCOGENES: assert gene not in names - for gene in CANONICAL_CG: + for gene in CANONICAL_OTHER_CG: assert gene not in names @@ -142,12 +167,12 @@ def test_tumour_supressors(conn): def test_cancer_genes(conn): result = get_cancer_genes(conn) names = {row['name'] for row in result} - for gene in CANONICAL_CG: + for gene in CANONICAL_OTHER_CG: assert gene in names for gene in CANONICAL_TS: - assert gene not in names + assert gene in names for gene in CANONICAL_ONCOGENES: - assert gene not in names + assert gene in names @pytest.mark.skipif( @@ -254,7 +279,7 @@ def test_get_gene_information(conn): conn, CANONICAL_ONCOGENES + CANONICAL_TS - + CANONICAL_CG + + CANONICAL_OTHER_CG + CANONICAL_FUSION_GENES + CANONICAL_STRUCTURAL_VARIANT_GENES + CANNONICAL_THERAPY_GENES @@ -300,7 +325,7 @@ def test_get_gene_information(conn): f'Missed kbStatementRelated {gene}' ) - for gene in CANONICAL_CG: + for gene in CANONICAL_ONCOGENES + CANONICAL_TS + CANONICAL_OTHER_CG: assert gene in [g['name'] for g in gene_info if g.get('cancerGeneListMatch')], ( f'Missed cancerGeneListMatch {gene}' ) diff --git a/tests/test_graphkb/test_util.py b/tests/test_graphkb/test_util.py index 36760b2a..dbbb2c2b 100644 --- a/tests/test_graphkb/test_util.py +++ b/tests/test_graphkb/test_util.py @@ -1,5 +1,6 @@ import os import pytest +import re from pori_python.graphkb import GraphKBConnection, util @@ -149,3 +150,44 @@ def test_stringifyVariant_positional(self, conn, rid, createdAt, stringifiedVari variant = conn.get_record_by_id(rid) if variant and variant.get('createdAt', None) == createdAt: assert util.stringifyVariant(variant=variant, **opt) == stringifiedVariant + + +class TestGraphKBConnection: + def test_version(self, conn): + version = conn.version + assert version['db'] in [ + 'production', + 'production-sync-dev', + 'production-sync-staging', + ] + SEMANTIC_VERSIONING_REGEX = re.compile(r'^(0|[1-9]\d*)\.(0|[1-9]\d*)\.(0|[1-9]\d*)$') + assert SEMANTIC_VERSIONING_REGEX.match(version['api']) + assert SEMANTIC_VERSIONING_REGEX.match(version['parser']) + assert SEMANTIC_VERSIONING_REGEX.match(version['schema']) + + def test_get_related_records(self, conn): + base = util.convert_to_rid_list( + conn.query({'target': 'Vocabulary', 'filters': {'name': 'missense'}}) + ) + records = conn.get_related_records( + base=base, + ontology='Vocabulary', + subgraphType='similar', + returnProperties=['displayName'], + ) + assert 'missense mutation' in list(map(lambda x: x['displayName'], records.values())) + + def test_get_related_terms(self, conn): + # with defaults + vocab_terms = conn.get_related_terms( + terms='missense', + ) + assert 'missense mutation' in vocab_terms + + # overriding ontology & subgraphType defaults + disease_terms = conn.get_related_terms( + terms='all solid tumors', + ontology='Disease', + subgraphType='parents', + ) + assert 'cancer' in disease_terms diff --git a/tests/test_ipr/test_connection.py b/tests/test_ipr/test_connection.py index d83ac79a..e2f8d4f5 100644 --- a/tests/test_ipr/test_connection.py +++ b/tests/test_ipr/test_connection.py @@ -95,3 +95,125 @@ def request(*args, **kwargs): ) }, ) + + +class TestCheckUploadPermission: + def _user_response(self, groups=None, projects=None): + return { + 'groups': [{'name': g} for g in (groups or [])], + 'projects': [{'name': p} for p in (projects or [])], + } + + def test_rejects_user_without_create_report_access(self): + conn = IprConnection('user', 'pass') + conn.get = mock.MagicMock( + side_effect=[[{'name': 'TEST'}], self._user_response(projects=['TEST'])] + ) + conn.post = mock.MagicMock() + + with pytest.raises(Exception, match='User does not have report creation permission'): + conn.check_upload_permission('TEST') + + conn.post.assert_not_called() + + def test_rejects_user_without_project_access(self): + conn = IprConnection('user', 'pass') + conn.get = mock.MagicMock( + side_effect=[ + [{'name': 'TEST'}], + self._user_response(groups=['create report access'], projects=['OTHER']), + ] + ) + conn.post = mock.MagicMock() + + with pytest.raises( + Exception, match='User has no permission to create report in project TEST' + ): + conn.check_upload_permission('TEST') + + conn.post.assert_not_called() + + def test_allows_user_with_project_and_create_report_access(self): + conn = IprConnection('user', 'pass') + conn.get = mock.MagicMock( + side_effect=[ + [{'name': 'TEST'}], + self._user_response(groups=['create report access'], projects=['TEST']), + ] + ) + conn.post = mock.MagicMock() + + conn.check_upload_permission('TEST') + + conn.post.assert_not_called() + + def test_project_not_found_raises(self): + conn = IprConnection('user', 'pass') + conn.get = mock.MagicMock(side_effect=[[{'name': 'OTHER'}]]) + conn.post = mock.MagicMock() + + with pytest.raises(Exception, match='Project TEST does not exist'): + conn.check_upload_permission('TEST') + + conn.post.assert_not_called() + + def test_manager_without_project_membership_raises(self): + conn = IprConnection('user', 'pass') + conn.get = mock.MagicMock( + side_effect=[ + [{'name': 'TEST'}], + self._user_response(groups=['manager'], projects=[]), + ] + ) + conn.post = mock.MagicMock() + + with pytest.raises( + Exception, match='User has no permission to create report in project TEST' + ): + conn.check_upload_permission('TEST') + + conn.post.assert_not_called() + + def test_manager_with_project_membership_allowed(self): + conn = IprConnection('user', 'pass') + conn.get = mock.MagicMock( + side_effect=[ + [{'name': 'TEST'}], + self._user_response(groups=['manager'], projects=['TEST']), + ] + ) + conn.post = mock.MagicMock() + + conn.check_upload_permission('TEST') + + conn.post.assert_not_called() + + def test_admin_bypasses_all_checks(self): + conn = IprConnection('user', 'pass') + conn.get = mock.MagicMock( + side_effect=[ + [{'name': 'TEST'}], + self._user_response(groups=['admin'], projects=[]), + ] + ) + conn.post = mock.MagicMock() + + conn.check_upload_permission('TEST') + + conn.post.assert_not_called() + + def test_all_projects_access_without_project_membership(self): + conn = IprConnection('user', 'pass') + conn.get = mock.MagicMock( + side_effect=[ + [{'name': 'TEST'}], + self._user_response( + groups=['create report access', 'all projects access'], projects=[] + ), + ] + ) + conn.post = mock.MagicMock() + + conn.check_upload_permission('TEST') + + conn.post.assert_not_called() diff --git a/tests/test_ipr/test_inputs.py b/tests/test_ipr/test_inputs.py index 4bdd6b6d..f3cd6f99 100644 --- a/tests/test_ipr/test_inputs.py +++ b/tests/test_ipr/test_inputs.py @@ -17,6 +17,7 @@ check_comparators, check_variant_links, create_graphkb_sv_notation, + normalize_seqqc, preprocess_copy_variants, preprocess_cosmic, preprocess_expression_variants, @@ -558,3 +559,162 @@ def test_valid_json_inputs(example_name: str): with open(os.path.join(DATA_DIR, 'json_examples', f'{example_name}.json'), 'r') as fh: content = json.load(fh) validate_report_content(content) + + +class TestNormalizeSeqQC: + """Test seqQC field name normalization from production format to schema format.""" + + def test_normalize_seqqc_production_format(self): + """Test normalization of production report field names.""" + content = { + 'seqQC': [ + { + 'Reads': '2407M', + 'Sample': 'Tumour DNA', + 'Library': 'LIB0001', + 'Coverage': '96X', + 'Input_ng': 400, + 'Input_ug': '', + 'Protocol': 'Genome Shotgun FFPE 4.2', + 'Sample Name': 'SAMPLE-T-01', + 'bioQC': 'Passed', + 'labQC': 'Approved', + 'Duplicate_Reads_Perc': 18, + } + ] + } + + result = normalize_seqqc(content) + + assert result['seqQC'][0]['reads'] == '2407M' + assert result['seqQC'][0]['sample'] == 'Tumour DNA' + assert result['seqQC'][0]['library'] == 'LIB0001' + assert result['seqQC'][0]['coverage'] == '96X' + assert result['seqQC'][0]['inputNg'] == 400 + assert result['seqQC'][0]['inputUg'] == '' + assert result['seqQC'][0]['protocol'] == 'Genome Shotgun FFPE 4.2' + assert result['seqQC'][0]['sampleName'] == 'SAMPLE-T-01' + assert result['seqQC'][0]['bioQC'] == 'Passed' + assert result['seqQC'][0]['labQC'] == 'Approved' + assert result['seqQC'][0]['duplicateReadsPerc'] == 18 + # Old keys should be gone + assert 'Reads' not in result['seqQC'][0] + assert 'Sample' not in result['seqQC'][0] + + def test_normalize_seqqc_already_normalized(self): + """Test that already-normalized field names are preserved.""" + content = { + 'seqQC': [ + { + 'reads': '1200M', + 'sample': 'Constitutional DNA', + 'library': 'LIB0002', + 'coverage': '40x', + 'inputNg': '300', + 'protocol': 'WGS', + 'sampleName': 'SAMPLE-N-01', + 'bioQC': 'passed', + 'labQC': 'passed', + 'duplicateReadsPerc': '8.1', + } + ] + } + + result = normalize_seqqc(content) + + # All normalized keys should still exist with same values + assert result['seqQC'][0]['reads'] == '1200M' + assert result['seqQC'][0]['sample'] == 'Constitutional DNA' + assert result['seqQC'][0]['inputNg'] == '300' + + def test_normalize_seqqc_no_seqqc_field(self): + """Test that content without seqQC is unchanged.""" + content = { + 'patientId': 'TEST001', + 'project': 'TEST', + } + + result = normalize_seqqc(content) + + assert result == content + assert 'seqQC' not in result + + def test_normalize_seqqc_empty_seqqc(self): + """Test that empty seqQC array is handled.""" + content = {'seqQC': []} + + result = normalize_seqqc(content) + + assert result['seqQC'] == [] + + def test_normalize_seqqc_multiple_items(self): + """Test normalization of multiple seqQC items.""" + content = { + 'seqQC': [ + { + 'Reads': '2534M', + 'Sample': 'Tumour DNA', + 'Duplicate_Reads_Perc': 12.3, + }, + { + 'Reads': '1200M', + 'Sample': 'Constitutional DNA', + 'Duplicate_Reads_Perc': 8.1, + }, + ] + } + + result = normalize_seqqc(content) + + assert len(result['seqQC']) == 2 + assert result['seqQC'][0]['reads'] == '2534M' + assert result['seqQC'][0]['sample'] == 'Tumour DNA' + assert result['seqQC'][0]['duplicateReadsPerc'] == 12.3 + assert result['seqQC'][1]['reads'] == '1200M' + assert result['seqQC'][1]['sample'] == 'Constitutional DNA' + assert result['seqQC'][1]['duplicateReadsPerc'] == 8.1 + + def test_normalize_seqqc_numeric_fields_pass_validation(self): + """Test that integer/float values for inputNg, inputUg, duplicateReadsPerc pass schema validation.""" + content = { + 'patientId': 'PATIENT001', + 'kbDiseaseMatch': 'colorectal cancer', + 'project': 'TEST', + 'template': 'genomic', + 'seqQC': [ + { + 'reads': '2407M', + 'sample': 'Tumour DNA', + 'library': 'LIB0001', + 'inputNg': 400, + 'inputUg': 0.4, + 'duplicateReadsPerc': 18, + } + ], + } + # Should not raise + validate_report_content(content) + + def test_normalize_seqqc_numeric_float_duplicateReadsPerc_passes_validation(self): + """Test that a float duplicateReadsPerc value passes schema validation after normalization.""" + content = { + 'patientId': 'PATIENT001', + 'kbDiseaseMatch': 'colorectal cancer', + 'project': 'TEST', + 'template': 'genomic', + 'seqQC': [ + { + 'Reads': '2534M', + 'Sample': 'Tumour DNA', + 'Duplicate_Reads_Perc': 12.3, + 'Input_ng': 500, + 'Input_ug': 0.5, + } + ], + } + result = normalize_seqqc(content) + assert result['seqQC'][0]['duplicateReadsPerc'] == 12.3 + assert result['seqQC'][0]['inputNg'] == 500 + assert result['seqQC'][0]['inputUg'] == 0.5 + # Should not raise after normalization + validate_report_content(result) diff --git a/tests/test_ipr/test_ipr.py b/tests/test_ipr/test_ipr.py index 3e9b01a3..687c14b4 100644 --- a/tests/test_ipr/test_ipr.py +++ b/tests/test_ipr/test_ipr.py @@ -1,4 +1,5 @@ import pytest +import pandas as pd from unittest.mock import Mock, patch from pori_python.graphkb import statement as gkb_statement @@ -12,7 +13,11 @@ get_kb_variants, get_kb_matches_sections, create_key_alterations, + ensure_str_list, + add_transcript_flags, + get_variant_flags, ) + from pori_python.types import Statement DISEASE_RIDS = ['#138:12', '#138:13'] @@ -415,6 +420,179 @@ def test_approved_therapeutic(self, mock_get_evidencelevel_mapping, graphkb_conn assert row['category'] == 'therapeutic' +class TestFlagUtilities: + def test_ensure_str_list_accepts_string(self): + assert ensure_str_list('abc') == ['abc'] + + def test_ensure_str_list_splits_comma_separated_string(self): + assert ensure_str_list('a, b , c') == ['a', 'b', 'c'] + + def test_ensure_str_list_accepts_list_of_strings(self): + assert ensure_str_list(['a', 'b']) == ['a', 'b'] + + def test_ensure_str_list_rejects_bad_types(self): + with pytest.raises(TypeError): + ensure_str_list([1, 'a']) + with pytest.raises(TypeError): + ensure_str_list(123) + + def test_add_transcript_flags_basic_adds_flags_from_comma_separated_string(self): + variant_sources = [ + {'transcript': 'T1', 'key': 'k1', 'variantType': 'mut'}, + ] + df = pd.DataFrame({'transcript': ['T1'], 'flags': ['flag_a,flag_b']}) + result = add_transcript_flags(variant_sources, df) + assert set(result[0]['flags']) == {'flag_a', 'flag_b'} + + def test_add_transcript_flags_basic_converts_string_flag_to_list_avoiding_duplicates(self): + variant_sources = [ + {'transcript': 'T2', 'flags': 'existing', 'key': 'k2', 'variantType': 'mut'}, + ] + df = pd.DataFrame({'transcript': ['T2'], 'flags': ['existing']}) + result = add_transcript_flags(variant_sources, df) + assert result[0]['flags'] == ['existing'] + + def test_add_transcript_flags_basic_leaves_unmatched_transcripts_unaffected(self): + variant_sources = [ + {'transcript': 'T3', 'flags': ['present'], 'key': 'k3', 'variantType': 'mut'}, + ] + df = pd.DataFrame({'transcript': ['T1', 'T2'], 'flags': ['flag_a,flag_b', 'existing']}) + result = add_transcript_flags(variant_sources, df) + assert result[0]['flags'] == ['present'] + + def test_add_transcript_flags_basic_strips_whitespace_from_comma_separated_flags(self): + variant_sources = [ + {'transcript': 'T4', 'key': 'k4', 'variantType': 'mut'}, + ] + df = pd.DataFrame({'transcript': ['T4'], 'flags': ['flag_c, flag_d']}) + result = add_transcript_flags(variant_sources, df) + assert set(result[0]['flags']) == {'flag_c', 'flag_d'} + + def test_add_transcript_flags_basic_accumulates_duplicate_transcript_rows(self): + variant_sources = [ + {'transcript': 'T5', 'key': 'k5', 'variantType': 'mut'}, + ] + df = pd.DataFrame( + { + 'gene': ['ENSG1', 'ENSG2'], + 'transcript': ['T5', 'T5'], + 'flags': ['flag_a', 'flag_b, flag_c'], + } + ) + result = add_transcript_flags(variant_sources, df) + assert result[0]['flags'] == ['flag_a', 'flag_b', 'flag_c'] + + def test_add_transcript_flags_fusions_tags_cterm_flags(self): + variant_sources = [ + { + 'key': 'f1', + 'variantType': 'fusion', + 'ctermTranscript': 'CT1', + 'ntermTranscript': 'NT1', + } + ] + df = pd.DataFrame( + { + 'transcript': ['CT1'], + 'flags': ['cterm_flag'], + } + ) + result = add_transcript_flags(variant_sources, df) + flags = result[0]['flags'] + assert 'cterm_flag (cterm)' in flags + + def test_add_transcript_flags_fusions_tags_nterm_flags(self): + variant_sources = [ + { + 'key': 'f1', + 'variantType': 'fusion', + 'ctermTranscript': 'CT1', + 'ntermTranscript': 'NT1', + } + ] + df = pd.DataFrame( + { + 'transcript': ['NT1'], + 'flags': ['nterm_flag'], + } + ) + result = add_transcript_flags(variant_sources, df) + flags = result[0]['flags'] + assert 'nterm_flag (nterm)' in flags + + def test_add_transcript_flags_fusions_accumulates_duplicate_transcript_rows(self): + variant_sources = [ + { + 'key': 'f2', + 'variantType': 'fusion', + 'ctermTranscript': 'CT2', + 'ntermTranscript': 'NT2', + } + ] + df = pd.DataFrame( + { + 'gene': ['ENSG3', 'ENSG4'], + 'transcript': ['CT2', 'CT2'], + 'flags': ['cterm_flag_a', 'cterm_flag_b'], + } + ) + result = add_transcript_flags(variant_sources, df) + assert result[0]['flags'] == ['cterm_flag_a (cterm)', 'cterm_flag_b (cterm)'] + + def test_get_variant_flags_converts_string_flags_to_records(self): + variants = [ + {'key': 'k1', 'variantType': 'mut', 'flags': 'foo'}, + ] + out = get_variant_flags(variants) + assert any(item['variant'] == 'k1' and item['flags'] == ['foo'] for item in out) + assert len(out) == 1 + + def test_get_variant_flags_deduplicates_and_removes_empty_strings(self): + variants = [ + {'key': 'k2', 'variantType': 'mut', 'flags': ['bar', 'bar', '']}, + ] + out = get_variant_flags(variants) + assert any(item['variant'] == 'k2' and set(item['flags']) == {'bar'} for item in out) + + def test_get_variant_flags_preserves_input_flag_order_when_deduplicating(self): + variants = [ + {'key': 'k5', 'variantType': 'mut', 'flags': ['flag_b', 'flag_a', 'flag_b', 'flag_c']}, + ] + out = get_variant_flags(variants) + assert out == [ + { + 'variant': 'k5', + 'variantType': 'mut', + 'flags': ['flag_b', 'flag_a', 'flag_c'], + } + ] + + def test_get_variant_flags_skips_null_flags(self): + variants = [ + {'key': 'k3', 'variantType': 'mut', 'flags': None}, + ] + out = get_variant_flags(variants) + assert not any(item['variant'] == 'k3' for item in out) + assert len(out) == 0 + + def test_get_variant_flags_skips_empty_list_flags(self): + variants = [ + {'key': 'k4', 'variantType': 'mut', 'flags': []}, + ] + out = get_variant_flags(variants) + assert not any(item['variant'] == 'k4' for item in out) + assert len(out) == 0 + + def test_get_variant_flags_removes_flags_key_from_processed_records(self): + variants = [ + {'key': 'k1', 'variantType': 'mut', 'flags': 'foo'}, + {'key': 'k2', 'variantType': 'mut', 'flags': ['bar', 'bar', '']}, + ] + get_variant_flags(variants) + assert 'flags' not in variants[0] + assert 'flags' not in variants[1] + + class TestKbmatchFilters: def test_germline_kb_matches(self): assert len(germline_kb_matches(GERMLINE_KB_MATCHES, GERMLINE_VARIANTS)) == len( diff --git a/tests/test_ipr/test_main.py b/tests/test_ipr/test_main.py index 8fe585cd..3dbd0aa5 100644 --- a/tests/test_ipr/test_main.py +++ b/tests/test_ipr/test_main.py @@ -7,7 +7,7 @@ from unittest.mock import MagicMock, patch from pori_python.ipr.connection import IprConnection -from pori_python.ipr.main import command_interface +from pori_python.ipr.main import command_interface, load_transcript_flags from pori_python.types import IprGene from .constants import EXCLUDE_INTEGRATION_TESTS @@ -28,6 +28,30 @@ def get_test_file(name: str) -> str: return os.path.join(os.path.dirname(__file__), 'test_data', name) +class TestLoadTranscriptFlags: + def test_accepts_file_without_header(self, tmp_path) -> None: + transcript_flags_file = tmp_path / 'transcript_flags.tsv' + transcript_flags_file.write_text('ENST1\tflag_a\nENST2\tflag_b, flag_c\n') + + result = load_transcript_flags(str(transcript_flags_file)) + + assert result.to_dict(orient='records') == [ + {'transcript': 'ENST1', 'flags': 'flag_a'}, + {'transcript': 'ENST2', 'flags': 'flag_b, flag_c'}, + ] + + def test_accepts_file_with_header(self, tmp_path) -> None: + transcript_flags_file = tmp_path / 'transcript_flags.tsv' + transcript_flags_file.write_text('transcript\tflags\nENST1\tflag_a\nENST2\tflag_b\n') + + result = load_transcript_flags(str(transcript_flags_file)) + + assert result.to_dict(orient='records') == [ + {'transcript': 'ENST1', 'flags': 'flag_a'}, + {'transcript': 'ENST2', 'flags': 'flag_b'}, + ] + + @pytest.fixture(scope='module') def report_upload_content(tmp_path_factory) -> Dict: mock = MagicMock() @@ -86,6 +110,11 @@ def side_effect_function(*args, **kwargs): return [{'name': 'genomic', 'ident': '001'}] elif args[0] == 'project': return [{'name': 'TEST', 'ident': '001'}] + elif args[0] == 'user/me': + return { + 'groups': [{'name': 'admin'}], + 'projects': [{'name': 'TEST'}], + } else: return [] diff --git a/tests/test_ipr/test_probe.py b/tests/test_ipr/test_probe.py index 43ead9f1..ec93599c 100644 --- a/tests/test_ipr/test_probe.py +++ b/tests/test_ipr/test_probe.py @@ -25,6 +25,11 @@ def side_effect_function(*args, **kwargs): return [{'name': 'genomic', 'ident': '001'}] elif args[0] == 'project': return [{'name': 'TEST', 'ident': '001'}] + elif args[0] == 'user/me': + return { + 'groups': [{'name': 'admin'}], + 'projects': [{'name': 'TEST'}], + } else: return [] diff --git a/tests/test_ipr/test_upload.py b/tests/test_ipr/test_upload.py index 2c6fb73c..6e7f9b05 100644 --- a/tests/test_ipr/test_upload.py +++ b/tests/test_ipr/test_upload.py @@ -19,7 +19,7 @@ DELETE_UPLOAD_TEST_REPORTS = os.environ.get('DELETE_UPLOAD_TEST_REPORTS', '1') == '1' -def get_test_spec(): +def get_test_spec() -> dict: ipr_spec = {'components': {'schemas': {'genesCreate': {'properties': {}}}}} ipr_gene_keys = IprGene.__required_keys__ | IprGene.__optional_keys__ for key in ipr_gene_keys: @@ -31,12 +31,35 @@ def get_test_file(name: str) -> str: return os.path.join(os.path.dirname(__file__), 'test_data', name) +def get_test_transcript_flags(json_contents) -> pd.DataFrame: + """creates a dataframe of transcript flags for test purposes, based on the input json contents""" + transcript_flags = [] + for item in json_contents['structuralVariants']: + transcript_flags.append((item['ntermTranscript'], 'TRANSCRIPT FLAG')) + transcript_flags.append((item['ctermTranscript'], 'TRANSCRIPT FLAG')) + for item in json_contents['smallMutations']: + transcript_flags.append((item['transcript'], 'TRANSCRIPT FLAG')) + df = pd.DataFrame(transcript_flags, columns=['transcript', 'flags']) + df = df.drop_duplicates() + return df + + +def add_test_variant_flags_to_input_data(json_contents) -> dict: + """adds flags to the input variants for test purposes""" + for vtype in ['structuralVariants', 'smallMutations', 'copyVariants', 'expressionVariants']: + for item in json_contents[vtype]: + item['flags'] = ['TEST FLAG'] + return json_contents + + @pytest.fixture(scope='module') def loaded_reports(tmp_path_factory) -> Generator: json_file = tmp_path_factory.mktemp('inputs') / 'content.json' async_json_file = tmp_path_factory.mktemp('inputs') / 'async_content.json' + transcript_flags_file = tmp_path_factory.mktemp('inputs') / 'transcript_flags.tsv' patient_id = f'TEST_{str(uuid.uuid4())}' async_patient_id = f'TEST_ASYNC_{str(uuid.uuid4())}' + json_contents = { 'comparators': [ {'analysisRole': 'expression (disease)', 'name': '1'}, @@ -106,9 +129,40 @@ def loaded_reports(tmp_path_factory) -> Generator: 'caption': 'Test adding a caption to an image', } ], + 'seqQC': [ + { + 'sample': 'Tumour DNA', + 'reads': '2534M', + 'library': 'LIB0001', + 'coverage': '80x', + 'inputNg': '500', + 'protocol': 'WGS', + 'sampleName': 'SAMPLE2-FF-1', + 'bioQC': 'passed', + 'labQC': 'passed', + 'duplicateReadsPerc': '12.3', + }, + { + 'sample': 'Constitutional DNA', + 'reads': '1200M', + 'library': 'LIB0002', + 'coverage': '40x', + 'inputNg': '300', + 'protocol': 'WGS', + 'sampleName': 'SAMPLE1-PB', + 'bioQC': 'passed', + 'labQC': 'passed', + 'duplicateReadsPerc': '8.1', + }, + ], 'config': 'test config', } + json_contents = add_test_variant_flags_to_input_data(json_contents) + + transcript_flags_df = get_test_transcript_flags(json_contents) + transcript_flags_df.to_csv(transcript_flags_file, sep='\t', index=False) + json_file.write_text( json.dumps( json_contents, @@ -140,6 +194,8 @@ def loaded_reports(tmp_path_factory) -> Generator: os.environ.get('GRAPHKB_URL', False), '--therapeutics', '--allow_partial_matches', + '--transcript_flags', + str(transcript_flags_file), ] sync_argslist = argslist.copy() @@ -192,7 +248,7 @@ def stringify_sorted(obj): obj.sort() return str(obj) elif isinstance(obj, dict): - for key in ('ident', 'updatedAt', 'createdAt', 'deletedAt'): + for key in ('ident', 'updatedAt', 'createdAt', 'deletedAt', 'variantId', 'id', 'reportId'): obj.pop(key, None) keys = obj.keys() for key in keys: @@ -330,6 +386,18 @@ def test_analyst_comments_loaded(self, loaded_reports) -> None: assert async_section['comments'] assert sync_section['comments'] == async_section['comments'] + def test_seqqc_loaded(self, loaded_reports) -> None: + """Test that seqQC data is present in the loaded report.""" + sync_report = loaded_reports['sync'][1]['reports'][0] + assert 'seqQC' in sync_report + assert len(sync_report['seqQC']) == 2 + samples = [item['sample'] for item in sync_report['seqQC']] + assert 'Tumour DNA' in samples + assert 'Constitutional DNA' in samples + async_report = loaded_reports['async'][1]['reports'][0] + assert 'seqQC' in async_report + assert len(async_report['seqQC']) == 2 + def test_sample_info_loaded(self, loaded_reports) -> None: sync_section = get_section(loaded_reports['sync'], 'sample-info') async_section = get_section(loaded_reports['async'], 'sample-info') diff --git a/tests/test_ipr/test_util.py b/tests/test_ipr/test_util.py index bbae6d98..9208f51e 100644 --- a/tests/test_ipr/test_util.py +++ b/tests/test_ipr/test_util.py @@ -17,13 +17,32 @@ def test_trim_empty_values(input, output_keys): [ [ {'variantType': 'exp', 'gene': 'GENE', 'expressionState': 'increased expression'}, - 'increased expression', + ('GENE', 'increased expression'), + ], + [ + {'variantType': 'cnv', 'gene': 'GENE', 'cnvState': 'amplification'}, + ('GENE', 'amplification'), + ], + [ + {'variantType': 'other', 'gene2': 'GENE', 'variant': 'GENE:anything'}, + ('GENE', 'anything'), + ], + [ + {'variantType': 'sigv', 'displayName': 'test signature signature present'}, + ('test signature signature present', ''), + ], + [ + { + 'variantType': 'sigv', + 'displayName': 'test signature signature present', + 'signatureName': 'test signature', + 'variantTypeName': 'signature present', + }, + ('test signature', 'signature present'), ], - [{'variantType': 'cnv', 'gene': 'GENE', 'cnvState': 'amplification'}, 'amplification'], - [{'variantType': 'other', 'gene2': 'GENE', 'variant': 'GENE:anything'}, 'anything'], ], ) def test_create_variant_name_tuple(variant, result): gene, name = create_variant_name_tuple(variant) - assert name == result - assert gene == 'GENE' + assert gene == result[0] + assert name == result[1]