Commit bb41dd1d authored by Fabrizio Detassis's avatar Fabrizio Detassis
Browse files

Added implementation examples for SMT and Cplex.

parent 9d0a385a
# Input Parameters:
# - n_points -> number of examples.
# - n_classes -> number of classes
import numpy as np
from docplex.mp.model import Model as CPModel
n_points = 10
n_classes = 5
"""
Variable matrix of the type
Z = [
[ 0, 0, 0, 1, 0],
[ 1, 0, 0, 0, 0],
[ 0, 1, 0, 0, 0],
...
[ 1, 0, 0, 0, 0]
]
"""
# Build a model
mod = CPModel('Class balancer')
# Define the (binary) variables.
Z = mod.binary_var_matrix(keys1=n_points,
keys2=n_classes,
name='z')
# Each example has to be assigned to one class.
for i in range(n_points):
xpr = mod.sum(Z[i, c] for c in range(n_classes))
mod.add_constraint(xpr == 1)
# Balance constraint.
print("Adding Balance Constraint")
B = int(np.ceil(n_points / n_classes))
for c in range(n_classes):
xpr = mod.sum([Z[i, c] for i in range(n_points)])
mod.add_constraint(xpr <= B)
# Define the loss w.r.t. the true labels
# p_loss = (1 / n_points) * mod.sum([(1 - Z[i, p[i]]) for i in range(n_points)])
# y_loss = (1 / n_points) * mod.sum([(1 - Z[i, y[i]]) for i in range(n_points)])
# Solve the problem
sol = mod.solve()
if sol:
sat = mod.get_solve_status()
print("Status: " + str(sat))
zarr = [sum(c * sol.get_value(Z[i, c])
for c in range(n_classes)) for i in range(n_points)]
# Z_sol = np.array([int(v) for v in zarr])
Z_sol = np.array([
[sol.get_value(Z[i, c]) for c in range(n_classes)] for i in range(n_points)
], dtype=int)
print(Z_sol)
else:
print("No solution found")
# Input Parameters:
# - n_points -> number of examples.
# - n_classes -> number of classes
import numpy as np
from pysmt.shortcuts import Symbol, LE, GE, Int, And, Equals, Plus, is_sat, get_model
from pysmt.typing import INT
n_points = 10
n_classes = 5
"""
Variable matrix of the type
Z = [
[ 0, 0, 0, 1, 0],
[ 1, 0, 0, 0, 0],
[ 0, 1, 0, 0, 0],
...
[ 1, 0, 0, 0, 0]
]
"""
Z = [[Symbol("z_%d%d" % (i, j), INT) for i in range(n_classes)]
for j in range(n_points)]
# Define the variable's domain.
bool_domain = And(And(And(GE(z, Int(0)),
LE(z, Int(1))) for z in Z[i]) for i in range(n_points))
# Each example has to be assigned to one class.
class_domain = And(Equals(Int(1), Plus(Z[i])) for i in range(n_points))
# Balance constraint.
print("Adding Balance Constraint")
B = int(np.ceil(n_points / n_classes))
bal_domain = And(LE(Plus([Z[i][j] for i in range(n_points)]), Int(B)) for j in range(n_classes))
formula = And(And(bal_domain, class_domain), bool_domain)
# print("Serialization of the formula:")
# print(formula)
model = get_model(formula)
if model:
# print(model)
sat = is_sat(formula, solver_name='msat')
print("SAT: " + str(sat))
Z_sol = np.array([[model.get_value(Z[j][i]) for i in range(n_classes)] for j in range(n_points)])
print(Z_sol)
else:
print("No solution found")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment