fix: don't reset context.scratch between files (#1151)

#453 fixed scratch leaking between files by setting it to empty, but that drops all the scratch space that was set up before the codemod runs (e.g. in the transformer's constructor)

This PR improves the fix by preserving the initial scratch.
This commit is contained in:
Zsolt Dollenstein 2024-05-21 15:52:49 -04:00 committed by GitHub
parent 71b0a1288b
commit db696e6348
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 12 additions and 2 deletions

View file

@ -14,6 +14,7 @@ import subprocess
import sys
import time
import traceback
from copy import deepcopy
from dataclasses import dataclass, replace
from multiprocessing import cpu_count, Pool
from pathlib import Path
@ -214,6 +215,7 @@ def _execute_transform( # noqa: C901
transformer: Codemod,
filename: str,
config: ExecutionConfig,
scratch: Dict[str, object],
) -> ExecutionResult:
for pattern in config.blacklist_patterns:
if re.fullmatch(pattern, filename):
@ -251,7 +253,7 @@ def _execute_transform( # noqa: C901
transformer.context = replace(
transformer.context,
filename=filename,
scratch={},
scratch=deepcopy(scratch),
)
# determine the module and package name for this file
@ -634,6 +636,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
"transformer": transform,
"filename": filename,
"config": config,
"scratch": transform.context.scratch,
}
for filename in files
]

View file

@ -2,7 +2,7 @@ import contextlib
import os
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Generator
from typing import Dict, Generator
from unittest import TestCase
from libcst import BaseExpression, Call, matchers as m, Name
@ -16,7 +16,14 @@ from libcst.codemod.visitors import AddImportsVisitor
class PrintToPPrintCommand(VisitorBasedCodemodCommand):
def __init__(self, context: CodemodContext, **kwargs: Dict[str, object]) -> None:
super().__init__(context, **kwargs)
self.context.scratch["PPRINT_WAS_HERE"] = True
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
if not self.context.scratch["PPRINT_WAS_HERE"]:
raise AssertionError("Scratch space lost")
if m.matches(updated_node, m.Call(func=m.Name("print"))):
AddImportsVisitor.add_needed_import(
self.context,