Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions medcat-v2/medcat/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,19 @@ def _train_supervised_for_project(self,
docs, current_document, train_from_false_positives,
devalue_others)

def _get_processed_name(self, raw_name: str) -> str:
pn_dict = prepare_name(
raw_name, self._pipeline.tokenizer_with_tag, {},
self._pn_configs)
processed_names = list(pn_dict.keys())
if len(processed_names) > 1:
logger.info("Got multiple processed names for %s: %s",
raw_name, processed_names)
elif not processed_names:
# NOTE: shouldn't really happen
raise ValueError(f"Could not process {raw_name} into names")
return processed_names[0]

def _prepare_doc_with_anns(
self, doc: MutableDocument, ann_doc: MedCATTrainerExportDocument,
anns: list[MedCATTrainerExportAnnotation]) -> None:
Expand All @@ -416,16 +429,7 @@ def _prepare_doc_with_anns(
tkns = doc.get_tokens(ann['start'], ann['end'])
try:
ent = self._pipeline.entity_from_tokens_in_doc(tkns, doc)
pn_dict = prepare_name(ann['value'], self._pipeline.tokenizer, {},
self._pn_configs)
processed_names = list(pn_dict.keys())
if len(processed_names) > 1:
logger.info("Got multiple processed names for %s: %s",
ann['value'], processed_names)
elif not processed_names:
# NOTE: shouldn't really happen
raise ValueError(f"Could not process {ann['value']} into names")
ent.detected_name = processed_names[0]
ent.detected_name = self._get_processed_name(ann['value'])
ent.cui = ann['cui']
ents.append(ent)
except ValueError as err:
Expand Down
30 changes: 30 additions & 0 deletions medcat-v2/tests/test_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from medcat import cat
from medcat.data.mctexport import count_all_annotations, iter_anns
from medcat.data.model_card import ModelCard
from medcat.preprocessors.cleaners import prepare_name
from medcat.vocab import Vocab
from medcat.config import Config
from medcat.config.config_meta_cat import ConfigMetaCAT
Expand Down Expand Up @@ -610,6 +611,35 @@ def _perform_training(cls):
data = cls._get_data()
cls.cat.trainer.train_supervised_raw(data)

def test_prepare_name_removes_new_lines(self):
# NOTE: This is easiest to test if I have the model pack
# available (to run tokenizer and tagger).
# The reason we need to check this is because in supervised
# training the detected name is set as per the prepare_name
# output and if you run that without the tagger it will
# keep stuff like new liens
text = "something\nwas\ndone"
names = prepare_name(
text, self.cat.pipe.tokenizer_with_tag, {},
(self.cat.config.general,
self.cat.config.preprocessing,
self.cat.config.cdb_maker))
self.assertEqual(len(names), 1)
name = list(names)[0]
self._assert_name_processed_correctly(text, name)

def _assert_name_processed_correctly(self, text: str, name: str):
self.assertNotIn("\n", name)
self.assertEqual(
name.count(self.cat.config.general.separator),
text.count("\n"),
"All new lines should convert to single separators")

def test_trainer_name_processor_removes_new_lines(self):
text = "something\nwas\ndone"
name = self.cat.trainer._get_processed_name(text)
self._assert_name_processed_correctly(text, name)

def test_lists_sup_train_in_config(self):
self.assertTrue(self.cat.config.meta.sup_trained)

Expand Down
Loading