neuropercolation/evaluation/phi.py

200 lines
6.5 KiB
Python
Raw Normal View History

2023-09-30 17:53:06 +00:00
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Sep 27 04:39:54 2023

@author: timofej
"""
import os
import json
import math as m
import numpy as np
from numpy.linalg import norm
from datetime import datetime
from random import sample as choose
from random import random
2023-12-10 21:47:15 +00:00
import itertools
#from numba import jit, njit, prange
2023-09-30 17:53:06 +00:00
2023-12-10 21:47:15 +00:00
# from plot import qtplot
2023-09-30 17:53:06 +00:00
2023-12-10 21:47:15 +00:00
# from neuropercolation import Simulate4Layers
2023-09-30 17:53:06 +00:00
# Noise levels to scan: 20 evenly spaced values from 0.01 to 0.2 (inclusive).
eps_space = list(np.linspace(start=0.01, stop=0.2, num=20))
def new_folder(path):
    """Make sure directory ``path`` exists (creating parents too) and return it.

    ``exist_ok=True`` replaces the former ``os.path.exists`` check followed by
    ``os.makedirs``: with that pattern, a directory created by another process
    between the check and the call made ``makedirs`` raise ``FileExistsError``.

    Raises:
        FileExistsError: if ``path`` exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
    return path
def _wrap_angle(angle):
    """Map an angle in radians into the interval [-pi, pi)."""
    return (angle + m.pi) % (2 * m.pi) - m.pi


# Elementwise polar angle of the point (x, y), wrapped into [-pi, pi).
phase = np.vectorize(lambda x, y: _wrap_angle(m.atan2(y, x)))
# Elementwise signed angular difference y - x, wrapped into [-pi, pi).
diff = np.vectorize(lambda x, y: _wrap_angle(y - x))


def H2(x):
    """Binary entropy, in bits, of a Bernoulli(x) variable.

    Uses the convention 0 * log2(0) = 0, so the endpoints give
    H2(0) == H2(1) == 0 instead of raising a math domain error.

    Raises:
        ValueError: if x lies outside [0, 1].
    """
    if x == 0 or x == 1:
        return 0.0
    return -x * m.log2(x) - (1 - x) * m.log2(1 - x)
def neighbor(digit0, digit1, lenght):
    """Tell whether two cells are orthogonally adjacent on a periodic square grid.

    ``lenght`` is the total number of cells. The grid side is
    ``int(sqrt(lenght))``, and a cell index maps to ``(index % side, index // side)``.
    The cells are adjacent when one coordinate offset is 0 and the other is
    1 or ``side - 1`` (wrap-around). For ``side == 1`` a cell therefore counts
    as its own neighbour.
    """
    side = int(np.sqrt(int(lenght)))
    dx = abs(digit1 % side - digit0 % side)
    dy = abs(digit1 // side - digit0 // side)
    offsets = {dx, dy}
    one_step = {1, side - 1}
    return 0 in offsets and not one_step.isdisjoint(offsets)
def kcomb(zp, zm):
    """Weight for a cell's next state given zp known ones and zm known zeros.

    ``phi`` calls this with the numbers of active (zp) and inactive (zm)
    cells among a cell and its neighbours that share its part label.

    Returns:
        None if zp + zm > 5; 1 if zp > 2; 0 if zm > 2; 0.5 if zp == zm;
        0.5**(3 - zp) if zm == 2; 1 - 0.5**(3 - zm) if zp == 2;
        9/16 for (1, 0) and 7/16 for (0, 1).

    Raises:
        NotImplementedError: for any other combination (e.g. negative counts).

    NOTE(review): every other branch equals the probability that at least 3
    of 5 inputs are 1 when the 5 - zp - zm unknown inputs are fair coin flips.
    By that rule, (1, 0) would give 11/16 and (0, 1) would give 5/16. Confirm
    that 9/16 and 7/16 are intended.
    """
    if zp + zm > 5:
        return None
    if zp > 2:
        return 1
    if zm > 2:
        return 0
    if zp == zm:
        return 0.5
    if zm == 2:
        return 0.5 ** (3 - zp)
    if zp == 2:
        return 1 - 0.5 ** (3 - zm)
    if (zp, zm) == (1, 0):
        return 9 / 16
    if (zp, zm) == (0, 1):
        return 7 / 16
    raise NotImplementedError(zp, zm)
#%%
# kcomb tabulated over all (zp, zm) pairs with 0 <= zp, zm <= 5 (rows indexed by zp).
neighbourcomb = [[(zp, zm) for zm in range(6)] for zp in range(6)]
probcombs = [[kcomb(zp, zm) for zp, zm in row] for row in neighbourcomb]
# Every distinct value in that table (None included, from zp + zm > 5).
posscombs = set(itertools.chain.from_iterable(probcombs))
# kcomb along zp + zm == 4: (0, 4), (1, 3), ..., (4, 0).
probcombs1 = [kcomb(zp, 4 - zp) for zp in range(5)]
def eta(eps):
    """Return 1 - eps."""
    return 1 - eps


def etastar(eps, kp):
    """Mixture kp * (1 - eps) + (1 - kp) * eps."""
    return kp*(1-eps)+(1-kp)*eps


def epsstar(eps, kp):
    """Mixture kp * eps + (1 - kp) * (1 - eps); mathematically 1 - etastar(eps, kp)."""
    return kp*eps+(1-kp)*(1-eps)


def KL(eps, kp):
    """KL divergence, in bits, of Bernoulli(eta(eps)) from Bernoulli(etastar(eps, kp)).

    Returns 0 when ``kp`` is None (kcomb returns None for zp + zm > 5).
    Zero-weight terms contribute 0 (the 0 * log2(0) = 0 convention), so
    eps = 0 and eps = 1 are accepted. A genuinely infinite divergence still
    raises ZeroDivisionError.
    """
    if kp is None:
        return 0
    p = eta(eps)
    first = p*m.log2(p/etastar(eps, kp)) if p != 0 else 0.0
    second = eps*m.log2(eps/epsstar(eps, kp)) if eps != 0 else 0.0
    return first + second


def phi_mat(eps):
    """6x6 table of KL(eps, kcomb(i, j)) for i, j in 0..5."""
    return [[KL(eps, kcomb(i, j)) for j in range(6)] for i in range(6)]


def phi_vals(eps, df=4):
    """Sorted distinct sums of ``df`` terms, each either KL(eps, 1) or KL(eps, 0.5)."""
    return sorted({sum([KL(eps, kp) for kp in kps])
                   for kps in itertools.product([1, 0.5], repeat=df)})


def phi_sing(eps):
    """Return KL(eps, kcomb(1, 0))."""
    return KL(eps, kcomb(1, 0))


def phi_lims(eps):
    """phi_sing(eps) + v for each v in phi_vals(eps) (default df), ascending."""
    base = phi_sing(eps)
    return [base+val for val in phi_vals(eps)]
#%%
# Root directory for the MIP result files. This module-level call to
# new_folder creates the directory when the module is imported.
_MIPS_ROOT = '/cloud/Public/_data/neuropercolation/1lay/mips/'
path = new_folder(_MIPS_ROOT)
def phi(dim,statestr,partstr,eps):
    """Cost, in bits, of predicting every cell from its own part only.

    ``statestr`` (cell states) and ``partstr`` (part labels) are read row-major
    as dim x dim grids of digits. Neighbour indexing wraps around at the edges.
    For every cell:

    * the reference probability is ``1 - eps`` if the cell plus its four
      neighbours sum to more than 2, and ``eps`` otherwise;
    * ``kcomb`` maps the counts of active / inactive cells (for 0/1 states)
      among the cell and its same-label neighbours to a weight ``k``, and the
      cut prediction is ``eps * (1 - k) + (1 - eps) * k``;
    * the cross-entropy (bits) of that prediction against the reference is taken.

    The per-cell cross-entropies are summed with ``np.sum``, and
    ``dim**2 * H2(eps)`` is subtracted from the total.

    Raises:
        ValueError: if either string is not exactly dim**2 digits.
    """
    eta = 1 - eps
    n_cells = dim * dim
    axis = range(dim)

    def grid(digits):
        # Row-major digit string -> dim x dim nested list of Python ints.
        return np.array([int(d) for d in digits]).reshape((dim, dim)).tolist()

    state = grid(statestr)
    part = grid(partstr)

    def around(g, i, j):
        # Values at (i+1, j), (i-1, j), (i, j+1), (i, j-1) with wrap-around.
        return [g[(i + 1) % dim][j], g[(i - 1) % dim][j],
                g[i][(j + 1) % dim], g[i][(j - 1) % dim]]

    def linked(i, j):
        # For each of those neighbours: does it carry the same part label as (i, j)?
        return [label == part[i][j] for label in around(part, i, j)]

    # Each stage is built for the whole grid before the next one starts.
    drive = [[state[i][j] + sum(around(state, i, j)) for j in axis] for i in axis]
    target = [[eta if drive[i][j] > 2 else eps for j in axis] for i in axis]
    ones = [[state[i][j] + sum(s * same for s, same in zip(around(state, i, j), linked(i, j)))
             for j in axis] for i in axis]
    zeros = [[(1 - state[i][j]) + sum((1 - s) * same for s, same in zip(around(state, i, j), linked(i, j)))
              for j in axis] for i in axis]
    weight = [[kcomb(ones[i][j], zeros[i][j]) for j in axis] for i in axis]
    predicted = [[eps * (1 - weight[i][j]) + eta * weight[i][j] for j in axis] for i in axis]
    xent = [[-target[i][j] * m.log2(predicted[i][j]) - (1 - target[i][j]) * m.log2(1 - predicted[i][j])
             for j in axis] for i in axis]
    return np.sum(xent) - n_cells * H2(eps)
def MIP(dim,statestr,eps,*,rel_tol=1e-9,abs_tol=1e-12):
    """Exhaustively find the minimum-phi bipartition(s) of a dim x dim grid.

    Each partition index ``parti`` in 1 .. 2**(dim**2-1)-1 becomes a bit
    string of part labels (row-major, zero-padded to dim**2). The leading
    cell is therefore always in part 0, and the single-part labelling is
    skipped. Because ``phi`` only compares labels for equality, this visits
    each split into two non-empty parts exactly once.

    Args:
        dim: grid side length.
        statestr: dim**2 digits giving the cell states, row-major.
        eps: noise level passed to phi.
        rel_tol, abs_tol: tolerances (as in math.isclose) for treating two phi
            values as a tie. Passing 0 for both restores exact comparison.

    Returns:
        (mip, lophi): the partition strings attaining the minimum, and the
        minimum itself. Returns ([], inf) when dim == 1 (nothing to try).

    Prints a one-line summary when done.
    """
    lophi = np.inf
    mip = []
    for parti in range(1, 2**(dim**2-1)):
        partstr = bin(parti)[2:].zfill(dim**2)
        curphi = phi(dim, statestr, partstr, eps)
        # Mirror-image or translated partitions give the same phi
        # mathematically, but their per-cell terms are summed in a different
        # order, so the results can differ in the last bits. Exact equality
        # would drop such ties.
        if m.isclose(curphi, lophi, rel_tol=rel_tol, abs_tol=abs_tol):
            mip.append(partstr)
            lophi = min(lophi, curphi)
        elif curphi < lophi:
            lophi = curphi
            mip = [partstr]
    print(f'Done with {dim},{eps} = {mip},{lophi}')
    return mip, lophi
def calc_mips(dim,eps,save=True):
    """Bucket every bipartition of the all-zero dim x dim grid by phi thresholds.

    Each partition index ``parti`` in 1 .. 2**(dim**2-1)-1 (bit string of part
    labels, row-major, leading cell in part 0) is scored with ``phi`` on the
    all-zero state. It is appended to ``mips[cha]`` for every threshold
    ``phi_lims(eps)[cha]`` it falls strictly below, and each hit is printed.

    Args:
        dim: grid side length.
        eps: noise level.
        save: if true, each bucket is written as JSON to
            ``path + 'dim=DD/eps=E.EEE_mipis_{cha}.txt'``. If false, nothing
            is written and no directory is created. (Previously the files were
            still opened in 'w' mode, which truncated existing results.)

    Returns:
        One list of partition indices per threshold.
    """
    lims = phi_lims(eps)
    # One bucket per threshold (5 with phi_vals' default df=4).
    mips = [[] for _ in lims]
    statestr = '0' * dim**2
    for parti in range(1, 2**(dim**2-1)):
        partstr = bin(parti)[2:].zfill(dim**2)
        curphi = phi(dim, statestr, partstr, eps)
        for cha, (mip, lim) in enumerate(zip(mips, lims)):
            if curphi < lim:
                mip.append(parti)
                print(f'Part {partstr} in mips[{cha}] with phi={curphi}')
    if save:
        mipath = new_folder(path+f'dim={dim:02d}/')
        for cha, mip in enumerate(mips):
            with open(mipath+f"eps={round(eps,3):.3f}_mipis_{cha}.txt", 'w', encoding='utf-8') as f:
                json.dump(mip, f, indent=1)
    return mips
def smartMIP(dim,statestr,eps,*,rel_tol=1e-9,abs_tol=1e-12):
    """Minimum-phi bipartition search that skips partitions a stored value rules out.

    First, every single-cell cut is scored to obtain an initial minimum. Next,
    the JSON list in ``path + 'dim=DD/eps=E.EEE_mips.txt'`` is loaded; entry
    ``parti-1`` belongs to partition index ``parti``. ``phi`` is then evaluated
    only for partitions whose stored value does not exceed the current minimum.

    NOTE(review): this assumes the stored values are lower bounds on phi, one
    per partition. Confirm this against whatever writes ``_mips.txt``:
    calc_mips writes ``_mipis_{cha}.txt`` files holding partition indices, not
    this file.

    Partition strings use MIP's labelling: bit strings of length dim**2 whose
    leading cell is in part 0 (when dim >= 2), each bipartition reported once.
    The previous version compared the stored value against the last computed
    phi instead of the running minimum, skipped partitions whose value equalled
    the minimum, and reported single-cell cuts twice.

    Args:
        rel_tol, abs_tol: tie tolerances as in math.isclose (0 and 0 = exact).

    Returns:
        (mip, lophi): the partition strings attaining the minimum, and the minimum.
    """
    n = dim**2
    lophi = np.inf
    mip = []
    seen = set()

    def evaluate(parti):
        # Score one partition index and fold it into the running minimum.
        nonlocal lophi, mip
        seen.add(parti)
        partstr = bin(parti)[2:].zfill(n)
        curphi = phi(dim, statestr, partstr, eps)
        if m.isclose(curphi, lophi, rel_tol=rel_tol, abs_tol=abs_tol):
            mip.append(partstr)
            lophi = min(lophi, curphi)
        elif curphi < lophi:
            lophi = curphi
            mip = [partstr]

    # Seed with the single-cell cuts. phi only compares part labels for
    # equality, so the cut isolating the leading cell ('10...0') is the same
    # bipartition as its complement ('01...1'). Use the complement so that the
    # second loop neither re-evaluates nor re-reports it.
    for cell in range(n):
        parti = 2**cell
        if n > 1 and parti >= 2**(n-1):
            parti = 2**n - 1 - parti
        if parti not in seen:
            evaluate(parti)

    mipath = new_folder(path+f'dim={dim:02d}/')
    with open(mipath+f"eps={round(eps,3):.3f}_mips.txt", 'r', encoding='utf-8') as f:
        bounds = json.load(f)
    for parti in range(1, 2**(n-1)):
        if parti in seen:
            continue  # already scored as a single-cell cut
        bound = bounds[parti-1]
        # Skip only if the stored value is clearly above the minimum; equal
        # (within tolerance) may still be a tie.
        if bound > lophi and not m.isclose(bound, lophi, rel_tol=rel_tol, abs_tol=abs_tol):
            continue
        evaluate(parti)
    return mip, lophi