from sympy import symbols, sqrt, re, im, pi, atan2, sin, cos, I, Rational
from spb import *
# Real-valued symbols: polar coordinates (r, theta) and cartesian ones (x, y).
r, theta, x, y = symbols("r, theta, x, y", real=True)
def mag(z):
    """Return the modulus of the complex expression ``z``.

    Computed as ``sqrt(re(z)**2 + im(z)**2)``.
    """
    return sqrt(re(z)**2 + im(z)**2)


def phase(z, k=0):
    """Return the argument of ``z`` shifted by ``k`` full turns.

    ``atan2(im(z), re(z))`` gives the base angle; adding ``2 * k * pi``
    selects the k-th branch (``k=0`` leaves the angle unshifted).
    """
    return atan2(im(z), re(z)) + 2 * k * pi
# Build the real part of every branch of the n-th root of z = x + I*y,
# expressed in polar coordinates so each surface is parametrized by
# (r, theta).
n = 2  # exponent (integer); one branch is built per value in range(n)
z = x + I * y  # cartesian
d = {x: r * cos(theta), y: r * sin(theta)}  # cartesian to polar
# Rational(1, n) keeps the exponent exact; the Python division 1 / n would
# inject a float (e.g. 0.5) into the symbolic expression.
root_mag = mag(z)**Rational(1, n)  # identical for every branch: build once
branches = [(root_mag * cos(phase(z, k) / n)).subs(d) for k in range(n)]
# Parametric surface coordinates: (x(r, theta), y(r, theta), branch value).
exprs = [(d[x], d[y], rb) for rb in branches]
# One parametric surface per branch, all sharing the same parameter ranges:
# r from 0 to 3 and theta from -pi to pi. Labels are numbered from 1.
series = [
    surface_parametric(
        coord_x, coord_y, coord_z, (r, 0, 3), (theta, -pi, pi),
        label="branch %s" % number, wireframe=True, wf_n2=20)
    for number, (coord_x, coord_y, coord_z) in enumerate(exprs, start=1)]
# Combine every branch into a single figure.
# NOTE(review): PB comes from the `spb` star-import (presumably the Plotly
# backend) -- confirm against the installed spb version.
graphics(*series, backend=PB, zlabel="f(z)")