# cython: infer_types=True, binding=True from cython.operator cimport dereference as deref from libc.stdint cimport UINT32_MAX, uint32_t from libc.string cimport memset from libcpp.pair cimport pair from libcpp.vector cimport vector from pathlib import Path from ...typedefs cimport hash_t from ... import util from ...errors import Errors from ...strings import StringStore from .schemas import validate_edit_tree NULL_TREE_ID = UINT32_MAX cdef LCS find_lcs(str source, str target): """ Find the longest common subsequence (LCS) between two strings. If there are multiple LCSes, only one of them is returned. source (str): The first string. target (str): The second string. RETURNS (LCS): The spans of the longest common subsequences. """ cdef Py_ssize_t source_len = len(source) cdef Py_ssize_t target_len = len(target) cdef size_t longest_align = 0; cdef int source_idx, target_idx cdef LCS lcs cdef Py_UCS4 source_cp, target_cp memset(&lcs, 0, sizeof(lcs)) cdef vector[size_t] prev_aligns = vector[size_t](target_len); cdef vector[size_t] cur_aligns = vector[size_t](target_len); for (source_idx, source_cp) in enumerate(source): for (target_idx, target_cp) in enumerate(target): if source_cp == target_cp: if source_idx == 0 or target_idx == 0: cur_aligns[target_idx] = 1 else: cur_aligns[target_idx] = prev_aligns[target_idx - 1] + 1 # Check if this is the longest alignment and replace previous # best alignment when this is the case. if cur_aligns[target_idx] > longest_align: longest_align = cur_aligns[target_idx] lcs.source_begin = source_idx - longest_align + 1 lcs.source_end = source_idx + 1 lcs.target_begin = target_idx - longest_align + 1 lcs.target_end = target_idx + 1 else: # No match, we start with a zero-length alignment. cur_aligns[target_idx] = 0 swap(prev_aligns, cur_aligns) return lcs cdef class EditTrees: """Container for constructing and storing edit trees.""" def __init__(self, strings: StringStore): """Create a container for edit trees. strings (StringStore): the string store to use.""" self.strings = strings cpdef uint32_t add(self, str form, str lemma): """Add an edit tree that rewrites the given string into the given lemma. RETURNS (int): identifier of the edit tree in the container. """ # Treat two empty strings as a special case. Generating an edit # tree for identical strings results in a match node. However, # since two empty strings have a zero-length LCS, a substitution # node would be created. Since we do not want to clutter the # recursive tree construction with logic for this case, handle # it in this wrapper method. if len(form) == 0 and len(lemma) == 0: tree = edittree_new_match(0, 0, NULL_TREE_ID, NULL_TREE_ID) return self._tree_id(tree) return self._add(form, lemma) cdef uint32_t _add(self, str form, str lemma): cdef LCS lcs = find_lcs(form, lemma) cdef EditTreeC tree cdef uint32_t tree_id, prefix_tree, suffix_tree if lcs_is_empty(lcs): tree = edittree_new_subst(self.strings.add(form), self.strings.add(lemma)) else: # If we have a non-empty LCS, such as "gooi" in "ge[gooi]d" and "[gooi]en", # create edit trees for the prefix pair ("ge"/"") and the suffix pair ("d"/"en"). prefix_tree = NULL_TREE_ID if lcs.source_begin != 0 or lcs.target_begin != 0: prefix_tree = self.add(form[:lcs.source_begin], lemma[:lcs.target_begin]) suffix_tree = NULL_TREE_ID if lcs.source_end != len(form) or lcs.target_end != len(lemma): suffix_tree = self.add(form[lcs.source_end:], lemma[lcs.target_end:]) tree = edittree_new_match(lcs.source_begin, len(form) - lcs.source_end, prefix_tree, suffix_tree) return self._tree_id(tree) cdef uint32_t _tree_id(self, EditTreeC tree): # If this tree has been constructed before, return its identifier. cdef hash_t hash = edittree_hash(tree) cdef unordered_map[hash_t, uint32_t].iterator iter = self.map.find(hash) if iter != self.map.end(): return deref(iter).second # The tree hasn't been seen before, store it. cdef uint32_t tree_id = self.trees.size() self.trees.push_back(tree) self.map.insert(pair[hash_t, uint32_t](hash, tree_id)) return tree_id cpdef str apply(self, uint32_t tree_id, str form): """Apply an edit tree to a form. tree_id (uint32_t): the identifier of the edit tree to apply. form (str): the form to apply the edit tree to. RETURNS (str): the transformer form or None if the edit tree could not be applied to the form. """ if tree_id >= self.trees.size(): raise IndexError(Errors.E1030) lemma_pieces = [] try: self._apply(tree_id, form, lemma_pieces) except ValueError: return None return "".join(lemma_pieces) cdef _apply(self, uint32_t tree_id, str form_part, list lemma_pieces): """Recursively apply an edit tree to a form, adding pieces to the lemma_pieces list.""" assert tree_id <= self.trees.size() cdef EditTreeC tree = self.trees[tree_id] cdef MatchNodeC match_node cdef int suffix_start if tree.is_match_node: match_node = tree.inner.match_node if match_node.prefix_len + match_node.suffix_len > len(form_part): raise ValueError(Errors.E1029) suffix_start = len(form_part) - match_node.suffix_len if match_node.prefix_tree != NULL_TREE_ID: self._apply(match_node.prefix_tree, form_part[:match_node.prefix_len], lemma_pieces) lemma_pieces.append(form_part[match_node.prefix_len:suffix_start]) if match_node.suffix_tree != NULL_TREE_ID: self._apply(match_node.suffix_tree, form_part[suffix_start:], lemma_pieces) else: if form_part == self.strings[tree.inner.subst_node.orig]: lemma_pieces.append(self.strings[tree.inner.subst_node.subst]) else: raise ValueError(Errors.E1029) cpdef unicode tree_to_str(self, uint32_t tree_id): """Return the tree as a string. The tree tree string is formatted like an S-expression. This is primarily useful for debugging. Match nodes have the following format: (m prefix_len suffix_len prefix_tree suffix_tree) Substitution nodes have the following format: (s original substitute) tree_id (uint32_t): the identifier of the edit tree. RETURNS (str): the tree as an S-expression. """ if tree_id >= self.trees.size(): raise IndexError(Errors.E1030) cdef EditTreeC tree = self.trees[tree_id] cdef SubstNodeC subst_node if not tree.is_match_node: subst_node = tree.inner.subst_node return f"(s '{self.strings[subst_node.orig]}' '{self.strings[subst_node.subst]}')" cdef MatchNodeC match_node = tree.inner.match_node prefix_tree = "()" if match_node.prefix_tree != NULL_TREE_ID: prefix_tree = self.tree_to_str(match_node.prefix_tree) suffix_tree = "()" if match_node.suffix_tree != NULL_TREE_ID: suffix_tree = self.tree_to_str(match_node.suffix_tree) return f"(m {match_node.prefix_len} {match_node.suffix_len} {prefix_tree} {suffix_tree})" def from_json(self, trees: list) -> "EditTrees": self.trees.clear() for tree in trees: tree = _dict2tree(tree) self.trees.push_back(tree) self._rebuild_tree_map() def from_bytes(self, bytes_data: bytes, *) -> "EditTrees": def deserialize_trees(tree_dicts): cdef EditTreeC c_tree for tree_dict in tree_dicts: c_tree = _dict2tree(tree_dict) self.trees.push_back(c_tree) deserializers = {} deserializers["trees"] = lambda n: deserialize_trees(n) util.from_bytes(bytes_data, deserializers, []) self._rebuild_tree_map() return self def to_bytes(self, **kwargs) -> bytes: tree_dicts = [] for tree in self.trees: tree = _tree2dict(tree) tree_dicts.append(tree) serializers = {} serializers["trees"] = lambda: tree_dicts return util.to_bytes(serializers, []) def to_disk(self, path, **kwargs) -> "EditTrees": path = util.ensure_path(path) with path.open("wb") as file_: file_.write(self.to_bytes()) def from_disk(self, path, **kwargs) -> "EditTrees": path = util.ensure_path(path) if path.exists(): with path.open("rb") as file_: data = file_.read() return self.from_bytes(data) return self def __getitem__(self, idx): return _tree2dict(self.trees[idx]) def __len__(self): return self.trees.size() def _rebuild_tree_map(self): """Rebuild the tree hash -> tree id mapping""" cdef EditTreeC c_tree cdef uint32_t tree_id cdef hash_t tree_hash self.map.clear() for tree_id in range(self.trees.size()): c_tree = self.trees[tree_id] tree_hash = edittree_hash(c_tree) self.map.insert(pair[hash_t, uint32_t](tree_hash, tree_id)) def __reduce__(self): return (unpickle_edittrees, (self.strings, self.to_bytes())) def unpickle_edittrees(strings, trees_data): return EditTrees(strings).from_bytes(trees_data) def _tree2dict(tree): if tree["is_match_node"]: tree = tree["inner"]["match_node"] else: tree = tree["inner"]["subst_node"] return(dict(tree)) def _dict2tree(tree): errors = validate_edit_tree(tree) if errors: raise ValueError(Errors.E1026.format(errors="\n".join(errors))) tree = dict(tree) if "prefix_len" in tree: tree = {"is_match_node": True, "inner": {"match_node": tree}} else: tree = {"is_match_node": False, "inner": {"subst_node": tree}} return tree