mirror of
https://github.com/Instagram/LibCST.git
synced 2025-12-23 10:35:53 +00:00
fix: don't reset context.scratch between files (#1151)
#453 fixed scratch leaking between files by setting it to empty, but that drops all the scratch space that was set up before the codemod runs (e.g. in the transformer's constructor) This PR improves the fix by preserving the initial scratch.
This commit is contained in:
parent
71b0a1288b
commit
db696e6348
2 changed files with 12 additions and 2 deletions
|
|
@ -14,6 +14,7 @@ import subprocess
|
|||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass, replace
|
||||
from multiprocessing import cpu_count, Pool
|
||||
from pathlib import Path
|
||||
|
|
@ -214,6 +215,7 @@ def _execute_transform( # noqa: C901
|
|||
transformer: Codemod,
|
||||
filename: str,
|
||||
config: ExecutionConfig,
|
||||
scratch: Dict[str, object],
|
||||
) -> ExecutionResult:
|
||||
for pattern in config.blacklist_patterns:
|
||||
if re.fullmatch(pattern, filename):
|
||||
|
|
@ -251,7 +253,7 @@ def _execute_transform( # noqa: C901
|
|||
transformer.context = replace(
|
||||
transformer.context,
|
||||
filename=filename,
|
||||
scratch={},
|
||||
scratch=deepcopy(scratch),
|
||||
)
|
||||
|
||||
# determine the module and package name for this file
|
||||
|
|
@ -634,6 +636,7 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
|
|||
"transformer": transform,
|
||||
"filename": filename,
|
||||
"config": config,
|
||||
"scratch": transform.context.scratch,
|
||||
}
|
||||
for filename in files
|
||||
]
|
||||
|
|
|
|||
|
|
@ -2,7 +2,7 @@ import contextlib
|
|||
import os
|
||||
from pathlib import Path
|
||||
from tempfile import TemporaryDirectory
|
||||
from typing import Generator
|
||||
from typing import Dict, Generator
|
||||
from unittest import TestCase
|
||||
|
||||
from libcst import BaseExpression, Call, matchers as m, Name
|
||||
|
|
@ -16,7 +16,14 @@ from libcst.codemod.visitors import AddImportsVisitor
|
|||
|
||||
|
||||
class PrintToPPrintCommand(VisitorBasedCodemodCommand):
|
||||
def __init__(self, context: CodemodContext, **kwargs: Dict[str, object]) -> None:
|
||||
super().__init__(context, **kwargs)
|
||||
self.context.scratch["PPRINT_WAS_HERE"] = True
|
||||
|
||||
def leave_Call(self, original_node: Call, updated_node: Call) -> BaseExpression:
|
||||
if not self.context.scratch["PPRINT_WAS_HERE"]:
|
||||
raise AssertionError("Scratch space lost")
|
||||
|
||||
if m.matches(updated_node, m.Call(func=m.Name("print"))):
|
||||
AddImportsVisitor.add_needed_import(
|
||||
self.context,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue