Merge remote-tracking branch 'upstream/develop' into indonesian

This commit is contained in:
Jim Geovedi 2017-07-30 21:24:16 +07:00
commit 2572a9ddf0
3 changed files with 22 additions and 6 deletions

View File

@ -28,8 +28,8 @@ def train_textcat(tokenizer, textcat,
batch_sizes = compounding(4., 128., 1.001) batch_sizes = compounding(4., 128., 1.001)
for i in range(n_iter): for i in range(n_iter):
losses = {} losses = {}
for batch in minibatch(tqdm.tqdm(train_data, leave=False), train_data = tqdm.tqdm(train_data, leave=False) # Progress bar
size=batch_sizes): for batch in minibatch(train_data, size=batch_sizes):
docs, golds = zip(*batch) docs, golds = zip(*batch)
textcat.update((docs, None), golds, sgd=optimizer, drop=0.2, textcat.update((docs, None), golds, sgd=optimizer, drop=0.2,
losses=losses) losses=losses)
@ -105,6 +105,5 @@ def main(model_loc=None):
print(doc.cats) print(doc.cats)
if __name__ == '__main__': if __name__ == '__main__':
plac.call(main) plac.call(main)

View File

@ -483,7 +483,7 @@ cdef class GoldParse:
return not nonproj.is_nonproj_tree(self.heads) return not nonproj.is_nonproj_tree(self.heads)
def biluo_tags_from_offsets(doc, entities): def biluo_tags_from_offsets(doc, entities, missing='O'):
"""Encode labelled spans into per-token tags, using the Begin/In/Last/Unit/Out """Encode labelled spans into per-token tags, using the Begin/In/Last/Unit/Out
scheme (BILUO). scheme (BILUO).
@ -535,7 +535,7 @@ def biluo_tags_from_offsets(doc, entities):
if i in entity_chars: if i in entity_chars:
break break
else: else:
biluo[token.i] = 'O' biluo[token.i] = missing
return biluo return biluo

View File

@ -141,6 +141,23 @@ cdef class BiluoPushDown(TransitionSystem):
entities[(start, end, label)] += prob entities[(start, end, label)] += prob
return entities return entities
def get_beam_parses(self, Beam beam):
    """Collect the entity spans held by every finished state in a beam.

    beam: The Beam whose states are inspected; `beam.size` states are
        visited and `beam.probs` supplies one score per state.

    RETURNS (list): One `(prob, spans)` tuple per state for which
        `is_final()` is true, in beam order. `spans` is a list of
        `(start, end, label)` tuples, where `label` is the entry of
        `self.strings` for the entity's stored label ID.
    """
    scores = beam.probs
    results = []
    for idx in range(beam.size):
        state = <StateClass>beam.at(idx)
        # Unfinished states contribute nothing to the output.
        if not state.is_final():
            continue
        # Finalize before reading the entity buffer, then pick up the score.
        self.finalize_state(state.c)
        score = scores[idx]
        spans = []
        for k in range(state.c._e_i):
            spans.append((state.c._ents[k].start,
                          state.c._ents[k].end,
                          self.strings[state.c._ents[k].label]))
        results.append((score, spans))
    return results
cdef Transition lookup_transition(self, object name) except *: cdef Transition lookup_transition(self, object name) except *:
cdef attr_t label cdef attr_t label
if name == '-' or name == None: if name == '-' or name == None: