pytorch code#
The necessary code is quite comparable to numpy.
import torch
def calculate(x_min, x_max, y_min, y_max, max_iterations, resolution):
x = torch.linspace(x_min, x_max, resolution)
y = torch.linspace(y_min, y_max, resolution)
c = x + y[:, None] * 1j
# a simple .copy() is not possible in torch
c0 = c.clone().detach()
iterations = torch.zeros_like(c, dtype=torch.int32)
for iteration in range(max_iterations):
mask = torch.abs(c) < 2
c[mask] = c[mask] ** 2 + c0[mask]
iterations[mask] += 1
# note the explicit conversion into a numpy type here
return iterations.detach().numpy(), {}