ml_pid_cbm.test_validate_model

  1import os
  2import shutil
  3import unittest
  4from pathlib import Path
  5from unittest.mock import mock_open, patch
  6
  7import pandas as pd
  8
  9from .tools.particles_id import ParticlesId as Pid
 10from .validate_model import ValidateModel
 11
 12
 13class TestValidateModel(unittest.TestCase):
 14    def setUp(self):
 15        first_entry = {
 16            "Complex_q": 1.0,
 17            "Complex_p": 2.0,
 18            "Complex_pid": 2212,
 19            "Complex_mass2": 1.1,
 20            "Complex_pT": 0.5,
 21            "Complex_rapidity": 3.0,
 22            "model_output_0": 0.8,
 23            "model_output_1": 0.1,
 24            "model_output_2": 0.1,
 25        }
 26        second_entry = {
 27            "Complex_q": 1.0,
 28            "Complex_p": 2.0,
 29            "Complex_pid": 2212,
 30            "Complex_mass2": 0.5,
 31            "Complex_pT": 0.5,
 32            "Complex_rapidity": 3.0,
 33            "model_output_0": 0.6,
 34            "model_output_1": 0.2,
 35            "model_output_2": 0.2,
 36        }
 37        proton_entry = {
 38            "Complex_q": 1.0,
 39            "Complex_p": 2.0,
 40            "Complex_pid": Pid.PROTON.value,
 41            "Complex_mass2": 0.8,
 42            "model_output_0": 0.8,
 43            "model_output_1": 0.1,
 44            "model_output_2": 0.1,
 45        }
 46        kaon_entry = {
 47            "Complex_q": 1.0,
 48            "Complex_p": 2.0,
 49            "Complex_pid": Pid.POS_KAON.value,
 50            "Complex_mass2": 0.4,
 51            "Complex_pT": 0.5,
 52            "Complex_rapidity": 3.0,
 53            "model_output_0": 0.1,
 54            "model_output_1": 0.8,
 55            "model_output_2": 0.1,
 56        }
 57        pion_entry = {
 58            "Complex_q": 1.0,
 59            "Complex_p": 2.0,
 60            "Complex_pid": Pid.POS_PION.value,
 61            "Complex_mass2": 0.2,
 62            "Complex_pT": 0.5,
 63            "Complex_rapidity": 3.0,
 64            "model_output_0": 0.1,
 65            "model_output_1": 0.1,
 66            "model_output_2": 0.8,
 67        }
 68        bckgr_entry = {
 69            "Complex_q": 1.0,
 70            "Complex_p": 2.0,
 71            "Complex_pid": 1010,
 72            "Complex_mass2": 0.1,
 73            "Complex_pT": 0.5,
 74            "Complex_rapidity": 3.0,
 75            "model_output_0": 0.3,
 76            "model_output_1": 0.4,
 77            "model_output_2": 0.3,
 78        }
 79        complete_data = [
 80            first_entry,
 81            second_entry,
 82            proton_entry,
 83            {**proton_entry, "Complex_mass2": 0.9},
 84            kaon_entry,
 85            {**kaon_entry, "Complex_mass2": 0.6},
 86            pion_entry,
 87            {**pion_entry, "Complex_mass2": 0.3},
 88            bckgr_entry,
 89            {**bckgr_entry, "Complex_mass2": 0.0},
 90        ]
 91        self.json_data = """{"var_names": {"momentum": "Complex_p","charge": "Complex_q","mass2": "Complex_mass2","pid": "Complex_pid"},
 92                    "vars_to_draw": ["Complex_mass2", "Complex_p"]}"""
 93        test_config_path = f"{Path(__file__).resolve().parent}/test_config.json"
 94        with patch("builtins.open", mock_open(read_data=self.json_data)):
 95            self.validate = ValidateModel(
 96                2, 4, False, test_config_path, pd.DataFrame(complete_data)
 97            )
 98            self.validate_false = ValidateModel(
 99                2, 4, True, test_config_path, pd.DataFrame(complete_data)
100            )
101
102    def test_get_n_classes(self):
103        self.assertEqual(self.validate.get_n_classes(), 4)
104
105    def test_xgb_preds(self):
106        self.validate.xgb_preds(0.7, 0.7, 0.7)
107        df = self.validate.particles_df
108        self.assertEqual(df[df["Complex_mass2"] == 1.1]["xgb_preds"].item(), 0)
109        self.assertEqual(df[df["Complex_mass2"] == 0.5]["xgb_preds"].item(), 3)
110
111    def test_remap_names(self):
112        self.validate.remap_names()
113        df = self.validate.particles_df
114        self.assertEqual(df[df["Complex_mass2"] == 0.2]["Complex_pid"].item(), 2)
115        self.assertEqual(df[df["Complex_mass2"] == 0.4]["Complex_pid"].item(), 1)
116        self.assertEqual(df[df["Complex_mass2"] == 0.8]["Complex_pid"].item(), 0)
117        self.validate_false.remap_names()
118
119    def test_save_df(self):
120        self.validate.save_df()
121
122    def test_generate_plots(self):
123        self.validate.xgb_preds(0.7, 0.7, 0.7)
124        # with patch("builtins.open", mock_open(read_data=self.json_data)):
125        # should be mock json but
126        # https://github.com/julnow/ml-pid-cbm/actions/runs/5004465285/jobs/9029522575
127        self.validate.generate_plots()
128
129    def test_evaluate_probas(self):
130        self.validate.evaluate_probas(0.1, 0.9, 5, 50)
131
132    def test_confusion_matrix_and_stats(self):
133        self.validate.xgb_preds(0.7, 0.7, 0.7)
134        self.validate.confusion_matrix_and_stats()
135
136    def test_parse_model_name(self):
137        model_name_positive = "model_0.0_6.0_positive"
138        lower_p, upper_p, anti = ValidateModel.parse_model_name(model_name_positive)
139        self.assertEqual([lower_p, upper_p, anti], [0.0, 6.0, False])
140        model_name_anti = "model_3.0_6.0_anti"
141        lower_p, upper_p, anti = ValidateModel.parse_model_name(model_name_anti)
142        self.assertEqual([lower_p, upper_p, anti], [3.0, 6.0, True])
143        model_name_incorrect = "model_anti_1_4"
144        self.assertRaises(
145            ValueError, lambda: ValidateModel.parse_model_name(model_name_incorrect)
146        )
147
148    @classmethod
149    def setUpClass(cls):
150        cls.test_dir = Path(__file__).resolve().parent
151        cls.img_dir = f"{cls.test_dir}/testimg"
152        if not os.path.exists(cls.img_dir):
153            os.makedirs(cls.img_dir)
154        os.chdir(cls.img_dir)
155
156    @classmethod
157    def tearDownClass(cls):
158        os.chdir(cls.test_dir)
159        shutil.rmtree(cls.img_dir)
class TestValidateModel(unittest.case.TestCase):
 14class TestValidateModel(unittest.TestCase):
 15    def setUp(self):
 16        first_entry = {
 17            "Complex_q": 1.0,
 18            "Complex_p": 2.0,
 19            "Complex_pid": 2212,
 20            "Complex_mass2": 1.1,
 21            "Complex_pT": 0.5,
 22            "Complex_rapidity": 3.0,
 23            "model_output_0": 0.8,
 24            "model_output_1": 0.1,
 25            "model_output_2": 0.1,
 26        }
 27        second_entry = {
 28            "Complex_q": 1.0,
 29            "Complex_p": 2.0,
 30            "Complex_pid": 2212,
 31            "Complex_mass2": 0.5,
 32            "Complex_pT": 0.5,
 33            "Complex_rapidity": 3.0,
 34            "model_output_0": 0.6,
 35            "model_output_1": 0.2,
 36            "model_output_2": 0.2,
 37        }
 38        proton_entry = {
 39            "Complex_q": 1.0,
 40            "Complex_p": 2.0,
 41            "Complex_pid": Pid.PROTON.value,
 42            "Complex_mass2": 0.8,
 43            "model_output_0": 0.8,
 44            "model_output_1": 0.1,
 45            "model_output_2": 0.1,
 46        }
 47        kaon_entry = {
 48            "Complex_q": 1.0,
 49            "Complex_p": 2.0,
 50            "Complex_pid": Pid.POS_KAON.value,
 51            "Complex_mass2": 0.4,
 52            "Complex_pT": 0.5,
 53            "Complex_rapidity": 3.0,
 54            "model_output_0": 0.1,
 55            "model_output_1": 0.8,
 56            "model_output_2": 0.1,
 57        }
 58        pion_entry = {
 59            "Complex_q": 1.0,
 60            "Complex_p": 2.0,
 61            "Complex_pid": Pid.POS_PION.value,
 62            "Complex_mass2": 0.2,
 63            "Complex_pT": 0.5,
 64            "Complex_rapidity": 3.0,
 65            "model_output_0": 0.1,
 66            "model_output_1": 0.1,
 67            "model_output_2": 0.8,
 68        }
 69        bckgr_entry = {
 70            "Complex_q": 1.0,
 71            "Complex_p": 2.0,
 72            "Complex_pid": 1010,
 73            "Complex_mass2": 0.1,
 74            "Complex_pT": 0.5,
 75            "Complex_rapidity": 3.0,
 76            "model_output_0": 0.3,
 77            "model_output_1": 0.4,
 78            "model_output_2": 0.3,
 79        }
 80        complete_data = [
 81            first_entry,
 82            second_entry,
 83            proton_entry,
 84            {**proton_entry, "Complex_mass2": 0.9},
 85            kaon_entry,
 86            {**kaon_entry, "Complex_mass2": 0.6},
 87            pion_entry,
 88            {**pion_entry, "Complex_mass2": 0.3},
 89            bckgr_entry,
 90            {**bckgr_entry, "Complex_mass2": 0.0},
 91        ]
 92        self.json_data = """{"var_names": {"momentum": "Complex_p","charge": "Complex_q","mass2": "Complex_mass2","pid": "Complex_pid"},
 93                    "vars_to_draw": ["Complex_mass2", "Complex_p"]}"""
 94        test_config_path = f"{Path(__file__).resolve().parent}/test_config.json"
 95        with patch("builtins.open", mock_open(read_data=self.json_data)):
 96            self.validate = ValidateModel(
 97                2, 4, False, test_config_path, pd.DataFrame(complete_data)
 98            )
 99            self.validate_false = ValidateModel(
100                2, 4, True, test_config_path, pd.DataFrame(complete_data)
101            )
102
103    def test_get_n_classes(self):
104        self.assertEqual(self.validate.get_n_classes(), 4)
105
106    def test_xgb_preds(self):
107        self.validate.xgb_preds(0.7, 0.7, 0.7)
108        df = self.validate.particles_df
109        self.assertEqual(df[df["Complex_mass2"] == 1.1]["xgb_preds"].item(), 0)
110        self.assertEqual(df[df["Complex_mass2"] == 0.5]["xgb_preds"].item(), 3)
111
112    def test_remap_names(self):
113        self.validate.remap_names()
114        df = self.validate.particles_df
115        self.assertEqual(df[df["Complex_mass2"] == 0.2]["Complex_pid"].item(), 2)
116        self.assertEqual(df[df["Complex_mass2"] == 0.4]["Complex_pid"].item(), 1)
117        self.assertEqual(df[df["Complex_mass2"] == 0.8]["Complex_pid"].item(), 0)
118        self.validate_false.remap_names()
119
120    def test_save_df(self):
121        self.validate.save_df()
122
123    def test_generate_plots(self):
124        self.validate.xgb_preds(0.7, 0.7, 0.7)
125        # with patch("builtins.open", mock_open(read_data=self.json_data)):
126        # should be mock json but
127        # https://github.com/julnow/ml-pid-cbm/actions/runs/5004465285/jobs/9029522575
128        self.validate.generate_plots()
129
130    def test_evaluate_probas(self):
131        self.validate.evaluate_probas(0.1, 0.9, 5, 50)
132
133    def test_confusion_matrix_and_stats(self):
134        self.validate.xgb_preds(0.7, 0.7, 0.7)
135        self.validate.confusion_matrix_and_stats()
136
137    def test_parse_model_name(self):
138        model_name_positive = "model_0.0_6.0_positive"
139        lower_p, upper_p, anti = ValidateModel.parse_model_name(model_name_positive)
140        self.assertEqual([lower_p, upper_p, anti], [0.0, 6.0, False])
141        model_name_anti = "model_3.0_6.0_anti"
142        lower_p, upper_p, anti = ValidateModel.parse_model_name(model_name_anti)
143        self.assertEqual([lower_p, upper_p, anti], [3.0, 6.0, True])
144        model_name_incorrect = "model_anti_1_4"
145        self.assertRaises(
146            ValueError, lambda: ValidateModel.parse_model_name(model_name_incorrect)
147        )
148
149    @classmethod
150    def setUpClass(cls):
151        cls.test_dir = Path(__file__).resolve().parent
152        cls.img_dir = f"{cls.test_dir}/testimg"
153        if not os.path.exists(cls.img_dir):
154            os.makedirs(cls.img_dir)
155        os.chdir(cls.img_dir)
156
157    @classmethod
158    def tearDownClass(cls):
159        os.chdir(cls.test_dir)
160        shutil.rmtree(cls.img_dir)

(Inherited from `unittest.TestCase`; `TestValidateModel` defines no docstring of its own.) A class whose instances are single test cases.

By default, the test code itself should be placed in a method named 'runTest'.

If the fixture may be used for many test cases, create as many test methods as are needed. When instantiating such a TestCase subclass, specify in the constructor arguments the name of the test method that the instance is to execute.

Test authors should subclass TestCase for their own tests. Setup and teardown of the test's environment ('fixture') can be implemented by overriding the 'setUp' and 'tearDown' methods, respectively.

If it is necessary to override the __init__ method, the base class __init__ method must always be called. It is important that subclasses should not change the signature of their __init__ method, since instances of the classes are instantiated automatically by parts of the framework in order to be run.

When subclassing TestCase, you can set these attributes:

  • failureException: determines which exception will be raised when the instance's assertion methods fail; test methods raising this exception will be deemed to have 'failed' rather than 'errored'.
  • longMessage: determines whether long messages (including repr of objects used in assert methods) will be printed on failure in addition to any explicit message passed.
  • maxDiff: sets the maximum length of a diff in failure messages by assert methods using difflib. It is looked up as an instance attribute, so it can be configured by individual tests if required.
def setUp(self):
 15    def setUp(self):
 16        first_entry = {
 17            "Complex_q": 1.0,
 18            "Complex_p": 2.0,
 19            "Complex_pid": 2212,
 20            "Complex_mass2": 1.1,
 21            "Complex_pT": 0.5,
 22            "Complex_rapidity": 3.0,
 23            "model_output_0": 0.8,
 24            "model_output_1": 0.1,
 25            "model_output_2": 0.1,
 26        }
 27        second_entry = {
 28            "Complex_q": 1.0,
 29            "Complex_p": 2.0,
 30            "Complex_pid": 2212,
 31            "Complex_mass2": 0.5,
 32            "Complex_pT": 0.5,
 33            "Complex_rapidity": 3.0,
 34            "model_output_0": 0.6,
 35            "model_output_1": 0.2,
 36            "model_output_2": 0.2,
 37        }
 38        proton_entry = {
 39            "Complex_q": 1.0,
 40            "Complex_p": 2.0,
 41            "Complex_pid": Pid.PROTON.value,
 42            "Complex_mass2": 0.8,
 43            "model_output_0": 0.8,
 44            "model_output_1": 0.1,
 45            "model_output_2": 0.1,
 46        }
 47        kaon_entry = {
 48            "Complex_q": 1.0,
 49            "Complex_p": 2.0,
 50            "Complex_pid": Pid.POS_KAON.value,
 51            "Complex_mass2": 0.4,
 52            "Complex_pT": 0.5,
 53            "Complex_rapidity": 3.0,
 54            "model_output_0": 0.1,
 55            "model_output_1": 0.8,
 56            "model_output_2": 0.1,
 57        }
 58        pion_entry = {
 59            "Complex_q": 1.0,
 60            "Complex_p": 2.0,
 61            "Complex_pid": Pid.POS_PION.value,
 62            "Complex_mass2": 0.2,
 63            "Complex_pT": 0.5,
 64            "Complex_rapidity": 3.0,
 65            "model_output_0": 0.1,
 66            "model_output_1": 0.1,
 67            "model_output_2": 0.8,
 68        }
 69        bckgr_entry = {
 70            "Complex_q": 1.0,
 71            "Complex_p": 2.0,
 72            "Complex_pid": 1010,
 73            "Complex_mass2": 0.1,
 74            "Complex_pT": 0.5,
 75            "Complex_rapidity": 3.0,
 76            "model_output_0": 0.3,
 77            "model_output_1": 0.4,
 78            "model_output_2": 0.3,
 79        }
 80        complete_data = [
 81            first_entry,
 82            second_entry,
 83            proton_entry,
 84            {**proton_entry, "Complex_mass2": 0.9},
 85            kaon_entry,
 86            {**kaon_entry, "Complex_mass2": 0.6},
 87            pion_entry,
 88            {**pion_entry, "Complex_mass2": 0.3},
 89            bckgr_entry,
 90            {**bckgr_entry, "Complex_mass2": 0.0},
 91        ]
 92        self.json_data = """{"var_names": {"momentum": "Complex_p","charge": "Complex_q","mass2": "Complex_mass2","pid": "Complex_pid"},
 93                    "vars_to_draw": ["Complex_mass2", "Complex_p"]}"""
 94        test_config_path = f"{Path(__file__).resolve().parent}/test_config.json"
 95        with patch("builtins.open", mock_open(read_data=self.json_data)):
 96            self.validate = ValidateModel(
 97                2, 4, False, test_config_path, pd.DataFrame(complete_data)
 98            )
 99            self.validate_false = ValidateModel(
100                2, 4, True, test_config_path, pd.DataFrame(complete_data)
101            )

Hook method for setting up the test fixture before exercising it.

def test_get_n_classes(self):
103    def test_get_n_classes(self):
104        self.assertEqual(self.validate.get_n_classes(), 4)
def test_xgb_preds(self):
106    def test_xgb_preds(self):
107        self.validate.xgb_preds(0.7, 0.7, 0.7)
108        df = self.validate.particles_df
109        self.assertEqual(df[df["Complex_mass2"] == 1.1]["xgb_preds"].item(), 0)
110        self.assertEqual(df[df["Complex_mass2"] == 0.5]["xgb_preds"].item(), 3)
def test_remap_names(self):
112    def test_remap_names(self):
113        self.validate.remap_names()
114        df = self.validate.particles_df
115        self.assertEqual(df[df["Complex_mass2"] == 0.2]["Complex_pid"].item(), 2)
116        self.assertEqual(df[df["Complex_mass2"] == 0.4]["Complex_pid"].item(), 1)
117        self.assertEqual(df[df["Complex_mass2"] == 0.8]["Complex_pid"].item(), 0)
118        self.validate_false.remap_names()
def test_save_df(self):
120    def test_save_df(self):
121        self.validate.save_df()
def test_generate_plots(self):
123    def test_generate_plots(self):
124        self.validate.xgb_preds(0.7, 0.7, 0.7)
125        # with patch("builtins.open", mock_open(read_data=self.json_data)):
126        # should be mock json but
127        # https://github.com/julnow/ml-pid-cbm/actions/runs/5004465285/jobs/9029522575
128        self.validate.generate_plots()
def test_evaluate_probas(self):
130    def test_evaluate_probas(self):
131        self.validate.evaluate_probas(0.1, 0.9, 5, 50)
def test_confusion_matrix_and_stats(self):
133    def test_confusion_matrix_and_stats(self):
134        self.validate.xgb_preds(0.7, 0.7, 0.7)
135        self.validate.confusion_matrix_and_stats()
def test_parse_model_name(self):
137    def test_parse_model_name(self):
138        model_name_positive = "model_0.0_6.0_positive"
139        lower_p, upper_p, anti = ValidateModel.parse_model_name(model_name_positive)
140        self.assertEqual([lower_p, upper_p, anti], [0.0, 6.0, False])
141        model_name_anti = "model_3.0_6.0_anti"
142        lower_p, upper_p, anti = ValidateModel.parse_model_name(model_name_anti)
143        self.assertEqual([lower_p, upper_p, anti], [3.0, 6.0, True])
144        model_name_incorrect = "model_anti_1_4"
145        self.assertRaises(
146            ValueError, lambda: ValidateModel.parse_model_name(model_name_incorrect)
147        )
@classmethod
def setUpClass(cls):
149    @classmethod
150    def setUpClass(cls):
151        cls.test_dir = Path(__file__).resolve().parent
152        cls.img_dir = f"{cls.test_dir}/testimg"
153        if not os.path.exists(cls.img_dir):
154            os.makedirs(cls.img_dir)
155        os.chdir(cls.img_dir)

Hook method for setting up the class fixture before running tests in the class.

@classmethod
def tearDownClass(cls):
157    @classmethod
158    def tearDownClass(cls):
159        os.chdir(cls.test_dir)
160        shutil.rmtree(cls.img_dir)

Hook method for tearing down the class fixture after running all tests in the class.

Inherited Members
unittest.case.TestCase
TestCase
addTypeEqualityFunc
addCleanup
addClassCleanup
tearDown
countTestCases
defaultTestResult
shortDescription
id
subTest
run
doCleanups
doClassCleanups
debug
skipTest
fail
assertFalse
assertTrue
assertRaises
assertWarns
assertLogs
assertEqual
assertNotEqual
assertAlmostEqual
assertNotAlmostEqual
assertSequenceEqual
assertListEqual
assertTupleEqual
assertSetEqual
assertIn
assertNotIn
assertIs
assertIsNot
assertDictEqual
assertDictContainsSubset
assertCountEqual
assertMultiLineEqual
assertLess
assertLessEqual
assertGreater
assertGreaterEqual
assertIsNone
assertIsNotNone
assertIsInstance
assertNotIsInstance
assertRaisesRegex
assertWarnsRegex
assertRegex
assertNotRegex
failUnlessRaises
failIf
assertRaisesRegexp
assertRegexpMatches
assertNotRegexpMatches
failUnlessEqual
assertEquals
failIfEqual
assertNotEquals
failUnlessAlmostEqual
assertAlmostEquals
failIfAlmostEqual
assertNotAlmostEquals
failUnless
assert_