Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-24 16:24:16 +03:00)
* Refactor away from the _ml module, to use thinc 4.0. Some work still needs to be done, e.g. adding __reduce__ to the models, more testing, etc.
This commit is contained in:
parent c339783bbe
commit 3c162dcac3
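For orientation before the per-file diffs: the refactor replaces the old `_ml.Model` wrapper with model classes that subclass thinc's `AveragedPerceptron` and drive an explicit set_features / set_valid (or set_costs) / set_prediction / update cycle. The sketch below is a minimal pure-Python illustration of that cycle, not spaCy or thinc code; `PerceptronModel` and every name in it are hypothetical stand-ins, and weight averaging is omitted.

import numpy as np

class PerceptronModel:
    # Hypothetical stand-in for the AveragedPerceptron subclasses in this
    # commit (ParserModel, TaggerModel); averaging is omitted for brevity.
    def __init__(self, n_classes, n_feats):
        self.W = np.zeros((n_feats, n_classes))

    def set_prediction(self, eg):
        # Sum the weights of the active features, mask out invalid classes,
        # and take the argmax -- the role set_prediction plays in the diff.
        scores = self.W[eg["features"]].sum(axis=0)
        scores[~eg["is_valid"]] = -np.inf
        eg["scores"] = scores
        eg["guess"] = int(scores.argmax())

    def update(self, eg):
        # Standard perceptron step: if the guess has non-zero cost, promote
        # the best zero-cost class and demote the guess.
        gold = int(np.where(eg["costs"] == 0.0, eg["scores"], -np.inf).argmax())
        if eg["costs"][eg["guess"]] != 0.0:
            self.W[eg["features"], gold] += 1.0
            self.W[eg["features"], eg["guess"]] -= 1.0

model = PerceptronModel(n_classes=4, n_feats=100)
eg = {"features": np.array([3, 17, 42]),            # filled by set_features
      "is_valid": np.array([True, True, False, True]),
      "costs": np.array([1.0, 0.0, 1.0, 1.0])}      # filled by set_costs
model.set_prediction(eg)
model.update(eg)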
setup.py

@@ -210,7 +210,6 @@ MOD_NAMES = ['spacy.parts_of_speech', 'spacy.strings',
              'spacy.lexeme', 'spacy.vocab', 'spacy.attrs',
              'spacy.morphology', 'spacy.tagger',
              'spacy.syntax.stateclass',
-             'spacy._ml',
              'spacy.tokenizer',
              'spacy.syntax.parser',
              'spacy.syntax.transition_system',
spacy/language.py

@@ -1,3 +1,4 @@
 from __future__ import absolute_import
 from os import path
 from warnings import warn
+import io

@@ -13,7 +14,6 @@ from .syntax.parser import Parser
 from .tagger import Tagger
 from .matcher import Matcher
 from .serialize.packer import Packer
-from ._ml import Model
 from . import attrs
 from . import orth
 from .syntax.ner import BiluoPushDown

@@ -245,9 +245,12 @@ class Language(object):
     def end_training(self, data_dir=None):
         if data_dir is None:
             data_dir = self.data_dir
-        self.parser.model.end_training(path.join(data_dir, 'deps', 'model'))
-        self.entity.model.end_training(path.join(data_dir, 'ner', 'model'))
-        self.tagger.model.end_training(path.join(data_dir, 'pos', 'model'))
+        self.parser.model.end_training()
+        self.parser.model.dump(path.join(data_dir, 'deps', 'model'))
+        self.entity.model.end_training()
+        self.entity.model.dump(path.join(data_dir, 'ner', 'model'))
+        self.tagger.model.end_training()
+        self.tagger.model.dump(path.join(data_dir, 'pos', 'model'))
 
         strings_loc = path.join(data_dir, 'vocab', 'strings.json')
         with io.open(strings_loc, 'w', encoding='utf8') as file_:
spacy/strings.pyx

@@ -78,7 +78,7 @@ cdef class StringStore:
     def __init__(self, strings=None):
        self.mem = Pool()
        self._map = PreshMap()
-       self._resize_at = 10
+       self._resize_at = 10000
        self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
        self.size = 1
        if strings is not None:
spacy/syntax/parser.pxd

@@ -1,18 +1,17 @@
 from thinc.search cimport Beam
+from thinc.api cimport AveragedPerceptron
+from thinc.api cimport Example, ExampleC
 
-from .._ml cimport Model
-
+from .stateclass cimport StateClass
 from .arc_eager cimport TransitionSystem
 
 from ..tokens.doc cimport Doc
 from ..structs cimport TokenC
-from thinc.api cimport Example, ExampleC
-from .stateclass cimport StateClass
 
 
+cdef class ParserModel(AveragedPerceptron):
+    cdef void set_features(self, ExampleC* eg, StateClass stcls) except *
+
+
 cdef class Parser:
-    cdef readonly Model model
+    cdef readonly ParserModel model
     cdef readonly TransitionSystem moves
 
-    cdef void parse(self, StateClass stcls, ExampleC eg) nogil
-    cdef void predict(self, StateClass stcls, ExampleC* eg) nogil
spacy/syntax/parser.pyx

@@ -18,18 +18,15 @@ import sys
 from cymem.cymem cimport Pool, Address
 from murmurhash.mrmr cimport hash64
 from thinc.typedefs cimport weight_t, class_t, feat_t, atom_t, hash_t
+from thinc.features cimport ConjunctionExtracter
 
 from util import Config
 
 from thinc.api cimport Example, ExampleC
-
-
 from ..structs cimport TokenC
-
 from ..tokens.doc cimport Doc
 from ..strings cimport StringStore
 
 from .transition_system import OracleError
 from .transition_system cimport TransitionSystem, Transition

@@ -40,7 +37,6 @@ from ._parse_features cimport CONTEXT_SIZE
 from ._parse_features cimport fill_context
 from .stateclass cimport StateClass
 
-from thinc.learner cimport arg_max_if_true
 
 
 DEBUG = False

@@ -66,8 +62,18 @@ def ParserFactory(transition_system):
     return lambda strings, dir_: Parser(strings, dir_, transition_system)
 
 
+cdef class ParserModel(AveragedPerceptron):
+    def __init__(self, n_classes, templates):
+        AveragedPerceptron.__init__(self, n_classes,
+            ConjunctionExtracter(CONTEXT_SIZE, templates))
+
+    cdef void set_features(self, ExampleC* eg, StateClass stcls) except *:
+        fill_context(eg.atoms, stcls)
+        eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms)
+
+
 cdef class Parser:
-    def __init__(self, StringStore strings, transition_system, model):
+    def __init__(self, StringStore strings, transition_system, ParserModel model):
         self.moves = transition_system
         self.model = model

@@ -80,54 +86,50 @@ cdef class Parser:
         cfg = Config.read(model_dir, 'config')
         moves = transition_system(strings, cfg.labels)
         templates = get_templates(cfg.features)
-        model = Model(moves.n_moves, templates, model_dir)
+        model = ParserModel(moves.n_moves, templates)
+        if path.exists(path.join(model_dir, 'model')):
+            model.load(path.join(model_dir, 'model'))
         return cls(strings, moves, model)
 
+    def __reduce__(self):
+        return (Parser, (self.moves.strings, self.moves, self.model), None, None)
+
     def __call__(self, Doc tokens):
         cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
         self.moves.initialize_state(stcls)
 
-        cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
-                                  self.model.n_feats, self.model.n_feats)
-        self.parse(stcls, eg.c)
-        tokens.set_parse(stcls._sent)
-
-    def __reduce__(self):
-        return (Parser, (self.moves.strings, self.moves, self.model), None, None)
-
-    cdef void predict(self, StateClass stcls, ExampleC* eg) nogil:
-        memset(eg.scores, 0, eg.nr_class * sizeof(weight_t))
-        self.moves.set_valid(eg.is_valid, stcls)
-        fill_context(eg.atoms, stcls)
-        self.model.set_scores(eg.scores, eg.atoms)
-        eg.guess = arg_max_if_true(eg.scores, eg.is_valid, self.model.n_classes)
-
-    cdef void parse(self, StateClass stcls, ExampleC eg) nogil:
+        cdef Pool mem = Pool()
+        cdef ExampleC eg = self.model.allocate(mem)
         while not stcls.is_final():
-            self.predict(stcls, &eg)
-            if not eg.is_valid[eg.guess]:
-                break
-            self.moves.c[eg.guess].do(stcls, self.moves.c[eg.guess].label)
-        self.moves.finalize_state(stcls)
+            self.model.set_features(&eg, stcls)
+            self.moves.set_valid(eg.is_valid, stcls)
+            self.model.set_prediction(&eg)
+
+            assert eg.is_valid[eg.guess]
+
+            action = self.moves.c[eg.guess]
+            action.do(stcls, action.label)
+        self.moves.finalize_state(stcls)
+        tokens.set_parse(stcls._sent)
 
     def train(self, Doc tokens, GoldParse gold):
         self.moves.preprocess_gold(gold)
         cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
         self.moves.initialize_state(stcls)
-        cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
-                                  self.model.n_feats, self.model.n_feats)
+        cdef Pool mem = Pool()
+        cdef ExampleC eg = self.model.allocate(mem)
         cdef weight_t loss = 0
         words = [w.orth_ for w in tokens]
-        cdef Transition G
+        cdef Transition action
        while not stcls.is_final():
-            memset(eg.c.scores, 0, eg.c.nr_class * sizeof(weight_t))
-            self.moves.set_costs(eg.c.is_valid, eg.c.costs, stcls, gold)
-            fill_context(eg.c.atoms, stcls)
-            self.model.train(eg)
-            G = self.moves.c[eg.c.guess]
+            self.model.set_features(&eg, stcls)
+            self.moves.set_costs(eg.is_valid, eg.costs, stcls, gold)
+            self.model.set_prediction(&eg)
+            self.model.update(&eg)
 
-            self.moves.c[eg.c.guess].do(stcls, self.moves.c[eg.c.guess].label)
-            loss += eg.c.loss
+            action = self.moves.c[eg.guess]
+            action.do(stcls, action.label)
+            loss += eg.costs[eg.guess]
         return loss
 
     def step_through(self, Doc doc):

@@ -176,7 +178,10 @@ cdef class StepwiseState:
                 for i in range(self.stcls.length)]
 
     def predict(self):
-        self.parser.predict(self.stcls, &self.eg.c)
+        self.parser.model.set_features(&self.eg.c, self.stcls)
+        self.parser.moves.set_valid(self.eg.c.is_valid, self.stcls)
+        self.parser.model.set_prediction(&self.eg.c)
+
         action = self.parser.moves.c[self.eg.c.guess]
         return self.parser.moves.move_name(action.move, action.label)
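The `__reduce__` added to `Parser` above follows the standard pickle protocol: when `__reduce__` returns `(callable, args, ...)`, pickle rebuilds the object as `callable(*args)`, which is how extension classes without a pickleable `__dict__` opt in to serialization. A minimal self-contained illustration of the pattern, using a plain-Python stand-in rather than the actual spaCy class:

import pickle

class Parser(object):
    # Plain-Python stand-in illustrating the __reduce__ pattern the diff
    # adds to Parser and Tagger; the real classes are Cython extension types.
    def __init__(self, strings, moves, model):
        self.strings = strings
        self.moves = moves
        self.model = model

    def __reduce__(self):
        # pickle reconstructs this object as Parser(strings, moves, model)
        return (Parser, (self.strings, self.moves, self.model), None, None)

parser = Parser('strings', 'moves', 'model')
restored = pickle.loads(pickle.dumps(parser))
assert (restored.strings, restored.moves, restored.model) == ('strings', 'moves', 'model')

This also connects to the TODO in the commit message: `ParserModel` and `TaggerModel` themselves still lack a `__reduce__`, so a pickled `Parser` or `Tagger` can only round-trip once the models gain one too.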
spacy/tagger.pxd

@@ -1,9 +1,17 @@
-from ._ml cimport Model
+from thinc.api cimport AveragedPerceptron
+from thinc.api cimport ExampleC
+
 from .structs cimport TokenC
 from .vocab cimport Vocab
 
 
+cdef class TaggerModel(AveragedPerceptron):
+    cdef void set_features(self, ExampleC* eg, const TokenC* tokens, int i) except *
+    cdef void set_costs(self, ExampleC* eg, int gold) except *
+    cdef void update(self, ExampleC* eg) except *
+
+
 cdef class Tagger:
     cdef readonly Vocab vocab
-    cdef readonly Model model
+    cdef readonly TaggerModel model
     cdef public dict freqs
spacy/tagger.pyx

@@ -1,10 +1,12 @@
 import json
 from os import path
 from collections import defaultdict
+from libc.string cimport memset
 
 from cymem.cymem cimport Pool
 from thinc.typedefs cimport atom_t, weight_t
 from thinc.learner cimport arg_max, arg_max_if_true, arg_max_if_zero
-from thinc.api cimport Example
+from thinc.api cimport Example, ExampleC
+from thinc.features cimport ConjunctionExtracter
 
 from .typedefs cimport attr_t
 from .tokens.doc cimport Doc

@@ -64,6 +66,44 @@ cpdef enum:
     N_CONTEXT_FIELDS
 
 
+cdef class TaggerModel(AveragedPerceptron):
+    def __init__(self, n_classes, templates):
+        AveragedPerceptron.__init__(self, n_classes,
+            ConjunctionExtracter(N_CONTEXT_FIELDS, templates))
+
+    cdef void set_features(self, ExampleC* eg, const TokenC* tokens, int i) except *:
+        _fill_from_token(&eg.atoms[P2_orth], &tokens[i-2])
+        _fill_from_token(&eg.atoms[P1_orth], &tokens[i-1])
+        _fill_from_token(&eg.atoms[W_orth], &tokens[i])
+        _fill_from_token(&eg.atoms[N1_orth], &tokens[i+1])
+        _fill_from_token(&eg.atoms[N2_orth], &tokens[i+2])
+
+        eg.nr_feat = self.extracter.set_features(eg.features, eg.atoms)
+
+    cdef void update(self, ExampleC* eg) except *:
+        self.updater.update(eg)
+
+
+cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil:
+    context[0] = t.lex.lower
+    context[1] = t.lex.cluster
+    context[2] = t.lex.shape
+    context[3] = t.lex.prefix
+    context[4] = t.lex.suffix
+    context[5] = t.tag
+    context[6] = t.lemma
+    if t.lex.flags & (1 << IS_ALPHA):
+        context[7] = 1
+    elif t.lex.flags & (1 << IS_PUNCT):
+        context[7] = 2
+    elif t.lex.flags & (1 << LIKE_URL):
+        context[7] = 3
+    elif t.lex.flags & (1 << LIKE_NUM):
+        context[7] = 4
+    else:
+        context[7] = 0
+
+
 cdef class Tagger:
     """A part-of-speech tagger for English"""
     @classmethod

@@ -105,7 +145,7 @@ cdef class Tagger:
 
     @classmethod
     def blank(cls, vocab, templates):
-        model = Model(vocab.morphology.n_tags, templates, model_loc=None)
+        model = TaggerModel(vocab.morphology.n_tags, templates)
         return cls(vocab, model)
 
     @classmethod

@@ -114,10 +154,12 @@ cdef class Tagger:
             templates = json.loads(open(path.join(data_dir, 'templates.json')))
         else:
             templates = cls.default_templates()
-        model = Model(vocab.morphology.n_tags, templates, data_dir)
+        model = TaggerModel(vocab.morphology.n_tags, templates)
+        if path.exists(path.join(data_dir, 'model')):
+            model.load(path.join(data_dir, 'model'))
         return cls(vocab, model)
 
-    def __init__(self, Vocab vocab, model):
+    def __init__(self, Vocab vocab, TaggerModel model):
         self.vocab = vocab
         self.model = model

@@ -131,27 +173,6 @@ cdef class Tagger:
     def tag_names(self):
         return self.vocab.morphology.tag_names
 
-    def __call__(self, Doc tokens):
-        """Apply the tagger, setting the POS tags onto the Doc object.
-
-        Args:
-            tokens (Doc): The tokens to be tagged.
-        """
-        if tokens.length == 0:
-            return 0
-
-        cdef Example eg = self.model._eg
-        cdef int i
-        for i in range(tokens.length):
-            if tokens.c[i].pos == 0:
-                eg.wipe()
-                fill_atoms(eg.c.atoms, tokens.c, i)
-                self.model(eg)
-                self.vocab.morphology.assign_tag(&tokens.c[i], eg.c.guess)
-
-        tokens.is_tagged = True
-        tokens._py_tokens = [None] * tokens.length
-
     def __reduce__(self):
         return (self.__class__, (self.vocab, self.model), None, None)

@@ -162,53 +183,45 @@ cdef class Tagger:
         tokens.is_tagged = True
         tokens._py_tokens = [None] * tokens.length
 
+    def __call__(self, Doc tokens):
+        """Apply the tagger, setting the POS tags onto the Doc object.
+
+        Args:
+            tokens (Doc): The tokens to be tagged.
+        """
+        if tokens.length == 0:
+            return 0
+
+        cdef Pool mem = Pool()
+        cdef ExampleC eg
+
+        cdef int i, tag
+        for i in range(tokens.length):
+            if tokens.c[i].pos == 0:
+                eg = self.model.allocate(mem)
+                self.model.set_features(&eg, tokens.c, i)
+                self.model.set_prediction(&eg)
+                self.vocab.morphology.assign_tag(&tokens.c[i], eg.guess)
+        tokens.is_tagged = True
+        tokens._py_tokens = [None] * tokens.length
+
     def train(self, Doc tokens, object gold_tag_strs):
         assert len(tokens) == len(gold_tag_strs)
         cdef int i
-        cdef int loss
-        cdef const weight_t* scores
-        try:
-            golds = [self.tag_names.index(g) if g is not None else -1 for g in gold_tag_strs]
-        except ValueError:
-            raise ValueError(
-                [g for g in gold_tag_strs if g is not None and g not in self.tag_names])
-        correct = 0
-        cdef Example eg = self.model._eg
+        golds = [self.tag_names.index(g) if g is not None else -1 for g in gold_tag_strs]
+        cdef int correct = 0
+        cdef Pool mem = Pool()
+        cdef ExampleC eg
         for i in range(tokens.length):
-            eg.wipe()
-            fill_atoms(eg.c.atoms, tokens.c, i)
-            self.train(eg)
+            eg = self.model.allocate(mem)
+            self.model.set_features(&eg, tokens.c, i)
+            self.model.set_costs(&eg, golds[i])
+            self.model.set_prediction(&eg)
+            self.model.update(&eg)
 
-            self.vocab.morphology.assign_tag(&tokens.c[i], eg.c.guess)
+            self.vocab.morphology.assign_tag(&tokens.c[i], eg.guess)
 
-            correct += eg.c.cost == 0
+            correct += eg.cost == 0
             self.freqs[TAG][tokens.c[i].tag] += 1
         tokens.is_tagged = True
         tokens._py_tokens = [None] * tokens.length
         return correct
-
-
-cdef inline void fill_atoms(atom_t* atoms, const TokenC* tokens, int i) nogil:
-    _fill_from_token(&atoms[P2_orth], &tokens[i-2])
-    _fill_from_token(&atoms[P1_orth], &tokens[i-1])
-    _fill_from_token(&atoms[W_orth], &tokens[i])
-    _fill_from_token(&atoms[N1_orth], &tokens[i+1])
-    _fill_from_token(&atoms[N2_orth], &tokens[i+2])
-
-
-cdef inline void _fill_from_token(atom_t* context, const TokenC* t) nogil:
-    context[0] = t.lex.lower
-    context[1] = t.lex.cluster
-    context[2] = t.lex.shape
-    context[3] = t.lex.prefix
-    context[4] = t.lex.suffix
-    context[5] = t.tag
-    context[6] = t.lemma
-    if t.lex.flags & (1 << IS_ALPHA):
-        context[7] = 1
-    elif t.lex.flags & (1 << IS_PUNCT):
-        context[7] = 2
-    elif t.lex.flags & (1 << LIKE_URL):
-        context[7] = 3
-    elif t.lex.flags & (1 << LIKE_NUM):
-        context[7] = 4
-    else:
-        context[7] = 0
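The tagger's feature extraction above fills an eight-slot context block for each of the five tokens in a P2..N2 window around the target word. A rough pure-Python equivalent of `_fill_from_token` and the window loop is sketched below; `Token` and its fields are hypothetical stand-ins for the `TokenC`/`t.lex.*` attributes used in the diff, and the collapsed flag in slot 7 is precomputed rather than derived from lexeme flag bits.

from collections import namedtuple

# Hypothetical stand-in; flag_id collapses IS_ALPHA/IS_PUNCT/LIKE_URL/LIKE_NUM.
Token = namedtuple('Token', 'lower cluster shape prefix suffix tag lemma flag_id')

def fill_from_token(context, offset, t):
    # Eight atoms per token, matching context[0]..context[7] in the diff.
    context[offset:offset + 8] = [t.lower, t.cluster, t.shape, t.prefix,
                                  t.suffix, t.tag, t.lemma, t.flag_id]

def fill_atoms(tokens, i, pad):
    # P2, P1, W, N1, N2 window; out-of-range positions use a pad token
    # (the C code relies on padding around the token array instead).
    context = [0] * 40
    for slot, j in enumerate(range(i - 2, i + 3)):
        tok = tokens[j] if 0 <= j < len(tokens) else pad
        fill_from_token(context, slot * 8, tok)
    return context

pad = Token(0, 0, 0, 0, 0, 0, 0, 0)
toks = [Token(*([k] * 8)) for k in range(1, 4)]
print(fill_atoms(toks, 1, pad))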
@@ -11,7 +11,6 @@ from spacy.strings import StringStore
 from spacy.vocab import Vocab
 from spacy.tokenizer import Tokenizer
 from spacy.syntax.arc_eager import ArcEager
-from spacy._ml import Model
 from spacy.tagger import Tagger
 from spacy.syntax.parser import Parser
 from spacy.matcher import Matcher
@@ -12,16 +12,13 @@ from spacy.strings import StringStore
 from spacy.vocab import Vocab
 from spacy.tokenizer import Tokenizer
 from spacy.syntax.arc_eager import ArcEager
-from spacy._ml import Model
 from spacy.tagger import Tagger
-from spacy.syntax.parser import Parser
+from spacy.syntax.parser import Parser, ParserModel
 from spacy.matcher import Matcher
 from spacy.syntax.parser import get_templates
 
 from spacy.en import English
 
-from thinc.learner import LinearModel
-
 
 class TestLoadVocab(unittest.TestCase):
     def test_load(self):

@@ -54,7 +51,6 @@ class TestLoadParser(unittest.TestCase):
         if path.exists(path.join(data_dir, 'deps')):
             parser = Parser.from_dir(path.join(data_dir, 'deps'), vocab.strings, ArcEager)
 
-
     def test_load_careful(self):
         config_data = {"labels": {"0": {"": True}, "1": {"": True}, "2": {"cc": True, "agent": True, "ccomp": True, "prt": True, "meta": True, "nsubjpass": True, "csubj": True, "conj": True, "dobj": True, "neg": True, "csubjpass": True, "mark": True, "auxpass": True, "advcl": True, "aux": True, "ROOT": True, "prep": True, "parataxis": True, "xcomp": True, "nsubj": True, "nummod": True, "advmod": True, "punct": True, "relcl": True, "quantmod": True, "acomp": True, "compound": True, "pcomp": True, "intj": True, "poss": True, "npadvmod": True, "case": True, "attr": True, "dep": True, "appos": True, "det": True, "nmod": True, "amod": True, "dative": True, "pobj": True, "expl": True, "predet": True, "preconj": True, "oprd": True, "acl": True}, "3": {"cc": True, "agent": True, "ccomp": True, "prt": True, "meta": True, "nsubjpass": True, "csubj": True, "conj": True, "acl": True, "poss": True, "neg": True, "mark": True, "auxpass": True, "advcl": True, "aux": True, "amod": True, "ROOT": True, "prep": True, "parataxis": True, "xcomp": True, "nsubj": True, "nummod": True, "advmod": True, "punct": True, "quantmod": True, "acomp": True, "pcomp": True, "intj": True, "relcl": True, "npadvmod": True, "case": True, "attr": True, "dep": True, "appos": True, "det": True, "nmod": True, "dobj": True, "dative": True, "pobj": True, "iobj": True, "expl": True, "predet": True, "preconj": True, "oprd": True}, "4": {"ROOT": True}}, "seed": 0, "features": "basic", "beam_width": 1}
 
         data_dir = English.default_data_dir()

@@ -63,20 +59,11 @@ class TestLoadParser(unittest.TestCase):
         moves = ArcEager(vocab.strings, config_data['labels'])
         templates = get_templates(config_data['features'])
 
-        model = Model(moves.n_moves, templates, path.join(data_dir, 'deps'))
+        model = ParserModel(moves.n_moves, templates)
+        model.load(path.join(data_dir, 'deps', 'model'))
 
         parser = Parser(vocab.strings, moves, model)
 
-    def test_thinc_load(self):
-        data_dir = English.default_data_dir()
-        model_loc = path.join(data_dir, 'deps', 'model')
-
-        # n classes. moves.n_moves above
-        # n features. len(templates) + 1 above
-        if path.exists(model_loc):
-            model = LinearModel(92, 116)
-            model.load(model_loc)
-
 
 if __name__ == '__main__':
     unittest.main()