# Source code for nordlys.core.eval.trec_run

"""
Trec run
========

Utility module for working with TREC runfiles.

Usage
-----

Get statistics about a runfile
  ``trec_run <run_file> -o stat``


Filter runfile to contain only documents from a given set
  ``trec_run <run_file> -o filter -d <doc_ids_file> -f <output_file> -n <num_results>``


:Authors: Krisztian Balog, Dario Garigliotti
"""

import argparse
from math import exp
from nordlys.core.retrieval.retrieval_results import RetrievalResults
from nordlys.core.storage.parser.uri_prefix import URIPrefix
from nordlys.config import PLOGGER


class TrecRun(object):
    """Represents a TREC runfile.

    :param file_name: name of the run file
    :param normalize: whether retrieval scores are to be normalized for each query (default: False)
    :param remap_by_exp: whether scores are to be converted from the log-domain by taking their exp (default: False)
    :param run_id: run identifier; if None, it is taken from the last field of the first
        valid line of the run file (default: None)
    """

    def __init__(self, file_name=None, normalize=False, remap_by_exp=False, run_id=None):
        self.__results = {}  # key is a query_id, value is a RetrievalResults object
        self.__sum_scores = {}  # key is a query_id, value is the sum of that query's scores
        self.run_id = run_id
        if file_name is not None:
            self.load_file(file_name, remap_by_exp)
        if normalize:
            self.normalize()

    def load_file(self, file_name, remap_by_exp=False):
        """Loads a TREC runfile.

        Each valid line has exactly six whitespace-separated fields, of which
        field 0 is the query_id, field 2 the doc_id, field 4 the score and
        field 5 the run identifier; fields 1 and 3 are ignored. Lines with a
        different number of fields (e.g. empty lines) are skipped.

        :param file_name: name of the run file
        :param remap_by_exp: whether scores are to be converted from the log-domain by taking their exp (default: False)
        """
        # load the file such that self.__results[query_id] holds the results for a given query,
        # as a RetrievalResults object
        with open(file_name, "r") as f_run:
            for line in f_run:
                fields = line.rstrip().split()
                if len(fields) != 6:
                    continue  # malformed or empty line
                query_id, doc_id, score = fields[0], fields[2], float(fields[4])
                if self.run_id is None:
                    self.run_id = fields[5]
                if query_id not in self.__results:
                    self.__results[query_id] = RetrievalResults()
                # remap log-domain scores by exp; for log-probabilities (score <= 0)
                # this yields values in (0, 1]
                if remap_by_exp:
                    score = exp(score)
                self.__results[query_id].append(doc_id, score)
                # running per-query sum, used by normalize()
                self.__sum_scores[query_id] = self.__sum_scores.get(query_id, 0) + score

    def normalize(self):
        """Normalizes the retrieval scores such that they sum up to one for each query.

        Scores are divided by the per-query sum accumulated while loading.
        Queries whose score sum is zero or unknown cannot be normalized and
        are left unchanged.
        """
        results = self.get_results()
        # iterate over a snapshot of the keys, since the loop rebinds values in the dict
        for query_id in list(results):
            total = self.__get_sum_scores(query_id)
            if not total:  # None (no sum recorded) or 0 (would divide by zero)
                continue
            norm_result = RetrievalResults()
            new_total = 0
            for doc_id, score in results[query_id].get_scores_sorted():
                norm_score = score / total
                norm_result.append(doc_id, norm_score)
                new_total += norm_score
            results[query_id] = norm_result  # overwrite previous result
            # keep the stored sum in sync with the new scores, so that calling
            # normalize() again does not divide by the stale pre-normalization sum
            self.__sum_scores[query_id] = new_total

    def filter(self, doc_ids_file, output_file, num_results=100):
        """Filters runfile to include only selected docIDs and outputs the results to a file.

        For each query, documents are visited in ``get_scores_sorted()`` order and
        kept only if their doc_id is listed in ``doc_ids_file``; at most
        ``num_results`` documents are kept per query.

        :param doc_ids_file: file with one doc_id per line
        :param output_file: output file name
        :param num_results: number of results per query
        """
        # load the allowed doc_ids (ignoring empty lines) into a set for O(1) membership tests
        with open(doc_ids_file, "r") as f:
            doc_ids = {doc_id for doc_id in (line.strip() for line in f) if doc_id}

        # filter the run and write it out in TREC format
        with open(output_file, "w") as f:
            for query_id, res in self.__results.items():
                filtered_res = RetrievalResults()
                for doc_id, score in res.get_scores_sorted():
                    if doc_id in doc_ids:
                        filtered_res.append(doc_id, score)
                        if filtered_res.num_docs() == num_results:
                            break
                filtered_res.write_trec_format(query_id, self.run_id, f, num_results)

    def get_query_results(self, query_id):
        """Returns the corresponding RetrievalResults object for a given query.

        :param query_id: queryID
        :rtype: :py:class:`nordlys.core.retrieval.retrieval_results.RetrievalResults`
            (or None if the queryID cannot be found)
        """
        return self.__results.get(query_id, None)

    def get_results(self):
        """Returns all results.

        :return: a dict with queryIDs as keys and RetrievalResults object as values
        """
        return self.__results

    def __get_sum_scores(self, query_id):
        """Returns the sum of all the retrieval scores for a given query.

        :param query_id: queryID
        :return: sum of scores (or None if the queryID cannot be found)
        """
        return self.__sum_scores.get(query_id, None)

    def print_stat(self):
        """Prints simple statistics (number of queries and total number of results)."""
        PLOGGER.info("#queries: " + str(len(self.__results)))
        PLOGGER.info("#results: " + str(sum(res.num_docs() for res in self.__results.values())))
def arg_parser():
    """Builds the command-line interface and parses the process arguments.

    :return: the parsed arguments as an :class:`argparse.Namespace`
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("run_file", help="run file")  # mandatory positional arg

    # (flags, keyword options) for every optional argument, in declaration order
    option_specs = (
        (("-o", "--operation"), {"help": "operation name", "choices": ["stat", "filter"]}),
        (("-d", "--doc_ids_file"), {"help": "file with the allowed doc_ids (for filtering)", "type": str}),
        (("-f", "--output_file"), {"help": "output file", "type": str}),
        (("-n", "--num_results"), {"help": "number of results", "type": int}),
    )
    for flags, options in option_specs:
        parser.add_argument(*flags, **options)

    return parser.parse_args()
def main(args):
    """Runs the operation selected on the command line.

    ``stat`` prints simple statistics about the run file; ``filter`` writes the
    run restricted to the doc_ids in ``args.doc_ids_file`` to ``args.output_file``,
    keeping at most ``args.num_results`` documents per query when given.

    :param args: parsed command-line arguments (as returned by ``arg_parser()``)
    """
    run = TrecRun(args.run_file)
    if args.operation == "stat":
        run.print_stat()
    elif args.operation == "filter":
        # argparse leaves omitted options as None, so test truthiness rather than len()
        if not args.doc_ids_file or not args.output_file:
            PLOGGER.info("doc_ids_file or output_file missing")
        elif args.num_results is None:
            run.filter(args.doc_ids_file, args.output_file)
        else:
            # forward the -n/--num_results option, which was previously ignored
            run.filter(args.doc_ids_file, args.output_file, num_results=args.num_results)


if __name__ == "__main__":
    main(arg_parser())