Morphology/Morphologizer optimizations and refactoring (#11024)

* `Morphology`: Refactor to use C types, reduce allocations, remove unused code

* `Morphologizer`: Avoid unnecessary sorting of morph features

* `Morphologizer`: Remove excessive reallocations of labels, improve hash lookups of labels, coerce `numpy` numeric types to native ints; update docs

* Remove unused method

* Replace `unique_ptr` usage with `shared_ptr`

* Add type annotations to internal Python methods, rename `hash` variable, fix typos

* Add comment to clarify implementation detail

* Fix return type

* `Morphology`: Stop early when splitting fields and values
Madeesh Kannan 2022-07-15 11:14:08 +02:00 committed by GitHub
parent 851a7ca4fa
commit ba18d2913d
8 changed files with 235 additions and 161 deletions
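
Two of the commit notes above — avoiding unnecessary sorting and stopping early when splitting fields and values — correspond to the new `sort_values` flag on `Morphology.feats_to_dict` in the diff below. A minimal plain-Python sketch of that parsing approach; the separator constants follow the UD FEATS conventions and everything else is simplified, so this is an illustration rather than the real implementation:

```python
FEATURE_SEP = "|"
FIELD_SEP = "="
VALUE_SEP = ","
EMPTY_MORPH = "_"

def feats_to_dict(feats: str, *, sort_values: bool = True) -> dict:
    # Simplified sketch of the optimized parser: split each feature only
    # once on "=", and only split/sort the values when the caller asks for it.
    if not feats or feats == EMPTY_MORPH:
        return {}
    out = {}
    for feat in feats.split(FEATURE_SEP):
        field, values = feat.split(FIELD_SEP, 1)   # stop after the first "="
        if sort_values:
            values = VALUE_SEP.join(sorted(values.split(VALUE_SEP)))
        out[field] = values
    return out

print(feats_to_dict("Number=Sing|Case=Nom"))                 # {'Number': 'Sing', 'Case': 'Nom'}
print(feats_to_dict("PronType=Rel,Dem", sort_values=False))  # {'PronType': 'Rel,Dem'}
```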

spacy/morphology.pxd

@@ -1,23 +1,41 @@
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
cimport numpy as np
from libc.stdint cimport uint64_t
from libc.stdint cimport uint32_t, uint64_t
from libcpp.unordered_map cimport unordered_map
from libcpp.vector cimport vector
from libcpp.memory cimport shared_ptr
from .structs cimport MorphAnalysisC
from .strings cimport StringStore
from .typedefs cimport attr_t, hash_t
cdef cppclass Feature:
hash_t field
hash_t value
__init__():
this.field = 0
this.value = 0
cdef cppclass MorphAnalysisC:
hash_t key
vector[Feature] features
__init__():
this.key = 0
cdef class Morphology:
cdef readonly Pool mem
cdef readonly StringStore strings
cdef PreshMap tags # Keyed by hash, value is pointer to tag
cdef unordered_map[hash_t, shared_ptr[MorphAnalysisC]] tags
cdef MorphAnalysisC create_morph_tag(self, field_feature_pairs) except *
cdef int insert(self, MorphAnalysisC tag) except -1
cdef shared_ptr[MorphAnalysisC] _lookup_tag(self, hash_t tag_hash)
cdef void _intern_morph_tag(self, hash_t tag_key, feats)
cdef hash_t _add(self, features)
cdef str _normalize_features(self, features)
cdef str get_morph_str(self, hash_t morph_key)
cdef shared_ptr[MorphAnalysisC] get_morph_c(self, hash_t morph_key)
cdef int check_feature(const MorphAnalysisC* morph, attr_t feature) nogil
cdef list list_features(const MorphAnalysisC* morph)
cdef np.ndarray get_by_field(const MorphAnalysisC* morph, attr_t field)
cdef int get_n_by_field(attr_t* results, const MorphAnalysisC* morph, attr_t field) nogil
cdef int check_feature(const shared_ptr[MorphAnalysisC] morph, attr_t feature) nogil
cdef list list_features(const shared_ptr[MorphAnalysisC] morph)
cdef np.ndarray get_by_field(const shared_ptr[MorphAnalysisC] morph, attr_t field)
cdef int get_n_by_field(attr_t* results, const shared_ptr[MorphAnalysisC] morph, attr_t field) nogil
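
For orientation, a rough Python analogue of the new data model declared above: each tag stores its key (the hash of the normalized FEATS string) plus a flat list of (field hash, feature hash) pairs, and the morphology table maps key to a shared tag. The class and variable names below are illustrative only, not part of spaCy's API:

```python
class Feature:
    # Mirrors the C++ Feature cppclass: both members are string-store hashes.
    def __init__(self, field_hash: int = 0, value_hash: int = 0):
        self.field = field_hash   # hash of "Field"
        self.value = value_hash   # hash of "Field=Value"

class MorphTag:
    # Mirrors MorphAnalysisC: a key plus a vector of Feature entries.
    def __init__(self, key: int = 0):
        self.key = key
        self.features: list[Feature] = []

# unordered_map[hash_t, shared_ptr[MorphAnalysisC]] in the header above;
# Python objects are reference-counted, which plays the role of shared_ptr here.
tags: dict[int, MorphTag] = {}
```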

spacy/morphology.pyx

@@ -1,10 +1,10 @@
# cython: infer_types
import numpy
import warnings
from typing import Union, Tuple, List, Dict, Optional
from cython.operator cimport dereference as deref
from libcpp.memory cimport shared_ptr
from .attrs cimport POS
from .parts_of_speech import IDS as POS_IDS
from .errors import Warnings
from . import symbols
@@ -24,134 +24,187 @@ cdef class Morphology:
EMPTY_MORPH = symbols.NAMES[symbols._]
def __init__(self, StringStore strings):
self.mem = Pool()
self.strings = strings
self.tags = PreshMap()
def __reduce__(self):
tags = set([self.get(self.strings[s]) for s in self.strings])
tags -= set([""])
return (unpickle_morphology, (self.strings, sorted(tags)), None, None)
def add(self, features):
cdef shared_ptr[MorphAnalysisC] _lookup_tag(self, hash_t tag_hash):
match = self.tags.find(tag_hash)
if match != self.tags.const_end():
return deref(match).second
else:
return shared_ptr[MorphAnalysisC]()
def _normalize_attr(self, attr_key : Union[int, str], attr_value : Union[int, str]) -> Optional[Tuple[str, Union[str, List[str]]]]:
if isinstance(attr_key, (int, str)) and isinstance(attr_value, (int, str)):
attr_key = self.strings.as_string(attr_key)
attr_value = self.strings.as_string(attr_value)
# Preserve multiple values as a list
if self.VALUE_SEP in attr_value:
values = attr_value.split(self.VALUE_SEP)
values.sort()
attr_value = values
else:
warnings.warn(Warnings.W100.format(feature={attr_key: attr_value}))
return None
return attr_key, attr_value
def _str_to_normalized_feat_dict(self, feats: str) -> Dict[str, str]:
if not feats or feats == self.EMPTY_MORPH:
return {}
out = []
for feat in feats.split(self.FEATURE_SEP):
field, values = feat.split(self.FIELD_SEP, 1)
normalized_attr = self._normalize_attr(field, values)
if normalized_attr is None:
continue
out.append((normalized_attr[0], normalized_attr[1]))
out.sort(key=lambda x: x[0])
return dict(out)
def _dict_to_normalized_feat_dict(self, feats: Dict[Union[int, str], Union[int, str]]) -> Dict[str, str]:
out = []
for field, values in feats.items():
normalized_attr = self._normalize_attr(field, values)
if normalized_attr is None:
continue
out.append((normalized_attr[0], normalized_attr[1]))
out.sort(key=lambda x: x[0])
return dict(out)
def _normalized_feat_dict_to_str(self, feats: Dict[str, str]) -> str:
norm_feats_string = self.FEATURE_SEP.join([
self.FIELD_SEP.join([field, self.VALUE_SEP.join(values) if isinstance(values, list) else values])
for field, values in feats.items()
])
return norm_feats_string or self.EMPTY_MORPH
cdef hash_t _add(self, features):
"""Insert a morphological analysis in the morphology table, if not
already present. The morphological analysis may be provided in the UD
FEATS format as a string or in the tag map dict format.
Returns the hash of the new analysis.
"""
cdef MorphAnalysisC* tag_ptr
cdef hash_t tag_hash = 0
cdef shared_ptr[MorphAnalysisC] tag
if isinstance(features, str):
if features == "":
features = self.EMPTY_MORPH
tag_ptr = <MorphAnalysisC*>self.tags.get(<hash_t>self.strings[features])
if tag_ptr != NULL:
return tag_ptr.key
features = self.feats_to_dict(features)
if not isinstance(features, dict):
tag_hash = self.strings[features]
tag = self._lookup_tag(tag_hash)
if tag:
return deref(tag).key
features = self._str_to_normalized_feat_dict(features)
elif isinstance(features, dict):
features = self._dict_to_normalized_feat_dict(features)
else:
warnings.warn(Warnings.W100.format(feature=features))
features = {}
string_features = {self.strings.as_string(field): self.strings.as_string(values) for field, values in features.items()}
# intified ("Field", "Field=Value") pairs
field_feature_pairs = []
for field in sorted(string_features):
values = string_features[field]
for value in values.split(self.VALUE_SEP):
field_feature_pairs.append((
self.strings.add(field),
self.strings.add(field + self.FIELD_SEP + value),
))
cdef MorphAnalysisC tag = self.create_morph_tag(field_feature_pairs)
# the hash key for the tag is either the hash of the normalized UFEATS
# string or the hash of an empty placeholder
norm_feats_string = self.normalize_features(features)
tag.key = self.strings.add(norm_feats_string)
self.insert(tag)
return tag.key
norm_feats_string = self._normalized_feat_dict_to_str(features)
tag_hash = self.strings.add(norm_feats_string)
tag = self._lookup_tag(tag_hash)
if tag:
return deref(tag).key
def normalize_features(self, features):
self._intern_morph_tag(tag_hash, features)
return tag_hash
cdef void _intern_morph_tag(self, hash_t tag_key, feats):
# intified ("Field", "Field=Value") pairs where fields with multiple values have
# been split into individual tuples, e.g.:
# [("Field1", "Field1=Value1"), ("Field1", "Field1=Value2"),
# ("Field2", "Field2=Value3")]
field_feature_pairs = []
# Feat dict is normalized at this point.
for field, values in feats.items():
field_key = self.strings.add(field)
if isinstance(values, list):
for value in values:
value_key = self.strings.add(field + self.FIELD_SEP + value)
field_feature_pairs.append((field_key, value_key))
else:
# We could box scalar values into a list and use a common
# code path to generate features but that incurs a small
# but measurable allocation/iteration overhead (as this
# branch is taken often enough).
value_key = self.strings.add(field + self.FIELD_SEP + values)
field_feature_pairs.append((field_key, value_key))
num_features = len(field_feature_pairs)
cdef shared_ptr[MorphAnalysisC] tag = shared_ptr[MorphAnalysisC](new MorphAnalysisC())
deref(tag).key = tag_key
deref(tag).features.resize(num_features)
for i in range(num_features):
deref(tag).features[i].field = field_feature_pairs[i][0]
deref(tag).features[i].value = field_feature_pairs[i][1]
self.tags[tag_key] = tag
cdef str get_morph_str(self, hash_t morph_key):
cdef shared_ptr[MorphAnalysisC] tag = self._lookup_tag(morph_key)
if not tag:
return ""
else:
return self.strings[deref(tag).key]
cdef shared_ptr[MorphAnalysisC] get_morph_c(self, hash_t morph_key):
return self._lookup_tag(morph_key)
cdef str _normalize_features(self, features):
"""Create a normalized FEATS string from a features string or dict.
features (Union[dict, str]): Features as dict or UFEATS string.
RETURNS (str): Features as normalized UFEATS string.
"""
if isinstance(features, str):
features = self.feats_to_dict(features)
if not isinstance(features, dict):
features = self._str_to_normalized_feat_dict(features)
elif isinstance(features, dict):
features = self._dict_to_normalized_feat_dict(features)
else:
warnings.warn(Warnings.W100.format(feature=features))
features = {}
features = self.normalize_attrs(features)
string_features = {self.strings.as_string(field): self.strings.as_string(values) for field, values in features.items()}
# normalized UFEATS string with sorted fields and values
norm_feats_string = self.FEATURE_SEP.join(sorted([
self.FIELD_SEP.join([field, values])
for field, values in string_features.items()
]))
return norm_feats_string or self.EMPTY_MORPH
def normalize_attrs(self, attrs):
"""Convert attrs dict so that POS is always by ID, other features are
by string. Values separated by VALUE_SEP are sorted.
"""
out = {}
attrs = dict(attrs)
for key, value in attrs.items():
# convert POS value to ID
if key == POS or (isinstance(key, str) and key.upper() == "POS"):
if isinstance(value, str) and value.upper() in POS_IDS:
value = POS_IDS[value.upper()]
elif isinstance(value, int) and value not in POS_IDS.values():
warnings.warn(Warnings.W100.format(feature={key: value}))
continue
out[POS] = value
# accept any string or ID fields and values and convert to strings
elif isinstance(key, (int, str)) and isinstance(value, (int, str)):
key = self.strings.as_string(key)
value = self.strings.as_string(value)
# sort values
if self.VALUE_SEP in value:
value = self.VALUE_SEP.join(sorted(value.split(self.VALUE_SEP)))
out[key] = value
else:
warnings.warn(Warnings.W100.format(feature={key: value}))
return out
return self._normalized_feat_dict_to_str(features)
cdef MorphAnalysisC create_morph_tag(self, field_feature_pairs) except *:
"""Creates a MorphAnalysisC from a list of intified
("Field", "Field=Value") tuples where fields with multiple values have
been split into individual tuples, e.g.:
[("Field1", "Field1=Value1"), ("Field1", "Field1=Value2"),
("Field2", "Field2=Value3")]
"""
cdef MorphAnalysisC tag
tag.length = len(field_feature_pairs)
if tag.length > 0:
tag.fields = <attr_t*>self.mem.alloc(tag.length, sizeof(attr_t))
tag.features = <attr_t*>self.mem.alloc(tag.length, sizeof(attr_t))
for i, (field, feature) in enumerate(field_feature_pairs):
tag.fields[i] = field
tag.features[i] = feature
return tag
def add(self, features):
return self._add(features)
cdef int insert(self, MorphAnalysisC tag) except -1:
cdef hash_t key = tag.key
if self.tags.get(key) == NULL:
tag_ptr = <MorphAnalysisC*>self.mem.alloc(1, sizeof(MorphAnalysisC))
tag_ptr[0] = tag
self.tags.set(key, <void*>tag_ptr)
def get(self, morph_key):
return self.get_morph_str(morph_key)
def get(self, hash_t morph):
tag = <MorphAnalysisC*>self.tags.get(morph)
if tag == NULL:
return ""
else:
return self.strings[tag.key]
def normalize_features(self, features):
return self._normalize_features(features)
@staticmethod
def feats_to_dict(feats):
def feats_to_dict(feats, *, sort_values=True):
if not feats or feats == Morphology.EMPTY_MORPH:
return {}
return {field: Morphology.VALUE_SEP.join(sorted(values.split(Morphology.VALUE_SEP))) for field, values in
[feat.split(Morphology.FIELD_SEP) for feat in feats.split(Morphology.FEATURE_SEP)]}
out = {}
for feat in feats.split(Morphology.FEATURE_SEP):
field, values = feat.split(Morphology.FIELD_SEP, 1)
if sort_values:
values = values.split(Morphology.VALUE_SEP)
values.sort()
values = Morphology.VALUE_SEP.join(values)
out[field] = values
return out
@staticmethod
def dict_to_feats(feats_dict):
@@ -160,34 +213,34 @@ cdef class Morphology:
return Morphology.FEATURE_SEP.join(sorted([Morphology.FIELD_SEP.join([field, Morphology.VALUE_SEP.join(sorted(values.split(Morphology.VALUE_SEP)))]) for field, values in feats_dict.items()]))
cdef int check_feature(const MorphAnalysisC* morph, attr_t feature) nogil:
cdef int check_feature(const shared_ptr[MorphAnalysisC] morph, attr_t feature) nogil:
cdef int i
for i in range(morph.length):
if morph.features[i] == feature:
for i in range(deref(morph).features.size()):
if deref(morph).features[i].value == feature:
return True
return False
cdef list list_features(const MorphAnalysisC* morph):
cdef list list_features(const shared_ptr[MorphAnalysisC] morph):
cdef int i
features = []
for i in range(morph.length):
features.append(morph.features[i])
for i in range(deref(morph).features.size()):
features.append(deref(morph).features[i].value)
return features
cdef np.ndarray get_by_field(const MorphAnalysisC* morph, attr_t field):
cdef np.ndarray results = numpy.zeros((morph.length,), dtype="uint64")
cdef np.ndarray get_by_field(const shared_ptr[MorphAnalysisC] morph, attr_t field):
cdef np.ndarray results = numpy.zeros((deref(morph).features.size(),), dtype="uint64")
n = get_n_by_field(<uint64_t*>results.data, morph, field)
return results[:n]
cdef int get_n_by_field(attr_t* results, const MorphAnalysisC* morph, attr_t field) nogil:
cdef int get_n_by_field(attr_t* results, const shared_ptr[MorphAnalysisC] morph, attr_t field) nogil:
cdef int n_results = 0
cdef int i
for i in range(morph.length):
if morph.fields[i] == field:
results[n_results] = morph.features[i]
for i in range(deref(morph).features.size()):
if deref(morph).features[i].field == field:
results[n_results] = deref(morph).features[i].value
n_results += 1
return n_results
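
The comment in `_intern_morph_tag` above describes how a normalized feature dict is flattened into (`Field`, `Field=Value`) pairs, with a separate fast path for single-valued fields. A small plain-Python illustration of that expansion (the real code stores `StringStore` hashes rather than the strings themselves):

```python
def expand_pairs(feats: dict) -> list:
    # Multi-valued fields were split into lists during normalization;
    # single values keep the cheaper scalar path, as in _intern_morph_tag.
    pairs = []
    for field, values in feats.items():
        if isinstance(values, list):
            for value in values:
                pairs.append((field, f"{field}={value}"))
        else:
            pairs.append((field, f"{field}={values}"))
    return pairs

print(expand_pairs({"Mood": "Ind", "PronType": ["Dem", "Rel"]}))
# [('Mood', 'Mood=Ind'), ('PronType', 'PronType=Dem'), ('PronType', 'PronType=Rel')]
```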

spacy/pipeline/morphologizer.pyx

@@ -127,8 +127,8 @@ class Morphologizer(Tagger):
@property
def labels(self):
"""RETURNS (Tuple[str]): The labels currently added to the component."""
return tuple(self.cfg["labels_morph"].keys())
"""RETURNS (Iterable[str]): The labels currently added to the component."""
return self.cfg["labels_morph"].keys()
@property
def label_data(self) -> Dict[str, Dict[str, Union[str, float, int, None]]]:
@@ -151,7 +151,7 @@ class Morphologizer(Tagger):
# normalize label
norm_label = self.vocab.morphology.normalize_features(label)
# extract separate POS and morph tags
label_dict = Morphology.feats_to_dict(label)
label_dict = Morphology.feats_to_dict(label, sort_values=False)
pos = label_dict.get(self.POS_FEAT, "")
if self.POS_FEAT in label_dict:
label_dict.pop(self.POS_FEAT)
@@ -189,7 +189,7 @@ class Morphologizer(Tagger):
continue
morph = str(token.morph)
# create and add the combined morph+POS label
morph_dict = Morphology.feats_to_dict(morph)
morph_dict = Morphology.feats_to_dict(morph, sort_values=False)
if pos:
morph_dict[self.POS_FEAT] = pos
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
@@ -206,7 +206,7 @@ class Morphologizer(Tagger):
for i, token in enumerate(example.reference):
pos = token.pos_
morph = str(token.morph)
morph_dict = Morphology.feats_to_dict(morph)
morph_dict = Morphology.feats_to_dict(morph, sort_values=False)
if pos:
morph_dict[self.POS_FEAT] = pos
norm_label = self.vocab.strings[self.vocab.morphology.add(morph_dict)]
@@ -231,26 +231,29 @@ class Morphologizer(Tagger):
cdef Vocab vocab = self.vocab
cdef bint overwrite = self.cfg["overwrite"]
cdef bint extend = self.cfg["extend"]
labels = self.labels
# We require random access for the upcoming ops, so we need
# to allocate a compatible container out of the iterable.
labels = tuple(self.labels)
for i, doc in enumerate(docs):
doc_tag_ids = batch_tag_ids[i]
if hasattr(doc_tag_ids, "get"):
doc_tag_ids = doc_tag_ids.get()
for j, tag_id in enumerate(doc_tag_ids):
morph = labels[tag_id]
morph = labels[int(tag_id)]
# set morph
if doc.c[j].morph == 0 or overwrite or extend:
if overwrite and extend:
# morphologizer morph overwrites any existing features
# while extending
extended_morph = Morphology.feats_to_dict(self.vocab.strings[doc.c[j].morph])
extended_morph.update(Morphology.feats_to_dict(self.cfg["labels_morph"].get(morph, 0)))
extended_morph = Morphology.feats_to_dict(self.vocab.strings[doc.c[j].morph], sort_values=False)
extended_morph.update(Morphology.feats_to_dict(self.cfg["labels_morph"].get(morph, 0), sort_values=False))
doc.c[j].morph = self.vocab.morphology.add(extended_morph)
elif extend:
# existing features are preserved and any new features
# are added
extended_morph = Morphology.feats_to_dict(self.cfg["labels_morph"].get(morph, 0))
extended_morph.update(Morphology.feats_to_dict(self.vocab.strings[doc.c[j].morph]))
extended_morph = Morphology.feats_to_dict(self.cfg["labels_morph"].get(morph, 0), sort_values=False)
extended_morph.update(Morphology.feats_to_dict(self.vocab.strings[doc.c[j].morph], sort_values=False))
doc.c[j].morph = self.vocab.morphology.add(extended_morph)
else:
# clobber
@@ -270,7 +273,7 @@ class Morphologizer(Tagger):
DOCS: https://spacy.io/api/morphologizer#get_loss
"""
validate_examples(examples, "Morphologizer.get_loss")
loss_func = SequenceCategoricalCrossentropy(names=self.labels, normalize=False)
loss_func = SequenceCategoricalCrossentropy(names=tuple(self.labels), normalize=False)
truths = []
for eg in examples:
eg_truths = []
@@ -291,7 +294,7 @@ class Morphologizer(Tagger):
label = None
# Otherwise, generate the combined label
else:
label_dict = Morphology.feats_to_dict(morph)
label_dict = Morphology.feats_to_dict(morph, sort_values=False)
if pos:
label_dict[self.POS_FEAT] = pos
label = self.vocab.strings[self.vocab.morphology.add(label_dict)]
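
The `set_annotations` change above materializes `self.labels` into a tuple and converts each predicted tag id to a native `int`. A short illustration of why (toy values; `labels_morph` here is only a stand-in for the component's config):

```python
import numpy

# labels is now a dict-keys view: iterable, but not indexable.
labels_morph = {"Mood=Ind|POS=VERB": "Mood=Ind", "POS=NOUN": ""}
labels_view = labels_morph.keys()
labels = tuple(labels_view)      # allocate a sequence once, then index freely

tag_id = numpy.int64(1)          # model predictions arrive as numpy scalars
morph = labels[int(tag_id)]      # coerce once; later hashing/indexing stays on native ints
print(morph)                     # POS=NOUN
```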

spacy/structs.pxd

@@ -58,14 +58,6 @@ cdef struct TokenC:
hash_t ent_id
cdef struct MorphAnalysisC:
hash_t key
int length
attr_t* fields
attr_t* features
# Internal struct, for storage and disambiguation of entities.
cdef struct KBEntryC:

spacy/tokens/morphanalysis.pxd

@@ -1,9 +1,12 @@
from ..vocab cimport Vocab
from ..typedefs cimport hash_t
from ..structs cimport MorphAnalysisC
from ..morphology cimport MorphAnalysisC
from libcpp.memory cimport shared_ptr
cdef class MorphAnalysis:
cdef readonly Vocab vocab
cdef readonly hash_t key
cdef MorphAnalysisC c
cdef shared_ptr[MorphAnalysisC] c
cdef void _init_c(self, hash_t key)

spacy/tokens/morphanalysis.pyx

@@ -5,7 +5,12 @@ from ..errors import Errors
from ..morphology import Morphology
from ..vocab cimport Vocab
from ..typedefs cimport hash_t, attr_t
from ..morphology cimport list_features, check_feature, get_by_field
from ..morphology cimport list_features, check_feature, get_by_field, MorphAnalysisC
from libcpp.memory cimport shared_ptr
from cython.operator cimport dereference as deref
cdef shared_ptr[MorphAnalysisC] EMPTY_MORPH_TAG = shared_ptr[MorphAnalysisC](new MorphAnalysisC())
cdef class MorphAnalysis:
@@ -13,39 +18,38 @@ cdef class MorphAnalysis:
def __init__(self, Vocab vocab, features=dict()):
self.vocab = vocab
self.key = self.vocab.morphology.add(features)
analysis = <const MorphAnalysisC*>self.vocab.morphology.tags.get(self.key)
if analysis is not NULL:
self.c = analysis[0]
self._init_c(self.key)
cdef void _init_c(self, hash_t key):
cdef shared_ptr[MorphAnalysisC] analysis = self.vocab.morphology.get_morph_c(key)
if analysis:
self.c = analysis
else:
memset(&self.c, 0, sizeof(self.c))
self.c = EMPTY_MORPH_TAG
@classmethod
def from_id(cls, Vocab vocab, hash_t key):
"""Create a morphological analysis from a given ID."""
cdef MorphAnalysis morph = MorphAnalysis.__new__(MorphAnalysis, vocab)
cdef MorphAnalysis morph = MorphAnalysis(vocab)
morph.vocab = vocab
morph.key = key
analysis = <const MorphAnalysisC*>vocab.morphology.tags.get(key)
if analysis is not NULL:
morph.c = analysis[0]
else:
memset(&morph.c, 0, sizeof(morph.c))
morph._init_c(key)
return morph
def __contains__(self, feature):
"""Test whether the morphological analysis contains some feature."""
cdef attr_t feat_id = self.vocab.strings.as_int(feature)
return check_feature(&self.c, feat_id)
return check_feature(self.c, feat_id)
def __iter__(self):
"""Iterate over the features in the analysis."""
cdef attr_t feature
for feature in list_features(&self.c):
for feature in list_features(self.c):
yield self.vocab.strings[feature]
def __len__(self):
"""The number of features in the analysis."""
return self.c.length
return deref(self.c).features.size()
def __hash__(self):
return self.key
@@ -61,7 +65,7 @@ cdef class MorphAnalysis:
def get(self, field):
"""Retrieve feature values by field."""
cdef attr_t field_id = self.vocab.strings.as_int(field)
cdef np.ndarray results = get_by_field(&self.c, field_id)
cdef np.ndarray results = get_by_field(self.c, field_id)
features = [self.vocab.strings[result] for result in results]
return [f.split(Morphology.FIELD_SEP)[1] for f in features]
@@ -69,7 +73,7 @@ cdef class MorphAnalysis:
"""Produce a json serializable representation as a UD FEATS-style
string.
"""
morph_string = self.vocab.strings[self.c.key]
morph_string = self.vocab.strings[deref(self.c).key]
if morph_string == self.vocab.morphology.EMPTY_MORPH:
return ""
return morph_string
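
The `MorphAnalysis` refactor above swaps the embedded struct for a `shared_ptr` into the shared tag table, but the public behaviour is unchanged. A quick usage check with a blank pipeline (no trained model required; printed values assume spaCy's normal FEATS normalization):

```python
import spacy

nlp = spacy.blank("en")
doc = nlp("cats")
doc[0].set_morph("Number=Plur|Case=Nom")
morph = doc[0].morph                  # MorphAnalysis backed by the shared tag table

print(str(morph))                     # Case=Nom|Number=Plur  (normalized, sorted fields)
print("Number=Plur" in morph)         # True   (__contains__ -> check_feature)
print(len(morph))                     # 2      (__len__ -> features.size())
print(morph.get("Number"))            # ['Plur']  (get -> get_by_field)
```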

spacy/tokens/token.pyx

@@ -22,6 +22,7 @@ from .. import parts_of_speech
from ..errors import Errors, Warnings
from ..attrs import IOB_STRINGS
from .underscore import Underscore, get_ext_args
from cython.operator cimport dereference as deref
cdef class Token:
@@ -230,7 +231,7 @@ cdef class Token:
# Check that the morph has the same vocab
if self.vocab != morph.vocab:
raise ValueError(Errors.E1013)
self.c.morph = morph.c.key
self.c.morph = deref(morph.c).key
def set_morph(self, features):
cdef hash_t key

website/docs/api/morphologizer.md

@@ -401,7 +401,7 @@ coarse-grained POS as the feature `POS`.
| Name | Description |
| ----------- | ------------------------------------------------------ |
| **RETURNS** | The labels added to the component. ~~Tuple[str, ...]~~ |
| **RETURNS** | The labels added to the component. ~~Iterable[str, ...]~~ |
## Morphologizer.label_data {#label_data tag="property" new="3"}