Merge branch 'master' of ssh://github.com/explosion/spaCy
Commit: 296d33a4fc
@@ -100,7 +100,7 @@ def evaluate(Language, gold_tuples, model_dir, gold_preproc=False, verbose=False
                 nlp.entity(tokens)
             else:
                 tokens = nlp(raw_text)
-            gold = GoldParse(tokens, annot_tuples)
+            gold = GoldParse.from_annot_tuples(tokens, annot_tuples)
             scorer.score(tokens, gold, verbose=verbose)
     return scorer

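For context, a minimal sketch of the annot-tuples layout that GoldParse.from_annot_tuples consumes here; the sentence and all values are invented for illustration:

    # Invented example of the (ids, words, tags, heads, deps, ner) tuple
    # that `annot_tuples` holds above.
    ids = [0, 1, 2]
    words = [u'She', u'sleeps', u'.']
    tags = [u'PRP', u'VBZ', u'.']
    heads = [1, 1, 1]      # head index per token; the root token points at itself
    deps = [u'nsubj', u'ROOT', u'punct']
    ner = [u'O', u'O', u'O']
    annot_tuples = (ids, words, tags, heads, deps, ner)
    # gold = GoldParse.from_annot_tuples(tokens, annot_tuples)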
@@ -1,3 +1,4 @@
 from __future__ import unicode_literals
 import plac
+import json
 from os import path

@@ -5,106 +6,25 @@ import shutil
 import os
 import random
+import io
+import pathlib
 
-from spacy.syntax.util import Config
+from spacy.tokens import Doc
+from spacy.syntax.nonproj import PseudoProjectivity
 from spacy.language import Language
 from spacy.gold import GoldParse
-from spacy.tokenizer import Tokenizer
 from spacy.vocab import Vocab
 from spacy.tagger import Tagger
-from spacy.syntax.parser import Parser
-from spacy.syntax.arc_eager import ArcEager
+from spacy.pipeline import DependencyParser
 from spacy.syntax.parser import get_templates
 from spacy.syntax.arc_eager import ArcEager
 from spacy.scorer import Scorer
 import spacy.attrs
 
-from spacy.language import Language
-
-from spacy.tagger import W_orth
-
-TAGGER_TEMPLATES = (
-    (W_orth,),
-)
-
-try:
-    from codecs import open
-except ImportError:
-    pass
-import io
-
-
-class TreebankParser(object):
-    @staticmethod
-    def setup_model_dir(model_dir, labels, templates, feat_set='basic', seed=0):
-        dep_model_dir = path.join(model_dir, 'deps')
-        pos_model_dir = path.join(model_dir, 'pos')
-        if path.exists(dep_model_dir):
-            shutil.rmtree(dep_model_dir)
-        if path.exists(pos_model_dir):
-            shutil.rmtree(pos_model_dir)
-        os.mkdir(dep_model_dir)
-        os.mkdir(pos_model_dir)
-
-        Config.write(dep_model_dir, 'config', features=feat_set, seed=seed,
-                     labels=labels)
-
-    @classmethod
-    def from_dir(cls, tag_map, model_dir):
-        vocab = Vocab(tag_map=tag_map, get_lex_attr=Language.default_lex_attrs())
-        vocab.get_lex_attr[spacy.attrs.LANG] = lambda _: 0
-        tokenizer = Tokenizer(vocab, {}, None, None, None)
-        tagger = Tagger.blank(vocab, TAGGER_TEMPLATES)
-
-        cfg = Config.read(path.join(model_dir, 'deps'), 'config')
-        parser = Parser.from_dir(path.join(model_dir, 'deps'), vocab.strings, ArcEager)
-        return cls(vocab, tokenizer, tagger, parser)
-
-    def __init__(self, vocab, tokenizer, tagger, parser):
-        self.vocab = vocab
-        self.tokenizer = tokenizer
-        self.tagger = tagger
-        self.parser = parser
-
-    def train(self, words, tags, heads, deps):
-        tokens = self.tokenizer.tokens_from_list(list(words))
-        self.tagger.train(tokens, tags)
-
-        tokens = self.tokenizer.tokens_from_list(list(words))
-        ids = range(len(words))
-        ner = ['O'] * len(words)
-        gold = GoldParse(tokens, ((ids, words, tags, heads, deps, ner)),
-                         make_projective=False)
-        self.tagger(tokens)
-        if gold.is_projective:
-            try:
-                self.parser.train(tokens, gold)
-            except:
-                for id_, word, head, dep in zip(ids, words, heads, deps):
-                    print(id_, word, head, dep)
-                raise
-
-    def __call__(self, words, tags=None):
-        tokens = self.tokenizer.tokens_from_list(list(words))
-        if tags is None:
-            self.tagger(tokens)
-        else:
-            self.tagger.tag_from_strings(tokens, tags)
-        self.parser(tokens)
-        return tokens
-
-    def end_training(self, data_dir):
-        self.parser.model.end_training()
-        self.parser.model.dump(path.join(data_dir, 'deps', 'model'))
-        self.tagger.model.end_training()
-        self.tagger.model.dump(path.join(data_dir, 'pos', 'model'))
-        strings_loc = path.join(data_dir, 'vocab', 'strings.json')
-        with io.open(strings_loc, 'w', encoding='utf8') as file_:
-            self.vocab.strings.dump(file_)
-        self.vocab.dump(path.join(data_dir, 'vocab', 'lexemes.bin'))
-
-
 def read_conllx(loc):
-    with open(loc, 'r', 'utf8') as file_:
+    with io.open(loc, 'r', encoding='utf8') as file_:
         text = file_.read()
     for sent in text.strip().split('\n\n'):
         lines = sent.strip().split('\n')

@@ -113,24 +33,31 @@ def read_conllx(loc):
         lines.pop(0)
         tokens = []
         for line in lines:
-            id_, word, lemma, pos, tag, morph, head, dep, _1, _2 = line.split()
+            id_, word, lemma, tag, pos, morph, head, dep, _1, _2 = line.split()
             if '-' in id_:
                 continue
-            id_ = int(id_) - 1
-            head = (int(head) - 1) if head != '0' else id_
-            dep = 'ROOT' if dep == 'root' else dep
-            tokens.append((id_, word, tag, head, dep, 'O'))
-        tuples = zip(*tokens)
-        yield (None, [(tuples, [])])
+            try:
+                id_ = int(id_) - 1
+                head = (int(head) - 1) if head != '0' else id_
+                dep = 'ROOT' if dep == 'root' else dep
+                tokens.append((id_, word, tag, head, dep, 'O'))
+            except:
+                print(line)
+                raise
+        tuples = [list(t) for t in zip(*tokens)]
+        yield (None, [[tuples, []]])
 
 
-def score_model(nlp, gold_docs, verbose=False):
+def score_model(vocab, tagger, parser, gold_docs, verbose=False):
     scorer = Scorer()
     for _, gold_doc in gold_docs:
-        for annot_tuples, _ in gold_doc:
-            tokens = nlp(list(annot_tuples[1]), tags=list(annot_tuples[2]))
-            gold = GoldParse(tokens, annot_tuples)
-            scorer.score(tokens, gold, verbose=verbose)
+        for (ids, words, tags, heads, deps, entities), _ in gold_doc:
+            doc = Doc(vocab, words=words)
+            tagger(doc)
+            parser(doc)
+            PseudoProjectivity.deprojectivize(doc)
+            gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
+            scorer.score(doc, gold, verbose=verbose)
     return scorer

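As a rough illustration of the ten-column CoNLL-X rows the updated read_conllx unpacks, with an invented token line:

    # Invented CoNLL-X token line: id, form, lemma, coarse tag, fine pos,
    # morphology, head, dependency label, plus two unused trailing columns.
    line = u'1\tShe\tshe\tPRON\tPRP\t_\t2\tnsubj\t_\t_'
    id_, word, lemma, tag, pos, morph, head, dep, _1, _2 = line.split()
    assert (id_, word, head, dep) == (u'1', u'She', u'2', u'nsubj')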
@@ -138,22 +65,45 @@ def main(train_loc, dev_loc, model_dir, tag_map_loc):
     with open(tag_map_loc) as file_:
         tag_map = json.loads(file_.read())
     train_sents = list(read_conllx(train_loc))
-    labels = ArcEager.get_labels(train_sents)
-    templates = get_templates('basic')
     train_sents = PseudoProjectivity.preprocess_training_data(train_sents)
 
-    TreebankParser.setup_model_dir(model_dir, labels, templates)
+    actions = ArcEager.get_actions(gold_parses=train_sents)
+    features = get_templates('basic')
 
-    nlp = TreebankParser.from_dir(tag_map, model_dir)
+    model_dir = pathlib.Path(model_dir)
+    with (model_dir / 'deps' / 'config.json').open('w') as file_:
+        json.dump({'pseudoprojective': True, 'labels': actions, 'features': features}, file_)
+
+    vocab = Vocab(lex_attr_getters=Language.Defaults.lex_attr_getters, tag_map=tag_map)
+    # Populate vocab
+    for _, doc_sents in train_sents:
+        for (ids, words, tags, heads, deps, ner), _ in doc_sents:
+            for word in words:
+                _ = vocab[word]
+            for dep in deps:
+                _ = vocab[dep]
+            for tag in tags:
+                _ = vocab[tag]
+            for tag in tags:
+                assert tag in tag_map, repr(tag)
+    tagger = Tagger(vocab, tag_map=tag_map)
+    parser = DependencyParser(vocab, actions=actions, features=features)
 
     for itn in range(15):
         for _, doc_sents in train_sents:
             for (ids, words, tags, heads, deps, ner), _ in doc_sents:
-                nlp.train(words, tags, heads, deps)
+                doc = Doc(vocab, words=words)
+                gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
+                tagger(doc)
+                parser.update(doc, gold)
+                doc = Doc(vocab, words=words)
+                tagger.update(doc, gold)
         random.shuffle(train_sents)
-        scorer = score_model(nlp, read_conllx(dev_loc))
+        scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
         print('%d:\t%.3f\t%.3f' % (itn, scorer.uas, scorer.tags_acc))
+    nlp = Language(vocab=vocab, tagger=tagger, parser=parser)
     nlp.end_training(model_dir)
-    scorer = score_model(nlp, read_conllx(dev_loc))
+    scorer = score_model(vocab, tagger, parser, read_conllx(dev_loc))
     print('%d:\t%.3f\t%.3f\t%.3f' % (itn, scorer.uas, scorer.las, scorer.tags_acc))

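The rewritten inner loop follows the Doc/GoldParse update pattern. A minimal sketch of one step, assuming the imports and objects constructed in the script above:

    def train_step(vocab, tagger, parser, words, tags, heads, deps):
        # Build a Doc from pre-tokenized words and attach the gold analysis.
        doc = Doc(vocab, words=words)
        gold = GoldParse(doc, tags=tags, heads=heads, deps=deps)
        tagger(doc)               # predict tags first, so the parser trains on predicted tags
        parser.update(doc, gold)
        doc = Doc(vocab, words=words)
        tagger.update(doc, gold)  # the tagger itself updates against a fresh Doc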
@@ -116,6 +116,10 @@ def intify_attrs(stringy_attrs, strings_map=None, _do_deprecated=False):
             stringy_attrs["TAG"] = stringy_attrs.pop("pos")
+        if 'morph' in stringy_attrs:
+            morphs = stringy_attrs.pop('morph')
         if 'number' in stringy_attrs:
             stringy_attrs.pop('number')
         if 'tenspect' in stringy_attrs:
             stringy_attrs.pop('tenspect')
+        # for name, value in morphs.items():
+        #     stringy_attrs[name] = value
     for name, value in stringy_attrs.items():

@@ -19,6 +19,7 @@ cdef class GoldParse:
 
     cdef int length
     cdef readonly int loss
+    cdef readonly list words
     cdef readonly list tags
     cdef readonly list heads
     cdef readonly list labels

@@ -19,6 +19,8 @@ def tags_to_entities(tags):
     entities = []
     start = None
     for i, tag in enumerate(tags):
+        if tag is None:
+            continue
         if tag.startswith('O'):
             # TODO: We shouldn't be getting these malformed inputs. Fix this.
             if start is not None:

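For intuition, a pure-Python sketch of what the new None guard skips; this is an illustration, not the library code:

    def drop_unaligned(tags):
        # Tokens whose gold tag is None (e.g. no alignment to the gold parse)
        # are skipped instead of crashing on tag.startswith('O') below.
        return [(i, tag) for i, tag in enumerate(tags) if tag is not None]

    print(drop_unaligned(['U-GPE', None, 'O']))   # [(0, 'U-GPE'), (2, 'O')]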
@@ -249,7 +251,7 @@ cdef class GoldParse:
         if deps is None:
             deps = [None for _ in doc]
         if entities is None:
-            entities = [None for _ in doc]
+            entities = ['-' for _ in doc]
         elif len(entities) == 0:
             entities = ['O' for _ in doc]
         elif not isinstance(entities[0], basestring):

@@ -266,6 +268,7 @@ cdef class GoldParse:
         self.c.labels = <int*>self.mem.alloc(len(doc), sizeof(int))
         self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
 
+        self.words = [None] * len(doc)
         self.tags = [None] * len(doc)
         self.heads = [None] * len(doc)
         self.labels = [''] * len(doc)

@@ -279,6 +282,7 @@ cdef class GoldParse:
 
         for i, gold_i in enumerate(self.cand_to_gold):
             if doc[i].text.isspace():
+                self.words[i] = doc[i].text
                 self.tags[i] = 'SP'
                 self.heads[i] = None
                 self.labels[i] = None

@@ -286,6 +290,7 @@ cdef class GoldParse:
             if gold_i is None:
                 pass
             else:
+                self.words[i] = words[gold_i]
                 self.tags[i] = tags[gold_i]
                 self.heads[i] = self.gold_to_cand[heads[gold_i]]
                 self.labels[i] = deps[gold_i]

@@ -5,10 +5,7 @@ import pathlib
 from contextlib import contextmanager
 import shutil
 
-try:
-    import ujson as json
-except ImportError:
-    import json
+import ujson as json
 
 
 try:

@@ -31,6 +28,8 @@ from .attrs import TAG, DEP, ENT_IOB, ENT_TYPE, HEAD, PROB, LANG, IS_STOP
 from .syntax.parser import get_templates
+from .syntax.nonproj import PseudoProjectivity
 from .pipeline import DependencyParser, EntityRecognizer
+from .syntax.arc_eager import ArcEager
 from .syntax.ner import BiluoPushDown
 
 
 class BaseDefaults(object):

@@ -96,26 +95,27 @@ class BaseDefaults(object):
             return Tagger.load(nlp.path / 'pos', nlp.vocab)
 
     @classmethod
-    def create_parser(cls, nlp=None):
+    def create_parser(cls, nlp=None, **cfg):
         if nlp is None:
-            return DependencyParser(cls.create_vocab(), features=cls.parser_features)
+            return DependencyParser(cls.create_vocab(), features=cls.parser_features,
+                                    **cfg)
         elif nlp.path is False:
-            return DependencyParser(nlp.vocab, features=cls.parser_features)
+            return DependencyParser(nlp.vocab, features=cls.parser_features, **cfg)
         elif nlp.path is None or not (nlp.path / 'deps').exists():
             return None
         else:
-            return DependencyParser.load(nlp.path / 'deps', nlp.vocab)
+            return DependencyParser.load(nlp.path / 'deps', nlp.vocab, **cfg)
 
     @classmethod
-    def create_entity(cls, nlp=None):
+    def create_entity(cls, nlp=None, **cfg):
         if nlp is None:
-            return EntityRecognizer(cls.create_vocab(), features=cls.entity_features)
+            return EntityRecognizer(cls.create_vocab(), features=cls.entity_features, **cfg)
         elif nlp.path is False:
-            return EntityRecognizer(nlp.vocab, features=cls.entity_features)
+            return EntityRecognizer(nlp.vocab, features=cls.entity_features, **cfg)
         elif nlp.path is None or not (nlp.path / 'ner').exists():
             return None
         else:
-            return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab)
+            return EntityRecognizer.load(nlp.path / 'ner', nlp.vocab, **cfg)
 
     @classmethod
     def create_matcher(cls, nlp=None):

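The point of threading **cfg through these factories is that callers can now forward settings, such as the transition-system actions, down to the component constructors. A hypothetical call, not part of this commit:

    # `actions` is assumed to come from ArcEager.get_actions(gold_parses=...):
    #
    #   parser = BaseDefaults.create_parser(nlp=None, actions=actions)
    #
    # Before this change, extra keyword arguments could not reach
    # DependencyParser.__init__ through the factory.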
@@ -216,14 +216,14 @@ class Language(object):
         # preprocess training data here before ArcEager.get_labels() is called
         gold_tuples = PseudoProjectivity.preprocess_training_data(gold_tuples)
 
-        parser_cfg['labels'] = ArcEager.get_labels(gold_tuples)
-        entity_cfg['labels'] = BiluoPushDown.get_labels(gold_tuples)
+        parser_cfg['actions'] = ArcEager.get_actions(gold_parses=gold_tuples)
+        entity_cfg['actions'] = BiluoPushDown.get_actions(gold_parses=gold_tuples)
 
-        with (dep_model_dir / 'config.json').open('wb') as file_:
+        with (dep_model_dir / 'config.json').open('w') as file_:
             json.dump(parser_cfg, file_)
-        with (ner_model_dir / 'config.json').open('wb') as file_:
+        with (ner_model_dir / 'config.json').open('w') as file_:
             json.dump(entity_cfg, file_)
-        with (pos_model_dir / 'config.json').open('wb') as file_:
+        with (pos_model_dir / 'config.json').open('w') as file_:
             json.dump(tagger_cfg, file_)
 
         self = cls(

@@ -238,15 +238,12 @@ class Language(object):
             vectors=False,
             pipeline=False)
 
-        self.defaults.parser_labels = parser_cfg['labels']
-        self.defaults.entity_labels = entity_cfg['labels']
-
-        self.vocab = self.defaults.Vocab()
-        self.tokenizer = self.defaults.Tokenizer(self.vocab)
-        self.tagger = self.defaults.Tagger(self.vocab, **tagger_cfg)
-        self.parser = self.defaults.Parser(self.vocab, **parser_cfg)
-        self.entity = self.defaults.Entity(self.vocab, **entity_cfg)
-        self.pipeline = self.defaults.Pipeline(self)
+        self.vocab = self.Defaults.create_vocab(self)
+        self.tokenizer = self.Defaults.create_tokenizer(self)
+        self.tagger = self.Defaults.create_tagger(self)
+        self.parser = self.Defaults.create_parser(self)
+        self.entity = self.Defaults.create_entity(self)
+        self.pipeline = self.Defaults.create_pipeline(self)
         yield Trainer(self, gold_tuples)
         self.end_training()
 
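A hedged sketch of what the switch from per-instance self.defaults attributes to the class-level Defaults factories enables; the subclass names are invented:

    # Hypothetical customisation: Language builds its pipeline through
    # self.Defaults.create_*(), so overriding a component only needs a
    # Defaults subclass.
    class MyDefaults(BaseDefaults):
        parser_features = get_templates('basic')

    class MyLanguage(Language):
        Defaults = MyDefaults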
@@ -267,7 +264,7 @@ class Language(object):
         add_vectors = self.Defaults.add_vectors(self) \
                       if 'add_vectors' not in overrides \
                       else overrides['add_vectors']
-        if add_vectors:
+        if self.vocab and add_vectors:
             add_vectors(self.vocab)
         self.tokenizer = self.Defaults.create_tokenizer(self) \
                          if 'tokenizer' not in overrides \

@@ -387,7 +384,7 @@ class Language(object):
         else:
             entity_iob_freqs = []
             entity_type_freqs = []
-        with (path / 'vocab' / 'serializer.json').open('wb') as file_:
+        with (path / 'vocab' / 'serializer.json').open('w') as file_:
             file_.write(
                 json.dumps([
                     (TAG, tagger_freqs),

@@ -87,7 +87,7 @@ class Scorer(object):
         gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot]))
         for id_, word, tag, head, dep, ner in gold.orig_annot:
             gold_tags.add((id_, tag))
-            if dep.lower() not in punct_labels:
+            if dep is not None and dep.lower() not in punct_labels:
                 gold_deps.add((id_, head, dep.lower()))
         cand_deps = set()
         cand_tags = set()

@@ -439,7 +439,7 @@ cdef class ArcEager(TransitionSystem):
                 if move_costs[move] == -1:
                     move_costs[move] = move_cost_funcs[move](stcls, &gold.c)
                 costs[i] = move_costs[move] + label_cost_funcs[move](stcls, &gold.c, label)
-                n_gold += costs[i] == 0
+                n_gold += costs[i] <= 0
             else:
                 is_valid[i] = False
                 costs[i] = 9000

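The relaxed comparison counts zero-cost and negative-cost moves as gold. A tiny standalone illustration of the difference:

    costs = [0, 3, -1, 9000]
    n_gold_old = sum(c == 0 for c in costs)   # 1: only exactly-zero costs
    n_gold_new = sum(c <= 0 for c in costs)   # 2: negative costs count as gold too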
@@ -456,8 +456,14 @@ cdef class ArcEager(TransitionSystem):
                     "before training and after parsing. Either pass make_projective=True "
                     "to the GoldParse class, or use PseudoProjectivity.preprocess_training_data")
             else:
+                print(gold.words)
+                print(gold.heads)
+                print(gold.labels)
                 raise ValueError(
                     "Could not find a gold-standard action to supervise the dependency "
                     "parser.\n"
-                    "The GoldParse was projective.")
+                    "The GoldParse was projective.\n"
+                    "The transition system has %d actions.\n"
+                    "State at failure:\n"
+                    "%s" % (self.n_moves, stcls.print_state(gold.words)))
         assert n_gold >= 1

@@ -65,7 +65,7 @@ cdef class BiluoPushDown(TransitionSystem):
             for action in (BEGIN, IN, LAST, UNIT):
                 actions[action][entity_type] = True
         moves = ('M', 'B', 'I', 'L', 'U')
-        for raw_text, sents in kwargs.get('gold_tuples', []):
+        for raw_text, sents in kwargs.get('gold_parses', []):
             for (ids, words, tags, heads, labels, biluo), _ in sents:
                 for i, ner_tag in enumerate(biluo):
                     if ner_tag != 'O' and ner_tag != '-':

@@ -1,3 +1,4 @@
 from __future__ import unicode_literals
+from copy import copy
 
 from ..tokens.doc cimport Doc

@@ -76,7 +76,7 @@ cdef class ParserModel(AveragedPerceptron):
 cdef class Parser:
     """Base class of the DependencyParser and EntityRecognizer."""
     @classmethod
-    def load(cls, path, Vocab vocab, TransitionSystem=None, require=False):
+    def load(cls, path, Vocab vocab, TransitionSystem=None, require=False, **cfg):
         """Load the statistical model from the supplied path.
 
         Arguments:

@@ -92,7 +92,7 @@ cdef class Parser:
         with (path / 'config.json').open() as file_:
             cfg = json.load(file_)
         # TODO: remove this shim when we don't have to support older data
-        if 'labels' in cfg:
+        if 'labels' in cfg and 'actions' not in cfg:
             cfg['actions'] = cfg.pop('labels')
         self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
         if (path / 'model').exists():

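A standalone sketch of the backward-compatibility shim, with an invented old-style config:

    cfg = {'labels': ['nsubj', 'dobj'], 'features': 'basic'}   # old on-disk format
    # Only rename when the new key is absent, so configs that already
    # carry 'actions' are left untouched.
    if 'labels' in cfg and 'actions' not in cfg:
        cfg['actions'] = cfg.pop('labels')
    assert 'actions' in cfg and 'labels' not in cfg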
@@ -266,7 +266,7 @@ cdef class Parser:
             loss += eg.costs[eg.guess]
             eg.fill_scores(0, eg.nr_class)
             eg.fill_costs(0, eg.nr_class)
-            eg.fill_is_valid(0, eg.nr_class)
+            eg.fill_is_valid(1, eg.nr_class)
         return loss
 
     def step_through(self, Doc doc):

@@ -14,22 +14,31 @@ class Trainer(object):
         self.gold_tuples = gold_tuples
 
     def epochs(self, nr_epoch, augment_data=None, gold_preproc=False):
-        def _epoch():
-            for raw_text, paragraph_tuples in self.gold_tuples:
+        cached_golds = {}
+        def _epoch(indices):
+            for i in indices:
+                raw_text, paragraph_tuples = self.gold_tuples[i]
                 if gold_preproc:
                     raw_text = None
                 else:
                     paragraph_tuples = merge_sents(paragraph_tuples)
-                if augment_data is not None:
+                if augment_data is None:
+                    docs = self.make_docs(raw_text, paragraph_tuples)
+                    if i in cached_golds:
+                        golds = cached_golds[i]
+                    else:
+                        golds = self.make_golds(docs, paragraph_tuples)
+                else:
                     raw_text, paragraph_tuples = augment_data(raw_text, paragraph_tuples)
-                docs = self.make_docs(raw_text, paragraph_tuples)
-                golds = self.make_golds(docs, paragraph_tuples)
+                    docs = self.make_docs(raw_text, paragraph_tuples)
+                    golds = self.make_golds(docs, paragraph_tuples)
                 for doc, gold in zip(docs, golds):
                     yield doc, gold
 
+        indices = list(range(len(self.gold_tuples)))
         for itn in range(nr_epoch):
-            random.shuffle(self.gold_tuples)
-            yield _epoch()
+            random.shuffle(indices)
+            yield _epoch(indices)
 
     def update(self, doc, gold):
         for process in self.nlp.pipeline:

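Shuffling a list of indices instead of self.gold_tuples is what keeps the cached_golds keys valid across epochs: the data never moves, only the visit order changes. A minimal standalone sketch of the idea:

    import random

    data = ['a', 'b', 'c']
    cache = {}

    def expensive_gold(i):
        return data[i].upper()    # stand-in for make_golds()

    indices = list(range(len(data)))
    for epoch in range(2):
        random.shuffle(indices)   # the underlying list stays put,
        for i in indices:         # so index keys remain stable
            if i not in cache:
                cache[i] = expensive_gold(i)
            gold = cache[i]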
@@ -62,8 +71,8 @@ class Trainer(object):
 
     def make_golds(self, docs, paragraph_tuples):
         if len(docs) == 1:
-            return [GoldParse(docs[0], sent_tuples[0])
+            return [GoldParse.from_annot_tuples(docs[0], sent_tuples[0])
                     for sent_tuples in paragraph_tuples]
         else:
-            return [GoldParse(doc, sent_tuples[0])
+            return [GoldParse.from_annot_tuples(doc, sent_tuples[0])
                     for doc, sent_tuples in zip(docs, paragraph_tuples)]

@@ -36,10 +36,10 @@ p
 
 +code("Example").
     doc = nlp(u'London is a big city in the United Kingdom.')
-    print(doc[0].text, doc[0].ent_iob, doc[0].ent_type_))
+    print(doc[0].text, doc[0].ent_iob, doc[0].ent_type_)
     # (u'London', 2, u'GPE')
-    print(doc[1].text, doc[1].ent_iob, doc[1].ent_type_))
-    (u'is', 3, u'')]
+    print(doc[1].text, doc[1].ent_iob, doc[1].ent_type_)
+    # (u'is', 3, u'')
 
 +h(2, "setting") Setting entity annotations