mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-12 15:25:47 +03:00
Update DocBin
Add missing strings when serializing
This commit is contained in:
parent
17226a60ac
commit
49145b9ec1
|
@ -14,7 +14,6 @@ ALL_ATTRS = (
|
|||
"TAG",
|
||||
"HEAD",
|
||||
"DEP",
|
||||
"SENT_START",
|
||||
"ENT_IOB",
|
||||
"ENT_TYPE",
|
||||
"LEMMA",
|
||||
|
@ -93,7 +92,12 @@ class DocBin(object):
|
|||
assert array.shape[0] == spaces.shape[0] # this should never happen
|
||||
spaces = spaces.reshape((spaces.shape[0], 1))
|
||||
self.spaces.append(numpy.asarray(spaces, dtype=bool))
|
||||
self.strings.update(w.text for w in doc)
|
||||
for token in doc:
|
||||
self.strings.add(token.text)
|
||||
self.strings.add(token.tag_)
|
||||
self.strings.add(token.lemma_)
|
||||
self.strings.add(token.dep_)
|
||||
self.strings.add(token.ent_type_)
|
||||
self.cats.append(doc.cats)
|
||||
if self.store_user_data:
|
||||
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
||||
|
@ -112,8 +116,7 @@ class DocBin(object):
|
|||
for i in range(len(self.tokens)):
|
||||
tokens = self.tokens[i]
|
||||
spaces = self.spaces[i]
|
||||
words = [vocab.strings[orth] for orth in tokens[:, orth_col]]
|
||||
doc = Doc(vocab, words=words, spaces=spaces)
|
||||
doc = Doc(vocab, words=tokens[:, orth_col], spaces=spaces)
|
||||
doc = doc.from_array(self.attrs, tokens)
|
||||
doc.cats = self.cats[i]
|
||||
if self.store_user_data:
|
||||
|
|
Loading…
Reference in New Issue
Block a user