mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 02:16:32 +03:00
Merge remote-tracking branch 'upstream/develop' into indonesian
This commit is contained in:
commit
2572a9ddf0
|
@ -28,8 +28,8 @@ def train_textcat(tokenizer, textcat,
|
||||||
batch_sizes = compounding(4., 128., 1.001)
|
batch_sizes = compounding(4., 128., 1.001)
|
||||||
for i in range(n_iter):
|
for i in range(n_iter):
|
||||||
losses = {}
|
losses = {}
|
||||||
for batch in minibatch(tqdm.tqdm(train_data, leave=False),
|
train_data = tqdm.tqdm(train_data, leave=False) # Progress bar
|
||||||
size=batch_sizes):
|
for batch in minibatch(train_data, size=batch_sizes):
|
||||||
docs, golds = zip(*batch)
|
docs, golds = zip(*batch)
|
||||||
textcat.update((docs, None), golds, sgd=optimizer, drop=0.2,
|
textcat.update((docs, None), golds, sgd=optimizer, drop=0.2,
|
||||||
losses=losses)
|
losses=losses)
|
||||||
|
@ -105,6 +105,5 @@ def main(model_loc=None):
|
||||||
print(doc.cats)
|
print(doc.cats)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
plac.call(main)
|
plac.call(main)
|
||||||
|
|
|
@ -483,7 +483,7 @@ cdef class GoldParse:
|
||||||
return not nonproj.is_nonproj_tree(self.heads)
|
return not nonproj.is_nonproj_tree(self.heads)
|
||||||
|
|
||||||
|
|
||||||
def biluo_tags_from_offsets(doc, entities):
|
def biluo_tags_from_offsets(doc, entities, missing='O'):
|
||||||
"""Encode labelled spans into per-token tags, using the Begin/In/Last/Unit/Out
|
"""Encode labelled spans into per-token tags, using the Begin/In/Last/Unit/Out
|
||||||
scheme (BILUO).
|
scheme (BILUO).
|
||||||
|
|
||||||
|
@ -535,7 +535,7 @@ def biluo_tags_from_offsets(doc, entities):
|
||||||
if i in entity_chars:
|
if i in entity_chars:
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
biluo[token.i] = 'O'
|
biluo[token.i] = missing
|
||||||
return biluo
|
return biluo
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -141,6 +141,23 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
entities[(start, end, label)] += prob
|
entities[(start, end, label)] += prob
|
||||||
return entities
|
return entities
|
||||||
|
|
||||||
|
def get_beam_parses(self, Beam beam):
|
||||||
|
parses = []
|
||||||
|
probs = beam.probs
|
||||||
|
for i in range(beam.size):
|
||||||
|
stcls = <StateClass>beam.at(i)
|
||||||
|
if stcls.is_final():
|
||||||
|
self.finalize_state(stcls.c)
|
||||||
|
prob = probs[i]
|
||||||
|
parse = []
|
||||||
|
for j in range(stcls.c._e_i):
|
||||||
|
start = stcls.c._ents[j].start
|
||||||
|
end = stcls.c._ents[j].end
|
||||||
|
label = stcls.c._ents[j].label
|
||||||
|
parse.append((start, end, self.strings[label]))
|
||||||
|
parses.append((prob, parse))
|
||||||
|
return parses
|
||||||
|
|
||||||
cdef Transition lookup_transition(self, object name) except *:
|
cdef Transition lookup_transition(self, object name) except *:
|
||||||
cdef attr_t label
|
cdef attr_t label
|
||||||
if name == '-' or name == None:
|
if name == '-' or name == None:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user