Merge remote-tracking branch 'upstream/develop' into indonesian

This commit is contained in:
Jim Geovedi 2017-07-30 21:24:16 +07:00
commit 2572a9ddf0
3 changed files with 22 additions and 6 deletions

View File

@ -28,8 +28,8 @@ def train_textcat(tokenizer, textcat,
batch_sizes = compounding(4., 128., 1.001)
for i in range(n_iter):
losses = {}
for batch in minibatch(tqdm.tqdm(train_data, leave=False),
size=batch_sizes):
train_data = tqdm.tqdm(train_data, leave=False) # Progress bar
for batch in minibatch(train_data, size=batch_sizes):
docs, golds = zip(*batch)
textcat.update((docs, None), golds, sgd=optimizer, drop=0.2,
losses=losses)
@ -70,7 +70,7 @@ def load_data():
texts, labels = zip(*train_data)
cats = [(['POSITIVE'] if y else []) for y in labels]
split = int(len(train_data) * 0.8)
train_texts = texts[:split]
@ -104,7 +104,6 @@ def main(model_loc=None):
doc = nlp(u'This movie sucked!')
print(doc.cats)
if __name__ == '__main__':
plac.call(main)

View File

@ -483,7 +483,7 @@ cdef class GoldParse:
return not nonproj.is_nonproj_tree(self.heads)
def biluo_tags_from_offsets(doc, entities):
def biluo_tags_from_offsets(doc, entities, missing='O'):
"""Encode labelled spans into per-token tags, using the Begin/In/Last/Unit/Out
scheme (BILUO).
@ -535,7 +535,7 @@ def biluo_tags_from_offsets(doc, entities):
if i in entity_chars:
break
else:
biluo[token.i] = 'O'
biluo[token.i] = missing
return biluo

View File

@ -141,6 +141,23 @@ cdef class BiluoPushDown(TransitionSystem):
entities[(start, end, label)] += prob
return entities
def get_beam_parses(self, Beam beam):
    """Collect the entity spans of every finished state held in a beam.

    beam (Beam): The beam whose states are inspected.
    RETURNS (list): One `(prob, spans)` tuple per state for which
        `is_final()` is true, where `prob` is the matching entry of
        `beam.probs` and `spans` is a list of `(start, end, label)`
        tuples with the label looked up in `self.strings`.
    """
    results = []
    scores = beam.probs
    for idx in range(beam.size):
        stcls = <StateClass>beam.at(idx)
        # States that have not finished are left out of the result.
        if not stcls.is_final():
            continue
        self.finalize_state(stcls.c)
        score = scores[idx]
        spans = []
        # NOTE(review): `_e_i` is presumably the number of entities stored
        # in `_ents` -- confirm against the state struct definition.
        for k in range(stcls.c._e_i):
            spans.append((stcls.c._ents[k].start,
                          stcls.c._ents[k].end,
                          self.strings[stcls.c._ents[k].label]))
        results.append((score, spans))
    return results
cdef Transition lookup_transition(self, object name) except *:
cdef attr_t label
if name == '-' or name == None: