Compare commits

..

No commits in common. "main" and "v0.4.0" have entirely different histories.
main ... v0.4.0

285 changed files with 6933 additions and 23629 deletions

View file

@ -1,7 +1,7 @@
root = true
[*.{py,pyi,rs,toml,md}]
charset = utf-8
charset = "utf-8"
end_of_line = lf
indent_size = 4
indent_style = space

210
.flake8
View file

@ -1,126 +1,69 @@
[flake8]
ignore =
# unnecessary list comprehension; A generator only better than a list
# comprehension if we don't always need to iterate through all items in
# the generator (based on the use case).
C407,
C407, # unnecessary list comprehension; A generator only better than a list
# comprehension if we don't always need to iterate through all items in
# the generator (based on the use case).
# The following codes belong to pycodestyle, and overlap with black:
# indentation contains mixed spaces and tabs
E101,
# indentation is not a multiple of four
E111,
# expected an indented block
E112,
# unexpected indentation
E113,
# indentation is not a multiple of four (comment)
E114,
# expected an indented block (comment)
E115,
# unexpected indentation (comment)
E116,
# continuation line under-indented for hanging indent
E121,
# continuation line missing indentation or outdented
E122,
# closing bracket does not match indentation of opening brackets line
E123,
# closing bracket does not match visual indentation
E124,
# continuation line with same indent as next logical line
E125,
# continuation line over-indented for hanging indent
E126,
# continuation line over-indented for visual indent; is harmless
# (over-indent is visually unambiguous) and currently generates too
# many warnings for existing code.
E127,
# continuation line under-indented for visual indent
E128,
# visually indented line with same indent as next logical line
E129,
# continuation line unaligned for hanging indent
E131,
# closing bracket is missing indentation
E133,
# whitespace after (
E201,
# whitespace before )
E202,
# whitespace before :; this warning is invalid for slices
E203,
# whitespace before (
E211,
# multiple spaces before operator
E221,
# multiple spaces after operator
E222,
# tab before operator
E223,
# tab after operator
E224,
# missing whitespace around operator
E225,
# missing whitespace around arithmetic operator
E226,
# missing whitespace around bitwise or shift operator
E227,
# missing whitespace around modulo operator
E228,
# missing whitespace after ,, ;, or :
E231,
# multiple spaces after ,
E241,
# tab after ,
E242,
# unexpected spaces around keyword / parameter equals
E251,
# at least two spaces before inline comment
E261,
# inline comment should start with #
E262,
# block comment should start with #
E265,
# too many leading # for block comment
E266,
# multiple spaces after keyword
E271,
# multiple spaces before keyword
E272,
# tab after keyword
E273,
# tab before keyword
E274,
# missing whitespace after keyword
E275,
# expected 1 blank line, found 0
E301,
# expected 2 blank lines, found 0
E302,
# too many blank lines (3)
E303,
# blank lines found after function decorator
E304,
# expected 2 blank lines after end of function or class
E305,
# expected 1 blank line before a nested definition
E306,
# multiple imports on one line
E401,
# line too long (> 79 characters)
E501,
# the backslash is redundant between brackets
E502,
# multiple statements on one line (colon)
E701,
# multiple statements on one line (semicolon)
E702,
# statement ends with a semicolon
E703,
# multiple statements on one line (def)
E704,
E101, # indentation contains mixed spaces and tabs
E111, # indentation is not a multiple of four
E112, # expected an indented block
E113, # unexpected indentation
E114, # indentation is not a multiple of four (comment)
E115, # expected an indented block (comment)
E116, # unexpected indentation (comment)
E121, # continuation line under-indented for hanging indent
E122, # continuation line missing indentation or outdented
E123, # closing bracket does not match indentation of opening brackets line
E124, # closing bracket does not match visual indentation
E125, # continuation line with same indent as next logical line
E126, # continuation line over-indented for hanging indent
E127, # continuation line over-indented for visual indent; is harmless
# (over-indent is visually unambiguous) and currently generates too
# many warnings for existing code.
E128, # continuation line under-indented for visual indent
E129, # visually indented line with same indent as next logical line
E131, # continuation line unaligned for hanging indent
E133, # closing bracket is missing indentation
E201, # whitespace after (
E202, # whitespace before )
E203, # whitespace before :; this warning is invalid for slices
E211, # whitespace before (
E221, # multiple spaces before operator
E222, # multiple spaces after operator
E223, # tab before operator
E224, # tab after operator
E225, # missing whitespace around operator
E226, # missing whitespace around arithmetic operator
E227, # missing whitespace around bitwise or shift operator
E228, # missing whitespace around modulo operator
E231, # missing whitespace after ,, ;, or :
E241, # multiple spaces after ,
E242, # tab after ,
E251, # unexpected spaces around keyword / parameter equals
E261, # at least two spaces before inline comment
E262, # inline comment should start with #
E265, # block comment should start with #
E266, # too many leading # for block comment
E271, # multiple spaces after keyword
E272, # multiple spaces before keyword
E273, # tab after keyword
E274, # tab before keyword
E275, # missing whitespace after keyword
E301, # expected 1 blank line, found 0
E302, # expected 2 blank lines, found 0
E303, # too many blank lines (3)
E304, # blank lines found after function decorator
E305, # expected 2 blank lines after end of function or class
E306, # expected 1 blank line before a nested definition
E401, # multiple imports on one line
E501, # line too long (> 79 characters)
E502, # the backslash is redundant between brackets
E701, # multiple statements on one line (colon)
E702, # multiple statements on one line (semicolon)
E703, # statement ends with a semicolon
E704, # multiple statements on one line (def)
# These are pycodestyle lints that black doesn't catch:
# E711, # comparison to None should be if cond is None:
# E712, # comparison to True should be if cond is True: or if cond:
@ -135,25 +78,16 @@ ignore =
# I think these are internal to pycodestyle?
# E901, # SyntaxError or IndentationError
# E902, # IOError
# isn't aware of type-only imports, results in false-positives
F811,
# indentation contains tabs
W191,
# trailing whitespace
W291,
# no newline at end of file
W292,
# blank line contains whitespace
W293,
# blank line at end of file
W391,
# line break before binary operator; binary operator in a new line is
# the standard
W503,
# line break after binary operator
W504,
# not part of PEP8; doc line too long (> 79 characters)
W505,
F811, # isn't aware of type-only imports, results in false-positives
W191, # indentation contains tabs
W291, # trailing whitespace
W292, # no newline at end of file
W293, # blank line contains whitespace
W391, # blank line at end of file
W503, # line break before binary operator; binary operator in a new line is
# the standard
W504, # line break after binary operator
W505, # not part of PEP8; doc line too long (> 79 characters)
# These are pycodestyle lints that black doesn't catch:
# W601, # .has_key() is deprecated, use in
# W602, # deprecated form of raising exception

View file

@ -1,31 +0,0 @@
[
{
"vers": "x86_64",
"os": "ubuntu-20.04"
},
{
"vers": "i686",
"os": "ubuntu-20.04"
},
{
"vers": "arm64",
"os": "macos-latest"
},
{
"vers": "auto64",
"os": "macos-latest"
},
{
"vers": "auto64",
"os": "windows-2019"
},
{
"vers": "aarch64",
"os": [
"self-hosted",
"linux",
"ARM64"
],
"on_ref_regex": "^refs/(heads/main|tags/.*)$"
}
]

View file

@ -1,18 +0,0 @@
# https://docs.github.com/en/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file
version: 2
updates:
- package-ecosystem: pip
directory: "/"
schedule:
interval: weekly
- package-ecosystem: cargo
directory: "/native"
schedule:
interval: weekly
- package-ecosystem: github-actions
directory: "/"
schedule:
interval: weekly

View file

@ -1,45 +1,283 @@
name: build
name: Python CI
on:
workflow_call:
push:
branches:
- main
pull_request:
jobs:
# Build python wheels
build:
name: Build wheels on ${{ matrix.os }}
# Run unittests
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os:
[
macos-latest,
ubuntu-latest,
ubuntu-24.04-arm,
windows-latest,
windows-11-arm,
]
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: [3.6, 3.7, 3.8, 3.9, "3.10"]
parser: [pure, native]
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python-version }}
- uses: actions/cache@v2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements.txt', 'requirements-dev.txt', 'setup.py') }}
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install --upgrade --upgrade-strategy eager build -r requirements.txt -r requirements-dev.txt
- if: ${{ matrix.parser == 'native' }}
uses: actions-rs/toolchain@v1
with:
toolchain: stable
- run: >-
echo LIBCST_PARSER_TYPE=${{ matrix.parser }} >> $GITHUB_ENV
- name: Run Tests
run: python setup.py test
# Run linters
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v2
with:
python-version: "3.10"
- uses: actions/cache@v2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements.txt', 'requirements-dev.txt', 'setup.py') }}
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install --upgrade --upgrade-strategy eager build -r requirements.txt -r requirements-dev.txt
- run: flake8
- run: ufmt check .
- run: python3 -m fixit.cli.run_rules
# Run pyre typechecker
typecheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v2
with:
python-version: "3.10"
- uses: actions/cache@v2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements.txt', 'requirements-dev.txt', 'setup.py') }}
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install --upgrade --upgrade-strategy eager build -r requirements.txt -r requirements-dev.txt
- name: Make sure Pyre uses the working copy
run: pip install -e .
- run: pyre --version
- run: pyre -n check
- run: python libcst/tests/test_pyre_integration.py
- run: git diff --exit-code
# Upload test coverage
coverage:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v2
with:
python-version: "3.10"
- uses: actions/cache@v2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements.txt', 'requirements-dev.txt', 'setup.py') }}
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install --upgrade --upgrade-strategy eager build -r requirements.txt -r requirements-dev.txt
- name: Generate Coverage
run: |
coverage run setup.py test
coverage xml -i
- uses: codecov/codecov-action@v2
with:
files: coverage.xml
fail_ci_if_error: true
verbose: true
- name: Archive Coverage
uses: actions/upload-artifact@v2
with:
name: coverage
path: coverage.xml
# Build the docs
docs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: actions/setup-python@v2
with:
python-version: "3.10"
- uses: actions/cache@v2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements.txt', 'requirements-dev.txt', 'setup.py') }}
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install --upgrade --upgrade-strategy eager build -r requirements.txt -r requirements-dev.txt
- uses: ts-graphviz/setup-graphviz@v1
- run: sphinx-build docs/source/ docs/build/
- name: Archive Docs
uses: actions/upload-artifact@v2
with:
name: sphinx-docs
path: docs/build
# Build python package
build:
name: Build wheels on ${{ matrix.os }}/${{ matrix.vers }}
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
include:
- vers: i686
os: ubuntu-20.04
# aarch64 seems to be stuck
# - vers: aarch64
# os: ubuntu-20.04
- vers: auto64
os: ubuntu-20.04
- vers: arm64
os: macos-10.15
- vers: auto64
os: macos-10.15
- vers: auto64
os: windows-2019
env:
SCCACHE_VERSION: 0.2.13
GITHUB_WORKSPACE: "${{github.workspace}}"
CIBW_BEFORE_ALL_LINUX: "curl https://sh.rustup.rs -sSf | sh -s -- --default-toolchain stable -y"
CIBW_BEFORE_ALL_MACOS: "rustup target add aarch64-apple-darwin x86_64-apple-darwin"
CIBW_BEFORE_ALL_WINDOWS: "rustup target add x86_64-pc-windows-msvc i686-pc-windows-msvc"
CIBW_ENVIRONMENT: 'PATH="$PATH:$HOME/.cargo/bin" LIBCST_NO_LOCAL_SCHEME=$LIBCST_NO_LOCAL_SCHEME'
CIBW_SKIP: "cp27-* cp34-* cp35-* pp* *-win32 *-win_arm64 *-musllinux_*"
CIBW_ARCHS: ${{ matrix.vers }}
CIBW_BUILD_VERBOSITY: 1
steps:
- uses: actions/checkout@v4
- uses: actions/checkout@v1
- uses: actions/setup-python@v2
with:
python-version: "3.10"
- uses: actions/cache@v2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements.txt', 'requirements-dev.txt', 'setup.py') }}
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install --upgrade --upgrade-strategy eager build -r requirements.txt -r requirements-dev.txt
- name: Disable scmtools local scheme
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
run: >-
echo LIBCST_NO_LOCAL_SCHEME=1 >> $GITHUB_ENV
- name: Build wheels
uses: pypa/cibuildwheel@v2.3.1
- uses: actions/upload-artifact@v2
with:
path: wheelhouse/*.whl
name: wheels
pypi:
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
name: Upload wheels to pypi
runs-on: ubuntu-latest
needs: build
steps:
- uses: actions/checkout@v1
- name: Download binary wheels
id: download
uses: actions/download-artifact@v2
with:
fetch-depth: 0
persist-credentials: false
- uses: actions/setup-python@v6
name: wheels
path: wheelhouse
- uses: actions/setup-python@v2
with:
python-version: "3.12"
- uses: dtolnay/rust-toolchain@stable
python-version: "3.10"
- uses: actions/cache@v2
id: cache
with:
path: ${{ env.pythonLocation }}
key: ${{ env.pythonLocation }}-${{ hashFiles('requirements.txt', 'requirements-dev.txt', 'setup.py') }}
- name: Install Dependencies
if: steps.cache.outputs.cache-hit != 'true'
run: |
pip install --upgrade --upgrade-strategy eager build -r requirements.txt -r requirements-dev.txt
- name: Disable scmtools local scheme
if: ${{ github.event_name == 'push' && github.ref == 'refs/heads/main' }}
run: >-
echo LIBCST_NO_LOCAL_SCHEME=1 >> $GITHUB_ENV
- name: Enable building wheels for pre-release CPython versions
if: github.event_name != 'release'
run: echo CIBW_ENABLE=cpython-prerelease >> $GITHUB_ENV
- name: Build wheels
uses: pypa/cibuildwheel@v3.2.1
- uses: actions/upload-artifact@v4
- name: Build a source tarball
run: >-
python -m
build
--sdist
--outdir ${{ steps.download.outputs.download-path }}
- name: Publish distribution 📦 to Test PyPI
uses: pypa/gh-action-pypi-publish@release/v1
with:
path: wheelhouse/*.whl
name: wheels-${{matrix.os}}
user: __token__
password: ${{ secrets.TEST_PYPI_API_TOKEN }}
repository_url: https://test.pypi.org/legacy/
packages_dir: ${{ steps.download.outputs.download-path }}
# Test rust parts
native:
name: Rust unit tests
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
toolchain: stable
components: rustfmt, clippy
- uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: test
uses: actions-rs/cargo@v1
with:
command: test
args: --manifest-path=native/Cargo.toml --release
- name: clippy
uses: actions-rs/clippy-check@v1
with:
token: ${{ secrets.GITHUB_TOKEN }}
args: --manifest-path=native/Cargo.toml --all-features
rustfmt:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- uses: actions-rs/toolchain@v1
with:
profile: minimal
toolchain: stable
override: true
- run: rustup component add rustfmt
- uses: actions-rs/cargo@v1
with:
command: fmt
args: --all --manifest-path=native/Cargo.toml -- --check

View file

@ -1,142 +0,0 @@
name: CI
on:
push:
branches:
- main
pull_request:
permissions: {}
jobs:
test:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [macos-latest, ubuntu-latest, windows-latest]
python-version:
- "3.9"
- "3.10"
- "3.11"
- "3.12"
- "3.13"
- "3.13t"
- "3.14"
- "3.14t"
steps:
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
version: "0.7.13"
python-version: ${{ matrix.python-version }}
- uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- uses: dtolnay/rust-toolchain@stable
- name: Build LibCST
run: uv sync --locked --dev
- name: Native Parser Tests
run: uv run poe test
- name: Coverage
run: uv run coverage report
# Run linters
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
version: "0.7.13"
python-version: "3.10"
- run: uv run poe lint
- run: uv run poe fixtures
# Run pyre typechecker
typecheck:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
version: "0.7.13"
python-version: "3.10"
- run: uv run poe typecheck
# Build the docs
docs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
version: "0.7.13"
python-version: "3.10"
- uses: ts-graphviz/setup-graphviz@v2
- run: uv run --group docs poe docs
- name: Archive Docs
uses: actions/upload-artifact@v4
with:
name: sphinx-docs
path: docs/build
# Test rust parts
native:
name: Rust unit tests
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest, macos-latest, windows-latest]
python-version: ["3.10", "3.13t"]
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt, clippy
- uses: actions/setup-python@v6
with:
python-version: ${{ matrix.python-version }}
- name: test
run: cargo test --manifest-path=native/Cargo.toml --release
- name: test without python
if: matrix.os == 'ubuntu-latest'
run: cargo test --manifest-path=native/Cargo.toml --release --no-default-features
- name: clippy
run: cargo clippy --manifest-path=native/Cargo.toml --all-targets --all-features
- name: compile-benchmarks
run: cargo bench --manifest-path=native/Cargo.toml --no-run
rustfmt:
name: Rustfmt
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
with:
persist-credentials: false
- uses: dtolnay/rust-toolchain@stable
with:
components: rustfmt
- run: rustup component add rustfmt
- name: format
run: cargo fmt --all --manifest-path=native/Cargo.toml -- --check
build:
# only trigger here for pull requests - regular pushes are handled in pypi_upload
if: ${{ github.event_name == 'pull_request' }}
uses: Instagram/LibCST/.github/workflows/build.yml@main

View file

@ -1,60 +0,0 @@
name: pypi_upload
on:
release:
types: [published]
push:
branches: [main]
permissions:
contents: read
jobs:
build:
uses: Instagram/LibCST/.github/workflows/build.yml@main
upload_release:
name: Upload wheels to pypi
runs-on: ubuntu-latest
needs: build
permissions:
id-token: write
steps:
- uses: actions/checkout@v4
with:
fetch-depth: 0
persist-credentials: false
- name: Download binary wheels
id: download
uses: actions/download-artifact@v5
with:
pattern: wheels-*
path: wheelhouse
merge-multiple: true
- uses: actions/setup-python@v6
with:
python-version: "3.10"
- name: Install uv
uses: astral-sh/setup-uv@v7
with:
version: "0.7.13"
enable-cache: false
- name: Build a source tarball
env:
LIBCST_NO_LOCAL_SCHEME: 1
OUTDIR: ${{ steps.download.outputs.download-path }}
run: >-
uv run python -m
build
--sdist
--outdir "$OUTDIR"
- name: Publish distribution 📦 to Test PyPI
if: github.event_name == 'push'
uses: pypa/gh-action-pypi-publish@release/v1
with:
repository-url: https://test.pypi.org/legacy/
packages-dir: ${{ steps.download.outputs.download-path }}
- name: Publish distribution 📦 to PyPI
if: github.event_name == 'release'
uses: pypa/gh-action-pypi-publish@release/v1
with:
packages-dir: ${{ steps.download.outputs.download-path }}

View file

@ -1,35 +0,0 @@
name: GitHub Actions Security Analysis with zizmor 🌈
on:
push:
branches: ["main"]
pull_request:
branches: ["**"]
jobs:
zizmor:
name: zizmor latest via PyPI
runs-on: ubuntu-latest
permissions:
security-events: write
contents: read
actions: read
steps:
- name: Checkout repository
uses: actions/checkout@v4
with:
persist-credentials: false
- name: Install the latest version of uv
uses: astral-sh/setup-uv@v7
- name: Run zizmor 🌈
run: uvx zizmor --format sarif . > results.sarif
env:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: Upload SARIF file
uses: github/codeql-action/upload-sarif@v4
with:
sarif_file: results.sarif
category: zizmor

4
.gitignore vendored
View file

@ -1,7 +1,6 @@
*.swp
*.swo
*.pyc
*.pyd
*.pyo
*.so
*.egg-info/
@ -18,6 +17,3 @@ libcst/_version.py
.hypothesis/
.python-version
target/
venv/
.venv/
.idea/

View file

@ -2,9 +2,6 @@
"exclude": [
".*\/native\/.*"
],
"ignore_all_errors": [
".venv"
],
"source_directories": [
"."
],

View file

@ -5,18 +5,12 @@ sphinx:
formats: all
build:
os: ubuntu-20.04
tools:
python: "3"
rust: "1.70"
apt_packages:
- graphviz
python:
version: 3.7
install:
- requirements: requirements.txt
- requirements: requirements-dev.txt
- method: pip
path: .
extra_requirements:
- dev
system_packages: true

File diff suppressed because it is too large Load diff

View file

@ -1,80 +1,5 @@
# Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to make participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies within all project spaces, and it also applies when
an individual is representing the project or its community in public spaces.
Examples of representing a project or community include using an official
project e-mail address, posting via an official social media account, or acting
as an appointed representative at an online or offline event. Representation of
a project may be further defined and clarified by project maintainers.
This Code of Conduct also applies outside the project spaces when there is a
reasonable belief that an individual's behavior may have a negative impact on
the project or its community.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at <opensource-conduct@fb.com>. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq
Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.

View file

@ -9,32 +9,12 @@ pull requests.
## Pull Requests
We actively welcome your pull requests.
### Setup Your Environment
1. Install a [Rust toolchain](https://rustup.rs) and [uv](https://docs.astral.sh/uv/)
2. Fork the repo on your side
3. Clone the repo
> git clone [your fork.git] libcst
> cd libcst
4. Sync with the main libcst version package
> git fetch --tags https://github.com/instagram/libcst
5. Setup the env
> uv sync
You are now ready to create your own branch from main, and contribute.
Please provide tests (using unittest), and update the documentation (both docstrings
and sphinx doc), if applicable.
### Before Submitting Your Pull Request
1. Format your code
> uv run poe format
2. Run the type checker
> uv run poe typecheck
3. Test your changes
> uv run poe test
4. Check linters
> uv run poe lint
1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes by `python -m unittest`.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need

View file

@ -13,8 +13,8 @@ PSF). These files are:
- libcst/_parser/parso/tests/test_fstring.py
- libcst/_parser/parso/tests/test_tokenize.py
- libcst/_parser/parso/tests/test_utils.py
- native/libcst/src/tokenizer/core/mod.rs
- native/libcst/src/tokenizer/core/string_types.rs
- libcst_native/src/tokenize/core/mod.rs
- libcst_native/src/tokenize/core/string_types.rs
Some Python files have been taken from dataclasses and are therefore Apache
licensed. Modifications on these files are licensed under Apache 2.0 license.

View file

@ -1,12 +0,0 @@
# How to make a new release
1. Add a new entry to `CHANGELOG.md` (I normally use the [new release page](https://github.com/Instagram/LibCST/releases/new) to generate a changelog, then manually group)
1. Follow the existing format: `Fixed`, `Added`, `Updated`, `Deprecated`, `Removed`, `New Contributors` sections, and the full changelog link at the bottom.
1. Mention only user-visible changes - improvements to CI, tests, or development workflow aren't noteworthy enough
1. Version bumps are generally not worth mentioning with some notable exceptions (like pyo3)
1. Group related PRs into one bullet point if it makes sense
2. manually bump versions in `Cargo.toml` files in the repo
3. run `cargo update -p libcst`
4. make a new PR with the above changes, get it reviewed and landed
5. make a new release on Github, create a new tag on publish, and copy the contents of the changelog entry in there
6. after publishing, check out the repo at the new tag, and run `cd native; cargo +nightly publish -Z package-workspace -p libcst_derive -p libcst`

View file

@ -1,5 +1,4 @@
include README.rst LICENSE CODE_OF_CONDUCT.md CONTRIBUTING.md docs/source/*.rst libcst/py.typed
include README.rst LICENSE CODE_OF_CONDUCT.md CONTRIBUTING.md requirements.txt requirements-dev.txt docs/source/*.rst libcst/py.typed
include native/Cargo.toml
recursive-include native *
recursive-exclude native/target *

View file

@ -4,13 +4,9 @@
A Concrete Syntax Tree (CST) parser and serializer library for Python
|support-ukraine| |readthedocs-badge| |ci-badge| |pypi-badge| |pypi-download| |notebook-badge| |types-badge|
|readthedocs-badge| |ci-badge| |codecov-badge| |pypi-badge| |pypi-download| |notebook-badge|
.. |support-ukraine| image:: https://img.shields.io/badge/Support-Ukraine-FFD500?style=flat&labelColor=005BBB
:alt: Support Ukraine - Help Provide Humanitarian Aid to Ukraine.
:target: https://opensource.fb.com/support-ukraine
.. |readthedocs-badge| image:: https://readthedocs.org/projects/libcst/badge/?version=latest&style=flat
.. |readthedocs-badge| image:: https://readthedocs.org/projects/pip/badge/?version=latest&style=flat
:target: https://libcst.readthedocs.io/en/latest/
:alt: Documentation
@ -18,6 +14,10 @@ A Concrete Syntax Tree (CST) parser and serializer library for Python
:target: https://github.com/Instagram/LibCST/actions/workflows/build.yml?query=branch%3Amain
:alt: Github Actions
.. |codecov-badge| image:: https://codecov.io/gh/Instagram/LibCST/branch/main/graph/badge.svg
:target: https://codecov.io/gh/Instagram/LibCST/branch/main
:alt: CodeCov
.. |pypi-badge| image:: https://img.shields.io/pypi/v/libcst.svg
:target: https://pypi.org/project/libcst
:alt: PYPI
@ -31,13 +31,9 @@ A Concrete Syntax Tree (CST) parser and serializer library for Python
:target: https://mybinder.org/v2/gh/Instagram/LibCST/main?filepath=docs%2Fsource%2Ftutorial.ipynb
:alt: Notebook
.. |types-badge| image:: https://img.shields.io/pypi/types/libcst
:target: https://pypi.org/project/libcst
:alt: PYPI - Types
.. intro-start
LibCST parses Python 3.0 -> 3.14 source code as a CST tree that keeps
LibCST parses Python 3.0, 3.1, 3.3, 3.5, 3.6, 3.7 or 3.8 source code as a CST tree that keeps
all formatting details (comments, whitespaces, parentheses, etc). It's useful for
building automated refactoring (codemod) applications and linters.
@ -62,9 +58,7 @@ Example expression::
1 + 2
CST representation:
.. code-block:: python
CST representation::
BinaryOperation(
left=Integer(
@ -127,7 +121,7 @@ For a more detailed usage example, `see our documentation
Installation
------------
LibCST requires Python 3.9+ and can be easily installed using most common Python
LibCST requires Python 3.6+ and can be easily installed using most common Python
packaging tools. We recommend installing the latest stable release from
`PyPI <https://pypi.org/project/libcst/>`_ with pip:
@ -135,11 +129,6 @@ packaging tools. We recommend installing the latest stable release from
pip install libcst
For parsing, LibCST ships with a native extension, so releases are distributed as binary
wheels as well as the source code. If a binary wheel is not available for your system
(Linux/Windows x86/x64 and Mac x64/arm are covered), you'll need a recent
`Rust toolchain <https://rustup.rs>`_ for installing.
Further Reading
---------------
- `Static Analysis at Scale: An Instagram Story. <https://instagram-engineering.com/static-analysis-at-scale-an-instagram-story-8f498ab71a0c>`_
@ -148,8 +137,42 @@ Further Reading
Development
-----------
See `CONTRIBUTING.md <CONTRIBUTING.md>`_ for more details.
Start by setting up and activating a virtualenv:
.. code-block:: shell
git clone git@github.com:Instagram/LibCST.git libcst
cd libcst
python3 -m venv ../libcst-env/ # just an example, put this wherever you want
source ../libcst-env/bin/activate
pip install --upgrade pip # optional, if you have an old system version of pip
pip install -r requirements.txt -r requirements-dev.txt
# If you're done with the virtualenv, you can leave it by running:
deactivate
We use `ufmt <https://ufmt.omnilib.dev/en/stable/>`_ to format code. To format
changes to be conformant, run the following in the root:
.. code-block:: shell
ufmt format && python -m fixit.cli.apply_fix
To run all tests, you'll need to do the following in the root:
.. code-block:: shell
python -m unittest
You can also run individual tests by using unittest and specifying a module like
this:
.. code-block:: shell
python -m unittest libcst.tests.test_batched_visitor
See the `unittest documentation <https://docs.python.org/3/library/unittest.html>`_
for more examples of how to run tests.
Building
~~~~~~~~
@ -166,11 +189,13 @@ directory:
cargo build
The ``libcst.native`` module should be rebuilt automatically, but to force it:
To build the ``libcst.native`` module and install ``libcst``, run this
from the root:
.. code-block:: shell
uv sync --reinstall-package libcst
pip uninstall -y libcst
pip install -e .
Type Checking
~~~~~~~~~~~~~
@ -181,7 +206,10 @@ To verify types for the library, do the following in the root:
.. code-block:: shell
uv run poe typecheck
pyre check
*Note:* You may need to run the ``pip install -e .`` command prior
to type checking, see the section above on building.
Generating Documents
~~~~~~~~~~~~~~~~~~~~
@ -190,7 +218,7 @@ To generate documents, do the following in the root:
.. code-block:: shell
uv run --group docs poe docs
sphinx-build docs/source/ docs/build/
Future
======

View file

@ -1,2 +0,0 @@
rustc
cargo

4
codecov.yml Normal file
View file

@ -0,0 +1,4 @@
coverage:
status:
project: no
patch: yes

View file

@ -26,7 +26,7 @@ then edit the produced ``.libcst.codemod.yaml`` file::
python3 -m libcst.tool initialize .
The file includes provisions for customizing any generated code marker, calling an
external code formatter such as `black <https://pypi.org/project/black/>`_, blacklisting
external code formatter such as `black <https://pypi.org/project/black/>`_, blackisting
patterns of files you never wish to touch and a list of modules that contain valid
codemods that can be executed. If you want to write and run codemods specific to your
repository or organization, you can add an in-repo module location to the list of
@ -135,18 +135,16 @@ replaces any string which matches our string command-line argument with a consta
It also takes care of adding the import required for the constant to be defined properly.
Cool! Let's look at the command-line help for this codemod. Let's assume you saved it
as ``constant_folding.py``. You can get help for the
as ``constant_folding.py`` inside ``libcst.codemod.commands``. You can get help for the
codemod by running the following command::
python3 -m libcst.tool codemod -x constant_folding.ConvertConstantCommand --help
python3 -m libcst.tool codemod constant_folding.ConvertConstantCommand --help
Notice that along with the default arguments, the ``--string`` and ``--constant``
arguments are present in the help, and the command-line description has been updated
with the codemod's description string. You'll notice that the codemod also shows up
on ``libcst.tool list``.
And ``-x`` flag allows to load any module as a codemod in addition to the standard ones.
----------------
Testing Codemods
----------------

View file

@ -71,7 +71,7 @@ master_doc = "index"
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = "en"
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
@ -196,7 +196,6 @@ intersphinx_mapping = {"python": ("https://docs.python.org/3", None)}
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
# -- autodoc customization
def strip_class_signature(app, what, name, obj, options, signature, return_annotation):
if what == "class":
@ -219,7 +218,7 @@ def setup(app):
nbsphinx_prolog = r"""
{% set docname = 'docs/source/' + env.doc2path(env.docname, base=None)|string%}
{% set docname = 'docs/source/' + env.doc2path(env.docname, base=None) %}
.. only:: html

View file

@ -32,18 +32,3 @@ Functions that assist in traversing an existing LibCST tree.
.. autofunction:: libcst.helpers.get_full_name_for_node
.. autofunction:: libcst.helpers.get_full_name_for_node_or_raise
.. autofunction:: libcst.helpers.ensure_type
Node fields filtering Helpers
-----------------------------
Function that assist when handling CST nodes' fields.
.. autofunction:: libcst.helpers.filter_node_fields
And lower level functions:
.. autofunction:: libcst.helpers.get_node_fields
.. autofunction:: libcst.helpers.is_whitespace_node_field
.. autofunction:: libcst.helpers.is_syntax_node_field
.. autofunction:: libcst.helpers.is_default_node_field
.. autofunction:: libcst.helpers.get_field_default_value

View file

@ -18,10 +18,10 @@ numbers of nodes through the :class:`~libcst.metadata.PositionProvider`:
.. code-block:: python
class NamePrinter(cst.CSTVisitor):
METADATA_DEPENDENCIES = (cst.metadata.PositionProvider,)
METADATA_DEPENDENCIES = (cst.PositionProvider,)
def visit_Name(self, node: cst.Name) -> None:
pos = self.get_metadata(cst.metadata.PositionProvider, node).start
pos = self.get_metadata(cst.PositionProvider, node).start
print(f"{node.value} found at line {pos.line}, column {pos.column}")
wrapper = cst.metadata.MetadataWrapper(cst.parse_module("x = 1"))
@ -94,7 +94,7 @@ declaring one of :class:`~libcst.metadata.PositionProvider` or
most cases, :class:`~libcst.metadata.PositionProvider` is what you probably
want.
Node positions are represented with :class:`~libcst.metadata.CodeRange`
Node positions are is represented with :class:`~libcst.metadata.CodeRange`
objects. See :ref:`the above example<libcst-metadata-position-example>`.
.. autoclass:: libcst.metadata.PositionProvider
@ -134,7 +134,7 @@ New scopes are created for classes, functions, and comprehensions. Other block
constructs like conditional statements, loops, and try…except don't create their
own scope.
There are five different types of scopes in Python:
There are five different type of scope in Python:
:class:`~libcst.metadata.BuiltinScope`,
:class:`~libcst.metadata.GlobalScope`,
:class:`~libcst.metadata.ClassScope`,
@ -226,14 +226,6 @@ We provide :class:`~libcst.metadata.ParentNodeProvider` for those use cases.
.. autoclass:: libcst.metadata.ParentNodeProvider
:no-undoc-members:
File Path Metadata
------------------
This provides the absolute file path on disk for any module being visited.
Requires an active :class:`~libcst.metadata.FullRepoManager` when using this provider.
.. autoclass:: libcst.metadata.FilePathProvider
:no-undoc-members:
Type Inference Metadata
-----------------------
`Type inference <https://en.wikipedia.org/wiki/Type_inference>`__ is to automatically infer
@ -242,8 +234,8 @@ In Python, type checkers like `Mypy <https://github.com/python/mypy>`_ or
`Pyre <https://pyre-check.org/>`__ analyze `type annotations <https://docs.python.org/3/library/typing.html>`__
and infer types for expressions.
:class:`~libcst.metadata.TypeInferenceProvider` is provided by `Pyre Query API <https://pyre-check.org/docs/querying-pyre.html>`__
which requires `setup watchman <https://pyre-check.org/docs/getting-started/>`_ for incremental typechecking.
:class:`~libcst.metadata.FullRepoManager` is built for manage the inter process communication to Pyre.
which requires `setup watchman <https://pyre-check.org/docs/watchman-integration.html>`_ for incremental typechecking.
:class:`~libcst.metadata.FullRepoManger` is built for manage the inter process communication to Pyre.
.. autoclass:: libcst.metadata.TypeInferenceProvider
:no-undoc-members:

View file

@ -90,7 +90,7 @@
"source": [
"Warn on unused imports and undefined references\n",
"===============================================\n",
"To find all unused imports, we iterate through :attr:`~libcst.metadata.Scope.assignments` and an assignment is unused when its :attr:`~libcst.metadata.BaseAssignment.references` is empty. To find all undefined references, we iterate through :attr:`~libcst.metadata.Scope.accesses` (we focus on :class:`~libcst.Import`/:class:`~libcst.ImportFrom` assignments) and an access is undefined reference when its :attr:`~libcst.metadata.Access.referents` is empty. When reporting the warning to the developer, we'll want to report the line number and column offset along with the suggestion to make it more clear. We can get position information from :class:`~libcst.metadata.PositionProvider` and print the warnings as follows.\n"
"To find all unused imports, we iterate through :attr:`~libcst.metadata.Scope.assignments` and an assignment is unused when its :attr:`~libcst.metadata.BaseAssignment.references` is empty. To find all undefined references, we iterate through :attr:`~libcst.metadata.Scope.accesses` (we focus on :class:`~libcst.Import`/:class:`~libcst.ImportFrom` assignments) and an access is undefined reference when its :attr:`~libcst.metadata.Access.referents` is empty. When reporting the warning to developer, we'll want to report the line number and column offset along with the suggestion to make it more clear. We can get position information from :class:`~libcst.metadata.PositionProvider` and print the warnings as follows.\n"
]
},
{
@ -136,13 +136,13 @@
"Automatically Remove Unused Import\n",
"==================================\n",
"Unused import is a commmon code suggestion provided by lint tool like `flake8 F401 <https://lintlyci.github.io/Flake8Rules/rules/F401.html>`_ ``imported but unused``.\n",
"Even though reporting unused imports is already useful, with LibCST we can provide an automatic fix to remove unused imports. That can make the suggestion more actionable and save developer's time.\n",
"Even though reporting unused import is already useful, with LibCST we can provide automatic fix to remove unused import. That can make the suggestion more actionable and save developer's time.\n",
"\n",
"An import statement may import multiple names, we want to remove those unused names from the import statement. If all the names in the import statement are not used, we remove the entire import.\n",
"To remove the unused name, we implement ``RemoveUnusedImportTransformer`` by subclassing :class:`~libcst.CSTTransformer`. We overwrite ``leave_Import`` and ``leave_ImportFrom`` to modify the import statements.\n",
"When we find the import node in the lookup table, we iterate through all ``names`` and keep used names in ``names_to_keep``.\n",
"When we find the import node in lookup table, we iterate through all ``names`` and keep used names in ``names_to_keep``.\n",
"If ``names_to_keep`` is empty, all names are unused and we remove the entire import node.\n",
"Otherwise, we update the import node and just remove partial names."
"Otherwise, we update the import node and just removing partial names."
]
},
{
@ -195,7 +195,7 @@
"raw_mimetype": "text/restructuredtext"
},
"source": [
"After the transform, we use ``.code`` to generate the fixed code and all unused names are fixed as expected! The difflib is used to show only the changed part and only imported lines are updated as expected."
"After the transform, we use ``.code`` to generate fixed code and all unused names are fixed as expected! The difflib is used to show only changed part and only import lines are updated as expected."
]
},
{

View file

@ -1,25 +1,24 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {
"raw_mimetype": "text/restructuredtext"
},
"cell_type": "raw",
"source": [
"====================\n",
"Parsing and Visiting\n",
"====================\n",
"\n",
"LibCST provides helpers to parse source code string as a concrete syntax tree. In order to perform static analysis to identify patterns in the tree or modify the tree programmatically, we can use the visitor pattern to traverse the tree. In this tutorial, we demonstrate a common four-step-workflow to build an automated refactoring (codemod) application:\n",
"LibCST provides helpers to parse source code string as concrete syntax tree. In order to perform static analysis to identify patterns in the tree or modify the tree programmatically, we can use visitor pattern to traverse the tree. In this tutorial, we demonstrate a common three-step-workflow to build an automated refactoring (codemod) application:\n",
"\n",
"1. `Parse Source Code <#Parse-Source-Code>`_\n",
"2. `Display The Source Code CST <#Display-Source-Code-CST>`_\n",
"3. `Build Visitor or Transformer <#Build-Visitor-or-Transformer>`_\n",
"4. `Generate Source Code <#Generate-Source-Code>`_\n",
"2. `Build Visitor or Transformer <#Build-Visitor-or-Transformer>`_\n",
"3. `Generate Source Code <#Generate-Source-Code>`_\n",
"\n",
"Parse Source Code\n",
"=================\n",
"LibCST provides various helpers to parse source code as a concrete syntax tree: :func:`~libcst.parse_module`, :func:`~libcst.parse_expression` and :func:`~libcst.parse_statement` (see :doc:`Parsing <parser>` for more detail)."
"LibCST provides various helpers to parse source code as concrete syntax tree: :func:`~libcst.parse_module`, :func:`~libcst.parse_expression` and :func:`~libcst.parse_statement` (see :doc:`Parsing <parser>` for more detail). The default :class:`~libcst.CSTNode` repr provides pretty print formatting for reading the tree easily."
]
},
{
@ -42,42 +41,7 @@
"source": [
"import libcst as cst\n",
"\n",
"source_tree = cst.parse_expression(\"1 + 2\")"
]
},
{
"metadata": {
"raw_mimetype": "text/restructuredtext"
},
"cell_type": "raw",
"source": [
"|\n",
"Display Source Code CST\n",
"=======================\n",
"The default :class:`~libcst.CSTNode` repr provides pretty print formatting for displaying the entire CST tree."
]
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": "print(source_tree)"
},
{
"metadata": {},
"cell_type": "raw",
"source": "The entire CST tree may be overwhelming at times. To only focus on essential elements of the CST tree, LibCST provides the ``dump`` helper."
},
{
"metadata": {},
"cell_type": "code",
"outputs": [],
"execution_count": null,
"source": [
"from libcst.display import dump\n",
"\n",
"print(dump(source_tree))"
"cst.parse_expression(\"1 + 2\")"
]
},
{
@ -86,11 +50,9 @@
"raw_mimetype": "text/restructuredtext"
},
"source": [
" \n",
"|\n",
"Example: add typing annotation from pyi stub file to Python source\n",
"------------------------------------------------------------------\n",
"Python `typing annotation <https://mypy.readthedocs.io/en/latest/cheat_sheet_py3.html>`_ was added in Python 3.5. Some Python applications add typing annotations in separate ``pyi`` stub files in order to support old Python versions. When applications decide to stop supporting old Python versions, they'll want to automatically copy the type annotation from a pyi file to a source file. Here we demonstrate how to do that easily using LibCST. The first step is to parse the pyi stub and source files as trees."
"Python `typing annotation <https://mypy.readthedocs.io/en/latest/cheat_sheet_py3.html>`_ was added in Python 3.5. Some Python applications add typing annotations in separate ``pyi`` stub files in order to support old Python versions. When applications decide to stop supporting old Python versions, they'll want to automatically copy the type annotation from a pyi file to a source file. Here we demonstrate how to do that easliy using LibCST. The first step is to parse the pyi stub and source files as trees."
]
},
{
@ -106,7 +68,7 @@
" self._replace(type=self.type.name))\n",
"\n",
"def tokenize(code, version_info, start_pos=(1, 0)):\n",
" \"\"\"Generate tokens from the source code (string).\"\"\"\n",
" \"\"\"Generate tokens from a the source code (string).\"\"\"\n",
" lines = split_lines(code, keepends=True)\n",
" return tokenize_lines(lines, version_info, start_pos=start_pos)\n",
"'''\n",
@ -130,11 +92,10 @@
"raw_mimetype": "text/restructuredtext"
},
"source": [
"|\n",
"Build Visitor or Transformer\n",
"============================\n",
"For traversing and modifying the tree, LibCST provides Visitor and Transformer classes similar to the `ast module <https://docs.python.org/3/library/ast.html#ast.NodeVisitor>`_. To implement a visitor (read only) or transformer (read/write), simply implement a subclass of :class:`~libcst.CSTVisitor` or :class:`~libcst.CSTTransformer` (see :doc:`Visitors <visitors>` for more detail).\n",
"In the typing example, we need to implement a visitor to collect typing annotations from the stub tree and a transformer to copy the annotation to the function signature. In the visitor, we implement ``visit_FunctionDef`` to collect annotations. Later in the transformer, we implement ``leave_FunctionDef`` to add the collected annotations."
"In the typing example, we need to implement a visitor to collect typing annotation from the stub tree and a transformer to copy the annotation to the function signature. In the visitor, we implement ``visit_FunctionDef`` to collect annotations. Later in the transformer, we implement ``leave_FunctionDef`` to add the collected annotations."
]
},
{
@ -223,10 +184,9 @@
"raw_mimetype": "text/restructuredtext"
},
"source": [
"|\n",
"Generate Source Code\n",
"====================\n",
"Generating the source code from a cst tree is as easy as accessing the :attr:`~libcst.Module.code` attribute on :class:`~libcst.Module`. After the code generation, we often use `ufmt <https://ufmt.omnilib.dev/en/stable/>`_ to reformat the code to keep a consistent coding style."
"Generating the source code from a cst tree is as easy as accessing the :attr:`~libcst.Module.code` attribute on :class:`~libcst.Module`. After the code generation, we often use `ufmt <https://ufmt.omnilib.dev/en/stable/>`_ to reformate the code to keep a consistent coding style."
]
},
{

View file

@ -4,7 +4,7 @@
# LICENSE file in the root directory of this source tree.
from libcst._batched_visitor import BatchableCSTVisitor, visit_batched
from libcst._exceptions import CSTLogicError, MetadataException, ParserSyntaxError
from libcst._exceptions import MetadataException, ParserSyntaxError
from libcst._flatten_sentinel import FlattenSentinel
from libcst._maybe_sentinel import MaybeSentinel
from libcst._metadata_dependent import MetadataDependent
@ -29,7 +29,6 @@ from libcst._nodes.expression import (
BaseSimpleComp,
BaseSlice,
BaseString,
BaseTemplatedStringContent,
BinaryOperation,
BooleanOperation,
Call,
@ -76,9 +75,6 @@ from libcst._nodes.expression import (
StarredElement,
Subscript,
SubscriptElement,
TemplatedString,
TemplatedStringExpression,
TemplatedStringText,
Tuple,
UnaryOperation,
Yield,
@ -187,7 +183,6 @@ from libcst._nodes.statement import (
MatchValue,
NameItem,
Nonlocal,
ParamSpec,
Pass,
Raise,
Return,
@ -195,11 +190,6 @@ from libcst._nodes.statement import (
SimpleStatementSuite,
Try,
TryStar,
TypeAlias,
TypeParam,
TypeParameters,
TypeVar,
TypeVarTuple,
While,
With,
WithItem,
@ -246,7 +236,6 @@ __all__ = [
"CSTVisitorT",
"FlattenSentinel",
"MaybeSentinel",
"CSTLogicError",
"MetadataException",
"ParserSyntaxError",
"PartialParserConfig",
@ -272,7 +261,6 @@ __all__ = [
"BaseElement",
"BaseExpression",
"BaseFormattedStringContent",
"BaseTemplatedStringContent",
"BaseList",
"BaseNumber",
"BaseSet",
@ -296,9 +284,6 @@ __all__ = [
"FormattedString",
"FormattedStringExpression",
"FormattedStringText",
"TemplatedString",
"TemplatedStringText",
"TemplatedStringExpression",
"From",
"GeneratorExp",
"IfExp",
@ -453,10 +438,4 @@ __all__ = [
"VisitorMetadataProvider",
"MetadataDependent",
"MetadataWrapper",
"TypeVar",
"TypeVarTuple",
"ParamSpec",
"TypeParam",
"TypeParameters",
"TypeAlias",
]

View file

@ -1,10 +1,8 @@
# This file is derived from github.com/ericvsmith/dataclasses, and is Apache 2 licensed.
# https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f188f452/LICENSE.txt
# https://github.com/ericvsmith/dataclasses/blob/ae712dd993420d43444f/dataclass_tools.py
# Changed: takes slots in base classes into account when creating slots
import dataclasses
from itertools import chain, filterfalse
from typing import Any, Mapping, Type, TypeVar
_T = TypeVar("_T")
@ -21,14 +19,7 @@ def add_slots(cls: Type[_T]) -> Type[_T]:
# Create a new dict for our new class.
cls_dict = dict(cls.__dict__)
field_names = tuple(f.name for f in dataclasses.fields(cls))
inherited_slots = set(
chain.from_iterable(
superclass.__dict__.get("__slots__", ()) for superclass in cls.mro()
)
)
cls_dict["__slots__"] = tuple(
filterfalse(inherited_slots.__contains__, field_names)
)
cls_dict["__slots__"] = field_names
for field_name in field_names:
# Remove our attributes, if present. They'll still be
# available in _MARKER.
@ -38,10 +29,19 @@ def add_slots(cls: Type[_T]) -> Type[_T]:
# Create the class.
qualname = getattr(cls, "__qualname__", None)
# pyre-fixme[9]: cls has type `Type[Variable[_T]]`; used as `_T`.
# pyre-fixme[19]: Expected 0 positional arguments.
cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
try:
# GenericMeta in py3.6 requires us to track __orig_bases__. This is fixed in py3.7
# by the removal of GenericMeta. We should just be able to use cls.__bases__ in the
# future.
bases = getattr(cls, "__orig_bases__", cls.__bases__)
# pyre-fixme[9]: cls has type `Type[Variable[_T]]`; used as `_T`.
# pyre-fixme[19]: Expected 0 positional arguments.
cls = type(cls)(cls.__name__, bases, cls_dict)
except TypeError:
# We're in py3.7 and should use cls.__bases__
# pyre-fixme[9]: cls has type `Type[Variable[_T]]`; used as `_T`.
# pyre-fixme[19]: Expected 0 positional arguments.
cls = type(cls)(cls.__name__, cls.__bases__, cls_dict)
if qualname is not None:
cls.__qualname__ = qualname
@ -50,14 +50,12 @@ def add_slots(cls: Type[_T]) -> Type[_T]:
def __getstate__(self: object) -> Mapping[str, Any]:
return {
field.name: getattr(self, field.name)
for field in dataclasses.fields(self)
if hasattr(self, field.name)
slot: getattr(self, slot) for slot in self.__slots__ if hasattr(self, slot)
}
def __setstate__(self: object, state: Mapping[str, Any]) -> None:
for fieldname, value in state.items():
object.__setattr__(self, fieldname, value)
for slot, value in state.items():
object.__setattr__(self, slot, value)
cls.__getstate__ = __getstate__
cls.__setstate__ = __setstate__

View file

@ -4,11 +4,18 @@
# LICENSE file in the root directory of this source tree.
from enum import auto, Enum
from typing import Any, Callable, final, Optional, Sequence, Tuple
from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union
from typing_extensions import final
from libcst._parser.parso.pgen2.generator import ReservedString
from libcst._parser.parso.python.token import PythonTokenTypes, TokenType
from libcst._parser.types.token import Token
from libcst._tabs import expand_tabs
_EOF_STR: str = "end of file (EOF)"
_INDENT_STR: str = "an indent"
_DEDENT_STR: str = "a dedent"
_NEWLINE_CHARS: str = "\r\n"
@ -16,10 +23,42 @@ class EOFSentinel(Enum):
EOF = auto()
class CSTLogicError(Exception):
"""General purpose internal error within LibCST itself."""
def get_expected_str(
encountered: Union[Token, EOFSentinel],
expected: Union[Iterable[Union[TokenType, ReservedString]], EOFSentinel],
) -> str:
if (
isinstance(encountered, EOFSentinel)
or encountered.type is PythonTokenTypes.ENDMARKER
):
encountered_str = _EOF_STR
elif encountered.type is PythonTokenTypes.INDENT:
encountered_str = _INDENT_STR
elif encountered.type is PythonTokenTypes.DEDENT:
encountered_str = _DEDENT_STR
else:
encountered_str = repr(encountered.string)
pass
if isinstance(expected, EOFSentinel):
expected_names = [_EOF_STR]
else:
expected_names = sorted(
[
repr(el.name) if isinstance(el, TokenType) else repr(el.value)
for el in expected
]
)
if len(expected_names) > 10:
# There's too many possibilities, so it's probably not useful to list them.
# Instead, let's just abbreviate the message.
return f"Unexpectedly encountered {encountered_str}."
else:
if len(expected_names) == 1:
expected_str = expected_names[0]
else:
expected_str = f"{', '.join(expected_names[:-1])}, or {expected_names[-1]}"
return f"Encountered {encountered_str}, but expected {expected_str}."
# pyre-fixme[2]: 'Any' type isn't pyre-strict.

View file

@ -7,17 +7,14 @@ import inspect
from abc import ABC
from contextlib import contextmanager
from typing import (
Callable,
cast,
ClassVar,
Collection,
Generic,
Iterator,
Mapping,
Type,
TYPE_CHECKING,
TypeVar,
Union,
)
if TYPE_CHECKING:
@ -32,28 +29,7 @@ if TYPE_CHECKING:
_T = TypeVar("_T")
class _UNDEFINED_DEFAULT:
pass
class LazyValue(Generic[_T]):
"""
The class for implementing a lazy metadata loading mechanism that improves the
performance when retriving expensive metadata (e.g., qualified names). Providers
including :class:`~libcst.metadata.QualifiedNameProvider` use this class to load
the metadata of a certain node lazily when calling
:func:`~libcst.MetadataDependent.get_metadata`.
"""
def __init__(self, callable: Callable[[], _T]) -> None:
self.callable = callable
self.return_value: Union[_T, Type[_UNDEFINED_DEFAULT]] = _UNDEFINED_DEFAULT
def __call__(self) -> _T:
if self.return_value is _UNDEFINED_DEFAULT:
self.return_value = self.callable()
return cast(_T, self.return_value)
_UNDEFINED_DEFAULT = object()
class MetadataDependent(ABC):
@ -131,9 +107,6 @@ class MetadataDependent(ABC):
)
if default is not _UNDEFINED_DEFAULT:
value = self.metadata[key].get(node, default)
return cast(_T, self.metadata[key].get(node, default))
else:
value = self.metadata[key][node]
if isinstance(value, LazyValue):
value = value()
return cast(_T, value)
return cast(_T, self.metadata[key][node])

View file

@ -6,9 +6,8 @@
from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass, field, fields, replace
from typing import Any, cast, ClassVar, Dict, List, Mapping, Sequence, TypeVar, Union
from typing import Any, cast, Dict, List, Mapping, Sequence, TypeVar, Union
from libcst import CSTLogicError
from libcst._flatten_sentinel import FlattenSentinel
from libcst._nodes.internal import CodegenState
from libcst._removal_sentinel import RemovalSentinel
@ -110,8 +109,6 @@ def _clone(val: object) -> object:
@dataclass(frozen=True)
class CSTNode(ABC):
__slots__: ClassVar[Sequence[str]] = ()
def __post_init__(self) -> None:
# PERF: It might make more sense to move validation work into the visitor, which
# would allow us to avoid validating the tree when parsing a file.
@ -238,7 +235,7 @@ class CSTNode(ABC):
# validate return type of the user-defined `visitor.on_leave` method
if not isinstance(leave_result, (CSTNode, RemovalSentinel, FlattenSentinel)):
raise CSTValidationError(
raise Exception(
"Expected a node of type CSTNode or a RemovalSentinel, "
+ f"but got a return value of {type(leave_result).__name__}"
)
@ -293,7 +290,8 @@ class CSTNode(ABC):
return False
@abstractmethod
def _codegen_impl(self, state: CodegenState) -> None: ...
def _codegen_impl(self, state: CodegenState) -> None:
...
def _codegen(self, state: CodegenState, **kwargs: Any) -> None:
state.before_codegen(self)
@ -383,7 +381,7 @@ class CSTNode(ABC):
new_tree = self.visit(_ChildReplacementTransformer(old_node, new_node))
if isinstance(new_tree, (FlattenSentinel, RemovalSentinel)):
# The above transform never returns *Sentinel, so this isn't possible
raise CSTLogicError("Logic error, cannot get a *Sentinel here!")
raise Exception("Logic error, cannot get a *Sentinal here!")
return new_tree
def deep_remove(
@ -400,7 +398,7 @@ class CSTNode(ABC):
if isinstance(new_tree, FlattenSentinel):
# The above transform never returns FlattenSentinel, so this isn't possible
raise CSTLogicError("Logic error, cannot get a FlattenSentinel here!")
raise Exception("Logic error, cannot get a FlattenSentinel here!")
return new_tree
@ -422,7 +420,7 @@ class CSTNode(ABC):
new_tree = self.visit(_ChildWithChangesTransformer(old_node, changes))
if isinstance(new_tree, (FlattenSentinel, RemovalSentinel)):
# This is impossible with the above transform.
raise CSTLogicError("Logic error, cannot get a *Sentinel here!")
raise Exception("Logic error, cannot get a *Sentinel here!")
return new_tree
def __eq__(self: _CSTNodeSelfT, other: object) -> bool:
@ -470,8 +468,6 @@ class CSTNode(ABC):
class BaseLeaf(CSTNode, ABC):
__slots__ = ()
@property
def children(self) -> Sequence[CSTNode]:
# override this with an optimized implementation
@ -491,8 +487,6 @@ class BaseValueToken(BaseLeaf, ABC):
into the parent CSTNode, and hard-coded into the implementation of _codegen.
"""
__slots__ = ()
value: str
def _codegen_impl(self, state: CodegenState) -> None:

View file

@ -15,9 +15,9 @@ from tokenize import (
Imagnumber as IMAGNUMBER_RE,
Intnumber as INTNUMBER_RE,
)
from typing import Callable, Generator, Literal, Optional, Sequence, Union
from typing import Callable, Generator, Optional, Sequence, Union
from libcst import CSTLogicError
from typing_extensions import Literal
from libcst._add_slots import add_slots
from libcst._maybe_sentinel import MaybeSentinel
@ -222,8 +222,6 @@ class _BaseParenthesizedNode(CSTNode, ABC):
this to get that functionality.
"""
__slots__ = ()
lpar: Sequence[LeftParen] = ()
# Sequence of parenthesis for precedence dictation.
rpar: Sequence[RightParen] = ()
@ -256,8 +254,6 @@ class BaseExpression(_BaseParenthesizedNode, ABC):
An base class for all expressions. :class:`BaseExpression` contains no fields.
"""
__slots__ = ()
def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
"""
Returns true if this expression is safe to be use with a word operator
@ -300,7 +296,7 @@ class BaseAssignTargetExpression(BaseExpression, ABC):
<https://github.com/python/cpython/blob/v3.8.0a4/Python/ast.c#L1120>`_.
"""
__slots__ = ()
pass
class BaseDelTargetExpression(BaseExpression, ABC):
@ -320,7 +316,7 @@ class BaseDelTargetExpression(BaseExpression, ABC):
<https://github.com/python/cpython/blob/v3.8.0a4/Python/compile.c#L4854>`_.
"""
__slots__ = ()
pass
@add_slots
@ -354,7 +350,7 @@ class Name(BaseAssignTargetExpression, BaseDelTargetExpression):
if len(self.value) == 0:
raise CSTValidationError("Cannot have empty name identifier.")
if not self.value.isidentifier():
raise CSTValidationError(f"Name {self.value!r} is not a valid identifier.")
raise CSTValidationError("Name is not a valid identifier.")
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
@ -397,8 +393,6 @@ class BaseNumber(BaseExpression, ABC):
used anywhere that you need to explicitly take any number type.
"""
__slots__ = ()
def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
"""
Numbers are funny. The expression "5in [1,2,3,4,5]" is a valid expression
@ -528,15 +522,13 @@ class BaseString(BaseExpression, ABC):
:class:`SimpleString`, :class:`ConcatenatedString`, and :class:`FormattedString`.
"""
__slots__ = ()
pass
StringQuoteLiteral = Literal['"', "'", '"""', "'''"]
class _BasePrefixedString(BaseString, ABC):
__slots__ = ()
@property
def prefix(self) -> str:
"""
@ -655,20 +647,14 @@ class SimpleString(_BasePrefixedString):
if len(quote) == 2:
# Let's assume this is an empty string.
quote = quote[:1]
elif 3 < len(quote) <= 6:
# Let's assume this can be one of the following:
# >>> """"foo"""
# '"foo'
# >>> """""bar"""
# '""bar'
# >>> """"""
# ''
elif len(quote) == 6:
# Let's assume this is an empty triple-quoted string.
quote = quote[:3]
if len(quote) not in {1, 3}:
# We shouldn't get here due to construction validation logic,
# but handle the case anyway.
raise CSTLogicError(f"Invalid string {self.value}")
raise Exception("Invalid string {self.value}")
# pyre-ignore We know via the above validation that we will only
# ever return one of the four string literals.
@ -699,7 +685,7 @@ class SimpleString(_BasePrefixedString):
state.add_token(self.value)
@property
def evaluated_value(self) -> Union[str, bytes]:
def evaluated_value(self) -> str:
"""
Return an :func:`ast.literal_eval` evaluated str of :py:attr:`value`.
"""
@ -713,7 +699,7 @@ class BaseFormattedStringContent(CSTNode, ABC):
sequence of :class:`BaseFormattedStringContent` parts.
"""
__slots__ = ()
pass
@add_slots
@ -958,253 +944,6 @@ class FormattedString(_BasePrefixedString):
state.add_token(self.end)
class BaseTemplatedStringContent(CSTNode, ABC):
"""
The base type for :class:`TemplatedStringText` and
:class:`TemplatedStringExpression`. A :class:`TemplatedString` is composed of a
sequence of :class:`BaseTemplatedStringContent` parts.
"""
__slots__ = ()
@add_slots
@dataclass(frozen=True)
class TemplatedStringText(BaseTemplatedStringContent):
"""
Part of a :class:`TemplatedString` that is not inside curly braces (``{`` or ``}``).
For example, in::
f"ab{cd}ef"
``ab`` and ``ef`` are :class:`TemplatedStringText` nodes, but ``{cd}`` is a
:class:`TemplatedStringExpression`.
"""
#: The raw string value, including any escape characters present in the source
#: code, not including any enclosing quotes.
value: str
def _visit_and_replace_children(
self, visitor: CSTVisitorT
) -> "TemplatedStringText":
return TemplatedStringText(value=self.value)
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token(self.value)
@add_slots
@dataclass(frozen=True)
class TemplatedStringExpression(BaseTemplatedStringContent):
"""
Part of a :class:`TemplatedString` that is inside curly braces (``{`` or ``}``),
including the surrounding curly braces. For example, in::
f"ab{cd}ef"
``{cd}`` is a :class:`TemplatedStringExpression`, but ``ab`` and ``ef`` are
:class:`TemplatedStringText` nodes.
An t-string expression may contain ``conversion`` and ``format_spec`` suffixes that
control how the expression is converted to a string.
"""
#: The expression we will evaluate and render when generating the string.
expression: BaseExpression
#: An optional conversion specifier, such as ``!s``, ``!r`` or ``!a``.
conversion: Optional[str] = None
#: An optional format specifier following the `format specification mini-language
#: <https://docs.python.org/3/library/string.html#formatspec>`_.
format_spec: Optional[Sequence[BaseTemplatedStringContent]] = None
#: Whitespace after the opening curly brace (``{``), but before the ``expression``.
whitespace_before_expression: BaseParenthesizableWhitespace = (
SimpleWhitespace.field("")
)
#: Whitespace after the ``expression``, but before the ``conversion``,
#: ``format_spec`` and the closing curly brace (``}``). Python does not
#: allow whitespace inside or after a ``conversion`` or ``format_spec``.
whitespace_after_expression: BaseParenthesizableWhitespace = SimpleWhitespace.field(
""
)
#: Equal sign for Templated string expression uses self-documenting expressions,
#: such as ``f"{x=}"``. See the `Python 3.8 release notes
#: <https://docs.python.org/3/whatsnew/3.8.html#f-strings-support-for-self-documenting-expressions-and-debugging>`_.
equal: Optional[AssignEqual] = None
def _validate(self) -> None:
if self.conversion is not None and self.conversion not in ("s", "r", "a"):
raise CSTValidationError("Invalid t-string conversion.")
def _visit_and_replace_children(
self, visitor: CSTVisitorT
) -> "TemplatedStringExpression":
format_spec = self.format_spec
return TemplatedStringExpression(
whitespace_before_expression=visit_required(
self,
"whitespace_before_expression",
self.whitespace_before_expression,
visitor,
),
expression=visit_required(self, "expression", self.expression, visitor),
equal=visit_optional(self, "equal", self.equal, visitor),
whitespace_after_expression=visit_required(
self,
"whitespace_after_expression",
self.whitespace_after_expression,
visitor,
),
conversion=self.conversion,
format_spec=(
visit_sequence(self, "format_spec", format_spec, visitor)
if format_spec is not None
else None
),
)
def _codegen_impl(self, state: CodegenState) -> None:
state.add_token("{")
self.whitespace_before_expression._codegen(state)
self.expression._codegen(state)
equal = self.equal
if equal is not None:
equal._codegen(state)
self.whitespace_after_expression._codegen(state)
conversion = self.conversion
if conversion is not None:
state.add_token("!")
state.add_token(conversion)
format_spec = self.format_spec
if format_spec is not None:
state.add_token(":")
for spec in format_spec:
spec._codegen(state)
state.add_token("}")
@add_slots
@dataclass(frozen=True)
class TemplatedString(_BasePrefixedString):
"""
An "t-string". Template strings are a generalization of f-strings,
using a t in place of the f prefix. Instead of evaluating to str,
t-strings evaluate to a new type: Template
T-Strings are defined in 'PEP 750'
>>> import libcst as cst
>>> cst.parse_expression('t"ab{cd}ef"')
TemplatedString(
parts=[
TemplatedStringText(
value='ab',
),
TemplatedStringExpression(
expression=Name(
value='cd',
lpar=[],
rpar=[],
),
conversion=None,
format_spec=None,
whitespace_before_expression=SimpleWhitespace(
value='',
),
whitespace_after_expression=SimpleWhitespace(
value='',
),
equal=None,
),
TemplatedStringText(
value='ef',
),
],
start='t"',
end='"',
lpar=[],
rpar=[],
)
>>>
"""
#: A templated string is composed as a series of :class:`TemplatedStringText` and
#: :class:`TemplatedStringExpression` parts.
parts: Sequence[BaseTemplatedStringContent]
#: The string prefix and the leading quote, such as ``t"``, ``T'``, ``tr"``, or
#: ``t"""``.
start: str = 't"'
#: The trailing quote. This must match the type of quote used in ``start``.
end: Literal['"', "'", '"""', "'''"] = '"'
lpar: Sequence[LeftParen] = ()
#: Sequence of parenthesis for precidence dictation.
rpar: Sequence[RightParen] = ()
def _validate(self) -> None:
super(_BasePrefixedString, self)._validate()
# Validate any prefix
prefix = self.prefix
if prefix not in ("t", "tr", "rt"):
raise CSTValidationError("Invalid t-string prefix.")
# Validate wrapping quotes
starttoken = self.start[len(prefix) :]
if starttoken != self.end:
raise CSTValidationError("t-string must have matching enclosing quotes.")
# Validate valid wrapping quote usage
if starttoken not in ('"', "'", '"""', "'''"):
raise CSTValidationError("Invalid t-string enclosing quotes.")
@property
def prefix(self) -> str:
"""
Returns the string's prefix, if any exists. The prefix can be ``t``,
``tr``, or ``rt``.
"""
prefix = ""
for c in self.start:
if c in ['"', "'"]:
break
prefix += c
return prefix.lower()
@property
def quote(self) -> StringQuoteLiteral:
"""
Returns the quotation used to denote the string. Can be either ``'``,
``"``, ``'''`` or ``\"\"\"``.
"""
return self.end
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TemplatedString":
return TemplatedString(
lpar=visit_sequence(self, "lpar", self.lpar, visitor),
start=self.start,
parts=visit_sequence(self, "parts", self.parts, visitor),
end=self.end,
rpar=visit_sequence(self, "rpar", self.rpar, visitor),
)
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
state.add_token(self.start)
for part in self.parts:
part._codegen(state)
state.add_token(self.end)
@add_slots
@dataclass(frozen=True)
class ConcatenatedString(BaseString):
@ -1259,7 +998,7 @@ class ConcatenatedString(BaseString):
elif isinstance(right, FormattedString):
rightbytes = "b" in right.prefix
else:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
if leftbytes != rightbytes:
raise CSTValidationError("Cannot concatenate string and bytes.")
@ -1281,7 +1020,7 @@ class ConcatenatedString(BaseString):
self.right._codegen(state)
@property
def evaluated_value(self) -> Union[str, bytes, None]:
def evaluated_value(self) -> Optional[str]:
"""
Return an :func:`ast.literal_eval` evaluated str of recursively concatenated :py:attr:`left` and :py:attr:`right`
if and only if both :py:attr:`left` and :py:attr:`right` are composed by :class:`SimpleString` or :class:`ConcatenatedString`
@ -1295,11 +1034,7 @@ class ConcatenatedString(BaseString):
right_val = right.evaluated_value
if right_val is None:
return None
if isinstance(left_val, bytes) and isinstance(right_val, bytes):
return left_val + right_val
if isinstance(left_val, str) and isinstance(right_val, str):
return left_val + right_val
return None
return left_val + right_val
@add_slots
@ -1680,8 +1415,6 @@ class BaseSlice(CSTNode, ABC):
This node is purely for typing.
"""
__slots__ = ()
@add_slots
@dataclass(frozen=True)
@ -1694,29 +1427,10 @@ class Index(BaseSlice):
#: The index value itself.
value: BaseExpression
#: An optional string with an asterisk appearing before the name. This is
#: expanded into variable number of positional arguments. See PEP-646
star: Optional[Literal["*"]] = None
#: Whitespace after the ``star`` (if it exists), but before the ``value``.
whitespace_after_star: Optional[BaseParenthesizableWhitespace] = None
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Index":
return Index(
star=self.star,
whitespace_after_star=visit_optional(
self, "whitespace_after_star", self.whitespace_after_star, visitor
),
value=visit_required(self, "value", self.value, visitor),
)
return Index(value=visit_required(self, "value", self.value, visitor))
def _codegen_impl(self, state: CodegenState) -> None:
star = self.star
if star is not None:
state.add_token(star)
ws = self.whitespace_after_star
if ws is not None:
ws._codegen(state)
self.value._codegen(state)
@ -1896,9 +1610,9 @@ class Annotation(CSTNode):
#: colon or arrow.
annotation: BaseExpression
whitespace_before_indicator: Union[BaseParenthesizableWhitespace, MaybeSentinel] = (
MaybeSentinel.DEFAULT
)
whitespace_before_indicator: Union[
BaseParenthesizableWhitespace, MaybeSentinel
] = MaybeSentinel.DEFAULT
whitespace_after_indicator: BaseParenthesizableWhitespace = SimpleWhitespace.field(
" "
)
@ -1937,7 +1651,7 @@ class Annotation(CSTNode):
if default_indicator == "->":
state.add_token(" ")
else:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
# Now, output the indicator and the rest of the annotation
state.add_token(default_indicator)
@ -1982,26 +1696,15 @@ class ParamSlash(CSTNode):
.. _PEP 570: https://www.python.org/dev/peps/pep-0570/#specification
"""
#: Optional comma that comes after the slash. This comma doesn't own the whitespace
#: between ``/`` and ``,``.
# Optional comma that comes after the slash.
comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT
#: Whitespace after the ``/`` character. This is captured here in case there is a
#: comma.
whitespace_after: BaseParenthesizableWhitespace = SimpleWhitespace.field("")
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "ParamSlash":
return ParamSlash(
comma=visit_sentinel(self, "comma", self.comma, visitor),
whitespace_after=visit_required(
self, "whitespace_after", self.whitespace_after, visitor
),
)
return ParamSlash(comma=visit_sentinel(self, "comma", self.comma, visitor))
def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
state.add_token("/")
self.whitespace_after._codegen(state)
comma = self.comma
if comma is MaybeSentinel.DEFAULT and default_comma:
state.add_token(", ")
@ -2230,25 +1933,6 @@ class Parameters(CSTNode):
star_kwarg=visit_optional(self, "star_kwarg", self.star_kwarg, visitor),
)
def _safe_to_join_with_lambda(self) -> bool:
"""
Determine if Parameters need a space after the `lambda` keyword. Returns True
iff it's safe to omit the space between `lambda` and these Parameters.
See also `BaseExpression._safe_to_use_with_word_operator`.
For example: `lambda*_: pass`
"""
if len(self.posonly_params) != 0:
return False
# posonly_ind can't appear if above condition is false
if len(self.params) > 0 and self.params[0].star not in {"*", "**"}:
return False
return True
def _codegen_impl(self, state: CodegenState) -> None: # noqa: C901
# Compute the star existence first so we can ask about whether
# each element is the last in the list or not.
@ -2350,16 +2034,9 @@ class Lambda(BaseExpression):
rpar: Sequence[RightParen] = ()
#: Whitespace after the lambda keyword, but before any argument or the colon.
whitespace_after_lambda: Union[BaseParenthesizableWhitespace, MaybeSentinel] = (
MaybeSentinel.DEFAULT
)
def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
if position == ExpressionPosition.LEFT:
return len(self.rpar) > 0 or self.body._safe_to_use_with_word_operator(
position
)
return super()._safe_to_use_with_word_operator(position)
whitespace_after_lambda: Union[
BaseParenthesizableWhitespace, MaybeSentinel
] = MaybeSentinel.DEFAULT
def _validate(self) -> None:
# Validate parents
@ -2388,7 +2065,6 @@ class Lambda(BaseExpression):
if (
isinstance(whitespace_after_lambda, BaseParenthesizableWhitespace)
and whitespace_after_lambda.empty
and not self.params._safe_to_join_with_lambda()
):
raise CSTValidationError(
"Must have at least one space after lambda when specifying params"
@ -2514,8 +2190,6 @@ class _BaseExpressionWithArgs(BaseExpression, ABC):
in typing. So, we have common validation functions here.
"""
__slots__ = ()
#: Sequence of arguments that will be passed to the function call.
args: Sequence[Arg] = ()
@ -2707,12 +2381,7 @@ class Await(BaseExpression):
# Validate any super-class stuff, whatever it may be.
super(Await, self)._validate()
# Make sure we don't run identifiers together.
if (
self.whitespace_after_await.empty
and not self.expression._safe_to_use_with_word_operator(
ExpressionPosition.RIGHT
)
):
if self.whitespace_after_await.empty:
raise CSTValidationError("Must have at least one space after await")
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Await":
@ -2766,12 +2435,6 @@ class IfExp(BaseExpression):
#: Whitespace after the ``else`` keyword, but before the ``orelse`` expression.
whitespace_after_else: BaseParenthesizableWhitespace = SimpleWhitespace.field(" ")
def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
if position == ExpressionPosition.RIGHT:
return self.body._safe_to_use_with_word_operator(position)
else:
return self.orelse._safe_to_use_with_word_operator(position)
def _validate(self) -> None:
# Paren validation and such
super(IfExp, self)._validate()
@ -2850,9 +2513,9 @@ class From(CSTNode):
item: BaseExpression
#: The whitespace at the very start of this node.
whitespace_before_from: Union[BaseParenthesizableWhitespace, MaybeSentinel] = (
MaybeSentinel.DEFAULT
)
whitespace_before_from: Union[
BaseParenthesizableWhitespace, MaybeSentinel
] = MaybeSentinel.DEFAULT
#: The whitespace after the ``from`` keyword, but before the ``item``.
whitespace_after_from: BaseParenthesizableWhitespace = SimpleWhitespace.field(" ")
@ -2911,9 +2574,9 @@ class Yield(BaseExpression):
rpar: Sequence[RightParen] = ()
#: Whitespace after the ``yield`` keyword, but before the ``value``.
whitespace_after_yield: Union[BaseParenthesizableWhitespace, MaybeSentinel] = (
MaybeSentinel.DEFAULT
)
whitespace_after_yield: Union[
BaseParenthesizableWhitespace, MaybeSentinel
] = MaybeSentinel.DEFAULT
def _validate(self) -> None:
# Paren rules and such
@ -2968,8 +2631,6 @@ class _BaseElementImpl(CSTNode, ABC):
An internal base class for :class:`Element` and :class:`DictElement`.
"""
__slots__ = ()
value: BaseExpression
comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT
@ -2997,7 +2658,8 @@ class _BaseElementImpl(CSTNode, ABC):
state: CodegenState,
default_comma: bool = False,
default_comma_whitespace: bool = False, # False for a single-item collection
) -> None: ...
) -> None:
...
class BaseElement(_BaseElementImpl, ABC):
@ -3006,8 +2668,6 @@ class BaseElement(_BaseElementImpl, ABC):
BaseDictElement.
"""
__slots__ = ()
class BaseDictElement(_BaseElementImpl, ABC):
"""
@ -3015,8 +2675,6 @@ class BaseDictElement(_BaseElementImpl, ABC):
BaseElement.
"""
__slots__ = ()
@add_slots
@dataclass(frozen=True)
@ -3103,7 +2761,7 @@ class DictElement(BaseDictElement):
@add_slots
@dataclass(frozen=True)
class StarredElement(BaseElement, BaseExpression, _BaseParenthesizedNode):
class StarredElement(BaseElement, _BaseParenthesizedNode):
"""
A starred ``*value`` element that expands to represent multiple values in a literal
:class:`List`, :class:`Tuple`, or :class:`Set`.
@ -3299,8 +2957,6 @@ class BaseList(BaseExpression, ABC):
object when evaluated.
"""
__slots__ = ()
lbracket: LeftSquareBracket = LeftSquareBracket.field()
#: Brackets surrounding the list.
rbracket: RightSquareBracket = RightSquareBracket.field()
@ -3381,8 +3037,6 @@ class _BaseSetOrDict(BaseExpression, ABC):
shouldn't be exported.
"""
__slots__ = ()
lbrace: LeftCurlyBrace = LeftCurlyBrace.field()
#: Braces surrounding the set or dict.
rbrace: RightCurlyBrace = RightCurlyBrace.field()
@ -3408,8 +3062,6 @@ class BaseSet(_BaseSetOrDict, ABC):
a set object when evaluated.
"""
__slots__ = ()
@add_slots
@dataclass(frozen=True)
@ -3479,8 +3131,6 @@ class BaseDict(_BaseSetOrDict, ABC):
a dict object when evaluated.
"""
__slots__ = ()
@add_slots
@dataclass(frozen=True)
@ -3757,8 +3407,6 @@ class BaseComp(BaseExpression, ABC):
:class:`GeneratorExp`, :class:`ListComp`, :class:`SetComp`, and :class:`DictComp`.
"""
__slots__ = ()
for_in: CompFor
@ -3769,12 +3417,10 @@ class BaseSimpleComp(BaseComp, ABC):
``value``.
"""
__slots__ = ()
#: The expression evaluated during each iteration of the comprehension. This
#: lexically comes before the ``for_in`` clause, but it is semantically the
#: inner-most element, evaluated inside the ``for_in`` clause.
elt: BaseExpression
elt: BaseAssignTargetExpression
#: The ``for ... in ... if ...`` clause that lexically comes after ``elt``. This may
#: be a nested structure for nested comprehensions. See :class:`CompFor` for
@ -3807,7 +3453,7 @@ class GeneratorExp(BaseSimpleComp):
"""
#: The expression evaluated and yielded during each iteration of the generator.
elt: BaseExpression
elt: BaseAssignTargetExpression
#: The ``for ... in ... if ...`` clause that comes after ``elt``. This may be a
#: nested structure for nested comprehensions. See :class:`CompFor` for details.
@ -3858,7 +3504,7 @@ class ListComp(BaseList, BaseSimpleComp):
"""
#: The expression evaluated and stored during each iteration of the comprehension.
elt: BaseExpression
elt: BaseAssignTargetExpression
#: The ``for ... in ... if ...`` clause that comes after ``elt``. This may be a
#: nested structure for nested comprehensions. See :class:`CompFor` for details.
@ -3900,7 +3546,7 @@ class SetComp(BaseSet, BaseSimpleComp):
"""
#: The expression evaluated and stored during each iteration of the comprehension.
elt: BaseExpression
elt: BaseAssignTargetExpression
#: The ``for ... in ... if ...`` clause that comes after ``elt``. This may be a
#: nested structure for nested comprehensions. See :class:`CompFor` for details.
@ -3942,10 +3588,10 @@ class DictComp(BaseDict, BaseComp):
"""
#: The key inserted into the dictionary during each iteration of the comprehension.
key: BaseExpression
key: BaseAssignTargetExpression
#: The value associated with the ``key`` inserted into the dictionary during each
#: iteration of the comprehension.
value: BaseExpression
value: BaseAssignTargetExpression
#: The ``for ... in ... if ...`` clause that lexically comes after ``key`` and
#: ``value``. This may be a nested structure for nested comprehensions. See
@ -4049,15 +3695,6 @@ class NamedExpr(BaseExpression):
rpar=visit_sequence(self, "rpar", self.rpar, visitor),
)
def _safe_to_use_with_word_operator(self, position: ExpressionPosition) -> bool:
if position == ExpressionPosition.LEFT:
return len(self.rpar) > 0 or self.value._safe_to_use_with_word_operator(
position
)
return len(self.lpar) > 0 or self.target._safe_to_use_with_word_operator(
position
)
def _codegen_impl(self, state: CodegenState) -> None:
with self._parenthesize(state):
self.target._codegen(state)

View file

@ -79,6 +79,7 @@ class Module(CSTNode):
has_trailing_newline=self.has_trailing_newline,
)
# pyre-fixme[14]: `visit` overrides method defined in `CSTNode` inconsistently.
def visit(self: _ModuleSelfT, visitor: CSTVisitorT) -> _ModuleSelfT:
"""
Returns the result of running a visitor over this module.

View file

@ -19,8 +19,6 @@ class _BaseOneTokenOp(CSTNode, ABC):
Any node that has a static value and needs to own whitespace on both sides.
"""
__slots__ = ()
whitespace_before: BaseParenthesizableWhitespace
whitespace_after: BaseParenthesizableWhitespace
@ -43,7 +41,8 @@ class _BaseOneTokenOp(CSTNode, ABC):
self.whitespace_after._codegen(state)
@abstractmethod
def _get_token(self) -> str: ...
def _get_token(self) -> str:
...
class _BaseTwoTokenOp(CSTNode, ABC):
@ -52,8 +51,6 @@ class _BaseTwoTokenOp(CSTNode, ABC):
in beteween them.
"""
__slots__ = ()
whitespace_before: BaseParenthesizableWhitespace
whitespace_between: BaseParenthesizableWhitespace
@ -87,7 +84,8 @@ class _BaseTwoTokenOp(CSTNode, ABC):
self.whitespace_after._codegen(state)
@abstractmethod
def _get_tokens(self) -> Tuple[str, str]: ...
def _get_tokens(self) -> Tuple[str, str]:
...
class BaseUnaryOp(CSTNode, ABC):
@ -95,8 +93,6 @@ class BaseUnaryOp(CSTNode, ABC):
Any node that has a static value used in a :class:`UnaryOperation` expression.
"""
__slots__ = ()
#: Any space that appears directly after this operator.
whitespace_after: BaseParenthesizableWhitespace
@ -113,7 +109,8 @@ class BaseUnaryOp(CSTNode, ABC):
self.whitespace_after._codegen(state)
@abstractmethod
def _get_token(self) -> str: ...
def _get_token(self) -> str:
...
class BaseBooleanOp(_BaseOneTokenOp, ABC):
@ -122,8 +119,6 @@ class BaseBooleanOp(_BaseOneTokenOp, ABC):
This node is purely for typing.
"""
__slots__ = ()
class BaseBinaryOp(CSTNode, ABC):
"""
@ -131,8 +126,6 @@ class BaseBinaryOp(CSTNode, ABC):
This node is purely for typing.
"""
__slots__ = ()
class BaseCompOp(CSTNode, ABC):
"""
@ -140,8 +133,6 @@ class BaseCompOp(CSTNode, ABC):
This node is purely for typing.
"""
__slots__ = ()
class BaseAugOp(CSTNode, ABC):
"""
@ -149,8 +140,6 @@ class BaseAugOp(CSTNode, ABC):
This node is purely for typing.
"""
__slots__ = ()
@add_slots
@dataclass(frozen=True)

View file

@ -7,9 +7,7 @@ import inspect
import re
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import Literal, Optional, Pattern, Sequence, Union
from libcst import CSTLogicError
from typing import Optional, Pattern, Sequence, Union
from libcst._add_slots import add_slots
from libcst._maybe_sentinel import MaybeSentinel
@ -23,6 +21,7 @@ from libcst._nodes.expression import (
BaseAssignTargetExpression,
BaseDelTargetExpression,
BaseExpression,
Call,
ConcatenatedString,
ExpressionPosition,
From,
@ -50,7 +49,6 @@ from libcst._nodes.op import (
AssignEqual,
BaseAugOp,
BitOr,
Colon,
Comma,
Dot,
ImportStar,
@ -81,8 +79,6 @@ class BaseSuite(CSTNode, ABC):
-- https://docs.python.org/3/reference/compound_stmts.html
"""
__slots__ = ()
body: Union[Sequence["BaseStatement"], Sequence["BaseSmallStatement"]]
@ -92,7 +88,7 @@ class BaseStatement(CSTNode, ABC):
in a particular location.
"""
__slots__ = ()
pass
class BaseSmallStatement(CSTNode, ABC):
@ -103,8 +99,6 @@ class BaseSmallStatement(CSTNode, ABC):
simplify type definitions and isinstance checks.
"""
__slots__ = ()
#: An optional semicolon that appears after a small statement. This is optional
#: for the last small statement in a :class:`SimpleStatementLine` or
#: :class:`SimpleStatementSuite`, but all other small statements inside a simple
@ -115,7 +109,8 @@ class BaseSmallStatement(CSTNode, ABC):
@abstractmethod
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None: ...
) -> None:
...
@add_slots
@ -274,9 +269,9 @@ class Return(BaseSmallStatement):
#: Optional whitespace after the ``return`` keyword before the optional
#: value expression.
whitespace_after_return: Union[SimpleWhitespace, MaybeSentinel] = (
MaybeSentinel.DEFAULT
)
whitespace_after_return: Union[
SimpleWhitespace, MaybeSentinel
] = MaybeSentinel.DEFAULT
#: Optional semicolon when this is used in a statement line. This semicolon
#: owns the whitespace on both sides of it when it is used.
@ -375,8 +370,6 @@ class _BaseSimpleStatement(CSTNode, ABC):
small statement.
"""
__slots__ = ()
#: Sequence of small statements. All but the last statement are required to have
#: a semicolon.
body: Sequence[BaseSmallStatement]
@ -561,8 +554,6 @@ class BaseCompoundStatement(BaseStatement, ABC):
-- https://docs.python.org/3/reference/compound_stmts.html
"""
__slots__ = ()
#: The body of this compound statement.
body: BaseSuite
@ -600,12 +591,7 @@ class If(BaseCompoundStatement):
#: The whitespace appearing after the test expression but before the colon.
whitespace_after_test: SimpleWhitespace = SimpleWhitespace.field("")
def _validate(self) -> None:
if (
self.whitespace_before_test.empty
and not self.test._safe_to_use_with_word_operator(ExpressionPosition.RIGHT)
):
raise CSTValidationError("Must have at least one space after 'if' keyword.")
# TODO: _validate
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "If":
return If(
@ -752,13 +738,12 @@ class AsName(CSTNode):
whitespace_after_as: BaseParenthesizableWhitespace = SimpleWhitespace.field(" ")
def _validate(self) -> None:
if (
self.whitespace_after_as.empty
and not self.name._safe_to_use_with_word_operator(ExpressionPosition.RIGHT)
):
if self.whitespace_after_as.empty:
raise CSTValidationError(
"There must be at least one space between 'as' and name."
)
if self.whitespace_before_as.empty:
raise CSTValidationError("There must be at least one space before 'as'.")
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "AsName":
return AsName(
@ -822,16 +807,6 @@ class ExceptHandler(CSTNode):
raise CSTValidationError(
"Must have at least one space after except when ExceptHandler has a type."
)
name = self.name
if (
type_ is not None
and name is not None
and name.whitespace_before_as.empty
and not type_._safe_to_use_with_word_operator(ExpressionPosition.LEFT)
):
raise CSTValidationError(
"Must have at least one space before as keyword in an except handler."
)
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "ExceptHandler":
return ExceptHandler(
@ -1156,21 +1131,18 @@ class ImportAlias(CSTNode):
def _validate(self) -> None:
asname = self.asname
if asname is not None:
if not isinstance(asname.name, Name):
raise CSTValidationError(
"Must use a Name node for AsName name inside ImportAlias."
)
if asname.whitespace_before_as.empty:
raise CSTValidationError(
"Must have at least one space before as keyword in an ImportAlias."
)
if asname is not None and not isinstance(asname.name, Name):
raise CSTValidationError(
"Must use a Name node for AsName name inside ImportAlias."
)
try:
self.evaluated_name
except CSTLogicError as e:
raise CSTValidationError(
"The imported name must be a valid qualified name."
) from e
except Exception as e:
if str(e) == "Logic error!":
raise CSTValidationError(
"The imported name must be a valid qualified name."
)
raise e
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "ImportAlias":
return ImportAlias(
@ -1199,7 +1171,7 @@ class ImportAlias(CSTNode):
elif isinstance(node, Attribute):
return f"{self._name(node.value)}.{node.attr.value}"
else:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
@property
def evaluated_name(self) -> str:
@ -1626,7 +1598,7 @@ class Decorator(CSTNode):
#: The decorator that will return a new function wrapping the parent
#: of this decorator.
decorator: BaseExpression
decorator: Union[Name, Attribute, Call]
#: Line comments and empty lines before this decorator. The parent
#: :class:`FunctionDef` or :class:`ClassDef` node owns leading lines before
@ -1639,6 +1611,19 @@ class Decorator(CSTNode):
#: Optional trailing comment and newline following the decorator before the next line.
trailing_whitespace: TrailingWhitespace = TrailingWhitespace.field()
def _validate(self) -> None:
decorator = self.decorator
if len(decorator.lpar) > 0 or len(decorator.rpar) > 0:
raise CSTValidationError(
"Cannot have parens around decorator in a Decorator."
)
if isinstance(decorator, Call) and not isinstance(
decorator.func, (Name, Attribute)
):
raise CSTValidationError(
"Decorator call function must be Name or Attribute node."
)
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "Decorator":
return Decorator(
leading_lines=visit_sequence(
@ -1694,8 +1679,6 @@ def get_docstring_impl(
evaluated_value = val.evaluated_value
else:
return None
if isinstance(evaluated_value, bytes):
return None
if evaluated_value is not None and clean:
return inspect.cleandoc(evaluated_value)
@ -1743,8 +1726,8 @@ class FunctionDef(BaseCompoundStatement):
#: Whitespace after the ``def`` keyword and before the function name.
whitespace_after_def: SimpleWhitespace = SimpleWhitespace.field(" ")
#: Whitespace after the function name and before the type parameters or the opening
#: parenthesis for the parameters.
#: Whitespace after the function name and before the opening parenthesis for
#: the parameters.
whitespace_after_name: SimpleWhitespace = SimpleWhitespace.field("")
#: Whitespace after the opening parenthesis for the parameters but before
@ -1755,13 +1738,6 @@ class FunctionDef(BaseCompoundStatement):
#: the colon.
whitespace_before_colon: SimpleWhitespace = SimpleWhitespace.field("")
#: An optional declaration of type parameters.
type_parameters: Optional["TypeParameters"] = None
#: Whitespace between the type parameters and the opening parenthesis for the
#: (non-type) parameters.
whitespace_after_type_parameters: SimpleWhitespace = SimpleWhitespace.field("")
def _validate(self) -> None:
if len(self.name.lpar) > 0 or len(self.name.rpar) > 0:
raise CSTValidationError("Cannot have parens around Name in a FunctionDef.")
@ -1770,15 +1746,6 @@ class FunctionDef(BaseCompoundStatement):
"There must be at least one space between 'def' and name."
)
if (
self.type_parameters is None
and not self.whitespace_after_type_parameters.empty
):
raise CSTValidationError(
"whitespace_after_type_parameters must be empty if there are no type "
"parameters in FunctionDef"
)
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "FunctionDef":
return FunctionDef(
leading_lines=visit_sequence(
@ -1798,15 +1765,6 @@ class FunctionDef(BaseCompoundStatement):
whitespace_after_name=visit_required(
self, "whitespace_after_name", self.whitespace_after_name, visitor
),
type_parameters=visit_optional(
self, "type_parameters", self.type_parameters, visitor
),
whitespace_after_type_parameters=visit_required(
self,
"whitespace_after_type_parameters",
self.whitespace_after_type_parameters,
visitor,
),
whitespace_before_params=visit_required(
self, "whitespace_before_params", self.whitespace_before_params, visitor
),
@ -1835,10 +1793,6 @@ class FunctionDef(BaseCompoundStatement):
self.whitespace_after_def._codegen(state)
self.name._codegen(state)
self.whitespace_after_name._codegen(state)
type_params = self.type_parameters
if type_params is not None:
type_params._codegen(state)
self.whitespace_after_type_parameters._codegen(state)
state.add_token("(")
self.whitespace_before_params._codegen(state)
self.params._codegen(state)
@ -1900,34 +1854,19 @@ class ClassDef(BaseCompoundStatement):
#: Whitespace after the ``class`` keyword and before the class name.
whitespace_after_class: SimpleWhitespace = SimpleWhitespace.field(" ")
#: Whitespace after the class name and before the type parameters or the opening
#: parenthesis for the bases and keywords.
#: Whitespace after the class name and before the opening parenthesis for
#: the bases and keywords.
whitespace_after_name: SimpleWhitespace = SimpleWhitespace.field("")
#: Whitespace after the closing parenthesis or class name and before
#: the colon.
whitespace_before_colon: SimpleWhitespace = SimpleWhitespace.field("")
#: An optional declaration of type parameters.
type_parameters: Optional["TypeParameters"] = None
#: Whitespace between type parameters and opening parenthesis for the bases and
#: keywords.
whitespace_after_type_parameters: SimpleWhitespace = SimpleWhitespace.field("")
def _validate_whitespace(self) -> None:
if self.whitespace_after_class.empty:
raise CSTValidationError(
"There must be at least one space between 'class' and name."
)
if (
self.type_parameters is None
and not self.whitespace_after_type_parameters.empty
):
raise CSTValidationError(
"whitespace_after_type_parameters must be empty if there are no type"
"parameters in a ClassDef"
)
def _validate_parens(self) -> None:
if len(self.name.lpar) > 0 or len(self.name.rpar) > 0:
@ -1970,15 +1909,6 @@ class ClassDef(BaseCompoundStatement):
whitespace_after_name=visit_required(
self, "whitespace_after_name", self.whitespace_after_name, visitor
),
type_parameters=visit_optional(
self, "type_parameters", self.type_parameters, visitor
),
whitespace_after_type_parameters=visit_required(
self,
"whitespace_after_type_parameters",
self.whitespace_after_type_parameters,
visitor,
),
lpar=visit_sentinel(self, "lpar", self.lpar, visitor),
bases=visit_sequence(self, "bases", self.bases, visitor),
keywords=visit_sequence(self, "keywords", self.keywords, visitor),
@ -2003,10 +1933,6 @@ class ClassDef(BaseCompoundStatement):
self.whitespace_after_class._codegen(state)
self.name._codegen(state)
self.whitespace_after_name._codegen(state)
type_params = self.type_parameters
if type_params is not None:
type_params._codegen(state)
self.whitespace_after_type_parameters._codegen(state)
lpar = self.lpar
if isinstance(lpar, MaybeSentinel):
if self.bases or self.keywords:
@ -2052,15 +1978,6 @@ class WithItem(CSTNode):
#: other items inside a with block must contain a comma to separate them.
comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT
def _validate(self) -> None:
asname = self.asname
if (
asname is not None
and asname.whitespace_before_as.empty
and not self.item._safe_to_use_with_word_operator(ExpressionPosition.LEFT)
):
raise CSTValidationError("Must have at least one space before as keyword.")
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "WithItem":
return WithItem(
item=visit_required(self, "item", self.item, visitor),
@ -2402,9 +2319,9 @@ class Raise(BaseSmallStatement):
cause: Optional[From] = None
#: Any whitespace appearing between the ``raise`` keyword and the exception.
whitespace_after_raise: Union[SimpleWhitespace, MaybeSentinel] = (
MaybeSentinel.DEFAULT
)
whitespace_after_raise: Union[
SimpleWhitespace, MaybeSentinel
] = MaybeSentinel.DEFAULT
#: Optional semicolon when this is used in a statement line. This semicolon
#: owns the whitespace on both sides of it when it is used.
@ -2716,8 +2633,6 @@ class MatchPattern(_BaseParenthesizedNode, ABC):
statement.
"""
__slots__ = ()
@add_slots
@dataclass(frozen=True)
@ -2762,6 +2677,11 @@ class Match(BaseCompoundStatement):
if len(self.cases) == 0:
raise CSTValidationError("A match statement must have at least one case.")
if self.whitespace_after_match.empty:
raise CSTValidationError(
"Must have at least one space after a 'match' keyword"
)
indent = self.indent
if indent is not None:
if len(indent) == 0:
@ -2858,16 +2778,17 @@ class MatchCase(CSTNode):
self, "whitespace_after_case", self.whitespace_after_case, visitor
),
pattern=visit_required(self, "pattern", self.pattern, visitor),
whitespace_before_if=visit_required(
# pyre-fixme[6]: Expected `SimpleWhitespace` for 4th param but got
# `Optional[SimpleWhitespace]`.
whitespace_before_if=visit_optional(
self, "whitespace_before_if", self.whitespace_before_if, visitor
),
whitespace_after_if=visit_required(
# pyre-fixme[6]: Expected `SimpleWhitespace` for 5th param but got
# `Optional[SimpleWhitespace]`.
whitespace_after_if=visit_optional(
self, "whitespace_after_if", self.whitespace_after_if, visitor
),
guard=visit_optional(self, "guard", self.guard, visitor),
whitespace_before_colon=visit_required(
self, "whitespace_before_colon", self.whitespace_before_colon, visitor
),
body=visit_required(self, "body", self.body, visitor),
)
@ -2886,9 +2807,6 @@ class MatchCase(CSTNode):
state.add_token("if")
self.whitespace_after_if._codegen(state)
guard._codegen(state)
else:
self.whitespace_before_if._codegen(state)
self.whitespace_after_if._codegen(state)
self.whitespace_before_colon._codegen(state)
state.add_token(":")
@ -2920,14 +2838,6 @@ class MatchValue(MatchPattern):
def lpar(self, value: Sequence[LeftParen]) -> None:
self.value.lpar = value
@property
def rpar(self) -> Sequence[RightParen]:
return self.value.rpar
@rpar.setter
def rpar(self, value: Sequence[RightParen]) -> None:
self.value.rpar = value
@add_slots
@dataclass(frozen=True)
@ -2961,15 +2871,6 @@ class MatchSingleton(MatchPattern):
# pyre-fixme[41]: Cannot reassign final attribute `lpar`.
self.value.lpar = value
@property
def rpar(self) -> Sequence[RightParen]:
return self.value.rpar
@rpar.setter
def rpar(self, value: Sequence[RightParen]) -> None:
# pyre-fixme[41]: Cannot reassign final attribute `rpar`.
self.value.rpar = value
@add_slots
@dataclass(frozen=True)
@ -3059,8 +2960,6 @@ class MatchSequence(MatchPattern, ABC):
otherwise matches a fixed length sequence.
"""
__slots__ = ()
#: Patterns to be matched against the subject elements if it is a sequence.
patterns: Sequence[Union[MatchSequenceElement, MatchStar]]
@ -3077,10 +2976,10 @@ class MatchList(MatchSequence):
patterns: Sequence[Union[MatchSequenceElement, MatchStar]]
#: An optional left bracket. If missing, this is an open sequence pattern.
lbracket: Optional[LeftSquareBracket] = None
lbracket: Optional[LeftSquareBracket] = LeftSquareBracket.field()
#: An optional left bracket. If missing, this is an open sequence pattern.
rbracket: Optional[RightSquareBracket] = None
rbracket: Optional[RightSquareBracket] = RightSquareBracket.field()
#: Parenthesis at the beginning of the node
lpar: Sequence[LeftParen] = ()
@ -3388,7 +3287,6 @@ class MatchClass(MatchPattern):
whitespace_after_kwds=visit_required(
self, "whitespace_after_kwds", self.whitespace_after_kwds, visitor
),
rpar=visit_sequence(self, "rpar", self.rpar, visitor),
)
def _codegen_impl(self, state: CodegenState) -> None:
@ -3425,15 +3323,15 @@ class MatchAs(MatchPattern):
#: Whitespace between ``pattern`` and the ``as`` keyword (if ``pattern`` is not
#: ``None``)
whitespace_before_as: Union[BaseParenthesizableWhitespace, MaybeSentinel] = (
MaybeSentinel.DEFAULT
)
whitespace_before_as: Union[
BaseParenthesizableWhitespace, MaybeSentinel
] = MaybeSentinel.DEFAULT
#: Whitespace between the ``as`` keyword and ``name`` (if ``pattern`` is not
#: ``None``)
whitespace_after_as: Union[BaseParenthesizableWhitespace, MaybeSentinel] = (
MaybeSentinel.DEFAULT
)
whitespace_after_as: Union[
BaseParenthesizableWhitespace, MaybeSentinel
] = MaybeSentinel.DEFAULT
#: Parenthesis at the beginning of the node
lpar: Sequence[LeftParen] = ()
@ -3476,13 +3374,6 @@ class MatchAs(MatchPattern):
state.add_token(" ")
elif isinstance(ws_after, BaseParenthesizableWhitespace):
ws_after._codegen(state)
else:
ws_before = self.whitespace_before_as
if isinstance(ws_before, BaseParenthesizableWhitespace):
ws_before._codegen(state)
ws_after = self.whitespace_after_as
if isinstance(ws_after, BaseParenthesizableWhitespace):
ws_after._codegen(state)
if name is None:
state.add_token("_")
else:
@ -3548,326 +3439,3 @@ class MatchOr(MatchPattern):
pats = self.patterns
for idx, pat in enumerate(pats):
pat._codegen(state, default_separator=idx + 1 < len(pats))
@add_slots
@dataclass(frozen=True)
class TypeVar(CSTNode):
"""
A simple (non-variadic) type variable.
Note: this node represents type a variable when declared using PEP-695 syntax.
"""
#: The name of the type variable.
name: Name
#: An optional bound on the type.
bound: Optional[BaseExpression] = None
#: The colon used to separate the name and bound. If not specified,
#: :class:`MaybeSentinel` will be replaced with a colon if there is a bound,
#: otherwise will be left empty.
colon: Union[Colon, MaybeSentinel] = MaybeSentinel.DEFAULT
def _codegen_impl(self, state: CodegenState) -> None:
with state.record_syntactic_position(self):
self.name._codegen(state)
bound = self.bound
colon = self.colon
if not isinstance(colon, MaybeSentinel):
colon._codegen(state)
else:
if bound is not None:
state.add_token(": ")
if bound is not None:
bound._codegen(state)
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TypeVar":
return TypeVar(
name=visit_required(self, "name", self.name, visitor),
colon=visit_sentinel(self, "colon", self.colon, visitor),
bound=visit_optional(self, "bound", self.bound, visitor),
)
@add_slots
@dataclass(frozen=True)
class TypeVarTuple(CSTNode):
"""
A variadic type variable.
"""
#: The name of this type variable.
name: Name
#: The (optional) whitespace between the star declaring this type variable as
#: variadic, and the variable's name.
whitespace_after_star: SimpleWhitespace = SimpleWhitespace.field("")
def _codegen_impl(self, state: CodegenState) -> None:
with state.record_syntactic_position(self):
state.add_token("*")
self.whitespace_after_star._codegen(state)
self.name._codegen(state)
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TypeVarTuple":
return TypeVarTuple(
name=visit_required(self, "name", self.name, visitor),
whitespace_after_star=visit_required(
self, "whitespace_after_star", self.whitespace_after_star, visitor
),
)
@add_slots
@dataclass(frozen=True)
class ParamSpec(CSTNode):
"""
A parameter specification.
Note: this node represents a parameter specification when declared using PEP-695
syntax.
"""
#: The name of this parameter specification.
name: Name
#: The (optional) whitespace between the double star declaring this type variable as
#: a parameter specification, and the name.
whitespace_after_star: SimpleWhitespace = SimpleWhitespace.field("")
def _codegen_impl(self, state: CodegenState) -> None:
with state.record_syntactic_position(self):
state.add_token("**")
self.whitespace_after_star._codegen(state)
self.name._codegen(state)
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "ParamSpec":
return ParamSpec(
name=visit_required(self, "name", self.name, visitor),
whitespace_after_star=visit_required(
self, "whitespace_after_star", self.whitespace_after_star, visitor
),
)
@add_slots
@dataclass(frozen=True)
class TypeParam(CSTNode):
"""
A single type parameter that is contained in a :class:`TypeParameters` list.
"""
#: The actual parameter.
param: Union[TypeVar, TypeVarTuple, ParamSpec]
#: A trailing comma. If one is not provided, :class:`MaybeSentinel` will be replaced
#: with a comma only if a comma is required.
comma: Union[Comma, MaybeSentinel] = MaybeSentinel.DEFAULT
#: The equal sign used to denote assignment if there is a default.
equal: Union[AssignEqual, MaybeSentinel] = MaybeSentinel.DEFAULT
#: The star used to denote a variadic default
star: Literal["", "*"] = ""
#: The whitespace between the star and the type.
whitespace_after_star: SimpleWhitespace = SimpleWhitespace.field("")
#: Any optional default value, used when the argument is not supplied.
default: Optional[BaseExpression] = None
def _codegen_impl(self, state: CodegenState, default_comma: bool = False) -> None:
self.param._codegen(state)
equal = self.equal
if equal is MaybeSentinel.DEFAULT and self.default is not None:
state.add_token(" = ")
elif isinstance(equal, AssignEqual):
equal._codegen(state)
state.add_token(self.star)
self.whitespace_after_star._codegen(state)
default = self.default
if default is not None:
default._codegen(state)
comma = self.comma
if isinstance(comma, MaybeSentinel):
if default_comma:
state.add_token(", ")
else:
comma._codegen(state)
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TypeParam":
ret = TypeParam(
param=visit_required(self, "param", self.param, visitor),
equal=visit_sentinel(self, "equal", self.equal, visitor),
star=self.star,
whitespace_after_star=visit_required(
self, "whitespace_after_star", self.whitespace_after_star, visitor
),
default=visit_optional(self, "default", self.default, visitor),
comma=visit_sentinel(self, "comma", self.comma, visitor),
)
return ret
def _validate(self) -> None:
if self.default is None and isinstance(self.equal, AssignEqual):
raise CSTValidationError(
"Must have a default when specifying an AssignEqual."
)
if self.star and not (self.default or isinstance(self.equal, AssignEqual)):
raise CSTValidationError("Star can only be present if a default")
if isinstance(self.star, str) and self.star not in ("", "*"):
raise CSTValidationError("Must specify either '' or '*' for star.")
@add_slots
@dataclass(frozen=True)
class TypeParameters(CSTNode):
"""
Type parameters when specified with PEP-695 syntax.
This node captures all specified parameters that are enclosed with square brackets.
"""
#: The parameters within the square brackets.
params: Sequence[TypeParam] = ()
#: Opening square bracket that marks the start of these parameters.
lbracket: LeftSquareBracket = LeftSquareBracket.field()
#: Closing square bracket that marks the end of these parameters.
rbracket: RightSquareBracket = RightSquareBracket.field()
def _codegen_impl(self, state: CodegenState) -> None:
self.lbracket._codegen(state)
params_len = len(self.params)
for idx, param in enumerate(self.params):
param._codegen(state, default_comma=idx + 1 < params_len)
self.rbracket._codegen(state)
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TypeParameters":
return TypeParameters(
lbracket=visit_required(self, "lbracket", self.lbracket, visitor),
params=visit_sequence(self, "params", self.params, visitor),
rbracket=visit_required(self, "rbracket", self.rbracket, visitor),
)
@add_slots
@dataclass(frozen=True)
class TypeAlias(BaseSmallStatement):
"""
A type alias statement.
This node represents the ``type`` statement as specified initially by PEP-695.
Example: ``type ListOrSet[T] = list[T] | set[T]``.
"""
#: The name being introduced in this statement.
name: Name
#: Everything on the right hand side of the ``=``.
value: BaseExpression
#: An optional list of type parameters, specified after the name.
type_parameters: Optional[TypeParameters] = None
#: Whitespace between the ``type`` soft keyword and the name.
whitespace_after_type: SimpleWhitespace = SimpleWhitespace.field(" ")
#: Whitespace between the name and the type parameters (if they exist) or the ``=``.
#: If not specified, :class:`MaybeSentinel` will be replaced with a single space if
#: there are no type parameters, otherwise no spaces.
whitespace_after_name: Union[SimpleWhitespace, MaybeSentinel] = (
MaybeSentinel.DEFAULT
)
#: Whitespace between the type parameters and the ``=``. Always empty if there are
#: no type parameters. If not specified, :class:`MaybeSentinel` will be replaced
#: with a single space if there are type parameters.
whitespace_after_type_parameters: Union[SimpleWhitespace, MaybeSentinel] = (
MaybeSentinel.DEFAULT
)
#: Whitespace between the ``=`` and the value.
whitespace_after_equals: SimpleWhitespace = SimpleWhitespace.field(" ")
#: Optional semicolon when this is used in a statement line. This semicolon
#: owns the whitespace on both sides of it when it is used.
semicolon: Union[Semicolon, MaybeSentinel] = MaybeSentinel.DEFAULT
def _validate(self) -> None:
if (
self.type_parameters is None
and self.whitespace_after_type_parameters
not in {
SimpleWhitespace(""),
MaybeSentinel.DEFAULT,
}
):
raise CSTValidationError(
"whitespace_after_type_parameters must be empty when there are no type parameters in a TypeAlias"
)
def _visit_and_replace_children(self, visitor: CSTVisitorT) -> "TypeAlias":
return TypeAlias(
whitespace_after_type=visit_required(
self, "whitespace_after_type", self.whitespace_after_type, visitor
),
name=visit_required(self, "name", self.name, visitor),
whitespace_after_name=visit_sentinel(
self, "whitespace_after_name", self.whitespace_after_name, visitor
),
type_parameters=visit_optional(
self, "type_parameters", self.type_parameters, visitor
),
whitespace_after_type_parameters=visit_sentinel(
self,
"whitespace_after_type_parameters",
self.whitespace_after_type_parameters,
visitor,
),
whitespace_after_equals=visit_required(
self, "whitespace_after_equals", self.whitespace_after_equals, visitor
),
value=visit_required(self, "value", self.value, visitor),
semicolon=visit_sentinel(self, "semicolon", self.semicolon, visitor),
)
def _codegen_impl(
self, state: CodegenState, default_semicolon: bool = False
) -> None:
with state.record_syntactic_position(self):
state.add_token("type")
self.whitespace_after_type._codegen(state)
self.name._codegen(state)
ws_after_name = self.whitespace_after_name
if isinstance(ws_after_name, MaybeSentinel):
if self.type_parameters is None:
state.add_token(" ")
else:
ws_after_name._codegen(state)
ws_after_type_params = self.whitespace_after_type_parameters
if self.type_parameters is not None:
self.type_parameters._codegen(state)
if isinstance(ws_after_type_params, MaybeSentinel):
state.add_token(" ")
else:
ws_after_type_params._codegen(state)
state.add_token("=")
self.whitespace_after_equals._codegen(state)
self.value._codegen(state)
semi = self.semicolon
if isinstance(semi, MaybeSentinel):
if default_semicolon:
state.add_token("; ")
else:
semi._codegen(state)

View file

@ -239,7 +239,7 @@ class CSTNodeTest(UnitTest):
def assert_parses(
self,
code: str,
parser: Callable[[str], cst.CSTNode],
parser: Callable[[str], cst.BaseExpression],
expect_success: bool,
) -> None:
if not expect_success:

View file

@ -9,6 +9,7 @@ from typing import Any
import libcst as cst
from libcst import parse_expression
from libcst._nodes.tests.base import CSTNodeTest, parse_expression_as
from libcst._parser.entrypoints import is_native
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider
@ -739,69 +740,6 @@ class AtomTest(CSTNodeTest):
"parser": parse_expression,
"expected_position": None,
},
# Unpacked tuple
{
"node": cst.FormattedString(
parts=[
cst.FormattedStringExpression(
expression=cst.Tuple(
elements=[
cst.Element(
value=cst.Name(
value="a",
),
comma=cst.Comma(
whitespace_before=cst.SimpleWhitespace(
value="",
),
whitespace_after=cst.SimpleWhitespace(
value=" ",
),
),
),
cst.Element(
value=cst.Name(
value="b",
),
),
],
lpar=[],
rpar=[],
),
),
],
start="f'",
end="'",
),
"code": "f'{a, b}'",
"parser": parse_expression,
"expected_position": None,
},
# Conditional expression
{
"node": cst.FormattedString(
parts=[
cst.FormattedStringExpression(
expression=cst.IfExp(
test=cst.Name(
value="b",
),
body=cst.Name(
value="a",
),
orelse=cst.Name(
value="c",
),
),
),
],
start="f'",
end="'",
),
"code": "f'{a if b else c}'",
"parser": parse_expression,
"expected_position": None,
},
# Concatenated strings
{
"node": cst.ConcatenatedString(
@ -1183,7 +1121,7 @@ class AtomTest(CSTNodeTest):
)
)
def test_versions(self, **kwargs: Any) -> None:
if not kwargs.get("expect_success", True):
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)

View file

@ -46,14 +46,6 @@ class AwaitTest(CSTNodeTest):
),
"expected_position": CodeRange((1, 2), (1, 13)),
},
# Whitespace after await
{
"node": cst.Await(
cst.Name("foo", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]),
whitespace_after_await=cst.SimpleWhitespace(""),
),
"code": "await(foo)",
},
)
)
def test_valid_py37(self, **kwargs: Any) -> None:

View file

@ -174,18 +174,3 @@ class BinaryOperationTest(CSTNodeTest):
)
def test_invalid(self, **kwargs: Any) -> None:
self.assert_invalid(**kwargs)
@data_provider(
(
{
"code": '"a"' * 6000,
"parser": parse_expression,
},
{
"code": "[_" + " for _ in _" * 6000 + "]",
"parser": parse_expression,
},
)
)
def test_parse_error(self, **kwargs: Any) -> None:
self.assert_parses(**kwargs, expect_success=False)

View file

@ -112,105 +112,6 @@ class ClassDefCreationTest(CSTNodeTest):
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)
@data_provider(
(
{
"node": cst.ClassDef(
cst.Name("Foo"),
cst.SimpleStatementSuite((cst.Pass(),)),
type_parameters=cst.TypeParameters(
(
cst.TypeParam(
cst.TypeVar(
cst.Name("T"),
bound=cst.Name("int"),
colon=cst.Colon(
whitespace_after=cst.SimpleWhitespace(" ")
),
),
cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
),
cst.TypeParam(
cst.TypeVarTuple(cst.Name("Ts")),
cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
),
cst.TypeParam(cst.ParamSpec(cst.Name("KW"))),
)
),
),
"code": "class Foo[T: int, *Ts, **KW]: pass\n",
},
{
"node": cst.ClassDef(
cst.Name("Foo"),
cst.SimpleStatementSuite((cst.Pass(),)),
type_parameters=cst.TypeParameters(
params=(
cst.TypeParam(
param=cst.TypeVar(
cst.Name("T"),
bound=cst.Name("str"),
colon=cst.Colon(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.ParenthesizedWhitespace(
empty_lines=(cst.EmptyLine(),),
indent=True,
),
),
),
comma=cst.Comma(cst.SimpleWhitespace(" ")),
),
cst.TypeParam(
cst.ParamSpec(
cst.Name("PS"), cst.SimpleWhitespace(" ")
),
cst.Comma(cst.SimpleWhitespace(" ")),
),
)
),
whitespace_after_type_parameters=cst.SimpleWhitespace(" "),
),
"code": "class Foo[T :\n\nstr ,** PS ,] : pass\n",
},
{
"node": cst.ClassDef(
cst.Name("Foo"),
cst.SimpleStatementSuite((cst.Pass(),)),
type_parameters=cst.TypeParameters(
params=(
cst.TypeParam(
param=cst.TypeVar(
cst.Name("T"),
bound=cst.Name("str"),
colon=cst.Colon(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.ParenthesizedWhitespace(
empty_lines=(cst.EmptyLine(),),
indent=True,
),
),
),
comma=cst.Comma(cst.SimpleWhitespace(" ")),
),
cst.TypeParam(
cst.ParamSpec(
cst.Name("PS"), cst.SimpleWhitespace(" ")
),
cst.Comma(cst.SimpleWhitespace(" ")),
),
)
),
lpar=cst.LeftParen(),
rpar=cst.RightParen(),
whitespace_after_type_parameters=cst.SimpleWhitespace(" "),
),
"code": "class Foo[T :\n\nstr ,** PS ,] (): pass\n",
},
)
)
def test_valid_native(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)
@data_provider(
(
# Basic parenthesis tests.

View file

@ -8,6 +8,7 @@ from typing import Any
import libcst as cst
from libcst import parse_expression
from libcst._nodes.tests.base import CSTNodeTest, parse_expression_as
from libcst._parser.entrypoints import is_native
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider
@ -187,6 +188,6 @@ class DictTest(CSTNodeTest):
)
)
def test_versions(self, **kwargs: Any) -> None:
if not kwargs.get("expect_success", True):
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)

View file

@ -26,17 +26,6 @@ class DictCompTest(CSTNodeTest):
"parser": parse_expression,
"expected_position": CodeRange((1, 0), (1, 17)),
},
# non-trivial keys & values in DictComp
{
"node": cst.DictComp(
cst.BinaryOperation(cst.Name("k1"), cst.Add(), cst.Name("k2")),
cst.BinaryOperation(cst.Name("v1"), cst.Add(), cst.Name("v2")),
cst.CompFor(target=cst.Name("a"), iter=cst.Name("b")),
),
"code": "{k1 + k2: v1 + v2 for a in b}",
"parser": parse_expression,
"expected_position": CodeRange((1, 0), (1, 29)),
},
# custom whitespace around colon
{
"node": cst.DictComp(

View file

@ -8,6 +8,7 @@ from typing import Any, Callable
import libcst as cst
from libcst import parse_statement
from libcst._nodes.tests.base import CSTNodeTest, DummyIndentedBlock, parse_statement_as
from libcst._parser.entrypoints import is_native
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider
@ -622,46 +623,6 @@ class FunctionDefCreationTest(CSTNodeTest):
"code": "@ bar ( )\n",
"expected_position": CodeRange((1, 0), (1, 10)),
},
# Allow nested calls on decorator
{
"node": cst.FunctionDef(
cst.Name("foo"),
cst.Parameters(),
cst.SimpleStatementSuite((cst.Pass(),)),
(cst.Decorator(cst.Call(func=cst.Call(func=cst.Name("bar")))),),
),
"code": "@bar()()\ndef foo(): pass\n",
},
# Allow any expression in decorator
{
"node": cst.FunctionDef(
cst.Name("foo"),
cst.Parameters(),
cst.SimpleStatementSuite((cst.Pass(),)),
(
cst.Decorator(
cst.BinaryOperation(cst.Name("a"), cst.Add(), cst.Name("b"))
),
),
),
"code": "@a + b\ndef foo(): pass\n",
},
# Allow parentheses around decorator
{
"node": cst.FunctionDef(
cst.Name("foo"),
cst.Parameters(),
cst.SimpleStatementSuite((cst.Pass(),)),
(
cst.Decorator(
cst.Name(
"bar", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
)
),
),
),
"code": "@(bar)\ndef foo(): pass\n",
},
# Parameters
{
"node": cst.Parameters(
@ -740,154 +701,6 @@ class FunctionDefCreationTest(CSTNodeTest):
)
)
def test_valid(self, **kwargs: Any) -> None:
if "native_only" in kwargs:
kwargs.pop("native_only")
self.validate_node(**kwargs)
@data_provider(
(
# PEP 646
{
"node": cst.FunctionDef(
name=cst.Name(value="foo"),
params=cst.Parameters(
params=[],
star_arg=cst.Param(
star="*",
name=cst.Name("a"),
annotation=cst.Annotation(
cst.StarredElement(value=cst.Name("b")),
whitespace_before_indicator=cst.SimpleWhitespace(""),
),
),
),
body=cst.SimpleStatementSuite((cst.Pass(),)),
),
"parser": parse_statement,
"code": "def foo(*a: *b): pass\n",
},
{
"node": cst.FunctionDef(
name=cst.Name(value="foo"),
params=cst.Parameters(
params=[],
star_arg=cst.Param(
star="*",
name=cst.Name("a"),
annotation=cst.Annotation(
cst.StarredElement(
value=cst.Subscript(
value=cst.Name("tuple"),
slice=[
cst.SubscriptElement(
cst.Index(cst.Name("int")),
comma=cst.Comma(),
),
cst.SubscriptElement(
cst.Index(
value=cst.Name("Ts"),
star="*",
whitespace_after_star=cst.SimpleWhitespace(
""
),
),
comma=cst.Comma(),
),
cst.SubscriptElement(
cst.Index(cst.Ellipsis())
),
],
)
),
whitespace_before_indicator=cst.SimpleWhitespace(""),
),
),
),
body=cst.SimpleStatementSuite((cst.Pass(),)),
),
"parser": parse_statement,
"code": "def foo(*a: *tuple[int,*Ts,...]): pass\n",
},
# Single type variable
{
"node": cst.FunctionDef(
cst.Name("foo"),
cst.Parameters(),
cst.SimpleStatementSuite((cst.Pass(),)),
type_parameters=cst.TypeParameters(
(cst.TypeParam(cst.TypeVar(cst.Name("T"))),)
),
),
"code": "def foo[T](): pass\n",
"parser": parse_statement,
},
# All the type parameters
{
"node": cst.FunctionDef(
cst.Name("foo"),
cst.Parameters(),
cst.SimpleStatementSuite((cst.Pass(),)),
type_parameters=cst.TypeParameters(
(
cst.TypeParam(
cst.TypeVar(
cst.Name("T"),
bound=cst.Name("int"),
colon=cst.Colon(
whitespace_after=cst.SimpleWhitespace(" ")
),
),
cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
),
cst.TypeParam(
cst.TypeVarTuple(cst.Name("Ts")),
cst.Comma(whitespace_after=cst.SimpleWhitespace(" ")),
),
cst.TypeParam(cst.ParamSpec(cst.Name("KW"))),
)
),
),
"code": "def foo[T: int, *Ts, **KW](): pass\n",
"parser": parse_statement,
},
# Type parameters with whitespace
{
"node": cst.FunctionDef(
cst.Name("foo"),
cst.Parameters(),
cst.SimpleStatementSuite((cst.Pass(),)),
type_parameters=cst.TypeParameters(
params=(
cst.TypeParam(
param=cst.TypeVar(
cst.Name("T"),
bound=cst.Name("str"),
colon=cst.Colon(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.ParenthesizedWhitespace(
empty_lines=(cst.EmptyLine(),),
indent=True,
),
),
),
comma=cst.Comma(cst.SimpleWhitespace(" ")),
),
cst.TypeParam(
cst.ParamSpec(
cst.Name("PS"), cst.SimpleWhitespace(" ")
),
cst.Comma(cst.SimpleWhitespace(" ")),
),
)
),
whitespace_after_type_parameters=cst.SimpleWhitespace(" "),
),
"code": "def foo[T :\n\nstr ,** PS ,] (): pass\n",
"parser": parse_statement,
},
)
)
def test_valid_native(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)
@data_provider(
@ -1034,6 +847,22 @@ class FunctionDefCreationTest(CSTNodeTest):
),
r"Expecting a star prefix of '\*\*'",
),
# Validate decorator name semantics
(
lambda: cst.FunctionDef(
cst.Name("foo"),
cst.Parameters(),
cst.SimpleStatementSuite((cst.Pass(),)),
(
cst.Decorator(
cst.Name(
"bar", lpar=(cst.LeftParen(),), rpar=(cst.RightParen(),)
)
),
),
),
"Cannot have parens around decorator in a Decorator",
),
)
)
def test_invalid(
@ -1047,9 +876,7 @@ def _parse_statement_force_38(code: str) -> cst.BaseCompoundStatement:
code, config=cst.PartialParserConfig(python_version="3.8")
)
if not isinstance(statement, cst.BaseCompoundStatement):
raise ValueError(
"This function is expecting to parse compound statements only!"
)
raise Exception("This function is expecting to parse compound statements only!")
return statement
@ -1972,36 +1799,6 @@ class FunctionDefParserTest(CSTNodeTest):
),
"code": "def foo(bar, baz, /): pass\n",
},
# Positional only params with whitespace after but no comma
{
"node": cst.FunctionDef(
cst.Name("foo"),
cst.Parameters(
posonly_params=(
cst.Param(
cst.Name("bar"),
star="",
comma=cst.Comma(
whitespace_after=cst.SimpleWhitespace(" ")
),
),
cst.Param(
cst.Name("baz"),
star="",
comma=cst.Comma(
whitespace_after=cst.SimpleWhitespace(" ")
),
),
),
posonly_ind=cst.ParamSlash(
whitespace_after=cst.SimpleWhitespace(" ")
),
),
cst.SimpleStatementSuite((cst.Pass(),)),
),
"code": "def foo(bar, baz, / ): pass\n",
"native_only": True,
},
# Typed positional only params
{
"node": cst.FunctionDef(
@ -2217,7 +2014,7 @@ class FunctionDefParserTest(CSTNodeTest):
},
)
)
def test_valid_38(self, node: cst.CSTNode, code: str, **kwargs: Any) -> None:
def test_valid_38(self, node: cst.CSTNode, code: str) -> None:
self.validate_node(node, code, _parse_statement_force_38)
@data_provider(
@ -2245,23 +2042,6 @@ class FunctionDefParserTest(CSTNodeTest):
)
)
def test_versions(self, **kwargs: Any) -> None:
if not kwargs.get("expect_success", True):
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)
@data_provider(
(
{"code": "A[:*b]"},
{"code": "A[*b:]"},
{"code": "A[*b:*b]"},
{"code": "A[*(1:2)]"},
{"code": "A[*:]"},
{"code": "A[:*]"},
{"code": "A[**b]"},
{"code": "def f(x: *b): pass"},
{"code": "def f(**x: *b): pass"},
{"code": "x: *b"},
)
)
def test_parse_error(self, **kwargs: Any) -> None:
self.assert_parses(**kwargs, expect_success=False, parser=parse_statement)

View file

@ -3,7 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable
from typing import Any
import libcst as cst
from libcst import parse_statement
@ -129,21 +129,3 @@ class IfTest(CSTNodeTest):
)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)
@data_provider(
(
# Validate whitespace handling
(
lambda: cst.If(
cst.Name("conditional"),
cst.SimpleStatementSuite((cst.Pass(),)),
whitespace_before_test=cst.SimpleWhitespace(""),
),
"Must have at least one space after 'if' keyword.",
),
)
)
def test_invalid(
self, get_node: Callable[[], cst.CSTNode], expected_re: str
) -> None:
self.assert_invalid(get_node, expected_re)

View file

@ -52,41 +52,6 @@ class IfExpTest(CSTNodeTest):
"(foo)if(bar)else(baz)",
CodeRange((1, 0), (1, 21)),
),
(
cst.IfExp(
body=cst.Name("foo"),
whitespace_before_if=cst.SimpleWhitespace(" "),
whitespace_after_if=cst.SimpleWhitespace(" "),
test=cst.Name("bar"),
whitespace_before_else=cst.SimpleWhitespace(" "),
whitespace_after_else=cst.SimpleWhitespace(""),
orelse=cst.IfExp(
body=cst.SimpleString("''"),
whitespace_before_if=cst.SimpleWhitespace(""),
test=cst.Name("bar"),
orelse=cst.Name("baz"),
),
),
"foo if bar else''if bar else baz",
CodeRange((1, 0), (1, 32)),
),
(
cst.GeneratorExp(
elt=cst.IfExp(
body=cst.Name("foo"),
test=cst.Name("bar"),
orelse=cst.SimpleString("''"),
whitespace_after_else=cst.SimpleWhitespace(""),
),
for_in=cst.CompFor(
target=cst.Name("_"),
iter=cst.Name("_"),
whitespace_before=cst.SimpleWhitespace(""),
),
),
"(foo if bar else''for _ in _)",
CodeRange((1, 1), (1, 28)),
),
# Make sure that spacing works
(
cst.IfExp(

View file

@ -195,20 +195,6 @@ class ImportCreateTest(CSTNodeTest):
),
"expected_re": "at least one space",
},
{
"get_node": lambda: cst.Import(
names=(
cst.ImportAlias(
cst.Name("foo"),
asname=cst.AsName(
cst.Name("bar"),
whitespace_before_as=cst.SimpleWhitespace(""),
),
),
),
),
"expected_re": "at least one space",
},
{
"get_node": lambda: cst.Import(
names=[
@ -578,25 +564,6 @@ class ImportFromCreateTest(CSTNodeTest):
),
"expected_re": "one space after import",
},
{
"get_node": lambda: cst.ImportFrom(
module=cst.Name("foo"),
names=(
cst.ImportAlias(
cst.Name("bar"),
asname=cst.AsName(
cst.Name(
"baz",
lpar=(cst.LeftParen(),),
rpar=(cst.RightParen(),),
),
whitespace_before_as=cst.SimpleWhitespace(""),
),
),
),
),
"expected_re": "one space before as keyword",
},
)
)
def test_invalid(self, **kwargs: Any) -> None:

View file

@ -30,22 +30,6 @@ class LambdaCreationTest(CSTNodeTest):
),
"code": "lambda bar, baz, /: 5",
},
# Test basic positional only params with extra trailing whitespace
{
"node": cst.Lambda(
cst.Parameters(
posonly_params=(
cst.Param(cst.Name("bar")),
cst.Param(cst.Name("baz")),
),
posonly_ind=cst.ParamSlash(
whitespace_after=cst.SimpleWhitespace(" ")
),
),
cst.Integer("5"),
),
"code": "lambda bar, baz, / : 5",
},
# Test basic positional params
(
cst.Lambda(
@ -303,6 +287,30 @@ class LambdaCreationTest(CSTNodeTest):
),
"at least one space after lambda",
),
(
lambda: cst.Lambda(
cst.Parameters(star_arg=cst.Param(cst.Name("arg"))),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"at least one space after lambda",
),
(
lambda: cst.Lambda(
cst.Parameters(kwonly_params=(cst.Param(cst.Name("arg")),)),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"at least one space after lambda",
),
(
lambda: cst.Lambda(
cst.Parameters(star_kwarg=cst.Param(cst.Name("arg"))),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"at least one space after lambda",
),
(
lambda: cst.Lambda(
cst.Parameters(
@ -920,53 +928,6 @@ class LambdaParserTest(CSTNodeTest):
),
"( lambda : 5 )",
),
# No space between lambda and params
(
cst.Lambda(
cst.Parameters(star_arg=cst.Param(cst.Name("args"), star="*")),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"lambda*args: 5",
),
(
cst.Lambda(
cst.Parameters(star_kwarg=cst.Param(cst.Name("kwargs"), star="**")),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"lambda**kwargs: 5",
),
(
cst.Lambda(
cst.Parameters(
star_arg=cst.ParamStar(
comma=cst.Comma(
cst.SimpleWhitespace(""), cst.SimpleWhitespace("")
)
),
kwonly_params=[cst.Param(cst.Name("args"), star="")],
),
cst.Integer("5"),
whitespace_after_lambda=cst.SimpleWhitespace(""),
),
"lambda*,args: 5",
),
(
cst.ListComp(
elt=cst.Lambda(
params=cst.Parameters(),
body=cst.Tuple(()),
colon=cst.Colon(),
),
for_in=cst.CompFor(
target=cst.Name("_"),
iter=cst.Name("_"),
whitespace_before=cst.SimpleWhitespace(""),
),
),
"[lambda:()for _ in _]",
),
)
)
def test_valid(

View file

@ -8,11 +8,13 @@ from typing import Any, Callable
import libcst as cst
from libcst import parse_expression, parse_statement
from libcst._nodes.tests.base import CSTNodeTest, parse_expression_as
from libcst._parser.entrypoints import is_native
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider
class ListTest(CSTNodeTest):
# A lot of Element/StarredElement tests are provided by the tests for Tuple, so we
# we don't need to duplicate them here.
@data_provider(
@ -125,6 +127,6 @@ class ListTest(CSTNodeTest):
)
)
def test_versions(self, **kwargs: Any) -> None:
if not kwargs.get("expect_success", True):
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)

View file

@ -1,16 +1,19 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable
from typing import Any, Callable, Optional
import libcst as cst
from libcst import parse_statement
from libcst._nodes.tests.base import CSTNodeTest
from libcst._parser.entrypoints import is_native
from libcst.testing.utils import data_provider
parser: Callable[[str], cst.CSTNode] = parse_statement
parser: Optional[Callable[[str], cst.CSTNode]] = (
parse_statement if is_native() else None
)
class MatchTest(CSTNodeTest):
@ -36,39 +39,6 @@ class MatchTest(CSTNodeTest):
+ ' case "foo": pass\n',
"parser": parser,
},
# Parenthesized value
{
"node": cst.Match(
subject=cst.Name(
value="x",
),
cases=[
cst.MatchCase(
pattern=cst.MatchAs(
pattern=cst.MatchValue(
value=cst.Integer(
value="1",
lpar=[
cst.LeftParen(),
],
rpar=[
cst.RightParen(),
],
),
),
name=cst.Name(
value="z",
),
whitespace_before_as=cst.SimpleWhitespace(" "),
whitespace_after_as=cst.SimpleWhitespace(" "),
),
body=cst.SimpleStatementSuite([cst.Pass()]),
),
],
),
"code": "match x:\n case (1) as z: pass\n",
"parser": parser,
},
# List patterns
{
"node": cst.Match(
@ -455,34 +425,6 @@ class MatchTest(CSTNodeTest):
+ " case None | False | True: pass\n",
"parser": None,
},
# Match without whitespace between keyword and the expr
{
"node": cst.Match(
subject=cst.Name(
"x", lpar=[cst.LeftParen()], rpar=[cst.RightParen()]
),
cases=[
cst.MatchCase(
pattern=cst.MatchSingleton(
cst.Name(
"None",
lpar=[cst.LeftParen()],
rpar=[cst.RightParen()],
)
),
body=cst.SimpleStatementSuite((cst.Pass(),)),
whitespace_after_case=cst.SimpleWhitespace(
value="",
),
),
],
whitespace_after_match=cst.SimpleWhitespace(
value="",
),
),
"code": "match(x):\n case(None): pass\n",
"parser": parser,
},
)
)
def test_valid(self, **kwargs: Any) -> None:

View file

@ -11,6 +11,7 @@ from libcst._nodes.tests.base import (
parse_expression_as,
parse_statement_as,
)
from libcst._parser.entrypoints import is_native
from libcst.testing.utils import data_provider
@ -69,6 +70,6 @@ class NamedExprTest(CSTNodeTest):
)
)
def test_versions(self, **kwargs: Any) -> None:
if not kwargs.get("expect_success", True):
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)

View file

@ -8,7 +8,7 @@ from typing import cast, Tuple
import libcst as cst
from libcst import parse_module, parse_statement
from libcst._nodes.tests.base import CSTNodeTest
from libcst._parser.entrypoints import is_native
from libcst.metadata import CodeRange, MetadataWrapper, PositionProvider
from libcst.testing.utils import data_provider
@ -117,7 +117,7 @@ class ModuleTest(CSTNodeTest):
def test_parser(
self, *, code: str, expected: cst.Module, enabled_for_native: bool = True
) -> None:
if not enabled_for_native:
if is_native() and not enabled_for_native:
self.skipTest("Disabled for native parser")
self.assertEqual(parse_module(code), expected)

View file

@ -22,9 +22,7 @@ def _parse_statement_force_38(code: str) -> cst.BaseCompoundStatement:
code, config=cst.PartialParserConfig(python_version="3.8")
)
if not isinstance(statement, cst.BaseCompoundStatement):
raise ValueError(
"This function is expecting to parse compound statements only!"
)
raise Exception("This function is expecting to parse compound statements only!")
return statement
@ -168,22 +166,6 @@ class NamedExprTest(CSTNodeTest):
"parser": _parse_expression_force_38,
"expected_position": None,
},
{
"node": cst.ListComp(
elt=cst.NamedExpr(
cst.Name("_"),
cst.SimpleString("''"),
whitespace_after_walrus=cst.SimpleWhitespace(""),
whitespace_before_walrus=cst.SimpleWhitespace(""),
),
for_in=cst.CompFor(
target=cst.Name("_"),
iter=cst.Name("_"),
whitespace_before=cst.SimpleWhitespace(""),
),
),
"code": "[_:=''for _ in _]",
},
)
)
def test_valid(self, **kwargs: Any) -> None:

View file

@ -95,7 +95,7 @@ class RemovalBehavior(CSTNodeTest):
self, before: str, after: str, visitor: Type[CSTTransformer]
) -> None:
if before.endswith("\n") or after.endswith("\n"):
raise ValueError("Test cases should not be newline-terminated!")
raise Exception("Test cases should not be newline-terminated!")
# Test doesn't have newline termination case
before_module = parse_module(before)

View file

@ -8,10 +8,12 @@ from typing import Any, Callable
import libcst as cst
from libcst import parse_expression
from libcst._nodes.tests.base import CSTNodeTest, parse_expression_as
from libcst._parser.entrypoints import is_native
from libcst.testing.utils import data_provider
class ListTest(CSTNodeTest):
# A lot of Element/StarredElement tests are provided by the tests for Tuple, so we
# we don't need to duplicate them here.
@data_provider(
@ -132,6 +134,6 @@ class ListTest(CSTNodeTest):
)
)
def test_versions(self, **kwargs: Any) -> None:
if not kwargs.get("expect_success", True):
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)

View file

@ -41,33 +41,6 @@ class SimpleCompTest(CSTNodeTest):
"code": "{a for b in c}",
"parser": parse_expression,
},
# non-trivial elt in GeneratorExp
{
"node": cst.GeneratorExp(
cst.BinaryOperation(cst.Name("a1"), cst.Add(), cst.Name("a2")),
cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
),
"code": "(a1 + a2 for b in c)",
"parser": parse_expression,
},
# non-trivial elt in ListComp
{
"node": cst.ListComp(
cst.BinaryOperation(cst.Name("a1"), cst.Add(), cst.Name("a2")),
cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
),
"code": "[a1 + a2 for b in c]",
"parser": parse_expression,
},
# non-trivial elt in SetComp
{
"node": cst.SetComp(
cst.BinaryOperation(cst.Name("a1"), cst.Add(), cst.Name("a2")),
cst.CompFor(target=cst.Name("b"), iter=cst.Name("c")),
),
"code": "{a1 + a2 for b in c}",
"parser": parse_expression,
},
# async GeneratorExp
{
"node": cst.GeneratorExp(

View file

@ -1,31 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import unittest
import libcst as cst
class TestSimpleString(unittest.TestCase):
def test_quote(self) -> None:
test_cases = [
('"a"', '"'),
("'b'", "'"),
('""', '"'),
("''", "'"),
('"""c"""', '"""'),
("'''d'''", "'''"),
('""""e"""', '"""'),
("''''f'''", "'''"),
('"""""g"""', '"""'),
("'''''h'''", "'''"),
('""""""', '"""'),
("''''''", "'''"),
]
for s, expected_quote in test_cases:
simple_string = cst.SimpleString(s)
actual = simple_string.quote
self.assertEqual(expected_quote, actual)

View file

@ -1,183 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Callable, Optional
import libcst as cst
from libcst import parse_expression
from libcst._nodes.tests.base import CSTNodeTest
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider
class TemplatedStringTest(CSTNodeTest):
@data_provider(
(
# Simple t-string with only text
(
cst.TemplatedString(
parts=(cst.TemplatedStringText("hello world"),),
),
't"hello world"',
True,
),
# t-string with one expression
(
cst.TemplatedString(
parts=(
cst.TemplatedStringText("hello "),
cst.TemplatedStringExpression(
expression=cst.Name("name"),
),
),
),
't"hello {name}"',
True,
),
# t-string with multiple expressions
(
cst.TemplatedString(
parts=(
cst.TemplatedStringText("a="),
cst.TemplatedStringExpression(expression=cst.Name("a")),
cst.TemplatedStringText(", b="),
cst.TemplatedStringExpression(expression=cst.Name("b")),
),
),
't"a={a}, b={b}"',
True,
CodeRange((1, 0), (1, 15)),
),
# t-string with nested expression
(
cst.TemplatedString(
parts=(
cst.TemplatedStringText("sum="),
cst.TemplatedStringExpression(
expression=cst.BinaryOperation(
left=cst.Name("a"),
operator=cst.Add(),
right=cst.Name("b"),
)
),
),
),
't"sum={a + b}"',
True,
),
# t-string with spacing in expression
(
cst.TemplatedString(
parts=(
cst.TemplatedStringText("x = "),
cst.TemplatedStringExpression(
whitespace_before_expression=cst.SimpleWhitespace(" "),
expression=cst.Name("x"),
whitespace_after_expression=cst.SimpleWhitespace(" "),
),
),
),
't"x = { x }"',
True,
),
# t-string with escaped braces
(
cst.TemplatedString(
parts=(cst.TemplatedStringText("{{foo}}"),),
),
't"{{foo}}"',
True,
),
# t-string with only an expression
(
cst.TemplatedString(
parts=(
cst.TemplatedStringExpression(expression=cst.Name("value")),
),
),
't"{value}"',
True,
),
# t-string with whitespace and newlines
(
cst.TemplatedString(
parts=(
cst.TemplatedStringText("line1\\n"),
cst.TemplatedStringExpression(expression=cst.Name("x")),
cst.TemplatedStringText("\\nline2"),
),
),
't"line1\\n{x}\\nline2"',
True,
),
# t-string with parenthesis (not typical, but test node construction)
(
cst.TemplatedString(
lpar=(cst.LeftParen(),),
parts=(cst.TemplatedStringText("foo"),),
rpar=(cst.RightParen(),),
),
'(t"foo")',
True,
),
# t-string with whitespace in delimiters
(
cst.TemplatedString(
lpar=(cst.LeftParen(whitespace_after=cst.SimpleWhitespace(" ")),),
parts=(cst.TemplatedStringText("foo"),),
rpar=(cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),),
),
'( t"foo" )',
True,
),
# Test TemplatedStringText and TemplatedStringExpression individually
(
cst.TemplatedStringText("abc"),
"abc",
False,
CodeRange((1, 0), (1, 3)),
),
(
cst.TemplatedStringExpression(expression=cst.Name("foo")),
"{foo}",
False,
CodeRange((1, 0), (1, 5)),
),
)
)
def test_valid(
self,
node: cst.CSTNode,
code: str,
check_parsing: bool,
position: Optional[CodeRange] = None,
) -> None:
if check_parsing:
self.validate_node(node, code, parse_expression, expected_position=position)
else:
self.validate_node(node, code, expected_position=position)
@data_provider(
(
(
lambda: cst.TemplatedString(
parts=(cst.TemplatedStringText("foo"),),
lpar=(cst.LeftParen(),),
),
"left paren without right paren",
),
(
lambda: cst.TemplatedString(
parts=(cst.TemplatedStringText("foo"),),
rpar=(cst.RightParen(),),
),
"right paren without left paren",
),
)
)
def test_invalid(
self, get_node: Callable[[], cst.CSTNode], expected_re: str
) -> None:
self.assert_invalid(get_node, expected_re)

View file

@ -3,15 +3,18 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable
from typing import Any, Callable, Optional
import libcst as cst
from libcst import parse_statement
from libcst._nodes.tests.base import CSTNodeTest, DummyIndentedBlock
from libcst._parser.entrypoints import is_native
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider
native_parse_statement: Callable[[str], cst.CSTNode] = parse_statement
native_parse_statement: Optional[Callable[[str], cst.CSTNode]] = (
parse_statement if is_native() else None
)
class TryTest(CSTNodeTest):
@ -326,52 +329,6 @@ class TryTest(CSTNodeTest):
"code": "try: pass\nexcept(IOError, ImportError): pass\n",
"parser": parse_statement,
},
# No space before as
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=[
cst.ExceptHandler(
cst.SimpleStatementSuite((cst.Pass(),)),
whitespace_after_except=cst.SimpleWhitespace(" "),
type=cst.Call(cst.Name("foo")),
name=cst.AsName(
whitespace_before_as=cst.SimpleWhitespace(""),
name=cst.Name("bar"),
),
)
],
),
"code": "try: pass\nexcept foo()as bar: pass\n",
},
# PEP758 - Multiple exceptions with no parentheses
{
"node": cst.Try(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=[
cst.ExceptHandler(
cst.SimpleStatementSuite((cst.Pass(),)),
type=cst.Tuple(
elements=[
cst.Element(
value=cst.Name(
value="ValueError",
),
),
cst.Element(
value=cst.Name(
value="RuntimeError",
),
),
],
lpar=[],
rpar=[],
),
)
],
),
"code": "try: pass\nexcept ValueError, RuntimeError: pass\n",
},
)
)
def test_valid(self, **kwargs: Any) -> None:
@ -389,6 +346,12 @@ class TryTest(CSTNodeTest):
),
"expected_re": "between 'as'",
},
{
"get_node": lambda: cst.AsName(
cst.Name("bla"), whitespace_before_as=cst.SimpleWhitespace("")
),
"expected_re": "before 'as'",
},
{
"get_node": lambda: cst.ExceptHandler(
cst.SimpleStatementSuite((cst.Pass(),)),
@ -604,38 +567,6 @@ class TryStarTest(CSTNodeTest):
"parser": native_parse_statement,
"expected_position": CodeRange((1, 0), (5, 13)),
},
# PEP758 - Multiple exceptions with no parentheses
{
"node": cst.TryStar(
cst.SimpleStatementSuite((cst.Pass(),)),
handlers=[
cst.ExceptStarHandler(
cst.SimpleStatementSuite((cst.Pass(),)),
type=cst.Tuple(
elements=[
cst.Element(
value=cst.Name(
value="ValueError",
),
comma=cst.Comma(
whitespace_after=cst.SimpleWhitespace(" ")
),
),
cst.Element(
value=cst.Name(
value="RuntimeError",
),
),
],
lpar=[],
rpar=[],
),
)
],
),
"code": "try: pass\nexcept* ValueError, RuntimeError: pass\n",
"parser": native_parse_statement,
},
)
)
def test_valid(self, **kwargs: Any) -> None:

View file

@ -8,6 +8,7 @@ from typing import Any, Callable
import libcst as cst
from libcst import parse_expression, parse_statement
from libcst._nodes.tests.base import CSTNodeTest, parse_expression_as
from libcst._parser.entrypoints import is_native
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider
@ -90,47 +91,6 @@ class TupleTest(CSTNodeTest):
"parser": parse_expression,
"expected_position": CodeRange((1, 1), (1, 11)),
},
# top-level two-element tuple, with one being starred
{
"node": cst.SimpleStatementLine(
body=[
cst.Expr(
value=cst.Tuple(
[
cst.Element(cst.Name("one"), comma=cst.Comma()),
cst.StarredElement(cst.Name("two")),
],
lpar=[],
rpar=[],
)
)
]
),
"code": "one,*two\n",
"parser": parse_statement,
},
# top-level three-element tuple, start/end is starred
{
"node": cst.SimpleStatementLine(
body=[
cst.Expr(
value=cst.Tuple(
[
cst.StarredElement(
cst.Name("one"), comma=cst.Comma()
),
cst.Element(cst.Name("two"), comma=cst.Comma()),
cst.StarredElement(cst.Name("three")),
],
lpar=[],
rpar=[],
)
)
]
),
"code": "*one,two,*three\n",
"parser": parse_statement,
},
# missing spaces around tuple, okay with parenthesis
{
"node": cst.For(
@ -285,6 +245,6 @@ class TupleTest(CSTNodeTest):
)
)
def test_versions(self, **kwargs: Any) -> None:
if not kwargs.get("expect_success", True):
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)

View file

@ -1,252 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any
import libcst as cst
from libcst import parse_statement
from libcst._nodes.tests.base import CSTNodeTest
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider
class TypeAliasCreationTest(CSTNodeTest):
@data_provider(
(
{
"node": cst.TypeAlias(
cst.Name("foo"),
cst.Name("bar"),
),
"code": "type foo = bar",
"expected_position": CodeRange((1, 0), (1, 14)),
},
{
"node": cst.TypeAlias(
cst.Name("foo"),
type_parameters=cst.TypeParameters(
[cst.TypeParam(cst.TypeVar(cst.Name("T")))]
),
value=cst.BinaryOperation(
cst.Name("bar"), cst.BitOr(), cst.Name("baz")
),
),
"code": "type foo[T] = bar | baz",
"expected_position": CodeRange((1, 0), (1, 23)),
},
{
"node": cst.TypeAlias(
cst.Name("foo"),
type_parameters=cst.TypeParameters(
[
cst.TypeParam(
cst.TypeVar(cst.Name("T"), bound=cst.Name("str"))
),
cst.TypeParam(cst.TypeVarTuple(cst.Name("Ts"))),
cst.TypeParam(cst.ParamSpec(cst.Name("KW"))),
]
),
value=cst.BinaryOperation(
cst.Name("bar"), cst.BitOr(), cst.Name("baz")
),
),
"code": "type foo[T: str, *Ts, **KW] = bar | baz",
"expected_position": CodeRange((1, 0), (1, 39)),
},
{
"node": cst.TypeAlias(
cst.Name("foo"),
type_parameters=cst.TypeParameters(
[
cst.TypeParam(
cst.TypeVar(cst.Name("T")), default=cst.Name("str")
),
]
),
value=cst.Name("bar"),
),
"code": "type foo[T = str] = bar",
"expected_position": CodeRange((1, 0), (1, 23)),
},
{
"node": cst.TypeAlias(
cst.Name("foo"),
type_parameters=cst.TypeParameters(
[
cst.TypeParam(
cst.ParamSpec(cst.Name("P")),
default=cst.List(
elements=[
cst.Element(cst.Name("int")),
cst.Element(cst.Name("str")),
]
),
),
]
),
value=cst.Name("bar"),
),
"code": "type foo[**P = [int, str]] = bar",
"expected_position": CodeRange((1, 0), (1, 32)),
},
{
"node": cst.TypeAlias(
cst.Name("foo"),
type_parameters=cst.TypeParameters(
[
cst.TypeParam(
cst.TypeVarTuple(cst.Name("T")),
equal=cst.AssignEqual(),
default=cst.Name("default"),
star="*",
),
]
),
value=cst.Name("bar"),
),
"code": "type foo[*T = *default] = bar",
"expected_position": CodeRange((1, 0), (1, 29)),
},
{
"node": cst.TypeAlias(
cst.Name("foo"),
type_parameters=cst.TypeParameters(
[
cst.TypeParam(
cst.TypeVarTuple(cst.Name("T")),
equal=cst.AssignEqual(),
default=cst.Name("default"),
star="*",
whitespace_after_star=cst.SimpleWhitespace(" "),
),
]
),
value=cst.Name("bar"),
),
"code": "type foo[*T = * default] = bar",
"expected_position": CodeRange((1, 0), (1, 31)),
},
)
)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)
class TypeAliasParserTest(CSTNodeTest):
@data_provider(
(
{
"node": cst.SimpleStatementLine(
[
cst.TypeAlias(
cst.Name("foo"),
cst.Name("bar"),
whitespace_after_name=cst.SimpleWhitespace(" "),
)
]
),
"code": "type foo = bar\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
[
cst.TypeAlias(
cst.Name("foo"),
cst.Name("bar"),
type_parameters=cst.TypeParameters(
params=[
cst.TypeParam(
cst.TypeVar(
cst.Name("T"), cst.Name("str"), cst.Colon()
),
cst.Comma(),
),
cst.TypeParam(
cst.ParamSpec(
cst.Name("KW"),
whitespace_after_star=cst.SimpleWhitespace(
" "
),
),
cst.Comma(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.SimpleWhitespace(" "),
),
),
],
rbracket=cst.RightSquareBracket(
cst.SimpleWhitespace("")
),
),
whitespace_after_name=cst.SimpleWhitespace(" "),
whitespace_after_type=cst.SimpleWhitespace(" "),
whitespace_after_equals=cst.SimpleWhitespace(" "),
whitespace_after_type_parameters=cst.SimpleWhitespace(" "),
semicolon=cst.Semicolon(
whitespace_before=cst.SimpleWhitespace(" "),
whitespace_after=cst.SimpleWhitespace(" "),
),
)
]
),
"code": "type foo [T:str,** KW , ] = bar ; \n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
[
cst.TypeAlias(
cst.Name("foo"),
type_parameters=cst.TypeParameters(
[
cst.TypeParam(
cst.TypeVarTuple(cst.Name("P")),
star="*",
equal=cst.AssignEqual(),
default=cst.Name("default"),
),
]
),
value=cst.Name("bar"),
whitespace_after_name=cst.SimpleWhitespace(" "),
whitespace_after_type_parameters=cst.SimpleWhitespace(" "),
)
]
),
"code": "type foo [*P = *default] = bar\n",
"parser": parse_statement,
},
{
"node": cst.SimpleStatementLine(
[
cst.TypeAlias(
cst.Name("foo"),
type_parameters=cst.TypeParameters(
[
cst.TypeParam(
cst.TypeVarTuple(cst.Name("P")),
star="*",
whitespace_after_star=cst.SimpleWhitespace(
" "
),
equal=cst.AssignEqual(),
default=cst.Name("default"),
),
]
),
value=cst.Name("bar"),
whitespace_after_name=cst.SimpleWhitespace(" "),
whitespace_after_type_parameters=cst.SimpleWhitespace(" "),
)
]
),
"code": "type foo [*P = * default] = bar\n",
"parser": parse_statement,
},
)
)
def test_valid(self, **kwargs: Any) -> None:
self.validate_node(**kwargs)

View file

@ -7,7 +7,9 @@ from typing import Any
import libcst as cst
from libcst import parse_statement, PartialParserConfig
from libcst._maybe_sentinel import MaybeSentinel
from libcst._nodes.tests.base import CSTNodeTest, DummyIndentedBlock, parse_statement_as
from libcst._parser.entrypoints import is_native
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider
@ -100,23 +102,6 @@ class WithTest(CSTNodeTest):
"code": "with context_mgr() as ctx: pass\n",
"parser": parse_statement,
},
{
"node": cst.With(
(
cst.WithItem(
cst.Call(cst.Name("context_mgr")),
cst.AsName(
cst.Tuple(()),
whitespace_after_as=cst.SimpleWhitespace(""),
whitespace_before_as=cst.SimpleWhitespace(""),
),
),
),
cst.SimpleStatementSuite((cst.Pass(),)),
),
"code": "with context_mgr()as(): pass\n",
"parser": parse_statement,
},
# indentation
{
"node": DummyIndentedBlock(
@ -185,14 +170,14 @@ class WithTest(CSTNodeTest):
cst.WithItem(
cst.Call(
cst.Name("context_mgr"),
lpar=(),
rpar=(),
lpar=() if is_native() else (cst.LeftParen(),),
rpar=() if is_native() else (cst.RightParen(),),
)
),
),
cst.SimpleStatementSuite((cst.Pass(),)),
lpar=(cst.LeftParen()),
rpar=(cst.RightParen()),
lpar=(cst.LeftParen() if is_native() else MaybeSentinel.DEFAULT),
rpar=(cst.RightParen() if is_native() else MaybeSentinel.DEFAULT),
whitespace_after_with=cst.SimpleWhitespace(""),
),
"code": "with(context_mgr()): pass\n",
@ -231,7 +216,7 @@ class WithTest(CSTNodeTest):
rpar=cst.RightParen(whitespace_before=cst.SimpleWhitespace(" ")),
),
"code": ("with ( foo(),\n" " bar(), ): pass\n"), # noqa
"parser": parse_statement,
"parser": parse_statement if is_native() else None,
"expected_position": CodeRange((1, 0), (2, 21)),
},
)
@ -308,7 +293,7 @@ class WithTest(CSTNodeTest):
)
)
def test_versions(self, **kwargs: Any) -> None:
if not kwargs.get("expect_success", True):
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)

View file

@ -8,6 +8,7 @@ from typing import Any, Callable, Optional
import libcst as cst
from libcst import parse_statement
from libcst._nodes.tests.base import CSTNodeTest, parse_statement_as
from libcst._parser.entrypoints import is_native
from libcst.helpers import ensure_type
from libcst.metadata import CodeRange
from libcst.testing.utils import data_provider
@ -240,6 +241,6 @@ class YieldParsingTest(CSTNodeTest):
)
)
def test_versions(self, **kwargs: Any) -> None:
if not kwargs.get("expect_success", True):
if is_native() and not kwargs.get("expect_success", True):
self.skipTest("parse errors are disabled for native parser")
self.assert_parses(**kwargs)

View file

@ -48,8 +48,6 @@ class BaseParenthesizableWhitespace(CSTNode, ABC):
``iftest``), it has some semantic value.
"""
__slots__ = ()
# TODO: Should we somehow differentiate places where we require non-zero whitespace
# with a separate type?

View file

@ -1,53 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Iterable, Union
from libcst._exceptions import EOFSentinel
from libcst._parser.parso.pgen2.generator import ReservedString
from libcst._parser.parso.python.token import PythonTokenTypes, TokenType
from libcst._parser.types.token import Token
_EOF_STR: str = "end of file (EOF)"
_INDENT_STR: str = "an indent"
_DEDENT_STR: str = "a dedent"
def get_expected_str(
encountered: Union[Token, EOFSentinel],
expected: Union[Iterable[Union[TokenType, ReservedString]], EOFSentinel],
) -> str:
if (
isinstance(encountered, EOFSentinel)
or encountered.type is PythonTokenTypes.ENDMARKER
):
encountered_str = _EOF_STR
elif encountered.type is PythonTokenTypes.INDENT:
encountered_str = _INDENT_STR
elif encountered.type is PythonTokenTypes.DEDENT:
encountered_str = _DEDENT_STR
else:
encountered_str = repr(encountered.string)
if isinstance(expected, EOFSentinel):
expected_names = [_EOF_STR]
else:
expected_names = sorted(
[
repr(el.name) if isinstance(el, TokenType) else repr(el.value)
for el in expected
]
)
if len(expected_names) > 10:
# There's too many possibilities, so it's probably not useful to list them.
# Instead, let's just abbreviate the message.
return f"Unexpectedly encountered {encountered_str}."
else:
if len(expected_names) == 1:
expected_str = expected_names[0]
else:
expected_str = f"{', '.join(expected_names[:-1])}, or {expected_names[-1]}"
return f"Encountered {encountered_str}, but expected {expected_str}."

View file

@ -26,8 +26,12 @@
from dataclasses import dataclass, field
from typing import Generic, Iterable, List, Sequence, TypeVar, Union
from libcst._exceptions import EOFSentinel, ParserSyntaxError, PartialParserSyntaxError
from libcst._parser._parsing_check import get_expected_str
from libcst._exceptions import (
EOFSentinel,
get_expected_str,
ParserSyntaxError,
PartialParserSyntaxError,
)
from libcst._parser.parso.pgen2.generator import DFAState, Grammar, ReservedString
from libcst._parser.parso.python.token import TokenType
from libcst._parser.types.token import Token
@ -99,7 +103,7 @@ class BaseParser(Generic[_TokenT, _TokenTypeT, _NodeT]):
def parse(self) -> _NodeT:
# Ensure that we don't re-use parsers.
if self.__was_parse_called:
raise ValueError("Each parser object may only be used to parse once.")
raise Exception("Each parser object may only be used to parse once.")
self.__was_parse_called = True
for token in self.tokens:
@ -125,9 +129,11 @@ class BaseParser(Generic[_TokenT, _TokenTypeT, _NodeT]):
def convert_nonterminal(
self, nonterminal: str, children: Sequence[_NodeT]
) -> _NodeT: ...
) -> _NodeT:
...
def convert_terminal(self, token: _TokenT) -> _NodeT: ...
def convert_terminal(self, token: _TokenT) -> _NodeT:
...
def _add_token(self, token: _TokenT) -> None:
"""

View file

@ -12,8 +12,7 @@ from tokenize import (
Intnumber as INTNUMBER_RE,
)
from libcst import CSTLogicError
from libcst._exceptions import ParserSyntaxError, PartialParserSyntaxError
from libcst._exceptions import PartialParserSyntaxError
from libcst._maybe_sentinel import MaybeSentinel
from libcst._nodes.expression import (
Arg,
@ -328,12 +327,7 @@ def convert_boolop(
# Convert all of the operations that have no precedence in a loop
for op, rightexpr in grouper(rightexprs, 2):
if op.string not in BOOLOP_TOKEN_LUT:
raise ParserSyntaxError(
f"Unexpected token '{op.string}'!",
lines=config.lines,
raw_line=0,
raw_column=0,
)
raise Exception(f"Unexpected token '{op.string}'!")
leftexpr = BooleanOperation(
left=leftexpr,
# pyre-ignore Pyre thinks that the type of the LUT is CSTNode.
@ -426,12 +420,7 @@ def convert_comp_op(
)
else:
# this should be unreachable
raise ParserSyntaxError(
f"Unexpected token '{op.string}'!",
lines=config.lines,
raw_line=0,
raw_column=0,
)
raise Exception(f"Unexpected token '{op.string}'!")
else:
# A two-token comparison
leftcomp, rightcomp = children
@ -462,12 +451,7 @@ def convert_comp_op(
)
else:
# this should be unreachable
raise ParserSyntaxError(
f"Unexpected token '{leftcomp.string} {rightcomp.string}'!",
lines=config.lines,
raw_line=0,
raw_column=0,
)
raise Exception(f"Unexpected token '{leftcomp.string} {rightcomp.string}'!")
@with_production("star_expr", "'*' expr")
@ -509,12 +493,7 @@ def convert_binop(
# Convert all of the operations that have no precedence in a loop
for op, rightexpr in grouper(rightexprs, 2):
if op.string not in BINOP_TOKEN_LUT:
raise ParserSyntaxError(
f"Unexpected token '{op.string}'!",
lines=config.lines,
raw_line=0,
raw_column=0,
)
raise Exception(f"Unexpected token '{op.string}'!")
leftexpr = BinaryOperation(
left=leftexpr,
# pyre-ignore Pyre thinks that the type of the LUT is CSTNode.
@ -561,12 +540,7 @@ def convert_factor(
)
)
else:
raise ParserSyntaxError(
f"Unexpected token '{op.string}'!",
lines=config.lines,
raw_line=0,
raw_column=0,
)
raise Exception(f"Unexpected token '{op.string}'!")
return WithLeadingWhitespace(
UnaryOperation(operator=opnode, expression=factor.value), op.whitespace_before
@ -677,7 +651,7 @@ def convert_atom_expr_trailer(
)
else:
# This is an invalid trailer, so lets give up
raise CSTLogicError()
raise Exception("Logic error!")
return WithLeadingWhitespace(atom, whitespace_before)
@ -896,19 +870,9 @@ def convert_atom_basic(
Imaginary(child.string), child.whitespace_before
)
else:
raise ParserSyntaxError(
f"Unparseable number {child.string}",
lines=config.lines,
raw_line=0,
raw_column=0,
)
raise Exception("Unparseable number {child.string}")
else:
raise ParserSyntaxError(
f"Logic error, unexpected token {child.type.name}",
lines=config.lines,
raw_line=0,
raw_column=0,
)
raise Exception(f"Logic error, unexpected token {child.type.name}")
@with_production("atom_squarebrackets", "'[' [testlist_comp_list] ']'")
@ -1483,7 +1447,7 @@ def convert_arg_assign_comp_for(
if equal.string == ":=":
val = convert_namedexpr_test(config, children)
if not isinstance(val, WithLeadingWhitespace):
raise TypeError(
raise Exception(
f"convert_namedexpr_test returned {val!r}, not WithLeadingWhitespace"
)
return Arg(value=val.value)

View file

@ -6,7 +6,6 @@
from typing import Any, List, Optional, Sequence, Union
from libcst import CSTLogicError
from libcst._exceptions import PartialParserSyntaxError
from libcst._maybe_sentinel import MaybeSentinel
from libcst._nodes.expression import (
@ -122,7 +121,7 @@ def convert_argslist( # noqa: C901
# Example code:
# def fn(*abc, *): ...
# This should be unreachable, the grammar already disallows it.
raise ValueError(
raise Exception(
"Cannot have multiple star ('*') markers in a single argument "
+ "list."
)
@ -137,7 +136,7 @@ def convert_argslist( # noqa: C901
# Example code:
# def fn(foo, /, *, /, bar): ...
# This should be unreachable, the grammar already disallows it.
raise ValueError(
raise Exception(
"Cannot have multiple slash ('/') markers in a single argument "
+ "list."
)
@ -169,7 +168,7 @@ def convert_argslist( # noqa: C901
# Example code:
# def fn(**kwargs, trailing=None)
# This should be unreachable, the grammar already disallows it.
raise ValueError("Cannot have any arguments after a kwargs expansion.")
raise Exception("Cannot have any arguments after a kwargs expansion.")
elif (
isinstance(param.star, str) and param.star == "*" and param.default is None
):
@ -182,7 +181,7 @@ def convert_argslist( # noqa: C901
# Example code:
# def fn(*first, *second): ...
# This should be unreachable, the grammar already disallows it.
raise ValueError(
raise Exception(
"Expected a keyword argument but found a starred positional "
+ "argument expansion."
)
@ -198,13 +197,13 @@ def convert_argslist( # noqa: C901
# Example code:
# def fn(**first, **second)
# This should be unreachable, the grammar already disallows it.
raise ValueError(
raise Exception(
"Multiple starred keyword argument expansions are not allowed in a "
+ "single argument list"
)
else:
# The state machine should never end up here.
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
return current_param

View file

@ -6,8 +6,7 @@
from typing import Any, Dict, List, Optional, Sequence, Tuple, Type
from libcst import CSTLogicError
from libcst._exceptions import ParserSyntaxError, PartialParserSyntaxError
from libcst._exceptions import PartialParserSyntaxError
from libcst._maybe_sentinel import MaybeSentinel
from libcst._nodes.expression import (
Annotation,
@ -284,9 +283,7 @@ def convert_annassign(config: ParserConfig, children: Sequence[Any]) -> Any:
whitespace_after=parse_simple_whitespace(config, equal.whitespace_after),
)
else:
raise ParserSyntaxError(
"Invalid parser state!", lines=config.lines, raw_line=0, raw_column=0
)
raise Exception("Invalid parser state!")
return AnnAssignPartial(
annotation=Annotation(
@ -322,13 +319,7 @@ def convert_annassign(config: ParserConfig, children: Sequence[Any]) -> Any:
def convert_augassign(config: ParserConfig, children: Sequence[Any]) -> Any:
op, expr = children
if op.string not in AUGOP_TOKEN_LUT:
raise ParserSyntaxError(
f"Unexpected token '{op.string}'!",
lines=config.lines,
raw_line=0,
raw_column=0,
)
raise Exception(f"Unexpected token '{op.string}'!")
return AugAssignPartial(
# pyre-ignore Pyre seems to think that the value of this LUT is CSTNode
operator=AUGOP_TOKEN_LUT[op.string](
@ -456,7 +447,7 @@ def convert_import_relative(config: ParserConfig, children: Sequence[Any]) -> An
# This should be the dotted name, and we can't get more than
# one, but lets be sure anyway
if dotted_name is not None:
raise CSTLogicError()
raise Exception("Logic error!")
dotted_name = child
return ImportRelativePartial(relative=tuple(dots), module=dotted_name)
@ -653,7 +644,7 @@ def convert_raise_stmt(config: ParserConfig, children: Sequence[Any]) -> Any:
item=source.value,
)
else:
raise CSTLogicError()
raise Exception("Logic error!")
return WithLeadingWhitespace(
Raise(whitespace_after_raise=whitespace_after_raise, exc=exc, cause=cause),
@ -902,7 +893,7 @@ def convert_try_stmt(config: ParserConfig, children: Sequence[Any]) -> Any:
if isinstance(clause, Token):
if clause.string == "else":
if orelse is not None:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
orelse = Else(
leading_lines=parse_empty_lines(config, clause.whitespace_before),
whitespace_before_colon=parse_simple_whitespace(
@ -912,7 +903,7 @@ def convert_try_stmt(config: ParserConfig, children: Sequence[Any]) -> Any:
)
elif clause.string == "finally":
if finalbody is not None:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
finalbody = Finally(
leading_lines=parse_empty_lines(config, clause.whitespace_before),
whitespace_before_colon=parse_simple_whitespace(
@ -921,7 +912,7 @@ def convert_try_stmt(config: ParserConfig, children: Sequence[Any]) -> Any:
body=suite,
)
else:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
elif isinstance(clause, ExceptClausePartial):
handlers.append(
ExceptHandler(
@ -936,7 +927,7 @@ def convert_try_stmt(config: ParserConfig, children: Sequence[Any]) -> Any:
)
)
else:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
return Try(
leading_lines=parse_empty_lines(config, trytoken.whitespace_before),
@ -1342,7 +1333,7 @@ def convert_asyncable_stmt(config: ParserConfig, children: Sequence[Any]) -> Any
asynchronous=asyncnode, leading_lines=leading_lines
)
else:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
@with_production("suite", "simple_stmt_suite | indented_suite")

View file

@ -9,6 +9,7 @@ parser. A parser entrypoint should take the source code and some configuration
information
"""
import os
from functools import partial
from typing import Union
@ -16,12 +17,19 @@ from libcst._nodes.base import CSTNode
from libcst._nodes.expression import BaseExpression
from libcst._nodes.module import Module
from libcst._nodes.statement import BaseCompoundStatement, SimpleStatementLine
from libcst._parser.detect_config import convert_to_utf8
from libcst._parser.detect_config import convert_to_utf8, detect_config
from libcst._parser.grammar import get_grammar, validate_grammar
from libcst._parser.python_parser import PythonCSTParser
from libcst._parser.types.config import PartialParserConfig
_DEFAULT_PARTIAL_PARSER_CONFIG: PartialParserConfig = PartialParserConfig()
def is_native() -> bool:
typ = os.environ.get("LIBCST_PARSER_TYPE", None)
return typ == "native"
def _parse(
entrypoint: str,
source: Union[str, bytes],
@ -30,21 +38,57 @@ def _parse(
detect_trailing_newline: bool,
detect_default_newline: bool,
) -> CSTNode:
if is_native():
from libcst.native import parse_expression, parse_module, parse_statement
encoding, source_str = convert_to_utf8(source, partial=config)
encoding, source_str = convert_to_utf8(source, partial=config)
from libcst import native
if entrypoint == "file_input":
parse = partial(parse_module, encoding=encoding)
elif entrypoint == "stmt_input":
parse = parse_statement
elif entrypoint == "expression_input":
parse = parse_expression
else:
raise ValueError(f"Unknown parser entry point: {entrypoint}")
if entrypoint == "file_input":
parse = partial(native.parse_module, encoding=encoding)
elif entrypoint == "stmt_input":
parse = native.parse_statement
elif entrypoint == "expression_input":
parse = native.parse_expression
else:
raise ValueError(f"Unknown parser entry point: {entrypoint}")
return parse(source_str)
return _pure_python_parse(
entrypoint,
source,
config,
detect_trailing_newline=detect_trailing_newline,
detect_default_newline=detect_default_newline,
)
return parse(source_str)
def _pure_python_parse(
entrypoint: str,
source: Union[str, bytes],
config: PartialParserConfig,
*,
detect_trailing_newline: bool,
detect_default_newline: bool,
) -> CSTNode:
detection_result = detect_config(
source,
partial=config,
detect_trailing_newline=detect_trailing_newline,
detect_default_newline=detect_default_newline,
)
validate_grammar()
grammar = get_grammar(config.parsed_python_version, config.future_imports)
parser = PythonCSTParser(
tokens=detection_result.tokens,
config=detection_result.config,
pgen_grammar=grammar,
start_nonterminal=entrypoint,
)
# The parser has an Any return type, we can at least refine it to CSTNode here.
result = parser.parse()
assert isinstance(result, CSTNode)
return result
def parse_module(

View file

@ -319,7 +319,7 @@ def validate_grammar() -> None:
production_name = fn_productions[0].name
expected_name = f"convert_{production_name}"
if fn.__name__ != expected_name:
raise ValueError(
raise Exception(
f"The conversion function for '{production_name}' "
+ f"must be called '{expected_name}', not '{fn.__name__}'."
)
@ -330,7 +330,7 @@ def _get_version_comparison(version: str) -> Tuple[str, PythonVersionInfo]:
return (version[:2], parse_version_string(version[2:].strip()))
if version[:1] in (">", "<"):
return (version[:1], parse_version_string(version[1:].strip()))
raise ValueError(f"Invalid version comparison specifier '{version}'")
raise Exception(f"Invalid version comparison specifier '{version}'")
def _compare_versions(
@ -350,7 +350,7 @@ def _compare_versions(
return actual_version > requested_version
if comparison == "<":
return actual_version < requested_version
raise ValueError(f"Invalid version comparison specifier '{comparison}'")
raise Exception(f"Invalid version comparison specifier '{comparison}'")
def _should_include(
@ -405,7 +405,7 @@ def get_nonterminal_conversions(
if not _should_include_future(fn_production.future, future_imports):
continue
if fn_production.name in conversions:
raise ValueError(
raise Exception(
f"Found duplicate '{fn_production.name}' production in grammar"
)
conversions[fn_production.name] = fn

View file

@ -72,9 +72,9 @@ class DFAState(Generic[_TokenTypeT]):
def __init__(self, from_rule: str, nfa_set: Set[NFAState], final: NFAState) -> None:
self.from_rule = from_rule
self.nfa_set = nfa_set
self.arcs: Mapping[str, DFAState] = (
{}
) # map from terminals/nonterminals to DFAState
self.arcs: Mapping[
str, DFAState
] = {} # map from terminals/nonterminals to DFAState
# In an intermediary step we set these nonterminal arcs (which has the
# same structure as arcs). These don't contain terminals anymore.
self.nonterminal_arcs: Mapping[str, DFAState] = {}
@ -259,7 +259,7 @@ def generate_grammar(bnf_grammar: str, token_namespace: Any) -> Grammar[Any]:
_calculate_tree_traversal(rule_to_dfas)
if start_nonterminal is None:
raise ValueError("could not find starting nonterminal!")
raise Exception("could not find starting nonterminal!")
return Grammar(start_nonterminal, rule_to_dfas, reserved_strings)

View file

@ -93,10 +93,14 @@ class GrammarParser:
def _parse_items(self):
# items: item+
a, b = self._parse_item()
while self.type in (
PythonTokenTypes.NAME,
PythonTokenTypes.STRING,
) or self.value in ("(", "["):
while (
self.type
in (
PythonTokenTypes.NAME,
PythonTokenTypes.STRING,
)
or self.value in ("(", "[")
):
c, d = self._parse_item()
# Need to end on the next item.
b.add_arc(c)

View file

@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) Meta Platforms, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
@ -26,8 +26,9 @@ try:
ERRORTOKEN: TokenType = native_token_type.ERRORTOKEN
ERROR_DEDENT: TokenType = native_token_type.ERROR_DEDENT
except ImportError:
from libcst._parser.parso.python.py_token import ( # noqa: F401
from libcst._parser.parso.python.py_token import ( # noqa F401
PythonTokenTypes,
TokenType,
)

View file

@ -36,7 +36,6 @@ from collections import namedtuple
from dataclasses import dataclass
from typing import Dict, Generator, Iterable, Optional, Pattern, Set, Tuple
from libcst import CSTLogicError
from libcst._parser.parso.python.token import PythonTokenTypes
from libcst._parser.parso.utils import PythonVersionInfo, split_lines
@ -523,14 +522,14 @@ def _tokenize_lines_py36_or_below( # noqa: C901
if contstr: # continued string
if endprog is None:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
endmatch = endprog.match(line)
if endmatch:
pos = endmatch.end(0)
if contstr_start is None:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
if stashed is not None:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
yield PythonToken(STRING, contstr + line[:pos], contstr_start, prefix)
contstr = ""
contline = None
@ -548,7 +547,7 @@ def _tokenize_lines_py36_or_below( # noqa: C901
)
if string:
if stashed is not None:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
yield PythonToken(
FSTRING_STRING,
string,
@ -573,7 +572,7 @@ def _tokenize_lines_py36_or_below( # noqa: C901
pos += quote_length
if fstring_end_token is not None:
if stashed is not None:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
yield fstring_end_token
continue
@ -886,12 +885,12 @@ def _tokenize_lines_py37_or_above( # noqa: C901
if contstr: # continued string
if endprog is None:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
endmatch = endprog.match(line)
if endmatch:
pos = endmatch.end(0)
if contstr_start is None:
raise CSTLogicError("Logic error!")
raise Exception("Logic error!")
yield PythonToken(STRING, contstr + line[:pos], contstr_start, prefix)
contstr = ""
contline = None

View file

@ -39,8 +39,8 @@ class ParsoUtilsTest(UnitTest):
# Invalid line breaks
("a\vb", ["a\vb"], False),
("a\vb", ["a\vb"], True),
("\x1c", ["\x1c"], False),
("\x1c", ["\x1c"], True),
("\x1C", ["\x1C"], False),
("\x1C", ["\x1C"], True),
)
)
def test_split_lines(self, string, expected_result, keepends):

View file

@ -29,9 +29,9 @@ from typing import Optional, Sequence, Tuple, Union
_NON_LINE_BREAKS = (
"\v", # Vertical Tabulation 0xB
"\f", # Form Feed 0xC
"\x1c", # File Separator
"\x1d", # Group Separator
"\x1e", # Record Separator
"\x1C", # File Separator
"\x1D", # Group Separator
"\x1E", # Record Separator
"\x85", # Next Line (NEL - Equivalent to CR+LF.
# Used to mark end-of-line on some IBM mainframes.)
"\u2028", # Line Separator
@ -114,11 +114,11 @@ def python_bytes_to_unicode(
return b"utf-8"
# pyre-ignore Pyre can't see that Union[str, bytes] conforms to AnyStr.
first_two_match = re.match(rb"(?:[^\n]*\n){0,2}", source)
first_two_match = re.match(br"(?:[^\n]*\n){0,2}", source)
if first_two_match is None:
return encoding
first_two_lines = first_two_match.group(0)
possible_encoding = re.search(rb"coding[=:]\s*([-\w.]+)", first_two_lines)
possible_encoding = re.search(br"coding[=:]\s*([-\w.]+)", first_two_lines)
if possible_encoding:
return possible_encoding.group(1)
else:

View file

@ -39,7 +39,7 @@ def with_production(
# pyre-ignore: Pyre doesn't think that fn has a __name__ attribute
fn_name = fn.__name__
if not fn_name.startswith("convert_"):
raise ValueError(
raise Exception(
"A function with a production must be named 'convert_X', not "
+ f"'{fn_name}'."
)

View file

@ -1,11 +1,10 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) Meta Platforms, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Sequence, Tuple, Union
from libcst import CSTLogicError, ParserSyntaxError
from libcst._nodes.whitespace import (
Comment,
COMMENT_RE,
@ -104,13 +103,10 @@ def parse_trailing_whitespace(
) -> TrailingWhitespace:
trailing_whitespace = _parse_trailing_whitespace(config, state)
if trailing_whitespace is None:
raise ParserSyntaxError(
raise Exception(
"Internal Error: Failed to parse TrailingWhitespace. This should never "
+ "happen because a TrailingWhitespace is never optional in the grammar, "
+ "so this error should've been caught by parso first.",
lines=config.lines,
raw_line=state.line,
raw_column=state.column,
+ "so this error should've been caught by parso first."
)
return trailing_whitespace
@ -181,9 +177,7 @@ def _parse_indent(
if state.column == len(line_str) and state.line == len(config.lines):
# We're at EOF, treat this as a failed speculative parse
return False
raise CSTLogicError(
"Internal Error: Column should be 0 when parsing an indent."
)
raise Exception("Internal Error: Column should be 0 when parsing an indent.")
if line_str.startswith(absolute_indent, state.column):
state.column += len(absolute_indent)
return True
@ -212,12 +206,7 @@ def _parse_newline(
newline_str = newline_match.group(0)
state.column += len(newline_str)
if state.column != len(line_str):
raise ParserSyntaxError(
"Internal Error: Found a newline, but it wasn't the EOL.",
lines=config.lines,
raw_line=state.line,
raw_column=state.column,
)
raise Exception("Internal Error: Found a newline, but it wasn't the EOL.")
if state.line < len(config.lines):
# this newline was the end of a line, and there's another line,
# therefore we should move to the next line

View file

@ -6,10 +6,9 @@
from textwrap import dedent
from typing import Callable
from unittest.mock import patch
import libcst as cst
from libcst._nodes.base import CSTValidationError
from libcst._parser.entrypoints import is_native
from libcst.testing.utils import data_provider, UnitTest
@ -171,11 +170,5 @@ class ParseErrorsTest(UnitTest):
) -> None:
with self.assertRaises(cst.ParserSyntaxError) as cm:
parse_fn()
# make sure str() doesn't blow up
self.assertIn("Syntax Error", str(cm.exception))
def test_native_fallible_into_py(self) -> None:
with patch("libcst._nodes.expression.Name._validate") as await_validate:
await_validate.side_effect = CSTValidationError("validate is broken")
with self.assertRaises((SyntaxError, cst.ParserSyntaxError)):
cst.parse_module("foo")
if not is_native():
self.assertEqual(str(cm.exception), expected)

View file

@ -27,9 +27,9 @@ except ImportError:
BaseWhitespaceParserConfig = config_mod.BaseWhitespaceParserConfig
ParserConfig = config_mod.ParserConfig
parser_config_asdict: Callable[[ParserConfig], Mapping[str, Any]] = (
config_mod.parser_config_asdict
)
parser_config_asdict: Callable[
[ParserConfig], Mapping[str, Any]
] = config_mod.parser_config_asdict
class AutoConfig(Enum):

View file

@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) Meta Platforms, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

View file

@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) Meta Platforms, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

View file

@ -1,4 +1,4 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# Copyright (c) Meta Platforms, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

View file

@ -9,4 +9,4 @@ try:
Token = tokenize.Token
except ImportError:
from libcst._parser.types.py_token import Token # noqa: F401
from libcst._parser.types.py_token import Token # noqa F401

View file

@ -40,10 +40,12 @@ class CodeRange:
end: CodePosition
@overload
def __init__(self, start: CodePosition, end: CodePosition) -> None: ...
def __init__(self, start: CodePosition, end: CodePosition) -> None:
...
@overload
def __init__(self, start: Tuple[int, int], end: Tuple[int, int]) -> None: ...
def __init__(self, start: Tuple[int, int], end: Tuple[int, int]) -> None:
...
def __init__(self, start: _CodePositionT, end: _CodePositionT) -> None:
if isinstance(start, tuple) and isinstance(end, tuple):

View file

@ -3,21 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import (
Any,
ClassVar,
ForwardRef,
get_args,
get_origin,
Iterable,
Literal,
Mapping,
MutableMapping,
MutableSequence,
Tuple,
TypeVar,
Union,
)
from typing import Any, Iterable, Mapping, MutableMapping, MutableSequence, Tuple
from typing_extensions import Literal
from typing_inspect import get_args, get_origin, is_classvar, is_typevar, is_union_type
try: # py37+
from typing import ForwardRef
except ImportError: # py36
# pyre-fixme[21]: Could not find name `_ForwardRef` in `typing` (stubbed).
from typing import _ForwardRef as ForwardRef
def is_value_of_type( # noqa: C901 "too complex"
@ -51,11 +46,15 @@ def is_value_of_type( # noqa: C901 "too complex"
- Forward Refs -- use `typing.get_type_hints` to resolve these
- Type[...]
"""
if expected_type is ClassVar or get_origin(expected_type) is ClassVar:
classvar_args = get_args(expected_type)
expected_type = (classvar_args[0] or Any) if classvar_args else Any
if is_classvar(expected_type):
# `ClassVar` (no subscript) is implicitly `ClassVar[Any]`
if hasattr(expected_type, "__type__"): # py36
expected_type = expected_type.__type__ or Any
else: # py37+
classvar_args = get_args(expected_type)
expected_type = (classvar_args[0] or Any) if classvar_args else Any
if type(expected_type) is TypeVar:
if is_typevar(expected_type):
# treat this the same as Any
# TODO: evaluate bounds
return True
@ -65,13 +64,16 @@ def is_value_of_type( # noqa: C901 "too complex"
if expected_origin_type == Any:
return True
elif expected_type is Union or get_origin(expected_type) is Union:
elif is_union_type(expected_type):
return any(
is_value_of_type(value, subtype) for subtype in expected_type.__args__
)
elif isinstance(expected_origin_type, type(Literal)):
literal_values = get_args(expected_type)
if hasattr(expected_type, "__values__"): # py36
literal_values = expected_type.__values__
else: # py37+
literal_values = get_args(expected_type, evaluate=True)
return any(value == literal for literal in literal_values)
elif isinstance(expected_origin_type, ForwardRef):
@ -85,11 +87,14 @@ def is_value_of_type( # noqa: C901 "too complex"
if not isinstance(value, tuple):
return False
type_args = get_args(expected_type)
type_args = get_args(expected_type, evaluate=True)
if len(type_args) == 0:
# `Tuple` (no subscript) is implicitly `Tuple[Any, ...]`
return True
if type_args is None:
return True
if len(value) != len(type_args):
return False
# TODO: Handle `Tuple[T, ...]` like `Iterable[T]`
@ -106,7 +111,7 @@ def is_value_of_type( # noqa: C901 "too complex"
if not issubclass(type(value), expected_origin_type):
return False
type_args = get_args(expected_type)
type_args = get_args(expected_type, evaluate=True)
if len(type_args) == 0:
# `Mapping` (no subscript) is implicitly `Mapping[Any, Any]`.
return True
@ -143,7 +148,7 @@ def is_value_of_type( # noqa: C901 "too complex"
if not issubclass(type(value), expected_origin_type):
return False
type_args = get_args(expected_type)
type_args = get_args(expected_type, evaluate=True)
if len(type_args) == 0:
# `Iterable` (no subscript) is implicitly `Iterable[Any]`.
return True

View file

@ -25,7 +25,6 @@ if TYPE_CHECKING:
BaseExpression,
BaseFormattedStringContent,
BaseSlice,
BaseTemplatedStringContent,
BinaryOperation,
BooleanOperation,
Call,
@ -72,9 +71,6 @@ if TYPE_CHECKING:
StarredElement,
Subscript,
SubscriptElement,
TemplatedString,
TemplatedStringExpression,
TemplatedStringText,
Tuple,
UnaryOperation,
Yield,
@ -182,7 +178,6 @@ if TYPE_CHECKING:
MatchValue,
NameItem,
Nonlocal,
ParamSpec,
Pass,
Raise,
Return,
@ -190,11 +185,6 @@ if TYPE_CHECKING:
SimpleStatementSuite,
Try,
TryStar,
TypeAlias,
TypeParam,
TypeParameters,
TypeVar,
TypeVarTuple,
While,
With,
WithItem,
@ -211,7 +201,6 @@ if TYPE_CHECKING:
class CSTTypedBaseFunctions:
@mark_no_op
def visit_Add(self, node: "Add") -> Optional[bool]:
pass
@ -1064,22 +1053,6 @@ class CSTTypedBaseFunctions:
def leave_ClassDef_whitespace_before_colon(self, node: "ClassDef") -> None:
pass
@mark_no_op
def visit_ClassDef_type_parameters(self, node: "ClassDef") -> None:
pass
@mark_no_op
def leave_ClassDef_type_parameters(self, node: "ClassDef") -> None:
pass
@mark_no_op
def visit_ClassDef_whitespace_after_type_parameters(self, node: "ClassDef") -> None:
pass
@mark_no_op
def leave_ClassDef_whitespace_after_type_parameters(self, node: "ClassDef") -> None:
pass
@mark_no_op
def visit_Colon(self, node: "Colon") -> Optional[bool]:
pass
@ -2366,26 +2339,6 @@ class CSTTypedBaseFunctions:
def leave_FunctionDef_whitespace_before_colon(self, node: "FunctionDef") -> None:
pass
@mark_no_op
def visit_FunctionDef_type_parameters(self, node: "FunctionDef") -> None:
pass
@mark_no_op
def leave_FunctionDef_type_parameters(self, node: "FunctionDef") -> None:
pass
@mark_no_op
def visit_FunctionDef_whitespace_after_type_parameters(
self, node: "FunctionDef"
) -> None:
pass
@mark_no_op
def leave_FunctionDef_whitespace_after_type_parameters(
self, node: "FunctionDef"
) -> None:
pass
@mark_no_op
def visit_GeneratorExp(self, node: "GeneratorExp") -> Optional[bool]:
pass
@ -2854,22 +2807,6 @@ class CSTTypedBaseFunctions:
def leave_Index_value(self, node: "Index") -> None:
pass
@mark_no_op
def visit_Index_star(self, node: "Index") -> None:
pass
@mark_no_op
def leave_Index_star(self, node: "Index") -> None:
pass
@mark_no_op
def visit_Index_whitespace_after_star(self, node: "Index") -> None:
pass
@mark_no_op
def leave_Index_whitespace_after_star(self, node: "Index") -> None:
pass
@mark_no_op
def visit_Integer(self, node: "Integer") -> Optional[bool]:
pass
@ -4354,34 +4291,6 @@ class CSTTypedBaseFunctions:
def leave_ParamSlash_comma(self, node: "ParamSlash") -> None:
pass
@mark_no_op
def visit_ParamSlash_whitespace_after(self, node: "ParamSlash") -> None:
pass
@mark_no_op
def leave_ParamSlash_whitespace_after(self, node: "ParamSlash") -> None:
pass
@mark_no_op
def visit_ParamSpec(self, node: "ParamSpec") -> Optional[bool]:
pass
@mark_no_op
def visit_ParamSpec_name(self, node: "ParamSpec") -> None:
pass
@mark_no_op
def leave_ParamSpec_name(self, node: "ParamSpec") -> None:
pass
@mark_no_op
def visit_ParamSpec_whitespace_after_star(self, node: "ParamSpec") -> None:
pass
@mark_no_op
def leave_ParamSpec_whitespace_after_star(self, node: "ParamSpec") -> None:
pass
@mark_no_op
def visit_ParamStar(self, node: "ParamStar") -> Optional[bool]:
pass
@ -5186,140 +5095,6 @@ class CSTTypedBaseFunctions:
def leave_SubtractAssign_whitespace_after(self, node: "SubtractAssign") -> None:
pass
@mark_no_op
def visit_TemplatedString(self, node: "TemplatedString") -> Optional[bool]:
pass
@mark_no_op
def visit_TemplatedString_parts(self, node: "TemplatedString") -> None:
pass
@mark_no_op
def leave_TemplatedString_parts(self, node: "TemplatedString") -> None:
pass
@mark_no_op
def visit_TemplatedString_start(self, node: "TemplatedString") -> None:
pass
@mark_no_op
def leave_TemplatedString_start(self, node: "TemplatedString") -> None:
pass
@mark_no_op
def visit_TemplatedString_end(self, node: "TemplatedString") -> None:
pass
@mark_no_op
def leave_TemplatedString_end(self, node: "TemplatedString") -> None:
pass
@mark_no_op
def visit_TemplatedString_lpar(self, node: "TemplatedString") -> None:
pass
@mark_no_op
def leave_TemplatedString_lpar(self, node: "TemplatedString") -> None:
pass
@mark_no_op
def visit_TemplatedString_rpar(self, node: "TemplatedString") -> None:
pass
@mark_no_op
def leave_TemplatedString_rpar(self, node: "TemplatedString") -> None:
pass
@mark_no_op
def visit_TemplatedStringExpression(
self, node: "TemplatedStringExpression"
) -> Optional[bool]:
pass
@mark_no_op
def visit_TemplatedStringExpression_expression(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def leave_TemplatedStringExpression_expression(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def visit_TemplatedStringExpression_conversion(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def leave_TemplatedStringExpression_conversion(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def visit_TemplatedStringExpression_format_spec(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def leave_TemplatedStringExpression_format_spec(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def visit_TemplatedStringExpression_whitespace_before_expression(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def leave_TemplatedStringExpression_whitespace_before_expression(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def visit_TemplatedStringExpression_whitespace_after_expression(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def leave_TemplatedStringExpression_whitespace_after_expression(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def visit_TemplatedStringExpression_equal(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def leave_TemplatedStringExpression_equal(
self, node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def visit_TemplatedStringText(self, node: "TemplatedStringText") -> Optional[bool]:
pass
@mark_no_op
def visit_TemplatedStringText_value(self, node: "TemplatedStringText") -> None:
pass
@mark_no_op
def leave_TemplatedStringText_value(self, node: "TemplatedStringText") -> None:
pass
@mark_no_op
def visit_TrailingWhitespace(self, node: "TrailingWhitespace") -> Optional[bool]:
pass
@ -5480,206 +5255,6 @@ class CSTTypedBaseFunctions:
def leave_Tuple_rpar(self, node: "Tuple") -> None:
pass
@mark_no_op
def visit_TypeAlias(self, node: "TypeAlias") -> Optional[bool]:
pass
@mark_no_op
def visit_TypeAlias_name(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def leave_TypeAlias_name(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def visit_TypeAlias_value(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def leave_TypeAlias_value(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def visit_TypeAlias_type_parameters(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def leave_TypeAlias_type_parameters(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def visit_TypeAlias_whitespace_after_type(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def leave_TypeAlias_whitespace_after_type(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def visit_TypeAlias_whitespace_after_name(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def leave_TypeAlias_whitespace_after_name(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def visit_TypeAlias_whitespace_after_type_parameters(
self, node: "TypeAlias"
) -> None:
pass
@mark_no_op
def leave_TypeAlias_whitespace_after_type_parameters(
self, node: "TypeAlias"
) -> None:
pass
@mark_no_op
def visit_TypeAlias_whitespace_after_equals(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def leave_TypeAlias_whitespace_after_equals(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def visit_TypeAlias_semicolon(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def leave_TypeAlias_semicolon(self, node: "TypeAlias") -> None:
pass
@mark_no_op
def visit_TypeParam(self, node: "TypeParam") -> Optional[bool]:
pass
@mark_no_op
def visit_TypeParam_param(self, node: "TypeParam") -> None:
pass
@mark_no_op
def leave_TypeParam_param(self, node: "TypeParam") -> None:
pass
@mark_no_op
def visit_TypeParam_comma(self, node: "TypeParam") -> None:
pass
@mark_no_op
def leave_TypeParam_comma(self, node: "TypeParam") -> None:
pass
@mark_no_op
def visit_TypeParam_equal(self, node: "TypeParam") -> None:
pass
@mark_no_op
def leave_TypeParam_equal(self, node: "TypeParam") -> None:
pass
@mark_no_op
def visit_TypeParam_star(self, node: "TypeParam") -> None:
pass
@mark_no_op
def leave_TypeParam_star(self, node: "TypeParam") -> None:
pass
@mark_no_op
def visit_TypeParam_whitespace_after_star(self, node: "TypeParam") -> None:
pass
@mark_no_op
def leave_TypeParam_whitespace_after_star(self, node: "TypeParam") -> None:
pass
@mark_no_op
def visit_TypeParam_default(self, node: "TypeParam") -> None:
pass
@mark_no_op
def leave_TypeParam_default(self, node: "TypeParam") -> None:
pass
@mark_no_op
def visit_TypeParameters(self, node: "TypeParameters") -> Optional[bool]:
pass
@mark_no_op
def visit_TypeParameters_params(self, node: "TypeParameters") -> None:
pass
@mark_no_op
def leave_TypeParameters_params(self, node: "TypeParameters") -> None:
pass
@mark_no_op
def visit_TypeParameters_lbracket(self, node: "TypeParameters") -> None:
pass
@mark_no_op
def leave_TypeParameters_lbracket(self, node: "TypeParameters") -> None:
pass
@mark_no_op
def visit_TypeParameters_rbracket(self, node: "TypeParameters") -> None:
pass
@mark_no_op
def leave_TypeParameters_rbracket(self, node: "TypeParameters") -> None:
pass
@mark_no_op
def visit_TypeVar(self, node: "TypeVar") -> Optional[bool]:
pass
@mark_no_op
def visit_TypeVar_name(self, node: "TypeVar") -> None:
pass
@mark_no_op
def leave_TypeVar_name(self, node: "TypeVar") -> None:
pass
@mark_no_op
def visit_TypeVar_bound(self, node: "TypeVar") -> None:
pass
@mark_no_op
def leave_TypeVar_bound(self, node: "TypeVar") -> None:
pass
@mark_no_op
def visit_TypeVar_colon(self, node: "TypeVar") -> None:
pass
@mark_no_op
def leave_TypeVar_colon(self, node: "TypeVar") -> None:
pass
@mark_no_op
def visit_TypeVarTuple(self, node: "TypeVarTuple") -> Optional[bool]:
pass
@mark_no_op
def visit_TypeVarTuple_name(self, node: "TypeVarTuple") -> None:
pass
@mark_no_op
def leave_TypeVarTuple_name(self, node: "TypeVarTuple") -> None:
pass
@mark_no_op
def visit_TypeVarTuple_whitespace_after_star(self, node: "TypeVarTuple") -> None:
pass
@mark_no_op
def leave_TypeVarTuple_whitespace_after_star(self, node: "TypeVarTuple") -> None:
pass
@mark_no_op
def visit_UnaryOperation(self, node: "UnaryOperation") -> Optional[bool]:
pass
@ -5902,7 +5477,6 @@ class CSTTypedBaseFunctions:
class CSTTypedVisitorFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_Add(self, original_node: "Add") -> None:
pass
@ -6405,10 +5979,6 @@ class CSTTypedVisitorFunctions(CSTTypedBaseFunctions):
def leave_ParamSlash(self, original_node: "ParamSlash") -> None:
pass
@mark_no_op
def leave_ParamSpec(self, original_node: "ParamSpec") -> None:
pass
@mark_no_op
def leave_ParamStar(self, original_node: "ParamStar") -> None:
pass
@ -6523,20 +6093,6 @@ class CSTTypedVisitorFunctions(CSTTypedBaseFunctions):
def leave_SubtractAssign(self, original_node: "SubtractAssign") -> None:
pass
@mark_no_op
def leave_TemplatedString(self, original_node: "TemplatedString") -> None:
pass
@mark_no_op
def leave_TemplatedStringExpression(
self, original_node: "TemplatedStringExpression"
) -> None:
pass
@mark_no_op
def leave_TemplatedStringText(self, original_node: "TemplatedStringText") -> None:
pass
@mark_no_op
def leave_TrailingWhitespace(self, original_node: "TrailingWhitespace") -> None:
pass
@ -6553,26 +6109,6 @@ class CSTTypedVisitorFunctions(CSTTypedBaseFunctions):
def leave_Tuple(self, original_node: "Tuple") -> None:
pass
@mark_no_op
def leave_TypeAlias(self, original_node: "TypeAlias") -> None:
pass
@mark_no_op
def leave_TypeParam(self, original_node: "TypeParam") -> None:
pass
@mark_no_op
def leave_TypeParameters(self, original_node: "TypeParameters") -> None:
pass
@mark_no_op
def leave_TypeVar(self, original_node: "TypeVar") -> None:
pass
@mark_no_op
def leave_TypeVarTuple(self, original_node: "TypeVarTuple") -> None:
pass
@mark_no_op
def leave_UnaryOperation(self, original_node: "UnaryOperation") -> None:
pass
@ -6595,6 +6131,7 @@ class CSTTypedVisitorFunctions(CSTTypedBaseFunctions):
class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
pass
@mark_no_op
def leave_Add(self, original_node: "Add", updated_node: "Add") -> "BaseBinaryOp":
@ -7372,12 +6909,6 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
) -> Union["ParamSlash", MaybeSentinel]:
return updated_node
@mark_no_op
def leave_ParamSpec(
self, original_node: "ParamSpec", updated_node: "ParamSpec"
) -> "ParamSpec":
return updated_node
@mark_no_op
def leave_ParamStar(
self, original_node: "ParamStar", updated_node: "ParamStar"
@ -7525,7 +7056,7 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
@mark_no_op
def leave_StarredElement(
self, original_node: "StarredElement", updated_node: "StarredElement"
) -> "BaseExpression":
) -> Union["BaseElement", FlattenSentinel["BaseElement"], RemovalSentinel]:
return updated_node
@mark_no_op
@ -7554,34 +7085,6 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
) -> "BaseAugOp":
return updated_node
@mark_no_op
def leave_TemplatedString(
self, original_node: "TemplatedString", updated_node: "TemplatedString"
) -> "BaseExpression":
return updated_node
@mark_no_op
def leave_TemplatedStringExpression(
self,
original_node: "TemplatedStringExpression",
updated_node: "TemplatedStringExpression",
) -> Union[
"BaseTemplatedStringContent",
FlattenSentinel["BaseTemplatedStringContent"],
RemovalSentinel,
]:
return updated_node
@mark_no_op
def leave_TemplatedStringText(
self, original_node: "TemplatedStringText", updated_node: "TemplatedStringText"
) -> Union[
"BaseTemplatedStringContent",
FlattenSentinel["BaseTemplatedStringContent"],
RemovalSentinel,
]:
return updated_node
@mark_no_op
def leave_TrailingWhitespace(
self, original_node: "TrailingWhitespace", updated_node: "TrailingWhitespace"
@ -7606,38 +7109,6 @@ class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):
) -> "BaseExpression":
return updated_node
@mark_no_op
def leave_TypeAlias(
self, original_node: "TypeAlias", updated_node: "TypeAlias"
) -> Union[
"BaseSmallStatement", FlattenSentinel["BaseSmallStatement"], RemovalSentinel
]:
return updated_node
@mark_no_op
def leave_TypeParam(
self, original_node: "TypeParam", updated_node: "TypeParam"
) -> Union["TypeParam", FlattenSentinel["TypeParam"], RemovalSentinel]:
return updated_node
@mark_no_op
def leave_TypeParameters(
self, original_node: "TypeParameters", updated_node: "TypeParameters"
) -> "TypeParameters":
return updated_node
@mark_no_op
def leave_TypeVar(
self, original_node: "TypeVar", updated_node: "TypeVar"
) -> "TypeVar":
return updated_node
@mark_no_op
def leave_TypeVarTuple(
self, original_node: "TypeVarTuple", updated_node: "TypeVarTuple"
) -> "TypeVarTuple":
return updated_node
@mark_no_op
def leave_UnaryOperation(
self, original_node: "UnaryOperation", updated_node: "UnaryOperation"

View file

@ -3,8 +3,10 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Callable, cast, TypeVar
from typing import Any, Callable, cast, TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from libcst._typed_visitor import CSTTypedBaseFunctions # noqa: F401
# pyre-fixme[24]: Generic type `Callable` expects 2 type parameters.
F = TypeVar("F", bound=Callable)

View file

@ -4,8 +4,7 @@
# LICENSE file in the root directory of this source tree.
from pathlib import PurePath
from typing import TYPE_CHECKING, TypeVar, Union
from typing import TYPE_CHECKING, TypeVar
if TYPE_CHECKING:
from libcst._nodes.base import CSTNode # noqa: F401
@ -13,4 +12,3 @@ if TYPE_CHECKING:
CSTNodeT = TypeVar("CSTNodeT", bound="CSTNode")
CSTNodeT_co = TypeVar("CSTNodeT_co", bound="CSTNode", covariant=True)
StrPath = Union[str, PurePath]

View file

@ -7,12 +7,12 @@ import inspect
from collections import defaultdict
from collections.abc import Sequence as ABCSequence
from dataclasses import dataclass, fields, replace
from typing import Dict, Iterator, List, Mapping, Sequence, Set, Type, Union
from typing import Dict, Generator, List, Mapping, Sequence, Set, Type, Union
import libcst as cst
def _get_bases() -> Iterator[Type[cst.CSTNode]]:
def _get_bases() -> Generator[Type[cst.CSTNode], None, None]:
"""
Get all base classes that are subclasses of CSTNode but not an actual
node itself. This allows us to keep our types sane by refering to the
@ -27,11 +27,11 @@ def _get_bases() -> Iterator[Type[cst.CSTNode]]:
typeclasses: Sequence[Type[cst.CSTNode]] = sorted(
_get_bases(), key=lambda base: base.__name__
list(_get_bases()), key=lambda base: base.__name__
)
def _get_nodes() -> Iterator[Type[cst.CSTNode]]:
def _get_nodes() -> Generator[Type[cst.CSTNode], None, None]:
"""
Grab all CSTNodes that are not a superclass. Basically, anything that a
person might use to generate a tree.
@ -53,7 +53,7 @@ def _get_nodes() -> Iterator[Type[cst.CSTNode]]:
all_libcst_nodes: Sequence[Type[cst.CSTNode]] = sorted(
_get_nodes(), key=lambda node: node.__name__
list(_get_nodes()), key=lambda node: node.__name__
)
node_to_bases: Dict[Type[cst.CSTNode], List[Type[cst.CSTNode]]] = {}
for node in all_libcst_nodes:

View file

@ -8,7 +8,7 @@ from dataclasses import dataclass, fields
from typing import Generator, List, Optional, Sequence, Set, Tuple, Type, Union
import libcst as cst
from libcst import CSTLogicError, ensure_type, parse_expression
from libcst import ensure_type, parse_expression
from libcst.codegen.gather import all_libcst_nodes, typeclasses
CST_DIR: Set[str] = set(dir(cst))
@ -16,109 +16,6 @@ CLASS_RE = r"<class \'(.*?)\'>"
OPTIONAL_RE = r"typing\.Union\[([^,]*?), NoneType]"
class NormalizeUnions(cst.CSTTransformer):
"""
Convert a binary operation with | operators into a Union type.
For example, converts `foo | bar | baz` into `typing.Union[foo, bar, baz]`.
Special case: converts `foo | None` or `None | foo` into `typing.Optional[foo]`.
Also flattens nested typing.Union types.
"""
def leave_Subscript(
self, original_node: cst.Subscript, updated_node: cst.Subscript
) -> cst.Subscript:
# Check if this is a typing.Union
if (
isinstance(updated_node.value, cst.Attribute)
and isinstance(updated_node.value.value, cst.Name)
and updated_node.value.attr.value == "Union"
and updated_node.value.value.value == "typing"
):
# Collect all operands from any nested Unions
operands: List[cst.BaseExpression] = []
for slc in updated_node.slice:
if not isinstance(slc.slice, cst.Index):
continue
value = slc.slice.value
# If this is a nested Union, add its elements
if (
isinstance(value, cst.Subscript)
and isinstance(value.value, cst.Attribute)
and isinstance(value.value.value, cst.Name)
and value.value.attr.value == "Union"
and value.value.value.value == "typing"
):
operands.extend(
nested_slc.slice.value
for nested_slc in value.slice
if isinstance(nested_slc.slice, cst.Index)
)
else:
operands.append(value)
# flatten operands into a Union type
return cst.Subscript(
cst.Attribute(cst.Name("typing"), cst.Name("Union")),
[cst.SubscriptElement(cst.Index(operand)) for operand in operands],
)
return updated_node
def leave_BinaryOperation(
self, original_node: cst.BinaryOperation, updated_node: cst.BinaryOperation
) -> Union[cst.BinaryOperation, cst.Subscript]:
if not updated_node.operator.deep_equals(cst.BitOr()):
return updated_node
def flatten_binary_op(node: cst.BaseExpression) -> List[cst.BaseExpression]:
"""Flatten a binary operation tree into a list of operands."""
if not isinstance(node, cst.BinaryOperation):
# If it's a Union type, extract its elements
if (
isinstance(node, cst.Subscript)
and isinstance(node.value, cst.Attribute)
and isinstance(node.value.value, cst.Name)
and node.value.attr.value == "Union"
and node.value.value.value == "typing"
):
return [
slc.slice.value
for slc in node.slice
if isinstance(slc.slice, cst.Index)
]
return [node]
if not node.operator.deep_equals(cst.BitOr()):
return [node]
left_operands = flatten_binary_op(node.left)
right_operands = flatten_binary_op(node.right)
return left_operands + right_operands
# Flatten the binary operation tree into a list of operands
operands = flatten_binary_op(updated_node)
# Check for Optional case (None in union)
none_count = sum(
1 for op in operands if isinstance(op, cst.Name) and op.value == "None"
)
if none_count == 1 and len(operands) == 2:
# This is an Optional case - find the non-None operand
non_none = next(
op
for op in operands
if not (isinstance(op, cst.Name) and op.value == "None")
)
return cst.Subscript(
cst.Attribute(cst.Name("typing"), cst.Name("Optional")),
[cst.SubscriptElement(cst.Index(non_none))],
)
# Regular Union case
return cst.Subscript(
cst.Attribute(cst.Name("typing"), cst.Name("Union")),
[cst.SubscriptElement(cst.Index(operand)) for operand in operands],
)
class CleanseFullTypeNames(cst.CSTTransformer):
def leave_Call(
self, original_node: cst.Call, updated_node: cst.Call
@ -283,9 +180,9 @@ class AddWildcardsToSequenceUnions(cst.CSTTransformer):
# type blocks, even for sequence types.
return
if len(node.slice) != 1:
raise ValueError(
raise Exception(
"Unexpected number of sequence elements inside Sequence type "
"annotation!"
+ "annotation!"
)
nodeslice = node.slice[0].slice
if isinstance(nodeslice, cst.Index):
@ -368,9 +265,7 @@ def _get_raw_name(node: cst.CSTNode) -> Optional[str]:
if isinstance(node, cst.Name):
return node.value
elif isinstance(node, cst.SimpleString):
evaluated_value = node.evaluated_value
if isinstance(evaluated_value, str):
return evaluated_value
return node.evaluated_value
elif isinstance(node, cst.SubscriptElement):
return _get_raw_name(node.slice)
elif isinstance(node, cst.Index):
@ -449,14 +344,10 @@ def _get_clean_type_from_subscript(
if typecst.value.deep_equals(cst.Name("Sequence")):
# Lets attempt to widen the sequence type and alias it.
if len(typecst.slice) != 1:
raise CSTLogicError(
"Logic error, Sequence shouldn't have more than one param!"
)
raise Exception("Logic error, Sequence shouldn't have more than one param!")
inner_type = typecst.slice[0].slice
if not isinstance(inner_type, cst.Index):
raise CSTLogicError(
"Logic error, expecting Index for only Sequence element!"
)
raise Exception("Logic error, expecting Index for only Sequence element!")
inner_type = inner_type.value
if isinstance(inner_type, cst.Subscript):
@ -464,9 +355,7 @@ def _get_clean_type_from_subscript(
elif isinstance(inner_type, (cst.Name, cst.SimpleString)):
clean_inner_type = _get_clean_type_from_expression(aliases, inner_type)
else:
raise CSTLogicError(
f"Logic error, unexpected type in Sequence: {type(inner_type)}!"
)
raise Exception("Logic error, unexpected type in Sequence!")
return _get_wrapped_union_type(
typecst.deep_replace(inner_type, clean_inner_type),
@ -495,12 +384,9 @@ def _get_clean_type_and_aliases(
typestr = re.sub(OPTIONAL_RE, r"typing.Optional[\1]", typestr)
# Now, parse the expression with LibCST.
cleanser = CleanseFullTypeNames()
typecst = parse_expression(typestr)
typecst = typecst.visit(NormalizeUnions())
assert isinstance(typecst, cst.BaseExpression)
typecst = typecst.visit(CleanseFullTypeNames())
assert isinstance(typecst, cst.BaseExpression)
typecst = typecst.visit(cleanser)
aliases: List[Alias] = []
# Now, convert the type to allow for MetadataMatchType and MatchIfTrue values.
@ -509,7 +395,7 @@ def _get_clean_type_and_aliases(
elif isinstance(typecst, (cst.Name, cst.SimpleString)):
clean_type = _get_clean_type_from_expression(aliases, typecst)
else:
raise CSTLogicError(f"Logic error, unexpected top level type: {type(typecst)}!")
raise Exception("Logic error, unexpected top level type!")
# Now, insert OneOf/AllOf and MatchIfTrue into unions so we can typecheck their usage.
# This allows us to put OneOf[SomeType] or MatchIfTrue[cst.SomeType] into any
@ -555,7 +441,8 @@ generated_code.append("")
generated_code.append("")
generated_code.append("# This file was generated by libcst.codegen.gen_matcher_classes")
generated_code.append("from dataclasses import dataclass")
generated_code.append("from typing import Literal, Optional, Sequence, Union")
generated_code.append("from typing import Optional, Sequence, Union")
generated_code.append("from typing_extensions import Literal")
generated_code.append("import libcst as cst")
generated_code.append("")
generated_code.append(
@ -660,7 +547,7 @@ for node in all_libcst_nodes:
# Make sure to add an __all__ for flake8 and compatibility with "from libcst.matchers import *"
generated_code.append(f"__all__ = {repr(sorted(all_exports))}")
generated_code.append(f"__all__ = {repr(sorted(list(all_exports)))}")
if __name__ == "__main__":

View file

@ -29,7 +29,7 @@ generated_code.append("")
generated_code.append("")
for module, objects in imports.items():
generated_code.append(f"from {module} import (")
generated_code.append(f" {', '.join(sorted(objects))}")
generated_code.append(f" {', '.join(sorted(list(objects)))}")
generated_code.append(")")
# Generate the base visit_ methods

View file

@ -32,7 +32,7 @@ generated_code.append("")
generated_code.append("if TYPE_CHECKING:")
for module, objects in imports.items():
generated_code.append(f" from {module} import ( # noqa: F401")
generated_code.append(f" {', '.join(sorted(objects))}")
generated_code.append(f" {', '.join(sorted(list(objects)))}")
generated_code.append(" )")
@ -87,6 +87,7 @@ for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
generated_code.append("")
generated_code.append("")
generated_code.append("class CSTTypedTransformerFunctions(CSTTypedBaseFunctions):")
generated_code.append(" pass")
for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
name = node.__name__
if name.startswith("Base"):
@ -110,7 +111,6 @@ for node in sorted(nodebases.keys(), key=lambda node: node.__name__):
)
generated_code.append(" return updated_node")
if __name__ == "__main__":
# Output the code
print("\n".join(generated_code))

View file

@ -25,11 +25,8 @@ from libcst.codegen.transforms import (
def format_file(fname: str) -> None:
subprocess.check_call(
["ufmt", "format", fname],
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
with open(os.devnull, "w") as devnull:
subprocess.check_call(["ufmt", "format", fname], stdout=devnull, stderr=devnull)
def clean_generated_code(code: str) -> str:
@ -68,11 +65,12 @@ def codegen_visitors() -> None:
# Now, see if the file we generated causes any import errors
# by attempting to run codegen again in a new process.
subprocess.check_call(
[sys.executable, "-m", "libcst.codegen.gen_visitor_functions"],
cwd=base,
stdout=subprocess.DEVNULL,
)
with open(os.devnull, "w") as devnull:
subprocess.check_call(
["python3", "-m", "libcst.codegen.gen_visitor_functions"],
cwd=base,
stdout=devnull,
)
# If it worked, lets format the file
format_file(visitors_file)

View file

@ -3,7 +3,6 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import difflib
import os
import os.path
@ -21,20 +20,12 @@ class TestCodegenClean(UnitTest):
new_code: str,
module_name: str,
) -> None:
if old_code != new_code:
diff = difflib.unified_diff(
old_code.splitlines(keepends=True),
new_code.splitlines(keepends=True),
fromfile="old_code",
tofile="new_code",
)
diff_str = "".join(diff)
self.fail(
f"{module_name} needs new codegen, see "
+ "`python -m libcst.codegen.generate --help` "
+ "for instructions, or run `python -m libcst.codegen.generate all`. "
+ f"Diff:\n{diff_str}"
)
self.assertTrue(
old_code == new_code,
f"{module_name} needs new codegen, see "
+ "`python -m libcst.codegen.generate --help` "
+ "for instructions, or run `python -m libcst.codegen.generate all`",
)
def test_codegen_clean_visitor_functions(self) -> None:
"""
@ -132,50 +123,3 @@ class TestCodegenClean(UnitTest):
# Now that we've done simple codegen, verify that it matches.
self.assert_code_matches(old_code, new_code, "libcst.matchers._return_types")
def test_normalize_unions(self) -> None:
"""
Verifies that NormalizeUnions correctly converts binary operations with |
into Union types, with special handling for Optional cases.
"""
import libcst as cst
from libcst.codegen.gen_matcher_classes import NormalizeUnions
def assert_transforms_to(input_code: str, expected_code: str) -> None:
input_cst = cst.parse_expression(input_code)
expected_cst = cst.parse_expression(expected_code)
result = input_cst.visit(NormalizeUnions())
assert isinstance(
result, cst.BaseExpression
), f"Expected BaseExpression, got {type(result)}"
result_code = cst.Module(body=()).code_for_node(result)
expected_code_str = cst.Module(body=()).code_for_node(expected_cst)
self.assertEqual(
result_code,
expected_code_str,
f"Expected {expected_code_str}, got {result_code}",
)
# Test regular union case
assert_transforms_to("foo | bar | baz", "typing.Union[foo, bar, baz]")
# Test Optional case (None on right)
assert_transforms_to("foo | None", "typing.Optional[foo]")
# Test Optional case (None on left)
assert_transforms_to("None | foo", "typing.Optional[foo]")
# Test case with more than 2 operands including None (should remain Union)
assert_transforms_to("foo | bar | None", "typing.Union[foo, bar, None]")
# Flatten existing Union types
assert_transforms_to(
"typing.Union[foo, typing.Union[bar, baz]]", "typing.Union[foo, bar, baz]"
)
# Merge two kinds of union types
assert_transforms_to(
"foo | typing.Union[bar, baz]", "typing.Union[foo, bar, baz]"
)

View file

@ -8,25 +8,20 @@ Provides helpers for CLI interaction.
"""
import difflib
import functools
import os.path
import re
import subprocess
import sys
import time
import traceback
from concurrent.futures import as_completed, Executor
from copy import deepcopy
from dataclasses import dataclass
from multiprocessing import cpu_count
from pathlib import Path
from typing import AnyStr, Callable, cast, Dict, List, Optional, Sequence, Type, Union
from warnings import warn
from dataclasses import dataclass, replace
from multiprocessing import cpu_count, Pool
from pathlib import Path, PurePath
from typing import Any, AnyStr, cast, Dict, List, Optional, Sequence, Union
from libcst import parse_module, PartialParserConfig
from libcst.codemod._codemod import Codemod
from libcst.codemod._context import CodemodContext
from libcst.codemod._dummy_pool import DummyExecutor
from libcst.codemod._dummy_pool import DummyPool
from libcst.codemod._runner import (
SkipFile,
SkipReason,
@ -37,7 +32,6 @@ from libcst.codemod._runner import (
TransformSkip,
TransformSuccess,
)
from libcst.helpers import calculate_module_and_package
from libcst.metadata import FullRepoManager
_DEFAULT_GENERATED_CODE_MARKER: str = f"@gen{''}erated"
@ -51,7 +45,7 @@ def invoke_formatter(formatter_args: Sequence[str], code: AnyStr) -> AnyStr:
# Make sure there is something to run
if len(formatter_args) == 0:
raise ValueError("No formatter configured but code formatting requested.")
raise Exception("No formatter configured but code formatting requested.")
# Invoke the formatter, giving it the code as stdin and assuming the formatted
# code comes from stdout.
@ -95,10 +89,7 @@ def gather_files(
ret.extend(
str(p)
for p in Path(fd).rglob("*.py*")
if Path.is_file(p)
and (
str(p).endswith("py") or (include_stubs and str(p).endswith("pyi"))
)
if str(p).endswith("py") or (include_stubs and str(p).endswith("pyi"))
)
return sorted(ret)
@ -193,6 +184,30 @@ def exec_transform_with_prettyprint(
return maybe_code
def _calculate_module(repo_root: Optional[str], filename: str) -> Optional[str]:
# Given an absolute repo_root and an absolute filename, calculate the
# python module name for the file.
if repo_root is None:
# We don't have a repo root, so this is impossible to calculate.
return None
try:
relative_filename = PurePath(filename).relative_to(repo_root)
except ValueError:
# This file seems to be out of the repo root.
return None
# get rid of extension
relative_filename = relative_filename.with_suffix("")
# get rid of any special cases
if relative_filename.stem in ["__init__", "__main__"]:
relative_filename = relative_filename.parent
# Now, convert to dots to represent the python module.
return ".".join(relative_filename.parts)
@dataclass(frozen=True)
class ExecutionResult:
# File we have results for
@ -215,52 +230,11 @@ class ExecutionConfig:
unified_diff: Optional[int] = None
def _prepare_context(
repo_root: str,
def _execute_transform( # noqa: C901
transformer: Codemod,
filename: str,
scratch: Dict[str, object],
repo_manager: Optional[FullRepoManager],
) -> CodemodContext:
# determine the module and package name for this file
try:
module_name_and_package = calculate_module_and_package(repo_root, filename)
mod_name = module_name_and_package.name
pkg_name = module_name_and_package.package
except ValueError as ex:
print(f"Failed to determine module name for {filename}: {ex}", file=sys.stderr)
mod_name = None
pkg_name = None
return CodemodContext(
scratch=scratch,
filename=filename,
full_module_name=mod_name,
full_package_name=pkg_name,
metadata_manager=repo_manager,
)
def _instantiate_transformer(
transformer: Union[Codemod, Type[Codemod]],
repo_root: str,
filename: str,
original_scratch: Dict[str, object],
codemod_kwargs: Dict[str, object],
repo_manager: Optional[FullRepoManager],
) -> Codemod:
if isinstance(transformer, type):
return transformer( # type: ignore
context=_prepare_context(repo_root, filename, {}, repo_manager),
**codemod_kwargs,
)
transformer.context = _prepare_context(
repo_root, filename, deepcopy(original_scratch), repo_manager
)
return transformer
def _check_for_skip(
filename: str, config: ExecutionConfig
) -> Union[ExecutionResult, bytes]:
config: ExecutionConfig,
) -> ExecutionResult:
for pattern in config.blacklist_patterns:
if re.fullmatch(pattern, filename):
return ExecutionResult(
@ -272,46 +246,33 @@ def _check_for_skip(
),
)
with open(filename, "rb") as fp:
oldcode = fp.read()
# Skip generated files
if (
not config.include_generated
and config.generated_code_marker.encode("utf-8") in oldcode
):
return ExecutionResult(
filename=filename,
changed=False,
transform_result=TransformSkip(
skip_reason=SkipReason.GENERATED,
skip_description="Generated file.",
),
)
return oldcode
def _execute_transform(
transformer: Union[Codemod, Type[Codemod]],
filename: str,
config: ExecutionConfig,
original_scratch: Dict[str, object],
codemod_args: Optional[Dict[str, object]],
repo_manager: Optional[FullRepoManager],
) -> ExecutionResult:
warnings: list[str] = []
try:
oldcode = _check_for_skip(filename, config)
if isinstance(oldcode, ExecutionResult):
return oldcode
with open(filename, "rb") as fp:
oldcode = fp.read()
transformer_instance = _instantiate_transformer(
transformer,
config.repo_root or ".",
filename,
original_scratch,
codemod_args or {},
repo_manager,
# Skip generated files
if (
not config.include_generated
and config.generated_code_marker.encode("utf-8") in oldcode
):
return ExecutionResult(
filename=filename,
changed=False,
transform_result=TransformSkip(
skip_reason=SkipReason.GENERATED,
skip_description="Generated file.",
),
)
# Somewhat gross hack to provide the filename in the transform's context.
# We do this after the fork so that a context that was initialized with
# some defaults before calling parallel_exec_transform_with_prettyprint
# will be updated per-file.
transformer.context = replace(
transformer.context,
filename=filename,
full_module_name=_calculate_module(config.repo_root, filename),
scratch={},
)
# Run the transform, bail if we failed or if we aren't formatting code
@ -324,26 +285,55 @@ def _execute_transform(
else PartialParserConfig()
),
)
output_tree = transformer_instance.transform_module(input_tree)
output_tree = transformer.transform_module(input_tree)
newcode = output_tree.bytes
encoding = output_tree.encoding
warnings.extend(transformer_instance.context.warnings)
except KeyboardInterrupt:
return ExecutionResult(
filename=filename, changed=False, transform_result=TransformExit()
)
except SkipFile as ex:
warnings.extend(transformer_instance.context.warnings)
return ExecutionResult(
filename=filename,
changed=False,
transform_result=TransformSkip(
skip_reason=SkipReason.OTHER,
skip_description=str(ex),
warning_messages=warnings,
warning_messages=transformer.context.warnings,
),
)
except Exception as ex:
return ExecutionResult(
filename=filename,
changed=False,
transform_result=TransformFailure(
error=ex,
traceback_str=traceback.format_exc(),
warning_messages=transformer.context.warnings,
),
)
# Call formatter if needed, but only if we actually changed something in this
# file
if config.format_code and newcode != oldcode:
newcode = invoke_formatter(config.formatter_args, newcode)
try:
newcode = invoke_formatter(config.formatter_args, newcode)
except KeyboardInterrupt:
return ExecutionResult(
filename=filename,
changed=False,
transform_result=TransformExit(),
)
except Exception as ex:
return ExecutionResult(
filename=filename,
changed=False,
transform_result=TransformFailure(
error=ex,
traceback_str=traceback.format_exc(),
warning_messages=transformer.context.warnings,
),
)
# Format as unified diff if needed, otherwise save it back
changed = oldcode != newcode
@ -366,14 +356,13 @@ def _execute_transform(
return ExecutionResult(
filename=filename,
changed=changed,
transform_result=TransformSuccess(warning_messages=warnings, code=newcode),
transform_result=TransformSuccess(
warning_messages=transformer.context.warnings, code=newcode
),
)
except KeyboardInterrupt:
return ExecutionResult(
filename=filename,
changed=False,
transform_result=TransformExit(warning_messages=warnings),
filename=filename, changed=False, transform_result=TransformExit()
)
except Exception as ex:
return ExecutionResult(
@ -382,7 +371,7 @@ def _execute_transform(
transform_result=TransformFailure(
error=ex,
traceback_str=traceback.format_exc(),
warning_messages=warnings,
warning_messages=transformer.context.warnings,
),
)
@ -435,7 +424,7 @@ class Progress:
operations still to do.
"""
if files_finished <= 0 or elapsed_seconds == 0:
if files_finished <= 0:
# Technically infinite but calculating sounds better.
return "[calculating]"
@ -493,7 +482,7 @@ def _print_parallel_result(
)
# In unified diff mode, the code is a diff we must print.
if unified_diff and result.code:
if unified_diff:
print(result.code)
@ -519,8 +508,15 @@ class ParallelTransformResult:
skips: int
# Unfortunate wrapper required since there is no `istarmap_unordered`...
def _execute_transform_wrap(
job: Dict[str, Any],
) -> ExecutionResult:
return _execute_transform(**job)
def parallel_exec_transform_with_prettyprint( # noqa: C901
transform: Union[Codemod, Type[Codemod]],
transform: Codemod,
files: Sequence[str],
*,
jobs: Optional[int] = None,
@ -536,52 +532,41 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
blacklist_patterns: Sequence[str] = (),
python_version: Optional[str] = None,
repo_root: Optional[str] = None,
codemod_args: Optional[Dict[str, object]] = None,
) -> ParallelTransformResult:
"""
Given a list of files and a codemod we should apply to them, fork and apply the
codemod in parallel to all of the files, including any configured formatter. The
``jobs`` parameter controls the maximum number of in-flight transforms, and needs to
be at least 1. If not included, the number of jobs will automatically be set to the
number of CPU cores. If ``unified_diff`` is set to a number, changes to files will
be printed to stdout with ``unified_diff`` lines of context. If it is set to
``None`` or left out, files themselves will be updated with changes and formatting.
If a ``python_version`` is provided, then we will parse each source file using this
version. Otherwise, we will use the version of the currently executing python
Given a list of files and an instantiated codemod we should apply to them,
fork and apply the codemod in parallel to all of the files, including any
configured formatter. The ``jobs`` parameter controls the maximum number of
in-flight transforms, and needs to be at least 1. If not included, the number
of jobs will automatically be set to the number of CPU cores. If ``unified_diff``
is set to a number, changes to files will be printed to stdout with
``unified_diff`` lines of context. If it is set to ``None`` or left out, files
themselves will be updated with changes and formatting. If a
``python_version`` is provided, then we will parse each source file using
this version. Otherwise, we will use the version of the currently executing python
binary.
A progress indicator as well as any generated warnings will be printed to stderr. To
supress the interactive progress indicator, set ``hide_progress`` to ``True``. Files
that include the generated code marker will be skipped unless the
``include_generated`` parameter is set to ``True``. Similarly, files that match a
supplied blacklist of regex patterns will be skipped. Warnings for skipping both
blacklisted and generated files will be printed to stderr along with warnings
generated by the codemod unless ``hide_blacklisted`` and ``hide_generated`` are set
to ``True``. Files that were successfully codemodded will not be printed to stderr
unless ``show_successes`` is set to ``True``.
A progress indicator as well as any generated warnings will be printed to stderr.
To supress the interactive progress indicator, set ``hide_progress`` to ``True``.
Files that include the generated code marker will be skipped unless the
``include_generated`` parameter is set to ``True``. Similarly, files that match
a supplied blacklist of regex patterns will be skipped. Warnings for skipping
both blacklisted and generated files will be printed to stderr along with
warnings generated by the codemod unless ``hide_blacklisted`` and
``hide_generated`` are set to ``True``. Files that were successfully codemodded
will not be printed to stderr unless ``show_successes`` is set to ``True``.
We take a :class:`~libcst.codemod._codemod.Codemod` class, or an instantiated
:class:`~libcst.codemod._codemod.Codemod`. In the former case, the codemod will be
instantiated for each file, with ``codemod_args`` passed in to the constructor.
Passing an already instantiated :class:`~libcst.codemod._codemod.Codemod` is
deprecated, because it leads to sharing of the
:class:`~libcst.codemod._codemod.Codemod` instance across files, which is a common
source of hard-to-track-down bugs when the :class:`~libcst.codemod._codemod.Codemod`
tracks its state on the instance.
To make this API possible, we take an instantiated transform. This is due to
the fact that lambdas are not pickleable and pickling functions is undefined.
This means we're implicitly relying on fork behavior on UNIX-like systems, and
this function will not work on Windows systems. To create a command-line utility
that runs on Windows, please instead see
:func:`~libcst.codemod.exec_transform_with_prettyprint`.
"""
if isinstance(transform, Codemod):
warn(
"Passing transformer instances to `parallel_exec_transform_with_prettyprint` "
"is deprecated and will break in a future version. "
"Please pass the transformer class instead.",
DeprecationWarning,
stacklevel=2,
)
# Ensure that we have no duplicates, otherwise we might get race conditions
# on write.
files = sorted({os.path.abspath(f) for f in files})
files = sorted(list({os.path.abspath(f) for f in files}))
total = len(files)
progress = Progress(enabled=not hide_progress, total=total)
@ -593,12 +578,11 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
)
if jobs < 1:
raise ValueError("Must have at least one job to process!")
raise Exception("Must have at least one job to process!")
if total == 0:
return ParallelTransformResult(successes=0, failures=0, skips=0, warnings=0)
metadata_manager: Optional[FullRepoManager] = None
if repo_root is not None:
# Make sure if there is a root that we have the absolute path to it.
repo_root = os.path.abspath(repo_root)
@ -611,7 +595,10 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
transform.get_inherited_dependencies(),
)
metadata_manager.resolve_cache()
transform.context = replace(
transform.context,
metadata_manager=metadata_manager,
)
print("Executing codemod...", file=sys.stderr)
config = ExecutionConfig(
@ -625,16 +612,13 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
python_version=python_version,
)
pool_impl: Callable[[], Executor]
if total == 1 or jobs == 1:
# Simple case, we should not pay for process overhead.
# Let's just use a dummy synchronous executor.
# Let's just use a dummy synchronous pool.
jobs = 1
pool_impl = DummyExecutor
elif getattr(sys, "_is_gil_enabled", lambda: True)(): # pyre-ignore[16]
from concurrent.futures import ProcessPoolExecutor
pool_impl = functools.partial(ProcessPoolExecutor, max_workers=jobs)
pool_impl = DummyPool
else:
pool_impl = Pool
# Warm the parser, pre-fork.
parse_module(
"",
@ -644,35 +628,25 @@ def parallel_exec_transform_with_prettyprint( # noqa: C901
else PartialParserConfig()
),
)
else:
from concurrent.futures import ThreadPoolExecutor
pool_impl = functools.partial(ThreadPoolExecutor, max_workers=jobs)
successes: int = 0
failures: int = 0
warnings: int = 0
skips: int = 0
original_scratch = (
deepcopy(transform.context.scratch) if isinstance(transform, Codemod) else {}
)
with pool_impl() as executor: # type: ignore
with pool_impl(processes=jobs) as p: # type: ignore
args = [
{
"transformer": transform,
"filename": filename,
"config": config,
}
for filename in files
]
try:
futures = [
executor.submit(
_execute_transform,
transformer=transform,
filename=filename,
config=config,
original_scratch=original_scratch,
codemod_args=codemod_args,
repo_manager=metadata_manager,
)
for filename in files
]
for future in as_completed(futures):
result = future.result()
for result in p.imap_unordered(
_execute_transform_wrap, args, chunksize=chunksize
):
# Print an execution result, keep track of failures
_print_parallel_result(
result,

View file

@ -56,9 +56,9 @@ class Codemod(MetadataDependent, ABC):
"""
module = self.context.module
if module is None:
raise ValueError(
raise Exception(
f"Attempted access of {self.__class__.__name__}.module outside of "
"transform_module()."
+ "transform_module()."
)
return module

View file

@ -3,14 +3,12 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
#
from __future__ import annotations
import argparse
import inspect
from abc import ABC, abstractmethod
from typing import Dict, Generator, List, Tuple, Type, TypeVar
from typing import Dict, Generator, List, Type, TypeVar
from libcst import CSTNode, Module
from libcst import Module
from libcst.codemod._codemod import Codemod
from libcst.codemod._context import CodemodContext
from libcst.codemod._visitor import ContextAwareTransformer
@ -67,28 +65,6 @@ class CodemodCommand(Codemod, ABC):
"""
...
# Lightweight wrappers for RemoveImportsVisitor static functions
def remove_unused_import(
self,
module: str,
obj: str | None = None,
asname: str | None = None,
) -> None:
RemoveImportsVisitor.remove_unused_import(self.context, module, obj, asname)
def remove_unused_import_by_node(self, node: CSTNode) -> None:
RemoveImportsVisitor.remove_unused_import_by_node(self.context, node)
# Lightweight wrappers for AddImportsVisitor static functions
def add_needed_import(
self,
module: str,
obj: str | None = None,
asname: str | None = None,
relative: int = 0,
) -> None:
AddImportsVisitor.add_needed_import(self.context, module, obj, asname, relative)
def transform_module(self, tree: Module) -> Module:
# Overrides (but then calls) Codemod's transform_module to provide
# a spot where additional supported transforms can be attached and run.
@ -99,13 +75,13 @@ class CodemodCommand(Codemod, ABC):
# have a static method that other transforms can use which takes
# a context and other optional args and modifies its own context key
# accordingly. We import them here so that we don't have circular imports.
supported_transforms: List[Tuple[str, Type[Codemod]]] = [
(AddImportsVisitor.CONTEXT_KEY, AddImportsVisitor),
(RemoveImportsVisitor.CONTEXT_KEY, RemoveImportsVisitor),
]
supported_transforms: Dict[str, Type[Codemod]] = {
AddImportsVisitor.CONTEXT_KEY: AddImportsVisitor,
RemoveImportsVisitor.CONTEXT_KEY: RemoveImportsVisitor,
}
# For any visitors that we support auto-running, run them here if needed.
for key, transform in supported_transforms:
for key, transform in supported_transforms.items():
if key in self.context.scratch:
# We have work to do, so lets run this.
tree = self._instantiate_and_run(transform, tree)

View file

@ -44,12 +44,6 @@ class CodemodContext:
#: in the repo named ``foo/bar/baz.py``.
full_module_name: Optional[str] = None
#: The current package if a codemod is being executed against a file that
#: lives on disk, and the repository root is correctly configured. This
#: Will take the form of a dotted name such as ``foo.bar`` for a file
#: in the repo named ``foo/bar/baz.py``
full_package_name: Optional[str] = None
#: The current top level metadata wrapper for the module being modified.
#: To access computed metadata when inside an actively running codemod, use
#: the :meth:`~libcst.MetadataDependent.get_metadata` method on

View file

@ -3,47 +3,37 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import sys
from concurrent.futures import Executor, Future
from types import TracebackType
from typing import Callable, Optional, Type, TypeVar
from typing import Callable, Generator, Iterable, Optional, Type, TypeVar
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
Return = TypeVar("Return")
Params = ParamSpec("Params")
RetT = TypeVar("RetT")
ArgT = TypeVar("ArgT")
class DummyExecutor(Executor):
class DummyPool:
"""
Synchronous dummy `concurrent.futures.Executor` analogue.
Synchronous dummy `multiprocessing.Pool` analogue.
"""
def submit(
def __init__(self, processes: Optional[int] = None) -> None:
pass
def imap_unordered(
self,
fn: Callable[Params, Return],
/,
*args: Params.args,
**kwargs: Params.kwargs,
) -> Future[Return]:
future: Future[Return] = Future()
try:
result = fn(*args, **kwargs)
future.set_result(result)
except Exception as exc:
future.set_exception(exc)
return future
func: Callable[[ArgT], RetT],
iterable: Iterable[ArgT],
chunksize: Optional[int] = None,
) -> Generator[RetT, None, None]:
for args in iterable:
yield func(args)
def __enter__(self) -> "DummyExecutor":
def __enter__(self) -> "DummyPool":
return self
def __exit__(
self,
exc_type: Optional[Type[BaseException]],
exc_val: Optional[BaseException],
exc_tb: Optional[TracebackType],
exc_type: Optional[Type[Exception]],
exc: Optional[Exception],
tb: Optional[TracebackType],
) -> None:
pass

Some files were not shown because too many files have changed in this diff Show more