mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
* Rename Doc.data to Doc.c
This commit is contained in:
parent
3ddea19b2b
commit
68f479e821
|
@ -215,7 +215,7 @@ cdef class Matcher:
|
||||||
cdef Pattern* state
|
cdef Pattern* state
|
||||||
matches = []
|
matches = []
|
||||||
for token_i in range(doc.length):
|
for token_i in range(doc.length):
|
||||||
token = &doc.data[token_i]
|
token = &doc.c[token_i]
|
||||||
q = 0
|
q = 0
|
||||||
# Go over the open matches, extending or finalizing if able. Otherwise,
|
# Go over the open matches, extending or finalizing if able. Otherwise,
|
||||||
# we over-write them (q doesn't advance)
|
# we over-write them (q doesn't advance)
|
||||||
|
@ -286,7 +286,7 @@ cdef class PhraseMatcher:
|
||||||
for i in range(self.max_length):
|
for i in range(self.max_length):
|
||||||
self._phrase_key[i] = 0
|
self._phrase_key[i] = 0
|
||||||
for i, tag in enumerate(tags):
|
for i, tag in enumerate(tags):
|
||||||
lexeme = self.vocab[tokens.data[i].lex.orth]
|
lexeme = self.vocab[tokens.c[i].lex.orth]
|
||||||
lexeme.set_flag(tag, True)
|
lexeme.set_flag(tag, True)
|
||||||
self._phrase_key[i] = lexeme.orth
|
self._phrase_key[i] = lexeme.orth
|
||||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
||||||
|
@ -309,7 +309,7 @@ cdef class PhraseMatcher:
|
||||||
for i in range(self.max_length):
|
for i in range(self.max_length):
|
||||||
self._phrase_key[i] = 0
|
self._phrase_key[i] = 0
|
||||||
for i, j in enumerate(range(start, end)):
|
for i, j in enumerate(range(start, end)):
|
||||||
self._phrase_key[i] = doc.data[j].lex.orth
|
self._phrase_key[i] = doc.c[j].lex.orth
|
||||||
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
cdef hash_t key = hash64(self._phrase_key, self.max_length * sizeof(attr_t), 0)
|
||||||
if self.phrase_ids.get(key):
|
if self.phrase_ids.get(key):
|
||||||
return True
|
return True
|
||||||
|
|
|
@ -84,7 +84,7 @@ cdef class Parser:
|
||||||
return cls(strings, moves, model)
|
return cls(strings, moves, model)
|
||||||
|
|
||||||
def __call__(self, Doc tokens):
|
def __call__(self, Doc tokens):
|
||||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
|
||||||
self.moves.initialize_state(stcls)
|
self.moves.initialize_state(stcls)
|
||||||
|
|
||||||
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
|
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
|
||||||
|
@ -112,7 +112,7 @@ cdef class Parser:
|
||||||
|
|
||||||
def train(self, Doc tokens, GoldParse gold):
|
def train(self, Doc tokens, GoldParse gold):
|
||||||
self.moves.preprocess_gold(gold)
|
self.moves.preprocess_gold(gold)
|
||||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
cdef StateClass stcls = StateClass.init(tokens.c, tokens.length)
|
||||||
self.moves.initialize_state(stcls)
|
self.moves.initialize_state(stcls)
|
||||||
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
|
cdef Example eg = Example(self.model.n_classes, CONTEXT_SIZE,
|
||||||
self.model.n_feats, self.model.n_feats)
|
self.model.n_feats, self.model.n_feats)
|
||||||
|
@ -143,7 +143,7 @@ cdef class StepwiseState:
|
||||||
def __init__(self, Parser parser, Doc doc):
|
def __init__(self, Parser parser, Doc doc):
|
||||||
self.parser = parser
|
self.parser = parser
|
||||||
self.doc = doc
|
self.doc = doc
|
||||||
self.stcls = StateClass.init(doc.data, doc.length)
|
self.stcls = StateClass.init(doc.c, doc.length)
|
||||||
self.parser.moves.initialize_state(self.stcls)
|
self.parser.moves.initialize_state(self.stcls)
|
||||||
self.eg = Example(self.parser.model.n_classes, CONTEXT_SIZE,
|
self.eg = Example(self.parser.model.n_classes, CONTEXT_SIZE,
|
||||||
self.parser.model.n_feats, self.parser.model.n_feats)
|
self.parser.model.n_feats, self.parser.model.n_feats)
|
||||||
|
|
|
@ -141,9 +141,9 @@ cdef class Tagger:
|
||||||
cdef int i
|
cdef int i
|
||||||
cdef const weight_t* scores
|
cdef const weight_t* scores
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
if tokens.data[i].pos == 0:
|
if tokens.c[i].pos == 0:
|
||||||
guess = self.predict(i, tokens.data)
|
guess = self.predict(i, tokens.c)
|
||||||
self.vocab.morphology.assign_tag(&tokens.data[i], guess)
|
self.vocab.morphology.assign_tag(&tokens.c[i], guess)
|
||||||
|
|
||||||
tokens.is_tagged = True
|
tokens.is_tagged = True
|
||||||
tokens._py_tokens = [None] * tokens.length
|
tokens._py_tokens = [None] * tokens.length
|
||||||
|
@ -154,7 +154,7 @@ cdef class Tagger:
|
||||||
def tag_from_strings(self, Doc tokens, object tag_strs):
|
def tag_from_strings(self, Doc tokens, object tag_strs):
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
self.vocab.morphology.assign_tag(&tokens.data[i], tag_strs[i])
|
self.vocab.morphology.assign_tag(&tokens.c[i], tag_strs[i])
|
||||||
tokens.is_tagged = True
|
tokens.is_tagged = True
|
||||||
tokens._py_tokens = [None] * tokens.length
|
tokens._py_tokens = [None] * tokens.length
|
||||||
|
|
||||||
|
@ -170,13 +170,13 @@ cdef class Tagger:
|
||||||
[g for g in gold_tag_strs if g is not None and g not in self.tag_names])
|
[g for g in gold_tag_strs if g is not None and g not in self.tag_names])
|
||||||
correct = 0
|
correct = 0
|
||||||
for i in range(tokens.length):
|
for i in range(tokens.length):
|
||||||
guess = self.update(i, tokens.data, golds[i])
|
guess = self.update(i, tokens.c, golds[i])
|
||||||
loss = golds[i] != -1 and guess != golds[i]
|
loss = golds[i] != -1 and guess != golds[i]
|
||||||
|
|
||||||
self.vocab.morphology.assign_tag(&tokens.data[i], guess)
|
self.vocab.morphology.assign_tag(&tokens.c[i], guess)
|
||||||
|
|
||||||
correct += loss == 0
|
correct += loss == 0
|
||||||
self.freqs[TAG][tokens.data[i].tag] += 1
|
self.freqs[TAG][tokens.c[i].tag] += 1
|
||||||
return correct
|
return correct
|
||||||
|
|
||||||
cdef int predict(self, int i, const TokenC* tokens) except -1:
|
cdef int predict(self, int i, const TokenC* tokens) except -1:
|
||||||
|
|
|
@ -113,7 +113,7 @@ cdef class Tokenizer:
|
||||||
self._tokenize(tokens, span, key)
|
self._tokenize(tokens, span, key)
|
||||||
in_ws = not in_ws
|
in_ws = not in_ws
|
||||||
if uc == ' ':
|
if uc == ' ':
|
||||||
tokens.data[tokens.length - 1].spacy = True
|
tokens.c[tokens.length - 1].spacy = True
|
||||||
start = i + 1
|
start = i + 1
|
||||||
else:
|
else:
|
||||||
start = i
|
start = i
|
||||||
|
@ -125,7 +125,7 @@ cdef class Tokenizer:
|
||||||
cache_hit = self._try_cache(key, tokens)
|
cache_hit = self._try_cache(key, tokens)
|
||||||
if not cache_hit:
|
if not cache_hit:
|
||||||
self._tokenize(tokens, span, key)
|
self._tokenize(tokens, span, key)
|
||||||
tokens.data[tokens.length - 1].spacy = string[-1] == ' '
|
tokens.c[tokens.length - 1].spacy = string[-1] == ' '
|
||||||
return tokens
|
return tokens
|
||||||
|
|
||||||
cdef int _try_cache(self, hash_t key, Doc tokens) except -1:
|
cdef int _try_cache(self, hash_t key, Doc tokens) except -1:
|
||||||
|
@ -148,7 +148,7 @@ cdef class Tokenizer:
|
||||||
orig_size = tokens.length
|
orig_size = tokens.length
|
||||||
span = self._split_affixes(span, &prefixes, &suffixes)
|
span = self._split_affixes(span, &prefixes, &suffixes)
|
||||||
self._attach_tokens(tokens, span, &prefixes, &suffixes)
|
self._attach_tokens(tokens, span, &prefixes, &suffixes)
|
||||||
self._save_cached(&tokens.data[orig_size], orig_key, tokens.length - orig_size)
|
self._save_cached(&tokens.c[orig_size], orig_key, tokens.length - orig_size)
|
||||||
|
|
||||||
cdef unicode _split_affixes(self, unicode string, vector[const LexemeC*] *prefixes,
|
cdef unicode _split_affixes(self, unicode string, vector[const LexemeC*] *prefixes,
|
||||||
vector[const LexemeC*] *suffixes):
|
vector[const LexemeC*] *suffixes):
|
||||||
|
|
|
@ -26,7 +26,7 @@ cdef class Doc:
|
||||||
cdef public object _vector
|
cdef public object _vector
|
||||||
cdef public object _vector_norm
|
cdef public object _vector_norm
|
||||||
|
|
||||||
cdef TokenC* data
|
cdef TokenC* c
|
||||||
|
|
||||||
cdef public bint is_tagged
|
cdef public bint is_tagged
|
||||||
cdef public bint is_parsed
|
cdef public bint is_parsed
|
||||||
|
|
|
@ -73,7 +73,7 @@ cdef class Doc:
|
||||||
data_start[i].lex = &EMPTY_LEXEME
|
data_start[i].lex = &EMPTY_LEXEME
|
||||||
data_start[i].l_edge = i
|
data_start[i].l_edge = i
|
||||||
data_start[i].r_edge = i
|
data_start[i].r_edge = i
|
||||||
self.data = data_start + PADDING
|
self.c = data_start + PADDING
|
||||||
self.max_length = size
|
self.max_length = size
|
||||||
self.length = 0
|
self.length = 0
|
||||||
self.is_tagged = False
|
self.is_tagged = False
|
||||||
|
@ -97,7 +97,7 @@ cdef class Doc:
|
||||||
if self._py_tokens[i] is not None:
|
if self._py_tokens[i] is not None:
|
||||||
return self._py_tokens[i]
|
return self._py_tokens[i]
|
||||||
else:
|
else:
|
||||||
return Token.cinit(self.vocab, &self.data[i], i, self)
|
return Token.cinit(self.vocab, &self.c[i], i, self)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
"""Iterate over the tokens.
|
"""Iterate over the tokens.
|
||||||
|
@ -110,7 +110,7 @@ cdef class Doc:
|
||||||
if self._py_tokens[i] is not None:
|
if self._py_tokens[i] is not None:
|
||||||
yield self._py_tokens[i]
|
yield self._py_tokens[i]
|
||||||
else:
|
else:
|
||||||
yield Token.cinit(self.vocab, &self.data[i], i, self)
|
yield Token.cinit(self.vocab, &self.c[i], i, self)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.length
|
return self.length
|
||||||
|
@ -187,7 +187,7 @@ cdef class Doc:
|
||||||
cdef int label = 0
|
cdef int label = 0
|
||||||
output = []
|
output = []
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
token = &self.data[i]
|
token = &self.c[i]
|
||||||
if token.ent_iob == 1:
|
if token.ent_iob == 1:
|
||||||
assert start != -1
|
assert start != -1
|
||||||
elif token.ent_iob == 2 or token.ent_iob == 0:
|
elif token.ent_iob == 2 or token.ent_iob == 0:
|
||||||
|
@ -212,23 +212,23 @@ cdef class Doc:
|
||||||
# 4. Test more nuanced date and currency regex
|
# 4. Test more nuanced date and currency regex
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
self.data[i].ent_type = 0
|
self.c[i].ent_type = 0
|
||||||
self.data[i].ent_iob = 0
|
self.c[i].ent_iob = 0
|
||||||
cdef attr_t ent_type
|
cdef attr_t ent_type
|
||||||
cdef int start, end
|
cdef int start, end
|
||||||
for ent_type, start, end in ents:
|
for ent_type, start, end in ents:
|
||||||
if ent_type is None or ent_type < 0:
|
if ent_type is None or ent_type < 0:
|
||||||
# Mark as O
|
# Mark as O
|
||||||
for i in range(start, end):
|
for i in range(start, end):
|
||||||
self.data[i].ent_type = 0
|
self.c[i].ent_type = 0
|
||||||
self.data[i].ent_iob = 2
|
self.c[i].ent_iob = 2
|
||||||
else:
|
else:
|
||||||
# Mark (inside) as I
|
# Mark (inside) as I
|
||||||
for i in range(start, end):
|
for i in range(start, end):
|
||||||
self.data[i].ent_type = ent_type
|
self.c[i].ent_type = ent_type
|
||||||
self.data[i].ent_iob = 1
|
self.c[i].ent_iob = 1
|
||||||
# Set start as B
|
# Set start as B
|
||||||
self.data[start].ent_iob = 3
|
self.c[start].ent_iob = 3
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def noun_chunks(self):
|
def noun_chunks(self):
|
||||||
|
@ -245,7 +245,7 @@ cdef class Doc:
|
||||||
np_deps = [self.vocab.strings[label] for label in labels]
|
np_deps = [self.vocab.strings[label] for label in labels]
|
||||||
np_label = self.vocab.strings['NP']
|
np_label = self.vocab.strings['NP']
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
word = &self.data[i]
|
word = &self.c[i]
|
||||||
if word.pos == NOUN and word.dep in np_deps:
|
if word.pos == NOUN and word.dep in np_deps:
|
||||||
yield Span(self, word.l_edge, i+1, label=np_label)
|
yield Span(self, word.l_edge, i+1, label=np_label)
|
||||||
|
|
||||||
|
@ -263,7 +263,7 @@ cdef class Doc:
|
||||||
cdef int i
|
cdef int i
|
||||||
start = 0
|
start = 0
|
||||||
for i in range(1, self.length):
|
for i in range(1, self.length):
|
||||||
if self.data[i].sent_start:
|
if self.c[i].sent_start:
|
||||||
yield Span(self, start, i)
|
yield Span(self, start, i)
|
||||||
start = i
|
start = i
|
||||||
yield Span(self, start, self.length)
|
yield Span(self, start, self.length)
|
||||||
|
@ -271,7 +271,7 @@ cdef class Doc:
|
||||||
cdef int push_back(self, LexemeOrToken lex_or_tok, bint has_space) except -1:
|
cdef int push_back(self, LexemeOrToken lex_or_tok, bint has_space) except -1:
|
||||||
if self.length == self.max_length:
|
if self.length == self.max_length:
|
||||||
self._realloc(self.length * 2)
|
self._realloc(self.length * 2)
|
||||||
cdef TokenC* t = &self.data[self.length]
|
cdef TokenC* t = &self.c[self.length]
|
||||||
if LexemeOrToken is const_TokenC_ptr:
|
if LexemeOrToken is const_TokenC_ptr:
|
||||||
t[0] = lex_or_tok[0]
|
t[0] = lex_or_tok[0]
|
||||||
else:
|
else:
|
||||||
|
@ -310,7 +310,7 @@ cdef class Doc:
|
||||||
output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int32)
|
output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.int32)
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
for j, feature in enumerate(attr_ids):
|
for j, feature in enumerate(attr_ids):
|
||||||
output[i, j] = get_token_attr(&self.data[i], feature)
|
output[i, j] = get_token_attr(&self.c[i], feature)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None):
|
def count_by(self, attr_id_t attr_id, exclude=None, PreshCounter counts=None):
|
||||||
|
@ -340,11 +340,11 @@ cdef class Doc:
|
||||||
# Take this check out of the loop, for a bit of extra speed
|
# Take this check out of the loop, for a bit of extra speed
|
||||||
if exclude is None:
|
if exclude is None:
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
counts.inc(get_token_attr(&self.data[i], attr_id), 1)
|
counts.inc(get_token_attr(&self.c[i], attr_id), 1)
|
||||||
else:
|
else:
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
if not exclude(self[i]):
|
if not exclude(self[i]):
|
||||||
attr = get_token_attr(&self.data[i], attr_id)
|
attr = get_token_attr(&self.c[i], attr_id)
|
||||||
counts.inc(attr, 1)
|
counts.inc(attr, 1)
|
||||||
if output_dict:
|
if output_dict:
|
||||||
return dict(counts)
|
return dict(counts)
|
||||||
|
@ -357,12 +357,12 @@ cdef class Doc:
|
||||||
# words out-of-bounds, and get out-of-bounds markers.
|
# words out-of-bounds, and get out-of-bounds markers.
|
||||||
# Now that we want to realloc, we need the address of the true start,
|
# Now that we want to realloc, we need the address of the true start,
|
||||||
# so we jump the pointer back PADDING places.
|
# so we jump the pointer back PADDING places.
|
||||||
cdef TokenC* data_start = self.data - PADDING
|
cdef TokenC* data_start = self.c - PADDING
|
||||||
data_start = <TokenC*>self.mem.realloc(data_start, n * sizeof(TokenC))
|
data_start = <TokenC*>self.mem.realloc(data_start, n * sizeof(TokenC))
|
||||||
self.data = data_start + PADDING
|
self.c = data_start + PADDING
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(self.length, self.max_length + PADDING):
|
for i in range(self.length, self.max_length + PADDING):
|
||||||
self.data[i].lex = &EMPTY_LEXEME
|
self.c[i].lex = &EMPTY_LEXEME
|
||||||
|
|
||||||
cdef int set_parse(self, const TokenC* parsed) except -1:
|
cdef int set_parse(self, const TokenC* parsed) except -1:
|
||||||
# TODO: This method is fairly misleading atm. It's used by Parser
|
# TODO: This method is fairly misleading atm. It's used by Parser
|
||||||
|
@ -371,14 +371,14 @@ cdef class Doc:
|
||||||
# Probably we should use from_array?
|
# Probably we should use from_array?
|
||||||
self.is_parsed = True
|
self.is_parsed = True
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
self.data[i] = parsed[i]
|
self.c[i] = parsed[i]
|
||||||
assert self.data[i].l_edge <= i
|
assert self.c[i].l_edge <= i
|
||||||
assert self.data[i].r_edge >= i
|
assert self.c[i].r_edge >= i
|
||||||
|
|
||||||
def from_array(self, attrs, array):
|
def from_array(self, attrs, array):
|
||||||
cdef int i, col
|
cdef int i, col
|
||||||
cdef attr_id_t attr_id
|
cdef attr_id_t attr_id
|
||||||
cdef TokenC* tokens = self.data
|
cdef TokenC* tokens = self.c
|
||||||
cdef int length = len(array)
|
cdef int length = len(array)
|
||||||
cdef attr_t[:] values
|
cdef attr_t[:] values
|
||||||
for col, attr_id in enumerate(attrs):
|
for col, attr_id in enumerate(attrs):
|
||||||
|
@ -412,7 +412,7 @@ cdef class Doc:
|
||||||
tokens[i].ent_type = values[i]
|
tokens[i].ent_type = values[i]
|
||||||
else:
|
else:
|
||||||
raise ValueError("Unknown attribute ID: %d" % attr_id)
|
raise ValueError("Unknown attribute ID: %d" % attr_id)
|
||||||
set_children_from_heads(self.data, self.length)
|
set_children_from_heads(self.c, self.length)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_bytes(self):
|
def to_bytes(self):
|
||||||
|
@ -447,9 +447,9 @@ cdef class Doc:
|
||||||
cdef int start = -1
|
cdef int start = -1
|
||||||
cdef int end = -1
|
cdef int end = -1
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
if self.data[i].idx == start_idx:
|
if self.c[i].idx == start_idx:
|
||||||
start = i
|
start = i
|
||||||
if (self.data[i].idx + self.data[i].lex.length) == end_idx:
|
if (self.c[i].idx + self.c[i].lex.length) == end_idx:
|
||||||
if start == -1:
|
if start == -1:
|
||||||
return None
|
return None
|
||||||
end = i + 1
|
end = i + 1
|
||||||
|
@ -464,10 +464,10 @@ cdef class Doc:
|
||||||
new_orth = new_orth[:-len(span[-1].whitespace_)]
|
new_orth = new_orth[:-len(span[-1].whitespace_)]
|
||||||
cdef const LexemeC* lex = self.vocab.get(self.mem, new_orth)
|
cdef const LexemeC* lex = self.vocab.get(self.mem, new_orth)
|
||||||
# House the new merged token where it starts
|
# House the new merged token where it starts
|
||||||
cdef TokenC* token = &self.data[start]
|
cdef TokenC* token = &self.c[start]
|
||||||
# Update fields
|
# Update fields
|
||||||
token.lex = lex
|
token.lex = lex
|
||||||
token.spacy = self.data[end-1].spacy
|
token.spacy = self.c[end-1].spacy
|
||||||
if tag in self.vocab.morphology.tag_map:
|
if tag in self.vocab.morphology.tag_map:
|
||||||
self.vocab.morphology.assign_tag(token, tag)
|
self.vocab.morphology.assign_tag(token, tag)
|
||||||
else:
|
else:
|
||||||
|
@ -486,31 +486,31 @@ cdef class Doc:
|
||||||
span_root = span.root.i
|
span_root = span.root.i
|
||||||
token.dep = span.root.dep
|
token.dep = span.root.dep
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
self.data[i].head += i
|
self.c[i].head += i
|
||||||
# Set the head of the merged token, and its dep relation, from the Span
|
# Set the head of the merged token, and its dep relation, from the Span
|
||||||
token.head = self.data[span_root].head
|
token.head = self.c[span_root].head
|
||||||
# Adjust deps before shrinking tokens
|
# Adjust deps before shrinking tokens
|
||||||
# Tokens which point into the merged token should now point to it
|
# Tokens which point into the merged token should now point to it
|
||||||
# Subtract the offset from all tokens which point to >= end
|
# Subtract the offset from all tokens which point to >= end
|
||||||
offset = (end - start) - 1
|
offset = (end - start) - 1
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
head_idx = self.data[i].head
|
head_idx = self.c[i].head
|
||||||
if start <= head_idx < end:
|
if start <= head_idx < end:
|
||||||
self.data[i].head = start
|
self.c[i].head = start
|
||||||
elif head_idx >= end:
|
elif head_idx >= end:
|
||||||
self.data[i].head -= offset
|
self.c[i].head -= offset
|
||||||
# Now compress the token array
|
# Now compress the token array
|
||||||
for i in range(end, self.length):
|
for i in range(end, self.length):
|
||||||
self.data[i - offset] = self.data[i]
|
self.c[i - offset] = self.c[i]
|
||||||
for i in range(self.length - offset, self.length):
|
for i in range(self.length - offset, self.length):
|
||||||
memset(&self.data[i], 0, sizeof(TokenC))
|
memset(&self.c[i], 0, sizeof(TokenC))
|
||||||
self.data[i].lex = &EMPTY_LEXEME
|
self.c[i].lex = &EMPTY_LEXEME
|
||||||
self.length -= offset
|
self.length -= offset
|
||||||
for i in range(self.length):
|
for i in range(self.length):
|
||||||
# ...And, set heads back to a relative position
|
# ...And, set heads back to a relative position
|
||||||
self.data[i].head -= i
|
self.c[i].head -= i
|
||||||
# Set the left/right children, left/right edges
|
# Set the left/right children, left/right edges
|
||||||
set_children_from_heads(self.data, self.length)
|
set_children_from_heads(self.c, self.length)
|
||||||
# Clear the cached Python objects
|
# Clear the cached Python objects
|
||||||
self._py_tokens = [None] * self.length
|
self._py_tokens = [None] * self.length
|
||||||
# Return the merged Python object
|
# Return the merged Python object
|
||||||
|
|
|
@ -139,12 +139,12 @@ cdef class Span:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
# This should probably be called 'head', and the other one called
|
# This should probably be called 'head', and the other one called
|
||||||
# 'gov'. But we went with 'head' elsehwhere, and now we're stuck =/
|
# 'gov'. But we went with 'head' elsehwhere, and now we're stuck =/
|
||||||
cdef const TokenC* start = &self.doc.data[self.start]
|
cdef const TokenC* start = &self.doc.c[self.start]
|
||||||
cdef const TokenC* end = &self.doc.data[self.end]
|
cdef const TokenC* end = &self.doc.c[self.end]
|
||||||
head = start
|
head = start
|
||||||
while start <= (head + head.head) < end and head.head != 0:
|
while start <= (head + head.head) < end and head.head != 0:
|
||||||
head += head.head
|
head += head.head
|
||||||
return self.doc[head - self.doc.data]
|
return self.doc[head - self.doc.c]
|
||||||
|
|
||||||
property lefts:
|
property lefts:
|
||||||
"""Tokens that are to the left of the Span, whose head is within the Span."""
|
"""Tokens that are to the left of the Span, whose head is within the Span."""
|
||||||
|
|
|
@ -31,7 +31,7 @@ cdef class Token:
|
||||||
def __cinit__(self, Vocab vocab, Doc doc, int offset):
|
def __cinit__(self, Vocab vocab, Doc doc, int offset):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.doc = doc
|
self.doc = doc
|
||||||
self.c = &self.doc.data[offset]
|
self.c = &self.doc.c[offset]
|
||||||
self.i = offset
|
self.i = offset
|
||||||
self.array_len = doc.length
|
self.array_len = doc.length
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user