Use the legacy sequence categorical cross-entropy loss (LegacySequenceCategoricalCrossentropy) from thinc.legacy

This commit is contained in:
Daniël de Kok 2022-12-02 16:30:03 +01:00
parent 1d705cc683
commit 1041f68be2
4 changed files with 16 additions and 9 deletions

View File

@ -5,8 +5,9 @@ from itertools import islice
import numpy as np import numpy as np
import srsly import srsly
from thinc.api import Config, Model, SequenceCategoricalCrossentropy from thinc.api import Config, Model
from thinc.types import ArrayXd, Floats2d, Ints1d from thinc.types import ArrayXd, Floats2d, Ints1d
from thinc.legacy import LegacySequenceCategoricalCrossentropy
from ._edit_tree_internals.edit_trees import EditTrees from ._edit_tree_internals.edit_trees import EditTrees
from ._edit_tree_internals.schemas import validate_edit_tree from ._edit_tree_internals.schemas import validate_edit_tree
@ -129,7 +130,9 @@ class EditTreeLemmatizer(TrainablePipe):
self, examples: Iterable[Example], scores: List[Floats2d] self, examples: Iterable[Example], scores: List[Floats2d]
) -> Tuple[float, List[Floats2d]]: ) -> Tuple[float, List[Floats2d]]:
validate_examples(examples, "EditTreeLemmatizer.get_loss") validate_examples(examples, "EditTreeLemmatizer.get_loss")
loss_func = SequenceCategoricalCrossentropy(normalize=False, missing_value=-1) loss_func = LegacySequenceCategoricalCrossentropy(
normalize=False, missing_value=-1
)
truths = [] truths = []
for eg in examples: for eg in examples:

View File

@ -1,7 +1,8 @@
# cython: infer_types=True, profile=True, binding=True # cython: infer_types=True, profile=True, binding=True
from typing import Callable, Dict, Iterable, List, Optional, Union from typing import Callable, Dict, Iterable, List, Optional, Union
import srsly import srsly
from thinc.api import SequenceCategoricalCrossentropy, Model, Config from thinc.api import Model, Config
from thinc.legacy import LegacySequenceCategoricalCrossentropy
from thinc.types import Floats2d, Ints1d from thinc.types import Floats2d, Ints1d
from itertools import islice from itertools import islice
@ -290,7 +291,7 @@ class Morphologizer(Tagger):
DOCS: https://spacy.io/api/morphologizer#get_loss DOCS: https://spacy.io/api/morphologizer#get_loss
""" """
validate_examples(examples, "Morphologizer.get_loss") validate_examples(examples, "Morphologizer.get_loss")
loss_func = SequenceCategoricalCrossentropy(names=tuple(self.labels), normalize=False) loss_func = LegacySequenceCategoricalCrossentropy(names=tuple(self.labels), normalize=False)
truths = [] truths = []
for eg in examples: for eg in examples:
eg_truths = [] eg_truths = []

View File

@ -3,7 +3,9 @@ from typing import Dict, Iterable, Optional, Callable, List, Union
from itertools import islice from itertools import islice
import srsly import srsly
from thinc.api import Model, SequenceCategoricalCrossentropy, Config from thinc.api import Model, Config
from thinc.legacy import LegacySequenceCategoricalCrossentropy
from thinc.types import Floats2d, Ints1d from thinc.types import Floats2d, Ints1d
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
@ -161,7 +163,7 @@ class SentenceRecognizer(Tagger):
""" """
validate_examples(examples, "SentenceRecognizer.get_loss") validate_examples(examples, "SentenceRecognizer.get_loss")
labels = self.labels labels = self.labels
loss_func = SequenceCategoricalCrossentropy(names=labels, normalize=False) loss_func = LegacySequenceCategoricalCrossentropy(names=labels, normalize=False)
truths = [] truths = []
for eg in examples: for eg in examples:
eg_truth = [] eg_truth = []

View File

@ -2,7 +2,8 @@
from typing import Callable, Dict, Iterable, List, Optional, Union from typing import Callable, Dict, Iterable, List, Optional, Union
import numpy import numpy
import srsly import srsly
from thinc.api import Model, set_dropout_rate, SequenceCategoricalCrossentropy, Config from thinc.api import Model, set_dropout_rate, Config
from thinc.legacy import LegacySequenceCategoricalCrossentropy
from thinc.types import Floats2d, Ints1d from thinc.types import Floats2d, Ints1d
import warnings import warnings
from itertools import islice from itertools import islice
@ -244,7 +245,7 @@ class Tagger(TrainablePipe):
DOCS: https://spacy.io/api/tagger#rehearse DOCS: https://spacy.io/api/tagger#rehearse
""" """
loss_func = SequenceCategoricalCrossentropy() loss_func = LegacySequenceCategoricalCrossentropy()
if losses is None: if losses is None:
losses = {} losses = {}
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
@ -275,7 +276,7 @@ class Tagger(TrainablePipe):
DOCS: https://spacy.io/api/tagger#get_loss DOCS: https://spacy.io/api/tagger#get_loss
""" """
validate_examples(examples, "Tagger.get_loss") validate_examples(examples, "Tagger.get_loss")
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"]) loss_func = LegacySequenceCategoricalCrossentropy(names=self.labels, normalize=False, neg_prefix=self.cfg["neg_prefix"])
# Convert empty tag "" to missing value None so that both misaligned # Convert empty tag "" to missing value None so that both misaligned
# tokens and tokens with missing annotation have the default missing # tokens and tokens with missing annotation have the default missing
# value None. # value None.