diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index 033cc50a9..eefae111f 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -28,8 +28,8 @@ def train_textcat(tokenizer, textcat, batch_sizes = compounding(4., 128., 1.001) for i in range(n_iter): losses = {} - for batch in minibatch(tqdm.tqdm(train_data, leave=False), - size=batch_sizes): + train_data = tqdm.tqdm(train_data, leave=False) # Progress bar + for batch in minibatch(train_data, size=batch_sizes): docs, golds = zip(*batch) textcat.update((docs, None), golds, sgd=optimizer, drop=0.2, losses=losses) @@ -70,7 +70,7 @@ def load_data(): texts, labels = zip(*train_data) cats = [(['POSITIVE'] if y else []) for y in labels] - + split = int(len(train_data) * 0.8) train_texts = texts[:split] @@ -104,7 +104,6 @@ def main(model_loc=None): doc = nlp(u'This movie sucked!') print(doc.cats) - if __name__ == '__main__': plac.call(main) diff --git a/spacy/gold.pyx b/spacy/gold.pyx index aa5daa41d..39951447c 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -483,7 +483,7 @@ cdef class GoldParse: return not nonproj.is_nonproj_tree(self.heads) -def biluo_tags_from_offsets(doc, entities): +def biluo_tags_from_offsets(doc, entities, missing='O'): """Encode labelled spans into per-token tags, using the Begin/In/Last/Unit/Out scheme (BILUO). @@ -535,7 +535,7 @@ def biluo_tags_from_offsets(doc, entities): if i in entity_chars: break else: - biluo[token.i] = 'O' + biluo[token.i] = missing return biluo diff --git a/spacy/syntax/ner.pyx b/spacy/syntax/ner.pyx index 023707aaa..d15de0181 100644 --- a/spacy/syntax/ner.pyx +++ b/spacy/syntax/ner.pyx @@ -141,6 +141,23 @@ cdef class BiluoPushDown(TransitionSystem): entities[(start, end, label)] += prob return entities + def get_beam_parses(self, Beam beam): + parses = [] + probs = beam.probs + for i in range(beam.size): + stcls = beam.at(i) + if stcls.is_final(): + self.finalize_state(stcls.c) + prob = probs[i] + parse = [] + for j in range(stcls.c._e_i): + start = stcls.c._ents[j].start + end = stcls.c._ents[j].end + label = stcls.c._ents[j].label + parse.append((start, end, self.strings[label])) + parses.append((prob, parse)) + return parses + cdef Transition lookup_transition(self, object name) except *: cdef attr_t label if name == '-' or name == None: