Python Metaclasses for Generic Learning Algorithms

Metaclasses are terrific, in the sense that they’re a powerful tool for programming, but also in that they should inspire a bit of terror. In this post, I talk about an example from my own work that fits both criteria.

# Outline #

NB: I started working on this post but other priorities came up; it’s currently only half complete.

This post explores how you can use metaclasses to generate custom methods at the instantiation of a class using information from the class itself. By employing this kind of introspection, we can do some useful things[1], like generating a similar interface for slightly different classes without introducing awkward code.

I provide an example of how I used metaclasses to make the learning algorithms in my benchmarking project varcompfa more easily interchangeable by generating a custom update() method for each of them.

If you’re familiar with Python’s class system, then we can dive right in; if not, see the refresher.

# The Problem #

Various learning algorithms that should be interchangeable require slightly different information in order to update their estimates.

For example, state-value estimators learn from transitions like $(s, r, s')$, where $s$ is the current state, $s'$ is the successor, and $r$ is the reward for the transition. However, state-action value algorithms understandably need to see the action taken, so they update from transitions of the form $(s, a, r, s')$.

TD(λ) would be an example of the state value kind of algorithm:

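```python
import numpy as np


class TD(LearningAlgorithm):
    ...

    def learn(self, x, r, xp, alpha, gm, gm_p, lm):
        # TD error: reward plus discounted next-state value, minus current value
        delta = r + gm_p*np.dot(self.w, xp) - np.dot(self.w, x)
        # Decay the eligibility trace and add the current features
        self.z = x + gm*lm*self.z
        self.w += alpha*delta*self.z
        return {'delta': delta}
```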

while SARSA(λ) is basically TD(λ) extended to the state-action value case.
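For comparison, a state-action version might look like the sketch below. This is not the varcompfa implementation: the class name `DiscreteSarsa`, the one-weight-vector-per-action layout, and the argument names `a` and `ap` (the current and next action) are assumptions made for illustration, and it is written as a standalone class rather than a `LearningAlgorithm` subclass so that it runs on its own.

```python
import numpy as np


class DiscreteSarsa:
    """Linear SARSA(lambda) for discrete actions (illustrative sketch)."""

    def __init__(self, num_features, num_actions):
        self.w = np.zeros((num_actions, num_features))  # weights
        self.z = np.zeros((num_actions, num_features))  # eligibility traces

    def learn(self, x, a, r, xp, ap, alpha, gm, gm_p, lm):
        # TD error uses the value of the action actually taken next
        delta = r + gm_p*np.dot(self.w[ap], xp) - np.dot(self.w[a], x)
        # Decay every trace, then add the features for the action taken
        self.z *= gm*lm
        self.z[a] += x
        self.w += alpha*delta*self.z
        return {'delta': delta}


agent = DiscreteSarsa(num_features=2, num_actions=2)
x, xp = np.array([1.0, 0.0]), np.array([0.0, 1.0])
out = agent.learn(x, 0, 1.0, xp, 1, alpha=0.5, gm=0.9, gm_p=0.9, lm=0.0)
```

The body is nearly identical to the TD(λ) version, yet the signature differs by two arguments, which is exactly the kind of mismatch that makes swapping algorithms awkward.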

More generally, some algorithms require you to specify certain additional parameters: TD(λ) interpolates between pure bootstrapping (with λ=0) and Monte Carlo (when λ=1), but you can also use state-dependent parameters[2]; other algorithms might need stepsize information, or an adjustment for off-policy prediction, etc.
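To make the idea of state-dependent parameters concrete, here is a small sketch in which every parameter is treated as a function of the state, whether or not it actually varies. The helper names (`constant`, `terminal_gamma`) and the dictionary-based state format are hypothetical, chosen only for illustration.

```python
def constant(value):
    """Wrap a fixed number so it can be called like a state-dependent parameter."""
    return lambda state: value


def terminal_gamma(state, gamma=0.9):
    """Discount that drops to zero in terminal states (hypothetical state format)."""
    return 0.0 if state["terminal"] else gamma


# Every parameter is now "a function of the state", constant or not
params = {
    "alpha": constant(0.1),
    "lm": constant(0.5),
    "gm": terminal_gamma,
}

s = {"terminal": False}   # current state
sp = {"terminal": True}   # successor state

values = {name: fn(s) for name, fn in params.items()}
values["gm_p"] = params["gm"](sp)  # discount evaluated at the successor
```

Evaluating the discount at both the current state and its successor is presumably why the `learn()` methods in this post take both `gm` and `gm_p`.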

Take Emphatic TD for example. It’s essentially similar to Original Recipe TD(λ) but with an additional parameter for the “interest” associated with a state. Its learn() method would be something like:

```python
import numpy as np


class ETD(LearningAlgorithm):
    ...

    def learn(self, x, r, xp, alpha, gm, gm_p, lm, rho, interest):
        delta = r + gm_p*np.dot(self.w, xp) - np.dot(self.w, x)
        # Followon trace and emphasis
        self.F = gm*self.F + interest
        self.M = lm*interest + (1 - lm)*self.F
        # Emphasis-weighted eligibility trace
        self.z = rho*(x*self.M + gm*lm*self.z)
        self.w += alpha*delta*self.z
        # prepare for next iteration
        self.F *= rho
        return delta
```

It’s frankly unpleasant from an aesthetic perspective[3]: given that all these algorithms address the same problem, it should be possible to substitute one in for another without rewriting huge chunks of code.

…But it becomes acutely annoying if you’re trying to compare a whole bunch of learning algorithms against each other:

• Some of the algorithms might share a few hyperparameters: e.g. the discount rate $γ$.
• Some hyperparameters might be needed for one algorithm but irrelevant to others, like how ETD’s state interest doesn’t affect vanilla TD.
• Other times we might be using variations on the same algorithm, like when we’re comparing performance with different bootstrapping $λ$.

I wanted to avoid writing separate boilerplate for each algorithm, because that’s tedious, it’s hard to modify/refine, and every keystroke risks introducing typos and therefore bugs.

# Solutions #

• Fundamentally, we’re looking for a way to update different algorithms from the same information: transitions of the form $(s, a, r, s')$.

• The implementations of the various algorithms have a lot in common, but might have slightly different function signatures.

• Some of the hyperparameters might be state-dependent, and so not computable ahead of time.

• For maximal generality, we also consider the case where information from one agent is used as the input to another.

• We’ve already noted our reluctance to write the same boilerplate for every algorithm we might want to use.

• Plus it makes it hard for others to inspect our implementations.

## My Solution #

algo_base.py

The definition of LearningAlgorithmMeta

The LearningAlgorithm class

Implementation of TD(λ)

Discrete SARSA(λ)

The learn method for TD(λ)

The learn method for SARSA(λ)
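Until those listings are filled in, here is a minimal sketch of the general idea, not the actual varcompfa code. It reuses the names `LearningAlgorithmMeta` and `LearningAlgorithm` from above, but everything else is an assumption: that `update()` takes a single `context` dictionary, the `_learn_params` attribute, and the scalar `ToyTD` and `ToyETD` classes, which are deliberately simplified stand-ins for the real algorithms.

```python
import inspect


class LearningAlgorithmMeta(type):
    """Give every class that defines learn() a matching update(context) method."""

    def __new__(mcls, name, bases, namespace):
        cls = super().__new__(mcls, name, bases, namespace)
        learn = namespace.get("learn")
        if learn is not None:
            # Names of learn()'s parameters, minus `self`
            params = tuple(inspect.signature(learn).parameters)[1:]

            def update(self, context):
                # Pick out only what this algorithm's learn() asks for
                return self.learn(**{key: context[key] for key in params})

            cls._learn_params = params
            cls.update = update
        return cls


class LearningAlgorithm(metaclass=LearningAlgorithmMeta):
    """Base class; subclasses only need to define learn()."""


class ToyTD(LearningAlgorithm):
    # Scalar stand-in for TD(0): one weight, one feature
    def __init__(self):
        self.w = 0.0

    def learn(self, x, r, xp, alpha, gm_p):
        delta = r + gm_p*self.w*xp - self.w*x
        self.w += alpha*delta*x
        return {'delta': delta}


class ToyETD(LearningAlgorithm):
    # Scalar stand-in for emphatic TD(0), with its extra parameters
    def __init__(self):
        self.w = 0.0
        self.F = 0.0

    def learn(self, x, r, xp, alpha, gm, gm_p, rho, interest):
        delta = r + gm_p*self.w*xp - self.w*x
        self.F = gm*self.F + interest
        self.w += alpha*delta*rho*self.F*x
        self.F *= rho
        return {'delta': delta}


# One context, containing a superset of what any single algorithm needs
context = {'x': 1.0, 'a': 0, 'r': 1.0, 'xp': 0.5, 'alpha': 0.1,
           'gm': 0.9, 'gm_p': 0.9, 'lm': 0.0, 'rho': 1.0, 'interest': 1.0}

agents = [ToyTD(), ToyETD()]
results = [agent.update(context) for agent in agents]
```

The metaclass runs once per class definition, so the signature inspection costs nothing at update time. This sketch ignores defaults, `*args`, and `**kwargs` in `learn()`, and it raises a `KeyError` if the context lacks a required entry; a fuller version would need to decide how to handle each of those.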

# A Refreshing Refresher on Classes and Metaclasses #

I’m considering making this a separate post, but there are already quite a few (and likely better) posts on this topic elsewhere. We’ll see, I suppose.

A detailed background on Python’s object model and class system is beyond the scope of this post, and there are a number of places where you can go to read about it (see the references at the end of this post).

## Python Classes Overview #

The general idea of classes is that they group functionality together into self-contained chunks. For example, if I’m building a robot, I might separate the code for controlling the LEDs from the code that controls the motors. The LED and motor subsystems would communicate with the robot through the Robot Controller, which would translate commands like “begin flashing red ominously” into signals to the robot’s actual hardware.

```mermaid
flowchart BT;
    M(Motor)
    L(LEDs)
    D(Death Ray)
    C[Robot Controller]
    R((Physical Robot))
    M o--o C
    L o--o C
    D o--o C
    C <-->|Send/Receive| R
```

It’s not strictly necessary to separate the components; we could just implement all the functionality in the Robot Controller. But keeping them separate can make it easier to reuse code. If I were building many robots with different kinds of motors but the same sorts of LEDs, then I would only have to rewrite the parts dealing with the motors.

Subclasses are for when you’re writing code with a number of objects that ought to behave similarly but not quite identically and want to avoid repeating yourself. To continue with the earlier example, maybe I want to switch to different kinds of LEDs (say, going from APA102s to the cheaper WS2812s). I have likely spent a bunch of time getting the ominous red flashing animation just right, and so rather than throwing that work away, I could reorganize the LED code into a base class that contains all the code the different kinds of LEDs have in common and then have subclasses inherit from it. Then each subclass only needs the additional code that’s specific to the corresponding brand of LED.

```mermaid
classDiagram
    LED <|-- WS2812
    LED <|-- APA102
    LED <|-- SK6812
    LED : +int red
    LED : +int green
    LED : +int blue
    LED: +set_brightness()
    LED: +pulse_ominously()
    class WS2812{
        # refresh_rate : int = 400
    }
    class APA102{
        # refresh_rate : int = 20000
    }
    class SK6812{
        +int white
        +set_brightness()
        # refresh_rate : int = 400
    }
```

The classes in the above diagram all inherit from the LED base, which is pretty generic. It has values for the red, green, and blue intensity, along with a method for setting the brightness and another one for pulsing ominously. The subclasses can reuse a lot of this functionality, but require minor tweaks, such as setting different refresh rates. SK6812 LEDs are basically similar to the WS2812s except that they contain a fourth element that just produces white light, and so we need to accommodate that.

In actual Python code, this would look something like:

```python
import time


class LED:
    # Class-level constants shared by every kind of LED
    MIN_INTENSITY = 0
    MAX_INTENSITY = 255

    def __init__(self):
        # Brightness values
        self.r = 0
        self.g = 0
        self.b = 0

    def set_brightness(self, brightness: int):
        """Set the overall brightness for each element to the same value."""
        self.r = max(self.MIN_INTENSITY, min(self.MAX_INTENSITY, brightness))
        self.g = max(self.MIN_INTENSITY, min(self.MAX_INTENSITY, brightness))
        self.b = max(self.MIN_INTENSITY, min(self.MAX_INTENSITY, brightness))

    def pulse_ominously(self):
        # Relies on `refresh_rate`, which each subclass defines
        self.set_brightness(0)  # turn off LEDs
        wait_time = 1 / self.refresh_rate
        counter = 0  # loop counter

        # Begin slowly pulsing the red element
        while True:
            self.r = int(
                self.MAX_INTENSITY
                * ((counter % self.refresh_rate) / self.refresh_rate)
            )
            # Increment counter and prepare for next loop
            counter += 1
            time.sleep(wait_time)


class WS2812(LED):
    refresh_rate = 400


class APA102(LED):
    refresh_rate = 20000


class SK6812(LED):
    refresh_rate = 400

    def __init__(self):
        super().__init__()  # covers red, green, blue elements
        self.w = 0  # initialize white element

    def set_brightness(self, brightness: int):
        super().set_brightness(brightness)  # red, green, blue elements
        self.w = max(self.MIN_INTENSITY, min(self.MAX_INTENSITY, brightness))
```

Having just written the above for illustrative purposes, there are a number of things I’d already like to change. In real use you’d probably want to do the ominous pulsing on a different thread, but this is just a toy example.
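As a rough sketch of what moving the pulsing onto another thread could look like, the following uses the standard library’s `threading` module. The `Pulser` class and its method names are hypothetical, and it only updates a number in memory rather than driving real hardware.

```python
import threading
import time


class Pulser:
    """Ramp a red intensity value in a background thread (no real hardware)."""

    MAX_INTENSITY = 255

    def __init__(self, refresh_rate=400):
        self.refresh_rate = refresh_rate
        self.r = 0
        self._stop = threading.Event()
        self._thread = None

    def _run(self):
        counter = 0
        # Event.wait() acts as an interruptible sleep: it returns True
        # as soon as stop() sets the event, which ends the loop
        while not self._stop.wait(1 / self.refresh_rate):
            phase = (counter % self.refresh_rate) / self.refresh_rate
            self.r = int(self.MAX_INTENSITY * phase)
            counter += 1

    def start(self):
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self):
        self._stop.set()
        self._thread.join()


pulser = Pulser()
pulser.start()    # returns immediately; pulsing continues in the background
time.sleep(0.05)  # the main thread is free to do other work here
pulser.stop()
```

Unlike the `while True` loop above, this version can be stopped cleanly, which matters once the robot has anything else to do.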

## Metaclasses #

1. And many more arguably useful things. ↩︎

2. TODO: GVF references ↩︎

3. Arguably this sort of problem was the motivating factor behind the discovery and formulation of the Standard Model in physics: it was just too much hassle using a slightly different formalism for each sub-theory, given that they are all trying to describe the same Universe. ↩︎