mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Add progress on SpanPredictor component
This isn't working. There is a CUDA error in the torch code during initialization and it's not clear why.
This commit is contained in:
		
							parent
							
								
									a098849112
								
							
						
					
					
						commit
						2190cbc0e6
					
				| 
						 | 
					@ -14,7 +14,7 @@ from ..extract_spans import extract_spans
 | 
				
			||||||
import torch
 | 
					import torch
 | 
				
			||||||
from thinc.util import xp2torch, torch2xp
 | 
					from thinc.util import xp2torch, torch2xp
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .coref_util import add_dummy
 | 
					from .coref_util import add_dummy, get_sentence_ids
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@registry.architectures("spacy.Coref.v1")
 | 
					@registry.architectures("spacy.Coref.v1")
 | 
				
			||||||
def build_wl_coref_model(
 | 
					def build_wl_coref_model(
 | 
				
			||||||
| 
						 | 
					@ -74,6 +74,33 @@ def build_wl_coref_model(
 | 
				
			||||||
    # and just return words as spans.
 | 
					    # and just return words as spans.
 | 
				
			||||||
    return coref_model
 | 
					    return coref_model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@registry.architectures("spacy.SpanPredictor.v1")
 | 
				
			||||||
 | 
					def build_span_predictor(
 | 
				
			||||||
 | 
					    tok2vec: Model[List[Doc], List[Floats2d]],
 | 
				
			||||||
 | 
					    hidden_size: int = 1024,
 | 
				
			||||||
 | 
					    dist_emb_size: int = 64,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					    # TODO fix this
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        dim = tok2vec.get_dim("nO")
 | 
				
			||||||
 | 
					    except ValueError:
 | 
				
			||||||
 | 
					        # happens with transformer listener
 | 
				
			||||||
 | 
					        dim = 768
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    with Model.define_operators({">>": chain, "&": tuplify}):
 | 
				
			||||||
 | 
					        # TODO fix device - should be automatic
 | 
				
			||||||
 | 
					        device = "cuda:0"
 | 
				
			||||||
 | 
					        span_predictor = PyTorchWrapper(
 | 
				
			||||||
 | 
					            SpanPredictor(hidden_size, dist_emb_size, device),
 | 
				
			||||||
 | 
					            convert_inputs=convert_span_predictor_inputs
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					        # TODO use proper parameter for prefix
 | 
				
			||||||
 | 
					        head_info = build_get_head_metadata("coref_head_clusters")
 | 
				
			||||||
 | 
					        model = (tok2vec & head_info) >> span_predictor
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def convert_coref_scorer_inputs(
 | 
					def convert_coref_scorer_inputs(
 | 
				
			||||||
    model: Model,
 | 
					    model: Model,
 | 
				
			||||||
    X: List[Floats2d],
 | 
					    X: List[Floats2d],
 | 
				
			||||||
| 
						 | 
					@ -84,6 +111,7 @@ def convert_coref_scorer_inputs(
 | 
				
			||||||
    # TODO real batching
 | 
					    # TODO real batching
 | 
				
			||||||
    X = X[0]
 | 
					    X = X[0]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    word_features = xp2torch(X, requires_grad=is_train)
 | 
					    word_features = xp2torch(X, requires_grad=is_train)
 | 
				
			||||||
    def backprop(args: ArgsKwargs) -> List[Floats2d]:
 | 
					    def backprop(args: ArgsKwargs) -> List[Floats2d]:
 | 
				
			||||||
        # convert to xp and wrap in list
 | 
					        # convert to xp and wrap in list
 | 
				
			||||||
| 
						 | 
					@ -116,10 +144,15 @@ def convert_span_predictor_inputs(
 | 
				
			||||||
    X: Tuple[Ints1d, Floats2d, Ints1d],
 | 
					    X: Tuple[Ints1d, Floats2d, Ints1d],
 | 
				
			||||||
    is_train: bool
 | 
					    is_train: bool
 | 
				
			||||||
):
 | 
					):
 | 
				
			||||||
    sent_id = xp2torch(X[0], requires_grad=False)
 | 
					    tok2vec, (sent_ids, head_ids) = X
 | 
				
			||||||
    word_features = xp2torch(X[1], requires_grad=False)
 | 
					    # Normally we should use the input is_train, but for these two it's not relevant
 | 
				
			||||||
    head_ids = xp2torch(X[2], requires_grad=False)
 | 
					    sent_ids = xp2torch(sent_ids[0], requires_grad=False)
 | 
				
			||||||
    argskwargs = ArgsKwargs(args=(sent_id, word_features, head_ids), kwargs={})
 | 
					    head_ids = xp2torch(head_ids[0], requires_grad=False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    word_features = xp2torch(tok2vec[0], requires_grad=is_train)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
 | 
				
			||||||
 | 
					    # TODO actually support backprop
 | 
				
			||||||
    return argskwargs, lambda dX: []
 | 
					    return argskwargs, lambda dX: []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
# TODO This probably belongs in the component, not the model.
 | 
					# TODO This probably belongs in the component, not the model.
 | 
				
			||||||
| 
						 | 
					@ -189,6 +222,36 @@ def _clusterize(
 | 
				
			||||||
            clusters.append(sorted(cluster))
 | 
					            clusters.append(sorted(cluster))
 | 
				
			||||||
    return sorted(clusters)
 | 
					    return sorted(clusters)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def build_get_head_metadata(prefix):
 | 
				
			||||||
 | 
					    # TODO this name is awful, fix it
 | 
				
			||||||
 | 
					    model = Model("HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward)
 | 
				
			||||||
 | 
					    return model
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def head_data_forward(model, docs, is_train):
 | 
				
			||||||
 | 
					    """A layer to generate the extra data needed for the span predictor.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					    sent_ids = []
 | 
				
			||||||
 | 
					    head_ids = []
 | 
				
			||||||
 | 
					    prefix = model.attrs["prefix"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    for doc in docs:
 | 
				
			||||||
 | 
					        sids = model.ops.asarray2i(get_sentence_ids(doc))
 | 
				
			||||||
 | 
					        sent_ids.append(sids)
 | 
				
			||||||
 | 
					        heads = []
 | 
				
			||||||
 | 
					        for key, sg in doc.spans.items():
 | 
				
			||||||
 | 
					            if not key.startswith(prefix):
 | 
				
			||||||
 | 
					                continue
 | 
				
			||||||
 | 
					            for span in sg:
 | 
				
			||||||
 | 
					                # TODO warn if spans are more than one token
 | 
				
			||||||
 | 
					                heads.append(span[0].i)
 | 
				
			||||||
 | 
					        heads = model.ops.asarray2i(heads)
 | 
				
			||||||
 | 
					        head_ids.append(heads)
 | 
				
			||||||
 | 
					    
 | 
				
			||||||
 | 
					    # each of these is a list with one entry per doc
 | 
				
			||||||
 | 
					    # backprop is just a placeholder
 | 
				
			||||||
 | 
					    # TODO it would probably be better to have a list of tuples than two lists of arrays
 | 
				
			||||||
 | 
					    return (sent_ids, head_ids), lambda x: []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CorefScorer(torch.nn.Module):
 | 
					class CorefScorer(torch.nn.Module):
 | 
				
			||||||
    """Combines all coref modules together to find coreferent spans.
 | 
					    """Combines all coref modules together to find coreferent spans.
 | 
				
			||||||
| 
						 | 
					@ -492,6 +555,7 @@ class SpanPredictor(torch.nn.Module):
 | 
				
			||||||
        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
 | 
					        emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
 | 
				
			||||||
        # Obtain "same sentence" boolean mask, [n_heads, n_words]
 | 
					        # Obtain "same sentence" boolean mask, [n_heads, n_words]
 | 
				
			||||||
        sent_id = torch.tensor(sent_id, device=words.device)
 | 
					        sent_id = torch.tensor(sent_id, device=words.device)
 | 
				
			||||||
 | 
					        heads_ids = heads_ids.long()
 | 
				
			||||||
        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
 | 
					        same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # To save memory, only pass candidates from one sentence for each head
 | 
					        # To save memory, only pass candidates from one sentence for each head
 | 
				
			||||||
| 
						 | 
					@ -506,7 +570,7 @@ class SpanPredictor(torch.nn.Module):
 | 
				
			||||||
        ), dim=1)
 | 
					        ), dim=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        lengths = same_sent.sum(dim=1)
 | 
					        lengths = same_sent.sum(dim=1)
 | 
				
			||||||
        padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
 | 
					        padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
 | 
				
			||||||
        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]
 | 
					        padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
 | 
					        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -39,6 +39,15 @@ def add_dummy(tensor: torch.Tensor, eps: bool = False):
 | 
				
			||||||
    output = torch.cat((dummy, tensor), dim=1)
 | 
					    output = torch.cat((dummy, tensor), dim=1)
 | 
				
			||||||
    return output
 | 
					    return output
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					def get_sentence_ids(doc):
 | 
				
			||||||
 | 
					    out = []
 | 
				
			||||||
 | 
					    sent_id = -1
 | 
				
			||||||
 | 
					    for tok in doc:
 | 
				
			||||||
 | 
					        if tok.is_sent_start:
 | 
				
			||||||
 | 
					            sent_id += 1
 | 
				
			||||||
 | 
					        out.append(sent_id)
 | 
				
			||||||
 | 
					    return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
 | 
					def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
 | 
				
			||||||
    """Given a doc, give the mention clusters.
 | 
					    """Given a doc, give the mention clusters.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -1,7 +1,7 @@
 | 
				
			||||||
from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
 | 
					from typing import Iterable, Tuple, Optional, Dict, Callable, Any, List
 | 
				
			||||||
import warnings
 | 
					import warnings
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from thinc.types import Floats2d, Ints2d
 | 
					from thinc.types import Floats2d, Floats3d, Ints2d
 | 
				
			||||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
 | 
					from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
 | 
				
			||||||
from thinc.api import set_dropout_rate
 | 
					from thinc.api import set_dropout_rate
 | 
				
			||||||
from itertools import islice
 | 
					from itertools import islice
 | 
				
			||||||
| 
						 | 
					@ -84,6 +84,7 @@ def make_coref(
 | 
				
			||||||
    )
 | 
					    )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class CoreferenceResolver(TrainablePipe):
 | 
					class CoreferenceResolver(TrainablePipe):
 | 
				
			||||||
    """Pipeline component for coreference resolution.
 | 
					    """Pipeline component for coreference resolution.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -208,7 +209,7 @@ class CoreferenceResolver(TrainablePipe):
 | 
				
			||||||
        total_loss = 0
 | 
					        total_loss = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        for eg in examples:
 | 
					        for eg in examples:
 | 
				
			||||||
            # TODO does this even work?
 | 
					            # TODO check this causes no issues (in practice it runs)
 | 
				
			||||||
            preds, backprop = self.model.begin_update([eg.predicted])
 | 
					            preds, backprop = self.model.begin_update([eg.predicted])
 | 
				
			||||||
            score_matrix, mention_idx = preds
 | 
					            score_matrix, mention_idx = preds
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -384,6 +385,52 @@ class CoreferenceResolver(TrainablePipe):
 | 
				
			||||||
        return out
 | 
					        return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					default_span_predictor_config = """
 | 
				
			||||||
 | 
					[model]
 | 
				
			||||||
 | 
					@architectures = "spacy.SpanPredictor.v1"
 | 
				
			||||||
 | 
					hidden_size = 1024
 | 
				
			||||||
 | 
					dist_emb_size = 64
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[model.tok2vec]
 | 
				
			||||||
 | 
					@architectures = "spacy.Tok2Vec.v2"
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[model.tok2vec.embed]
 | 
				
			||||||
 | 
					@architectures = "spacy.MultiHashEmbed.v1"
 | 
				
			||||||
 | 
					width = 64
 | 
				
			||||||
 | 
					rows = [2000, 2000, 1000, 1000, 1000, 1000]
 | 
				
			||||||
 | 
					attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
 | 
				
			||||||
 | 
					include_static_vectors = false
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					[model.tok2vec.encode]
 | 
				
			||||||
 | 
					@architectures = "spacy.MaxoutWindowEncoder.v2"
 | 
				
			||||||
 | 
					width = ${model.tok2vec.embed.width}
 | 
				
			||||||
 | 
					window_size = 1
 | 
				
			||||||
 | 
					maxout_pieces = 3
 | 
				
			||||||
 | 
					depth = 2
 | 
				
			||||||
 | 
					"""
 | 
				
			||||||
 | 
					DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)["model"]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					@Language.factory(
 | 
				
			||||||
 | 
					        "span_predictor",
 | 
				
			||||||
 | 
					        assigns=["doc.spans"],
 | 
				
			||||||
 | 
					        requires=["doc.spans"],
 | 
				
			||||||
 | 
					        default_config={
 | 
				
			||||||
 | 
					            "model": DEFAULT_SPAN_PREDICTOR_MODEL,
 | 
				
			||||||
 | 
					            "input_prefix": "coref_head_clusters",
 | 
				
			||||||
 | 
					            "output_prefix": "coref_clusters",
 | 
				
			||||||
 | 
					            },
 | 
				
			||||||
 | 
					    default_score_weights={"span_predictor_f": 1.0, "span_predictor_p": None, "span_predictor_r": None},
 | 
				
			||||||
 | 
					    )
 | 
				
			||||||
 | 
					def make_span_predictor(
 | 
				
			||||||
 | 
					        nlp: Language,
 | 
				
			||||||
 | 
					        name: str,
 | 
				
			||||||
 | 
					        model,
 | 
				
			||||||
 | 
					        input_prefix: str = "coref_head_clusters",
 | 
				
			||||||
 | 
					        output_prefix: str = "coref_clusters",
 | 
				
			||||||
 | 
					) -> "SpanPredictor":
 | 
				
			||||||
 | 
					    """Create a SpanPredictor component."""
 | 
				
			||||||
 | 
					    return SpanPredictor(nlp.vocab, model, name, input_prefix=input_prefix, output_prefix=output_prefix)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class SpanPredictor(TrainablePipe):
 | 
					class SpanPredictor(TrainablePipe):
 | 
				
			||||||
    """Pipeline component to resolve one-token spans to full spans.
 | 
					    """Pipeline component to resolve one-token spans to full spans.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -407,11 +454,41 @@ class SpanPredictor(TrainablePipe):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.cfg = {}
 | 
					        self.cfg = {}
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def predict(self, docs: Iterable[Doc]):
 | 
					    def predict(self, docs: Iterable[Doc]) -> List[MentionClusters]:
 | 
				
			||||||
        ...
 | 
					        # for now pretend there's just one doc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        out = []
 | 
				
			||||||
 | 
					        for doc in docs:
 | 
				
			||||||
 | 
					            # TODO check shape here
 | 
				
			||||||
 | 
					            span_scores = self.model.predict(doc)
 | 
				
			||||||
 | 
					            span_scores = span_scores[0]
 | 
				
			||||||
 | 
					            # the information about clustering has to come from the input docs
 | 
				
			||||||
 | 
					            # first let's convert the scores to a list of span idxs
 | 
				
			||||||
 | 
					            start_scores = span_scores[:, :, 0]
 | 
				
			||||||
 | 
					            end_scores = span_scores[:, :, 1]
 | 
				
			||||||
 | 
					            starts = start_scores.argmax(axis=1)
 | 
				
			||||||
 | 
					            ends = end_scores.argmax(axis=1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # TODO check start < end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # get the old clusters (shape will be preserved)
 | 
				
			||||||
 | 
					            clusters = doc2clusters(doc, self.input_prefix)
 | 
				
			||||||
 | 
					            cidx = 0
 | 
				
			||||||
 | 
					            out_clusters = []
 | 
				
			||||||
 | 
					            for cluster in clusters:
 | 
				
			||||||
 | 
					                ncluster = []
 | 
				
			||||||
 | 
					                for mention in cluster:
 | 
				
			||||||
 | 
					                    ncluster.append( (starts[cidx], ends[cidx]) )
 | 
				
			||||||
 | 
					                    cidx += 1
 | 
				
			||||||
 | 
					                out_clusters.append(ncluster)
 | 
				
			||||||
 | 
					            out.append(out_clusters)
 | 
				
			||||||
 | 
					        return out
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
 | 
					    def set_annotations(self, docs: Iterable[Doc], clusters_by_doc) -> None:
 | 
				
			||||||
        ...
 | 
					        for doc, clusters in zip(docs, clusters_by_doc):
 | 
				
			||||||
 | 
					            for ii, cluster in enumerate(clusters):
 | 
				
			||||||
 | 
					                spans = [doc[mm[0]:mm[1]] for mm in cluster]
 | 
				
			||||||
 | 
					                doc.spans[f"{self.output_prefix}_{ii}"] = spans
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def update(
 | 
					    def update(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
| 
						 | 
					@ -421,7 +498,33 @@ class SpanPredictor(TrainablePipe):
 | 
				
			||||||
        sgd: Optional[Optimizer] = None,
 | 
					        sgd: Optional[Optimizer] = None,
 | 
				
			||||||
        losses: Optional[Dict[str, float]] = None,
 | 
					        losses: Optional[Dict[str, float]] = None,
 | 
				
			||||||
    ) -> Dict[str, float]:
 | 
					    ) -> Dict[str, float]:
 | 
				
			||||||
        ...
 | 
					        """Learn from a batch of documents and gold-standard information,
 | 
				
			||||||
 | 
					        updating the pipe's model. Delegates to predict and get_loss.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        if losses is None:
 | 
				
			||||||
 | 
					            losses = {}
 | 
				
			||||||
 | 
					        losses.setdefault(self.name, 0.0)
 | 
				
			||||||
 | 
					        validate_examples(examples, "SpanPredictor.update")
 | 
				
			||||||
 | 
					        if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
 | 
				
			||||||
 | 
					            # Handle cases where there are no tokens in any docs.
 | 
				
			||||||
 | 
					            return losses
 | 
				
			||||||
 | 
					        set_dropout_rate(self.model, drop)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        total_loss = 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for eg in examples:
 | 
				
			||||||
 | 
					            preds, backprop = self.model.begin_update([eg.predicted])
 | 
				
			||||||
 | 
					            score_matrix, mention_idx = preds
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
 | 
				
			||||||
 | 
					            total_loss += loss
 | 
				
			||||||
 | 
					            # TODO check shape here
 | 
				
			||||||
 | 
					            backprop((d_scores, mention_idx))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if sgd is not None:
 | 
				
			||||||
 | 
					            self.finish_update(sgd)
 | 
				
			||||||
 | 
					        losses[self.name] += total_loss
 | 
				
			||||||
 | 
					        return losses
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def rehearse(
 | 
					    def rehearse(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
| 
						 | 
					@ -431,7 +534,12 @@ class SpanPredictor(TrainablePipe):
 | 
				
			||||||
        sgd: Optional[Optimizer] = None,
 | 
					        sgd: Optional[Optimizer] = None,
 | 
				
			||||||
        losses: Optional[Dict[str, float]] = None,
 | 
					        losses: Optional[Dict[str, float]] = None,
 | 
				
			||||||
    ) -> Dict[str, float]:
 | 
					    ) -> Dict[str, float]:
 | 
				
			||||||
        ...
 | 
					        # TODO this should be added later
 | 
				
			||||||
 | 
					        raise NotImplementedError(
 | 
				
			||||||
 | 
					            Errors.E931.format(
 | 
				
			||||||
 | 
					                parent="SpanPredictor", method="add_label", name=self.name
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def add_label(self, label: str) -> int:
 | 
					    def add_label(self, label: str) -> int:
 | 
				
			||||||
        """Technically this method should be implemented from TrainablePipe,
 | 
					        """Technically this method should be implemented from TrainablePipe,
 | 
				
			||||||
| 
						 | 
					@ -446,9 +554,39 @@ class SpanPredictor(TrainablePipe):
 | 
				
			||||||
    def get_loss(
 | 
					    def get_loss(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
        examples: Iterable[Example],
 | 
					        examples: Iterable[Example],
 | 
				
			||||||
        # TODO add necessary args
 | 
					        span_scores: Floats3d,
 | 
				
			||||||
    ):
 | 
					    ):
 | 
				
			||||||
        ...
 | 
					        ops = self.model.ops
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # NOTE This is doing fake batching, and should always get a list of one example
 | 
				
			||||||
 | 
					        assert len(examples) == 1, "Only fake batching is supported."
 | 
				
			||||||
 | 
					        # starts and ends are gold starts and ends (Ints1d)
 | 
				
			||||||
 | 
					        # span_scores is a Floats3d. What are the axes? mention x token x start/end
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for eg in examples:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # get gold data
 | 
				
			||||||
 | 
					            gold = doc2clusters(eg.reference, self.output_prefix)
 | 
				
			||||||
 | 
					            # flatten the gold data
 | 
				
			||||||
 | 
					            starts = []
 | 
				
			||||||
 | 
					            ends = []
 | 
				
			||||||
 | 
					            for cluster in gold:
 | 
				
			||||||
 | 
					                for mention in cluster:
 | 
				
			||||||
 | 
					                    starts.append(mention[0])
 | 
				
			||||||
 | 
					                    ends.append(mention[1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            start_scores = span_scores[:, :, 0]
 | 
				
			||||||
 | 
					            end_scores = span_scores[:, :, 1]
 | 
				
			||||||
 | 
					            n_classes = start_scores.shape[1]
 | 
				
			||||||
 | 
					            start_probs = ops.softmax(start_scores, axis=1)
 | 
				
			||||||
 | 
					            end_probs = ops.softmax(end_scores, axis=1)
 | 
				
			||||||
 | 
					            start_targets = to_categorical(starts, n_classes)
 | 
				
			||||||
 | 
					            end_targets = to_categorical(ends, n_classes)
 | 
				
			||||||
 | 
					            start_grads = (start_probs - start_targets)
 | 
				
			||||||
 | 
					            end_grads = (end_probs - end_targets)
 | 
				
			||||||
 | 
					            grads = ops.xp.stack((start_grads, end_grads), axis=2)
 | 
				
			||||||
 | 
					            loss = float((grads ** 2).sum())
 | 
				
			||||||
 | 
					        return loss, grads
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def initialize(
 | 
					    def initialize(
 | 
				
			||||||
        self,
 | 
					        self,
 | 
				
			||||||
| 
						 | 
					@ -461,6 +599,12 @@ class SpanPredictor(TrainablePipe):
 | 
				
			||||||
        X = []
 | 
					        X = []
 | 
				
			||||||
        Y = []
 | 
					        Y = []
 | 
				
			||||||
        for ex in islice(get_examples(), 2):
 | 
					        for ex in islice(get_examples(), 2):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if not ex.predicted.spans:
 | 
				
			||||||
 | 
					                # set placeholder for shape inference
 | 
				
			||||||
 | 
					                doc = ex.predicted
 | 
				
			||||||
 | 
					                assert len(doc) > 2, "Coreference requires at least two tokens"
 | 
				
			||||||
 | 
					                doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
 | 
				
			||||||
            X.append(ex.predicted)
 | 
					            X.append(ex.predicted)
 | 
				
			||||||
            Y.append(ex.reference)
 | 
					            Y.append(ex.reference)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -468,5 +612,31 @@ class SpanPredictor(TrainablePipe):
 | 
				
			||||||
        self.model.initialize(X=X, Y=Y)
 | 
					        self.model.initialize(X=X, Y=Y)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def score(self, examples, **kwargs):
 | 
					    def score(self, examples, **kwargs):
 | 
				
			||||||
        # TODO this will overlap significantly with coref, maybe factor into function
 | 
					        """Score a batch of examples."""
 | 
				
			||||||
        ...
 | 
					        # TODO This is basically the same as the main coref component - factor out?
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        scores = []
 | 
				
			||||||
 | 
					        for metric in (b_cubed, muc, ceafe):
 | 
				
			||||||
 | 
					            evaluator = Evaluator(metric)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            for ex in examples:
 | 
				
			||||||
 | 
					                # XXX this is the only different part
 | 
				
			||||||
 | 
					                p_clusters = doc2clusters(ex.predicted, self.output_prefix)
 | 
				
			||||||
 | 
					                g_clusters = doc2clusters(ex.reference, self.output_prefix)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                cluster_info = get_cluster_info(p_clusters, g_clusters)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					                evaluator.update(cluster_info)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            score = {
 | 
				
			||||||
 | 
					                "coref_f": evaluator.get_f1(),
 | 
				
			||||||
 | 
					                "coref_p": evaluator.get_precision(),
 | 
				
			||||||
 | 
					                "coref_r": evaluator.get_recall(),
 | 
				
			||||||
 | 
					            }
 | 
				
			||||||
 | 
					            scores.append(score)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        out = {}
 | 
				
			||||||
 | 
					        for field in ("f", "p", "r"):
 | 
				
			||||||
 | 
					            fname = f"coref_{field}"
 | 
				
			||||||
 | 
					            out[fname] = mean([ss[fname] for ss in scores])
 | 
				
			||||||
 | 
					        return out
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user