Fill in missing morphologizer methods

This commit is contained in:
Matthew Honnibal 2018-09-25 22:12:54 +02:00
parent 53eb96db09
commit d0dc032842

View File

@ -9,9 +9,9 @@ from .util import msgpack
from .util import msgpack_numpy from .util import msgpack_numpy
from thinc.api import chain from thinc.api import chain
from thinc.neural.util import to_categorical, copy_array from thinc.neural.util import to_categorical, copy_array, get_array_module
from . import util from . import util
from .pipe import Pipe from .pipeline import Pipe
from ._ml import Tok2Vec, build_morphologizer_model from ._ml import Tok2Vec, build_morphologizer_model
from ._ml import link_vectors_to_models, zero_init, flatten from ._ml import link_vectors_to_models, zero_init, flatten
from ._ml import create_default_optimizer from ._ml import create_default_optimizer
@ -20,6 +20,7 @@ from .compat import json_dumps, basestring_
from .tokens.doc cimport Doc from .tokens.doc cimport Doc
from .vocab cimport Vocab from .vocab cimport Vocab
from .morphology cimport Morphology from .morphology cimport Morphology
from .morphology import parse_feature
from .pipeline import Pipe from .pipeline import Pipe
@ -118,7 +119,7 @@ class Morphologizer(Pipe):
target[idx] = guesses[idx] target[idx] = guesses[idx]
else: else:
for feature in features: for feature in features:
column = feature_to_column(feature) # TODO _, column = parse_feature(feature)
target[idx, column] = 1 target[idx, column] = 1
idx += 1 idx += 1
target = self.model.ops.xp.array(target, dtype='f') target = self.model.ops.xp.array(target, dtype='f')
@ -132,7 +133,10 @@ class Morphologizer(Pipe):
yield yield
def scores_to_guesses(scores, out_sizes): def scores_to_guesses(scores, out_sizes):
raise NotImplementedError xp = get_array_module(scores)
guesses = xp.zeros((scores.shape[0], len(out_sizes)), dtype='i')
def feature_to_column(feature): offset = 0
raise NotImplementedError for i, size in enumerate(out_sizes):
guesses[:, i] = scores[:, offset : offset + size].argmax(axis=1)
offset += size
return guesses