Commit 94709cc2 authored by Fabrizio Detassis's avatar Fabrizio Detassis
Browse files

Crime and Adult test set updated, SMT implementation refined.

parent 2b06e488
import numpy as np
from moving_target_abc import MovingTarget
from pysmt.shortcuts import Symbol, LE, GE, LT, GT, Int, And, Equals, Plus, Minus, Div, Times, Max, is_sat, get_model, \
Real, Solver, Ite
from pysmt.typing import REAL, INT
from pysmt.shortcuts import Symbol, And, LE, GE, LT, GT, is_sat, get_model, Real, Int, Solver, Ite
from pysmt.shortcuts import Equals, Plus, Minus, Div, Times
from pysmt.typing import REAL, INT, BOOL
from constraint import InequalityRegGlobalConstraint, BalanceConstraint
from constraint import FairnessRegConstraint, FairnessClsConstraint
import utils
import signal
from contextlib import contextmanager
import _thread
import threading
import sys
import multiprocess
import time
class TimeoutException(Exception):
......@@ -19,7 +21,7 @@ class TimeoutException(Exception):
@contextmanager
def time_limit2(seconds, msg=''):
def time_limit(seconds, msg=''):
timer = threading.Timer(seconds, lambda: _thread.interrupt_main())
timer.start()
try:
......@@ -31,18 +33,6 @@ def time_limit2(seconds, msg=''):
timer.cancel()
@contextmanager
def time_limit(seconds):
def signal_handler(signum, frame):
raise TimeoutException("Timed out!")
signal.signal(signal.SIGALRM, signal_handler)
signal.alarm(seconds)
try:
yield
finally:
signal.alarm(0)
class SMTModel:
_TIME_LIMIT = 5
......@@ -62,7 +52,109 @@ class SMTModel:
for c in constraints:
self.add_constraint(c)
def optimize(self, alpha=0.05, max_iter=20):
def optimize(self, max_iter=20):
it = 0
domain = self.domain
loss = self.loss
ub = None
lb = None
print("Optimizing SMT Model")
with Solver('msat') as solver:
solver.add_assertion(domain)
solver.push()
unsat_counts = 0
if solver.solve():
while it < max_iter and unsat_counts < 3:
print("Iteration ", str(it))
# print(solver.print_model())
if ub is None:
ub = solver.get_py_value(loss)
if lb is None:
lb = 0
print("Bounds: [%.2f - %.2f]" % (lb, ub))
if np.abs(ub - lb) > 0.01:
it += 1
else:
break
bound = (ub + lb) / 2
solver.add_assertion(LT(loss, Real(bound)))
solver.push()
try:
with time_limit(self._TIME_LIMIT):
sat = solver.solve()
except TimeoutException:
print("Timed out!")
sat = False
"""
sat = False
queue = multiprocess.Queue()
p = multiprocess.Process(target=self.solve, args=(solver, queue))
p.start()
sat = queue.get()
p.join(self._TIME_LIMIT)
if p.is_alive():
print("Timed out!")
p.terminate()
p.close()
manager = multiprocess.Manager()
return_dict = manager.dict()
p = multiprocess.Process(target=self.solve, args=(solver, return_dict))
p.start()
p.join(self._TIME_LIMIT)
if p.is_alive():
print("Timed out!")
p.terminate()
sat = return_dict['sat']
if sat is None:
sat = False
"""
# Update the bounds.
if sat:
ub = bound
# Distinguish between regression and classification models.
if isinstance(self.variables[0], list):
y_opt = [
[float(solver.get_py_value(_y)) for _y in y] for y in self.variables
]
else:
y_opt = [float(solver.get_py_value(y)) for y in self.variables]
unsat_counts = 0
print("SAT!")
else:
print("UNSAT!")
solver.pop()
lb = bound
unsat_counts += 1
return y_opt
else:
print("Problem not satisfiable!")
@staticmethod
def solve(solver, q):
sat = solver.solve()
q.put(sat)
# q['sat'] = sat
def optimize_OLD(self, alpha=0.05, max_iter=20):
"""
Heuristic search for the optimal value of the SMT.
We proceed as follows:
......@@ -79,7 +171,7 @@ class SMTModel:
sat = False
while it < max_iter and unsat_counts < 3:
model = get_model(domain)
model = get_model(domain, solver_name='msat')
print("Solving problem")
current_loss = model.get_py_value(loss)
print("It: %d; current loss: %.2f" % (it, current_loss))
......@@ -93,16 +185,14 @@ class SMTModel:
else:
break
# try:
# with time_limit2(self._TIME_LIMIT):
# sat = is_sat(And(self.domain, LT(loss, Real(lb))))
# except TimeoutException as e:
# print("Timed out!")
# sat = False
# except Exception as e:
# raise e
try:
with time_limit(self._TIME_LIMIT):
sat = is_sat(And(self.domain, LT(loss, Real(lb))))
except TimeoutException:
print("Timed out!")
sat = False
sat = is_sat(And(self.domain, LT(loss, Real(lb))))
# sat = is_sat(And(self.domain, LT(loss, Real(lb))))
if sat:
print("SAT!")
domain = And(self.domain, LT(loss, Real(lb)))
......@@ -111,7 +201,7 @@ class SMTModel:
alpha = alpha / 2
unsat_counts += 1
model = get_model(domain)
model = get_model(domain, solver_name='msat')
# Distinguish between regression and classification models.
try:
......@@ -328,7 +418,7 @@ class MovingTargetClsSMT(MovingTarget):
pfeat = c[2]
cval = c[3]
cstr = FairnessClsConstraint('ct', pfeat, cval)
abs_val = [Symbol("y_%d" % i, REAL) for i in range(len(np.unique(x_s[:, pfeat])) * self.n_classes)]
abs_val = [Symbol("y_%d" % i, REAL) for i in range(len(np.unique(x_s[:, pfeat])) * self.n_classes)]
# Add fairness constraint.
for ix_feat in pfeat:
......
......@@ -4,3 +4,4 @@ matplotlib==3.4.1
numpy==1.20.2
pysmt==0.9.0
jupyter==1.0.0
scikit-learn
\ No newline at end of file
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment