diff --git a/medcat-v2/medcat/cat.py b/medcat-v2/medcat/cat.py index 8b8e11383..2825ee561 100644 --- a/medcat-v2/medcat/cat.py +++ b/medcat-v2/medcat/cat.py @@ -660,7 +660,7 @@ def _doc_to_out(self, def trainer(self): """The trainer object.""" if not self._trainer: - self._trainer = Trainer(self.cdb, self.__call__, self._pipeline) + self._trainer = Trainer(self.cdb, self._pipeline) return self._trainer def save_model_card(self, model_card_path: str) -> None: diff --git a/medcat-v2/medcat/trainer.py b/medcat-v2/medcat/trainer.py index f1c65eeca..c7f62550f 100644 --- a/medcat-v2/medcat/trainer.py +++ b/medcat-v2/medcat/trainer.py @@ -1,4 +1,4 @@ -from typing import Iterable, Callable, Optional, Union, cast +from typing import Iterable, Optional, Union, cast import logging import tempfile from itertools import chain, repeat, islice @@ -28,11 +28,10 @@ class Trainer: strict_train: bool = False - def __init__(self, cdb: CDB, caller: Callable[[str], MutableDocument], + def __init__(self, cdb: CDB, pipeline: Pipeline): self.cdb = cdb self.config = cdb.config - self.caller = caller self._pipeline = pipeline def train_unsupervised(self, @@ -92,7 +91,7 @@ def _train_unsupervised(self, # inference run for the document try: - doc = self.caller(line) + doc = self._pipeline.get_doc(line) except Exception as e: logger.warning("LINE: '%s...' \t WAS SKIPPED", line[0:100]) logger.warning("BECAUSE OF:", exc_info=e) @@ -416,7 +415,19 @@ def _prepare_doc_with_anns( for ann in anns: tkns = doc.get_tokens(ann['start'], ann['end']) try: - ents.append(self._pipeline.entity_from_tokens_in_doc(tkns, doc)) + ent = self._pipeline.entity_from_tokens_in_doc(tkns, doc) + pn_dict = prepare_name(ann['value'], self._pipeline.tokenizer, {}, + self._pn_configs) + processed_names = list(pn_dict.keys()) + if len(processed_names) > 1: + logger.info("Got multiple processed names for %s: %s", + ann['value'], processed_names) + elif not processed_names: + # NOTE: shouldn't really happen + raise ValueError(f"Could not process {ann['value']} into names") + ent.detected_name = processed_names[0] + ent.cui = ann['cui'] + ents.append(ent) except ValueError as err: self._warn_on_error( err, doc.base.text, @@ -471,7 +482,8 @@ def _train_supervised_for_project2(self, doc = docs[idx_doc] with temp_changed_config(self.config.components.linking, 'train', False): - mut_doc = self.caller(doc['text']) + # NOTE: only need tokenization here + mut_doc = self._pipeline.tokenizer_with_tag(doc['text']) self._prepare_doc_with_anns(mut_doc, doc, doc['annotations']) # Compatibility with old output where annotations are a list @@ -646,6 +658,7 @@ def add_and_train_concept(self, mut_entity = self._pipeline.entity_from_tokens(mut_entity) component.train(cui=cui, entity=mut_entity, doc=mut_doc, negative=negative, names=names) + trained_comps += 1 if not negative and devalue_others: # Find all cuis diff --git a/medcat-v2/tests/test_trainer.py b/medcat-v2/tests/test_trainer.py index 9e3808e5d..dd518c521 100644 --- a/medcat-v2/tests/test_trainer.py +++ b/medcat-v2/tests/test_trainer.py @@ -1,5 +1,6 @@ import os import json +from typing import Callable from medcat.tokenizing.tokens import MutableDocument from medcat.trainer import Trainer @@ -100,6 +101,12 @@ def train(self, *args, **kwargs) -> None: class FakePipeline: + def __init__(self, caller: Callable[[str], MutableDocument] = None) -> None: + self._caller = caller or FakeMutDoc + + def get_doc(self, text: str) -> FakeMutDoc: + return self._caller(text) + def tokenizer(self, text: str) -> FakeMutDoc: return FakeMutDoc(text) @@ -118,7 +125,8 @@ def entity_from_tokens_in_doc(self, tkns: list, doc: MutableDocument) -> FakeMut class FakePipelineWithComponents(FakePipeline): - def __init__(self, components: list): + def __init__(self, components: list, caller: Callable[[str], MutableDocument] = None): + super().__init__(caller) self._components = components def iter_all_components(self): @@ -137,8 +145,7 @@ def setUpClass(cls): cls.cnf = Config() cls.cdb = FakeCDB(cls.cnf) cls.vocab = Vocab() - cls.trainer = Trainer(cls.cdb, - cls.caller, FakePipeline()) + cls.trainer = Trainer(cls.cdb, FakePipeline(cls.caller)) def setUp(self): self.cnf = Config() @@ -203,8 +210,7 @@ def test_unsup_training_trains_non_linking_component(self): ner_component = FakeTrainableNERComponent() trainer = Trainer( self.cdb, - self.caller, - FakePipelineWithComponents([ner_component]), + FakePipelineWithComponents([ner_component], self.caller), ) trainer.config = self.cnf @@ -219,8 +225,7 @@ def test_unsup_training_skips_non_trainable_components(self): ner_component = FakeTrainableNERComponent() trainer = Trainer( self.cdb, - self.caller, - FakePipelineWithComponents([FakeComponent(), ner_component, object()]), + FakePipelineWithComponents([FakeComponent(), ner_component, object()], self.caller), ) trainer.config = self.cnf diff --git a/medcat-v2/tests/utils/test_training_utils.py b/medcat-v2/tests/utils/test_training_utils.py index 6e2549897..3bafac66c 100644 --- a/medcat-v2/tests/utils/test_training_utils.py +++ b/medcat-v2/tests/utils/test_training_utils.py @@ -4,6 +4,7 @@ from medcat.config import Config from medcat.components.types import CoreComponentType, AbstractEntityProvidingComponent from medcat.stats.stats import get_stats +from medcat.tokenizing.tokens import MutableDocument from medcat.trainer import Trainer from medcat.utils.training_utils import dataset_aware_component @@ -59,6 +60,9 @@ def entity_from_tokens_in_doc(self, tkns, doc: _FakeDoc): cui = doc.cui_by_start.get(start, "C_WRONG") return _FakeEntity(start, end, text, cui) + def __call__(self, text: str) -> _FakeDoc: + return _FakeDoc(text, {}) + class _EmptyNER(AbstractEntityProvidingComponent): name = "empty_ner" @@ -140,6 +144,9 @@ def get_component(self, comp_type): return comp raise KeyError(comp_type) + def get_doc(self, text: str) -> MutableDocument: + return self(self.tokenizer(text)) + def iter_all_components(self): return self._components @@ -231,7 +238,7 @@ def test_train_unsupervised_can_train_only_linker_when_ner_is_cheating(self): ner = _TrainableNER() linker = _TrainablePassThroughLinker() cat = _FakeCat(self.DATASET, [ner, linker]) - trainer = Trainer(cat.cdb, cat.__call__, cat.pipe) + trainer = Trainer(cat.cdb, cat.pipe) with dataset_aware_component(cat, CoreComponentType.ner, self.DATASET): trainer.train_unsupervised(["abc def"], nepochs=1) @@ -243,7 +250,7 @@ def test_train_unsupervised_can_train_only_ner_when_linker_is_cheating(self): ner = _TrainableNER() linker = _TrainablePassThroughLinker() cat = _FakeCat(self.DATASET, [ner, linker]) - trainer = Trainer(cat.cdb, cat.__call__, cat.pipe) + trainer = Trainer(cat.cdb, cat.pipe) with dataset_aware_component(cat, CoreComponentType.linking, self.DATASET): trainer.train_unsupervised(["abc def"], nepochs=1) @@ -255,7 +262,7 @@ def test_train_supervised_can_train_only_linker_when_ner_is_cheating(self): ner = _TrainableNER() linker = _TrainablePassThroughLinker() cat = _FakeCat(self.DATASET, [ner, linker]) - trainer = Trainer(cat.cdb, cat.__call__, cat.pipe) + trainer = Trainer(cat.cdb, cat.pipe) with unittest.mock.patch("medcat.trainer.prepare_name", return_value={"abc": {}}): with dataset_aware_component(cat, CoreComponentType.ner, self.DATASET): @@ -268,7 +275,7 @@ def test_train_supervised_can_train_only_ner_when_linker_is_cheating(self): ner = _TrainableNER() linker = _TrainablePassThroughLinker() cat = _FakeCat(self.DATASET, [ner, linker]) - trainer = Trainer(cat.cdb, cat.__call__, cat.pipe) + trainer = Trainer(cat.cdb, cat.pipe) with unittest.mock.patch("medcat.trainer.prepare_name", return_value={"abc": {}}): with dataset_aware_component(cat, CoreComponentType.linking, self.DATASET):