Consolidate and freeze symbols (#11352)

* Consolidate and freeze symbols

Instead of having symbol values defined in three potentially conflicting
places (`spacy.attrs`, `spacy.parts_of_speech`, `spacy.symbols`), define
all symbols in `spacy.symbols` and reference those values in
`spacy.attrs` and `spacy.parts_of_speech`.

Remove deprecated and placeholder symbols from `spacy.attrs.IDS`.

Make `spacy.attrs.NAMES` and `spacy.symbols.NAMES` reverse dicts rather
than lists in order to support future use of hash values in `attr_id_t`.

Minor changes:

* Use `uint64_t` for attrs in `Doc.to_array` to support future use of
hash values
* Remove unneeded attrs filter for error message in `Doc.to_array`
* Remove unused attr `SENT_END`

* Handle dynamic size of attr_id_t in Doc.to_array

* Undo added warnings

* Refactor to make Doc.to_array more similar to Doc.from_array

* Improve refactoring
This commit is contained in:
Adriane Boyd 2022-09-02 09:08:40 +02:00 committed by GitHub
parent 698b8b495f
commit 4a615cacd2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 551 additions and 179 deletions

View File

@ -1,98 +1,49 @@
# Reserve 64 values for flag features
from . cimport symbols
cdef enum attr_id_t:
NULL_ATTR
IS_ALPHA
IS_ASCII
IS_DIGIT
IS_LOWER
IS_PUNCT
IS_SPACE
IS_TITLE
IS_UPPER
LIKE_URL
LIKE_NUM
LIKE_EMAIL
IS_STOP
IS_OOV_DEPRECATED
IS_BRACKET
IS_QUOTE
IS_LEFT_PUNCT
IS_RIGHT_PUNCT
IS_CURRENCY
NULL_ATTR = 0
IS_ALPHA = symbols.IS_ALPHA
IS_ASCII = symbols.IS_ASCII
IS_DIGIT = symbols.IS_DIGIT
IS_LOWER = symbols.IS_LOWER
IS_PUNCT = symbols.IS_PUNCT
IS_SPACE = symbols.IS_SPACE
IS_TITLE = symbols.IS_TITLE
IS_UPPER = symbols.IS_UPPER
LIKE_URL = symbols.LIKE_URL
LIKE_NUM = symbols.LIKE_NUM
LIKE_EMAIL = symbols.LIKE_EMAIL
IS_STOP = symbols.IS_STOP
IS_BRACKET = symbols.IS_BRACKET
IS_QUOTE = symbols.IS_QUOTE
IS_LEFT_PUNCT = symbols.IS_LEFT_PUNCT
IS_RIGHT_PUNCT = symbols.IS_RIGHT_PUNCT
IS_CURRENCY = symbols.IS_CURRENCY
FLAG19 = 19
FLAG20
FLAG21
FLAG22
FLAG23
FLAG24
FLAG25
FLAG26
FLAG27
FLAG28
FLAG29
FLAG30
FLAG31
FLAG32
FLAG33
FLAG34
FLAG35
FLAG36
FLAG37
FLAG38
FLAG39
FLAG40
FLAG41
FLAG42
FLAG43
FLAG44
FLAG45
FLAG46
FLAG47
FLAG48
FLAG49
FLAG50
FLAG51
FLAG52
FLAG53
FLAG54
FLAG55
FLAG56
FLAG57
FLAG58
FLAG59
FLAG60
FLAG61
FLAG62
FLAG63
ID = symbols.ID
ORTH = symbols.ORTH
LOWER = symbols.LOWER
NORM = symbols.NORM
SHAPE = symbols.SHAPE
PREFIX = symbols.PREFIX
SUFFIX = symbols.SUFFIX
ID
ORTH
LOWER
NORM
SHAPE
PREFIX
SUFFIX
LENGTH = symbols.LENGTH
CLUSTER = symbols.CLUSTER
LEMMA = symbols.LEMMA
POS = symbols.POS
TAG = symbols.TAG
DEP = symbols.DEP
ENT_IOB = symbols.ENT_IOB
ENT_TYPE = symbols.ENT_TYPE
HEAD = symbols.HEAD
SENT_START = symbols.SENT_START
SPACY = symbols.SPACY
PROB = symbols.PROB
LENGTH
CLUSTER
LEMMA
POS
TAG
DEP
ENT_IOB
ENT_TYPE
HEAD
SENT_START
SPACY
PROB
LANG
LANG = symbols.LANG
ENT_KB_ID = symbols.ENT_KB_ID
MORPH
MORPH = symbols.MORPH
ENT_ID = symbols.ENT_ID
IDX
SENT_END
IDX = symbols.IDX

View File

@ -16,57 +16,11 @@ IDS = {
"LIKE_NUM": LIKE_NUM,
"LIKE_EMAIL": LIKE_EMAIL,
"IS_STOP": IS_STOP,
"IS_OOV_DEPRECATED": IS_OOV_DEPRECATED,
"IS_BRACKET": IS_BRACKET,
"IS_QUOTE": IS_QUOTE,
"IS_LEFT_PUNCT": IS_LEFT_PUNCT,
"IS_RIGHT_PUNCT": IS_RIGHT_PUNCT,
"IS_CURRENCY": IS_CURRENCY,
"FLAG19": FLAG19,
"FLAG20": FLAG20,
"FLAG21": FLAG21,
"FLAG22": FLAG22,
"FLAG23": FLAG23,
"FLAG24": FLAG24,
"FLAG25": FLAG25,
"FLAG26": FLAG26,
"FLAG27": FLAG27,
"FLAG28": FLAG28,
"FLAG29": FLAG29,
"FLAG30": FLAG30,
"FLAG31": FLAG31,
"FLAG32": FLAG32,
"FLAG33": FLAG33,
"FLAG34": FLAG34,
"FLAG35": FLAG35,
"FLAG36": FLAG36,
"FLAG37": FLAG37,
"FLAG38": FLAG38,
"FLAG39": FLAG39,
"FLAG40": FLAG40,
"FLAG41": FLAG41,
"FLAG42": FLAG42,
"FLAG43": FLAG43,
"FLAG44": FLAG44,
"FLAG45": FLAG45,
"FLAG46": FLAG46,
"FLAG47": FLAG47,
"FLAG48": FLAG48,
"FLAG49": FLAG49,
"FLAG50": FLAG50,
"FLAG51": FLAG51,
"FLAG52": FLAG52,
"FLAG53": FLAG53,
"FLAG54": FLAG54,
"FLAG55": FLAG55,
"FLAG56": FLAG56,
"FLAG57": FLAG57,
"FLAG58": FLAG58,
"FLAG59": FLAG59,
"FLAG60": FLAG60,
"FLAG61": FLAG61,
"FLAG62": FLAG62,
"FLAG63": FLAG63,
"ID": ID,
"ORTH": ORTH,
"LOWER": LOWER,
@ -92,8 +46,7 @@ IDS = {
}
# Reverse lookup for IDS: maps each attr ID back to its attr name
NAMES = [key for key, value in sorted(IDS.items(), key=lambda item: item[1])]
NAMES = {v: k for k, v in IDS.items()}
locals().update(IDS)

View File

@ -3,22 +3,22 @@ from . cimport symbols
cpdef enum univ_pos_t:
NO_TAG = 0
ADJ = symbols.ADJ
ADP
ADV
AUX
CONJ
CCONJ # U20
DET
INTJ
NOUN
NUM
PART
PRON
PROPN
PUNCT
SCONJ
SYM
VERB
X
EOL
SPACE
ADP = symbols.ADP
ADV = symbols.ADV
AUX = symbols.AUX
CONJ = symbols.CONJ
CCONJ = symbols.CCONJ # U20
DET = symbols.DET
INTJ = symbols.INTJ
NOUN = symbols.NOUN
NUM = symbols.NUM
PART = symbols.PART
PRON = symbols.PRON
PROPN = symbols.PROPN
PUNCT = symbols.PUNCT
SCONJ = symbols.SCONJ
SYM = symbols.SYM
VERB = symbols.VERB
X = symbols.X
EOL = symbols.EOL
SPACE = symbols.SPACE

View File

@ -144,7 +144,7 @@ def validate_init_settings(
def validate_token_pattern(obj: list) -> List[str]:
# Try to convert non-string keys (e.g. {ORTH: "foo"} -> {"ORTH": "foo"})
get_key = lambda k: NAMES[k] if isinstance(k, int) and k < len(NAMES) else k
get_key = lambda k: NAMES[k] if isinstance(k, int) and k in NAMES else k
if isinstance(obj, list):
converted = []
for pattern in obj:

View File

@ -147,7 +147,7 @@ cdef class StringStore:
elif _try_coerce_to_hash(string_or_id, &str_hash):
if str_hash == 0:
return ""
elif str_hash < len(SYMBOLS_BY_INT):
elif str_hash in SYMBOLS_BY_INT:
return SYMBOLS_BY_INT[str_hash]
else:
utf8str = <Utf8Str*>self._map.get(str_hash)
@ -223,7 +223,7 @@ cdef class StringStore:
# TODO: Raise an error instead
return self._map.get(string_or_id) is not NULL
if str_hash < len(SYMBOLS_BY_INT):
if str_hash in SYMBOLS_BY_INT:
return True
else:
return self._map.get(str_hash) is not NULL

View File

@ -1,5 +1,6 @@
# DO NOT EDIT! The symbols are frozen as of spaCy v3.0.0.
cdef enum symbol_t:
NIL
NIL = 0
IS_ALPHA
IS_ASCII
IS_DIGIT
@ -65,7 +66,7 @@ cdef enum symbol_t:
FLAG62
FLAG63
ID
ID = 64
ORTH
LOWER
NORM
@ -385,7 +386,7 @@ cdef enum symbol_t:
DEPRECATED275
DEPRECATED276
PERSON
PERSON = 380
NORP
FACILITY
ORG
@ -405,7 +406,7 @@ cdef enum symbol_t:
ORDINAL
CARDINAL
acomp
acomp = 398
advcl
advmod
agent
@ -458,12 +459,12 @@ cdef enum symbol_t:
rcmod
root
xcomp
acl
ENT_KB_ID
ENT_KB_ID = 452
MORPH
ENT_ID
IDX
_
_ = 456
# DO NOT ADD ANY NEW SYMBOLS!

View File

@ -469,11 +469,7 @@ IDS = {
}
def sort_nums(x):
return x[1]
NAMES = [it[0] for it in sorted(IDS.items(), key=sort_nums)]
NAMES = {v: k for k, v in IDS.items()}
# Unfortunate hack here, to work around problem with long cpdef enum
# (which is generating an enormous amount of C++ in Cython 0.24+)
# We keep the enum cdef, and just make sure the names are available to Python

467
spacy/tests/test_symbols.py Normal file
View File

@ -0,0 +1,467 @@
import pytest
from spacy.symbols import IDS, NAMES
# Expected symbol name -> integer ID table for the frozen symbol set.
# test_frozen_symbols() compares spacy.symbols.IDS against this literal, so
# any accidental renumbering, removal, or addition of a symbol fails the test.
# Entries are not all listed in numeric order (e.g. "LAW": 390 near the end,
# "rcmod": 448 before "relcl": 447); that is harmless because the check is a
# dict equality comparison, which does not depend on insertion order.
V3_SYMBOLS = {
"": 0,
"IS_ALPHA": 1,
"IS_ASCII": 2,
"IS_DIGIT": 3,
"IS_LOWER": 4,
"IS_PUNCT": 5,
"IS_SPACE": 6,
"IS_TITLE": 7,
"IS_UPPER": 8,
"LIKE_URL": 9,
"LIKE_NUM": 10,
"LIKE_EMAIL": 11,
"IS_STOP": 12,
"IS_OOV_DEPRECATED": 13,
"IS_BRACKET": 14,
"IS_QUOTE": 15,
"IS_LEFT_PUNCT": 16,
"IS_RIGHT_PUNCT": 17,
"IS_CURRENCY": 18,
"FLAG19": 19,
"FLAG20": 20,
"FLAG21": 21,
"FLAG22": 22,
"FLAG23": 23,
"FLAG24": 24,
"FLAG25": 25,
"FLAG26": 26,
"FLAG27": 27,
"FLAG28": 28,
"FLAG29": 29,
"FLAG30": 30,
"FLAG31": 31,
"FLAG32": 32,
"FLAG33": 33,
"FLAG34": 34,
"FLAG35": 35,
"FLAG36": 36,
"FLAG37": 37,
"FLAG38": 38,
"FLAG39": 39,
"FLAG40": 40,
"FLAG41": 41,
"FLAG42": 42,
"FLAG43": 43,
"FLAG44": 44,
"FLAG45": 45,
"FLAG46": 46,
"FLAG47": 47,
"FLAG48": 48,
"FLAG49": 49,
"FLAG50": 50,
"FLAG51": 51,
"FLAG52": 52,
"FLAG53": 53,
"FLAG54": 54,
"FLAG55": 55,
"FLAG56": 56,
"FLAG57": 57,
"FLAG58": 58,
"FLAG59": 59,
"FLAG60": 60,
"FLAG61": 61,
"FLAG62": 62,
"FLAG63": 63,
"ID": 64,
"ORTH": 65,
"LOWER": 66,
"NORM": 67,
"SHAPE": 68,
"PREFIX": 69,
"SUFFIX": 70,
"LENGTH": 71,
"CLUSTER": 72,
"LEMMA": 73,
"POS": 74,
"TAG": 75,
"DEP": 76,
"ENT_IOB": 77,
"ENT_TYPE": 78,
"ENT_ID": 454,
"ENT_KB_ID": 452,
"HEAD": 79,
"SENT_START": 80,
"SPACY": 81,
"PROB": 82,
"LANG": 83,
"IDX": 455,
"ADJ": 84,
"ADP": 85,
"ADV": 86,
"AUX": 87,
"CONJ": 88,
"CCONJ": 89,
"DET": 90,
"INTJ": 91,
"NOUN": 92,
"NUM": 93,
"PART": 94,
"PRON": 95,
"PROPN": 96,
"PUNCT": 97,
"SCONJ": 98,
"SYM": 99,
"VERB": 100,
"X": 101,
"EOL": 102,
"SPACE": 103,
"DEPRECATED001": 104,
"DEPRECATED002": 105,
"DEPRECATED003": 106,
"DEPRECATED004": 107,
"DEPRECATED005": 108,
"DEPRECATED006": 109,
"DEPRECATED007": 110,
"DEPRECATED008": 111,
"DEPRECATED009": 112,
"DEPRECATED010": 113,
"DEPRECATED011": 114,
"DEPRECATED012": 115,
"DEPRECATED013": 116,
"DEPRECATED014": 117,
"DEPRECATED015": 118,
"DEPRECATED016": 119,
"DEPRECATED017": 120,
"DEPRECATED018": 121,
"DEPRECATED019": 122,
"DEPRECATED020": 123,
"DEPRECATED021": 124,
"DEPRECATED022": 125,
"DEPRECATED023": 126,
"DEPRECATED024": 127,
"DEPRECATED025": 128,
"DEPRECATED026": 129,
"DEPRECATED027": 130,
"DEPRECATED028": 131,
"DEPRECATED029": 132,
"DEPRECATED030": 133,
"DEPRECATED031": 134,
"DEPRECATED032": 135,
"DEPRECATED033": 136,
"DEPRECATED034": 137,
"DEPRECATED035": 138,
"DEPRECATED036": 139,
"DEPRECATED037": 140,
"DEPRECATED038": 141,
"DEPRECATED039": 142,
"DEPRECATED040": 143,
"DEPRECATED041": 144,
"DEPRECATED042": 145,
"DEPRECATED043": 146,
"DEPRECATED044": 147,
"DEPRECATED045": 148,
"DEPRECATED046": 149,
"DEPRECATED047": 150,
"DEPRECATED048": 151,
"DEPRECATED049": 152,
"DEPRECATED050": 153,
"DEPRECATED051": 154,
"DEPRECATED052": 155,
"DEPRECATED053": 156,
"DEPRECATED054": 157,
"DEPRECATED055": 158,
"DEPRECATED056": 159,
"DEPRECATED057": 160,
"DEPRECATED058": 161,
"DEPRECATED059": 162,
"DEPRECATED060": 163,
"DEPRECATED061": 164,
"DEPRECATED062": 165,
"DEPRECATED063": 166,
"DEPRECATED064": 167,
"DEPRECATED065": 168,
"DEPRECATED066": 169,
"DEPRECATED067": 170,
"DEPRECATED068": 171,
"DEPRECATED069": 172,
"DEPRECATED070": 173,
"DEPRECATED071": 174,
"DEPRECATED072": 175,
"DEPRECATED073": 176,
"DEPRECATED074": 177,
"DEPRECATED075": 178,
"DEPRECATED076": 179,
"DEPRECATED077": 180,
"DEPRECATED078": 181,
"DEPRECATED079": 182,
"DEPRECATED080": 183,
"DEPRECATED081": 184,
"DEPRECATED082": 185,
"DEPRECATED083": 186,
"DEPRECATED084": 187,
"DEPRECATED085": 188,
"DEPRECATED086": 189,
"DEPRECATED087": 190,
"DEPRECATED088": 191,
"DEPRECATED089": 192,
"DEPRECATED090": 193,
"DEPRECATED091": 194,
"DEPRECATED092": 195,
"DEPRECATED093": 196,
"DEPRECATED094": 197,
"DEPRECATED095": 198,
"DEPRECATED096": 199,
"DEPRECATED097": 200,
"DEPRECATED098": 201,
"DEPRECATED099": 202,
"DEPRECATED100": 203,
"DEPRECATED101": 204,
"DEPRECATED102": 205,
"DEPRECATED103": 206,
"DEPRECATED104": 207,
"DEPRECATED105": 208,
"DEPRECATED106": 209,
"DEPRECATED107": 210,
"DEPRECATED108": 211,
"DEPRECATED109": 212,
"DEPRECATED110": 213,
"DEPRECATED111": 214,
"DEPRECATED112": 215,
"DEPRECATED113": 216,
"DEPRECATED114": 217,
"DEPRECATED115": 218,
"DEPRECATED116": 219,
"DEPRECATED117": 220,
"DEPRECATED118": 221,
"DEPRECATED119": 222,
"DEPRECATED120": 223,
"DEPRECATED121": 224,
"DEPRECATED122": 225,
"DEPRECATED123": 226,
"DEPRECATED124": 227,
"DEPRECATED125": 228,
"DEPRECATED126": 229,
"DEPRECATED127": 230,
"DEPRECATED128": 231,
"DEPRECATED129": 232,
"DEPRECATED130": 233,
"DEPRECATED131": 234,
"DEPRECATED132": 235,
"DEPRECATED133": 236,
"DEPRECATED134": 237,
"DEPRECATED135": 238,
"DEPRECATED136": 239,
"DEPRECATED137": 240,
"DEPRECATED138": 241,
"DEPRECATED139": 242,
"DEPRECATED140": 243,
"DEPRECATED141": 244,
"DEPRECATED142": 245,
"DEPRECATED143": 246,
"DEPRECATED144": 247,
"DEPRECATED145": 248,
"DEPRECATED146": 249,
"DEPRECATED147": 250,
"DEPRECATED148": 251,
"DEPRECATED149": 252,
"DEPRECATED150": 253,
"DEPRECATED151": 254,
"DEPRECATED152": 255,
"DEPRECATED153": 256,
"DEPRECATED154": 257,
"DEPRECATED155": 258,
"DEPRECATED156": 259,
"DEPRECATED157": 260,
"DEPRECATED158": 261,
"DEPRECATED159": 262,
"DEPRECATED160": 263,
"DEPRECATED161": 264,
"DEPRECATED162": 265,
"DEPRECATED163": 266,
"DEPRECATED164": 267,
"DEPRECATED165": 268,
"DEPRECATED166": 269,
"DEPRECATED167": 270,
"DEPRECATED168": 271,
"DEPRECATED169": 272,
"DEPRECATED170": 273,
"DEPRECATED171": 274,
"DEPRECATED172": 275,
"DEPRECATED173": 276,
"DEPRECATED174": 277,
"DEPRECATED175": 278,
"DEPRECATED176": 279,
"DEPRECATED177": 280,
"DEPRECATED178": 281,
"DEPRECATED179": 282,
"DEPRECATED180": 283,
"DEPRECATED181": 284,
"DEPRECATED182": 285,
"DEPRECATED183": 286,
"DEPRECATED184": 287,
"DEPRECATED185": 288,
"DEPRECATED186": 289,
"DEPRECATED187": 290,
"DEPRECATED188": 291,
"DEPRECATED189": 292,
"DEPRECATED190": 293,
"DEPRECATED191": 294,
"DEPRECATED192": 295,
"DEPRECATED193": 296,
"DEPRECATED194": 297,
"DEPRECATED195": 298,
"DEPRECATED196": 299,
"DEPRECATED197": 300,
"DEPRECATED198": 301,
"DEPRECATED199": 302,
"DEPRECATED200": 303,
"DEPRECATED201": 304,
"DEPRECATED202": 305,
"DEPRECATED203": 306,
"DEPRECATED204": 307,
"DEPRECATED205": 308,
"DEPRECATED206": 309,
"DEPRECATED207": 310,
"DEPRECATED208": 311,
"DEPRECATED209": 312,
"DEPRECATED210": 313,
"DEPRECATED211": 314,
"DEPRECATED212": 315,
"DEPRECATED213": 316,
"DEPRECATED214": 317,
"DEPRECATED215": 318,
"DEPRECATED216": 319,
"DEPRECATED217": 320,
"DEPRECATED218": 321,
"DEPRECATED219": 322,
"DEPRECATED220": 323,
"DEPRECATED221": 324,
"DEPRECATED222": 325,
"DEPRECATED223": 326,
"DEPRECATED224": 327,
"DEPRECATED225": 328,
"DEPRECATED226": 329,
"DEPRECATED227": 330,
"DEPRECATED228": 331,
"DEPRECATED229": 332,
"DEPRECATED230": 333,
"DEPRECATED231": 334,
"DEPRECATED232": 335,
"DEPRECATED233": 336,
"DEPRECATED234": 337,
"DEPRECATED235": 338,
"DEPRECATED236": 339,
"DEPRECATED237": 340,
"DEPRECATED238": 341,
"DEPRECATED239": 342,
"DEPRECATED240": 343,
"DEPRECATED241": 344,
"DEPRECATED242": 345,
"DEPRECATED243": 346,
"DEPRECATED244": 347,
"DEPRECATED245": 348,
"DEPRECATED246": 349,
"DEPRECATED247": 350,
"DEPRECATED248": 351,
"DEPRECATED249": 352,
"DEPRECATED250": 353,
"DEPRECATED251": 354,
"DEPRECATED252": 355,
"DEPRECATED253": 356,
"DEPRECATED254": 357,
"DEPRECATED255": 358,
"DEPRECATED256": 359,
"DEPRECATED257": 360,
"DEPRECATED258": 361,
"DEPRECATED259": 362,
"DEPRECATED260": 363,
"DEPRECATED261": 364,
"DEPRECATED262": 365,
"DEPRECATED263": 366,
"DEPRECATED264": 367,
"DEPRECATED265": 368,
"DEPRECATED266": 369,
"DEPRECATED267": 370,
"DEPRECATED268": 371,
"DEPRECATED269": 372,
"DEPRECATED270": 373,
"DEPRECATED271": 374,
"DEPRECATED272": 375,
"DEPRECATED273": 376,
"DEPRECATED274": 377,
"DEPRECATED275": 378,
"DEPRECATED276": 379,
"PERSON": 380,
"NORP": 381,
"FACILITY": 382,
"ORG": 383,
"GPE": 384,
"LOC": 385,
"PRODUCT": 386,
"EVENT": 387,
"WORK_OF_ART": 388,
"LANGUAGE": 389,
"DATE": 391,
"TIME": 392,
"PERCENT": 393,
"MONEY": 394,
"QUANTITY": 395,
"ORDINAL": 396,
"CARDINAL": 397,
"acomp": 398,
"advcl": 399,
"advmod": 400,
"agent": 401,
"amod": 402,
"appos": 403,
"attr": 404,
"aux": 405,
"auxpass": 406,
"cc": 407,
"ccomp": 408,
"complm": 409,
"conj": 410,
"cop": 411,
"csubj": 412,
"csubjpass": 413,
"dep": 414,
"det": 415,
"dobj": 416,
"expl": 417,
"hmod": 418,
"hyph": 419,
"infmod": 420,
"intj": 421,
"iobj": 422,
"mark": 423,
"meta": 424,
"neg": 425,
"nmod": 426,
"nn": 427,
"npadvmod": 428,
"nsubj": 429,
"nsubjpass": 430,
"num": 431,
"number": 432,
"oprd": 433,
"obj": 434,
"obl": 435,
"parataxis": 436,
"partmod": 437,
"pcomp": 438,
"pobj": 439,
"poss": 440,
"possessive": 441,
"preconj": 442,
"prep": 443,
"prt": 444,
"punct": 445,
"quantmod": 446,
"rcmod": 448,
"relcl": 447,
"root": 449,
"xcomp": 450,
"acl": 451,
"LAW": 390,
"MORPH": 453,
"_": 456,
}
def test_frozen_symbols():
    """Check that the symbol tables still match the frozen reference values.

    ``IDS`` must be exactly equal to ``V3_SYMBOLS``, and ``NAMES`` must be the
    reverse (ID -> name) mapping of ``IDS``.
    """
    assert IDS == V3_SYMBOLS
    # Build the expected reverse lookup explicitly, one entry per symbol.
    reverse_ids = {}
    for symbol_name, symbol_id in IDS.items():
        reverse_ids[symbol_id] = symbol_name
    assert NAMES == reverse_ids

View File

@ -974,22 +974,26 @@ cdef class Doc:
py_attr_ids = [(IDS[id_.upper()] if hasattr(id_, "upper") else id_)
for id_ in py_attr_ids]
except KeyError as msg:
keys = [k for k in IDS.keys() if not k.startswith("FLAG")]
keys = list(IDS.keys())
raise KeyError(Errors.E983.format(dict="IDS", key=msg, keys=keys)) from None
# Make an array from the attributes --- otherwise our inner loop is
# Python dict iteration.
cdef np.ndarray attr_ids = numpy.asarray(py_attr_ids, dtype="i")
output = numpy.ndarray(shape=(self.length, len(attr_ids)), dtype=numpy.uint64)
cdef Pool mem = Pool()
cdef int n_attrs = len(py_attr_ids)
cdef attr_id_t* c_attr_ids
if n_attrs > 0:
c_attr_ids = <attr_id_t*>mem.alloc(n_attrs, sizeof(attr_id_t))
for i, attr_id in enumerate(py_attr_ids):
c_attr_ids[i] = attr_id
output = numpy.ndarray(shape=(self.length, n_attrs), dtype=numpy.uint64)
c_output = <attr_t*>output.data
c_attr_ids = <attr_id_t*>attr_ids.data
cdef TokenC* token
cdef int nr_attr = attr_ids.shape[0]
for i in range(self.length):
token = &self.c[i]
for j in range(nr_attr):
c_output[i*nr_attr + j] = get_token_attr(token, c_attr_ids[j])
for j in range(n_attrs):
c_output[i*n_attrs + j] = get_token_attr(token, c_attr_ids[j])
# Handle 1d case
return output if len(attr_ids) >= 2 else output.reshape((self.length,))
return output if n_attrs >= 2 else output.reshape((self.length,))
def count_by(self, attr_id_t attr_id, exclude=None, object counts=None):
"""Count the frequencies of a given attribute. Produces a dict of