neuropercolation/evaluation/phi.py
2023-09-30 19:53:06 +02:00

173 lines
5.4 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 27 04:39:54 2023
@author: astral
"""
import os
import json
import math as m
import numpy as np
from numpy.linalg import norm
from datetime import datetime
from random import sample as choose
from random import random
from numba import jit, njit, prange
from plot import qtplot
from neuropercolation import Simulate4Layers
# Grid of eps values: 20 evenly spaced points from 0.01 to 0.2 (np.linspace
# includes the end point by default), stored as a plain list of NumPy floats.
# Not referenced in the visible part of this file; presumably swept by callers.
eps_space = list(np.linspace(0.01,0.2,20))
def new_folder(path):
    """Make sure the directory *path* exists and return *path* unchanged.

    Missing parent directories are created as well.  ``exist_ok=True`` is used
    instead of a separate ``os.path.exists`` check so that two processes
    creating the same folder at the same time cannot race each other into a
    ``FileExistsError``.

    Raises:
        FileExistsError: if *path* already exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
    return path
# Angle of the vector (x, y), shifted by pi, reduced modulo 2*pi and shifted
# back, i.e. wrapped into [-pi, pi).  Vectorized so it accepts arrays.
phase = np.vectorize(lambda x,y: (m.atan2(y,x)+m.pi)%(2*m.pi)-m.pi)
# Signed difference y - x between two angles, wrapped into [-pi, pi) the same way.
diff = np.vectorize(lambda x,y: (y-x+m.pi)%(2*m.pi)-m.pi)


def H2(x):
    """Return the binary entropy of *x* in bits: -x*log2(x) - (1-x)*log2(1-x).

    At the end points x == 0 and x == 1 the entropy is 0 (the usual
    0*log2(0) = 0 convention); they are handled explicitly because
    ``math.log2(0)`` raises ``ValueError``.

    Raises:
        ValueError: if *x* lies outside the closed interval [0, 1].
    """
    if x == 0 or x == 1:
        return 0.0
    return -x*m.log2(x)-(1-x)*m.log2(1-x)
@njit
def neighbor(digit0, digit1, lenght):
    """Tell whether two flat cell indices are adjacent on a square periodic grid.

    The grid side is ``dim = int(sqrt(int(lenght)))``; an index ``d`` is mapped
    to the coordinates ``(d % dim, d // dim)``.  The result is True when at
    least one coordinate difference is 0 and at least one coordinate
    difference is 1 or ``dim - 1`` (the latter being the wrap-around step).
    """
    side = int(np.sqrt(int(lenght)))
    # Absolute coordinate distances along both grid axes.
    dx = abs(digit1 % side - digit0 % side)
    dy = abs(digit1 // side - digit0 // side)
    shares_axis = dx == 0 or dy == 0
    one_step = dx == 1 or dx == side - 1 or dy == 1 or dy == side - 1
    return bool(shares_axis and one_step)
@njit
def kcomb(zp,zm):
    """Look up the weight associated with the count pair (*zp*, *zm*).

    The branches are tested in order, so earlier rules win:

    * ``zp > 2``            -> 1
    * ``zm > 2``            -> 0
    * ``zp == zm``          -> 0.5
    * ``zm == 2``           -> 0.5**(3-zp)
    * ``zp == 2``           -> 1 - 0.5**(3-zm)
    * ``(zp, zm) == (1, 0)`` -> 9/16
    * ``(zp, zm) == (0, 1)`` -> 7/16

    Raises:
        NotImplementedError: for any pair not covered by the rules above.
    """
    if zp>2:
        return 1
    if zm>2:
        return 0
    if zp==zm:
        return 0.5
    if zm==2:
        return 0.5**(3-zp)
    if zp==2:
        return 1-0.5**(3-zm)
    if zm==0 and zp==1:
        return 9/16
    if zp==0 and zm==1:
        return 7/16
    raise NotImplementedError(zp,zm)
# Base directory for the per-partition phi tables written by calc_mips and read
# by smartMIP.  NOTE: new_folder creates this directory as a side effect at
# import time if it does not exist yet.
path = new_folder('/cloud/Public/_data/neuropercolation/1lay/mips/')
def phi(dim,statestr,partstr,eps):
    """Evaluate the phi value of one partition of a dim x dim periodic grid.

    Args:
        dim: side length of the square grid.
        statestr: ``dim**2`` digits giving the cell states row by row.
        partstr: ``dim**2`` digits labelling the part each cell belongs to.
        eps: probability parameter; ``eta = 1 - eps`` is its complement.

    For every cell two probabilities are formed:

    * ``beps`` is ``eta`` when the cell plus its four wrap-around neighbours
      sum to more than 2, and ``eps`` otherwise;
    * ``pi`` mixes ``eps`` and ``eta`` with the weight ``kcomb(zplus, zminus)``,
      where ``zplus`` / ``zminus`` count the ones / zeros among the cell itself
      and those neighbours that carry the same part label.

    Returns:
        The sum over all cells of the cross-entropy (in bits) of ``beps``
        against ``pi``, minus ``dim**2 * H2(eps)``.

    Raises:
        ValueError: if a string does not contain exactly ``dim**2`` digits.
    """
    eta = 1-eps
    n_cells = dim**2
    # reshape doubles as a length check on both digit strings
    state = np.array([int(ch) for ch in statestr]).reshape((dim,dim)).tolist()
    part = np.array([int(ch) for ch in partstr]).reshape((dim,dim)).tolist()

    # down, up, right, left -- all taken modulo dim (periodic boundary)
    steps = ((1,0),(-1,0),(0,1),(0,-1))

    beps = []
    kplus = []
    for i in range(dim):
        beps_row = []
        kplus_row = []
        for j in range(dim):
            own = state[i][j]
            total = own            # cell + all four neighbours
            zplus = own            # ones within the same part
            zminus = 1-own         # zeros within the same part
            for di,dj in steps:
                ni = (i+di)%dim
                nj = (j+dj)%dim
                nb = state[ni][nj]
                total += nb
                if part[i][j]==part[ni][nj]:
                    zplus += nb
                    zminus += 1-nb
            beps_row.append(int(total>2)*eta+int(total<3)*eps)
            kplus_row.append(kcomb(zplus,zminus))
        beps.append(beps_row)
        kplus.append(kplus_row)

    crossent = []
    for beps_row,kplus_row in zip(beps,kplus):
        ce_row = []
        for b,k in zip(beps_row,kplus_row):
            p = eps*(1-k) + eta*k
            ce_row.append(-b*m.log2(p)-(1-b)*m.log2(1-p))
        crossent.append(ce_row)
    return np.sum(crossent) - n_cells*H2(eps)
def MIP(dim,statestr,eps):
    """Exhaustively search all bipartitions for the one(s) with the lowest phi.

    Partitions are enumerated as the binary representations of
    ``1 .. 2**(dim**2 - 1) - 1``, zero-padded to ``dim**2`` digits, and each is
    scored with ``phi``.  A progress line is printed when the search is done.

    Returns:
        ``(mip, lophi)``: the list of partition strings that reach the minimum
        (exact float ties are all kept, in enumeration order) and that minimum.
        If no partition is enumerated the result is ``([], inf)``.
    """
    n_cells = dim**2
    lophi = np.inf
    mip = []
    for code in range(1, 2**(n_cells-1)):
        candidate = format(code, 'b').zfill(n_cells)
        score = phi(dim, statestr, candidate, eps)
        if score < lophi:
            lophi, mip = score, [candidate]
        elif score == lophi:
            mip.append(candidate)
    print(f'Done with {dim},{eps} = {mip},{lophi}')
    return mip,lophi
def calc_mips(dim,eps):
    """Tabulate phi of every bipartition for the all-zero state and save it.

    The partitions are the binary representations of
    ``1 .. 2**(dim**2 - 1) - 1`` (zero-padded to ``dim**2`` digits); each phi
    value is rounded to 6 decimals.  The resulting list is written as JSON to
    ``<path>/dim=<dim>/eps=<eps>_mips.txt``; the folder is created on demand.
    """
    n_cells = dim**2
    zero_state = '0'*n_cells
    table = [
        round(phi(dim, zero_state, format(code, 'b').zfill(n_cells), eps), 6)
        for code in range(1, 2**(n_cells-1))
    ]
    mipath = new_folder(path+f'dim={dim:02d}/')
    with open(mipath+f"eps={round(eps,3):.3f}_mips.txt", 'w', encoding='utf-8') as f:
        json.dump(table, f, indent=1)
def smartMIP(dim,statestr,eps):
    """Search for the minimum-phi partition(s), skipping unpromising candidates.

    Step 1 scores every partition that isolates a single cell.  Step 2 walks
    over all bipartitions (binary representations of
    ``1 .. 2**(dim**2 - 1) - 1``) and evaluates ``phi`` only for those whose
    precomputed table entry is below the most recently evaluated phi.  The
    table is read from ``<path>/dim=<dim>/eps=<eps>_mips.txt``, the file that
    ``calc_mips`` writes for the all-zero state.

    Returns:
        ``(mip, lophi)``: the partition strings reaching the minimum phi (each
        listed once) and that minimum.

    Raises:
        FileNotFoundError: if the table for this ``dim`` / ``eps`` is missing.
    """
    n_cells = dim**2
    lophi=np.inf
    mip = []
    # Step 1: single-cell partitions (exactly one '1' in the string).
    for parti in range(0,n_cells):
        partstr = bin(2**parti)[2:].zfill(n_cells)
        curphi = phi(dim,statestr,partstr,eps)
        if curphi<lophi:
            lophi=curphi
            mip = [partstr]
        elif curphi==lophi:
            mip.append(partstr)
    mipath = new_folder(path+f'dim={dim:02d}/')
    with open(mipath+f"eps={round(eps,3):.3f}_mips.txt", 'r', encoding='utf-8') as f:
        mips=json.load(f)
    # Step 2: all bipartitions, filtered by the precomputed table.
    # NOTE(review): the threshold is the last evaluated phi (curphi), not the
    # running minimum (lophi) -- confirm that this is the intended bound.
    for parti in range(1,2**(n_cells-1)):
        partstr = bin(parti)[2:].zfill(n_cells)
        if mips[parti-1]<curphi:
            curphi = phi(dim,statestr,partstr,eps)
            if curphi<lophi:
                lophi=curphi
                mip = [partstr]
            elif curphi==lophi and partstr not in mip:
                # Single-cell partitions were already scored in step 1;
                # without this check a tied minimum would be listed twice.
                mip.append(partstr)
    return mip,lophi