Commit 98ad5eed authored by Uwe Köckemann's avatar Uwe Köckemann
Browse files

Fixed more broken references and some rogue comments.

parent 168d3964
package org.aiddl.common.learning.linear_regression;
import org.aiddl.common.math.linear_algebra.GivensQrDecomposition;
import org.aiddl.common.math.linear_algebra.IdentityMatrix;
import org.aiddl.common.math.linear_algebra.LupInverter;
import org.aiddl.common.math.linear_algebra.MatrixMultiplication;
import org.aiddl.common.math.linear_algebra.MatrixScalarMultiplication;
import org.aiddl.common.math.linear_algebra.MatrixVectorMultiplication;
import org.aiddl.core.interfaces.Function;
import org.aiddl.core.representation.NumericalTerm;
import org.aiddl.core.representation.Term;
import org.aiddl.core.tools.LockableList;
public class RidgeRegression implements Function {
@Override
public Term compute(Term args) {
Term X = args.get(0);
Term y = args.get(1);
int D = X.get(0).size();
int N = X.size();
double tau_val = 2.0;
NumericalTerm sigma = Term.real(tau_val*tau_val).asNum();
NumericalTerm tau = Term.real(tau_val).asNum();
NumericalTerm one_by_tau_sq = Term.integer(1).div(tau.mult(tau));
IdentityMatrix ident = new IdentityMatrix();
MatrixScalarMultiplication msMult = new MatrixScalarMultiplication();
Term X_div = msMult.compute(Term.tuple(X, Term.real(1.0).div(sigma)));
Term Lambda = msMult.compute(Term.tuple(ident.compute(Term.integer(D)), one_by_tau_sq));
LockableList y_tilde_list = new LockableList();
LockableList X_tilde_list = new LockableList();
for ( int i = 0 ; i < X_div.size(); i++ ) {
X_tilde_list.add(X_div.get(i));
y_tilde_list.add(y.get(i).asNum().mult(one_by_tau_sq));
}
for ( int i = 0 ; i < Lambda.size() ; i++ ) {
X_tilde_list.add(Lambda.get(i));
y_tilde_list.add(Term.integer(0));
}
Term X_tilde = Term.tuple(X_tilde_list);
Term y_tilde = Term.tuple(y_tilde_list);
GivensQrDecomposition qrDecomp = new GivensQrDecomposition();
Term QR = qrDecomp.compute(X_tilde);
Term Q = QR.get(0);
Term R = QR.get(1);
LockableList R_square_list = new LockableList();
for ( int i = 0 ; i < R.get(0).size() ; i++ ) {
R_square_list.add(R.get(i));
}
Term R_square = Term.tuple(R_square_list);
// System.out.println("Q");
// for ( int i = 0 ; i < Q.size() ; i++ ) {
// System.out.println(Q.get(i));
// }
System.out.println("R");
for ( int i = 0 ; i < R.size() ; i++ ) {
System.out.println(R.get(i));
}
LupInverter inv = new LupInverter();
LockableList null_vec_list = new LockableList();
for ( int i = 0 ; i < R.get(0).size() ; i++ ) {
null_vec_list.add(Term.integer(0));
}
Term null_vec = Term.tuple(null_vec_list);
Term R_inv = inv.compute(R_square);
LockableList R_inv_ext_list = new LockableList();
for ( int i = 0 ; i < R_inv.size() ; i++ ) {
R_inv_ext_list.add(R.get(i));
}
for ( int i = R_inv.size() ; i < R.size() ; i++ ) {
R_inv_ext_list.add(null_vec);
}
MatrixMultiplication mult = new MatrixMultiplication();
MatrixVectorMultiplication mvMult = new MatrixVectorMultiplication();
// System.out.println("INV: " + mult.compute(Term.tuple(R_square, R_inv)));
Term R_invQ = mult.compute(Term.tuple(R_inv, Q));
// System.out.println(y_tilde);
Term w = mvMult.compute(Term.tuple(R_invQ, y_tilde));
System.out.println("w");
System.out.println(w);
return null;
}
}
......@@ -3,37 +3,7 @@ package org.aiddl.common.math.graph;
import org.aiddl.core.representation.Term;
public class GraphTools {
/**
* Assemble and return a directed graph edge.
* This is a convenience function that wraps two AIDDL terms into a minimal edge.
* @param v1 source node
* @param v2 destination node
* @return term representing the new edge
*/
// public static Term assembleDirectedEdge ( Term v1, Term v2 ) {
// return Term.tuple(Term.keyVal(GraphTerm.Edge, Term.tuple(v1, v2)));
// }
// public static Term assembleDirectedLabelledEdge ( Term v1, Term v2, Term l ) {
// return Term.tuple(
// Term.keyVal(GraphTerm.Edge, Term.tuple(v1, v2)),
// Term.keyVal(GraphTerm.Label, l));
// }
// public static Term assembleEdge ( Term v1, Term v2 ) {
// return Term.tuple(Term.keyVal(GraphTerm.Edge, Term.set(v1, v2)));
// }
// public static Term assembleEdge ( Term v1, Term v2, List<KeyValueTerm> features ) {
// LockableList L = new LockableList();
// L.add(Term.keyVal(GraphTerm.Edge, Term.set(v1, v2)));
// for ( KeyValueTerm f : features ) {
// L.add(f);
// }
// return Term.tuple(L);
// }
public static Term assembleGraph ( Term V, Term E ) {
return Term.tuple(Term.keyVal(GraphTerm.Nodes, V), Term.keyVal(GraphTerm.Edges, E));
}
......
......@@ -4,7 +4,6 @@ import java.util.HashMap;
import org.aiddl.common.learning.LearningTerm;
import org.aiddl.common.learning.linear_regression.LinearRegression;
import org.aiddl.common.learning.linear_regression.OrdinaryLeastSquaresRegression;
import org.aiddl.common.learning.linear_regression.RidgeRegression;
import org.aiddl.common.learning.testing.CrossValidation;
import org.aiddl.core.container.Container;
import org.aiddl.core.container.Entry;
......@@ -78,7 +77,6 @@ public class TestLinearRegression extends TestCase {
Term X = Term.tuple(X_l);
Term y = Term.tuple(y_l);
OrdinaryLeastSquaresRegression oReg = new OrdinaryLeastSquaresRegression();
RidgeRegression reg = new RidgeRegression();
Term w_oreg = oReg.compute(Term.tuple(X, y));
}
......
package org.aiddl.common.math.linear_algebra;
import org.aiddl.core.parser.Parser;
import org.aiddl.core.representation.Term;
import junit.framework.TestCase;
@SuppressWarnings("javadoc")
public class TestCholeskyDecomposition extends TestCase {
@Override
public void setUp() throws Exception {
}
@Override
public void tearDown() throws Exception {
}
public void testCholeskyDecomposition() {
MatrixMultiplication mult = new MatrixMultiplication();
Term A = Parser.ParseTerm("("
+ "(0.25 0 0)"
+ "(0 0.25 0)"
+ "(0 0 0.25)"
+ ")");
CholeskyDecomposition cd = new CholeskyDecomposition();
Term LL = cd.compute(A);
Term B = mult.compute(LL);
assertTrue( A.size() == B.size() );
assertTrue( A.get(0).size() == B.get(0).size() );
for ( int i = 0 ; i < B.size() ; i++ ) {
for ( int j = 0 ; j < B.get(0).size() ; j++ ) {
double a_ij = A.get(i).get(j).getDoubleValue();
double b_ij = B.get(i).get(j).getDoubleValue();
assertTrue(Math.abs(a_ij - b_ij) < 0.0001);
}
}
}
}
package org.aiddl.common.math.linear_algebra;
import org.aiddl.core.parser.Parser;
import org.aiddl.core.representation.Term;
import junit.framework.TestCase;
@SuppressWarnings("javadoc")
public class TestQRDecomposition extends TestCase {
EpsilonEquality eq = new EpsilonEquality();
@Override
public void setUp() throws Exception {
}
@Override
public void tearDown() throws Exception {
}
public void testQrGivensSquare01() {
MatrixMultiplication mult = new MatrixMultiplication();
Term A = Parser.ParseTerm("("
+ "(6 5 0)"
+ "(5 1 4)"
+ "(0 4 3)"
+ ")");
GivensQrDecomposition qr = new GivensQrDecomposition();
Term QR = qr.compute(A);
Term Q = QR.get(0);
// System.out.println("Q");
// for ( int i = 0 ; i < Q.size() ; i++ ) {
// System.out.println(Q.get(i));
// }
Term R = QR.get(1);
// System.out.println("R");
// for ( int i = 0 ; i < R.size() ; i++ ) {
// System.out.println(R.get(i));
// }
Term B = mult.compute(QR);
assertTrue( A.size() == B.size() );
assertTrue( A.get(0).size() == B.get(0).size() );
for ( int i = 0 ; i < B.size() ; i++ ) {
for ( int j = 0 ; j < B.get(0).size() ; j++ ) {
double a_ij = A.get(i).get(j).getDoubleValue();
double b_ij = B.get(i).get(j).getDoubleValue();
assertTrue(Math.abs(a_ij - b_ij) < 0.0001);
}
}
}
public void testGivens01() {
GivensRotation gRot = new GivensRotation();
Term G = gRot.compute(Term.tuple(
Term.integer(3),
Term.integer(3),
Term.integer(2),
Term.integer(0),
Term.integer(12),
Term.integer(-4)));
Term G_test = Parser.ParseTerm(
"((0.9486832980505138 0.0 -0.31622776601683794)" +
"(0.0 1.0 0.0)" +
"(0.31622776601683794 0.0 0.9486832980505138))");
// MatrixTools.printMatrix(G);
assertTrue( eq.compute(Term.tuple(G, G_test, Term.real(0.0001))).getBooleanValue());
}
public void testGivens02() {
GivensRotation gRot = new GivensRotation();
MatrixMultiplication mult = new MatrixMultiplication();
Term G = gRot.compute(Term.tuple(
Term.integer(3),
Term.integer(3),
Term.integer(0),
Term.integer(1),
Term.integer(6),
Term.integer(5)));
Term A = Parser.ParseTerm(
"( (6 5 0)"
+ "(5 1 4)"
+ "(0 4 3))");
Term G1 = Parser.ParseTerm(
"((0.7682 0.6402 0.0)"
+ "(-0.6402 0.7682 0.0)"
+ "(0.0 0.0 1.0))");
Term A2 = mult.compute(Term.tuple(G1, A));
double absVal = Math.abs(A2.get(1).get(0).getDoubleValue());
assertTrue(absVal < 0.001);
// assertTrue( eq.compute(Term.tuple(G, G_test, Term.real(0.0001))).getBooleanValue());
}
public void testGivens03() {
GivensRotation gRot = new GivensRotation();
Term G = gRot.compute(Term.tuple(
Term.integer(3),
Term.integer(3),
Term.integer(1),
Term.integer(0),
Term.integer(6),
Term.integer(5)));
Term G_test = Parser.ParseTerm(
"((0.7682212795973759 0.6401843996644799 0.0)"
+ "(-0.6401843996644799 0.7682212795973759 0.0)"
+ "(0.0 0.0 1.0))");
assertTrue( eq.compute(Term.tuple(G, G_test, Term.real(0.0001))).getBooleanValue());
}
public void testGivens04() {
GivensRotation gRot = new GivensRotation();
Term G = gRot.compute(Term.tuple(
Term.integer(3),
Term.integer(3),
Term.integer(2),
Term.integer(1),
Term.real(-2.4327),
Term.integer(4)));
Term G_test = Parser.ParseTerm("((1 0 0)"
+ "(0 -0.5196 0.8544)"
+ "(0 -0.8544 -0.5196))");
// MatrixTools.printMatrix(G);
assertTrue( eq.compute(Term.tuple(G, G_test, Term.real(0.0001))).getBooleanValue() );
}
public void testQrGivensRect() {
MatrixMultiplication mult = new MatrixMultiplication();
Term A = Parser.ParseTerm("("
+ "(6 5)"
+ "(5 1)"
+ "(0 4)"
+ ")");
GivensQrDecomposition qr = new GivensQrDecomposition();
Term QR = qr.compute(A);
Term Q = QR.get(0);
// System.out.println("Q");
// for ( int i = 0 ; i < Q.size() ; i++ ) {
// System.out.println(Q.get(i));
// }
Term R = QR.get(1);
// System.out.println("R");
for ( int i = 0 ; i < R.size() ; i++ ) {
// System.out.println(R.get(i));
for ( int j = 0 ; j < R.get(0).size() ; j++ ) {
if ( i > j ) {
double absVal = Math.abs(R.get(i).get(j).getDoubleValue());
assertTrue(absVal < 0.0001);
}
}
}
Term B = mult.compute(QR);
EpsilonEquality eq = new EpsilonEquality();
assertTrue( eq.compute(Term.tuple(A, B, Term.real(0.0001))).getBooleanValue() );
}
// public void testQrDecomposition02() {
// MatrixMultiplication mult = new MatrixMultiplication();
//
// Term A = Parser.ParseTerm("("
// + "(12.0)"
// + "(5.0)"
// + "(3.0)"
// + "(5.0)"
// + "(10.0)"
// + "(-5.0)"
// + "(6.0)"
// + "(-5.0)"
// + "(1.0)"
// + ")");
//
// GivensQrDecomposition qr = new GivensQrDecomposition();
//
// Term QR = qr.compute(A);
//
// Term B = mult.compute(QR);
//
// assertTrue( eq.compute(Term.tuple(A, B, Term.real(0.0001))).getBooleanValue() );
//
//
// Term R = QR.get(1);
//// System.out.println("R");
// for ( int i = 0 ; i < R.size() ; i++ ) {
//// System.out.println(R.get(i));
//
// for ( int j = 0 ; j < R.get(0).size() ; j++ ) {
// if ( i > j ) {
// double absVal = Math.abs(R.get(i).get(j).getDoubleValue());
// assertTrue(absVal < 0.0001);
// }
// }
// }
// }
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment