Source code for test_summary_stats

"""Tests for n2j.inference.summary_stats

"""

import os
import shutil
import unittest
import numpy as np
from addict import Dict
import torch
import n2j.inference.summary_stats_baseline as ssb


[docs]class TestSummaryStats(unittest.TestCase): """A suite of tests verifying n2j.inference.summary_stats_utils utility functions """ @classmethod
[docs] def setUpClass(cls): cls.stats_path = 'stats_path_testing.npy'
[docs] def test_get_number_counts(self): """Test `get_number_counts` """ x = torch.randn([10, 3]) batch_indices = torch.tensor([0, 0, 0, 0, 0, 1, 1, 2, 2, 2]) actual = ssb.get_number_counts(x, batch_indices) expected = [5, 2, 3] np.testing.assert_array_equal(actual, expected)
[docs] def test_get_inv_dist_number_counts(self): """Test `get_inv_dist_number_counts` """ x = np.ones([6, 3]) batch_indices = np.array([0, 0, 0, 1, 2, 2]) ra_dec_idx = [0, 1] x[:, ra_dec_idx] = np.array([[3, 4], [1.e-7, 1.e-7], [20, 21], [5, 12], [8, 15], [7, 24]]) actual = ssb.get_inv_dist_number_counts(torch.from_numpy(x), torch.from_numpy(batch_indices), ra_dec_idx) dist = np.array([5, np.sqrt(2)*1.e-7 + 1.e-5, 29, 13, 17, 25]) expected = np.array([sum(1.0/dist[:3]), sum(1.0/dist[3:4]), sum(1.0/dist[4:])]) np.testing.assert_array_almost_equal(actual, expected, decimal=4)
[docs] def test_update(self): """Test `update` method """ ss_obj = ssb.SummaryStats(n_data=3) x_np = np.array([[3, 4], [1.e-7, 1.e-7], [20, 21], [5, 12], [8, 15], [7, 24]]) batch_indices_np = np.array([0, 0, 0, 1, 2, 2]) batches = [] batches.append(Dict( x=torch.tensor(x_np), batch=torch.tensor(batch_indices_np) )) for i, b in enumerate(batches): ss_obj.update(b, i) # Compute expected expected_N = [3, 1, 2] dist = np.array([5, np.sqrt(2)*1.e-7 + 1.e-5, 29, 13, 17, 25]) expected_N_inv_dist = np.array([sum(1.0/dist[:3]), sum(1.0/dist[3:4]), sum(1.0/dist[4:])]) np.testing.assert_array_equal(ss_obj.stats['N'], expected_N) np.testing.assert_array_almost_equal(ss_obj.stats['N_inv_dist'], expected_N_inv_dist, decimal=4)
[docs] def test_set_stats_export_stats(self): """Test `set_stats` method """ any_stats = dict( N=np.arange(5), N_inv_dist=np.arange(5) ) np.save(self.stats_path, any_stats, allow_pickle=True) ss_obj = ssb.SummaryStats(n_data=5) ss_obj.set_stats(self.stats_path) np.testing.assert_array_equal(ss_obj.stats['N'], np.arange(5)) np.testing.assert_array_equal(ss_obj.stats['N_inv_dist'], np.arange(5)) ss_obj.export_stats(self.stats_path) ss_obj.set_stats(self.stats_path) np.testing.assert_array_equal(ss_obj.stats['N'], np.arange(5)) np.testing.assert_array_equal(ss_obj.stats['N_inv_dist'], np.arange(5))
@classmethod
[docs] def tearDownClass(cls): os.remove(cls.stats_path)
if __name__ == '__main__': unittest.main()