diff --git a/archivebox/cli/archivebox_add.py b/archivebox/cli/archivebox_add.py index c729e9fb..c692750b 100644 --- a/archivebox/cli/archivebox_add.py +++ b/archivebox/cli/archivebox_add.py @@ -10,7 +10,7 @@ from typing import List, Optional, IO from ..main import add, docstring from ..config import OUTPUT_DIR, ONLY_NEW -from .logging import SmartFormatter, reject_stdin +from .logging import SmartFormatter, accept_stdin @docstring(add.__doc__) @@ -55,9 +55,20 @@ def main(args: Optional[List[str]]=None, stdin: Optional[IO]=None, pwd: Optional help="Recursively archive all linked pages up to this many hops away" ) command = parser.parse_args(args or ()) - reject_stdin(__command__, stdin) + import_string = accept_stdin(stdin) + if import_string and command.import_path: + stderr( + '[X] You should pass an import path or a page url as an argument or in stdin but not both\n', + color='red', + ) + raise SystemExit(2) + elif import_string: + import_path = import_string + else: + import_path = command.import_path + add( - import_str=command.import_path, + import_str=import_path, import_path=None, update_all=command.update_all, index_only=command.index_only, @@ -66,7 +77,7 @@ def main(args: Optional[List[str]]=None, stdin: Optional[IO]=None, pwd: Optional if command.depth == 1: add( import_str=None, - import_path=command.import_path, + import_path=import_path, update_all=command.update_all, index_only=command.index_only, out_dir=pwd or OUTPUT_DIR, diff --git a/tests/test_init.py b/tests/test_init.py index d592b0a1..97870459 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -31,9 +31,15 @@ def test_add_link(tmp_path, process): output_html = f.read() assert "Example Domain" in output_html -def test_add_link_does_not_support_stdin(tmp_path, process): +def test_add_link_support_stdin(tmp_path, process): os.chdir(tmp_path) stdin_process = subprocess.Popen(["archivebox", "add"], stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) - output = stdin_process.communicate(input="example.com".encode())[0] - assert "does not accept stdin" in output.decode("utf-8") + stdin_process.communicate(input="http://example.com".encode()) + archived_item_path = list(tmp_path.glob('archive/**/*'))[0] + + assert "index.json" in [x.name for x in archived_item_path.iterdir()] + + with open(archived_item_path / "index.json", "r") as f: + output_json = json.load(f) + assert "Example Domain" == output_json['history']['title'][0]['output']