mirror of
https://github.com/StevenBlack/hosts.git
synced 2026-07-01 02:36:52 +00:00
Refactor out global settings usage in update logic
This commit is contained in:
@@ -12,8 +12,8 @@ from updateHostsFile import (
|
||||
normalize_rule, path_join_robust, print_failure, print_success,
|
||||
prompt_for_exclusions, prompt_for_move, prompt_for_flush_dns_cache,
|
||||
prompt_for_update, query_yes_no, recursive_glob, remove_old_hosts_file,
|
||||
supports_color, strip_rule, update_readme_data, write_data,
|
||||
write_opening_header)
|
||||
supports_color, strip_rule, update_all_sources, update_readme_data,
|
||||
write_data, write_opening_header)
|
||||
|
||||
import updateHostsFile
|
||||
import unittest
|
||||
@@ -190,9 +190,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
|
||||
|
||||
sys.stdout = StringIO()
|
||||
|
||||
@mock.patch("updateHostsFile.update_all_sources", return_value=0)
|
||||
@mock.patch("updateHostsFile.query_yes_no", return_value=False)
|
||||
def test_freshen_no_update(self, _, mock_update):
|
||||
def test_freshen_no_update(self, _):
|
||||
hosts_file = os.path.join(self.test_dir, "hosts")
|
||||
hosts_data = "This data should not be overwritten"
|
||||
|
||||
@@ -204,10 +203,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
|
||||
|
||||
dir_count = self.dir_count
|
||||
|
||||
prompt_for_update(freshen=True, update_auto=False)
|
||||
|
||||
mock_update.assert_not_called()
|
||||
mock_update.reset_mock()
|
||||
update_sources = prompt_for_update(freshen=True, update_auto=False)
|
||||
self.assertFalse(update_sources)
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
expected = ("OK, we'll stick with "
|
||||
@@ -222,9 +219,8 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
|
||||
contents = f.read()
|
||||
self.assertEqual(contents, hosts_data)
|
||||
|
||||
@mock.patch("updateHostsFile.update_all_sources", return_value=0)
|
||||
@mock.patch("updateHostsFile.query_yes_no", return_value=True)
|
||||
def test_freshen_update(self, _, mock_update):
|
||||
def test_freshen_update(self, _):
|
||||
hosts_file = os.path.join(self.test_dir, "hosts")
|
||||
hosts_data = "This data should not be overwritten"
|
||||
|
||||
@@ -237,10 +233,9 @@ class TestPromptForUpdate(BaseStdout, BaseMockDir):
|
||||
dir_count = self.dir_count
|
||||
|
||||
for update_auto in (False, True):
|
||||
prompt_for_update(freshen=True, update_auto=update_auto)
|
||||
|
||||
self.assert_called_once(mock_update)
|
||||
mock_update.reset_mock()
|
||||
update_sources = prompt_for_update(freshen=True,
|
||||
update_auto=update_auto)
|
||||
self.assertTrue(update_sources)
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
self.assertEqual(output, "")
|
||||
@@ -547,6 +542,76 @@ class TestMatchesExclusions(Base):
|
||||
# End Exclusion Logic
|
||||
|
||||
|
||||
# Update Logic
|
||||
class TestUpdateAllSources(BaseStdout):
|
||||
|
||||
def setUp(self):
|
||||
BaseStdout.setUp(self)
|
||||
|
||||
self.source_data_filename = "data.json"
|
||||
self.host_filename = "hosts.txt"
|
||||
|
||||
@mock.patch(builtins() + ".open")
|
||||
@mock.patch("updateHostsFile.recursive_glob", return_value=[])
|
||||
def test_no_sources(self, _, mock_open):
|
||||
update_all_sources(self.source_data_filename, self.host_filename)
|
||||
mock_open.assert_not_called()
|
||||
|
||||
@mock.patch(builtins() + ".open", return_value=mock.Mock())
|
||||
@mock.patch("json.load", return_value={"url": "example.com"})
|
||||
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
|
||||
@mock.patch("updateHostsFile.write_data", return_value=0)
|
||||
@mock.patch("updateHostsFile.get_file_by_url", return_value="file_data")
|
||||
def test_one_source(self, mock_get, mock_write, *_):
|
||||
update_all_sources(self.source_data_filename, self.host_filename)
|
||||
self.assert_called_once(mock_write)
|
||||
self.assert_called_once(mock_get)
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
expected = "Updating source from example.com"
|
||||
|
||||
self.assertIn(expected, output)
|
||||
|
||||
@mock.patch(builtins() + ".open", return_value=mock.Mock())
|
||||
@mock.patch("json.load", return_value={"url": "example.com"})
|
||||
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo"])
|
||||
@mock.patch("updateHostsFile.write_data", return_value=0)
|
||||
@mock.patch("updateHostsFile.get_file_by_url",
|
||||
return_value=Exception("fail"))
|
||||
def test_source_fail(self, mock_get, mock_write, *_):
|
||||
update_all_sources(self.source_data_filename, self.host_filename)
|
||||
mock_write.assert_not_called()
|
||||
self.assert_called_once(mock_get)
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
expecteds = ["Updating source from example.com",
|
||||
"Error in updating source: example.com"]
|
||||
for expected in expecteds:
|
||||
self.assertIn(expected, output)
|
||||
|
||||
@mock.patch(builtins() + ".open", return_value=mock.Mock())
|
||||
@mock.patch("json.load", side_effect=[{"url": "example.com"},
|
||||
{"url": "example2.com"}])
|
||||
@mock.patch("updateHostsFile.recursive_glob", return_value=["foo", "bar"])
|
||||
@mock.patch("updateHostsFile.write_data", return_value=0)
|
||||
@mock.patch("updateHostsFile.get_file_by_url",
|
||||
side_effect=[Exception("fail"), "file_data"])
|
||||
def test_sources_fail_succeed(self, mock_get, mock_write, *_):
|
||||
update_all_sources(self.source_data_filename, self.host_filename)
|
||||
self.assert_called_once(mock_write)
|
||||
|
||||
get_calls = [mock.call("example.com"), mock.call("example2.com")]
|
||||
mock_get.assert_has_calls(get_calls)
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
expecteds = ["Updating source from example.com",
|
||||
"Error in updating source: example.com",
|
||||
"Updating source from example2.com"]
|
||||
for expected in expecteds:
|
||||
self.assertIn(expected, output)
|
||||
# End Update Logic
|
||||
|
||||
|
||||
# File Logic
|
||||
class TestNormalizeRule(BaseStdout):
|
||||
|
||||
|
||||
Reference in New Issue
Block a user