diff --git a/medcat/datasets/transformers_ner.py b/medcat/datasets/transformers_ner.py index 7d8362d45..c0db635ba 100644 --- a/medcat/datasets/transformers_ner.py +++ b/medcat/datasets/transformers_ner.py @@ -64,7 +64,7 @@ def _info(self): def _split_generators(self, dl_manager): # noqa """Returns SplitGenerators.""" # noqa - return [ + splits = [ datasets.SplitGenerator( name=datasets.Split.TRAIN, gen_kwargs={ @@ -73,6 +73,19 @@ def _split_generators(self, dl_manager): # noqa ), ] + # Only add test split if test data files are provided + if 'test' in self.config.data_files: + splits.append( + datasets.SplitGenerator( + name=datasets.Split.TEST, + gen_kwargs={ + "filepaths": self.config.data_files['test'], + }, + ) + ) + + return splits + def _generate_examples(self, filepaths): # noqa cnt = 0 for filepath in filepaths: diff --git a/medcat/ner/transformers_ner.py b/medcat/ner/transformers_ner.py index 715771027..a6a5c57a0 100644 --- a/medcat/ner/transformers_ner.py +++ b/medcat/ner/transformers_ner.py @@ -177,6 +177,8 @@ def train(self, ignore_extra_labels=False, dataset=None, meta_requirements=None, + train_json_path: Union[str, list, None]=None, + test_json_path: Union[str, list, None]=None, trainer_callbacks: Optional[List[Callable[[Trainer], TrainerCallback]]] = None) -> Tuple: """Train or continue training a model give a json_path containing a MedCATtrainer export. It will continue training if an existing model is loaded or start new training if the model is blank/new. @@ -187,8 +189,10 @@ def train(self, ignore_extra_labels: Makes only sense when an existing deid model was loaded and from the new data we want to ignore labels that did not exist in the old model. - dataset: Defaults to None. + dataset: Defaults to None. Will be split by self.config.general['test_size'] into train and test datasets. meta_requirements: Defaults to None + train_json_path (str): Defaults to None. 
If provided, will be used as the training dataset json_path to load from + test_json_path (str): Defaults to None. If provided, will be used as the test dataset json_path to load from trainer_callbacks (List[Callable[[Trainer], TrainerCallback]]]): A list of trainer callbacks for collecting metrics during the training at the client side. The transformers Trainer object will be passed in when each callback is called. @@ -200,11 +204,16 @@ def train(self, Tuple: The dataframe, examples, and the dataset """ - if dataset is None and json_path is not None: + if dataset is None: # Load the medcattrainer export - json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels, + if json_path is not None: + json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels, meta_requirements=meta_requirements, file_name='data_eval.json') - # Load dataset + elif test_json_path is not None and train_json_path is not None: + train_json_path = self._prepare_dataset(train_json_path, ignore_extra_labels=ignore_extra_labels, + meta_requirements=meta_requirements, file_name='data_train.json') + test_json_path = self._prepare_dataset(test_json_path, ignore_extra_labels=ignore_extra_labels, + meta_requirements=meta_requirements, file_name='data_test.json') # NOTE: The following is for backwards comppatibility # in datasets==2.20.0 `trust_remote_code=True` must be explicitly @@ -216,13 +225,21 @@ def train(self, ds_load_dataset = partial(datasets.load_dataset, trust_remote_code=True) else: ds_load_dataset = datasets.load_dataset - dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__), - data_files={'train': json_path}, # type: ignore - split='train', - cache_dir='/tmp/') - # We split before encoding so the split is document level, as encoding - #does the document splitting into max_seq_len - dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore + + if json_path: + dataset = 
ds_load_dataset(os.path.abspath(transformers_ner.__file__), + data_files={'train': json_path}, # type: ignore + split='train', + cache_dir='/tmp/') + # We split before encoding so the split is document level, as encoding + # does the document splitting into max_seq_len + dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore + elif train_json_path and test_json_path: + dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__), + data_files={'train': train_json_path, 'test': test_json_path}, # type: ignore + cache_dir='/tmp/') + else: + raise ValueError("Either json_path or train_json_path and test_json_path must be provided when no dataset is provided") # Update labelmap in case the current dataset has more labels than what we had before self.tokenizer.calculate_label_map(dataset['train']) @@ -231,8 +248,8 @@ def train(self, if self.model.num_labels != len(self.tokenizer.label_map): logger.warning("The dataset contains labels we've not seen before, model is being reinitialized") logger.warning("Model: {} vs Dataset: {}".format(self.model.num_labels, len(self.tokenizer.label_map))) - self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'], - num_labels=len(self.tokenizer.label_map), + self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'], + num_labels=len(self.tokenizer.label_map), ignore_mismatched_sizes=True) self.tokenizer.cui2name = {k:self.cdb.get_name(k) for k in self.tokenizer.label_map.keys()} @@ -273,7 +290,6 @@ def train(self, # NOTE: this shouldn't really happen, but we'll do this for type safety raise ValueError("Output path should not be None!") self.save(save_dir_path=os.path.join(output_dir, 'final_model')) - # Run an eval step and return metrics p = trainer.predict(encoded_dataset['test']) # type: ignore df, examples = metrics(p, return_df=True, tokenizer=self.tokenizer, dataset=encoded_dataset['test']) diff --git 
a/medcat/utils/ner/deid.py b/medcat/utils/ner/deid.py index 522924503..814632eef 100644 --- a/medcat/utils/ner/deid.py +++ b/medcat/utils/ner/deid.py @@ -34,6 +34,7 @@ - config - cdb """ +import re from typing import Union, Tuple, Any, List, Iterable, Optional, Dict import logging @@ -62,9 +63,11 @@ class DeIdModel(NerModel): def __init__(self, cat: CAT) -> None: self.cat = cat - def train(self, json_path: Union[str, list, None], + def train(self, json_path: Union[str, list, None] = None, *args, **kwargs) -> Tuple[Any, Any, Any]: - return super().train(json_path, *args, train_nr=0, **kwargs) # type: ignore + assert not all([json_path, kwargs.get('train_json_path'), kwargs.get('test_json_path')]), \ + "Either json_path or train_json_path and test_json_path must be provided when no dataset is provided" + return super().train(json_path=json_path, *args, **kwargs) # type: ignore def eval(self, json_path: Union[str, list, None], *args, **kwargs) -> Tuple[Any, Any, Any]: @@ -146,7 +149,8 @@ def deid_multi_texts(self, return out @classmethod - def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> 'DeIdModel': + def load_model_pack(cls, model_pack_path: str, + config: Optional[Dict] = None) -> 'DeIdModel': """Load DeId model from model pack. The method first loads the CAT instance. @@ -164,7 +168,7 @@ def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> Returns: DeIdModel: The resulting DeI model. 
""" - ner_model = NerModel.load_model_pack(model_pack_path,config=config) + ner_model = NerModel.load_model_pack(model_pack_path, config=config) cat = ner_model.cat if not cls._is_deid_model(cat): raise ValueError( @@ -180,7 +184,135 @@ def _is_deid_model(cls, cat: CAT) -> bool: @classmethod def _get_reason_not_deid(cls, cat: CAT) -> str: if cat.vocab is not None: - return "Has vocab" + return "Has vocab" if len(cat._addl_ner) != 1: return f"Incorrect number of addl_ner: {len(cat._addl_ner)}" return "" + + +def match_rules(rules: List[Tuple[str, str]], texts: List[str], cui2preferred_name: Dict[str, str]) -> List[List[Dict]]: + """Match a set of rules - pat / cui combos as post processing labels. + + Uses a cat DeID model for pretty name mapping. + + Args: + rules (List[Tuple[str, str]]): List of tuples of pattern and cui + texts (List[str]): List of texts to match rules on + cui2preferred_name (Dict[str, str]): Dictionary of CUI to preferred name, likely to be cat.cdb.cui2preferred_name. + + Examples: + >>> cat = CAT.load_model_pack(model_pack_path) + ... 
+ >>> rules = [ + (re.escape('(123) 456-7890'), '134'), + ('1234567890', '134'), + (re.escape('123.456.7890'), '134'), + ('1234567890', '134'), + ('1234567890', '134'), + ] + >>> texts = [ + 'My phone number is (123) 456-7890', + 'My phone number is 1234567890', + 'My phone number is 123.456.7890', + 'My phone number is 1234567890', + ] + >>> matches = match_rules(rules, texts, cat.cdb.cui2preferred_name) + + Returns: + List[List[Dict]]: List of lists of predictions from `match_rules` + """ + # Iterate through each text and pattern combination + rule_matches_per_text = [] + for i, text in enumerate(texts): + matches_in_text = [] + for pattern, concept in rules: + # Find all matches of current pattern in current text + text_matches = re.finditer(pattern, text, flags=re.M) + # Add each match with its pattern and text info + for match in text_matches: + matches_in_text.append({ + 'source_value': match.group(), + 'pretty_name': cui2preferred_name[concept], + 'start': match.start(), + 'end': match.end(), + 'cui': concept, + 'acc': 1.0 + }) + rule_matches_per_text.append(matches_in_text) + return rule_matches_per_text + + +def merge_all_preds(model_preds_by_text: List[List[Dict]], + rule_matches_per_text: List[List[Dict]], + accept_preds: bool = True) -> List[List[Dict]]: + """Convenience method to merge predictions from rule based and deID model predictions. + + Args: + model_preds_by_text (List[Dict]): list of predictions from + `cat.get_entities()`, then `[list(m['entities'].values()) for m in model_preds]` + rule_matches_per_text (List[Dict]): list of predictions from output of + running `match_rules` + accept_preds (bool): uses the predicted label from the model, + model_preds_by_text, over the rule matches if they overlap. + Defaults to using model preds over rules. 
+ + Returns: + List[List[Dict]]: List of lists of predictions from `merge_all_preds` + """ + assert len(model_preds_by_text) == len(rule_matches_per_text), \ + "model_preds_by_text and rule_matches_per_text must have the same length as they should be CAT.get_entities and match_rules outputs of the same text" + return [merge_preds(model_preds_by_text[i], rule_matches_per_text[i], accept_preds) for i in range(len(model_preds_by_text))] + + +def merge_preds(model_preds: List[Dict], + rule_matches: List[Dict], + accept_preds: bool = True) -> List[Dict]: + """Merge predictions from rule based and deID model predictions. + + Args: + model_preds (List[Dict]): predictions from `cat.get_entities()` + rule_matches (List[Dict]): predictions from output of running `match_rules` on a text + accept_preds (bool): uses the predicted label from the model, + model_preds, over the rule matches if they overlap. + Defaults to using model preds over rules. + + Examples: + >>> # a list of predictions from `cat.get_entities()` + >>> model_preds = [ + [ + {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0, + 'pretty_name': 'Phone Number'}, + {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0, + 'pretty_name': 'Phone Number'} + ] + ] + >>> # a list of predictions from `match_rules` + >>> rule_matches = [ + [ + {'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0, + 'pretty_name': 'Phone Number'}, + {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0, + 'pretty_name': 'Phone Number'} + ] + ] + >>> merged_preds = merge_preds(model_preds[0], rule_matches[0]) + + Returns: + List[Dict]: List of predictions from `merge_preds` + """ + if accept_preds: + labels1 = model_preds + labels2 = rule_matches + else: + labels1 = rule_matches + labels2 = model_preds + + # Keep only the lower-priority spans (labels2) that do not overlap any higher-priority span (labels1) + labels2 = [span2 for span2 in labels2 + if not any(not (span2['end'] <= span1['start'] or span1['end'] <= span2['start']) + for span1 in labels1)] + # merge preds and sort on start + merged_preds = labels1 + 
labels2 + merged_preds.sort(key=lambda x: x['start']) + merged_preds + return merged_preds diff --git a/medcat/utils/ner/metrics.py b/medcat/utils/ner/metrics.py index 78175cec8..862748380 100644 --- a/medcat/utils/ner/metrics.py +++ b/medcat/utils/ner/metrics.py @@ -1,3 +1,4 @@ +from typing import Dict, List from sklearn.metrics import classification_report import numpy as np import pandas as pd @@ -11,7 +12,25 @@ def metrics(p, return_df=False, plus_recall=0, tokenizer=None, dataset=None, merged_negative={0, 1, -100}, padding_label=-100, csize=15, subword_label=1, verbose=False): - """TODO: This could be done better, for sure. But it works.""" # noqa + """ + Calculate metrics for a model's predictions, based off the tokenized output of a MedCATTrainer project. + + Args: + p: The model's predictions. + return_df: Whether to return a DataFrame of metrics. + plus_recall: The recall to add to the model's predictions. + tokenizer: The tokenizer used to tokenize the texts. + dataset: The dataset used to train the model. + merged_negative: The negative labels to merge. + padding_label: The padding label. + csize: The size of the context window. + subword_label: The subword label. + verbose: Whether to print the metrics. + + Returns: + Dict: A dictionary of metrics. 
+ """ + predictions = np.array(p.predictions) predictions = softmax(predictions, axis=2) examples = None @@ -104,7 +123,7 @@ def metrics(p, return_df=False, plus_recall=0, tokenizer=None, dataset=None, mer for key in _cr: cui = ilabel_map[key] p_merged = tp_all / (tp_all + fp_all) if (tp_all + fp_all) > 0 else 0 - data.append([cui, tokenizer.cui2name.get(cui, cui), _cr[key]['precision'], + data.append([cui, tokenizer.cui2name.get(cui, cui), _cr[key]['precision'], _cr[key]['recall'], _cr[key]['f1-score'], _cr[key]['support'], _cr[key]['r_merged'], p_merged]) df = pd.DataFrame(data[1:], columns=data[0]) @@ -117,3 +136,87 @@ def metrics(p, return_df=False, plus_recall=0, tokenizer=None, dataset=None, mer 'precison_merged': np.average([x for x in df.p_merged.values if pd.notna(x)])} else: return df, examples + + +def _anno_within_pred_list(label: Dict, preds: List[Dict]) -> bool: + """ + Check if a label is within a list of predictions, + + Args: + label (Dict): an annotation likely from a MedCATTrainer project + preds (List[Dict]): a list of predictions likely from a cat.__call__ + + Returns: + bool: True if the label is within the list of predictions, False otherwise + """ + return any(label['start'] >= p['start'] and label['end'] <= p['end'] for p in preds) + + +def evaluate_predictions(true_annotations: List[List[Dict]], all_preds: List[List[Dict]], texts: List[str], cui2preferred_name: Dict[str, str]): + """ + Evaluate predictions against sets of collected labels as collected and output from a MedCATTrainer project. + Counts predictions as correct if the prediction fully encloses the label. + + Args: + true_annotations (List[List[Dict]]): Ground truth predictions by text + all_preds (List[List[Dict]]): Model predictions by text + texts (List[str]): Original list of texts + cui2preferred_name (Dict[str, str]): Dictionary of CUI to preferred name, likely to be cat.cdb.cui2preferred_name. 
+ + Returns: + Tuple[pd.DataFrame, Dict]: A tuple containing a DataFrame of evaluation metrics and a dictionary of missed annotations per CUI. + """ + per_cui_recall = {} + per_cui_prec = {} + per_cui_recall_merged = {} + per_cui_anno_counts = {} + per_cui_annos_missed = defaultdict(list) + uniq_labels = set([p['cui'] for ap in true_annotations for p in ap]) + + for cui in uniq_labels: + # annos in test set + anno_count = sum([len([p for p in cui_annos if p['cui'] == cui]) for cui_annos in true_annotations]) + pred_counts = sum([len([p for p in d if p['cui'] == cui]) for d in all_preds]) + + # print(anno_count) + # print(pred_counts) + + # print(f'pred_count: {pred_counts}, anno_count:{anno_count}') + per_cui_anno_counts[cui] = anno_count + + doc_annos_left, preds_left, doc_annos_left_any_cui = [], [], [] + + for doc_preds, doc_labels, text in zip(all_preds, true_annotations, texts): + # num of annos that are not found - recall + cui_labels = [label for label in doc_labels if label['cui'] == cui] + cui_doc_preds = [pred for pred in doc_preds if pred['cui'] == cui] + + labels_not_found = [label for label in cui_labels if not _anno_within_pred_list(label, cui_doc_preds)] + doc_annos_left.append(len(labels_not_found)) + + # num of annos that are not found across any cui prediction - recall_merged + any_labels_not_found = [label for label in cui_labels if not _anno_within_pred_list(label, doc_preds)] + doc_annos_left_any_cui.append(len(any_labels_not_found)) + + per_cui_annos_missed[cui].append(any_labels_not_found) + + # num of preds that are incorrect - precision + preds_left.append(len([label for label in cui_doc_preds if not _anno_within_pred_list(label, cui_labels)])) + + if anno_count != 0 and pred_counts != 0: + per_cui_recall[cui] = (anno_count - sum(doc_annos_left)) / anno_count + per_cui_recall_merged[cui] = (anno_count - sum(doc_annos_left_any_cui)) / anno_count + per_cui_prec[cui] = (pred_counts - sum(preds_left)) / pred_counts + else: + per_cui_recall[cui] 
= 0 + per_cui_recall_merged[cui] = 0 + per_cui_prec[cui] = 0 + + res_df = pd.DataFrame({ + 'cui': per_cui_recall_merged.keys(), + 'recall_merged': per_cui_recall_merged.values(), + 'recall': per_cui_recall.values(), + 'precision': per_cui_prec.values(), + 'label_count': per_cui_anno_counts.values()}, index=[cui2preferred_name[k] for k in per_cui_recall_merged]) + + return res_df, per_cui_annos_missed diff --git a/tests/ner/test_transformers_ner.py b/tests/ner/test_transformers_ner.py index 14579711c..9e50ccd8c 100644 --- a/tests/ner/test_transformers_ner.py +++ b/tests/ner/test_transformers_ner.py @@ -34,9 +34,11 @@ def test_pipe(self): def test_train(self): tracker = unittest.mock.Mock() + class _DummyCallback(TrainerCallback): def __init__(self, trainer) -> None: self._trainer = trainer + def on_epoch_end(self, *args, **kwargs) -> None: tracker.call() @@ -49,13 +51,32 @@ def on_epoch_end(self, *args, **kwargs) -> None: assert dataset["test"].num_rows == 12 self.assertEqual(tracker.call.call_count, 2) + def test_train_with_test_file(self): + tracker = unittest.mock.Mock() + + class _DummyCallback(TrainerCallback): + def __init__(self, trainer) -> None: + self._trainer = trainer + + def on_epoch_end(self, *args, **kwargs) -> None: + tracker.call() + + train_data = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "resources", "deid_train_data.json") + test_data = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "resources", "deid_test_data.json") + self.undertest.training_arguments.num_train_epochs = 1 + df, examples, dataset = self.undertest.train(train_json_path=train_data, test_json_path=test_data, trainer_callbacks=[_DummyCallback]) + assert "fp" in examples + assert "fn" in examples + assert dataset["train"].num_rows == 60 + self.assertEqual(tracker.call.call_count, 1) + def test_expand_model_with_concepts(self): original_num_labels = self.undertest.model.num_labels - original_out_features = 
self.undertest.model.classifier.out_features + original_out_features = self.undertest.model.classifier.out_features original_label_map_size = len(self.undertest.tokenizer.label_map) cui2preferred_name = { - "concept_1" : "Preferred Name 1", - "concept_2" : "Preferred Name 2", + "concept_1": "Preferred Name 1", + "concept_2": "Preferred Name 2", } self.undertest.expand_model_with_concepts(cui2preferred_name) diff --git a/tests/utils/ner/test_deid.py b/tests/utils/ner/test_deid.py index 9eda6c973..4aaf2bc08 100644 --- a/tests/utils/ner/test_deid.py +++ b/tests/utils/ner/test_deid.py @@ -41,6 +41,7 @@ def test_can_create_model(self): deid_model = deid.DeIdModel.create(ner) self.assertIsNotNone(deid_model) + def _add_model(cls): cdb = make_or_update_cdb(TRAIN_DATA) config = transformers_ner.ConfigTransformersNER() @@ -98,13 +99,14 @@ def test_add_new_concepts(self): self.assertTrue("CONCEPT" in self.deid_model.cat._addl_ner[0].tokenizer.label_map) self.assertTrue("CONCEPT" in self.deid_model.cat._addl_ner[0].tokenizer.cui2name) + input_text = ''' -James Joyce -7 Eccles Street, +James Joyce +7 Eccles Street, Dublin CC: Memory difficulty. -HX: Mr James is a 64 y/o RHM, had difficulty remembering names, phone numbers and events for 12 months prior to presentation, on 2/28/95. He had visited London recently and had had no professional or social faux pas or mishaps due to his memory. J.J. could not tell whether his problem was becoming worse, so he brought himself to the Neurology clinic on his own referral. +HX: Mr James is a 64 y/o RHM, had difficulty remembering names, phone numbers and events for 12 months prior to presentation, on 2/28/95. He had visited London recently and had had no professional or social faux pas or mishaps due to his memory. J.J. could not tell whether his problem was becoming worse, so he brought himself to the Neurology clinic on his own referral. 
FHX: Both parents (Mary and John) experienced memory problems in their ninth decades, but not earlier. 5 siblings have had no memory trouble. There are no neurological illnesses in his family. @@ -147,6 +149,7 @@ def test_model_works_deid_text_redact(self): # self.assertNotIn("Dublin", anon_text) self.assertNotIn("7 Eccles Street", anon_text) + class DeIDModelMultiprocessingWorks(unittest.TestCase): processes = 2 @@ -198,3 +201,170 @@ def test_model_can_multiprocess_redact(self): for tid, new_text in enumerate(processed): with self.subTest(str(tid)): self.assertTextHasBeenDeIded(new_text, redacted=True) + + +class MatchRulesTests(unittest.TestCase): + def test_match_rules(self): + # Test data from the docstring example + rules = [ + (r'\(\d{3}\)\s*\d{3}-\d{4}', '134'), # (123) 456-7890 + (r'\d{3}\.\d{3}\.\d{4}', '134'), # 123.456.7890 + (r'\d{10}', '134'), # 1234567890 + ] + texts = [ + 'My phone number is (123) 456-7890', + 'My phone number is 1234567890', + 'My phone number is 123.456.7890', + ] + cui2preferred_name = {'134': 'Phone Number'} + + # Get matches + matches = deid.match_rules(rules, texts, cui2preferred_name) + + # Verify results + self.assertEqual(len(matches), len(texts)) # One list of matches per text + + # Check first text matches + self.assertEqual(len(matches[0]), 1) # One match in first text + self.assertEqual(matches[0][0]['source_value'], '(123) 456-7890') + self.assertEqual(matches[0][0]['pretty_name'], 'Phone Number') + self.assertEqual(matches[0][0]['cui'], '134') + self.assertEqual(matches[0][0]['acc'], 1.0) + self.assertEqual(matches[0][0]['start'], 19) # Position of phone number in text + self.assertEqual(matches[0][0]['end'], 33) # End position of phone number + + # Check second text matches + self.assertEqual(len(matches[1]), 1) # One match in second text + self.assertEqual(matches[1][0]['source_value'], '1234567890') + self.assertEqual(matches[1][0]['pretty_name'], 'Phone Number') + self.assertEqual(matches[1][0]['cui'], '134') + 
self.assertEqual(matches[1][0]['acc'], 1.0) + self.assertEqual(matches[1][0]['start'], 19) # Position of phone number in text + self.assertEqual(matches[1][0]['end'], 29) # End position of phone number + + # Check third text matches + self.assertEqual(len(matches[2]), 1) # One match in third text + self.assertEqual(matches[2][0]['source_value'], '123.456.7890') + self.assertEqual(matches[2][0]['pretty_name'], 'Phone Number') + self.assertEqual(matches[2][0]['cui'], '134') + self.assertEqual(matches[2][0]['acc'], 1.0) + self.assertEqual(matches[2][0]['start'], 19) # Position of phone number in text + self.assertEqual(matches[2][0]['end'], 31) # End position of phone number + + def test_merge_preds(self): + # Test data with overlapping predictions + model_preds = [ + {'cui': '134', 'start': 10, 'end': 20, 'acc': 0.9, + 'pretty_name': 'Phone Number'}, + {'cui': '134', 'start': 25, 'end': 35, 'acc': 0.8, + 'pretty_name': 'Phone Number'}, + {'cui': '134', 'start': 50, 'end': 60, 'acc': 0.9, # Non-overlapping model pred + 'pretty_name': 'Phone Number'} + ] + rule_matches = [ + {'cui': '134', 'start': 15, 'end': 25, 'acc': 1.0, # Overlaps with first model pred + 'pretty_name': 'Phone Number'}, + {'cui': '134', 'start': 30, 'end': 40, 'acc': 1.0, # Overlaps with second model pred + 'pretty_name': 'Phone Number'}, + {'cui': '134', 'start': 70, 'end': 80, 'acc': 1.0, # Non-overlapping rule match + 'pretty_name': 'Phone Number'} + ] + + # Test with accept_preds=True (default) + merged_preds = deid.merge_preds(model_preds, rule_matches) + self.assertEqual(len(merged_preds), 4) # Should return a list with 4 elements + self.assertEqual(merged_preds[0]['start'], 10) # First model pred + self.assertEqual(merged_preds[1]['start'], 25) # Second model pred + self.assertEqual(merged_preds[2]['start'], 50) # Third model pred + self.assertEqual(merged_preds[3]['start'], 70) # Fourth rule match + + # Test with accept_preds=False + merged_preds = deid.merge_preds(model_preds, 
rule_matches, accept_preds=False) + self.assertEqual(len(merged_preds), 4) # Should return a list with 4 elements + self.assertEqual(merged_preds[0]['start'], 15) # First rule match + self.assertEqual(merged_preds[1]['start'], 30) # Second rule match + self.assertEqual(merged_preds[2]['start'], 50) # Third model pred + self.assertEqual(merged_preds[3]['start'], 70) # Fourth rule match + + # Test with non-overlapping predictions + model_preds = [ + {'cui': '134', 'start': 10, 'end': 20, 'acc': 0.9, + 'pretty_name': 'Phone Number'}, + {'cui': '134', 'start': 50, 'end': 60, 'acc': 0.9, # Additional non-overlapping model pred + 'pretty_name': 'Phone Number'} + ] + rule_matches = [ + {'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0, + 'pretty_name': 'Phone Number'}, + {'cui': '134', 'start': 70, 'end': 80, 'acc': 1.0, # Additional non-overlapping rule match + 'pretty_name': 'Phone Number'} + ] + + # Test with accept_preds=True (default) + merged_preds = deid.merge_preds(model_preds, rule_matches) + self.assertEqual(len(merged_preds), 4) # Should keep all predictions + self.assertEqual(merged_preds[0]['start'], 10) # First model pred + self.assertEqual(merged_preds[1]['start'], 25) # First rule match + self.assertEqual(merged_preds[2]['start'], 50) # Second model pred + self.assertEqual(merged_preds[3]['start'], 70) # Second rule match + + def test_merge_all_preds(self): + # Test with lists of different lengths + model_preds_by_text = [ + [{'cui': '134', 'start': 10, 'end': 20, 'acc': 0.9, 'pretty_name': 'Phone Number'}], + [{'cui': '134', 'start': 25, 'end': 35, 'acc': 0.8, 'pretty_name': 'Phone Number'}] + ] + rule_matches_per_text = [ + [{'cui': '134', 'start': 15, 'end': 25, 'acc': 1.0, 'pretty_name': 'Phone Number'}] + ] + + # Test that it raises AssertionError for different lengths + with self.assertRaises(AssertionError) as context: + deid.merge_all_preds(model_preds_by_text, rule_matches_per_text) + self.assertIn("must have the same length", str(context.exception)) 
+ + # Test with consistent lengths + model_preds_by_text = [ + [{'cui': '134', 'start': 10, 'end': 20, 'acc': 0.9, 'pretty_name': 'Phone Number'}], + [{'cui': '134', 'start': 25, 'end': 35, 'acc': 0.8, 'pretty_name': 'Phone Number'}] + ] + rule_matches_per_text = [ + [{'cui': '134', 'start': 15, 'end': 25, 'acc': 1.0, 'pretty_name': 'Phone Number'}], + [{'cui': '134', 'start': 30, 'end': 40, 'acc': 1.0, 'pretty_name': 'Phone Number'}] + ] + + # Test with accept_preds=True (default) + merged_preds = deid.merge_all_preds(model_preds_by_text, rule_matches_per_text) + self.assertEqual(len(merged_preds), 2) # Two texts + self.assertEqual(len(merged_preds[0]), 1) # First text has one model pred + self.assertEqual(len(merged_preds[1]), 1) # Second text has one model pred + self.assertEqual(merged_preds[0][0]['start'], 10) # First text model pred + self.assertEqual(merged_preds[1][0]['start'], 25) # Second text model pred + + # Test with accept_preds=False + merged_preds = deid.merge_all_preds(model_preds_by_text, rule_matches_per_text, accept_preds=False) + self.assertEqual(len(merged_preds), 2) # Two texts + self.assertEqual(len(merged_preds[0]), 1) # First text has one rule match + self.assertEqual(len(merged_preds[1]), 1) # Second text has one rule match + self.assertEqual(merged_preds[0][0]['start'], 15) # First text rule match + self.assertEqual(merged_preds[1][0]['start'], 30) # Second text rule match + + # Test with non-overlapping predictions + model_preds_by_text = [ + [{'cui': '134', 'start': 10, 'end': 20, 'acc': 0.9, 'pretty_name': 'Phone Number'}], + [{'cui': '134', 'start': 25, 'end': 35, 'acc': 0.8, 'pretty_name': 'Phone Number'}] + ] + rule_matches_per_text = [ + [{'cui': '134', 'start': 30, 'end': 40, 'acc': 1.0, 'pretty_name': 'Phone Number'}], + [{'cui': '134', 'start': 50, 'end': 60, 'acc': 1.0, 'pretty_name': 'Phone Number'}] + ] + + # Test with accept_preds=True (default) + merged_preds = deid.merge_all_preds(model_preds_by_text, 
rule_matches_per_text) + self.assertEqual(len(merged_preds), 2) # Two texts + self.assertEqual(len(merged_preds[0]), 2) # First text has both preds + self.assertEqual(len(merged_preds[1]), 2) # Second text has both preds + self.assertEqual(merged_preds[0][0]['start'], 10) # First text model pred + self.assertEqual(merged_preds[0][1]['start'], 30) # First text rule match + self.assertEqual(merged_preds[1][0]['start'], 25) # Second text model pred + self.assertEqual(merged_preds[1][1]['start'], 50) # Second text rule match