mirror of
https://github.com/StevenBlack/hosts.git
synced 2026-07-01 02:36:52 +00:00
Reduce dependency on global settings variable
Global variables make code less modular and therefore more difficult to test.
This commit is contained in:
@@ -11,10 +11,15 @@ from updateHostsFile import (Colors, PY3, colorize, flush_dns_cache,
|
||||
move_hosts_file_into_place, normalize_rule,
|
||||
path_join_robust, print_failure, print_success,
|
||||
supports_color, query_yes_no, recursive_glob,
|
||||
strip_rule, write_data)
|
||||
remove_old_hosts_file, strip_rule,
|
||||
update_readme_data, write_data,
|
||||
write_opening_header)
|
||||
import updateHostsFile
|
||||
import unittest
|
||||
import tempfile
|
||||
import locale
|
||||
import shutil
|
||||
import json
|
||||
import sys
|
||||
import os
|
||||
|
||||
@@ -48,6 +53,19 @@ class BaseStdout(Base):
|
||||
def tearDown(self):
|
||||
sys.stdout.close()
|
||||
sys.stdout = sys.__stdout__
|
||||
|
||||
|
||||
class BaseMockDir(Base):
|
||||
|
||||
@property
|
||||
def dir_count(self):
|
||||
return len(os.listdir(self.test_dir))
|
||||
|
||||
def setUp(self):
|
||||
self.test_dir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
shutil.rmtree(self.test_dir)
|
||||
# End Base Test Classes
|
||||
|
||||
|
||||
@@ -119,12 +137,12 @@ class TestGatherCustomExclusions(BaseStdout):
|
||||
# File Logic
|
||||
class TestNormalizeRule(BaseStdout):
|
||||
|
||||
# Can only test non-matches because they don't
|
||||
# interact with the settings global variable.
|
||||
def test_no_match(self):
|
||||
kwargs = dict(target_ip="0.0.0.0", keep_domain_comments=False)
|
||||
|
||||
for rule in ["foo", "128.0.0.1", "bar.com/usa", "0.0.0 google",
|
||||
"0.1.2.3.4 foo/bar", "twitter.com"]:
|
||||
self.assertEqual(normalize_rule(rule), (None, None))
|
||||
self.assertEqual(normalize_rule(rule, **kwargs), (None, None))
|
||||
|
||||
output = sys.stdout.getvalue()
|
||||
sys.stdout = StringIO()
|
||||
@@ -132,6 +150,38 @@ class TestNormalizeRule(BaseStdout):
|
||||
expected = "==>" + rule + "<=="
|
||||
self.assertIn(expected, output)
|
||||
|
||||
def test_no_comments(self):
|
||||
for target_ip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
|
||||
rule = "127.0.0.1 google foo"
|
||||
expected = ("google", str(target_ip) + " google\n")
|
||||
|
||||
actual = normalize_rule(rule, target_ip=target_ip,
|
||||
keep_domain_comments=False)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
# Nothing gets printed if there's a match.
|
||||
output = sys.stdout.getvalue()
|
||||
self.assertEqual(output, "")
|
||||
|
||||
sys.stdout = StringIO()
|
||||
|
||||
def test_with_comments(self):
|
||||
for target_ip in ("0.0.0.0", "127.0.0.1", "8.8.8.8"):
|
||||
for comment in ("foo", "bar", "baz"):
|
||||
rule = "127.0.0.1 google " + comment
|
||||
expected = ("google", (str(target_ip) + " google # " +
|
||||
comment + "\n"))
|
||||
|
||||
actual = normalize_rule(rule, target_ip=target_ip,
|
||||
keep_domain_comments=True)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
# Nothing gets printed if there's a match.
|
||||
output = sys.stdout.getvalue()
|
||||
self.assertEqual(output, "")
|
||||
|
||||
sys.stdout = StringIO()
|
||||
|
||||
|
||||
class TestStripRule(Base):
|
||||
|
||||
@@ -153,6 +203,304 @@ class TestStripRule(Base):
|
||||
self.assertEqual(output, line)
|
||||
|
||||
|
||||
class TestWriteOpeningHeader(BaseMockDir):
|
||||
|
||||
def setUp(self):
|
||||
super(TestWriteOpeningHeader, self).setUp()
|
||||
self.final_file = BytesIO()
|
||||
|
||||
def test_missing_keyword(self):
|
||||
kwargs = dict(extensions="", outputsubfolder="",
|
||||
numberofrules=5, skipstatichosts=False)
|
||||
|
||||
for k in kwargs.keys():
|
||||
bad_kwargs = kwargs.copy()
|
||||
bad_kwargs.pop(k)
|
||||
|
||||
self.assertRaises(KeyError, write_opening_header,
|
||||
self.final_file, **bad_kwargs)
|
||||
|
||||
def test_basic(self):
|
||||
kwargs = dict(extensions="", outputsubfolder="",
|
||||
numberofrules=5, skipstatichosts=True)
|
||||
write_opening_header(self.final_file, **kwargs)
|
||||
|
||||
contents = self.final_file.getvalue()
|
||||
contents = contents.decode("UTF-8")
|
||||
|
||||
# Expected contents.
|
||||
for expected in (
|
||||
"# This hosts file is a merged collection",
|
||||
"# with a dash of crowd sourcing via Github",
|
||||
"# Number of unique domains: {count}".format(
|
||||
count=kwargs["numberofrules"]),
|
||||
"Fetch the latest version of this file:",
|
||||
"Project home page: https://github.com/StevenBlack/hosts",
|
||||
):
|
||||
self.assertIn(expected, contents)
|
||||
|
||||
# Expected non-contents.
|
||||
for expected in (
|
||||
"# Extensions added to this file:",
|
||||
"127.0.0.1 localhost",
|
||||
"127.0.0.1 local",
|
||||
"127.0.0.53",
|
||||
"127.0.1.1",
|
||||
):
|
||||
self.assertNotIn(expected, contents)
|
||||
|
||||
def test_basic_include_static_hosts(self):
|
||||
kwargs = dict(extensions="", outputsubfolder="",
|
||||
numberofrules=5, skipstatichosts=False)
|
||||
with self.mock_property("platform.system") as obj:
|
||||
obj.return_value = "Windows"
|
||||
write_opening_header(self.final_file, **kwargs)
|
||||
|
||||
contents = self.final_file.getvalue()
|
||||
contents = contents.decode("UTF-8")
|
||||
|
||||
# Expected contents.
|
||||
for expected in (
|
||||
"127.0.0.1 local",
|
||||
"127.0.0.1 localhost",
|
||||
"# This hosts file is a merged collection",
|
||||
"# with a dash of crowd sourcing via Github",
|
||||
"# Number of unique domains: {count}".format(
|
||||
count=kwargs["numberofrules"]),
|
||||
"Fetch the latest version of this file:",
|
||||
"Project home page: https://github.com/StevenBlack/hosts",
|
||||
):
|
||||
self.assertIn(expected, contents)
|
||||
|
||||
# Expected non-contents.
|
||||
for expected in (
|
||||
"# Extensions added to this file:",
|
||||
"127.0.0.53",
|
||||
"127.0.1.1",
|
||||
):
|
||||
self.assertNotIn(expected, contents)
|
||||
|
||||
def test_basic_include_static_hosts_linux(self):
|
||||
kwargs = dict(extensions="", outputsubfolder="",
|
||||
numberofrules=5, skipstatichosts=False)
|
||||
with self.mock_property("platform.system") as system:
|
||||
system.return_value = "Linux"
|
||||
|
||||
with self.mock_property("socket.gethostname") as hostname:
|
||||
hostname.return_value = "steven-hosts"
|
||||
write_opening_header(self.final_file, **kwargs)
|
||||
|
||||
contents = self.final_file.getvalue()
|
||||
contents = contents.decode("UTF-8")
|
||||
|
||||
# Expected contents.
|
||||
for expected in (
|
||||
"127.0.1.1",
|
||||
"127.0.0.53",
|
||||
"steven-hosts",
|
||||
"127.0.0.1 local",
|
||||
"127.0.0.1 localhost",
|
||||
"# This hosts file is a merged collection",
|
||||
"# with a dash of crowd sourcing via Github",
|
||||
"# Number of unique domains: {count}".format(
|
||||
count=kwargs["numberofrules"]),
|
||||
"Fetch the latest version of this file:",
|
||||
"Project home page: https://github.com/StevenBlack/hosts",
|
||||
):
|
||||
self.assertIn(expected, contents)
|
||||
|
||||
# Expected non-contents.
|
||||
expected = "# Extensions added to this file:"
|
||||
self.assertNotIn(expected, contents)
|
||||
|
||||
def test_extensions(self):
|
||||
kwargs = dict(extensions=["epsilon", "gamma", "mu", "phi"],
|
||||
outputsubfolder="", numberofrules=5,
|
||||
skipstatichosts=True)
|
||||
write_opening_header(self.final_file, **kwargs)
|
||||
|
||||
contents = self.final_file.getvalue()
|
||||
contents = contents.decode("UTF-8")
|
||||
|
||||
# Expected contents.
|
||||
for expected in (
|
||||
", ".join(kwargs["extensions"]),
|
||||
"# Extensions added to this file:",
|
||||
"# This hosts file is a merged collection",
|
||||
"# with a dash of crowd sourcing via Github",
|
||||
"# Number of unique domains: {count}".format(
|
||||
count=kwargs["numberofrules"]),
|
||||
"Fetch the latest version of this file:",
|
||||
"Project home page: https://github.com/StevenBlack/hosts",
|
||||
):
|
||||
self.assertIn(expected, contents)
|
||||
|
||||
# Expected non-contents.
|
||||
for expected in (
|
||||
"127.0.0.1 localhost",
|
||||
"127.0.0.1 local",
|
||||
"127.0.0.53",
|
||||
"127.0.1.1",
|
||||
):
|
||||
self.assertNotIn(expected, contents)
|
||||
|
||||
def test_no_preamble(self):
|
||||
# We should not even attempt to read this, as it is a directory.
|
||||
hosts_dir = os.path.join(self.test_dir, "myhosts")
|
||||
os.mkdir(hosts_dir)
|
||||
|
||||
kwargs = dict(extensions="", outputsubfolder="",
|
||||
numberofrules=5, skipstatichosts=True)
|
||||
|
||||
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
|
||||
updateHostsFile.BASEDIR_PATH = self.test_dir
|
||||
write_opening_header(self.final_file, **kwargs)
|
||||
|
||||
contents = self.final_file.getvalue()
|
||||
contents = contents.decode("UTF-8")
|
||||
|
||||
# Expected contents.
|
||||
for expected in (
|
||||
"# This hosts file is a merged collection",
|
||||
"# with a dash of crowd sourcing via Github",
|
||||
"# Number of unique domains: {count}".format(
|
||||
count=kwargs["numberofrules"]),
|
||||
"Fetch the latest version of this file:",
|
||||
"Project home page: https://github.com/StevenBlack/hosts",
|
||||
):
|
||||
self.assertIn(expected, contents)
|
||||
|
||||
# Expected non-contents.
|
||||
for expected in (
|
||||
"# Extensions added to this file:",
|
||||
"127.0.0.1 localhost",
|
||||
"127.0.0.1 local",
|
||||
"127.0.0.53",
|
||||
"127.0.1.1",
|
||||
):
|
||||
self.assertNotIn(expected, contents)
|
||||
|
||||
def test_preamble(self):
|
||||
hosts_file = os.path.join(self.test_dir, "myhosts")
|
||||
with open(hosts_file, "w") as f:
|
||||
f.write("peter-piper-picked-a-pepper")
|
||||
|
||||
kwargs = dict(extensions="", outputsubfolder="",
|
||||
numberofrules=5, skipstatichosts=True)
|
||||
|
||||
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
|
||||
updateHostsFile.BASEDIR_PATH = self.test_dir
|
||||
write_opening_header(self.final_file, **kwargs)
|
||||
|
||||
contents = self.final_file.getvalue()
|
||||
contents = contents.decode("UTF-8")
|
||||
|
||||
# Expected contents.
|
||||
for expected in (
|
||||
"peter-piper-picked-a-pepper",
|
||||
"# This hosts file is a merged collection",
|
||||
"# with a dash of crowd sourcing via Github",
|
||||
"# Number of unique domains: {count}".format(
|
||||
count=kwargs["numberofrules"]),
|
||||
"Fetch the latest version of this file:",
|
||||
"Project home page: https://github.com/StevenBlack/hosts",
|
||||
):
|
||||
self.assertIn(expected, contents)
|
||||
|
||||
# Expected non-contents.
|
||||
for expected in (
|
||||
"# Extensions added to this file:",
|
||||
"127.0.0.1 localhost",
|
||||
"127.0.0.1 local",
|
||||
"127.0.0.53",
|
||||
"127.0.1.1",
|
||||
):
|
||||
self.assertNotIn(expected, contents)
|
||||
|
||||
def tearDown(self):
|
||||
super(TestWriteOpeningHeader, self).tearDown()
|
||||
self.final_file.close()
|
||||
|
||||
|
||||
class TestUpdateReadmeData(BaseMockDir):
|
||||
|
||||
def setUp(self):
|
||||
super(TestUpdateReadmeData, self).setUp()
|
||||
self.readme_file = os.path.join(self.test_dir, "readmeData.json")
|
||||
|
||||
def test_missing_keyword(self):
|
||||
kwargs = dict(extensions="", outputsubfolder="",
|
||||
numberofrules="", sourcesdata="")
|
||||
|
||||
for k in kwargs.keys():
|
||||
bad_kwargs = kwargs.copy()
|
||||
bad_kwargs.pop(k)
|
||||
|
||||
self.assertRaises(KeyError, update_readme_data,
|
||||
self.readme_file, **bad_kwargs)
|
||||
|
||||
def test_add_fields(self):
|
||||
with open(self.readme_file, "w") as f:
|
||||
json.dump({"foo": "bar"}, f)
|
||||
|
||||
kwargs = dict(extensions=None, outputsubfolder="foo",
|
||||
numberofrules=5, sourcesdata="hosts")
|
||||
update_readme_data(self.readme_file, **kwargs)
|
||||
|
||||
expected = {
|
||||
"base": {
|
||||
"location": "foo" + self.sep,
|
||||
"sourcesdata": "hosts",
|
||||
"entries": 5,
|
||||
},
|
||||
"foo": "bar"
|
||||
}
|
||||
|
||||
with open(self.readme_file, "r") as f:
|
||||
actual = json.load(f)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_modify_fields(self):
|
||||
with open(self.readme_file, "w") as f:
|
||||
json.dump({"base": "soprano"}, f)
|
||||
|
||||
kwargs = dict(extensions=None, outputsubfolder="foo",
|
||||
numberofrules=5, sourcesdata="hosts")
|
||||
update_readme_data(self.readme_file, **kwargs)
|
||||
|
||||
expected = {
|
||||
"base": {
|
||||
"location": "foo" + self.sep,
|
||||
"sourcesdata": "hosts",
|
||||
"entries": 5,
|
||||
}
|
||||
}
|
||||
|
||||
with open(self.readme_file, "r") as f:
|
||||
actual = json.load(f)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
def test_set_extensions(self):
|
||||
with open(self.readme_file, "w") as f:
|
||||
json.dump({}, f)
|
||||
|
||||
kwargs = dict(extensions=["com", "org"], outputsubfolder="foo",
|
||||
numberofrules=5, sourcesdata="hosts")
|
||||
update_readme_data(self.readme_file, **kwargs)
|
||||
|
||||
expected = {
|
||||
"com-org": {
|
||||
"location": "foo" + self.sep,
|
||||
"sourcesdata": "hosts",
|
||||
"entries": 5,
|
||||
}
|
||||
}
|
||||
|
||||
with open(self.readme_file, "r") as f:
|
||||
actual = json.load(f)
|
||||
self.assertEqual(actual, expected)
|
||||
|
||||
|
||||
class TestMoveHostsFile(BaseStdout):
|
||||
|
||||
@mock.patch("os.path.abspath", side_effect=lambda f: f)
|
||||
@@ -312,6 +660,78 @@ class TestFlushDnsCache(BaseStdout):
|
||||
("Flushing the DNS cache by restarting "
|
||||
"NetworkManager.service succeeded")]:
|
||||
self.assertIn(expected, output)
|
||||
|
||||
|
||||
def mock_path_join_robust(*args):
|
||||
# We want to hard-code the backup hosts filename
|
||||
# instead of parametrizing based on current time.
|
||||
if len(args) == 2 and args[1].startswith("hosts-"):
|
||||
return os.path.join(args[0], "hosts-new")
|
||||
else:
|
||||
return os.path.join(*args)
|
||||
|
||||
|
||||
class TestRemoveOldHostsFile(BaseMockDir):
|
||||
|
||||
def setUp(self):
|
||||
super(TestRemoveOldHostsFile, self).setUp()
|
||||
self.hosts_file = os.path.join(self.test_dir, "hosts")
|
||||
|
||||
def test_remove_hosts_file(self):
|
||||
old_dir_count = self.dir_count
|
||||
|
||||
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
|
||||
updateHostsFile.BASEDIR_PATH = self.test_dir
|
||||
remove_old_hosts_file(backup=False)
|
||||
|
||||
new_dir_count = old_dir_count + 1
|
||||
self.assertEqual(self.dir_count, new_dir_count)
|
||||
|
||||
with open(self.hosts_file, "r") as f:
|
||||
contents = f.read()
|
||||
self.assertEqual(contents, "")
|
||||
|
||||
def test_remove_hosts_file_exists(self):
|
||||
with open(self.hosts_file, "w") as f:
|
||||
f.write("foo")
|
||||
|
||||
old_dir_count = self.dir_count
|
||||
|
||||
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
|
||||
updateHostsFile.BASEDIR_PATH = self.test_dir
|
||||
remove_old_hosts_file(backup=False)
|
||||
|
||||
new_dir_count = old_dir_count
|
||||
self.assertEqual(self.dir_count, new_dir_count)
|
||||
|
||||
with open(self.hosts_file, "r") as f:
|
||||
contents = f.read()
|
||||
self.assertEqual(contents, "")
|
||||
|
||||
@mock.patch("updateHostsFile.path_join_robust",
|
||||
side_effect=mock_path_join_robust)
|
||||
def test_remove_hosts_file_backup(self, _):
|
||||
with open(self.hosts_file, "w") as f:
|
||||
f.write("foo")
|
||||
|
||||
old_dir_count = self.dir_count
|
||||
|
||||
with self.mock_property("updateHostsFile.BASEDIR_PATH"):
|
||||
updateHostsFile.BASEDIR_PATH = self.test_dir
|
||||
remove_old_hosts_file(backup=True)
|
||||
|
||||
new_dir_count = old_dir_count + 1
|
||||
self.assertEqual(self.dir_count, new_dir_count)
|
||||
|
||||
with open(self.hosts_file, "r") as f:
|
||||
contents = f.read()
|
||||
self.assertEqual(contents, "")
|
||||
|
||||
new_hosts_file = self.hosts_file + "-new"
|
||||
|
||||
with open(new_hosts_file, "r") as f:
|
||||
contents = f.read()
|
||||
self.assertEqual(contents, "foo")
|
||||
# End File Logic
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user