mirror of https://github.com/explosion/spaCy.git (synced 2024-12-26 01:46:28 +03:00)

commit cf8d26c3d2
parent 4c4aa2c5c9

    * POS tagger training working after reorg

@@ -9,6 +9,7 @@ from ..tokens import Tokens
 from ..morphology import Morphologizer
 from .lemmatizer import Lemmatizer
 from .pos import EnPosTagger
+from .pos import POS_TAGS
 from .attrs import get_flags


@@ -21,13 +22,13 @@ class English(object):
         if data_dir is None:
             data_dir = path.join(path.dirname(__file__), 'data')
         self.vocab = Vocab.from_dir(data_dir, get_lex_props=get_lex_props)
+        for pos_str in POS_TAGS:
+            _ = self.vocab.strings.pos_tags[pos_str]
         self.tokenizer = Tokenizer.from_dir(self.vocab, data_dir)
         if pos_tag:
-            self.pos_tagger = EnPosTagger(data_dir,
-                                          Morphologizer.from_dir(
-                                              self.vocab.strings,
-                                              Lemmatizer(path.join(data_dir, 'wordnet')),
-                                              data_dir))
+            morph = Morphologizer(self.vocab.strings, POS_TAGS,
+                                  Lemmatizer(path.join(data_dir, 'wordnet')))
+            self.pos_tagger = EnPosTagger(data_dir, morph)
         else:
             self.pos_tagger = None
         if parse:
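
Note: a hedged sketch of how the reorganised entry point is exercised,
mirroring the updated test at the bottom of this commit (keyword defaults
assumed from the fixture there):

    from spacy.en import English

    nlp = English(pos_tag=True, parse=False)         # builds vocab, tokenizer, EnPosTagger
    tokens = nlp('I like his style.', pos_tag=True)  # tag while processing
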
@@ -35,15 +35,15 @@ cdef struct _Cached:
 cdef class Morphologizer:
     """Given a POS tag and a Lexeme, find its lemma and morphological analysis.
     """
-    def __init__(self, StringStore strings, object lemmatizer,
-                 irregulars=None, tag_map=None, tag_names=None):
+    def __init__(self, StringStore strings, object tag_map, object lemmatizer,
+                 irregulars=None):
         self.mem = Pool()
         self.strings = strings
-        self.tag_names = tag_names
         self.lemmatizer = lemmatizer
-        self._cache = PreshMapArray(len(self.tag_names))
-        self.tags = <PosTag*>self.mem.alloc(len(self.tag_names), sizeof(PosTag))
-        for i, tag in enumerate(self.tag_names):
+        cdef int n_tags = len(self.strings.pos_tags) + 1
+        self._cache = PreshMapArray(n_tags)
+        self.tags = <PosTag*>self.mem.alloc(n_tags, sizeof(PosTag))
+        for tag, i in self.strings.pos_tags:
             pos, props = tag_map[tag]
             self.tags[i].id = i
             self.tags[i].pos = pos
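
Note: the unpacking "pos, props = tag_map[tag]" implies tag_map entries of the
form tag -> (universal-tag id, property dict), with POS_TAGS passed in as the
tag map from English.__init__ above. A sketch with invented placeholder values,
not spaCy's real POS_TAGS:

    tag_map = {
        'NN':   (4, {}),             # 4 standing in for a univ_tag_t NOUN id
        'PRP$': (5, {'person': 3}),  # props carry morphological features
    }
    pos, props = tag_map['PRP$']
    assert pos == 5 and props['person'] == 3
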
@@ -57,15 +57,6 @@ cdef class Morphologizer:
         if irregulars is not None:
             self.load_exceptions(irregulars)

-    @classmethod
-    def from_dir(cls, StringStore strings, object lemmatizer, data_dir):
-        tagger_cfg = json.loads(open(path.join(data_dir, 'pos', 'config.json')).read())
-        tag_map = tagger_cfg['tag_map']
-        tag_names = tagger_cfg['tag_names']
-        irregulars = json.loads(open(path.join(data_dir, 'morphs.json')).read())
-        return cls(strings, lemmatizer, tag_map=tag_map, irregulars=irregulars,
-                   tag_names=tag_names)
-
     cdef int lemmatize(self, const univ_tag_t pos, const Lexeme* lex) except -1:
         if self.lemmatizer is None:
             return lex.sic
@@ -104,9 +95,10 @@ cdef class Morphologizer:
         cdef dict props
         cdef int lemma
         cdef id_t sic
-        cdef univ_tag_t pos
+        cdef int pos
         for pos_str, entries in exc.items():
-            pos = self.tag_names.index(pos_str)
+            pos = self.strings.pos_tags[pos_str]
+            assert pos < len(self.strings.pos_tags)
             for form_str, props in entries.items():
                 lemma_str = props.get('L', form_str)
                 sic = self.strings[form_str]
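
Note: the exceptions dict consumed here maps a fine-grained tag string to
surface forms, each with a props dict whose 'L' key is the lemma, matching
the morph_exc test fixture at the bottom of this commit:

    exc = {'PRP$': {'his': {'L': '-PRP-', 'person': 3, 'case': 2}}}
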
@@ -19,6 +19,8 @@ cdef class _SymbolMap:
 cdef class StringStore:
     cdef Pool mem
     cdef Utf8Str* strings
+    cdef readonly _SymbolMap pos_tags
+    cdef readonly _SymbolMap dep_tags
     cdef size_t size

     cdef PreshMap _map
@@ -18,6 +18,9 @@ cdef class _SymbolMap:
         for id_, string in enumerate(self._id_to_string[1:]):
             yield string, id_

+    def __len__(self):
+        return len(self._id_to_string)
+
     def __getitem__(self, object string_or_id):
         cdef bytes byte_string
         if isinstance(string_or_id, int) or isinstance(string_or_id, long):
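
Note: a loose pure-Python sketch of the _SymbolMap behaviour implied by the
diff: strings are interned on first lookup, lookups work in both directions,
and iteration skips a reserved slot 0. Whether real ids start at 0 or 1 is not
visible here; this sketch assumes 1, consistent with self._id_to_string[1:]
above and with every id staying below the n_tags = len(...) + 1 allocation in
the Morphologizer hunk:

    class SymbolMapSketch:
        def __init__(self):
            self._string_to_id = {}
            self._id_to_string = [None]    # slot 0 reserved

        def __getitem__(self, string_or_id):
            if isinstance(string_or_id, int):
                return self._id_to_string[string_or_id]
            if string_or_id not in self._string_to_id:
                self._id_to_string.append(string_or_id)
                self._string_to_id[string_or_id] = len(self._id_to_string) - 1
            return self._string_to_id[string_or_id]

        def __iter__(self):
            for id_, string in enumerate(self._id_to_string[1:]):
                yield string, id_ + 1

        def __len__(self):
            return len(self._id_to_string)   # includes the reserved slot

    tags = SymbolMapSketch()
    nn = tags['NN']            # interned on first lookup -> 1
    assert tags[nn] == 'NN'    # reverse lookup by id
    assert dict(iter(tags)) == {'NN': 1}
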
@@ -42,6 +45,7 @@ cdef class StringStore:
         self.mem = Pool()
         self._map = PreshMap()
         self._resize_at = 10000
+        self.size = 1
         self.strings = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
         self.pos_tags = _SymbolMap()
         self.dep_tags = _SymbolMap()

@@ -18,6 +18,3 @@ cdef class Tagger:
     cpdef readonly Pool mem
     cpdef readonly Extractor extractor
     cpdef readonly LinearModel model
-
-    cpdef readonly list tag_names
-    cdef dict tagdict

@@ -12,15 +12,13 @@ import cython
 from thinc.features cimport Feature, count_feats


-def setup_model_dir(tag_names, tag_map, tag_counts, templates, model_dir):
+def setup_model_dir(tag_names, templates, model_dir):
     if path.exists(model_dir):
         shutil.rmtree(model_dir)
     os.mkdir(model_dir)
     config = {
         'templates': templates,
         'tag_names': tag_names,
-        'tag_map': tag_map,
-        'tag_counts': tag_counts,
     }
     with open(path.join(model_dir, 'config.json'), 'w') as file_:
         json.dump(config, file_)
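
Note: after this change config.json carries only the feature templates and the
tag names; the tag map now goes to the Morphologizer directly and the
tag-count dictionary is dropped. A hypothetical invocation (path and template
values invented):

    setup_model_dir(tag_names=['NN', 'VBZ', 'PRP$'],
                    templates=((1,), (2, 3)),
                    model_dir='/tmp/en_pos')
    # /tmp/en_pos/config.json -> {"templates": [[1], [2, 3]],
    #                             "tag_names": ["NN", "VBZ", "PRP$"]}
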
@@ -37,10 +35,9 @@ cdef class Tagger:
         univ_counts = {}
         cdef unicode tag
         cdef unicode univ_tag
-        self.tag_names = cfg['tag_names']
-        self.tagdict = _make_tag_dict(cfg['tag_counts'])
+        tag_names = cfg['tag_names']
         self.extractor = Extractor(templates)
-        self.model = LinearModel(len(self.tag_names), self.extractor.n_templ+2)
+        self.model = LinearModel(len(tag_names) + 1, self.extractor.n_templ+2) # TODO
         if path.exists(path.join(model_dir, 'model')):
             self.model.load(path.join(model_dir, 'model'))

@@ -63,30 +60,6 @@ cdef class Tagger:
             self.model.update(counts)
         return guess

-    def tag_id(self, object tag_name):
-        """Encode tag_name into a tag ID integer."""
-        tag_id = self.tag_names.index(tag_name)
-        if tag_id == -1:
-            tag_id = len(self.tag_names)
-            self.tag_names.append(tag_name)
-        return tag_id
-
-
-def _make_tag_dict(counts):
-    freq_thresh = 20
-    ambiguity_thresh = 0.97
-    tagdict = {}
-    cdef atom_t word
-    cdef atom_t tag
-    for word_str, tag_freqs in counts.items():
-        tag_str, mode = max(tag_freqs.items(), key=lambda item: item[1])
-        n = sum(tag_freqs.values())
-        word = int(word_str)
-        tag = int(tag_str)
-        if n >= freq_thresh and (float(mode) / n) >= ambiguity_thresh:
-            tagdict[word] = tag
-    return tagdict
-

 cdef int _arg_max(const weight_t* scores, int n_classes) except -1:
     cdef int best = 0
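
Note: for reference, a worked illustration of the frequency/ambiguity
heuristic being deleted here: a word got a fixed tag only if it occurred at
least 20 times and a single tag covered at least 97% of those occurrences
(word/tag ids below are invented):

    counts = {'7': {'3': 99, '5': 1},   # 99/100 -> passes both thresholds
              '8': {'3': 5}}            # only 5 occurrences -> rejected
    tagdict = {}
    for word_str, tag_freqs in counts.items():
        tag_str, mode = max(tag_freqs.items(), key=lambda item: item[1])
        n = sum(tag_freqs.values())
        if n >= 20 and (float(mode) / n) >= 0.97:
            tagdict[int(word_str)] = int(tag_str)
    assert tagdict == {7: 3}
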
@@ -39,10 +39,10 @@ cdef class Token:
     cdef readonly StringStore string_store
     cdef public int i
     cdef public int idx
-    cdef int pos
+    cdef readonly int pos_id
+    cdef readonly int dep_id
     cdef int lemma
     cdef public int head
-    cdef public int dep_tag

     cdef public atom_t id
     cdef public atom_t cluster
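
Note: the struct now stores an integer pos_id, while the test below still
asserts his.pos == 'PRP$'. That suggests a string-valued property decoding
through the new pos_tags symbol map; a hedged Cython-style sketch, not shown
in this commit:

    property pos:
        def __get__(self):
            return self.string_store.pos_tags[self.pos_id]
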
@@ -1,19 +1,25 @@
 from __future__ import unicode_literals
-from spacy.en import EN
+from spacy.en import English

 import pytest

+from spacy.en import English
+
+
+@pytest.fixture
+def EN():
+    return English(pos_tag=True, parse=False)
+
+
 @pytest.fixture
 def morph_exc():
     return {
         'PRP$': {'his': {'L': '-PRP-', 'person': 3, 'case': 2}},
     }

-def test_load_exc(morph_exc):
-    EN.load()
-    EN.morphologizer.load_exceptions(morph_exc)
-    tokens = EN.tokenize('I like his style.')
-    EN.set_pos(tokens)
+def test_load_exc(EN, morph_exc):
+    EN.pos_tagger.morphologizer.load_exceptions(morph_exc)
+    tokens = EN('I like his style.', pos_tag=True)
     his = tokens[2]
     assert his.pos == 'PRP$'
     assert his.lemma == '-PRP-'