Question about the code #46

@Sunjc234

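Hello, when I run the BERT part of the code with the `mode` parameter set to Oracle, the `test` function below raises an error: `cal_lead` is False, so `selected_ids` is never initialized. Could you tell me which parts of your paper `lead` and `oracle` correspond to? I'd really appreciate it if you could let me know. Thanks a lot!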
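The failure follows from a standard Python scoping rule: a local variable assigned only inside an `if` branch is unbound when that branch is skipped, and reading it raises `UnboundLocalError`. A hypothetical, stripped-down reduction of that control flow (not the repository's code; `run_test_loop` and the hard-coded ids are made up for illustration):

```python
def run_test_loop(cal_lead: bool = False, cal_oracle: bool = False) -> list:
    """Hypothetical reduction of the control flow in `test`."""
    if cal_lead:
        selected_ids = [[0, 1, 2]]
    # Nothing assigns selected_ids when cal_lead is False (e.g. oracle mode),
    # so reading it here raises UnboundLocalError.
    return [ids for ids in selected_ids]


try:
    run_test_loop(cal_oracle=True)
except UnboundLocalError as err:
    print("oracle mode fails:", err)
```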
```python
def test(self, test_iter, step, cal_lead=False, cal_oracle=False):

    # Set model in validating mode.
    def _get_ngrams(n, text):
        ngram_set = set()
        text_length = len(text)
        max_index_ngram_start = text_length - n
        for i in range(max_index_ngram_start + 1):
            ngram_set.add(tuple(text[i:i + n]))
        return ngram_set

    def _block_tri(c, p):
        tri_c = _get_ngrams(3, c.split())
        for s in p:
            tri_s = _get_ngrams(3, s.split())
            if len(tri_c.intersection(tri_s))>0:
                return True
        return False

    if (not cal_lead and not cal_oracle):
        self.model.eval()
    stats = Statistics()

    can_path = '%s_step%d.candidate'%(self.args.result_path,step)
    gold_path = '%s_step%d.gold' % (self.args.result_path, step)
    with open(can_path, 'w') as save_pred:
        with open(gold_path, 'w') as save_gold:
            with torch.no_grad():
                for batch in test_iter:
                    gold = []
                    pred = []
                    if (cal_lead):
                        selected_ids = [list(range(batch.clss.size(1)))] * batch.batch_size
                    # selected_ids is only assigned in the cal_lead branch above,
                    # so with cal_lead=False (e.g. oracle mode) the next line fails.
                    for i, idx in enumerate(selected_ids):
                        _pred = []
                        if(len(batch.src_str[i])==0):
                            continue
                        for j in selected_ids[i][:len(batch.src_str[i])]:
                            if(j>=len( batch.src_str[i])):
                                continue
                            candidate = batch.src_str[i][j].strip()
                            _pred.append(candidate)

                            if ((not cal_oracle) and (not self.args.recall_eval) and len(_pred) == 3):
                                break
```