import matplotlib.cm as cm
# Plot the expression cos(x**2 + y**2) with x and y each ranging over [-2, 2].
# NOTE(review): `plot3d`, `cos`, `x` and `y` are not defined in this snippet --
# presumably a SymPy-style plotting function, cosine and symbols; confirm
# against the code that precedes this.
plot3d(
    cos(x**2 + y**2),
    (x, -2, 2),
    (y, -2, 2),
    # Options dict passed positionally, selecting matplotlib's `coolwarm`
    # colormap; presumably forwarded to the renderer -- verify in plot3d docs.
    {"cmap": cm.coolwarm},
    use_cm=True,  # presumably turns on colormap-based surface coloring; verify
)