/*
 * Decompiled with CFR 0.152.
 */
package cc.mallet.types.tests;

import cc.mallet.types.MatrixOps;
import java.io.File;
import org.junit.Assert;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.TemporaryFolder;

/**
 * Unit tests for {@link MatrixOps}: vector summation, deep cloning, squared-difference
 * comparison, text save/load round-tripping, and matrix multiplication.
 */
public class TestMatrixOps {
    /** The vector 1..5; its element sum is 15. */
    public static double[] digits = new double[]{1.0, 2.0, 3.0, 4.0, 5.0};
    /** A 5x2 matrix in which every row is {1, 2}. */
    public static double[][] matrix = new double[][]{{1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}, {1.0, 2.0}};
    /** The 2x5 transpose of {@link #matrix}. */
    public static double[][] matrixTranspose = new double[][]{{1.0, 1.0, 1.0, 1.0, 1.0}, {2.0, 2.0, 2.0, 2.0, 2.0}};

    /** Per-test scratch directory managed by JUnit, used for the save/load round trip. */
    @Rule
    public TemporaryFolder folder = new TemporaryFolder();

    /** {@code sum} of 1..5 must be exactly 15. */
    @Test
    public void testSum() {
        double sum = MatrixOps.sum(digits);
        Assert.assertEquals(15.0, sum, 0.0);
    }

    /**
     * {@code deepClone} must produce an equal matrix that shares no array storage
     * with the original: neither the outer array nor any row may be aliased.
     */
    @Test
    public void testClone() {
        double[][] clone = MatrixOps.deepClone(matrix);
        // Check shape first so the squared-difference comparison covers every row.
        Assert.assertEquals(matrix.length, clone.length);
        double diff = MatrixOps.sumSquaredDiff(clone, matrix);
        Assert.assertEquals(0.0, diff, 0.0);
        Assert.assertNotSame(matrix, clone);
        for (int row = 0; row < matrix.length; row++) {
            Assert.assertNotSame(matrix[row], clone[row]);
        }
    }

    /**
     * {@code sumSquaredDiff} is zero for identical matrices and equals the number
     * of cells when comparing an all-zeros matrix to an all-ones matrix.
     */
    @Test
    public void testFrobenius() {
        double diff = MatrixOps.sumSquaredDiff(matrix, matrix);
        Assert.assertEquals(0.0, diff, 0.0);
        double[][] zeros = new double[][]{{0.0, 0.0}, {0.0, 0.0}};
        double[][] ones = new double[][]{{1.0, 1.0}, {1.0, 1.0}};
        diff = MatrixOps.sumSquaredDiff(zeros, ones);
        Assert.assertEquals(4.0, diff, 0.0);
    }

    /**
     * Writing a matrix with {@code savetxt} and reading it back with {@code loadtxt}
     * must reproduce the original values.
     *
     * <p>Exceptions are deliberately propagated rather than caught: a failure to
     * create, write, or read the file must fail the test instead of being swallowed.
     *
     * @throws Exception if the temporary file cannot be created or save/load fails
     */
    @Test
    public void saveAndLoad() throws Exception {
        File savedFile = this.folder.newFile("matrix.txt");
        MatrixOps.savetxt(matrix, savedFile);
        double[][] loadedMatrix = MatrixOps.loadtxt(savedFile);
        Assert.assertEquals(matrix.length, loadedMatrix.length);
        double diff = MatrixOps.sumSquaredDiff(matrix, loadedMatrix);
        Assert.assertEquals(0.0, diff, 0.0);
    }

    /**
     * For the 5x2 {@link #matrix} M, both {@code aTransposeTimesB(M, M)} and
     * {@code aTimesB(M^T, M)} must equal the 2x2 Gram matrix {{5, 10}, {10, 20}}.
     */
    @Test
    public void matrixMultiply() {
        double[][] correct = new double[][]{{5.0, 10.0}, {10.0, 20.0}};
        double[][] product = MatrixOps.aTransposeTimesB(matrix, matrix);
        double diff = MatrixOps.sumSquaredDiff(product, correct);
        Assert.assertEquals(0.0, diff, 0.0);
        // Same product computed from an explicit transpose must agree.
        product = MatrixOps.aTimesB(matrixTranspose, matrix);
        diff = MatrixOps.sumSquaredDiff(product, correct);
        Assert.assertEquals(0.0, diff, 0.0);
    }
}

