Source code for unKR.loss.GMUCp_Loss

import torch
import torch.nn as nn
import torch.nn.functional as F


class GMUCp_Loss(nn.Module):
    """GMUC+ training loss.

    The loss is the sum of three weighted terms:

    * a margin ranking loss between the query (positive) scores and the
      per-query mean of the false (negative) scores, optionally weighted by
      the query confidence (entries below ``args.conf_thr`` are zeroed);
    * a mean squared error between the predicted confidence and the true
      confidence;
    * an "ic" regression loss: the squared error between an affine map
      (``model.ic_loss_w * x + model.ic_loss_b``) of the L2 norm of each
      symbol's ``symbol_emb_var`` embedding and a per-symbol target value.

    Attributes:
        args: Pre-set hyper-parameters. ``forward`` reads ``num_neg``,
            ``conf_thr``, ``margin``, ``if_conf``, ``mse_weight``,
            ``rank_weight`` and ``ic_weight``.
        model: The UKG model being trained. ``forward`` reads its
            ``symbol_emb_var`` module and its ``ic_loss_w`` / ``ic_loss_b``
            attributes.
    """

    def __init__(self, args, model):
        super().__init__()
        self.args = args
        self.model = model

    def forward(self, query_scores, query_scores_var, false_scores, query_confidence, symbolid_ic):
        """Compute the total GMUC+ training loss.

        Args:
            query_scores: Matching scores of the query (positive) triples;
                its first dimension is the batch size.
            query_scores_var: Predicted confidence values, compared against
                ``query_confidence`` with a squared error.
            false_scores: Matching scores of the false (negative) triples.
                When ``args.num_neg != 1`` it must hold
                ``batch_size * args.num_neg`` elements, which are averaged per
                query; otherwise it must hold ``batch_size`` elements.
            query_confidence: True confidence values of the query triples.
                Also used as the rank-loss weight when ``args.if_conf`` is set.
            symbolid_ic: 2-D tensor whose column 0 holds symbol ids (cast to
                ``long``) and whose column 1 holds the target value for that
                symbol's ic regression.

        Returns:
            Scalar tensor ``rank_loss + mse_loss + ic_loss`` for
            back-propagation.
        """
        args = self.args
        batch_size = query_scores.shape[0]

        # Collapse the negatives so there is exactly one false score per query.
        if args.num_neg != 1:
            false_scores = false_scores.reshape((batch_size, args.num_neg))
            false_scores = torch.mean(false_scores, dim=1)
        false_scores = false_scores.reshape((batch_size,))

        # ------ MSE loss ------
        mse_loss = (query_scores_var - query_confidence) ** 2
        mse_loss = args.mse_weight * mse_loss.mean()

        # ------ rank loss ------
        rank_loss = F.relu(args.margin - (query_scores - false_scores))
        if args.if_conf:
            # Weight each query by its confidence, dropping low-confidence ones.
            # zeros_like keeps the input's device/dtype instead of forcing CUDA.
            query_conf_mask = torch.where(
                query_confidence < args.conf_thr,
                torch.zeros_like(query_confidence),
                query_confidence,
            )
            rank_loss = rank_loss * query_conf_mask
        rank_loss = args.rank_weight * torch.mean(rank_loss)

        # ------ ic loss ------
        # reshape(-1) (not squeeze) so a single-row symbolid_ic stays 1-D and
        # the embedding output keeps its (rows, dim) shape for norm(dim=1).
        symbol_ids = symbolid_ic[:, 0].reshape(-1).long()
        symbol_ics = symbolid_ic[:, 1].reshape(-1)
        symbol_emb_var_norm = torch.norm(self.model.symbol_emb_var(symbol_ids), p=2, dim=1)
        ic_residual = self.model.ic_loss_w * symbol_emb_var_norm + self.model.ic_loss_b - symbol_ics
        ic_loss = args.ic_weight * torch.mean(torch.square(ic_residual))

        return rank_loss + mse_loss + ic_loss