| """Tests for ``agent.tools.parse_config._parse_config``. |
| |
| Covers all three input shapes (Python AST, JSON, YAML), redaction of every |
| secret pattern we ship, error paths for missing/malformed inputs, and the |
| WorkloadConfig field mapping for HF TrainingArguments + DataLoader kwargs. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| import pytest |
|
|
| from agent.schemas import WorkloadConfig |
| from agent.tools.parse_config import PARSE_CONFIG, _parse_config, _parse_config_full |
|
|
# Directory of sample inputs, located next to this test module.
FIXTURES = Path(__file__).with_name("fixtures")
|
|
|
|
| |
| |
| |
|
|
|
|
class TestPythonScript:
    """Field extraction when the input is the ``sample_train.py`` fixture."""

    @staticmethod
    def _parsed() -> dict:
        """Parse the Python fixture and return the tool-result mapping.

        The ``ok`` assertion runs before the payload is handed back, so a
        parser failure is reported with the parser's own error message
        instead of an unrelated error when a test subscripts the payload.
        """
        result = _parse_config(str(FIXTURES / "sample_train.py"))
        assert result.ok, result.error
        assert result.result is not None
        return result.result

    def test_returns_ok(self) -> None:
        result = _parse_config(str(FIXTURES / "sample_train.py"))
        assert result.ok, result.error
        assert result.result is not None

    def test_extracts_training_arguments_kwargs(self) -> None:
        cfg = self._parsed()
        assert cfg["batch_size"] == 4
        assert cfg["grad_accum_steps"] == 8
        # Float field: compare with a tolerance rather than exact equality.
        assert cfg["lr"] == pytest.approx(2e-4)
        assert cfg["warmup_steps"] == 100
        assert cfg["optimizer"] == "adamw_torch"
        assert cfg["precision"] == "fp16"

    def test_dataloader_kwargs_captured(self) -> None:
        cfg = self._parsed()
        assert cfg["dataloader_workers"] == 0
        # Identity checks: the flags must be real booleans, not merely falsy.
        assert cfg["dataloader_pin_memory"] is False
        assert cfg["dataloader_prefetch_factor"] == 2
        assert cfg["dataloader_persistent_workers"] is False

    def test_torch_compile_call_flips_flag(self) -> None:
        assert self._parsed()["torch_compile"] is True

    def test_gradient_checkpointing_enable_call(self) -> None:
        assert self._parsed()["gradient_checkpointing"] is True

    def test_env_vars_captured(self) -> None:
        env = self._parsed()["env_vars"]
        # Values are compared as strings.
        assert env["HSA_FORCE_FINE_GRAIN_PCIE"] == "1"
        assert env["MIOPEN_FIND_MODE"] == "3"
        assert env["NCCL_MIN_NCHANNELS"] == "112"

    def test_lora_rank_extracted(self) -> None:
        assert self._parsed()["lora_rank"] == 16

    def test_attention_impl_from_from_pretrained(self) -> None:
        assert self._parsed()["attention_impl"] == "eager"

    def test_model_name_resolved(self) -> None:
        assert self._parsed()["model_name"] == "Qwen/Qwen2.5-7B-Instruct"
|
|
|
|
| |
| |
| |
|
|
|
|
class TestJsonConfig:
    """Field mapping when the input is the ``sample_train.json`` fixture."""

    def test_returns_ok(self) -> None:
        outcome = _parse_config(str(FIXTURES / "sample_train.json"))
        assert outcome.ok, outcome.error

    def test_field_mapping(self) -> None:
        parsed = _parse_config(str(FIXTURES / "sample_train.json")).result

        # Fields compared by equality.
        expected_values = {
            "model_name": "Qwen/Qwen2.5-7B-Instruct",
            "batch_size": 8,
            "grad_accum_steps": 4,
            "seq_len": 4096,
            "precision": "bf16",
            "optimizer": "adamw_torch_fused",
            "dataloader_workers": 4,
            "attention_impl": "flash",
        }
        for field, wanted in expected_values.items():
            assert parsed[field] == wanted

        # Flags checked by identity so a truthy non-bool does not pass.
        enabled_flags = (
            "torch_compile",
            "gradient_checkpointing",
            "dataloader_pin_memory",
            "dataloader_persistent_workers",
        )
        for flag in enabled_flags:
            assert parsed[flag] is True

    def test_env_vars_dict(self) -> None:
        parsed = _parse_config(str(FIXTURES / "sample_train.json")).result
        env_vars = parsed["env_vars"]
        assert env_vars["HSA_FORCE_FINE_GRAIN_PCIE"] == "1"

    def test_extras_collects_unmapped_fields(self) -> None:
        parsed = _parse_config(str(FIXTURES / "sample_train.json")).result
        leftovers = parsed["extras"]
        assert leftovers["num_train_epochs"] == 3
        assert leftovers["save_steps"] == 500
|
|
|
|
| |
| |
| |
|
|
|
|
class TestYamlConfig:
    """Field mapping when the input is the ``sample_train.yaml`` fixture."""

    def test_returns_ok(self) -> None:
        outcome = _parse_config(str(FIXTURES / "sample_train.yaml"))
        assert outcome.ok, outcome.error

    def test_field_mapping(self) -> None:
        parsed = _parse_config(str(FIXTURES / "sample_train.yaml")).result

        # Fields compared by equality.
        expected_values = {
            "batch_size": 2,
            "grad_accum_steps": 16,
            "seq_len": 8192,
            "precision": "bf16",
            "attention_impl": "sdpa",
            "dataloader_workers": 8,
        }
        for field, wanted in expected_values.items():
            assert parsed[field] == wanted

        # Boolean flags are checked by identity, as in the JSON tests.
        assert parsed["torch_compile"] is False
        for flag in ("gradient_checkpointing", "dataloader_persistent_workers"):
            assert parsed[flag] is True
|
|
|
|
| |
| |
| |
|
|
|
|
class TestRedaction:
    """Secret scrubbing across the Python, JSON and YAML fixtures."""

    def test_python_script_redactions(self) -> None:
        parsed = _parse_config(str(FIXTURES / "sample_train.py")).result
        reported = set(parsed["redactions"])
        # Every one of these labels must be reported for the Python fixture.
        for label in (
            "hf_token",
            "openai_key",
            "github_token",
            "bearer_token",
            "home_path",
            "s3_uri",
            "ws_uri",
        ):
            assert label in reported

    def test_raw_source_is_scrubbed(self) -> None:
        # ``raw_source`` is read from the object returned by
        # ``_parse_config_full``, not from the tool-result dict.
        full_cfg = _parse_config_full(str(FIXTURES / "sample_train.py"))
        assert isinstance(full_cfg, WorkloadConfig)
        scrubbed = full_cfg.raw_source

        # None of the literal secrets or private locations may survive.
        for secret in (
            "hf_abcdefghijklmnopqrstuvwxyz123456",
            "sk-abcdefghijklmnopqrstuvwxyz1234567890",
            "gho_abcdefghijklmnopqrstuvwxyz123456",
            "/home/researcher/datasets/alpaca",
            "s3://my-team/checkpoints/qwen-lora/",
            "wss://logs.internal.example.com/stream",
        ):
            assert secret not in scrubbed

        # Placeholders must appear in their place.
        for marker in ("<REDACTED:hf_token>", "<REDACTED:openai_key>"):
            assert marker in scrubbed

    def test_raw_source_excluded_from_tool_result(self) -> None:
        # The dict returned by the tool must not carry a ``raw_source`` key.
        parsed = _parse_config(str(FIXTURES / "sample_train.py")).result
        assert "raw_source" not in parsed

    def test_json_redactions(self) -> None:
        json_fixture = str(FIXTURES / "sample_train.json")

        reported = set(_parse_config(json_fixture).result["redactions"])
        assert "hf_token" in reported
        assert "s3_uri" in reported

        # The full config's raw source must not contain the literal token.
        full_cfg = _parse_config_full(json_fixture)
        assert isinstance(full_cfg, WorkloadConfig)
        assert "hf_jsonsamplehfabcdefghijklmnopqrs" not in full_cfg.raw_source

    def test_extras_values_are_scrubbed(self) -> None:
        # Values that end up in ``extras`` must be redacted as well.
        parsed = _parse_config(str(FIXTURES / "sample_train.json")).result
        leftovers = parsed["extras"]
        hub_token = leftovers.get("hub_token", "")
        assert "hf_jsonsamplehfabcdefghijklmnopqrs" not in hub_token
        assert hub_token.startswith("<REDACTED:")
        assert leftovers.get("checkpoint_uri", "").startswith("<REDACTED:")

    def test_yaml_redactions(self) -> None:
        parsed = _parse_config(str(FIXTURES / "sample_train.yaml")).result
        reported = set(parsed["redactions"])
        for label in ("hf_token", "bearer_token", "home_path"):
            assert label in reported
|
|
|
|
| |
| |
| |
|
|
|
|
class TestErrors:
    """Failure paths: every case must come back with ``ok is False``.

    Where a test also inspects the message, it matches a lowercase
    substring of ``result.error`` (treating a missing error as ``""``).
    """

    def test_missing_file(self, tmp_path: Path) -> None:
        # A path inside pytest's fresh per-test directory is guaranteed not
        # to exist and does not depend on a POSIX-style ``/tmp``.
        missing = tmp_path / "definitely-does-not-exist-xyz.py"
        result = _parse_config(str(missing))
        assert result.ok is False
        assert "not found" in (result.error or "").lower()

    def test_unsupported_extension(self, tmp_path: Path) -> None:
        bad = tmp_path / "config.toml"
        bad.write_text("model_name = 'foo'\n")
        result = _parse_config(str(bad))
        assert result.ok is False
        assert "unsupported" in (result.error or "").lower()

    def test_malformed_python(self, tmp_path: Path) -> None:
        bad = tmp_path / "broken.py"
        # ``def oops(:`` is a syntax error.
        bad.write_text("def oops(:\n pass\n")
        result = _parse_config(str(bad))
        assert result.ok is False
        assert "parse error" in (result.error or "").lower()

    def test_malformed_json(self, tmp_path: Path) -> None:
        bad = tmp_path / "broken.json"
        bad.write_text("{not really json")
        result = _parse_config(str(bad))
        assert result.ok is False
        assert "json" in (result.error or "").lower()

    def test_json_top_level_must_be_dict(self, tmp_path: Path) -> None:
        bad = tmp_path / "list.json"
        # Valid JSON, but the top level is a list rather than an object.
        bad.write_text(json.dumps([{"foo": 1}]))
        result = _parse_config(str(bad))
        assert result.ok is False

    def test_yaml_top_level_must_be_mapping(self, tmp_path: Path) -> None:
        bad = tmp_path / "scalar.yaml"
        # Valid YAML, but the top level is a sequence rather than a mapping.
        bad.write_text("- 1\n- 2\n")
        result = _parse_config(str(bad))
        assert result.ok is False
|
|
|
|
| |
| |
| |
|
|
|
|
class TestSchema:
    """Shape of the result payload and of the exported tool definition."""

    def test_result_round_trips_through_workload_config(self) -> None:
        outcome = _parse_config(str(FIXTURES / "sample_train.py"))
        # The result dict must be accepted as ``WorkloadConfig`` keyword arguments.
        rebuilt = WorkloadConfig(**outcome.result)
        assert rebuilt.model_name == "Qwen/Qwen2.5-7B-Instruct"

    def test_defaults_when_field_absent(self, tmp_path: Path) -> None:
        # A config that sets only ``model_name``.
        minimal = tmp_path / "tiny.json"
        payload = json.dumps({"model_name": "test/tiny"})
        minimal.write_text(payload)

        outcome = _parse_config(str(minimal))
        assert outcome.ok

        parsed = outcome.result
        assert parsed["batch_size"] == 1
        assert parsed["precision"] == "fp16"
        assert parsed["optimizer"] == "adamw_torch"
        assert parsed["gradient_checkpointing"] is False
        assert parsed["redactions"] == []

    def test_tool_definition_unchanged_in_shape(self) -> None:
        # Name, bound callable and the ``file_path`` input property.
        assert PARSE_CONFIG.name == "parse_config"
        assert PARSE_CONFIG.fn is _parse_config
        schema_properties = PARSE_CONFIG.input_schema["properties"]
        assert "file_path" in schema_properties
|
|
|
|
| |
| |
| |
| |
| |
| |
|
|
|
|
# Two directory levels above this (resolved) test file.
REPO_ROOT = Path(__file__).resolve().parents[1]
|
|
|
|
class TestCanonicalWorkload:
    """End-to-end check against ``workloads/train_qwen_lora.py`` under ``REPO_ROOT``."""

    def test_canonical_workload_extracts_full_config(self) -> None:
        """The canonical demo workload must yield batch_size=4, lr=2e-4, etc.
        — not HF defaults. Catches the `**dict_var` splat regression where
        every TrainingArguments kwarg disappears.
        """
        result = _parse_config(str(REPO_ROOT / "workloads" / "train_qwen_lora.py"))
        assert result.ok, result.error
        cfg = result.result
        assert cfg["model_name"] == "Qwen/Qwen2.5-7B-Instruct"
        assert cfg["batch_size"] == 4, (
            "expected batch_size=4 from per_device_train_batch_size; "
            "did `**_ta_kwargs` splat hide the kwargs?"
        )
        assert cfg["grad_accum_steps"] == 8
        # Tolerant float comparison, matching the sample-fixture test for the
        # same field.
        assert cfg["lr"] == pytest.approx(2e-4)
        assert cfg["warmup_steps"] == 100
        assert cfg["precision"] == "fp16"
        assert cfg["attention_impl"] == "eager"
        assert cfg["dataloader_workers"] == 0
        assert cfg["dataloader_pin_memory"] is False
        assert cfg["lora_rank"] == 16
        assert cfg["torch_compile"] is False
        assert cfg["env_vars"]["HSA_FORCE_FINE_GRAIN_PCIE"] == "1"
|
|
|
|
class TestSplatKwargsResolution:
    """`_ta = dict(k=v); Foo(**_ta)` must resolve back through the dict
    constant. Defensive — the canonical workload no longer uses this
    pattern, but third-party scripts often do.
    """

    @staticmethod
    def _parse_snippet(tmp_path: Path, filename: str, source: str) -> dict:
        """Write ``source`` to ``tmp_path / filename``, parse it, return the result.

        Asserts ``ok`` first so a parser failure is reported with the
        parser's own error message rather than an error from subscripting
        the payload.
        """
        script = tmp_path / filename
        script.write_text(source)
        result = _parse_config(str(script))
        assert result.ok, result.error
        return result.result

    def test_dict_function_call_splat(self, tmp_path: Path) -> None:
        src = """
from transformers import TrainingArguments

_ta = dict(
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    fp16=True,
    optim="adamw_torch_fused",
)
training_args = TrainingArguments(output_dir="./out", **_ta)
"""
        cfg = self._parse_snippet(tmp_path, "splat.py", src)
        assert cfg["batch_size"] == 8
        assert cfg["grad_accum_steps"] == 2
        assert cfg["precision"] == "fp16"
        assert cfg["optimizer"] == "adamw_torch_fused"

    def test_dict_literal_splat(self, tmp_path: Path) -> None:
        src = """
from transformers import TrainingArguments

_ta = {
    "per_device_train_batch_size": 16,
    "bf16": True,
}
training_args = TrainingArguments(output_dir="./out", **_ta)
"""
        cfg = self._parse_snippet(tmp_path, "splat_literal.py", src)
        assert cfg["batch_size"] == 16
        assert cfg["precision"] == "bf16"

    def test_explicit_kwarg_overrides_splat(self, tmp_path: Path) -> None:
        src = """
from transformers import TrainingArguments

_ta = dict(per_device_train_batch_size=8)
training_args = TrainingArguments(per_device_train_batch_size=32, **_ta)
"""
        cfg = self._parse_snippet(tmp_path, "splat_override.py", src)
        # The explicitly written keyword (32) must win over the value that
        # arrives through the splatted dict (8).
        assert cfg["batch_size"] == 32
|
|
|
|