Optimizing Cython for Mandelbrot fractal calculations
This is a continuation of a series of blog posts found here:
http://aroberge.blogspot.com/2010/01/profiling-adventures-and-cython-setting.html
At the end of that series of posts, the OP had Cython code that calculated and drew the fractal image. I noted that the drawing might be faster if you modified the pixels one by one as an array in memory, then blitted that to the image.
Not knowing quite how to do that with Tk, I made a wx version, then set about making the Cython modifications, using a numpy array for the image.
Here is the first version I got working:
# mandelcy1.pyx
# cython: profile=True

import cython

@cython.profile(False)
cdef inline int mandel(double real, double imag, int max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
    after a maximum allowed number of iterations, the absolute value of
    the resulting number is greater or equal to 2.'''
    cdef double z_real = 0., z_imag = 0.
    cdef int i

    for i in range(0, max_iterations):
        z_real, z_imag = ( z_real*z_real - z_imag*z_imag + real,
                           2*z_real*z_imag + imag )
        if (z_real*z_real + z_imag*z_imag) >= 4:
            return i
    return -1

def create_fractal( double min_x,
                    double min_y,
                    double pixel_size,
                    int nb_iterations,
                    colours,
                    image):

    cdef int width, height
    cdef int x, y, start_y, end_y
    cdef int nb_colours, current_colour, new_colour
    cdef double real, imag

    nb_colours = len(colours)
    # image is an ndarray of size: w,h,3
    width = image.shape[0]
    height = image.shape[1]

    for x in range(width):
        real = min_x + x*pixel_size
        for y in range(height):
            imag = min_y + y*pixel_size
            colour = mandel(real, imag, nb_iterations)
            image[x, y, :] = colours[ colour%nb_colours ]
What I did differently:
- moved the palette-building code (make_palette) out of Cython -- it only runs once, so there's no need for Cython
- changed the colors to be numbers (RGB) rather than hex strings.
- had create_fractal take a numpy array of pixels as input. The code then loops through the pixels and sets their color. Since I'm setting each pixel, there is none of that line-drawing stuff, so it's pretty simple.
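For reference, here is roughly what that pixel-by-pixel approach looks like in plain Python with numpy, before any Cython typing. This is only a sketch: the three-colour palette and the tiny 8x8 image are made up for illustration, and it will be far slower than the compiled version.

```python
import numpy as np


def mandel(real, imag, max_iterations=20):
    """Return the iteration at which the point escapes, or -1 if it never does."""
    z_real = z_imag = 0.0
    for i in range(max_iterations):
        z_real, z_imag = (z_real * z_real - z_imag * z_imag + real,
                          2 * z_real * z_imag + imag)
        if z_real * z_real + z_imag * z_imag >= 4:
            return i
    return -1


def create_fractal(min_x, min_y, pixel_size, nb_iterations, colours, image):
    """Fill `image` (shape: width, height, 3) in place, one pixel at a time."""
    nb_colours = len(colours)
    width, height = image.shape[:2]
    for x in range(width):
        real = min_x + x * pixel_size
        for y in range(height):
            imag = min_y + y * pixel_size
            colour = mandel(real, imag, nb_iterations)
            # -1 (never escaped) wraps around to the last palette entry
            image[x, y, :] = colours[colour % nb_colours]


# Hypothetical palette and image size, chosen only to keep the example small.
palette = [(0, 0, 0), (255, 0, 0), (0, 0, 255)]
image = np.zeros((8, 8, 3), dtype=np.uint8)
create_fractal(-2.0, -2.0, 0.5, 20, palette, image)
```

The pixel at index (4, 4) corresponds to the point 0 + 0i, which never escapes, so it gets the last colour in the palette.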
I put very simple timing code in the wx wrapper (see below for that), and this takes about 5.9 seconds to run on a 500x500 image. Increasing the iterations doesn't change it much -- I think it's stopping far short of the max iterations much of the time anyway. I don't know how this compares to the OP's version -- I never did get that running on my machine.
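The timing really is just a clock reading before and after the call. The sketch below does the same thing around a pure-Python stand-in, and also tallies how many sample points escape early, which is one way to check the guess that most pixels stop well short of the maximum. The helper names (timed, escape_summary), the 50x50 grid, and the cutoff of 20 iterations are all assumptions made for this example, not part of the original code.

```python
import time


def mandel(real, imag, max_iterations):
    """Return the iteration at which the point escapes, or -1 if it never does."""
    z_real = z_imag = 0.0
    for i in range(max_iterations):
        z_real, z_imag = (z_real * z_real - z_imag * z_imag + real,
                          2 * z_real * z_imag + imag)
        if z_real * z_real + z_imag * z_imag >= 4:
            return i
    return -1


def timed(func, *args):
    """Call func(*args) and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = func(*args)
    return result, time.perf_counter() - start


def escape_summary(min_x, min_y, pixel_size, size, max_iterations):
    """Count grid points that escape within 20 iterations, later, or never."""
    early = late = never = 0
    for x in range(size):
        for y in range(size):
            i = mandel(min_x + x * pixel_size, min_y + y * pixel_size,
                       max_iterations)
            if i < 0:
                never += 1
            elif i < 20:
                early += 1
            else:
                late += 1
    return early, late, never


# A coarse 50x50 grid over the same region the viewer uses (x and y from -1.5 to 0).
(early, late, never), seconds = timed(escape_summary, -1.5, -1.5, 0.03, 50, 500)
print("early: %d, late: %d, never: %d, took %f seconds"
      % (early, late, never, seconds))
```

Only the points in the "never" bucket cost the full iteration budget, so raising the maximum only slows down that part of the image.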
What I haven't done:
Adding numpy arrays
I haven't told Cython that this is a numpy array -- Cython understands numpy arrays; I suspect that will make a big difference:
1) make the colours sequence a numpy array:
return np.array(colours, dtype=np.uint8)
WOW! That sped it up by about a factor of two (2.8 seconds), even without telling Cython that it was a numpy array!
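To make the palette change concrete, here is a small sketch of a colour list converted to a uint8 array and then indexed the way the loop does it. The four colours are hypothetical; the point is the (n, 3) shape, the uint8 dtype, and the fact that the -1 returned for points that never escape wraps around to the last entry.

```python
import numpy as np

# Hypothetical four-entry palette as plain Python tuples of (R, G, B).
colour_list = [(0, 0, 0), (10, 8, 58), (20, 16, 66), (30, 24, 74)]
colours = np.array(colour_list, dtype=np.uint8)

print(colours.shape, colours.dtype)

image = np.zeros((2, 2, 3), dtype=np.uint8)
image[0, 0, :] = colours[5 % len(colours)]    # an escape count of 5 wraps to entry 1
image[1, 1, :] = colours[-1 % len(colours)]   # -1 (never escaped) wraps to the last entry
```

With the palette as an array, each pixel assignment copies three bytes from one array to another instead of unpacking a Python tuple of integers.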
2) let cython know it's a numpy array:
cimport numpy as np # for the special numpy stuff

def create_fractal( double min_x,
                    double min_y,
                    double pixel_size,
                    int nb_iterations,
                    np.ndarray colours not None,
                    np.ndarray image not None):
The "not None" means that you'll get an exception if you pass in None. That's left over from Pyrex, where None can be used to mean "not a valid value" for a dtype that doesn't have a NaN or anything.
Note that you need to tell the compiler where to find the numpy headers. I do this in the setup.py I'm using to build it:
setup(
    cmdclass = {'build_ext': build_ext},
    ext_modules = [Extension("mandelcy", ["mandelcy.pyx"], )],
    include_dirs = [numpy.get_include(),],
)
OK -- this helped a little, but not much -- about 2 seconds now.
But we can tell Cython what data types to expect:
def create_fractal( double min_x,
                    double min_y,
                    double pixel_size,
                    int nb_iterations,
                    np.ndarray[np.uint8_t, ndim=2, mode="c"] colours not None,
                    np.ndarray[np.uint8_t, ndim=3, mode="c"] image not None):
So it now knows that colours and image are arrays of unsigned 8 bit integers, and what their dimensions are. The "mode = 'c'" means that the array is c-contiguous (rather than fortran order). Interestingly, this didn't help speed any ... yet.
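Once the signature promises C-contiguous uint8 arrays, the caller has to actually pass them; as far as I know, Cython rejects an array with the wrong layout or dtype when the function is called. A minimal sketch of how to check and fix that on the Python side follows. The prepare_image helper is a hypothetical name, not something in the original code.

```python
import numpy as np


def prepare_image(arr):
    """Return arr as a C-contiguous uint8 array, copying only if it has to."""
    return np.ascontiguousarray(arr, dtype=np.uint8)


image = np.zeros((500, 500, 3), dtype=np.uint8)
fortran = np.asfortranarray(image)   # same values, Fortran (column-major) layout
strided = image[:, ::2, :]           # a view that skips every other entry on axis 1

for name, arr in [("image", image), ("fortran", fortran), ("strided", strided)]:
    print(name, "C-contiguous:", arr.flags["C_CONTIGUOUS"])

fixed = prepare_image(strided)   # makes a contiguous copy
same = prepare_image(image)      # already fine, so no copy is needed
```

One thing to watch: create_fractal fills the array in place, so if prepare_image had to make a copy, it is the copy that gets drawn into, not the original.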
Turning off bounds checking:
@cython.boundscheck(False)
def create_fractal( double min_x, ...):
Not much change there, either.
Oops, I forgot to tell Cython that colour is an int, so it had to convert it to a generic Python object, and then back again, inside the loop. Adding colour to the cdef int declarations at the top of create_fractal fixes that for these two lines:
colour = mandel(real, imag, nb_iterations)
image[x, y, :] = colours[ colour%nb_colours]
That helped a bit: from 2.04 to 1.96 seconds.
Cython can index numpy arrays with simple pointer math, but only if you use plain integer indexing (no slices), so I changed:
image[x, y, :] = colours[ colour%nb_colours]
to:
image[x, y, 0] = colours[ colour%nb_colours, 0 ]
image[x, y, 1] = colours[ colour%nb_colours, 1 ]
image[x, y, 2] = colours[ colour%nb_colours, 2 ]
BINGO! 0.03 seconds, about a 66 times speed up!
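It's worth convincing yourself that the two forms write exactly the same pixels, so the speed-up comes purely from how Cython compiles the indexing, not from any change in the result. A quick check in plain numpy, using a made-up three-colour palette and a stand-in for the value mandel() would return:

```python
import numpy as np

# Hypothetical palette, only for this check.
colours = np.array([(0, 0, 0), (10, 8, 58), (20, 16, 66)], dtype=np.uint8)
nb_colours = len(colours)

by_slice = np.zeros((4, 4, 3), dtype=np.uint8)
by_channel = np.zeros((4, 4, 3), dtype=np.uint8)

for x in range(4):
    for y in range(4):
        colour = x + y - 1   # stand-in for mandel(); includes -1 at (0, 0)

        # Slice form: one assignment of a whole RGB row
        by_slice[x, y, :] = colours[colour % nb_colours]

        # Integer-only form: one assignment per channel
        by_channel[x, y, 0] = colours[colour % nb_colours, 0]
        by_channel[x, y, 1] = colours[colour % nb_colours, 1]
        by_channel[x, y, 2] = colours[colour % nb_colours, 2]

print(np.array_equal(by_slice, by_channel))
```

In pure Python the two forms cost about the same; it is only in the typed Cython function that the integer-only version turns into direct memory access.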
Someone on the Cython list suggested that using a single integer to store the color, rather than the three separate 8-bit ones, would speed things up. I tried it, and there was no noticeable difference. I think that the calculation is swamping the color-setting at this point, and perhaps the pushing around of extra data (I had to use RGBA for a 32-bit integer) slows things down a bit.
Here is the fast version, with the three separate channel assignments (I also tweaked the inputs a little to make it a bit easier to change what part of the fractal to view):
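For anyone who wants to try the single-integer idea, packing a colour into 32 bits might look like the sketch below. The byte layout (R in the lowest byte, then G, B, and A) is an assumption that matches R, G, B, A order in memory on a little-endian machine; pack_rgba and unpack_rgba are hypothetical names, and this is not the code I actually tested.

```python
def pack_rgba(r, g, b, a=255):
    """Pack four 8-bit channels into one 32-bit integer (R in the low byte)."""
    return r | (g << 8) | (b << 16) | (a << 24)


def unpack_rgba(value):
    """Split a packed 32-bit colour back into (r, g, b, a)."""
    return (value & 0xFF,
            (value >> 8) & 0xFF,
            (value >> 16) & 0xFF,
            (value >> 24) & 0xFF)


packed = pack_rgba(10, 8, 58)
print(hex(packed), unpack_rgba(packed))
```

The palette then becomes a 1-d array of 32-bit values and the image a 2-d one, so each pixel is a single store, at the cost of moving a fourth byte per pixel.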
# mandel3cy.pyx
# cython: profile=True

import cython
#import numpy as np
cimport numpy as np # for the special numpy stuff

@cython.profile(False)
cdef inline int mandel(double real, double imag, int max_iterations=20):
    '''determines if a point is in the Mandelbrot set based on deciding if,
    after a maximum allowed number of iterations, the absolute value of
    the resulting number is greater or equal to 2.'''
    cdef double z_real = 0., z_imag = 0.
    cdef int i

    for i in range(0, max_iterations):
        z_real, z_imag = ( z_real*z_real - z_imag*z_imag + real,
                           2*z_real*z_imag + imag )
        if (z_real*z_real + z_imag*z_imag) >= 4:
            return i
    return -1

@cython.boundscheck(False)
def create_fractal( double min_x,
                    double max_x,
                    double min_y,
                    int nb_iterations,
                    np.ndarray[np.uint8_t, ndim=2, mode="c"] colours not None,
                    np.ndarray[np.uint8_t, ndim=3, mode="c"] image not None):

    cdef int width, height
    cdef int x, y, start_y, end_y
    cdef int nb_colours, colour
    cdef double real, imag, pixel_size

    nb_colours = len(colours)

    width = image.shape[0]
    height = image.shape[1]

    pixel_size = (max_x - min_x) / width

    for x in range(width):
        real = min_x + x*pixel_size
        for y in range(height):
            imag = min_y + y*pixel_size
            colour = mandel(real, imag, nb_iterations)
            image[x, y, 0] = colours[ colour%nb_colours, 0 ]
            image[x, y, 1] = colours[ colour%nb_colours, 1 ]
            image[x, y, 2] = colours[ colour%nb_colours, 2 ]
and the setup.py:
#!/usr/bin/env python

"""
setup.py to build Mandelbrot code with cython
"""
from distutils.core import setup
from distutils.extension import Extension
from Cython.Distutils import build_ext
import numpy # to get includes


setup(
    cmdclass = {'build_ext': build_ext},
    ext_modules = [Extension("mandelcy", ["mandelcy.pyx"], )],
    include_dirs = [numpy.get_include(),],
)
And here is the wx code to view it:
#!/usr/bin/env python

"""
A simple app to show mandelbrot fractals.
"""

import time
import numpy as np
import wx

import mandelcy # this is the cython module that does the real work.

class BitmapWindow(wx.Window):
    """
    A simple window to display a bitmap from a numpy array
    """
    def __init__(self, parent, bytearray, *args, **kwargs):
        wx.Window.__init__(self, parent, *args, **kwargs)

        self.bytearray = bytearray
        self.Bind(wx.EVT_PAINT, self.OnPaint)

    def OnPaint(self, evt):
        dc = wx.PaintDC(self)
        w, h = self.bytearray.shape[:2]
        bmp = wx.BitmapFromBuffer(w, h, self.bytearray)
        dc.DrawBitmap(bmp, 50, 0 )

class DemoFrame(wx.Frame):
    def __init__(self, title = "Mandelbrot Demo"):
        wx.Frame.__init__(self, None , -1, title)#, size = (800,600), style=wx.DEFAULT_FRAME_STYLE|wx.NO_FULL_REPAINT_ON_RESIZE)

        # create the array and bitmap:
        self.bytearray = np.zeros((500, 500, 3), dtype=np.uint8) + 125

        self.BitmapWindow = BitmapWindow(self, self.bytearray,
                                         size=self.bytearray.shape[:2])

        sizer = wx.BoxSizer(wx.VERTICAL)
        sizer.Add(self.BitmapWindow, 0, wx.ALIGN_CENTER|wx.ALL, 10)
        # set up the buttons
        sizer.Add(self.SetUpTheButtons(), 0, wx.EXPAND)
        self.SetSizerAndFit(sizer)

        self.colours = make_palette()

    def SetUpTheButtons(self):
        RunButton = wx.Button(self, wx.NewId(), "Run")
        RunButton.Bind(wx.EVT_BUTTON, self.OnRun)

        self.IterSlider = wx.Slider( self, wx.ID_ANY,
                                     value=20,
                                     minValue=20,
                                     maxValue=10000,
                                     size=(250, -1),
                                     style = wx.SL_HORIZONTAL | wx.SL_AUTOTICKS | wx.SL_LABELS
                                     )

        QuitButton = wx.Button(self, wx.NewId(), "Quit")
        QuitButton.Bind(wx.EVT_BUTTON, self.OnQuit)

        self.Bind(wx.EVT_CLOSE, self.OnQuit)

        sizer = wx.BoxSizer(wx.HORIZONTAL)
        sizer.Add((1,1), 1)
        sizer.Add(RunButton, 0, wx.ALIGN_CENTER | wx.ALL, 4 )
        sizer.Add((1,1), 1)
        sizer.Add(self.IterSlider, 0, wx.ALIGN_CENTER | wx.ALL, 4 )
        sizer.Add((1,1), 1)
        sizer.Add(QuitButton, 0, wx.ALIGN_CENTER | wx.ALL, 4 )
        sizer.Add((1,1), 1)
        return sizer

    def OnRun(self,Event):
        width, height = self.bytearray.shape[:2]

        min_x = -1.5
        max_x = 0
        min_y = -1.5
        # max_y is calculated from X, to keep it symmetric

        nb_iterations = self.IterSlider.Value
        print("Calculating with %i iterations:" % nb_iterations)
        start = time.clock()
        mandelcy.create_fractal(min_x, max_x, min_y, nb_iterations, self.colours, self.bytearray)
        print("it took %f seconds to run" % (time.clock() - start))
        self.Refresh()

    def OnStop(self, Event=None):
        self.Timer.Stop()

    def OnQuit(self,Event):
        self.Destroy()

def make_palette():
    '''sample coloring scheme for the fractal - feel free to experiment

    No need for this to be in Cython
    '''
    colours = []

    for i in range(0, 25):
        #colours.append('#%02x%02x%02x' % (i*10, i*8, 50 + i*8))
        colours.append( (i*10, i*8, 50 + i*8), )
    for i in range(25, 5, -1):
        #colours.append('#%02x%02x%02x' % (50 + i*8, 150+i*2, i*10))
        colours.append( (50 + i*8, 150+i*2, i*10), )
    for i in range(10, 2, -1):
        #colours.append('#00%02x30' % (i*15))
        colours.append( (0, i*15, 48), )
    return np.array(colours, dtype=np.uint8)


app = wx.PySimpleApp(0)
frame = DemoFrame()
frame.Show()
app.MainLoop()
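The viewer only passes min_x, max_x and min_y; the pixel size and the upper y limit follow from the image dimensions. Here is a small sketch of that arithmetic, mirroring the pixel_size calculation in create_fractal. The helper names view_bounds and pixel_to_point are hypothetical, introduced only for this example.

```python
def view_bounds(min_x, max_x, min_y, width, height):
    """Return (pixel_size, max_y) for a view that keeps pixels square."""
    pixel_size = (max_x - min_x) / float(width)
    max_y = min_y + height * pixel_size
    return pixel_size, max_y


def pixel_to_point(x, y, min_x, min_y, pixel_size):
    """Map pixel indices to the (real, imag) point the loop evaluates."""
    return min_x + x * pixel_size, min_y + y * pixel_size


# The values used in OnRun, with the 500x500 image.
pixel_size, max_y = view_bounds(-1.5, 0, -1.5, 500, 500)
print("pixel size: %f, max_y: %f" % (pixel_size, max_y))
print("centre pixel:", pixel_to_point(250, 250, -1.5, -1.5, pixel_size))
```

With these settings each pixel covers about 0.003 units, and the view runs from -1.5 to (just short of) 0 on both axes, since the last pixel sits one pixel_size below the upper limit.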
If you don't have wx installed, here's a script that uses SciPy and PIL to save images:
import time
import numpy as np
import mandelcy # this is the cython module that does the real work.

def make_palette():
    '''sample coloring scheme for the fractal - feel free to experiment
    No need for this to be in Cython
    '''
    colours = []
    for i in range(0, 25):
        colours.append( (i*10, i*8, 50 + i*8), )
    for i in range(25, 5, -1):
        colours.append( (50 + i*8, 150+i*2, i*10), )
    for i in range(10, 2, -1):
        colours.append( (0, i*15, 48), )
    return np.array(colours, dtype=np.uint8)

min_x = -1.5
max_x = 0
min_y = -1.5
nb_iterations = 500
bytearray = np.zeros((500, 500, 3), dtype=np.uint8) + 125
colours = make_palette()
start = time.clock()
mandelcy.create_fractal(min_x, max_x, min_y, nb_iterations, colours, bytearray)
print("it took %f seconds to run" % (time.clock() - start))
from scipy.misc import toimage
toimage(bytearray).save("mandelbrot.png")
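On newer SciPy releases, where scipy.misc.toimage is no longer available, the array can be handed to Pillow directly. The sketch below is an alternative to the script above rather than what I originally ran; it uses a small grey stand-in array in place of the one filled by create_fractal, and writes to an in-memory buffer so it runs without touching the disk.

```python
import io

import numpy as np
from PIL import Image

# Stand-in for the array filled by create_fractal: 4 wide (x), 6 high (y), RGB.
pixels = np.zeros((4, 6, 3), dtype=np.uint8) + 125

# Pillow treats the first axis as rows, while the fractal code indexes
# image[x, y], so swap to (y, x, 3) before converting.
rows_first = np.ascontiguousarray(pixels.swapaxes(0, 1))
img = Image.fromarray(rows_first)

buffer = io.BytesIO()   # pass a filename such as "mandelbrot.png" to write a file
img.save(buffer, format="PNG")
print(img.size, img.mode)
```

For the square 500x500 image the axis swap only changes the orientation of the picture, so it can be dropped if you don't mind the fractal being drawn transposed.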