Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion medcat-v2/medcat/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -660,7 +660,7 @@ def _doc_to_out(self,
def trainer(self):
"""The trainer object."""
if not self._trainer:
self._trainer = Trainer(self.cdb, self.__call__, self._pipeline)
self._trainer = Trainer(self.cdb, self._pipeline)
return self._trainer

def save_model_card(self, model_card_path: str) -> None:
Expand Down
25 changes: 19 additions & 6 deletions medcat-v2/medcat/trainer.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Iterable, Callable, Optional, Union, cast
from typing import Iterable, Optional, Union, cast
import logging
import tempfile
from itertools import chain, repeat, islice
Expand Down Expand Up @@ -28,11 +28,10 @@
class Trainer:
strict_train: bool = False

def __init__(self, cdb: CDB, caller: Callable[[str], MutableDocument],
def __init__(self, cdb: CDB,
pipeline: Pipeline):
self.cdb = cdb
self.config = cdb.config
self.caller = caller
self._pipeline = pipeline

def train_unsupervised(self,
Expand Down Expand Up @@ -92,7 +91,7 @@ def _train_unsupervised(self,

# inference run for the document
try:
doc = self.caller(line)
doc = self._pipeline.get_doc(line)
except Exception as e:
logger.warning("LINE: '%s...' \t WAS SKIPPED", line[0:100])
logger.warning("BECAUSE OF:", exc_info=e)
Expand Down Expand Up @@ -416,7 +415,19 @@ def _prepare_doc_with_anns(
for ann in anns:
tkns = doc.get_tokens(ann['start'], ann['end'])
try:
ents.append(self._pipeline.entity_from_tokens_in_doc(tkns, doc))
ent = self._pipeline.entity_from_tokens_in_doc(tkns, doc)
pn_dict = prepare_name(ann['value'], self._pipeline.tokenizer, {},
self._pn_configs)
processed_names = list(pn_dict.keys())
if len(processed_names) > 1:
logger.info("Got multiple processed names for %s: %s",
ann['value'], processed_names)
elif not processed_names:
# NOTE: shouldn't really happen
raise ValueError(f"Could not process {ann['value']} into names")
ent.detected_name = processed_names[0]
ent.cui = ann['cui']
ents.append(ent)
Comment on lines +418 to +430
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As in the correct CUI isn't already set when performing training?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You're right - it doesn't matter. Even in the previous case this would have already gone through the pipe. Might as well add that as well.

except ValueError as err:
self._warn_on_error(
err, doc.base.text,
Expand Down Expand Up @@ -471,7 +482,8 @@ def _train_supervised_for_project2(self,
doc = docs[idx_doc]
with temp_changed_config(self.config.components.linking,
'train', False):
mut_doc = self.caller(doc['text'])
# NOTE: only need tokenization here
mut_doc = self._pipeline.tokenizer_with_tag(doc['text'])
self._prepare_doc_with_anns(mut_doc, doc, doc['annotations'])

# Compatibility with old output where annotations are a list
Expand Down Expand Up @@ -646,6 +658,7 @@ def add_and_train_concept(self,
mut_entity = self._pipeline.entity_from_tokens(mut_entity)
component.train(cui=cui, entity=mut_entity, doc=mut_doc,
negative=negative, names=names)
trained_comps += 1

if not negative and devalue_others:
# Find all cuis
Expand Down
19 changes: 12 additions & 7 deletions medcat-v2/tests/test_trainer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import json
from typing import Callable

from medcat.tokenizing.tokens import MutableDocument
from medcat.trainer import Trainer
Expand Down Expand Up @@ -100,6 +101,12 @@ def train(self, *args, **kwargs) -> None:

class FakePipeline:

def __init__(self, caller: Callable[[str], MutableDocument] = None) -> None:
self._caller = caller or FakeMutDoc

def get_doc(self, text: str) -> FakeMutDoc:
return self._caller(text)

def tokenizer(self, text: str) -> FakeMutDoc:
return FakeMutDoc(text)

Expand All @@ -118,7 +125,8 @@ def entity_from_tokens_in_doc(self, tkns: list, doc: MutableDocument) -> FakeMut

class FakePipelineWithComponents(FakePipeline):

def __init__(self, components: list):
def __init__(self, components: list, caller: Callable[[str], MutableDocument] = None):
super().__init__(caller)
self._components = components

def iter_all_components(self):
Expand All @@ -137,8 +145,7 @@ def setUpClass(cls):
cls.cnf = Config()
cls.cdb = FakeCDB(cls.cnf)
cls.vocab = Vocab()
cls.trainer = Trainer(cls.cdb,
cls.caller, FakePipeline())
cls.trainer = Trainer(cls.cdb, FakePipeline(cls.caller))

def setUp(self):
self.cnf = Config()
Expand Down Expand Up @@ -203,8 +210,7 @@ def test_unsup_training_trains_non_linking_component(self):
ner_component = FakeTrainableNERComponent()
trainer = Trainer(
self.cdb,
self.caller,
FakePipelineWithComponents([ner_component]),
FakePipelineWithComponents([ner_component], self.caller),
)
trainer.config = self.cnf

Expand All @@ -219,8 +225,7 @@ def test_unsup_training_skips_non_trainable_components(self):
ner_component = FakeTrainableNERComponent()
trainer = Trainer(
self.cdb,
self.caller,
FakePipelineWithComponents([FakeComponent(), ner_component, object()]),
FakePipelineWithComponents([FakeComponent(), ner_component, object()], self.caller),
)
trainer.config = self.cnf

Expand Down
15 changes: 11 additions & 4 deletions medcat-v2/tests/utils/test_training_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from medcat.config import Config
from medcat.components.types import CoreComponentType, AbstractEntityProvidingComponent
from medcat.stats.stats import get_stats
from medcat.tokenizing.tokens import MutableDocument
from medcat.trainer import Trainer
from medcat.utils.training_utils import dataset_aware_component

Expand Down Expand Up @@ -59,6 +60,9 @@ def entity_from_tokens_in_doc(self, tkns, doc: _FakeDoc):
cui = doc.cui_by_start.get(start, "C_WRONG")
return _FakeEntity(start, end, text, cui)

def __call__(self, text: str) -> _FakeDoc:
return _FakeDoc(text, {})


class _EmptyNER(AbstractEntityProvidingComponent):
name = "empty_ner"
Expand Down Expand Up @@ -140,6 +144,9 @@ def get_component(self, comp_type):
return comp
raise KeyError(comp_type)

def get_doc(self, text: str) -> MutableDocument:
return self(self.tokenizer(text))

def iter_all_components(self):
return self._components

Expand Down Expand Up @@ -231,7 +238,7 @@ def test_train_unsupervised_can_train_only_linker_when_ner_is_cheating(self):
ner = _TrainableNER()
linker = _TrainablePassThroughLinker()
cat = _FakeCat(self.DATASET, [ner, linker])
trainer = Trainer(cat.cdb, cat.__call__, cat.pipe)
trainer = Trainer(cat.cdb, cat.pipe)

with dataset_aware_component(cat, CoreComponentType.ner, self.DATASET):
trainer.train_unsupervised(["abc def"], nepochs=1)
Expand All @@ -243,7 +250,7 @@ def test_train_unsupervised_can_train_only_ner_when_linker_is_cheating(self):
ner = _TrainableNER()
linker = _TrainablePassThroughLinker()
cat = _FakeCat(self.DATASET, [ner, linker])
trainer = Trainer(cat.cdb, cat.__call__, cat.pipe)
trainer = Trainer(cat.cdb, cat.pipe)

with dataset_aware_component(cat, CoreComponentType.linking, self.DATASET):
trainer.train_unsupervised(["abc def"], nepochs=1)
Expand All @@ -255,7 +262,7 @@ def test_train_supervised_can_train_only_linker_when_ner_is_cheating(self):
ner = _TrainableNER()
linker = _TrainablePassThroughLinker()
cat = _FakeCat(self.DATASET, [ner, linker])
trainer = Trainer(cat.cdb, cat.__call__, cat.pipe)
trainer = Trainer(cat.cdb, cat.pipe)

with unittest.mock.patch("medcat.trainer.prepare_name", return_value={"abc": {}}):
with dataset_aware_component(cat, CoreComponentType.ner, self.DATASET):
Expand All @@ -268,7 +275,7 @@ def test_train_supervised_can_train_only_ner_when_linker_is_cheating(self):
ner = _TrainableNER()
linker = _TrainablePassThroughLinker()
cat = _FakeCat(self.DATASET, [ner, linker])
trainer = Trainer(cat.cdb, cat.__call__, cat.pipe)
trainer = Trainer(cat.cdb, cat.pipe)

with unittest.mock.patch("medcat.trainer.prepare_name", return_value={"abc": {}}):
with dataset_aware_component(cat, CoreComponentType.linking, self.DATASET):
Expand Down
Loading