from sympy import symbols, sqrt, re, im, pi, atan2, sin, cos, I, Rational
from spb import plot3d_parametric_surface, PB
# Visualize the real part of every branch of the multivalued map
# f(z) = z**(1/n) over the disk |z| <= 3 of the complex plane.
# Each surface is parametrized in polar coordinates (r, theta): the x/y
# position is the point z = r*cos(theta) + I*r*sin(theta), and the height
# is the real part of f(z) on that branch.
r, theta, x, y = symbols("r, theta, x, y", real=True)

# Modulus |z| and argument arg(z) + 2*k*pi of a complex expression.
# Shifting the argument by 2*k*pi before dividing by n selects the k-th
# of the n distinct n-th roots.
mag = lambda z: sqrt(re(z)**2 + im(z)**2)
phase = lambda z, k=0: atan2(im(z), re(z)) + 2 * k * pi

n = 2  # exponent (integer); z**(1/n) has n branches
z = x + I * y  # cartesian
d = {x: r * cos(theta), y: r * sin(theta)}  # cartesian to polar

# Real part of the k-th branch: |z|**(1/n) * cos((arg(z) + 2*k*pi) / n).
# Rational(1, n) keeps the root exact; the plain Python expression 1 / n
# would yield a float exponent (0.5, 0.333...), making the result inexact.
branches = [
    (mag(z)**Rational(1, n) * cos(phase(z, k) / n)).subs(d)
    for k in range(n)
]
exprs = [(r * cos(theta), r * sin(theta), rb) for rb in branches]
plot3d_parametric_surface(
    *exprs, (r, 0, 3), (theta, -pi, pi),
    backend=PB, wireframe=True, wf_n2=20, zlabel="f(z)",
    label=[f"branch {i + 1}" for i in range(len(branches))])