Source code for swap.utils.subject


from collections import OrderedDict
import json

from swap.utils.collection import Collection


import logging
logger = logging.getLogger(__name__)

[docs]class Subject: """ Class to track an individual subject, its gold status, and its score. """
[docs] def __init__(self, subject, gold, config, score=None, seen=0, retired=None): self.id = subject self.gold = gold self.config = config score = score or config.p0 self.prior = score self._score = score self.has_changed = False self.seen = seen self.history = [] self.retired = retired
@property def score(self): return self._score @score.setter def score(self, new): self.has_changed = self.has_changed or self._score != new self._score = new
[docs] @classmethod def new(cls, subject, gold, config): """ Create a new Subject """ return cls(subject, gold, config)
[docs] def classify(self, user, cl): """ Add a classification to this subject Params ------ user: user that made the classification cl: classification, 1 or 0 """ # Add classification to history self.seen += 1 self.history.append((user.id, user.score, cl))
[docs] def update_user(self, user): """ Update the history of this subject when a user's score has changed """ for i in range(len(self.history)): h = self.history[i] if h[0] == user.id: self.history[i] = (h[0], user.score, h[2])
def _retire(self, thresholds, score): bogus, real = thresholds if score < bogus: return 0 elif score > real: return 1 return -1
[docs] def update_score(self, thresholds=None, history=False): """ Recalculate the score for this subject from its stored classification history. Params ------ thresholds: Also update retirement status of this subject given threshold parameters. (bogus, real) history: (bool) Return list of score history """ score = self.prior _history = [] for _, (u0, u1), cl in self.history: if cl == 1: a = score * u1 b = (1-score) * (1-u0) elif cl == 0: a = score*(1-u1) b = (1-score)*(u0) try: score = a / (a + b) # leave score unchanged except ZeroDivisionError as e: print(e) if history: _history.append(score) if thresholds is not None: retired = self._retire(thresholds, score) if retired in [0, 1]: break if thresholds is not None: self.retired = self._retire(thresholds, score) self.score = score if history: return score, _history return score
[docs] def dump(self): """ Dump this subject """ return self.id, self.gold, self.score, self.retired, self.seen
# return OrderedDict([ # ('subject', self.id), # ('gold', self.gold), # ('score', self.score), # #('history', self.history), # ('retired', self.retired), # ('seen', self.seen), # ])
[docs] def truncate(self): """ Clear the history of this subject, and update the prior to the current score.""" self.prior = self.score self.history = []
[docs] @classmethod def load(cls, data): """ Load a subject from dumped data """ keys = 'subject', 'gold', 'score', 'retired', 'seen', 'config' data = {k: data[k] for k in keys} return cls(**data)
def __str__(self): return 'id %d gold %d score %.3f length %d' % \ (self.id, self.gold, self.score, len(self.history)) def __repr__(self): return str(self)
[docs]class Subjects(Collection): """ Collection of Subjects """
[docs] def new(self, subject): """ Create and return a new Subject """ return Subject.new(subject, -1, self.config)
def _load_item(self, data): return Subject.load({'config': self.config, **data})
[docs] def retired(self): """ Return all retired subjects """ subjects = [] for subject in self.iter(): if subject.retired in [0, 1]: subjects.append(subject.id) return self.subset(subjects)
[docs] def get_changed(self): subjects = [] for subject in self.iter(): if subject.has_changed: subjects.append(subject.id) return self.subset(subjects)
[docs] def gold(self): """ Return subjects that have a gold label """ subjects = [] for subject in self.iter(): if subject.gold in [0, 1]: subjects.append(subject.id) return self.subset(subjects)
[docs]class Thresholds: """ Class to determine retirement thresholds Thresholds are determined from the false positive rate (fpr) and the missed detection rate (mdr), considering only the subjects with gold labels. The bogus retirement threshold is set such that a rate equal to mdr of real subjects are mislabeled as bogus. The real retirement threshold is set such that a rate equal to fpr of bogus subjects are labeled as real. """
[docs] def __init__(self, subjects, fpr, mdr, thresholds=None): self.subjects = subjects self.fpr = fpr self.mdr = mdr self.thresholds = thresholds
[docs] def dump(self): """ Dump thresholds object """ return self.fpr, self.mdr, json.dumps(self.thresholds)
# return { # 'fpr': self.fpr, # 'mdr': self.mdr, # 'thresholds': json.dumps(self.thresholds) # } def __str__(self): return str(self.dump()) def __repr__(self): return str(self)
[docs] @classmethod def load(cls, subjects, data): """ Load thresholds from dumped data """ keys = 'fpr', 'mdr', 'thresholds' data = {k: data[k] for k in keys} data['subjects'] = subjects data['thresholds'] = json.loads(data['thresholds']) return cls(**data)
[docs] def get_scores(self): """ Generate sorted list of subject scores and gold labels """ scores = [] for subject in self.subjects.iter(): # if len(subject.history) > 0: scores.append((subject.gold, subject.score)) scores = sorted(scores, key=lambda item: item[1]) return scores
[docs] def get_counts(self, scores): """ Get number of subjects in each gold label class (-1,0,1) """ counts = {k: 0 for k in [-1, 0, 1]} for score in scores: counts[score[0]] += 1 return counts
def __call__(self): """ Determine retirement tresholds """ if self.thresholds is not None: return self.thresholds fpr = self.fpr mdr = self.mdr logger.debug('determining retirement thresholds fpr %.3f mdr %.3f', fpr, mdr) scores = self.get_scores() totals = self.get_counts(scores) # Calculate real retirement threshold count = 0 real = 0 if totals[0] == 0: logger.error('No bogus gold labels!') real = 1 _fpr = None else: for gold, score in scores: if gold == 0: count += 1 _fpr = 1 - count / totals[0] # print(_fpr, count, totals[0], score) if _fpr < fpr: real = score break # Calculate bogus retirement threshold count = 0 bogus = 0 if totals[1] == 0: logger.error('No real gold labels!') bogus = 0 _mdr = None else: for gold, score in reversed(scores): if gold == 1: count += 1 _mdr = 1 - count / totals[1] # print(_mdr, count, totals[1], score) if _mdr < mdr: bogus = score break p0 = self.subjects.config.p0 if bogus >= p0: logger.warning('bogus is greater than prior, ' 'setting bogus threshold to p0') bogus = p0 if real <= p0: logger.warning('real is less than prior, ' 'setting real threshold to p0') real = p0 logger.debug('bogus %.4f real %.4f, fpr %.4f mdr %.4f', bogus, real, _fpr, _mdr) self.thresholds = bogus, real return self.thresholds
class ScoreStats: def __init__(self, subjects, thresholds): self.subjects = subjects self.thresholds = thresholds self.tpr = None self.tnr = None self.fpr = None self.fnr = None self.mse = None self.mse_t = None self.mdr = None self.purity = None self.retired = None self.retired_correct = None def __call__(self): self.calculate() @property def completeness(self): return self.tpr def get_scores(self): scores = [] for subject in self.subjects.iter(): if subject.gold in [0, 1]: scores.append((subject.gold, subject.score)) scores = sorted(scores, key=lambda item: item[1]) return scores def calculate(self): scores = self.get_scores() bogus, real = self.thresholds() low = self.counts(scores, 0, bogus) high = self.counts(scores, real, 1) total = self.counts(scores) logger.debug('low %s high %s total %s', low, high, total) stats = {} def divide(n, d): if d == 0: return None return n / d self.tpr = divide(high[1], total[1]) self.tnr = divide(low[0], total[0]) self.fpr = divide(high[0], total[0]) self.fnr = divide(low[1], total[1]) self.purity = divide(high[1], self.total(high)) self.retired = divide( (self.total(low) + self.total(high)), len(self.subjects)) self.retired_correct = divide( (high[1] + low[0]), (self.total(low) + self.total(high))) # Calculate mean squared error self.mse = self.mean_squared_error(scores) self.mse_t = self.mean_squared_error(scores, True) # self.completeness = self.tpr self.mdr = 1 - self.tpr return stats def mean_squared_error(self, scores, retirement=False): bogus, real = self.thresholds() error = 0 n = 0 for gold, p in scores: if gold in [0, 1]: if retirement: if p < bogus: p = 0 elif p > real: p = 1 error += (gold - p) ** 2 n += 1 error = error / n return error @staticmethod def total(counts): return counts[0] + counts[1] @staticmethod def counts(scores, left=0, right=1): counts = {-1: 0, 0: 0, 1: 0} for gold, p in scores: if gold == -1 or p is None or p < left or p > right: continue counts[gold] += 1 return counts def dict(self): keys = [ 'tpr', 'tnr', 'fpr', 'fnr', 'mse', 'mse_t', 'purity', 'retired', 'retired_correct', 'mdr'] data = [] for k in keys: v = self.__dict__[k] if v is not None: data.append((k, v)) return OrderedDict(data) def __str__(self): s = '' stats = self.dict() for key, value in sorted(stats.items(), key=lambda x: x): s += '%s: %.3f ' % (key, value) return '{%s}' % s[:-1] def __repr__(self): return str(self)