Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 14 additions & 1 deletion medcat/datasets/transformers_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def _info(self):

def _split_generators(self, dl_manager): # noqa
"""Returns SplitGenerators.""" # noqa
return [
splits = [
datasets.SplitGenerator(
name=datasets.Split.TRAIN,
gen_kwargs={
Expand All @@ -73,6 +73,19 @@ def _split_generators(self, dl_manager): # noqa
),
]

# Only add test split if test data files are provided
if 'test' in self.config.data_files:
splits.append(
datasets.SplitGenerator(
name=datasets.Split.TEST,
gen_kwargs={
"filepaths": self.config.data_files['test'],
},
)
)

return splits

def _generate_examples(self, filepaths): # noqa
cnt = 0
for filepath in filepaths:
Expand Down
44 changes: 30 additions & 14 deletions medcat/ner/transformers_ner.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,8 @@ def train(self,
ignore_extra_labels=False,
dataset=None,
meta_requirements=None,
train_json_path: Union[str, list, None]=None,
test_json_path: Union[str, list, None]=None,
trainer_callbacks: Optional[List[Callable[[Trainer], TrainerCallback]]] = None) -> Tuple:
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
continue training if an existing model is loaded or start new training if the model is blank/new.
Expand All @@ -187,8 +189,10 @@ def train(self,
ignore_extra_labels:
Makes only sense when an existing deid model was loaded and from the new data we want to ignore
labels that did not exist in the old model.
dataset: Defaults to None.
dataset: Defaults to None. Will be split by self.config.general['test_size'] into train and test datasets.
meta_requirements: Defaults to None
train_json_path (str): Defaults to None. If provided, will be used as the training dataset json_path to load from
test_json_path (str): Defaults to None. If provided, will be used as the test dataset json_path to load from
trainer_callbacks (List[Callable[[Trainer], TrainerCallback]]]):
A list of trainer callbacks for collecting metrics during the training at the client side. The
transformers Trainer object will be passed in when each callback is called.
Expand All @@ -200,11 +204,16 @@ def train(self,
Tuple: The dataframe, examples, and the dataset
"""

if dataset is None and json_path is not None:
if dataset is None:
# Load the medcattrainer export
json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels,
if json_path is not None:
json_path = self._prepare_dataset(json_path, ignore_extra_labels=ignore_extra_labels,
meta_requirements=meta_requirements, file_name='data_eval.json')
# Load dataset
elif test_json_path is not None and train_json_path is not None:
train_json_path = self._prepare_dataset(train_json_path, ignore_extra_labels=ignore_extra_labels,
meta_requirements=meta_requirements, file_name='data_train.json')
test_json_path = self._prepare_dataset(test_json_path, ignore_extra_labels=ignore_extra_labels,
meta_requirements=meta_requirements, file_name='data_test.json')

# NOTE: The following is for backwards comppatibility
# in datasets==2.20.0 `trust_remote_code=True` must be explicitly
Expand All @@ -216,13 +225,21 @@ def train(self,
ds_load_dataset = partial(datasets.load_dataset, trust_remote_code=True)
else:
ds_load_dataset = datasets.load_dataset
dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
data_files={'train': json_path}, # type: ignore
split='train',
cache_dir='/tmp/')
# We split before encoding so the split is document level, as encoding
#does the document splitting into max_seq_len
dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore

if json_path:
dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
data_files={'train': json_path}, # type: ignore
split='train',
cache_dir='/tmp/')
# We split before encoding so the split is document level, as encoding
# does the document splitting into max_seq_len
dataset = dataset.train_test_split(test_size=self.config.general['test_size']) # type: ignore
elif train_json_path and test_json_path:
dataset = ds_load_dataset(os.path.abspath(transformers_ner.__file__),
data_files={'train': train_json_path, 'test': test_json_path}, # type: ignore
cache_dir='/tmp/')
else:
raise ValueError("Either json_path or train_json_path and test_json_path must be provided when no dataset is provided")

# Update labelmap in case the current dataset has more labels than what we had before
self.tokenizer.calculate_label_map(dataset['train'])
Expand All @@ -231,8 +248,8 @@ def train(self,
if self.model.num_labels != len(self.tokenizer.label_map):
logger.warning("The dataset contains labels we've not seen before, model is being reinitialized")
logger.warning("Model: {} vs Dataset: {}".format(self.model.num_labels, len(self.tokenizer.label_map)))
self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'],
num_labels=len(self.tokenizer.label_map),
self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'],
num_labels=len(self.tokenizer.label_map),
ignore_mismatched_sizes=True)
self.tokenizer.cui2name = {k:self.cdb.get_name(k) for k in self.tokenizer.label_map.keys()}

Expand Down Expand Up @@ -273,7 +290,6 @@ def train(self,
# NOTE: this shouldn't really happen, but we'll do this for type safety
raise ValueError("Output path should not be None!")
self.save(save_dir_path=os.path.join(output_dir, 'final_model'))

# Run an eval step and return metrics
p = trainer.predict(encoded_dataset['test']) # type: ignore
df, examples = metrics(p, return_df=True, tokenizer=self.tokenizer, dataset=encoded_dataset['test'])
Expand Down
142 changes: 137 additions & 5 deletions medcat/utils/ner/deid.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
- config
- cdb
"""
import re
from typing import Union, Tuple, Any, List, Iterable, Optional, Dict
import logging

Expand Down Expand Up @@ -62,9 +63,11 @@ class DeIdModel(NerModel):
def __init__(self, cat: CAT) -> None:
self.cat = cat

def train(self, json_path: Union[str, list, None],
def train(self, json_path: Union[str, list, None] = None,
*args, **kwargs) -> Tuple[Any, Any, Any]:
return super().train(json_path, *args, train_nr=0, **kwargs) # type: ignore
assert not all([json_path, kwargs.get('train_json_path'), kwargs.get('test_json_path')]), \
"Either json_path or train_json_path and test_json_path must be provided when no dataset is provided"
return super().train(json_path=json_path, *args, **kwargs) # type: ignore

def eval(self, json_path: Union[str, list, None],
*args, **kwargs) -> Tuple[Any, Any, Any]:
Expand Down Expand Up @@ -146,7 +149,8 @@ def deid_multi_texts(self,
return out

@classmethod
def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) -> 'DeIdModel':
def load_model_pack(cls, model_pack_path: str,
config: Optional[Dict] = None) -> 'DeIdModel':
"""Load DeId model from model pack.

The method first loads the CAT instance.
Expand All @@ -164,7 +168,7 @@ def load_model_pack(cls, model_pack_path: str, config: Optional[Dict] = None) ->
Returns:
DeIdModel: The resulting DeI model.
"""
ner_model = NerModel.load_model_pack(model_pack_path,config=config)
ner_model = NerModel.load_model_pack(model_pack_path, config=config)
cat = ner_model.cat
if not cls._is_deid_model(cat):
raise ValueError(
Expand All @@ -180,7 +184,135 @@ def _is_deid_model(cls, cat: CAT) -> bool:
@classmethod
def _get_reason_not_deid(cls, cat: CAT) -> str:
if cat.vocab is not None:
return "Has vocab"
return "Has voc§ab"
if len(cat._addl_ner) != 1:
return f"Incorrect number of addl_ner: {len(cat._addl_ner)}"
return ""


def match_rules(rules: List[Tuple[str, str]], texts: List[str], cui2preferred_name: Dict[str, str]) -> List[List[Dict]]:
"""Match a set of rules - pat / cui combos as post processing labels.

Uses a cat DeID model for pretty name mapping.

Args:
rules (List[Tuple[str, str]]): List of tuples of pattern and cui
texts (List[str]): List of texts to match rules on
cui2preferred_name (Dict[str, str]): Dictionary of CUI to preferred name, likely to be cat.cdb.cui2preferred_name.

Examples:
>>> cat = CAT.load_model_pack(model_pack_path)
...
>>> rules = [
('(123) 456-7890', '134'),
('1234567890', '134'),
('123.456.7890', '134'),
('1234567890', '134'),
('1234567890', '134'),
]
>>> texts = [
'My phone number is (123) 456-7890',
'My phone number is 1234567890',
'My phone number is 123.456.7890',
'My phone number is 1234567890',
]
>>> matches = match_rules(rules, texts, cat.cdb.cui2preferred_name)

Returns:
List[List[Dict]]: List of lists of predictions from `match_rules`
"""
# Iterate through each text and pattern combination
rule_matches_per_text = []
for i, text in enumerate(texts):
matches_in_text = []
for pattern, concept in rules:
# Find all matches of current pattern in current text
text_matches = re.finditer(pattern, text, flags=re.M)
# Add each match with its pattern and text info
for match in text_matches:
matches_in_text.append({
'source_value': match.group(),
'pretty_name': cui2preferred_name[concept],
'start': match.start(),
'end': match.end(),
'cui': concept,
'acc': 1.0
})
rule_matches_per_text.append(matches_in_text)
return rule_matches_per_text


def merge_all_preds(model_preds_by_text: List[List[Dict]],
rule_matches_per_text: List[List[Dict]],
accept_preds: bool = True) -> List[List[Dict]]:
"""Conveniance method to merge predictions from rule based and deID model predictions.

Args:
model_preds_by_text (List[Dict]): list of predictions from
`cat.get_entities()`, then `[list(m['entities'].values()) for m in model_preds]`
rule_matches_per_text (List[Dict]): list of predictions from output of
running `match_rules`
accept_preds (bool): uses the predicted label from the model,
model_preds_by_text, over the rule matches if they overlap.
Defaults to using model preds over rules.

Returns:
List[List[Dict]]: List of lists of predictions from `merge_all_preds`
"""
assert len(model_preds_by_text) == len(rule_matches_per_text), \
"model_preds_by_text and rule_matches_per_text must have the same length as they should be CAT.get_entities and match_rules outputs of the same text"
return [merge_preds(model_preds_by_text[i], rule_matches_per_text[i], accept_preds) for i in range(len(model_preds_by_text))]


def merge_preds(model_preds: List[Dict],
rule_matches: List[Dict],
accept_preds: bool = True) -> List[Dict]:
"""Merge predictions from rule based and deID model predictions.

Args:
model_preds (List[Dict]): predictions from `cat.get_entities()`
rule_matches (List[Dict]): predictions from output of running `match_rules` on a text
accept_preds (bool): uses the predicted label from the model,
model_preds, over the rule matches if they overlap.
Defaults to using model preds over rules.

Examples:
>>> # a list of predictions from `cat.get_entities()`
>>> model_preds = [
[
{'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
'pretty_name': 'Phone Number'},
{'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
'pretty_name': 'Phone Number'}
]
]
>>> # a list of predictions from `match_rules`
>>> rule_matches = [
[
{'cui': '134', 'start': 10, 'end': 20, 'acc': 1.0,
'pretty_name': 'Phone Number'},
{'cui': '134', 'start': 25, 'end': 35, 'acc': 1.0,
'pretty_name': 'Phone Number'}
]
]
>>> merged_preds = merge_preds(model_preds, rule_matches)

Returns:
List[Dict]: List of predictions from `merge_preds`
"""
if accept_preds:
labels1 = model_preds
labels2 = rule_matches
else:
labels1 = rule_matches
labels2 = model_preds

# Keep only non-overlapping model predictions
labels2 = [span2 for span2 in labels2
if not any(not (span2['end'] <= span1['start'] or span1['end'] <= span2['start'])
for span1 in labels1)]
# merge preds and sort on start
merged_preds = labels1 + labels2
merged_preds.sort(key=lambda x: x['start'])
merged_preds
return merged_preds
Loading