from __future__ import print_function
from h2p3s import *

# Build the test problem. The call returns four values; judging only by their
# names they are presumably the distributed matrix, the distributed right-hand
# side, a reference solution and a local-to-global index map -- TODO confirm
# against generate_problem.
A_d, b_d, x_s, loc2glob = generate_problem(geom="Cube", ni=4, dim=3, K1=1, K2=1, nLineK=3, random_rhs=True)

# run(ConjGrad(M=AdditiveSchwarz), A_d, b_d, x_s)
# Wrap A_d in a DistSchur object (no interface solver, SpFacto as local solver)
# and keep its `S` attribute, which is the operator used for the rest of the
# script (presumably the distributed Schur complement -- confirm).
s = DistSchur(A_d, interface_solver=None, local_solver=SpFacto)
S = s.S

# Object attached to S that supplies `ni` and `D` below (presumably the
# domain-decomposition description -- confirm).
dd = S.dd

#Si = ssp.eye(dd.ni).A * (rank+1.)
#S = DistMatrix(Si, dd)
# Manufactured right-hand side: S applied to a vector of ones with dd.ni rows
# and one column, so the all-ones vector satisfies S x = f by construction.
f = S.dot(DistVector(np.ones((dd.ni, 1)), dd))
# Local part of f multiplied by dd.D (presumably a weighting of the unknowns
# shared between subdomains -- confirm). Used later when recovering x from l.
fi = S.dd.D.dot(f.local)

# Reference solve of S x = f through ConjGrad (assuming ConjGrad(S).dot applies
# the solver -- confirm).
# NOTE(review): this x is only referenced by a commented-out print and is
# rebound later in the script.
x = ConjGrad(S).dot(f)

# Robin-Robin operator built on S with Ti=.1, and the right-hand side that the
# operator derives from f.
r = RobinRobin_operator(S, Ti=.1)
rhs = r.rhs(f)

# print(rank, "x", x.local.T)
# print(rank, "S", S.local)
# print(rank, "f", f.local.T)
# print(rank, "fi", fi.T)
#print(rank, "rhs1", rhs.local.T)
#TimeIt.DEBUG = rank==1

def cb(x):
    """Print the transposed local part of ``x``, prefixed by ``rank`` and "cb".

    ``x`` only needs a ``local`` attribute that exposes ``.T``; ``rank`` is
    looked up in the module namespace at call time.
    """
    local_row = x.local.T
    print(rank, "cb", local_row)

#l = DistVector2(S.local.dot(x.local) + x.local - fi, dd)

# Start from zero: 0*rhs gives an initial iterate built from rhs itself, so it
# needs no explicit constructor call.
l = 0*rhs
# 200 fixed-step updates l <- l + 0.2 * (rhs - r.dot(l)), i.e. a stationary
# iteration with step 0.2 driving the residual of r l = rhs towards zero.
# The in-place `+=` relies on the operator overloads of rhs's type.
for i in range(200):
    l += .2*(rhs - r.dot(l))

# Set the TimeIt debug flag to 0 before the final checks.
TimeIt.DEBUG = 0

#print(rank, "rhs2", rhs.local.T)

# Residual of the Robin-Robin system for the final iterate.
res = r.dot(l) - rhs

print(rank, "res", res.local.T)

# Rebuild x from l: apply r.Sih_solver to the local part of l plus fi and wrap
# the result as a DistVector on S.dd (presumably a local solve recovering the
# interface solution -- confirm against RobinRobin_operator).
x = DistVector(r.Sih_solver.dot(l.local+fi), S.dd)

# Residual of the original system S x = f for the recovered x.
# NOTE(review): this print uses the same "res" label as the one above, so the
# two residuals can only be told apart by their order in the output.
print(rank, "res", (S.dot(x) - f).local.T)



# 
# xs = ConjGrad(S).dot(f)
# li = S.local.dot(xs.local) - S.dd.D.dot(f.local)
# 
# rhs_ = r.dot(li)
# print(rank, rhs.T.round(4), rhs_.T.round(4))
#