Commit 150e7c46d7 in explosion/spaCy (mirror of https://github.com/explosion/spaCy.git)
@@ -40,11 +40,8 @@ def build_wl_coref_model(
 
     with Model.define_operators({">>": chain}):
         # TODO chain tok2vec with these models
-        # TODO fix device - should be automatic
-        device = "cuda:0"
         coref_scorer = PyTorchWrapper(
             CorefScorer(
-                device,
                 dim,
                 embedding_size,
                 hidden_size,
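The two removed lines pinned the model to "cuda:0"; dropping the device argument lets placement be decided outside the constructor, as the "# TODO fix device - should be automatic" note asks. A minimal sketch of the device-agnostic pattern, assuming Thinc's global GPU preference should drive placement (prefer_gpu, Model, and PyTorchWrapper are real thinc.api names; build_scorer is a hypothetical helper standing in for the factory above):

import torch
from thinc.api import Model, PyTorchWrapper, prefer_gpu

def build_scorer(module: torch.nn.Module) -> Model:
    # Instead of hardcoding "cuda:0", let the global Thinc GPU
    # preference decide where the wrapped parameters live.
    if prefer_gpu():            # True when a GPU was successfully activated
        module = module.cuda()  # move all parameters and buffers once
    return PyTorchWrapper(module)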
@@ -64,7 +61,6 @@ def build_wl_coref_model(
             # TODO this was hardcoded to 1024, check
             hidden_size,
             sp_embedding_size,
-            device
         ),
 
         convert_inputs=convert_span_predictor_inputs
@@ -266,7 +262,6 @@ class CorefScorer(torch.nn.Module):
     """
     def __init__(
         self,
-        device: str,
         dim: int,  # tok2vec size
         dist_emb_size: int,
         hidden_size: int,
@@ -283,8 +278,7 @@ class CorefScorer(torch.nn.Module):
         epochs_trained (int): the number of epochs finished
             (useful for warm start)
         """
-        # device, dist_emb_size, hidden_size, n_layers, dropout_rate
-        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
+        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
         #TODO clean this up
         bert_emb = dim
         pair_emb = bert_emb * 3 + self.pw.shape
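This and the next two hunks drop per-submodule .to(device) calls. That is safe at construction time because torch.nn.Module.to recurses into registered submodules, so one move of the parent relocates everything. A standalone sketch of that behaviour with a toy module (names illustrative, not the file's real classes):

import torch

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Submodules assigned as attributes are registered automatically;
        # no per-child .to(device) is needed here.
        self.pw = torch.nn.Linear(4, 4)
        self.a_scorer = torch.nn.Linear(4, 1)

parent = Parent()                 # parameters start on the CPU
if torch.cuda.is_available():
    parent = parent.to("cuda")    # recursively moves pw and a_scorer
print(next(parent.pw.parameters()).device)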
@@ -293,7 +287,7 @@ class CorefScorer(torch.nn.Module):
             hidden_size,
             n_layers,
             dropout_rate
-        ).to(device)
+        )
         self.lstm = torch.nn.LSTM(
             input_size=bert_emb,
             hidden_size=bert_emb,
@@ -304,7 +298,7 @@ class CorefScorer(torch.nn.Module):
             bert_emb,
             dropout_rate,
             roughk
-        ).to(device)
+        )
         self.batch_size = batch_size
 
     def forward(
@@ -453,7 +447,6 @@ class AnaphoricityScorer(torch.nn.Module):
         return out
 
 
-
 class RoughScorer(torch.nn.Module):
     """
     Is needed to give a roughly estimate of the anaphoricity of two candidates,
@@ -484,7 +477,6 @@ class RoughScorer(torch.nn.Module):
         pair_mask = torch.arange(mentions.shape[0])
         pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
         pair_mask = torch.log((pair_mask > 0).to(torch.float))
-        pair_mask = pair_mask.to(mentions.device)
         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
         rough_scores = pair_mask + bilinear_scores
 
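Dropping pair_mask.to(mentions.device) leaves the mask on the default device, which is only safe if mentions lives there too. A device-safe variant (a sketch under that assumption, not the commit's code) allocates the mask on the input's device from the start:

import torch

def rough_pair_mask(mentions: torch.Tensor) -> torch.Tensor:
    # Build the antecedent mask directly on mentions' device so that no
    # follow-up .to(...) transfer is needed.
    idx = torch.arange(mentions.shape[0], device=mentions.device)
    diff = idx.unsqueeze(1) - idx.unsqueeze(0)    # i - j for every pair
    # log(1) = 0 keeps valid antecedents (j < i); log(0) = -inf masks the rest.
    return torch.log((diff > 0).to(torch.float))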
@@ -511,7 +503,7 @@ class RoughScorer(torch.nn.Module):
 
 
 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, distance_emb_size: int, device):
+    def __init__(self, input_size: int, distance_emb_size: int):
         super().__init__()
         self.ffnn = torch.nn.Sequential(
             torch.nn.Linear(input_size * 2 + 64, input_size),
@@ -522,7 +514,6 @@ class SpanPredictor(torch.nn.Module):
             torch.nn.Dropout(0.3),
             torch.nn.Linear(256, 64),
         )
-        self.device = device
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
@@ -600,17 +591,10 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.shape = emb_size
 
-    @property
-    def device(self) -> torch.device:
-        """ A workaround to get current device (which is assumed to be the
-        device of the first parameter of one of the submodules) """
-        return next(self.distance_emb.parameters()).device
-
-
     def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                 top_indices: torch.Tensor
                 ) -> torch.Tensor:
-        word_ids = torch.arange(0, top_indices.size(0), device=self.device)
+        word_ids = torch.arange(0, top_indices.size(0))
         distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
                     ).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
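The deleted device property was the usual workaround of reading the device off a module's own parameters; after this change torch.arange falls back to the default device. Both patterns in one self-contained sketch (class name and sizes are illustrative, not the file's real values):

import torch

class DistanceEncoder(torch.nn.Module):
    def __init__(self, emb_size: int = 64):
        super().__init__()
        self.distance_emb = torch.nn.Embedding(9, emb_size)

    @property
    def device(self) -> torch.device:
        # Same trick as the removed property: the module lives wherever
        # its first parameter lives.
        return next(self.distance_emb.parameters()).device

    def forward(self, top_indices: torch.Tensor) -> torch.Tensor:
        # Allocating on top_indices.device keeps everything co-located
        # without needing the property at all.
        word_ids = torch.arange(0, top_indices.size(0), device=top_indices.device)
        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min(1)
        return distance.to(torch.float).log2().floor_()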