changed Module interface to 3 functions

Signed-off-by: Arthur Lu <learthurgo@gmail.com>
This commit is contained in:
Arthur Lu 2021-11-23 22:23:59 +00:00
parent a251fef024
commit 356b71be62
2 changed files with 42 additions and 25 deletions

View File

@ -8,20 +8,24 @@ class Module(metaclass = abc.ABCMeta):
@classmethod @classmethod
def __subclasshook__(cls, subclass): def __subclasshook__(cls, subclass):
return (hasattr(subclass, 'validate_config') and return (hasattr(subclass, '__init__') and
callable(subclass.__init__) and
hasattr(subclass, 'validate_config') and
callable(subclass.validate_config) and callable(subclass.validate_config) and
hasattr(subclass, 'load_data') and hasattr(subclass, 'run') and
callable(subclass.load_data) and callable(subclass.run)
hasattr(subclass, 'process_data') and
callable(subclass.process_data) and
hasattr(subclass, 'push_results') and
callable(subclass.push_results)
) )
@abc.abstractmethod @abc.abstractmethod
def validate_config(self): def __init__(self, config, apikey, tbakey, timestamp, competition, *args, **kwargs):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def validate_config(self, *args, **kwargs):
raise NotImplementedError
@abc.abstractmethod
def run(self, exec_threads, *args, **kwargs):
raise NotImplementedError
"""
@abc.abstractmethod
def load_data(self): def load_data(self):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
@ -29,7 +33,7 @@ class Module(metaclass = abc.ABCMeta):
raise NotImplementedError raise NotImplementedError
@abc.abstractmethod @abc.abstractmethod
def push_results(self): def push_results(self):
raise NotImplementedError raise NotImplementedError"""
class Match (Module): class Match (Module):
@ -52,10 +56,15 @@ class Match (Module):
def validate_config(self): def validate_config(self):
return True, "" return True, ""
def load_data(self): def run(self, exec_threads):
self._load_data()
self._process_data(exec_threads)
self._push_results()
def _load_data(self):
self.data = d.load_match(self.apikey, self.competition) self.data = d.load_match(self.apikey, self.competition)
def simplestats(data_test): def _simplestats(data_test):
signal.signal(signal.SIGINT, signal.SIG_IGN) signal.signal(signal.SIGINT, signal.SIG_IGN)
@ -86,7 +95,7 @@ class Match (Module):
if test == "regression_sigmoidal": if test == "regression_sigmoidal":
return an.regression(ranges, data, ['sig']) return an.regression(ranges, data, ['sig'])
def process_data(self, exec_threads): def _process_data(self, exec_threads):
tests = self.config["tests"] tests = self.config["tests"]
data = self.data data = self.data
@ -104,9 +113,9 @@ class Match (Module):
input_vector.append((team, variable, test, data[team][variable])) input_vector.append((team, variable, test, data[team][variable]))
self.data = input_vector self.data = input_vector
self.results = list(exec_threads.map(self.simplestats, self.data)) self.results = list(exec_threads.map(self._simplestats, self.data))
def push_results(self): def _push_results(self):
short_mapping = {"regression_linear": "lin", "regression_logarithmic": "log", "regression_exponential": "exp", "regression_polynomial": "ply", "regression_sigmoidal": "sig"} short_mapping = {"regression_linear": "lin", "regression_logarithmic": "log", "regression_exponential": "exp", "regression_polynomial": "ply", "regression_sigmoidal": "sig"}
@ -162,10 +171,15 @@ class Metric (Module):
def validate_config(self): def validate_config(self):
return True, "" return True, ""
def load_data(self): def run(self, exec_threads):
self._load_data()
self._process_data(exec_threads)
self._push_results()
def _load_data(self):
self.data = d.pull_new_tba_matches(self.tbakey, self.competition, self.timestamp) self.data = d.pull_new_tba_matches(self.tbakey, self.competition, self.timestamp)
def process_data(self, exec_threads): def _process_data(self, exec_threads):
elo_N = self.config["tests"]["elo"]["N"] elo_N = self.config["tests"]["elo"]["N"]
elo_K = self.config["tests"]["elo"]["K"] elo_K = self.config["tests"]["elo"]["K"]
@ -258,7 +272,7 @@ class Metric (Module):
d.push_metric(self.client, self.competition, temp_vector) d.push_metric(self.client, self.competition, temp_vector)
def push_results(self): def _push_results(self):
pass pass
class Pit (Module): class Pit (Module):
@ -282,10 +296,15 @@ class Pit (Module):
def validate_config(self): def validate_config(self):
return True, "" return True, ""
def load_data(self): def run(self, exec_threads):
self._load_data()
self._process_data(exec_threads)
self._push_results()
def _load_data(self):
self.data = d.load_pit(self.apikey, self.competition) self.data = d.load_pit(self.apikey, self.competition)
def process_data(self, exec_threads): def _process_data(self, exec_threads):
return_vector = {} return_vector = {}
for team in self.data: for team in self.data:
for variable in self.data[team]: for variable in self.data[team]:
@ -296,7 +315,7 @@ class Pit (Module):
self.results = return_vector self.results = return_vector
def push_results(self): def _push_results(self):
d.push_pit(self.apikey, self.competition, self.results) d.push_pit(self.apikey, self.competition, self.results)
class Rating (Module): class Rating (Module):

View File

@ -223,9 +223,7 @@ def main(send, verbose = False, profile = False, debug = False):
valid = current_module.validate_config() valid = current_module.validate_config()
if not valid: if not valid:
continue continue
current_module.load_data() current_module.run(exec_threads)
current_module.process_data(exec_threads)
current_module.push_results()
send(stdout, INF, m + " module finished in " + str(time.time() - start) + " seconds") send(stdout, INF, m + " module finished in " + str(time.time() - start) + " seconds")
if debug: if debug:
f = open(m + ".log", "w+") f = open(m + ".log", "w+")