mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 05:31:15 +03:00 
			
		
		
		
	Preserve both `-` and `O` annotation in augmenters rather than relying on `Example.to_dict`'s default support for one option outside of labeled entity spans. This is intended as a temporary workaround for augmenters for v3.4.x. The behavior of `Example` and related IOB utils could be improved in the general case for v3.5.
		
			
				
	
	
		
			350 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			350 lines
		
	
	
		
			13 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import Callable, Iterator, Dict, List, Tuple, TYPE_CHECKING
 | |
| from typing import Optional
 | |
| import random
 | |
| import itertools
 | |
| from functools import partial
 | |
| 
 | |
| from ..util import registry
 | |
| from .example import Example
 | |
| from .iob_utils import split_bilu_label, _doc_to_biluo_tags_with_partial
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from ..language import Language  # noqa: F401
 | |
| 
 | |
| 
 | |
@registry.augmenters("spacy.combined_augmenter.v1")
def create_combined_augmenter(
    lower_level: float,
    orth_level: float,
    orth_variants: Optional[Dict[str, List[Dict]]],
    whitespace_level: float,
    whitespace_per_token: float,
    whitespace_variants: Optional[List[str]],
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that combines lowercasing,
    orth-variant replacement and whitespace token insertion. The callback can
    be added to a corpus or other data iterator during training.

    lower_level (float): The percentage of texts that will be lowercased.
    orth_level (float): The percentage of texts that will have orth variants
        replaced.
    orth_variants (Optional[Dict[str, List[Dict]]]): A dictionary containing the
        single and paired orth variants. Typically loaded from a JSON file.
    whitespace_level (float): The percentage of texts that will have whitespace
        tokens inserted.
    whitespace_per_token (float): The number of whitespace tokens to insert in
        the modified doc as a percentage of the doc length.
    whitespace_variants (Optional[List[str]]): The whitespace token texts.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    # Bind every setting by keyword so that the returned callable only needs
    # the (nlp, example) pair.
    settings = {
        "lower_level": lower_level,
        "orth_level": orth_level,
        "orth_variants": orth_variants,
        "whitespace_level": whitespace_level,
        "whitespace_per_token": whitespace_per_token,
        "whitespace_variants": whitespace_variants,
    }
    return partial(combined_augmenter, **settings)
 | |
| 
 | |
| 
 | |
def combined_augmenter(
    nlp: "Language",
    example: Example,
    *,
    lower_level: float = 0.0,
    orth_level: float = 0.0,
    orth_variants: Optional[Dict[str, List[Dict]]] = None,
    whitespace_level: float = 0.0,
    whitespace_per_token: float = 0.0,
    whitespace_variants: Optional[List[str]] = None,
) -> Iterator[Example]:
    """Apply lowercasing, orth-variant replacement and whitespace insertion
    to an example, each one independently at random, and yield the result.

    nlp (Language): The pipeline; its make_doc is used to re-tokenize
        modified texts.
    example (Example): The example to augment.
    lower_level (float): Probability of lowercasing the example.
    orth_level (float): Probability of replacing orth variants. Skipped
        entirely if no orth_variants are provided.
    orth_variants (Optional[Dict[str, List[Dict]]]): The single and paired
        orth variants.
    whitespace_level (float): Probability of inserting whitespace tokens.
        Skipped entirely if no whitespace_variants are provided.
    whitespace_per_token (float): Number of whitespace tokens to insert, as a
        fraction of the number of tokens in the reference doc.
    whitespace_variants (Optional[List[str]]): The whitespace token texts.
    YIELDS (Example): A single, possibly modified, example.
    """
    # Step 1: lowercasing.
    if random.random() < lower_level:
        example = make_lowercase_variant(nlp, example)

    # Step 2: orth variants (no random draw unless variants are available).
    if orth_variants and random.random() < orth_level:
        source_text = example.text
        annotations = example.to_dict()
        # Keep both "-" (missing) and "O" (outside) entity annotation rather
        # than the single default that to_dict() provides.
        annotations["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
            example.reference
        )
        new_text, new_token_annot = make_orth_variants(
            nlp,
            source_text,
            annotations["token_annotation"],
            orth_variants,
            lower=False,
        )
        annotations["token_annotation"] = new_token_annot
        example = example.from_dict(nlp.make_doc(new_text), annotations)

    # Step 3: whitespace tokens (no random draw unless variants are available).
    if whitespace_variants and random.random() < whitespace_level:
        # The number of insertions is fixed up front from the current length;
        # the insertion offset is drawn from the growing doc each round.
        num_insertions = int(len(example.reference) * whitespace_per_token)
        for _ in range(num_insertions):
            space_text = random.choice(whitespace_variants)
            insert_at = random.randrange(0, len(example.reference))
            example = make_whitespace_variant(nlp, example, space_text, insert_at)

    yield example
 | |
| 
 | |
| 
 | |
@registry.augmenters("spacy.orth_variants.v1")
def create_orth_variants_augmenter(
    level: float, lower: float, orth_variants: Dict[str, List[Dict]]
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that replaces tokens with their
    orth variants. The callback can be added to a corpus or other data
    iterator during training.

    level (float): The percentage of texts that will be augmented.
    lower (float): The percentage of texts that will be lowercased.
    orth_variants (Dict[str, List[Dict]]): A dictionary containing
        the single and paired orth variants. Typically loaded from a JSON file.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    # Bind the settings by keyword; the callable then takes (nlp, example).
    settings = {
        "orth_variants": orth_variants,
        "level": level,
        "lower": lower,
    }
    return partial(orth_variants_augmenter, **settings)
 | |
| 
 | |
| 
 | |
@registry.augmenters("spacy.lower_case.v1")
def create_lower_casing_augmenter(
    level: float,
) -> Callable[["Language", Example], Iterator[Example]]:
    """Create a data augmentation callback that lowercases documents. The
    callback can be added to a corpus or other data iterator during training.

    level (float): The percentage of texts that will be augmented.
    RETURNS (Callable[[Language, Example], Iterator[Example]]): The augmenter.
    """
    augmenter = partial(lower_casing_augmenter, level=level)
    return augmenter
 | |
| 
 | |
| 
 | |
def dont_augment(nlp: "Language", example: Example) -> Iterator[Example]:
    """No-op augmenter: yield the example exactly as it was passed in.

    nlp (Language): Unused; accepted to match the augmenter signature.
    example (Example): The example to pass through.
    YIELDS (Example): The unmodified example.
    """
    yield from (example,)
 | |
| 
 | |
| 
 | |
def lower_casing_augmenter(
    nlp: "Language", example: Example, *, level: float
) -> Iterator[Example]:
    """Yield either the original example or a lowercased variant of it.

    nlp (Language): The pipeline, used to tokenize the lowercased text.
    example (Example): The example to augment.
    level (float): Probability that the lowercased variant is yielded.
    YIELDS (Example): The original or the lowercased example.
    """
    keep_original = random.random() >= level
    if keep_original:
        yield example
        return
    yield make_lowercase_variant(nlp, example)
 | |
| 
 | |
| 
 | |
def make_lowercase_variant(nlp: "Language", example: Example):
    """Build a copy of the example whose text and token texts are lowercased.

    nlp (Language): The pipeline, used to tokenize the lowercased text.
    example (Example): The example to lowercase.
    RETURNS (Example): A new example created from the lowercased text and
        the original annotation.
    """
    annotations = example.to_dict()
    reference = example.reference
    # Keep both "-" (missing) and "O" (outside) entity annotation rather than
    # the single default that to_dict() provides.
    annotations["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
        reference
    )
    lowered_doc = nlp.make_doc(example.text.lower())
    lowered_words = []
    for token in reference:
        lowered_words.append(token.lower_)
    annotations["token_annotation"]["ORTH"] = lowered_words
    return example.from_dict(lowered_doc, annotations)
 | |
| 
 | |
| 
 | |
def orth_variants_augmenter(
    nlp: "Language",
    example: Example,
    orth_variants: Dict[str, List[Dict]],
    *,
    level: float = 0.0,
    lower: float = 0.0,
) -> Iterator[Example]:
    """Yield either the original example or a variant with orth variants
    replaced (and optionally lowercased).

    nlp (Language): The pipeline, used to tokenize the modified text.
    example (Example): The example to augment.
    orth_variants (Dict[str, List[Dict]]): The single and paired orth
        variants.
    level (float): Probability that the example is augmented.
    lower (float): Probability that an augmented example is also lowercased.
    YIELDS (Example): The original or the augmented example.
    """
    if random.random() >= level:
        yield example
        return

    source_text = example.text
    annotations = example.to_dict()
    # Keep both "-" (missing) and "O" (outside) entity annotation rather than
    # the single default that to_dict() provides.
    annotations["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
        example.reference
    )
    # Only draw the lowercasing decision if there is a text to lowercase.
    use_lowercase = source_text is not None and random.random() < lower
    new_text, new_token_annot = make_orth_variants(
        nlp,
        source_text,
        annotations["token_annotation"],
        orth_variants,
        lower=use_lowercase,
    )
    annotations["token_annotation"] = new_token_annot
    yield example.from_dict(nlp.make_doc(new_text), annotations)
 | |
| 
 | |
| 
 | |
def make_orth_variants(
    nlp: "Language",
    raw: str,
    token_dict: Dict[str, List[str]],
    orth_variants: Dict[str, List[Dict[str, List[str]]]],
    *,
    lower: bool = False,
) -> Tuple[str, Dict[str, List[str]]]:
    """Replace tokens with randomly chosen orth variants and rebuild the text.

    nlp (Language): Not used by this function.
    raw (str): The original raw text.
    token_dict (Dict[str, List[str]]): The token annotation. Its "ORTH" entry
        is updated in place and the same dict is returned.
    orth_variants (Dict[str, List[Dict[str, List[str]]]]): Rules under the
        optional keys "single" and "paired"; every rule has "tags" and
        "variants".
    lower (bool): Whether to lowercase the words and the raw text first.
    RETURNS (Tuple[str, Dict[str, List[str]]]): The (possibly modified) raw
        text and the token annotation.
    """
    words = token_dict.get("ORTH", [])
    tags = token_dict.get("TAG", [])
    # Without words there is nothing to modify.
    if not words:
        return raw, token_dict
    if lower:
        words = [word.lower() for word in words]
        raw = raw.lower()
    # Without tags no variant rule can match, so lowercasing is all we do.
    if not tags:
        token_dict["ORTH"] = words
        return raw, token_dict

    token_count = len(words)

    # Single variants: one replacement is drawn per rule and reused for every
    # matching token.
    single_rules = orth_variants.get("single", [])
    single_choices = [random.choice(rule["variants"]) for rule in single_rules]
    for idx in range(token_count):
        for rule, replacement in zip(single_rules, single_choices):
            if tags[idx] in rule["tags"] and words[idx] in rule["variants"]:
                words[idx] = replacement

    # Paired variants: one (left, right) pair is drawn per rule.
    paired_rules = orth_variants.get("paired", [])
    paired_choices = [random.choice(rule["variants"]) for rule in paired_rules]
    for idx in range(token_count):
        for rule, chosen_pair in zip(paired_rules, paired_choices):
            if tags[idx] not in rule["tags"]:
                continue
            if words[idx] not in itertools.chain.from_iterable(rule["variants"]):
                continue
            # Backup option: random left vs. right from the pair.
            side = random.choice([0, 1])
            if len(rule["tags"]) == 2:
                # Best option: rely on paired POS tags like `` / ''.
                side = rule["tags"].index(tags[idx])
            else:
                # Next best option: rely on the position in the variants
                # (may not be unambiguous, so the order of variants matters;
                # the last pair containing the word wins).
                for pair in rule["variants"]:
                    if words[idx] in pair:
                        side = pair.index(words[idx])
            words[idx] = chosen_pair[side]

    token_dict["ORTH"] = words
    raw = construct_modified_raw_text(token_dict)
    return raw, token_dict
 | |
| 
 | |
| 
 | |
def make_whitespace_variant(
    nlp: "Language",
    example: Example,
    whitespace: str,
    position: int,
) -> Example:
    """Insert the whitespace token at the specified token offset in the doc.
    This is primarily intended for v2-compatible training data that doesn't
    include links or spans. If the document includes links, spans, or partial
    dependency annotation, it is returned without modifications.

    The augmentation follows the basics of the v2 space attachment policy, but
    without a distinction between "real" and other tokens, so space tokens
    may be attached to space tokens:
    - at the beginning of a sentence attach the space token to the following
      token
    - otherwise attach the space token to the preceding token

    The augmenter does not attempt to consolidate adjacent whitespace in the
    same way that the tokenizer would.

    The following annotation is used for the space token:
    TAG: "_SP"
    MORPH: ""
    POS: "SPACE"
    LEMMA: ORTH
    DEP: "dep"
    SENT_START: False

    The annotation for each attribute is only set for the space token if there
    is already at least partial annotation for that attribute in the original
    example.

    nlp (Language): The pipeline, used to tokenize the rebuilt raw text.
    example (Example): The example to modify.
    whitespace (str): The text of the whitespace token to insert.
    position (int): The token offset to insert at; must satisfy
        0 <= position <= number of tokens.
    RETURNS (Example): Example with one additional space token.
    """
    example_dict = example.to_dict()
    # Keep both "-" (missing) and "O" (outside) entity annotation rather than
    # the single default that to_dict() provides.
    example_dict["doc_annotation"]["entities"] = _doc_to_biluo_tags_with_partial(
        example.reference
    )
    doc_dict = example_dict.get("doc_annotation", {})
    token_dict = example_dict.get("token_annotation", {})
    # returned unmodified if:
    # - doc is empty
    # - words are not defined
    # - links are defined (only character-based offsets, which is more a quirk
    #   of Example.to_dict than a technical constraint)
    # - spans are defined
    # - there are partial dependencies
    if (
        len(example.reference) == 0
        or "ORTH" not in token_dict
        or len(doc_dict.get("links", [])) > 0
        or len(example.reference.spans) > 0
        or (
            example.reference.has_annotation("DEP")
            and not example.reference.has_annotation("DEP", require_complete=True)
        )
    ):
        return example
    words = token_dict.get("ORTH", [])
    length = len(words)
    assert 0 <= position <= length
    if example.reference.has_annotation("ENT_TYPE"):
        # I-ENTITY if between B/I-ENTITY and I/L-ENTITY otherwise O
        entity = "O"
        # A preceding and a following token exist for every 0 < position <
        # length. The gap between token 0 and token 1 (position == 1) must be
        # included, otherwise a space inserted there splits the entity.
        if 0 < position < length:
            ent_prev = doc_dict["entities"][position - 1]
            ent_next = doc_dict["entities"][position]
            if "-" in ent_prev and "-" in ent_next:
                ent_iob_prev, ent_type_prev = split_bilu_label(ent_prev)
                ent_iob_next, ent_type_next = split_bilu_label(ent_next)
                if (
                    ent_iob_prev in ("B", "I")
                    and ent_iob_next in ("I", "L")
                    and ent_type_prev == ent_type_next
                ):
                    entity = f"I-{ent_type_prev}"
        doc_dict["entities"].insert(position, entity)
    else:
        del doc_dict["entities"]
    token_dict["ORTH"].insert(position, whitespace)
    # The inserted token is never followed by a space of its own.
    token_dict["SPACY"].insert(position, False)
    # Per-token attributes with a fixed value for the space token: set the
    # value if the attribute is annotated at all, otherwise drop the column.
    space_annotation = (
        ("TAG", "_SP"),
        ("LEMMA", whitespace),
        ("POS", "SPACE"),
        ("MORPH", ""),
    )
    for attr, value in space_annotation:
        if example.reference.has_annotation(attr):
            token_dict[attr].insert(position, value)
        else:
            del token_dict[attr]
    if example.reference.has_annotation("DEP", require_complete=True):
        heads = token_dict["HEAD"]
        # Attach to the preceding token, or provisionally to index 0 when
        # inserting at the very start; the shift below then moves that head
        # to index 1, i.e. the following token.
        heads.insert(position, 0 if position == 0 else position - 1)
        # Every head at or after the insertion point moved one to the right.
        for i in range(len(heads)):
            if heads[i] >= position:
                heads[i] += 1
        token_dict["DEP"].insert(position, "dep")
    else:
        del token_dict["HEAD"]
        del token_dict["DEP"]
    if example.reference.has_annotation("SENT_START"):
        token_dict["SENT_START"].insert(position, False)
    else:
        del token_dict["SENT_START"]
    raw = construct_modified_raw_text(token_dict)
    return Example.from_dict(nlp.make_doc(raw), example_dict)
 | |
| 
 | |
| 
 | |
def construct_modified_raw_text(token_dict: Dict[str, List]) -> str:
    """Construct modified raw text from words and spaces.

    token_dict (Dict[str, List]): Token annotation with the parallel lists
        "ORTH" (token texts) and "SPACY" (whether each token is followed by
        a single space).
    RETURNS (str): The token texts concatenated in order, each followed by
        a space where its "SPACY" flag is truthy.
    """
    # Join once instead of growing the string with repeated "+=".
    return "".join(
        orth + (" " if spacy else "")
        for orth, spacy in zip(token_dict["ORTH"], token_dict["SPACY"])
    )
 |