mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
* POS tagger training working after reorg
This commit is contained in:
parent
4c4aa2c5c9
commit
cf8d26c3d2
|
@ -9,6 +9,7 @@ from ..tokens import Tokens
|
|||
from ..morphology import Morphologizer
|
||||
from .lemmatizer import Lemmatizer
|
||||
from .pos import EnPosTagger
|
||||
from .pos import POS_TAGS
|
||||
from .attrs import get_flags
|
||||
|
||||
|
||||
|
@ -21,13 +22,13 @@ class English(object):
|
|||
if data_dir is None:
|
||||
data_dir = path.join(path.dirname(__file__), 'data')
|
||||
self.vocab = Vocab.from_dir(data_dir, get_lex_props=get_lex_props)
|
||||
for pos_str in POS_TAGS:
|
||||
_ = self.vocab.strings.pos_tags[pos_str]
|
||||
self.tokenizer = Tokenizer.from_dir(self.vocab, data_dir)
|
||||
if pos_tag:
|
||||
self.pos_tagger = EnPosTagger(data_dir,
|
||||
Morphologizer.from_dir(
|
||||
self.vocab.strings,
|
||||
Lemmatizer(path.join(data_dir, 'wordnet')),
|
||||
data_dir))
|
||||
morph = Morphologizer(self.vocab.strings, POS_TAGS,
|
||||
Lemmatizer(path.join(data_dir, 'wordnet')))
|
||||
self.pos_tagger = EnPosTagger(data_dir, morph)
|
||||
else:
|
||||
self.pos_tagger = None
|
||||
if parse:
|
||||
|
|
|
@ -35,15 +35,15 @@ cdef struct _Cached:
|
|||
cdef class Morphologizer:
|
||||
"""Given a POS tag and a Lexeme, find its lemma and morphological analysis.
|
||||
"""
|
||||
def __init__(self, StringStore strings, object lemmatizer,
|
||||
irregulars=None, tag_map=None, tag_names=None):
|
||||
def __init__(self, StringStore strings, object tag_map, object lemmatizer,
|
||||
irregulars=None):
|
||||
self.mem = Pool()
|
||||
self.strings = strings
|
||||
self.tag_names = tag_names
|
||||
self.lemmatizer = lemmatizer
|
||||
self._cache = PreshMapArray(len(self.tag_names))
|
||||
self.tags = <PosTag*>self.mem.alloc(len(self.tag_names), sizeof(PosTag))
|
||||
for i, tag in enumerate(self.tag_names):
|
||||
cdef int n_tags = len(self.strings.pos_tags) + 1
|
||||
self._cache = PreshMapArray(n_tags)
|
||||
self.tags = <PosTag*>self.mem.alloc(n_tags, sizeof(PosTag))
|
||||
for tag, i in self.strings.pos_tags:
|
||||
pos, props = tag_map[tag]
|
||||
self.tags[i].id = i
|
||||
self.tags[i].pos = pos
|
||||
|
@ -57,15 +57,6 @@ cdef class Morphologizer:
|
|||
if irregulars is not None:
|
||||
self.load_exceptions(irregulars)
|
||||
|
||||
@classmethod
|
||||
def from_dir(cls, StringStore strings, object lemmatizer, data_dir):
|
||||
tagger_cfg = json.loads(open(path.join(data_dir, 'pos', 'config.json')).read())
|
||||
tag_map = tagger_cfg['tag_map']
|
||||
tag_names = tagger_cfg['tag_names']
|
||||
irregulars = json.loads(open(path.join(data_dir, 'morphs.json')).read())
|
||||
return cls(strings, lemmatizer, tag_map=tag_map, irregulars=irregulars,
|
||||
tag_names=tag_names)
|
||||
|
||||
cdef int lemmatize(self, const univ_tag_t pos, const Lexeme* lex) except -1:
|
||||
if self.lemmatizer is None:
|
||||
return lex.sic
|
||||
|
@ -104,9 +95,10 @@ cdef class Morphologizer:
|
|||
cdef dict props
|
||||
cdef int lemma
|
||||
cdef id_t sic
|
||||
cdef univ_tag_t pos
|
||||
cdef int pos
|
||||
for pos_str, entries in exc.items():
|
||||
pos = self.tag_names.index(pos_str)
|
||||
pos = self.strings.pos_tags[pos_str]
|
||||
assert pos < len(self.strings.pos_tags)
|
||||
for form_str, props in entries.items():
|
||||
lemma_str = props.get('L', form_str)
|
||||
sic = self.strings[form_str]
|
||||
|
|
|
@ -19,6 +19,8 @@ cdef class _SymbolMap:
|
|||
cdef class StringStore:
|
||||
cdef Pool mem
|
||||
cdef Utf8Str* strings
|
||||
cdef readonly _SymbolMap pos_tags
|
||||
cdef readonly _SymbolMap dep_tags
|
||||
cdef size_t size
|
||||
|
||||
cdef PreshMap _map
|
||||
|
|
|
@ -18,6 +18,9 @@ cdef class _SymbolMap:
|
|||
for id_, string in enumerate(self._id_to_string[1:]):
|
||||
yield string, id_
|
||||
|
||||
def __len__(self):
|
||||
return len(self._id_to_string)
|
||||
|
||||
def __getitem__(self, object string_or_id):
|
||||
cdef bytes byte_string
|
||||
if isinstance(string_or_id, int) or isinstance(string_or_id, long):
|
||||
|
@ -42,6 +45,7 @@ cdef class StringStore:
|
|||
self.mem = Pool()
|
||||
self._map = PreshMap()
|
||||
self._resize_at = 10000
|
||||
self.size = 1
|
||||
self.strings = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
|
||||
self.pos_tags = _SymbolMap()
|
||||
self.dep_tags = _SymbolMap()
|
||||
|
|
|
@ -18,6 +18,3 @@ cdef class Tagger:
|
|||
cpdef readonly Pool mem
|
||||
cpdef readonly Extractor extractor
|
||||
cpdef readonly LinearModel model
|
||||
|
||||
cpdef readonly list tag_names
|
||||
cdef dict tagdict
|
||||
|
|
|
@ -12,15 +12,13 @@ import cython
|
|||
from thinc.features cimport Feature, count_feats
|
||||
|
||||
|
||||
def setup_model_dir(tag_names, tag_map, tag_counts, templates, model_dir):
|
||||
def setup_model_dir(tag_names, templates, model_dir):
|
||||
if path.exists(model_dir):
|
||||
shutil.rmtree(model_dir)
|
||||
os.mkdir(model_dir)
|
||||
config = {
|
||||
'templates': templates,
|
||||
'tag_names': tag_names,
|
||||
'tag_map': tag_map,
|
||||
'tag_counts': tag_counts,
|
||||
}
|
||||
with open(path.join(model_dir, 'config.json'), 'w') as file_:
|
||||
json.dump(config, file_)
|
||||
|
@ -37,10 +35,9 @@ cdef class Tagger:
|
|||
univ_counts = {}
|
||||
cdef unicode tag
|
||||
cdef unicode univ_tag
|
||||
self.tag_names = cfg['tag_names']
|
||||
self.tagdict = _make_tag_dict(cfg['tag_counts'])
|
||||
tag_names = cfg['tag_names']
|
||||
self.extractor = Extractor(templates)
|
||||
self.model = LinearModel(len(self.tag_names), self.extractor.n_templ+2)
|
||||
self.model = LinearModel(len(tag_names) + 1, self.extractor.n_templ+2) # TODO
|
||||
if path.exists(path.join(model_dir, 'model')):
|
||||
self.model.load(path.join(model_dir, 'model'))
|
||||
|
||||
|
@ -63,30 +60,6 @@ cdef class Tagger:
|
|||
self.model.update(counts)
|
||||
return guess
|
||||
|
||||
def tag_id(self, object tag_name):
|
||||
"""Encode tag_name into a tag ID integer."""
|
||||
tag_id = self.tag_names.index(tag_name)
|
||||
if tag_id == -1:
|
||||
tag_id = len(self.tag_names)
|
||||
self.tag_names.append(tag_name)
|
||||
return tag_id
|
||||
|
||||
|
||||
def _make_tag_dict(counts):
|
||||
freq_thresh = 20
|
||||
ambiguity_thresh = 0.97
|
||||
tagdict = {}
|
||||
cdef atom_t word
|
||||
cdef atom_t tag
|
||||
for word_str, tag_freqs in counts.items():
|
||||
tag_str, mode = max(tag_freqs.items(), key=lambda item: item[1])
|
||||
n = sum(tag_freqs.values())
|
||||
word = int(word_str)
|
||||
tag = int(tag_str)
|
||||
if n >= freq_thresh and (float(mode) / n) >= ambiguity_thresh:
|
||||
tagdict[word] = tag
|
||||
return tagdict
|
||||
|
||||
|
||||
cdef int _arg_max(const weight_t* scores, int n_classes) except -1:
|
||||
cdef int best = 0
|
||||
|
|
|
@ -39,10 +39,10 @@ cdef class Token:
|
|||
cdef readonly StringStore string_store
|
||||
cdef public int i
|
||||
cdef public int idx
|
||||
cdef int pos
|
||||
cdef readonly int pos_id
|
||||
cdef readonly int dep_id
|
||||
cdef int lemma
|
||||
cdef public int head
|
||||
cdef public int dep_tag
|
||||
|
||||
cdef public atom_t id
|
||||
cdef public atom_t cluster
|
||||
|
|
|
@ -1,19 +1,25 @@
|
|||
from __future__ import unicode_literals
|
||||
from spacy.en import EN
|
||||
from spacy.en import English
|
||||
|
||||
import pytest
|
||||
|
||||
from spacy.en import English
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def EN():
|
||||
return English(pos_tag=True, parse=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def morph_exc():
|
||||
return {
|
||||
'PRP$': {'his': {'L': '-PRP-', 'person': 3, 'case': 2}},
|
||||
}
|
||||
|
||||
def test_load_exc(morph_exc):
|
||||
EN.load()
|
||||
EN.morphologizer.load_exceptions(morph_exc)
|
||||
tokens = EN.tokenize('I like his style.')
|
||||
EN.set_pos(tokens)
|
||||
def test_load_exc(EN, morph_exc):
|
||||
EN.pos_tagger.morphologizer.load_exceptions(morph_exc)
|
||||
tokens = EN('I like his style.', pos_tag=True)
|
||||
his = tokens[2]
|
||||
assert his.pos == 'PRP$'
|
||||
assert his.lemma == '-PRP-'
|
||||
|
|
Loading…
Reference in New Issue
Block a user