Repository: MasoniteFramework/orm Branch: 2.0 Commit: 0d31e53f1881 Files: 224 Total size: 1020.6 KB Directory structure: gitextract_q9lmgekj/ ├── .deepsource.toml ├── .env-example ├── .envrc ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── feature_request.md │ └── workflows/ │ ├── pythonapp.yml │ └── pythonpublish.yml ├── .gitignore ├── .pre-commit-config.yaml ├── .pypirc ├── .tool-versions ├── CONTRIBUTING.md ├── LICENSE ├── MANIFEST.in ├── README.md ├── TODO.md ├── app/ │ └── observers/ │ └── UserObserver.py ├── cc.py ├── conda/ │ ├── conda_build_config.yaml │ └── meta.yaml ├── config/ │ └── test-database.py ├── databases/ │ ├── migrations/ │ │ ├── 2018_01_09_043202_create_users_table.py │ │ ├── 2020_04_17_000000_create_friends_table.py │ │ ├── 2020_04_17_00000_create_articles_table.py │ │ ├── 2020_10_20_152904_create_table_schema_migration.py │ │ └── __init__.py │ └── seeds/ │ ├── database_seeder.py │ └── user_table_seeder.py ├── makefile ├── orm ├── pyproject.toml ├── pytest.ini ├── requirements.dev ├── requirements.txt ├── setup.py ├── src/ │ └── masoniteorm/ │ ├── .gitignore │ ├── __init__.py │ ├── collection/ │ │ ├── Collection.py │ │ └── __init__.py │ ├── commands/ │ │ ├── CanOverrideConfig.py │ │ ├── CanOverrideOptionsDefault.py │ │ ├── Command.py │ │ ├── Entry.py │ │ ├── MakeMigrationCommand.py │ │ ├── MakeModelCommand.py │ │ ├── MakeModelDocstringCommand.py │ │ ├── MakeObserverCommand.py │ │ ├── MakeSeedCommand.py │ │ ├── MigrateCommand.py │ │ ├── MigrateFreshCommand.py │ │ ├── MigrateRefreshCommand.py │ │ ├── MigrateResetCommand.py │ │ ├── MigrateRollbackCommand.py │ │ ├── MigrateStatusCommand.py │ │ ├── SeedRunCommand.py │ │ ├── ShellCommand.py │ │ ├── __init__.py │ │ └── stubs/ │ │ ├── create_migration.stub │ │ ├── create_seed.stub │ │ ├── model.stub │ │ ├── observer.stub │ │ └── table_migration.stub │ ├── config.py │ ├── connections/ │ │ ├── .gitignore │ │ ├── BaseConnection.py │ │ ├── ConnectionFactory.py │ │ ├── 
ConnectionResolver.py │ │ ├── MSSQLConnection.py │ │ ├── MySQLConnection.py │ │ ├── PostgresConnection.py │ │ ├── SQLiteConnection.py │ │ └── __init__.py │ ├── exceptions.py │ ├── expressions/ │ │ ├── __init__.py │ │ └── expressions.py │ ├── factories/ │ │ ├── Factory.py │ │ └── __init__.py │ ├── helpers/ │ │ ├── __init__.py │ │ └── misc.py │ ├── migrations/ │ │ ├── Migration.py │ │ └── __init__.py │ ├── models/ │ │ ├── MigrationModel.py │ │ ├── Model.py │ │ ├── Model.pyi │ │ ├── Pivot.py │ │ └── __init__.py │ ├── observers/ │ │ ├── ObservesEvents.py │ │ └── __init__.py │ ├── pagination/ │ │ ├── BasePaginator.py │ │ ├── LengthAwarePaginator.py │ │ ├── SimplePaginator.py │ │ └── __init__.py │ ├── providers/ │ │ ├── ORMProvider.py │ │ └── __init__.py │ ├── query/ │ │ ├── EagerRelation.py │ │ ├── QueryBuilder.py │ │ ├── __init__.py │ │ ├── grammars/ │ │ │ ├── BaseGrammar.py │ │ │ ├── MSSQLGrammar.py │ │ │ ├── MySQLGrammar.py │ │ │ ├── PostgresGrammar.py │ │ │ ├── SQLiteGrammar.py │ │ │ └── __init__.py │ │ └── processors/ │ │ ├── MSSQLPostProcessor.py │ │ ├── MySQLPostProcessor.py │ │ ├── PostgresPostProcessor.py │ │ ├── SQLitePostProcessor.py │ │ └── __init__.py │ ├── relationships/ │ │ ├── BaseRelationship.py │ │ ├── BelongsTo.py │ │ ├── BelongsToMany.py │ │ ├── HasMany.py │ │ ├── HasManyThrough.py │ │ ├── HasOne.py │ │ ├── HasOneThrough.py │ │ ├── MorphMany.py │ │ ├── MorphOne.py │ │ ├── MorphTo.py │ │ ├── MorphToMany.py │ │ └── __init__.py │ ├── schema/ │ │ ├── Blueprint.py │ │ ├── Column.py │ │ ├── ColumnDiff.py │ │ ├── Constraint.py │ │ ├── ForeignKeyConstraint.py │ │ ├── Index.py │ │ ├── Schema.py │ │ ├── Table.py │ │ ├── TableDiff.py │ │ ├── __init__.py │ │ └── platforms/ │ │ ├── MSSQLPlatform.py │ │ ├── MySQLPlatform.py │ │ ├── Platform.py │ │ ├── PostgresPlatform.py │ │ ├── SQLitePlatform.py │ │ └── __init__.py │ ├── scopes/ │ │ ├── BaseScope.py │ │ ├── SoftDeleteScope.py │ │ ├── SoftDeletesMixin.py │ │ ├── TimeStampsMixin.py │ │ ├── TimeStampsScope.py │ │ 
├── UUIDPrimaryKeyMixin.py │ │ ├── UUIDPrimaryKeyScope.py │ │ ├── __init__.py │ │ └── scope.py │ ├── seeds/ │ │ ├── Seeder.py │ │ └── __init__.py │ ├── stubs/ │ │ ├── create-migration.html │ │ └── table-migration.html │ └── testing/ │ ├── BaseTestCaseSelectGrammar.py │ └── __init__.py └── tests/ ├── User.py ├── collection/ │ └── test_collection.py ├── commands/ │ └── test_shell.py ├── config/ │ └── test_db_url.py ├── connections/ │ └── test_base_connections.py ├── eagers/ │ └── test_eager.py ├── factories/ │ └── test_factories.py ├── integrations/ │ └── config/ │ ├── __init__.py │ └── database.py ├── models/ │ └── test_models.py ├── mssql/ │ ├── builder/ │ │ ├── test_mssql_query_builder.py │ │ └── test_mssql_query_builder_relationships.py │ ├── grammar/ │ │ ├── test_mssql_delete_grammar.py │ │ ├── test_mssql_insert_grammar.py │ │ ├── test_mssql_qmark.py │ │ ├── test_mssql_select_grammar.py │ │ └── test_mssql_update_grammar.py │ └── schema/ │ ├── test_mssql_schema_builder.py │ └── test_mssql_schema_builder_alter.py ├── mysql/ │ ├── builder/ │ │ ├── test_mysql_builder_transaction.py │ │ ├── test_query_builder.py │ │ ├── test_query_builder_scopes.py │ │ └── test_transactions.py │ ├── connections/ │ │ └── test_mysql_connection_selects.py │ ├── grammar/ │ │ ├── test_mysql_delete_grammar.py │ │ ├── test_mysql_insert_grammar.py │ │ ├── test_mysql_qmark.py │ │ ├── test_mysql_select_grammar.py │ │ └── test_mysql_update_grammar.py │ ├── model/ │ │ ├── test_accessors_and_mutators.py │ │ └── test_model.py │ ├── relationships/ │ │ ├── test_belongs_to_many.py │ │ ├── test_has_many_through.py │ │ ├── test_has_one_through.py │ │ └── test_relationships.py │ ├── schema/ │ │ ├── test_mysql_schema_builder.py │ │ └── test_mysql_schema_builder_alter.py │ └── scopes/ │ ├── test_can_use_global_scopes.py │ ├── test_can_use_scopes.py │ └── test_soft_delete.py ├── postgres/ │ ├── builder/ │ │ ├── test_postgres_query_builder.py │ │ └── test_postgres_transaction.py │ ├── grammar/ │ │ ├── 
test_delete_grammar.py │ │ ├── test_insert_grammar.py │ │ ├── test_select_grammar.py │ │ └── test_update_grammar.py │ ├── relationships/ │ │ └── test_postgres_relationships.py │ └── schema/ │ ├── test_postgres_schema_builder.py │ └── test_postgres_schema_builder_alter.py ├── scopes/ │ └── test_default_global_scopes.py ├── seeds/ │ └── test_seeds.py ├── sqlite/ │ ├── builder/ │ │ ├── test_sqlite_builder_insert.py │ │ ├── test_sqlite_builder_pagination.py │ │ ├── test_sqlite_query_builder.py │ │ ├── test_sqlite_query_builder_eager_loading.py │ │ ├── test_sqlite_query_builder_relationships.py │ │ └── test_sqlite_transaction.py │ ├── grammar/ │ │ ├── test_sqlite_delete_grammar.py │ │ ├── test_sqlite_insert_grammar.py │ │ ├── test_sqlite_select_grammar.py │ │ └── test_sqlite_update_grammar.py │ ├── models/ │ │ ├── test_attach_detach.py │ │ ├── test_observers.py │ │ └── test_sqlite_model.py │ ├── relationships/ │ │ ├── test_sqlite_has_many_through_relationship.py │ │ ├── test_sqlite_has_one_through_relationship.py │ │ ├── test_sqlite_polymorphic.py │ │ └── test_sqlite_relationships.py │ └── schema/ │ ├── test_sqlite_schema_builder.py │ ├── test_sqlite_schema_builder_alter.py │ ├── test_table.py │ └── test_table_diff.py └── utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .deepsource.toml ================================================ # generated by deepsource.io version = 1 test_patterns = [ 'tests/**/*.py' ] exclude_patterns = [ 'databases/migrations/*' ] [[analyzers]] name = "python" enabled = true runtime_version = "3.x.x" ================================================ FILE: .env-example ================================================ RUN_MYSQL_DATABASE=False MYSQL_DATABASE_HOST= MYSQL_DATABASE_USER= MYSQL_DATABASE_PASSWORD= MYSQL_DATABASE_DATABASE= MYSQL_DATABASE_PORT= POSTGRES_DATABASE_HOST= POSTGRES_DATABASE_USER= 
POSTGRES_DATABASE_PASSWORD= POSTGRES_DATABASE_DATABASE= POSTGRES_DATABASE_PORT= DATABASE_URL= ================================================ FILE: .envrc ================================================ use asdf layout python ================================================ FILE: .github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: A bug would be defined as an issue / problem in the original requirement. If the feature works but could be enhanced please use the feature request option. title: '' labels: 'bug' assignees: '' --- **Describe the bug** A clear and concise description of what the bug is. **To Reproduce** Steps to reproduce the behavior: 1. Go to '...' 2. Click on '....' 3. Scroll down to '....' 4. See error **Expected behavior** What do you believe should be happening? **Screenshots or code snippets** Screenshots help a lot. If applicable, add screenshots to help explain your problem. **Desktop (please complete the following information):** - OS: [e.g. Mac OSX, Windows] - Version [e.g. Big Sur, 10] **What database are you using?** - Type: [e.g. Postgres, MySQL, SQLite] - Version [e.g. 8, 9.1, 10.5] - Masonite ORM [e.g. v1.0.26, v1.0.27] **Additional context** Any other steps you are doing or any other related information that will help us debug the problem please put here. ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request or enhancement about: Suggest an idea or improvement for this project. title: '' labels: enhancement, feature request assignees: '' --- **Describe the feature as you'd like to see it** A clear and concise description of what you want to happen. **What do we currently have to do now?** Give some examples or code snippets on the current way of doing things. **Additional context** Add any other context or screenshots about the feature request here. 
- [ ] Is this a breaking change? ================================================ FILE: .github/workflows/pythonapp.yml ================================================ name: Test Application on: [push, pull_request] jobs: build: runs-on: ubuntu-20.04 services: postgres: image: postgres:10.8 env: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres POSTGRES_DB: postgres ports: # will assign a random free host port - 5432/tcp # needed because the postgres container does not provide a healthcheck options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 mysql: image: mysql:5.7 env: MYSQL_ALLOW_EMPTY_PASSWORD: yes MYSQL_DATABASE: orm ports: - 3306 options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 strategy: matrix: python-version: ["3.6", "3.7", "3.8", "3.9"] name: Python ${{ matrix.python-version }} steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | make init-ci - name: Test with pytest env: POSTGRES_DATABASE_HOST: localhost POSTGRES_DATABASE_DATABASE: postgres POSTGRES_DATABASE_USER: postgres POSTGRES_DATABASE_PASSWORD: postgres POSTGRES_DATABASE_PORT: ${{ job.services.postgres.ports[5432] }} MYSQL_DATABASE_HOST: localhost MYSQL_DATABASE_DATABASE: orm MYSQL_DATABASE_USER: root MYSQL_DATABASE_PORT: ${{ job.services.mysql.ports[3306] }} DB_CONFIG_PATH: tests/integrations/config/database.py run: | python orm migrate --connection postgres python orm migrate --connection mysql make test lint: runs-on: ubuntu-20.04 name: Lint steps: - uses: actions/checkout@v1 - name: Set up Python 3.6 uses: actions/setup-python@v4 with: python-version: 3.6 - name: Install Flake8 run: | pip install flake8-pyproject - name: Lint run: make lint ================================================ FILE: .github/workflows/pythonpublish.yml 
================================================ name: Upload Python Package on: release: types: [created] jobs: build: runs-on: ubuntu-20.04 services: postgres: image: postgres:10.8 env: POSTGRES_USER: postgres POSTGRES_PASSWORD: postgres POSTGRES_DB: postgres ports: # will assign a random free host port - 5432/tcp # needed because the postgres container does not provide a healthcheck options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 mysql: image: mysql:5.7 env: MYSQL_ALLOW_EMPTY_PASSWORD: yes MYSQL_DATABASE: orm ports: - 3306 options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 strategy: matrix: python-version: ["3.6"] name: Python ${{ matrix.python-version }} steps: - uses: actions/checkout@v1 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | make init-ci - name: Test with Pytest and Publish to PYPI env: POSTGRES_DATABASE_HOST: localhost POSTGRES_DATABASE_DATABASE: postgres POSTGRES_DATABASE_USER: postgres POSTGRES_DATABASE_PASSWORD: postgres POSTGRES_DATABASE_PORT: ${{ job.services.postgres.ports[5432] }} MYSQL_DATABASE_HOST: localhost MYSQL_DATABASE_DATABASE: orm MYSQL_DATABASE_USER: root MYSQL_DATABASE_PORT: ${{ job.services.mysql.ports[3306] }} TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} DB_CONFIG_PATH: tests/integrations/config/database.py run: | python orm migrate --connection postgres python orm migrate --connection mysql make publish - name: Discord notification env: DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK }} uses: Ilshidur/action-discord@master with: args: "{{ EVENT_PAYLOAD.repository.full_name }} {{ EVENT_PAYLOAD.release.tag_name }} has been released. 
Checkout the full release notes here: {{ EVENT_PAYLOAD.release.html_url }}" publish: runs-on: ubuntu-latest steps: - uses: actions/checkout@v2 - name: publish-to-conda uses: fcakyon/conda-publish-action@v1.3 with: subdir: "conda" anacondatoken: ${{ secrets.ANACONDA_TOKEN }} platforms: "win osx linux" ================================================ FILE: .gitignore ================================================ venv .direnv .python-version .vscode .pytest_* **/*__pycache__* **/*.DS_Store* masonite_validation* dist .env *.db *.sqlite3 .idea **/*.egg-info htmlcov/* coverage.xml .coverage *.log build /orm.sqlite3 /.bootstrapped-pip /.ignore-pre-commit ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/psf/black rev: 25.1.0 hooks: - id: black exclude: | (?x)( ^build| ^conda ) - repo: https://github.com/pycqa/isort rev: 6.0.1 hooks: - id: isort exclude: | (?x)( ^build| ^conda ) - repo: https://github.com/pycqa/flake8 rev: 7.1.2 hooks: - id: flake8 additional_dependencies: [flake8-pyproject] exclude: | (?x)( ^build| ^conda ) ================================================ FILE: .pypirc ================================================ [distutils] index-servers = pypi pypitest [pypi] username=username password=password ================================================ FILE: .tool-versions ================================================ python 3.8.10 ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing Guide This guide is intended to explain how to contribute to this project. ## Preface Note that you do not need to write code in order to contribute to the project. You can contribute your voice, your ideas, past experiences or just join general discussions we are having in GitHub or the Slack channel. Whether its 1 hour per day or 1 minute per week. 
All input and ideas are important for the success of the project. That one sentence could lead to more discussion and ideas. If you have any questions at all then be sure to join the [Slack Channel](https://slack.masoniteproject.com). If you are interested in the project then it would be a great idea to read the "White Paper". This is a document about how the project works and how the classes all work together. The White Paper can be [Found Here](https://orm.masoniteproject.com/white-page). ## Issues Everything really should start with opening an issue or finding an issue. If you feel you have an idea for how the project can be improved, no matter how small, you should open an issue so we can have an open discussion with the maintainers of the project. In that issue we can discuss the solution to the problem or the feature you have in mind. If we do not feel it fits within the project then we will close the issue. Feel free to open a new issue if new information comes up. If there is already an issue open that you want to contribute ideas to, have information to add to the discussion, or want to contribute to the issue by writing code to complete the issue then please comment on the issue saying you would like to contribute to it. ## Labels To improve the quality of issues, please add any related labels to the issue you think are most relevant. You may add as many as you think make sense. There are tag descriptions in the labels section of the repo so please read those descriptions to choose which labels best fit the issue. **Please do not use any of the difficulty labels (easy, medium or hard). A maintainer will label the issue with the difficulty level after reviewing the issue.** ## Difficulty Levels Before contributing, it is assumed you have basic Python or programming skills and you are able to understand the issues well enough to discuss them without much direction.
All issues are marked with a difficulty level to indicate how much effort will be involved in closing the issue. There are several difficulty levels: **good first issue** - Issues marked with this label are great issues to take if you have never contributed to open source before. These issues typically include a step-by-step solution in the issue and are intended for first-time contributors, to expand the pool of maintainers. **easy** - Issues marked as easy are great issues to take if you have never contributed to this project before. Use this opportunity to pick up a simple issue and learn how some of the code works together and how a simple test is written. **medium** - Issues marked as medium should not be worked on by someone who has not contributed to the project before. These issues assume you have basic knowledge of the codebase and can work on the issue with little direction. These issues should be discussed to agree on the best way to solve and close them. **hard** - These issues should really not be worked on unless you are a maintainer of the Masonite organization. These issues are very involved and assume advanced knowledge of the codebase. You may contribute your voice to the issue, but it is not advised that you work on these issues unless you are a maintainer or have contributed to the project in the past and have completed a medium difficulty task. ## Pull Request Flow If you choose to contribute to an issue via code contribution then please follow the steps below: * First you will need to fork the repository. You can do this directly in GitHub by clicking the fork icon in the top right corner of the repository. * You should then clone your fork to your computer. * Make the code change and push your changes to a branch on your fork. * **The branch should follow a common naming convention. If the issue is #123 then your branch should be called `feature/123`.
This helps us identify which issue the branch is supposed to fix.** * You should then open a pull request against this repository. * **All tests are required to be written before merging a pull request.** If you do not know how to write tests you can open the pull request without tests and we can discuss the best way to test the code you wrote. A maintainer or contributor could also step in and write tests for you. Once the pull request is open, the code will be reviewed and we will discuss how this particular solution to the problem solves the original issue. If there are code improvements or corrections to be made then they will be discussed with maintainers of the project. ## Running Tests You should run all tests locally and make sure they pass before writing any code. This way you can tell whether a failing test was broken by your change or was already failing for other reasons. You should set up a virtual environment, install the dependencies, and run tests via pytest: ``` $ python -m venv venv $ source venv/bin/activate $ pip install -r requirements.txt -r requirements.dev $ python -m pytest ``` This should run all tests successfully. The code was written so that you do not need a database to run the tests, so all tests should run fine. ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2020 Joseph Mancuso Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: MANIFEST.in ================================================ # include src/package/some/directory/* include src/masoniteorm/commands/stubs/* ================================================ FILE: README.md ================================================

Masonite ORM

Masonite Package Python Version GitHub release (latest by date) License Code style: black

## Installation & Usage All documentation can be found here [https://orm.masoniteproject.com](https://orm.masoniteproject.com). Hop on [Masonite Discord Community](https://discord.gg/TwKeFahmPZ) to ask any questions you need! ## Contributing If you would like to contribute please read the [Contributing Documentation](CONTRIBUTING.md) here. ## License Masonite ORM is open-sourced software licensed under the [MIT License](LICENSE). ================================================ FILE: TODO.md ================================================ - [x] fix scopes - need to find a new way to perform scopes - [x] scopes need to be set on the model and then passed off to the query builder - [x] global scopes - on select need to call a scope - on delete need to call a scope - need to be able to remove global scopes - need to be able to able to call something like with_trashed() - this needs to remove global scopes only from the soft deletes class ================================================ FILE: app/observers/UserObserver.py ================================================ """Observer hooks for the User model. Every handler below is currently a no-op stub (its body is just ``pass``).""" from masoniteorm.models import Model class UserObserver: def created(self, clients): """Handle the User "created" event. Args: clients (masoniteorm.models.Model): User model. """ pass def creating(self, clients): """Handle the User "creating" event. Args: clients (masoniteorm.models.Model): User model. """ pass def saving(self, clients): """Handle the User "saving" event. Args: clients (masoniteorm.models.Model): User model. """ pass def saved(self, clients): """Handle the User "saved" event. Args: clients (masoniteorm.models.Model): User model. """ pass def updating(self, clients): """Handle the User "updating" event. Args: clients (masoniteorm.models.Model): User model. """ pass def updated(self, clients): """Handle the User "updated" event. Args: clients (masoniteorm.models.Model): User model.
""" pass def booted(self, clients): """Handle the User "booted" event. Args: clients (masoniteorm.models.Model): User model. """ pass def booting(self, clients): """Handle the User "booting" event. Args: clients (masoniteorm.models.Model): User model. """ pass def hydrating(self, clients): """Handle the User "hydrating" event. Args: clients (masoniteorm.models.Model): User model. """ pass def hydrated(self, clients): """Handle the User "hydrated" event. Args: clients (masoniteorm.models.Model): User model. """ pass def deleting(self, clients): """Handle the User "deleting" event. Args: clients (masoniteorm.models.Model): User model. """ pass def deleted(self, clients): """Handle the User "deleted" event. Args: clients (masoniteorm.models.Model): User model. """ pass ================================================ FILE: cc.py ================================================ """Sandbox file used to quickly try out features of the package. NOTE(review): this is scratch code, not part of the package. `User.articles` returns `Article`, which is not defined in this file, and the "t" connection is presumably a local-only config entry; confirm both before running. """ from src.masoniteorm.query import QueryBuilder from src.masoniteorm.connections import MySQLConnection, PostgresConnection from src.masoniteorm.query.grammars import MySQLGrammar, PostgresGrammar from src.masoniteorm.models import Model from src.masoniteorm.relationships import has_many import inspect # builder = QueryBuilder(connection=PostgresConnection, grammar=PostgresGrammar).table("users").on("postgres") # print(builder.where("id", 1).or_where(lambda q: q.where('id', 2).or_where('id', 3)).get()) class User(Model): __connection__ = "t" __table__ = "users" __dates__ = ["verified_at"] @has_many("id", "user_id") def articles(self): return Article class Company(Model): __connection__ = "sqlite" # user = User.create({"name": "phill", "email": "phill"}) # print(inspect.isclass(User)) user = User.first() # user.update({"verified_at": None, "updated_at": None}) print(user.serialize()) # print(user.serialize()) # print(User.first())
================================================ FILE: conda/conda_build_config.yaml ================================================ python: - 3.6 - 3.7 - 3.8 - 3.9 ================================================ FILE: conda/meta.yaml ================================================ {% set data = load_setup_py_data() %} package: name: masonite-orm version: {{ data['version'] }} source: path: .. build: number: 0 script: python setup.py install --single-version-externally-managed --record=record.txt requirements: build: - python run: - python test: run: - python -m pytest about: home: {{ data['url'] }} license: {{ data['license'] }} summary: {{ data['description'] }} ================================================ FILE: config/test-database.py ================================================ from src.masoniteorm.connections import ConnectionResolver DATABASES = { "default": "mysql", "mysql": { "host": "127.0.0.1", "driver": "mysql", "database": "masonite", "user": "root", "password": "", "port": 3306, "log_queries": False, "options": { # } }, "postgres": { "host": "127.0.0.1", "driver": "postgres", "database": "masonite", "user": "root", "password": "", "port": 5432, "log_queries": False, "options": { # } }, "sqlite": { "driver": "sqlite", "database": "masonite.sqlite3", } } DB = ConnectionResolver().set_connection_details(DATABASES) ================================================ FILE: databases/migrations/2018_01_09_043202_create_users_table.py ================================================ from src.masoniteorm.migrations import Migration from tests.User import User class CreateUsersTable(Migration): def up(self): """Create the `users` table and, unless this is a dry run (``self.schema._dry``), insert a default 'Joe' user on this migration's connection and schema.""" with self.schema.create('users') as table: table.increments('id') table.string('name') table.string('email').unique() table.string('password') table.string('second_password').nullable() table.string('remember_token').nullable() table.timestamp('verified_at').nullable() table.timestamps() if not self.schema._dry:
User.on(self.connection).set_schema(self.schema_name).create({ 'name': 'Joe', 'email': 'joe@email.com', 'password': 'secret' }) def down(self): """Drop the `users` table.""" self.schema.drop('users') ================================================ FILE: databases/migrations/2020_04_17_000000_create_friends_table.py ================================================ from src.masoniteorm.migrations.Migration import Migration class CreateFriendsTable(Migration): def up(self): """ Create the `friends` table (id, name, age). """ with self.schema.create('friends') as table: table.increments('id') table.string('name') table.integer('age') def down(self): """ Drop the `friends` table. """ self.schema.drop('friends') ================================================ FILE: databases/migrations/2020_04_17_00000_create_articles_table.py ================================================ from src.masoniteorm.migrations.Migration import Migration class CreateArticlesTable(Migration): def up(self): """ Create the `fans` table (id, name, age). NOTE(review): despite the class and file name, this migration creates `fans`, not `articles`; confirm this is intended before renaming either side. """ with self.schema.create('fans') as table: table.increments('id') table.string('name') table.integer('age') def down(self): """ Drop the `fans` table. """ self.schema.drop('fans') ================================================ FILE: databases/migrations/2020_10_20_152904_create_table_schema_migration.py ================================================ """CreateTableSchemaMigration Migration.""" from src.masoniteorm.migrations import Migration class CreateTableSchemaMigration(Migration): def up(self): """ Create the `table_schema` table (id, name, timestamps). """ with self.schema.create("table_schema") as table: table.increments('id') table.string('name') table.timestamps() def down(self): """ Drop the `table_schema` table.
""" self.schema.drop("table_schema") ================================================ FILE: databases/migrations/__init__.py ================================================ """Put the current working directory on sys.path so migration modules can import project packages such as `tests.User`.""" import os import sys sys.path.append(os.getcwd()) ================================================ FILE: databases/seeds/database_seeder.py ================================================ """Base Database Seeder Module.""" from src.masoniteorm.seeds import Seeder from .user_table_seeder import UserTableSeeder class DatabaseSeeder(Seeder): def run(self): """Run the database seeds by delegating to UserTableSeeder.""" self.call(UserTableSeeder) ================================================ FILE: databases/seeds/user_table_seeder.py ================================================ """UserTableSeeder Seeder.""" from src.masoniteorm.seeds import Seeder from src.masoniteorm.factories import Factory as factory from tests.User import User factory.register(User, lambda faker: {'email': faker.email()}) class UserTableSeeder(Seeder): def run(self): """Create 5 users through the registered User factory (faker-generated emails), all named 'Joe' with password 'joe'.""" factory(User, 5).create({ 'name': 'Joe', 'password': 'joe', }) ================================================ FILE: makefile ================================================ SHELL := /bin/bash init: .env .bootstrapped-pip .git/hooks/pre-commit init-ci: touch .ignore-pre-commit make init .bootstrapped-pip: requirements.txt requirements.dev pip install -r requirements.txt -r requirements.dev touch .bootstrapped-pip .git/hooks/pre-commit: @if !
test -e ".ignore-pre-commit"; then \ pip install pre-commit; \ pre-commit install --install-hooks; \ fi .env: cp .env-example .env # Create MySQL Database # Create Postgres Database test: init python -m pytest tests ci: make test check: format sort lint lint: flake8 src/masoniteorm/ format: init black src/masoniteorm tests/ sort: init isort src/masoniteorm tests/ coverage: python -m pytest --cov-report term --cov-report xml --cov=src/masoniteorm tests/ python -m coveralls show: python -m pytest --cov-report term --cov-report html --cov=src/masoniteorm tests/ cov: python -m pytest --cov-report term --cov-report xml --cov=src/masoniteorm tests/ publish: pip install twine make test python setup.py sdist twine upload dist/* rm -fr build dist .egg masonite.egg-info rm -rf dist/* pub: python setup.py sdist twine upload dist/* rm -fr build dist .egg masonite.egg-info rm -rf dist/* pypirc: cp .pypirc ~/.pypirc ================================================ FILE: orm ================================================ """ORM Command-line Entry Point. This module is only a fallback for when the Masonite CLI cannot import these commands for you. It registers the migration, model, observer and seed commands on a cleo Application and can be used by running "python orm" (e.g. "python orm migrate"). This module is not run when the CLI can successfully import the commands for you.
""" from cleo import Application from src.masoniteorm.commands import ( MigrateCommand, MigrateRollbackCommand, MigrateRefreshCommand, MigrateFreshCommand, MakeMigrationCommand, MakeObserverCommand, MakeModelCommand, MigrateStatusCommand, MigrateResetCommand, MakeSeedCommand, MakeModelDocstringCommand, SeedRunCommand, ) application = Application("ORM Version:", 0.1) application.add(MigrateCommand()) application.add(MigrateRollbackCommand()) application.add(MigrateRefreshCommand()) application.add(MigrateFreshCommand()) application.add(MakeMigrationCommand()) application.add(MakeModelCommand()) application.add(MakeModelDocstringCommand()) application.add(MakeObserverCommand()) application.add(MigrateResetCommand()) application.add(MigrateStatusCommand()) application.add(MakeSeedCommand()) application.add(SeedRunCommand()) if __name__ == "__main__": application.run() ================================================ FILE: pyproject.toml ================================================ [tool.black] target-version = ['py38'] include = '\.pyi?$' line-length = 79 [tool.isort] profile = "black" multi_line_output = 3 include_trailing_comma = true force_grid_wrap = 0 use_parentheses = true ensure_newline_before_comments = true [tool.flake8] ignore = ['E501', 'E203', 'E128', 'E402', 'E731', 'F821', 'E712', 'W503', 'F811'] #max-line-length = 79 #max-complexity = 18 per-file-ignores = [ '__init__.py:F401', ] ================================================ FILE: pytest.ini ================================================ [pytest] env = D:DB_CONFIG_PATH=config/test-database ================================================ FILE: requirements.dev ================================================ flake8-pyproject black faker pytest pytest-cov pytest-env pymysql isort ================================================ FILE: requirements.txt ================================================ inflection==0.3.1 psycopg2-binary pyodbc pendulum>=2.1,<3.1 cleo>=0.8.0,<0.9 python-dotenv==0.14.0
================================================ FILE: setup.py ================================================ from setuptools import setup with open("README.md", "r") as fh: long_description = fh.read() setup( name="masonite-orm", # Versions should comply with PEP440. For a discussion on single-sourcing # the version across setup.py and the project code, see # https://packaging.python.org/en/latest/single_source_version.html version="2.24.0", package_dir={"": "src"}, description="The Official Masonite ORM", long_description=long_description, long_description_content_type="text/markdown", # The project's main homepage. url="https://github.com/masoniteframework/orm", # Author details author="Joe Mancuso", author_email="joe@masoniteproject.com", # Choose your license license="MIT", # If your package should include things you specify in your MANIFEST.in file # Use this option if your package needs to include files that are not python files # like html templates or css files include_package_data=True, # List run-time dependencies here. These will be installed by pip when # your project is installed. For an analysis of "install_requires" vs pip's # requirements files see: # https://packaging.python.org/en/latest/requirements.html install_requires=[ "inflection>=0.3,<0.6", "pendulum>=2.1,<3.1", "faker>=4.1.0,<14.0", "cleo>=0.8.0,<0.9", ], # See https://pypi.python.org/pypi?%3Aaction=list_classifiers classifiers=[ # How mature is this project? Common values are # 3 - Alpha # 4 - Beta # 5 - Production/Stable "Development Status :: 5 - Production/Stable", # Indicate who your project is intended for "Intended Audience :: Developers", "Topic :: Software Development :: Build Tools", "Environment :: Web Environment", # Pick your license as you wish (should match "license" above) "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", # Specify the Python versions you support here. 
In particular, ensure # that you indicate whether you support Python 2, Python 3 or both. "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", "Programming Language :: Python :: 3.9", "Framework :: Masonite", "Topic :: Software Development :: Libraries :: Python Modules", "Framework :: Masonite", ], # What does your project relate to? keywords="Masonite, MasoniteFramework, Python, ORM", # You can just specify the packages manually here if your project is # simple. Or you can use find_packages(). packages=[ "masoniteorm", "masoniteorm.collection", "masoniteorm.commands", "masoniteorm.connections", "masoniteorm.expressions", "masoniteorm.factories", "masoniteorm.helpers", "masoniteorm.migrations", "masoniteorm.models", "masoniteorm.observers", "masoniteorm.pagination", "masoniteorm.providers", "masoniteorm.query", "masoniteorm.query.grammars", "masoniteorm.query.processors", "masoniteorm.relationships", "masoniteorm.schema", "masoniteorm.schema.platforms", "masoniteorm.scopes", "masoniteorm.seeds", "masoniteorm.testing", ], # List additional groups of dependencies here (e.g. development # dependencies). You can install these using the following syntax, # for example: # $ pip install -e .[dev,test] # $ pip install your-package[dev,test] extras_require={ "test": ["coverage", "pytest"], }, # If there are data files included in your packages that need to be # installed, specify them here. If using Python 2.6 or less, then these # have to be included in MANIFEST.in as well. ## package_data={ ## 'sample': [], ## }, # Although 'package_data' is the preferred approach, in some case you may # need to place data files outside of your packages. 
See: # http://docs.python.org/3.4/distutils/setupscript.html#installing-additional-files # noqa # In this case, 'data_file' will be installed into '/my_data' ## data_files=[('my_data', ['data/data_file.txt'])], # To provide executable scripts, use entry points in preference to the # "scripts" keyword. Entry points provide cross-platform support and allow # pip to create the appropriate form of executable for the target platform. entry_points={ "console_scripts": [ "masonite-orm = masoniteorm.commands.Entry:application.run", ], }, ) ================================================ FILE: src/masoniteorm/.gitignore ================================================ ================================================ FILE: src/masoniteorm/__init__.py ================================================ from .models import Model from .factories.Factory import Factory ================================================ FILE: src/masoniteorm/collection/Collection.py ================================================ import json import random import operator from functools import reduce class Collection: """Wraps various data types to make working with them easier.""" def __init__(self, items=None): self._items = items or [] self.__appends__ = [] def take(self, number: int): """Takes a specific number of results from the items. Arguments: number {integer} -- The number of results to take. Returns: int """ if number < 0: return self[number:] return self[:number] def first(self, callback=None): """Takes the first result in the items. If a callback is given then the first result will be the result after the filter. Keyword Arguments: callback {callable} -- Used to filter the results before returning the first item. (default: {None}) Returns: mixed -- Returns whatever the first item is. 
""" filtered = self if callback: filtered = self.filter(callback) response = None if filtered: response = filtered[0] return response def items(self): return self._items.items() def last(self, callback=None): """Takes the last result in the items. If a callback is given then the last result will be the result after the filter. Keyword Arguments: callback {callable} -- Used to filter the results before returning the last item. (default: {None}) Returns: mixed -- Returns whatever the last item is. """ filtered = self if callback: filtered = self.filter(callback) return filtered[-1] def all(self): """Returns all the items. Returns: mixed -- Returns all items. """ return self._items def avg(self, key=None): """Returns the average of the items. If a key is given it will return the average of all the values of the key. Keyword Arguments: key {string} -- The key to use to find the average of all the values of that key. (default: {None}) Returns: int -- Returns the average. """ result = 0 items = self._get_value(key) or self._items try: result = sum(items) / len(items) except TypeError: pass return result def max(self, key=None): """Returns the max of the items. If a key is given it will return the max of all the values of the key. Keyword Arguments: key {string} -- The key to use to find the max of all the values of that key. (default: {None}) Returns: int -- Returns the max. """ result = 0 items = self._get_value(key) or self._items try: return max(items) except (TypeError, ValueError): pass return result def min(self, key=None): """Returns the min of the items. If a key is given it will return the min of all the values of the key. Keyword Arguments: key {string} -- The key to use to find the min of all the values of that key. (default: {None}) Returns: int -- Returns the min. """ result = 0 items = self._get_value(key) or self._items try: return min(items) except (TypeError, ValueError): pass return result def chunk(self, size: int): """Chunks the items. 
Keyword Arguments: size {integer} -- The number of values in each chunk. Returns: int -- Returns the average. """ items = [] for i in range(0, self.count(), size): items.append(self[i : i + size]) return self.__class__(items) def collapse(self): items = [] for item in self: items += self.__get_items(item) return self.__class__(items) def contains(self, key, value=None): if value: return self.contains(lambda x: self._data_get(x, key) == value) if self._check_is_callable(key, raise_exception=False): return self.first(key) is not None return key in self def count(self): return len(self._items) def diff(self, items): items = self.__get_items(items) return self.__class__([x for x in self if x not in items]) def each(self, callback): self._check_is_callable(callback) for k, v in enumerate(self): result = callback(v) if not result: break self[k] = result def every(self, callback): self._check_is_callable(callback) return all([callback(x) for x in self]) def filter(self, callback): self._check_is_callable(callback) return self.__class__(list(filter(callback, self))) def flatten(self): def _flatten(items): if isinstance(items, dict): for v in items.values(): for x in _flatten(v): yield x elif isinstance(items, list): for i in items: for j in _flatten(i): yield j else: yield items return self.__class__(list(_flatten(self._items))) def forget(self, *keys): keys = reversed(sorted(keys)) for key in keys: del self[key] return self def for_page(self, page, number): return self.__class__(self[page:number]) def get(self, key, default=None): try: return self[key] except IndexError: pass return self._value(default) def implode(self, glue=",", key=None): first = self.first() if not isinstance(first, str) and key: return glue.join(self.pluck(key)) return glue.join([str(x) for x in self]) def is_empty(self): return not self def map(self, callback): self._check_is_callable(callback) items = [callback(x) for x in self] return self.__class__(items) def map_into(self, cls, method=None, 
**kwargs): results = [] for item in self: if method: results.append(getattr(cls, method)(item, **kwargs)) else: results.append(cls(item)) return self.__class__(results) def merge(self, items): if isinstance(items, Collection): items = items._items elif not isinstance(items, list): raise ValueError("Unable to merge uncompatible types") items = self.__get_items(items) self._items += items return self def pluck(self, value, key=None, keep_nulls=True): if key: attributes = {} else: attributes = [] if isinstance(self._items, dict): return Collection([self._items.get(value)]) for item in self: if isinstance(item, dict): iterable = item.items() elif hasattr(item, "serialize"): iterable = item.serialize().items() else: iterable = self.all().items() for k, v in iterable: if keep_nulls is False and v is None: continue if k == value: if key: attributes[self._data_get(item, key)] = self._data_get( item, value ) else: attributes.append(v) return Collection(attributes) def pop(self): last = self._items.pop() return last def prepend(self, value): self._items.insert(0, value) return self def pull(self, key): value = self.get(key) self.forget(key) return value def push(self, value): self._items.append(value) def put(self, key, value): self._items[key] = value return self def random(self, count=None): """Returns a random item of the collection.""" collection_count = self.count() if collection_count == 0: return None elif count and count > collection_count: raise ValueError("count argument must be inferior to collection length.") elif count: self._items = random.sample(self._items, k=count) return self else: return random.choice(self._items) def reduce(self, callback, initial=0): return reduce(callback, self, initial) def reject(self, callback): self._check_is_callable(callback) items = self._get_value(callback) or self._items self._items = items def reverse(self): self._items = self[::-1] def serialize(self, *args, **kwargs): def _serialize(item): if self.__appends__: 
item.set_appends(self.__appends__) if hasattr(item, "serialize"): return item.serialize(*args, **kwargs) elif hasattr(item, "to_dict"): return item.to_dict() return item return list(map(_serialize, self)) def add_relation(self, result=None): for model in self._items: model.add_relation(result or {}) return self def shift(self): return self.pull(0) def sort(self, key=None): if key: self._items.sort(key=lambda x: x[key], reverse=False) return self self._items = sorted(self) return self def sum(self, key=None): result = 0 items = self._get_value(key) or self._items try: result = sum(items) except TypeError: pass return result def to_json(self, **kwargs): return json.dumps(self.serialize(), **kwargs) def group_by(self, key): from itertools import groupby self.sort(key) new_dict = {} for k, v in groupby(self._items, key=lambda x: x[key]): new_dict.update({k: list(v)}) return Collection(new_dict) def transform(self, callback): self._check_is_callable(callback) self._items = self._get_value(callback) def unique(self, key=None): if not key: items = list(set(self._items)) return self.__class__(items) keys = set() items = [] if isinstance(self.all(), dict): return self for item in self: if isinstance(item, dict): comparison = item.get(key) elif isinstance(item, str): comparison = item else: comparison = getattr(item, key) if comparison not in keys: items.append(item) keys.add(comparison) return self.__class__(items) def where(self, key, *args): op = "==" value = args[0] if len(args) >= 2: op = args[0] value = args[1] attributes = [] for item in self._items: if isinstance(item, dict): comparison = item.get(key) else: comparison = getattr(item, key) if hasattr(item, key) else False if self._make_comparison(comparison, value, op): attributes.append(item) return self.__class__(attributes) def where_in(self, key, args: list) -> "Collection": # Compatibility patch - allow numeric strings to match integers # (if all args are numeric strings) if all( [isinstance(arg, str) and 
arg.isnumeric() for arg in args] ): return self.where_in(key, [int(arg) for arg in args]) attributes = [] for item in self._items: if isinstance(item, dict): if key not in item: continue comparison = item.get(key) else: if not hasattr(item, key): continue comparison = getattr(item, key) if comparison in args: attributes.append(item) return self.__class__(attributes) def where_not_in(self, key, args: list) -> "Collection": # Compatibility patch - allow numeric strings to match integers # (if all args are numeric strings) if all( [isinstance(arg, str) and arg.isnumeric() for arg in args] ): return self.where_not_in(key, [int(arg) for arg in args]) attributes = [] for item in self._items: if isinstance(item, dict): if key not in item: continue comparison = item.get(key) else: if not hasattr(item, key): continue comparison = getattr(item, key) if comparison not in args: attributes.append(item) return self.__class__(attributes) def zip(self, items): items = self.__get_items(items) if not isinstance(items, list): raise ValueError("The 'items' parameter must be a list or a Collection") _items = [] for x, y in zip(self, items): _items.append([x, y]) return self.__class__(_items) def set_appends(self, appends): """ Set the attributes that should be appended to the Collection. 
:rtype: list """ self.__appends__ += appends return self def _get_value(self, key): if not key: return None items = [] for item in self: if isinstance(key, str): if hasattr(item, key) or (key in item): items.append(getattr(item, key, item[key])) elif callable(key): result = key(item) if result: items.append(result) return items def _data_get(self, item, key, default=None): try: if isinstance(item, (list, tuple, dict)): item = item[key] elif isinstance(item, object): item = getattr(item, key) except (IndexError, AttributeError, KeyError, TypeError): return self._value(default) return item def _value(self, value): if callable(value): return value() return value def _check_is_callable(self, callback, raise_exception=True): if not callable(callback): if not raise_exception: return False raise ValueError("The 'callback' should be a function") return True def _make_comparison(self, a, b, op): operators = { "<": operator.lt, "<=": operator.le, "==": operator.eq, "!=": operator.ne, ">": operator.gt, ">=": operator.ge, } return operators[op](str(a), str(b)) def __iter__(self): for item in self._items: yield item def __eq__(self, other): other = self.__get_items(other) return other == self._items def __getitem__(self, item): if isinstance(item, slice): return self.__class__(self._items[item]) if isinstance(item, dict): return self._items.get(item, None) try: return self._items[item] except KeyError: return None def __setitem__(self, key, value): self._items[key] = value def __delitem__(self, key): del self._items[key] def __ne__(self, other): other = self.__get_items(other) return other != self._items def __len__(self): return len(self._items) def __le__(self, other): other = self.__get_items(other) return self._items <= other def __lt__(self, other): other = self.__get_items(other) return self._items < other def __ge__(self, other): other = self.__get_items(other) return self._items >= other def __gt__(self, other): other = self.__get_items(other) return self._items > other 
@classmethod def __get_items(cls, items): if isinstance(items, Collection): items = items.all() return items ================================================ FILE: src/masoniteorm/collection/__init__.py ================================================ from .Collection import Collection ================================================ FILE: src/masoniteorm/commands/CanOverrideConfig.py ================================================ from cleo import Command class CanOverrideConfig(Command): def __init__(self): super().__init__() self.add_option() def add_option(self): # 8 is the required flag constant in cleo self._config.add_option( "config", "C", 8, description="The path to the ORM configuration file. If not given DB_CONFIG_PATH env variable will be used and finally 'config.database'.", ) ================================================ FILE: src/masoniteorm/commands/CanOverrideOptionsDefault.py ================================================ from inflection import underscore class CanOverrideOptionsDefault: """Command mixin to allow to override optional argument default values when instantiating the command. Example: SomeCommand(app, option1="other/default"). If an argument long name is using - then use _ in keyword argument: Example: SomeCommand(app, option_1="other/default") for an option named in option-1 """ def __init__(self, **kwargs): super().__init__() self.overriden_default = kwargs for option_name, option in self.config.options.items(): # Cleo does not authorize _ in option name but - are authorized and unfortunately - # cannot be used in Python variables. 
So underscore() is called to make sure that # an option like 'option-a' will be accessible with 'option_a' in kwargs default = self.overriden_default.get(underscore(option_name)) if default: option.set_default(default) ================================================ FILE: src/masoniteorm/commands/Command.py ================================================ from .CanOverrideConfig import CanOverrideConfig from .CanOverrideOptionsDefault import CanOverrideOptionsDefault class Command(CanOverrideOptionsDefault, CanOverrideConfig): pass ================================================ FILE: src/masoniteorm/commands/Entry.py ================================================ """Craft Command. This module is really used for backup only if the masonite CLI cannot import this for you. This can be used by running "python craft". This module is not ran when the CLI can successfully import commands for you. """ from cleo import Application from . import ( MigrateCommand, MigrateRollbackCommand, MigrateRefreshCommand, MigrateFreshCommand, MakeMigrationCommand, MakeModelCommand, MakeModelDocstringCommand, MakeObserverCommand, MigrateStatusCommand, MigrateResetCommand, MakeSeedCommand, SeedRunCommand, ShellCommand, ) application = Application("ORM Version:", 0.1) application.add(MigrateCommand()) application.add(MigrateRollbackCommand()) application.add(MigrateRefreshCommand()) application.add(MigrateFreshCommand()) application.add(MakeMigrationCommand()) application.add(MakeModelCommand()) application.add(MakeModelDocstringCommand()) application.add(MakeObserverCommand()) application.add(MigrateResetCommand()) application.add(MigrateStatusCommand()) application.add(MakeSeedCommand()) application.add(SeedRunCommand()) application.add(ShellCommand()) if __name__ == "__main__": application.run() ================================================ FILE: src/masoniteorm/commands/MakeMigrationCommand.py ================================================ import datetime import os import 
pathlib from inflection import camelize, tableize from .Command import Command class MakeMigrationCommand(Command): """ Creates a new migration file. migration {name : The name of the migration} {--c|create=None : The table to create} {--t|table=None : The table to alter} {--d|directory=databases/migrations : The location of the migration directory} """ def handle(self): name = self.argument("name") now = datetime.datetime.today() if self.option("create") != "None": table = self.option("create") stub_file = "create_migration" else: table = self.option("table") stub_file = "table_migration" if table == "None": table = tableize(name.replace("create_", "").replace("_table", "")) stub_file = "create_migration" migration_directory = self.option("directory") with open( os.path.join( pathlib.Path(__file__).parent.absolute(), f"stubs/{stub_file}.stub" ) ) as fp: output = fp.read() output = output.replace("__MIGRATION_NAME__", camelize(name)) output = output.replace("__TABLE_NAME__", table) file_name = f"{now.strftime('%Y_%m_%d_%H%M%S')}_{name}.py" with open(os.path.join(os.getcwd(), migration_directory, file_name), "w") as fp: fp.write(output) self.info( f"Migration file created: {os.path.join(migration_directory, file_name)}" ) ================================================ FILE: src/masoniteorm/commands/MakeModelCommand.py ================================================ import os import pathlib from inflection import camelize, tableize, underscore from .Command import Command class MakeModelCommand(Command): """ Creates a new model file. 
model {name : The name of the model} {--m|migration : Optionally create a migration file} {--s|seeder : Optionally create a seeder file} {--c|create : If the migration file should create a table} {--t|table : If the migration file should modify an existing table} {--p|pep : Makes the file into pep 8 standards} {--d|directory=app : The location of the model directory} {--D|migrations-directory=databases/migrations : The location of the migration directory} {--S|seeders-directory=databases/seeds : The location of the seeders directory} """ def handle(self): name = self.argument("name") model_directory = self.option("directory") with open( os.path.join(pathlib.Path(__file__).parent.absolute(), "stubs/model.stub") ) as fp: output = fp.read() output = output.replace("__CLASS__", camelize(name)) if self.option("pep"): file_name = f"{underscore(name)}.py" else: file_name = f"{camelize(name)}.py" full_directory_path = os.path.join(os.getcwd(), model_directory) if os.path.exists(os.path.join(full_directory_path, file_name)): self.line( f'Model "{name}" Already Exists ({full_directory_path}/{file_name})' ) return os.makedirs(os.path.dirname(os.path.join(full_directory_path)), exist_ok=True) with open(os.path.join(os.getcwd(), model_directory, file_name), "w+") as fp: fp.write(output) self.info(f"Model created: {os.path.join(model_directory, file_name)}") if self.option("migration"): migrations_directory = self.option("migrations-directory") if self.option("table"): self.call( "migration", f"update_{tableize(name)}_table --table {tableize(name)} --directory {migrations_directory}", ) else: self.call( "migration", f"create_{tableize(name)}_table --create {tableize(name)} --directory {migrations_directory}", ) if self.option("seeder"): directory = self.option("seeders-directory") self.call("seed", f"{self.argument('name')} --directory {directory}") ================================================ FILE: src/masoniteorm/commands/MakeModelDocstringCommand.py 
================================================ from ..config import load_config from .Command import Command class MakeModelDocstringCommand(Command): """ Generate model docstring and type hints (for auto-completion). model:docstring {table : The table you want to generate docstring and type hints} {--t|type-hints : The table you want to generate docstring and type hints} {--c|connection=default : The connection you want to use} """ def handle(self): table = self.argument("table") DB = load_config(self.option("config")).DB schema = DB.get_schema_builder(self.option("connection")) if not schema.has_table(table): return self.line_error( f"There is no such table {table} for this connection." ) self.info(f"Model Docstring for table: {table}") print('"""') for _, column in schema.get_columns(table).items(): length = f"({column.length})" if column.length else "" default = f" default: {column.default}" print(f"{column.name}: {column.column_type}{length}{default}") print('"""') if self.option("type-hints"): self.info(f"Model Type Hints for table: {table}") for name, column in schema.get_columns(table).items(): print(f" {name}:{column.column_python_type.__name__}") ================================================ FILE: src/masoniteorm/commands/MakeObserverCommand.py ================================================ import os import pathlib from inflection import camelize, underscore from .Command import Command class MakeObserverCommand(Command): """ Creates a new observer file. 
observer {name : The name of the observer} {--m|model=None : The name of the model} {--d|directory=app/observers : The location of the observers directory} """ def handle(self): name = self.argument("name") model = self.option("model") if model == "None": model = name observer_directory = self.option("directory") with open( os.path.join( pathlib.Path(__file__).parent.absolute(), "stubs/observer.stub" ) ) as fp: output = fp.read() output = output.replace("__CLASS__", camelize(name)) output = output.replace("__MODEL_VARIABLE__", underscore(model)) output = output.replace("__MODEL__", camelize(model)) file_name = f"{camelize(name)}Observer.py" full_directory_path = os.path.join(os.getcwd(), observer_directory) if os.path.exists(os.path.join(full_directory_path, file_name)): self.line( f'Observer "{name}" Already Exists ({full_directory_path}/{file_name})' ) return os.makedirs(os.path.join(full_directory_path), exist_ok=True) with open(os.path.join(os.getcwd(), observer_directory, file_name), "w+") as fp: fp.write(output) self.info(f"Observer created: {file_name}") ================================================ FILE: src/masoniteorm/commands/MakeSeedCommand.py ================================================ import os import pathlib from inflection import camelize, underscore from .Command import Command class MakeSeedCommand(Command): """ Creates a new seed file. 
seed {name : The name of the seed} {--d|directory=databases/seeds : The location of the seed directory} """ def handle(self): # get the contents of a stub file # replace the placeholders of a stub file # output the content to a file location name = self.argument("name") + "TableSeeder" seed_directory = self.option("directory") file_name = underscore(name) stub_file = "create_seed" with open( os.path.join( pathlib.Path(__file__).parent.absolute(), f"stubs/{stub_file}.stub" ) ) as fp: output = fp.read() output = output.replace("__SEEDER_NAME__", camelize(name)) file_name = f"{underscore(name)}.py" full_path = pathlib.Path(os.path.join(os.getcwd(), seed_directory, file_name)) path_normalized = pathlib.Path(seed_directory) / pathlib.Path(file_name) if os.path.exists(full_path): return self.line(f"{path_normalized} already exists.") with open(full_path, "w") as fp: fp.write(output) self.info(f"Seed file created: {path_normalized}") ================================================ FILE: src/masoniteorm/commands/MigrateCommand.py ================================================ import os from ..migrations import Migration from .Command import Command class MigrateCommand(Command): """ Run migrations. migrate {--m|migration=all : Migration's name to be migrated} {--c|connection=default : The connection you want to run migrations on} {--f|force : Force migrations without prompt in production} {--s|show : Shows the output of SQL for migrations that would be running} {--schema=? : Sets the schema to be migrated} {--d|directory=databases/migrations : The location of the migration directory} """ def handle(self): # prompt user for confirmation in production if os.getenv("APP_ENV") == "production" and not self.option("force"): answer = "" while answer not in ["y", "n"]: answer = input( "Do you want to run migrations in PRODUCTION ? 
(y/n)\n" ).lower() if answer != "y": self.info("Migrations cancelled") exit(0) migration = Migration( command_class=self, connection=self.option("connection"), migration_directory=self.option("directory"), config_path=self.option("config"), schema=self.option("schema"), ) migration.create_table_if_not_exists() if not migration.get_unran_migrations(): self.info("Nothing To Migrate!") return migration_name = self.option("migration") show_output = self.option("show") migration.migrate(migration=migration_name, output=show_output) ================================================ FILE: src/masoniteorm/commands/MigrateFreshCommand.py ================================================ from ..migrations import Migration from .Command import Command class MigrateFreshCommand(Command): """ Drops all tables and migrates them again. migrate:fresh {--c|connection=default : The connection you want to run migrations on} {--d|directory=databases/migrations : The location of the migration directory} {--f|ignore-fk=? : The connection you want to run migrations on} {--s|seed=? : Seed database after fresh. The seeder to be ran can be provided in argument} {--schema=? 
: Sets the schema to be migrated} {--D|seed-directory=databases/seeds : The location of the seed directory if seed option is used.} """ def handle(self): migration = Migration( command_class=self, connection=self.option("connection"), migration_directory=self.option("directory"), config_path=self.option("config"), schema=self.option("schema"), ) migration.fresh(ignore_fk=self.option("ignore-fk")) self.line("") if self.option("seed") == "null": self.call( "seed:run", f"None --directory {self.option('seed-directory')} --connection {self.option('connection')}", ) elif self.option("seed"): self.call( "seed:run", f"{self.option('seed')} --directory {self.option('seed-directory')} --connection {self.option('connection')}", ) ================================================ FILE: src/masoniteorm/commands/MigrateRefreshCommand.py ================================================ from ..migrations import Migration from .Command import Command class MigrateRefreshCommand(Command): """ Rolls back migrations and migrates them again. migrate:refresh {--m|migration=all : Migration's name to be refreshed} {--c|connection=default : The connection you want to run migrations on} {--d|directory=databases/migrations : The location of the migration directory} {--s|seed=? : Seed database after refresh. The seeder to be ran can be provided in argument} {--schema=? 
: Sets the schema to be migrated} {--D|seed-directory=databases/seeds : The location of the seed directory if seed option is used.} """ def handle(self): migration = Migration( command_class=self, connection=self.option("connection"), migration_directory=self.option("directory"), config_path=self.option("config"), schema=self.option("schema"), ) migration.refresh(self.option("migration")) self.line("") if self.option("seed") == "null": self.call( "seed:run", f"None --directory {self.option('seed-directory')} --connection {self.option('connection')}", ) elif self.option("seed"): self.call( "seed:run", f"{self.option('seed')} --directory {self.option('seed-directory')} --connection {self.option('connection')}", ) ================================================ FILE: src/masoniteorm/commands/MigrateResetCommand.py ================================================ from ..migrations import Migration from .Command import Command class MigrateResetCommand(Command): """ Reset migrations. migrate:reset {--m|migration=all : Migration's name to be rollback} {--c|connection=default : The connection you want to run migrations on} {--schema=? : Sets the schema to be migrated} {--d|directory=databases/migrations : The location of the migration directory} """ def handle(self): migration = Migration( command_class=self, connection=self.option("connection"), migration_directory=self.option("directory"), config_path=self.option("config"), schema=self.option("schema"), ) migration.reset(self.option("migration")) ================================================ FILE: src/masoniteorm/commands/MigrateRollbackCommand.py ================================================ from ..migrations import Migration from .Command import Command class MigrateRollbackCommand(Command): """ Rolls back the last batch of migrations. 
migrate:rollback {--m|migration=all : Migration's name to be rollback} {--c|connection=default : The connection you want to run migrations on} {--s|show : Shows the output of SQL for migrations that would be running} {--schema=? : Sets the schema to be migrated} {--d|directory=databases/migrations : The location of the migration directory} """ def handle(self): Migration( command_class=self, connection=self.option("connection"), migration_directory=self.option("directory"), config_path=self.option("config"), schema=self.option("schema"), ).rollback(migration=self.option("migration"), output=self.option("show")) ================================================ FILE: src/masoniteorm/commands/MigrateStatusCommand.py ================================================ from ..migrations import Migration from .Command import Command class MigrateStatusCommand(Command): """ Display migrations status. migrate:status {--c|connection=default : The connection you want to run migrations on} {--schema=? 
: Sets the schema to be migrated} {--d|directory=databases/migrations : The location of the migration directory} """ def handle(self): migration = Migration( command_class=self, connection=self.option("connection"), migration_directory=self.option("directory"), config_path=self.option("config"), schema=self.option("schema"), ) migration.create_table_if_not_exists() table = self.table() table.set_header_row(["Ran?", "Migration", "Batch"]) migrations = [] for migration_data in migration.get_ran_migrations(): migration_file = migration_data["migration_file"] batch = migration_data["batch"] migrations.append( [ "Y", f"{migration_file}", f"{batch}", ] ) for migration_file in migration.get_unran_migrations(): migrations.append( [ "N", f"{migration_file}", "-", ] ) table.set_rows(migrations) table.render(self.io) ================================================ FILE: src/masoniteorm/commands/SeedRunCommand.py ================================================ from inflection import camelize, underscore from ..seeds import Seeder from .Command import Command class SeedRunCommand(Command): """ Run seeds. 
seed:run {--c|connection=default : The connection you want to run migrations on} {--dry : If the seed should run in dry mode} {table=None : Name of the table to seed} {--d|directory=databases/seeds : The location of the seed directory} """ def handle(self): seeder = Seeder( dry=self.option("dry"), seed_path=self.option("directory"), connection=self.option("connection"), ) if self.argument("table") == "None": seeder.run_database_seed() seeder_seeded = "Database Seeder" else: table = self.argument("table") seeder_file = ( f"{underscore(table)}_table_seeder.{camelize(table)}TableSeeder" ) seeder.run_specific_seed(seeder_file) seeder_seeded = f"{camelize(table)}TableSeeder" self.line(f"{seeder_seeded} seeded!") ================================================ FILE: src/masoniteorm/commands/ShellCommand.py ================================================ import subprocess import os import re import shlex from collections import OrderedDict from ..config import load_config from .Command import Command class ShellCommand(Command): """ Connect to your database interactive terminal. shell {--c|connection=default : The connection you want to use to connect to interactive terminal} {--s|show=? : Display the command which will be called to connect} """ shell_programs = { "mysql": "mysql", "postgres": "psql", "sqlite": "sqlite3", "mssql": "sqlcmd", } def handle(self): resolver = load_config(self.option("config")).DB connection = self.option("connection") if connection == "default": connection = resolver.get_connection_details()["default"] config = resolver.get_connection_information(connection) if not config.get("full_details"): self.line( f"Connection configuration for '{connection}' not found !" 
) exit(-1) command, env = self.get_command(config) if self.option("show"): cleaned_command = self.hide_sensitive_options(config, command) self.comment(cleaned_command) self.line("") # let shlex split command in a list as it's safer command_args = shlex.split(command) try: subprocess.run(command_args, check=True, env=env) except FileNotFoundError: self.line( f"Cannot find {config.get('full_details').get('driver')} program ! Please ensure you can call this program in your shell first." ) exit(-1) except subprocess.CalledProcessError: self.line("An error happened calling the command.") exit(-1) def get_shell_program(self, connection): """Get the database shell program to run.""" return self.shell_programs.get(connection) def get_command(self, config): """Get the command to run as a string.""" driver = config.get("full_details").get("driver") program = self.get_shell_program(driver) try: get_driver_args = getattr(self, f"get_{driver}_args") except AttributeError: self.line( f"Connecting with driver '{driver}' is not implemented !" 
) exit(-1) args, options = get_driver_args(config) # process positional arguments args = " ".join(args) # process optional arguments options = self.remove_empty_options(options) options_string = " ".join( f"{option} {value}" if value else option for option, value in options.items() ) # finally build command string command = program if args: command += f" {args}" if options_string: command += f" {options_string}" # prepare environment in which command will be run # some drivers need to define env variable such as psql for specifying password try: driver_env = getattr(self, f"get_{driver}_env")(config) except AttributeError: driver_env = {} command_env = {**os.environ.copy(), **driver_env} return command, command_env def get_mysql_args(self, config): """Get command positional arguments and options for MySQL driver.""" args = [config.get("database")] options = OrderedDict( { "--host": config.get("host"), "--port": config.get("port"), "--user": config.get("user"), "--password": config.get("password"), "--default-character-set": config.get("options", {}).get("charset"), } ) return args, options def get_postgres_args(self, config): """Get command positional arguments and options for PostgreSQL driver.""" args = [config.get("database")] options = OrderedDict( { "--host": config.get("host"), "--port": config.get("port"), "--username": config.get("user"), } ) return args, options def get_postgres_env(self, config): return {"PGPASSWORD": config.get("password")} def get_mssql_args(self, config): """Get command positional arguments and options for MSSQL driver.""" args = [] # instance of SQL Server: -S [protocol:]server[instance_name][,port] server = f"tcp:{config.get('host')}" if config.get("port"): server += f",{config.get('port')}" trusted_connection = config.get("options").get("trusted_connection") == "Yes" options = OrderedDict( { "-d": config.get("database"), "-U": config.get("user"), "-P": config.get("password"), "-S": server, "-E": trusted_connection, "-t": 
config.get("options", {}).get("connection_timeout"), } ) return args, options def get_sqlite_args(self, config): """Get command positional arguments and options for SQLite driver.""" args = [config.get("database")] options = OrderedDict() return args, options def remove_empty_options(self, options): """Remove empty options when value does not evaluate to True. Also when value is exactly 'True' we don't want True to be passed as option value but we want the option to be passed. """ # remove empty options and remove value when option is True cleaned_options = OrderedDict() for key, value in options.items(): if value is True: cleaned_options[key] = "" elif value: cleaned_options[key] = value return cleaned_options def get_sensitive_options(self, config): driver = config.get("full_details").get("driver") try: sensitive_options = getattr(self, f"get_{driver}_sensitive_options")() except AttributeError: sensitive_options = [] return sensitive_options def get_mysql_sensitive_options(self): return ["--password"] def get_mssql_sensitive_options(self): return ["-P"] def hide_sensitive_options(self, config, command): """Hide sensitive options in command string before it's displayed in the shell for security reasons. 
All drivers can declare what options are considered sensitive, their values will be then replaced by *** when displayed only.""" cleaned_command = command sensitive_options = self.get_sensitive_options(config) for option in sensitive_options: # if option is used obfuscate its value if option in command: match = re.search(rf"{option} (\w+)", command) if match.groups(): cleaned_command = cleaned_command.replace(match.groups()[0], "***") return cleaned_command ================================================ FILE: src/masoniteorm/commands/__init__.py ================================================ import os import sys sys.path.append(os.getcwd()) from .MigrateCommand import MigrateCommand from .MigrateRollbackCommand import MigrateRollbackCommand from .MigrateRefreshCommand import MigrateRefreshCommand from .MigrateFreshCommand import MigrateFreshCommand from .MigrateResetCommand import MigrateResetCommand from .MakeModelCommand import MakeModelCommand from .MakeModelDocstringCommand import MakeModelDocstringCommand from .MakeObserverCommand import MakeObserverCommand from .MigrateStatusCommand import MigrateStatusCommand from .MakeMigrationCommand import MakeMigrationCommand from .MakeSeedCommand import MakeSeedCommand from .SeedRunCommand import SeedRunCommand from .ShellCommand import ShellCommand ================================================ FILE: src/masoniteorm/commands/stubs/create_migration.stub ================================================ """__MIGRATION_NAME__ Migration.""" from masoniteorm.migrations import Migration class __MIGRATION_NAME__(Migration): def up(self): """ Run the migrations. """ with self.schema.create("__TABLE_NAME__") as table: table.increments("id") table.timestamps() def down(self): """ Revert the migrations. 
""" self.schema.drop("__TABLE_NAME__") ================================================ FILE: src/masoniteorm/commands/stubs/create_seed.stub ================================================ """__SEEDER_NAME__ Seeder.""" from masoniteorm.seeds import Seeder class __SEEDER_NAME__(Seeder): def run(self): """Run the database seeds.""" pass ================================================ FILE: src/masoniteorm/commands/stubs/model.stub ================================================ """ __CLASS__ Model """ from masoniteorm.models import Model class __CLASS__(Model): """__CLASS__ Model""" pass ================================================ FILE: src/masoniteorm/commands/stubs/observer.stub ================================================ """__CLASS__ Observer""" from masoniteorm.models import Model class __CLASS__Observer: def created(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "created" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass def creating(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "creating" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass def saving(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "saving" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass def saved(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "saved" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass def updating(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "updating" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass def updated(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "updated" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass def booted(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "booted" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. 
""" pass def booting(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "booting" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass def hydrating(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "hydrating" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass def hydrated(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "hydrated" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass def deleting(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "deleting" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass def deleted(self, __MODEL_VARIABLE__): """Handle the __MODEL__ "deleted" event. Args: __MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model. """ pass ================================================ FILE: src/masoniteorm/commands/stubs/table_migration.stub ================================================ """__MIGRATION_NAME__ Migration.""" from masoniteorm.migrations import Migration class __MIGRATION_NAME__(Migration): def up(self): """ Run the migrations. """ with self.schema.table("__TABLE_NAME__") as table: pass def down(self): """ Revert the migrations. """ with self.schema.table("__TABLE_NAME__") as table: pass ================================================ FILE: src/masoniteorm/config.py ================================================ import os import pydoc import urllib.parse as urlparse from .exceptions import ConfigurationNotFound from .exceptions import InvalidUrlConfiguration def load_config(config_path=None): """Load ORM configuration from given configuration path (dotted or not). If no path is provided: 1. try to load from DB_CONFIG_PATH environment variable 2. 
else try to load from default config_path: config/database """ selected_config_path = ( os.getenv("DB_CONFIG_PATH", None) or config_path or "config/database" ) os.environ["DB_CONFIG_PATH"] = selected_config_path # format path as python module if needed selected_config_path = ( selected_config_path.replace("/", ".").replace("\\", ".").rstrip(".py") ) config_module = pydoc.locate(selected_config_path) if config_module is None: raise ConfigurationNotFound( f"ORM configuration file has not been found in {selected_config_path}." ) return config_module def db_url(database_url=None, prefix="", options={}, log_queries=False): """Parse connection configuration from database url format. If no url is provided, DATABASE_URL environment variable will be used instead. Reference: Code adapted from https://github.com/jacobian/dj-database-url """ url = database_url or os.getenv("DATABASE_URL") if not url: raise InvalidUrlConfiguration("Database url is empty !") # Register database schemes in URLs. urlparse.uses_netloc.append("postgres") urlparse.uses_netloc.append("postgresql") urlparse.uses_netloc.append("pgsql") urlparse.uses_netloc.append("postgis") urlparse.uses_netloc.append("mysql") urlparse.uses_netloc.append("mysql2") urlparse.uses_netloc.append("mysqlgis") urlparse.uses_netloc.append("mssql") urlparse.uses_netloc.append("sqlite") DRIVERS_MAP = { "postgres": "postgres", "postgresql": "postgres", "pgsql": "postgres", "postgis": "postgres", "mysql": "mysql", "mysql2": "mysql", "mysqlgis": "mysql", "mysql-connector": "mysql", "mssql": "mssql", "sqlite": "sqlite", } # this is a special case, because if we pass this URL into # urlparse, urlparse will choke trying to interpret "memory" # as a port number if url in ["sqlite://:memory:", "sqlite://memory"]: driver = DRIVERS_MAP["sqlite"] path = ":memory:" # otherwise parse the url as normal else: url = urlparse.urlparse(url) # remove query string from path (not parsed for now) path = url.path[1:] if "?" 
in path and not url.query: path, _ = path.split("?", 2) # if we are using sqlite and we have no path, then assume we # want an in-memory database (this is the behaviour of sqlalchemy) if url.scheme == "sqlite" and path == "": path = ":memory:" # handle postgres percent-encoded paths. hostname = url.hostname or "" if "%2f" in hostname.lower(): # Switch to url.netloc to avoid lower cased paths hostname = url.netloc if "@" in hostname: hostname = hostname.rsplit("@", 1)[1] if ":" in hostname: hostname = hostname.split(":", 1)[0] hostname = hostname.replace("%2f", "/").replace("%2F", "/") # lookup specified driver driver = DRIVERS_MAP[url.scheme] port = ( str(url.port) if url.port and driver in [DRIVERS_MAP["mssql"]] else url.port ) # build final configuration config = { "driver": driver, "database": urlparse.unquote(path or ""), "prefix": prefix, "options": options, "log_queries": log_queries, } if driver != DRIVERS_MAP["sqlite"]: config.update( { "user": urlparse.unquote(url.username or ""), "password": urlparse.unquote(url.password or ""), "host": hostname, "port": port or "", } ) return config ================================================ FILE: src/masoniteorm/connections/.gitignore ================================================ ================================================ FILE: src/masoniteorm/connections/BaseConnection.py ================================================ import logging from timeit import default_timer as timer from .ConnectionResolver import ConnectionResolver class BaseConnection: _connection = None _cursor = None _dry = False def dry(self): self._dry = True return self def set_schema(self, schema): self.schema = schema return self def log( self, query, bindings, query_time=0, logger="masoniteorm.connections.queries" ): logger = logging.getLogger("masoniteorm.connection.queries") logger.propagate = self.full_details.get("propagate", True) logger.debug( f"Running query {query}, {bindings}. 
Executed in {query_time}ms", extra={"query": query, "bindings": bindings, "query_time": query_time}, ) def statement(self, query, bindings=()): """Wrapper around calling the cursor query. Helpful for logging output. Args: query (string): The query to execute on the cursor bindings (tuple, optional): Tuple of query bindings. Defaults to (). """ start = timer() if not self._cursor: raise AttributeError( f"Must set the _cursor attribute on the {self.__class__.__name__} class before calling the 'statement' method." ) self._cursor.execute(query, bindings) end = "{:.2f}".format(timer() - start) if self.full_details and self.full_details.get("log_queries", False): self.log(query, bindings, query_time=end) def has_global_connection(self): return self.name in ConnectionResolver().get_global_connections() def get_global_connection(self): return ConnectionResolver().get_global_connections()[self.name] def enable_query_log(self): self.full_details["log_queries"] = True def disable_query_log(self): self.full_details["log_queries"] = False def format_cursor_results(self, cursor_result): return cursor_result def set_cursor(self): self._cursor = self._connection.cursor() return self def select_many(self, query, bindings, amount): self.set_cursor() self.statement(query) if not self.open: self.make_connection() result = self.format_cursor_results(self._cursor.fetchmany(amount)) while result: yield result result = self.format_cursor_results(self._cursor.fetchmany(amount)) def enable_disable_foreign_keys(self): foreign_keys = self.full_details.get("foreign_keys") platform = self.get_default_platform()() if foreign_keys: self._connection.execute(platform.enable_foreign_key_constraints()) elif foreign_keys is not None: self._connection.execute(platform.disable_foreign_key_constraints()) ================================================ FILE: src/masoniteorm/connections/ConnectionFactory.py ================================================ from ..config import load_config class 
ConnectionFactory: """Class for controlling the registration and creation of connection types.""" _connections = {} def __init__(self, config_path=None): self.config_path = config_path @classmethod def register(cls, key, connection): """Registers new connections Arguments: key {key} -- The key or driver name you want assigned to this connection connection {masoniteorm.connections.BaseConnection} -- An instance of a BaseConnection class. Returns: cls """ cls._connections.update({key: connection}) return cls def make(self, key): """Makes already registered connections Arguments: key {string} -- The name of the connection you want to make Raises: Exception: Raises exception if there are no driver keys that match Returns: masoniteorm.connection.BaseConnection -- Returns an instance of a BaseConnection class. """ DB = load_config(config_path=self.config_path).DB connections = DB.get_connection_details() if key == "default": connection_details = connections.get(connections.get("default")) connection = self._connections.get(connection_details.get("driver")) else: connection_details = connections.get(key) connection = self._connections.get(key) if connection: return connection raise Exception( "The '{connection}' connection does not exist".format(connection=key) ) ================================================ FILE: src/masoniteorm/connections/ConnectionResolver.py ================================================ from contextlib import contextmanager class ConnectionResolver: _connection_details = {} _connections = {} _morph_map = {} def __init__(self, config_path=None): from ..connections import ( SQLiteConnection, PostgresConnection, MySQLConnection, MSSQLConnection, ) self.config_path = config_path from ..connections import ConnectionFactory self.connection_factory = ConnectionFactory(config_path=config_path) self.register(SQLiteConnection) self.register(PostgresConnection) self.register(MySQLConnection) self.register(MSSQLConnection) def morph_map(self, map): 
self._morph_map = map return self def set_connection_details(self, connection_details): self.__class__._connection_details = connection_details return self def get_connection_details(self): return self._connection_details def set_connection_option(self, connection: str, options: dict): self._connection_details.get(connection).update(options) return self def get_global_connections(self): return self._connections def remove_global_connection(self, name=None): self._connections.pop(name) def register(self, connection): self.connection_factory.register(connection.name, connection) def begin_transaction(self, name=None): if name is None: name = self.get_connection_details()["default"] driver = self.get_connection_details()[name].get("driver") connection = ( self.connection_factory.make(driver)( **self.get_connection_information(name) ) .make_connection() .begin() ) self.__class__._connections.update({name: connection}) return connection def commit(self, name=None): if name is None: name = self.get_connection_details()["default"] connection = self.get_global_connections()[name] self.remove_global_connection(name) connection.commit() def rollback(self, name=None): if name is None: name = self.get_connection_details()["default"] connection = self.get_global_connections()[name] self.remove_global_connection(name) connection.rollback() @contextmanager def transaction(self, name=None): self.begin_transaction(name) try: yield self except Exception: self.rollback(name) raise try: self.commit(name) except Exception: self.rollback(name) raise def get_connection_information(self, name): details = self.get_connection_details() return { "host": details.get(name, {}).get("host"), "database": details.get(name, {}).get("database"), "user": details.get(name, {}).get("user"), "port": details.get(name, {}).get("port"), "password": details.get(name, {}).get("password"), "prefix": details.get(name, {}).get("prefix"), "options": details.get(name, {}).get("options", {}), "full_details": 
details.get(name, {}), } def get_schema_builder(self, connection="default", schema=None): from ..schema import Schema return Schema( connection=connection, connection_details=self.get_connection_details(), schema=schema, ) def get_query_builder(self, connection="default"): from ..query import QueryBuilder return QueryBuilder( connection=connection, connection_details=self.get_connection_details() ) def statement(self, query, bindings=(), connection="default"): return self.get_query_builder().on(connection).statement(query, bindings) ================================================ FILE: src/masoniteorm/connections/MSSQLConnection.py ================================================ from ..exceptions import DriverNotFound from .BaseConnection import BaseConnection from ..query.grammars import MSSQLGrammar from ..schema.platforms import MSSQLPlatform from ..query.processors import MSSQLPostProcessor from ..exceptions import QueryException CONNECTION_POOL = [] class MSSQLConnection(BaseConnection): """MSSQL Connection class.""" name = "mssql" def __init__( self, host=None, database=None, user=None, port=None, password=None, prefix=None, options=None, full_details=None, name=None, ): self.host = host if port: self.port = int(port) else: self.port = port self.database = database self.user = user self.password = password self.prefix = prefix self.full_details = full_details or {} self.options = options or {} self._cursor = None self.transaction_level = 0 self.open = 0 if name: self.name = name def make_connection(self): """This sets the connection on the connection class""" try: import pyodbc except ModuleNotFoundError: raise DriverNotFound( "You must have the 'pyodbc' package installed to make a connection to Microsoft SQL Server. 
Please install it using 'pip install pyodbc'" ) if self.has_global_connection(): return self.get_global_connection() driver = self.options.get("driver", "ODBC Driver 17 for SQL Server") integrated_security = self.options.get("integrated_security") connection_timeout = str(self.options.get("connection_timeout", "30")) authentication = self.options.get("authentication") instance = self.options.get("instance", "") trusted_connection = self.options.get("trusted_connection") if instance: instance = "\\" + instance self._connection = pyodbc.connect( f"DRIVER={driver};SERVER={self.host}{instance if instance else ''},{self.port};Connection Timeout={connection_timeout};DATABASE={self.database}{f';Integrated Security={integrated_security}' if integrated_security else ''};UID={self.user};PWD={self.password}{f';Trusted_Connection={trusted_connection}' if trusted_connection else ''}{f';Authentication={authentication}' if authentication else ''}", autocommit=True, ) self.enable_disable_foreign_keys() self.open = 1 return self def get_database_name(self): return self.database @classmethod def get_default_query_grammar(cls): return MSSQLGrammar @classmethod def get_default_platform(cls): return MSSQLPlatform @classmethod def get_default_post_processor(cls): return MSSQLPostProcessor def reconnect(self): pass def commit(self): """Transaction""" if self.get_transaction_level() == 1: self._connection.commit() self._connection.autocommit = True self.transaction_level -= 1 def begin(self): """MSSQL Transaction""" self._connection.autocommit = False self.transaction_level += 1 return self def rollback(self): """Transaction""" if self.get_transaction_level() == 1: self._connection.rollback() self._connection.autocommit = True self.transaction_level -= 1 def get_transaction_level(self): """Transaction""" return self.transaction_level def get_cursor(self): return self._cursor def query(self, query, bindings=(), results="*"): """Make the actual query that will reach the database and come 
back with a result. Arguments: query {string} -- A string query. This could be a qmarked string or a regular query. bindings {tuple} -- A tuple of bindings Keyword Arguments: results {str|1} -- If the results is equal to an asterisks it will call 'fetchAll' else it will return 'fetchOne' and return a single record. (default: {"*"}) Returns: dict|None -- Returns a dictionary of results or None """ try: if not self.open: self.make_connection() self._cursor = self._connection.cursor() with self._cursor as cursor: if isinstance(query, list) and not self._dry: for q in query: self.statement(q, ()) return query = query.replace("'?'", "?") self.statement(query, bindings) if results == 1: if not cursor.description: return {} columnNames = [column[0] for column in cursor.description] result = cursor.fetchone() return dict(zip(columnNames, result)) if result is not None else {} else: if not cursor.description: return {} return self.format_cursor_results(cursor.fetchall()) return {} except Exception as e: raise QueryException(str(e)) from e finally: if self.get_transaction_level() <= 0: self._connection.close() def format_cursor_results(self, cursor_result): columnNames = [column[0] for column in self.get_cursor().description] results = [] for record in cursor_result: results.append(dict(zip(columnNames, record))) return results ================================================ FILE: src/masoniteorm/connections/MySQLConnection.py ================================================ from ..exceptions import DriverNotFound from .BaseConnection import BaseConnection from ..query.grammars import MySQLGrammar from ..schema.platforms import MySQLPlatform from ..query.processors import MySQLPostProcessor from ..exceptions import QueryException CONNECTION_POOL = [] class MySQLConnection(BaseConnection): """MYSQL Connection class.""" name = "mysql" _dry = False def __init__( self, host=None, database=None, user=None, port=None, password=None, prefix=None, options=None, full_details=None, 
name=None, ): self.host = host self.port = port if str(port).isdigit(): self.port = int(self.port) self.database = database self.user = user self.password = password self.prefix = prefix self.full_details = full_details or {} self.connection_pool_size = full_details.get("connection_pooling_max_size", 100) self.options = options or {} self._cursor = None self.open = 0 self.transaction_level = 0 if name: self.name = name def make_connection(self): """This sets the connection on the connection class""" if self._dry: return if self.has_global_connection(): return self.get_global_connection() # Check if there is an available connection in the pool self._connection = self.create_connection() self.enable_disable_foreign_keys() return self def close_connection(self): if ( self.full_details.get("connection_pooling_enabled") and len(CONNECTION_POOL) < self.connection_pool_size ): CONNECTION_POOL.append(self._connection) self.open = 0 self._connection = None def create_connection(self, autocommit=True): try: import pymysql except ModuleNotFoundError: raise DriverNotFound( "You must have the 'pymysql' package " "installed to make a connection to MySQL. 
" "Please install it using 'pip install pymysql'" ) import pendulum import pymysql.converters pymysql.converters.conversions[pendulum.DateTime] = ( pymysql.converters.escape_datetime ) # Initialize the connection pool if the option is set initialize_size = self.full_details.get("connection_pooling_min_size") if initialize_size and len(CONNECTION_POOL) < initialize_size: for _ in range(initialize_size - len(CONNECTION_POOL)): connection = pymysql.connect( cursorclass=pymysql.cursors.DictCursor, autocommit=autocommit, host=self.host, user=self.user, password=self.password, port=self.port, database=self.database, **self.options ) CONNECTION_POOL.append(connection) if ( self.full_details.get("connection_pooling_enabled") and CONNECTION_POOL and len(CONNECTION_POOL) > 0 ): connection = CONNECTION_POOL.pop() else: connection = pymysql.connect( cursorclass=pymysql.cursors.DictCursor, autocommit=autocommit, host=self.host, user=self.user, password=self.password, port=self.port, database=self.database, **self.options ) connection.close = self.close_connection self.open = 1 return connection def reconnect(self): self._connection.connect() return self @classmethod def get_default_query_grammar(cls): return MySQLGrammar @classmethod def get_default_platform(cls): return MySQLPlatform @classmethod def get_default_post_processor(cls): return MySQLPostProcessor def get_database_name(self): return self.database def commit(self): """Transaction""" self._connection.commit() self.transaction_level -= 1 if self.get_transaction_level() <= 0: self.open = 0 self._connection.close() def dry(self): """Transaction""" self._dry = True return self def begin(self): """Mysql Transaction""" self._connection.begin() self.transaction_level += 1 return self def rollback(self): """Transaction""" self._connection.rollback() self.transaction_level -= 1 if self.get_transaction_level() <= 0: self.open = 0 self._connection.close() def get_transaction_level(self): """Transaction""" return 
self.transaction_level def get_cursor(self): return self._cursor def query(self, query, bindings=(), results="*"): """Make the actual query that will reach the database and come back with a result. Arguments: query {string} -- A string query. This could be a qmarked string or a regular query. bindings {tuple} -- A tuple of bindings Keyword Arguments: results {str|1} -- If the results is equal to an asterisks it will call 'fetchAll' else it will return 'fetchOne' and return a single record. (default: {"*"}) Returns: dict|None -- Returns a dictionary of results or None """ if self._dry: return {} if not self.open: if self._connection is None: self._connection = self.create_connection() self._connection.connect() self._cursor = self._connection.cursor() try: with self._cursor as cursor: if isinstance(query, list): for q in query: q = q.replace("'?'", "%s") self.statement(q, ()) return query = query.replace("'?'", "%s") self.statement(query, bindings) if results == 1: return self.format_cursor_results(cursor.fetchone()) else: return self.format_cursor_results(cursor.fetchall()) except Exception as e: raise QueryException(str(e)) from e finally: self._cursor.close() if self.get_transaction_level() <= 0: self.open = 0 self._connection.close() ================================================ FILE: src/masoniteorm/connections/PostgresConnection.py ================================================ from ..exceptions import DriverNotFound from .BaseConnection import BaseConnection from ..query.grammars import PostgresGrammar from ..schema.platforms import PostgresPlatform from ..query.processors import PostgresPostProcessor from ..exceptions import QueryException CONNECTION_POOL = [] class PostgresConnection(BaseConnection): """Postgres Connection class.""" name = "postgres" def __init__( self, host=None, database=None, user=None, port=None, password=None, prefix=None, options=None, full_details=None, name=None, ): self.host = host if port: self.port = int(port) else: 
self.port = port self.database = database self.user = user self.password = password self.prefix = prefix self.full_details = full_details or {} self.connection_pool_size = full_details.get("connection_pooling_max_size", 100) self.options = options or {} self._cursor = None self.transaction_level = 0 self.open = 0 self.schema = None if name: self.name = name def make_connection(self): """This sets the connection on the connection class""" try: import psycopg2 # noqa F401 except ModuleNotFoundError: raise DriverNotFound( "You must have the 'psycopg2' package installed to make a connection to Postgres. Please install it using 'pip install psycopg2-binary'" ) if self.has_global_connection(): return self.get_global_connection() self._connection = self.create_connection() self._connection.autocommit = True self.enable_disable_foreign_keys() self.open = 1 return self def create_connection(self): import psycopg2 # Initialize the connection pool if the option is set initialize_size = self.full_details.get("connection_pooling_min_size") if ( self.full_details.get("connection_pooling_enabled") and initialize_size and len(CONNECTION_POOL) < initialize_size ): for _ in range(initialize_size - len(CONNECTION_POOL)): connection = psycopg2.connect( database=self.database, user=self.user, password=self.password, host=self.host, port=self.port, sslmode=self.options.get("sslmode"), sslcert=self.options.get("sslcert"), sslkey=self.options.get("sslkey"), sslrootcert=self.options.get("sslrootcert"), options=( f"-c search_path={self.schema or self.full_details.get('schema')}" if self.schema or self.full_details.get("schema") else "" ), ) CONNECTION_POOL.append(connection) if ( self.full_details.get("connection_pooling_enabled") and CONNECTION_POOL and len(CONNECTION_POOL) > 0 ): connection = CONNECTION_POOL.pop() else: connection = psycopg2.connect( database=self.database, user=self.user, password=self.password, host=self.host, port=self.port, sslmode=self.options.get("sslmode"), 
sslcert=self.options.get("sslcert"), sslkey=self.options.get("sslkey"), sslrootcert=self.options.get("sslrootcert"), options=( f"-c search_path={self.schema or self.full_details.get('schema')}" if self.schema or self.full_details.get("schema") else "" ), ) return connection def get_database_name(self): return self.database @classmethod def get_default_query_grammar(cls): return PostgresGrammar @classmethod def get_default_platform(cls): return PostgresPlatform @classmethod def get_default_post_processor(cls): return PostgresPostProcessor def reconnect(self): pass def close_connection(self): if ( self.full_details.get("connection_pooling_enabled") and len(CONNECTION_POOL) < self.connection_pool_size ): CONNECTION_POOL.append(self._connection) else: self._connection.close() self._connection = None def commit(self): """Transaction""" if self.get_transaction_level() == 1: self._connection.commit() self._connection.autocommit = True self.transaction_level -= 1 def begin(self): """Postgres Transaction""" self._connection.autocommit = False self.transaction_level += 1 return self def rollback(self): """Transaction""" if self.get_transaction_level() == 1: self._connection.rollback() self._connection.autocommit = True self.transaction_level -= 1 def get_transaction_level(self): """Transaction""" return self.transaction_level def set_cursor(self): from psycopg2.extras import RealDictCursor self._cursor = self._connection.cursor(cursor_factory=RealDictCursor) return self._cursor def query(self, query, bindings=(), results="*"): """Make the actual query that will reach the database and come back with a result. Arguments: query {string} -- A string query. This could be a qmarked string or a regular query. bindings {tuple} -- A tuple of bindings Keyword Arguments: results {str|1} -- If the results is equal to an asterisks it will call 'fetchAll' else it will return 'fetchOne' and return a single record. 
(default: {"*"}) Returns: dict|None -- Returns a dictionary of results or None """ try: if not self._connection or self._connection.closed: self.make_connection() self.set_cursor() with self._cursor as cursor: if isinstance(query, list) and not self._dry: for q in query: self.statement(q, ()) return query = query.replace("'?'", "%s") self.statement(query, bindings) if results == 1: return dict(cursor.fetchone() or {}) else: if "SELECT" in cursor.statusmessage: return cursor.fetchall() return {} except Exception as e: raise QueryException(str(e)) from e finally: if self.get_transaction_level() <= 0: self.open = 0 self.close_connection() # self._connection.close() ================================================ FILE: src/masoniteorm/connections/SQLiteConnection.py ================================================ from ..query.grammars import SQLiteGrammar from .BaseConnection import BaseConnection from ..schema.platforms import SQLitePlatform from ..query.processors import SQLitePostProcessor from ..exceptions import DriverNotFound, QueryException import re def regexp(expr, item): reg = re.compile(expr) return reg.search(item) is not None class SQLiteConnection(BaseConnection): """SQLite Connection class.""" name = "sqlite" _connection = None def __init__( self, host=None, database=None, user=None, port=None, password=None, prefix=None, full_details=None, options=None, name=None, ): self.host = host if port: self.port = int(port) else: self.port = port self.database = database self.user = user self.password = password self.prefix = prefix self.full_details = full_details or {} self.options = options or {} self._cursor = None self.transaction_level = 0 self.open = 0 if name: self.name = name def make_connection(self): """This sets the connection on the connection class""" try: import sqlite3 except ModuleNotFoundError: raise DriverNotFound( "You must have the 'sqlite3' package installed to make a connection to SQLite." 
) if self.has_global_connection(): return self.get_global_connection() self._connection = sqlite3.connect(self.database, isolation_level=None) self._connection.create_function("REGEXP", 2, regexp) self._connection.row_factory = sqlite3.Row self.enable_disable_foreign_keys() self.open = 1 return self @classmethod def get_default_query_grammar(cls): return SQLiteGrammar @classmethod def get_default_platform(cls): return SQLitePlatform @classmethod def get_default_post_processor(cls): return SQLitePostProcessor def get_database_name(self): return self.database def reconnect(self): pass def commit(self): """Transaction""" if self.get_transaction_level() == 1: self.transaction_level -= 1 self._connection.commit() self._connection.isolation_level = None self._connection.close() self.open = 0 self.transaction_level -= 1 return self def begin(self): """Sqlite Transaction""" self._connection.isolation_level = "DEFERRED" self.transaction_level += 1 return self def rollback(self): """Transaction""" if self.get_transaction_level() == 1: self.transaction_level -= 1 self._connection.rollback() self._connection.close() self.open = 0 self.transaction_level -= 1 return self def get_cursor(self): return self._cursor def get_transaction_level(self): return self.transaction_level def query(self, query, bindings=(), results="*"): """Make the actual query that will reach the database and come back with a result. Arguments: query {string} -- A string query. This could be a qmarked string or a regular query. bindings {tuple} -- A tuple of bindings Keyword Arguments: results {str|1} -- If the results is equal to an asterisks it will call 'fetchAll' else it will return 'fetchOne' and return a single record. 
(default: {"*"}) Returns: dict|None -- Returns a dictionary of results or None """ if not self.open: self.make_connection() try: self._cursor = self._connection.cursor() if isinstance(query, list): for query in query: self.statement(query) else: query = query.replace("'?'", "?") self.statement(query, bindings) if results == 1: result = [dict(row) for row in self._cursor.fetchall()] if result: return result[0] else: return [dict(row) for row in self._cursor.fetchall()] except Exception as e: raise QueryException(str(e)) from e finally: if self.get_transaction_level() <= 0: self._connection.close() self.open = 0 def format_cursor_results(self, cursor_result): return [dict(row) for row in cursor_result] def select_many(self, query, bindings, amount): self._cursor = self._connection.cursor() self.statement(query) if not self.open: self.make_connection() result = self.format_cursor_results(self._cursor.fetchmany(amount)) while result: yield result result = self.format_cursor_results(self._cursor.fetchmany(amount)) ================================================ FILE: src/masoniteorm/connections/__init__.py ================================================ from .ConnectionResolver import ConnectionResolver from .ConnectionFactory import ConnectionFactory from .MySQLConnection import MySQLConnection from .PostgresConnection import PostgresConnection from .SQLiteConnection import SQLiteConnection from .MSSQLConnection import MSSQLConnection ================================================ FILE: src/masoniteorm/exceptions.py ================================================ class DriverNotFound(Exception): pass class ModelNotFound(Exception): pass class HTTP404(Exception): pass class ConnectionNotRegistered(Exception): pass class QueryException(Exception): pass class MigrationNotFound(Exception): pass class ConfigurationNotFound(Exception): pass class InvalidUrlConfiguration(Exception): pass class MultipleRecordsFound(Exception): pass class InvalidArgument(Exception): pass 
================================================ FILE: src/masoniteorm/expressions/__init__.py ================================================ from .expressions import Raw from .expressions import JoinClause ================================================ FILE: src/masoniteorm/expressions/expressions.py ================================================ from ..helpers.misc import deprecated class QueryExpression: """A helper class to manage query expressions.""" def __init__( self, column, equality, value, value_type="value", keyword=None, raw=False, bindings=(), ): self.column = column self.equality = equality self.value = value self.value_type = value_type self.keyword = keyword self.raw = raw self.bindings = bindings class HavingExpression: """A helper class to manage having expressions.""" def __init__(self, column, equality=None, value=None, raw=False): self.column = column self.raw = raw if equality and not value: value = equality equality = "=" self.equality = equality self.value = value self.value_type = "having" class FromTable: """A helper class to manage having expressions.""" def __init__(self, name, raw=False): self.name = name self.raw = raw class UpdateQueryExpression: """A helper class to manage update expressions.""" def __init__(self, column, value=None, update_type="keyvalue"): self.column = column self.value = value self.update_type = update_type class BetweenExpression: """A helper class to manage where between expressions.""" def __init__(self, column, low, high, equality="BETWEEN"): self.column = column self.low = low self.high = high self.equality = equality self.value = None self.value_type = "BETWEEN" self.raw = False class SubSelectExpression: """A helper class to manage subselect expressions.""" def __init__(self, builder): self.builder = builder class SubGroupExpression: """A helper class to manage subgroup expressions.""" def __init__(self, builder, alias="group"): self.builder = builder self.alias = alias class SelectExpression: """A 
helper class to manage select expressions.""" def __init__(self, column, raw=False): self.column = column.strip() self.alias = None self.raw = raw if raw is False and " as " in self.column: self.column, self.alias = self.column.split(" as ") self.column = self.column.strip() self.alias = self.alias.strip() class OrderByExpression: """A helper class to manage select expressions.""" def __init__(self, column, direction="ASC", raw=False, bindings=()): self.column = column.strip() self.raw = raw self.direction = direction self.bindings = bindings if raw is False: if self.column.endswith(" desc"): self.column = self.column.split(" desc")[0].strip() self.direction = "DESC" if self.column.endswith(" asc"): self.column = self.column.split(" asc")[0].strip() self.direction = "ASC" class GroupByExpression: """A helper class to manage select expressions.""" def __init__(self, column=None, raw=False, bindings=()): self.column = column.strip() self.raw = raw self.bindings = bindings class AggregateExpression: def __init__(self, aggregate=None, column=None, alias=False): self.aggregate = aggregate self.column = column.strip() self.alias = alias if " as " in self.column: self.column, self.alias = self.column.split(" as ") class Raw: def __init__(self, expression): self.expression = expression class JoinClause: def __init__(self, table, clause="join"): self.table = table self.alias = None self.clause = clause self.on_clauses = [] if " as " in self.table: self.table = table.split(" as ")[0] self.alias = table.split(" as ")[1] def on(self, column1, equality, column2): self.on_clauses.append(OnClause(column1, equality, column2)) return self def or_on(self, column1, equality, column2): self.on_clauses.append(OnClause(column1, equality, column2, "or")) return self def on_value(self, column, *args): equality, value = self._extract_operator_value(*args) self.on_clauses += ((OnValueClause(column, equality, value, "value")),) return self def or_on_value(self, column, *args): equality, 
value = self._extract_operator_value(*args) self.on_clauses += ( (OnValueClause(column, equality, value, "value", operator="or")), ) return self def on_null(self, column): """Specifies an ON expression where the column IS NULL. Arguments: column {string} -- The name of the column. Returns: self """ self.on_clauses += ((OnValueClause(column, "=", None, "NULL")),) return self def on_not_null(self, column: str): """Specifies an ON expression where the column IS NOT NULL. Arguments: column {string} -- The name of the column. Returns: self """ self.on_clauses += ((OnValueClause(column, "=", True, "NOT NULL")),) return self def or_on_null(self, column): """Specifies an ON expression where the column IS NULL. Arguments: column {string} -- The name of the column. Returns: self """ self.on_clauses += ((OnValueClause(column, "=", None, "NULL", operator="or")),) return self def or_on_not_null(self, column: str): """Specifies an ON expression where the column IS NOT NULL. Arguments: column {string} -- The name of the column. Returns: self """ self.on_clauses += ( (OnValueClause(column, "=", True, "NOT NULL", operator="or")), ) return self @deprecated("Using where() in a Join clause has been superceded by on_value()") def where(self, column, *args): return self.on_value(column, *args) def _extract_operator_value(self, *args): operators = ["=", ">", ">=", "<", "<=", "!=", "<>", "like", "not like"] operator = operators[0] value = None if (len(args)) >= 2: operator = args[0] value = args[1] elif len(args) == 1: value = args[0] if operator not in operators: raise ValueError( "Invalid comparison operator. 
The operator can be %s" % ", ".join(operators) ) return operator, value def get_on_clauses(self): return self.on_clauses class OnClause: def __init__(self, column1, equality, column2, operator="and"): self.column1 = column1 self.column2 = column2 self.equality = equality self.operator = operator class OnValueClause: """A helper class to manage ON expressions in joins with a value.""" def __init__( self, column, equality, value, value_type="value", keyword=None, raw=False, bindings=(), operator="and", ): self.column = column self.equality = equality self.value = value self.value_type = value_type self.keyword = keyword self.raw = raw self.bindings = bindings self.operator = operator ================================================ FILE: src/masoniteorm/factories/Factory.py ================================================ from faker import Faker import random class Factory: _factories = {} _after_creates = {} _faker = None @property def faker(self): if not Factory._faker: Factory._faker = Faker() random.seed() Factory._faker.seed_instance(random.randint(1, 10000)) return Factory._faker def __init__(self, model, number=1): self.model = model self.number = number def make(self, dictionary=None, name="default"): if dictionary is None: dictionary = {} if self.number == 1 and not isinstance(dictionary, list): called = self._factories[self.model][name](self.faker) called.update(dictionary) model = self.model.hydrate(called) self.run_after_creates(model) return model elif isinstance(dictionary, list): results = [] for index in range(0, len(dictionary)): called = self._factories[self.model][name](self.faker) called.update(dictionary) results.append(called) models = self.model.hydrate(results) for model in models: self.run_after_creates(model) return models else: results = [] for index in range(0, self.number): called = self._factories[self.model][name](self.faker) called.update(dictionary) results.append(called) models = self.model.hydrate(results) for model in models: 
self.run_after_creates(model) return models def create(self, dictionary=None, name="default"): if dictionary is None: dictionary = {} if self.number == 1 and not isinstance(dictionary, list): called = self._factories[self.model][name](self.faker) called.update(dictionary) model = self.model.create(called) self.run_after_creates(model) return model elif isinstance(dictionary, list): results = [] for index in range(0, len(dictionary)): called = self._factories[self.model][name](self.faker) called.update(dictionary) results.append(called) models = self.model.create(results) for model in models: self.run_after_creates(model) return models else: full_collection = [] for index in range(0, self.number): called = self._factories[self.model][name](self.faker) called.update(dictionary) full_collection.append(called) model = self.model.create(called) self.run_after_creates(model) return self.model.hydrate(full_collection) @classmethod def register(cls, model, call, name="default"): if model not in cls._factories: cls._factories[model] = {name: call} else: cls._factories[model][name] = call @classmethod def after_creating(cls, model, call, name="default"): if model not in cls._after_creates: cls._after_creates[model] = {name: call} else: cls._after_creates[model][name] = call def run_after_creates(self, model): if self.model not in self._after_creates: return model for name, callback in self._after_creates[self.model].items(): callback(model, self.faker) ================================================ FILE: src/masoniteorm/factories/__init__.py ================================================ from .Factory import Factory ================================================ FILE: src/masoniteorm/helpers/__init__.py ================================================ ================================================ FILE: src/masoniteorm/helpers/misc.py ================================================ """Module for miscellaneous helper methods.""" import warnings def 
deprecated(message): warnings.simplefilter("default", DeprecationWarning) def deprecated_decorator(func): def deprecated_func(*args, **kwargs): warnings.warn( "{} is a deprecated function. {}".format(func.__name__, message), category=DeprecationWarning, stacklevel=2, ) return func(*args, **kwargs) return deprecated_func return deprecated_decorator ================================================ FILE: src/masoniteorm/migrations/Migration.py ================================================ import os from os import listdir from os.path import isfile, join from pydoc import locate from inflection import camelize from ..models.MigrationModel import MigrationModel from ..schema import Schema from ..config import load_config from timeit import default_timer as timer class Migration: def __init__( self, connection="default", dry=False, command_class=None, migration_directory="databases/migrations", config_path=None, schema=None, ): self.connection = connection self.migration_directory = migration_directory self.last_migrations_ran = [] self.command_class = command_class self.schema_name = schema DB = load_config(config_path).DB DATABASES = DB.get_connection_details() self.schema = Schema( connection=connection, connection_details=DATABASES, dry=dry, schema=self.schema_name, ) self.migration_model = MigrationModel.on(self.connection) if self.schema_name: self.migration_model.set_schema(self.schema_name) def create_table_if_not_exists(self): if not self.schema.has_table("migrations"): with self.schema.create("migrations") as table: table.increments("migration_id") table.string("migration") table.integer("batch") return True return False def get_unran_migrations(self): directory_path = os.path.join(os.getcwd(), self.migration_directory) all_migrations = [ f.replace(".py", "") for f in listdir(directory_path) if isfile(join(directory_path, f)) and f != "__init__.py" and not f.startswith(".") ] all_migrations.sort() unran_migrations = [] database_migrations = 
self.migration_model.all() for migration in all_migrations: if migration not in database_migrations.pluck("migration"): unran_migrations.append(migration) return unran_migrations def get_rollback_migrations(self): return ( self.migration_model.where("batch", self.migration_model.all().max("batch")) .order_by("migration_id", "desc") .get() .pluck("migration") ) def get_all_migrations(self, reverse=False): if reverse: return ( self.migration_model.order_by("migration_id", "desc") .get() .pluck("migration") ) return self.migration_model.all().pluck("migration") def get_last_batch_number(self): return self.migration_model.select("batch").get().max("batch") def delete_migration(self, file_path): return self.migration_model.where("migration", file_path).delete() def locate(self, file_name): migration_name = camelize("_".join(file_name.split("_")[4:]).replace(".py", "")) file_name = file_name.replace(".py", "") migration_directory = self.migration_directory.replace("/", ".").replace( "\\", "." ) return locate(f"{migration_directory}.{file_name}.{migration_name}") def get_ran_migrations(self): directory_path = os.path.join(os.getcwd(), self.migration_directory) all_migrations = [ f.replace(".py", "") for f in listdir(directory_path) if isfile(join(directory_path, f)) and f != "__init__.py" and not f.startswith(".") ] all_migrations.sort() ran = [] database_migrations = self.migration_model.all() for migration in all_migrations: matched_migration = database_migrations.where( "migration", migration ).first() if matched_migration: ran.append( { "migration_file": matched_migration.migration, "batch": matched_migration.batch, } ) return ran def migrate(self, migration="all", output=False): default_migrations = self.get_unran_migrations() migrations = default_migrations if migration == "all" else [migration] batch = self.get_last_batch_number() + 1 for migration in migrations: try: migration_class = self.locate(migration) except TypeError: self.command_class.line(f"Not Found: 
{migration}") continue self.last_migrations_ran.append(migration) if self.command_class: self.command_class.line( f"Migrating: {migration}" ) migration_class = migration_class( connection=self.connection, schema=self.schema_name ) if output: migration_class.schema.dry() start = timer() migration_class.up() duration = "{:.2f}".format(timer() - start) if output: if self.command_class: table = self.command_class.table() table.set_header_row(["SQL"]) sql = migration_class.schema._blueprint.to_sql() if isinstance(sql, list): sql = ",".join(sql) table.set_rows([[sql]]) table.render(self.command_class.io) continue else: print(migration_class.schema._blueprint.to_sql()) if self.command_class: self.command_class.line( f"Migrated: {migration} ({duration}s)" ) self.migration_model.create( {"batch": batch, "migration": migration.replace(".py", "")} ) def rollback(self, migration="all", output=False): default_migrations = self.get_rollback_migrations() migrations = default_migrations if migration == "all" else [migration] for migration in migrations: if migration.endswith(".py"): migration = migration.replace(".py", "") if self.command_class: self.command_class.line( f"Rolling back: {migration}" ) try: migration_class = self.locate(migration) except TypeError: self.command_class.line(f"Not Found: {migration}") continue migration_class = migration_class( connection=self.connection, schema=self.schema_name ) if output: migration_class.schema.dry() start = timer() migration_class.down() duration = "{:.2f}".format(timer() - start) if output: if self.command_class: table = self.command_class.table() table.set_header_row(["SQL"]) if ( hasattr(migration_class.schema, "_blueprint") and migration_class.schema._blueprint ): sql = migration_class.schema._blueprint.to_sql() if isinstance(sql, list): sql = ",".join(sql) table.set_rows([[sql]]) elif migration_class.schema._sql: table.set_rows([[migration_class.schema._sql]]) table.render(self.command_class.io) continue else: 
print(migration_class.schema._blueprint.to_sql()) self.delete_migration(migration) if self.command_class: self.command_class.line( f"Rolled back: {migration} ({duration}s)" ) def delete_migrations(self, migrations=None): return self.migration_model.where_in("migration", migrations or []).delete() def delete_last_batch(self): return self.migration_model.where( "batch", self.get_last_batch_number() ).delete() def reset(self, migration="all"): default_migrations = self.get_all_migrations(reverse=True) migrations = default_migrations if migration == "all" else [migration] if not len(migrations): if self.command_class: self.command_class.line("Nothing to reset") else: print("Nothing to reset") for migration in migrations: if self.command_class: self.command_class.line( f"Rolling back: {migration}" ) try: self.locate(migration)( connection=self.connection, schema=self.schema_name ).down() except TypeError: self.command_class.line(f"Not Found: {migration}") continue # raise MigrationNotFound(f"Could not find {migration}") self.delete_migration(migration) if self.command_class: self.command_class.line( f"Rolled back: {migration}" ) self.delete_migrations([migration]) if self.command_class: self.command_class.line("") def refresh(self, migration="all"): self.reset(migration) self.migrate(migration) def drop_all_tables(self, ignore_fk=False): if self.command_class: self.command_class.line("Dropping all tables") if ignore_fk: self.schema.disable_foreign_key_constraints() for table in self.schema.get_all_tables(): self.schema.drop(table) if ignore_fk: self.schema.enable_foreign_key_constraints() if self.command_class: self.command_class.line("All tables dropped") def fresh(self, ignore_fk=False, migration="all"): self.drop_all_tables(ignore_fk=ignore_fk) self.create_table_if_not_exists() if not self.get_unran_migrations(): if self.command_class: self.command_class.line("Nothing to migrate") return self.migrate(migration) ================================================ FILE: 
src/masoniteorm/migrations/__init__.py
================================================
from .Migration import Migration

================================================
FILE: src/masoniteorm/models/MigrationModel.py
================================================
from .Model import Model


class MigrationModel(Model):
    """Model for the "migrations" table (migration name and batch number per run migration)."""

    __table__ = "migrations"
    __fillable__ = ["migration", "batch"]
    # The migrations table has no created_at/updated_at columns.
    __timestamps__ = None
    __primary_key__ = "migration_id"

================================================
FILE: src/masoniteorm/models/Model.py
================================================
import inspect
import json
import logging
from datetime import date as datetimedate
from datetime import datetime
from datetime import time as datetimetime
from decimal import Decimal
from typing import Any, Dict

import pendulum
from inflection import tableize, underscore

from ..collection import Collection
from ..config import load_config
from ..exceptions import ModelNotFound
from ..observers import ObservesEvents
from ..query import QueryBuilder
from ..scopes import TimeStampsMixin

"""This is a magic class that will help using models like User.first() instead of having to instatiate a class like User().first() """


class ModelMeta(type):
    """Metaclass that forwards class-level attribute lookups to a fresh instance."""

    def __getattr__(self, attribute, *args, **kwargs):
        """Called when an attribute is not found on the model class itself.

        This is a quick and easy way to instantiate a class before the first
        method is called, so instead of writing User().where(..) you can write
        User.where(...). The class is instantiated behind the scenes.

        Args:
            attribute (string): The name of the attribute

        Returns:
            Model|mixed: The attribute looked up on a new instance of the model.
        """
        instantiated = self()
        return getattr(instantiated, attribute)


class BoolCast:
    """Casts a value to a boolean"""

    def get(self, value):
        return bool(value)

    def set(self, value):
        return bool(value)


class JsonCast:
    """Casts a value to JSON"""

    def get(self, value):
        """Decode a JSON string (None if it is invalid); non-strings pass through."""
        if isinstance(value, str):
            try:
                return json.loads(value)
            except ValueError:
                return None

        return value

    def set(self, value):
        """Return ``value`` as a JSON string (strings are validated, others dumped)."""
        if isinstance(value, str):
            # make sure the string is valid JSON
            json.loads(value)
            return value

        return json.dumps(value, default=str)


class IntCast:
    """Casts a value to a int"""

    def get(self, value):
        return int(value)

    def set(self, value):
        return int(value)


class FloatCast:
    """Casts a value to a float"""

    def get(self, value):
        return float(value)

    def set(self, value):
        return float(value)


class DateCast:
    """Casts a value to a date string via pendulum.parse(...).to_date_string()"""

    def get(self, value):
        return pendulum.parse(value).to_date_string()

    def set(self, value):
        return pendulum.parse(value).to_date_string()


class DecimalCast:
    """Casts a value to Decimal for accuracy"""

    def get(self, value):
        """
        Get the value: a Decimal input is returned as str, anything else as Decimal.
        """
        # NOTE(review): the return type differs by input type (str vs Decimal) — confirm intended.
        if isinstance(value, Decimal):
            return str(value)

        return Decimal(str(value))

    def set(self, value):
        """
        Set the value as a Decimal built from its string form.
        """
        return Decimal(str(value))


class Model(TimeStampsMixin, ObservesEvents, metaclass=ModelMeta): """The ORM Model class Base Classes: TimeStampsMixin (TimeStampsMixin): Adds scopes to add timestamps when something is inserted metaclass (ModelMeta, optional): Helps instantiate a class when it hasn't been instantiated. Defaults to ModelMeta.
""" __fillable__ = ["*"] __guarded__ = [] __dry__ = False __table__ = None __connection__ = "default" __resolved_connection__ = None __selects__ = [] __observers__ = {} __has_events__ = True _booted = False _scopes = {} __primary_key__ = "id" __primary_key_type__ = "int" __casts__ = {} __dates__ = [] __hidden__ = [] __relationship_hidden__ = {} __visible__ = [] __timestamps__ = True __timezone__ = "UTC" __with__ = () __force_update__ = False date_created_at = "created_at" date_updated_at = "updated_at" builder: QueryBuilder """Pass through will pass any method calls to the model directly through to the query builder. Anytime one of these methods are called on the model it will actually be called on the query builder class. """ __passthrough__ = set( ( "add_select", "aggregate", "all", "avg", "between", "bulk_create", "chunk", "count", "decrement", "delete", "distinct", "doesnt_exist", "doesnt_have", "exists", "find_or", "find_or_404", "find_or_fail", "first_or_fail", "first", "first_where", "first_or_create", "force_update", "from_", "from_raw", "get", "get_table_schema", "group_by_raw", "group_by", "has", "having", "having_raw", "increment", "in_random_order", "join_on", "join", "joins", "last", "left_join", "limit", "lock_for_update", "make_lock", "max", "min", "new_from_builder", "new", "not_between", "offset", "on", "or_where", "or_where_null", "order_by_raw", "order_by", "paginate", "right_join", "select_raw", "select", "set_global_scope", "set_schema", "shared_lock", "simple_paginate", "skip", "statement", "sum", "table_raw", "take", "to_qmark", "to_sql", "truncate", "update", "when", "where_between", "where_column", "where_date", "or_where_doesnt_have", "or_has", "or_where_has", "or_doesnt_have", "or_where_not_exists", "or_where_date", "where_exists", "where_from_builder", "where_has", "where_in", "where_like", "where_not_between", "where_not_in", "where_not_like", "where_not_null", "where_null", "where_raw", "without_global_scopes", "where",
"where_doesnt_have", "with_", "with_count", "latest", "oldest", "value", ) ) __cast_map__ = {} __internal_cast_map__ = { "bool": BoolCast, "json": JsonCast, "int": IntCast, "float": FloatCast, "date": DateCast, "decimal": DecimalCast, } def __init__(self): self.__attributes__ = {} self.__original_attributes__ = {} self.__dirty_attributes__ = {} if not hasattr(self, "__appends__"): self.__appends__ = [] self._relationships = {} self._global_scopes = {} self.boot() @classmethod def get_primary_key(self): """Gets the primary key column Returns: mixed """ return self.__primary_key__ def get_primary_key_type(self): """Gets the primary key column type Returns: mixed """ return self.__primary_key_type__ def get_primary_key_value(self): """Gets the primary key value. Raises: AttributeError: Raises attribute error if the model does not have an attribute with the primary key. Returns: str|int """ try: return getattr(self, self.get_primary_key()) except AttributeError: name = self.__class__.__name__ raise AttributeError( f"class '{name}' has no attribute {self.get_primary_key()}. Did you set the primary key correctly on the model using the __primary_key__ attribute?" ) def get_foreign_key(self): """Gets the foreign key based on this model name. Args: relationship (str): The relationship name. 
Returns: str """ return underscore(self.__class__.__name__ + "_" + self.get_primary_key()) def query(self): return self.get_builder() def get_builder(self): if hasattr(self, "builder"): return self.builder self.builder = QueryBuilder( connection=self.__connection__, table=self.get_table_name(), connection_details=self.get_connection_details(), model=self, scopes=self._scopes.get(self.__class__), dry=self.__dry__, ) return self.builder def get_selects(self): return self.__selects__ @classmethod def get_columns(cls): return list(cls.first().__attributes__.keys()) def get_connection_details(self): DB = load_config().DB return DB.get_connection_details() def boot(self): if not self._booted: self.observe_events(self, "booting") for base_class in inspect.getmro(self.__class__): class_name = base_class.__name__ if class_name.endswith("Mixin"): getattr(self, "boot_" + class_name)(self.get_builder()) elif ( base_class != Model and issubclass(base_class, Model) and "__fillable__" in base_class.__dict__ and "__guarded__" in base_class.__dict__ ): raise AttributeError( f"{type(self).__name__} must specify either __fillable__ or __guarded__ properties, but not both." ) self._booted = True self.observe_events(self, "booted") self.append_passthrough(list(self.get_builder()._macros.keys())) def append_passthrough(self, passthrough): self.__passthrough__.update(passthrough) return self @classmethod def get_table_name(cls): """Gets the table name. Returns: str """ return cls.__table__ or tableize(cls.__name__) @classmethod def table(cls, table): """Gets the table name. Returns: str """ cls.__table__ = table return cls @classmethod def find(cls, record_id, query=False): """Finds a row by the primary key ID. Arguments: record_id {int} -- The ID of the primary key to fetch. 
Returns: Model """ if isinstance(record_id, (list, tuple)): builder = cls().where_in(cls.get_primary_key(), record_id) else: builder = cls().where(cls.get_primary_key(), record_id) if query: return builder else: if isinstance(record_id, (list, tuple)): return builder.get() return builder.first() @classmethod def find_or_fail(cls, record_id, query=False): """Finds a row by the primary key ID or raise a ModelNotFound exception. Arguments: record_id {int} -- The ID of the primary key to fetch. Returns: Model """ result = cls.find(record_id, query) if not result: raise ModelNotFound() return result def is_loaded(self): return bool(self.__attributes__) def is_created(self): return self.get_primary_key() in self.__attributes__ def add_relation(self, relations): self._relationships.update(relations) return self @classmethod def hydrate(cls, result, relations=None): """Takes a result and loads it into a model Args: result ([type]): [description] relations (dict, optional): [description]. Defaults to {}. 
Returns: [type]: [description] """ relations = relations or {} if result is None: return None if isinstance(result, (list, tuple)): response = [] for element in result: response.append(cls.hydrate(element)) return cls.new_collection(response) elif isinstance(result, dict): model = cls() dic = {} for key, value in result.items(): if key in model.get_dates() and value: value = model.get_new_date(value) dic.update({key: value}) logger = logging.getLogger("masoniteorm.models.hydrate") logger.setLevel(logging.INFO) logger.propagate = False logger.info( f"Hydrating Model {cls.__name__}", extra={"class_name": cls.__name__, "class_module": cls.__module__}, ) model.observe_events(model, "hydrating") model.__attributes__.update(dic or {}) model.__original_attributes__.update(dic or {}) model.add_relation(relations) model.observe_events(model, "hydrated") return model elif hasattr(result, "serialize"): model = cls() model.__attributes__.update(result.serialize()) model.__original_attributes__.update(result.serialize()) return model else: model = cls() model.observe_events(model, "hydrating") model.__attributes__.update(dict(result)) model.__original_attributes__.update(dict(result)) model.observe_events(model, "hydrated") return model def fill(self, attributes): self.__attributes__.update(attributes) return self def fill_original(self, attributes): self.__original_attributes__.update(attributes) return self @classmethod def new_collection(cls, data): """Takes a result and puts it into a new collection. This is designed to be able to be overidden by the user. Args: data (list|dict): Could be any data type but will be loaded directly into a collection. Returns: Collection """ return Collection(data) @classmethod def create( cls, dictionary: Dict[str, Any] = None, query: bool = False, cast: bool = False, **kwargs, ): """Creates new records based off of a dictionary as well as data set on the model such as fillable values. Args: dictionary (dict, optional): [description]. 
Defaults to {}. query (bool, optional): [description]. Defaults to False. cast (bool, optional): [description]. Whether or not to cast passed values. Returns: self: A hydrated version of a model """ if query: return cls.builder.create(dictionary, query=True, cast=cast, **kwargs) return cls.builder.create(dictionary, cast=cast, **kwargs) @classmethod def cast_value(cls, attribute: str, value: Any): """ Given an attribute name and a value, casts the value using the model's registered caster. If no registered caster exists, returns the unmodified value. """ cast_method = cls.__casts__.get(attribute) cast_map = cls.get_cast_map(cls) if value is None: return None if isinstance(cast_method, str): return cast_map[cast_method]().set(value) if cast_method: return cast_method(value) return value @classmethod def cast_values(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: """ Runs provided dictionary through all model casters and returns the result. Does not mutate the passed dictionary. """ return {x: cls.cast_value(x, dictionary[x]) for x in dictionary} def fresh(self): return ( self.get_builder() .where(self.get_primary_key(), self.get_primary_key_value()) .first() ) def serialize(self, exclude=None, include=None): """Takes the data as a model and converts it into a dictionary. Returns: dict """ serialized_dictionary = self.__attributes__.copy() # prevent using both exclude and include at the same time if exclude is not None and include is not None: raise AttributeError("Can not define both includes and exclude values.") if exclude is not None: self.__hidden__ = exclude if include is not None: self.__visible__ = include # prevent using both hidden and visible at the same time if self.__visible__ and self.__hidden__: raise AttributeError( f"class model '{self.__class__.__name__}' defines both __visible__ and __hidden__." 
) if self.__visible__: new_serialized_dictionary = { k: serialized_dictionary[k] for k in self.__visible__ if k in serialized_dictionary } serialized_dictionary = new_serialized_dictionary else: for key in self.__hidden__: if key in serialized_dictionary: serialized_dictionary.pop(key) for date_column in self.get_dates(): if ( date_column in serialized_dictionary and serialized_dictionary[date_column] ): serialized_dictionary[date_column] = self.get_new_serialized_date( serialized_dictionary[date_column] ) serialized_dictionary.update(self.__dirty_attributes__) # The builder is inside the attributes but should not be serialized if "builder" in serialized_dictionary: serialized_dictionary.pop("builder") # Serialize relationships as well serialized_dictionary.update(self.relations_to_dict()) for append in self.__appends__: serialized_dictionary.update({append: getattr(self, append)}) remove_keys = [] for key, value in serialized_dictionary.items(): if key in self.__hidden__: remove_keys.append(key) if hasattr(value, "serialize"): value = value.serialize(self.__relationship_hidden__.get(key, [])) if isinstance(value, datetime): value = self.get_new_serialized_date(value) if key in self.__casts__: value = self._cast_attribute(key, value) serialized_dictionary.update({key: value}) for key in remove_keys: serialized_dictionary.pop(key) return serialized_dictionary def to_json(self): """Converts a model to JSON Returns: string """ return json.dumps(self.serialize(), default=str) @classmethod def first_or_create(cls, wheres, creates: dict = None): """Get the first record matching the attributes or create it. 
Returns: Model """ if creates is None: creates = {} self = cls() record = self.where(wheres).first() total = {} total.update(creates) total.update(wheres) if not record: return self.create(total, id_key=cls.get_primary_key()) return record @classmethod def update_or_create(cls, wheres, updates): self = cls() record = self.where(wheres).first() total = {} total.update(updates) total.update(wheres) if not record: return self.create(total, id_key=cls.get_primary_key()).fresh() return self.where(wheres).update(total) def relations_to_dict(self): """Converts a models relationships to a dictionary Returns: [type]: [description] """ new_dic = {} for key, value in self._relationships.items(): if value == {}: new_dic.update({key: {}}) else: if value is None: new_dic.update({key: {}}) continue elif isinstance(value, list): value = Collection(value).serialize() elif isinstance(value, dict): pass else: value = value.serialize() new_dic.update({key: value}) return new_dic def touch(self, date=None, query=True): """Updates the current timestamps on the model""" if not self.__timestamps__: return False self._update_timestamps(date=date) return self.save(query=query) def _update_timestamps(self, date=None): """Sets the updated at date to the current time or a specified date Args: date (datetime.datetime, optional): a date. If none is specified then it will use the current date Defaults to None. """ self.updated_at = date or self._current_timestamp() def _current_timestamp(self): return datetime.now() def __getattr__(self, attribute): """Magic method that is called when an attribute does not exist on the model. Args: attribute (string): the name of the attribute being accessed or called. Returns: mixed: Could be anything that a method can return. 
""" new_name_accessor = "get_" + attribute + "_attribute" if (new_name_accessor) in self.__class__.__dict__: return self.__class__.__dict__.get(new_name_accessor)(self) if ( "__dirty_attributes__" in self.__dict__ and attribute in self.__dict__["__dirty_attributes__"] ): return self.get_dirty_value(attribute) if ( "__attributes__" in self.__dict__ and attribute in self.__dict__["__attributes__"] ): if attribute in self.get_dates(): return ( self.get_new_date(self.get_value(attribute)) if self.get_value(attribute) else None ) return self.get_value(attribute) if attribute in self.__passthrough__: def method(*args, **kwargs): return getattr(self.get_builder(), attribute)(*args, **kwargs) return method if attribute in self.__dict__.get("_relationships", {}): return self.__dict__["_relationships"][attribute] if attribute not in self.__dict__: name = self.__class__.__name__ raise AttributeError(f"class model '{name}' has no attribute {attribute}") return None def only(self, attributes: list) -> dict: if isinstance(attributes, str): attributes = [attributes] results: dict[str, Any] = {} for attribute in attributes: if " as " in attribute: attribute, alias = attribute.split(" as ") alias = alias.strip() attribute = attribute.strip() else: alias = attribute.strip() attribute = attribute.strip() results[alias] = self.get_raw_attribute(attribute) return results def __setattr__(self, attribute, value): if hasattr(self, "set_" + attribute + "_attribute"): method = getattr(self, "set_" + attribute + "_attribute") value = method(value) if attribute in self.__casts__: value = self._set_cast_attribute(attribute, value) if attribute in self.get_dates(): value = self.get_new_datetime_string(value) try: if not attribute.startswith("_"): self.__dict__["__dirty_attributes__"].update({attribute: value}) else: self.__dict__[attribute] = value except KeyError: pass def get_raw_attribute(self, attribute): """Gets an attribute without having to call the models magic methods. 
Gets around infinite recursion loops. Args: attribute (string): The attribute to fetch Returns: mixed: Any value an attribute can be. """ return self.__attributes__.get(attribute) def is_dirty(self): return bool(self.__dirty_attributes__) def get_original(self, key): return self.__original_attributes__.get(key) def get_dirty(self, key): return self.__dirty_attributes__.get(key) def get_dirty_keys(self): return list(self.get_dirty_attributes().keys()) def save(self, query=False): builder = self.get_builder() if "builder" in self.__dirty_attributes__: self.__dirty_attributes__.pop("builder") self.observe_events(self, "saving") if not query: if self.is_loaded(): result = builder.update( self.__dirty_attributes__, ignore_mass_assignment=True ) else: result = self.create( self.__dirty_attributes__, query=query, id_key=self.get_primary_key(), ignore_mass_assignment=True, ) self.observe_events(self, "saved") self.fill(result.__attributes__) self.__dirty_attributes__ = {} return result if self.is_loaded(): result = builder.update( self.__dirty_attributes__, dry=query, ignore_mass_assignment=True ) else: result = self.create(self.__dirty_attributes__, query=query) return result def get_value(self, attribute): value = self.__attributes__[attribute] if attribute in self.__casts__: return self._cast_attribute(attribute, value) return value def get_dirty_value(self, attribute): value = self.__dirty_attributes__[attribute] if attribute in self.__casts__: return self._cast_attribute(attribute, value) return value def all_attributes(self): attributes = self.__attributes__ attributes.update(self.get_dirty_attributes()) for key, value in attributes.items(): if key in self.__casts__: attributes.update({key: self._cast_attribute(key, value)}) return attributes def delete_attribute(self, key): if key in self.__attributes__: del self.__attributes__[key] return True return False def get_dirty_attributes(self): if "builder" in self.__dirty_attributes__: 
self.__dirty_attributes__.pop("builder") return self.__dirty_attributes__ or {} def get_cast_map(self): cast_map = self.__internal_cast_map__ cast_map.update(self.__cast_map__) return cast_map def _cast_attribute(self, attribute, value): cast_method = self.__casts__[attribute] cast_map = self.get_cast_map() if value is None: return None if isinstance(cast_method, str): return cast_map[cast_method]().get(value) return cast_method(value) def _set_cast_attribute(self, attribute, value): cast_method = self.__casts__[attribute] cast_map = self.get_cast_map() if isinstance(cast_method, str): return cast_map[cast_method]().set(value) return cast_method(value) @classmethod def load(cls, *loads): cls.boot() cls._loads += loads return cls.builder def __getitem__(self, attribute): return getattr(self, attribute) def get_dates(self): """ Get the attributes that should be converted to dates. :rtype: list """ defaults = [self.date_created_at, self.date_updated_at] return self.__dates__ + defaults def get_new_date(self, _datetime=None): """ Get the attributes that should be converted to dates. :rtype: list """ import pendulum if not _datetime: return pendulum.now(tz=self.__timezone__) elif isinstance(_datetime, str): return pendulum.parse(_datetime, tz=self.__timezone__) elif isinstance(_datetime, datetime): return pendulum.instance(_datetime, tz=self.__timezone__) elif isinstance(_datetime, datetimedate): return pendulum.datetime( _datetime.year, _datetime.month, _datetime.day, tz=self.__timezone__ ) elif isinstance(_datetime, datetimetime): return pendulum.parse( f"{_datetime.hour}:{_datetime.minute}:{_datetime.second}", tz=self.__timezone__, ) return pendulum.instance(_datetime, tz=self.__timezone__) def get_new_datetime_string(self, _datetime=None): """ Given an optional datetime value, constructs and returns a new datetime string. If no datetime is specified, returns the current time. 
:rtype: list """ return self.get_new_date(_datetime).to_datetime_string() def get_new_serialized_date(self, _datetime): """ Get the attributes that should be converted to dates. :rtype: list """ return self.get_new_date(_datetime).isoformat() def set_appends(self, appends): """ Get the attributes that should be converted to dates. :rtype: list """ self.__appends__ += appends return self def save_many(self, relation, relating_records): if isinstance(relating_records, Model): raise ValueError( "Saving many records requires an iterable like a collection or a list of models and not a Model object. To attach a model, use the 'attach' method." ) for related_record in relating_records: self.attach(relation, related_record) def detach_many(self, relation, relating_records): if isinstance(relating_records, Model): raise ValueError( "Detaching many records requires an iterable like a collection or a list of models and not a Model object. To detach a model, use the 'detach' method." ) related = getattr(self.__class__, relation) for related_record in relating_records: if not related_record.is_created(): related_record.create(related_record.all_attributes()) else: related_record.save() related.detach(self, related_record) def related(self, relation): related = getattr(self.__class__, relation) return related.relate(self) def get_related(self, relation): related = getattr(self.__class__, relation) return related def attach(self, relation, related_record): related = getattr(self.__class__, relation) return related.attach(self, related_record) def detach(self, relation, related_record): related = getattr(self.__class__, relation) if not related_record.is_created(): related_record = related_record.create(related_record.all_attributes()) else: related_record.save() return related.detach(self, related_record) def save_quietly(self): """This method calls the save method on a model without firing the saved & saving observer events. 
Saved/Saving are toggled back on once save_quietly has been ran. Instead of calling: User().save(...) you can use this: User.save_quietly(...) """ self.without_events() saved = self.save() self.with_events() return saved def delete_quietly(self): """This method calls the delete method on a model without firing the delete & deleting observer events. Instead of calling: User().delete(...) you can use this: User.delete_quietly(...) Returns: self """ delete = ( self.without_events() .where(self.get_primary_key(), self.get_primary_key_value()) .delete() ) self.with_events() return delete def attach_related(self, relation, related_record): return self.attach(relation, related_record) @classmethod def filter_fillable(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: """ Filters provided dictionary to only include fields specified in the model's __fillable__ property Passed dictionary is not mutated. """ if cls.__fillable__ != ["*"]: dictionary = {x: dictionary[x] for x in cls.__fillable__ if x in dictionary} return dictionary @classmethod def filter_mass_assignment(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: """ Filters the provided dictionary in preparation for a mass-assignment operation Wrapper around filter_fillable() & filter_guarded(). Passed dictionary is not mutated. """ return cls.filter_guarded(cls.filter_fillable(dictionary)) @classmethod def filter_guarded(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]: """ Filters provided dictionary to exclude fields specified in the model's __guarded__ property Passed dictionary is not mutated. 
""" if cls.__guarded__ == ["*"]: # If all fields are guarded, all data should be filtered return {} return {f: dictionary[f] for f in dictionary if f not in cls.__guarded__} ================================================ FILE: src/masoniteorm/models/Model.pyi ================================================ from typing import Any, Dict from typing_extensions import Self from ..query.QueryBuilder import QueryBuilder class Model: def add_select(alias: str, callable: Any): """Specifies a select subquery.""" pass def aggregate(aggregate: str, column: str, alias: str): """Helper function to aggregate. Arguments: aggregate {string} -- The name of the aggregation. column {string} -- The name of the column to aggregate. """ def all(selects: list = [], query: bool = False): """Returns all records from the table. Returns: dictionary -- Returns a dictionary of results. """ pass def get(selects: list = []): """Runs the select query built from the query builder. Returns: self """ pass def avg(column: str): """Aggregates a columns values. Arguments: column {string} -- The name of the column to aggregate. Returns: self """ pass def between(column: str, low: str | int, high: str | int): """Specifies a where between expression. Arguments: column {string} -- The name of the column. low {string} -- The value on the low end. high {string} -- The value on the high end. Returns: self """ pass def bulk_create(creates: dict, query: bool = False): pass def cast_value(attribute: str, value: Any): """ Given an attribute name and a value, casts the value using the model's registered caster. If no registered caster exists, returns the unmodified value. """ pass def cast_values(dictionary: Dict[str, Any]) -> Dict[str, Any]: """ Runs provided dictionary through all model casters and returns the result. Does not mutate the passed dictionary. """ pass def chunk(chunk_amount: str | int): pass def count(column: str = None): """Aggregates a columns values. 
Arguments: column {string} -- The name of the column to aggregate. Returns: self """ pass def create( dictionary: Dict[str, Any] = None, query: bool = False, cast: bool = False, **kwargs ): """Creates new records based off of a dictionary as well as data set on the model such as fillable values. Args: dictionary (dict, optional): [description]. Defaults to {}. query (bool, optional): [description]. Defaults to False. cast (bool, optional): [description]. Whether or not to cast passed values. Returns: self: A hydrated version of a model """ pass def decrement(column: str, value: int = 1): """Decrements a column's value. Arguments: column {string} -- The name of the column. Keyword Arguments: value {int} -- The value to decrement by. (default: {1}) Returns: self """ def delete(column: str = None, value: str = None, query: bool = False): """Specify the column and value to delete or deletes everything based on a previously used where expression. Keyword Arguments: column {string} -- The name of the column (default: {None}) value {string|int} -- The value of the column (default: {None}) Returns: self """ pass def distinct(boolean: bool = True): """Species that the select query should be a SELECT DISTINCT query.""" pass def doesnt_exist() -> bool: """Determines if any rows exist for the current query. Returns: Bool - True or False """ pass def doesnt_have() -> bool: """Determine if any related rows exist for the current query. Returns: Bool - True or False """ pass def exists() -> bool: """Determine if rows exist for the current query. Returns: Bool - True or False """ pass def filter_fillable(dictionary: Dict[str, Any]) -> Dict[str, Any]: """ Filters provided dictionary to only include fields specified in the model's __fillable__ property Passed dictionary is not mutated. 
""" pass def filter_mass_assignment(dictionary: Dict[str, Any]) -> Dict[str, Any]: """ Filters the provided dictionary in preparation for a mass-assignment operation Wrapper around filter_fillable() & filter_guarded(). Passed dictionary is not mutated. """ pass def filter_guarded(dictionary: Dict[str, Any]) -> Dict[str, Any]: """ Filters provided dictionary to exclude fields specified in the model's __guarded__ property Passed dictionary is not mutated. """ pass def find_or_404(record_id: str | int): """Finds a row by the primary key ID (Requires a model) or raise an 404 exception. Arguments: record_id {int} -- The ID of the primary key to fetch. Returns: Model|HTTP404 """ pass def find(record_id: str | list) -> Self: """Finds a row by the primary key ID (Requires a model) or raise an 404 exception. Arguments: record_id {int} -- The ID of the primary key to fetch. Returns: Model|Collection """ pass def find_or_fail(record_id: str | int): """Finds a row by the primary key ID (Requires a model) or raise a ModelNotFound exception. Arguments: record_id {int} -- The ID of the primary key to fetch. Returns: Model|ModelNotFound """ pass def first_or_fail(query: bool = False): """Returns the first row from database. If no result found a ModelNotFound exception. Returns: dictionary|ModelNotFound """ def first(fields: list = None, query: bool = False): """Gets the first record. Returns: dictionary -- Returns a dictionary of results. """ pass def first_where(column: str, *args): """Gets the first record with the given key / value pair""" pass def first_or_create(wheres: dict, creates: dict = None): """Get the first record matching the attributes or create it. 
Returns: Model """ pass def force_update(updates: dict, dry: bool = False): pass def from_(table: str): """Alias for the table method Arguments: table {string} -- The name of the table Returns: self """ pass def from_raw(table: str): """Alias for the table method Arguments: table {string} -- The name of the table Returns: self """ pass def last(column: str = None, query: bool = False): """Gets the last record, ordered by column in descendant order or primary key if no column is given. Returns: dictionary -- Returns a dictionary of results. """ pass def group_by_raw(query: str, bindings: list = []): """Specifies a column to group by. Arguments: query {string} -- A raw query Returns: self """ pass def group_by(column: str): """Specifies a column to group by. Arguments: column {string} -- The name of the column to group by. Returns: self """ pass def has(*relationships: str): pass def having_raw(string: str): """Specifies raw SQL that should be injected into the having expression. Arguments: string {string} -- The raw query string. Returns: self """ pass def increment(column: str, value: int = 1): """Increments a column's value. Arguments: column {string} -- The name of the column. Keyword Arguments: value {int} -- The value to increment by. (default: {1}) Returns: self """ pass def in_random_order(): """Puts Query results in random order""" pass def join_on(relationship: str, callback: callable = None, clause: str = ["inner"]): pass def join( self, table: str, column1: str = None, equality: str = None, column2: str = None, clause: str = "inner", ): """Specifies a join expression. Arguments: table {string} -- The name of the table or an instance of JoinClause. column1 {string} -- The name of the foreign table. equality {string} -- The equality to join on. column2 {string} -- The name of the local column. Keyword Arguments: clause {string} -- The action clause. 
(default: {"inner"}) Returns: self """ pass def joins(*relationships: list[str], clause: str = "inner"): pass def left_join( table: str, column1: str = None, equality: str = None, column2: str = None ): """A helper method to add a left join expression. Arguments: table {string} -- The name of the table to join on. column1 {string} -- The name of the foreign table. equality {string} -- The equality to join on. column2 {string} -- The name of the local column. Returns: self """ pass def limit(amount: int): """Specifies a limit expression. Arguments: amount {int} -- The number of rows to limit. Returns: self """ pass def lock_for_update(): pass def make_lock(lock: bool): pass def max(column: str): """Aggregates a columns values. Arguments: column {string} -- The name of the column to aggregate. Returns: self """ pass def min(column: str): """Aggregates a columns values. Arguments: column {string} -- The name of the column to aggregate. Returns: self """ pass def new_from_builder(from_builder: QueryBuilder = None): """Creates a new QueryBuilder class. Returns: QueryBuilder -- The ORM QueryBuilder class. """ pass def new(): """Creates a new QueryBuilder class. Returns: QueryBuilder -- The ORM QueryBuilder class. """ pass def not_between(column: str, low: str | int, high: str | int): """Specifies a where not between expression. Arguments: column {string} -- The name of the column. low {string} -- The value on the low end. high {string} -- The value on the high end. Returns: self """ pass def offset(amount: int): """Specifies an offset expression. Arguments: amount {int} -- The number of rows to limit. Returns: self """ pass def on(connection: str): pass def or_where(column: str | int, *args) -> QueryBuilder: """Specifies an or where query expression. Arguments: column {[type]} -- [description] value {[type]} -- [description] Returns: [type] -- [description] """ pass def or_where_null(column: str): """Specifies a where expression where the column is NULL. 
Arguments: column {string} -- The name of the column. Returns: self """ pass def order_by_raw(query: str, bindings: list = []): """Specifies a column to order by. Arguments: column {string} -- The name of the column. Keyword Arguments: direction {string} -- Specify either ASC or DESC order. (default: {"ASC"}) Returns: self """ pass def order_by(column: str, direction: str = "ASC|DESC"): """Specifies a column to order by. Arguments: column {string} -- The name of the column. Keyword Arguments: direction {string} -- Specify either ASC or DESC order. (default: {"ASC"}) Returns: self """ pass def paginate(per_page: int, page: int = 1): pass def right_join( table: str, column1: str = None, equality: str = None, column2: str = None ): """A helper method to add a right join expression. Arguments: table {string} -- The name of the table to join on. column1 {string} -- The name of the foreign table. equality {string} -- The equality to join on. column2 {string} -- The name of the local column. Returns: self """ pass def select_raw(query: str): """Specifies raw SQL that should be injected into the select expression. Returns: self """ pass def select(*args: str): """Specifies columns that should be selected Returns: self """ pass def set_global_scope( self, name: str = "", callable: callable = None, action: str = ["select", "update", "create", "delete"], ): """Sets the global scopes that should be used before creating the SQL. Arguments: cls {masoniteorm.Model} -- An ORM model class. name {string} -- The name of the global scope. Returns: self """ pass def shared_lock(): pass def simple_paginate(per_page: int, page: int = 1): pass def skip(*args, **kwargs): """Alias for limit method.""" pass def statement(query: str, bindings: list = ()): pass def sum(column: str): """Aggregates a columns values. Arguments: column {string} -- The name of the column to aggregate. 
Returns: self """ pass def table_raw(query: str): """Sets a query as the table Arguments: query {string} -- The query to use for the table Returns: self """ pass def take(*args, **kwargs): """Alias for limit method""" pass def to_qmark() -> str: """Compiles the QueryBuilder class into a Qmark SQL statement. Returns: self """ pass def to_sql() -> str: """Compiles the QueryBuilder class into a SQL statement. Returns: self """ pass def truncate(foreign_keys: bool = False): pass def update( updates: dict, dry: bool = False, force: bool = False, cast: bool = False ): """Specifies columns and values to be updated. Arguments: updates {dictionary} -- A dictionary of columns and values to update. dry {bool, optional} -- Whether a query should actually run force {bool, optional} -- Force the update even if there are no changes cast {bool, optional} -- Run all values through model's casters Returns: self """ pass def when(conditional: bool, callback: callable): pass def where_between(*args, **kwargs): """Alias for between""" pass def where_column(column1: str, column2: str): """Specifies where two columns equal eachother. Arguments: column1 {string} -- The name of the column. column2 {string} -- The name of the column. Returns: self """ pass def take(*args: Any, **kwargs: Any): """Alias for limit method""" pass def where_column(column1: str, column2: str): """Specifies where two columns equal eachother. Arguments: column1 {string} -- The name of the column. column2 {string} -- The name of the column. Returns: self """ pass def where_date(column: str, date: Any): """Specifies a where DATE expression Arguments: column {string} -- The name of the column. Returns: self """ pass def or_where_date(column: str, date: Any): """Specifies a where DATE expression Arguments: column {string} -- The name of the column. date {string|datetime|pendulum} -- The name of the column. Returns: self """ pass def where_exists(value: Any): """Specifies a where exists expression. 
Arguments: value {string|int|QueryBuilder} -- A value to check for the existence of a query expression. Returns: self """ pass def where_from_builder(builder: QueryBuilder): """Specifies a where expression. Arguments: column {string} -- The name of the column to search Keyword Arguments: args {List} -- The operator and the value of the column to search. (default: {None}) Returns: self """ pass def where_has(relationship: str, callback: Any): pass def where_in(column: str, wheres: list = []): """Specifies where a column contains a list of a values. Arguments: column {string} -- The name of the column. Keyword Arguments: wheres {list} -- A list of values (default: {[]}) Returns: self """ pass def where_like(column: str, value: str): """Specifies a where LIKE expression. Arguments: column {string} -- The name of the column to search value {string} -- The value of the column to match Returns: self """ pass def where_not_between(*args: Any, **kwargs: Any): """Alias for not_between""" pass def where_not_in(column: str, wheres: list = []): """Specifies where a column does not contain a list of a values. Arguments: column {string} -- The name of the column. Keyword Arguments: wheres {list} -- A list of values (default: {[]}) Returns: self """ pass def where_not_like(column: str, value: str): """Specifies a where expression. Arguments: column {string} -- The name of the column to search value {string} -- The value of the column to match Returns: self """ pass def where_not_null(column: str): """Specifies a where expression where the column is not NULL. Arguments: column {string} -- The name of the column. Returns: self """ pass def where_null(column: str): """Specifies a where expression where the column is NULL. Arguments: column {string} -- The name of the column. Returns: self """ pass def where_raw(query: str, bindings: list = []): """Specifies raw SQL that should be injected into the where expression. Arguments: query {string} -- The raw query string. 
Keyword Arguments: bindings {tuple} -- query bindings that should be added to the connection. (default: {()}) Returns: self """ pass def without_global_scopes(): pass def where(column: str, *args: Any): """Specifies a where expression. Arguments: column {string} -- The name of the column to search Keyword Arguments: args {List} -- The operator and the value of the column to search. (default: {None}) Returns: self """ pass def with_(*eagers: str): pass def with_count(relationship: str, callback: Any = None): pass ================================================ FILE: src/masoniteorm/models/Pivot.py ================================================ from .Model import Model class Pivot(Model): __primary_key__ = "id" ================================================ FILE: src/masoniteorm/models/__init__.py ================================================ from .Model import Model ================================================ FILE: src/masoniteorm/observers/ObservesEvents.py ================================================ class ObservesEvents: def observe_events(self, model, event): if model.__has_events__ == True: for observer in model.__observers__.get(model.__class__, []): try: getattr(observer, event)(model) except AttributeError: pass @classmethod def observe(cls, observer): if cls in cls.__observers__: cls.__observers__[cls].append(observer) else: cls.__observers__.update({cls: [observer]}) @classmethod def without_events(cls): """Sets __has_events__ attribute on model to false.""" cls.__has_events__ = False return cls @classmethod def with_events(cls): """Sets __has_events__ attribute on model to True.""" cls.__has_events__ = True return cls ================================================ FILE: src/masoniteorm/observers/__init__.py ================================================ from .ObservesEvents import ObservesEvents ================================================ FILE: src/masoniteorm/pagination/BasePaginator.py 
================================================ import json class BasePaginator: def __iter__(self): for result in self.result: yield result def to_json(self): return json.dumps(self.serialize()) ================================================ FILE: src/masoniteorm/pagination/LengthAwarePaginator.py ================================================ import math from .BasePaginator import BasePaginator class LengthAwarePaginator(BasePaginator): def __init__(self, result, per_page, current_page, total, url=None): self.result = result self.current_page = current_page self.per_page = per_page self.count = len(self.result) self.last_page = int(math.ceil(total / per_page)) self.next_page = (int(self.current_page) + 1) if self.has_more_pages() else None self.previous_page = (int(self.current_page) - 1) or None self.total = total self.url = url def serialize(self, *args, **kwargs): return { "data": self.result.serialize(*args, **kwargs), "meta": { "total": self.total, "next_page": self.next_page, "count": self.count, "previous_page": self.previous_page, "last_page": self.last_page, "current_page": self.current_page, }, } def has_more_pages(self): return self.current_page < self.last_page ================================================ FILE: src/masoniteorm/pagination/SimplePaginator.py ================================================ from .BasePaginator import BasePaginator class SimplePaginator(BasePaginator): def __init__(self, result, per_page, current_page, url=None): self.result = result self.current_page = current_page self.per_page = per_page self.count = len(self.result) self.next_page = (int(self.current_page) + 1) if self.has_more_pages() else None self.previous_page = (int(self.current_page) - 1) or None self.url = url def serialize(self, *args, **kwargs): return { "data": self.result.serialize(*args, **kwargs), "meta": { "next_page": self.next_page, "count": self.count, "previous_page": self.previous_page, "current_page": self.current_page, }, } def 
has_more_pages(self): return len(self.result) > self.per_page ================================================ FILE: src/masoniteorm/pagination/__init__.py ================================================ from .LengthAwarePaginator import LengthAwarePaginator from .SimplePaginator import SimplePaginator ================================================ FILE: src/masoniteorm/providers/ORMProvider.py ================================================ from masonite.providers import Provider from masoniteorm.commands import ( MigrateCommand, MigrateRollbackCommand, MigrateRefreshCommand, MigrateResetCommand, MakeModelCommand, MakeObserverCommand, MigrateStatusCommand, MakeMigrationCommand, MakeSeedCommand, SeedRunCommand, ) class ORMProvider(Provider): """Masonite ORM database provider to be used inside Masonite based projects.""" def __init__(self, application): self.application = application def register(self): self.application.make("commands").add( MakeMigrationCommand(), MakeSeedCommand(), MakeObserverCommand(), MigrateCommand(), MigrateResetCommand(), MakeModelCommand(), MigrateStatusCommand(), MigrateRefreshCommand(), MigrateRollbackCommand(), SeedRunCommand(), ), def boot(self): pass ================================================ FILE: src/masoniteorm/providers/__init__.py ================================================ from .ORMProvider import ORMProvider ================================================ FILE: src/masoniteorm/query/EagerRelation.py ================================================ class EagerRelations: def __init__(self, relation=None): self.eagers = [] self.nested_eagers = {} self.callback_eagers = {} self.is_nested = False self.relation = relation def register(self, *relations, callback=None): for relation in relations: if isinstance(relation, str) and "." not in relation: self.eagers += [relation] elif isinstance(relation, str) and "." 
in relation: self.is_nested = True relation_key = relation.split(".")[0] if relation_key not in self.nested_eagers: self.nested_eagers = {relation_key: relation.split(".")[1:]} else: self.nested_eagers[relation_key] += relation.split(".")[1:] elif isinstance(relation, (tuple, list)): for eagers in relations: for eager in eagers: self.register(eager) elif isinstance(relation, dict): self.callback_eagers.update(relation) return self def get_eagers(self): eagers = [] if self.eagers: eagers.append(self.eagers) if self.nested_eagers: eagers.append(self.nested_eagers) if self.callback_eagers: eagers.append(self.callback_eagers) return eagers ================================================ FILE: src/masoniteorm/query/QueryBuilder.py ================================================ import inspect from copy import deepcopy from datetime import datetime from typing import Any, Callable, Dict, List, Optional from ..collection.Collection import Collection from ..config import load_config from ..exceptions import ( HTTP404, ConnectionNotRegistered, InvalidArgument, ModelNotFound, MultipleRecordsFound, ) from ..expressions.expressions import ( AggregateExpression, BetweenExpression, FromTable, GroupByExpression, HavingExpression, JoinClause, OrderByExpression, QueryExpression, SelectExpression, SubGroupExpression, SubSelectExpression, UpdateQueryExpression, ) from ..observers import ObservesEvents from ..pagination import LengthAwarePaginator, SimplePaginator from ..schema import Schema from ..scopes import BaseScope from .EagerRelation import EagerRelations class QueryBuilder(ObservesEvents): """A builder class to manage the building and creation of query expressions.""" def __init__( self, grammar=None, connection="default", connection_class=None, table=None, connection_details=None, connection_driver="default", model=None, scopes=None, schema=None, dry=False, config_path=None, ): """QueryBuilder initializer Arguments: grammar {masoniteorm.grammar.Grammar} -- A grammar class. 
Keyword Arguments: connection {masoniteorm.connection.Connection} -- A connection class (default: {None}) table {str} -- the name of the table (default: {""}) """ self.config_path = config_path self.grammar = grammar self.table(table) self.dry = dry self._creates_related = {} self.connection = connection self.connection_class = connection_class self._connection = None self._connection_details = connection_details or {} self._connection_driver = connection_driver self._scopes = scopes or {} self.lock = False self._schema = schema self._eager_relation = EagerRelations() if model: self._global_scopes = model._global_scopes if model.__with__: self.with_(model.__with__) else: self._global_scopes = {} self.builder = self self._columns = () self._creates = {} self._sql = "" self._bindings = () self._updates = () self._wheres = () self._order_by = () self._group_by = () self._joins = () self._having = () self._macros = {} self._aggregates = () self._limit = False self._offset = False self._distinct = False self._model = model self.set_action("select") if not self._connection_details: DB = load_config(config_path=self.config_path).DB self._connection_details = DB.get_connection_details() self.on(connection) if grammar: self.grammar = grammar if connection_class: self.connection_class = connection_class def _set_creates_related(self, fields: dict): self._creates_related = fields return self def set_schema(self, schema): self._schema = schema return self def shared_lock(self): return self.make_lock("share") def lock_for_update(self): return self.make_lock("update") def make_lock(self, lock): self.lock = lock return self def reset(self): """Resets the query builder instance so you can make multiple calls with the same builder instance""" self.set_action("select") self._updates = () self._wheres = () self._order_by = () self._group_by = () self._joins = () self._having = () return self def get_connection_information(self): return { "host": 
self._connection_details.get(self.connection, {}).get("host"), "database": self._connection_details.get(self.connection, {}).get( "database" ), "user": self._connection_details.get(self.connection, {}).get("user"), "port": self._connection_details.get(self.connection, {}).get("port"), "password": self._connection_details.get(self.connection, {}).get( "password" ), "prefix": self._connection_details.get(self.connection, {}).get("prefix"), "options": self._connection_details.get(self.connection, {}).get( "options", {} ), "full_details": self._connection_details.get(self.connection, {}), } def table(self, table, raw=False): """Sets a table on the query builder Arguments: table {string} -- The name of the table Returns: self """ if table: self._table = FromTable(table, raw=raw) else: self._table = table return self def from_(self, table): """Alias for the table method Arguments: table {string} -- The name of the table Returns: self """ return self.table(table) def from_raw(self, table): """Alias for the table method Arguments: table {string} -- The name of the table Returns: self """ return self.table(table, raw=True) def table_raw(self, query): """Sets a query on the query builder Arguments: query {string} -- The query to use for the table Returns: self """ return self.from_raw(query) def get_table_name(self): """Sets a table on the query builder Arguments: table {string} -- The name of the table Returns: self """ return self._table.name def get_connection(self): """Sets a table on the query builder Arguments: table {string} -- The name of the table Returns: self """ return self.connection_class def begin(self): """Sets a table on the query builder Arguments: table {string} -- The name of the table Returns: self """ return self.new_connection().begin() def begin_transaction(self, *args, **kwargs): return self.begin(*args, **kwargs) def get_schema_builder(self): return Schema(connection=self.connection_class, grammar=self.grammar) def commit(self): """Sets a table on 
the query builder Arguments: table {string} -- The name of the table Returns: self """ return self._connection.commit() def rollback(self): """Sets a table on the query builder Arguments: table {string} -- The name of the table Returns: self """ self._connection.rollback() return self def get_relation(self, key): """Sets a table on the query builder Arguments: table {string} -- The name of the table Returns: self """ return getattr(self.owner, key) def set_scope(self, name, callable): """Sets a scope based on a class and maps it to a name. Arguments: cls {masoniteorm.Model} -- An ORM model class. name {string} -- The name of the scope to use. Returns: self """ # setattr(self, name, callable) self._scopes.update({name: callable}) return self def set_global_scope(self, name="", callable=None, action="select"): """Sets the global scopes that should be used before creating the SQL. Arguments: cls {masoniteorm.Model} -- An ORM model class. name {string} -- The name of the global scope. Returns: self """ if isinstance(name, BaseScope): name.on_boot(self) return self if action not in self._global_scopes: self._global_scopes[action] = {} self._global_scopes[action].update({name: callable}) return self def without_global_scopes(self): self._global_scopes = {} return self def remove_global_scope(self, scope, action=None): """Sets the global scopes that should be used before creating the SQL. Arguments: cls {masoniteorm.Model} -- An ORM model class. name {string} -- The name of the global scope. Returns: self """ if isinstance(scope, BaseScope): scope.on_remove(self) return self del self._global_scopes.get(action, {})[scope] return self def __getattr__(self, attribute): """Magic method for fetching query scopes. This method is only used when a method or attribute does not already exist. Arguments: attribute {string} -- The attribute to fetch. Raises: AttributeError: Raised when there is no attribute or scope on the builder class. 
Returns: self """ if attribute == "__setstate__": raise AttributeError( "'QueryBuilder' object has no attribute '{}'".format(attribute) ) if attribute in self._scopes: def method(*args, **kwargs): return self._scopes[attribute](self._model, self, *args, **kwargs) return method if attribute in self._macros: def method(*args, **kwargs): return self._macros[attribute](self._model, self, *args, **kwargs) return method raise AttributeError( "'QueryBuilder' object has no attribute '{}'".format(attribute) ) def on(self, connection): DB = load_config(self.config_path).DB if connection == "default": self.connection = self._connection_details.get("default") else: self.connection = connection if self.connection not in self._connection_details: raise ConnectionNotRegistered( f"Could not find the '{self.connection}' connection details" ) self._connection_driver = self._connection_details.get(self.connection).get( "driver" ) self.connection_class = DB.connection_factory.make(self._connection_driver) self.grammar = self.connection_class.get_default_query_grammar() return self def select(self, *args): """Specifies columns that should be selected Returns: self """ for arg in args: if isinstance(arg, list): for column in arg: self._columns += (SelectExpression(column),) else: for column in arg.split(","): self._columns += (SelectExpression(column),) return self def distinct(self, boolean=True): """Specifies that all columns should be distinct Returns: self """ self._distinct = boolean return self def add_select(self, alias, callable): """Specifies columns that should be selected Returns: self """ builder = callable(self.new()) self._columns += (SubGroupExpression(builder, alias=alias),) return self def statement(self, query, bindings=None): if bindings is None: bindings = [] result = self.new_connection().query(query, bindings) return self.prepare_result(result) def select_raw(self, query): """Specifies raw SQL that should be injected into the select expression. 
Returns: self """ self._columns += (SelectExpression(query, raw=True),) return self def get_processor(self): return self.connection_class.get_default_post_processor()() def bulk_create( self, creates: List[Dict[str, Any]], query: bool = False, cast: bool = False ): self.set_action("bulk_create") model = None if self._model: model = self._model self._creates = [] for unsorted_create in creates: if model: unsorted_create = model.filter_mass_assignment(unsorted_create) if cast: unsorted_create = model.cast_values(unsorted_create) # sort the dicts by key so the values inserted align with the correct column self._creates.append(dict(sorted(unsorted_create.items()))) if query: return self if model: model = model.hydrate(self._creates) if not self.dry: connection = self.new_connection() query_result = connection.query(self.to_qmark(), self._bindings, results=1) processed_results = query_result or self._creates else: processed_results = self._creates if model: return model return processed_results def create( self, creates: Optional[Dict[str, Any]] = None, query: bool = False, id_key: str = "id", cast: bool = False, ignore_mass_assignment: bool = False, **kwargs, ): """Specifies a dictionary that should be used to create new values. Arguments: creates {dict} -- A dictionary of columns and values. 
Returns: self """ self.set_action("insert") model = None self._creates = creates if creates else kwargs if self._model: model = self._model # Update values with related record's self._creates.update(self._creates_related) # Filter __fillable/__guarded__ fields if not ignore_mass_assignment: self._creates = model.filter_mass_assignment(self._creates) # Cast values if necessary if cast: self._creates = model.cast_values(self._creates) if query: return self if model: model = model.hydrate(self._creates) self.observe_events(model, "creating") # if attributes were modified during model observer then we need to update the creates here self._creates.update(model.get_dirty_attributes()) if not self.dry: connection = self.new_connection() query_result = connection.query(self.to_qmark(), self._bindings, results=1) if model: id_key = model.get_primary_key() processed_results = self.get_processor().process_insert_get_id( self, query_result or self._creates, id_key ) else: processed_results = self._creates if model: model = model.fill(processed_results) self.observe_events(model, "created") return model return processed_results def hydrate(self, result, relations=None): return self._model.hydrate(result, relations) def delete(self, column=None, value=None, query=False): """Specify the column and value to delete or deletes everything based on a previously used where expression. 
Keyword Arguments: column {string} -- The name of the column (default: {None}) value {string|int} -- The value of the column (default: {None}) Returns: self """ model = None self.set_action("delete") if self._model: model = self._model if column and value: if isinstance(value, (list, tuple)): self.where_in(column, value) else: self.where(column, value) if query: return self if model and model.is_loaded(): self.where(model.get_primary_key(), model.get_primary_key_value()) self.observe_events(model, "deleting") result = self.new_connection().query(self.to_qmark(), self._bindings) if model: self.observe_events(model, "deleted") return result def where(self, column, *args): """Specifies a where expression. Arguments: column {string} -- The name of the column to search Keyword Arguments: args {List} -- The operator and the value of the column to search. (default: {None}) Returns: self """ operator, value = self._extract_operator_value(*args) if inspect.isfunction(column): builder = column(self.new()) self._wheres += ( (QueryExpression(None, operator, SubGroupExpression(builder))), ) elif isinstance(column, dict): for key, value in column.items(): self._wheres += ((QueryExpression(key, "=", value, "value")),) elif isinstance(value, QueryBuilder): self._wheres += ( (QueryExpression(column, operator, SubSelectExpression(value))), ) else: self._wheres += ((QueryExpression(column, operator, value, "value")),) return self def where_from_builder(self, builder): """Specifies a where expression. Arguments: column {string} -- The name of the column to search Keyword Arguments: args {List} -- The operator and the value of the column to search. (default: {None}) Returns: self """ self._wheres += ((QueryExpression(None, "=", SubGroupExpression(builder))),) return self def where_like(self, column, value): """Specifies a where expression. Arguments: column {string} -- The name of the column to search Keyword Arguments: args {List} -- The operator and the value of the column to search. 
(default: {None}) Returns: self """ return self.where(column, "like", value) def where_not_like(self, column, value): """Specifies a where expression. Arguments: column {string} -- The name of the column to search Keyword Arguments: args {List} -- The operator and the value of the column to search. (default: {None}) Returns: self """ return self.where(column, "not like", value) def where_raw(self, query: str, bindings=()): """Specifies raw SQL that should be injected into the where expression. Arguments: query {string} -- The raw query string. Keyword Arguments: bindings {tuple} -- query bindings that should be added to the connection. (default: {()}) Returns: self """ self._wheres += ( (QueryExpression(query, "=", None, "value", raw=True, bindings=bindings)), ) return self def or_where(self, column, *args): """Specifies an or where query expression. Arguments: column {[type]} -- [description] value {[type]} -- [description] Returns: [type] -- [description] """ operator, value = self._extract_operator_value(*args) if inspect.isfunction(column): builder = column(self.new()) self._wheres += ( ( QueryExpression( None, operator, SubGroupExpression(builder), keyword="or" ) ), ) elif isinstance(value, QueryBuilder): self._wheres += ( (QueryExpression(column, operator, SubSelectExpression(value))), ) else: self._wheres += ( (QueryExpression(column, operator, value, "value", keyword="or")), ) return self def where_exists(self, value: "str|int|QueryBuilder"): """Specifies a where exists expression. Arguments: value {string|int|QueryBuilder} -- A value to check for the existence of a query expression. 
Returns: self """ if inspect.isfunction(value): self._wheres += ( ( QueryExpression( None, "EXISTS", SubSelectExpression(value(self.new())) ) ), ) elif isinstance(value, QueryBuilder): self._wheres += ( (QueryExpression(None, "EXISTS", SubSelectExpression(value))), ) else: self._wheres += ((QueryExpression(None, "EXISTS", value, "value")),) return self def or_where_exists(self, value: "str|int|QueryBuilder"): """Specifies a where exists expression. Arguments: value {string|int|QueryBuilder} -- A value to check for the existence of a query expression. Returns: self """ if inspect.isfunction(value): self._wheres += ( ( QueryExpression( None, "EXISTS", SubSelectExpression(value(self.new())), keyword="or", ) ), ) elif isinstance(value, QueryBuilder): self._wheres += ( ( QueryExpression( None, "EXISTS", SubSelectExpression(value), keyword="or" ) ), ) else: self._wheres += ( (QueryExpression(None, "EXISTS", value, "value", keyword="or")), ) return self def where_not_exists(self, value: "str|int|QueryBuilder"): """Specifies a where exists expression. Arguments: value {string|int|QueryBuilder} -- A value to check for the existence of a query expression. Returns: self """ if inspect.isfunction(value): self._wheres += ( ( QueryExpression( None, "NOT EXISTS", SubSelectExpression(value(self.new())) ) ), ) elif isinstance(value, QueryBuilder): self._wheres += ( (QueryExpression(None, "NOT EXISTS", SubSelectExpression(value))), ) else: self._wheres += ((QueryExpression(None, "NOT EXISTS", value, "value")),) return self def or_where_not_exists(self, value: "str|int|QueryBuilder"): """Specifies a where exists expression. Arguments: value {string|int|QueryBuilder} -- A value to check for the existence of a query expression. 
Returns: self """ if inspect.isfunction(value): self._wheres += ( ( QueryExpression( None, "NOT EXISTS", SubSelectExpression(value(self.new())), keyword="or", ) ), ) elif isinstance(value, QueryBuilder): self._wheres += ( ( QueryExpression( None, "NOT EXISTS", SubSelectExpression(value), keyword="or" ) ), ) else: self._wheres += ( (QueryExpression(None, "NOT EXISTS", value, "value", keyword="or")), ) return self def having(self, column, equality="", value=""): """Specifying a having expression. Arguments: column {string} -- The name of the column. Keyword Arguments: equality {string} -- An equality operator (default: {"="}) value {string} -- The value of the having expression (default: {""}) Returns: self """ self._having += ((HavingExpression(column, equality, value)),) return self def having_raw(self, string): """Specifies raw SQL that should be injected into the having expression. Arguments: string {string} -- The raw query string. Returns: self """ self._having += ((HavingExpression(string, raw=True)),) return self def where_null(self, column): """Specifies a where expression where the column is NULL. Arguments: column {string} -- The name of the column. Returns: self """ self._wheres += ((QueryExpression(column, "=", None, "NULL")),) return self def or_where_null(self, column): """Specifies a where expression where the column is NULL. Arguments: column {string} -- The name of the column. Returns: self """ self._wheres += ((QueryExpression(column, "=", None, "NULL", keyword="or")),) return self def chunk(self, chunk_amount): chunk_connection = self.new_connection() for result in chunk_connection.select_many(self.to_sql(), (), chunk_amount): yield self.prepare_result(result) def where_not_null(self, column: str): """Specifies a where expression where the column is not NULL. Arguments: column {string} -- The name of the column. 
Returns: self """ self._wheres += ((QueryExpression(column, "=", True, "NOT NULL")),) return self def _get_date_string(self, date): if isinstance(date, str): return date elif hasattr(date, "to_date_string"): return date.to_date_string() elif hasattr(date, "strftime"): return date.strftime("%m-%d-%Y") def where_date(self, column: str, date: "str|datetime"): """Specifies a where DATE expression Arguments: column {string} -- The name of the column. Returns: self """ self._wheres += ( (QueryExpression(column, "=", self._get_date_string(date), "DATE")), ) return self def or_where_date(self, column: str, date: "str|datetime"): """Specifies a where DATE expression Arguments: column {string} -- The name of the column. date {string|datetime|pendulum} -- The name of the column. Returns: self """ self._wheres += ( ( QueryExpression( column, "=", self._get_date_string(date), "DATE", keyword="or" ) ), ) return self def between(self, column: str, low: int, high: int): """Specifies a where between expression. Arguments: column {string} -- The name of the column. low {string} -- The value on the low end. high {string} -- The value on the high end. Returns: self """ self._wheres += (BetweenExpression(column, low, high),) return self def where_between(self, *args, **kwargs): return self.between(*args, **kwargs) def where_not_between(self, *args, **kwargs): return self.not_between(*args, **kwargs) def not_between(self, column: str, low: str, high: str): """Specifies a where not between expression. Arguments: column {string} -- The name of the column. low {string} -- The value on the low end. high {string} -- The value on the high end. Returns: self """ self._wheres += (BetweenExpression(column, low, high, equality="NOT BETWEEN"),) return self def where_in(self, column, wheres=None): """Specifies where a column contains a list of a values. Arguments: column {string} -- The name of the column. 
Keyword Arguments: wheres {list} -- A list of values (default: {[]}) Returns: self """ wheres = wheres or [] if not wheres: self._wheres += ((QueryExpression(0, "=", 1, "value_equals")),) elif isinstance(wheres, QueryBuilder): self._wheres += ( (QueryExpression(column, "IN", SubSelectExpression(wheres))), ) elif callable(wheres): self._wheres += ( ( QueryExpression( column, "IN", SubSelectExpression(wheres(self.new())) ) ), ) else: self._wheres += ((QueryExpression(column, "IN", list(wheres))),) return self def get_relation(self, relationship, builder=None): if not builder: builder = self if not builder._model: raise AttributeError( "You must specify a model in order to use relationship methods" ) return getattr(builder._model, relationship) def has(self, *relationships): if not self._model: raise AttributeError( "You must specify a model in order to use 'has' relationship methods" ) for relationship in relationships: if "." in relationship: last_builder = self._model.builder for split_relationship in relationship.split("."): related = last_builder.get_relation(split_relationship) last_builder = related.query_has(last_builder) else: related = getattr(self._model, relationship) related.query_has(self) return self def or_has(self, *relationships): if not self._model: raise AttributeError( "You must specify a model in order to use 'has' relationship methods" ) for relationship in relationships: if "." 
in relationship: last_builder = self._model.builder split_count = len(relationship.split(".")) for index, split_relationship in enumerate(relationship.split(".")): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: last_builder = related.query_has( last_builder, method="where_exists" ) continue last_builder = related.query_has( last_builder, method="or_where_exists" ) else: related = getattr(self._model, relationship) related.query_has(self, method="or_where_exists") return self def doesnt_have(self, *relationships): if not self._model: raise AttributeError( "You must specify a model in order to use the 'doesnt_have' relationship methods" ) for relationship in relationships: if "." in relationship: last_builder = self._model.builder split_count = len(relationship.split(".")) for index, split_relationship in enumerate(relationship.split(".")): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: last_builder = related.query_has( last_builder, method="where_exists" ) continue last_builder = related.query_has( last_builder, method="where_not_exists" ) else: related = getattr(self._model, relationship) related.query_has(self, method="where_not_exists") return self def or_doesnt_have(self, *relationships): if not self._model: raise AttributeError( "You must specify a model in order to use the 'doesnt_have' relationship methods" ) for relationship in relationships: if "." 
in relationship: last_builder = self._model.builder split_count = len(relationship.split(".")) for index, split_relationship in enumerate(relationship.split(".")): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: last_builder = related.query_has( last_builder, method="where_exists" ) continue last_builder = related.query_has( last_builder, method="or_where_not_exists" ) else: related = getattr(self._model, relationship) related.query_has(self, method="or_where_not_exists") return self def where_has(self, relationship, callback): if not self._model: raise AttributeError( "You must specify a model in order to use 'has' relationship methods" ) if "." in relationship: last_builder = self._model.builder splits = relationship.split(".") split_count = len(splits) for index, split_relationship in enumerate(splits): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: last_builder = related.query_where_exists( last_builder, callback, method="where_exists" ) continue last_builder = related.query_has(last_builder, method="where_exists") else: related = getattr(self._model, relationship) related.query_where_exists(self, callback, method="where_exists") return self def or_where_has(self, relationship, callback): if not self._model: raise AttributeError( "You must specify a model in order to use 'has' relationship methods" ) if "." 
in relationship: last_builder = self._model.builder splits = relationship.split(".") split_count = len(splits) for index, split_relationship in enumerate(splits): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: last_builder = related.query_where_exists( last_builder, callback, method="where_exists" ) continue last_builder = related.query_has(last_builder, method="or_where_exists") else: related = getattr(self._model, relationship) related.query_where_exists(self, callback, method="or_where_exists") return self def where_doesnt_have(self, relationship, callback): if not self._model: raise AttributeError( "You must specify a model in order to use the 'doesnt_have' relationship methods" ) if "." in relationship: last_builder = self._model.builder split_count = len(relationship.split(".")) for index, split_relationship in enumerate(relationship.split(".")): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: last_builder = getattr( last_builder._model, split_relationship ).query_where_exists(self, callback, method="where_exists") continue last_builder = related.query_has( last_builder, method="where_not_exists" ) else: related = getattr(self._model, relationship) related.query_where_exists(self, callback, method="where_not_exists") return self def or_where_doesnt_have(self, relationship, callback): if not self._model: raise AttributeError( "You must specify a model in order to use the 'doesnt_have' relationship methods" ) if "." 
in relationship: last_builder = self._model.builder split_count = len(relationship.split(".")) for index, split_relationship in enumerate(relationship.split(".")): related = last_builder.get_relation(split_relationship) if index + 1 == split_count: last_builder = getattr( last_builder._model, split_relationship ).query_where_exists(self, callback, method="where_exists") continue last_builder = related.query_has( last_builder, method="or_where_not_exists" ) else: related = getattr(self._model, relationship) related.query_where_exists(self, callback, method="or_where_not_exists") return self def with_count(self, relationship, callback=None): self.select(*self._model.get_selects()) return getattr(self._model, relationship).get_with_count_query( self, callback=callback ) def where_not_in(self, column, wheres=None): """Specifies where a column does not contain a list of a values. Arguments: column {string} -- The name of the column. Keyword Arguments: wheres {list} -- A list of values (default: {[]}) Returns: self """ wheres = wheres or [] if isinstance(wheres, QueryBuilder): self._wheres += ( (QueryExpression(column, "NOT IN", SubSelectExpression(wheres))), ) else: self._wheres += ((QueryExpression(column, "NOT IN", list(wheres))),) return self def join( self, table: str, column1=None, equality=None, column2=None, clause="inner" ): """Specifies a join expression. Arguments: table {string} -- The name of the table or an instance of JoinClause. column1 {string} -- The name of the foreign table. equality {string} -- The equality to join on. column2 {string} -- The name of the local column. Keyword Arguments: clause {string} -- The action clause. 
(default: {"inner"})

        Returns:
            self
        """
        if inspect.isfunction(column1):
            # Callback form: the function configures the JoinClause and its
            # return value is registered.
            self._joins += (column1(JoinClause(table, clause=clause)),)
        elif isinstance(table, str):
            self._joins += (
                JoinClause(table, clause=clause).on(column1, equality, column2),
            )
        else:
            # Otherwise ``table`` is registered as-is (expected to be a JoinClause).
            self._joins += (table,)
        return self

    def left_join(self, table, column1=None, equality=None, column2=None):
        """A helper method to add a left join expression.

        Arguments:
            table {string} -- The name of the table to join on.
            column1 {string} -- The name of the foreign table.
            equality {string} -- The equality to join on.
            column2 {string} -- The name of the local column.

        Returns:
            self
        """
        return self.join(
            table=table,
            column1=column1,
            equality=equality,
            column2=column2,
            clause="left",
        )

    def right_join(self, table, column1=None, equality=None, column2=None):
        """A helper method to add a right join expression.

        Arguments:
            table {string} -- The name of the table to join on.
            column1 {string} -- The name of the foreign table.
            equality {string} -- The equality to join on.
            column2 {string} -- The name of the local column.

        Returns:
            self
        """
        return self.join(
            table=table,
            column1=column1,
            equality=equality,
            column2=column2,
            clause="right",
        )

    def joins(self, *relationships, clause="inner"):
        """Adds a join for each named relationship via the relationship's ``joins``.

        Returns:
            self
        """
        for relationship in relationships:
            getattr(self._model, relationship).joins(self, clause=clause)
        return self

    def join_on(self, relationship, callback=None, clause="inner"):
        """Joins a relationship and optionally adds wheres built by ``callback``.

        ``callback`` receives a new builder pointed at the related table; its
        return value is passed to ``where_from_builder``.

        Returns:
            self
        """
        relation = getattr(self._model, relationship)
        relation.joins(self, clause=clause)

        if callback:
            new_from_builder = self.new_from_builder()
            new_from_builder.table(relation.get_builder().get_table_name())
            self.where_from_builder(callback(new_from_builder))

        return self

    def where_column(self, column1, column2):
        """Specifies where two columns equal each other.

        Arguments:
            column1 {string} -- The name of the column.
            column2 {string} -- The name of the column.
Returns:
            self
        """
        self._wheres += ((QueryExpression(column1, "=", column2, "column")),)
        return self

    def take(self, *args, **kwargs):
        """Alias for limit method"""
        return self.limit(*args, **kwargs)

    def limit(self, amount):
        """Specifies a limit expression.

        Arguments:
            amount {int} -- The number of rows to limit.

        Returns:
            self
        """
        self._limit = amount
        return self

    def offset(self, amount):
        """Specifies an offset expression.

        Arguments:
            amount {int} -- The number of rows to skip.

        Returns:
            self
        """
        self._offset = amount
        return self

    def skip(self, *args, **kwargs):
        """Alias for offset method"""
        return self.offset(*args, **kwargs)

    def update(
        self,
        updates: Dict[str, Any],
        dry: bool = False,
        force: bool = False,
        cast: bool = False,
        ignore_mass_assignment: bool = False,
    ):
        """Specifies columns and values to be updated.

        Arguments:
            updates {dictionary} -- A dictionary of columns and values to update.
            dry {bool, optional}: Do everything except execute the query against the DB
            force {bool, optional}: Force an update statement to be executed even if nothing was changed
            cast {bool, optional}: Run all values through model's casters
            ignore_mass_assignment {bool, optional}: Whether the update should ignore mass assignment on the model

        Returns:
            self when dry, or when there is no model and nothing to update;
            the model when a model is set; otherwise a dict of the updated
            values.
        """
        model = None
        additional = {}

        if self._model:
            model = self._model
            # Filter __fillable/__guarded__ fields
            if not ignore_mass_assignment:
                updates = model.filter_mass_assignment(updates)

        if model and model.is_loaded():
            # Scope the UPDATE to this model's row.
            self.where(model.get_primary_key(), model.get_primary_key_value())
            additional.update({model.get_primary_key(): model.get_primary_key_value()})
            self.observe_events(model, "updating")

        if model:
            if not model.__force_update__ and not force:
                # Filter updates to only those with changes
                updates = {
                    attr: value
                    for attr, value in updates.items()
                    if (
                        value is None
                        or model.__original_attributes__.get(attr, None) != value
                    )
                }

            # Do not execute query if no changes
            if not updates:
                return self if dry or self.dry else model

            # Cast date fields
            date_fields = model.get_dates()
            for key, value in updates.items():
                if key in date_fields:
                    if value:
                        updates[key] = model.get_new_datetime_string(value)
                    else:
                        updates[key] = value
                # Cast value if necessary
                # NOTE(review): this casts the ORIGINAL `value`, so for a date
                # key with cast=True the datetime-string conversion above is
                # overwritten -- confirm intended.
                if cast:
                    if value:
                        updates[key] = model.cast_value(value)
                    else:
                        updates[key] = value
        elif not updates:
            # Do not perform query if there are no updates
            return self

        self._updates = (UpdateQueryExpression(updates),)
        self.set_action("update")
        if dry or self.dry:
            return self

        additional.update(updates)

        self.new_connection().query(self.to_qmark(), self._bindings)
        if model:
            model.fill(updates)
            self.observe_events(model, "updated")
            model.fill_original(updates)
            return model
        return additional

    def force_update(self, updates: dict, dry=False):
        """Runs ``update`` with force=True, so unchanged values are still written."""
        return self.update(updates, dry=dry, force=True)

    def set_updates(self, updates: dict, dry=False):
        """Specifies columns and values to be updated.

        Arguments:
            updates {dictionary} -- A dictionary of columns and values to update.

        Keyword Arguments:
            dry {bool} -- Accepted for signature compatibility; not used by
                this method. (default: {False})

        Returns:
            self
        """
        self._updates += (UpdateQueryExpression(updates),)
        return self

    def increment(self, column, value=1, dry=False):
        """Increments a column's value.

        Arguments:
            column {string} -- The name of the column.

        Keyword Arguments:
            value {int} -- The value to increment by.
(default: {1})
            dry {bool} -- Return the compiled SQL instead of executing it.
                (default: {False})

        Returns:
            The result of the processor's get_column_value(...), or the
            compiled SQL string when dry.
        """
        model = None
        # NOTE(review): id_key is hard-coded to "id" even when the model's
        # primary key has another name -- confirm against get_column_value.
        id_key = "id"
        id_value = None
        # NOTE(review): `additional` is filled below but never read here.
        additional = {}
        if self._model:
            model = self._model
            id_value = self._model.get_primary_key_value()

        if model and model.is_loaded():
            self.where(model.get_primary_key(), model.get_primary_key_value())
            additional.update({model.get_primary_key(): model.get_primary_key_value()})
            self.observe_events(model, "updating")

        self._updates += (
            UpdateQueryExpression(column, value, update_type="increment"),
        )

        if dry or self.dry:
            return self.get_grammar().compile("update").to_sql()

        self.set_action("update")
        results = self.new_connection().query(self.to_qmark(), self._bindings)
        processed_results = self.get_processor().get_column_value(
            self, column, results, id_key, id_value
        )
        return processed_results

    def decrement(self, column, value=1, dry=False):
        """Decrements a column's value.

        Arguments:
            column {string} -- The name of the column.

        Keyword Arguments:
            value {int} -- The value to decrement by. (default: {1})
            dry {bool} -- Return the compiled SQL instead of executing it.
                (default: {False})

        Returns:
            The result of the processor's get_column_value(...), or the
            compiled SQL string when dry.
        """
        model = None
        # NOTE(review): hard-coded "id" key, as in increment.
        id_key = "id"
        id_value = None
        additional = {}
        if self._model:
            model = self._model
            id_value = self._model.get_primary_key_value()

        if model and model.is_loaded():
            self.where(model.get_primary_key(), model.get_primary_key_value())
            additional.update({model.get_primary_key(): model.get_primary_key_value()})
            self.observe_events(model, "updating")

        self._updates += (
            UpdateQueryExpression(column, value, update_type="decrement"),
        )

        if dry or self.dry:
            return self.get_grammar().compile("update").to_sql()

        self.set_action("update")
        result = self.new_connection().query(self.to_qmark(), self._bindings)
        processed_results = self.get_processor().get_column_value(
            self, column, result, id_key, id_value
        )
        return processed_results

    def sum(self, column):
        """Adds a SUM aggregate for a column.

        Arguments:
            column {string} -- The name of the column to aggregate.
Returns:
            self
        """
        self.aggregate("SUM", "{column}".format(column=column))
        return self

    def count(self, column=None, dry=False):
        """Counts rows, or adds a COUNT aggregate for a given column.

        Arguments:
            column {string} -- The column to count; None counts all rows.

        Keyword Arguments:
            dry {bool} -- Return the builder instead of executing (default: {False})

        Returns:
            int -- The row count when ``column`` is falsy (and not dry).
            self -- When dry, or when a column (including "*") is given.
        """
        alias = "m_count_reserved" if (column == "*" or column is None) else column

        if column == "*":
            self.aggregate("COUNT", f"{column} as {alias}")
        elif column is None:
            self.aggregate("COUNT", f"* as {alias}")
        else:
            self.aggregate("COUNT", f"{column}")

        if dry or self.dry:
            return self

        # NOTE(review): "*" receives the reserved alias above, but since
        # `not "*"` is False it is returned unexecuted -- confirm intended.
        if not column:
            result = self.new_connection().query(
                self.to_qmark(), self._bindings, results=1
            )

            if isinstance(result, dict):
                return result.get(alias, 0)

            prepared_result = list(result.values())

            if not prepared_result:
                return 0

            return prepared_result[0]
        else:
            return self

    def max(self, column):
        """Adds a MAX aggregate for a column.

        Arguments:
            column {string} -- The name of the column to aggregate.

        Returns:
            self
        """
        self.aggregate("MAX", "{column}".format(column=column))
        return self

    def order_by(self, column, direction="ASC"):
        """Specifies a column to order by.

        Arguments:
            column {string} -- The name of the column; a comma-separated list
                adds one ordering per entry.

        Keyword Arguments:
            direction {string} -- Specify either ASC or DESC order. (default: {"ASC"})

        Returns:
            self
        """
        # NOTE(review): entries are not stripped, so "a, b" yields " b";
        # confirm OrderByExpression tolerates surrounding whitespace.
        for col in column.split(","):
            self._order_by += (OrderByExpression(col, direction=direction),)
        return self

    def order_by_raw(self, query, bindings=None):
        """Specifies a raw SQL ordering.

        Arguments:
            query {string} -- The raw ORDER BY expression.

        Keyword Arguments:
            bindings {list} -- Bindings for the raw expression (default: {[]})

        Returns:
            self
        """
        if bindings is None:
            bindings = []

        self._order_by += (OrderByExpression(query, raw=True, bindings=bindings),)
        return self

    def group_by(self, column):
        """Specifies a column to group by.

        Arguments:
            column {string} -- The name of the column to group by; a
                comma-separated list adds one grouping per entry.
Returns:
            self
        """
        for col in column.split(","):
            self._group_by += (GroupByExpression(column=col),)

        return self

    def group_by_raw(self, query, bindings=None):
        """Specifies a raw SQL grouping.

        Arguments:
            query {string} -- A raw query

        Keyword Arguments:
            bindings {list} -- Bindings for the raw expression (default: {[]})

        Returns:
            self
        """
        if bindings is None:
            bindings = []

        self._group_by += (
            GroupByExpression(column=query, raw=True, bindings=bindings),
        )
        return self

    def aggregate(self, aggregate, column, alias=None):
        """Helper function to aggregate.

        Arguments:
            aggregate {string} -- The name of the aggregation.
            column {string} -- The name of the column to aggregate.

        Keyword Arguments:
            alias {string} -- Optional alias for the aggregate (default: {None})

        Returns:
            None -- unlike most builder methods, this does not return self.
        """
        self._aggregates += (
            AggregateExpression(aggregate=aggregate, column=column, alias=alias),
        )

    def first(self, fields=None, query=False):
        """Gets the first record.

        Keyword Arguments:
            fields {list} -- Columns to select (default: {[]})
            query {bool} -- Return the limited builder instead of executing.

        Returns:
            The prepared first row (a model when one is set, else the raw
            row), None when nothing matches, or self when ``query`` is True.
        """
        if not fields:
            fields = []

        self.select(fields).limit(1)

        if query:
            return self

        result = self.new_connection().query(self.to_qmark(), self._bindings, results=1)

        return self.prepare_result(result)

    def first_or_create(self, wheres, creates: dict = None):
        """Get the first record matching the attributes or create it.

        Arguments:
            wheres {dict} -- Attributes used to look the record up; they are
                also written into the created record.

        Keyword Arguments:
            creates {dict} -- Extra attributes used only when creating.
Returns:
            Model
        """
        if creates is None:
            creates = {}

        record = self.where(wheres).first()
        # The merged attribute set is only used when no record exists.
        total = {}
        if record:
            if hasattr(record, "serialize"):
                total.update(record.serialize())
            else:
                total.update(record)

        total.update(creates)
        total.update(wheres)

        total.update(self._creates_related)

        if not record:
            return self.create(total, id_key=self.get_primary_key())
        return record

    def sole(self, query=False):
        """Gets the only record matching a given criteria.

        At most two rows are fetched, which is enough to detect duplicates.
        ``query`` is accepted but not used.

        Raises:
            ModelNotFound -- When no record matches.
            MultipleRecordsFound -- When more than one record matches.
        """
        result = self.take(2).get()

        if result.is_empty():
            raise ModelNotFound()

        if result.count() > 1:
            raise MultipleRecordsFound()

        return result.first()

    def sole_value(self, column: str, query=False):
        """Returns ``column`` from the record found by ``sole`` (``query`` unused)."""
        return self.sole()[column]

    def first_where(self, column, *args):
        """Gets the first record with the given key / value pair.

        With no extra arguments, returns the first record where ``column``
        is not NULL.
        """
        if not args:
            return self.where_not_null(column).first()
        return self.where(column, *args).first()

    def last(self, column=None, query=False):
        """Gets the last record, ordered by column in descendant order or
        primary key if no column is given.

        Omitting ``column`` requires a model (its primary key is used).

        Returns:
            The prepared row, None when nothing matches, or self when
            ``query`` is True.
        """
        _column = column if column else self._model.get_primary_key()
        self.limit(1).order_by(_column, direction="DESC")

        if query:
            return self

        result = self.new_connection().query(
            self.to_qmark(),
            self._bindings,
            results=1,
        )

        return self.prepare_result(result)

    def _get_eager_load_result(self, related, collection):
        return related.eager_load_from_collection(collection)

    def find(self, record_id, column=None, query=False):
        """Finds a row by the primary key ID. Requires a model

        Arguments:
            record_id {int|list|tuple} -- The ID of the primary key to fetch;
                a list/tuple is matched with WHERE IN.

        Keyword Arguments:
            column {string} -- Column to match instead of the primary key.
            query {bool} -- Return the builder instead of executing.
Returns:
            Model|None
        """
        if not column:
            if not self._model:
                raise InvalidArgument("A colum to search is required")
            column = self._model.get_primary_key()

        if isinstance(record_id, (list, tuple)):
            self.where_in(column, record_id)
        else:
            self.where(column, record_id)

        if query:
            return self

        # NOTE(review): for a list/tuple of IDs this still returns only the
        # first matching row -- confirm callers expect that.
        return self.first()

    def find_or(self, record_id: int, callback: Callable, args=None, column=None):
        """Finds a row by the primary key ID (Requires a model), or returns
        the result of ``callback`` when no row is found.

        Arguments:
            record_id {int} -- The ID of the primary key to fetch.
            callback {Callable} -- The function to call if no record is found.

        Keyword Arguments:
            args {list} -- Positional arguments passed to ``callback``.
            column {string} -- Column to match instead of the primary key.

        Raises:
            InvalidArgument -- When ``callback`` is not callable.

        Returns:
            Model|Callable
        """
        if not callable(callback):
            raise InvalidArgument("A callback must be callable.")

        result = self.find(record_id=record_id, column=column)

        if not result:
            if not args:
                return callback()
            else:
                return callback(*args)

        return result

    def find_or_fail(self, record_id, column=None):
        """Finds a row by the primary key ID (Requires a model) or raise a ModelNotFound exception.

        Arguments:
            record_id {int} -- The ID of the primary key to fetch.

        Returns:
            Model|ModelNotFound
        """
        result = self.find(record_id=record_id, column=column)

        if not result:
            raise ModelNotFound()

        return result

    def find_or_404(self, record_id, column=None):
        """Finds a row by the primary key ID (Requires a model) or raises HTTP404.

        Arguments:
            record_id {int} -- The ID of the primary key to fetch.

        Returns:
            Model|HTTP404
        """
        try:
            return self.find_or_fail(record_id=record_id, column=column)
        except ModelNotFound:
            raise HTTP404()

    def first_or_fail(self, query=False):
        """Returns the first row from the database, raising ModelNotFound
        when there is none.
Returns:
            The first record, or self when ``query`` is True.

        Raises:
            ModelNotFound -- When no record is found.
        """
        if query:
            return self.first(query=True)

        result = self.first()

        if not result:
            raise ModelNotFound()

        return result

    def get_primary_key(self):
        return self._model.get_primary_key()

    def prepare_result(self, result, collection=False):
        """Hydrates raw query results and eager-loads registered relationships.

        With a model and a non-empty result, the result is hydrated through
        the model and every registered eager relation is resolved and
        attached. Without a model the raw result is returned (wrapped in a
        Collection when ``collection`` is True).

        Arguments:
            result -- The raw result returned by the connection.

        Keyword Arguments:
            collection {bool} -- Whether a collection is expected (default: {False})
        """
        if self._model and result:
            # eager load here
            hydrated_model = self._model.hydrate(result)
            if (
                self._eager_relation.eagers
                or self._eager_relation.nested_eagers
                or self._eager_relation.callback_eagers
            ) and hydrated_model:
                for eager_load in self._eager_relation.get_eagers():
                    if isinstance(eager_load, dict):
                        # Nested
                        for relation, eagers in eager_load.items():
                            callback = None
                            if inspect.isclass(self._model):
                                related = getattr(self._model, relation)
                            elif callable(eagers):
                                # A callable value is a constraint callback
                                # for this relation.
                                related = getattr(self._model, relation)
                                callback = eagers
                            else:
                                related = self._model.get_related(relation)

                            result_set = related.get_related(
                                self, hydrated_model, eagers=eagers, callback=callback
                            )

                            self._register_relationships_to_model(
                                related,
                                result_set,
                                hydrated_model,
                                relation_key=relation,
                            )
                    else:
                        # Not Nested
                        for eager in eager_load:
                            if inspect.isclass(self._model):
                                related = getattr(self._model, eager)
                            else:
                                related = self._model.get_related(eager)

                            result_set = related.get_related(self, hydrated_model)

                            self._register_relationships_to_model(
                                related, result_set, hydrated_model, relation_key=eager
                            )

            # `result` is always truthy inside this branch, so these
            # conditionals always take their first arm.
            if collection:
                return hydrated_model if result else Collection([])
            else:
                return hydrated_model if result else None

        if collection:
            return Collection(result) if result else Collection([])
        else:
            return result or None

    def _register_relationships_to_model(
        self, related, related_result, hydrated_model, relation_key
    ):
        """Takes a related result and a hydrated model and registers them to each other using the relation key.

        Args:
            related: The relationship object that produced ``related_result``.
            related_result (Model|Collection): Will be the related result based on the type of relationship.
hydrated_model (Model|Collection): If a collection we will need to loop through the collection of models
                and register each one individually. Else we can just load the
                related_result into the hydrated_models
            relation_key (string): A key to bind the relationship with.

        Returns:
            self
        """
        if related_result and isinstance(hydrated_model, Collection):
            map_related = self._map_related(related_result, related)
            for model in hydrated_model:
                if isinstance(related_result, Collection):
                    related.register_related(relation_key, model, map_related)
                else:
                    model.add_relation({relation_key: map_related or None})
        else:
            # NOTE(review): reached also when hydrated_model is a Collection
            # and related_result is empty -- confirm Collection supports
            # add_relation.
            hydrated_model.add_relation({relation_key: related_result or None})
        return self

    def _map_related(self, related_result, related):
        return related.map_related(related_result)

    def all(self, selects=[], query=False):
        """Returns all records from the table.

        Keyword Arguments:
            selects {list} -- Columns to select (the default list is never mutated).
            query {bool} -- Return the builder instead of executing.

        Returns:
            Collection -- The prepared results, or self when ``query`` is True.
        """
        self.select(*selects)

        if query:
            return self

        result = self.new_connection().query(self.to_qmark(), self._bindings) or []

        return self.prepare_result(result, collection=True)

    def get(self, selects=[]):
        """Runs the select query built from the query builder.

        Keyword Arguments:
            selects {list} -- Columns to select (the default list is never mutated).
Returns: self """ self.select(*selects) result = self.new_connection().query(self.to_qmark(), self._bindings) return self.prepare_result(result, collection=True) def new_connection(self): if self._connection: return self._connection self._connection = ( self.connection_class( **self.get_connection_information(), name=self.connection ) .set_schema(self._schema) .make_connection() ) return self._connection def get_connection(self): return self._connection def without_eager(self): self._should_eager = False return self def with_(self, *eagers): self._eager_relation.register(eagers) return self def paginate(self, per_page, page=1): if page == 1: offset = 0 else: offset = (int(page) * per_page) - per_page new_from_builder = self.new_from_builder() new_from_builder._order_by = () new_from_builder._columns = () result = self.limit(per_page).offset(offset).get() total = new_from_builder.count() paginator = LengthAwarePaginator(result, per_page, page, total) return paginator def simple_paginate(self, per_page, page=1): if page == 1: offset = 0 else: offset = (int(page) * per_page) - per_page result = self.limit(per_page).offset(offset).get() paginator = SimplePaginator(result, per_page, page) return paginator def set_action(self, action): """Sets the action that the query builder should take when the query is built. Arguments: action {string} -- The action that the query builder should take. Returns: self """ self._action = action return self def get_grammar(self): """Initializes and returns the grammar class. Returns: masoniteorm.grammar.Grammar -- An ORM grammar class. 
"""
        # Either _creates when creating, otherwise use columns
        columns = self._creates or self._columns
        if not columns and not self._aggregates and self._model:
            self.select(*self._model.get_selects())
            columns = self._columns

        return self.grammar(
            columns=columns,
            table=self._table,
            wheres=self._wheres,
            limit=self._limit,
            offset=self._offset,
            updates=self._updates,
            aggregates=self._aggregates,
            order_by=self._order_by,
            group_by=self._group_by,
            distinct=self._distinct,
            lock=self.lock,
            joins=self._joins,
            having=self._having,
        )

    def to_sql(self):
        """Compiles the QueryBuilder class into a SQL statement.

        Global scopes for the current action are applied first.

        Returns:
            string -- The SQL compiled with qmark=False.
        """
        self.run_scopes()
        grammar = self.get_grammar()
        sql = grammar.compile(self._action, qmark=False).to_sql()
        return sql

    def explain(self):
        """Explains the Query execution plan.

        Runs ``EXPLAIN <to_sql()>`` through ``self.statement``.

        Returns:
            Collection
        """
        sql = self.to_sql()
        explanation = self.statement(f"EXPLAIN {sql}")
        return explanation

    def run_scopes(self):
        """Applies every global scope registered for the current action."""
        for name, scope in self._global_scopes.get(self._action, {}).items():
            scope(self)

        return self

    def to_qmark(self):
        """Compiles the QueryBuilder class into a Qmark SQL statement.

        The grammar's bindings are stored on ``self._bindings`` and then
        ``self.reset()`` is called (presumably clearing the query state --
        reset is not shown here).

        Returns:
            string -- The SQL compiled with qmark=True.
        """
        self.run_scopes()
        grammar = self.get_grammar()
        sql = grammar.compile(self._action, qmark=True).to_sql()

        self._bindings = grammar._bindings

        self.reset()

        return sql

    def new(self):
        """Creates a new QueryBuilder class.

        Only the grammar, connection settings, model and table are carried
        over to the new builder.

        Returns:
            QueryBuilder -- The ORM QueryBuilder class.
        """
        builder = QueryBuilder(
            grammar=self.grammar,
            connection_class=self.connection_class,
            connection=self.connection,
            connection_driver=self._connection_driver,
            model=self._model,
        )

        if self._table:
            builder.table(self._table.name)

        return builder

    def avg(self, column):
        """Adds an AVG aggregate for a column.

        Arguments:
            column {string} -- The name of the column to aggregate.

        Returns:
            self
        """
        self.aggregate("AVG", "{column}".format(column=column))
        return self

    def min(self, column):
        """Adds a MIN aggregate for a column.

        Arguments:
            column {string} -- The name of the column to aggregate.
Returns:
            self
        """
        self.aggregate("MIN", "{column}".format(column=column))
        return self

    def _extract_operator_value(self, *args):
        """Splits ``(operator, value)`` or ``(value,)`` arguments.

        The operator defaults to "=". The allowed-operator check is
        case-sensitive (e.g. "LIKE" is rejected).

        Raises:
            ValueError -- When the operator is not in the allowed list.

        Returns:
            tuple -- (operator, value)
        """
        operators = [
            "=",
            ">",
            ">=",
            "<",
            "<=",
            "!=",
            "<>",
            "like",
            "not like",
            "regexp",
            "not regexp",
        ]

        operator = operators[0]

        value = None

        if (len(args)) >= 2:
            operator = args[0]
            value = args[1]
        elif len(args) == 1:
            value = args[0]

        if operator not in operators:
            raise ValueError(
                "Invalid comparison operator. The operator can be %s"
                % ", ".join(operators)
            )

        return operator, value

    def __call__(self):
        """Magic method to standardize what happens when the query builder object is called.

        Returns:
            self
        """
        return self

    def macro(self, name, callable):
        """Registers ``callable`` under ``name`` in the builder's macros."""
        self._macros.update({name: callable})
        return self

    def when(self, conditional, callback):
        """Calls ``callback(self)`` only when ``conditional`` is truthy."""
        if conditional:
            callback(self)
        return self

    def truncate(self, foreign_keys=False, dry=False):
        """Truncates the table, or returns the TRUNCATE SQL when dry."""
        sql = self.get_grammar().truncate_table(self.get_table_name(), foreign_keys)

        if dry or self.dry:
            return sql

        return self.new_connection().query(sql, ())

    def exists(self):
        """Determine if rows exist for the current query.

        Returns:
            Bool - True or False
        """
        if self.first():
            return True
        else:
            return False

    def doesnt_exist(self):
        """Determine if no rows exist for the current query.

        Returns:
            Bool - True or False
        """
        if self.exists():
            return False
        else:
            return True

    def in_random_order(self):
        """Puts Query results in random order"""
        return self.order_by_raw(self.grammar().compile_random())

    def new_from_builder(self, from_builder=None):
        """Creates a new QueryBuilder that copies the query state of
        ``from_builder`` (default: self).

        Returns:
            QueryBuilder -- The ORM QueryBuilder class.
"""
        if from_builder is None:
            from_builder = self

        # The model is not passed, so the copy is a plain (model-less) builder.
        builder = QueryBuilder(
            grammar=self.grammar,
            connection_class=self.connection_class,
            connection=self.connection,
            connection_driver=self._connection_driver,
        )

        # NOTE(review): the table comes from `self`, while everything below
        # comes from `from_builder` -- confirm intended when they differ.
        if self._table:
            builder.table(self._table.name)

        # Deep copies keep the new builder from sharing mutable state. Not
        # copied: distinct, limit, offset, lock.
        builder._columns = deepcopy(from_builder._columns)
        builder._creates = deepcopy(from_builder._creates)
        builder._sql = ""
        builder._bindings = deepcopy(from_builder._bindings)
        builder._updates = deepcopy(from_builder._updates)
        builder._wheres = deepcopy(from_builder._wheres)
        builder._order_by = deepcopy(from_builder._order_by)
        builder._group_by = deepcopy(from_builder._group_by)
        builder._joins = deepcopy(from_builder._joins)
        builder._having = deepcopy(from_builder._having)
        builder._macros = deepcopy(from_builder._macros)
        builder._aggregates = deepcopy(from_builder._aggregates)
        builder._global_scopes = deepcopy(from_builder._global_scopes)

        return builder

    def get_table_columns(self):
        return self.get_schema().get_columns(self._table.name)

    def get_schema(self):
        return Schema(
            connection=self.connection, connection_details=self._connection_details
        )

    def latest(self, *fields):
        """Orders by the given fields (default "created_at") descending.

        Returns:
            querybuilder
        """
        if not fields:
            fields = ("created_at",)

        return self.order_by(column=",".join(fields), direction="DESC")

    def oldest(self, *fields):
        """Orders by the given fields (default "created_at") ascending.
Returns: querybuilder """ if not fields: fields = ("created_at",) return self.order_by(column=",".join(fields), direction="ASC") def value(self, column: str): return self.get().first()[column] ================================================ FILE: src/masoniteorm/query/__init__.py ================================================ from .QueryBuilder import QueryBuilder ================================================ FILE: src/masoniteorm/query/grammars/BaseGrammar.py ================================================ import re from ...expressions.expressions import ( JoinClause, OnClause, SelectExpression, SubGroupExpression, SubSelectExpression, ) class BaseGrammar: """The keys in this dictionary is how the ORM will reference these aggregates The values on the right are the matching functions for the grammar Returns: [type] -- [description] """ table = "users" def __init__( self, columns=(), table="users", database=None, wheres=(), limit=False, offset=False, updates=None, aggregates=(), order_by=(), distinct=False, group_by=(), joins=(), lock=False, having=(), connection_details=None, ): self._columns = columns self.table = table self.database = database self._wheres = wheres self._limit = limit self._offset = offset self._updates = updates or {} self._aggregates = aggregates self._order_by = order_by self._group_by = group_by self._distinct = distinct self._joins = joins self._having = having self.lock = lock self._connection_details = connection_details or {} self._column = None self._bindings = [] self._sql = "" self._sql_qmark = "" self._action = "select" self.queries = [] def compile(self, action, qmark=False): self._action = action return getattr(self, "_compile_" + action)(qmark=qmark) def _compile_select(self, qmark=False): """Compile a select query statement. 
Keyword Arguments: qmark {bool} -- [description] (default: {False}) Returns: [type] -- [description] """ if not self.table: self._sql = ( self.select_no_table() .format( columns=self.process_columns(separator=", ", qmark=qmark), table=self.process_table(self.table), joins=self.process_joins(qmark=qmark), wheres=self.process_wheres(qmark=qmark), limit=self.process_limit(), offset=self.process_offset(), aggregates=self.process_aggregates(), order_by=self.process_order_by(), group_by=self.process_group_by(), having=self.process_having(), lock=self.process_locks(), ) .strip() ) else: self._sql = ( self.select_format() .format( columns=self.process_columns(separator=", ", qmark=qmark), keyword="DISTINCT" if self._distinct else "", table=self.process_table(self.table), joins=self.process_joins(qmark=qmark), wheres=self.process_wheres(qmark=qmark), limit=self.process_limit(), offset=self.process_offset(), aggregates=self.process_aggregates(), order_by=self.process_order_by(), group_by=self.process_group_by(), having=self.process_having(), lock=self.process_locks(), ) .strip() ) return self def _compile_update(self, qmark=False): """Compiles an update query statement. Keyword Arguments: qmark {bool} -- Whether the query should use qmark. (default: {False}) Returns: self """ self._sql = self.update_format().format( key_equals=self._compile_key_value_equals(qmark=qmark), table=self.process_table(self.table), wheres=self.process_wheres(qmark=qmark), ) return self def _compile_insert(self, qmark=False): """Compiles an insert expression. Returns: self """ self._sql = self.insert_format().format( key_equals=self._compile_key_value_equals(qmark=qmark), table=self.process_table(self.table), columns=self.process_columns(separator=", ", action="insert", qmark=qmark), values=self.process_values(separator=", ", qmark=qmark), ) return self def _compile_bulk_create(self, qmark=False): """Compiles an insert expression. 
Returns: self """ all_values = [list(x.values()) for x in self._columns] self._sql = self.bulk_insert_format().format( key_equals=self._compile_key_value_equals(qmark=qmark), table=self.process_table(self.table), columns=self.columnize_bulk_columns(list(self._columns[0].keys())), values=self.columnize_bulk_values(all_values, qmark=qmark), ) return self def columnize_bulk_columns(self, columns=[]): return ", ".join( self.column_string().format(column=x, separator="") for x in columns ).rstrip(",") def columnize_bulk_values(self, columns=[], qmark=False): sql = "" for x in columns: inner = "" if isinstance(x, list): for y in x: if qmark: self.add_binding(y) inner += ( "'?', " if qmark else self.value_string().format(value=y, separator=", ") ) inner = inner.rstrip(", ") sql += self.process_value_string().format(value=inner, separator=", ") else: if qmark: self.add_binding(x) sql += ( "'?', " if qmark else self.process_value_string().format( value="?" if qmark else x, separator=", " ) ) return sql.rstrip(", ") def process_value_string(self): return "({value}){separator}" def _compile_delete(self, qmark=False): """Compiles a delete expression. Returns: self """ self._sql = self.delete_format().format( key_equals=self._compile_key_value_equals(qmark=qmark), table=self.process_table(self.table), wheres=self.process_wheres(qmark=qmark), ) return self # TODO: Columnize? def _get_multiple_columns(self, columns): """Compiles a string or a list of strings into the grammars column syntax. Arguments: columns {string|list} -- A column or list of columns Returns: self """ if isinstance(columns, list): column_string = "" for col in columns: column_string += self.process_column(col) + ", " return column_string.rstrip(", ") return self.process_column(columns) def process_joins(self, qmark=False): """Compiles a join expression. 
Returns: self """ sql = "" for join in self._joins: if isinstance(join, JoinClause): on_string = "" for clause_idx, clause in enumerate(join.get_on_clauses()): keyword = clause.operator.upper() if clause_idx else "ON" if isinstance(clause, OnClause): on_string += f"{keyword} {self._table_column_string(clause.column1)} {clause.equality} {self._table_column_string(clause.column2)} " else: if clause.value_type == "NULL": sql_string = f"{self.where_null_string()} " on_string += sql_string.format( keyword=keyword, column=self.process_column(clause.column), ) elif clause.value_type == "NOT NULL": sql_string = f"{self.where_not_null_string()} " on_string += sql_string.format( keyword=keyword, column=self.process_column(clause.column), ) else: if qmark: value = "'?'" self.add_binding(clause.value) else: value = self._compile_value(clause.value) on_string += f"{keyword} {self._table_column_string(clause.column)} {clause.equality} {value} " sql += self.join_string().format( foreign_table=self.process_table(join.table), alias=f" AS {self.process_table(join.alias)}" if join.alias else "", on=on_string, keyword=self.join_keywords[join.clause], ) sql += " " return sql # TODO: Clean def _compile_key_value_equals(self, qmark=False): """Compiles key value pairs. Keyword Arguments: qmark {bool} -- Whether the query should use qmark. 
(default: {False}) Returns: self """ sql = "" for update in self._updates: if update.update_type == "increment": sql_string = self.increment_string() elif update.update_type == "decrement": sql_string = self.decrement_string() else: sql_string = self.key_value_string() column = update.column value = update.value if isinstance(column, dict): for key, value in column.items(): if hasattr(value, "expression"): sql += self.column_value_string().format( column=self._table_column_string(key), value=value.expression, separator=", ", ) else: sql += sql_string.format( column=self._table_column_string(key), value=value if not qmark else "?", separator=", ", ) if qmark: self._bindings += (value,) else: sql += sql_string.format( column=self._table_column_string(column), value=value if not qmark else "?", separator=", ", ) if qmark: self._bindings += (value,) sql = sql.rstrip(", ") return sql def process_aggregates(self): """Compiles aggregates to be used in a query expression. Returns: self """ sql = "" for aggregates in self._aggregates: aggregate = aggregates.aggregate column = aggregates.column aggregate_function = self.aggregate_options.get(aggregate, "") if not aggregates.alias and column == "*": aggregate_string = self.aggregate_string_without_alias() else: aggregate_string = self.aggregate_string_with_alias() sql += ( aggregate_string.format( aggregate_function=aggregate_function, column="*" if column == "*" else self._table_column_string(column), alias=self.process_alias(aggregates.alias or column), ) + ", " ) return sql def process_order_by(self): """Compiles an order by for a query expression. Returns: self """ sql = "" if self._order_by: order_crit = "" for order_bys in self._order_by: if order_bys.raw: order_crit += order_bys.column if not isinstance(order_bys.bindings, (list, tuple)): raise ValueError( f"Bindings must be tuple or list. 
Received {type(order_bys.bindings)}" ) if order_bys.bindings: self.add_binding(*order_bys.bindings) continue if len(order_crit): order_crit += ", " column = order_bys.column direction = order_bys.direction if "." in column: column_string = self._table_column_string(column) else: column_string = self.column_string().format( column=column, separator="" ) order_crit += self.order_by_format().format( column=column_string, direction=direction.upper() ) sql += self.order_by_string().format(order_columns=order_crit) return sql def process_group_by(self): """Compiles a group by for a query expression. Returns: self """ sql = "" columns = [] for group_by in self._group_by: if group_by.raw: if group_by.bindings: self.add_binding(*group_by.bindings) sql += "GROUP BY " + group_by.column return sql else: columns.append(self._table_column_string(group_by.column)) if columns: sql += " GROUP BY {column}".format(column=", ".join(columns)) return sql def process_alias(self, column): """Compiles an alias for a column. Arguments: column {string} -- The name of the column. Returns: self """ return column def process_table(self, table): """Compiles a given table name. Arguments: table {string} -- The table name to compile. Returns: self """ if not table: return "" if isinstance(table, str): return ".".join( self.table_string().format( table=t, database=self._connection_details.get("database", ""), prefix=self._connection_details.get("prefix", ""), ) for t in table.split(".") ) if table.raw: return table.name return ".".join( self.table_string().format( table=t, database=self._connection_details.get("database", ""), prefix=self._connection_details.get("prefix", ""), ) for t in table.name.split(".") ) def process_limit(self): """Compiles the limit expression. Returns: self """ if not self._limit: return "" return self.limit_string(offset=self._offset).format(limit=self._limit) def process_offset(self): """Compiles the offset expression. 
Returns: self """ if not self._offset: return "" return self.offset_string().format(offset=self._offset, limit=self._limit or 1) def process_locks(self): return self.locks.get(self.lock, "") def process_having(self, qmark=False): """Compiles having expression. Keyword Arguments: qmark {bool} -- Whether or not to use Qmark (default: {False}) Returns: self """ sql = "" for having in self._having: column = having.column equality = having.equality value = having.value raw = having.raw if not equality and not value: sql_string = self.having_string() else: sql_string = self.having_equality_string() sql += sql_string.format( column=self._table_column_string(column) if raw is False else column, equality=equality, value=self._compile_value(value), ) return sql def process_wheres(self, qmark=False, strip_first_where=False): """Compiles the where expression. Keyword Arguments: qmark {bool} -- Whether or not to use Qmark. (default: {False}) strip_first_where {bool} -- Whether or not to strip out the first where keyword. This is useful when using subselects (default: {False}) Returns: self """ sql = "" loop_count = 0 for where in self._wheres: column = where.column equality = where.equality value = where.value value_type = where.value_type """Need to get a specific keyword here. This keyword either needs to be something like WHERE, AND, OR. Depending on the loop depends on the placement of the AND """ if loop_count == 0: if strip_first_where: keyword = "" else: keyword = " " + self.first_where_string() elif hasattr(where, "keyword") and where.keyword == "or": keyword = " " + self.or_where_string() else: keyword = " " + self.additional_where_string() if where.raw: """If we have a raw query we just want to use the query supplied and don't need to compile anything. """ sql += self.raw_query_string().format( keyword=keyword, query=where.column ) if not isinstance(where.bindings, (list, tuple)): raise ValueError( f"Bindings must be tuple or list. 
Received {type(where.bindings)}" ) if where.bindings: self.add_binding(*where.bindings) loop_count += 1 continue """The column is an easy compile """ column = self._table_column_string(column) """Need to find which type of where string it is. If it is a WHERE NULL, WHERE EXISTS, WHERE `col` = 'val' etc """ equality = equality.upper() if equality == "BETWEEN": low = where.low high = where.high if qmark: self.add_binding(low) self.add_binding(high) low = "?" high = "?" sql_string = self.between_string().format( low=self._compile_value(low), high=self._compile_value(high), column=self._table_column_string(where.column), keyword=keyword, ) elif equality == "NOT BETWEEN": sql_string = self.not_between_string().format( low=self._compile_value(where.low), high=self._compile_value(where.high), column=self._table_column_string(where.column), keyword=keyword, ) elif value_type == "value_equals": sql_string = self.value_equal_string().format( value1=where.column, value2=where.value, keyword=keyword ) elif value_type == "NULL": sql_string = self.where_null_string() elif value_type == "DATE": sql_string = self.where_date_string() elif value_type == "NOT NULL": sql_string = self.where_not_null_string() elif equality == "EXISTS": sql_string = self.where_exists_string() elif equality == "NOT EXISTS": sql_string = self.where_not_exists_string() elif equality == "LIKE": sql_string = self.where_like_string() elif equality == "REGEXP": sql_string = self.where_regexp_string() elif equality == "NOT REGEXP": sql_string = self.where_not_regexp_string() elif equality == "NOT LIKE": sql_string = self.where_not_like_string() else: sql_string = self.where_string() """If the value should actually be a sub query then we need to wrap it in a query here """ if isinstance(value, SubGroupExpression): grammar = value.builder.get_grammar() query_value = ( self.subquery_string() .format( query=grammar.process_wheres( qmark=qmark, strip_first_where=True ) ) .replace("( ", "(") ) if grammar._bindings: 
self.add_binding(*grammar._bindings) sql_string = self.where_group_string() elif isinstance(value, SubSelectExpression): if qmark: query_from_builder = value.builder.to_qmark() if value.builder._bindings: self.add_binding(*value.builder._bindings) else: query_from_builder = value.builder.to_sql() query_value = self.subquery_string().format(query=query_from_builder) elif isinstance(value, list): query_value = "(" for val in value: if qmark: query_value += "'?', " self.add_binding(val) else: query_value += self.value_string().format( value=val, separator="," ) query_value = query_value.rstrip(",").rstrip(", ") + ")" elif value is True and value_type != "NOT NULL": sql_string = self.get_true_column_string() query_value = 1 elif value is False and value_type != "NOT NULL": sql_string = self.get_false_column_string() query_value = 0 elif qmark and value_type != "column": query_value = "'?'" if ( value is not True and value_type != "value_equals" and value_type != "NULL" and value_type != "BETWEEN" ): self.add_binding(value) elif value_type == "value": if qmark: query_value = "'?'" else: query_value = self.value_string().format(value=value, separator="") self.add_binding(value) elif value_type == "column": query_value = self._table_column_string(column=value, separator="") elif value_type == "DATE": query_value = self.value_string().format(value=value, separator="") elif value_type == "having": query_value = self._table_column_string(column=value, separator="") else: query_value = "" sql += sql_string.format( keyword=keyword, column=column, equality=equality, value=query_value ) loop_count += 1 return sql def get_true_column_string(self): return "{keyword} {column} = '1'" def get_false_column_string(self): return "{keyword} {column} = '0'" def add_binding(self, *bindings): """Adds one or more bindings to the bindings tuple. Arguments: binding {string} -- A value to bind. 
""" self._bindings += bindings def column_exists(self, column): """Check if a column exists Arguments: column {string} -- The name of the column to check for existence. Returns: self """ self._column = column self._sql = self.process_exists() return self def table_exists(self): """Checks if a table exists. Returns: self """ self._sql = self.table_exists_string().format( table=self.process_table(self.table), database=self.database, clean_table=self.table, ) return self def wrap_table(self, table_name): return self.table_string().format(table=table_name) def process_exists(self): """Specifies the column exists expression. Returns: self """ return self.column_exists_string().format( table=self.process_table(self.table), clean_table=self.table, value=self._compile_value(self._column), ) def to_sql(self): """Cleans up the SQL string and returns the SQL Returns: string """ return re.sub(" +", " ", self._sql.strip()) def to_qmark(self): """Cleans up the SQL string and returns the SQL Returns: string """ return re.sub(" +", " ", self._sql.strip()) # TODO: Inspect this can't just be used by another method. seems duplicative def process_columns(self, separator="", action="select", qmark=False): """Specifies the columns in a selection expression. 
Keyword Arguments: separator {str} -- The separator used between columns (default: {""}) Returns: self """ sql = "" for column in self._columns: alias = None if isinstance(column, SelectExpression): alias = column.alias if column.raw: sql += column.column + ", " continue column = column.column if isinstance(column, SubGroupExpression): if qmark: builder_sql = column.builder.to_qmark() if column.builder._bindings: self.add_binding(*column.builder._bindings) else: builder_sql = column.builder.to_sql() sql += f"({builder_sql}) AS {column.alias}, " continue sql += self._table_column_string(column, alias=alias, separator=separator) if self._aggregates: sql += self.process_aggregates() if sql == "": return "*" return sql.rstrip(",").rstrip(", ") # TODO: Duplicative? def process_values(self, separator="", qmark=False): """Compiles column values for insert expressions. Keyword Arguments: separator {str} -- The separator used between columns (default: {""}) Returns: self """ sql = "" if self._columns == "*": return self._columns elif isinstance(self._columns, list): for c in self._columns: for column, value in dict(c).items(): if qmark: self.add_binding(value) sql += f"'?'{separator}".strip() else: sql += self._compile_value(value, separator=separator) else: for column, value in dict(self._columns).items(): if qmark: self.add_binding(value) sql += f"'?'{separator}".strip() else: sql += self._compile_value(value, separator=separator) if not qmark: return sql[:-2] return sql.rstrip(separator.strip()) def process_column(self, column, separator=""): """Compiles a column into the column syntax. Arguments: column {string} -- The name of the column. Keyword Arguments: separator {string} -- The separator used between columns (default: {""}) Returns: self """ table = None if column and "." 
in column: table, column = column.split(".") return self.column_string().format( column=column, separator=separator, table=table or self.table ) def _table_column_string(self, column, alias=None, separator=""): """Compiles a column into the column syntax. Arguments: column {string} -- The name of the column. Keyword Arguments: separator {string} -- The separator used between columns (default: {""}) Returns: self """ table = None if column and "." in column: table, column = column.split(".") if column == "*": return self.column_strings.get("select_all").format( column=column, separator=separator, table=self.process_table(table or self.table), ) if alias: alias_string = self.subquery_alias_string().format(alias=alias) return self.column_strings.get(self._action).format( column=column, separator=separator, alias=" " + alias_string if alias else "", table=self.process_table(table or self.table), ) def _compile_value(self, value, separator=""): """Compiles a value using the value syntax. Arguments: value {string} -- The value to compile. Keyword Arguments: separator {string} -- The separator used between columns (default: {""}) Returns: self """ return self.value_string().format(value=value, separator=separator) def drop_table(self, table): """Specifies a drop table expression. Arguments: table {string} -- The table to drop. Returns: self """ self._sql = self.drop_table_string().format(table=self.process_column(table)) return self def drop_table_if_exists(self, table): """Specifies a drop table if exists expression. Arguments: table {string} -- The name of the table to drop. Returns: self """ self._sql = self.drop_table_if_exists_string().format( table=self.process_column(table) ) return self def rename_table(self, current_table_name, new_table_name): """Specifies a rename table expression. Arguments: current_table_name {string} -- The name of the table currently. new_table_name {string} -- The name you want to use now for the table. 
Returns: self """ self._sql = self.rename_table_string().format( current_table_name=self.process_column(current_table_name), new_table_name=self.process_column(new_table_name), ) return self def truncate_table(self, table, foreign_keys=False): """Specifies a truncate table expression. Arguments; table {string} -- The name of the table to truncate. Returns: self """ raise NotImplementedError( f"'{self.__class__.__name__}' does not support truncating" ) def where_regexp_string(self): return "{keyword} {column} REGEXP {value}" def where_not_regexp_string(self): return "{keyword} {column} NOT REGEXP {value}" ================================================ FILE: src/masoniteorm/query/grammars/MSSQLGrammar.py ================================================ from .BaseGrammar import BaseGrammar class MSSQLGrammar(BaseGrammar): """Microsoft SQL Server grammar class.""" aggregate_options = { "SUM": "SUM", "MAX": "MAX", "MIN": "MIN", "AVG": "AVG", "COUNT": "COUNT", } join_keywords = { "inner": "INNER JOIN", "join": "INNER JOIN", "outer": "OUTER JOIN", "left": "LEFT JOIN", "right": "RIGHT JOIN", "left_inner": "LEFT INNER JOIN", "right_inner": "RIGHT INNER JOIN", } column_strings = { "select": "{table}.[{column}]{alias}{separator}", "select_all": "{table}.*{separator}", "insert": "{table}.[{column}]{separator}", "update": "{table}.[{column}]{separator}", "delete": "{table}.[{column}]{separator}", } locks = {"share": "WITH(ROWLOCK)", "update": "WITH(ROWLOCK)"} def select_no_table(self): return "SELECT {columns}" def select_format(self): return "SELECT {keyword} {limit} {columns} FROM {table} {lock} {joins} {wheres} {group_by} {having} {order_by} {offset}" def update_format(self): return "UPDATE {table} SET {key_equals} {wheres}" def insert_format(self): return "INSERT INTO {table} ({columns}) VALUES ({values})" def bulk_insert_format(self): return "INSERT INTO {table} ({columns}) VALUES {values}" def delete_format(self): return "DELETE FROM {table} {wheres}" def 
create_column_string(self): return "{column} {data_type}{length}{nullable}, " def create_start(self): return "CREATE TABLE {table} " def having_string(self): return "HAVING {column}" def where_exists_string(self): return "{keyword} EXISTS {value}" def where_not_exists_string(self): return "{keyword} NOT EXISTS {value}" def where_like_string(self): return "{keyword} {column} LIKE {value}" def where_not_like_string(self): return "{keyword} {column} NOT LIKE {value}" def where_date_string(self): return "{keyword} DATE({column}) {equality} {value}" def where_regexp_string(self): return self.where_like_string() def where_not_regexp_string(self): return self.where_not_like_string() def having_equality_string(self): return "HAVING {column} {equality} {value}" def aggregate_string_without_alias(self): return "{aggregate_function}({column})" def create_column_length(self): return "({length})" def limit_string(self, offset=False): if offset: return "" return "TOP {limit}" def first_where_string(self): return "WHERE" def additional_where_string(self): return "AND" def join_string(self): return "{keyword} {foreign_table}{alias} {on}" def aggregate_string(self): return "{aggregate_function}({column}) AS {alias}" def subquery_string(self): return "({query})" def subquery_alias_string(self): return "AS {alias}" def where_group_string(self): return "{keyword} {value}" def or_where_string(self): return "OR" def raw_query_string(self): return "{keyword} {query}" def where_in_string(self): return "WHERE IN ({values})" def value_equal_string(self): return "{keyword} {value1} = {value2}" def where_null_string(self): return " {keyword} {column} IS NULL" def between_string(self): return "{keyword} {column} BETWEEN {low} AND {high}" def not_between_string(self): return "{keyword} {column} NOT BETWEEN {low} AND {high}" def where_not_null_string(self): return " {keyword} {column} IS NOT NULL" def where_string(self): return " {keyword} {column} {equality} {value}" def offset_string(self): 
return "OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY" def increment_string(self): return "{column} = {column} + '{value}'{separator}" def decrement_string(self): return "{column} = {column} - '{value}'{separator}" def aggregate_string_with_alias(self): return "{aggregate_function}({column}) AS {alias}" def key_value_string(self): return "{column} = '{value}'{separator}" def column_value_string(self): return "{column} = {value}{separator}" def table_string(self): return "[{table}]" def order_by_format(self): return "{column} {direction}" def order_by_string(self): return "ORDER BY {order_columns}" def column_string(self): return "[{column}]{separator}" def table_column_string(self): return "[{table}].[{column}]{separator}" def table_update_column_string(self): return "[{table}].[{column}]{separator}" def table_insert_column_string(self): return "[{table}].[{column}]{separator}" def value_string(self): return "'{value}'{separator}" def wrap_table(self, table_name): return self.table_string().format(table=table_name) def truncate_table(self, table, foreign_keys=False): return f"TRUNCATE TABLE {self.wrap_table(table)}" def compile_random(self, seed): return "NEWID()" ================================================ FILE: src/masoniteorm/query/grammars/MySQLGrammar.py ================================================ from .BaseGrammar import BaseGrammar class MySQLGrammar(BaseGrammar): """MySQL grammar class.""" aggregate_options = { "SUM": "SUM", "MAX": "MAX", "MIN": "MIN", "AVG": "AVG", "COUNT": "COUNT", } join_keywords = { "inner": "INNER JOIN", "join": "INNER JOIN", "outer": "OUTER JOIN", "left": "LEFT JOIN", "right": "RIGHT JOIN", "left_inner": "LEFT INNER JOIN", "right_inner": "RIGHT INNER JOIN", } """Column strings are formats for how columns and key values should be formatted on specific queries. These can be different depending on the type of query. 
For example for Postgres, You can specify columns as "users"."name": SELECT "users"."name" from "users" But on updates we can only specify the column name and cannot have the table prefixed: UPDATE "users" SET "name" = "value" This dictionary allows you to modify the format depending on the type of query we are generating. For most databases these will be the same but this allows you to modify formats depending on the database. """ column_strings = { "select": "{table}.`{column}`{alias}{separator}", "select_all": "{table}.*{separator}", "insert": "{table}.`{column}`{separator}", "update": "{table}.`{column}`{separator}", "delete": "{table}.`{column}`{separator}", } locks = {"share": "LOCK IN SHARE MODE", "update": "FOR UPDATE"} def select_format(self): return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {having} {order_by} {limit} {offset} {lock}" def select_no_table(self): return "SELECT {columns} {lock}" def update_format(self): return "UPDATE {table} SET {key_equals} {wheres}" def insert_format(self): return "INSERT INTO {table} ({columns}) VALUES ({values})" def bulk_insert_format(self): return "INSERT INTO {table} ({columns}) VALUES {values}" def delete_format(self): return "DELETE FROM {table} {wheres}" def aggregate_string_with_alias(self): return "{aggregate_function}({column}) AS {alias}" def aggregate_string_without_alias(self): return "{aggregate_function}({column})" def subquery_string(self): return "({query})" def raw_query_string(self): return "{keyword} {query}" def where_group_string(self): return "{keyword} {value}" def between_string(self): return "{keyword} {column} BETWEEN {low} AND {high}" def not_between_string(self): return "{keyword} {column} NOT BETWEEN {low} AND {high}" def where_exists_string(self): return "{keyword} EXISTS {value}" def where_date_string(self): return "{keyword} DATE({column}) {equality} {value}" def where_not_exists_string(self): return "{keyword} NOT EXISTS {value}" def where_like_string(self): 
return "{keyword} {column} LIKE {value}" def where_not_like_string(self): return "{keyword} {column} NOT LIKE {value}" def get_true_column_string(self): return "{keyword} {column} = '1'" def get_false_column_string(self): return "{keyword} {column} = '0'" def process_table(self, table): """Compiles a given table name. Arguments: table {string} -- The table name to compile. Returns: self """ if not table: return "" if isinstance(table, str): return ".".join( self.table_string().format(table=t) for t in table.split(".") ) if table.raw: return table.name return ".".join( self.table_string().format(table=t) for t in table.name.split(".") ) def subquery_alias_string(self): return "AS {alias}" def key_value_string(self): return "{column} = '{value}'{separator}" def column_value_string(self): return "{column} = {value}{separator}" def increment_string(self): return "{column} = {column} + '{value}'{separator}" def decrement_string(self): return "{column} = {column} - '{value}'{separator}" def create_column_string(self): return "{column} {data_type}{length}{nullable}{default_value}, " def column_exists_string(self): return "SHOW COLUMNS FROM {table} LIKE {value}" def table_exists_string(self): return "SELECT * from information_schema.tables where table_name='{clean_table}' AND table_schema = '{database}'" def create_column_length(self, column_type): return "({length})" def table_string(self): return "`{table}`" def order_by_format(self): return "{column} {direction}" def order_by_string(self): return "ORDER BY {order_columns}" def column_string(self): return "`{column}`{separator}" def value_string(self): return "'{value}'{separator}" def join_string(self): return "{keyword} {foreign_table}{alias} {on}" def limit_string(self, offset=False): return "LIMIT {limit}" def offset_string(self): return "OFFSET {offset}" def first_where_string(self): return "WHERE" def additional_where_string(self): return "AND" def or_where_string(self): return "OR" def where_in_string(self): 
return "WHERE IN ({values})" def value_equal_string(self): return "{keyword} {value1} = {value2}" def where_string(self): return " {keyword} {column} {equality} {value}" def having_string(self): return "HAVING {column}" def having_equality_string(self): return "HAVING {column} {equality} {value}" def where_null_string(self): return " {keyword} {column} IS NULL" def where_not_null_string(self): return " {keyword} {column} IS NOT NULL" def enable_foreign_key_constraints(self): return "SET FOREIGN_KEY_CHECKS=1" def disable_foreign_key_constraints(self): return "SET FOREIGN_KEY_CHECKS=0" def truncate_table(self, table, foreign_keys=False): """Specifies a truncate table expression. Arguments; table {string} -- The name of the table to truncate. Returns: self """ if not foreign_keys: return f"TRUNCATE TABLE {self.wrap_table(table)}" return [ self.disable_foreign_key_constraints(), f"TRUNCATE TABLE {self.wrap_table(table)}", self.enable_foreign_key_constraints(), ] def compile_random(self): return "RAND()" ================================================ FILE: src/masoniteorm/query/grammars/PostgresGrammar.py ================================================ import re from .BaseGrammar import BaseGrammar class PostgresGrammar(BaseGrammar): """Postgres grammar class.""" aggregate_options = { "SUM": "SUM", "MAX": "MAX", "MIN": "MIN", "AVG": "AVG", "COUNT": "COUNT", } join_keywords = { "inner": "INNER JOIN", "join": "INNER JOIN", "outer": "OUTER JOIN", "left": "LEFT JOIN", "right": "RIGHT JOIN", "left_inner": "LEFT INNER JOIN", "right_inner": "RIGHT INNER JOIN", } column_strings = { "select": '{table}."{column}"{alias}{separator}', "select_all": "{table}.*{separator}", "insert": '"{column}"{separator}', "update": '"{column}"{separator}', "delete": '{table}."{column}"{separator}', } locks = {"share": "FOR SHARE", "update": "FOR UPDATE"} def select_no_table(self): return "SELECT {columns} {lock}" def select_format(self): return "SELECT {keyword} {columns} FROM {table} {joins} 
{wheres} {group_by} {having} {order_by} {limit} {offset} {lock}" def update_format(self): return "UPDATE {table} SET {key_equals} {wheres}" def insert_format(self): return "INSERT INTO {table} ({columns}) VALUES ({values}) RETURNING *" def bulk_insert_format(self): return "INSERT INTO {table} ({columns}) VALUES {values} RETURNING *" def delete_format(self): return "DELETE FROM {table} {wheres}" def aggregate_string_with_alias(self): return "{aggregate_function}({column}) AS {alias}" def aggregate_string_without_alias(self): return "{aggregate_function}({column})" def get_true_column_string(self): return "{keyword} {column} IS True" def get_false_column_string(self): return "{keyword} {column} IS False" def subquery_string(self): return "({query})" def raw_query_string(self): return "{keyword} {query}" def where_group_string(self): return "{keyword} {value}" def between_string(self): return "{keyword} {column} BETWEEN {low} AND {high}" def not_between_string(self): return "{keyword} {column} NOT BETWEEN {low} AND {high}" def where_exists_string(self): return "{keyword} EXISTS {value}" def where_not_exists_string(self): return "{keyword} NOT EXISTS {value}" def where_like_string(self): return "{keyword} {column} ILIKE {value}" def where_not_like_string(self): return "{keyword} {column} NOT ILIKE {value}" def subquery_alias_string(self): return "AS {alias}" def key_value_string(self): return "{column} = '{value}'{separator}" def column_value_string(self): return "{column} = {value}{separator}" def increment_string(self): return "{column} = {column} + '{value}'{separator}" def decrement_string(self): return "{column} = {column} - '{value}'{separator}" def create_column_string(self): return "{column} {data_type}{length}{nullable}, " def column_exists_string(self): return "SELECT column_name FROM information_schema.columns WHERE table_name='{clean_table}' and column_name={value}" def table_exists_string(self): return ( "SELECT * from information_schema.tables where 
table_name='{clean_table}'" ) def create_column_length(self, column_type): if column_type in self.types_without_lengths: return "" return "({length})" def to_sql(self): """Cleans up the SQL string and returns the SQL Returns: string """ if self.queries and (not self._columns and not self._creates): sql = "" for query in self.queries: query += "; " sql += re.sub(" +", " ", query) return sql.rstrip(" ") else: sql = re.sub(" +", " ", self._sql.strip().replace(",)", ")")) for query in self.queries: sql += "; " sql += re.sub(" +", " ", query.strip()) return sql def table_string(self): return '"{table}"' def order_by_format(self): return "{column} {direction}" def order_by_string(self): return "ORDER BY {order_columns}" def column_string(self): return '"{column}"{separator}' def value_string(self): return "'{value}'{separator}" def join_string(self): return "{keyword} {foreign_table}{alias} {on}" def limit_string(self, offset=False): return "LIMIT {limit}" def offset_string(self): return "OFFSET {offset}" def first_where_string(self): return "WHERE" def additional_where_string(self): return "AND" def or_where_string(self): return "OR" def where_in_string(self): return "WHERE IN ({values})" def where_date_string(self): return "{keyword} DATE({column}) {equality} {value}" def value_equal_string(self): return "{keyword} {value1} = {value2}" def where_string(self): return " {keyword} {column} {equality} {value}" def having_string(self): return "HAVING {column}" def having_equality_string(self): return "HAVING {column} {equality} {value}" def where_null_string(self): return " {keyword} {column} IS NULL" def where_not_null_string(self): return " {keyword} {column} IS NOT NULL" def truncate_table(self, table, foreign_keys=False): """Specifies a truncate table expression. Arguments; table {string} -- The name of the table to truncate. 
Returns: string """ return f"TRUNCATE TABLE {self.wrap_table(table)}" def compile_random(self): return "random()" ================================================ FILE: src/masoniteorm/query/grammars/SQLiteGrammar.py ================================================ import re from .BaseGrammar import BaseGrammar class SQLiteGrammar(BaseGrammar): """SQLite grammar class.""" aggregate_options = { "SUM": "SUM", "MAX": "MAX", "MIN": "MIN", "AVG": "AVG", "COUNT": "COUNT", } join_keywords = { "inner": "INNER JOIN", "join": "INNER JOIN", "outer": "OUTER JOIN", "left": "LEFT JOIN", "right": "LEFT JOIN", "left_inner": "LEFT INNER JOIN", "right_inner": "LEFT INNER JOIN", } column_strings = { "select": '{table}."{column}"{alias}{separator}', "select_all": "{table}.*{separator}", "insert": '"{column}"{separator}', "update": '"{column}"{separator}', "delete": '"{column}"{separator}', } locks = {"share": "", "update": ""} def select_format(self): return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {having} {order_by} {limit} {offset} {lock}" def select_no_table(self): return "SELECT {columns} {lock}" def update_format(self): return "UPDATE {table} SET {key_equals} {wheres}" def insert_format(self): return "INSERT INTO {table} ({columns}) VALUES ({values})" def bulk_insert_format(self): return "INSERT INTO {table} ({columns}) VALUES {values}" def delete_format(self): return "DELETE FROM {table} {wheres}" def aggregate_string_with_alias(self): return "{aggregate_function}({column}) AS {alias}" def aggregate_string_without_alias(self): return "{aggregate_function}({column})" def subquery_string(self): return "({query})" def default_string(self): return " DEFAULT {default} " def raw_query_string(self): return "{keyword} {query}" def where_group_string(self): return "{keyword} {value}" def between_string(self): return "{keyword} {column} BETWEEN {low} AND {high}" def not_between_string(self): return "{keyword} {column} NOT BETWEEN {low} AND {high}" def 
where_exists_string(self): return "{keyword} EXISTS {value}" def where_not_exists_string(self): return "{keyword} NOT EXISTS {value}" def where_like_string(self): return "{keyword} {column} LIKE {value}" def where_not_like_string(self): return "{keyword} {column} NOT LIKE {value}" def subquery_alias_string(self): return "AS {alias}" def key_value_string(self): return "{column} = '{value}'{separator}" def column_value_string(self): return "{column} = {value}{separator}" def increment_string(self): return "{column} = {column} + '{value}'{separator}" def decrement_string(self): return "{column} = {column} - '{value}'{separator}" def column_exists_string(self): return "SELECT column_name FROM information_schema.columns WHERE table_name='{clean_table}' and column_name={value}" def table_exists_string(self): return ( "SELECT name FROM sqlite_master WHERE type='table' AND name='{clean_table}'" ) def to_sql(self): """Cleans up the SQL string and returns the SQL Returns: string """ if self.queries and (not self._columns and not self._creates): sql = "" for query in self.queries: query += "; " sql += re.sub(" +", " ", query) return sql.rstrip(" ") else: sql = re.sub(" +", " ", self._sql.strip().replace(",)", ")")) for query in self.queries: sql += "; " sql += re.sub(" +", " ", query.strip()) return sql def table_string(self): return '"{table}"' def order_by_format(self): return "{column} {direction}" def order_by_string(self): return "ORDER BY {order_columns}" def column_string(self): return '"{column}"{separator}' def value_string(self): return "'{value}'{separator}" def join_string(self): return "{keyword} {foreign_table}{alias} {on}" def limit_string(self, offset=False): if offset: return "" return "LIMIT {limit}" def offset_string(self): return "LIMIT {limit} OFFSET {offset}" def first_where_string(self): return "WHERE" def additional_where_string(self): return "AND" def or_where_string(self): return "OR" def where_in_string(self): return "WHERE IN ({values})" def 
where_string(self): return " {keyword} {column} {equality} {value}" def having_string(self): return "HAVING {column}" def having_equality_string(self): return "HAVING {column} {equality} {value}" def where_null_string(self): return " {keyword} {column} IS NULL" def where_date_string(self): return "{keyword} DATE({column}) {equality} {value}" def value_equal_string(self): return "{keyword} {value1} = {value2}" def where_not_null_string(self): return " {keyword} {column} IS NOT NULL" def enable_foreign_key_constraints(self): return "PRAGMA foreign_keys = ON" def disable_foreign_key_constraints(self): return "PRAGMA foreign_keys = OFF" def get_true_column_string(self): return "{keyword} {column} = '1'" def get_false_column_string(self): return "{keyword} {column} = '0'" def truncate_table(self, table, foreign_keys=False): # SQLite do not have TRUNCATE TABLE command but we can # use SQLite DELETE command to delete complete data from an existing table if not foreign_keys: return f"DELETE FROM {self.wrap_table(table)}" return [ self.disable_foreign_key_constraints(), f"DELETE FROM {self.wrap_table(table)}", self.enable_foreign_key_constraints(), ] def compile_random(self): return "random()" def process_offset(self): """Compiles the offset expression. Returns: self """ if not self._limit: self._limit = int(-1) return super().process_offset() ================================================ FILE: src/masoniteorm/query/grammars/__init__.py ================================================ from .MSSQLGrammar import MSSQLGrammar from .MySQLGrammar import MySQLGrammar from .PostgresGrammar import PostgresGrammar from .SQLiteGrammar import SQLiteGrammar ================================================ FILE: src/masoniteorm/query/processors/MSSQLPostProcessor.py ================================================ class MSSQLPostProcessor: """Post processor classes are responsable for modifying the result after a query. 
Post Processors are called after the connection calls the database in the Query Builder but before the result is returned in that builder method. We can use this oppurtunity to get things like the inserted ID. For the Postgres Post Processor we have a RETURNING * string in the insert so the result will already have the full inserted record in the results. Therefore, we can just return the results """ def process_insert_get_id(self, builder, results, id_key): """Process the results from the query to the database. Args: builder (masoniteorm.builder.QueryBuilder): The query builder class results (dict): The result from an insert query or the creates from the query builder. This is usually a dictionary. id_key (string): The key to set the primary key to. This is usually the primary key of the table. Returns: dictionary: Should return the modified dictionary. """ last_id = builder.new_connection().query( "SELECT @@Identity as [id]", results=1 ) id = last_id["id"] if str(id).isdigit(): id = int(id) else: id = str(id) results.update({id_key: id}) return results def get_column_value(self, builder, column, results, id_key, id_value): """Gets the specific column value from a table. Typically done after an update to refetch the new value of a field. builder (masoniteorm.builder.QueryBuilder): The query builder class column (string): The column to refetch the value for. results (dict): The result from an update query from the query builder. This is usually a dictionary. id_key (string): The key to fetch the primary key for. This is usually the primary key of the table. 
id_value (string): The value of the primary key to fetch """ new_builder = builder.select(column) if id_key and id_value: new_builder.where(id_key, id_value) return new_builder.first()[column] return {} ================================================ FILE: src/masoniteorm/query/processors/MySQLPostProcessor.py ================================================ class MySQLPostProcessor: """Post processor classes are responsable for modifying the result after a query. Post Processors are called after the connection calls the database in the Query Builder but before the result is returned in that builder method. We can use this oppurtunity to get things like the inserted ID. For the SQLite Post Processor we have an attribute on the connection class we can use to fetch the ID. """ def process_insert_get_id(self, builder, results, id_key): """Process the results from the query to the database. Args: builder (masoniteorm.builder.QueryBuilder): The query builder class results (dict): The result from an insert query or the creates from the query builder. This is usually a dictionary. id_key (string): The key to set the primary key to. This is usually the primary key of the table. Returns: dictionary: Should return the modified dictionary. """ if id_key not in results: results.update({id_key: builder._connection.get_cursor().lastrowid}) return results def get_column_value(self, builder, column, results, id_key, id_value): """Gets the specific column value from a table. Typically done after an update to refetch the new value of a field. builder (masoniteorm.builder.QueryBuilder): The query builder class column (string): The column to refetch the value for. results (dict): The result from an update query from the query builder. This is usually a dictionary. id_key (string): The key to fetch the primary key for. This is usually the primary key of the table. 
id_value (string): The value of the primary key to fetch """ new_builder = builder.select(column) if id_key and id_value: new_builder.where(id_key, id_value) return new_builder.first()[column] return {} ================================================ FILE: src/masoniteorm/query/processors/PostgresPostProcessor.py ================================================ class PostgresPostProcessor: """Post processor classes are responsable for modifying the result after a query. Post Processors are called after the connection calls the database in the Query Builder but before the result is returned in that builder method. We can use this oppurtunity to get things like the inserted ID. For the Postgres Post Processor we have a RETURNING * string in the insert so the result will already have the full inserted record in the results. Therefore, we can just return the results """ def process_insert_get_id(self, builder, results, id_key): """Process the results from the query to the database. Args: builder (masoniteorm.builder.QueryBuilder): The query builder class results (dict): The result from an insert query or the creates from the query builder. This is usually a dictionary. id_key (string): The key to set the primary key to. This is usually the primary key of the table. Returns: dictionary: Should return the modified dictionary. """ return results def get_column_value(self, builder, column, results, id_key, id_value): """Gets the specific column value from a table. Typically done after an update to refetch the new value of a field. builder (masoniteorm.builder.QueryBuilder): The query builder class column (string): The column to refetch the value for. results (dict): The result from an update query from the query builder. This is usually a dictionary. id_key (string): The key to fetch the primary key for. This is usually the primary key of the table. 
id_value (string): The value of the primary key to fetch """ if column in results: return results[column] new_builder = builder.select(column) if id_key and id_value: new_builder.where(id_key, id_value) return new_builder.first()[column] return {} ================================================ FILE: src/masoniteorm/query/processors/SQLitePostProcessor.py ================================================ class SQLitePostProcessor: """Post processor classes are responsable for modifying the result after a query. Post Processors are called after the connection calls the database in the Query Builder but before the result is returned in that builder method. We can use this oppurtunity to get things like the inserted ID. For the SQLite Post Processor we have an attribute on the connection class we can use to fetch the ID. """ def process_insert_get_id(self, builder, results, id_key="id"): """Process the results from the query to the database. Args: builder (masoniteorm.builder.QueryBuilder): The query builder class results (dict): The result from an insert query or the creates from the query builder. This is usually a dictionary. id_key (string): The key to set the primary key to. This is usually the primary key of the table. Returns: dictionary: Should return the modified dictionary. """ if id_key not in results: results.update({id_key: builder.get_connection().get_cursor().lastrowid}) return results def get_column_value(self, builder, column, results, id_key, id_value): """Gets the specific column value from a table. Typically done after an update to refetch the new value of a field. builder (masoniteorm.builder.QueryBuilder): The query builder class column (string): The column to refetch the value for. results (dict): The result from an update query from the query builder. This is usually a dictionary. id_key (string): The key to fetch the primary key for. This is usually the primary key of the table. 
id_value (string): The value of the primary key to fetch """ new_builder = builder.select(column) if id_key and id_value: new_builder.where(id_key, id_value) return new_builder.first()[column] return {} ================================================ FILE: src/masoniteorm/query/processors/__init__.py ================================================ from .MSSQLPostProcessor import MSSQLPostProcessor from .MySQLPostProcessor import MySQLPostProcessor from .PostgresPostProcessor import PostgresPostProcessor from .SQLitePostProcessor import SQLitePostProcessor ================================================ FILE: src/masoniteorm/relationships/BaseRelationship.py ================================================ class BaseRelationship: def __init__(self, fn, local_key=None, foreign_key=None): if isinstance(fn, str): self.fn = None self.local_key = fn self.foreign_key = local_key else: self.fn = fn self.local_key = local_key self.foreign_key = foreign_key def __set_name__(self, cls, name): """This method is called right after the decorator is registered. At this point we finally have access to the model cls Arguments: name {object} -- The model class. """ pass def __call__(self, fn=None, *args, **kwargs): """This method is called when the decorator contains arguments. When you do something like this: @belongs_to('id', 'user_id'). In this case, the {fn} argument will be the callable. """ if callable(fn): self.fn = fn return self def get_builder(self): return self._related_builder def __get__(self, instance, owner): """ This method is called when the decorated method is accessed. Arguments: instance {object|None} -- The instance we called. If we didn't call the attribute and only accessed it then this will be None. owner {object} -- The current model that the property was accessed on. Returns: object -- Either returns a builder or a hydrated model. 
""" attribute = self.fn.__name__ relationship = self.fn(instance)() self.set_keys(instance, attribute) self._related_builder = relationship.builder if not instance.is_loaded(): return self if attribute in instance._relationships: return instance._relationships[attribute] return self.apply_query(self._related_builder, instance) def __getattr__(self, attribute): relationship = self.fn(self)() return getattr(relationship.builder, attribute) def apply_query(self, foreign, owner): """Return a dictionary to hydrate the model with Arguments: foreign {oject} -- The relationship object owner {object} -- The current model oject. Returns: dict -- A dictionary of data which will be hydrated. """ klass = self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'apply_query' method" ) def query_where_exists(self, builder, callback, method="where_exists"): """Adds a criteria clause to the query filter for existing related records""" klass = self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'query_where_exists' method" ) def joins(self, builder, clause=None): """Helper method for adding join clauses to a relationship""" other_table = self.get_builder().get_table_name() local_table = builder.get_table_name() return builder.join( other_table, f"{local_table}.{self.local_key}", "=", f"{other_table}.{self.foreign_key}", clause=clause, ) def get_with_count_query(self, builder, callback): """Adds a clause to the query to get the record count of the relationship""" klass = self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'get_with_count_query' method" ) def attach(self, current_model, related_record): """Link a related model to the current model""" klass = self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'attach' method" ) def get_related(self, query, relation, eagers=None, callback=None): klass = 
self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'get_related' method" ) def relate(self, related_record): klass = self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'relate' method" ) def detach(self, current_model, related_record): """Unlink a related model from the current model""" klass = self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'detach' method" ) def attach_related(self, current_model, related_record): """Unlink a related model from the current model""" klass = self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'attach_related' method" ) def detach_related(self, current_model, related_record): """Unlink a related model from the current model""" klass = self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'detach_related' method" ) def query_has(self, current_query_builder, method="where_exists"): """Adds a clause to the query to chek if a rwlarion exists""" klass = self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'query_has' method" ) def map_related(self, related_result): klass = self.__class__.__name__ raise NotImplementedError( f"{klass} relationship does not implement the 'related_result' method" ) ================================================ FILE: src/masoniteorm/relationships/BelongsTo.py ================================================ from ..collection import Collection from .BaseRelationship import BaseRelationship class BelongsTo(BaseRelationship): """Belongs To Relationship Class.""" def __init__(self, fn, local_key=None, foreign_key=None): if isinstance(fn, str): self.fn = None self.local_key = fn or "id" self.foreign_key = local_key else: self.fn = fn self.local_key = local_key or "id" self.foreign_key = foreign_key def set_keys(self, owner, attribute): 
self.local_key = self.local_key or f"{attribute}_id" self.foreign_key = self.foreign_key or "id" return self def apply_query(self, foreign, owner): """Apply the query and return a dictionary to be hydrated Arguments: foreign {oject} -- The relationship object owner {object} -- The current model oject. Returns: dict -- A dictionary of data which will be hydrated. """ return foreign.where( self.foreign_key, owner.__attributes__[self.local_key] ).first() def query_has(self, current_query_builder, method="where_exists"): related_builder = self.get_builder() getattr(current_query_builder, method)( related_builder.where_column( f"{related_builder.get_table_name()}.{self.foreign_key}", f"{current_query_builder.get_table_name()}.{self.local_key}", ) ) return related_builder def query_where_exists(self, builder, callback, method="where_exists"): query = self.get_builder() getattr(builder, method)( callback( query.where_column( f"{query.get_table_name()}.{self.foreign_key}", f"{builder.get_table_name()}.{self.local_key}", ) ) ) return query def get_related(self, query, relation, eagers=(), callback=None): """Gets the relation needed between the relation and the related builder. If the relation is a collection then will need to pluck out all the keys from the collection and fetch from the related builder. If relation is just a Model then we can just call the model based on the value of the related builders primary key. 
Args: relation (Model|Collection): Returns: Model|Collection """ builder = self.get_builder().with_(eagers) if callback: callback(builder) if isinstance(relation, Collection): return builder.where_in( f"{builder.get_table_name()}.{self.foreign_key}", Collection(relation._get_value(self.local_key)).unique(), ).get() else: return builder.where( f"{builder.get_table_name()}.{self.foreign_key}", getattr(relation, self.local_key), ).first() def register_related(self, key, model, collection): related = collection.get(getattr(model, self.local_key), None) model.add_relation({key: related[0] if related else None}) def map_related(self, related_result): return related_result.group_by(self.foreign_key) def attach(self, current_model, related_record): foreign_key_value = getattr(related_record, self.foreign_key) if not current_model.is_created(): current_model.fill({self.local_key: foreign_key_value}) return current_model.create(current_model.all_attributes(), cast=True) return current_model.update({self.local_key: foreign_key_value}) def detach(self, current_model, related_record): return current_model.update({self.local_key: None}) def relate(self, related_record): return ( self.get_builder() .where(self.foreign_key, related_record.__attributes__[self.local_key]) ._set_creates_related( {self.foreign_key: related_record.__attributes__[self.local_key]} ) ) ================================================ FILE: src/masoniteorm/relationships/BelongsToMany.py ================================================ import pendulum from inflection import singularize from ..collection import Collection from ..models.Pivot import Pivot from .BaseRelationship import BaseRelationship class BelongsToMany(BaseRelationship): """Has Many Relationship Class.""" def __init__( self, fn=None, local_foreign_key=None, other_foreign_key=None, local_owner_key=None, other_owner_key=None, table=None, with_timestamps=False, pivot_id="id", attribute="pivot", with_fields=[], ): if isinstance(fn, str): 
self.fn = None self.local_key = fn self.foreign_key = local_foreign_key self.local_owner_key = other_foreign_key or "id" self.other_owner_key = local_owner_key or "id" else: self.fn = fn self.local_key = local_foreign_key self.foreign_key = other_foreign_key self.local_owner_key = local_owner_key or "id" self.other_owner_key = other_owner_key or "id" self._table = table self.with_timestamps = with_timestamps self._as = attribute self.pivot_id = pivot_id self.with_fields = with_fields def set_keys(self, owner, attribute): self.local_key = self.local_key or "id" self.foreign_key = self.foreign_key or f"{attribute}_id" return self def apply_query(self, query, owner): """Apply the query and return a dictionary to be hydrated. Used during accessing a relationship on a model Arguments: query {oject} -- The relationship object owner {object} -- The current model oject. Returns: dict -- A dictionary of data which will be hydrated. """ if not self._table: pivot_tables = [ singularize(owner.builder.get_table_name()), singularize(query.get_table_name()), ] pivot_tables.sort() pivot_table_1, pivot_table_2 = pivot_tables self._table = "_".join(pivot_tables) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" elif self.local_key is None or self.foreign_key is None: pivot_table_1, pivot_table_2 = self._table.split("_", 1) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" table1 = owner.get_table_name() table2 = query.get_table_name() result = query.select( f"{query.get_table_name()}.*", f"{self._table}.{self.local_key} as {self._table}_id", f"{self._table}.{self.foreign_key} as m_reserved2", ).table(f"{table1}") if self.pivot_id: result.select(f"{self._table}.{self.pivot_id} as m_reserved3") if self.with_timestamps: result.select( f"{self._table}.updated_at as m_reserved4", f"{self._table}.created_at as m_reserved5", ) result.join( 
f"{self._table}", f"{self._table}.{self.local_key}", "=", f"{table1}.{self.local_owner_key}", ) result.join( f"{table2}", f"{self._table}.{self.foreign_key}", "=", f"{table2}.{self.other_owner_key}", ) if hasattr(owner, self.local_owner_key): result.where( f"{table1}.{self.local_owner_key}", getattr(owner, self.local_owner_key) ) if self.with_fields: for field in self.with_fields: result.select(f"{self._table}.{field}") result = result.get() for model in result: pivot_data = { self.local_key: getattr(model, f"{self._table}_id"), self.foreign_key: getattr(model, "m_reserved2"), } if self.with_timestamps: pivot_data = { "created_at": getattr(model, "m_reserved5"), "updated_at": getattr(model, "m_reserved4"), } model.delete_attribute("m_reserved4") model.delete_attribute("m_reserved5") model.delete_attribute("m_reserved2") if self.pivot_id: pivot_data.update({self.pivot_id: getattr(model, "m_reserved3")}) model.delete_attribute("m_reserved3") if self.with_fields: for field in self.with_fields: pivot_data.update({field: getattr(model, field)}) model.delete_attribute(field) model.__original_attributes__.update( { self._as: ( Pivot.on(query.connection) .table(self._table) .hydrate(pivot_data) .activate_timestamps(self.with_timestamps) ) } ) return result def table(self, table): self._table = table return self def make_builder(self, eagers=None): builder = self.get_builder().with_(eagers) return builder def make_query(self, query, relation, eagers=None, callback=None): """Used during eager loading a relationship Args: query ([type]): [description] relation ([type]): [description] eagers (list, optional): List of eager loaded relationships. Defaults to None. 
Returns: [type]: [description] """ eagers = eagers or [] builder = self.get_builder().with_(eagers) if not self._table: pivot_tables = [ singularize(builder.get_table_name()), singularize(query.get_table_name()), ] pivot_tables.sort() pivot_table_1, pivot_table_2 = pivot_tables self._table = "_".join(pivot_tables) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" elif self.local_key is None or self.foreign_key is None: pivot_table_1, pivot_table_2 = self._table.split("_", 1) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" table2 = builder.get_table_name() table1 = query.get_table_name() result = ( builder.select( f"{table2}.*", f"{self._table}.{self.local_key} as {self._table}_id", f"{self._table}.{self.foreign_key} as m_reserved2", ) .run_scopes() .table(f"{table1}") ) if self.with_fields: for field in self.with_fields: result.select(f"{self._table}.{field}") result.join( f"{self._table}", f"{self._table}.{self.local_key}", "=", f"{table1}.{self.local_owner_key}", ) result.join( f"{table2}", f"{self._table}.{self.foreign_key}", "=", f"{table2}.{self.other_owner_key}", ) if self.with_timestamps: result.select( f"{self._table}.updated_at as m_reserved4", f"{self._table}.created_at as m_reserved5", ) if self.pivot_id: result.select(f"{self._table}.{self.pivot_id} as m_reserved3") result.without_global_scopes() if callback: callback(result) if isinstance(relation, Collection): return result.where_in( self.local_owner_key, Collection(relation._get_value(self.local_owner_key)).unique(), ).get() else: return result.where( self.local_owner_key, getattr(relation, self.local_owner_key) ).get() def get_related(self, query, relation, eagers=None, callback=None): final_result = self.make_query( query, relation, eagers=eagers, callback=callback ) builder = self.make_builder(eagers) for model in final_result: pivot_data = { self.local_key: 
getattr(model, f"{self._table}_id"), self.foreign_key: getattr(model, "m_reserved2"), } model.delete_attribute("m_reserved2") if self.with_timestamps: pivot_data.update( { "updated_at": getattr(model, "m_reserved4"), "created_at": getattr(model, "m_reserved5"), } ) if self.pivot_id: pivot_data.update({self.pivot_id: getattr(model, "m_reserved3")}) model.delete_attribute("m_reserved3") if self.with_fields: for field in self.with_fields: pivot_data.update({field: getattr(model, field)}) model.delete_attribute(field) model.__original_attributes__.update( { self._as: ( Pivot.on(builder.connection) .table(self._table) .hydrate(pivot_data) .activate_timestamps(self.with_timestamps) ) } ) return final_result def relate(self, related_record): owner = related_record.get_builder() query = self.get_builder() if not self._table: pivot_tables = [ singularize(owner.builder.get_table_name()), singularize(query.get_table_name()), ] pivot_tables.sort() pivot_table_1, pivot_table_2 = pivot_tables self._table = "_".join(pivot_tables) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" elif self.local_key is None or self.foreign_key is None: pivot_table_1, pivot_table_2 = self._table.split("_", 1) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" table1 = owner.get_table_name() table2 = query.get_table_name() result = query.select( f"{query.get_table_name()}.*", f"{self._table}.{self.local_key} as {self._table}_id", f"{self._table}.{self.foreign_key} as m_reserved2", ).table(f"{table1}") if self.pivot_id: result.select(f"{self._table}.{self.pivot_id} as m_reserved3") if self.with_timestamps: result.select( f"{self._table}.updated_at as m_reserved4", f"{self._table}.created_at as m_reserved5", ) result.join( f"{self._table}", f"{self._table}.{self.local_key}", "=", f"{table1}.{self.local_owner_key}", ) result.join( f"{table2}", 
f"{self._table}.{self.foreign_key}", "=", f"{table2}.{self.other_owner_key}", ) if hasattr(owner, self.local_owner_key): result.where( f"{table1}.{self.local_owner_key}", getattr(owner, self.local_owner_key) ) if self.with_fields: for field in self.with_fields: result.select(f"{self._table}.{field}") return result def register_related(self, key, model, collection): model.add_relation( { key: collection.where( f"{self._table}_id", getattr(model, self.local_owner_key) ) } ) def joins(self, builder, clause=None): if not self._table: pivot_tables = [ singularize(self.get_builder().get_table_name()), singularize(builder.get_table_name()), ] pivot_tables.sort() pivot_table_1, pivot_table_2 = pivot_tables self._table = "_".join(pivot_tables) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" elif self.local_key is None or self.foreign_key is None: pivot_table_1, pivot_table_2 = self._table.split("_", 1) self.foreign_key = self.foreign_key or f"{pivot_table_1}_id" self.local_key = self.local_key or f"{pivot_table_2}_id" query = self.get_builder() table1 = query.get_table_name() table2 = builder.get_table_name() result = builder if not builder._columns: result = result.select( f"{table2}.*", f"{self._table}.{self.local_key} as {self._table}_id", f"{self._table}.{self.foreign_key} as m_reserved2", ) if self.pivot_id: result.select(f"{self._table}.{self.pivot_id} as m_reserved3") if self.with_timestamps: result.select( f"{self._table}.updated_at as m_reserved4", f"{self._table}.created_at as m_reserved5", ) if self.with_fields: for field in self.with_fields: result.select(f"{self._table}.{field}") # Join pivot table with an inner join result.join( f"{self._table}", f"{self._table}.{self.local_key}", "=", f"{table2}.{self.local_owner_key}", clause="inner", ) result.join( f"{table1}", f"{self._table}.{self.local_owner_key}", "=", f"{table1}.{self.other_owner_key}", clause=clause, ) if self.with_fields: for field 
in self.with_fields: result.select(f"{self._table}.{field}") return result def query_where_exists(self, builder, callback, method="where_exists"): query = self.get_builder() pivot_table = self._table or self.get_pivot_table_name(query, builder) table = self.get_builder().get_table_name() getattr(builder, method)( query.new() .table(table) .join( f"{pivot_table}", f"{table}.{self.other_owner_key}", "=", f"{pivot_table}.{self.foreign_key}", ) .where_column( f"{pivot_table}.{self.local_key}", f"{builder.get_table_name()}.{self.local_owner_key}", ) .where_in( self.other_owner_key, callback(query.select(self.other_owner_key)) ) ) def query_has(self, builder, method="where_exists"): query = self.get_builder() pivot_table = self._table or self.get_pivot_table_name(query, builder) table = self.get_builder().get_table_name() return getattr(builder, method)( query.new() .table(table) .join( f"{pivot_table}", f"{table}.{self.other_owner_key}", "=", f"{pivot_table}.{self.foreign_key}", ) .where_column( f"{pivot_table}.{self.local_key}", f"{builder.get_table_name()}.{self.local_owner_key}", ) ) def get_pivot_table_name(self, query, builder): pivot_tables = [ singularize(query.get_table_name()), singularize(builder.get_table_name()), ] pivot_tables.sort() return "_".join(pivot_tables) def get_with_count_query(self, builder, callback): query = self.get_builder() self._table = self._table or self.get_pivot_table_name(query, builder) if not builder._columns: builder = builder.select("*") return_query = builder.add_select( f"{query.get_table_name()}_count", lambda q: ( ( q.count("*") .where_column( f"{builder.get_table_name()}.{self.local_owner_key}", f"{self._table}.{self.local_key}", ) .table(self._table) .when( callback, lambda q: ( q.where_in( self.foreign_key, callback(query.select(self.other_owner_key)), ) ), ) ) ), ) return return_query def attach(self, current_model, related_record): data = { self.local_key: getattr(current_model, self.local_owner_key), self.foreign_key: 
getattr(related_record, self.other_owner_key), } self._table = self._table or self.get_pivot_table_name( current_model, related_record ) if self.with_timestamps: data.update( { "created_at": pendulum.now().to_datetime_string(), "updated_at": pendulum.now().to_datetime_string(), } ) return ( Pivot.on(current_model.get_builder().connection) .table(self._table) .without_global_scopes() .create(data) ) def detach(self, current_model, related_record): data = { self.local_key: getattr(current_model, self.local_owner_key), self.foreign_key: getattr(related_record, self.other_owner_key), } self._table = self._table or self.get_pivot_table_name( current_model, related_record ) return ( Pivot.on(current_model.get_builder().connection) .table(self._table) .without_global_scopes() .where(data) .delete() ) def attach_related(self, current_model, related_record): data = { self.local_key: getattr(current_model, self.local_owner_key), self.foreign_key: getattr(related_record, self.other_owner_key), } self._table = self._table or self.get_pivot_table_name( current_model, related_record ) if self.with_timestamps: data.update( { "created_at": pendulum.now().to_datetime_string(), "updated_at": pendulum.now().to_datetime_string(), } ) return ( Pivot.table(self._table) .on(current_model.get_builder().connection) .without_global_scopes() .create(data) ) def detach_related(self, current_model, related_record): data = { self.local_key: getattr(current_model, self.local_owner_key), self.foreign_key: getattr(related_record, self.other_owner_key), } self._table = self._table or self.get_pivot_table_name( current_model, related_record ) if self.with_timestamps: data.update( { "created_at": pendulum.now().to_datetime_string(), "updated_at": pendulum.now().to_datetime_string(), } ) return ( Pivot.on(current_model.get_builder().connection) .table(self._table) .without_global_scopes() .where(data) .delete() ) ================================================ FILE: 
src/masoniteorm/relationships/HasMany.py ================================================ from ..collection import Collection from .BaseRelationship import BaseRelationship class HasMany(BaseRelationship): """Has Many Relationship Class.""" def apply_query(self, foreign, owner): """Apply the query and return a dictionary to be hydrated Arguments: foreign {oject} -- The relationship object owner {object} -- The current model oject. Returns: dict -- A dictionary of data which will be hydrated. """ result = foreign.where( self.foreign_key, owner.__attributes__[self.local_key] ).get() return result def set_keys(self, owner, attribute): self.local_key = self.local_key or "id" self.foreign_key = self.foreign_key or f"{attribute}_id" return self def register_related(self, key, model, collection): model.add_relation( {key: collection.get(getattr(model, self.local_key)) or Collection()} ) def map_related(self, related_result): return related_result.group_by(self.foreign_key) def attach(self, current_model, related_record): local_key_value = getattr(current_model, self.local_key) if not related_record.is_created(): related_record.fill({self.foreign_key: local_key_value}) return related_record.create(related_record.all_attributes(), cast=True) return related_record.update({self.foreign_key: local_key_value}) def get_related(self, query, relation, eagers=None, callback=None): eagers = eagers or [] builder = self.get_builder().with_(eagers) if callback: callback(builder) if isinstance(relation, Collection): return builder.where_in( f"{builder.get_table_name()}.{self.foreign_key}", Collection(relation._get_value(self.local_key)).unique(), ).get() return builder.where( f"{builder.get_table_name()}.{self.foreign_key}", getattr(relation, self.local_key), ).get() ================================================ FILE: src/masoniteorm/relationships/HasManyThrough.py ================================================ from ..collection import Collection from .BaseRelationship import 
BaseRelationship class HasManyThrough(BaseRelationship): """HasManyThrough Relationship Class.""" def __init__( self, fn=None, local_foreign_key=None, other_foreign_key=None, local_owner_key=None, other_owner_key=None, ): if isinstance(fn, str): self.fn = None self.local_key = fn self.foreign_key = local_foreign_key self.local_owner_key = other_foreign_key or "id" self.other_owner_key = local_owner_key or "id" else: self.fn = fn self.local_key = local_foreign_key self.foreign_key = other_foreign_key self.local_owner_key = local_owner_key or "id" self.other_owner_key = other_owner_key or "id" def set_keys(self, distant_builder, intermediary_builder, attribute): self.local_key = self.local_key or "id" self.foreign_key = self.foreign_key or f"{attribute}_id" self.local_owner_key = self.local_owner_key or "id" self.other_owner_key = self.other_owner_key or "id" return self def __get__(self, instance, owner): """This method is called when the decorated method is accessed. Arguments: instance {object|None} -- The instance we called. If we didn't call the attribute and only accessed it then this will be None. owner {object} -- The current model that the property was accessed on. Returns: object -- Either returns a builder or a hydrated model. """ attribute = self.fn.__name__ self.attribute = attribute relationship1 = self.fn(self)[0]() relationship2 = self.fn(self)[1]() self.distant_builder = relationship1.builder self.intermediary_builder = relationship2.builder self.set_keys(self.distant_builder, self.intermediary_builder, attribute) if not instance.is_loaded(): return self if attribute in instance._relationships: return instance._relationships[attribute] return self.apply_related_query( self.distant_builder, self.intermediary_builder, instance ) def apply_related_query(self, distant_builder, intermediary_builder, owner): """ Apply the query to return a Collection of data for the distant models to be hydrated with. 
Method is used when accessing a relationship on a model if its not already eager loaded Arguments distant_builder (QueryBuilder): QueryBuilder attached to the distant table intermediate_builder (QueryBuilder): QueryBuilder attached to the intermediate (linking) table owner (Any): the model this relationship is starting from Returns Collection: Collection of dicts which will be used for hydrating models. """ distant_table = distant_builder.get_table_name() intermediate_table = intermediary_builder.get_table_name() return ( self.distant_builder.select( f"{distant_table}.*, {intermediate_table}.{self.local_key}" ) .join( f"{intermediate_table}", f"{intermediate_table}.{self.foreign_key}", "=", f"{distant_table}.{self.other_owner_key}", ) .where( f"{intermediate_table}.{self.local_key}", getattr(owner, self.local_owner_key), ) .get() ) def relate(self, related_model): return self.distant_builder.join( f"{self.intermediary_builder.get_table_name()}", f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}", "=", f"{self.distant_builder.get_table_name()}.{self.other_owner_key}", ).where( f"{self.intermediary_builder.get_table_name()}.{self.local_key}", getattr(related_model, self.local_owner_key), ) def get_builder(self): return self.distant_builder def make_builder(self, eagers=None): builder = self.get_builder().with_(eagers) return builder def register_related(self, key, model, collection): """ Attach the related model to source models attribute Arguments key (str): The attribute name model (Any): The model instance collection (Collection): The data for the related models Returns None """ related = collection.get(getattr(model, self.local_owner_key), None) if related and not isinstance(related, Collection): related = Collection(related) model.add_relation({key: related if related else None}) def get_related(self, current_builder, relation, eagers=None, callback=None): """ Get a Collection to hydrate the models for the distant table with Used when eager 
loading the model attribute Arguments current_builder (QueryBuilder): The source models QueryBuilder object relation (HasManyThrough): this relationship object eagers (Any): callback (Any): Returns Collection the collection of dicts to hydrate the distant models with """ distant_table = self.distant_builder.get_table_name() intermediate_table = self.intermediary_builder.get_table_name() if callback: callback(current_builder) ( self.distant_builder.select( f"{distant_table}.*, {intermediate_table}.{self.local_key}" ).join( f"{intermediate_table}", f"{intermediate_table}.{self.foreign_key}", "=", f"{distant_table}.{self.other_owner_key}", ) ) if isinstance(relation, Collection): return self.distant_builder.where_in( f"{intermediate_table}.{self.local_key}", Collection(relation._get_value(self.local_owner_key)).unique(), ).get() else: return self.distant_builder.where( f"{intermediate_table}.{self.local_key}", getattr(relation, self.local_owner_key), ).get() def query_has(self, current_builder, method="where_exists"): distant_table = self.distant_builder.get_table_name() intermediate_table = self.intermediary_builder.get_table_name() getattr(current_builder, method)( self.distant_builder.join( f"{intermediate_table}", f"{intermediate_table}.{self.foreign_key}", "=", f"{distant_table}.{self.other_owner_key}", ).where_column( f"{intermediate_table}.{self.local_key}", f"{current_builder.get_table_name()}.{self.local_owner_key}", ) ) return self.distant_builder def query_where_exists(self, current_builder, callback, method="where_exists"): distant_table = self.distant_builder.get_table_name() intermediate_table = self.intermediary_builder.get_table_name() getattr(current_builder, method)( self.distant_builder.join( f"{intermediate_table}", f"{intermediate_table}.{self.foreign_key}", "=", f"{distant_table}.{self.other_owner_key}", ) .where_column( f"{intermediate_table}.{self.local_key}", f"{current_builder.get_table_name()}.{self.local_owner_key}", ) .when(callback, 
lambda q: (callback(q))) ) def get_with_count_query(self, current_builder, callback): distant_table = self.distant_builder.get_table_name() intermediate_table = self.intermediary_builder.get_table_name() if not current_builder._columns: current_builder.select("*") return_query = current_builder.add_select( f"{self.attribute}_count", lambda q: ( ( q.count("*") .join( f"{intermediate_table}", f"{intermediate_table}.{self.foreign_key}", "=", f"{distant_table}.{self.other_owner_key}", ) .where_column( f"{intermediate_table}.{self.local_key}", f"{current_builder.get_table_name()}.{self.local_owner_key}", ) .table(distant_table) .when( callback, lambda q: ( q.where_in( self.foreign_key, callback( self.distant_builder.select(self.other_owner_key) ), ) ), ) ) ), ) return return_query def map_related(self, related_result): return related_result.group_by(self.local_key) ================================================ FILE: src/masoniteorm/relationships/HasOne.py ================================================ from ..collection import Collection from .BaseRelationship import BaseRelationship class HasOne(BaseRelationship): """Belongs To Relationship Class.""" def __init__(self, fn, foreign_key=None, local_key=None): if isinstance(fn, str): self.foreign_key = fn self.local_key = foreign_key or "id" else: self.fn = fn self.local_key = local_key or "id" self.foreign_key = foreign_key def set_keys(self, owner, attribute): self.local_key = self.local_key or "id" self.foreign_key = self.foreign_key or f"{attribute}_id" return self def apply_query(self, foreign, owner): """Apply the query and return a dictionary to be hydrated Arguments: foreign {oject} -- The relationship object owner {object} -- The current model oject. Returns: dict -- A dictionary of data which will be hydrated. 
""" return foreign.where( self.foreign_key, owner.__attributes__[self.local_key] ).first() def get_related(self, query, relation, eagers=(), callback=None): """Gets the relation needed between the relation and the related builder. If the relation is a collection then will need to pluck out all the keys from the collection and fetch from the related builder. If relation is just a Model then we can just call the model based on the value of the related builders primary key. Args: relation (Model|Collection): Returns: Model|Collection """ builder = self.get_builder().with_(eagers) if callback: callback(builder) if isinstance(relation, Collection): return builder.where_in( f"{builder.get_table_name()}.{self.foreign_key}", Collection(relation._get_value(self.local_key)).unique(), ).get() else: return builder.where( f"{builder.get_table_name()}.{self.foreign_key}", getattr(relation, self.local_key), ).first() def query_has(self, current_query_builder, method="where_exists"): related_builder = self.get_builder() getattr(current_query_builder, method)( related_builder.where_column( f"{related_builder.get_table_name()}.{self.foreign_key}", f"{current_query_builder.get_table_name()}.{self.local_key}", ) ) return related_builder def query_where_exists(self, builder, callback, method="where_exists"): query = self.get_builder() getattr(builder, method)( callback( query.where_column( f"{query.get_table_name()}.{self.foreign_key}", f"{builder.get_table_name()}.{self.local_key}", ) ) ) return query def register_related(self, key, model, collection): related = collection.where( self.foreign_key, getattr(model, self.local_key) ).first() model.add_relation({key: related or None}) def map_related(self, related_result): return related_result def attach(self, current_model, related_record): local_key_value = getattr(current_model, self.local_key) if not related_record.is_created(): related_record.fill({self.foreign_key: local_key_value}) return 
related_record.create(related_record.all_attributes(), cast=True) return related_record.update({self.foreign_key: local_key_value}) def detach(self, current_model, related_record): return related_record.update({self.foreign_key: None}) ================================================ FILE: src/masoniteorm/relationships/HasOneThrough.py ================================================ from ..collection import Collection from .BaseRelationship import BaseRelationship class HasOneThrough(BaseRelationship): """HasOneThrough Relationship Class.""" def __init__( self, fn=None, local_foreign_key=None, other_foreign_key=None, local_owner_key=None, other_owner_key=None, ): if isinstance(fn, str): self.fn = None self.local_key = fn self.foreign_key = local_foreign_key self.local_owner_key = other_foreign_key or "id" self.other_owner_key = local_owner_key or "id" else: self.fn = fn self.local_key = local_foreign_key self.foreign_key = other_foreign_key self.local_owner_key = local_owner_key or "id" self.other_owner_key = other_owner_key or "id" def __getattr__(self, attribute): relationship = self.fn(self)[1]() return getattr(relationship.builder, attribute) def set_keys(self, distant_builder, intermediary_builder, attribute): self.local_key = self.local_key or "id" self.foreign_key = self.foreign_key or f"{attribute}_id" self.local_owner_key = self.local_owner_key or "id" self.other_owner_key = self.other_owner_key or "id" return self def __get__(self, instance, owner): """ This method is called when the decorated method is accessed. Arguments instance (object|None): The instance we called. If we didn't call the attribute and only accessed it then this will be None. owner (object): The current model that the property was accessed on. Returns QueryBuilder|Model: Either returns a builder or a hydrated model. 
""" attribute = self.fn.__name__ self.attribute = attribute relationship1 = self.fn(self)[0]() relationship2 = self.fn(self)[1]() self.distant_builder = relationship1.builder self.intermediary_builder = relationship2.builder self.set_keys(self.distant_builder, self.intermediary_builder, attribute) if instance.is_loaded(): if attribute in instance._relationships: return instance._relationships[attribute] return self.apply_relation_query( self.distant_builder, self.intermediary_builder, instance ) else: return self def apply_relation_query(self, distant_builder, intermediary_builder, owner): """ Apply the query and return a dict of data for the distant model to be hydrated with. Method is used when accessing a relationship on a model if its not already eager loaded Arguments distant_builder (QueryBuilder): QueryBuilder attached to the distant table intermediate_builder (QueryBuilder): QueryBuilder attached to the intermediate (linking) table owner (Any): the model this relationship is starting from Returns dict: A dictionary of data which will be hydrated. 
""" dist_table = distant_builder.get_table_name() int_table = intermediary_builder.get_table_name() return ( distant_builder.select( f"{dist_table}.*, {int_table}.{self.local_owner_key} as {self.local_key}" ) .join( f"{int_table}", f"{int_table}.{self.foreign_key}", "=", f"{dist_table}.{self.other_owner_key}", ) .where( f"{int_table}.{self.local_owner_key}", getattr(owner, self.local_key), ) .first() ) def relate(self, related_model): dist_table = self.distant_builder.get_table_name() int_table = self.intermediary_builder.get_table_name() return self.distant_builder.join( f"{int_table}", f"{int_table}.{self.foreign_key}", "=", f"{dist_table}.{self.other_owner_key}", ).where_column( f"{int_table}.{self.local_owner_key}", getattr(related_model, self.local_key), ) def get_builder(self): return self.distant_builder def make_builder(self, eagers=None): builder = self.get_builder().with_(eagers) return builder def register_related(self, key, model, collection): """ Attach the related model to source models attribute Arguments key (str): The attribute name model (Any): The model instance collection (Collection): The data for the related models Returns None """ related = collection.get(getattr(model, self.local_key), None) model.add_relation({key: related[0] if related else None}) def get_related(self, current_builder, relation, eagers=None, callback=None): """ Get the data to hydrate the model for the distant table with Used when eager loading the model attribute Arguments query (QueryBuilder): The source models QueryBuilder object relation (HasOneThrough): this relationship object eagers (Any): callback (Any): Returns dict: the dict to hydrate the distant model with """ dist_table = self.distant_builder.get_table_name() int_table = self.intermediary_builder.get_table_name() if callback: callback(current_builder) ( self.distant_builder.select( f"{dist_table}.*, {int_table}.{self.local_owner_key} as {self.local_key}" ).join( f"{int_table}", 
f"{int_table}.{self.foreign_key}", "=", f"{dist_table}.{self.other_owner_key}", ) ) if isinstance(relation, Collection): return self.distant_builder.where_in( f"{int_table}.{self.local_owner_key}", Collection(relation._get_value(self.local_key)).unique(), ).get() else: return self.distant_builder.where( f"{int_table}.{self.local_owner_key}", getattr(relation, self.local_key), ).first() def query_has(self, current_builder, method="where_exists"): dist_table = self.distant_builder.get_table_name() int_table = self.intermediary_builder.get_table_name() getattr(current_builder, method)( self.distant_builder.join( f"{int_table}", f"{int_table}.{self.foreign_key}", "=", f"{dist_table}.{self.other_owner_key}", ).where_column( f"{int_table}.{self.local_owner_key}", f"{current_builder.get_table_name()}.{self.local_key}", ) ) return self.distant_builder def query_where_exists(self, current_builder, callback, method="where_exists"): dist_table = self.distant_builder.get_table_name() int_table = self.intermediary_builder.get_table_name() getattr(current_builder, method)( self.distant_builder.join( f"{int_table}", f"{int_table}.{self.foreign_key}", "=", f"{dist_table}.{self.other_owner_key}", ) .where_column( f"{int_table}.{self.local_owner_key}", f"{current_builder.get_table_name()}.{self.local_key}", ) .when(callback, lambda q: (callback(q))) ) def get_with_count_query(self, current_builder, callback): dist_table = self.distant_builder.get_table_name() int_table = self.intermediary_builder.get_table_name() if not current_builder._columns: current_builder.select("*") return_query = current_builder.add_select( f"{self.attribute}_count", lambda q: ( ( q.count("*") .join( f"{int_table}", f"{int_table}.{self.foreign_key}", "=", f"{dist_table}.{self.other_owner_key}", ) .where_column( f"{int_table}.{self.local_owner_key}", f"{current_builder.get_table_name()}.{self.local_key}", ) .table(dist_table) .when( callback, lambda q: ( q.where_in( self.foreign_key, callback( 
self.distant_builder.select(self.other_owner_key) ), ) ), ) ) ), ) return return_query def map_related(self, related_result): return related_result.group_by(self.local_key) ================================================ FILE: src/masoniteorm/relationships/MorphMany.py ================================================ from ..collection import Collection from ..config import load_config from .BaseRelationship import BaseRelationship class MorphMany(BaseRelationship): def __init__(self, fn, morph_key="record_type", morph_id="record_id"): if isinstance(fn, str): self.fn = None self.morph_key = fn self.morph_id = morph_key else: self.fn = fn self.morph_id = morph_id self.morph_key = morph_key def get_builder(self): return self._related_builder def set_keys(self, owner, attribute): self.morph_id = self.morph_id or "record_id" self.morph_key = self.morph_key or "record_type" return self def __get__(self, instance, owner): """This method is called when the decorated method is accessed. Arguments: instance {object|None} -- The instance we called. If we didn't call the attribute and only accessed it then this will be None. owner {object} -- The current model that the property was accessed on. Returns: object -- Either returns a builder or a hydrated model. """ attribute = self.fn.__name__ self._related_builder = instance.builder self.polymorphic_builder = self.fn(self)() self.set_keys(owner, self.fn) if not instance.is_loaded(): return self if attribute in instance._relationships: return instance._relationships[attribute] return self.apply_query(self._related_builder, instance) def __getattr__(self, attribute): relationship = self.fn(self)() return getattr(relationship.builder, attribute) def apply_query(self, builder, instance): """Apply the query and return a dictionary to be hydrated Arguments: builder {oject} -- The relationship object instance {object} -- The current model oject. Returns: dict -- A dictionary of data which will be hydrated. 
""" polymorphic_key = self.get_record_key_lookup(builder._model) polymorphic_builder = self.polymorphic_builder return ( polymorphic_builder.where(self.morph_key, polymorphic_key) .where(self.morph_id, instance.get_primary_key_value()) .get() ) def get_related(self, query, relation, eagers=None, callback=None): """Gets the relation needed between the relation and the related builder. If the relation is a collection then will need to pluck out all the keys from the collection and fetch from the related builder. If relation is just a Model then we can just call the model based on the value of the related builders primary key. Args: relation (Model|Collection): Returns: Model|Collection """ if isinstance(relation, Collection): record_type = self.get_record_key_lookup(relation.first()) if callback: return callback( self.polymorphic_builder.where( f"{self.polymorphic_builder.get_table_name()}.{self.morph_key}", record_type, ).where_in( self.morph_id, relation.pluck( relation.first().get_primary_key(), keep_nulls=False ).unique(), ) ).get() return ( self.polymorphic_builder.where( f"{self.polymorphic_builder.get_table_name()}.{self.morph_key}", record_type, ) .where_in( self.morph_id, relation.pluck( relation.first().get_primary_key(), keep_nulls=False ).unique(), ) .get() ) else: record_type = self.get_record_key_lookup(relation) if callback: return callback( self.polymorphic_builder.where(self.morph_key, record_type).where( self.morph_id, relation.get_primary_key_value() ) ).get() return ( self.polymorphic_builder.where(self.morph_key, record_type) .where(self.morph_id, relation.get_primary_key_value()) .get() ) def register_related(self, key, model, collection): record_type = self.get_record_key_lookup(model) related = collection.where(self.morph_key, record_type).where( self.morph_id, model.get_primary_key_value() ) model.add_relation({key: related}) def morph_map(self): return load_config().DB._morph_map def get_record_key_lookup(self, relation): record_type = None 
for record_type_loop, model in self.morph_map().items(): if model == relation.__class__: record_type = record_type_loop break if not record_type: raise ValueError( f"Could not find the record type key for the {relation} class" ) return record_type ================================================ FILE: src/masoniteorm/relationships/MorphOne.py ================================================ from ..collection import Collection from ..config import load_config from .BaseRelationship import BaseRelationship class MorphOne(BaseRelationship): def __init__(self, fn, morph_key="record_type", morph_id="record_id"): if isinstance(fn, str): self.fn = None self.morph_key = fn self.morph_id = morph_key else: self.fn = fn self.morph_id = morph_id self.morph_key = morph_key def get_builder(self): return self._related_builder def set_keys(self, owner, attribute): self.morph_id = self.morph_id or "record_id" self.morph_key = self.morph_key or "record_type" return self def __get__(self, instance, owner): """This method is called when the decorated method is accessed. Arguments: instance {object|None} -- The instance we called. If we didn't call the attribute and only accessed it then this will be None. owner {object} -- The current model that the property was accessed on. Returns: object -- Either returns a builder or a hydrated model. """ attribute = self.fn.__name__ self._related_builder = instance.builder self.polymorphic_builder = self.fn(self)() self.set_keys(owner, self.fn) if not instance.is_loaded(): return self if attribute in instance._relationships: return instance._relationships[attribute] return self.apply_query(self._related_builder, instance) def __getattr__(self, attribute): relationship = self.fn(self)() return getattr(relationship.builder, attribute) def apply_query(self, builder, instance): """Apply the query and return a dictionary to be hydrated Arguments: builder {oject} -- The relationship object instance {object} -- The current model oject. 
Returns: dict -- A dictionary of data which will be hydrated. """ polymorphic_key = self.get_record_key_lookup(builder._model) polymorphic_builder = self.polymorphic_builder return ( polymorphic_builder.where(self.morph_key, polymorphic_key) .where(self.morph_id, instance.get_primary_key_value()) .first() ) def get_related(self, query, relation, eagers=None, callback=None): """Gets the relation needed between the relation and the related builder. If the relation is a collection then will need to pluck out all the keys from the collection and fetch from the related builder. If relation is just a Model then we can just call the model based on the value of the related builders primary key. Args: relation (Model|Collection): Returns: Model|Collection """ if isinstance(relation, Collection): record_type = self.get_record_key_lookup(relation.first()) if callback: return callback( self.polymorphic_builder.where( f"{self.polymorphic_builder.get_table_name()}.{self.morph_key}", record_type, ).where_in( self.morph_id, relation.pluck( relation.first().get_primary_key(), keep_nulls=False ).unique(), ) ).get() return ( self.polymorphic_builder.where( f"{self.polymorphic_builder.get_table_name()}.{self.morph_key}", record_type, ) .where_in( self.morph_id, relation.pluck( relation.first().get_primary_key(), keep_nulls=False ).unique(), ) .get() ) else: record_type = self.get_record_key_lookup(relation) if callback: return callback( self.polymorphic_builder.where(self.morph_key, record_type).where( self.morph_id, relation.get_primary_key_value() ) ).first() return ( self.polymorphic_builder.where(self.morph_key, record_type) .where(self.morph_id, relation.get_primary_key_value()) .first() ) def register_related(self, key, model, collection): record_type = self.get_record_key_lookup(model) related = ( collection.where(self.morph_key, record_type) .where(self.morph_id, model.get_primary_key_value()) .first() ) model.add_relation({key: related}) def morph_map(self): return 
load_config().DB._morph_map def get_record_key_lookup(self, relation): record_type = None for record_type_loop, model in self.morph_map().items(): if model == relation.__class__: record_type = record_type_loop break if not record_type: raise ValueError( f"Could not find the record type key for the {relation} class" ) return record_type ================================================ FILE: src/masoniteorm/relationships/MorphTo.py ================================================ from ..collection import Collection from ..config import load_config from .BaseRelationship import BaseRelationship class MorphTo(BaseRelationship): def __init__(self, fn, morph_key="record_type", morph_id="record_id"): if isinstance(fn, str): self.fn = None self.morph_key = fn self.morph_id = morph_key else: self.fn = fn self.morph_id = morph_id self.morph_key = morph_key def get_builder(self): return self._related_builder def set_keys(self, owner, attribute): self.morph_id = self.morph_id or "record_id" self.morph_key = self.morph_key or "record_type" return self def __get__(self, instance, owner): """This method is called when the decorated method is accessed. Arguments: instance {object|None} -- The instance we called. If we didn't call the attribute and only accessed it then this will be None. owner {object} -- The current model that the property was accessed on. Returns: object -- Either returns a builder or a hydrated model. 
""" attribute = self.fn.__name__ self._related_builder = instance.builder self.set_keys(owner, self.fn) if not instance.is_loaded(): return self if attribute in instance._relationships: return instance._relationships[attribute] return self.apply_query(self._related_builder, instance) def __getattr__(self, attribute): relationship = self.fn(self)() return getattr(relationship._related_builder, attribute) def apply_query(self, builder, instance): """Apply the query and return a dictionary to be hydrated Arguments: builder {oject} -- The relationship object instance {object} -- The current model oject. Returns: dict -- A dictionary of data which will be hydrated. """ model = self.morph_map().get(instance.__attributes__[self.morph_key]) record = instance.__attributes__[self.morph_id] return model.where(model.get_primary_key(), record).first() def get_related(self, query, relation, eagers=None, callback=None): """Gets the relation needed between the relation and the related builder. If the relation is a collection then will need to pluck out all the keys from the collection and fetch from the related builder. If relation is just a Model then we can just call the model based on the value of the related builders primary key. 
Args: relation (Model|Collection): Returns: Model|Collection """ if isinstance(relation, Collection): relations = Collection() for group, items in relation.group_by(self.morph_key).items(): morphed_model = self.morph_map().get(group) relations.merge( morphed_model.where_in( f"{morphed_model.get_table_name()}.{morphed_model.get_primary_key()}", Collection(items) .pluck(self.morph_id, keep_nulls=False) .unique(), ).get() ) return relations else: model = self.morph_map().get(getattr(relation, self.morph_key)) if model: return model.find(getattr(relation, self.morph_id)) def register_related(self, key, model, collection): morphed_model = self.morph_map().get(getattr(model, self.morph_key)) related = collection.where( morphed_model.get_primary_key(), getattr(model, self.morph_id) ).first() model.add_relation({key: related}) def morph_map(self): return load_config().DB._morph_map def map_related(self, related_result): return related_result ================================================ FILE: src/masoniteorm/relationships/MorphToMany.py ================================================ from ..collection import Collection from ..config import load_config from .BaseRelationship import BaseRelationship class MorphToMany(BaseRelationship): def __init__(self, fn, morph_key="record_type", morph_id="record_id"): if isinstance(fn, str): self.fn = None self.morph_key = fn self.morph_id = morph_key else: self.fn = fn self.morph_id = morph_id self.morph_key = morph_key def get_builder(self): return self._related_builder def set_keys(self, owner, attribute): self.morph_id = self.morph_id or "record_id" self.morph_key = self.morph_key or "record_type" return self def __get__(self, instance, owner): """This method is called when the decorated method is accessed. Arguments: instance {object|None} -- The instance we called. If we didn't call the attribute and only accessed it then this will be None. owner {object} -- The current model that the property was accessed on. 
Returns: object -- Either returns a builder or a hydrated model. """ attribute = self.fn.__name__ self._related_builder = instance.builder self.set_keys(owner, self.fn) if not instance.is_loaded(): return self if attribute in instance._relationships: return instance._relationships[attribute] return self.apply_query(self._related_builder, instance) def __getattr__(self, attribute): relationship = self.fn(self)() return getattr(relationship.builder, attribute) def apply_query(self, builder, instance): """Apply the query and return a dictionary to be hydrated Arguments: builder {oject} -- The relationship object instance {object} -- The current model oject. Returns: dict -- A dictionary of data which will be hydrated. """ model = self.morph_map().get(instance.__attributes__[self.morph_key]) record = instance.__attributes__[self.morph_id] return model.where(model.get_primary_key(), record).first() def get_related(self, query, relation, eagers=None, callback=None): """Gets the relation needed between the relation and the related builder. If the relation is a collection then will need to pluck out all the keys from the collection and fetch from the related builder. If relation is just a Model then we can just call the model based on the value of the related builders primary key. 
Args: relation (Model|Collection): Returns: Model|Collection """ if isinstance(relation, Collection): relations = Collection() for group, items in relation.group_by(self.morph_key).items(): morphed_model = self.morph_map().get(group) relations.merge( morphed_model.where_in( f"{morphed_model.get_table_name()}.{morphed_model.get_primary_key()}", Collection(items) .pluck(self.morph_id, keep_nulls=False) .unique(), ).get() ) return relations else: model = self.morph_map().get(getattr(relation, self.morph_key)) if model: return model.find([getattr(relation, self.morph_id)]) def register_related(self, key, model, collection): morphed_model = self.morph_map().get(getattr(model, self.morph_key)) related = collection.where( morphed_model.get_primary_key(), getattr(model, self.morph_id) ) model.add_relation({key: related}) def morph_map(self): return load_config().DB._morph_map ================================================ FILE: src/masoniteorm/relationships/__init__.py ================================================ from .BelongsTo import BelongsTo as belongs_to from .BelongsToMany import BelongsToMany as belongs_to_many from .HasMany import HasMany as has_many from .HasManyThrough import HasManyThrough as has_many_through from .HasOne import HasOne as has_one from .HasOneThrough import HasOneThrough as has_one_through from .MorphMany import MorphMany as morph_many from .MorphOne import MorphOne as morph_one from .MorphTo import MorphTo as morph_to from .MorphToMany import MorphToMany as morph_to_many ================================================ FILE: src/masoniteorm/schema/Blueprint.py ================================================ class Blueprint: """Used for building schemas for creating, modifying or altering schema.""" def __init__( self, grammar, table="", connection=None, platform=None, schema=None, action=None, default_string_length=None, dry=False, ): self.grammar = grammar self.table = table self._last_column = None self._default_string_length = 
default_string_length self.platform = platform self.schema = schema self._dry = dry self._action = action self.connection = connection if not platform: self.platform = self.connection.get_default_platform() def string(self, column, length=255, nullable=False): """Sets a column to be the string representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column. (default: {255}) nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "string", length=length, nullable=nullable ) return self def tiny_integer(self, column, length=1, nullable=False): """Sets a column to be the tiny_integer representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column. (default: {1}) nullable {bool} -- Whether the column is nullable (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "tiny_integer", length=length, nullable=nullable ) return self def small_integer(self, column, length=5, nullable=False): """Sets a column to be the small_integer representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column. (default: {5}) nullable {bool} -- Whether the column is nullable (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "small_integer", length=length, nullable=nullable ) return self def medium_integer(self, column, length=7, nullable=False): """Sets a column to be the medium_integer representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column. 
(default: {7}) nullable {bool} -- Whether the column is nullable (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "medium_integer", length=length, nullable=nullable ) return self def integer(self, column, length=11, nullable=False): """Sets a column to be the integer representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column. (default: {11}) nullable {bool} -- Whether the column is nullable (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "integer", length=length, nullable=nullable ) return self def big_integer(self, column, length=32, nullable=False): """Sets a column to be the big_integer representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column. (default: {32}) nullable {bool} -- Whether the column is nullable (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "big_integer", length=length, nullable=nullable ) return self def unsigned_big_integer(self, column, length=32, nullable=False): """Sets a column to be the unsigned big_integer representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column. (default: {32}) nullable {bool} -- Whether the column is nullable (default: {False}) Returns: self """ return self.big_integer(column, length=length, nullable=nullable).unsigned() def _compile_create(self): return self.grammar(creates=self._columns, table=self.table)._compile_create() def _compile_alter(self): return self.grammar(creates=self._columns, table=self.table)._compile_create() def increments(self, column, nullable=False): """Sets a column to be the auto incrementing primary key representation for the table. Arguments: column {string} -- The column name. 
Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "increments", nullable=nullable ) self.primary(column) return self def tiny_increments(self, column, nullable=False): """Sets a column to be the auto tiny incrementing primary key representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "tiny_increments", nullable=nullable ) self.primary(column) return self def id(self, column="id"): """Sets a column to be the auto-incrementing big integer (8-byte) primary key representation for the table. Arguments: column {string} -- The column name. Defaults to "id". Returns: self """ return self.big_increments(column) def uuid(self, column, nullable=False, length=36): """Sets a column to be the UUID4 representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "uuid", nullable=nullable, length=length ) return self def big_increments(self, column, nullable=False): """Sets a column to be the the big integer increments representation for the table Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "big_increments", nullable=nullable ) self.primary(column) return self def binary(self, column, nullable=False): """Sets a column to be the binary representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. 
(default: {False}) Returns: self """ self._last_column = self.table.add_column(column, "binary", nullable=nullable) return self def boolean(self, column, nullable=False): """Sets a column to be the boolean representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column(column, "boolean", nullable=nullable) return self def default(self, value, raw=False): self._last_column.default = value self._last_column.default_is_raw = raw return self def default_raw(self, value): self.default(value, True) return self def char(self, column, length=1, nullable=False): """Sets a column to be the char representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length for the column (default: {1}) nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "char", length=length, nullable=nullable ) return self def date(self, column, nullable=False): """Sets a column to be the date representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column(column, "date", nullable=nullable) return self def time(self, column, nullable=False): """Sets a column to be the time representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column(column, "time", nullable=nullable) return self def datetime(self, column, nullable=False, now=False): """Sets a column to be the datetime representation for the table. Arguments: column {string} -- The column name. 
Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) now {bool} -- Whether the default for the column should be the current time. (default: {False}) Returns: self """ self._last_column = self.table.add_column(column, "datetime", nullable=nullable) if now: self._last_column.use_current() return self def timestamp(self, column, nullable=False, now=False): """Sets a column to be the timestamp representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) now {bool} -- Whether the default for the column should be the current time. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "timestamp", nullable=nullable ) if now: self._last_column.use_current() return self def timestamps(self): """Creates `created_at` and `updated_at` timestamp columns. Returns: self """ self.datetime("created_at", nullable=True, now=True) self.datetime("updated_at", nullable=True, now=True) return self def decimal(self, column, length=17, precision=6, nullable=False): """Sets a column to be the decimal representation for the table. Arguments: column {string} -- The name of the column. Keyword Arguments: length {int} -- The total length of the decimal number (default: {17}) precision {int} -- The number of places that should be used for floating numbers. (default: {6}) nullable {bool} -- Whether the column is nullable (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "decimal", length="{length}, {precision}".format(length=length, precision=precision), nullable=nullable, ) return self def float(self, column, length=19, precision=4, nullable=False): """Sets a column to be the float representation for the table. Arguments: column {string} -- The name of the column. 
Keyword Arguments: length {int} -- The total length of the float number (default: {17}) precision {int} -- The number of places that should be used for floating numbers. (default: {6}) nullable {bool} -- Whether the column is nullable (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "float", length="{length}, {precision}".format(length=length, precision=precision), nullable=nullable, ) return self def double(self, column, nullable=False): """Sets a column to be the the double representation for the table Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column(column, "double", nullable=nullable) return self def enum(self, column, options=None, nullable=False): """Sets a column to be the enum representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: options {list} -- A list of available options for the enum. (default: {False}) nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ options = options or [] new_options = "" for option in options: new_options += "'{}',".format(option) new_options = new_options.rstrip(",") self._last_column = self.table.add_column( column, "enum", length="255", values=options, nullable=nullable ) return self def text(self, column, length=None, nullable=False): """Sets a column to be the text representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column if any. (default: {False}) nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "text", length=length, nullable=nullable ) return self def tiny_text(self, column, length=None, nullable=False): """Sets a column to be the text representation for the table. 
Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column if any. (default: {False}) nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "tiny_text", length=length, nullable=nullable ) return self def unsigned_decimal(self, column, length=17, precision=6, nullable=False): """Sets a column to be the text representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column if any. (default: {False}) nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "decimal", length="{length}, {precision}".format(length=length, precision=precision), nullable=nullable, ).unsigned() return self return self def long_text(self, column, length=None, nullable=False): """Sets a column to be the long_text representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column if any. (default: {False}) nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "long_text", length=length, nullable=nullable ) return self def json(self, column, nullable=False): """Sets a column to be the json representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column(column, "json", nullable=nullable) return self def jsonb(self, column, nullable=False): """Sets a column to be the jsonb representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. 
(default: {False}) Returns: self """ self._last_column = self.table.add_column(column, "jsonb", nullable=nullable) return self def inet(self, column, length=255, nullable=False): """Sets a column to be the inet representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "inet", length=255, nullable=nullable ) return self def cidr(self, column, length=255, nullable=False): """Sets a column to be the cidr representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "cidr", length=255, nullable=nullable ) return self def macaddr(self, column, length=255, nullable=False): """Sets a column to be the macaddr representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "macaddr", length=255, nullable=nullable ) return self def point(self, column, nullable=False): """Sets a column to be the point representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column(column, "point", nullable=nullable) return self def geometry(self, column, nullable=False): """Sets a column to be the geometry representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. 
(default: {False}) Returns: self """ self._last_column = self.table.add_column(column, "geometry", nullable=nullable) return self def year(self, column, length=4, default=None, nullable=False): """Sets a column to be the year representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "year", length=length, nullable=nullable, default=default ) return self def unsigned(self, column=None, length=None, nullable=False): """Sets a column to be the unsigned integer representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: length {int} -- The length of the column. (default: {False}) nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ if not column: self._last_column.unsigned() return self self._last_column = self.table.add_column( column, "unsigned", length=length, nullable=nullable ).unsigned() return self def unsigned_integer(self, column, nullable=False): """Sets a column to be the unsigned integer representation for the table. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. (default: {False}) Returns: self """ self._last_column = self.table.add_column( column, "integer", nullable=nullable ).unsigned() return self def morphs(self, column, nullable=False, indexes=True): """Sets a column to be used in a polymorphic relationship. Arguments: column {string} -- The column name. Keyword Arguments: nullable {bool} -- Whether the column is nullable. 
(default: {False}) Returns: self """ _columns = [] _columns.append( self.table.add_column( "{}_id".format(column), "integer", nullable=nullable ).unsigned() ) _columns.append( self.table.add_column( "{}_type".format(column), "string", nullable=nullable, length=self._default_string_length, ) ) if indexes: for column in _columns: self.index(column.name) self._last_column = _columns return self def to_sql(self): """Compiles the blueprint class into a sql statement. Returns: string -- The SQL statement generated. """ if self._action == "create": return self.platform().compile_create_sql(self.table) elif self._action == "create_table_if_not_exists": return self.platform().compile_create_sql(self.table, if_not_exists=True) else: if not self._dry: # get current table schema table = self.platform().get_current_schema( self.connection, self.table.name, schema=self.schema ) self.table.from_table = table return self.platform().compile_alter_sql(self.table) def __enter__(self): return self def __exit__(self, exc_type, exc_value, exc_traceback): if self._dry: return return self.connection.query(self.to_sql(), ()) def nullable(self): """Sets the last columns created as nullable Returns: self """ _columns = [] if not isinstance(self._last_column, list): _columns = [self._last_column] for column in _columns: column.nullable() return self def soft_deletes(self, name="deleted_at"): return self.datetime(name, nullable=True).nullable() def unique(self, column=None, name=None): """Sets the last column to be unique if no column name is passed. If a column name is passed this method will create a new unique column. Keyword Arguments: column {string} -- The name of the column. 
(default: {None}) Returns: self """ if not column: column = self._last_column.name if not isinstance(column, list): column = [column] self.table.add_constraint( name or f"{self.table.name}_{'_'.join(column)}_unique", "unique", columns=column, ) return self def index(self, column=None, name=None): """Creates a constraint based on the index constraint representation of the table. Arguments: column {string} -- The name of the column to create the index on. Returns: self """ if not column: column = self._last_column.name if not isinstance(column, list): column = [column] self.table.add_index( column, name or f"{self.table.name}_{'_'.join(column)}_index", "index" ) return self def fulltext(self, column=None, name=None): """Creates a constraint based on the full text constraint representation of the table. Arguments: column {string} -- The name of the column to create the index on. Returns: self """ if not column: column = self._last_column.name if not isinstance(column, list): column = [column] self.table.add_constraint( name or f"{'_'.join(column)}_fulltext", "fulltext", column ) return self def primary(self, column=None, name=None): """Creates a constraint based on the primary key constraint representation of the table. Sets the constraint on the last column if no column name is passed. Arguments: column {string} -- The name of the column to create the index on. (default: {None}) Returns: self """ if column is None: column = self._last_column.name if not isinstance(column, list): column = [column] self.table.add_constraint( name or f"{self.table.name}_{'_'.join(column)}_primary", "primary_key", columns=column, ) return self def add_foreign(self, columns, name=None): """Creates the foreign spliting the foreign name, reference column, and reference table. Arguments: columns {string} -- The name of the from_column . to_column . 
table """ if len(columns.split(".")) != 3: raise Exception( "Wrong add_foreign argument, the struncture is from_column.to_column.table" ) from_column, to_column, table = columns.split(".") return self.foreign(from_column, name=name).references(to_column).on(table) def foreign(self, column, name=None): """Starts the creation of a foreign key constraint Arguments: column {string} -- The name of the column to create the index on. Returns: self """ self._last_foreign = self.table.add_foreign_key( column, name=name or f"{self.table.name}_{column}_foreign" ) return self def foreign_id(self, column): """Sets a column to be a unsigned big integer (8-byte) representation for a foreign ID. Arguments: column {string} -- The name of the column to reference. Returns: self """ return self.unsigned_big_integer(column).foreign(column) def foreign_uuid(self, column): """Sets a column to be a UUID representation for a foreign UUID. Arguments: column {string} -- The name of the column to reference. Returns: self """ return self.uuid(column).foreign(column) def foreign_id_for(self, model, column=None): """Sets a column to be a unsigned big integer (8-byte) representation for a foreign ID. Arguments: model {Model} -- The model to reference. Returns: self """ clm = column if column else model.get_foreign_key() return ( self.foreign_id(clm) if model.get_primary_key_type() == "int" else self.foreign_uuid(column) ) def references(self, column): """Sets the other column on the foreign table that the local column will use to reference. Arguments: column {string} -- The name of the column to create the index on. Returns: self """ self._last_foreign.references(column) return self def on(self, table): """Sets the foreign table that the local column will use to reference on. Arguments: table {string} -- The foreign table name. Returns: self """ self._last_foreign.on(table) return self def on_delete(self, action): """Sets the last foreign key to a specific on delete action. 
Arguments: action {string} -- The specific action to do on delete. Returns: self """ self._last_foreign.on_delete(action) return self def on_update(self, action): """Sets the last foreign key to a specific on update action. Arguments: action {string} -- The specific action to do on update. Returns: self """ self._last_foreign.on_update(action) return self def comment(self, comment): self._last_column.add_comment(comment) return self def table_comment(self, comment): self.table.add_comment(comment) return self def rename(self, old_column, new_column, data_type, length=None): """Rename a column from the old value to a new value. Arguments: old_column {string} -- The name of the original old column name. new_column {string} -- The name of the new column name. Returns: self """ self.table.rename_column(old_column, new_column, data_type, length=length) return self def after(self, old_column): """Sets the column that this new column should be created after. This is useful for setting the location of the new column in the table schema. Arguments: old_column {string} -- The column that this new column should be created after Returns: self """ self._last_column.after(old_column) return self def drop_column(self, *columns): """Sets columns that should be dropped Returns: self """ for column in columns: self.table.drop_column(column) return self def drop_index(self, index): """Specifies indexes that should be dropped. Arguments: indexes {list|string} -- Either a list of indexes or a specific index. Returns: self """ if isinstance(index, list): for column in index: self.table.remove_index(f"{self.table.name}_{column}_index") return self self.table.remove_index(index) return self def change(self): self.table.change_column(self._last_column) return self def drop_unique(self, index): """Drops a unique index. Arguments: indexes {list|string} -- Either a list of indexes or a specific index. 
Returns: self """ if isinstance(index, list): for column in index: self.table.remove_unique_index(f"{self.table.name}_{column}_unique") return self self.table.remove_unique_index(index) def drop_primary(self, index): """Drops a unique index. Arguments: indexes {list|string} -- Either a list of indexes or a specific index. Returns: self """ if isinstance(index, list): for column in index: self.table.drop_primary(f"{self.table.name}_{column}_primary") return self self.table.drop_primary(index) def drop_foreign(self, index): """Drops foreign key indexes. Arguments: indexes {list|string} -- Either a list of indexes or a specific index. Returns: self """ if isinstance(index, list): for column in index: self.table.drop_foreign(f"{self.table.name}_{column}_foreign") return self self.table.drop_foreign(index) return self ================================================ FILE: src/masoniteorm/schema/Column.py ================================================ class Column: """Used for creating or modifying columns.""" def __init__( self, name, column_type, length=None, values=None, nullable=False, default=None, signed=None, default_is_raw=False, column_python_type=str, ): self.column_type = column_type self.column_python_type = column_python_type self.name = name self.length = length self.values = values or [] self.is_null = nullable self._after = None self.old_column = "" self.default = default self._signed = signed self.default_is_raw = default_is_raw self.primary = False self.comment = None def nullable(self): """Sets this column to be nullable Returns: self """ self.is_null = True return self def signed(self): """Sets this column to be nullable Returns: self """ self._signed = "signed" return self def unsigned(self): """Sets this column to be nullable Returns: self """ self._signed = "unsigned" return self def not_nullable(self): """Sets this column to be not nullable Returns: self """ self.is_null = False return self def set_as_primary(self): self.primary = True def 
rename(self, column): """Renames this column to a new name Arguments: column {string} -- The old column name Returns: self """ self.old_column = column return self def after(self, after): """Sets the column that this new column should be created after. This is useful for setting the location of the new column in the table schema. Arguments: after {string} -- The column that this new column should be created after Returns: self """ self._after = after return self def get_after_column(self): """Sets the column that this new column should be created after. This is useful for setting the location of the new column in the table schema. Arguments: after {string} -- The column that this new column should be created after Returns: self """ return self._after def default(self, value, raw=False): """Sets a default value for this column Arguments: value {string} -- A default value. raw {bool} -- should the value be quoted Returns: self """ self.default = value self.default_is_raw = raw return self def change(self): """Sets the schema to create a modify sql statement. Returns: self """ self._action = "modify" return self def use_current(self): """Sets the column to use a current timestamp. Used for timestamp columns. 
Returns: self """ self.default = "current" return self def add_comment(self, comment): self.comment = comment return self ================================================ FILE: src/masoniteorm/schema/ColumnDiff.py ================================================ ================================================ FILE: src/masoniteorm/schema/Constraint.py ================================================ class Constraint: def __init__(self, name, constraint_type, columns=None): self.name = name self.constraint_type = constraint_type self.columns = columns or [] ================================================ FILE: src/masoniteorm/schema/ForeignKeyConstraint.py ================================================ class ForeignKeyConstraint: def __init__(self, column, foreign_table, foreign_column, name=None): self.column = column self.foreign_table = foreign_table self.foreign_column = foreign_column self.delete_action = None self.update_action = None self.constraint_name = name def references(self, foreign_column): self.foreign_column = foreign_column return self def on(self, foreign_table): self.foreign_table = foreign_table return self def on_delete(self, action): self.delete_action = action return self def on_update(self, action): self.update_action = action return self def name(self, name): self.constraint_name = name return self ================================================ FILE: src/masoniteorm/schema/Index.py ================================================ class Index: def __init__(self, column, name, index_type): self.column = column self.name = name self.index_type = index_type ================================================ FILE: src/masoniteorm/schema/Schema.py ================================================ from .Blueprint import Blueprint from .Table import Table from .TableDiff import TableDiff from ..exceptions import ConnectionNotRegistered from ..config import load_config class Schema: _default_string_length = "255" _type_hints_map = { "string": 
str, "char": str, "big_increments": int, "integer": int, "big_integer": int, "tiny_integer": int, "small_integer": int, "medium_integer": int, "integer_unsigned": int, "big_integer_unsigned": int, "tiny_integer_unsigned": int, "small_integer_unsigned": int, "medium_integer_unsigned": int, "increments": int, "uuid": str, "binary": bytes, "boolean": bool, "decimal": float, "double": float, "enum": str, "text": str, "float": float, "geometry": str, # ? "json": dict, "jsonb": bytes, "inet": str, "cidr": str, "macaddr": str, "long_text": str, "point": str, # ? "time": str, # or pendulum.DateTime "timestamp": str, # or pendulum.DateTime "date": str, # or pendulum.DateTime "year": str, "datetime": str, # or pendulum.DateTime "tiny_increments": int, "unsigned": int, "unsigned_integer": int, } def __init__( self, dry=False, connection="default", connection_class=None, platform=None, grammar=None, connection_details=None, schema=None, config_path=None, ): self._dry = dry self.connection = connection self.connection_class = connection_class self._connection = None self.grammar = grammar self.platform = platform self.connection_details = connection_details or {} self._blueprint = None self._sql = None self.schema = schema self.config_path = config_path if not self.connection_class: self.on(self.connection) if not self.platform: self.platform = self.connection_class.get_default_platform() def on(self, connection_key): """Change the connection from the default connection Arguments: connection {string} -- A connection string like 'mysql' or 'mssql'. It will be made with the connection factory. 
Returns: cls """ DB = load_config(config_path=self.config_path).DB if connection_key == "default": self.connection = self.connection_details.get("default") connection_detail = self._connection_driver = self.connection_details.get( self.connection ) if connection_detail: self._connection_driver = connection_detail.get("driver") else: raise ConnectionNotRegistered( f"Could not find the '{connection_key}' connection details" ) self.connection_class = DB.connection_factory.make(self._connection_driver) return self def dry(self): """Whether the query should be executed. (default: {False}) Returns: self """ self._dry = True return self def create(self, table): """Sets the table and returns the blueprint. This should be used as a context manager. Arguments: table {string} -- The name of a table like 'users' Returns: masoniteorm.blueprint.Blueprint -- The Masonite ORM blueprint object. """ self._table = table self._blueprint = Blueprint( self.grammar, connection=self.new_connection(), table=Table(table), action="create", platform=self.platform, schema=self.schema, default_string_length=self._default_string_length, dry=self._dry, ) return self._blueprint def create_table_if_not_exists(self, table): self._table = table self._blueprint = Blueprint( self.grammar, connection=self.new_connection(), table=Table(table), action="create_table_if_not_exists", platform=self.platform, schema=self.schema, default_string_length=self._default_string_length, dry=self._dry, ) return self._blueprint def table(self, table): """Sets the table and returns the blueprint. This should be used as a context manager. Arguments: table {string} -- The name of a table like 'users' Returns: masoniteorm.blueprint.Blueprint -- The Masonite ORM blueprint object. 
""" self._table = table self._blueprint = Blueprint( self.grammar, connection=self.new_connection(), table=TableDiff(table), action="alter", platform=self.platform, schema=self.schema, default_string_length=self._default_string_length, dry=self._dry, ) return self._blueprint def get_connection_information(self): return { "host": self.connection_details.get(self.connection, {}).get("host"), "database": self.connection_details.get(self.connection, {}).get( "database" ), "user": self.connection_details.get(self.connection, {}).get("user"), "port": self.connection_details.get(self.connection, {}).get("port"), "password": self.connection_details.get(self.connection, {}).get( "password" ), "prefix": self.connection_details.get(self.connection, {}).get("prefix"), "options": self.connection_details.get(self.connection, {}).get( "options", {} ), "full_details": self.connection_details.get(self.connection), } def new_connection(self): if self._dry: return self._connection = ( self.connection_class(**self.get_connection_information()) .set_schema(self.schema) .make_connection() ) return self._connection def has_column(self, table, column, query_only=False): """Checks if the a table has a specific column Arguments: table {string} -- The name of a table like 'users' Returns: masoniteorm.blueprint.Blueprint -- The Masonite ORM blueprint object. 
""" sql = self.platform().compile_column_exists(table, column) if self._dry: self._sql = sql return sql return bool(self.new_connection().query(sql, ())) def get_columns(self, table, dict=True): table = self.platform().get_current_schema( self.new_connection(), table, schema=self.get_schema() ) result = {} if dict: for column in table.get_added_columns().items(): result.update({column[0]: column[1]}) return result else: return table.get_added_columns().items() @classmethod def set_default_string_length(cls, length): cls._default_string_length = length return cls def drop_table(self, table, query_only=False): sql = self.platform().compile_drop_table(table) if self._dry: self._sql = sql return sql return bool(self.new_connection().query(sql, ())) def drop(self, *args, **kwargs): return self.drop_table(*args, **kwargs) def drop_table_if_exists(self, table, exists=False, query_only=False): sql = self.platform().compile_drop_table_if_exists(table) if self._dry: self._sql = sql return sql return bool(self.new_connection().query(sql, ())) def rename(self, table, new_name): sql = self.platform().compile_rename_table(table, new_name) if self._dry: self._sql = sql return sql return bool(self.new_connection().query(sql, ())) def truncate(self, table, foreign_keys=False): sql = self.platform().compile_truncate(table, foreign_keys=foreign_keys) if self._dry: self._sql = sql return sql return bool(self.new_connection().query(sql, ())) def get_schema(self): """Gets the schema set on the migration class""" return self.schema or self.get_connection_information().get("full_details").get( "schema" ) def get_all_tables(self): """Gets all tables in the database""" sql = self.platform().compile_get_all_tables( database=self.get_connection_information().get("database"), schema=self.get_schema(), ) if self._dry: self._sql = sql return sql result = self.new_connection().query(sql, ()) return list(map(lambda t: list(t.values())[0], result)) if result else [] def has_table(self, table, 
query_only=False): """Checks if the a database has a specific table Arguments: table {string} -- The name of a table like 'users' Returns: masoniteorm.blueprint.Blueprint -- The Masonite ORM blueprint object. """ sql = self.platform().compile_table_exists( table, database=self.get_connection_information().get("database"), schema=self.get_schema(), ) if self._dry: self._sql = sql return sql return bool(self.new_connection().query(sql, ())) def enable_foreign_key_constraints(self): sql = self.platform().enable_foreign_key_constraints() if self._dry: self._sql = sql return sql return bool(self.new_connection().query(sql, ())) def disable_foreign_key_constraints(self): sql = self.platform().disable_foreign_key_constraints() if self._dry: self._sql = sql return sql return bool(self.new_connection().query(sql, ())) ================================================ FILE: src/masoniteorm/schema/Table.py ================================================ from .Column import Column from .Constraint import Constraint from .Index import Index from .ForeignKeyConstraint import ForeignKeyConstraint class Table: def __init__(self, table): self.name = table self.added_columns = {} self.added_constraints = {} self.added_indexes = {} self.added_foreign_keys = {} self.renamed_columns = {} self.drop_indexes = {} self.foreign_keys = {} self.primary_key = None self.comment = None def add_column( self, name=None, column_type=None, length=None, values=None, nullable=False, default=None, signed=None, default_is_raw=False, primary=False, column_python_type=str, ): column = Column( name, column_type, length=length, nullable=nullable, values=values or [], default=default, signed=signed, default_is_raw=default_is_raw, column_python_type=column_python_type, ) if primary: column.set_as_primary() self.added_columns.update({name: column}) return column def add_constraint(self, name, constraint_type, columns=None): self.added_constraints.update( {name: Constraint(name, constraint_type, columns=columns 
or [])} ) def add_foreign_key(self, column, table=None, foreign_column=None, name=None): foreign_key = ForeignKeyConstraint( column, table, foreign_column, name=name or f"{self.name}_{column}_foreign" ) self.added_foreign_keys.update({column: foreign_key}) return foreign_key def get_added_foreign_keys(self): return self.added_foreign_keys def get_constraint(self, name): return self.added_constraints[name] def get_added_constraints(self): return self.added_constraints def get_added_columns(self): return self.added_columns def get_renamed_columns(self): return self.added_columns def set_primary_key(self, columns): self.primary_key = columns return self def add_index(self, column, name, index_type): self.added_indexes.update({name: Index(column, name, index_type)}) def get_index(self, name): return self.added_indexes[name] def add_comment(self, comment): self.comment = comment return self ================================================ FILE: src/masoniteorm/schema/TableDiff.py ================================================ from .Column import Column from .Table import Table class TableDiff(Table): def __init__(self, name): self.name = name self.from_table = None self.new_name = None self.removed_indexes = [] self.removed_unique_indexes = [] self.added_indexes = {} self.added_columns = {} self.changed_columns = {} self.dropped_columns = [] self.dropped_foreign_keys = [] self.dropped_primary_keys = [] self.renamed_columns = {} self.removed_constraints = {} self.added_constraints = {} self.added_foreign_keys = {} self.comment = None def remove_constraint(self, name): self.removed_constraints.update({name: self.from_table.get_constraint(name)}) def get_removed_constraints(self): return self.removed_constraints def get_renamed_columns(self): return self.renamed_columns def rename_column( self, original_name, new_name, column_type=None, length=None, nullable=False, default=None, ): self.renamed_columns.update( { original_name: Column( new_name, column_type, 
length=length, nullable=nullable, default=default, ) } ) def remove_index(self, name): self.removed_indexes.append(name) def remove_unique_index(self, name): self.removed_unique_indexes.append(name) def drop_column(self, name): self.dropped_columns.append(name) def get_dropped_columns(self): return self.dropped_columns def get_dropped_foreign_keys(self): return self.dropped_foreign_keys def drop_foreign(self, name): self.dropped_foreign_keys.append(name) return self def drop_primary(self, name): self.dropped_primary_keys.append(name) return self def change_column(self, added_column): self.added_columns.pop(added_column.name) self.changed_columns.update({added_column.name: added_column}) def add_comment(self, comment): self.comment = comment return self ================================================ FILE: src/masoniteorm/schema/__init__.py ================================================ from .Schema import Schema from .Table import Table from .Column import Column ================================================ FILE: src/masoniteorm/schema/platforms/MSSQLPlatform.py ================================================ from .Platform import Platform from ..Table import Table class MSSQLPlatform(Platform): types_without_lengths = [ "integer", "big_integer", "tiny_integer", "small_integer", "medium_integer", ] type_map = { "string": "VARCHAR", "char": "CHAR", "big_increments": "BIGINT IDENTITY", "integer": "INT", "big_integer": "BIGINT", "tiny_integer": "TINYINT", "small_integer": "SMALLINT", "medium_integer": "MEDIUMINT", "integer_unsigned": "INT", "big_integer_unsigned": "BIGINT", "tiny_integer_unsigned": "TINYINT", "small_integer_unsigned": "SMALLINT", "medium_integer_unsigned": "MEDIUMINT", "increments": "INT IDENTITY", "uuid": "CHAR", "binary": "LONGBLOB", "boolean": "BOOLEAN", "decimal": "DECIMAL", "double": "DOUBLE", "enum": "VARCHAR", "text": "TEXT", "tiny_text": "TINYTEXT", "float": "FLOAT", "geometry": "GEOMETRY", "json": "JSON", "jsonb": "LONGBLOB", "inet": 
"VARCHAR", "cidr": "VARCHAR", "macaddr": "VARCHAR", "long_text": "LONGTEXT", "point": "POINT", "time": "TIME", "timestamp": "DATETIME", "date": "DATE", "year": "YEAR", "datetime": "DATETIME", "tiny_increments": "TINYINT IDENTITY", "unsigned": "INT", "unsigned_integer": "INT", } premapped_nulls = {True: "NULL", False: "NOT NULL"} premapped_defaults = { "current": " DEFAULT CURRENT_TIMESTAMP", "now": " DEFAULT NOW()", "null": " DEFAULT NULL", } def compile_create_sql(self, table, if_not_exists=False): sql = [] table_create_format = ( self.create_if_not_exists_format() if if_not_exists else self.create_format() ) sql.append( table_create_format.format( table=self.wrap_table(table.name), columns=", ".join(self.columnize(table.get_added_columns())).strip(), constraints=( ", " + ", ".join( self.constraintize(table.get_added_constraints(), table) ) if table.get_added_constraints() else "" ), foreign_keys=( ", " + ", ".join( self.foreign_key_constraintize( table.name, table.added_foreign_keys ) ) if table.added_foreign_keys else "" ), ) ) if table.added_indexes: for name, index in table.added_indexes.items(): sql.append( "CREATE INDEX {name} ON {table}({column})".format( name=index.name, table=self.wrap_table(table.name), column=",".join(index.column), ) ) return sql def compile_alter_sql(self, table): sql = [] if table.added_columns: sql.append( self.alter_format().format( table=self.wrap_table(table.name), columns="ADD " + ", ".join(self.columnize(table.added_columns)).strip(), ) ) if table.changed_columns: sql.append( self.alter_format().format( table=self.wrap_table(table.name), columns="ALTER COLUMN " + ", ".join(self.columnize(table.changed_columns)).strip(), ) ) if table.renamed_columns: for name, column in table.get_renamed_columns().items(): sql.append( self.rename_column_string(table.name, name, column.name).strip() ) if table.dropped_columns: dropped_sql = [] for name in table.get_dropped_columns(): 
dropped_sql.append(self.drop_column_string().format(name=name).strip()) sql.append( self.alter_format().format( table=self.wrap_table(table.name), columns="DROP COLUMN " + ", ".join(dropped_sql), ) ) if table.added_foreign_keys: for ( column, foreign_key_constraint, ) in table.get_added_foreign_keys().items(): cascade = "" if foreign_key_constraint.delete_action: cascade += f" ON DELETE {self.foreign_key_actions.get(foreign_key_constraint.delete_action.lower())}" if foreign_key_constraint.update_action: cascade += f" ON UPDATE {self.foreign_key_actions.get(foreign_key_constraint.update_action.lower())}" sql.append( f"ALTER TABLE {self.wrap_table(table.name)} ADD " + self.get_foreign_key_constraint_string().format( constraint_name=foreign_key_constraint.constraint_name, column=self.wrap_column(column), table=self.wrap_table(table.name), foreign_table=self.wrap_table( foreign_key_constraint.foreign_table ), foreign_column=self.wrap_column( foreign_key_constraint.foreign_column ), cascade=cascade, ) ) if table.dropped_foreign_keys: constraints = table.dropped_foreign_keys for constraint in constraints: sql.append( f"ALTER TABLE {self.wrap_table(table.name)} DROP CONSTRAINT {constraint}" ) if table.added_indexes: for name, index in table.added_indexes.items(): sql.append( "CREATE INDEX {name} ON {table}({column})".format( name=index.name, table=self.wrap_table(table.name), column=",".join(index.column), ) ) if ( table.removed_indexes or table.removed_unique_indexes or table.dropped_primary_keys ): constraints = table.removed_indexes constraints += table.removed_unique_indexes constraints += table.dropped_primary_keys for constraint in constraints: sql.append( f"DROP INDEX {self.wrap_table(table.name)}.{self.wrap_table(constraint)}" ) if table.added_constraints: for name, constraint in table.added_constraints.items(): if constraint.constraint_type == "unique": sql.append( f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} 
UNIQUE({','.join(constraint.columns)})" ) elif constraint.constraint_type == "fulltext": pass elif constraint.constraint_type == "primary_key": sql.append( f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} PRIMARY KEY ({','.join(constraint.columns)})" ) return sql def add_column_string(self): return "{name} {data_type}{length}" def drop_column_string(self): return "{name}" def rename_column_string(self, table, old, new): return f"EXEC sp_rename '{table}.{old}', '{new}', 'COLUMN'" def columnize(self, columns): sql = [] for name, column in columns.items(): if column.length: length = self.create_column_length(column.column_type).format( length=column.length ) else: length = "" if column.default == "": default = " DEFAULT ''" elif column.default in (0,): default = f" DEFAULT {column.default}" elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: if isinstance(column.default, (str,)) and not column.default_is_raw: default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" else: default = "" constraint = "" column_constraint = "" if column.primary: constraint = " PRIMARY KEY" if column.column_type == "enum": values = ", ".join(f"'{x}'" for x in column.values) column_constraint = f" CHECK([{column.name}] IN ({values}))" sql.append( self.columnize_string() .format( name=column.name, data_type=self.type_map.get(column.column_type, ""), column_constraint=column_constraint, length=length, constraint=constraint, nullable=self.premapped_nulls.get(column.is_null) or "", default=default, ) .strip() ) return sql def columnize_string(self): return "[{name}] {data_type}{length} {nullable}{default}{column_constraint}{constraint}" def constraintize(self, constraints, table): sql = [] for name, constraint in constraints.items(): sql.append( getattr( self, f"get_{constraint.constraint_type}_constraint_string" )().format( columns=", ".join(constraint.columns), 
name_columns="_".join(constraint.columns), constraint_name=constraint.name, table=table.name, ) ) return sql def get_table_string(self): return "[{table}]" def get_column_string(self): return "[{column}]" def create_format(self): return "CREATE TABLE {table} ({columns}{constraints}{foreign_keys})" def create_if_not_exists_format(self): return ( "CREATE TABLE IF NOT EXISTS {table} ({columns}{constraints}{foreign_keys})" ) def alter_format(self): return "ALTER TABLE {table} {columns}" def get_foreign_key_constraint_string(self): return "CONSTRAINT {constraint_name} FOREIGN KEY ({column}) REFERENCES {foreign_table}({foreign_column}){cascade}" def get_primary_key_constraint_string(self): return "CONSTRAINT {constraint_name} PRIMARY KEY ({columns})" def get_unique_constraint_string(self): return "CONSTRAINT {constraint_name} UNIQUE ({columns})" def compile_table_exists(self, table, database=None, schema=None): return f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = '{table}'" def compile_truncate(self, table, foreign_keys=False): if not foreign_keys: return f"TRUNCATE TABLE {self.wrap_table(table)}" return [ f"ALTER TABLE {self.wrap_table(table)} NOCHECK CONSTRAINT ALL", f"TRUNCATE TABLE {self.wrap_table(table)}", f"ALTER TABLE {self.wrap_table(table)} WITH CHECK CHECK CONSTRAINT ALL", ] def compile_rename_table(self, current_name, new_name): return f"EXEC sp_rename {self.wrap_table(current_name)}, {self.wrap_table(new_name)}" def compile_drop_table_if_exists(self, table): return f"DROP TABLE IF EXISTS {self.wrap_table(table)}" def compile_drop_table(self, table): return f"DROP TABLE {self.wrap_table(table)}" def compile_column_exists(self, table, column): return f"SELECT 1 FROM sys.columns WHERE Name = N'{column}' AND Object_ID = Object_ID(N'{table}')" def compile_get_all_tables(self, database, schema=None): return f"SELECT name FROM {database}.sys.tables" def get_current_schema(self, connection, table_name, schema=None): return Table(table_name) def 
enable_foreign_key_constraints(self): """MSSQL does not allow a global way to enable foreign key constraints""" return "" def disable_foreign_key_constraints(self): """MSSQL does not allow a global way to disable foreign key constraints""" return "" ================================================ FILE: src/masoniteorm/schema/platforms/MySQLPlatform.py ================================================ from ...schema import Schema from .Platform import Platform from ..Table import Table import re class MySQLPlatform(Platform): types_without_lengths = ["enum"] type_map = { "string": "VARCHAR", "char": "CHAR", "integer": "INT", "big_integer": "BIGINT", "tiny_integer": "TINYINT", "small_integer": "SMALLINT", "medium_integer": "MEDIUMINT", "integer_unsigned": "INT UNSIGNED", "big_integer_unsigned": "BIGINT UNSIGNED", "tiny_integer_unsigned": "TINYINT UNSIGNED", "small_integer_unsigned": "SMALLINT UNSIGNED", "medium_integer_unsigned": "MEDIUMINT UNSIGNED", "big_increments": "BIGINT UNSIGNED AUTO_INCREMENT", "increments": "INT UNSIGNED AUTO_INCREMENT", "uuid": "CHAR", "binary": "LONGBLOB", "boolean": "BOOLEAN", "decimal": "DECIMAL", "double": "DOUBLE", "enum": "ENUM", "text": "TEXT", "tiny_text": "TINYTEXT", "float": "FLOAT", "geometry": "GEOMETRY", "json": "JSON", "jsonb": "LONGBLOB", "inet": "VARCHAR", "cidr": "VARCHAR", "macaddr": "VARCHAR", "long_text": "LONGTEXT", "point": "POINT", "time": "TIME", "timestamp": "TIMESTAMP", "date": "DATE", "year": "YEAR", "datetime": "DATETIME", "tiny_increments": "TINYINT AUTO_INCREMENT", "unsigned": "INT UNSIGNED", } premapped_nulls = {True: "NULL", False: "NOT NULL"} premapped_defaults = { "current": " DEFAULT CURRENT_TIMESTAMP", "now": " DEFAULT NOW()", "null": " DEFAULT NULL", } signed = {"unsigned": "UNSIGNED", "signed": "SIGNED"} def columnize(self, columns): sql = [] for name, column in columns.items(): if column.length: length = self.create_column_length(column.column_type).format( length=column.length ) else: length = "" if 
column.default == "": default = " DEFAULT ''" elif column.default in (0,): default = f" DEFAULT {column.default}" elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: if isinstance(column.default, (str,)) and not column.default_is_raw: default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" else: default = "" constraint = "" column_constraint = "" if column.primary: constraint = "PRIMARY KEY" if column.column_type == "enum": values = ", ".join(f"'{x}'" for x in column.values) column_constraint = f"({values})" sql.append( self.columnize_string() .format( name=self.get_column_string().format(column=column.name), data_type=self.type_map.get(column.column_type, ""), column_constraint=column_constraint, length=length, constraint=constraint, nullable=self.premapped_nulls.get(column.is_null) or "", default=default, signed=( " " + self.signed.get(column._signed) if column._signed else "" ), comment=( "COMMENT '" + column.comment + "'" if column.comment else "" ), ) .strip() ) return sql def compile_create_sql(self, table, if_not_exists=False): sql = [] table_create_format = ( self.create_if_not_exists_format() if if_not_exists else self.create_format() ) sql.append( table_create_format.format( table=self.get_table_string().format(table=table.name), columns=", ".join(self.columnize(table.get_added_columns())).strip(), constraints=( ", " + ", ".join( self.constraintize(table.get_added_constraints(), table) ) if table.get_added_constraints() else "" ), foreign_keys=( ", " + ", ".join( self.foreign_key_constraintize( table.name, table.added_foreign_keys ) ) if table.added_foreign_keys else "" ), comment=f" COMMENT '{table.comment}'" if table.comment else "", ) ) if table.added_indexes: for name, index in table.added_indexes.items(): sql.append( "CREATE INDEX {name} ON {table}({column})".format( name=index.name, table=self.wrap_table(table.name), 
column=",".join(index.column), ) ) return sql def compile_alter_sql(self, table): sql = [] if table.added_columns: add_columns = [] for name, column in table.get_added_columns().items(): if column.length: length = self.create_column_length(column.column_type).format( length=column.length ) else: length = "" default = "" if column.default in (0,): default = f" DEFAULT {column.default}" elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: if isinstance(column.default, (str,)): default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" else: default = "" column_constraint = "" if column.column_type == "enum": values = ", ".join(f"'{x}'" for x in column.values) column_constraint = f"({values})" add_columns.append( self.add_column_string() .format( name=self.get_column_string().format(column=column.name), data_type=self.type_map.get(column.column_type, ""), column_constraint=column_constraint, length=length, constraint="PRIMARY KEY" if column.primary else "", nullable="NULL" if column.is_null else "NOT NULL", default=default, signed=( " " + self.signed.get(column._signed) if column._signed else "" ), after=( (" AFTER " + self.wrap_column(column._after)) if column._after else "" ), comment=( " COMMENT '" + column.comment + "'" if column.comment else "" ), ) .strip() ) sql.append( self.alter_format().format( table=self.wrap_table(table.name), columns=", ".join(add_columns).strip(), comment=f" COMMENT '{table.comment}'" if table.comment else "", ) ) if table.renamed_columns: renamed_sql = [] for name, column in table.get_renamed_columns().items(): if column.length: length = self.create_column_length(column.column_type).format( length=column.length ) else: length = "" renamed_sql.append( self.rename_column_string() .format( to=self.columnize({column.name: column})[0], old=self.get_column_string().format(column=name), ) .strip() ) sql.append( self.alter_format().format( 
table=self.wrap_table(table.name), columns=", ".join(renamed_sql).strip(), ) ) if table.changed_columns: sql.append( self.alter_format().format( table=self.wrap_table(table.name), columns=", ".join( f"MODIFY {x}" for x in self.columnize(table.changed_columns) ), ) ) if table.dropped_columns: dropped_sql = [] for name in table.get_dropped_columns(): dropped_sql.append( self.drop_column_string() .format(name=self.get_column_string().format(column=name)) .strip() ) sql.append( self.alter_format().format( table=self.wrap_table(table.name), columns=", ".join(dropped_sql) ) ) if table.added_foreign_keys: for ( column, foreign_key_constraint, ) in table.get_added_foreign_keys().items(): cascade = "" if foreign_key_constraint.delete_action: cascade += f" ON DELETE {self.foreign_key_actions.get(foreign_key_constraint.delete_action.lower())}" if foreign_key_constraint.update_action: cascade += f" ON UPDATE {self.foreign_key_actions.get(foreign_key_constraint.update_action.lower())}" sql.append( f"ALTER TABLE {self.wrap_table(table.name)} ADD " + self.get_foreign_key_constraint_string().format( column=column, constraint_name=foreign_key_constraint.constraint_name, table=table.name, foreign_table=foreign_key_constraint.foreign_table, foreign_column=foreign_key_constraint.foreign_column, cascade=cascade, ) ) if table.dropped_foreign_keys: constraints = table.dropped_foreign_keys for constraint in constraints: sql.append( f"ALTER TABLE {self.wrap_table(table.name)} DROP FOREIGN KEY {constraint}" ) if table.added_indexes: for name, index in table.added_indexes.items(): sql.append( "CREATE INDEX {name} ON {table}({column})".format( name=index.name, table=self.wrap_table(table.name), column=",".join(index.column), ) ) if table.added_constraints: for name, constraint in table.added_constraints.items(): if constraint.constraint_type == "unique": sql.append( f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT UNIQUE INDEX {constraint.name}({','.join(constraint.columns)})" ) 
elif constraint.constraint_type == "fulltext": sql.append( f"ALTER TABLE {self.wrap_table(table.name)} ADD FULLTEXT {constraint.name}({','.join(constraint.columns)})" ) elif constraint.constraint_type == "primary_key": sql.append( f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} PRIMARY KEY ({','.join(constraint.columns)})" ) if ( table.removed_indexes or table.removed_unique_indexes or table.dropped_primary_keys ): constraints = table.removed_indexes constraints += table.removed_unique_indexes constraints += table.dropped_primary_keys for constraint in constraints: sql.append( f"ALTER TABLE {self.wrap_table(table.name)} DROP INDEX {constraint}" ) if table.comment: sql.append( f"ALTER TABLE {self.wrap_table(table.name)} COMMENT '{table.comment}'" ) return sql def add_column_string(self): return "ADD {name} {data_type}{length}{column_constraint}{signed} {nullable}{default}{after}{comment}" def drop_column_string(self): return "DROP COLUMN {name}" def change_column_string(self): return "MODIFY {name}{data_type}{length}{column_constraint} {nullable}{default} {constraint}" def rename_column_string(self): return "CHANGE {old} {to}" def columnize_string(self): return "{name} {data_type}{length}{column_constraint}{signed} {nullable}{default} {constraint}{comment}" def constraintize(self, constraints, table): sql = [] for name, constraint in constraints.items(): sql.append( getattr( self, f"get_{constraint.constraint_type}_constraint_string" )().format( columns=", ".join(constraint.columns), name_columns="_".join(constraint.columns), table=table.name, constraint_name=constraint.name, ) ) return sql def get_table_string(self): return "`{table}`" def get_column_string(self): return "`{column}`" def create_format(self): return "CREATE TABLE {table} ({columns}{constraints}{foreign_keys}){comment}" def create_if_not_exists_format(self): return "CREATE TABLE IF NOT EXISTS {table} ({columns}{constraints}{foreign_keys}){comment}" def alter_format(self): 
return "ALTER TABLE {table} {columns}" def get_foreign_key_constraint_string(self): return "CONSTRAINT {constraint_name} FOREIGN KEY ({column}) REFERENCES {foreign_table}({foreign_column}){cascade}" def get_primary_key_constraint_string(self): return "CONSTRAINT {constraint_name} PRIMARY KEY ({columns})" def get_unique_constraint_string(self): return "CONSTRAINT {constraint_name} UNIQUE ({columns})" def compile_table_exists(self, table, database=None, schema=None): return f"SELECT * from information_schema.tables where table_name='{table}' AND table_schema = '{database}'" def compile_truncate(self, table, foreign_keys=False): if not foreign_keys: return f"TRUNCATE {self.wrap_table(table)}" return [ self.disable_foreign_key_constraints(), f"TRUNCATE {self.wrap_table(table)}", self.enable_foreign_key_constraints(), ] def compile_rename_table(self, current_name, new_name): return f"ALTER TABLE {self.wrap_table(current_name)} RENAME TO {self.wrap_table(new_name)}" def compile_drop_table_if_exists(self, table): return f"DROP TABLE IF EXISTS {self.wrap_table(table)}" def compile_drop_table(self, table): return f"DROP TABLE {self.wrap_table(table)}" def compile_column_exists(self, table, column): return f"SELECT column_name FROM information_schema.columns WHERE table_name='{table}' and column_name='{column}'" def compile_get_all_tables(self, database, schema=None): return f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{database}'" def get_current_schema(self, connection, table_name, schema=None): table = Table(table_name) sql = f"DESCRIBE {table_name}" result = connection.query(sql, ()) reversed_type_map = {v: k for k, v in self.type_map.items()} for column in result: column_type = self.get_column_type( reversed_type_map, column["Type"].upper() ) length = self.get_column_length(column["Type"]) default = column.get("Default") table.add_column( column["Field"], column_type, column_python_type=Schema._type_hints_map.get(column_type, str), 
default=default, length=length, ) return table def get_column_type(self, reversed_type_map, column_type): if "(" in column_type: parenthesis_index = column_type.find("(") column_type = column_type[:parenthesis_index] length = self.get_column_length(column_type) if column_type == "CHAR": if length == "1": return "char" elif length == "36": return "uuid" else: return "char" elif column_type == "VARCHAR": if length == "4": return "year" else: return "string" return reversed_type_map.get(column_type) def get_column_length(self, raw_column_type): regex = re.compile(r"^\w+\((\d+)\)") match = regex.match(raw_column_type) if match: return match.groups()[0] if "(" in raw_column_type: parenthesis_index = raw_column_type.find("(") return raw_column_type[parenthesis_index + 1 : -1] else: return None def enable_foreign_key_constraints(self): return "SET FOREIGN_KEY_CHECKS=1" def disable_foreign_key_constraints(self): return "SET FOREIGN_KEY_CHECKS=0" ================================================ FILE: src/masoniteorm/schema/platforms/Platform.py ================================================ class Platform: foreign_key_actions = { "cascade": "CASCADE", "set null": "SET NULL", "cascade": "CASCADE", "restrict": "RESTRICT", "no action": "NO ACTION", "default": "SET DEFAULT", } signed = {"signed": "SIGNED", "unsigned": "UNSIGNED"} def columnize(self, columns): sql = [] for name, column in columns.items(): if column.length: length = self.create_column_length(column.column_type).format( length=column.length ) else: length = "" if column.default in (0,): default = f" DEFAULT {column.default}" elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: if isinstance(column.default, (str,)) and not column.default_is_raw: default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" else: default = "" sql.append( self.columnize_string() .format( name=column.name, 
data_type=self.type_map.get(column.column_type, ""), length=length, constraint="PRIMARY KEY" if column.primary else "", nullable=self.premapped_nulls.get(column.is_null) or "", default=default, ) .strip() ) return sql def columnize_string(self): raise NotImplementedError def create_column_length(self, column_type): if column_type in self.types_without_lengths: return "" return "({length})" def foreign_key_constraintize(self, table, foreign_keys): sql = [] for name, foreign_key in foreign_keys.items(): cascade = "" if foreign_key.delete_action: cascade += f" ON DELETE {self.foreign_key_actions.get(foreign_key.delete_action.lower())}" if foreign_key.update_action: cascade += f" ON UPDATE {self.foreign_key_actions.get(foreign_key.update_action.lower())}" sql.append( self.get_foreign_key_constraint_string().format( column=self.wrap_column(foreign_key.column), constraint_name=foreign_key.constraint_name, table=self.wrap_table(table), foreign_table=self.wrap_table(foreign_key.foreign_table), foreign_column=self.wrap_column(foreign_key.foreign_column), cascade=cascade, ) ) return sql def constraintize(self, constraints): sql = [] for name, constraint in constraints.items(): sql.append( getattr( self, f"get_{constraint.constraint_type}_constraint_string" )().format(columns=", ".join(constraint.columns)) ) return sql def wrap_table(self, table_name): return self.get_table_string().format(table=table_name) def wrap_column(self, column_name): return self.get_column_string().format(column=column_name) ================================================ FILE: src/masoniteorm/schema/platforms/PostgresPlatform.py ================================================ from ...schema import Schema from .Platform import Platform from ..Table import Table class PostgresPlatform(Platform): types_without_lengths = [ "integer", "big_integer", "tiny_integer", "small_integer", "medium_integer", "inet", "cidr", "macaddr", "uuid", ] type_map = { "string": "VARCHAR", "char": "CHAR", "integer": 
"INTEGER", "big_integer": "BIGINT", "tiny_integer": "TINYINT", "big_increments": "BIGSERIAL UNIQUE", "small_integer": "SMALLINT", "medium_integer": "MEDIUMINT", # Postgres database does not implement unsigned types # So the below types are the same as the normal ones "integer_unsigned": "INTEGER", "big_integer_unsigned": "BIGINT", "tiny_integer_unsigned": "TINYINT", "small_integer_unsigned": "SMALLINT", "medium_integer_unsigned": "MEDIUMINT", "increments": "SERIAL UNIQUE", "uuid": "UUID", "binary": "BYTEA", "boolean": "BOOLEAN", "decimal": "DECIMAL", "double": "DOUBLE PRECISION", "enum": "VARCHAR", "text": "TEXT", "tiny_text": "TEXT", "float": "FLOAT", "geometry": "GEOMETRY", "json": "JSON", "jsonb": "JSONB", "inet": "INET", "cidr": "CIDR", "macaddr": "MACADDR", "long_text": "TEXT", "point": "POINT", "time": "TIME", "timestamp": "TIMESTAMP", "date": "DATE", "year": "YEAR", "datetime": "TIMESTAMPTZ", "tiny_increments": "TINYINT AUTO_INCREMENT", "unsigned": "INT", } table_info_map = { "CHARACTER VARYING": "string", "TIMESTAMP WITH TIME ZONE": "datetime", "TIMESTAMP WITHOUT TIME ZONE": "datetime", } premapped_defaults = { "current": " DEFAULT CURRENT_TIMESTAMP", "now": " DEFAULT NOW()", "null": " DEFAULT NULL", } premapped_nulls = {True: "NULL", False: "NOT NULL"} def compile_create_sql(self, table, if_not_exists=False): sql = [] table_create_format = ( self.create_if_not_exists_format() if if_not_exists else self.create_format() ) sql.append( table_create_format.format( table=self.wrap_table(table.name), columns=", ".join(self.columnize(table.get_added_columns())).strip(), constraints=( ", " + ", ".join( self.constraintize(table.get_added_constraints(), table) ) if table.get_added_constraints() else "" ), foreign_keys=( ", " + ", ".join( self.foreign_key_constraintize( table.name, table.added_foreign_keys ) ) if table.added_foreign_keys else "" ), ) ) if table.added_indexes: for name, index in table.added_indexes.items(): sql.append( "CREATE INDEX {name} ON 
{table}({column})".format( name=index.name, table=self.wrap_table(table.name), column=",".join(index.column), ) ) for name, column in table.get_added_columns().items(): if column.comment: sql.append( f"""COMMENT ON COLUMN "{table.name}"."{name}" is '{column.comment}'""" ) if table.comment: sql.append(f"""COMMENT ON TABLE "{table.name}" is '{table.comment}'""") return sql def columnize(self, columns): sql = [] for name, column in columns.items(): if column.length: length = self.create_column_length(column.column_type).format( length=column.length ) else: length = "" if column.default == "": default = " DEFAULT ''" elif column.default in (0,): default = f" DEFAULT {column.default}" elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: if isinstance(column.default, (str,)) and not column.default_is_raw: default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" else: default = "" constraint = "" column_constraint = "" if column.primary: constraint = "PRIMARY KEY" if column.column_type == "enum": values = ", ".join(f"'{x}'" for x in column.values) column_constraint = f" CHECK({column.name} IN ({values}))" sql.append( self.columnize_string() .format( name=self.wrap_column(column.name), data_type=self.type_map.get(column.column_type, ""), column_constraint=column_constraint, length=length, constraint=constraint, nullable=self.premapped_nulls.get(column.is_null) or "", default=default, ) .strip() ) return sql def compile_alter_sql(self, table): sql = [] if table.added_columns: add_columns = [] for name, column in table.get_added_columns().items(): if column.length: length = self.create_column_length(column.column_type).format( length=column.length ) else: length = "" default = "" if column.default in (0,): default = f" DEFAULT {column.default}" elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif 
column.default: if isinstance(column.default, (str,)): default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" else: default = "" column_constraint = "" if column.column_type == "enum": values = ", ".join(f"'{x}'" for x in column.values) column_constraint = f" CHECK({column.name} IN ({values}))" add_columns.append( self.add_column_string() .format( name=self.wrap_column(column.name), data_type=self.type_map.get(column.column_type, ""), length=length, constraint="PRIMARY KEY" if column.primary else "", column_constraint=column_constraint, nullable="NULL" if column.is_null else "NOT NULL", default=default, after=( (" AFTER " + self.wrap_column(column._after)) if column._after else "" ), ) .strip() ) sql.append( self.alter_format().format( table=self.wrap_table(table.name), columns=", ".join(add_columns).strip(), ) ) if table.renamed_columns: renamed_sql = [] for name, column in table.get_renamed_columns().items(): if column.length: length = self.create_column_length(column.column_type).format( length=column.length ) else: length = "" renamed_sql.append( self.rename_column_string() .format( to=self.wrap_column(column.name), old=self.wrap_column(name) ) .strip() ) sql.append( self.alter_format().format( table=self.wrap_table(table.name), columns=", ".join(renamed_sql).strip(), ) ) if table.dropped_columns: dropped_sql = [] for name in table.get_dropped_columns(): dropped_sql.append( self.drop_column_string() .format(name=self.wrap_column(name)) .strip() ) sql.append( self.alter_format().format( table=self.wrap_table(table.name), columns=", ".join(dropped_sql) ) ) if table.changed_columns: changed_sql = [] for name, column in table.changed_columns.items(): column_constraint = "" if column.column_type == "enum": values = ", ".join(f"'{x}'" for x in column.values) column_constraint = f" CHECK({column.name} IN ({values}))" changed_sql.append( self.modify_column_string() .format( name=self.wrap_column(name), 
data_type=self.type_map.get(column.column_type), column_constraint=column_constraint, constraint="PRIMARY KEY" if column.primary else "", length=( "(" + str(column.length) + ")" if column.column_type not in self.types_without_lengths else "" ), ) .strip() ) if column.is_null: changed_sql.append( f"ALTER COLUMN {self.wrap_column(name)} DROP NOT NULL" ) else: changed_sql.append( f"ALTER COLUMN {self.wrap_column(name)} SET NOT NULL" ) if column.default is not None: changed_sql.append( f"ALTER COLUMN {self.wrap_column(name)} SET DEFAULT {column.default}" ) sql.append( self.alter_format().format( table=self.wrap_table(table.name), columns=", ".join(changed_sql) ) ) if table.added_foreign_keys: for ( column, foreign_key_constraint, ) in table.get_added_foreign_keys().items(): cascade = "" if foreign_key_constraint.delete_action: cascade += f" ON DELETE {self.foreign_key_actions.get(foreign_key_constraint.delete_action.lower())}" if foreign_key_constraint.update_action: cascade += f" ON UPDATE {self.foreign_key_actions.get(foreign_key_constraint.update_action.lower())}" sql.append( f"ALTER TABLE {self.wrap_table(table.name)} ADD " + self.get_foreign_key_constraint_string().format( column=self.wrap_column(column), constraint_name=foreign_key_constraint.constraint_name, table=self.wrap_table(table.name), foreign_table=self.wrap_table( foreign_key_constraint.foreign_table ), foreign_column=self.wrap_column( foreign_key_constraint.foreign_column ), cascade=cascade, ) ) if table.removed_indexes: constraints = table.removed_indexes for constraint in constraints: sql.append(f"DROP INDEX {constraint}") if ( table.dropped_foreign_keys or table.removed_unique_indexes or table.dropped_primary_keys ): constraints = table.dropped_foreign_keys constraints += table.removed_unique_indexes constraints += table.dropped_primary_keys for constraint in constraints: sql.append( f"ALTER TABLE {self.wrap_table(table.name)} DROP CONSTRAINT {constraint}" ) if table.added_indexes: for name, index 
in table.added_indexes.items(): sql.append( "CREATE INDEX {name} ON {table}({column})".format( name=index.name, table=self.wrap_table(table.name), column=",".join(index.column), ) ) if table.added_constraints: for name, constraint in table.added_constraints.items(): if constraint.constraint_type == "unique": sql.append( f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} UNIQUE({','.join(constraint.columns)})" ) elif constraint.constraint_type == "primary_key": sql.append( f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} PRIMARY KEY ({','.join(constraint.columns)})" ) for name, column in table.get_added_columns().items(): if column.comment: sql.append( f"""COMMENT ON COLUMN {self.wrap_table(table.name)}.{self.wrap_column(name)} is '{column.comment}'""" ) if table.comment: sql.append( f"""COMMENT ON TABLE {self.wrap_table(table.name)} is '{table.comment}'""" ) return sql def alter_format(self): return "ALTER TABLE {table} {columns}" def alter_format_add_foreign_key(self): return "ALTER TABLE {table} {columns}" def add_column_string(self): return "ADD COLUMN {name} {data_type}{length}{column_constraint} {nullable}{default} {constraint}" def drop_column_string(self): return "DROP COLUMN {name}" def modify_column_string(self): return "ALTER COLUMN {name} TYPE {data_type}{length}{column_constraint} {constraint}" def rename_column_string(self): return "RENAME COLUMN {old} TO {to}" def columnize_string(self): return "{name} {data_type}{length}{column_constraint} {nullable}{default} {constraint}" def constraintize(self, constraints, table): sql = [] for name, constraint in constraints.items(): sql.append( getattr( self, f"get_{constraint.constraint_type}_constraint_string" )().format( columns=", ".join(constraint.columns), name_columns="_".join(constraint.columns), constraint_name=constraint.name, table=table.name, ) ) return sql def create_format(self): return "CREATE TABLE {table} 
({columns}{constraints}{foreign_keys})" def create_if_not_exists_format(self): return ( "CREATE TABLE IF NOT EXISTS {table} ({columns}{constraints}{foreign_keys})" ) def get_foreign_key_constraint_string(self): return "CONSTRAINT {constraint_name} FOREIGN KEY ({column}) REFERENCES {foreign_table}({foreign_column}){cascade}" def get_primary_key_constraint_string(self): return "CONSTRAINT {constraint_name} PRIMARY KEY ({columns})" def get_unique_constraint_string(self): return "CONSTRAINT {constraint_name} UNIQUE ({columns})" def get_table_string(self): return '"{table}"' def get_column_string(self): return '"{column}"' def table_information_string(self): return "SELECT * FROM information_schema.columns WHERE table_schema = '{schema}' AND table_name = '{table}'" def compile_table_exists(self, table, database=None, schema=None): return f"SELECT * from information_schema.tables where table_name='{table}' AND table_schema = '{schema or 'public'}'" def compile_truncate(self, table, foreign_keys=False): if not foreign_keys: return f"TRUNCATE {self.wrap_table(table)}" return [ f"ALTER TABLE {self.wrap_table(table)} DISABLE TRIGGER ALL", f"TRUNCATE {self.wrap_table(table)}", f"ALTER TABLE {self.wrap_table(table)} ENABLE TRIGGER ALL", ] def compile_rename_table(self, current_name, new_name): return f"ALTER TABLE {self.wrap_table(current_name)} RENAME TO {self.wrap_table(new_name)}" def compile_drop_table_if_exists(self, table): return f"DROP TABLE IF EXISTS {self.wrap_table(table)}" def compile_drop_table(self, table): return f"DROP TABLE {self.wrap_table(table)}" def compile_column_exists(self, table, column): return f"SELECT column_name FROM information_schema.columns WHERE table_name='{table}' and column_name='{column}'" def compile_get_all_tables(self, database=None, schema=None): return f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_catalog = '{database}'" def get_current_schema(self, connection, table_name, schema=None): sql 
= self.table_information_string().format( table=table_name, schema=schema or "public" ) reversed_type_map = {v: k for k, v in self.type_map.items()} reversed_type_map.update(self.table_info_map) table = Table(table_name) result = connection.query(sql, ()) for column in result: column_type = reversed_type_map.get(column["data_type"].upper()) # find length if column.get("character_maximum_length", None): length = column.get("character_maximum_length") elif column.get("numeric_precision", None): length = column.get("numeric_precision") elif column.get("datetime_precision", None): length = column.get("datetime_precision") else: length = None # find default default = column.get("dflt_value", "") or column.get("column_default", "") if default and default.startswith("nextval"): table.set_primary_key(column["column_name"]) default = None table.add_column( column["column_name"], column_type, default=default, column_python_type=Schema._type_hints_map.get(column_type, str), length=length, ) return table def enable_foreign_key_constraints(self): """Postgres does not allow a global way to enable foreign key constraints""" return "" def disable_foreign_key_constraints(self): """Postgres does not allow a global way to disable foreign key constraints""" return "" ================================================ FILE: src/masoniteorm/schema/platforms/SQLitePlatform.py ================================================ from ...schema import Schema from ..Table import Table from .Platform import Platform class SQLitePlatform(Platform): types_without_lengths = [ "integer", "big_integer", "tiny_integer", "small_integer", "medium_integer", ] types_without_signs = ["decimal"] type_map = { "string": "VARCHAR", "char": "CHAR", "integer": "INTEGER", "big_integer": "BIGINT", "tiny_integer": "TINYINT", "big_increments": "BIGINT", "small_integer": "SMALLINT", "medium_integer": "MEDIUMINT", "integer_unsigned": "INT UNSIGNED", "big_integer_unsigned": "BIGINT UNSIGNED", "tiny_integer_unsigned": 
"TINYINT UNSIGNED", "small_integer_unsigned": "SMALLINT UNSIGNED", "medium_integer_unsigned": "MEDIUMINT UNSIGNED", "increments": "INTEGER", "uuid": "CHAR", "binary": "LONGBLOB", "boolean": "BOOLEAN", "decimal": "DECIMAL", "double": "DOUBLE", "enum": "VARCHAR", "text": "TEXT", "tiny_text": "TEXT", "float": "FLOAT", "geometry": "GEOMETRY", "json": "JSON", "jsonb": "LONGBLOB", "inet": "VARCHAR", "cidr": "VARCHAR", "macaddr": "VARCHAR", "long_text": "LONGTEXT", "point": "POINT", "time": "TIME", "timestamp": "TIMESTAMP", "date": "DATE", "year": "VARCHAR", "datetime": "DATETIME", "tiny_increments": "TINYINT AUTO_INCREMENT", "unsigned": "INT UNSIGNED", } premapped_defaults = { "current": " DEFAULT CURRENT_TIMESTAMP", "now": " DEFAULT NOW()", "null": " DEFAULT NULL", } premapped_nulls = {True: "NULL", False: "NOT NULL"} def compile_create_sql(self, table, if_not_exists=False): sql = [] table_create_format = ( self.create_if_not_exists_format() if if_not_exists else self.create_format() ) sql.append( table_create_format.format( table=self.get_table_string().format(table=table.name).strip(), columns=", ".join(self.columnize(table.get_added_columns())).strip(), constraints=( ", " + ", ".join(self.constraintize(table.get_added_constraints())) if table.get_added_constraints() else "" ), foreign_keys=( ", " + ", ".join( self.foreign_key_constraintize( table.name, table.added_foreign_keys ) ) if table.added_foreign_keys else "" ), ) ) if table.added_indexes: for name, index in table.added_indexes.items(): sql.append( f"CREATE INDEX {index.name} ON {self.wrap_table(table.name)}({','.join(index.column)})" ) return sql def columnize(self, columns): sql = [] for name, column in columns.items(): if column.length: length = self.create_column_length(column.column_type).format( length=column.length ) else: length = "" if column.default == "": default = " DEFAULT ''" elif column.default in (0,): default = f" DEFAULT {column.default}" elif column.default in self.premapped_defaults.keys(): 
default = self.premapped_defaults.get(column.default) elif column.default: if isinstance(column.default, (str,)) and not column.default_is_raw: default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" else: default = "" constraint = "" column_constraint = "" if column.primary: constraint = "PRIMARY KEY" if column.column_type == "enum": values = ", ".join(f"'{x}'" for x in column.values) column_constraint = f" CHECK({column.name} IN ({values}))" sql.append( self.columnize_string() .format( name=self.wrap_column(column.name), data_type=self.type_map.get(column.column_type, ""), column_constraint=column_constraint, length=length, signed=( " " + self.signed.get(column._signed) if column.column_type not in self.types_without_signs and column._signed else "" ), constraint=constraint, nullable=self.premapped_nulls.get(column.is_null) or "", default=default, ) .strip() ) return sql def compile_alter_sql(self, diff): sql = [] if diff.removed_indexes or diff.removed_unique_indexes: indexes = diff.removed_indexes indexes += diff.removed_unique_indexes for name in indexes: sql.append("DROP INDEX {name}".format(name=name)) if diff.added_columns: for name, column in diff.added_columns.items(): default = "" if column.default in (0,): default = f" DEFAULT {column.default}" elif column.default in self.premapped_defaults.keys(): default = self.premapped_defaults.get(column.default) elif column.default: if isinstance(column.default, (str,)): default = f" DEFAULT '{column.default}'" else: default = f" DEFAULT {column.default}" else: default = "" constraint = "" column_constraint = "" if column.name in diff.added_foreign_keys: foreign_key = diff.added_foreign_keys[column.name] constraint = f" REFERENCES {self.wrap_table(foreign_key.foreign_table)}({self.wrap_column(foreign_key.foreign_column)})" if column.column_type == "enum": values = ", ".join(f"'{x}'" for x in column.values) column_constraint = f" CHECK('{column.name}' IN({values}))" sql.append( 
self.add_column_string() .format( table=self.wrap_table(diff.name), name=self.wrap_column(column.name), data_type=self.type_map.get(column.column_type, ""), column_constraint=column_constraint, nullable="NULL" if column.is_null else "NOT NULL", default=default, signed=( " " + self.signed.get(column._signed) if column.column_type not in self.types_without_signs and column._signed else "" ), constraint=constraint, ) .strip() ) if ( diff.renamed_columns or diff.dropped_columns or diff.changed_columns or diff.added_foreign_keys ): original_columns = diff.from_table.added_columns # pop off the dropped columns. No need for them here for column in diff.dropped_columns: original_columns.pop(column) sql.append( "CREATE TEMPORARY TABLE __temp__{table} AS SELECT {original_column_names} FROM {table}".format( table=diff.name, original_column_names=", ".join( diff.from_table.added_columns.keys() ), ) ) sql.append("DROP TABLE {table}".format(table=self.wrap_table(diff.name))) columns = diff.from_table.added_columns columns.update(diff.renamed_columns) columns.update(diff.changed_columns) columns.update(diff.added_columns) sql.append( self.create_format().format( table=self.get_table_string().format(table=diff.name).strip(), columns=", ".join(self.columnize(columns)).strip(), constraints=( ", " + ", ".join(self.constraintize(diff.get_added_constraints())) if diff.get_added_constraints() else "" ), foreign_keys=( ", " + ", ".join( self.foreign_key_constraintize( diff.name, diff.added_foreign_keys ) ) if diff.added_foreign_keys else "" ), ) ) for column in diff.added_columns: columns.pop(column) sql.append( "INSERT INTO {quoted_table} ({new_columns}) SELECT {original_column_names} FROM __temp__{table}".format( quoted_table=self.wrap_table(diff.name), table=diff.name, new_columns=", ".join(self.columnize_names(columns)), original_column_names=", ".join( diff.from_table.added_columns.keys() ), ) ) sql.append("DROP TABLE __temp__{table}".format(table=diff.name)) if diff.new_name: 
sql.append( "ALTER TABLE {old_name} RENAME TO {new_name}".format( old_name=self.wrap_table(diff.name), new_name=self.wrap_table(diff.new_name), ) ) if diff.added_indexes: for name, index in diff.added_indexes.items(): sql.append( f"CREATE INDEX {index.name} ON {self.wrap_table(diff.name)}({','.join(index.column)})" ) if diff.added_constraints: for name, constraint in diff.added_constraints.items(): if constraint.constraint_type == "unique": sql.append( f"CREATE UNIQUE INDEX {constraint.name} ON {self.wrap_table(diff.name)}({','.join(constraint.columns if isinstance(constraint.columns, list) else [constraint.columns])})" ) elif constraint.constraint_type == "primary_key": sql.append( f"ALTER TABLE {self.wrap_table(diff.name)} ADD CONSTRAINT {constraint.name} PRIMARY KEY ({','.join(constraint.columns)})" ) return sql def create_format(self): return "CREATE TABLE {table} ({columns}{constraints}{foreign_keys})" def create_if_not_exists_format(self): return ( "CREATE TABLE IF NOT EXISTS {table} ({columns}{constraints}{foreign_keys})" ) def get_table_string(self): return '"{table}"' def get_column_string(self): return '"{column}"' def add_column_string(self): return "ALTER TABLE {table} ADD COLUMN {name} {data_type}{column_constraint}{signed} {nullable}{default}{constraint}" def create_column_length(self, column_type): if column_type in self.types_without_lengths: return "" return "({length})" def columnize_string(self): return "{name} {data_type}{length}{column_constraint}{signed} {nullable}{default} {constraint}" def get_unique_constraint_string(self): return "UNIQUE({columns})" def get_foreign_key_constraint_string(self): return "CONSTRAINT {constraint_name} FOREIGN KEY ({column}) REFERENCES {foreign_table}({foreign_column}){cascade}" def get_primary_key_constraint_string(self): return "CONSTRAINT {constraint_name} PRIMARY KEY ({columns})" def constraintize(self, constraints): sql = [] for name, constraint in constraints.items(): sql.append( getattr( self, 
f"get_{constraint.constraint_type}_constraint_string" )().format( columns=", ".join(constraint.columns), constraint_name=constraint.name, ) ) return sql def foreign_key_constraintize(self, table, foreign_keys): sql = [] for name, foreign_key in foreign_keys.items(): cascade = "" if foreign_key.delete_action: cascade += f" ON DELETE {self.foreign_key_actions.get(foreign_key.delete_action.lower())}" if foreign_key.update_action: cascade += f" ON UPDATE {self.foreign_key_actions.get(foreign_key.update_action.lower())}" sql.append( self.get_foreign_key_constraint_string().format( column=self.wrap_column(foreign_key.column), constraint_name=foreign_key.constraint_name, table=self.wrap_table(table), foreign_table=self.wrap_table(foreign_key.foreign_table), foreign_column=self.wrap_column(foreign_key.foreign_column), cascade=cascade, ) ) return sql def columnize_names(self, columns): names = [] for name, column in columns.items(): names.append(self.wrap_column(column.name)) return names def get_current_schema(self, connection, table_name, schema=None): sql = f"PRAGMA table_info({table_name})" reversed_type_map = {v: k for k, v in self.type_map.items()} table = Table(table_name) result = connection.query(sql, ()) for column in result: column_type = self.get_column_type( reversed_type_map, column["type"].upper() ) length = self.get_column_length(column["type"]) # find default default = column.get("dflt_value") if default: default = default.replace("'", "") table.add_column( column["name"], column_type, column_python_type=Schema._type_hints_map.get(column_type, str), default=default, length=length, nullable=int(column.get("notnull")) == 0, ) if column.get("pk") == 1: table.set_primary_key(column["name"]) return table def get_column_length(self, column_type): if "(" in column_type: parenthesis_index = column_type.find("(") return column_type[parenthesis_index + 1 : -1] else: return None def get_column_type(self, reversed_type_map, column_type): if "(" in column_type: 
parenthesis_index = column_type.find("(") db_type = column_type[:parenthesis_index] length = self.get_column_length(column_type) if db_type == "CHAR": if length == "1": return "char" elif length == "36": return "uuid" else: return "char" elif db_type == "VARCHAR": if length == "4": return "year" else: return "string" else: return reversed_type_map.get(column_type) def compile_table_exists(self, table, database=None, schema=None): return f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table}'" def compile_column_exists(self, table, column): return f"SELECT column_name FROM information_schema.columns WHERE table_name='{table}' and column_name='{column}'" def compile_get_all_tables(self, database, schema=None): return "SELECT name FROM sqlite_master WHERE type='table'" def compile_truncate(self, table, foreign_keys=False): if not foreign_keys: return f"DELETE FROM {self.wrap_table(table)}" return [ self.disable_foreign_key_constraints(), f"DELETE FROM {self.wrap_table(table)}", self.enable_foreign_key_constraints(), ] def compile_rename_table(self, current_table, new_name): return f"ALTER TABLE {self.wrap_table(current_table)} RENAME TO {self.wrap_table(new_name)}" def compile_drop_table_if_exists(self, current_table): return f"DROP TABLE IF EXISTS {self.wrap_table(current_table)}" def compile_drop_table(self, current_table): return f"DROP TABLE {self.wrap_table(current_table)}" def enable_foreign_key_constraints(self): return "PRAGMA foreign_keys = ON" def disable_foreign_key_constraints(self): return "PRAGMA foreign_keys = OFF" ================================================ FILE: src/masoniteorm/schema/platforms/__init__.py ================================================ from .SQLitePlatform import SQLitePlatform from .MySQLPlatform import MySQLPlatform from .MSSQLPlatform import MSSQLPlatform from .PostgresPlatform import PostgresPlatform ================================================ FILE: src/masoniteorm/scopes/BaseScope.py 
================================================ class BaseScope: def on_boot(self, builder): raise NotImplementedError() def on_remove(self, builder): raise NotImplementedError() ================================================ FILE: src/masoniteorm/scopes/SoftDeleteScope.py ================================================ from .BaseScope import BaseScope class SoftDeleteScope(BaseScope): """Global scope class to add soft deleting to models.""" def __init__(self, deleted_at_column="deleted_at"): self.deleted_at_column = deleted_at_column def on_boot(self, builder): builder.set_global_scope("_where_null", self._where_null, action="select") builder.set_global_scope( "_query_set_null_on_delete", self._query_set_null_on_delete, action="delete" ) builder.macro("with_trashed", self._with_trashed) builder.macro("only_trashed", self._only_trashed) builder.macro("force_delete", self._force_delete) builder.macro("restore", self._restore) def on_remove(self, builder): builder.remove_global_scope("_where_null", action="select") builder.remove_global_scope("_query_set_null_on_delete", action="delete") def _where_null(self, builder): return builder.where_null( f"{builder.get_table_name()}.{self.deleted_at_column}" ) def _with_trashed(self, model, builder): builder.remove_global_scope("_where_null", action="select") return builder def _only_trashed(self, model, builder): builder.remove_global_scope("_where_null", action="select") return builder.where_not_null(self.deleted_at_column) def _force_delete(self, model, builder, query=False): if query: return builder.remove_global_scope(self).set_action("delete") return builder.remove_global_scope(self).delete() def _restore(self, model, builder): return builder.remove_global_scope(self).update({self.deleted_at_column: None}) def _query_set_null_on_delete(self, builder): return builder.set_action("update").set_updates( {self.deleted_at_column: builder._model.get_new_datetime_string()} ) ================================================ 
FILE: src/masoniteorm/scopes/SoftDeletesMixin.py ================================================ from .SoftDeleteScope import SoftDeleteScope class SoftDeletesMixin: """Global scope class to add soft deleting to models.""" __deleted_at__ = "deleted_at" def boot_SoftDeletesMixin(self, builder): builder.set_global_scope(SoftDeleteScope(self.__deleted_at__)) def get_deleted_at_column(self): return self.__deleted_at__ ================================================ FILE: src/masoniteorm/scopes/TimeStampsMixin.py ================================================ from .TimeStampsScope import TimeStampsScope class TimeStampsMixin: """Global scope class to add soft deleting to models.""" def boot_TimeStampsMixin(self, builder): builder.set_global_scope(TimeStampsScope()) def activate_timestamps(self, boolean=True): self.__timestamps__ = boolean return self ================================================ FILE: src/masoniteorm/scopes/TimeStampsScope.py ================================================ from ..expressions.expressions import UpdateQueryExpression from .BaseScope import BaseScope class TimeStampsScope(BaseScope): """Global scope class to add soft deleting to models.""" def on_boot(self, builder): builder.set_global_scope( "_timestamps", self.set_timestamp_create, action="insert" ) builder.set_global_scope( "_timestamp_update", self.set_timestamp_update, action="update" ) def on_remove(self, builder): pass def set_timestamp(owner_cls, query): owner_cls.updated_at = "now" def set_timestamp_create(self, builder): if not builder._model.__timestamps__: return builder builder._creates.update( { builder._model.date_updated_at: builder._model.get_new_date().to_datetime_string(), builder._model.date_created_at: builder._model.get_new_date().to_datetime_string(), } ) def set_timestamp_update(self, builder): if not builder._model.__timestamps__: return builder for update in builder._updates: if builder._model.date_updated_at in update.column: return builder._updates += ( 
UpdateQueryExpression( { builder._model.date_updated_at: builder._model.get_new_date().to_datetime_string() } ), ) ================================================ FILE: src/masoniteorm/scopes/UUIDPrimaryKeyMixin.py ================================================ from .UUIDPrimaryKeyScope import UUIDPrimaryKeyScope class UUIDPrimaryKeyMixin: """Global scope class to add UUID as primary key to models.""" def boot_UUIDPrimaryKeyMixin(self, builder): builder.set_global_scope(UUIDPrimaryKeyScope()) ================================================ FILE: src/masoniteorm/scopes/UUIDPrimaryKeyScope.py ================================================ import uuid from .BaseScope import BaseScope class UUIDPrimaryKeyScope(BaseScope): """Global scope class to use UUID4 as primary key.""" def on_boot(self, builder): builder.set_global_scope( "_UUID_primary_key", self.set_uuid_create, action="insert" ) builder.set_global_scope( "_UUID_primary_key", self.set_bulk_uuid_create, action="bulk_create" ) def on_remove(self, builder): pass def generate_uuid(self, builder, uuid_version, bytes=False): # UUID 3 and 5 requires parameters uuid_func = getattr(uuid, f"uuid{uuid_version}") args = [] if uuid_version in [3, 5]: args = [builder._model.__uuid_namespace__, builder._model.__uuid_name__] return uuid_func(*args).bytes if bytes else str(uuid_func(*args)) def build_uuid_pk(self, builder): uuid_version = getattr(builder._model, "__uuid_version__", 4) uuid_bytes = getattr(builder._model, "__uuid_bytes__", False) return { builder._model.__primary_key__: self.generate_uuid( builder, uuid_version, uuid_bytes ) } def set_uuid_create(self, builder): # if there is already a primary key, no need to set a new one if builder._model.__primary_key__ not in builder._creates: builder._creates.update(self.build_uuid_pk(builder)) def set_bulk_uuid_create(self, builder): for idx, create_atts in enumerate(builder._creates): if builder._model.__primary_key__ not in create_atts: 
builder._creates[idx].update(self.build_uuid_pk(builder)) ================================================ FILE: src/masoniteorm/scopes/__init__.py ================================================ from .scope import scope from .BaseScope import BaseScope from .SoftDeletesMixin import SoftDeletesMixin from .SoftDeleteScope import SoftDeleteScope from .TimeStampsMixin import TimeStampsMixin from .TimeStampsScope import TimeStampsScope from .UUIDPrimaryKeyScope import UUIDPrimaryKeyScope from .UUIDPrimaryKeyMixin import UUIDPrimaryKeyMixin ================================================ FILE: src/masoniteorm/scopes/scope.py ================================================ class scope: def __init__(self, callback, *params, **kwargs): self.fn = callback def __set_name__(self, cls, name): if cls not in cls._scopes: cls._scopes[cls] = {name: self.fn} else: cls._scopes[cls].update({name: self.fn}) self.cls = cls def __call__(self, *args, **kwargs): instantiated = self.cls() builder = instantiated.get_builder() return self.fn(instantiated, builder, *args, **kwargs) ================================================ FILE: src/masoniteorm/seeds/Seeder.py ================================================ import pydoc class Seeder: def __init__(self, dry=False, seed_path="databases/seeds", connection=None): self.ran_seeds = [] self.dry = dry self.seed_path = seed_path self.connection = connection self.seed_module = seed_path.replace("/", ".").replace("\\", ".") def call(self, *seeder_classes): for seeder_class in seeder_classes: self.ran_seeds.append(seeder_class) if not self.dry: seeder_class(connection=self.connection).run() def run_database_seed(self): database_seeder = pydoc.locate( f"{self.seed_module}.database_seeder.DatabaseSeeder" ) self.ran_seeds.append(database_seeder) if not self.dry: database_seeder(connection=self.connection).run() def run_specific_seed(self, seed): file_name = f"{self.seed_module}.{seed}" database_seeder = pydoc.locate(file_name) if not 
database_seeder: raise ValueError(f"Could not find the {file_name} seeder file") self.ran_seeds.append(database_seeder) if not self.dry: database_seeder(connection=self.connection).run() else: print(f"Running {database_seeder}") ================================================ FILE: src/masoniteorm/seeds/__init__.py ================================================ from .Seeder import Seeder ================================================ FILE: src/masoniteorm/stubs/create-migration.html ================================================ from masoniteorm.migrations import Migration class CreateUsersTable(Migration): def up(self): """Run the migrations.""" with self.schema.create('users') as table: pass def down(self): """Revert the migrations.""" self.schema.drop('users') ================================================ FILE: src/masoniteorm/stubs/table-migration.html ================================================ ================================================ FILE: src/masoniteorm/testing/BaseTestCaseSelectGrammar.py ================================================ import inspect from ..query import QueryBuilder from ..expressions import JoinClause from ..models import Model class MockConnection: connection_details = {} def make_connection(self): return self class BaseTestCaseSelectGrammar: def setUp(self): self.builder = QueryBuilder( self.grammar, table="users", connection_class=MockConnection, model=Model(), dry=True, ) def test_can_compile_select(self): to_sql = self.builder.to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_order_by_and_first(self): to_sql = self.builder.order_by("id", "asc").first(query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_with_columns(self): to_sql = self.builder.select("username", "password").to_sql() sql = getattr( self, 
inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_with_where(self): to_sql = self.builder.select("username", "password").where("id", 1).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_or_where(self): to_sql = self.builder.where("name", 2).or_where("name", 3).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_grouped_where(self): to_sql = self.builder.where( lambda query: query.where("age", 2).where("name", "Joe") ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_with_several_where(self): to_sql = ( self.builder.select("username", "password") .where("id", 1) .where("username", "joe") .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_with_several_where_and_limit(self): to_sql = ( self.builder.select("username", "password") .where("id", 1) .where("username", "joe") .limit(10) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_with_sum(self): to_sql = self.builder.sum("age").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_with_max(self): to_sql = self.builder.max("age").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_with_max_and_columns(self): to_sql = self.builder.select("username").max("age").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def 
test_can_compile_with_max_and_columns_different_order(self): to_sql = self.builder.max("age").select("username").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_with_order_by(self): to_sql = self.builder.select("username").order_by("age", "desc").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_with_multiple_order_by(self): to_sql = ( self.builder.select("username") .order_by("age", "desc") .order_by("name") .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_with_group_by(self): to_sql = self.builder.select("username").group_by("age").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_where_in(self): to_sql = self.builder.select("username").where_in("age", [1, 2, 3]).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_where_in_empty(self): to_sql = self.builder.where_in("age", []).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_where_not_in(self): to_sql = self.builder.select("username").where_not_in("age", [1, 2, 3]).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_where_null(self): to_sql = self.builder.select("username").where_null("age").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_where_not_null(self): to_sql = self.builder.select("username").where_not_null("age").to_sql() sql = getattr( self, 
inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_count(self): to_sql = self.builder.count("*").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_count_column(self): to_sql = self.builder.count("money").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_where_column(self): to_sql = self.builder.where_column("name", "email").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_sub_select(self): to_sql = self.builder.where_in( "name", self.builder.new().select("age") ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_complex_sub_select(self): to_sql = self.builder.where_in( "name", ( self.builder.new() .select("age") .where_in("email", self.builder.new().select("email")) ), ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_sub_select_where(self): to_sql = self.builder.where_in( "age", self.builder.new().select("age").where("age", 2).where("name", "Joe") ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_sub_select_from_lambda(self): to_sql = ( self.builder.new() .where_in( "age", lambda q: (q.select("age").where("age", 2).where("name", "Joe")) ) .to_sql() ) sql = getattr(self, "can_compile_sub_select_where")() self.assertEqual(to_sql, sql) def test_can_compile_sub_select_value(self): to_sql = self.builder.where("name", self.builder.new().sum("age")).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() 
self.assertEqual(to_sql, sql) def test_can_compile_exists(self): to_sql = ( self.builder.select("age") .where_exists(self.builder.new().select("username").where("age", 12)) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_not_exists(self): to_sql = ( self.builder.select("age") .where_not_exists(self.builder.new().select("username").where("age", 12)) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_having(self): to_sql = self.builder.sum("age").group_by("age").having("age").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_having_with_expression(self): to_sql = self.builder.sum("age").group_by("age").having("age", 10).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_having_with_greater_than_expression(self): to_sql = self.builder.sum("age").group_by("age").having("age", ">", 10).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_join(self): to_sql = self.builder.join( "contacts", "users.id", "=", "contacts.user_id" ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_join_clause(self): clause = ( JoinClause("report_groups as rg") .on("bgt.fund", "=", "rg.fund") .on("bgt.dept", "=", "rg.dept") .on("bgt.acct", "=", "rg.acct") .on("bgt.sub", "=", "rg.sub") ) to_sql = self.builder.join(clause).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_join_clause_with_value(self): clause = ( JoinClause("report_groups as rg") 
.on_value("bgt.active", "=", "1") .or_on_value("bgt.acct", "=", "1234") ) to_sql = self.builder.join(clause).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_join_clause_with_null(self): clause = ( JoinClause("report_groups as rg") .on_null("bgt.acct") .or_on_null("bgt.dept") .on_value("rg.abc", 10) ) to_sql = self.builder.join(clause).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_join_clause_with_not_null(self): clause = ( JoinClause("report_groups as rg") .on_not_null("bgt.acct") .or_on_not_null("bgt.dept") .on_value("rg.abc", 10) ) to_sql = self.builder.join(clause).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_join_clause_with_lambda(self): to_sql = self.builder.join( "report_groups as rg", lambda clause: (clause.on("bgt.fund", "=", "rg.fund").on_null("bgt")), ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_left_join_clause_with_lambda(self): to_sql = self.builder.left_join( "report_groups as rg", lambda clause: (clause.on("bgt.fund", "=", "rg.fund").or_on_null("bgt")), ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_right_join_clause_with_lambda(self): to_sql = self.builder.right_join( "report_groups as rg", lambda clause: (clause.on("bgt.fund", "=", "rg.fund").or_on_null("bgt")), ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_left_join(self): to_sql = self.builder.left_join( "contacts", "users.id", "=", "contacts.user_id" ).to_sql() sql = getattr( self, 
inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_multiple_join(self): to_sql = ( self.builder.join("contacts", "users.id", "=", "contacts.user_id") .join("posts", "comments.post_id", "=", "posts.id") .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_limit_and_offset(self): to_sql = self.builder.limit(10).offset(10).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_between(self): to_sql = self.builder.between("age", 18, 21).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_not_between(self): to_sql = self.builder.not_between("age", 18, 21).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_user_where_raw_and_where(self): to_sql = ( self.builder.where_raw("age = '18'").where("name", "=", "James").to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_where_raw_and_where_with_multiple_bindings(self): query = self.builder.where_raw( "`age` = '?' 
AND `is_admin` = '?'", [18, True] ).where("email", "test@example.com") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(query.to_qmark(), sql) self.assertEqual(query._bindings, [18, True, "test@example.com"]) def test_can_compile_first_or_fail(self): to_sql = ( self.builder.where("is_admin", "=", True).first_or_fail(query=True).to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_where_like(self): to_sql = self.builder.where("age", "like", "%name%").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_where_regexp(self): to_sql = self.builder.where("age", "regexp", "Joe").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_where_exists_with_lambda(self): to_sql = self.builder.where_exists(lambda q: q.where("age", 1)).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() print(to_sql) self.assertEqual(to_sql, sql) def test_where_not_exists_with_lambda(self): to_sql = self.builder.where_not_exists(lambda q: q.where("age", 1)).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() print(to_sql) self.assertEqual(to_sql, sql) def test_where_not_regexp(self): to_sql = self.builder.where("age", "not regexp", "Joe").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_where_not_like(self): to_sql = self.builder.where("age", "not like", "%name%").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_shared_lock(self): to_sql = self.builder.where("votes", ">=", 100).shared_lock().to_sql() sql = getattr( self, 
inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_update_lock(self): to_sql = self.builder.where("votes", ">=", 100).lock_for_update().to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_where_date(self): to_sql = self.builder.where_date("created_at", "2022-06-01").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_or_where_null(self): to_sql = self.builder.where_null("column1").or_where_null("column2").to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_select_distinct(self): to_sql = self.builder.select("group").distinct().to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) ================================================ FILE: src/masoniteorm/testing/__init__.py ================================================ from .BaseTestCaseSelectGrammar import BaseTestCaseSelectGrammar ================================================ FILE: tests/User.py ================================================ """ User Model """ from src.masoniteorm import Model class User(Model): """User Model""" __fillable__ = ["name", "email", "password"] __connection__ = "t" __auth__ = "email" @property def meta(self): return 1 ================================================ FILE: tests/collection/test_collection.py ================================================ import unittest from src.masoniteorm.collection import Collection from src.masoniteorm.factories import Factory as factory from src.masoniteorm.models import Model from tests.User import User class TestCollection(unittest.TestCase): def test_take(self): collection = Collection([1, 2, 3, 4]) self.assertEqual(collection.take(2), [1, 2]) def test_first(self): collection = 
Collection([1, 2, 3, 4]) self.assertEqual(collection.first(), 1) self.assertEqual(collection.last(), 4) self.assertEqual(collection.first(lambda x: x < 3), 1) def test_last(self): collection = Collection([1, 2, 3, 4]) self.assertEqual(collection.last(), 4) self.assertEqual(collection.last(lambda x: x < 3), 2) def test_pluck(self): collection = Collection([{"id": 1, "name": "Joe"}, {"id": 2, "name": "Bob"}]) self.assertEqual(collection.pluck("id"), [1, 2]) self.assertEqual(collection.pluck("id").serialize(), [1, 2]) self.assertEqual(collection.pluck("name", "id"), {1: "Joe", 2: "Bob"}) def test_pluck_with_models(self): factory.register(Model, lambda faker: {"id": 1, "batch": 1}) collection = factory(Model, 5).make() self.assertEqual(collection.pluck("batch"), [1, 1, 1, 1, 1]) def test_where(self): collection = Collection( [ {"id": 1, "name": "Joe"}, {"id": 2, "name": "Joe"}, {"id": 3, "name": "Bob"}, ] ) self.assertEqual(len(collection.where("name", "Joe")), 2) self.assertEqual(len(collection.where("id", "!=", 1)), 2) self.assertEqual(len(collection.where("id", ">", 1)), 2) self.assertEqual(len(collection.where("id", ">=", 1)), 3) self.assertEqual(len(collection.where("id", "<=", 1)), 1) self.assertEqual(len(collection.where("id", "<", 3)), 2) def test_where_in(self): collection = Collection( [ {"id": 1, "name": "Joe"}, {"id": 2, "name": "Joe"}, {"id": 3, "name": "Bob"}, ] ) self.assertEqual(len(collection.where_in("id", [1, 2])), 2) self.assertEqual(len(collection.where_in("id", [3])), 1) self.assertEqual(len(collection.where_in("id", [4])), 0) self.assertEqual(len(collection.where_in("id", ["1", "2"])), 2) self.assertEqual(len(collection.where_in("id", ["3"])), 1) self.assertEqual(len(collection.where_in("id", ["4"])), 0) self.assertEqual(len(collection.where_in("name", ["Joe"])), 2) def test_where_not_in(self): collection = Collection( [ {"id": 1, "name": "Joe"}, {"id": 2, "name": "Joe"}, {"id": 3, "name": "Bob"}, ] ) 
self.assertEqual(len(collection.where_not_in("id", [1, 2])), 1) self.assertEqual(len(collection.where_not_in("id", [3])), 2) self.assertEqual(len(collection.where_not_in("id", [4])), 3) self.assertEqual(len(collection.where_not_in("id", ["1", "2"])), 1) self.assertEqual(len(collection.where_not_in("id", ["3"])), 2) self.assertEqual(len(collection.where_not_in("id", ["4"])), 3) self.assertEqual(len(collection.where_not_in("name", ["Joe"])), 1) def test_where_in_bool(self): nested_collection = Collection( [ {"id": 1, "is_active": True}, {"id": 2, "is_active": True}, {"id": 3, "is_active": True}, {"id": 4}, ] ) self.assertEqual(len(nested_collection.where_in("is_active", [False])), 0) self.assertEqual(len(nested_collection.where_in("is_active", [True])), 3) self.assertEqual(len(nested_collection.where_in("is_active", [True, False])), 3) obj_collection = Collection( [ type("", (), {"is_active": True, "is_disabled": False}), type("", (), {"is_active": False, "is_disabled": True}), type("", (), {"is_active": True, "is_disabled": True}), ] ) self.assertEqual(len(obj_collection.where_in("is_active", [False])), 1) self.assertEqual(len(obj_collection.where_in("is_active", [True])), 2) self.assertEqual(len(obj_collection.where_in("is_active", [True, False])), 3) self.assertEqual(len(obj_collection.where_in("nonexistent_key", [False])), 0) self.assertEqual(len(obj_collection.where_in("nonexistent_key", [True])), 0) def test_where_in_bytes(self): byte_strs = [bytes("should find this", "utf-8"), bytes("and this", "utf-8")] collection = Collection( [ {"id": 1, "name": "Joe", "bytes_val": byte_strs[0]}, {"id": 2, "name": "Joe", "bytes_val": byte_strs[1]}, { "id": 3, "name": "Bob", "bytes_val": bytes("should not find", "utf-8"), }, {"id": 4, "name": "Bob"}, ] ) self.assertEqual(len(collection.where_in("bytes_val", byte_strs)), 2) self.assertEqual(len(collection.where_in("bytes_val", [byte_strs[0]])), 1) def test_pop(self): collection = Collection([1, 2, 3]) 
self.assertEqual(collection.pop(), 3) self.assertEqual(collection.all(), [1, 2]) def test_is_empty(self): collection = Collection([]) self.assertEqual(collection.is_empty(), True) collection = Collection([1, 2, 3]) self.assertEqual(collection.is_empty(), False) def test_sum(self): collection = Collection([1, 1, 2, 4]) self.assertEqual(collection.sum(), 8) collection = Collection( [ {"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}, ] ) self.assertEqual(collection.sum("age"), 10) self.assertEqual(collection.sum(), 0) collection = Collection( [ {"name": "chair", "colours": ["green", "black"]}, {"name": "desk", "colours": ["red", "yellow"]}, {"name": "bookcase", "colours": ["white"]}, ] ) self.assertEqual(collection.sum(lambda x: len(x["colours"])), 5) self.assertEqual(collection.sum(lambda x: len(x)), 6) def test_avg(self): collection = Collection([1, 1, 2, 4]) self.assertEqual(collection.avg(), 2) collection = Collection( [ {"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}, ] ) self.assertEqual(collection.avg("age"), 2.5) self.assertEqual(collection.avg(), 0) collection = Collection( [ {"name": "chair", "colours": ["green", "black"]}, {"name": "desk", "colours": ["red", "yellow"]}, {"name": "bookcase", "colours": ["white"]}, ] ) self.assertEqual(collection.avg(lambda x: len(x["colours"])), 5 / 3) self.assertEqual(collection.avg(lambda x: len(x)), 2) def test_max(self): collection = Collection([1, 1, 2, 4]) self.assertEqual(collection.max(), 4) collection = Collection( [ {"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}, ] ) self.assertEqual(collection.max("age"), 4) self.assertEqual(collection.max(), 0) collection = Collection([{"batch": 1}, {"batch": 1}]) 
self.assertEqual(collection.max("batch"), 1) def test_min(self): collection = Collection([1, 1, 2, 4]) self.assertEqual(collection.min(), 1) collection = Collection( [ {"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}, ] ) self.assertEqual(collection.min("age"), 1) self.assertEqual(collection.min(), 0) collection = Collection([{"batch": 1}, {"batch": 1}]) self.assertEqual(collection.min("batch"), 1) def test_count(self): collection = Collection([1, 1, 2, 4]) self.assertEqual(collection.count(), 4) collection = Collection( [{"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}] ) self.assertEqual(collection.count(), 2) def test_chunk(self): collection = Collection([1, 1, 2, 4]) chunked = collection.chunk(2) self.assertEqual(chunked, Collection([Collection([1, 1]), Collection([2, 4])])) collection = Collection( [ {"name": "chair", "colours": ["green", "black"]}, {"name": "desk", "colours": ["red", "yellow"]}, {"name": "bookcase", "colours": ["white"]}, ] ) chunked = collection.chunk(2) self.assertEqual( chunked, Collection( [ Collection( [ {"name": "chair", "colours": ["green", "black"]}, {"name": "desk", "colours": ["red", "yellow"]}, ] ), Collection([{"name": "bookcase", "colours": ["white"]}]), ] ), ) def test_collapse(self): collection = Collection([[1, 1], [2, 4]]) collapsed = collection.collapse() self.assertEqual(collapsed, Collection([1, 1, 2, 4])) def test_get(self): collection = Collection([[1, 1], [2, 4]]) self.assertEqual(collection.get(0), [1, 1]) self.assertIsNone(collection.get(2)) self.assertEqual(collection.get(2, 0), 0) def test_merge(self): collection = Collection([[1, 1], [2, 4]]) collection.merge([[2, 1]]) self.assertEqual(collection.all(), [[1, 1], [2, 4], [2, 1]]) def test_reduce(self): callback = lambda x, y: x + y collection = Collection([1, 1, 2, 4]) sum = collection.sum() reduce = collection.reduce(callback) self.assertEqual(sum, 
reduce) reduce = collection.reduce(callback, 10) self.assertEqual(10 + sum, reduce) def test_forget(self): collection = Collection([1, 2, 3, 4]) collection.forget(0) self.assertEqual(collection.all(), [2, 3, 4]) collection.forget(1, 2) self.assertEqual(collection.all(), [2]) collection.forget(0) self.assertTrue(collection.is_empty()) def test_prepend(self): collection = Collection([1, 2, 3, 4]) collection.prepend(0) self.assertEqual(collection.get(0), 0) self.assertEqual(collection.all(), [0, 1, 2, 3, 4]) def test_pull(self): collection = Collection([1, 2, 3, 4]) value = collection.pull(0) self.assertEqual(value, 1) self.assertEqual(collection.all(), [2, 3, 4]) def test_push(self): collection = Collection([1, 2, 3, 4]) collection.push(5) self.assertEqual(collection.get(4), 5) self.assertEqual(collection.all(), [1, 2, 3, 4, 5]) def test_put(self): collection = Collection([1, 2, 3, 4]) collection.put(2, 5) self.assertEqual(collection.get(2), 5) self.assertEqual(collection.all(), [1, 2, 5, 4]) def test_reject(self): collection = Collection([1, 2, 3, 4]) collection.reject(lambda x: x if x > 2 else None) self.assertEqual(collection.all(), [3, 4]) collection = Collection( [ {"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}, ] ) collection.reject(lambda x: x if x["age"] > 2 else None) self.assertEqual( Collection( [{"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}] ), collection.all(), ) collection.reject(lambda x: x["age"] if x["age"] > 2 else None) self.assertEqual(collection.all(), [3, 4]) def test_for_page(self): collection = Collection([1, 2, 3, 4]) chunked = collection.for_page(0, 3) self.assertEqual(chunked.all(), [1, 2, 3]) def test_unique(self): collection = Collection([1, 1, 2, 3, 4]) unique = collection.unique() self.assertEqual(unique.all(), [1, 2, 3, 4]) collection = Collection( [ {"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 
1}, {"name": "Corentin All", "age": 2}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}, ] ) unique = collection.unique("age") self.assertEqual( unique.all(), [ {"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}, ], ) self.assertEqual(collection.pluck("name").unique().all(), ["Corentin All"]) def test_transform(self): collection = Collection([1, 1, 2, 3, 4]) collection.transform(lambda x: x * 2) self.assertEqual(collection.all(), [2, 2, 4, 6, 8]) def test_shift(self): collection = Collection([1, 2, 3, 4]) value = collection.shift() self.assertEqual(value, 1) self.assertEqual(collection.all(), [2, 3, 4]) collection = Collection( [ {"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}, ] ) value = collection.shift() self.assertEqual(value, {"name": "Corentin All", "age": 1}) self.assertEqual( collection.all(), [ {"name": "Corentin All", "age": 2}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}, ], ) def test_sort(self): collection = Collection([4, 1, 2, 3]) collection.sort() self.assertEqual(collection.all(), [1, 2, 3, 4]) def test_reverse(self): collection = Collection([4, 1, 2, 3]) collection.reverse() self.assertEqual(collection.all(), [3, 2, 1, 4]) collection = Collection( [ {"name": "Corentin All", "age": 2}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}, ] ) collection.reverse() self.assertEqual( collection.all(), [ {"name": "Corentin All", "age": 4}, {"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 2}, ], ) def test_zip(self): collection = Collection(["Chair", "Desk"]) zipped = collection.zip([100, 200]) self.assertEqual(zipped.all(), [["Chair", 100], ["Desk", 200]]) def test_diff(self): collection = Collection(["Chair", "Desk"]) diff = collection.diff([100, 200]) self.assertEqual(diff.all(), 
["Chair", "Desk"]) def test_each(self): collection = Collection([1, 2, 3, 4]) collection.each(lambda x: x + 2) self.assertEqual(collection.all(), [x + 2 for x in range(1, 5)]) def test_every(self): collection = Collection([1, 2, 3, 4]) self.assertFalse(collection.every(lambda x: x > 2)) collection = Collection([1, 2, 3, 4]) self.assertTrue(collection.every(lambda x: x >= 1)) def test_filter(self): collection = Collection([1, 2, 3, 4]) filtered = collection.filter(lambda x: x > 2) self.assertEqual(filtered.all(), Collection([3, 4])) def test_implode(self): collection = Collection([1, 2, 3, 4]) result = collection.implode("-") self.assertEqual(result, "1-2-3-4") collection = Collection( [{"name": "Corentin"}, {"name": "Joe"}, {"name": "Marlysson"}] ) result = collection.implode(key="name") self.assertEqual(result, "Corentin,Joe,Marlysson") def test_map_into(self): collection = Collection(["USD", "EUR", "GBP"]) class Currency: def __init__(self, code): self.code = code def __eq__(self, other): return self.code == other.code currencies = collection.map_into(Currency) self.assertEqual( currencies.all(), [Currency("USD"), Currency("EUR"), Currency("GBP")] ) def test_map(self): collection = Collection([1, 2, 3, 4]) multiplied = collection.map(lambda x: x * 2) self.assertEqual(multiplied.all(), [2, 4, 6, 8]) def callback(x): x["age"] = x["age"] + 2 return x collection = Collection( [ {"name": "Corentin", "age": 10}, {"name": "Joe", "age": 20}, {"name": "Marlysson", "age": 15}, ] ) result = collection.map(callback) self.assertEqual( result.all(), [ {"name": "Corentin", "age": 12}, {"name": "Joe", "age": 22}, {"name": "Marlysson", "age": 17}, ], ) def test_serialize(self): class Currency: def __init__(self, code): self.code = code def __eq__(self, other): return self.code == other.code def to_dict(self): return {"code": self.code} collection = Collection( [ Collection([{"name": "Corentin", "age": 12}]), {"name": "Joe", "age": 22}, {"name": "Marlysson", "age": 17}, 
Currency("USD"), ] ) serialized_data = collection.serialize() self.assertEqual( serialized_data, [ [{"name": "Corentin", "age": 12}], {"name": "Joe", "age": 22}, {"name": "Marlysson", "age": 17}, {"code": "USD"}, ], ) def test_json(self): collection = Collection( [ {"name": "Corentin", "age": 10}, {"name": "Joe", "age": 20}, {"name": "Marlysson", "age": 15}, ] ) json_data = collection.to_json() self.assertEqual( json_data, '[{"name": "Corentin", "age": 10}, ' '{"name": "Joe", "age": 20}, {"name": "Marlysson", "age": 15}]', ) def test_contains(self): collection = Collection([1, 2, 3, 4]) self.assertTrue(collection.contains(3)) self.assertFalse(collection.contains(5)) collection = Collection( [ {"name": "Corentin", "age": 10}, {"name": "Joe", "age": 20}, {"name": "Marlysson", "age": 15}, ] ) self.assertTrue(collection.contains(lambda x: x["age"] == 10)) self.assertFalse(collection.contains("age")) self.assertTrue(collection.contains("age", 10)) self.assertFalse(collection.contains("age", 11)) def test_all(self): collection = Collection([1, 2, 3, 4]) self.assertEqual(collection.all(), [1, 2, 3, 4]) collection = Collection( [ {"name": "Corentin", "age": 10}, {"name": "Joe", "age": 20}, {"name": "Marlysson", "age": 15}, ] ) self.assertEqual( collection.all(), [ {"name": "Corentin", "age": 10}, {"name": "Joe", "age": 20}, {"name": "Marlysson", "age": 15}, ], ) def test_flatten(self): collection = Collection([1, 2, [3, 4, 5, {"foo": "bar"}]]) flattened = collection.flatten() self.assertEqual(flattened.all(), [1, 2, 3, 4, 5, "bar"]) def test_group_by(self): collection = Collection( [ {"name": "Corentin", "age": 10}, {"name": "Joe", "age": 10}, {"name": "Marlysson", "age": 20}, ] ) grouped = collection.group_by("age") self.assertIsInstance(grouped, Collection) self.assertEqual( grouped, { 10: [{"name": "Corentin", "age": 10}, {"name": "Joe", "age": 10}], 20: [{"name": "Marlysson", "age": 20}], }, ) def test_serialize_with_model_appends(self): User.__appends__ = ["meta"] 
users = User.all().serialize() self.assertTrue(users[0].get("meta")) def test_serialize_with_on_the_fly_appends(self): users = User.all().set_appends(["meta"]).serialize() self.assertTrue(users[0].get("meta")) def test_random(self): collection = Collection([1, 2, 3, 4]) item = collection.random() self.assertIn(item, collection) collection = Collection([]) item = collection.random() self.assertIsNone(item) collection = Collection([3]) item = collection.random() self.assertEqual(item, 3) def test_random_with_count(self): collection = Collection([1, 2, 3, 4]) items = collection.random(2) self.assertEqual(items.count(), 2) self.assertIsInstance(items, Collection) with self.assertRaises(ValueError): items = collection.random(6) items = collection.random(1) self.assertEqual(items.count(), 1) self.assertIsInstance(items, Collection) def test_make_comparison(self): collection = Collection([]) self.assertTrue(collection._make_comparison(1, 1, "==")) self.assertTrue(collection._make_comparison(1, "1", "==")) def test_eq(self): collection = Collection([1, 2, 3, 4]) other = Collection([1, 2, 3, 4]) self.assertTrue(collection == other) different = Collection([1, 2, 3]) self.assertFalse(collection == different) ================================================ FILE: tests/commands/test_shell.py ================================================ import unittest from cleo import CommandTester from src.masoniteorm.commands import ShellCommand class TestShellCommand(unittest.TestCase): def setUp(self): self.command = ShellCommand() self.command_tester = CommandTester(self.command) def test_for_mysql(self): config = { "host": "localhost", "database": "orm", "user": "root", "port": "1234", "password": "secret", "prefix": "", "options": {"charset": "utf8mb4"}, "full_details": {"driver": "mysql"}, } command, _ = self.command.get_command(config) assert ( command == "mysql orm --host localhost --port 1234 --user root --password secret --default-character-set utf8mb4" ) def 
test_for_postgres(self): config = { "host": "localhost", "database": "orm", "user": "root", "port": "1234", "password": "secretpostgres", "prefix": "", "options": {"charset": "utf8mb4"}, "full_details": {"driver": "postgres"}, } command, env = self.command.get_command(config) assert command == "psql orm --host localhost --port 1234 --username root" assert env.get("PGPASSWORD", "secretpostgres") def test_for_sqlite(self): config = { "database": "orm.sqlite3", "prefix": "", "full_details": {"driver": "sqlite"}, } command, _ = self.command.get_command(config) assert command == "sqlite3 orm.sqlite3" def test_for_mssql(self): config = { "host": "db.masonite.com", "database": "orm", "user": "root", "port": "1234", "password": "secretpostgres", "prefix": "", "options": {"charset": "utf8mb4"}, "full_details": {"driver": "mssql"}, } command, _ = self.command.get_command(config) assert ( command == "sqlcmd -d orm -U root -P secretpostgres -S tcp:db.masonite.com,1234" ) def test_running_command_with_sqlite(self): self.command_tester.execute("-c dev") assert "sqlite3" not in self.command_tester.io.fetch_output() self.command_tester.execute("-c dev -s") assert "sqlite3 orm.sqlite3" in self.command_tester.io.fetch_output() def test_hiding_sensitive_options(self): config = { "host": "localhost", "database": "orm", "user": "root", "port": "", "password": "secret", "full_details": {"driver": "mysql"}, } command, _ = self.command.get_command(config) cleaned_command = self.command.hide_sensitive_options(config, command) assert ( cleaned_command == "mysql orm --host localhost --user root --password ***" ) ================================================ FILE: tests/config/test_db_url.py ================================================ import os import pytest import unittest from src.masoniteorm.config import db_url, load_config from src.masoniteorm.exceptions import InvalidUrlConfiguration from src.masoniteorm.connections import ConnectionResolver class 
TestDbUrlHelper(unittest.TestCase): def setUp(self): self.original_db_url = os.getenv("DATABASE_URL") # def tearDown(self): # os.environ["DATABASE_URL"] = self.original_db_url def test_parse_env_by_default(self): os.environ["DATABASE_URL"] = "mysql://root:@localhost:3306/orm" config = db_url() assert config.get("driver") == "mysql" # def test_raise_error_if_no_url(self): # # no DATABASE_URL is defined yet # with pytest.raises(InvalidUrlConfiguration): # db_url() def test_parse_sqlite(self): # check in memory use config = db_url("sqlite://") assert config.get("driver", "sqlite") assert config.get("database", ":memory:") assert not config.get("user") config = db_url("sqlite://:memory:") assert config.get("driver", "sqlite") assert config.get("database", ":memory:") assert not config.get("user") config = db_url("sqlite://db.sqlite3") assert config.get("driver", "sqlite") assert config.get("database", "db.sqlite3") assert not config.get("user") def test_parse_mysql(self): config = db_url("mysql://root:@localhost:3306/orm") assert config == { "driver": "mysql", "database": "orm", "prefix": "", "options": {}, "log_queries": False, "user": "root", "password": "", "host": "localhost", "port": 3306, } def test_parse_postgres(self): config = db_url( "postgres://utpcrbiihfqqys:de0a0d847094a66e32274262aa5b5f0ad78e5e34197875fc6089a2d9185d0032@ec2-54-225-242-183.compute-1.amazonaws.com:5432/da455n1ef8kout" ) assert config == { "driver": "postgres", "database": "da455n1ef8kout", "prefix": "", "options": {}, "log_queries": False, "user": "utpcrbiihfqqys", "password": "de0a0d847094a66e32274262aa5b5f0ad78e5e34197875fc6089a2d9185d0032", "host": "ec2-54-225-242-183.compute-1.amazonaws.com", "port": 5432, } def test_parse_mssql(self): config = db_url("mssql://john:secret@127.0.0.1:1433/mssql_db") assert config == { "driver": "mssql", "database": "mssql_db", "prefix": "", "options": {}, "log_queries": False, "user": "john", "password": "secret", "host": "127.0.0.1", "port": "1433", } 
def test_parse_with_params(self): config = db_url( "mysql://root:@localhost:3306/orm", log_queries=True, prefix="a", options={"key": "value"}, ) assert config == { "driver": "mysql", "database": "orm", "prefix": "a", "options": {"key": "value"}, "log_queries": True, "user": "root", "password": "", "host": "localhost", "port": 3306, } def test_using_it_with_connection_resolver(self): TEST_DATABASES = { "default": "test", "test": { **db_url("mysql://root:@localhost:3306/orm"), "prefix": "", "log_queries": True, }, } resolver = ConnectionResolver().set_connection_details(TEST_DATABASES) config = resolver.get_connection_details().get("test") assert config.get("database") == "orm" assert config.get("user") == "root" assert config.get("password") == "" assert config.get("port") == 3306 assert config.get("host") == "localhost" assert config.get("log_queries") # reset connection resolver to default for other tests to continue working from tests.integrations.config.database import DATABASES ConnectionResolver().set_connection_details(DATABASES) ================================================ FILE: tests/connections/test_base_connections.py ================================================ import unittest from src.masoniteorm.connections import ConnectionResolver from tests.integrations.config.database import DB class TestDefaultBehaviorConnections(unittest.TestCase): def test_should_return_connection_with_enabled_logs(self): connection = DB.begin_transaction("dev") should_log_queries = connection.full_details.get("log_queries") DB.commit("dev") self.assertTrue(should_log_queries) def test_should_disable_log_queries_in_connection(self): connection = DB.begin_transaction("dev") connection.disable_query_log() should_log_queries = connection.full_details.get("log_queries") self.assertFalse(should_log_queries) connection.enable_query_log() should_log_queries = connection.full_details.get("log_queries") DB.commit("dev") self.assertTrue(should_log_queries) 
================================================ FILE: tests/eagers/test_eager.py ================================================ import os import unittest from src.masoniteorm.query.EagerRelation import EagerRelations class TestEagerRelation(unittest.TestCase): def test_can_register_string_eager_load(self): self.assertEqual( EagerRelations().register("profile").get_eagers(), [["profile"]] ) self.assertEqual(EagerRelations().register("profile").is_nested, False) self.assertEqual( EagerRelations().register("profile.user").get_eagers(), [{"profile": ["user"]}], ) self.assertEqual( EagerRelations().register("profile.user", "profile.logo").get_eagers(), [{"profile": ["user", "logo"]}], ) self.assertEqual( EagerRelations() .register("profile.user", "profile.logo", "profile.bio") .get_eagers(), [{"profile": ["user", "logo", "bio"]}], ) self.assertEqual( EagerRelations().register("user", "logo", "bio").get_eagers(), [["user", "logo", "bio"]], ) def test_can_register_tuple_eager_load(self): self.assertEqual( EagerRelations().register(("profile",)).get_eagers(), [["profile"]] ) self.assertEqual( EagerRelations().register(("profile", "user")).get_eagers(), [["profile", "user"]], ) self.assertEqual( EagerRelations().register(("profile.name", "profile.user")).get_eagers(), [{"profile": ["name", "user"]}], ) def test_can_register_list_eager_load(self): self.assertEqual( EagerRelations().register(["profile"]).get_eagers(), [["profile"]] ) self.assertEqual( EagerRelations().register(["profile", "user"]).get_eagers(), [["profile", "user"]], ) self.assertEqual( EagerRelations().register(["profile.name", "profile.user"]).get_eagers(), [{"profile": ["name", "user"]}], ) self.assertEqual( EagerRelations().register(["profile.name"]).get_eagers(), [{"profile": ["name"]}], ) self.assertEqual( EagerRelations().register(["profile.name", "logo"]).get_eagers(), [["logo"], {"profile": ["name"]}], ) self.assertEqual( EagerRelations() .register(["profile.name", "logo", "profile.user"]) 
.get_eagers(), [["logo"], {"profile": ["name", "user"]}], ) ================================================ FILE: tests/factories/test_factories.py ================================================ import os import unittest from src.masoniteorm import Factory as factory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder class User(Model): pass class AfterCreatedModel(Model): __dry__ = True pass class TestFactories(unittest.TestCase): def setUp(self): factory.register(User, self.user_factory) factory.register(User, self.named_user_factory, name="admin") factory.register(AfterCreatedModel, self.named_user_factory) factory.after_creating(AfterCreatedModel, self.after_creating) def user_factory(self, faker): return {"id": 1, "name": faker.name()} def named_user_factory(self, faker): return {"id": 1, "name": faker.name(), "admin": 1} def after_creating(self, model, faker): model.after_created = True def test_can_make_single(self): user = factory(User).make({"id": 1, "name": "Joe"}) self.assertEqual(user.name, "Joe") self.assertIsInstance(user, User) def test_can_make_several(self): users = factory(User).make([{"id": 1, "name": "Joe"}, {"id": 2, "name": "Bob"}]) self.assertEqual(users.count(), 2) def test_can_make_any_number(self): users = factory(User, 50).make() self.assertEqual(users.count(), 50) def test_can_make_named_factory(self): user = factory(User).make(name="admin") self.assertEqual(user.admin, 1) def test_after_creates(self): user = factory(AfterCreatedModel).create() self.assertTrue(user.name) self.assertEqual(user.after_created, True) users = factory(AfterCreatedModel).create({"name": "billy"}) self.assertEqual(users.name, "billy") self.assertEqual(users.after_created, True) users = factory(AfterCreatedModel).make() self.assertEqual(users.after_created, True) users = factory(AfterCreatedModel, 2).make() for user in users: self.assertEqual(user.after_created, True) ================================================ FILE: 
tests/integrations/config/__init__.py ================================================ ================================================ FILE: tests/integrations/config/database.py ================================================ """ Database Settings """ import os import logging from dotenv import load_dotenv from src.masoniteorm.connections import ConnectionResolver from src.masoniteorm.config import db_url """ |-------------------------------------------------------------------------- | Load Environment Variables |-------------------------------------------------------------------------- | | Loads in the environment variables when this page is imported. | """ load_dotenv(".env") """ The connections here don't determine the database but determine the "connection". They can be named whatever you want. """ DATABASES = { "default": "mysql", "mysql": { "driver": "mysql", "host": os.getenv("MYSQL_DATABASE_HOST"), "user": os.getenv("MYSQL_DATABASE_USER"), "password": os.getenv("MYSQL_DATABASE_PASSWORD"), "database": os.getenv("MYSQL_DATABASE_DATABASE"), "port": os.getenv("MYSQL_DATABASE_PORT"), "prefix": "", "options": {"charset": "utf8mb4"}, "log_queries": True, "propagate": False, "connection_pooling_enabled": True, "connection_pooling_max_size": 10, "connection_pooling_min_size": None, }, "t": {"driver": "sqlite", "database": "orm.sqlite3", "log_queries": True, "foreign_keys": True}, "devprod": { "driver": "mysql", "host": os.getenv("MYSQL_DATABASE_HOST"), "user": os.getenv("MYSQL_DATABASE_USER"), "password": os.getenv("MYSQL_DATABASE_PASSWORD"), "database": "DEVPROD", "port": os.getenv("MYSQL_DATABASE_PORT"), "prefix": "", "options": {"charset": "utf8mb4"}, "log_queries": True, "propagate": False, }, "many": { "driver": "mysql", "host": "localhost", "user": "root", "password": "", "database": "replicate", "port": os.getenv("MYSQL_DATABASE_PORT"), "options": {"charset": "utf8mb4"}, "log_queries": True, "propagate": False, }, "postgres": { "driver": "postgres", 
"host": os.getenv("POSTGRES_DATABASE_HOST"), "user": os.getenv("POSTGRES_DATABASE_USER"), "password": os.getenv("POSTGRES_DATABASE_PASSWORD"), "database": os.getenv("POSTGRES_DATABASE_DATABASE"), "port": os.getenv("POSTGRES_DATABASE_PORT"), "connection_pooling_enabled": True, "connection_pooling_max_size": 10, "connection_pooling_min_size": 2, "prefix": "", "log_queries": True, "propagate": False, }, # Example with db_url() # "postgres": db_url( # "postgres://user:@localhost:5432/postgres", log_queries=True # ), "dev": { "driver": "sqlite", "database": "orm.sqlite3", "prefix": "", "log_queries": True, }, # Example with db_url() # "dev": {**db_url("sqlite://orm.sqlite3"), "prefix": "", "log_queries": True}, "mssql": { "driver": "mssql", "host": os.getenv("MSSQL_DATABASE_HOST"), "user": os.getenv("MSSQL_DATABASE_USER"), "password": os.getenv("MSSQL_DATABASE_PASSWORD"), "database": os.getenv("MSSQL_DATABASE_DATABASE"), "port": os.getenv("MSSQL_DATABASE_PORT"), "prefix": "", "log_queries": True, "options": { "trusted_connection": "Yes", "integrated_security": "sspi", "instance": "SQLExpress", "authentication": "ActiveDirectoryPassword", "driver": "ODBC Driver 17 for SQL Server", "connection_timeout": 15, "connection_pooling": False, "connection_pooling_size": 100, }, }, } DB = ConnectionResolver().set_connection_details(DATABASES) logger = logging.getLogger("masoniteorm.connection.queries") logger.setLevel(logging.DEBUG) formatter = logging.Formatter("It executed the query %(query)s") stream_handler = logging.StreamHandler() file_handler = logging.FileHandler("queries.log") logger.addHandler(stream_handler) logger.addHandler(file_handler) # DB = QueryBuilder(connection_details=DATABASES) # DATABASES = { # 'default': os.environ.get('DB_DRIVER'), # 'sqlite': { # 'driver': 'sqlite', # 'database': os.environ.get('DB_DATABASE') # }, # 'postgres': { # 'driver': 'postgres', # 'host': env('DB_HOST'), # 'database': env('DB_DATABASE'), # 'port': env('DB_PORT'), # 'user': 
env('DB_USERNAME'), # 'password': env('DB_PASSWORD'), # 'log_queries': env('DB_LOG'), # }, # } # DB = DatabaseManager(DATABASES) # Model.set_connection_resolver(DB) ================================================ FILE: tests/models/test_models.py ================================================ import datetime import json import unittest import pendulum from src.masoniteorm.models import Model class ModelTest(Model): __dates__ = ["due_date"] __casts__ = { "is_vip": "bool", "payload": "json", "x": "int", "f": "float", "d": "decimal", } class FillableModelTest(Model): __fillable__ = ["due_date", "is_vip"] class InvalidFillableGuardedModelTest(Model): __fillable__ = ["due_date"] __guarded__ = ["is_vip", "payload"] class InvalidFillableGuardedChildModelTest(ModelTest): __fillable__ = ["due_date"] __guarded__ = ["is_vip", "payload"] class ModelTestForced(Model): __table__ = "users" __force_update__ = True class BaseModel(Model): __dry__ = True def get_selects(self): return [f"{self.get_table_name()}.*"] class ModelWithBaseModel(BaseModel): __table__ = "users" class TestModels(unittest.TestCase): def test_model_can_access_str_dates_as_pendulum(self): model = ModelTest.hydrate({"user": "joe", "due_date": "2020-11-28 11:42:07"}) self.assertTrue(model.user) self.assertTrue(model.due_date) self.assertIsInstance(model.due_date, pendulum.now().__class__) def test_model_can_access_str_dates_as_pendulum_from_correct_datetimes(self): model = ModelTest() self.assertEqual( model.get_new_date(datetime.datetime(2021, 1, 1, 7, 10)).hour, 7 ) self.assertEqual(model.get_new_date(datetime.date(2021, 1, 1)).hour, 0) self.assertEqual(model.get_new_date(datetime.time(1, 1, 1)).hour, 1) self.assertEqual(model.get_new_date("2020-11-28 11:42:07").hour, 11) def test_model_can_access_str_dates_on_relationships(self): model = ModelTest.hydrate({"user": "joe", "due_date": "2020-11-28 11:42:07"}) model.add_relation( { "profile": ModelTest.hydrate( {"name": "bob", "due_date": "2020-11-28 11:42:07"} 
) } ) self.assertEqual(model.profile.name, "bob") self.assertTrue(model.profile.due_date.is_past()) def test_model_original_and_dirty_attributes(self): model = ModelTest.hydrate({"username": "joe", "admin": True}) self.assertEqual(model.username, "joe") self.assertEqual( model.__original_attributes__, {"username": "joe", "admin": True} ) model.username = "bob" self.assertEqual(model.username, "bob") self.assertEqual(model.get_original("username"), "joe") self.assertEqual(model.get_dirty("username"), "bob") self.assertEqual(model.__dirty_attributes__["username"], "bob") self.assertEqual(model.get_dirty_keys(), ["username"]) self.assertTrue(model.is_dirty() is True) self.assertEqual( model.__original_attributes__, {"username": "joe", "admin": True} ) def test_model_creates_when_new(self): model = ModelTest.hydrate({"id": 1, "username": "joe", "admin": True}) model.name = "Bill" sql = model.save(query=True).to_sql() self.assertTrue(sql.startswith("UPDATE")) model = ModelTest() model.name = "Bill" sql = model.save(query=True).to_sql() self.assertTrue(sql.startswith("INSERT")) def test_model_can_cast_attributes(self): model = ModelTest.hydrate( { "is_vip": 1, "payload": '["item1", "item2"]', "x": True, "f": "10.5", "d": 3.14, } ) self.assertEqual(type(model.payload), list) self.assertEqual(type(model.x), int) self.assertEqual(type(model.f), float) self.assertEqual(type(model.is_vip), bool) self.assertEqual(type(model.serialize()["is_vip"]), bool) def test_model_can_cast_dict_attributes(self): """test cast with dict object to json field""" dictcasttest = {} dictcasttest["key"] = "value" model = ModelTest.hydrate( {"is_vip": 1, "payload": dictcasttest, "x": True, "f": "10.5"} ) self.assertEqual(type(model.payload), dict) self.assertEqual(type(model.x), int) self.assertEqual(type(model.f), float) self.assertEqual(type(model.is_vip), bool) self.assertEqual(type(model.serialize()["is_vip"]), bool) def test_valid_json_cast(self): model = ModelTest.hydrate( {"payload": 
{"this": "dict", "is": "usable", "as": "json"}} ) self.assertEqual(type(model.payload), dict) model = ModelTest.hydrate( {"payload": {"this": "dict", "is": "invalid", "as": "json"}} ) self.assertEqual(type(model.payload), dict) model = ModelTest.hydrate( {"payload": '{"this": "dict", "is": "usable", "as": "json"}'} ) self.assertEqual(type(model.payload), dict) model = ModelTest.hydrate({"payload": '{"valid": "json", "int": 1}'}) self.assertEqual(type(model.payload), dict) model = ModelTest.hydrate({"payload": "{'this': 'should', 'throw': 'error'}"}) self.assertEqual(model.payload, None) with self.assertRaises(ValueError): model.payload = "{'this': 'should', 'throw': 'error'}" model.save() def test_model_update_without_changes(self): model = ModelTest.hydrate( {"id": 1, "username": "joe", "name": "Joe", "admin": True} ) model.username = "joe" model.name = "Bill" sql = model.save(query=True).to_sql() self.assertTrue(sql.startswith("UPDATE")) self.assertNotIn("username", sql) def test_force_update_on_model_class(self): model = ModelTestForced.hydrate( {"id": 1, "username": "joe", "name": "Joe", "admin": True} ) model.username = "joe" model.name = "Bill" sql = model.save(query=True).to_sql() self.assertTrue(sql.startswith("UPDATE")) self.assertIn("username", sql) self.assertIn("name", sql) def test_only_method(self): model = ModelTestForced.hydrate( {"id": 1, "username": "joe", "name": "Joe", "admin": True} ) self.assertEqual({"username": "joe"}, model.only("username")) self.assertEqual({"username": "joe"}, model.only(["username"])) def test_model_update_without_changes_at_all(self): model = ModelTest.hydrate( {"id": 1, "username": "joe", "name": "Joe", "admin": True} ) model.username = "joe" model.name = "Joe" sql = model.save(query=True).to_sql() self.assertFalse(sql.startswith("UPDATE")) def test_model_using_or_where(self): model = ModelTest() sql = model.where("name", "=", "joe").or_where("is_vip", True).to_sql() self.assertEqual( sql, """SELECT * FROM 
`model_tests` WHERE `model_tests`.`name` = 'joe' OR `model_tests`.`is_vip` = '1'""", ) def test_model_using_or_where_and_chaining_wheres(self): model = ModelTest() sql = ( model.where("name", "=", "joe") .or_where( lambda query: query.where("username", "Joseph").or_where( "age", ">=", 18 ) ) .to_sql() ) self.assertTrue( sql, """SELECT * FROM `model_tests` WHERE `model_tests`.`name` = 'joe' OR (`model_tests`.`username` = 'Joseph' OR `model_tests`.`age` >= '18'))""", ) def test_both_fillable_and_guarded_attributes_raise(self): # Both fillable and guarded props are populated on this class with self.assertRaises(AttributeError): InvalidFillableGuardedModelTest() # Child that inherits from an intermediary class also fails with self.assertRaises(AttributeError): InvalidFillableGuardedChildModelTest() # Still shouldn't be allowed to define even if empty InvalidFillableGuardedModelTest.__fillable__ = [] with self.assertRaises(AttributeError): InvalidFillableGuardedModelTest() # Or wildcard InvalidFillableGuardedModelTest.__fillable__ = ["*"] with self.assertRaises(AttributeError): InvalidFillableGuardedModelTest() # Empty guarded attr still raises InvalidFillableGuardedModelTest.__guarded__ = [] with self.assertRaises(AttributeError): InvalidFillableGuardedModelTest() # Removing one of the props allows us to instantiate delattr(InvalidFillableGuardedModelTest, "__guarded__") InvalidFillableGuardedModelTest() def test_model_can_provide_default_select(self): sql = ModelWithBaseModel.to_sql() self.assertEqual( sql, """SELECT `users`.* FROM `users`""", ) def test_model_can_override_to_default_select(self): sql = ModelWithBaseModel.select(["products.name", "products.id", "store.name"]).to_sql() self.assertEqual( sql, """SELECT `products`.`name`, `products`.`id`, `store`.`name` FROM `users`""", ) def test_model_can_use_aggregate_funcs_with_default_selects(self): sql = ModelWithBaseModel.count().to_sql() self.assertEqual( sql, """SELECT COUNT(*) AS m_count_reserved FROM 
`users`""", ) sql = ModelWithBaseModel.max("id").to_sql() self.assertEqual( sql, """SELECT MAX(`users`.`id`) AS id FROM `users`""", ) sql = ModelWithBaseModel.min("id").to_sql() self.assertEqual( sql, """SELECT MIN(`users`.`id`) AS id FROM `users`""", )
================================================
FILE: tests/mssql/builder/test_mssql_query_builder.py
================================================
import inspect
import unittest

# NOTE(review): inspect, ConnectionFactory and PostgresGrammar are not used in
# this module (and PostgresGrammar belongs to another dialect) — removal candidates.
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import PostgresGrammar
from tests.utils import MockConnectionFactory


# Minimal stand-in connection class passed as connection_class in
# test_builder_alone: make_connection() returns the instance itself and no
# default query grammar is provided.
class MockConnection:
    connection_details = {}

    def make_connection(self):
        return self

    @classmethod
    def get_default_query_grammar(cls):
        return


# Model fixture for the builder; timestamps are switched off
# (presumably so INSERT/UPDATE SQL carries no timestamp columns — confirm).
class ModelTest(Model):
    __timestamps__ = False


class TestMSSQLQueryBuilder(unittest.TestCase):
    """Compare QueryBuilder.to_sql() output with the expected T-SQL for "mssql"."""

    maxDiff = None

    def get_builder(self, table="users", dry=True):
        """Return a QueryBuilder on *table* backed by the mock "mssql" connection.

        ``dry`` is forwarded unchanged to QueryBuilder.
        """
        connection = MockConnectionFactory().make("mssql")
        return QueryBuilder(
            # self.grammar,
            connection_class=connection,
            connection="mssql",
            table=table,
            model=ModelTest(),
            dry=dry,
        )

    def test_sum(self):
        builder = self.get_builder()
        builder.sum("age")
        self.assertEqual(
            builder.to_sql(), "SELECT SUM([users].[age]) AS age FROM [users]"
        )

    def test_where_like(self):
        builder = self.get_builder()
        builder.where("age", "like", "%name%")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] LIKE '%name%'"
        )

    def test_where_not_like(self):
        builder = self.get_builder()
        builder.where("age", "not like", "%name%")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[age] NOT LIKE '%name%'",
        )

    def test_max(self):
        builder = self.get_builder()
        builder.max("age")
        self.assertEqual(
            builder.to_sql(), "SELECT MAX([users].[age]) AS age FROM [users]"
        )

    def test_min(self):
        builder = self.get_builder()
        builder.min("age")
        self.assertEqual(
            builder.to_sql(), "SELECT MIN([users].[age]) AS age FROM [users]"
        )

    def test_avg(self):
        builder = self.get_builder()
        builder.avg("age")
        self.assertEqual(
            builder.to_sql(), "SELECT AVG([users].[age]) AS age FROM [users]"
        )

    def test_all(self):
        builder = self.get_builder()
        builder.all()
        self.assertEqual(builder.to_sql(), "SELECT * FROM [users]")

    def test_get(self):
        builder = self.get_builder()
        builder.get()
        self.assertEqual(builder.to_sql(), "SELECT * FROM [users]")

    def test_first(self):
        # first(query=True) is used so that to_sql() can be called on its result.
        builder = self.get_builder().first(query=True)
        self.assertEqual(builder.to_sql(), "SELECT TOP 1 * FROM [users]")

    def test_select(self):
        builder = self.get_builder()
        builder.select("name", "email")
        self.assertEqual(
            builder.to_sql(), "SELECT [users].[name], [users].[email] FROM [users]"
        )

    def test_add_select_no_table(self):
        # With table=None the outer query has no FROM; each add_select becomes
        # an aliased sub-select.
        builder = self.get_builder(table=None)
        builder.add_select(
            "other_test", lambda q: q.max("updated_at").table("different_table")
        ).add_select("some_alias", lambda q: q.max("updated_at").table("another_table"))
        self.assertEqual(
            builder.to_sql(),
            (
                "SELECT "
                "(SELECT MAX([different_table].[updated_at]) AS updated_at FROM [different_table]) AS other_test, "
                "(SELECT MAX([another_table].[updated_at]) AS updated_at FROM [another_table]) AS some_alias"
            ),
        )

    def test_select_raw(self):
        builder = self.get_builder()
        builder.select_raw("count(email) as email_count")
        self.assertEqual(
            builder.to_sql(), "SELECT count(email) as email_count FROM [users]"
        )

    def test_create(self):
        builder = self.get_builder().without_global_scopes()
        builder.create(
            {"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True
        )
        self.assertEqual(
            builder.to_sql(),
            "INSERT INTO [users] ([users].[name], [users].[email]) VALUES ('Corentin All', 'corentin@yopmail.com')",
        )

    def test_delete(self):
        builder = self.get_builder()
        builder.delete("name", "Joe", query=True)
        self.assertEqual(
            builder.to_sql(), "DELETE FROM [users] WHERE [users].[name] = 'Joe'"
        )

    def test_where(self):
        builder = self.get_builder()
        builder.where("name", "Joe")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[name] = 'Joe'"
        )

    def test_where_exists(self):
        # Pins current behaviour: a plain string argument is rendered as a
        # quoted literal after EXISTS.
        builder = self.get_builder()
        builder.where_exists("name")
        self.assertEqual(builder.to_sql(), "SELECT * FROM [users] WHERE EXISTS 'name'")

    def test_limit(self):
        builder = self.get_builder()
        builder.limit(5)
        self.assertEqual(builder.to_sql(), "SELECT TOP 5 * FROM [users]")

    def test_offset(self):
        # Pins current behaviour: with only an offset, FETCH NEXT 1 ROWS is emitted.
        # NOTE(review): SQL Server only accepts OFFSET/FETCH after an ORDER BY
        # clause, so this SQL would be rejected by a real server — confirm intent.
        builder = self.get_builder()
        builder.offset(5)
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] OFFSET 5 ROWS FETCH NEXT 1 ROWS ONLY",
        )

    def test_join(self):
        builder = self.get_builder()
        builder.join("profiles", "users.id", "=", "profiles.user_id")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] INNER JOIN [profiles] ON [users].[id] = [profiles].[user_id]",
        )

    def test_left_join(self):
        builder = self.get_builder()
        builder.left_join("profiles", "users.id", "=", "profiles.user_id")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] LEFT JOIN [profiles] ON [users].[id] = [profiles].[user_id]",
        )

    def test_right_join(self):
        builder = self.get_builder()
        builder.right_join("profiles", "users.id", "=", "profiles.user_id")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] RIGHT JOIN [profiles] ON [users].[id] = [profiles].[user_id]",
        )

    def test_update(self):
        builder = self.get_builder().update(
            {"name": "Joe", "email": "joe@yopmail.com"}, dry=True
        )
        self.assertEqual(
            builder.to_sql(),
            "UPDATE [users] SET [users].[name] = 'Joe', [users].[email] = 'joe@yopmail.com'",
        )

    # The increment/decrement tests below are currently disabled.
    # def test_increment(self):
    #     builder = self.get_builder()
    #     builder.increment("age", 1)
    #     self.assertEqual(
    #         builder.to_sql(), "UPDATE [users] SET [users].[age] = [users].[age] + '1'"
    #     )

    # def test_decrement(self):
    #     builder = self.get_builder()
    #     builder.decrement("age", 1)
    #     self.assertEqual(
    #         builder.to_sql(), "UPDATE [users] SET [users].[age] = [users].[age] - '1'"
    #     )

    def test_count(self):
        builder = self.get_builder()
        builder.count("id")
        self.assertEqual(
            builder.to_sql(), "SELECT COUNT([users].[id]) AS id FROM [users]"
        )

    def test_order_by_asc(self):
        builder = self.get_builder()
        builder.order_by("email", "asc")
        self.assertEqual(builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] ASC")

    def test_order_by_desc(self):
        builder = self.get_builder()
        builder.order_by("email", "desc")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] DESC"
        )

    def test_where_column(self):
        builder = self.get_builder()
        builder.where_column("name", "username")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[name] = [users].[username]",
        )

    def test_where_not_in(self):
        builder = self.get_builder()
        builder.where_not_in("id", [1, 2, 3])
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[id] NOT IN ('1','2','3')",
        )

    def test_between(self):
        builder = self.get_builder()
        builder.between("id", 2, 5)
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[id] BETWEEN '2' AND '5'",
        )

    def test_not_between(self):
        builder = self.get_builder()
        builder.not_between("id", 2, 5)
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[id] NOT BETWEEN '2' AND '5'",
        )

    def test_where_in(self):
        builder = self.get_builder()
        builder.where_in("id", [1, 2, 3])
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[id] IN ('1','2','3')",
        )

    def test_where_null(self):
        builder = self.get_builder()
        builder.where_null("name")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[name] IS NULL"
        )

    def test_where_not_null(self):
        builder = self.get_builder()
        builder.where_not_null("name")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[name] IS NOT NULL"
        )

    def test_having(self):
        builder = self.get_builder(table="payments")
        builder.select("user_id").avg("salary").group_by("user_id").having(
            "salary", ">=", "1000"
        )
        self.assertEqual(
            builder.to_sql(),
            "SELECT [payments].[user_id], AVG([payments].[salary]) AS salary FROM [payments] GROUP BY [payments].[user_id] HAVING [payments].[salary] >= '1000'",
        )

    def test_group_by(self):
        builder = self.get_builder(table="payments")
        builder.select("user_id").min("salary").group_by("user_id")
        self.assertEqual(
            builder.to_sql(),
            "SELECT [payments].[user_id], MIN([payments].[salary]) AS salary FROM [payments] GROUP BY [payments].[user_id]",
        )

    def test_builder_alone(self):
        # Builds a QueryBuilder from explicit connection_details and the local
        # MockConnection rather than the mock factory; only truthiness is asserted.
        self.assertTrue(
            QueryBuilder(
                connection_class=MockConnection,
                connection="mssql",
                connection_details={
                    "default": "mssql",
                    "mssql": {
                        "driver": "mssql",
                        "host": "localhost",
                        "user": "root",
                        "password": "root",
                        "database": "orm",
                        "port": "5432",
                        "prefix": "",
                        "grammar": "mssql",
                    },
                },
            ).table("users")
        )

    def test_where_lt(self):
        builder = self.get_builder()
        builder.where("age", "<", "20")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] < '20'"
        )

    def test_where_lte(self):
        builder = self.get_builder()
        builder.where("age", "<=", "20")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] <= '20'"
        )

    def test_where_gt(self):
        builder = self.get_builder()
        builder.where("age", ">", "20")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] > '20'"
        )

    def test_where_gte(self):
        builder = self.get_builder()
        builder.where("age", ">=", "20")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] >= '20'"
        )

    def test_where_ne(self):
        builder = self.get_builder()
        builder.where("age", "!=", "20")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] != '20'"
        )

    def test_or_where(self):
        builder = self.get_builder()
        builder.where("age", "20").or_where("age", "<", 20)
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[age] = '20' OR [users].[age] < '20'",
        )

    def test_can_call_with_schema(self):
        # A dotted table name is quoted part by part: [information_schema].[columns].
        builder = self.get_builder()
        sql = (
            builder.table("information_schema.columns")
            .select("table_name")
            .where("table_name", "users")
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT [information_schema].[columns].[table_name] FROM [information_schema].[columns] WHERE [information_schema].[columns].[table_name] = 'users'""",
        )

    def test_truncate(self):
        builder = self.get_builder(dry=True)
        sql = builder.truncate()
        self.assertEqual(sql, "TRUNCATE TABLE [users]")

    def test_truncate_without_foreign_keys(self):
        # NOTE(review): the name says "without" but foreign_keys=True is passed,
        # and the expected SQL equals test_truncate's — confirm the intended flag.
        builder = self.get_builder(dry=True)
        sql = builder.truncate(foreign_keys=True)
        self.assertEqual(sql, "TRUNCATE TABLE [users]")

    def test_latest(self):
        builder = self.get_builder()
        builder.latest("email")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] DESC"
        )

    def test_latest_multiple(self):
        builder = self.get_builder()
        builder.latest("email", "created_at")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] ORDER BY [email] DESC, [created_at] DESC",
        )

    def test_oldest(self):
        builder = self.get_builder()
        builder.oldest("email")
        self.assertEqual(builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] ASC")

    def test_oldest_multiple(self):
        builder = self.get_builder()
        builder.oldest("email", "created_at")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] ORDER BY [email] ASC, [created_at] ASC",
        )


================================================
FILE: tests/mssql/builder/test_mssql_query_builder_relationships.py
================================================
import inspect
import unittest

# NOTE(review): inspect and ConnectionFactory are not used in this module.
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MSSQLGrammar
from src.masoniteorm.relationships import belongs_to
from tests.utils import MockConnectionFactory
from dotenv import load_dotenv

# Load environment variables from ".env" at import time.
load_dotenv(".env")


# Model fixtures; every model uses the "mssql" connection.
class Logo(Model):
    __connection__ = "mssql"


class Article(Model):
    __connection__ = "mssql"

    @belongs_to("id", "article_id")
    def logo(self):
        return Logo


class Profile(Model):
    __connection__ = "mssql"


class User(Model):
    __connection__ = "mssql"

    @belongs_to("id", "user_id")
    def articles(self):
        return Article

    @belongs_to("id", "user_id")
    def profile(self):
        return Profile

    # Self-reference resolved dynamically through the instance's class.
    @belongs_to("id", "parent_dynamic_id")
    def parent_dynamic(self):
        return self.__class__

    # Self-reference naming the User class explicitly.
    @belongs_to("id", "parent_specified_id")
    def parent_specified(self):
        return User


class BaseTestQueryRelationships(unittest.TestCase):
    """Check the EXISTS sub-queries that has()/where_has() emit on MSSQL."""

    maxDiff = None

    def get_builder(self, table="users"):
        """Return a QueryBuilder for *table* using MSSQLGrammar and a User model."""
        connection = MockConnectionFactory().make("mssql")
        return QueryBuilder(
            grammar=MSSQLGrammar,
            connection_class=connection,
            connection="mssql",
            table=table,
            model=User(),
        )

    def test_has(self):
        builder = self.get_builder()
        sql = builder.has("articles").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [articles] WHERE [articles].[user_id] = [users].[id]"""
            """)""",
        )

    def test_has_reference_to_self(self):
        builder = self.get_builder()
        sql = builder.has("parent_dynamic").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [users] WHERE [users].[parent_dynamic_id] = [users].[id]"""
            """)""",
        )

    def test_has_reference_to_self_using_class(self):
        builder = self.get_builder()
        sql = builder.has("parent_specified").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [users] WHERE [users].[parent_specified_id] = [users].[id]"""
            """)""",
        )

    def test_where_has_query(self):
        builder = self.get_builder()
        sql = builder.where_has("articles", lambda q: q.where("active", 1)).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [articles] WHERE [articles].[user_id] = [users].[id] AND [articles].[active] = '1'"""
            """)""",
        )

    def test_relationship_multiple_has(self):
        to_sql = User.has("articles", "profile").to_sql()
        self.assertEqual(
            to_sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [articles] WHERE [articles].[user_id] = [users].[id]"""
            """) AND EXISTS ("""
            """SELECT * FROM [profiles] WHERE [profiles].[user_id] = [users].[id]"""
            """)""",
        )

    def test_relationship_multiple_has_calls(self):
        to_sql = User.has("articles").has("profile").to_sql()
        self.assertEqual(
            to_sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [articles] WHERE [articles].[user_id] = [users].[id]"""
            """) AND EXISTS ("""
            """SELECT * FROM [profiles] WHERE [profiles].[user_id] = [users].[id]"""
            """)""",
        )

    def test_nested_has(self):
        # A dotted relationship path nests one EXISTS inside the other.
        to_sql = User.has("articles.logo").to_sql()
        self.assertEqual(
            to_sql,
            """SELECT * FROM [users] WHERE EXISTS (SELECT * FROM [articles] WHERE [articles].[user_id] = [users].[id] AND EXISTS (SELECT * FROM [logos] WHERE [logos].[article_id] = [articles].[id]))""",
        )


================================================
FILE: tests/mssql/grammar/test_mssql_delete_grammar.py
================================================
import unittest

from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MSSQLGrammar


# NOTE(review): the class name says MySQL, but this module tests MSSQLGrammar;
# consider renaming to TestMSSQLDeleteGrammar.
class TestMySQLDeleteGrammar(unittest.TestCase):
    """DELETE statement compilation with MSSQLGrammar."""

    def setUp(self):
        self.builder = QueryBuilder(MSSQLGrammar, table="users")

    def test_can_compile_delete(self):
        to_sql = self.builder.delete("id", 1, query=True).to_sql()

        sql = "DELETE FROM [users] WHERE [users].[id] = '1'"
        self.assertEqual(to_sql, sql)

    def test_can_compile_delete_with_where(self):
        to_sql = (
            self.builder.where("age", 20)
            .where("profile", 1)
            .set_action("delete")
            .delete(query=True)
            .to_sql()
        )

        sql = (
            "DELETE FROM [users] WHERE [users].[age] = '20' AND [users].[profile] = '1'"
        )
        self.assertEqual(to_sql, sql)


================================================
FILE: tests/mssql/grammar/test_mssql_insert_grammar.py
================================================
import unittest

from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MSSQLGrammar


# NOTE(review): the class name says MySQL, but this module tests MSSQLGrammar;
# consider renaming to TestMSSQLInsertGrammar.
class TestMySQLInsertGrammar(unittest.TestCase):
    """INSERT and bulk INSERT compilation with MSSQLGrammar."""

    def setUp(self):
        self.builder = QueryBuilder(MSSQLGrammar, table="users")

    def test_can_compile_insert(self):
        to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql()

        sql = "INSERT INTO [users] ([users].[name]) VALUES ('Joe')"
        self.assertEqual(to_sql, sql)

    def test_can_compile_bulk_create(self):
        to_sql = self.builder.bulk_create(
            # These keys are intentionally out of order to show column to value alignment works
            [
                {"name": "Joe", "age": 5},
                {"age": 35, "name": "Bill"},
                {"name": "John", "age": 10},
            ],
            query=True,
        ).to_sql()

        sql = "INSERT INTO [users] ([age], [name]) VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John')"
        self.assertEqual(to_sql, sql)

    def test_can_compile_bulk_create_qmark(self):
        to_sql = self.builder.bulk_create(
            [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True
        ).to_qmark()

        sql = "INSERT INTO [users] ([name]) VALUES ('?'), ('?'), ('?')"
        self.assertEqual(to_sql, sql)


================================================
FILE: tests/mssql/grammar/test_mssql_qmark.py
================================================
import unittest

from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MSSQLGrammar


class TestMSSQLQmark(unittest.TestCase):
    """Check to_qmark() placeholder SQL and the collected _bindings on MSSQL."""

    def setUp(self):
        self.builder = QueryBuilder(MSSQLGrammar, table="users")

    def test_can_compile_select(self):
        mark = self.builder.select("username").where("name", "Joe")

        sql = "SELECT [users].[username] FROM [users] WHERE [users].[name] = '?'"
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, ["Joe"])

    def test_can_compile_update(self):
        mark = self.builder.update({"name": "Bob"}, dry=True).where("name", "Joe")

        # SET bindings come before WHERE bindings.
        sql = "UPDATE [users] SET [users].[name] = '?' WHERE [users].[name] = '?'"
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, ["Bob", "Joe"])

    def test_can_compile_insert(self):
        mark = self.builder.create({"name": "Bob"}, query=True)

        sql = "INSERT INTO [users] ([users].[name]) VALUES ('?')"
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, ["Bob"])

    def test_can_compile_where_in(self):
        mark = self.builder.where_in("id", [1, 2, 3])

        qmark_sql = "SELECT * FROM [users] WHERE [users].[id] IN ('?', '?', '?')"
        sql = "SELECT * FROM [users] WHERE [users].[id] IN ('1','2','3')"
        self.assertEqual(mark.to_qmark(), qmark_sql)
        self.assertEqual(mark._bindings, [1, 2, 3])

        self.assertEqual(self.builder.where_in("id", [1, 2, 3]).to_sql(), sql)
        # presumably clears the accumulated where clauses before the next check
        self.builder.reset()

        # Assert that when passed string values it generates synonymous sql
        self.assertEqual(self.builder.where_in("id", ["1", "2", "3"]).to_sql(), sql)


================================================
FILE: tests/mssql/grammar/test_mssql_select_grammar.py
================================================
import inspect
import unittest

from src.masoniteorm.query.grammars import MSSQLGrammar
from src.masoniteorm.testing import BaseTestCaseSelectGrammar


class TestMSSQLGrammar(BaseTestCaseSelectGrammar, unittest.TestCase):
    """MSSQL expectations for the shared select-grammar test suite.

    Methods without a ``test_`` prefix return the expected SQL; their
    docstrings show the builder call the SQL corresponds to. A ``test_<name>``
    method pairs with the expectation method ``<name>``, as the getattr lookup
    in test_can_compile_select_raw shows. The ``test_*`` methods defined here
    add MSSQL-specific checks or replace inherited ones.
    """

    grammar = MSSQLGrammar

    def can_compile_select(self):
        """
        self.builder.to_sql()
        """
        return "SELECT * FROM [users]"

    def can_compile_with_columns(self):
        """
        self.builder.select('username', 'password').to_sql()
        """
        return "SELECT [users].[username], [users].[password] FROM [users]"

    def can_compile_with_where(self):
        """
        self.builder.select('username', 'password').where('id', 1).to_sql()
        """
        return "SELECT [users].[username], [users].[password] FROM [users] WHERE [users].[id] = '1'"

    def can_compile_with_several_where(self):
        """
        self.builder.select('username', 'password').where('id', 1).where('username', 'joe').to_sql()
        """
        return "SELECT [users].[username], [users].[password] FROM [users] WHERE [users].[id] = '1' AND [users].[username] = 'joe'"

    def can_compile_with_several_where_and_limit(self):
        """
        self.builder.select('username', 'password').where('id', 1).where('username', 'joe').limit(10).to_sql()
        """
        # MSSQL expresses a bare limit as TOP n.
        return "SELECT TOP 10 [users].[username], [users].[password] FROM [users] WHERE [users].[id] = '1' AND [users].[username] = 'joe'"

    def can_compile_with_sum(self):
        """
        self.builder.sum('age').to_sql()
        """
        return "SELECT SUM([users].[age]) AS age FROM [users]"

    def can_compile_order_by_and_first(self):
        """
        self.builder.order_by('id', 'asc').first()
        """
        return """SELECT TOP 1 * FROM [users] ORDER BY [id] ASC"""

    def can_compile_with_max(self):
        """
        self.builder.max('age').to_sql()
        """
        return "SELECT MAX([users].[age]) AS age FROM [users]"

    def can_compile_with_max_and_columns(self):
        """
        self.builder.select('username').max('age').to_sql()
        """
        return "SELECT [users].[username], MAX([users].[age]) AS age FROM [users]"

    def can_compile_with_max_and_columns_different_order(self):
        """
        self.builder.max('age').select('username').to_sql()
        """
        return "SELECT [users].[username], MAX([users].[age]) AS age FROM [users]"

    def can_compile_with_order_by(self):
        """
        self.builder.select('username').order_by('age', 'desc').to_sql()
        """
        return "SELECT [users].[username] FROM [users] ORDER BY [age] DESC"

    def can_compile_with_multiple_order_by(self):
        """
        self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql()
        """
        return "SELECT [users].[username] FROM [users] ORDER BY [age] DESC, [name] ASC"

    def can_compile_with_group_by(self):
        """
        self.builder.select('username').group_by('age').to_sql()
        """
        return "SELECT [users].[username] FROM [users] GROUP BY [users].[age]"

    def can_compile_where_in(self):
        """
        self.builder.select('username').where_in('age', [1,2,3]).to_sql()
        """
        return "SELECT [users].[username] FROM [users] WHERE [users].[age] IN ('1','2','3')"

    def can_compile_where_in_empty(self):
        """
        self.builder.where_in('age', []).to_sql()
        """
        # An empty IN list compiles to an always-false predicate.
        return """SELECT * FROM [users] WHERE 0 = 1"""

    def can_compile_where_null(self):
        """
        self.builder.select('username').where_null('age').to_sql()
        """
        return "SELECT [users].[username] FROM [users] WHERE [users].[age] IS NULL"

    def can_compile_where_not_null(self):
        """
        self.builder.select('username').where_not_null('age').to_sql()
        """
        return "SELECT [users].[username] FROM [users] WHERE [users].[age] IS NOT NULL"

    def can_compile_where_raw(self):
        """
        self.builder.where_raw("`age` = '18'").to_sql()
        """
        # NOTE(review): test_can_compile_where_raw is overridden in this class with
        # its own expected SQL, so this value appears unused — confirm.
        return "SELECT * FROM [users] WHERE [users].[age] = '18'"

    def test_can_compile_where_raw_and_where_with_multiple_bindings(self):
        # Raw-clause bindings are collected before the regular where() binding.
        query = self.builder.where_raw(
            "[age] = '?' AND [is_admin] = '?'", [18, True]
        ).where("email", "test@example.com")
        self.assertEqual(
            query.to_qmark(),
            "SELECT * FROM [users] WHERE [age] = '?' AND [is_admin] = '?' AND [users].[email] = '?'",
        )
        self.assertEqual(query._bindings, [18, True, "test@example.com"])

    def can_compile_select_raw(self):
        """
        self.builder.select_raw("COUNT(*)").to_sql()
        """
        return "SELECT COUNT(*) FROM [users]"

    def can_compile_limit_and_offset(self):
        """
        self.builder.limit(10).offset(10).to_sql()
        """
        # NOTE(review): SQL Server only accepts OFFSET/FETCH after ORDER BY; this
        # expected SQL has none — confirm whether the grammar should add one.
        return "SELECT * FROM [users] OFFSET 10 ROWS FETCH NEXT 10 ROWS ONLY"

    def can_compile_select_raw_with_select(self):
        """
        self.builder.select('id').select_raw("COUNT(*)").to_sql()
        """
        return "SELECT [users].[id], COUNT(*) FROM [users]"

    def can_compile_having_raw(self):
        """
        self.builder.select_raw("COUNT(*) as counts").having_raw("counts > 18").to_sql()
        """
        # NOTE(review): test_can_compile_having_raw is overridden in this class
        # (with "counts > 10"), so this value appears unused — confirm.
        return "SELECT COUNT(*) as counts FROM [users] HAVING counts > 18"

    def can_compile_count(self):
        """
        self.builder.count().to_sql()
        """
        return "SELECT COUNT(*) AS m_count_reserved FROM [users]"

    def can_compile_count_column(self):
        """
        self.builder.count('money').to_sql()
        """
        return "SELECT COUNT([users].[money]) AS money FROM [users]"

    def can_compile_where_column(self):
        """
        self.builder.where_column('name', 'email').to_sql()
        """
        return "SELECT * FROM [users] WHERE [users].[name] = [users].[email]"

    def can_compile_or_where(self):
        """
        self.builder.where('name', 2).or_where('name', 3).to_sql()
        """
        return (
            "SELECT * FROM [users] WHERE [users].[name] = '2' OR [users].[name] = '3'"
        )

    def can_grouped_where(self):
        """
        self.builder.where(lambda query: query.where('age', 2).where('name', 'Joe')).to_sql()
        """
        return "SELECT * FROM [users] WHERE ([users].[age] = '2' AND [users].[name] = 'Joe')"

    def can_compile_sub_select(self):
        """
        self.builder.where_in('name',
            QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age')
        ).to_sql()
        """
        return "SELECT * FROM [users] WHERE [users].[name] IN (SELECT [users].[age] FROM [users])"

    def can_compile_sub_select_where(self):
        """
        self.builder.where_in('age',
            QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age').where('age', 2).where('name', 'Joe')
        ).to_sql()
        """
        return "SELECT * FROM [users] WHERE [users].[age] IN (SELECT [users].[age] FROM [users] WHERE [users].[age] = '2' AND [users].[name] = 'Joe')"

    def can_compile_sub_select_value(self):
        """
        self.builder.where('name',
            self.builder.new().sum('age')
        ).to_sql()
        """
        return "SELECT * FROM [users] WHERE [users].[name] = (SELECT SUM([users].[age]) AS age FROM [users])"

    def can_compile_complex_sub_select(self):
        """
        self.builder.where_in('name',
            (QueryBuilder(GrammarFactory.make(self.grammar), table='users')
                .select('age').where_in('email',
                    QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('email')
            ))
        ).to_sql()
        """
        return "SELECT * FROM [users] WHERE [users].[name] IN (SELECT [users].[age] FROM [users] WHERE [users].[email] IN (SELECT [users].[email] FROM [users]))"

    def can_compile_exists(self):
        """
        self.builder.select('age').where_exists(
            self.builder.new().select('username').where('age', 12)
        ).to_sql()
        """
        return "SELECT [users].[age] FROM [users] WHERE EXISTS (SELECT [users].[username] FROM [users] WHERE [users].[age] = '12')"

    def can_compile_not_exists(self):
        """
        self.builder.select('age').where_not_exists(
            self.builder.new().select('username').where('age', 12)
        ).to_sql()
        """
        return "SELECT [users].[age] FROM [users] WHERE NOT EXISTS (SELECT [users].[username] FROM [users] WHERE [users].[age] = '12')"

    def can_compile_having(self):
        """
        builder.sum('age').group_by('age').having('age').to_sql()
        """
        return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age]"

    def can_compile_having_order(self):
        """
        builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql()
        """
        # NOTE(review): this expectation has "ORDER" without "BY", unlike every
        # other ORDER BY expectation in this class — confirm it is exercised/correct.
        return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] ORDER [users].[age] DESC"

    def can_compile_between(self):
        """
        builder.between('age', 18, 21).to_sql()
        """
        return "SELECT * FROM [users] WHERE [users].[age] BETWEEN '18' AND '21'"

    def can_compile_not_between(self):
        """
        builder.not_between('age', 18, 21).to_sql()
        """
        return "SELECT * FROM [users] WHERE [users].[age] NOT BETWEEN '18' AND '21'"

    def can_compile_where_not_in(self):
        """
        self.builder.select('username').where_not_in('age', [1,2,3]).to_sql()
        """
        return "SELECT [users].[username] FROM [users] WHERE [users].[age] NOT IN ('1','2','3')"

    def can_compile_having_with_expression(self):
        """
        builder.sum('age').group_by('age').having('age', 10).to_sql()
        """
        return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] = '10'"

    def can_compile_having_with_greater_than_expression(self):
        """
        builder.sum('age').group_by('age').having('age', '>', 10).to_sql()
        """
        return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] > '10'"

    def can_compile_join(self):
        """
        builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
        """
        return "SELECT * FROM [users] INNER JOIN [contacts] ON [users].[id] = [contacts].[user_id]"

    def can_compile_left_join(self):
        """
        builder.left_join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
        """
        return "SELECT * FROM [users] LEFT JOIN [contacts] ON [users].[id] = [contacts].[user_id]"

    # NOTE(review): the second join in the docstring is reconstructed from the
    # expected SQL — confirm against BaseTestCaseSelectGrammar.
    def can_compile_multiple_join(self):
        """
        builder.join('contacts', 'users.id', '=', 'contacts.user_id').join('posts', 'comments.post_id', '=', 'posts.id').to_sql()
        """
        return "SELECT * FROM [users] INNER JOIN [contacts] ON [users].[id] = [contacts].[user_id] INNER JOIN [posts] ON [comments].[post_id] = [posts].[id]"

    def test_can_compile_where_raw(self):
        to_sql = self.builder.where_raw("[age] = '18'").to_sql()
        self.assertEqual(to_sql, "SELECT * FROM [users] WHERE [age] = '18'")

    def test_can_compile_having_raw(self):
        to_sql = (
            self.builder.select_raw("COUNT(*) as counts")
            .having_raw("counts > 10")
            .to_sql()
        )
        self.assertEqual(
            to_sql, "SELECT COUNT(*) as counts FROM [users] HAVING counts > 10"
        )

    def test_can_compile_having_raw_order(self):
        to_sql = (
            self.builder.select_raw("COUNT(*) as counts")
            .having_raw("counts > 10")
            .order_by_raw("counts DESC")
            .to_sql()
        )
        self.assertEqual(
            to_sql,
            "SELECT COUNT(*) as counts FROM [users] HAVING counts > 10 ORDER BY counts DESC",
        )

    def test_can_compile_select_raw(self):
        to_sql = self.builder.select_raw("COUNT(*)").to_sql()
        # Expected SQL comes from the method named like this test minus "test_".
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(to_sql, sql)

    def test_can_compile_select_raw_with_select(self):
        to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql()
        # Expected SQL comes from the method named like this test minus "test_".
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(to_sql, sql)

    def can_compile_first_or_fail(self):
        """
        builder = self.get_builder()
        builder.where("is_admin", "=", True).first_or_fail()
        """
        return """SELECT TOP 1 * FROM [users] WHERE [users].[is_admin] = '1'"""

    def where_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "like", "%name%")
        """
        return """SELECT * FROM [users] WHERE [users].[age] LIKE '%name%'"""

    def where_not_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "not like", "%name%")
        """
        return """SELECT * FROM [users] WHERE [users].[age] NOT LIKE '%name%'"""

    def where_regexp(self):
        """
        builder = self.get_builder()
        builder.where("age", "regexp", "Joe")
        """
        # The expected SQL renders "regexp" as LIKE on MSSQL.
        return """SELECT * FROM [users] WHERE [users].[age] LIKE 'Joe'"""

    def where_not_regexp(self):
        """
        builder = self.get_builder()
        builder.where("age", "not regexp", "Joe")
        """
        return """SELECT * FROM [users] WHERE [users].[age] NOT LIKE 'Joe'"""

    # NOTE(review): the docstring call is reconstructed from the four
    # column-to-column ON conditions in the expected SQL — confirm against the base test.
    def can_compile_join_clause(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on("bgt.fund", "=", "rg.fund")
            .on("bgt.dept", "=", "rg.dept")
            .on("bgt.acct", "=", "rg.acct")
            .on("bgt.sub", "=", "rg.sub")
        )
        builder.join(clause).to_sql()
        """
        return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [bgt].[fund] = [rg].[fund] AND [bgt].[dept] = [rg].[dept] AND [bgt].[acct] = [rg].[acct] AND [bgt].[sub] = [rg].[sub]"

    def can_compile_join_clause_with_value(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on_value("bgt.active", "=", "1")
            .or_on_value("bgt.acct", "=", "1234")
        )
        builder.join(clause).to_sql()
        """
        return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [bgt].[active] = '1' OR [bgt].[acct] = '1234'"

    def can_compile_join_clause_with_null(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on_null("bgt.acct")
            .or_on_null("bgt.dept")
            .on_value("rg.abc", 10)
        )
        builder.join(clause).to_sql()
        """
        # Pins current output: the table prefix is dropped for NULL checks ([acct], [dept]).
        return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [acct] IS NULL OR [dept] IS NULL AND [rg].[abc] = '10'"

    def can_compile_join_clause_with_not_null(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on_not_null("bgt.acct")
            .or_on_not_null("bgt.dept")
            .on_value("rg.abc", 10)
        )
        builder.join(clause).to_sql()
        """
        return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [acct] IS NOT NULL OR [dept] IS NOT NULL AND [rg].[abc] = '10'"

    def can_compile_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.join(
            "report_groups as rg",
            lambda clause: (
                clause.on("bgt.fund", "=", "rg.fund")
                .on_null("bgt")
            ),
        ).to_sql()
        """
        return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [bgt].[fund] = [rg].[fund] AND [bgt] IS NULL"

    def can_compile_left_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.left_join(
            "report_groups as rg",
            lambda clause: (
                clause.on("bgt.fund", "=", "rg.fund")
                .or_on_null("bgt")
            ),
        ).to_sql()
        """
        return "SELECT * FROM [users] LEFT JOIN [report_groups] AS [rg] ON [bgt].[fund] = [rg].[fund] OR [bgt] IS NULL"

    def can_compile_right_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.right_join(
            "report_groups as rg",
            lambda clause: (
                clause.on("bgt.fund", "=", "rg.fund")
                .or_on_null("bgt")
            ),
        ).to_sql()
        """
        return "SELECT * FROM [users] RIGHT JOIN [report_groups] AS [rg] ON [bgt].[fund] = [rg].[fund] OR [bgt] IS NULL"

    # NOTE(review): the docstring call is reconstructed from the expected SQL;
    # confirm the lock method name against BaseTestCaseSelectGrammar.
    def shared_lock(self):
        """
        builder = self.get_builder()
        builder.where("votes", ">=", "100").shared_lock().to_sql()
        """
        return "SELECT * FROM [users] WITH(ROWLOCK) WHERE [users].[votes] >= '100'"

    # NOTE(review): the expected SQL is identical to shared_lock's; SQL Server
    # uses a distinct UPDLOCK table hint for update locks — confirm intent.
    # The docstring call is reconstructed; confirm the lock method name.
    def update_lock(self):
        """
        builder = self.get_builder()
        builder.where("votes", ">=", "100").lock_for_update().to_sql()
        """
        return "SELECT * FROM [users] WITH(ROWLOCK) WHERE [users].[votes] >= '100'"

    # NOTE(review): the docstring uses MySQL backticks, but the expected SQL has
    # the bare raw text "age = '18'" — confirm the raw string the base test passes.
    def can_user_where_raw_and_where(self):
        """
        builder.where_raw("`age` = '18'").where("name", "=", "James").to_sql()
        """
        return "SELECT * FROM [users] WHERE age = '18' AND [users].[name] = 'James'"

    def where_exists_with_lambda(self):
        return """SELECT * FROM [users] WHERE EXISTS (SELECT * FROM [users] WHERE [users].[age] = '1')"""

    def where_not_exists_with_lambda(self):
        return """SELECT * FROM [users] WHERE NOT EXISTS (SELECT * FROM [users] WHERE [users].[age] = '1')"""

    # NOTE(review): T-SQL has no DATE() function (CAST(... AS DATE) is the usual
    # form), so this pinned SQL would fail on SQL Server — confirm.
    def where_date(self):
        return (
            """SELECT * FROM [users] WHERE DATE([users].[created_at]) = '2022-06-01'"""
        )

    def or_where_null(self):
        return """SELECT * FROM [users] WHERE [users].[column1] IS NULL OR [users].[column2] IS NULL"""

    def select_distinct(self):
        return """SELECT DISTINCT [users].[group] FROM [users]"""


================================================
FILE: 
tests/mssql/grammar/test_mssql_update_grammar.py
================================================
import unittest

from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MSSQLGrammar
from src.masoniteorm.expressions import Raw


class TestMSSQLUpdateGrammar(unittest.TestCase):
    """UPDATE statement compilation with MSSQLGrammar."""

    def setUp(self):
        self.builder = QueryBuilder(MSSQLGrammar, table="users")

    def test_can_compile_update(self):
        to_sql = (
            self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql()
        )

        sql = "UPDATE [users] SET [users].[name] = 'Joe' WHERE [users].[name] = 'bob'"
        self.assertEqual(to_sql, sql)

    def test_can_compile_update_with_multiple_where(self):
        to_sql = (
            self.builder.where("name", "bob")
            .where("age", 20)
            .update({"name": "Joe"}, dry=True)
            .to_sql()
        )

        sql = "UPDATE [users] SET [users].[name] = 'Joe' WHERE [users].[name] = 'bob' AND [users].[age] = '20'"
        self.assertEqual(to_sql, sql)

    def test_raw_expression(self):
        # A Raw value is emitted verbatim instead of as a quoted literal.
        to_sql = self.builder.update({"name": Raw("[username]")}, dry=True).to_sql()

        sql = "UPDATE [users] SET [users].[name] = [username]"
        self.assertEqual(to_sql, sql)


================================================
FILE: tests/mssql/schema/test_mssql_schema_builder.py
================================================
import unittest

from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import MSSQLConnection
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import MSSQLPlatform


class TestMSSQLSchemaBuilder(unittest.TestCase):
    """Dry-run DDL generation for the MSSQL schema builder.

    NOTE(review): several pinned column types (TINYTEXT, MEDIUMINT,
    FLOAT(19, 4)) are not SQL Server data types. These tests pin the platform's
    current output rather than valid T-SQL — confirm whether that is intended.
    Some test names also misspell "constraint" as "constaint".
    """

    maxDiff = None

    def setUp(self):
        # dry=True: the generated SQL is inspected via blueprint.to_sql().
        self.schema = Schema(
            connection_class=MSSQLConnection,
            connection="mssql",
            connection_details=DATABASES,
            platform=MSSQLPlatform,
            dry=True,
        ).on("mssql")

    def test_can_add_columns(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")

        self.assertEqual(len(blueprint.table.added_columns), 2)
        self.assertEqual(
            blueprint.to_sql(),
            ["CREATE TABLE [users] ([name] VARCHAR(255) NOT NULL, [age] INT NOT NULL)"],
        )

    def test_can_add_tiny_text(self):
        with self.schema.create("users") as blueprint:
            blueprint.tiny_text("description")

        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            ["CREATE TABLE [users] ([description] TINYTEXT NOT NULL)"],
        )

    def test_can_add_unsigned_decimal(self):
        # The expected SQL carries no UNSIGNED modifier.
        with self.schema.create("users") as blueprint:
            blueprint.unsigned_decimal("amount", 19, 4)

        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            ["CREATE TABLE [users] ([amount] DECIMAL(19, 4) NOT NULL)"],
        )

    def test_can_add_columns_with_constaint(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")
            blueprint.unique("name")

        # unique() adds a table constraint, not a column.
        self.assertEqual(len(blueprint.table.added_columns), 2)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] ([name] VARCHAR(255) NOT NULL, [age] INT NOT NULL, CONSTRAINT users_name_unique UNIQUE (name))"
            ],
        )

    def test_can_have_float_type(self):
        with self.schema.create("users") as blueprint:
            blueprint.float("amount")

        self.assertEqual(
            blueprint.to_sql(),
            ["""CREATE TABLE [users] (""" """[amount] FLOAT(19, 4) NOT NULL)"""],
        )

    def test_can_have_unsigned_columns(self):
        # The expected SQL carries no UNSIGNED modifier.
        with self.schema.create("users") as blueprint:
            blueprint.integer("profile_id").unsigned()
            blueprint.big_integer("big_profile_id").unsigned()
            blueprint.tiny_integer("tiny_profile_id").unsigned()
            blueprint.small_integer("small_profile_id").unsigned()
            blueprint.medium_integer("medium_profile_id").unsigned()

        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] ("
                "[profile_id] INT NOT NULL, "
                "[big_profile_id] BIGINT NOT NULL, "
                "[tiny_profile_id] TINYINT NOT NULL, "
                "[small_profile_id] SMALLINT NOT NULL, "
                "[medium_profile_id] MEDIUMINT NOT NULL)"
            ],
        )

    def test_can_add_columns_with_foreign_key_constaint(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.foreign("profile_id").references("id").on("profiles")

        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] "
                "([name] VARCHAR(255) NOT NULL, "
                "[age] INT NOT NULL, "
                "[profile_id] INT NOT NULL, "
                "CONSTRAINT users_name_unique UNIQUE (name), "
                "CONSTRAINT users_profile_id_foreign FOREIGN KEY ([profile_id]) REFERENCES [profiles]([id]))"
            ],
        )

    def test_can_add_columns_with_add_foreign_constaint(self):
        # add_foreign("column.ref_column.ref_table") is expected to match the
        # fluent foreign().references().on() form above.
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.add_foreign("profile_id.id.profiles")

        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] "
                "([name] VARCHAR(255) NOT NULL, "
                "[age] INT NOT NULL, "
                "[profile_id] INT NOT NULL, "
                "CONSTRAINT users_name_unique UNIQUE (name), "
                "CONSTRAINT users_profile_id_foreign FOREIGN KEY ([profile_id]) REFERENCES [profiles]([id]))"
            ],
        )

    def test_can_advanced_table_creation(self):
        with self.schema.create("users") as blueprint:
            blueprint.increments("id")
            blueprint.string("name")
            blueprint.string("email").unique()
            blueprint.string("password")
            blueprint.integer("admin").default(0)
            blueprint.string("remember_token").nullable()
            blueprint.timestamp("verified_at").nullable()
            blueprint.timestamp("registered_at").default_raw("CURRENT_TIMESTAMP")
            blueprint.timestamps()

        # timestamps() contributes created_at and updated_at, giving 10 columns.
        self.assertEqual(len(blueprint.table.added_columns), 10)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] ([id] INT IDENTITY NOT NULL, [name] VARCHAR(255) NOT NULL, [email] VARCHAR(255) NOT NULL, "
                "[password] VARCHAR(255) NOT NULL, [admin] INT NOT NULL DEFAULT 0, [remember_token] VARCHAR(255) NULL, "
                "[verified_at] DATETIME NULL, [registered_at] DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, [created_at] DATETIME NULL DEFAULT CURRENT_TIMESTAMP, "
                "[updated_at] DATETIME NULL DEFAULT CURRENT_TIMESTAMP, CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_email_unique UNIQUE (email))"
            ],
        )

    def test_can_advanced_table_creation2(self):
        with self.schema.create("users") as blueprint:
            blueprint.increments("id")
            blueprint.enum("gender", ["male", "female"])
            blueprint.string("name")
            blueprint.string("duration")
            blueprint.string("url")
            blueprint.inet("last_address").nullable()
            blueprint.cidr("route_origin").nullable()
            blueprint.macaddr("mac_address").nullable()
            blueprint.datetime("published_at")
            blueprint.string("thumbnail").nullable()
            blueprint.integer("premium")
            blueprint.integer("author_id").unsigned().nullable()
            blueprint.foreign("author_id").references("id").on("users").on_delete(
                "CASCADE"
            )
            blueprint.text("description")
            blueprint.timestamps()

        # enum maps to VARCHAR plus a CHECK constraint; inet/cidr/macaddr map to VARCHAR(255).
        self.assertEqual(len(blueprint.table.added_columns), 15)
        self.assertEqual(
            blueprint.to_sql(),
            (
                [
                    "CREATE TABLE [users] ([id] INT IDENTITY NOT NULL, [gender] VARCHAR(255) NOT NULL CHECK([gender] IN ('male', 'female')), [name] VARCHAR(255) NOT NULL, [duration] VARCHAR(255) NOT NULL, "
                    "[url] VARCHAR(255) NOT NULL, [last_address] VARCHAR(255) NULL, [route_origin] VARCHAR(255) NULL, [mac_address] VARCHAR(255) NULL, [published_at] DATETIME NOT NULL, [thumbnail] VARCHAR(255) NULL, [premium] INT NOT NULL, "
                    "[author_id] INT NULL, [description] TEXT NOT NULL, [created_at] DATETIME NULL DEFAULT CURRENT_TIMESTAMP, "
                    "[updated_at] DATETIME NULL DEFAULT CURRENT_TIMESTAMP, "
                    "CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_author_id_foreign FOREIGN KEY ([author_id]) REFERENCES [users]([id]) ON DELETE CASCADE)"
                ]
            ),
        )

    # The body of this test continues on the following line(s) of the file and
    # is left in its original single-line form.
    def test_can_add_columns_with_foreign_key_constraint_name(self): with self.schema.create("users") as blueprint: blueprint.integer("profile_id") blueprint.foreign("profile_id", name="profile_foreign").references("id").on( "profiles" ) self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE [users] (" "[profile_id] INT NOT NULL, " "CONSTRAINT 
profile_foreign FOREIGN KEY ([profile_id]) REFERENCES [profiles]([id]))" ], ) def test_can_have_composite_keys(self): with self.schema.create("users") as blueprint: blueprint.string("name").unique() blueprint.integer("age") blueprint.integer("profile_id") blueprint.primary(["name", "age"]) self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE [users] " "([name] VARCHAR(255) NOT NULL, " "[age] INT NOT NULL, " "[profile_id] INT NOT NULL, " "CONSTRAINT users_name_unique UNIQUE (name), " "CONSTRAINT users_name_age_primary PRIMARY KEY (name, age))" ], ) def test_can_have_column_primary_key(self): with self.schema.create("users") as blueprint: blueprint.string("name").primary() blueprint.integer("age") blueprint.integer("profile_id") self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE [users] " "([name] VARCHAR(255) NOT NULL, " "[age] INT NOT NULL, " "[profile_id] INT NOT NULL, " "CONSTRAINT users_name_primary PRIMARY KEY (name))" ], ) def test_has_table(self): schema_sql = self.schema.has_table("users") sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'users'" self.assertEqual(schema_sql, sql) def test_can_truncate(self): sql = self.schema.truncate("users") self.assertEqual(sql, "TRUNCATE TABLE [users]") def test_can_rename_table(self): sql = self.schema.rename("users", "clients") self.assertEqual(sql, "EXEC sp_rename [users], [clients]") def test_can_drop_table_if_exists(self): sql = self.schema.drop_table_if_exists("users", "clients") self.assertEqual(sql, "DROP TABLE IF EXISTS [users]") def test_can_drop_table(self): sql = self.schema.drop_table("users", "clients") self.assertEqual(sql, "DROP TABLE [users]") def test_has_column(self): sql = self.schema.has_column("users", "name") self.assertEqual( sql, "SELECT 1 FROM sys.columns WHERE Name = N'name' AND Object_ID = Object_ID(N'users')", ) def test_can_enable_foreign_keys(self): sql = 
self.schema.enable_foreign_key_constraints() self.assertEqual(sql, "") def test_can_disable_foreign_keys(self): sql = self.schema.disable_foreign_key_constraints() self.assertEqual(sql, "") def test_can_truncate_without_foreign_keys(self): sql = self.schema.truncate("users", foreign_keys=True) self.assertEqual( sql, [ "ALTER TABLE [users] NOCHECK CONSTRAINT ALL", "TRUNCATE TABLE [users]", "ALTER TABLE [users] WITH CHECK CHECK CONSTRAINT ALL", ], ) def test_can_add_enum(self): with self.schema.create("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE [users] ([status] VARCHAR(255) NOT NULL DEFAULT 'active' CHECK([status] IN ('active', 'inactive')))" ], ) def test_can_change_column_enum(self): with self.schema.table("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active").change() self.assertEqual(len(blueprint.table.changed_columns), 1) self.assertEqual( blueprint.to_sql(), [ "ALTER TABLE [users] ALTER COLUMN [status] VARCHAR(255) NOT NULL DEFAULT 'active' CHECK([status] IN ('active', 'inactive'))" ], ) ================================================ FILE: tests/mssql/schema/test_mssql_schema_builder_alter.py ================================================ import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import MSSQLConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import MSSQLPlatform from src.masoniteorm.schema.Table import Table class TestMySQLSchemaBuilderAlter(unittest.TestCase): maxDiff = None def setUp(self): self.schema = Schema( connection_class=MSSQLConnection, connection="mssql", connection_details=DATABASES, platform=MSSQLPlatform, dry=True, ) def test_can_add_columns(self): with self.schema.table("users") as blueprint: blueprint.string("name") blueprint.integer("age") 
self.assertEqual(len(blueprint.table.added_columns), 2) sql = [ "ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL, [age] INT NOT NULL" ] self.assertEqual(blueprint.to_sql(), sql) def test_can_adds_column_with_default(self): with self.schema.table("users") as blueprint: blueprint.string("name").default(0) self.assertEqual(len(blueprint.table.added_columns), 1) sql = ["ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL DEFAULT 0"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_rename(self): with self.schema.table("users") as blueprint: blueprint.rename("post", "comment", "integer") table = Table("users") table.add_column("post", "integer") blueprint.table.from_table = table sql = ["EXEC sp_rename 'users.post', 'comment', 'COLUMN'"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_and_rename(self): with self.schema.table("users") as blueprint: blueprint.string("name") blueprint.rename("post", "comment", "integer") table = Table("users") table.add_column("post", "integer") blueprint.table.from_table = table sql = [ "ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL", "EXEC sp_rename 'users.post', 'comment', 'COLUMN'", ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop1(self): with self.schema.table("users") as blueprint: blueprint.drop_column("post") sql = ["ALTER TABLE [users] DROP COLUMN post"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_column_and_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.unsigned_integer("playlist_id").nullable() blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( "cascade" ) sql = [ "ALTER TABLE [users] ADD [playlist_id] INT NULL", "ALTER TABLE [users] ADD CONSTRAINT users_playlist_id_foreign FOREIGN KEY ([playlist_id]) REFERENCES [playlists]([id]) ON DELETE CASCADE", ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_column_and_add_foreign(self): with self.schema.table("users") as blueprint: 
blueprint.unsigned_integer("playlist_id").nullable() blueprint.add_foreign("playlist_id.id.playlists").on_delete("cascade") sql = [ "ALTER TABLE [users] ADD [playlist_id] INT NULL", "ALTER TABLE [users] ADD CONSTRAINT users_playlist_id_foreign FOREIGN KEY ([playlist_id]) REFERENCES [playlists]([id]) ON DELETE CASCADE", ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.drop_foreign("users_playlist_id_foreign") sql = ["ALTER TABLE [users] DROP CONSTRAINT users_playlist_id_foreign"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_foreign_key_shortcut(self): with self.schema.table("users") as blueprint: blueprint.drop_foreign(["playlist_id"]) sql = ["ALTER TABLE [users] DROP CONSTRAINT users_playlist_id_foreign"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_unique_constraint(self): with self.schema.table("users") as blueprint: blueprint.drop_unique("users_playlist_id_unique") sql = ["DROP INDEX [users].[users_playlist_id_unique]"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_primary(self): with self.schema.table("users") as blueprint: blueprint.primary("playlist_id") sql = [ "ALTER TABLE [users] ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)" ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_index(self): with self.schema.table("users") as blueprint: blueprint.index("playlist_id") sql = ["CREATE INDEX users_playlist_id_index ON [users](playlist_id)"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_index(self): with self.schema.table("users") as blueprint: blueprint.drop_index("users_playlist_id_index") sql = ["DROP INDEX [users].[users_playlist_id_index]"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_index_shortcut(self): with self.schema.table("users") as blueprint: blueprint.drop_index(["playlist_id"]) sql = ["DROP INDEX [users].[users_playlist_id_index]"] 
self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_unique_constraint_shortcut(self): with self.schema.table("users") as blueprint: blueprint.drop_unique(["playlist_id"]) sql = ["DROP INDEX [users].[users_playlist_id_unique]"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_primary(self): with self.schema.table("users") as blueprint: blueprint.drop_primary(["id"]) sql = ["DROP INDEX [users].[users_id_primary]"] self.assertEqual(blueprint.to_sql(), sql) def test_has_table(self): schema_sql = self.schema.has_table("users") sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'users'" self.assertEqual(schema_sql, sql) def test_drop_table(self): schema_sql = self.schema.has_table("users") sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'users'" self.assertEqual(schema_sql, sql) def test_change(self): with self.schema.table("users") as blueprint: blueprint.integer("age").change() blueprint.string("name") blueprint.string("external_type").default("external") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual(len(blueprint.table.changed_columns), 1) table = Table("users") table.add_column("age", "string") blueprint.table.from_table = table sql = [ "ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL, [external_type] VARCHAR(255) NOT NULL DEFAULT 'external'", "ALTER TABLE [users] ALTER COLUMN [age] INT NOT NULL", ] self.assertEqual(blueprint.to_sql(), sql) def test_drop_add_and_change(self): with self.schema.table("users") as blueprint: blueprint.integer("age").default(0).change() blueprint.string("name") blueprint.drop_column("email") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual(len(blueprint.table.changed_columns), 1) table = Table("users") table.add_column("age", "string") table.add_column("email", "string") blueprint.table.from_table = table sql = [ "ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL", "ALTER TABLE [users] ALTER COLUMN [age] INT NOT NULL DEFAULT 
0", "ALTER TABLE [users] DROP COLUMN email", ] self.assertEqual(blueprint.to_sql(), sql) def test_can_create_indexes(self): with self.schema.table("users") as blueprint: blueprint.index("name") blueprint.index(["name", "email"]) blueprint.unique("name") blueprint.unique(["name", "email"]) blueprint.fulltext("description") self.assertEqual(len(blueprint.table.added_columns), 0) print(blueprint.to_sql()) self.assertEqual( blueprint.to_sql(), [ "CREATE INDEX users_name_index ON [users](name)", "CREATE INDEX users_name_email_index ON [users](name,email)", "ALTER TABLE [users] ADD CONSTRAINT users_name_unique UNIQUE(name)", "ALTER TABLE [users] ADD CONSTRAINT users_name_email_unique UNIQUE(name,email)", ], ) def test_timestamp_alter_add_nullable_column(self): with self.schema.table("users") as blueprint: blueprint.timestamp("due_date").nullable() self.assertEqual(len(blueprint.table.added_columns), 1) table = Table("users") table.add_column("age", "string") blueprint.table.from_table = table sql = ["ALTER TABLE [users] ADD [due_date] DATETIME NULL"] self.assertEqual(blueprint.to_sql(), sql) def test_can_add_column_enum(self): with self.schema.table("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active") self.assertEqual(len(blueprint.table.added_columns), 1) sql = [ "ALTER TABLE [users] ADD [status] VARCHAR(255) NOT NULL DEFAULT 'active' CHECK([status] IN ('active', 'inactive'))" ] self.assertEqual(blueprint.to_sql(), sql) ================================================ FILE: tests/mysql/builder/test_mysql_builder_transaction.py ================================================ import inspect import os import unittest from src.masoniteorm.connections import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.relationships import belongs_to from tests.utils import MockConnectionFactory from 
tests.integrations.config.database import DB if os.getenv("RUN_MYSQL_DATABASE") == "True": class User(Model): __connection__ = "mysql" __timestamps__ = False class BaseTestQueryRelationships(unittest.TestCase): maxDiff = None def get_builder(self, table="users"): connection = ConnectionFactory().make("mysql") return QueryBuilder( grammar=MySQLGrammar, connection=connection, table=table ).on("mysql") def test_transaction(self): builder = self.get_builder() builder.begin() builder.create({"name": "phillip2", "email": "phillip2"}) # builder.commit() user = builder.where("name", "phillip2").first() self.assertEqual(user["name"], "phillip2") builder.rollback() user = builder.where("name", "phillip2").first() self.assertEqual(user, None) def test_transaction_default_globally(self): connection = DB.begin_transaction() self.assertEqual(connection, self.get_builder().new_connection()) DB.commit() DB.begin_transaction() DB.rollback() ================================================ FILE: tests/mysql/builder/test_query_builder.py ================================================ import datetime import inspect import unittest from src.masoniteorm.exceptions import InvalidArgument from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.relationships import has_many from src.masoniteorm.scopes import SoftDeleteScope from tests.integrations.config.database import DATABASES from tests.utils import MockConnectionFactory class Articles(Model): pass class User(Model): __timestamps__ = False @has_many("id", "user_id") def articles(self): return Articles class BaseTestQueryBuilder: maxDiff = None def get_builder(self, table="users", dry=True): connection = MockConnectionFactory().make("default") return QueryBuilder( grammar=self.grammar, connection_class=connection, connection="mysql", table=table, model=User(), dry=dry, connection_details=DATABASES, ) def test_sum(self): builder 
= self.get_builder() builder.sum("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_sum_chained(self): builder = self.get_builder() builder.sum("age").max("salary") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_with_(self): builder = self.get_builder() builder.with_("articles").sum("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_like(self): builder = self.get_builder() builder.where("age", "like", "%name%") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_not_like(self): builder = self.get_builder() builder.where("age", "not like", "%name%") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_max(self): builder = self.get_builder() builder.max("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_min(self): builder = self.get_builder() builder.min("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_avg(self): builder = self.get_builder() builder.avg("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_all(self): builder = self.get_builder() builder.all() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_get(self): builder = self.get_builder() builder.get() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def 
test_first(self): builder = self.get_builder().first(query=True) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_find_with_model(self): builder = self.get_builder() builder.find(1000, query=True) sql = '''SELECT * FROM `users` WHERE `users`.`id` = '1000\'''' self.assertEqual(builder.to_sql(), sql) def test_find_with_model_and_list(self): builder = self.get_builder() builder.find([1000, 2000, 3000], query=True) sql = '''SELECT * FROM `users` WHERE `users`.`id` IN ('1000','2000','3000')''' self.assertEqual(builder.to_sql(), sql) def test_find_with_model_custom_column(self): builder = self.get_builder() builder.find(10, column="age", query=True) sql = '''SELECT * FROM `users` WHERE `users`.`age` = '10\'''' self.assertEqual(builder.to_sql(), sql) def test_find_with_builder(self): builder = self.get_builder() builder._model = None builder.find(10, column="age", query=True) sql = '''SELECT * FROM `users` WHERE `users`.`age` = '10\'''' self.assertEqual(builder.to_sql(), sql) def test_find_with_builder_and_list(self): builder = self.get_builder() builder._model = None builder.find([10, 20, 30], column="age", query=True) sql = '''SELECT * FROM `users` WHERE `users`.`age` IN ('10','20','30')''' self.assertEqual(builder.to_sql(), sql) def test_find_with_builder_without_column(self): builder = self.get_builder() builder._model = None with self.assertRaises(InvalidArgument): builder.find(10, query=True) def test_select(self): builder = self.get_builder() builder.select("name", "email") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_select_with_table(self): builder = self.get_builder() builder.select("users.*") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_select_with_table_raw(self): builder = self.get_builder() 
builder.select("users.*").from_raw("orders, customers") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_select_with_alias(self): builder = self.get_builder() builder.select("users.username as name") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_select_raw(self): builder = self.get_builder() builder.select_raw("count(email) as email_count") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_add_select(self): builder = self.get_builder() sql = ( builder.select("name") .add_select("phone_count", lambda q: q.count("*").table("phones")) .add_select("salary", lambda q: q.count("*").table("salary")) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_add_select_no_table(self): builder = self.get_builder(table=None) sql = ( builder.add_select( "other_test", lambda q: q.max("updated_at").table("different_table") ) .add_select( "some_alias", lambda q: q.max("updated_at").table("another_table") ) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_create(self): builder = self.get_builder().without_global_scopes() builder.create( {"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_delete(self): builder = self.get_builder() builder.delete("name", "Joe", query=True) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where(self): builder = self.get_builder() builder.where("name", "Joe") sql = getattr( self, 
inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_exists(self): builder = self.get_builder() builder.where_exists("name") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_limit(self): builder = self.get_builder() builder.limit(5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_offset(self): builder = self.get_builder() builder.offset(5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_join(self): builder = self.get_builder() builder.join("profiles", "users.id", "=", "profiles.user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_left_join(self): builder = self.get_builder() builder.left_join("profiles", "users.id", "=", "profiles.user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_right_join(self): builder = self.get_builder() builder.right_join("profiles", "users.id", "=", "profiles.user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_update(self): builder = self.get_builder().update( {"name": "Joe", "email": "joe@yopmail.com"}, dry=True ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) # def test_increment(self): # builder = self.get_builder() # builder.increment("age", 1) # sql = getattr( # self, inspect.currentframe().f_code.co_name.replace("test_", "") # )() # self.assertEqual(builder.to_sql(), sql) # def test_decrement(self): # builder = self.get_builder() # builder.decrement("age", 1) # sql = getattr( # 
self, inspect.currentframe().f_code.co_name.replace("test_", "") # )() # self.assertEqual(builder.to_sql(), sql) def test_count(self): builder = self.get_builder() builder.count("id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_order_by_asc(self): builder = self.get_builder() builder.order_by("email", "asc") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_order_by_desc(self): builder = self.get_builder() builder.order_by("email", "desc") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_column(self): builder = self.get_builder() builder.where_column("name", "username") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_not_in(self): builder = self.get_builder() builder.where_not_in("id", [1, 2, 3]) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_between(self): builder = self.get_builder() builder.between("id", 2, 5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_not_between(self): builder = self.get_builder() builder.not_between("id", 2, 5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_in(self): builder = self.get_builder() builder.where_in("id", [1, 2, 3]) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_null(self): builder = self.get_builder() builder.where_null("name") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() 
self.assertEqual(builder.to_sql(), sql) def test_where_not_null(self): builder = self.get_builder() builder.where_not_null("name") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_having(self): builder = self.get_builder(table="payments") builder.select("user_id").avg("salary").group_by("user_id").having( "salary", ">=", "1000" ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_group_by(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by("user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_builder_alone(self): self.assertTrue( QueryBuilder( dry=True, connection_details={ "default": "mysql", "mysql": { "driver": "mysql", "host": "localhost", "username": "root", "password": "", "database": "orm", "port": "3306", "prefix": "", "grammar": "mysql", "options": {"charset": "utf8mb4"}, }, }, ).table("users") ) def test_where_lt(self): builder = self.get_builder() builder.where("age", "<", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_lte(self): builder = self.get_builder() builder.where("age", "<=", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_gt(self): builder = self.get_builder() builder.where("age", ">", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_gte(self): builder = self.get_builder() builder.where("age", ">=", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_ne(self): builder 
= self.get_builder() builder.where("age", "!=", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_or_where(self): builder = self.get_builder() builder.where("age", "20").or_where("age", "<", 20) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_or_where(self): builder = self.get_builder() builder.where("age", "20").or_where("age", "<", 20) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_like_as_operator(self): builder = self.get_builder() builder.where("age", "like", "%name%") sql = getattr(self, "where_like")() self.assertEqual(builder.to_sql(), sql) def test_where_like(self): builder = self.get_builder() builder.where_like("age", "%name%") sql = getattr(self, "where_like")() self.assertEqual(builder.to_sql(), sql) def test_where_not_like_as_operator(self): builder = self.get_builder() builder.where("age", "not like", "%name%") sql = getattr(self, "where_not_like")() self.assertEqual(builder.to_sql(), sql) def test_where_not_like(self): builder = self.get_builder() builder.where_not_like("age", "%name%") sql = getattr(self, "where_not_like")() self.assertEqual(builder.to_sql(), sql) def test_can_call_with_multi_tables(self): builder = self.get_builder() sql = ( builder.table("information_schema.columns") .select("table_name") .where("table_name", "users") .to_sql() ) self.assertEqual( sql, """SELECT `information_schema`.`columns`.`table_name` FROM `information_schema`.`columns` WHERE `information_schema`.`columns`.`table_name` = 'users'""", ) def test_truncate(self): builder = self.get_builder(dry=True) sql = builder.truncate() sql_ref = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(sql, sql_ref) def test_truncate_without_foreign_keys(self): builder = 
self.get_builder(dry=True) sql = builder.truncate(foreign_keys=True) sql_ref = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(sql, sql_ref) def test_shared_lock(self): builder = self.get_builder(dry=True) sql = builder.where("votes", ">=", 100).shared_lock().to_sql() sql_ref = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(sql, sql_ref) def test_update_lock(self): builder = self.get_builder(dry=True) sql = builder.where("votes", ">=", 100).lock_for_update().to_sql() sql_ref = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(sql, sql_ref) class MySQLQueryBuilderTest(BaseTestQueryBuilder, unittest.TestCase): grammar = MySQLGrammar def sum(self): """ builder = self.get_builder() builder.sum('age') """ return "SELECT SUM(`users`.`age`) AS age FROM `users`" def sum_chained(self): """ builder = self.get_builder() builder.sum('age') """ return "SELECT SUM(`users`.`age`) AS age, MAX(`users`.`salary`) AS salary FROM `users`" def with_(self): """ builder = self.get_builder() builder.with_('articles').sum('age') """ return "SELECT SUM(`users`.`age`) AS age FROM `users`" def max(self): """ builder = self.get_builder() builder.max('age') """ return "SELECT MAX(`users`.`age`) AS age FROM `users`" def min(self): """ builder = self.get_builder() builder.min('age') """ return "SELECT MIN(`users`.`age`) AS age FROM `users`" def avg(self): """ builder = self.get_builder() builder.avg('age') """ return "SELECT AVG(`users`.`age`) AS age FROM `users`" def first(self): """ builder = self.get_builder() builder.first() """ return "SELECT * FROM `users` LIMIT 1" def all(self): """ builder = self.get_builder() builder.all() """ return "SELECT * FROM `users`" def get(self): """ builder = self.get_builder() builder.get() """ return "SELECT * FROM `users`" def select(self): """ builder = self.get_builder() builder.select('name', 'email') """ return 
"SELECT `users`.`name`, `users`.`email` FROM `users`" def select_with_table(self): """ builder = self.get_builder() builder.select('users.*') """ return "SELECT `users`.* FROM `users`" def select_with_table_raw(self): """ builder = self.get_builder() builder.select('users.*') """ return "SELECT `users`.* FROM orders, customers" def select_with_alias(self): """ builder = self.get_builder() builder.select('users.name as name') """ return "SELECT `users`.`username` AS name FROM `users`" def select_raw(self): """ builder = self.get_builder() builder.select_raw('count(email) as email_count') """ return "SELECT count(email) as email_count FROM `users`" def add_select(self): """ builder = self.get_builder() builder.select('name', 'email') """ return "SELECT `users`.`name`, (SELECT COUNT(*) AS m_count_reserved FROM `phones`) AS phone_count, (SELECT COUNT(*) AS m_count_reserved FROM `salary`) AS salary FROM `users`" def add_select_no_table(self): """ builder = self.get_builder() builder.select('name', 'email') """ return ( "SELECT " "(SELECT MAX(`different_table`.`updated_at`) AS updated_at FROM `different_table`) AS other_test, " "(SELECT MAX(`another_table`.`updated_at`) AS updated_at FROM `another_table`) AS some_alias" ) def create(self): """ builder = get_builder() builder.create({"name": "Corentin All", 'email': 'corentin@yopmail.com'}) """ return "INSERT INTO `users` (`users`.`name`, `users`.`email`) VALUES ('Corentin All', 'corentin@yopmail.com')" def delete(self): """ builder = get_builder() builder.delete("name', 'Joe') """ return "DELETE FROM `users` WHERE `users`.`name` = 'Joe'" def where(self): """ builder = get_builder() builder.where('name', 'Joe') """ return "SELECT * FROM `users` WHERE `users`.`name` = 'Joe'" def where_exists(self): """ builder = get_builder() builder.where_exists('name') """ return "SELECT * FROM `users` WHERE EXISTS 'name'" def limit(self): """ builder = get_builder() builder.limit(5) """ return "SELECT * FROM `users` LIMIT 5" def 
offset(self): """ builder = get_builder() builder.offset(5) """ return "SELECT * FROM `users` OFFSET 5" def join(self): """ builder.join("profiles", "users.id", "=", "profiles.user_id") """ return "SELECT * FROM `users` INNER JOIN `profiles` ON `users`.`id` = `profiles`.`user_id`" def left_join(self): """ builder.left_join("profiles", "users.id", "=", "profiles.user_id") """ return "SELECT * FROM `users` LEFT JOIN `profiles` ON `users`.`id` = `profiles`.`user_id`" def right_join(self): """ builder.right_join("profiles", "users.id", "=", "profiles.user_id") """ return "SELECT * FROM `users` RIGHT JOIN `profiles` ON `users`.`id` = `profiles`.`user_id`" def update(self): """ builder.update({"name": "Joe", "email": "joe@yopmail.com"}) """ return "UPDATE `users` SET `users`.`name` = 'Joe', `users`.`email` = 'joe@yopmail.com'" def increment(self): """ builder.increment('age', 1) """ return "UPDATE `users` SET `users`.`age` = `users`.`age` + '1'" def decrement(self): """ builder.decrement('age', 1) """ return "UPDATE `users` SET `users`.`age` = `users`.`age` - '1'" def count(self): """ builder.count(id) """ return "SELECT COUNT(`users`.`id`) AS id FROM `users`" def order_by_asc(self): """ builder.order_by('email', 'asc') """ return "SELECT * FROM `users` ORDER BY `email` ASC" def order_by_desc(self): """ builder.order_by('email', 'des') """ return "SELECT * FROM `users` ORDER BY `email` DESC" def where_column(self): """ builder.where_column('name', 'username') """ return "SELECT * FROM `users` WHERE `users`.`name` = `users`.`username`" def where_null(self): """ builder.where_null('name') """ return "SELECT * FROM `users` WHERE `users`.`name` IS NULL" def where_not_null(self): """ builder.where_null('name') """ return "SELECT * FROM `users` WHERE `users`.`name` IS NOT NULL" def where_not_in(self): """ builder.where_not_in('id', [1, 2, 3]) """ return "SELECT * FROM `users` WHERE `users`.`id` NOT IN ('1','2','3')" def where_in(self): """ builder.where_in('id', [1, 2, 3]) """ 
return "SELECT * FROM `users` WHERE `users`.`id` IN ('1','2','3')" def between(self): """ builder.between('id', 2, 5) """ return "SELECT * FROM `users` WHERE `users`.`id` BETWEEN '2' AND '5'" def not_between(self): """ builder.not_between('id', 2, 5) """ return "SELECT * FROM `users` WHERE `users`.`id` NOT BETWEEN '2' AND '5'" def having(self): """ builder.select('user_id').avg('salary').group_by('user_id').having('salary', '>=', '1000') """ return "SELECT `payments`.`user_id`, AVG(`payments`.`salary`) AS salary FROM `payments` GROUP BY `payments`.`user_id` HAVING `payments`.`salary` >= '1000'" def group_by(self): """ builder.select('user_id').min('salary').group_by('user_id') """ return "SELECT `payments`.`user_id`, MIN(`payments`.`salary`) AS salary FROM `payments` GROUP BY `payments`.`user_id`" def where_lt(self): """ builder = self.get_builder() builder.where('age', '<', '20') """ return "SELECT * FROM `users` WHERE `users`.`age` < '20'" def where_lte(self): """ builder = self.get_builder() builder.where('age', '<=', '20') """ return "SELECT * FROM `users` WHERE `users`.`age` <= '20'" def where_gt(self): """ builder = self.get_builder() builder.where('age', '>', '20') """ return "SELECT * FROM `users` WHERE `users`.`age` > '20'" def where_gte(self): """ builder = self.get_builder() builder.where('age', '>=', '20') """ return "SELECT * FROM `users` WHERE `users`.`age` >= '20'" def where_ne(self): """ builder = self.get_builder() builder.where('age', '!=', '20') """ return "SELECT * FROM `users` WHERE `users`.`age` != '20'" def or_where(self): """ builder = self.get_builder() builder.where('age', '20').or_where('age','<', 20) """ return ( "SELECT * FROM `users` WHERE `users`.`age` = '20' OR `users`.`age` < '20'" ) def where_like(self): """ builder = self.get_builder() builder.where("age", "like", "%name%") """ return "SELECT * FROM `users` WHERE `users`.`age` LIKE '%name%'" def where_not_like(self): """ builder = self.get_builder() builder.where("age", "not 
like", "%name%") """ return "SELECT * FROM `users` WHERE `users`.`age` NOT LIKE '%name%'" def truncate(self): """ builder = self.get_builder() builder.truncate() """ return """TRUNCATE TABLE `users`""" def truncate_without_foreign_keys(self): """ builder = self.get_builder() builder.truncate() """ return [ "SET FOREIGN_KEY_CHECKS=0", "TRUNCATE TABLE `users`", "SET FOREIGN_KEY_CHECKS=1", ] def shared_lock(self): """ builder = self.get_builder() builder.truncate() """ return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' LOCK IN SHARE MODE" def update_lock(self): """ builder = self.get_builder() builder.truncate() """ return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' FOR UPDATE" def test_latest(self): builder = self.get_builder() builder.latest("email") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_oldest(self): builder = self.get_builder() builder.oldest("email") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def latest(self): """ builder.order_by('email', 'des') """ return "SELECT * FROM `users` ORDER BY `email` DESC" def oldest(self): """ builder.order_by('email', 'asc') """ return "SELECT * FROM `users` ORDER BY `email` ASC" ================================================ FILE: tests/mysql/builder/test_query_builder_scopes.py ================================================ import inspect import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.relationships import has_many from src.masoniteorm.scopes import SoftDeleteScope from tests.utils import MockConnectionFactory class BaseTestQueryBuilderScopes(unittest.TestCase): grammar = "mysql" def get_builder(self, table="users"): connection = 
MockConnectionFactory().make("default") return QueryBuilder( grammar=MySQLGrammar, connection_class=connection, connection="mysql", table=table, connection_details=DATABASES, ) def test_scopes(self): builder = self.get_builder().set_scope( "gender", lambda model, q: q.where("gender", "w") ) self.assertEqual( builder.gender().where("id", 1).to_sql(), "SELECT * FROM `users` WHERE `users`.`gender` = 'w' AND `users`.`id` = '1'", ) def test_global_scopes(self): builder = self.get_builder().set_global_scope( "where_not_null", lambda q: q.where_not_null("deleted_at"), action="select" ) self.assertEqual( builder.where("id", 1).to_sql(), "SELECT * FROM `users` WHERE `users`.`id` = '1' AND `users`.`deleted_at` IS NOT NULL", ) def test_global_scope_from_class(self): builder = self.get_builder().set_global_scope(SoftDeleteScope()) self.assertEqual( builder.where("id", 1).to_sql(), "SELECT * FROM `users` WHERE `users`.`id` = '1' AND `users`.`deleted_at` IS NULL", ) def test_global_scope_remove_from_class(self): builder = ( self.get_builder() .set_global_scope(SoftDeleteScope()) .remove_global_scope(SoftDeleteScope()) ) self.assertEqual( builder.where("id", 1).to_sql(), "SELECT * FROM `users` WHERE `users`.`id` = '1'", ) def test_global_scope_adds_method(self): builder = self.get_builder().set_global_scope(SoftDeleteScope()) self.assertEqual(builder.with_trashed().to_sql(), "SELECT * FROM `users`") ================================================ FILE: tests/mysql/builder/test_transactions.py ================================================ import inspect import os import unittest from src.masoniteorm.connections.ConnectionFactory import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.relationships import has_many from tests.utils import MockConnectionFactory class Articles(Model): pass class User(Model): @has_many("id", "user_id") def 
articles(self): return Articles if os.getenv("RUN_MYSQL_DATABASE", False) == "True": class TestTransactions(unittest.TestCase): pass # def get_builder(self, table="users"): # connection = ConnectionFactory().make("default") # return QueryBuilder(MySQLGrammar, connection, table=table, model=User()) # def test_can_start_transaction(self, table="users"): # builder = self.get_builder() # builder.begin() # builder.create({"name": "mike", "email": "mike@email.com"}) # builder.rollback() # self.assertFalse(builder.where("email", "mike@email.com").first()) ================================================ FILE: tests/mysql/connections/test_mysql_connection_selects.py ================================================ import os import unittest from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar class MockUser(Model): __table__ = "users" if os.getenv("RUN_MYSQL_DATABASE", False) == "True": class TestMySQLSelectConnection(unittest.TestCase): def setUp(self): self.builder = QueryBuilder(MySQLGrammar, table="users") def test_can_compile_select(self): to_sql = MockUser.where("id", 1).to_sql() sql = "SELECT * FROM `users` WHERE `users`.`id` = '1'" self.assertEqual(to_sql, sql) def test_can_get_first_record(self): user = MockUser.where("id", 1).first() self.assertEqual(user.id, 1) def test_can_find_first_record(self): user = MockUser.find(1) self.assertEqual(user.id, 1) def test_can_get_all_records(self): users = MockUser.all() self.assertGreater(len(users), 1) def test_can_get_5_records(self): users = MockUser.limit(5).get() self.assertEqual(len(users), 5) def test_can_get_1_record_with_get(self): users = MockUser.where("id", 1).limit(5).get() self.assertEqual(len(users), 1) users = MockUser.limit(5).where("id", 1).get() self.assertEqual(len(users), 1) ================================================ FILE: tests/mysql/grammar/test_mysql_delete_grammar.py 
================================================ import inspect import unittest from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar class BaseDeleteGrammarTest: """ Mixin of delete-grammar tests. Each test_X compiles a DELETE query and compares it with the SQL returned by the subclass method X (the test name without the "test_" prefix). """ def setUp(self): self.builder = QueryBuilder(MySQLGrammar, table="users") def test_can_compile_delete(self): to_sql = self.builder.delete("id", 1, query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_delete_in(self): to_sql = self.builder.delete("id", [1, 2, 3], query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_delete_with_where(self): to_sql = ( self.builder.where("age", 20) .where("profile", 1) .set_action("delete") .delete(query=True) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) class TestMySQLDeleteGrammar(BaseDeleteGrammarTest, unittest.TestCase): grammar = "mysql" def can_compile_delete(self): """ ( self.builder .delete('id', 1) .to_sql() ) """ return "DELETE FROM `users` WHERE `users`.`id` = '1'" def can_compile_delete_in(self): """ ( self.builder .delete('id', [1, 2, 3]) .to_sql() ) """ return "DELETE FROM `users` WHERE `users`.`id` IN ('1','2','3')" def can_compile_delete_with_where(self): """ ( self.builder .where('age', 20) .where('profile', 1) .set_action('delete') .delete() .to_sql() ) """ return ( "DELETE FROM `users` WHERE `users`.`age` = '20' AND `users`.`profile` = '1'" ) ================================================ FILE: tests/mysql/grammar/test_mysql_insert_grammar.py ================================================ import inspect import unittest from src.masoniteorm.query import QueryBuilder from 
src.masoniteorm.query.grammars import MySQLGrammar class BaseInsertGrammarTest: """ Mixin of insert-grammar tests. Each test_X compiles an INSERT query and compares it with the SQL returned by the subclass method X (the test name without the "test_" prefix). """ def setUp(self): self.builder = QueryBuilder(MySQLGrammar, table="users") def
test_can_compile_insert(self): to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_insert_with_keywords(self): to_sql = self.builder.create(name="Joe", query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create(self): to_sql = self.builder.bulk_create( # These keys are intentionally out of order to show column to value alignment works [ {"name": "Joe", "age": 5}, {"age": 35, "name": "Bill"}, {"name": "John", "age": 10}, ], query=True, ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_qmark(self): to_sql = self.builder.bulk_create( [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True ).to_qmark() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_multiple(self): to_sql = self.builder.bulk_create( [ {"name": "Joe", "active": "1"}, {"name": "Bill", "active": "1"}, {"name": "John", "active": "1"}, ], query=True, ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) class TestMySQLUpdateGrammar(BaseInsertGrammarTest, unittest.TestCase): """ Expected MySQL SQL for the insert / bulk-insert tests in BaseInsertGrammarTest (despite the class name, nothing here exercises UPDATE). """ grammar = "mysql" def can_compile_insert(self): """ self.builder.create({ 'name': 'Joe' }).to_sql() """ return "INSERT INTO `users` (`users`.`name`) VALUES ('Joe')" def can_compile_insert_with_keywords(self): """ self.builder.create(name="Joe").to_sql() """ return "INSERT INTO `users` (`users`.`name`) VALUES ('Joe')" def can_compile_bulk_create(self): """ 
self.builder.bulk_create([{"name": "Joe", "age": 5}, {"age": 35, "name": "Bill"}, {"name": "John", "age": 10}]).to_sql() """ return """INSERT INTO `users` (`age`, `name`) VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John')"""
def can_compile_bulk_create_multiple(self): """ self.builder.bulk_create([{"name": "Joe", "active": "1"}, {"name": "Bill", "active": "1"}, {"name": "John", "active": "1"}]).to_sql() """ return """INSERT INTO `users` (`active`, `name`) VALUES ('1', 'Joe'), ('1', 'Bill'), ('1', 'John')""" def can_compile_bulk_create_qmark(self): """ self.builder.bulk_create([{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}]).to_qmark() """ return """INSERT INTO `users` (`name`) VALUES ('?'), ('?'), ('?')""" ================================================ FILE: tests/mysql/grammar/test_mysql_qmark.py ================================================ import inspect import unittest from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar class BaseQMarkTest: """ Mixin of qmark tests. Each test_X compiles a query with to_qmark() and compares it with the SQL in the (sql, bindings) pair returned by the subclass method X; most tests also compare the builder's bindings with that pair. """ def setUp(self): self.builder = QueryBuilder(grammar=MySQLGrammar, table="users") def test_can_compile_select(self): mark = self.builder.select("username").where("name", "Joe") sql, bindings = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_delete(self): mark = self.builder.where("name", "Joe").delete(query=True) sql, bindings = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_update(self): mark = self.builder.update({"name": "Bob"}, dry=True).where("name", "Joe") sql, bindings = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_where_in(self): mark = 
self.builder.where_in("id", [1, 2, 3]) sql, bindings = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_where_not_null(self): mark = self.builder.where_not_null("id") sql, bindings = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )()
self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, []) def test_can_compile_where_with_falsy_values(self): mark = self.builder.where("name", 0) sql, bindings = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_where_with_true_value(self): mark = self.builder.where("is_admin", True) sql, bindings = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_where_with_false_value(self): mark = self.builder.where("is_admin", False) sql, bindings = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) def test_can_compile_sub_group_bindings(self): mark = self.builder.where( lambda query: ( query.where("challenger", 1) .or_where("proposer", 1) .or_where("referee", 1) ) ) sql, bindings = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(mark.to_qmark(), sql) self.assertEqual(mark._bindings, bindings) class TestMySQLQmark(BaseQMarkTest, unittest.TestCase): """ Expected MySQL (qmark SQL, bindings) pairs for the BaseQMarkTest tests. """ def can_compile_select(self): """ self.builder.select('username').where('name', 'Joe') """ return ( "SELECT `users`.`username` FROM `users` WHERE `users`.`name` = '?'", ["Joe"], ) def can_compile_delete(self): """ self.builder.where('name', 'Joe').delete() """ return "DELETE FROM `users` WHERE `users`.`name` = '?'", ["Joe"] def can_compile_update(self): """ self.builder.update({ 'name': 'Bob' }).where('name', 'Joe') """ return ( "UPDATE `users` SET `users`.`name` = '?'
WHERE `users`.`name` = '?'", ["Bob", "Joe"], ) def can_compile_where_in(self): """ self.builder.where_in('id', [1,2,3]).to_qmark() """ return ( "SELECT * FROM `users` WHERE `users`.`id` IN ('?', '?', '?')", [1, 2, 3], ) def can_compile_where_not_null(self): """ self.builder.where_not_null("id").to_qmark() """ return ("SELECT * FROM `users` WHERE `users`.`id` IS NOT NULL", ()) def can_compile_where_with_falsy_values(self): """ self.builder.where("name", 0).to_qmark() """ return ("SELECT * FROM `users` WHERE `users`.`name` = '?'", [0]) def can_compile_where_with_true_value(self): """ self.builder.where("is_admin", True).to_qmark() """ return ("SELECT * FROM `users` WHERE `users`.`is_admin` = '1'", []) def can_compile_where_with_false_value(self): """ self.builder.where("is_admin", False).to_qmark() """ return ("SELECT * FROM `users` WHERE `users`.`is_admin` = '0'", []) def can_compile_sub_group_bindings(self): """ self.builder.where(lambda query: query.where("challenger", 1).or_where("proposer", 1).or_where("referee", 1)).to_qmark() """ return ( "SELECT * FROM `users` WHERE (`users`.`challenger` = '?' OR `users`.`proposer` = '?'
OR `users`.`referee` = '?')", [1, 1, 1], ) ================================================ FILE: tests/mysql/grammar/test_mysql_select_grammar.py ================================================ import inspect import unittest from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.testing import BaseTestCaseSelectGrammar class TestMySQLGrammar(BaseTestCaseSelectGrammar, unittest.TestCase): """ Expected MySQL SQL strings for the select-grammar tests in BaseTestCaseSelectGrammar (presumably looked up by test name without the "test_" prefix, as in the other grammar tests; confirm against BaseTestCaseSelectGrammar). """ grammar = MySQLGrammar def can_compile_select(self): """ self.builder.to_sql() """ return "SELECT * FROM `users`" def can_compile_with_columns(self): """ self.builder.select('username', 'password').to_sql() """ return "SELECT `users`.`username`, `users`.`password` FROM `users`" def can_compile_order_by_and_first(self): """ self.builder.order_by('id', 'asc').first() """ return """SELECT * FROM `users` ORDER BY `id` ASC LIMIT 1""" def can_compile_with_where(self): """ self.builder.select('username', 'password').where('id', 1).to_sql() """ return "SELECT `users`.`username`, `users`.`password` FROM `users` WHERE `users`.`id` = '1'" def can_compile_with_several_where(self): """ self.builder.select('username', 'password').where('id', 1).where('username', 'joe').to_sql() """ return "SELECT `users`.`username`, `users`.`password` FROM `users` WHERE `users`.`id` = '1' AND `users`.`username` = 'joe'" def can_compile_with_several_where_and_limit(self): """ self.builder.select('username', 'password').where('id', 1).where('username', 'joe').limit(10).to_sql() """ return "SELECT `users`.`username`, `users`.`password` FROM `users` WHERE `users`.`id` = '1' AND `users`.`username` = 'joe' LIMIT 10" def can_compile_with_sum(self): """ self.builder.sum('age').to_sql() """ return "SELECT SUM(`users`.`age`) AS age FROM `users`" def can_compile_with_max(self): """ self.builder.max('age').to_sql() 
""" return "SELECT MAX(`users`.`age`) AS age FROM `users`" def can_compile_with_max_and_columns(self): """ self.builder.select('username').max('age').to_sql() """ return "SELECT
`users`.`username`, MAX(`users`.`age`) AS age FROM `users`" def can_compile_with_max_and_columns_different_order(self): """ self.builder.max('age').select('username').to_sql() """ return "SELECT `users`.`username`, MAX(`users`.`age`) AS age FROM `users`" def can_compile_with_order_by(self): """ self.builder.select('username').order_by('age', 'desc').to_sql() """ return "SELECT `users`.`username` FROM `users` ORDER BY `age` DESC" def can_compile_with_multiple_order_by(self): """ self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql() """ return "SELECT `users`.`username` FROM `users` ORDER BY `age` DESC, `name` ASC" def can_compile_with_group_by(self): """ self.builder.select('username').group_by('age').to_sql() """ return "SELECT `users`.`username` FROM `users` GROUP BY `users`.`age`" def can_compile_where_in(self): """ self.builder.select('username').where_in('age', [1,2,3]).to_sql() """ return "SELECT `users`.`username` FROM `users` WHERE `users`.`age` IN ('1','2','3')" def can_compile_where_in_empty(self): """ self.builder.where_in('age', []).to_sql() """ return """SELECT * FROM `users` WHERE 0 = 1""" def can_compile_where_not_in(self): """ self.builder.select('username').where_not_in('age', [1,2,3]).to_sql() """ return "SELECT `users`.`username` FROM `users` WHERE `users`.`age` NOT IN ('1','2','3')" def can_compile_where_null(self): """ self.builder.select('username').where_null('age').to_sql() """ return "SELECT `users`.`username` FROM `users` WHERE `users`.`age` IS NULL" def can_compile_where_not_null(self): """ self.builder.select('username').where_not_null('age').to_sql() """ return "SELECT `users`.`username` FROM `users` WHERE `users`.`age` IS NOT NULL" def can_compile_where_raw(self): """ self.builder.where_raw("`age` = '18'").to_sql() """ return "SELECT * FROM `users` WHERE `users`.`age` = '18'" def can_compile_where_raw_and_where_with_multiple_bindings(self): """ self.builder.where_raw("`age` = '?'
AND `is_admin` = '?'", [18, True]).where("email", "test@example.com") """ return "SELECT * FROM `users` WHERE `age` = '?' AND `is_admin` = '?' AND `users`.`email` = '?'" def can_compile_having_raw(self): """ self.builder.select_raw("COUNT(*) as counts").having_raw("counts > 18").to_sql() """ return "SELECT COUNT(*) as counts FROM `users` HAVING counts > 18" def can_compile_select_raw(self): """ self.builder.select_raw("COUNT(*)").to_sql() """ return "SELECT COUNT(*) FROM `users`" def can_compile_limit_and_offset(self): """ self.builder.limit(10).offset(10).to_sql() """ return "SELECT * FROM `users` LIMIT 10 OFFSET 10" def can_compile_select_raw_with_select(self): """ self.builder.select('id').select_raw("COUNT(*)").to_sql() """ return "SELECT `users`.`id`, COUNT(*) FROM `users`" def can_compile_count(self): """ self.builder.count().to_sql() """ return "SELECT COUNT(*) AS m_count_reserved FROM `users`" def can_compile_count_column(self): """ self.builder.count('money').to_sql() """ return "SELECT COUNT(`users`.`money`) AS money FROM `users`" def can_compile_where_column(self): """ self.builder.where_column('name', 'email').to_sql() """ return "SELECT * FROM `users` WHERE `users`.`name` = `users`.`email`" def can_compile_or_where(self): """ self.builder.where('name', 2).or_where('name', 3).to_sql() """ return ( "SELECT * FROM `users` WHERE `users`.`name` = '2' OR `users`.`name` = '3'" ) def can_grouped_where(self): """ self.builder.where(lambda query: query.where('age', 2).where('name', 'Joe')).to_sql() """ return "SELECT * FROM `users` WHERE (`users`.`age` = '2' AND `users`.`name` = 'Joe')" def can_compile_sub_select(self): """ self.builder.where_in('name', QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age') ).to_sql() """ return "SELECT * FROM `users` WHERE `users`.`name` IN (SELECT `users`.`age` FROM `users`)" def can_compile_sub_select_where(self): """ self.builder.where_in('age', QueryBuilder(GrammarFactory.make(self.grammar),
table='users').select('age').where('age', 2).where('name', 'Joe') ).to_sql() """ return "SELECT * FROM `users` WHERE `users`.`age` IN (SELECT `users`.`age` FROM `users` WHERE `users`.`age` = '2' AND `users`.`name` = 'Joe')" def can_compile_sub_select_value(self): """ self.builder.where('name', self.builder.new().sum('age') ).to_sql() """ return "SELECT * FROM `users` WHERE `users`.`name` = (SELECT SUM(`users`.`age`) AS age FROM `users`)" def can_compile_complex_sub_select(self): """ self.builder.where_in('name', (QueryBuilder(GrammarFactory.make(self.grammar), table='users') .select('age').where_in('email', QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('email') )) ).to_sql() """ return "SELECT * FROM `users` WHERE `users`.`name` IN (SELECT `users`.`age` FROM `users` WHERE `users`.`email` IN (SELECT `users`.`email` FROM `users`))" def can_compile_exists(self): """ self.builder.select('age').where_exists( self.builder.new().select('username').where('age', 12) ).to_sql() """ return "SELECT `users`.`age` FROM `users` WHERE EXISTS (SELECT `users`.`username` FROM `users` WHERE `users`.`age` = '12')" def can_compile_not_exists(self): """ self.builder.select('age').where_not_exists( self.builder.new().select('username').where('age', 12) ).to_sql() """ return "SELECT `users`.`age` FROM `users` WHERE NOT EXISTS (SELECT `users`.`username` FROM `users` WHERE `users`.`age` = '12')" def can_compile_having(self): """ builder.sum('age').group_by('age').having('age').to_sql() """ return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age`" def can_compile_having_order(self): """ builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql() """ return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` ORDER `users`.`age` DESC" def can_compile_having_with_expression(self): """ builder.sum('age').group_by('age').having('age', 10).to_sql() """ return "SELECT
SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` = '10'" def can_compile_having_with_greater_than_expression(self): """ builder.sum('age').group_by('age').having('age', '>', 10).to_sql() """ return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` > '10'" def can_compile_join(self): """ builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql() """ return "SELECT * FROM `users` INNER JOIN `contacts` ON `users`.`id` = `contacts`.`user_id`" def can_compile_left_join(self): """ builder.left_join('contacts', 'users.id', '=', 'contacts.user_id').to_sql() """ return "SELECT * FROM `users` LEFT JOIN `contacts` ON `users`.`id` = `contacts`.`user_id`" def can_compile_multiple_join(self): """ builder.join('contacts', 'users.id', '=', 'contacts.user_id').join('posts', 'comments.post_id', '=', 'posts.id').to_sql() """ return "SELECT * FROM `users` INNER JOIN `contacts` ON `users`.`id` = `contacts`.`user_id` INNER JOIN `posts` ON `comments`.`post_id` = `posts`.`id`" def can_compile_between(self): """ builder.between('age', 18, 21).to_sql() """ return "SELECT * FROM `users` WHERE `users`.`age` BETWEEN '18' AND '21'" def can_compile_not_between(self): """ builder.not_between('age', 18, 21).to_sql() """ return "SELECT * FROM `users` WHERE `users`.`age` NOT BETWEEN '18' AND '21'" def test_can_compile_where_raw(self): to_sql = self.builder.where_raw("`age` = '18'").to_sql() self.assertEqual(to_sql, "SELECT * FROM `users` WHERE `age` = '18'") def test_can_compile_having_raw(self): to_sql = ( self.builder.select_raw("COUNT(*) as counts") .having_raw("counts > 10") .to_sql() ) self.assertEqual( to_sql, "SELECT COUNT(*) as counts FROM `users` HAVING counts > 10" ) def test_can_compile_having_raw_order(self): to_sql = ( self.builder.select_raw("COUNT(*) as counts") .having_raw("counts > 10") .order_by_raw("counts DESC") .to_sql() ) self.assertEqual( to_sql, "SELECT COUNT(*) as counts FROM `users` HAVING counts > 10 ORDER BY counts 
DESC", ) def
test_can_compile_select_raw(self): to_sql = self.builder.select_raw("COUNT(*)").to_sql() self.assertEqual(to_sql, "SELECT COUNT(*) FROM `users`") def test_can_compile_select_raw_with_select(self): to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql() self.assertEqual(to_sql, "SELECT `users`.`id`, COUNT(*) FROM `users`") def can_compile_first_or_fail(self): """ builder = self.get_builder() builder.where("is_admin", "=", True).first_or_fail() """ return """SELECT * FROM `users` WHERE `users`.`is_admin` = '1' LIMIT 1""" def where_not_like(self): """ builder = self.get_builder() builder.where("age", "not like", "%name%").to_sql() """ return "SELECT * FROM `users` WHERE `users`.`age` NOT LIKE '%name%'" def where_regexp(self): """ builder = self.get_builder() builder.where("age", "regexp", "Joe").to_sql() """ return "SELECT * FROM `users` WHERE `users`.`age` REGEXP 'Joe'" def where_not_regexp(self): """ builder = self.get_builder() builder.where("age", "not regexp", "Joe").to_sql() """ return "SELECT * FROM `users` WHERE `users`.`age` NOT REGEXP 'Joe'" def where_like(self): """ builder = self.get_builder() builder.where("age", "like", "%name%").to_sql() """ return "SELECT * FROM `users` WHERE `users`.`age` LIKE '%name%'" def can_compile_join_clause(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on("bgt.fund", "=", "rg.fund") .on("bgt.dept", "=", "rg.dept") .on("bgt.acct", "=", "rg.acct") .on("bgt.sub", "=", "rg.sub") ) builder.join(clause).to_sql() """ return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `bgt`.`fund` = `rg`.`fund` AND `bgt`.`dept` = `rg`.`dept` AND `bgt`.`acct` = `rg`.`acct` AND `bgt`.`sub` = `rg`.`sub`" def can_compile_join_clause_with_value(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on_value("bgt.active", "=", "1") .or_on_value("bgt.acct", "=", "1234") ) builder.join(clause).to_sql() """ return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON
`bgt`.`active` = '1' OR `bgt`.`acct` = '1234'" def can_compile_join_clause_with_null(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on_null("bgt.acct") .or_on_null("bgt.dept") .on_value("rg.abc", 10) ) builder.join(clause).to_sql() """ return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `acct` IS NULL OR `dept` IS NULL AND `rg`.`abc` = '10'" def can_compile_join_clause_with_not_null(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on_not_null("bgt.acct") .or_on_not_null("bgt.dept") .on_value("rg.abc", 10) ) builder.join(clause).to_sql() """ return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `acct` IS NOT NULL OR `dept` IS NOT NULL AND `rg`.`abc` = '10'" def can_compile_join_clause_with_lambda(self): """ builder = self.get_builder() builder.join( "report_groups as rg", lambda clause: ( clause.on("bgt.fund", "=", "rg.fund") .on_null("bgt") ), ).to_sql() """ return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `bgt`.`fund` = `rg`.`fund` AND `bgt` IS NULL" def can_compile_left_join_clause_with_lambda(self): """ builder = self.get_builder() builder.left_join( "report_groups as rg", lambda clause: ( clause.on("bgt.fund", "=", "rg.fund") .or_on_null("bgt") ), ).to_sql() """ return "SELECT * FROM `users` LEFT JOIN `report_groups` AS `rg` ON `bgt`.`fund` = `rg`.`fund` OR `bgt` IS NULL" def can_compile_right_join_clause_with_lambda(self): """ builder = self.get_builder() builder.right_join( "report_groups as rg", lambda clause: ( clause.on("bgt.fund", "=", "rg.fund") .or_on_null("bgt") ), ).to_sql() """ return "SELECT * FROM `users` RIGHT JOIN `report_groups` AS `rg` ON `bgt`.`fund` = `rg`.`fund` OR `bgt` IS NULL" def shared_lock(self): """ builder = self.get_builder() builder.where("votes", ">=", 100).shared_lock().to_sql() """ return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' LOCK IN SHARE MODE" def update_lock(self): """ builder =
self.get_builder() builder.where("votes", ">=", 100).lock_for_update().to_sql() """ return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' FOR UPDATE" def can_user_where_raw_and_where(self): """ builder.where_raw("`age` = '18'").where("name", "=", "James").to_sql() """ return "SELECT * FROM `users` WHERE age = '18' AND `users`.`name` = 'James'" def where_exists_with_lambda(self): return """SELECT * FROM `users` WHERE EXISTS (SELECT * FROM `users` WHERE `users`.`age` = '1')""" def where_not_exists_with_lambda(self): return """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `users` WHERE `users`.`age` = '1')""" def where_date(self): return ( """SELECT * FROM `users` WHERE DATE(`users`.`created_at`) = '2022-06-01'""" ) def or_where_null(self): return """SELECT * FROM `users` WHERE `users`.`column1` IS NULL OR `users`.`column2` IS NULL""" def select_distinct(self): return """SELECT DISTINCT `users`.`group` FROM `users`""" ================================================ FILE: tests/mysql/grammar/test_mysql_update_grammar.py ================================================ import inspect import unittest from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.expressions import Raw class BaseTestCaseUpdateGrammar: def setUp(self): self.builder = QueryBuilder(self.grammar, table="users") def test_can_compile_update(self): to_sql = ( self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_multiple_update(self): to_sql = self.builder.update( {"name": "Joe", "email": "user@email.com"}, dry=True ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_update_with_multiple_where(self): to_sql = ( self.builder.where("name", "bob") .where("age", 20) .update({"name":
"Joe"}, dry=True) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) # def test_can_compile_increment(self): # to_sql = self.builder.increment("age") # sql = getattr( # self, inspect.currentframe().f_code.co_name.replace("test_", "") # )() # self.assertEqual(to_sql, sql) # def test_can_compile_decrement(self): # to_sql = self.builder.decrement("age", 20).to_sql() # sql = getattr( # self, inspect.currentframe().f_code.co_name.replace("test_", "") # )() # self.assertEqual(to_sql, sql) def test_raw_expression(self): to_sql = self.builder.update({"name": Raw("`username`")}, dry=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) class TestMySQLUpdateGrammar(BaseTestCaseUpdateGrammar, unittest.TestCase): grammar = MySQLGrammar def can_compile_update(self): """ builder.where('name', 'bob').update({ 'name': 'Joe' }).to_sql() """ return "UPDATE `users` SET `users`.`name` = 'Joe' WHERE `users`.`name` = 'bob'" def raw_expression(self): """ builder.where('name', 'bob').update({ 'name': 'Joe' }).to_sql() """ return "UPDATE `users` SET `users`.`name` = `username`" def can_compile_multiple_update(self): """ self.builder.update({"name": "Joe", "email": "user@email.com"}, dry=True).to_sql() """ return "UPDATE `users` SET `users`.`name` = 'Joe', `users`.`email` = 'user@email.com'" def can_compile_update_with_multiple_where(self): """ builder.where('name', 'bob').where('age', 20).update({ 'name': 'Joe' }).to_sql() """ return "UPDATE `users` SET `users`.`name` = 'Joe' WHERE `users`.`name` = 'bob' AND `users`.`age` = '20'" def can_compile_increment(self): """ builder.increment('age').to_sql() """ return "UPDATE `users` SET `users`.`age` = `users`.`age` + '1'" def can_compile_decrement(self): """ builder.decrement('age', 20).to_sql() """ return "UPDATE `users` SET `users`.`age` = `users`.`age` - '20'" 
================================================ FILE: tests/mysql/model/test_accessors_and_mutators.py ================================================ import datetime import json import os import unittest import pendulum from src.masoniteorm.collection import Collection from src.masoniteorm.models import Model from src.masoniteorm.query.grammars import MSSQLGrammar from tests.User import User class User(Model): __casts__ = {"is_admin": "bool"} def get_name_attribute(self): return f"Hello, {self.get_raw_attribute('name')}" def set_name_attribute(self, attribute): return str(attribute).upper() class SetUser(Model): __casts__ = {"is_admin": "bool"} def set_name_attribute(self, attribute): return str(attribute).upper() class TestAccessor(unittest.TestCase): def test_can_get_accessor(self): user = User.hydrate( {"name": "joe", "email": "joe@masoniteproject.com", "is_admin": 1} ) self.assertEqual(user.email, "joe@masoniteproject.com") self.assertEqual(user.name, "Hello, joe") self.assertTrue(user.is_admin is True, f"{user.is_admin} is not True") def test_mutator(self): user = SetUser.hydrate({"email": "joe@masoniteproject.com", "is_admin": 1}) user.name = "joe" self.assertEqual(user.name, "JOE") ================================================ FILE: tests/mysql/model/test_model.py ================================================ import datetime import json import os import unittest import pendulum from src.masoniteorm.collection import Collection from src.masoniteorm.exceptions import ModelNotFound from src.masoniteorm.models import Model from tests.User import User class ProfileFillable(Model): __fillable__ = ["name"] __table__ = "profiles" __timestamps__ = None class ProfileFillTimeStamped(Model): __fillable__ = ["*"] __table__ = "profiles" class ProfileFillAsterisk(Model): __fillable__ = ["*"] __table__ = "profiles" __timestamps__ = None class ProfileGuarded(Model): __guarded__ = ["email"] __table__ = "profiles" __timestamps__ = None class 
ProfileGuardedAsterisk(Model): __guarded__ = ["*"] __table__ = "profiles" __timestamps__ = None class ProfileSerialize(Model): __fillable__ = ["*"] __table__ = "profiles" __hidden__ = ["password"] class ProfileSerializeWithVisible(Model): __fillable__ = ["*"] __table__ = "profiles" __visible__ = ["name", "email"] class ProfileSerializeWithVisibleAndHidden(Model): __fillable__ = ["*"] __table__ = "profiles" __visible__ = ["name", "email"] __hidden__ = ["password"] class Profile(Model): pass class Company(Model): pass class User(Model): @property def meta(self): return {"is_subscribed": True} class ProductNames(Model): pass class TestModel(unittest.TestCase): def test_create_can_use_fillable(self): sql = ProfileFillable.create( {"name": "Joe", "email": "user@example.com"}, query=True ).to_sql() self.assertEqual( sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')" ) def test_create_can_use_fillable_asterisk(self): sql = ProfileFillAsterisk.create( {"name": "Joe", "email": "user@example.com"}, query=True ).to_sql() self.assertEqual( sql, "INSERT INTO `profiles` (`profiles`.`name`, `profiles`.`email`) VALUES ('Joe', 'user@example.com')", ) def test_create_can_use_guarded(self): sql = ProfileGuarded.create( {"name": "Joe", "email": "user@example.com"}, query=True ).to_sql() self.assertEqual( sql, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')" ) def test_create_can_use_guarded_asterisk(self): sql = ProfileGuardedAsterisk.create( {"name": "Joe", "email": "user@example.com"}, query=True ).to_sql() # An asterisk guarded attribute excludes all fields from mass-assignment. # This would raise a DB error if there are any required fields. 
self.assertEqual(sql, "INSERT INTO `profiles` (*) VALUES ()") def test_bulk_create_can_use_fillable(self): query_builder = ProfileFillable.bulk_create( [ {"name": "Joe", "email": "user@example.com"}, {"name": "Joe II", "email": "userII@example.com"}, ], query=True, ) self.assertEqual( query_builder.to_sql(), "INSERT INTO `profiles` (`name`) VALUES ('Joe'), ('Joe II')", ) def test_bulk_create_can_use_fillable_asterisk(self): query_builder = ProfileFillAsterisk.bulk_create( [ {"name": "Joe", "email": "user@example.com"}, {"name": "Joe II", "email": "userII@example.com"}, ], query=True, ) self.assertEqual( query_builder.to_sql(), "INSERT INTO `profiles` (`email`, `name`) VALUES ('user@example.com', 'Joe'), ('userII@example.com', 'Joe II')", ) def test_bulk_create_can_use_guarded(self): query_builder = ProfileGuarded.bulk_create( [ {"name": "Joe", "email": "user@example.com"}, {"name": "Joe II", "email": "userII@example.com"}, ], query=True, ) self.assertEqual( query_builder.to_sql(), "INSERT INTO `profiles` (`name`) VALUES ('Joe'), ('Joe II')", ) def test_bulk_create_can_use_guarded_asterisk(self): query_builder = ProfileGuardedAsterisk.bulk_create( [ {"name": "Joe", "email": "user@example.com"}, {"name": "Joe II", "email": "userII@example.com"}, ], query=True, ) # An asterisk guarded attribute excludes all fields from mass-assignment. # This would obviously raise an invalid SQL syntax error. # TODO: Raise a clearer error? 
self.assertEqual( query_builder.to_sql(), "INSERT INTO `profiles` () VALUES (), ()" ) def test_update_can_use_fillable(self): query_builder = ProfileFillable().update( {"name": "Joe", "email": "user@example.com"}, dry=True ) self.assertEqual( query_builder.to_sql(), "UPDATE `profiles` SET `profiles`.`name` = 'Joe'" ) def test_update_can_use_fillable_asterisk(self): query_builder = ProfileFillAsterisk().update( {"name": "Joe", "email": "user@example.com"}, dry=True ) self.assertEqual( query_builder.to_sql(), "UPDATE `profiles` SET `profiles`.`name` = 'Joe', `profiles`.`email` = 'user@example.com'", ) def test_update_can_use_guarded(self): query_builder = ProfileGuarded().update( {"name": "Joe", "email": "user@example.com"}, dry=True ) self.assertEqual( query_builder.to_sql(), "UPDATE `profiles` SET `profiles`.`name` = 'Joe'" ) def test_update_can_use_guarded_asterisk(self): profile = ProfileGuardedAsterisk() initial_sql = profile.get_builder().to_sql() query_builder = profile.update( {"name": "Joe", "email": "user@example.com"}, dry=True ) # An asterisk guarded attribute excludes all fields from mass-assignment. # The query builder's sql should not have been altered in any way. 
self.assertEqual(query_builder.to_sql(), initial_sql) def test_table_name(self): table_name = Profile.get_table_name() self.assertEqual(table_name, "profiles") table_name = Company.get_table_name() self.assertEqual(table_name, "companies") table_name = ProductNames.get_table_name() self.assertEqual(table_name, "product_names") def test_serialize(self): profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1}) self.assertEqual(profile.serialize(), {"name": "Joe", "id": 1}) def test_json(self): profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1}) self.assertEqual(profile.to_json(), '{"name": "Joe", "id": 1}') def test_serialize_with_hidden(self): profile = ProfileSerialize.hydrate( {"name": "Joe", "id": 1, "password": "secret"} ) self.assertTrue(profile.serialize().get("name")) self.assertTrue(profile.serialize().get("id")) self.assertFalse(profile.serialize().get("password")) def test_serialize_with_visible(self): profile = ProfileSerializeWithVisible.hydrate( {"name": "Joe", "id": 1, "password": "secret", "email": "joe@masonite.com"} ) self.assertTrue( {"name": "Joe", "email": "joe@masonite.com"}, profile.serialize() ) def test_serialize_with_visible_and_hidden_raise_error(self): profile = ProfileSerializeWithVisibleAndHidden.hydrate( {"name": "Joe", "id": 1, "password": "secret", "email": "joe@masonite.com"} ) with self.assertRaises(AttributeError): profile.serialize() def test_serialize_with_on_the_fly_appends(self): user = User.hydrate({"name": "Joe", "id": 1}) user.set_appends(["meta"]) serialized = user.serialize() self.assertEqual(serialized["id"], 1) self.assertEqual(serialized["name"], "Joe") self.assertEqual(serialized["meta"]["is_subscribed"], True) def test_serialize_with_model_appends(self): User.__appends__ = ["meta"] user = User.hydrate({"name": "Joe", "id": 1}) serialized = user.serialize() self.assertEqual(serialized["id"], 1) self.assertEqual(serialized["name"], "Joe") self.assertEqual(serialized["meta"]["is_subscribed"], True) def 
test_serialize_with_date(self): user = User.hydrate({"name": "Joe", "created_at": pendulum.now()}) self.assertTrue(json.dumps(user.serialize())) def test_set_as_date(self): user = User.hydrate( { "name": "Joe", "created_at": pendulum.now().add(days=10).to_datetime_string(), } ) self.assertTrue(user.created_at) self.assertTrue(user.created_at.is_future()) def test_access_as_date(self): user = User.hydrate( { "name": "Joe", "created_at": datetime.datetime.now() + datetime.timedelta(days=1), } ) self.assertTrue(user.created_at) self.assertTrue(user.created_at.is_future()) def test_hydrate_with_none(self): profile = ProfileFillAsterisk.hydrate(None) self.assertEqual(profile, None) def test_serialize_with_dirty_attribute(self): profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1}) profile.age = 18 self.assertEqual(profile.serialize(), {"age": 18, "name": "Joe", "id": 1}) def test_attribute_check_with_hasattr(self): self.assertFalse(hasattr(Profile(), "__password__")) if os.getenv("RUN_MYSQL_DATABASE", "false").lower() == "true": class MysqlTestModel(unittest.TestCase): # TODO: these tests aren't getting run in CI... is that intentional? 
def test_can_find_first(self): profile = User.find(1) def test_can_touch(self): profile = ProfileFillTimeStamped.hydrate({"name": "Joe", "id": 1}) sql = profile.touch("now", query=True).to_sql() self.assertEqual( sql, "UPDATE `profiles` SET `profiles`.`updated_at` = 'now' WHERE `profiles`.`id` = '1'", ) def test_find_or_fail_raise_an_exception_if_not_exists(self): with self.assertRaises(ModelNotFound): User.find(100) def test_returns_correct_data_type(self): self.assertIsInstance(User.all(), Collection) # self.assertIsInstance(User.first(), User) # self.assertIsInstance(User.first(), User) ================================================ FILE: tests/mysql/relationships/test_belongs_to_many.py ================================================ import unittest # from src.masoniteorm import query from src.masoniteorm.models import Model from src.masoniteorm.relationships import ( has_one, belongs_to_many, has_one_through, has_many, ) from dotenv import load_dotenv load_dotenv(".env") class User(Model): @has_one def profile(self): return Profile class Profile(Model): pass class Permission(Model): @belongs_to_many("permission_id", "role_id", "id", "id") def role(self): return Role class PermissionSelect(Model): __table__ = "permissions" __selects__ = ["permission_id"] @belongs_to_many("permission_id", "role_id", "id", "id") def role(self): return Role class Role(Model): @belongs_to_many("role_id", "permission_id", "id", "id") def permissions(self): return Permission class MySQLRelationships(unittest.TestCase): maxDiff = None def test_belongs_to_many(self): sql = Permission.where_has( "role", lambda query: (query.where("slug", "users")) ).to_sql() self.assertEqual( sql, """SELECT * FROM `permissions` WHERE EXISTS (SELECT * FROM `roles` INNER JOIN `permission_role` ON `roles`.`id` = `permission_role`.`role_id` WHERE `permission_role`.`permission_id` = `permissions`.`id` AND `roles`.`id` IN (SELECT `roles`.`id` FROM `roles` WHERE `roles`.`slug` = 'users'))""", ) def 
test_belongs_to_many_has(self): sql = Role.has("permissions").to_sql() self.assertEqual( sql, """SELECT * FROM `roles` WHERE EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id`)""", ) def test_belongs_to_many_or_has(self): sql = Role.where("name", "role_name").or_has("permissions").to_sql() self.assertEqual( sql, """SELECT * FROM `roles` WHERE `roles`.`name` = 'role_name' OR EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id`)""", ) def test_belongs_to_many_or_where_has(self): sql = ( Role.where("name", "role_name") .or_where_has("permissions", lambda q: q.where("permission_id", 1)) .to_sql() ) self.assertEqual( sql, """SELECT * FROM `roles` WHERE `roles`.`name` = 'role_name' OR EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id` AND `permissions`.`id` IN (SELECT `permissions`.`id` FROM `permissions` WHERE `permissions`.`permission_id` = '1'))""", ) def test_belongs_to_many_or_doesnt_have(self): sql = Role.where("name", "role_name").or_doesnt_have("permissions").to_sql() self.assertEqual( sql, """SELECT * FROM `roles` WHERE `roles`.`name` = 'role_name' OR NOT EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id`)""", ) def test_where_doesnt_have(self): sql = ( Role.where("name", "role_name") .where_doesnt_have( "permissions", lambda q: q.where("name", "Creates Users") ) .to_sql() ) self.assertEqual( sql, """SELECT * FROM `roles` WHERE `roles`.`name` = 'role_name' AND NOT EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` 
WHERE `permission_role`.`role_id` = `roles`.`id` AND `permissions`.`id` IN (SELECT `permissions`.`id` FROM `permissions` WHERE `permissions`.`name` = 'Creates Users'))""", ) def test_or_where_doesnt_have(self): sql = ( Role.where("name", "role_name") .or_where_doesnt_have( "permissions", lambda q: q.where("name", "Creates Users") ) .to_sql() ) self.assertEqual( sql, """SELECT * FROM `roles` WHERE `roles`.`name` = 'role_name' OR NOT EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id` AND `permissions`.`id` IN (SELECT `permissions`.`id` FROM `permissions` WHERE `permissions`.`name` = 'Creates Users'))""", ) def test_belongs_to_many_where_has(self): sql = Role.where_has( "permissions", lambda q: q.where("name", "Creates Users") ).to_sql() self.assertEqual( sql, """SELECT * FROM `roles` WHERE EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id` AND `permissions`.`id` IN (SELECT `permissions`.`id` FROM `permissions` WHERE `permissions`.`name` = 'Creates Users'))""", ) def test_belongs_to_many_relate_method(self): permission = Permission.hydrate({"id": 1, "name": "Create Users"}) sql = permission.related("role").to_sql() self.assertEqual( sql, """SELECT `roles`.*, `permission_role`.`permission_id` AS permission_role_id, `permission_role`.`role_id` AS m_reserved2, `permission_role`.`id` AS m_reserved3 FROM `permissions` INNER JOIN `permission_role` ON `permission_role`.`permission_id` = `permissions`.`id` INNER JOIN `roles` ON `permission_role`.`role_id` = `roles`.`id`""", ) def test_belongs_to_many_relate_method_reversed(self): role = Role.hydrate({"id": 1, "name": "Create Users"}) sql = role.related("permissions").to_sql() self.assertEqual( sql, """SELECT `permissions`.*, `permission_role`.`role_id` AS permission_role_id, 
`permission_role`.`permission_id` AS m_reserved2, `permission_role`.`id` AS m_reserved3 FROM `roles` INNER JOIN `permission_role` ON `permission_role`.`role_id` = `roles`.`id` INNER JOIN `permissions` ON `permission_role`.`permission_id` = `permissions`.`id`""", ) def test_belongs_to_many_joins(self): sql = Role.joins("permissions").to_sql() self.assertEqual( sql, """SELECT `roles`.*, `permission_role`.`role_id` AS permission_role_id, `permission_role`.`permission_id` AS m_reserved2, `permission_role`.`id` AS m_reserved3 FROM `roles` INNER JOIN `permission_role` ON `permission_role`.`role_id` = `roles`.`id` INNER JOIN `permissions` ON `permission_role`.`id` = `permissions`.`id`""", ) def test_with_count(self): sql = Permission.with_count("role").to_sql() self.assertEqual( sql, """SELECT `permissions`.*, (SELECT COUNT(*) AS m_count_reserved FROM `permission_role` WHERE `permissions`.`id` = `permission_role`.`permission_id`) AS roles_count FROM `permissions`""", ) def test_with_count_with_selects(self): sql = PermissionSelect.with_count("role").to_sql() self.assertEqual( sql, """SELECT `permissions`.`permission_id`, (SELECT COUNT(*) AS m_count_reserved FROM `permission_role` WHERE `permissions`.`id` = `permission_role`.`permission_id`) AS roles_count FROM `permissions`""", ) ================================================ FILE: tests/mysql/relationships/test_has_many_through.py ================================================ import unittest from src.masoniteorm.models import Model from src.masoniteorm.relationships import ( has_many_through, ) from dotenv import load_dotenv load_dotenv(".env") class InboundShipment(Model): @has_many_through("port_id", "country_id", "from_port_id", "country_id") def from_country(self): return Country, Port class Country(Model): pass class Port(Model): pass class MySQLRelationships(unittest.TestCase): maxDiff = None def test_has_query(self): sql = InboundShipment.has("from_country").to_sql() self.assertEqual( sql, """SELECT * FROM 
`inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_or_has(self): sql = InboundShipment.where("name", "Joe").or_has("from_country").to_sql() self.assertEqual( sql, """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_where_has_query(self): sql = InboundShipment.where_has( "from_country", lambda query: query.where("name", "USA") ).to_sql() self.assertEqual( sql, """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_or_where_has(self): sql = ( InboundShipment.where("name", "Joe") .or_where_has("from_country", lambda query: query.where("name", "USA")) .to_sql() ) self.assertEqual( sql, """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_doesnt_have(self): sql = InboundShipment.doesnt_have("from_country").to_sql() self.assertEqual( sql, """SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_or_where_doesnt_have(self): sql = ( InboundShipment.where("name", "Joe") .or_where_doesnt_have( "from_country", lambda query: query.where("name", "USA") ) .to_sql() ) self.assertEqual( sql, """SELECT * FROM 
`inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) ================================================ FILE: tests/mysql/relationships/test_has_one_through.py ================================================ import unittest from src.masoniteorm.models import Model from src.masoniteorm.relationships import ( has_one_through, ) from dotenv import load_dotenv load_dotenv(".env") class InboundShipment(Model): @has_one_through(None, "from_port_id", "country_id", "port_id", "country_id") def from_country(self): return Country, Port class Country(Model): pass class Port(Model): pass class MySQLHasOneThroughRelationship(unittest.TestCase): maxDiff = None def test_has_query(self): sql = InboundShipment.has("from_country").to_sql() self.assertEqual( sql, """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_or_has(self): sql = InboundShipment.where("name", "Joe").or_has("from_country").to_sql() self.assertEqual( sql, """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_where_has_query(self): sql = InboundShipment.where_has( "from_country", lambda query: query.where("name", "USA") ).to_sql() self.assertEqual( sql, """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_or_where_has(self): sql = ( 
InboundShipment.where("name", "Joe") .or_where_has("from_country", lambda query: query.where("name", "USA")) .to_sql() ) self.assertEqual( sql, """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_doesnt_have(self): sql = InboundShipment.doesnt_have("from_country").to_sql() self.assertEqual( sql, """SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""", ) def test_or_where_doesnt_have(self): sql = ( InboundShipment.where("name", "Joe") .or_where_doesnt_have( "from_country", lambda query: query.where("name", "USA") ) .to_sql() ) self.assertEqual( sql, """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""", ) def test_has_one_through_with_count(self): sql = InboundShipment.with_count("from_country").to_sql() self.assertEqual( sql, """SELECT `inbound_shipments`.*, (SELECT COUNT(*) AS m_count_reserved FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`) AS from_country_count FROM `inbound_shipments`""", ) ================================================ FILE: tests/mysql/relationships/test_relationships.py ================================================ import unittest from src.masoniteorm.models import Model from src.masoniteorm.relationships import ( has_one, belongs_to_many, has_one_through, has_many, ) from dotenv import load_dotenv load_dotenv(".env") 
class User(Model): @has_one def profile(self): return Profile class Profile(Model): @has_one def identification(self): return Identification class Identification(Model): pass class MySQLRelationships(unittest.TestCase): maxDiff = None def test_has(self): sql = User.has("profile").to_sql() self.assertEqual( sql, """SELECT * FROM `users` WHERE EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id`)""", ) def test_has_nested(self): sql = User.has("profile.identification").to_sql() self.assertEqual( sql, """SELECT * FROM `users` WHERE EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `profiles`.`id`))""", ) def test_or_has(self): sql = User.where("name", "Joe").or_has("profile").to_sql() self.assertEqual( sql, """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' OR EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id`)""", ) def test_or_has_nested(self): sql = User.where("name", "Joe").or_has("profile.identification").to_sql() self.assertEqual( sql, """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' OR EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `profiles`.`id`))""", ) def test_relationship_where_has(self): sql = ( User.where("name", "Joe") .where_has("profile", lambda q: q.where("profile_id", 1)) .to_sql() ) self.assertEqual( sql, """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' AND EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND `profiles`.`profile_id` = '1')""", ) def test_relationship_where_has_nested(self): sql = ( User.where("name", "Joe") .where_has( "profile.identification", lambda q: q.where("identification_id", 1) ) .to_sql() ) self.assertEqual( sql, """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' AND EXISTS (SELECT * FROM 
`profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `profiles`.`id` AND `identifications`.`identification_id` = '1'))""", ) def test_relationship_or_where_has(self): sql = ( User.where("name", "Joe") .or_where_has("profile", lambda q: q.where("profile_id", 1)) .to_sql() ) self.assertEqual( sql, """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' OR EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND `profiles`.`profile_id` = '1')""", ) def test_relationship_or_where_has_nested(self): sql = ( User.where("name", "Joe") .or_where_has( "profile.identification", lambda q: q.where("identification_id", 1) ) .to_sql() ) self.assertEqual( sql, """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' OR EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `profiles`.`id` AND `identifications`.`identification_id` = '1'))""", ) def test_relationship_doesnt_have(self): sql = User.doesnt_have("profile").to_sql() self.assertEqual( sql, """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id`)""", ) def test_relationship_doesnt_have_nested(self): sql = User.doesnt_have("profile.identification").to_sql() self.assertEqual( sql, """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `profiles`.`id`))""", ) def test_relationship_where_doesnt_have(self): sql = User.where_doesnt_have( "profile", lambda q: q.where("profile_id", 1) ).to_sql() self.assertEqual( sql, """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND `profiles`.`profile_id` = '1')""", ) def 
test_relationship_where_doesnt_have_nested(self): sql = User.where_doesnt_have( "profile.identification", lambda q: q.where("identification_id", 1) ).to_sql() self.assertEqual( sql, """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id`) AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `users`.`id` AND `identifications`.`identification_id` = '1')""", ) def test_relationship_or_where_doesnt_have(self): sql = User.or_where_doesnt_have( "profile", lambda q: q.where("profile_id", 1) ).to_sql() self.assertEqual( sql, """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND `profiles`.`profile_id` = '1')""", ) def test_relationship_or_where_doesnt_have_nested(self): sql = User.or_where_doesnt_have( "profile.identification", lambda q: q.where("identification_id", 1) ).to_sql() self.assertEqual( sql, """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id`) AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `users`.`id` AND `identifications`.`identification_id` = '1')""", ) def test_joins(self): sql = User.joins("profile").to_sql() self.assertEqual( sql, """SELECT * FROM `users` INNER JOIN `profiles` ON `users`.`id` = `profiles`.`profile_id`""", ) def test_join_on(self): sql = User.join_on("profile", lambda q: (q.where("active", 1))).to_sql() self.assertEqual( sql, """SELECT * FROM `users` INNER JOIN `profiles` ON `users`.`id` = `profiles`.`profile_id` WHERE (`profiles`.`active` = '1')""", ) ================================================ FILE: tests/mysql/schema/test_mysql_schema_builder.py ================================================ import os import unittest from src.masoniteorm import Model from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import MySQLConnection from 
src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import MySQLPlatform from tests.integrations.config.database import DATABASES class Discussion(Model): pass class TestMySQLSchemaBuilder(unittest.TestCase): maxDiff = None def setUp(self): self.schema = Schema( connection_class=MySQLConnection, connection="mysql", connection_details=DATABASES, platform=MySQLPlatform, dry=True, ).on("mysql") def test_can_add_columns1(self): with self.schema.create("users") as blueprint: blueprint.string("name") blueprint.integer("age") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL, `age` INT(11) NOT NULL)" ], ) def test_can_add_tiny_text(self): with self.schema.create("users") as blueprint: blueprint.tiny_text("description") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), ["CREATE TABLE `users` (`description` TINYTEXT NOT NULL)"], ) def test_can_add_unsigned_decimal(self): with self.schema.create("users") as blueprint: blueprint.unsigned_decimal("amount", 19, 4) self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), ["CREATE TABLE `users` (`amount` DECIMAL(19, 4) UNSIGNED NOT NULL)"], ) def test_can_create_table_if_not_exists(self): with self.schema.create_table_if_not_exists("users") as blueprint: blueprint.string("name") blueprint.integer("age") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE IF NOT EXISTS `users` (`name` VARCHAR(255) NOT NULL, `age` INT(11) NOT NULL)" ], ) def test_can_add_columns_with_constaint(self): with self.schema.create("users") as blueprint: blueprint.string("name") blueprint.integer("age") blueprint.unique("name"), blueprint.unique("name", name="table_unique"), self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (`name` 
VARCHAR(255) NOT NULL, `age` INT(11) NOT NULL, CONSTRAINT users_name_unique UNIQUE (name), CONSTRAINT table_unique UNIQUE (name))" ], ) def test_add_column_comment(self): with self.schema.create("users") as blueprint: blueprint.string("name").comment("A users username") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL COMMENT 'A users username')" ], ) def test_can_add_table_comment(self): with self.schema.create("users") as blueprint: blueprint.string("name") blueprint.table_comment("A users table") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL) COMMENT 'A users table'" ], ) def test_can_add_columns_with_foreign_key_constaint(self): with self.schema.create("users") as blueprint: blueprint.string("name").unique() blueprint.integer("age") blueprint.integer("profile_id") blueprint.foreign("profile_id").references("id").on("profiles") blueprint.foreign_id("post_id").references("id").on("posts") blueprint.foreign_id_for(Discussion).references("id").on("discussions") self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL, " "`age` INT(11) NOT NULL, " "`profile_id` INT(11) NOT NULL, " "`post_id` BIGINT UNSIGNED NOT NULL, " "CONSTRAINT users_name_unique UNIQUE (name), " "CONSTRAINT users_profile_id_foreign FOREIGN KEY (`profile_id`) REFERENCES `profiles`(`id`), " "CONSTRAINT users_profile_id_foreign FOREIGN KEY (`post_id`) REFERENCES `posts`(`id`)), " "CONSTRAINT users_discussions_id_foreign FOREIGN KEY (`discussion_id`) REFERENCES `posts`(`id`))" ], ) def test_can_add_columns_with_foreign_key_constaint(self): with self.schema.create("users") as blueprint: blueprint.string("name").unique() blueprint.integer("age") blueprint.integer("profile_id") 
blueprint.add_foreign("profile_id.id.profiles") self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL, " "`age` INT(11) NOT NULL, " "`profile_id` INT(11) NOT NULL, " "CONSTRAINT users_name_unique UNIQUE (name), " "CONSTRAINT users_profile_id_foreign FOREIGN KEY (`profile_id`) REFERENCES `profiles`(`id`))" ], ) def test_can_advanced_table_creation(self): with self.schema.create("users") as blueprint: blueprint.increments("id") blueprint.id("id2") blueprint.string("name") blueprint.tiny_integer("active") blueprint.string("email").unique() blueprint.enum("gender", ["male", "female"]) blueprint.string("password") blueprint.decimal("money") blueprint.integer("admin").default(0) blueprint.string("option").default("ADMIN") blueprint.string("remember_token").nullable() blueprint.timestamp("verified_at").nullable() blueprint.timestamps() self.assertEqual(len(blueprint.table.added_columns), 14) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (`id` INT UNSIGNED AUTO_INCREMENT NOT NULL, " "`id2` BIGINT UNSIGNED AUTO_INCREMENT NOT NULL, " "`name` VARCHAR(255) NOT NULL, `active` TINYINT(1) NOT NULL, `email` VARCHAR(255) NOT NULL, `gender` ENUM('male', 'female') NOT NULL, " "`password` VARCHAR(255) NOT NULL, `money` DECIMAL(17, 6) NOT NULL, " "`admin` INT(11) NOT NULL DEFAULT 0, `option` VARCHAR(255) NOT NULL DEFAULT 'ADMIN', `remember_token` VARCHAR(255) NULL, `verified_at` TIMESTAMP NULL, " "`created_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, `updated_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_id2_primary PRIMARY KEY (id2), CONSTRAINT users_email_unique UNIQUE (email))" ], ) def test_can_add_primary_constraint_without_column_name(self): with self.schema.create("users") as blueprint: blueprint.integer("user_id").primary() blueprint.string("name") blueprint.string("email") 
self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual(len(blueprint.table.added_constraints), 1) self.assertTrue( blueprint.to_sql()[0].startswith( "CREATE TABLE `users` (`user_id` INT(11) NOT NULL" ) ) def test_can_advanced_table_creation2(self): with self.schema.create("users") as blueprint: blueprint.big_increments("id") blueprint.string("name") blueprint.string("duration") blueprint.string("url") blueprint.inet("last_address").nullable() blueprint.cidr("route_origin").nullable() blueprint.macaddr("mac_address").nullable() blueprint.datetime("published_at") blueprint.string("thumbnail").nullable() blueprint.integer("premium") blueprint.integer("author_id").unsigned().nullable() blueprint.foreign("author_id").references("id").on("users").on_delete( "CASCADE" ) blueprint.text("description") blueprint.timestamps() self.assertEqual(len(blueprint.table.added_columns), 14) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (`id` BIGINT UNSIGNED AUTO_INCREMENT NOT NULL, `name` VARCHAR(255) NOT NULL, " "`duration` VARCHAR(255) NOT NULL, `url` VARCHAR(255) NOT NULL, `last_address` VARCHAR(255) NULL, `route_origin` VARCHAR(255) NULL, `mac_address` VARCHAR(255) NULL, " "`published_at` DATETIME NOT NULL, `thumbnail` VARCHAR(255) NULL, " "`premium` INT(11) NOT NULL, `author_id` INT(11) UNSIGNED NULL, `description` TEXT NOT NULL, `created_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, " "`updated_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_author_id_foreign FOREIGN KEY (`author_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)" ], ) def test_can_add_columns_with_foreign_key_constraint_name(self): with self.schema.create("users") as blueprint: blueprint.integer("profile_id") blueprint.foreign("profile_id", name="profile_foreign").references("id").on( "profiles" ) self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (" 
"`profile_id` INT(11) NOT NULL, " "CONSTRAINT profile_foreign FOREIGN KEY (`profile_id`) REFERENCES `profiles`(`id`))" ], ) def test_can_have_composite_keys(self): with self.schema.create("users") as blueprint: blueprint.string("name").unique() blueprint.integer("age") blueprint.integer("profile_id") blueprint.primary(["name", "age"]) self.assertEqual(len(blueprint.table.added_columns), 3) print(blueprint.to_sql()) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` " "(`name` VARCHAR(255) NOT NULL, " "`age` INT(11) NOT NULL, " "`profile_id` INT(11) NOT NULL, " "CONSTRAINT users_name_unique UNIQUE (name), " "CONSTRAINT users_name_age_primary PRIMARY KEY (name, age))" ], ) def test_can_have_column_primary_key(self): with self.schema.create("users") as blueprint: blueprint.string("name").primary() blueprint.integer("age") blueprint.integer("profile_id") self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` " "(`name` VARCHAR(255) NOT NULL, " "`age` INT(11) NOT NULL, " "`profile_id` INT(11) NOT NULL, " "CONSTRAINT users_name_primary PRIMARY KEY (name))" ], ) def test_can_have_unsigned_columns(self): with self.schema.create("users") as blueprint: blueprint.integer("profile_id").unsigned() blueprint.big_integer("big_profile_id").unsigned() blueprint.tiny_integer("tiny_profile_id").unsigned() blueprint.small_integer("small_profile_id").unsigned() blueprint.medium_integer("medium_profile_id").unsigned() blueprint.unsigned_integer("unsigned_profile_id") blueprint.unsigned_big_integer("unsigned_big_profile_id") self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (" "`profile_id` INT(11) UNSIGNED NOT NULL, " "`big_profile_id` BIGINT(32) UNSIGNED NOT NULL, " "`tiny_profile_id` TINYINT(1) UNSIGNED NOT NULL, " "`small_profile_id` SMALLINT(5) UNSIGNED NOT NULL, " "`medium_profile_id` MEDIUMINT(7) UNSIGNED NOT NULL, " "`unsigned_profile_id` INT UNSIGNED NOT NULL, " "`unsigned_big_profile_id` 
BIGINT(32) UNSIGNED NOT NULL)" ], ) def test_can_have_default_blank_string(self): with self.schema.create("users") as blueprint: blueprint.string("profile_id").default("") self.assertEqual( blueprint.to_sql(), ["CREATE TABLE `users` (" "`profile_id` VARCHAR(255) NOT NULL DEFAULT '')"], ) def test_can_have_float_type(self): with self.schema.create("users") as blueprint: blueprint.float("amount") self.assertEqual( blueprint.to_sql(), ["CREATE TABLE `users` (" "`amount` FLOAT(19, 4) NOT NULL)"], ) def test_has_table(self): schema_sql = self.schema.has_table("users") sql = f"SELECT * from information_schema.tables where table_name='users' AND table_schema = '{os.getenv('MYSQL_DATABASE_DATABASE')}'" self.assertEqual(schema_sql, sql) def test_can_truncate(self): sql = self.schema.truncate("users") self.assertEqual(sql, "TRUNCATE `users`") def test_can_rename_table(self): sql = self.schema.rename("users", "clients") self.assertEqual(sql, "ALTER TABLE `users` RENAME TO `clients`") def test_can_drop_table_if_exists(self): sql = self.schema.drop_table_if_exists("users", "clients") self.assertEqual(sql, "DROP TABLE IF EXISTS `users`") def test_can_drop_table(self): sql = self.schema.drop_table("users", "clients") self.assertEqual(sql, "DROP TABLE `users`") def test_has_column(self): sql = self.schema.has_column("users", "name") self.assertEqual( sql, "SELECT column_name FROM information_schema.columns WHERE table_name='users' and column_name='name'", ) def test_can_enable_foreign_keys(self): sql = self.schema.enable_foreign_key_constraints() self.assertEqual(sql, "SET FOREIGN_KEY_CHECKS=1") def test_can_disable_foreign_keys(self): sql = self.schema.disable_foreign_key_constraints() self.assertEqual(sql, "SET FOREIGN_KEY_CHECKS=0") def test_can_truncate_without_foreign_keys(self): sql = self.schema.truncate("users", foreign_keys=True) self.assertEqual( sql, [ "SET FOREIGN_KEY_CHECKS=0", "TRUNCATE `users`", "SET FOREIGN_KEY_CHECKS=1", ], ) def test_can_add_enum(self): with 
self.schema.create("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), [ "CREATE TABLE `users` (`status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active')" ], ) ================================================ FILE: tests/mysql/schema/test_mysql_schema_builder_alter.py ================================================ import unittest import os from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import MySQLConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import MySQLPlatform from src.masoniteorm.schema.Table import Table class TestMySQLSchemaBuilderAlter(unittest.TestCase): maxDiff = None def setUp(self): self.schema = Schema( connection_class=MySQLConnection, connection="mysql", connection_details=DATABASES, platform=MySQLPlatform, dry=True, ).on("mysql") def test_can_add_columns(self): with self.schema.table("users") as blueprint: blueprint.string("name") blueprint.integer("age") self.assertEqual(len(blueprint.table.added_columns), 2) sql = [ "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL, ADD `age` INT(11) NOT NULL" ] self.assertEqual(blueprint.to_sql(), sql) def test_can_add_column_comments(self): with self.schema.table("users") as blueprint: blueprint.string("name").comment("A users username") self.assertEqual(len(blueprint.table.added_columns), 1) sql = [ "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL COMMENT 'A users username'" ] def test_can_add_table_comment(self): with self.schema.table("users") as blueprint: blueprint.string("name") blueprint.table_comment("A users username") self.assertEqual(len(blueprint.table.added_columns), 1) sql = [ "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL COMMENT 'A users username'" ] def test_can_add_table_comment_with_no_columns(self): with self.schema.table("users") as blueprint: 
blueprint.table_comment("A users username") self.assertEqual(len(blueprint.table.added_columns), 0) sql = ["ALTER TABLE `users` COMMENT 'A users username'"] self.assertEqual(blueprint.to_sql(), sql) def test_can_add_column_after(self): with self.schema.table("users") as blueprint: blueprint.string("name").after("age") self.assertEqual(len(blueprint.table.added_columns), 1) sql = ["ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL AFTER `age`"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_rename(self): with self.schema.table("users") as blueprint: blueprint.rename("post", "comment", "integer") table = Table("users") table.add_column("post", "integer") blueprint.table.from_table = table sql = ["ALTER TABLE `users` CHANGE `post` `comment` INT NOT NULL"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_and_rename(self): with self.schema.table("users") as blueprint: blueprint.string("name") blueprint.rename("post", "comment", "string") table = Table("users") table.add_column("post", "string") blueprint.table.from_table = table sql = [ "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL", "ALTER TABLE `users` CHANGE `post` `comment` VARCHAR NOT NULL", ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_and_rename_to_string(self): with self.schema.table("users") as blueprint: blueprint.string("name") blueprint.rename("post", "comment", "string", length=200) table = Table("users") table.add_column("post", "integer") blueprint.table.from_table = table sql = [ "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL", "ALTER TABLE `users` CHANGE `post` `comment` VARCHAR(200) NOT NULL", ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop1(self): with self.schema.table("users") as blueprint: blueprint.drop_column("post") sql = ["ALTER TABLE `users` DROP COLUMN `post`"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_column_and_foreign_key(self): with self.schema.table("users") as blueprint: 
blueprint.unsigned_integer("playlist_id").nullable() blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( "cascade" ) sql = [ "ALTER TABLE `users` ADD `playlist_id` INT UNSIGNED NULL", "ALTER TABLE `users` ADD CONSTRAINT users_playlist_id_foreign FOREIGN KEY (playlist_id) REFERENCES playlists(id) ON DELETE CASCADE", ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.drop_foreign("users_playlist_id_foreign") sql = ["ALTER TABLE `users` DROP FOREIGN KEY users_playlist_id_foreign"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_foreign_key_shortcut(self): with self.schema.table("users") as blueprint: blueprint.drop_foreign(["playlist_id"]) sql = ["ALTER TABLE `users` DROP FOREIGN KEY users_playlist_id_foreign"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_unique_constraint(self): with self.schema.table("users") as blueprint: blueprint.drop_unique("users_playlist_id_unique") sql = ["ALTER TABLE `users` DROP INDEX users_playlist_id_unique"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_index(self): with self.schema.table("users") as blueprint: blueprint.index("playlist_id") sql = ["CREATE INDEX users_playlist_id_index ON `users`(playlist_id)"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_index(self): with self.schema.table("users") as blueprint: blueprint.drop_index("users_playlist_id_index") sql = ["ALTER TABLE `users` DROP INDEX users_playlist_id_index"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_primary(self): with self.schema.table("users") as blueprint: blueprint.primary("playlist_id") sql = [ "ALTER TABLE `users` ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)" ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_index_shortcut(self): with self.schema.table("users") as blueprint: blueprint.drop_index(["playlist_id"]) sql = ["ALTER TABLE `users` 
DROP INDEX users_playlist_id_index"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_unique_constraint_shortcut(self): with self.schema.table("users") as blueprint: blueprint.drop_unique(["playlist_id"]) sql = ["ALTER TABLE `users` DROP INDEX users_playlist_id_unique"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_primary(self): with self.schema.table("users") as blueprint: blueprint.drop_primary("users_id_primary") sql = ["ALTER TABLE `users` DROP INDEX users_id_primary"] self.assertEqual(blueprint.to_sql(), sql) def test_change(self): with self.schema.table("users") as blueprint: blueprint.integer("age").change() blueprint.string("external_type").default("external") blueprint.integer("gender").nullable().change() blueprint.string("name") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual(len(blueprint.table.changed_columns), 2) table = Table("users") table.add_column("age", "string") blueprint.table.from_table = table sql = [ "ALTER TABLE `users` ADD `external_type` VARCHAR(255) NOT NULL DEFAULT 'external', ADD `name` VARCHAR(255) NOT NULL", "ALTER TABLE `users` MODIFY `age` INT(11) NOT NULL, MODIFY `gender` INT(11) NULL", ] self.assertEqual(blueprint.to_sql(), sql) def test_timestamp_alter_add_nullable_column(self): with self.schema.table("users") as blueprint: blueprint.timestamp("due_date").nullable() self.assertEqual(len(blueprint.table.added_columns), 1) table = Table("users") table.add_column("age", "string") blueprint.table.from_table = table sql = ["ALTER TABLE `users` ADD `due_date` TIMESTAMP NULL"] self.assertEqual(blueprint.to_sql(), sql) def test_drop_add_and_change(self): with self.schema.table("users") as blueprint: blueprint.integer("age").default(0).change() blueprint.string("name") blueprint.drop_column("email") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual(len(blueprint.table.changed_columns), 1) table = Table("users") table.add_column("age", "string") 
table.add_column("email", "string") blueprint.table.from_table = table sql = [ "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL", "ALTER TABLE `users` MODIFY `age` INT(11) NOT NULL DEFAULT 0", "ALTER TABLE `users` DROP COLUMN `email`", ] self.assertEqual(blueprint.to_sql(), sql) def test_can_create_indexes(self): with self.schema.table("users") as blueprint: blueprint.index("name") blueprint.index(["name", "email"]) blueprint.unique("name") blueprint.unique("name", name="table_unique") blueprint.unique(["name", "email"]) blueprint.fulltext("description") self.assertEqual(len(blueprint.table.added_columns), 0) print(blueprint.to_sql()) self.assertEqual( blueprint.to_sql(), [ "CREATE INDEX users_name_index ON `users`(name)", "CREATE INDEX users_name_email_index ON `users`(name,email)", "ALTER TABLE `users` ADD CONSTRAINT UNIQUE INDEX users_name_unique(name)", "ALTER TABLE `users` ADD CONSTRAINT UNIQUE INDEX table_unique(name)", "ALTER TABLE `users` ADD CONSTRAINT UNIQUE INDEX users_name_email_unique(name,email)", "ALTER TABLE `users` ADD FULLTEXT description_fulltext(description)", ], ) def test_can_add_column_enum(self): with self.schema.table("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active") self.assertEqual(len(blueprint.table.added_columns), 1) sql = [ "ALTER TABLE `users` ADD `status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active'" ] self.assertEqual(blueprint.to_sql(), sql) def test_can_change_column_enum(self): with self.schema.table("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active").change() self.assertEqual(len(blueprint.table.changed_columns), 1) sql = [ "ALTER TABLE `users` MODIFY `status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active'" ] self.assertEqual(blueprint.to_sql(), sql) ================================================ FILE: tests/mysql/scopes/test_can_use_global_scopes.py ================================================ import inspect import unittest 
from src.masoniteorm.models import Model from src.masoniteorm.scopes import ( SoftDeleteScope, SoftDeletesMixin, TimeStampsMixin, scope, ) class UserSoft(Model, SoftDeletesMixin): __dry__ = True class User(Model): __dry__ = True class TestMySQLGlobalScopes(unittest.TestCase): def test_can_use_global_scopes_on_select(self): sql = "SELECT * FROM `user_softs` WHERE `user_softs`.`name` = 'joe' AND `user_softs`.`deleted_at` IS NULL" self.assertEqual(sql, UserSoft.where("name", "joe").to_sql()) # def test_can_use_global_scopes_on_delete(self): # sql = "UPDATE `users` SET `users`.`deleted_at` = 'now' WHERE `users`.`name` = 'joe'" # self.assertEqual( # sql, # User.apply_scope(SoftDeletes) # .where("name", "joe") # .delete(query=True) # .to_sql(), # ) def test_can_use_global_scopes_on_time(self): sql = "INSERT INTO `users` (`users`.`name`, `users`.`updated_at`, `users`.`created_at`) VALUES ('Joe'" self.assertTrue(User.create({"name": "Joe"}, query=True).to_sql().startswith(sql)) # def test_can_use_global_scopes_on_inherit(self): # sql = "SELECT * FROM `user_softs` WHERE `user_softs`.`deleted_at` IS NULL" # self.assertEqual(sql, UserSoft.all(query=True)) ================================================ FILE: tests/mysql/scopes/test_can_use_scopes.py ================================================ import inspect import unittest from src.masoniteorm.models import Model from src.masoniteorm.scopes import SoftDeletesMixin, scope from tests.User import User class User(Model): __dry__ = True @scope def active(self, query, status): return query.where("active", status) @scope def gender(self, query, status): return query.where("gender", status) class UserSoft(Model, SoftDeletesMixin): __dry__ = True class TestMySQLScopes(unittest.TestCase): def test_can_get_sql(self): sql = "SELECT * FROM `users` WHERE `users`.`name` = 'joe'" self.assertEqual(sql, User.where("name", "joe").to_sql()) def test_active_scope(self): sql = "SELECT * FROM `users` WHERE `users`.`name` = 'joe' AND 
`users`.`active` = '1'" self.assertEqual(sql, User.where("name", "joe").active(1).to_sql()) def test_active_scope_with_params(self): sql = "SELECT * FROM `users` WHERE `users`.`active` = '2' AND `users`.`name` = 'joe'" self.assertEqual(sql, User.active(2).where("name", "joe").to_sql()) def test_can_chain_scopes(self): sql = "SELECT * FROM `users` WHERE `users`.`active` = '2' AND `users`.`gender` = 'W' AND `users`.`name` = 'joe'" self.assertEqual(sql, User.active(2).gender("W").where("name", "joe").to_sql()) ================================================ FILE: tests/mysql/scopes/test_soft_delete.py ================================================ import unittest import pendulum from tests.integrations.config.database import DATABASES from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import MySQLGrammar from src.masoniteorm.scopes import SoftDeleteScope from tests.utils import MockConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.scopes import SoftDeletesMixin class UserSoft(Model, SoftDeletesMixin): __dry__ = True __table__ = "users" class UserSoftArchived(Model, SoftDeletesMixin): __dry__ = True __deleted_at__ = "archived_at" __table__ = "users" class TestSoftDeleteScope(unittest.TestCase): def get_builder(self, table="users"): connection = MockConnectionFactory().make("default") return QueryBuilder( grammar=MySQLGrammar, connection_class=connection, connection="mysql", table=table, connection_details=DATABASES, dry=True, ) def test_with_trashed(self): sql = "SELECT * FROM `users`" builder = self.get_builder().set_global_scope(SoftDeleteScope()) self.assertEqual(sql, builder.with_trashed().to_sql()) def test_force_delete(self): sql = "DELETE FROM `users`" builder = self.get_builder().set_global_scope(SoftDeleteScope()) self.assertEqual(sql, builder.force_delete(query=True).to_sql()) def test_restore(self): sql = "UPDATE `users` SET `users`.`deleted_at` = 'None'" builder = 
self.get_builder().set_global_scope(SoftDeleteScope()) self.assertEqual(sql, builder.restore().to_sql()) def test_force_delete_with_wheres(self): sql = "DELETE FROM `users` WHERE `users`.`active` = '1'" builder = self.get_builder().set_global_scope(SoftDeleteScope()) self.assertEqual( sql, UserSoft.where("active", 1).force_delete(query=True).to_sql() ) def test_that_trashed_users_are_not_returned_by_default(self): sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL" builder = self.get_builder().set_global_scope(SoftDeleteScope()) self.assertEqual(sql, builder.to_sql()) def test_only_trashed(self): sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL" builder = self.get_builder().set_global_scope(SoftDeleteScope()) self.assertEqual(sql, builder.only_trashed().to_sql()) def test_only_trashed_on_model(self): sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL" self.assertEqual(sql, UserSoft.only_trashed().to_sql()) def test_can_change_column(self): sql = "SELECT * FROM `users` WHERE `users`.`archived_at` IS NOT NULL" self.assertEqual(sql, UserSoftArchived.only_trashed().to_sql()) def test_find_with_global_scope(self): find_sql = UserSoft.find("1", query=True).to_sql() raw_sql = """SELECT * FROM `users` WHERE `users`.`id` = '1' AND `users`.`deleted_at` IS NULL""" self.assertEqual(find_sql, raw_sql) def test_find_with_trashed_scope(self): find_sql = UserSoft.with_trashed().find("1", query=True).to_sql() raw_sql = """SELECT * FROM `users` WHERE `users`.`id` = '1'""" self.assertEqual(find_sql, raw_sql) def test_find_with_only_trashed_scope(self): find_sql = UserSoft.only_trashed().find("1", query=True).to_sql() raw_sql = """SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL AND `users`.`id` = '1'""" self.assertEqual(find_sql, raw_sql) ================================================ FILE: tests/postgres/builder/test_postgres_query_builder.py ================================================ import inspect import unittest 
from src.masoniteorm.connections import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import PostgresGrammar from tests.utils import MockConnectionFactory class MockConnection: connection_details = {} def make_connection(self): return self @classmethod def get_default_query_grammar(cls): return class ModelTest(Model): __timestamps__ = False class BaseTestQueryBuilder: def get_builder(self, table="users", dry=True): connection = MockConnectionFactory().make("postgres") return QueryBuilder( self.grammar, connection_class=connection, connection="postgres", table=table, model=ModelTest(), dry=dry, ) def test_sum(self): builder = self.get_builder() builder.sum("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_like(self): builder = self.get_builder() builder.where("age", "like", "%name%") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_not_like(self): builder = self.get_builder() builder.where("age", "not like", "%name%") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_max(self): builder = self.get_builder() builder.max("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_min(self): builder = self.get_builder() builder.min("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_avg(self): builder = self.get_builder() builder.avg("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_all(self): builder = self.get_builder() builder.all() sql = getattr( self, 
inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_get(self): builder = self.get_builder() builder.get() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_first(self): builder = self.get_builder().first(query=True) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_select(self): builder = self.get_builder() builder.select("name", "email") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_add_select_no_table(self): builder = self.get_builder(table=None) sql = ( builder.add_select( "other_test", lambda q: q.max("updated_at").table("different_table") ) .add_select( "some_alias", lambda q: q.max("updated_at").table("another_table") ) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_select_raw(self): builder = self.get_builder() builder.select_raw("count(email) as email_count") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_create(self): builder = self.get_builder().without_global_scopes() builder.create( {"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_delete(self): builder = self.get_builder() builder.delete("name", "Joe", query=True) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where(self): builder = self.get_builder() builder.where("name", "Joe") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() 
self.assertEqual(builder.to_sql(), sql) def test_where_exists(self): builder = self.get_builder() builder.where_exists("name") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_limit(self): builder = self.get_builder() builder.limit(5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_offset(self): builder = self.get_builder() builder.offset(5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_join(self): builder = self.get_builder() builder.join("profiles", "users.id", "=", "profiles.user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_left_join(self): builder = self.get_builder() builder.left_join("profiles", "users.id", "=", "profiles.user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_right_join(self): builder = self.get_builder() builder.right_join("profiles", "users.id", "=", "profiles.user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_update(self): builder = self.get_builder().update( {"name": "Joe", "email": "joe@yopmail.com"}, dry=True ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) # def test_increment(self): # builder = self.get_builder() # builder.increment("age", 1) # sql = getattr( # self, inspect.currentframe().f_code.co_name.replace("test_", "") # )() # self.assertEqual(builder.to_sql(), sql) # def test_decrement(self): # builder = self.get_builder() # builder.decrement("age", 1) # sql = getattr( # self, inspect.currentframe().f_code.co_name.replace("test_", 
"") # )() # self.assertEqual(builder.to_sql(), sql) def test_count(self): builder = self.get_builder() builder.count("id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_order_by_asc(self): builder = self.get_builder() builder.order_by("email", "asc") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_order_by_desc(self): builder = self.get_builder() builder.order_by("email", "desc") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_column(self): builder = self.get_builder() builder.where_column("name", "username") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_not_in(self): builder = self.get_builder() builder.where_not_in("id", [1, 2, 3]) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_between(self): builder = self.get_builder() builder.between("id", 2, 5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_not_between(self): builder = self.get_builder() builder.not_between("id", 2, 5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_in(self): builder = self.get_builder() builder.where_in("id", [1, 2, 3]) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_null(self): builder = self.get_builder() builder.where_null("name") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_not_null(self): builder = 
self.get_builder() builder.where_not_null("name") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_having(self): builder = self.get_builder(table="payments") builder.select("user_id").avg("salary").group_by("user_id").having( "salary", ">=", "1000" ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_group_by(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by("user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_builder_alone(self): self.assertTrue( QueryBuilder( connection_class=MockConnection, connection="postgres", connection_details={ "default": "postgres", "postgres": { "driver": "postgres", "host": "localhost", "user": "postgres", "password": "postgres", "database": "orm", "port": "5432", "prefix": "", "grammar": "postgres", }, }, ).table("users") ) def test_where_lt(self): builder = self.get_builder() builder.where("age", "<", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_lte(self): builder = self.get_builder() builder.where("age", "<=", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_gt(self): builder = self.get_builder() builder.where("age", ">", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_gte(self): builder = self.get_builder() builder.where("age", ">=", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_ne(self): builder = self.get_builder() builder.where("age", "!=", 
"20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_or_where(self): builder = self.get_builder() builder.where("age", "20").or_where("age", "<", 20) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_can_call_with_schema(self): builder = self.get_builder() sql = ( builder.table("information_schema.columns") .select("table_name") .where("table_name", "users") .to_sql() ) self.assertEqual( sql, """SELECT "information_schema"."columns"."table_name" FROM "information_schema"."columns" WHERE "information_schema"."columns"."table_name" = 'users'""", ) def test_truncate(self): builder = self.get_builder(dry=True) sql = builder.truncate() sql_ref = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(sql, sql_ref) def test_truncate_without_foreign_keys(self): builder = self.get_builder(dry=True) sql = builder.truncate(foreign_keys=True) sql_ref = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(sql, sql_ref) def test_shared_lock(self): builder = self.get_builder(dry=True) sql = builder.where("votes", ">=", 100).shared_lock().to_sql() sql_ref = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(sql, sql_ref) def test_update_lock(self): builder = self.get_builder(dry=True) sql = builder.where("votes", ">=", 100).lock_for_update().to_sql() sql_ref = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(sql, sql_ref) class PostgresQueryBuilderTest(BaseTestQueryBuilder, unittest.TestCase): grammar = PostgresGrammar def sum(self): """ builder = self.get_builder() builder.sum('age') """ return """SELECT SUM("users"."age") AS age FROM "users\"""" def max(self): """ builder = self.get_builder() builder.max('age') """ return """SELECT 
MAX("users"."age") AS age FROM "users\"""" def min(self): """ builder = self.get_builder() builder.min('age') """ return """SELECT MIN("users"."age") AS age FROM "users\"""" def avg(self): """ builder = self.get_builder() builder.avg('age') """ return """SELECT AVG("users"."age") AS age FROM "users\"""" def first(self): """ builder = self.get_builder() builder.first() """ return """SELECT * FROM "users" LIMIT 1""" def all(self): """ builder = self.get_builder() builder.all() """ return """SELECT * FROM "users\"""" def get(self): """ builder = self.get_builder() builder.get() """ return """SELECT * FROM "users\"""" def select(self): """ builder = self.get_builder() builder.select('name', 'email') """ return """SELECT "users"."name", "users"."email" FROM "users\"""" def add_select_no_table(self): """ builder = self.get_builder() builder.select('name', 'email') """ return ( "SELECT " '(SELECT MAX("different_table"."updated_at") AS updated_at FROM "different_table") AS other_test, ' '(SELECT MAX("another_table"."updated_at") AS updated_at FROM "another_table") AS some_alias' ) def select_raw(self): """ builder = self.get_builder() builder.select_raw('count(email) as email_count') """ return """SELECT count(email) as email_count FROM "users\"""" def create(self): """ builder = get_builder() builder.create({"name": "Corentin All", 'email': 'corentin@yopmail.com'}) """ return """INSERT INTO "users" ("name", "email") VALUES ('Corentin All', 'corentin@yopmail.com') RETURNING *""" def delete(self): """ builder = get_builder() builder.delete("name', 'Joe') """ return """DELETE FROM "users" WHERE "users"."name" = 'Joe'""" def where(self): """ builder = get_builder() builder.where('name', 'Joe') """ return """SELECT * FROM "users" WHERE "users"."name" = 'Joe'""" def where_exists(self): """ builder = get_builder() builder.where_exists('name') """ return """SELECT * FROM "users" WHERE EXISTS 'name'""" def limit(self): """ builder = get_builder() builder.limit(5) """ return 
"""SELECT * FROM "users" LIMIT 5""" def offset(self): """ builder = get_builder() builder.offset(5) """ return """SELECT * FROM "users" OFFSET 5""" def join(self): """ builder.join("profiles", "users.id", "=", "profiles.user_id") """ return """SELECT * FROM "users" INNER JOIN "profiles" ON "users"."id" = "profiles"."user_id\"""" def left_join(self): """ builder.left_join("profiles", "users.id", "=", "profiles.user_id") """ return """SELECT * FROM "users" LEFT JOIN "profiles" ON "users"."id" = "profiles"."user_id\"""" def right_join(self): """ builder.right_join("profiles", "users.id", "=", "profiles.user_id") """ return """SELECT * FROM "users" RIGHT JOIN "profiles" ON "users"."id" = "profiles"."user_id\"""" def update(self): """ builder.update({"name": "Joe", "email": "joe@yopmail.com"}) """ return """UPDATE "users" SET "name" = 'Joe', "email" = 'joe@yopmail.com'""" def increment(self): """ builder.increment('age', 1) """ return """UPDATE "users" SET "age" = "age" + '1'""" def decrement(self): """ builder.decrement('age', 1) """ return """UPDATE "users" SET "age" = "age" - '1'""" def count(self): """ builder.count(id) """ return """SELECT COUNT("users"."id") AS id FROM "users\"""" def order_by_asc(self): """ builder.order_by('email', 'asc') """ return """SELECT * FROM "users" ORDER BY "email" ASC""" def order_by_desc(self): """ builder.order_by('email', 'des') """ return """SELECT * FROM "users" ORDER BY "email" DESC""" def where_column(self): """ builder.where_column('name', 'username') """ return """SELECT * FROM "users" WHERE "users"."name" = "users"."username\"""" def where_null(self): """ builder.where_null('name') """ return """SELECT * FROM "users" WHERE "users"."name" IS NULL""" def where_not_null(self): """ builder.where_null('name') """ return """SELECT * FROM "users" WHERE "users"."name" IS NOT NULL""" def where_not_in(self): """ builder.where_not_in('id', [1, 2, 3]) """ return """SELECT * FROM "users" WHERE "users"."id" NOT IN ('1','2','3')""" def 
where_in(self): """ builder.where_in('id', [1, 2, 3]) """ return """SELECT * FROM "users" WHERE "users"."id" IN ('1','2','3')""" def between(self): """ builder.between('id', 2, 5) """ return """SELECT * FROM "users" WHERE "users"."id" BETWEEN '2' AND '5'""" def not_between(self): """ builder.not_between('id', 2, 5) """ return """SELECT * FROM "users" WHERE "users"."id" NOT BETWEEN '2' AND '5'""" def having(self): """ builder.select('user_id').avg('salary').group_by('user_id').having('salary', '>=', '1000') """ return """SELECT "payments"."user_id", AVG("payments"."salary") AS salary FROM "payments" GROUP BY "payments"."user_id" HAVING "payments"."salary" >= '1000'""" def group_by(self): """ builder.select('user_id').min('salary').group_by('user_id') """ return """SELECT "payments"."user_id", MIN("payments"."salary") AS salary FROM "payments" GROUP BY "payments"."user_id\"""" def where_lt(self): """ builder = self.get_builder() builder.where('age', '<', '20') """ return """SELECT * FROM "users" WHERE "users"."age" < '20'""" def where_lte(self): """ builder = self.get_builder() builder.where('age', '<=', '20') """ return """SELECT * FROM "users" WHERE "users"."age" <= '20'""" def where_gt(self): """ builder = self.get_builder() builder.where('age', '>', '20') """ return """SELECT * FROM "users" WHERE "users"."age" > '20'""" def where_gte(self): """ builder = self.get_builder() builder.where('age', '>=', '20') """ return """SELECT * FROM "users" WHERE "users"."age" >= '20'""" def where_ne(self): """ builder = self.get_builder() builder.where('age', '!=', '20') """ return """SELECT * FROM "users" WHERE "users"."age" != '20'""" def or_where(self): """ builder = self.get_builder() builder.where('age', '20').or_where('age','<', 20) """ return """SELECT * FROM "users" WHERE "users"."age" = '20' OR "users"."age" < '20'""" def where_like(self): """ builder = self.get_builder() builder.where("age", "like", "%name%") """ return """SELECT * FROM "users" WHERE "users"."age" 
ILIKE '%name%'""" def where_not_like(self): """ builder = self.get_builder() builder.where("age", "not like", "%name%") """ return """SELECT * FROM "users" WHERE "users"."age" NOT ILIKE '%name%'""" def truncate(self): """ builder = self.get_builder() builder.truncate() """ return """TRUNCATE TABLE "users\"""" def truncate_without_foreign_keys(self): """ builder = self.get_builder() builder.truncate() """ return """TRUNCATE TABLE "users\"""" def update_lock(self): """ builder = self.get_builder() builder.truncate() """ return """SELECT * FROM "users" WHERE "users"."votes" >= '100' FOR UPDATE""" def shared_lock(self): """ builder = self.get_builder() builder.truncate() """ return """SELECT * FROM "users" WHERE "users"."votes" >= '100' FOR SHARE""" def test_latest(self): builder = self.get_builder() builder.latest("email") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_oldest(self): builder = self.get_builder() builder.oldest("email") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def oldest(self): """ builder.order_by('email', 'asc') """ return """SELECT * FROM "users" ORDER BY "email" ASC""" def latest(self): """ builder.order_by('email', 'des') """ return """SELECT * FROM "users" ORDER BY "email" DESC""" ================================================ FILE: tests/postgres/builder/test_postgres_transaction.py ================================================ import inspect import os import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import PostgresGrammar from src.masoniteorm.relationships import belongs_to from tests.utils import MockConnectionFactory if os.getenv("RUN_POSTGRES_DATABASE") == "True": class 
User(Model): __connection__ = "postgres" __timestamps__ = False class BaseTestQueryRelationships(unittest.TestCase): maxDiff = None def get_builder(self, table="users"): connection = ConnectionFactory().make("postgres") return QueryBuilder( grammar=PostgresGrammar, connection=connection, table=table, connection_details=DATABASES, ).on("postgres") def test_transaction(self): builder = self.get_builder() builder.begin() builder.create({"name": "phillip2", "email": "phillip2"}) # builder.commit() user = builder.where("name", "phillip2").first() self.assertEqual(user["name"], "phillip2") builder.rollback() user = builder.where("name", "phillip2").first() self.assertEqual(user, None) ================================================ FILE: tests/postgres/grammar/test_delete_grammar.py ================================================ import inspect import unittest from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import PostgresGrammar class BaseDeleteGrammarTest: def setUp(self): self.builder = QueryBuilder(PostgresGrammar, table="users") def test_can_compile_delete(self): to_sql = self.builder.delete("id", 1, query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_delete_in(self): to_sql = self.builder.delete("id", [1, 2, 3], query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_delete_with_where(self): to_sql = ( self.builder.where("age", 20) .where("profile", 1) .set_action("delete") .delete(query=True) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) class TestPostgresDeleteGrammar(BaseDeleteGrammarTest, unittest.TestCase): grammar = "postgres" def can_compile_delete(self): """ ( self.builder .delete('id', 1) .to_sql() ) """ return """DELETE FROM "users" WHERE 
"users"."id" = '1'""" def can_compile_delete_in(self): """ ( self.builder .delete('id', 1) .to_sql() ) """ return """DELETE FROM "users" WHERE "users"."id" IN ('1','2','3')""" def can_compile_delete_with_where(self): """ ( self.builder .where('age', 20) .where('profile', 1) .set_action('delete') .delete() .to_sql() ) """ return """DELETE FROM "users" WHERE "users"."age" = '20' AND "users"."profile" = '1'""" ================================================ FILE: tests/postgres/grammar/test_insert_grammar.py ================================================ import inspect import unittest from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import PostgresGrammar class BaseInsertGrammarTest: def setUp(self): self.builder = QueryBuilder(PostgresGrammar, table="users") def test_can_compile_insert(self): to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_insert_with_keywords(self): to_sql = self.builder.create(name="Joe", query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create(self): to_sql = self.builder.bulk_create( # These keys are intentionally out of order to show column to value alignment works [ {"name": "Joe", "age": 5}, {"age": 35, "name": "Bill"}, {"name": "John", "age": 10}, ], query=True, ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_qmark(self): to_sql = self.builder.bulk_create( [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True ).to_qmark() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) class TestPostgresUpdateGrammar(BaseInsertGrammarTest, unittest.TestCase): grammar = 
"postgres" def can_compile_insert(self): """ self.builder.create({ 'name': 'Joe' }).to_sql() """ return """INSERT INTO "users" ("name") VALUES ('Joe') RETURNING *""" def can_compile_insert_with_keywords(self): """ self.builder.create(name="Joe").to_sql() """ return """INSERT INTO "users" ("name") VALUES ('Joe') RETURNING *""" def can_compile_bulk_create(self): """ self.builder.create(name="Joe").to_sql() """ return """INSERT INTO "users" ("age", "name") VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John') RETURNING *""" def can_compile_bulk_create_qmark(self): """ self.builder.create(name="Joe").to_sql() """ return """INSERT INTO "users" ("name") VALUES ('?'), ('?'), ('?') RETURNING *""" ================================================ FILE: tests/postgres/grammar/test_select_grammar.py ================================================ import inspect import unittest from src.masoniteorm.query.grammars import PostgresGrammar from src.masoniteorm.testing import BaseTestCaseSelectGrammar class TestPostgresGrammar(BaseTestCaseSelectGrammar, unittest.TestCase): grammar = PostgresGrammar def can_compile_select(self): """ self.builder.to_sql() """ return """SELECT * FROM "users\"""" def can_compile_with_columns(self): """ self.builder.select('username', 'password').to_sql() """ return """SELECT "users"."username", "users"."password" FROM "users\"""" def can_compile_with_where(self): """ self.builder.select('username', 'password').where('id', 1).to_sql() """ return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1'""" def can_compile_with_several_where(self): """ self.builder.select('username', 'password').where('id', 1).where('username', 'joe').to_sql() """ return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1' AND "users"."username" = 'joe'""" def can_compile_with_several_where_and_limit(self): """ self.builder.select('username', 'password').where('id', 1).where('username', 'joe').limit(10).to_sql() 
""" return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1' AND "users"."username" = 'joe' LIMIT 10""" def can_compile_with_sum(self): """ self.builder.sum('age').to_sql() """ return """SELECT SUM("users"."age") AS age FROM "users\"""" def can_compile_with_max(self): """ self.builder.max('age').to_sql() """ return """SELECT MAX("users"."age") AS age FROM "users\"""" def can_compile_with_max_and_columns(self): """ self.builder.select('username').max('age').to_sql() """ return """SELECT "users"."username", MAX("users"."age") AS age FROM "users\"""" def can_compile_with_max_and_columns_different_order(self): """ self.builder.max('age').select('username').to_sql() """ return """SELECT "users"."username", MAX("users"."age") AS age FROM "users\"""" def can_compile_with_order_by(self): """ self.builder.select('username').order_by('age', 'desc').to_sql() """ return """SELECT "users"."username" FROM "users" ORDER BY "age" DESC""" def can_compile_with_multiple_order_by(self): """ self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql() """ return ( """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC""" ) def can_compile_with_group_by(self): """ self.builder.select('username').group_by('age').to_sql() """ return """SELECT "users"."username" FROM "users" GROUP BY "users"."age\"""" def can_compile_where_in(self): """ self.builder.select('username').where_in('age', [1,2,3]).to_sql() """ return """SELECT "users"."username" FROM "users" WHERE "users"."age" IN ('1','2','3')""" def can_compile_where_in_empty(self): """ self.builder.where_in('age', []).to_sql() """ return """SELECT * FROM "users" WHERE 0 = 1""" def can_compile_where_not_in(self): """ self.builder.select('username').where_not_in('age', [1,2,3]).to_sql() """ return """SELECT "users"."username" FROM "users" WHERE "users"."age" NOT IN ('1','2','3')""" def can_compile_where_null(self): """ 
self.builder.select('username').where_null('age').to_sql() """ return """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NULL""" def can_compile_where_not_null(self): """ self.builder.select('username').where_not_null('age').to_sql() """ return ( """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL""" ) def can_compile_where_raw(self): """ self.builder.where_raw(""age" = '18'").to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" = '18'""" def can_compile_having_raw(self): """ self.builder.select_raw("COUNT(*) as counts").having_raw("counts > 18").to_sql() """ return """SELECT COUNT(*) as counts FROM "users" HAVING counts > 18""" def can_compile_select_raw(self): """ self.builder.select_raw("COUNT(*)").to_sql() """ return """SELECT COUNT(*) FROM "users\"""" def can_compile_limit_and_offset(self): """ self.builder.limit(10).offset(10).to_sql() """ return """SELECT * FROM "users" LIMIT 10 OFFSET 10""" def can_compile_select_raw_with_select(self): """ self.builder.select('id').select_raw("COUNT(*)").to_sql() """ return """SELECT "users"."id", COUNT(*) FROM "users\"""" def can_compile_count(self): """ self.builder.count().to_sql() """ return """SELECT COUNT(*) AS m_count_reserved FROM "users\"""" def can_compile_count_column(self): """ self.builder.count().to_sql() """ return """SELECT COUNT("users"."money") AS money FROM "users\"""" def can_compile_where_column(self): """ self.builder.where_column('name', 'email').to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" = "users"."email\"""" def can_compile_or_where(self): """ self.builder.where('name', 2).or_where('name', 3).to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" = '2' OR "users"."name" = '3'""" def can_grouped_where(self): """ self.builder.where(lambda query: query.where('age', 2).where('name', 'Joe')).to_sql() """ return """SELECT * FROM "users" WHERE ("users"."age" = '2' AND "users"."name" = 'Joe')""" def 
can_compile_sub_select(self): """ self.builder.where_in('name', QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age') ).to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" IN (SELECT "users"."age" FROM "users")""" def can_compile_sub_select_from_lambda(self): """ self.builder.where_in('name', QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age') ).to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" IN (SELECT "users"."age" FROM "users")""" def can_compile_sub_select_where(self): """ self.builder.where_in('age', QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age').where('age', 2).where('name', 'Joe') ).to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" IN (SELECT "users"."age" FROM "users" WHERE "users"."age" = '2' AND "users"."name" = 'Joe')""" def can_compile_sub_select_value(self): """ self.builder.where('name', self.builder.new().sum('age') ).to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" = (SELECT SUM("users"."age") AS age FROM "users")""" def can_compile_complex_sub_select(self): """ self.builder.where_in('name', (QueryBuilder(GrammarFactory.make(self.grammar), table='users') .select('age').where_in('email', QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('email') )) ).to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" IN (SELECT "users"."age" FROM "users" WHERE "users"."email" IN (SELECT "users"."email" FROM "users"))""" def can_compile_exists(self): """ self.builder.select('age').where_exists( self.builder.new().select('username').where('age', 12) ).to_sql() """ return """SELECT "users"."age" FROM "users" WHERE EXISTS (SELECT "users"."username" FROM "users" WHERE "users"."age" = '12')""" def can_compile_not_exists(self): """ self.builder.select('age').where_not_exists( self.builder.new().select('username').where('age', 12) ).to_sql() """ return """SELECT "users"."age" FROM "users" WHERE NOT 
EXISTS (SELECT "users"."username" FROM "users" WHERE "users"."age" = '12')""" def can_compile_having(self): """ builder.sum('age').group_by('age').having('age').to_sql() """ return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\"""" def can_compile_having_order(self): """ builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql() """ return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" ORDER "users"."age" DESC""" def can_compile_having_with_expression(self): """ builder.sum('age').group_by('age').having('age', 10).to_sql() """ return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" = '10'""" def can_compile_order_by_and_first(self): """ self.builder.order_by('id', 'asc').first() """ return """SELECT * FROM "users" ORDER BY "id" ASC LIMIT 1""" def can_compile_having_with_greater_than_expression(self): """ builder.sum('age').group_by('age').having('age', '>', 10).to_sql() """ return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" > '10'""" def can_compile_join(self): """ builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql() """ return """SELECT * FROM "users" INNER JOIN "contacts" ON "users"."id" = "contacts"."user_id\"""" def can_compile_left_join(self): """ builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql() """ return """SELECT * FROM "users" LEFT JOIN "contacts" ON "users"."id" = "contacts"."user_id\"""" def can_compile_multiple_join(self): """ builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql() """ return """SELECT * FROM "users" INNER JOIN "contacts" ON "users"."id" = "contacts"."user_id" INNER JOIN "posts" ON "comments"."post_id" = "posts"."id\"""" def can_compile_between(self): """ builder.between('age', 18, 21).to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" BETWEEN '18' AND '21'""" def 
can_compile_not_between(self): """ builder.not_between('age', 18, 21).to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" NOT BETWEEN '18' AND '21'""" def test_can_compile_where_raw(self): to_sql = self.builder.where_raw(""" "age" = '18'""").to_sql() self.assertEqual(to_sql, """SELECT * FROM "users" WHERE "age" = '18'""") def test_can_compile_having_raw(self): to_sql = ( self.builder.select_raw("COUNT(*) as counts") .having_raw("counts > 10") .to_sql() ) self.assertEqual( to_sql, """SELECT COUNT(*) as counts FROM "users" HAVING counts > 10""" ) def test_can_compile_having_raw_order(self): to_sql = ( self.builder.select_raw("COUNT(*) as counts") .having_raw("counts > 10") .order_by_raw("counts DESC") .to_sql() ) self.assertEqual( to_sql, """SELECT COUNT(*) as counts FROM "users" HAVING counts > 10 ORDER BY counts DESC""", ) def test_can_compile_where_raw_and_where_with_multiple_bindings(self): query = self.builder.where_raw( """ "age" = '?' AND "is_admin" = '?'""", [18, True] ).where("email", "test@example.com") self.assertEqual( query.to_qmark(), """SELECT * FROM "users" WHERE "age" = '?' AND "is_admin" = '?' 
AND "users"."email" = '?'""", ) self.assertEqual(query._bindings, [18, True, "test@example.com"]) def test_can_compile_select_raw(self): to_sql = self.builder.select_raw("COUNT(*)").to_sql() self.assertEqual(to_sql, """SELECT COUNT(*) FROM "users\"""") def test_can_compile_select_raw_with_select(self): to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql() self.assertEqual(to_sql, """SELECT "users"."id", COUNT(*) FROM "users\"""") def can_compile_first_or_fail(self): """ builder = self.get_builder() builder.where("is_admin", "=", True).first_or_fail() """ return """SELECT * FROM "users" WHERE "users"."is_admin" IS True LIMIT 1""" def where_not_like(self): """ builder = self.get_builder() builder.where("age", "not like", "%name%").to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" NOT ILIKE '%name%'""" def where_like(self): """ builder = self.get_builder() builder.where("age", "like", "%name%").to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" ILIKE '%name%'""" def where_regexp(self): """ builder = self.get_builder() builder.where("age", "regexp", "Joe").to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" REGEXP 'Joe'""" def where_not_regexp(self): """ builder = self.get_builder() builder.where("age", "regexp", "Joe").to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" NOT REGEXP 'Joe'""" def can_compile_join_clause(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on("bgt.fund", "=", "rg.fund") .on_value("bgt.active", "=", "1") .or_on_value("bgt.acct", "=", "1234") ) builder.join(clause).to_sql() """ return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" AND "bgt"."dept" = "rg"."dept" AND "bgt"."acct" = "rg"."acct" AND "bgt"."sub" = "rg"."sub\"""" def can_compile_join_clause_with_value(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on_value("bgt.active", "=", "1") 
.or_on_value("bgt.acct", "=", "1234") ) builder.join(clause).to_sql() """ return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."active" = '1' OR "bgt"."acct" = '1234'""" def can_compile_join_clause_with_null(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on_null("bgt.acct") .or_on_null("bgt.dept") .on_value("rg.abc", 10) ) builder.join(clause).to_sql() """ return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NULL OR "dept" IS NULL AND "rg"."abc" = '10'""" def can_compile_join_clause_with_not_null(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on_not_null("bgt.acct") .or_on_not_null("bgt.dept") .on_value("rg.abc", 10) ) builder.join(clause).to_sql() """ return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NOT NULL OR "dept" IS NOT NULL AND "rg"."abc" = '10'""" def can_compile_join_clause_with_lambda(self): """ builder = self.get_builder() builder.join( "report_groups as rg", lambda clause: ( clause.on("bgt.fund", "=", "rg.fund") .on_null("bgt") ), ).to_sql() """ return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" AND "bgt" IS NULL""" def can_compile_left_join_clause_with_lambda(self): """ builder = self.get_builder() builder.left_join( "report_groups as rg", lambda clause: ( clause.on("bgt.fund", "=", "rg.fund") .or_on_null("bgt") ), ).to_sql() """ return """SELECT * FROM "users" LEFT JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" OR "bgt" IS NULL""" def can_compile_right_join_clause_with_lambda(self): """ builder = self.get_builder() builder.right_join( "report_groups as rg", lambda clause: ( clause.on("bgt.fund", "=", "rg.fund") .or_on_null("bgt") ), ).to_sql() """ return """SELECT * FROM "users" RIGHT JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" OR "bgt" IS NULL""" def shared_lock(self): """ builder = self.get_builder() builder.where("age", 
"not like", "%name%").to_sql() """ return """SELECT * FROM "users" WHERE "users"."votes" >= '100' FOR SHARE""" def update_lock(self): """ builder = self.get_builder() builder.where("age", "not like", "%name%").to_sql() """ return """SELECT * FROM "users" WHERE "users"."votes" >= '100' FOR UPDATE""" def can_user_where_raw_and_where(self): """ builder.where_raw("`age` = '18'").where("name", "=", "James").to_sql() """ return """SELECT * FROM "users" WHERE age = '18' AND "users"."name" = 'James'""" def where_exists_with_lambda(self): return """SELECT * FROM "users" WHERE EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')""" def where_not_exists_with_lambda(self): return """SELECT * FROM "users" WHERE NOT EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')""" def where_date(self): return ( """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'""" ) def or_where_null(self): return """SELECT * FROM "users" WHERE "users"."column1" IS NULL OR "users"."column2" IS NULL""" def select_distinct(self): return """SELECT DISTINCT "users"."group" FROM "users\"""" ================================================ FILE: tests/postgres/grammar/test_update_grammar.py ================================================ import inspect import unittest from src.masoniteorm.connections import PostgresConnection from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import PostgresGrammar from src.masoniteorm.expressions import Raw class BaseTestCaseUpdateGrammar: def setUp(self): self.builder = QueryBuilder( PostgresGrammar, connection_class=PostgresConnection, table="users" ) def test_can_compile_update(self): to_sql = ( self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_multiple_update(self): to_sql = self.builder.update( {"name": "Joe", "email": "user@email.com"}, dry=True 
).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_update_with_multiple_where(self): to_sql = ( self.builder.where("name", "bob") .where("age", 20) .update({"name": "Joe"}, dry=True) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) # def test_can_compile_increment(self): # to_sql = self.builder.increment("age").to_sql() # sql = getattr( # self, inspect.currentframe().f_code.co_name.replace("test_", "") # )() # self.assertEqual(to_sql, sql) # def test_can_compile_decrement(self): # to_sql = self.builder.decrement("age", 20).to_sql() # sql = getattr( # self, inspect.currentframe().f_code.co_name.replace("test_", "") # )() # self.assertEqual(to_sql, sql) def test_raw_expression(self): to_sql = self.builder.update({"name": Raw('"username"')}, dry=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_update_null(self): to_sql = self.builder.update({"name": None}, dry=True).to_sql() print(to_sql) sql = """UPDATE "users" SET "name" = \'None\'""" self.assertEqual(to_sql, sql) class TestPostgresUpdateGrammar(BaseTestCaseUpdateGrammar, unittest.TestCase): grammar = "postgres" def can_compile_update(self): """ builder.where('name', 'bob').update({ 'name': 'Joe' }).to_sql() """ return """UPDATE "users" SET "name" = 'Joe' WHERE "name" = 'bob'""" def raw_expression(self): """ builder.where('name', 'bob').update({ 'name': 'Joe' }).to_sql() """ return """UPDATE "users" SET "name" = "username\"""" def can_compile_multiple_update(self): """ self.builder.update({"name": "Joe", "email": "user@email.com"}, dry=True).to_sql() """ return """UPDATE "users" SET "name" = 'Joe', "email" = 'user@email.com'""" def can_compile_update_with_multiple_where(self): """ builder.where('name', 'bob').where('age', 20).update({ 'name': 'Joe' }).to_sql() 
""" return """UPDATE "users" SET "name" = 'Joe' WHERE "name" = 'bob' AND "age" = '20'""" def can_compile_increment(self): """ builder.increment('age').to_sql() """ return """UPDATE "users" SET "age" = "age" + '1'""" def can_compile_decrement(self): """ builder.decrement('age', 20).to_sql() """ return """UPDATE "users" SET "age" = "age" - '20'""" ================================================ FILE: tests/postgres/relationships/test_postgres_relationships.py ================================================ import os import unittest from src.masoniteorm.models import Model from src.masoniteorm.relationships import belongs_to, has_many if os.getenv("RUN_POSTGRES_DATABASE", False) == "True": class Profile(Model): __table__ = "profiles" __connection__ = "postgres" class Articles(Model): __table__ = "articles" __connection__ = "postgres" @belongs_to("id", "article_id") def logo(self): return Logo class Logo(Model): __table__ = "logos" __connection__ = "postgres" class User(Model): __connection__ = "postgres" _eager_loads = () __casts__ = {"is_admin": "bool"} @belongs_to("id", "user_id") def profile(self): return Profile @has_many("id", "user_id") def articles(self): return Articles def get_is_admin(self): return "You are an admin" class TestRelationships(unittest.TestCase): maxDiff = None def test_relationship_can_be_callable(self): self.assertEqual( User.profile().where("name", "Joe").to_sql(), """SELECT * FROM "profiles" WHERE "profiles"."name" = 'Joe'""", ) def test_can_access_relationship(self): for user in User.where("id", 1).get(): self.assertIsInstance(user.profile, Profile) def test_can_access_has_many_relationship(self): user = User.hydrate(User.where("id", 1).first()) self.assertEqual(len(user.articles), 4) def test_can_access_relationship_multiple_times(self): user = User.hydrate(User.where("id", 1).first()) self.assertEqual(len(user.articles), 4) self.assertEqual(len(user.articles), 4) def test_loading(self): users = User.with_("articles").get() for user in 
users: user def test_casting(self): users = User.with_("articles").where("is_admin", True).get() for user in users: user def test_setting(self): users = User.with_("articles").where("is_admin", True).get() for user in users: user.name = "Joe" user.is_admin = 1 user.save() def test_relationship_has(self): to_sql = User.has("articles").to_sql() self.assertEqual( to_sql, """SELECT * FROM "users" WHERE EXISTS (""" """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\"""" """)""", ) def test_relationship_has_off_builder(self): to_sql = User.where("active", 1).has("articles").to_sql() self.assertEqual( to_sql, """SELECT * FROM "users" WHERE "users"."active" = '1' AND EXISTS (""" """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\"""" """)""", ) def test_relationship_multiple_has(self): to_sql = User.has("articles", "profile").to_sql() self.assertEqual( to_sql, """SELECT * FROM "users" WHERE EXISTS (""" """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\"""" """) AND EXISTS (""" """SELECT * FROM "profiles" WHERE "profiles"."user_id" = "users"."id\"""" """)""", ) count = User.has("articles", "profile").get().count() self.assertEqual(count, 2) def test_nested_has(self): to_sql = User.has("articles.logo").to_sql() self.assertEqual( to_sql, """SELECT * FROM "users" WHERE EXISTS (SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id" AND EXISTS (SELECT * FROM "logos" WHERE "logos"."article_id" = "articles"."id"))""", ) count = User.has("articles.logo").get().count() self.assertEqual(count, 2) def test_relationship_where_has(self): to_sql = User.where_has("articles", lambda q: q.where("status", 1)).to_sql() self.assertEqual( to_sql, """SELECT * FROM "users" WHERE EXISTS (""" """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id" AND "articles"."status" = '1'""" """)""", ) ================================================ FILE: tests/postgres/schema/test_postgres_schema_builder.py 
================================================ import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import PostgresConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import PostgresPlatform class TestPostgresSchemaBuilder(unittest.TestCase): maxDiff = None def setUp(self): self.schema = Schema( connection_class=PostgresConnection, connection="postgres", connection_details=DATABASES, platform=PostgresPlatform, dry=True, ) def test_can_add_columns(self): with self.schema.create("users") as blueprint: blueprint.string("name") blueprint.integer("age") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)' ], ) def test_can_add_tiny_text(self): with self.schema.create("users") as blueprint: blueprint.tiny_text("description") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)'] ) def test_can_add_unsigned_decimal(self): with self.schema.create("users") as blueprint: blueprint.unsigned_decimal("amount", 19, 4) self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), ['CREATE TABLE "users" ("amount" DECIMAL(19, 4) NOT NULL)'], ) def test_can_create_table_if_not_exists(self): with self.schema.create_table_if_not_exists("users") as blueprint: blueprint.string("name") blueprint.integer("age") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE IF NOT EXISTS "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)' ], ) def test_can_add_column_comment(self): with self.schema.create("users") as blueprint: blueprint.string("name").comment("A users username") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users"
("name" VARCHAR(255) NOT NULL)', """COMMENT ON COLUMN "users"."name" is 'A users username'""", ], ) def test_can_add_table_comment(self): with self.schema.create("users") as blueprint: blueprint.string("name") blueprint.table_comment("A users table") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL)', """COMMENT ON TABLE "users" is 'A users table'""", ], ) def test_can_truncate(self): sql = self.schema.truncate("users") self.assertEqual(sql, 'TRUNCATE "users"') def test_can_rename_table(self): sql = self.schema.rename("users", "clients") self.assertEqual(sql, 'ALTER TABLE "users" RENAME TO "clients"') def test_can_drop_table_if_exists(self): """NOTE(review): the extra 'clients' argument does not appear in the expected SQL; confirm what drop_table_if_exists does with a second argument.""" sql = self.schema.drop_table_if_exists("users", "clients") self.assertEqual(sql, 'DROP TABLE IF EXISTS "users"') def test_can_drop_table(self): """NOTE(review): the extra 'clients' argument does not appear in the expected SQL; confirm what drop_table does with a second argument.""" sql = self.schema.drop_table("users", "clients") self.assertEqual(sql, 'DROP TABLE "users"') def test_has_column(self): sql = self.schema.has_column("users", "name") self.assertEqual( sql, "SELECT column_name FROM information_schema.columns WHERE table_name='users' and column_name='name'", ) def test_can_add_columns_with_constaint(self): with self.schema.create("users") as blueprint: blueprint.string("name") blueprint.integer("age") blueprint.unique("name") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL, CONSTRAINT users_name_unique UNIQUE (name))' ], ) def test_can_add_columns_with_long_text(self): with self.schema.create("users") as blueprint: blueprint.long_text("description") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)'] ) def test_can_have_unsigned_columns(self): """unsigned() leaves no UNSIGNED keyword in the Postgres DDL; each integer type is emitted unchanged.""" with self.schema.create("users") as blueprint: blueprint.integer("profile_id").unsigned()
blueprint.big_integer("big_profile_id").unsigned() blueprint.tiny_integer("tiny_profile_id").unsigned() blueprint.small_integer("small_profile_id").unsigned() blueprint.medium_integer("medium_profile_id").unsigned() self.assertEqual( blueprint.to_sql(), [ """CREATE TABLE "users" (""" """"profile_id" INTEGER NOT NULL, """ """"big_profile_id" BIGINT NOT NULL, """ """"tiny_profile_id" TINYINT NOT NULL, """ """"small_profile_id" SMALLINT NOT NULL, """ """"medium_profile_id" MEDIUMINT NOT NULL)""" ], ) def test_can_add_columns_with_foreign_key_constaint(self): with self.schema.create("users") as blueprint: blueprint.string("name").unique() blueprint.integer("age") blueprint.integer("profile_id") blueprint.foreign("profile_id").references("id").on("profiles") self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL, ' '"profile_id" INTEGER NOT NULL, CONSTRAINT users_name_unique UNIQUE (name), ' 'CONSTRAINT users_profile_id_foreign FOREIGN KEY ("profile_id") REFERENCES "profiles"("id"))' ], ) def test_can_advanced_table_creation(self): with self.schema.create("users") as blueprint: blueprint.increments("id") blueprint.string("name") blueprint.string("email").unique() blueprint.string("password") blueprint.integer("admin").default(0) blueprint.string("remember_token").nullable() blueprint.timestamp("verified_at").nullable() blueprint.timestamps() self.assertEqual(len(blueprint.table.added_columns), 9) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ("id" SERIAL UNIQUE NOT NULL, "name" VARCHAR(255) NOT NULL, ' '"email" VARCHAR(255) NOT NULL, "password" VARCHAR(255) NOT NULL, "admin" INTEGER NOT NULL DEFAULT 0, ' '"remember_token" VARCHAR(255) NULL, "verified_at" TIMESTAMP NULL, ' '"created_at" TIMESTAMPTZ NULL DEFAULT CURRENT_TIMESTAMP, "updated_at" TIMESTAMPTZ NULL DEFAULT CURRENT_TIMESTAMP, ' "CONSTRAINT users_id_primary PRIMARY KEY (id),
CONSTRAINT users_email_unique UNIQUE (email))" ], ) def test_can_advanced_table_creation2(self): with self.schema.create("users") as blueprint: blueprint.big_increments("id") blueprint.string("name") blueprint.enum("gender", ["male", "female"]) blueprint.string("duration") blueprint.decimal("money") blueprint.string("url") blueprint.string("option").default("ADMIN") blueprint.jsonb("payload") blueprint.inet("last_address").nullable() blueprint.cidr("route_origin").nullable() blueprint.macaddr("mac_address").nullable() blueprint.datetime("published_at") blueprint.string("thumbnail").nullable() blueprint.integer("premium") blueprint.double("amount").default(0.0) blueprint.integer("author_id").unsigned().nullable() blueprint.foreign("author_id").references("id").on("authors").on_delete( "CASCADE" ) blueprint.text("description") blueprint.timestamps() self.assertEqual(len(blueprint.table.added_columns), 19) self.assertEqual( blueprint.to_sql(), ( [ """CREATE TABLE "users" ("id" BIGSERIAL UNIQUE NOT NULL, "name" VARCHAR(255) NOT NULL, "gender" VARCHAR(255) CHECK(gender IN ('male', 'female')) NOT NULL, """ """"duration" VARCHAR(255) NOT NULL, "money" DECIMAL(17, 6) NOT NULL, "url" VARCHAR(255) NOT NULL, "option" VARCHAR(255) NOT NULL DEFAULT 'ADMIN', "payload" JSONB NOT NULL, "last_address" INET NULL, """ '"route_origin" CIDR NULL, "mac_address" MACADDR NULL, "published_at" TIMESTAMPTZ NOT NULL, "thumbnail" VARCHAR(255) NULL, "premium" INTEGER NOT NULL, "amount" DOUBLE PRECISION NOT NULL DEFAULT 0.0, ' '"author_id" INTEGER NULL, "description" TEXT NOT NULL, "created_at" TIMESTAMPTZ NULL DEFAULT CURRENT_TIMESTAMP, ' '"updated_at" TIMESTAMPTZ NULL DEFAULT CURRENT_TIMESTAMP, ' 'CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_author_id_foreign FOREIGN KEY ("author_id") REFERENCES "authors"("id") ON DELETE CASCADE)' ] ), ) def test_can_add_uuid_column(self): # might not be the right place for this test + other column types # are not tested => just for testing
the PR now with self.schema.create("users") as table: table.uuid("id").default_raw("uuid_generate_v4()") table.primary("id") table.string("name") table.uuid("public_id").nullable() table.uuid("other_id").default_raw("uuid_generate_v4()") self.assertEqual(len(table.table.added_columns), 4) self.assertEqual( table.to_sql(), [ 'CREATE TABLE "users" ("id" UUID NOT NULL DEFAULT uuid_generate_v4(), "name" VARCHAR(255) NOT NULL, "public_id" UUID NULL, "other_id" UUID NOT NULL DEFAULT uuid_generate_v4(), CONSTRAINT users_id_primary PRIMARY KEY (id))' ], ) def test_can_add_columns_with_foreign_key_constraint_name(self): with self.schema.create("users") as blueprint: blueprint.integer("profile_id") blueprint.foreign("profile_id", name="profile_foreign").references("id").on( "profiles" ) self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" (' '"profile_id" INTEGER NOT NULL, ' 'CONSTRAINT profile_foreign FOREIGN KEY ("profile_id") REFERENCES "profiles"("id"))' ], ) def test_can_have_composite_keys(self): with self.schema.create("users") as blueprint: blueprint.string("name").unique() blueprint.integer("age") blueprint.integer("profile_id") blueprint.primary(["name", "age"]) self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ' '("name" VARCHAR(255) NOT NULL, ' '"age" INTEGER NOT NULL, ' '"profile_id" INTEGER NOT NULL, ' "CONSTRAINT users_name_unique UNIQUE (name), " "CONSTRAINT users_name_age_primary PRIMARY KEY (name, age))" ], ) def test_can_have_column_primary_key(self): with self.schema.create("users") as blueprint: blueprint.string("name").primary() blueprint.integer("age") blueprint.integer("profile_id") self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ' '("name" VARCHAR(255) NOT NULL, ' '"age" INTEGER NOT NULL, ' '"profile_id" INTEGER NOT NULL, ' "CONSTRAINT
users_name_primary PRIMARY KEY (name))" ], ) def test_can_add_other_integer_types_column(self): with self.schema.create("integer_types") as table: table.tiny_integer("tiny") table.small_integer("small") table.medium_integer("medium") table.big_integer("big") self.assertEqual(len(table.table.added_columns), 4) self.assertEqual( table.to_sql(), [ 'CREATE TABLE "integer_types" ("tiny" TINYINT NOT NULL, "small" SMALLINT NOT NULL, "medium" MEDIUMINT NOT NULL, "big" BIGINT NOT NULL)' ], ) def test_can_add_binary_column(self): with self.schema.create("binary_storing") as table: table.binary("filecontent") self.assertEqual(len(table.table.added_columns), 1) self.assertEqual( table.to_sql(), ['CREATE TABLE "binary_storing" ("filecontent" BYTEA NOT NULL)'], ) def test_can_have_float_type(self): with self.schema.create("users") as blueprint: blueprint.float("amount") self.assertEqual( blueprint.to_sql(), ["""CREATE TABLE "users" (""" """\"amount" FLOAT(19, 4) NOT NULL)"""], ) def test_can_enable_foreign_keys(self): """Enabling foreign key constraints yields an empty statement on this platform.""" sql = self.schema.enable_foreign_key_constraints() self.assertEqual(sql, "") def test_can_disable_foreign_keys(self): """Disabling foreign key constraints yields an empty statement on this platform.""" sql = self.schema.disable_foreign_key_constraints() self.assertEqual(sql, "") def test_can_truncate_without_foreign_keys(self): """truncate(..., foreign_keys=True) wraps TRUNCATE in DISABLE/ENABLE TRIGGER ALL statements.""" sql = self.schema.truncate("users", foreign_keys=True) self.assertEqual( sql, [ 'ALTER TABLE "users" DISABLE TRIGGER ALL', 'TRUNCATE "users"', 'ALTER TABLE "users" ENABLE TRIGGER ALL', ], ) def test_can_add_enum(self): with self.schema.create("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ("status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL ' 'DEFAULT \'active\')' ], ) ================================================ FILE: tests/postgres/schema/test_postgres_schema_builder_alter.py ================================================ import
unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import PostgresConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import PostgresPlatform from src.masoniteorm.schema.Table import Table class TestPostgresSchemaBuilderAlter(unittest.TestCase): maxDiff = None def setUp(self): self.schema = Schema( connection_class=PostgresConnection, connection="postgres", connection_details=DATABASES, platform=PostgresPlatform, dry=True, ).on("postgres") def test_can_add_columns(self): with self.schema.table("users") as blueprint: blueprint.string("name") blueprint.integer("age") self.assertEqual(len(blueprint.table.added_columns), 2) sql = [ 'ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL, ADD COLUMN "age" INTEGER NOT NULL' ] self.assertEqual(blueprint.to_sql(), sql) def test_can_add_column_comments(self): with self.schema.table("users") as blueprint: blueprint.string("name").comment("A users username") self.assertEqual(len(blueprint.table.added_columns), 1) sql = [ 'ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL', """COMMENT ON COLUMN "users"."name" is 'A users username'""", ] self.assertEqual(blueprint.to_sql(), sql) def test_can_add_table_comment(self): with self.schema.table("users") as blueprint: blueprint.string("name") blueprint.table_comment("A users table") self.assertEqual(len(blueprint.table.added_columns), 1) sql = [ 'ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL', """COMMENT ON TABLE "users" is 'A users table'""", ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_rename(self): with self.schema.table("users") as blueprint: blueprint.rename("post", "comment", "integer") table = Table("users") table.add_column("post", "integer") blueprint.table.from_table = table sql = ['ALTER TABLE "users" RENAME COLUMN "post" TO "comment"'] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_and_rename(self): with self.schema.table("users")
as blueprint: blueprint.string("name") blueprint.rename("post", "comment", "integer") table = Table("users") table.add_column("post", "integer") blueprint.table.from_table = table sql = [ 'ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL', 'ALTER TABLE "users" RENAME COLUMN "post" TO "comment"', ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop(self): with self.schema.table("users") as blueprint: blueprint.drop_column("post") sql = ['ALTER TABLE "users" DROP COLUMN "post"'] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_column_and_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.unsigned_integer("playlist_id").nullable() blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( "cascade" ) sql = [ 'ALTER TABLE "users" ADD COLUMN "playlist_id" INTEGER NULL', 'ALTER TABLE "users" ADD CONSTRAINT users_playlist_id_foreign FOREIGN KEY ("playlist_id") REFERENCES "playlists"("id") ON DELETE CASCADE', ] self.assertEqual(blueprint.to_sql(), sql) def test_can_create_indexes(self): """fulltext() adds no statement here: only the two CREATE INDEX and two UNIQUE statements are expected.""" with self.schema.table("users") as blueprint: blueprint.index("name") blueprint.index(["name", "email"]) blueprint.unique("name") blueprint.unique(["name", "email"]) blueprint.fulltext("description") self.assertEqual(len(blueprint.table.added_columns), 0) self.assertEqual( blueprint.to_sql(), [ 'CREATE INDEX users_name_index ON "users"(name)', 'CREATE INDEX users_name_email_index ON "users"(name,email)', 'ALTER TABLE "users" ADD CONSTRAINT users_name_unique UNIQUE(name)', 'ALTER TABLE "users" ADD CONSTRAINT users_name_email_unique UNIQUE(name,email)', ], ) def test_alter_drop_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.drop_foreign("users_playlist_id_foreign") sql = ['ALTER TABLE "users" DROP CONSTRAINT users_playlist_id_foreign'] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_foreign_key_shortcut(self): with self.schema.table("users") as blueprint:
blueprint.drop_foreign(["playlist_id"]) sql = ['ALTER TABLE "users" DROP CONSTRAINT users_playlist_id_foreign'] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_unique_constraint(self): with self.schema.table("users") as blueprint: blueprint.drop_unique("users_playlist_id_unique") sql = ['ALTER TABLE "users" DROP CONSTRAINT users_playlist_id_unique'] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_index(self): with self.schema.table("users") as blueprint: blueprint.index("playlist_id") sql = ['CREATE INDEX users_playlist_id_index ON "users"(playlist_id)'] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_primary(self): with self.schema.table("users") as blueprint: blueprint.primary("playlist_id") sql = [ 'ALTER TABLE "users" ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)' ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_index(self): with self.schema.table("users") as blueprint: blueprint.drop_index("users_playlist_id_index") sql = ["DROP INDEX users_playlist_id_index"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_index_shortcut(self): with self.schema.table("users") as blueprint: blueprint.drop_index(["playlist_id"]) sql = ["DROP INDEX users_playlist_id_index"] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_primary(self): with self.schema.table("users") as blueprint: blueprint.drop_primary("users_id_primary") sql = ['ALTER TABLE "users" DROP CONSTRAINT users_id_primary'] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_unique_constraint_shortcut(self): with self.schema.table("users") as blueprint: blueprint.drop_unique(["playlist_id"]) sql = ['ALTER TABLE "users" DROP CONSTRAINT users_playlist_id_unique'] self.assertEqual(blueprint.to_sql(), sql) def test_has_table(self): schema_sql = self.schema.has_table("users") sql = "SELECT * from information_schema.tables where table_name='users' AND table_schema = 'public'" self.assertEqual(schema_sql, sql)
def test_drop_table(self): schema_sql = self.schema.has_table("users") sql = "SELECT * from information_schema.tables where table_name='users' AND table_schema = 'public'" self.assertEqual(schema_sql, sql) def test_change(self): with self.schema.table("users") as blueprint: blueprint.integer("age").change() blueprint.string("name") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual(len(blueprint.table.changed_columns), 1) table = Table("users") table.add_column("age", "string") blueprint.table.from_table = table sql = [ 'ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL', 'ALTER TABLE "users" ALTER COLUMN "age" TYPE INTEGER, ALTER COLUMN "age" SET NOT NULL', ] self.assertEqual(blueprint.to_sql(), sql) def test_change_string(self): with self.schema.table("users") as blueprint: blueprint.string("name", 93).change() self.assertEqual(len(blueprint.table.changed_columns), 1) table = Table("users") table.add_column("age", "string") blueprint.table.from_table = table sql = [ 'ALTER TABLE "users" ALTER COLUMN "name" TYPE VARCHAR(93), ALTER COLUMN "name" SET NOT NULL' ] self.assertEqual(blueprint.to_sql(), sql) def test_drop_add_and_change(self): with self.schema.table("users") as blueprint: blueprint.integer("age").default(0).nullable().change() blueprint.string("name") blueprint.string("external_type").default("external") blueprint.drop_column("email") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual(len(blueprint.table.changed_columns), 1) table = Table("users") table.add_column("age", "string") table.add_column("email", "string") blueprint.table.from_table = table sql = [ """ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL, ADD COLUMN "external_type" VARCHAR(255) NOT NULL DEFAULT 'external'""", 'ALTER TABLE "users" DROP COLUMN "email"', 'ALTER TABLE "users" ALTER COLUMN "age" TYPE INTEGER, ALTER COLUMN "age" DROP NOT NULL, ALTER COLUMN "age" SET DEFAULT 0', ] self.assertEqual(blueprint.to_sql(), sql) def 
test_timestamp_alter_add_nullable_column(self): with self.schema.table("users") as blueprint: blueprint.timestamp("due_date").nullable() self.assertEqual(len(blueprint.table.added_columns), 1) table = Table("users") table.add_column("age", "string") blueprint.table.from_table = table sql = ['ALTER TABLE "users" ADD COLUMN "due_date" TIMESTAMP NULL'] self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_on_table_schema_table(self): schema = Schema( connection_class=PostgresConnection, connection="postgres", connection_details=DATABASES, ).on("postgres") with schema.table("table_schema") as blueprint: blueprint.drop_column("name") with schema.table("table_schema") as blueprint: blueprint.string("name") def test_can_add_column_enum(self): with self.schema.table("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active") self.assertEqual(len(blueprint.table.added_columns), 1) sql = [ 'ALTER TABLE "users" ADD COLUMN "status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL DEFAULT \'active\'', ] self.assertEqual(blueprint.to_sql(), sql) def test_can_change_column_enum(self): with self.schema.table("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active").change() self.assertEqual(len(blueprint.table.changed_columns), 1) sql = [ 'ALTER TABLE "users" ALTER COLUMN "status" TYPE VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')), ALTER COLUMN "status" SET NOT NULL, ALTER COLUMN "status" SET DEFAULT active', ] self.assertEqual(blueprint.to_sql(), sql) ================================================ FILE: tests/scopes/test_default_global_scopes.py ================================================ """Test the default global scopes available in ORM.""" import unittest import uuid import pendulum from src.masoniteorm.models import Model from src.masoniteorm.scopes import ( SoftDeletesMixin, TimeStampsScope, TimeStampsMixin, UUIDPrimaryKeyScope, UUIDPrimaryKeyMixin, ) class MockBuilder: 
def __init__(self, model): self._model = model() self._creates = {} self._updates = {} class UserWithUUID(Model, UUIDPrimaryKeyMixin): __dry__ = True class UserWithTimeStamps(Model, TimeStampsMixin): __dry__ = True class UserWithCustomTimeStamps(Model, TimeStampsMixin): __dry__ = True date_updated_at = "updated_ts" date_created_at = "created_ts" class UserSoft(Model, SoftDeletesMixin): __dry__ = True class TestUUIDPrimaryKeyScope(unittest.TestCase): def setUp(self): self.builder = MockBuilder(UserWithUUID) self.scope = UUIDPrimaryKeyScope() # reset User attributes before each test UserWithUUID.__primary_key__ = "id" flags = { "__uuid_version__", "__uuid_namespace__", "__uuid_name__", "__uuid_bytes__", } for flag in flags: if hasattr(UserWithUUID, flag): delattr(UserWithUUID, flag) def test_default_to_uuid4(self): self.scope.set_uuid_create(self.builder) uuid_value = uuid.UUID(self.builder._creates["id"]) self.assertEqual(4, uuid_value.version) def test_can_set_uuid_version(self): # required for uuid 3 and 5 UserWithUUID.__uuid_namespace__ = uuid.NAMESPACE_DNS UserWithUUID.__uuid_name__ = "domain.com" for version in [1, 3, 4, 5]: UserWithUUID.__uuid_version__ = version self.scope.set_uuid_create(self.builder) uuid_value = uuid.UUID(self.builder._creates["id"]) self.assertEqual(version, uuid_value.version) del self.builder._creates["id"] def test_default_to_uuid_str(self): # Generates UUIDs as strings by default self.scope.set_uuid_create(self.builder) self.assertIsInstance(self.builder._creates["id"], str) def test_can_set_uuid_bytes(self): # Generates UUIDs in bytes instead of strings UserWithUUID.__uuid_bytes__ = True self.scope.set_uuid_create(self.builder) self.assertIsInstance(self.builder._creates["id"], bytes) def test_works_with_custom_pk_column(self): UserWithUUID.__primary_key__ = "ref" self.scope.set_uuid_create(self.builder) self.assertIn("ref", self.builder._creates) class TestSoftDeletesScope(unittest.TestCase): def 
test_soft_deletes_changes_delete_to_update(self): UserSoft.__timestamps__ = False user = UserSoft.hydrate({"id": 1}) sql = user.delete(query=True).to_sql() self.assertTrue(sql.startswith("UPDATE")) class TestTimeStampsScope(unittest.TestCase): def setUp(self): self.builder = MockBuilder(UserWithTimeStamps) self.scope = TimeStampsScope() try: del UserWithTimeStamps.__timestamps__ except: pass def test_updated_and_created_dates_are_set_when_create(self): self.scope.set_timestamp_create(self.builder) self.assertIn("created_at", self.builder._creates) self.assertIn("updated_at", self.builder._creates) created_at = pendulum.parse(self.builder._creates["created_at"]) updated_at = pendulum.parse(self.builder._creates["updated_at"]) self.assertIsInstance(created_at, pendulum.DateTime) self.assertIsInstance(updated_at, pendulum.DateTime) def test_timestamps_can_be_disabled(self): UserWithTimeStamps.__timestamps__ = False self.scope.set_timestamp_create(self.builder) self.assertNotIn("created_at", self.builder._creates) self.assertNotIn("updated_at", self.builder._creates) def test_uses_custom_timestamp_columns_on_create(self): self.builder = MockBuilder(UserWithCustomTimeStamps) self.scope.set_timestamp_create(self.builder) created_column = UserWithCustomTimeStamps.date_created_at updated_column = UserWithCustomTimeStamps.date_updated_at self.assertNotIn("created_at", self.builder._creates) self.assertNotIn("updated_at", self.builder._creates) self.assertIn(created_column, self.builder._creates) self.assertIn(updated_column, self.builder._creates) self.assertIsInstance( pendulum.parse(self.builder._creates[created_column]), pendulum.DateTime ) self.assertIsInstance( pendulum.parse(self.builder._creates[updated_column]), pendulum.DateTime ) def test_uses_custom_updated_column_on_update(self): user = UserWithCustomTimeStamps.hydrate({"id": 1}) sql = user.update({"id": 2}).to_sql() self.assertTrue(UserWithCustomTimeStamps.date_updated_at in sql) 
================================================ FILE: tests/seeds/test_seeds.py ================================================ import unittest from databases.seeds.user_table_seeder import UserTableSeeder from src.masoniteorm.seeds import Seeder class TestSeeds(unittest.TestCase): def test_can_run_seeds(self): seeder = Seeder(dry=True) seeder.call(UserTableSeeder) self.assertEqual(seeder.ran_seeds, [UserTableSeeder]) ================================================ FILE: tests/sqlite/builder/test_sqlite_builder_insert.py ================================================ import inspect import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar class User(Model): __connection__ = "dev" __timestamps__ = False pass class BaseTestQueryRelationships(unittest.TestCase): maxDiff = None def get_builder(self, table="users"): connection = ConnectionFactory().make("sqlite") return QueryBuilder( grammar=SQLiteGrammar, connection_class=connection, connection="dev", table=table, connection_details=DATABASES, ).on("dev") def test_insert(self): builder = self.get_builder() result = builder.create( {"name": "Joe", "email": "joe@masoniteproject.com", "password": "secret"} ) self.assertIsInstance(result["id"], int) ================================================ FILE: tests/sqlite/builder/test_sqlite_builder_pagination.py ================================================ import inspect import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar class User(Model): __connection__ = "dev" class BaseTestQueryRelationships(unittest.TestCase): maxDiff = None def 
get_builder(self, table="users", model=User()): connection = ConnectionFactory().make("sqlite") return QueryBuilder( grammar=SQLiteGrammar, connection_class=connection, connection="dev", table=table, model=model, connection_details=DATABASES, ).on("dev") def test_pagination(self): builder = self.get_builder() paginator = builder.table("users").paginate(1) self.assertTrue(paginator.count) self.assertTrue(paginator.serialize()["data"]) self.assertTrue(paginator.serialize()["meta"]) self.assertTrue(paginator.result) self.assertTrue(paginator.current_page) self.assertTrue(paginator.per_page) self.assertTrue(paginator.count) self.assertTrue(paginator.last_page) self.assertTrue(paginator.next_page) self.assertEqual(paginator.previous_page, None) self.assertTrue(paginator.total) for user in paginator: self.assertIsInstance(user, User) paginator = builder.table("users").simple_paginate(10, 1) self.assertIsInstance(paginator.to_json(), str) self.assertTrue(paginator.count) self.assertTrue(paginator.serialize()["data"]) self.assertTrue(paginator.serialize()["meta"]) self.assertTrue(paginator.result) self.assertTrue(paginator.current_page) self.assertTrue(paginator.per_page) self.assertTrue(paginator.count) self.assertEqual(paginator.next_page, None) self.assertEqual(paginator.previous_page, None) for user in paginator: self.assertIsInstance(user, User) self.assertIsInstance(paginator.to_json(), str) ================================================ FILE: tests/sqlite/builder/test_sqlite_query_builder.py ================================================ import inspect import unittest from src.masoniteorm.connections import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar from tests.utils import MockConnectionFactory from src.masoniteorm.exceptions import ModelNotFound, HTTP404 class UserMock(Model): __connection__ = "dev" __table__ = "users" class 
BaseTestQueryBuilder: maxDiff = None def get_builder(self, table="users"): connection = MockConnectionFactory().make("sqlite") return QueryBuilder( self.grammar, connection_class=connection, connection="mysql", table=table, dry=True, ) def test_sum(self): builder = self.get_builder() builder.sum("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_sum_aggregate(self): builder = self.get_builder() builder.aggregate("SUM", "age") sql = getattr(self, "sum")() self.assertEqual(builder.to_sql(), sql) def test_sum_aggregate_with_alias(self): builder = self.get_builder() builder.aggregate("SUM", "age", alias="number") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_sum_aggregate_with_alias_in_column_name(self): builder = self.get_builder() builder.sum("age as number") sql = getattr(self, "sum_aggregate_with_alias")() self.assertEqual(builder.to_sql(), sql) def test_where_like(self): builder = self.get_builder() builder.where("age", "like", "%name%") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_not_like(self): builder = self.get_builder() builder.where("age", "not like", "%name%") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_max(self): builder = self.get_builder() builder.max("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_min(self): builder = self.get_builder() builder.min("age") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_avg(self): builder = self.get_builder() builder.avg("age") sql = getattr( self, 
inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_all(self): builder = self.get_builder() builder.all() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_get(self): builder = self.get_builder() builder.get() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_first(self): builder = self.get_builder().first(query=True) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_last(self): UserMock.order_by("id", "DESC").first().id == UserMock.last("id").id def test_last_with_default_primary_key(self): UserMock.order_by("id", "DESC").first().id == UserMock.last().id def test_first_or_fail_exception(self): with self.assertRaises(ModelNotFound): user = self.get_builder().where("name", "=", "Marlysson").first_or_fail() def test_find_or_fail_exception(self): with self.assertRaises(ModelNotFound): user = UserMock.find_or_fail(1000) def test_find_or_404_exception(self): with self.assertRaises(HTTP404): user = UserMock.find_or_404(10000) def test_select(self): builder = self.get_builder() builder.select("name", "email") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_select_multiple(self): builder = self.get_builder() builder.select("name, email") sql = getattr(self, "select")() self.assertEqual(builder.to_sql(), sql) def test_add_select(self): builder = self.get_builder() sql = ( builder.select("name") .add_select("phone_count", lambda q: q.count("*").table("phones")) .add_select("salary", lambda q: q.count("*").table("salary")) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def 
test_add_select_no_table(self): builder = self.get_builder(table=None) sql = ( builder.add_select( "other_test", lambda q: q.max("updated_at").table("different_table") ) .add_select( "some_alias", lambda q: q.max("updated_at").table("another_table") ) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_add_select_with_raw(self): builder = self.get_builder(table=None) sql = ( builder.select_raw("max(updated_at) as test") .from_("some_table") .add_select( "other_test", lambda query: ( query.max("updated_at") .from_("different_table") .where("some_id", "=", "3") ), ) ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_select_raw(self): builder = self.get_builder() builder.select_raw("count(email) as email_count") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_create(self): builder = self.get_builder() builder.create( {"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_delete(self): builder = self.get_builder() builder.delete("name", "Joe", query=True) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where(self): builder = self.get_builder() builder.where("name", "Joe") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_dictionary(self): builder = self.get_builder() builder.where({"name": "Joe"}) sql = getattr(self, "where")() self.assertEqual(builder.to_sql(), sql) def test_where_exists(self): builder = self.get_builder() builder.where_exists("name") sql = getattr( self, 
inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_limit(self): builder = self.get_builder() builder.limit(5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_offset(self): builder = self.get_builder() builder.offset(5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_offset_with_limit(self): builder = self.get_builder() builder.limit(2).offset(5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_join(self): builder = self.get_builder() builder.join("profiles", "users.id", "=", "profiles.user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_left_join(self): builder = self.get_builder() builder.left_join("profiles", "users.id", "=", "profiles.user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_right_join(self): builder = self.get_builder() builder.right_join("profiles", "users.id", "=", "profiles.user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_update(self): builder = self.get_builder().update( {"name": "Joe", "email": "joe@yopmail.com"}, dry=True ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_increment(self): builder = self.get_builder() builder_sql = builder.increment("age", 1) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder_sql, sql) def test_decrement(self): builder = self.get_builder() builder_sql = builder.decrement("age", 1) sql = 
getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder_sql, sql) def test_count(self): builder = self.get_builder() builder.count("id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_order_by_asc(self): builder = self.get_builder() builder.order_by("email", "asc") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_order_by_multiple(self): builder = self.get_builder() builder.order_by("email, name, active") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_order_by_reference_direction(self): builder = self.get_builder() builder.order_by("email, name desc") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_order_by_raw(self): builder = self.get_builder() builder.order_by_raw("col asc") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_order_by_desc(self): builder = self.get_builder() builder.order_by("email", "desc") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_column(self): builder = self.get_builder() builder.where_column("name", "username") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_not_in(self): builder = self.get_builder() builder.where_not_in("id", [1, 2, 3]) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_between(self): builder = self.get_builder() builder.between("id", 2, 5) sql = getattr( self, 
inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_between_persisted(self): builder = QueryBuilder().table("users").on("dev") users = builder.between("age", 1, 2).count() self.assertEqual(users, 2) def test_not_between(self): builder = self.get_builder() builder.not_between("id", 2, 5) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_not_between_persisted(self): builder = QueryBuilder().table("users").on("dev") users = builder.where_not_null("id").not_between("age", 1, 2).count() self.assertEqual(users, 0) def test_where_in(self): builder = self.get_builder() builder.where_in("id", [1, 2, 3]) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_null(self): builder = self.get_builder() builder.where_null("name") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_not_null(self): builder = self.get_builder() builder.where_not_null("name") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_having(self): builder = self.get_builder(table="payments") builder.select("user_id").avg("salary").group_by("user_id").having( "salary", ">=", "1000" ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_group_by(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by("user_id") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_group_by_raw(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by_raw("count(*)") sql = getattr( self, 
inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_group_by_multiple(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by("user_id").group_by("salary") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_group_by_multiple_in_same_group_by(self): builder = self.get_builder(table="payments") builder.select("user_id").min("salary").group_by("user_id, salary") sql = getattr(self, "group_by_multiple")() self.assertEqual(builder.to_sql(), sql) def test_builder_alone(self): self.assertTrue( QueryBuilder( connection_details={ "default": "sqlite", "sqlite": { "driver": "sqlite", "database": "orm.sqlite3", "prefix": "", }, } ).table("users") ) def test_where_lt(self): builder = self.get_builder() builder.where("age", "<", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_lte(self): builder = self.get_builder() builder.where("age", "<=", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_gt(self): builder = self.get_builder() builder.where("age", ">", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_gte(self): builder = self.get_builder() builder.where("age", ">=", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_where_ne(self): builder = self.get_builder() builder.where("age", "!=", "20") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_or_where(self): builder = self.get_builder() builder.where("age", "20").or_where("age", "<", 20) sql = 
getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_can_call_with_schema(self): builder = self.get_builder() sql = ( builder.table("information_schema.columns") .select("table_name") .where("table_name", "users") .to_sql() ) self.assertEqual( sql, """SELECT "information_schema"."columns"."table_name" FROM "information_schema"."columns" WHERE "information_schema"."columns"."table_name" = 'users'""", ) def test_can_call_with_raw(self): builder = self.get_builder() sql = builder.on("dev").statement("select * from users") self.assertTrue(sql) def test_truncate(self): builder = self.get_builder() sql = builder.truncate(dry=True) sql_ref = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(sql, sql_ref) def test_truncate_without_foreign_keys(self): builder = self.get_builder() sql = builder.truncate(foreign_keys=True) sql_ref = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(sql, sql_ref) class SQLiteQueryBuilderTest(BaseTestQueryBuilder, unittest.TestCase): grammar = SQLiteGrammar def sum(self): """ builder = self.get_builder() builder.sum('age') """ return """SELECT SUM("users"."age") AS age FROM "users\"""" def sum_aggregate_with_alias(self): """ builder = self.get_builder() builder.sum('age') """ return """SELECT SUM("users"."age") AS number FROM "users\"""" def max(self): """ builder = self.get_builder() builder.max('age') """ return """SELECT MAX("users"."age") AS age FROM "users\"""" def min(self): """ builder = self.get_builder() builder.min('age') """ return """SELECT MIN("users"."age") AS age FROM "users\"""" def avg(self): """ builder = self.get_builder() builder.avg('age') """ return """SELECT AVG("users"."age") AS age FROM "users\"""" def first(self): """ builder = self.get_builder() builder.first() """ return """SELECT * FROM "users" LIMIT 1""" def all(self): """ builder = self.get_builder() 
builder.all() """ return """SELECT * FROM "users\"""" def get(self): """ builder = self.get_builder() builder.get() """ return """SELECT * FROM "users\"""" def select(self): """ builder = self.get_builder() builder.select('name', 'email') """ return """SELECT "users"."name", "users"."email" FROM "users\"""" def select_multiple(self): """ builder = self.get_builder() builder.select('name', 'email') """ return """SELECT "users"."name", "users"."email" FROM "users\"""" def add_select(self): """ builder = self.get_builder() builder.select('name', 'email') """ return """SELECT "users"."name", (SELECT COUNT(*) AS m_count_reserved FROM "phones") AS phone_count, (SELECT COUNT(*) AS m_count_reserved FROM "salary") AS salary FROM "users\"""" def add_select_no_table(self): """ builder = self.get_builder() builder.select('name', 'email') """ return ( "SELECT " '(SELECT MAX("different_table"."updated_at") AS updated_at FROM "different_table") AS other_test, ' '(SELECT MAX("another_table"."updated_at") AS updated_at FROM "another_table") AS some_alias' ) def add_select_with_raw(self): """ builder .select_raw("max(updated_at) as test").from_("some_table") .add_select( "other_test", lambda query: ( query.max("updated_at") .from_("different_table") .where( "some_id", "=", "3" ) ), ) """ return ( "SELECT max(updated_at) as test, " '(SELECT MAX("different_table"."updated_at") AS updated_at ' 'FROM "different_table" ' 'WHERE "different_table"."some_id" = \'3\') AS other_test ' 'FROM "some_table"' ) def select_raw(self): """ builder = self.get_builder() builder.select_raw('count(email) as email_count') """ return """SELECT count(email) as email_count FROM "users\"""" def create(self): """ builder = get_builder() builder.create({"name": "Corentin All", 'email': 'corentin@yopmail.com'}) """ return """INSERT INTO "users" ("name", "email") VALUES ('Corentin All', 'corentin@yopmail.com')""" def delete(self): """ builder = get_builder() builder.delete("name', 'Joe') """ return """DELETE FROM 
"users" WHERE "name" = 'Joe'""" def where(self): """ builder = get_builder() builder.where('name', 'Joe') """ return """SELECT * FROM "users" WHERE "users"."name" = 'Joe'""" def where_exists(self): """ builder = get_builder() builder.where_exists('name') """ return """SELECT * FROM "users" WHERE EXISTS 'name'""" def limit(self): """ builder = get_builder() builder.limit(5) """ return """SELECT * FROM "users" LIMIT 5""" def offset(self): """ builder = get_builder() builder.offset(5) """ return """SELECT * FROM "users" LIMIT -1 OFFSET 5""" def offset_with_limit(self): """ builder = get_builder() builder.limit(2).offset(5) """ return """SELECT * FROM "users" LIMIT 2 OFFSET 5""" def join(self): """ builder.join("profiles", "users.id", "=", "profiles.user_id") """ return """SELECT * FROM "users" INNER JOIN "profiles" ON "users"."id" = "profiles"."user_id\"""" def left_join(self): """ builder.left_join("profiles", "users.id", "=", "profiles.user_id") """ return """SELECT * FROM "users" LEFT JOIN "profiles" ON "users"."id" = "profiles"."user_id\"""" def right_join(self): """ builder.right_join("profiles", "users.id", "=", "profiles.user_id") """ return """SELECT * FROM "users" LEFT JOIN "profiles" ON "users"."id" = "profiles"."user_id\"""" def update(self): """ builder.update({"name": "Joe", "email": "joe@yopmail.com"}) """ return """UPDATE "users" SET "name" = 'Joe', "email" = 'joe@yopmail.com'""" def increment(self): """ builder.increment('age', 1) """ return """UPDATE "users" SET "age" = "age" + '1'""" def decrement(self): """ builder.decrement('age', 1) """ return """UPDATE "users" SET "age" = "age" - '1'""" def count(self): """ builder.count(id) """ return """SELECT COUNT("users"."id") AS id FROM "users\"""" def order_by_asc(self): """ builder.order_by('email', 'asc') """ return """SELECT * FROM "users" ORDER BY "email" ASC""" def order_by_multiple(self): """ builder.order_by('email', 'asc') """ return ( """SELECT * FROM "users" ORDER BY "email" ASC, "name" ASC, 
"active" ASC""" ) def order_by_raw(self): """ builder.order_by('email', 'asc') """ return """SELECT * FROM "users" ORDER BY col asc""" def order_by_reference_direction(self): """ builder.order_by('email', 'asc') """ return """SELECT * FROM "users" ORDER BY "email" ASC, "name" DESC""" def order_by_desc(self): """ builder.order_by('email', 'des') """ return """SELECT * FROM "users" ORDER BY "email" DESC""" def where_column(self): """ builder.where_column('name', 'username') """ return """SELECT * FROM "users" WHERE "users"."name" = "users"."username\"""" def where_null(self): """ builder.where_null('name') """ return """SELECT * FROM "users" WHERE "users"."name" IS NULL""" def where_not_null(self): """ builder.where_null('name') """ return """SELECT * FROM "users" WHERE "users"."name" IS NOT NULL""" def where_not_in(self): """ builder.where_not_in('id', [1, 2, 3]) """ return """SELECT * FROM "users" WHERE "users"."id" NOT IN ('1','2','3')""" def where_in(self): """ builder.where_in('id', [1, 2, 3]) """ return """SELECT * FROM "users" WHERE "users"."id" IN ('1','2','3')""" def between(self): """ builder.between('id', 2, 5) """ return """SELECT * FROM "users" WHERE "users"."id" BETWEEN '2' AND '5'""" def not_between(self): """ builder.not_between('id', 2, 5) """ return """SELECT * FROM "users" WHERE "users"."id" NOT BETWEEN '2' AND '5'""" def having(self): """ builder.select('user_id').avg('salary').group_by('user_id').having('salary', '>=', '1000') """ return """SELECT "payments"."user_id", AVG("payments"."salary") AS salary FROM "payments" GROUP BY "payments"."user_id" HAVING "payments"."salary" >= '1000'""" def group_by(self): """ builder.select('user_id').min('salary').group_by('user_id') """ return """SELECT "payments"."user_id", MIN("payments"."salary") AS salary FROM "payments" GROUP BY "payments"."user_id\"""" def group_by_multiple(self): """ builder.select('user_id').min('salary').group_by('user_id') """ return """SELECT "payments"."user_id", 
MIN("payments"."salary") AS salary FROM "payments" GROUP BY "payments"."user_id", "payments"."salary\"""" def group_by_raw(self): """ builder.select('user_id').min('salary').group_by('user_id') """ return """SELECT "payments"."user_id", MIN("payments"."salary") AS salary FROM "payments" GROUP BY count(*)""" def where_lt(self): """ builder = self.get_builder() builder.where('age', '<', '20') """ return """SELECT * FROM "users" WHERE "users"."age" < '20'""" def where_lte(self): """ builder = self.get_builder() builder.where('age', '<=', '20') """ return """SELECT * FROM "users" WHERE "users"."age" <= '20'""" def where_gt(self): """ builder = self.get_builder() builder.where('age', '>', '20') """ return """SELECT * FROM "users" WHERE "users"."age" > '20'""" def where_gte(self): """ builder = self.get_builder() builder.where('age', '>=', '20') """ return """SELECT * FROM "users" WHERE "users"."age" >= '20'""" def where_ne(self): """ builder = self.get_builder() builder.where('age', '!=', '20') """ return """SELECT * FROM "users" WHERE "users"."age" != '20'""" def or_where(self): """ builder = self.get_builder() builder.where('age', '20').or_where('age','<', 20) """ return """SELECT * FROM "users" WHERE "users"."age" = '20' OR "users"."age" < '20'""" def where_like(self): """ builder = self.get_builder() builder.where("age", "like", "%name%") """ return """SELECT * FROM "users" WHERE "users"."age" LIKE '%name%'""" def where_not_like(self): """ builder = self.get_builder() builder.where("age", "like", "%name%") """ return """SELECT * FROM "users" WHERE "users"."age" NOT LIKE '%name%'""" def test_when(self): builder = self.get_builder() sql = builder.when(19 > 18, lambda q: q.where("age_restricted", 1)).to_sql() return self.assertEqual( sql, """SELECT * FROM "users" WHERE "users"."age_restricted" = '1'""" ) builder = self.get_builder() sql = builder.when(17 > 18, lambda q: q.where("age_restricted", 1)).to_sql() return self.assertEqual(sql, """SELECT * FROM "users\"""") 
def truncate(self): """ builder = self.get_builder() builder.truncate() """ return """DELETE FROM "users\"""" def truncate_without_foreign_keys(self): """ builder = self.get_builder() builder.truncate(foreign_keys=True) """ return [ "PRAGMA foreign_keys = OFF", 'DELETE FROM "users"', "PRAGMA foreign_keys = ON", ] def test_latest(self): builder = self.get_builder() builder.latest("email") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def test_oldest(self): builder = self.get_builder() builder.oldest("email") sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(builder.to_sql(), sql) def oldest(self): """ builder.order_by('email', 'asc') """ return """SELECT * FROM "users" ORDER BY "email" ASC""" def latest(self): """ builder.order_by('email', 'des') """ return """SELECT * FROM "users" ORDER BY "email" DESC""" ================================================ FILE: tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py ================================================ import inspect import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar from src.masoniteorm.relationships import belongs_to, has_many class Logo(Model): __connection__ = "dev" class Article(Model): __connection__ = "dev" @belongs_to("id", "article_id") def logo(self): return Logo @belongs_to("user_id", "id") def user(self): return User class Profile(Model): __connection__ = "dev" class User(Model): __connection__ = "dev" __with__ = ["articles.logo"] @has_many("id", "user_id") def articles(self): return Article @belongs_to("id", "user_id") def profile(self): return Profile class EagerUser(Model): __connection__ = "dev" __with__ = ("profile",) __table__ = 
"users" @belongs_to("id", "user_id") def profile(self): return Profile class BaseTestQueryRelationships(unittest.TestCase): maxDiff = None def get_builder(self, table="users", model=User()): connection = ConnectionFactory().make("sqlite") return QueryBuilder( grammar=SQLiteGrammar, connection="dev", table=table, model=model, connection_details=DATABASES, ).on("dev") def test_with(self): builder = self.get_builder() result = builder.with_("profile").get() for model in result: if model.profile: self.assertEqual(model.profile.title, "title") def test_with_from_model(self): builder = EagerUser result = builder.get() for model in result: if model.profile: self.assertEqual(model.profile.title, "title") def test_with_first(self): builder = self.get_builder() result = builder.with_("profile").where("id", 1).first() self.assertEqual(result.profile.title, "title") def test_with_where_no_relation(self): builder = self.get_builder() result = builder.with_("profile").where("id", 5).first() result.serialize() def test_with_multiple_per_same_relation(self): builder = self.get_builder() result = User.with_("articles", "articles.logo").where("id", 1).first() self.assertTrue(result.serialize()["articles"]) self.assertTrue(result.serialize()["articles"][0]["logo"]) ================================================ FILE: tests/sqlite/builder/test_sqlite_query_builder_relationships.py ================================================ import unittest from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar from src.masoniteorm.relationships import belongs_to from tests.utils import MockConnectionFactory from dotenv import load_dotenv load_dotenv(".env") class Logo(Model): __connection__ = "dev" class Article(Model): __connection__ = "dev" @belongs_to("id", "article_id") def logo(self): return Logo class Profile(Model): __connection__ = "dev" class User(Model): __connection__ = "dev" @belongs_to("id", 
"user_id") def articles(self): return Article @belongs_to("id", "user_id") def profile(self): return Profile class BaseTestQueryRelationships(unittest.TestCase): maxDiff = None def get_builder(self, table="users"): connection = MockConnectionFactory().make("sqlite") return QueryBuilder( grammar=SQLiteGrammar, connection_class=connection, table=table, model=User() ) def test_has(self): builder = self.get_builder() sql = builder.has("articles").to_sql() self.assertEqual( sql, """SELECT * FROM "users" WHERE EXISTS (""" """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\"""" """)""", ) def test_doesnt_have(self): builder = self.get_builder() sql = builder.doesnt_have("articles").to_sql() self.assertEqual( sql, """SELECT * FROM "users" WHERE NOT EXISTS (""" """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\"""" """)""", ) def test_where_doesnt_have(self): builder = self.get_builder() sql = builder.where_doesnt_have( "articles", lambda q: q.where("title", "Eggs and Ham") ).to_sql() self.assertEqual( sql, """SELECT * FROM "users" WHERE NOT EXISTS (""" """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id" AND "articles"."title" = 'Eggs and Ham'""" """)""", ) def test_where_has_query(self): builder = self.get_builder() sql = builder.where_has("articles", lambda q: q.where("active", 1)).to_sql() self.assertEqual( sql, """SELECT * FROM "users" WHERE EXISTS (""" """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id" AND "articles"."active" = '1'""" """)""", ) def test_relationship_multiple_has(self): to_sql = User.has("articles", "profile").to_sql() self.assertEqual( to_sql, """SELECT * FROM "users" WHERE EXISTS (""" """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\"""" """) AND EXISTS (""" """SELECT * FROM "profiles" WHERE "profiles"."user_id" = "users"."id\"""" """)""", ) def test_relationship_multiple_has_calls(self): to_sql = User.has("articles").has("profile").to_sql() 
self.assertEqual( to_sql, """SELECT * FROM "users" WHERE EXISTS (""" """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\"""" """) AND EXISTS (""" """SELECT * FROM "profiles" WHERE "profiles"."user_id" = "users"."id\"""" """)""", ) def test_nested_has(self): to_sql = User.has("articles.logo").to_sql() self.assertEqual( to_sql, """SELECT * FROM "users" WHERE EXISTS (SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id" AND EXISTS (SELECT * FROM "logos" WHERE "logos"."article_id" = "articles"."id"))""", ) def test_joins(self): to_sql = self.get_builder().joins("articles").to_sql() self.assertEqual( to_sql, """SELECT * FROM "users" INNER JOIN "articles" ON "users"."id" = "articles"."user_id\"""", ) ================================================ FILE: tests/sqlite/builder/test_sqlite_transaction.py ================================================ import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar from tests.integrations.config.database import DB from src.masoniteorm.collection import Collection class User(Model): __connection__ = "dev" __timestamps__ = False class BaseTestQueryRelationships(unittest.TestCase): maxDiff = None def get_builder(self, table="users"): connection = ConnectionFactory().make("sqlite") return QueryBuilder( grammar=SQLiteGrammar, connection="dev", table=table, model=User(), connection_details=DATABASES, ).on("dev") def test_transaction(self): builder = self.get_builder() builder.begin() builder.create({"name": "phillip3", "email": "phillip3"}) user = builder.where("name", "phillip3").first() self.assertEqual(user["name"], "phillip3") builder.rollback() user = builder.where("name", "phillip3").first() self.assertEqual(user, None) def test_transaction_globally(self): connection = 
DB.begin_transaction("dev") self.assertEqual(connection, self.get_builder().new_connection()) DB.commit("dev") DB.begin_transaction("dev") DB.rollback("dev") def test_chunking(self): for users in self.get_builder().chunk(10): self.assertIsInstance(users, Collection) ================================================ FILE: tests/sqlite/grammar/test_sqlite_delete_grammar.py ================================================ import inspect import unittest from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar class BaseDeleteGrammarTest: def setUp(self): self.builder = QueryBuilder(SQLiteGrammar, table="users") def test_can_compile_delete(self): to_sql = self.builder.delete("id", 1, query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_delete_in(self): to_sql = self.builder.delete("id", [1, 2, 3], query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_delete_with_where(self): to_sql = ( self.builder.where("age", 20) .where("profile", 1) .delete(query=True) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) class TestSqliteDeleteGrammar(BaseDeleteGrammarTest, unittest.TestCase): grammar = "sqlite" def can_compile_delete(self): """ ( self.builder .delete('id', 1) .to_sql() ) """ return """DELETE FROM "users" WHERE "id" = '1'""" def can_compile_delete_in(self): """ ( self.builder .delete('id', 1) .to_sql() ) """ return """DELETE FROM "users" WHERE "id" IN ('1','2','3')""" def can_compile_delete_with_where(self): """ ( self.builder .where('age', 20) .where('profile', 1) .delete() .to_sql() ) """ return """DELETE FROM "users" WHERE "age" = '20' AND "profile" = '1'""" ================================================ FILE: 
tests/sqlite/grammar/test_sqlite_insert_grammar.py ================================================ import inspect import unittest from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar class BaseInsertGrammarTest: """ Shared INSERT compilation tests. Each test_* method compares the compiled SQL with the value returned by the subclass method of the same name minus the "test_" prefix. """ def setUp(self): self.builder = QueryBuilder(SQLiteGrammar, table="users") def test_can_compile_insert(self): to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_insert_with_keywords(self): to_sql = self.builder.create(name="Joe", query=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create(self): to_sql = self.builder.bulk_create( # These keys are intentionally out of order to show column to value alignment works [ {"name": "Joe", "age": 5}, {"age": 35, "name": "Bill"}, {"name": "John", "age": 10}, ], query=True, ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_qmark(self): to_sql = self.builder.bulk_create( [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True ).to_qmark() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_bulk_create_multiple(self): to_sql = self.builder.bulk_create( [ {"name": "Joe", "active": "1"}, {"name": "Bill", "active": "1"}, {"name": "John", "active": "1"}, ], query=True, ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) class TestSqliteUpdateGrammar(BaseInsertGrammarTest, unittest.TestCase): """ Expected SQLite INSERT statements. NOTE(review): despite its name, this class checks INSERT compilation, not UPDATE. """ grammar = "sqlite" def can_compile_insert(self): """ self.builder.create({ 'name': 'Joe' }, query=True).to_sql() """ return """INSERT INTO "users" ("name") VALUES ('Joe')"""
def can_compile_insert_with_keywords(self): """ self.builder.create(name="Joe", query=True).to_sql() """ return """INSERT INTO "users" ("name") VALUES ('Joe')""" def can_compile_bulk_create(self): """ self.builder.bulk_create([{"name": "Joe", "age": 5}, {"age": 35, "name": "Bill"}, {"name": "John", "age": 10}], query=True).to_sql() """ return """INSERT INTO "users" ("age", "name") VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John')""" def can_compile_bulk_create_multiple(self): """ self.builder.bulk_create([{"name": "Joe", "active": "1"}, {"name": "Bill", "active": "1"}, {"name": "John", "active": "1"}], query=True).to_sql() """ return """INSERT INTO "users" ("active", "name") VALUES ('1', 'Joe'), ('1', 'Bill'), ('1', 'John')""" def can_compile_bulk_create_qmark(self): """ self.builder.bulk_create([{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True).to_qmark() """ return """INSERT INTO "users" ("name") VALUES ('?'), ('?'), ('?')""" ================================================ FILE: tests/sqlite/grammar/test_sqlite_select_grammar.py ================================================ import inspect import unittest from src.masoniteorm.query.grammars import SQLiteGrammar from src.masoniteorm.testing import BaseTestCaseSelectGrammar class TestSQLiteGrammar(BaseTestCaseSelectGrammar, unittest.TestCase): """ Expected SQLite SELECT statements for the shared BaseTestCaseSelectGrammar suite, plus a few SQLite-specific test_* methods. """ grammar = SQLiteGrammar maxDiff = None def can_compile_select(self): """ self.builder.to_sql() """ return """SELECT * FROM "users\"""" def can_compile_order_by_and_first(self): """ self.builder.order_by('id', 'asc').first() """ return """SELECT * FROM "users" ORDER BY "id" ASC LIMIT 1""" def can_compile_with_columns(self): """ self.builder.select('username', 'password').to_sql() """ return """SELECT "users"."username", "users"."password" FROM "users\"""" def can_compile_with_where(self): """ self.builder.select('username', 'password').where('id', 1).to_sql() """ return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1'""" def can_compile_with_several_where(self): """ self.builder.select('username', 'password').where('id', 1).where('username', 'joe').to_sql() """ return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1' AND "users"."username" =
'joe'""" def can_compile_with_several_where_and_limit(self): """ self.builder.select('username', 'password').where('id', 1).where('username', 'joe').limit(10).to_sql() """ return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1' AND "users"."username" = 'joe' LIMIT 10""" def can_compile_with_sum(self): """ self.builder.sum('age').to_sql() """ return """SELECT SUM("users"."age") AS age FROM "users\"""" def can_compile_with_max(self): """ self.builder.max('age').to_sql() """ return """SELECT MAX("users"."age") AS age FROM "users\"""" def can_compile_with_max_and_columns(self): """ self.builder.select('username').max('age').to_sql() """ return """SELECT "users"."username", MAX("users"."age") AS age FROM "users\"""" def can_compile_with_max_and_columns_different_order(self): """ self.builder.max('age').select('username').to_sql() """ return """SELECT "users"."username", MAX("users"."age") AS age FROM "users\"""" def can_compile_with_order_by(self): """ self.builder.select('username').order_by('age', 'desc').to_sql() """ return """SELECT "users"."username" FROM "users" ORDER BY "age" DESC""" def can_compile_with_multiple_order_by(self): """ self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql() """ return ( """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC""" ) def can_compile_with_group_by(self): """ self.builder.select('username').group_by('age').to_sql() """ return """SELECT "users"."username" FROM "users" GROUP BY "users"."age\"""" def can_compile_where_in(self): """ self.builder.select('username').where_in('age', [1,2,3]).to_sql() """ return """SELECT "users"."username" FROM "users" WHERE "users"."age" IN ('1','2','3')""" def can_compile_where_in_empty(self): """ self.builder.where_in('age', []).to_sql() """ return """SELECT * FROM "users" WHERE 0 = 1""" def can_compile_where_not_in(self): """ self.builder.select('username').where_not_in('age', [1,2,3]).to_sql() """ return
"""SELECT "users"."username" FROM "users" WHERE "users"."age" NOT IN ('1','2','3')""" def can_compile_where_null(self): """ self.builder.select('username').where_null('age').to_sql() """ return """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NULL""" def can_compile_where_not_null(self): """ self.builder.select('username').where_not_null('age').to_sql() """ return ( """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL""" ) def can_compile_where_raw(self): """ self.builder.where_raw(""age" = '18'").to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" = '18'""" def can_compile_select_raw(self): """ self.builder.select_raw("COUNT(*)").to_sql() """ return """SELECT COUNT(*) FROM "users\"""" def can_compile_limit_and_offset(self): """ self.builder.limit(10).offset(10).to_sql() """ return """SELECT * FROM "users" LIMIT 10 OFFSET 10""" def can_compile_select_raw_with_select(self): """ self.builder.select('id').select_raw("COUNT(*)").to_sql() """ return """SELECT "users"."id", COUNT(*) FROM "users\"""" def can_compile_count(self): """ self.builder.count().to_sql() """ return """SELECT COUNT(*) AS m_count_reserved FROM "users\"""" def can_compile_count_column(self): """ self.builder.count('money').to_sql() """ return """SELECT COUNT("users"."money") AS money FROM "users\"""" def can_compile_where_column(self): """ self.builder.where_column('name', 'email').to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" = "users"."email\"""" def can_compile_or_where(self): """ self.builder.where('name', 2).or_where('name', 3).to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" = '2' OR "users"."name" = '3'""" def can_grouped_where(self): """ self.builder.where(lambda query: query.where('age', 2).where('name', 'Joe')).to_sql() """ return """SELECT * FROM "users" WHERE ("users"."age" = '2' AND "users"."name" = 'Joe')""" def can_compile_sub_select(self): """ self.builder.where_in('name',
QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age') ).to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" IN (SELECT "users"."age" FROM "users")""" def can_compile_sub_select_where(self): """ self.builder.where_in('age', QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age').where('age', 2).where('name', 'Joe') ).to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" IN (SELECT "users"."age" FROM "users" WHERE "users"."age" = '2' AND "users"."name" = 'Joe')""" def can_compile_sub_select_value(self): """ self.builder.where('name', self.builder.new().sum('age') ).to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" = (SELECT SUM("users"."age") AS age FROM "users")""" def can_compile_complex_sub_select(self): """ self.builder.where_in('name', (QueryBuilder(GrammarFactory.make(self.grammar), table='users') .select('age').where_in('email', QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('email') )) ).to_sql() """ return """SELECT * FROM "users" WHERE "users"."name" IN (SELECT "users"."age" FROM "users" WHERE "users"."email" IN (SELECT "users"."email" FROM "users"))""" def can_compile_exists(self): """ self.builder.select('age').where_exists( self.builder.new().select('username').where('age', 12) ).to_sql() """ return """SELECT "users"."age" FROM "users" WHERE EXISTS (SELECT "users"."username" FROM "users" WHERE "users"."age" = '12')""" def can_compile_not_exists(self): """ self.builder.select('age').where_not_exists( self.builder.new().select('username').where('age', 12) ).to_sql() """ return """SELECT "users"."age" FROM "users" WHERE NOT EXISTS (SELECT "users"."username" FROM "users" WHERE "users"."age" = '12')""" def can_compile_having(self): """ builder.sum('age').group_by('age').having('age').to_sql() """ return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\"""" def can_compile_having_order(self): """
builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql() NOTE(review): unlike can_compile_with_order_by, this expected string has "ORDER" without "BY" and a table-qualified column; confirm whether any test still uses it. """ return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\" ORDER "users"."age" DESC""" def can_compile_having_raw(self): """ builder.select_raw("COUNT(*) as counts").having_raw("counts > 18").to_sql() """ return """SELECT COUNT(*) as counts FROM "users" HAVING counts > 18""" def can_compile_having_with_expression(self): """ builder.sum('age').group_by('age').having('age', 10).to_sql() """ return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" = '10'""" def can_compile_having_with_greater_than_expression(self): """ builder.sum('age').group_by('age').having('age', '>', 10).to_sql() """ return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" > '10'""" def can_compile_join(self): """ builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql() """ return """SELECT * FROM "users" INNER JOIN "contacts" ON "users"."id" = "contacts"."user_id\"""" def can_compile_left_join(self): """ builder.left_join('contacts', 'users.id', '=', 'contacts.user_id').to_sql() """ return """SELECT * FROM "users" LEFT JOIN "contacts" ON "users"."id" = "contacts"."user_id\"""" def can_compile_multiple_join(self): """ builder.join('contacts', 'users.id', '=', 'contacts.user_id').join('posts', 'comments.post_id', '=', 'posts.id').to_sql() """ return """SELECT * FROM "users" INNER JOIN "contacts" ON "users"."id" = "contacts"."user_id" INNER JOIN "posts" ON "comments"."post_id" = "posts"."id\"""" def can_compile_between(self): """ builder.between('age', 18, 21).to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" BETWEEN '18' AND '21'""" def can_compile_not_between(self): """ builder.not_between('age', 18, 21).to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" NOT BETWEEN '18' AND '21'""" def test_can_compile_where_raw(self): to_sql = self.builder.where_raw(""" "age" = '18'""").to_sql()
self.assertEqual(to_sql, """SELECT * FROM "users" WHERE "age" = '18'""") def test_can_compile_where_raw_and_where_with_multiple_bindings(self): query = self.builder.where_raw( """ "age" = '?' AND "is_admin" = '?' """, [18, True] ).where("email", "test@example.com") self.assertEqual( query.to_qmark(), """SELECT * FROM "users" WHERE "age" = '?' AND "is_admin" = '?' AND "users"."email" = '?'""", ) self.assertEqual(query._bindings, [18, True, "test@example.com"]) def test_can_compile_select_raw(self): to_sql = self.builder.select_raw("COUNT(*)").to_sql() self.assertEqual(to_sql, """SELECT COUNT(*) FROM "users\"""") def test_can_compile_select_raw_with_select(self): to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql() self.assertEqual(to_sql, """SELECT "users"."id", COUNT(*) FROM "users\"""") def can_compile_first_or_fail(self): """ builder = self.get_builder() builder.where("is_admin", "=", True).first_or_fail() """ return """SELECT * FROM "users" WHERE "users"."is_admin" = '1' LIMIT 1""" def where_not_like(self): """ builder = self.get_builder() builder.where("age", "not like", "%name%").to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" NOT LIKE '%name%'""" def where_like(self): """ builder = self.get_builder() builder.where("age", "like", "%name%").to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" LIKE '%name%'""" def where_regexp(self): """ builder = self.get_builder() builder.where("age", "regexp", "Joe").to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" REGEXP 'Joe'""" def where_not_regexp(self): """ builder = self.get_builder() builder.where("age", "not regexp", "Joe").to_sql() """ return """SELECT * FROM "users" WHERE "users"."age" NOT REGEXP 'Joe'""" def can_compile_join_clause(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on("bgt.fund", "=", "rg.fund") .on("bgt.dept", "=", "rg.dept") .on("bgt.acct", "=", "rg.acct") .on("bgt.sub", "=", "rg.sub") ) builder.join(clause).to_sql() NOTE(review): call reconstructed from the expected SQL below; confirm against the base test. """ return
"""SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" AND "bgt"."dept" = "rg"."dept" AND "bgt"."acct" = "rg"."acct" AND "bgt"."sub" = "rg"."sub\"""" def can_compile_join_clause_with_value(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on_value("bgt.active", "=", "1") .or_on_value("bgt.acct", "=", "1234") ) builder.join(clause).to_sql() """ return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."active" = '1' OR "bgt"."acct" = '1234'""" def can_compile_join_clause_with_null(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on_null("bgt.acct") .or_on_null("bgt.dept") .on_value("rg.abc", 10) ) builder.join(clause).to_sql() """ return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NULL OR "dept" IS NULL AND "rg"."abc" = '10'""" def can_compile_join_clause_with_not_null(self): """ builder = self.get_builder() clause = ( JoinClause("report_groups as rg") .on_not_null("bgt.acct") .or_on_not_null("bgt.dept") .on_value("rg.abc", 10) ) builder.join(clause).to_sql() """ return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NOT NULL OR "dept" IS NOT NULL AND "rg"."abc" = '10'""" def can_compile_join_clause_with_lambda(self): """ builder = self.get_builder() builder.join( "report_groups as rg", lambda clause: ( clause.on("bgt.fund", "=", "rg.fund") .on_null("bgt") ), ).to_sql() """ return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" AND "bgt" IS NULL""" def can_compile_left_join_clause_with_lambda(self): """ builder = self.get_builder() builder.left_join( "report_groups as rg", lambda clause: ( clause.on("bgt.fund", "=", "rg.fund") .or_on_null("bgt") ), ).to_sql() """ return """SELECT * FROM "users" LEFT JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" OR "bgt" IS NULL""" def can_compile_right_join_clause_with_lambda(self): """ builder =
self.get_builder() builder.right_join( "report_groups as rg", lambda clause: ( clause.on("bgt.fund", "=", "rg.fund") .or_on_null("bgt") ), ).to_sql() NOTE(review): the expected SQL is a LEFT JOIN even though right_join is called; confirm this is the intended SQLite output. """ return """SELECT * FROM "users" LEFT JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" OR "bgt" IS NULL""" def update_lock(self): """ builder = self.get_builder() builder.where("votes", ">=", "100"), presumably with an update lock requested (per the method name). NOTE(review): the expected SQL has no lock clause; confirm the SQLite grammar is meant to drop it. """ return """SELECT * FROM "users" WHERE "users"."votes" >= '100'""" def shared_lock(self): """ builder = self.get_builder() builder.where("votes", ">=", "100"), presumably with a shared lock requested (per the method name). NOTE(review): the expected SQL has no lock clause; confirm the SQLite grammar is meant to drop it. """ return """SELECT * FROM "users" WHERE "users"."votes" >= '100'""" def can_user_where_raw_and_where(self): """ builder.where_raw("age = '18'").where("name", "=", "James").to_sql() """ return """SELECT * FROM "users" WHERE age = '18' AND "users"."name" = 'James'""" def where_exists_with_lambda(self): """ Expected SQL: EXISTS sub-query on "users" filtered by age = '1'. """ return """SELECT * FROM "users" WHERE EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')""" def where_not_exists_with_lambda(self): """ Expected SQL: NOT EXISTS sub-query on "users" filtered by age = '1'. """ return """SELECT * FROM "users" WHERE NOT EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')""" def where_date(self): """ Expected SQL: DATE() of "created_at" compared with '2022-06-01'. """ return ( """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'""" ) def or_where_null(self): """ Expected SQL: "column1" IS NULL OR "column2" IS NULL. """ return """SELECT * FROM "users" WHERE "users"."column1" IS NULL OR "users"."column2" IS NULL""" def select_distinct(self): """ Expected SQL: SELECT DISTINCT on the "group" column. """ return """SELECT DISTINCT "users"."group" FROM "users\"""" ================================================ FILE: tests/sqlite/grammar/test_sqlite_update_grammar.py ================================================ import inspect import unittest from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar from src.masoniteorm.expressions import Raw class BaseTestCaseUpdateGrammar: """ Shared UPDATE compilation tests built with dry=True. Expected SQL comes from the subclass method named after each test minus its "test_" prefix. """ def setUp(self): self.builder = QueryBuilder(SQLiteGrammar, table="users") def test_can_compile_update(self): to_sql = ( self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql() ) sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_multiple_update(self): to_sql = self.builder.update( {"name": "Joe", "email": "user@email.com"}, dry=True ).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) def test_can_compile_update_with_multiple_where(self): to_sql = ( self.builder.where("name", "bob") .where("age", 20) .update({"name": "Joe"}, dry=True) .to_sql() ) sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) # def test_can_compile_increment(self): # to_sql = self.builder.increment("age") # print(to_sql) # self.assertTrue(to_sql.isnumeric()) # def test_can_compile_decrement(self): # to_sql = self.builder.decrement("age", 20).to_sql() # sql = getattr( # self, inspect.currentframe().f_code.co_name.replace("test_", "") # )() # self.assertEqual(to_sql, sql) def test_raw_expression(self): to_sql = self.builder.update({"name": Raw('"username"')}, dry=True).to_sql() sql = getattr( self, inspect.currentframe().f_code.co_name.replace("test_", "") )() self.assertEqual(to_sql, sql) class TestSqliteUpdateGrammar(BaseTestCaseUpdateGrammar, unittest.TestCase): grammar = "sqlite" def can_compile_update(self): """ builder.where('name', 'bob').update({ 'name': 'Joe' }).to_sql() """ return """UPDATE "users" SET "name" = 'Joe' WHERE "name" = 'bob'""" def raw_expression(self): """ builder.update({'name': Raw('"username"')}, dry=True).to_sql() """ return 'UPDATE "users" SET "name" = "username"' def can_compile_multiple_update(self): """ self.builder.update({"name": "Joe", "email": "user@email.com"}, dry=True).to_sql() """ return """UPDATE "users" SET "name" = 'Joe', "email" = 'user@email.com'""" def can_compile_update_with_multiple_where(self): """ builder.where('name', 'bob').where('age', 20).update({ 'name': 'Joe' }).to_sql() """ return """UPDATE "users" SET "name" =
'Joe' WHERE "name" = 'bob' AND "age" = '20'""" def can_compile_increment(self): """ builder.increment('age').to_sql() NOTE(review): the matching test in BaseTestCaseUpdateGrammar is commented out. """ return """UPDATE "users" SET "age" = "age" + '1'""" def can_compile_decrement(self): """ builder.decrement('age', 20).to_sql() NOTE(review): the matching test in BaseTestCaseUpdateGrammar is commented out. """ return """UPDATE "users" SET "age" = "age" - '20'""" ================================================ FILE: tests/sqlite/models/test_attach_detach.py ================================================ import unittest from src.masoniteorm.models import Model from src.masoniteorm.relationships import belongs_to, has_one from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import SQLitePlatform from tests.integrations.config.database import DATABASES class Bottle(Model): __table__ = "bottles" __connection__ = "dev" __timestamps__ = False __fillable__ = ["label"] @has_one(None, "bottle_id", "id") def lid(self): return BottleLid class BottleLid(Model): __table__ = "bottle_lids" __connection__ = "dev" __timestamps__ = False __fillable__ = ["colour", "bottle_id"] @belongs_to(None, "bottle_id", "id") def bottle(self): return Bottle class TestAttachDetach(unittest.TestCase): """ attach/detach round-trips for the Bottle.lid (has_one) / BottleLid.bottle (belongs_to) pair on the "dev" connection. """ def setUp(self): self.schema = Schema( connection="dev", connection_details=DATABASES, platform=SQLitePlatform, ).on("dev") with self.schema.create_table_if_not_exists("bottles") as table: table.integer("id").primary() table.string("label") with self.schema.create_table_if_not_exists("bottle_lids") as table: table.integer("id").primary() table.string("colour") table.integer("bottle_id", nullable=True) # HasOne / BelongsTo relationship def tearDown(self): BottleLid.delete() Bottle.delete() def test_has_one_attach_detach(self): bottle = Bottle.create( { "label": "cola", } ) # test unsaved red_lid = BottleLid().fill( { "colour": "Red", } ) current_lid = bottle.attach("lid", red_lid) self.assertIsNotNone(bottle.lid) self.assertIsInstance(current_lid, BottleLid) self.assertTrue(current_lid.is_created()) self.assertEqual(bottle.id,
current_lid.bottle_id) bottle.detach("lid", current_lid) test_lid = BottleLid.find(current_lid.id) self.assertIsNone(test_lid.bottle_id) self.assertIsNone(bottle.lid) # test using a pre-saved record green_lid = BottleLid.create( { "colour": "Green", } ) current_lid = bottle.attach("lid", green_lid) self.assertIsNotNone(bottle.lid) self.assertIsInstance(current_lid, BottleLid) self.assertEqual(bottle.id, current_lid.bottle_id) bottle.detach("lid", current_lid) test_lid = BottleLid.find(current_lid.id) self.assertIsNone(test_lid.bottle_id) self.assertIsNone(bottle.lid) def test_belongs_to_attach_detach(self): bottle = Bottle.create( { "label": "milk", } ) # test unsaved red_lid = BottleLid().fill( { "colour": "Red", } ) current_lid = red_lid.attach("bottle", bottle) self.assertIsNotNone(bottle.lid) self.assertIsInstance(current_lid, BottleLid) self.assertTrue(current_lid.is_created()) self.assertEqual(bottle.id, current_lid.bottle_id) current_lid.detach("bottle", bottle) test_lid = BottleLid.find(current_lid.id) self.assertIsNone(test_lid.bottle_id) self.assertIsNone(bottle.lid) # test using a pre-saved record green_lid = BottleLid.create( { "colour": "Green", } ) current_lid = green_lid.attach("bottle", bottle) self.assertIsNotNone(bottle.lid) self.assertIsInstance(current_lid, BottleLid) self.assertEqual(bottle.id, current_lid.bottle_id) current_lid.detach("bottle", bottle) test_lid = BottleLid.find(current_lid.id) self.assertIsNone(test_lid.bottle_id) self.assertIsNone(bottle.lid) ================================================ FILE: tests/sqlite/models/test_observers.py ================================================ import inspect import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar from src.masoniteorm.relationships import belongs_to from
tests.utils import MockConnectionFactory from tests.integrations.config.database import DB class TestM: pass class UserObserver: def created(self, user): TestM.observed_created = 1 def creating(self, user): TestM.observed_creating = 1 def saving(self, user): TestM.observed_saving = 1 def saved(self, user): TestM.observed_saved = 1 def updating(self, user): TestM.observed_updating = 1 def updated(self, user): TestM.observed_updated = 1 def booted(self, user): TestM.observed_booting = 1 def booting(self, user): TestM.observed_booted = 1 def hydrating(self, user): TestM.observed_hydrating = 1 def hydrated(self, user): TestM.observed_hydrated = 1 def deleting(self, user): TestM.observed_deleting = 1 def deleted(self, user): TestM.observed_deleted = 1 class Observer(Model): __connection__ = "dev" __timestamps__ = False __observers__ = {} Observer.observe(UserObserver()) class BaseTestQueryRelationships(unittest.TestCase): maxDiff = None def test_created_is_observed(self): # DB.begin_transaction("dev") user = Observer.create({"name": "joe"}) self.assertEqual(TestM.observed_creating, 1) self.assertEqual(TestM.observed_created, 1) # DB.rollback("dev") def test_saving_is_observed(self): # DB.begin_transaction("dev") user = Observer.hydrate({"id": 1, "name": "joe"}) user.name = "bill" user.save() self.assertEqual(TestM.observed_saving, 1) self.assertEqual(TestM.observed_saved, 1) # DB.rollback("dev") def test_updating_is_observed(self): # DB.begin_transaction("dev") user = Observer.hydrate({"id": 1, "name": "joe"}) re = user.update({"name": "bill"}) self.assertEqual(TestM.observed_updated, 1) self.assertEqual(TestM.observed_updating, 1) # DB.rollback("dev") def test_booting_is_observed(self): # DB.begin_transaction("dev") user = Observer.hydrate({"id": 1, "name": "joe"}) re = user.update({"name": "bill"}) self.assertEqual(TestM.observed_booting, 1) self.assertEqual(TestM.observed_booted, 1) # DB.rollback("dev") def test_deleting_is_observed(self): DB.begin_transaction("dev") 
user = Observer.hydrate({"id": 10, "name": "joe"}) re = user.delete() self.assertEqual(TestM.observed_deleting, 1) self.assertEqual(TestM.observed_deleted, 1) DB.rollback("dev") def test_hydrating_is_observed(self): DB.begin_transaction("dev") user = Observer.hydrate({"id": 10, "name": "joe"}) self.assertEqual(TestM.observed_hydrating, 1) self.assertEqual(TestM.observed_hydrated, 1) DB.rollback("dev") ================================================ FILE: tests/sqlite/models/test_sqlite_model.py ================================================ import inspect import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import ConnectionFactory from src.masoniteorm.models import Model from src.masoniteorm.query import QueryBuilder from src.masoniteorm.query.grammars import SQLiteGrammar from src.masoniteorm.relationships import belongs_to, belongs_to_many from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms.SQLitePlatform import SQLitePlatform from tests.utils import MockConnectionFactory class User(Model): __connection__ = "dev" __timestamps__ = False __dry__ = True class UserForced(Model): __connection__ = "dev" __table__ = "users" __timestamps__ = False __dry__ = True __force_update__ = True class Select(Model): __connection__ = "dev" __selects__ = ["username", "rememember_token as token"] __dry__ = True class SelectPass(Model): __connection__ = "dev" __dry__ = True class UserHydrateHidden(Model): __connection__ = "dev" __table__ = "users_hidden" __hidden__ = ["token", "password"] class Group(Model): __connection__ = "dev" __table__ = "groups" __fillable = ["name"] __with__ = ["team"] @belongs_to_many("group_id", "user_id", "id", "id", table="group_user") def team(self): return UserHydrateHidden class BaseTestQueryRelationships(unittest.TestCase): maxDiff = None def test_update_specific_record(self): user = User.first() sql = user.update({"name": "joe"}).to_sql() self.assertEqual( sql, 
"""UPDATE "users" SET "name" = 'joe' WHERE "id" = '{}'""".format(user.id), ) def test_update_all_records(self): sql = User.update({"name": "joe"}).to_sql() self.assertEqual(sql, """UPDATE "users" SET "name" = 'joe'""") def test_can_find_list(self): sql = User.find(1, query=True).to_sql() self.assertEqual(sql, """SELECT * FROM "users" WHERE "users"."id" = '1'""") sql = User.find([1, 2, 3], query=True).to_sql() self.assertEqual( sql, """SELECT * FROM "users" WHERE "users"."id" IN ('1','2','3')""" ) def test_find_or_if_record_not_found(self): # Insane record number so record cannot be found record_id = 1_000_000_000_000_000 result = User.find_or(record_id, lambda: "Record not found.") self.assertEqual(result, "Record not found.") def test_find_or_if_record_found(self): record_id = 1 result_id = User.find_or(record_id, lambda: "Record not found.").id self.assertEqual(result_id, record_id) def test_can_set_and_retreive_attribute(self): user = User.hydrate({"id": 1, "name": "joe", "customer_id": 1}) user.customer_id = "CUST1" self.assertEqual(user.customer_id, "CUST1") def test_model_can_use_selects(self): self.assertEqual( Select.to_sql(), 'SELECT "selects"."username", "selects"."rememember_token" AS token FROM "selects"', ) def test_model_can_use_selects_from_methods(self): self.assertEqual( SelectPass.all(["username"], query=True).to_sql(), 'SELECT "select_passes"."username" FROM "select_passes"', ) def test_update_only_changed_attributes(self): user = User.first() sql = user.update({"name": user.name, "username": "new"}).to_sql() # unchanged name attribute is not updated self.assertEqual( sql, """UPDATE "users" SET "username" = 'new' WHERE "id" = '{}'""".format( user.id ), ) def test_can_force_update_on_method(self): user = User.first() sql = user.update({"name": user.name, "username": "new"}, force=True).to_sql() self.assertEqual( sql, """UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{}'""".format( user.id ), ) def 
test_can_force_update_on_model(self): user = UserForced.first() sql = user.update({"name": user.name, "username": "new"}).to_sql() self.assertEqual( sql, """UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{}'""".format( user.id ), ) def test_force_update(self): user = User.first() sql = user.force_update({"name": user.name, "username": "new"}).to_sql() self.assertEqual( sql, """UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{}'""".format( user.id ), ) def test_update_is_not_done_when_no_changes(self): user = User.first() sql = user.update({"name": user.name}).to_sql() self.assertNotIn("UPDATE", sql) def test_should_collect_correct_amount_data_using_between(self): class ModelUser(Model): __connection__ = "dev" __table__ = "users" count = User.between("age", 1, 2).get().count() self.assertEqual(count, 2) def test_should_collect_correct_amount_data_using_not_between(self): class ModelUser(Model): __connection__ = "dev" __table__ = "users" count = User.where_not_null("id").not_between("age", 1, 2).get().count() self.assertEqual(count, 0) def test_get_columns(self): columns = User.get_columns() self.assertEqual( columns, [ "id", "name", "email", "password", "remember_token", "created_at", "is_admin", "age", "boo", "tool1", "tool2", "active", "updated_at", "profile_id", "name5", "name6", "age6", "age7", "age8", "age10", ], ) def test_should_return_relation_applying_hidden_attributes(self): schema = Schema( connection_details=DATABASES, connection="dev", platform=SQLitePlatform ).on("dev") tables = ["users_hidden", "group_user", "groups"] for table in tables: schema.drop_table_if_exists(table) with schema.create("users_hidden") as blueprint: blueprint.increments("id") blueprint.string("name") blueprint.integer("token") blueprint.string("password") blueprint.timestamps() with schema.create("groups") as blueprint: blueprint.increments("id") blueprint.string("name") blueprint.timestamps() with schema.create("group_user") as 
blueprint: blueprint.increments("id") blueprint.unsigned_integer("group_id") blueprint.unsigned_integer("user_id") blueprint.foreign("group_id").references("id").on("groups") blueprint.foreign("user_id").references("id").on("users_hidden") blueprint.timestamps() UserHydrateHidden.create( name="Name", password="pass_value", token="token_value" ) Group.create(name="Group") user = UserHydrateHidden.first() group = Group.first() group.attach_related("team", user) serialized = Group.first().serialize() self.assertIn("team", serialized) self.assertTrue("team", serialized) relation_serialized = serialized.get("team") self.assertNotIn("password", relation_serialized) self.assertNotIn("token", relation_serialized) for table in tables: schema.truncate(table) ================================================ FILE: tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py ================================================ import unittest from src.masoniteorm.collection import Collection from src.masoniteorm.models import Model from src.masoniteorm.relationships import has_many_through from tests.integrations.config.database import DATABASES from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import SQLitePlatform class Enrolment(Model): __table__ = "enrolment" __connection__ = "dev" __fillable__ = ["active_student_id", "in_course_id"] class Student(Model): __table__ = "student" __connection__ = "dev" __fillable__ = ["student_id", "name"] class Course(Model): __table__ = "course" __connection__ = "dev" __fillable__ = ["course_id", "name"] @has_many_through( None, "in_course_id", "active_student_id", "course_id", "student_id" ) def students(self): return [Student, Enrolment] class TestHasManyThroughRelationship(unittest.TestCase): def setUp(self): self.schema = Schema( connection="dev", connection_details=DATABASES, platform=SQLitePlatform, ).on("dev") with self.schema.create_table_if_not_exists("student") as table: 
table.integer("student_id").primary() table.string("name") with self.schema.create_table_if_not_exists("course") as table: table.integer("course_id").primary() table.string("name") with self.schema.create_table_if_not_exists("enrolment") as table: table.integer("enrolment_id").primary() table.integer("active_student_id") table.integer("in_course_id") if not Course.count(): Course.builder.new().bulk_create( [ {"course_id": 10, "name": "Math 101"}, {"course_id": 20, "name": "History 101"}, {"course_id": 30, "name": "Math 302"}, {"course_id": 40, "name": "Biology 302"}, ] ) if not Student.count(): Student.builder.new().bulk_create( [ {"student_id": 100, "name": "Bob"}, {"student_id": 200, "name": "Alice"}, {"student_id": 300, "name": "Steve"}, {"student_id": 400, "name": "Megan"}, ] ) if not Enrolment.count(): Enrolment.builder.new().bulk_create( [ {"active_student_id": 100, "in_course_id": 30}, {"active_student_id": 200, "in_course_id": 10}, {"active_student_id": 100, "in_course_id": 10}, {"active_student_id": 400, "in_course_id": 20}, ] ) def test_has_many_through_can_eager_load(self): courses = Course.where("name", "Math 101").with_("students").get() students = courses.first().students self.assertIsInstance(students, Collection) self.assertEqual(students.count(), 2) student1 = students.shift() self.assertIsInstance(student1, Student) self.assertEqual(student1.name, "Alice") student2 = students.shift() self.assertIsInstance(student2, Student) self.assertEqual(student2.name, "Bob") # check .first() and .get() produce the same result single = ( Course.where("name", "History 101") .with_("students") .first() ) self.assertIsInstance(single.students, Collection) single_get = ( Course.where("name", "History 101").with_("students").get() ) print(single.students) print(single_get.first().students) self.assertEqual(single.students.count(), 1) self.assertEqual(single_get.first().students.count(), 1) single_name = single.students.first().name single_get_name = 
single_get.first().students.first().name self.assertEqual(single_name, single_get_name) def test_has_many_through_eager_load_can_be_empty(self): courses = ( Course.where("name", "Biology 302") .with_("students") .get() ) self.assertIsNone(courses.first().students) def test_has_many_through_can_get_related(self): course = Course.where("name", "Math 101").first() self.assertIsInstance(course.students, Collection) self.assertIsInstance(course.students.first(), Student) self.assertEqual(course.students.count(), 2) def test_has_many_through_has_query(self): courses = Course.where_has( "students", lambda query: query.where("name", "Bob") ) self.assertEqual(courses.count(), 2) ================================================ FILE: tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py ================================================ import unittest from src.masoniteorm.models import Model from src.masoniteorm.relationships import has_one_through from tests.integrations.config.database import DATABASES from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import SQLitePlatform class Port(Model): __table__ = "ports" __connection__ = "dev" __fillable__ = ["port_id", "name", "port_country_id"] class Country(Model): __table__ = "countries" __connection__ = "dev" __fillable__ = ["country_id", "name"] class IncomingShipment(Model): __table__ = "incoming_shipments" __connection__ = "dev" __fillable__ = ["shipment_id", "name", "from_port_id"] @has_one_through(None, "from_port_id", "port_country_id", "port_id", "country_id") def from_country(self): return [Country, Port] class TestHasOneThroughRelationship(unittest.TestCase): def setUp(self): self.schema = Schema( connection="dev", connection_details=DATABASES, platform=SQLitePlatform, ).on("dev") with self.schema.create_table_if_not_exists("incoming_shipments") as table: table.integer("shipment_id").primary() table.string("name") table.integer("from_port_id") with 
self.schema.create_table_if_not_exists("ports") as table: table.integer("port_id").primary() table.string("name") table.integer("port_country_id") with self.schema.create_table_if_not_exists("countries") as table: table.integer("country_id").primary() table.string("name") if not Country.count(): Country.builder.new().bulk_create( [ {"country_id": 10, "name": "Australia"}, {"country_id": 20, "name": "USA"}, {"country_id": 30, "name": "Canada"}, {"country_id": 40, "name": "United Kingdom"}, ] ) if not Port.count(): Port.builder.new().bulk_create( [ {"port_id": 100, "name": "Melbourne", "port_country_id": 10}, {"port_id": 200, "name": "Darwin", "port_country_id": 10}, {"port_id": 300, "name": "South Louisiana", "port_country_id": 20}, {"port_id": 400, "name": "Houston", "port_country_id": 20}, {"port_id": 500, "name": "Montreal", "port_country_id": 30}, {"port_id": 600, "name": "Vancouver", "port_country_id": 30}, {"port_id": 700, "name": "Southampton", "port_country_id": 40}, {"port_id": 800, "name": "London Gateway", "port_country_id": 40}, ] ) if not IncomingShipment.count(): IncomingShipment.builder.new().bulk_create( [ {"name": "Bread", "from_port_id": 300}, {"name": "Milk", "from_port_id": 100}, {"name": "Tractor Parts", "from_port_id": 100}, {"name": "Fridges", "from_port_id": 700}, {"name": "Wheat", "from_port_id": 600}, {"name": "Kettles", "from_port_id": 400}, {"name": "Bread", "from_port_id": 700}, ] ) def test_has_one_through_can_eager_load(self): shipments = IncomingShipment.where("name", "Bread").with_("from_country").get() self.assertEqual(shipments.count(), 2) shipment1 = shipments.shift() self.assertIsInstance(shipment1.from_country, Country) self.assertEqual(shipment1.from_country.country_id, 20) shipment2 = shipments.shift() self.assertIsInstance(shipment2.from_country, Country) self.assertEqual(shipment2.from_country.country_id, 40) # check .first() and .get() produce the same result single = ( IncomingShipment.where("name", "Tractor Parts") 
.with_("from_country") .first() ) single_get = ( IncomingShipment.where("name", "Tractor Parts").with_("from_country").get() ) self.assertEqual(single.from_country.country_id, 10) self.assertEqual(single_get.count(), 1) self.assertEqual( single.from_country.country_id, single_get.first().from_country.country_id ) def test_has_one_through_eager_load_can_be_empty(self): shipments = ( IncomingShipment.where("name", "Bread") .where_has("from_country", lambda query: query.where("name", "Ueaguay")) .with_( "from_country", ) .get() ) self.assertEqual(shipments.count(), 0) def test_has_one_through_can_get_related(self): shipment = IncomingShipment.where("name", "Milk").first() self.assertIsInstance(shipment.from_country, Country) self.assertEqual(shipment.from_country.country_id, 10) def test_has_one_through_has_query(self): shipments = IncomingShipment.where_has( "from_country", lambda query: query.where("name", "USA") ) self.assertEqual(shipments.count(), 2) ================================================ FILE: tests/sqlite/relationships/test_sqlite_polymorphic.py ================================================ import os import unittest from src.masoniteorm.models import Model from src.masoniteorm.relationships import belongs_to, has_many, morph_to from tests.integrations.config.database import DB class Profile(Model): __table__ = "profiles" __connection__ = "dev" class Articles(Model): __table__ = "articles" __connection__ = "dev" @belongs_to("id", "article_id") def logo(self): return Logo class Logo(Model): __table__ = "logos" __connection__ = "dev" class Like(Model): __connection__ = "dev" @morph_to("record_type", "record_id") def record(self): return self class User(Model): __connection__ = "dev" _eager_loads = () DB.morph_map({"user": User, "article": Articles}) class TestRelationships(unittest.TestCase): maxDiff = None def test_can_get_polymorphic_relation(self): likes = Like.get() for like in likes: self.assertIsInstance(like.record, (Articles, User)) def 
test_can_get_eager_load_polymorphic_relation(self): likes = Like.with_("record").get() for like in likes: self.assertIsInstance(like.record, (Articles, User)) ================================================ FILE: tests/sqlite/relationships/test_sqlite_relationships.py ================================================ import unittest from src.masoniteorm.models import Model from src.masoniteorm.relationships import belongs_to, has_many, has_one, belongs_to_many from tests.integrations.config.database import DB class Profile(Model): __table__ = "profiles" __connection__ = "dev" class Articles(Model): __table__ = "articles" __connection__ = "dev" __timestamps__ = None __dates__ = ["published_date"] @belongs_to("id", "article_id") def logo(self): return Logo class Logo(Model): __table__ = "logos" __connection__ = "dev" __timestamps__ = None __dates__ = ["published_date"] class User(Model): __connection__ = "dev" _eager_loads = () __casts__ = {"is_admin": "bool"} @belongs_to("id", "user_id") def profile(self): return Profile @has_many("id", "user_id") def articles(self): return Articles def get_is_admin(self): return "You are an admin" class Store(Model): __connection__ = "dev" @belongs_to_many("store_id", "product_id", "id", "id", with_timestamps=True) def products(self): return Product @belongs_to_many("store_id", "product_id", "id", "id", table="product_table") def products_table(self): return Product @belongs_to_many def store_products(self): return Product class Product(Model): __connection__ = "dev" class UserHasOne(Model): __table__ = "users" __connection__ = "dev" @has_one("user_id", "user_id") def profile(self): return Profile class TestRelationships(unittest.TestCase): maxDiff = None def test_relationship_can_be_callable(self): self.assertEqual( User.profile().where("name", "Joe").to_sql(), """SELECT * FROM "profiles" WHERE "profiles"."name" = 'Joe'""", ) def test_can_access_relationship(self): for user in User.where("id", 1).get(): 
self.assertIsInstance(user.profile, Profile) def test_can_access_has_many_relationship(self): user = User.hydrate(User.where("id", 1).first()) self.assertEqual(len(user.articles), 1) def test_can_access_relationship_multiple_times(self): user = User.hydrate(User.where("id", 1).first()) self.assertEqual(len(user.articles), 1) self.assertEqual(len(user.articles), 1) def test_can_access_relationship_date(self): user = User.with_("articles").where("id", 1).first() for article in user.articles: print(article.logo.published_date.is_past()) def test_loading(self): users = User.with_("articles").get() for user in users: user def test_relationship_has_one_sql(self): self.assertEqual(UserHasOne.profile().to_sql(), 'SELECT * FROM "profiles"') def test_loading_with_nested_with(self): users = User.with_("articles", "articles.logo").get() for user in users: for article in user.articles: if article.logo: print("aa", article.logo.url) def test_casting(self): users = User.with_("articles").where("is_admin", True).get() for user in users: user def test_setting(self): users = User.with_("articles").where("is_admin", True).get() for user in users: user.name = "Joe" user.is_admin = 1 user.save() def test_related(self): user = User.first() related_query = user.related("profile").where("active", 1).to_sql() self.assertEqual( related_query, """SELECT * FROM "profiles" WHERE "profiles"."user_id" = '1' AND "profiles"."active" = '1'""", ) def test_associate_records(self): DB.begin_transaction("dev") user = User.first() articles = [Articles.hydrate({"id": 1, "title": "associate records"})] user.save_many("articles", articles) DB.rollback("dev") def test_belongs_to_many(self): store = Store.hydrate({"id": 2, "name": "Walmart"}) self.assertEqual(store.products.count(), 3) self.assertEqual(store.products.serialize()[0]["id"], 4) self.assertEqual(store.products.serialize()[0]["name"], "Handgun") self.assertEqual( store.products.serialize()[0]["updated_at"], "2020-01-01T00:00:00+00:00" ) 
self.assertEqual( store.products.serialize()[0]["created_at"], "2020-01-01T00:00:00+00:00" ) def test_belongs_to_eager_many(self): store = Store.hydrate({"id": 2, "name": "Walmart"}) store = Store.with_("products").first() self.assertEqual(store.products.count(), 3) ================================================ FILE: tests/sqlite/schema/test_sqlite_schema_builder.py ================================================ import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import SQLiteConnection from src.masoniteorm.schema import Schema from src.masoniteorm.schema.platforms import SQLitePlatform class TestSQLiteSchemaBuilder(unittest.TestCase): maxDiff = None def setUp(self): self.schema = Schema( connection="dev", connection_details=DATABASES, platform=SQLitePlatform, dry=True, ).on("dev") def test_can_add_columns(self): with self.schema.create("users") as blueprint: blueprint.string("name") blueprint.integer("age") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)' ], ) def test_can_add_tiny_text(self): with self.schema.create("users") as blueprint: blueprint.tiny_text("description") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)'] ) def test_can_add_unsigned_decimal(self): with self.schema.create("users") as blueprint: blueprint.unsigned_decimal("amount", 19, 4) self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual( blueprint.to_sql(), ['CREATE TABLE "users" ("amount" DECIMAL(19, 4) NOT NULL)'], ) def test_can_create_table_if_not_exists(self): with self.schema.create_table_if_not_exists("users") as blueprint: blueprint.string("name") blueprint.integer("age") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE IF NOT 
EXISTS "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)' ], ) def test_can_add_columns_with_constraint(self): with self.schema.create("users") as blueprint: blueprint.string("name") blueprint.integer("age") blueprint.unique("name") self.assertEqual(len(blueprint.table.added_columns), 2) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL, UNIQUE(name))' ], ) def test_can_have_float_type(self): with self.schema.create("users") as blueprint: blueprint.float("amount") self.assertEqual( blueprint.to_sql(), ["""CREATE TABLE "users" (""" """\"amount" FLOAT(19, 4) NOT NULL)"""], ) def test_can_add_columns_with_foreign_key_constraint(self): with self.schema.create("users") as blueprint: blueprint.string("name").unique() blueprint.integer("age") blueprint.integer("profile_id") blueprint.foreign("profile_id").references("id").on("profiles") self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ' '("name" VARCHAR(255) NOT NULL, ' '"age" INTEGER NOT NULL, ' '"profile_id" INTEGER NOT NULL, ' "UNIQUE(name), " 'CONSTRAINT users_profile_id_foreign FOREIGN KEY ("profile_id") REFERENCES "profiles"("id"))' ], ) def test_can_add_columns_with_foreign_key_constraint_name(self): with self.schema.create("users") as blueprint: blueprint.string("name").unique() blueprint.integer("age") blueprint.integer("profile_id") blueprint.foreign("profile_id", name="profile_foreign").references("id").on( "profiles" ) self.assertEqual(len(blueprint.table.added_columns), 3) self.assertEqual( blueprint.to_sql(), [ 'CREATE TABLE "users" ' '("name" VARCHAR(255) NOT NULL, ' '"age" INTEGER NOT NULL, ' '"profile_id" INTEGER NOT NULL, ' "UNIQUE(name), " 'CONSTRAINT profile_foreign FOREIGN KEY ("profile_id") REFERENCES "profiles"("id"))' ], ) def test_can_use_morphs_for_polymorphism_relationships(self): with self.schema.create("likes") as blueprint: 
blueprint.morphs("record")

        # morphs("record") is expected to yield two columns (record_id,
        # record_type) plus one index on each, per the SQL list below.
        self.assertEqual(len(blueprint.table.added_columns), 2)

        sql = [
            'CREATE TABLE "likes" ("record_id" INTEGER UNSIGNED NOT NULL, "record_type" VARCHAR(255) NOT NULL)',
            'CREATE INDEX likes_record_id_index ON "likes"(record_id)',
            'CREATE INDEX likes_record_type_index ON "likes"(record_type)',
        ]

        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_advanced_table_creation(self):
        # Covers increments, enum (CHECK constraint), defaults, nullable
        # columns, a composite unique and timestamps in one CREATE TABLE.
        with self.schema.create("users") as blueprint:
            blueprint.increments("id")
            blueprint.string("name")
            blueprint.enum("gender", ["male", "female"])
            blueprint.string("email").unique()
            blueprint.string("password")
            blueprint.string("option").default("ADMIN")
            blueprint.integer("admin").default(0)
            blueprint.string("remember_token").nullable()
            blueprint.timestamp("verified_at").nullable()
            blueprint.unique(["email", "name"])
            blueprint.timestamps()

        self.assertEqual(len(blueprint.table.added_columns), 11)
        self.assertEqual(
            blueprint.to_sql(),
            [
                """CREATE TABLE "users" ("id" INTEGER NOT NULL, "name" VARCHAR(255) NOT NULL, "gender" VARCHAR(255) CHECK(gender IN ('male', 'female')) NOT NULL, "email" VARCHAR(255) NOT NULL, """
                """"password" VARCHAR(255) NOT NULL, "option" VARCHAR(255) NOT NULL DEFAULT 'ADMIN', "admin" INTEGER NOT NULL DEFAULT 0, "remember_token" VARCHAR(255) NULL, """
                '"verified_at" TIMESTAMP NULL, "created_at" DATETIME NULL DEFAULT CURRENT_TIMESTAMP, '
                '"updated_at" DATETIME NULL DEFAULT CURRENT_TIMESTAMP, CONSTRAINT users_id_primary PRIMARY KEY (id), '
                "UNIQUE(email), UNIQUE(email, name))"
            ],
        )

    def test_can_create_indexes(self):
        with self.schema.table("users") as blueprint:
            blueprint.index("name")
            blueprint.index("active", "active_idx")
            blueprint.index(["name", "email"])
            blueprint.unique("name")
            blueprint.unique(["name", "email"])
            # No statement for this full-text index appears in the expected
            # list below. NOTE(review): presumably the SQLite platform skips
            # full-text indexes; confirm.
            blueprint.fulltext("description")

        self.assertEqual(len(blueprint.table.added_columns), 0)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE INDEX users_name_index ON "users"(name)',
                'CREATE INDEX active_idx ON "users"(active)',
                'CREATE INDEX users_name_email_index ON "users"(name,email)',
                'CREATE UNIQUE INDEX users_name_unique ON "users"(name)',
                'CREATE UNIQUE INDEX users_name_email_unique ON "users"(name,email)',
            ],
        )

    def test_can_create_indexes_on_previous_column(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("email").index()
            # The explicit name "email_idx" is attached to the "active" column;
            # the expected SQL below reflects exactly that.
            blueprint.string("active").index(name="email_idx")

        self.assertEqual(len(blueprint.table.added_columns), 2)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'ALTER TABLE "users" ADD COLUMN "email" VARCHAR NOT NULL',
                'ALTER TABLE "users" ADD COLUMN "active" VARCHAR NOT NULL',
                'CREATE INDEX users_email_index ON "users"(email)',
                'CREATE INDEX email_idx ON "users"(active)',
            ],
        )

    def test_can_have_composite_keys(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.primary(["name", "age"])

        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" '
                '("name" VARCHAR(255) NOT NULL, '
                '"age" INTEGER NOT NULL, '
                '"profile_id" INTEGER NOT NULL, '
                "UNIQUE(name), "
                "CONSTRAINT users_name_age_primary PRIMARY KEY (name, age))"
            ],
        )

    def test_can_have_column_primary_key(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").primary()
            blueprint.integer("age")
            blueprint.integer("profile_id")

        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" '
                '("name" VARCHAR(255) NOT NULL, '
                '"age" INTEGER NOT NULL, '
                '"profile_id" INTEGER NOT NULL, '
                "CONSTRAINT users_name_primary PRIMARY KEY (name))"
            ],
        )

    def test_can_advanced_table_creation2(self):
        with self.schema.create("users") as blueprint:
            blueprint.big_increments("id")
            blueprint.string("name")
            blueprint.string("duration")
            blueprint.string("url")
            blueprint.json("payload")
            blueprint.year("birth")
            blueprint.inet("last_address").nullable()
            blueprint.cidr("route_origin").nullable()
blueprint.macaddr("mac_address").nullable()
            blueprint.datetime("published_at")
            blueprint.time("wakeup_at")
            blueprint.string("thumbnail").nullable()
            blueprint.integer("premium")
            blueprint.integer("author_id").unsigned().nullable()
            blueprint.foreign("author_id").references("id").on("users").on_delete(
                "set null"
            )
            blueprint.text("description")
            blueprint.timestamps()

        self.assertEqual(len(blueprint.table.added_columns), 17)
        self.assertEqual(
            blueprint.to_sql(),
            (
                [
                    'CREATE TABLE "users" ("id" BIGINT NOT NULL, "name" VARCHAR(255) NOT NULL, "duration" VARCHAR(255) NOT NULL, '
                    '"url" VARCHAR(255) NOT NULL, "payload" JSON NOT NULL, "birth" VARCHAR(4) NOT NULL, "last_address" VARCHAR(255) NULL, "route_origin" VARCHAR(255) NULL, "mac_address" VARCHAR(255) NULL, '
                    '"published_at" DATETIME NOT NULL, "wakeup_at" TIME NOT NULL, "thumbnail" VARCHAR(255) NULL, "premium" INTEGER NOT NULL, "author_id" INTEGER UNSIGNED NULL, "description" TEXT NOT NULL, '
                    '"created_at" DATETIME NULL DEFAULT CURRENT_TIMESTAMP, "updated_at" DATETIME NULL DEFAULT CURRENT_TIMESTAMP, '
                    'CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_author_id_foreign FOREIGN KEY ("author_id") REFERENCES "users"("id") ON DELETE SET NULL)'
                ]
            ),
        )

    def test_has_table(self):
        schema_sql = self.schema.has_table("users")

        sql = "SELECT name FROM sqlite_master WHERE type='table' AND name='users'"
        self.assertEqual(schema_sql, sql)

    def test_can_truncate(self):
        sql = self.schema.truncate("users")

        self.assertEqual(sql, 'DELETE FROM "users"')

    def test_can_rename_table(self):
        sql = self.schema.rename("users", "clients")

        self.assertEqual(sql, 'ALTER TABLE "users" RENAME TO "clients"')

    def test_can_drop_table_if_exists(self):
        # The second argument ("clients") does not appear in the expected SQL.
        sql = self.schema.drop_table_if_exists("users", "clients")

        self.assertEqual(sql, 'DROP TABLE IF EXISTS "users"')

    def test_can_drop_table(self):
        # As above, only "users" is expected in the generated statement.
        sql = self.schema.drop_table("users", "clients")

        self.assertEqual(sql, 'DROP TABLE "users"')

    def test_has_column(self):
        sql = self.schema.has_column("users", "name")
        # NOTE(review): this expects an information_schema query even though
        # the schema is built with SQLitePlatform; confirm that is the intended
        # dry-run output.
        self.assertEqual(
            sql,
            "SELECT column_name FROM information_schema.columns WHERE table_name='users' and column_name='name'",
        )

    def test_can_have_unsigned_columns(self):
        with self.schema.create("users") as blueprint:
            blueprint.integer("profile_id").unsigned()
            blueprint.big_integer("big_profile_id").unsigned()
            blueprint.tiny_integer("tiny_profile_id").unsigned()
            blueprint.small_integer("small_profile_id").unsigned()
            blueprint.medium_integer("medium_profile_id").unsigned()

        self.assertEqual(
            blueprint.to_sql(),
            [
                """CREATE TABLE "users" ("""
                """"profile_id" INTEGER UNSIGNED NOT NULL, """
                """"big_profile_id" BIGINT UNSIGNED NOT NULL, """
                """"tiny_profile_id" TINYINT UNSIGNED NOT NULL, """
                """"small_profile_id" SMALLINT UNSIGNED NOT NULL, """
                """"medium_profile_id" MEDIUMINT UNSIGNED NOT NULL)"""
            ],
        )

    def test_can_enable_foreign_keys(self):
        sql = self.schema.enable_foreign_key_constraints()

        self.assertEqual(sql, "PRAGMA foreign_keys = ON")

    def test_can_disable_foreign_keys(self):
        sql = self.schema.disable_foreign_key_constraints()

        self.assertEqual(sql, "PRAGMA foreign_keys = OFF")

    def test_can_truncate_without_foreign_keys(self):
        # With foreign_keys=True the DELETE is expected to be wrapped in
        # PRAGMA foreign_keys OFF / ON statements.
        sql = self.schema.truncate("users", foreign_keys=True)

        self.assertEqual(
            sql,
            [
                "PRAGMA foreign_keys = OFF",
                'DELETE FROM "users"',
                "PRAGMA foreign_keys = ON",
            ],
        )

    def test_can_add_enum(self):
        with self.schema.create("users") as blueprint:
            blueprint.enum("status", ["active", "inactive"]).default("active")

        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL DEFAULT \'active\')'
            ],
        )


================================================
FILE: tests/sqlite/schema/test_sqlite_schema_builder_alter.py
================================================
import unittest

from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import SQLiteConnection
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import SQLitePlatform
from src.masoniteorm.schema.Table import Table


class TestSQLiteSchemaBuilderAlter(unittest.TestCase):
    """Compares the SQL generated for ALTER-style SQLite schema changes."""

    maxDiff = None

    def setUp(self):
        self.schema = Schema(
            connection="dev",
            connection_details=DATABASES,
            platform=SQLitePlatform,
            dry=True,
        ).on("dev")

    def test_can_add_columns(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name")
            blueprint.string("external_type").default("external")
            blueprint.integer("age")

        self.assertEqual(len(blueprint.table.added_columns), 3)

        sql = [
            'ALTER TABLE "users" ADD COLUMN "name" VARCHAR NOT NULL',
            """ALTER TABLE "users" ADD COLUMN "external_type" VARCHAR NOT NULL DEFAULT 'external'""",
            'ALTER TABLE "users" ADD COLUMN "age" INTEGER NOT NULL',
        ]

        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_add_constraints(self):
        with self.schema.table("users") as blueprint:
            blueprint.unique("name", name="table_unique")

        self.assertEqual(len(blueprint.table.added_columns), 0)

        sql = ['CREATE UNIQUE INDEX table_unique ON "users"(name)']

        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_rename(self):
        with self.schema.table("users") as blueprint:
            blueprint.rename("post", "comment", "integer")

        # Describe the existing table, then expect the rename to be done by
        # copying through a temporary table.
        table = Table("users")
        table.add_column("post", "integer")
        blueprint.table.from_table = table

        sql = [
            "CREATE TEMPORARY TABLE __temp__users AS SELECT post FROM users",
            'DROP TABLE "users"',
            'CREATE TABLE "users" ("comment" INTEGER NOT NULL)',
            'INSERT INTO "users" ("comment") SELECT post FROM __temp__users',
            "DROP TABLE __temp__users",
        ]

        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_column("post")

        # The existing table has post/name/email; the expected SQL rebuilds it
        # without "post".
        table = Table("users")
        table.add_column("post", "string")
        table.add_column("name", "string")
        table.add_column("email", "string")
        blueprint.table.from_table = table

        sql = [
            "CREATE TEMPORARY TABLE __temp__users AS SELECT name, email FROM users",
            'DROP TABLE "users"',
            'CREATE TABLE
"users" ("name" VARCHAR NOT NULL, "email" VARCHAR NOT NULL)', 'INSERT INTO "users" ("name", "email") SELECT name, email FROM __temp__users', "DROP TABLE __temp__users", ] self.assertEqual(blueprint.to_sql(), sql) def test_change(self): with self.schema.table("users") as blueprint: blueprint.integer("age").change() blueprint.string("name") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual(len(blueprint.table.changed_columns), 1) table = Table("users") table.add_column("age", "string") blueprint.table.from_table = table sql = [ 'ALTER TABLE "users" ADD COLUMN "name" VARCHAR NOT NULL', "CREATE TEMPORARY TABLE __temp__users AS SELECT age FROM users", 'DROP TABLE "users"', 'CREATE TABLE "users" ("age" INTEGER NOT NULL, "name" VARCHAR(255) NOT NULL)', 'INSERT INTO "users" ("age") SELECT age FROM __temp__users', "DROP TABLE __temp__users", ] self.assertEqual(blueprint.to_sql(), sql) def test_drop_add_and_change(self): with self.schema.table("users") as blueprint: blueprint.integer("age").change() blueprint.string("name") blueprint.drop_column("email") self.assertEqual(len(blueprint.table.added_columns), 1) self.assertEqual(len(blueprint.table.changed_columns), 1) table = Table("users") table.add_column("age", "string") table.add_column("email", "string") blueprint.table.from_table = table sql = [ 'ALTER TABLE "users" ADD COLUMN "name" VARCHAR', "CREATE TEMPORARY TABLE __temp__users AS SELECT age FROM users", 'DROP TABLE "users"', 'CREATE TABLE "users" ("age" INTEGER NOT NULL, "name" VARCHAR(255) NOT NULL)', 'INSERT INTO "users" ("age") SELECT age FROM __temp__users', "DROP TABLE __temp__users", ] def test_timestamp_alter_add_nullable_column(self): with self.schema.table("users") as blueprint: blueprint.timestamp("due_date").nullable() self.assertEqual(len(blueprint.table.added_columns), 1) table = Table("users") table.add_column("age", "string") blueprint.table.from_table = table sql = ['ALTER TABLE "users" ADD COLUMN "due_date" TIMESTAMP NULL'] 
self.assertEqual(blueprint.to_sql(), sql) def test_alter_drop_on_table_schema_table(self): schema = Schema(connection="dev", connection_details=DATABASES).on("dev") with schema.table("table_schema") as blueprint: blueprint.drop_column("name") with schema.table("table_schema") as blueprint: blueprint.string("name").nullable() def test_alter_add_primary(self): with self.schema.table("users") as blueprint: blueprint.primary("playlist_id") sql = [ 'ALTER TABLE "users" ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)' ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_column_and_foreign_key(self): with self.schema.table("users") as blueprint: blueprint.unsigned_integer("playlist_id").nullable() blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( "cascade" ).on_update("SET NULL") table = Table("users") table.add_column("age", "string") table.add_column("email", "string") blueprint.table.from_table = table sql = [ 'ALTER TABLE "users" ADD COLUMN "playlist_id" INTEGER UNSIGNED NULL REFERENCES "playlists"("id")', "CREATE TEMPORARY TABLE __temp__users AS SELECT age, email FROM users", 'DROP TABLE "users"', 'CREATE TABLE "users" ("age" VARCHAR NOT NULL, "email" VARCHAR NOT NULL, "playlist_id" INTEGER UNSIGNED NULL, ' 'CONSTRAINT users_playlist_id_foreign FOREIGN KEY ("playlist_id") REFERENCES "playlists"("id") ON DELETE CASCADE ON UPDATE SET NULL)', 'INSERT INTO "users" ("age", "email") SELECT age, email FROM __temp__users', "DROP TABLE __temp__users", ] self.assertEqual(blueprint.to_sql(), sql) def test_alter_add_foreign_key_only(self): with self.schema.table("users") as blueprint: blueprint.foreign("playlist_id").references("id").on("playlists").on_delete( "cascade" ).on_update("set null") table = Table("users") table.add_column("age", "string") table.add_column("email", "string") blueprint.table.from_table = table sql = [ "CREATE TEMPORARY TABLE __temp__users AS SELECT age, email FROM users", 'DROP TABLE "users"', 
'CREATE TABLE "users" ("age" VARCHAR NOT NULL, "email" VARCHAR NOT NULL, ' 'CONSTRAINT users_playlist_id_foreign FOREIGN KEY ("playlist_id") REFERENCES "playlists"("id") ON DELETE CASCADE ON UPDATE SET NULL)', 'INSERT INTO "users" ("age", "email") SELECT age, email FROM __temp__users', "DROP TABLE __temp__users", ] self.assertEqual(blueprint.to_sql(), sql) def test_can_add_column_enum(self): with self.schema.table("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active") self.assertEqual(len(blueprint.table.added_columns), 1) sql = [ 'ALTER TABLE "users" ADD COLUMN "status" VARCHAR CHECK(\'status\' IN(\'active\', \'inactive\')) NOT NULL DEFAULT \'active\'' ] self.assertEqual(blueprint.to_sql(), sql) def test_can_change_column_enum(self): with self.schema.table("users") as blueprint: blueprint.enum("status", ["active", "inactive"]).default("active").change() blueprint.table.from_table = Table("users") self.assertEqual(len(blueprint.table.changed_columns), 1) sql = [ 'CREATE TEMPORARY TABLE __temp__users AS SELECT FROM users', 'DROP TABLE "users"', 'CREATE TABLE "users" ("status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL DEFAULT \'active\')', 'INSERT INTO "users" ("status") SELECT status FROM __temp__users', 'DROP TABLE __temp__users' ] self.assertEqual(blueprint.to_sql(), sql) ================================================ FILE: tests/sqlite/schema/test_table.py ================================================ import unittest from tests.integrations.config.database import DATABASES from src.masoniteorm.connections import SQLiteConnection from src.masoniteorm.schema import Column, Table from src.masoniteorm.schema.platforms.SQLitePlatform import SQLitePlatform class TestTable(unittest.TestCase): maxDiff = None def setUp(self): self.platform = SQLitePlatform() def test_add_columns(self): table = Table("users") table.add_column("name", "string") self.assertIsInstance(table.added_columns["name"], Column) def 
test_primary_key(self): table = Table("users") table.add_column("id", "integer") table.set_primary_key("id") self.assertEqual(table.primary_key, "id") def test_create_sql(self): table = Table("users") table.add_column("id", "integer") table.add_column("name", "string") sql = 'CREATE TABLE "users" ("id" INTEGER NOT NULL, "name" VARCHAR NOT NULL)' self.assertEqual([sql], self.platform.compile_create_sql(table)) def test_create_sql_with_primary_key(self): table = Table("users") table.add_column("id", "integer") table.add_column("name", "string") table.set_primary_key("id") sql = 'CREATE TABLE "users" ("id" INTEGER NOT NULL, "name" VARCHAR NOT NULL)' self.assertEqual([sql], self.platform.compile_create_sql(table)) def test_create_sql_with_unique_constraint(self): table = Table("users") table.add_column("id", "integer") table.add_column("name", "string") table.add_constraint("name", "unique", ["name"]) table.set_primary_key("id") sql = 'CREATE TABLE "users" ("id" INTEGER NOT NULL, "name" VARCHAR NOT NULL, UNIQUE(name))' self.platform.constraintize(table.get_added_constraints()) self.assertEqual(self.platform.compile_create_sql(table), [sql]) def test_create_sql_with_multiple_unique_constraints(self): table = Table("users") table.add_column("id", "integer") table.add_column("email", "string") table.add_column("name", "string") table.add_constraint("name", "unique", ["name"]) table.add_constraint("email", "unique", ["email"]) table.set_primary_key("id") sql = 'CREATE TABLE "users" ("id" INTEGER NOT NULL, "email" VARCHAR NOT NULL, "name" VARCHAR NOT NULL, UNIQUE(name), UNIQUE(email))' self.platform.constraintize(table.get_added_constraints()) self.assertEqual(self.platform.compile_create_sql(table), [sql]) def test_create_sql_with_multiple_unique_constraint(self): table = Table("users") table.add_column("id", "integer") table.add_column("email", "string") table.add_column("name", "string") table.add_constraint("name", "unique", ["name", "email"]) 
table.set_primary_key("id") sql = 'CREATE TABLE "users" ("id" INTEGER NOT NULL, "email" VARCHAR NOT NULL, "name" VARCHAR NOT NULL, UNIQUE(name, email))' self.platform.constraintize(table.get_added_constraints()) self.assertEqual(self.platform.compile_create_sql(table), [sql]) def test_create_sql_with_foreign_key_constraint(self): table = Table("users") table.add_column("id", "integer") table.add_column("profile_id", "integer") table.add_column("comment_id", "integer") table.add_foreign_key("profile_id", "profiles", "id") table.add_foreign_key("comment_id", "comments", "id") table.set_primary_key("id") sql = ( 'CREATE TABLE "users" (' '"id" INTEGER NOT NULL, "profile_id" INTEGER NOT NULL, "comment_id" INTEGER NOT NULL, ' 'CONSTRAINT users_profile_id_foreign FOREIGN KEY ("profile_id") REFERENCES "profiles"("id"), ' 'CONSTRAINT users_comment_id_foreign FOREIGN KEY ("comment_id") REFERENCES "comments"("id"))' ) self.platform.constraintize(table.get_added_constraints()) self.assertEqual(self.platform.compile_create_sql(table), [sql]) def test_can_build_table_from_connection_call(self): sql_details = DATABASES["dev"] table = self.platform.get_current_schema( SQLiteConnection( database=sql_details["database"], name="dev" ).make_connection(), "table_schema", ) self.assertEqual(len(table.added_columns), 4) ================================================ FILE: tests/sqlite/schema/test_table_diff.py ================================================ import unittest from src.masoniteorm.schema import Column, Table from src.masoniteorm.schema.platforms.SQLitePlatform import SQLitePlatform from src.masoniteorm.schema.TableDiff import TableDiff class TestTableDiff(unittest.TestCase): def setUp(self): self.platform = SQLitePlatform() def test_rename_table(self): table = Table("users") table.add_column("name", "string") diff = TableDiff("users") diff.from_table = table diff.new_name = "clients" sql = ['ALTER TABLE "users" RENAME TO "clients"'] self.assertEqual(sql, 
self.platform.compile_alter_sql(diff)) def test_drop_index(self): table = Table("users") table.add_index("name", "name_index", "unique") diff = TableDiff("users") diff.from_table = table diff.remove_index("name") sql = ["DROP INDEX name"] self.assertEqual(sql, self.platform.compile_alter_sql(diff)) def test_drop_index_and_rename_table(self): table = Table("users") table.add_index("name", "name_unique", "unique") diff = TableDiff("users") diff.from_table = table diff.new_name = "clients" diff.remove_index("name_unique") sql = ["DROP INDEX name_unique", 'ALTER TABLE "users" RENAME TO "clients"'] self.assertEqual(sql, self.platform.compile_alter_sql(diff)) def test_alter_add_column(self): table = Table("users") diff = TableDiff("users") diff.from_table = table diff.add_column("name", "string") diff.add_column("email", "string") sql = [ 'ALTER TABLE "users" ADD COLUMN "name" VARCHAR NOT NULL', 'ALTER TABLE "users" ADD COLUMN "email" VARCHAR NOT NULL', ] self.assertEqual(sql, self.platform.compile_alter_sql(diff)) def test_alter_rename(self): table = Table("users") table.add_column("post", "integer") diff = TableDiff("users") diff.from_table = table diff.rename_column("post", "comment", "integer") sql = [ "CREATE TEMPORARY TABLE __temp__users AS SELECT post FROM users", 'DROP TABLE "users"', 'CREATE TABLE "users" ("comment" INTEGER NOT NULL)', 'INSERT INTO "users" ("comment") SELECT post FROM __temp__users', "DROP TABLE __temp__users", ] self.assertEqual(sql, self.platform.compile_alter_sql(diff)) def test_alter_advanced_rename_columns(self): table = Table("users") table.add_column("post", "integer") table.add_column("user", "integer") table.add_column("email", "integer") diff = TableDiff("users") diff.from_table = table diff.rename_column("post", "comment", "integer") sql = [ "CREATE TEMPORARY TABLE __temp__users AS SELECT post, user, email FROM users", 'DROP TABLE "users"', 'CREATE TABLE "users" ("comment" INTEGER NOT NULL, "user" INTEGER NOT NULL, "email" INTEGER NOT 
NULL)', 'INSERT INTO "users" ("comment", "user", "email") SELECT post, user, email FROM __temp__users', "DROP TABLE __temp__users", ] self.assertEqual(sql, self.platform.compile_alter_sql(diff)) def test_alter_rename_column_and_rename_table(self): table = Table("users") table.add_column("post", "integer") diff = TableDiff("users") diff.from_table = table diff.new_name = "clients" diff.rename_column("post", "comment", "integer") sql = [ "CREATE TEMPORARY TABLE __temp__users AS SELECT post FROM users", 'DROP TABLE "users"', 'CREATE TABLE "users" ("comment" INTEGER NOT NULL)', 'INSERT INTO "users" ("comment") SELECT post FROM __temp__users', "DROP TABLE __temp__users", 'ALTER TABLE "users" RENAME TO "clients"', ] self.assertEqual(sql, self.platform.compile_alter_sql(diff)) def test_alter_rename_column_and_rename_table_and_drop_index(self): table = Table("users") table.add_column("post", "integer") table.add_index("name", "name_unique", "unique") diff = TableDiff("users") diff.from_table = table diff.new_name = "clients" diff.rename_column("post", "comment", "integer") diff.remove_index("name") sql = [ "DROP INDEX name", "CREATE TEMPORARY TABLE __temp__users AS SELECT post FROM users", 'DROP TABLE "users"', 'CREATE TABLE "users" ("comment" INTEGER NOT NULL)', 'INSERT INTO "users" ("comment") SELECT post FROM __temp__users', "DROP TABLE __temp__users", 'ALTER TABLE "users" RENAME TO "clients"', ] self.assertEqual(sql, self.platform.compile_alter_sql(diff)) def test_alter_can_drop_column(self): table = Table("users") table.add_column("post", "integer") table.add_column("name", "string") table.add_column("email", "string") diff = TableDiff("users") diff.from_table = table diff.drop_column("post") sql = [ "CREATE TEMPORARY TABLE __temp__users AS SELECT name, email FROM users", 'DROP TABLE "users"', 'CREATE TABLE "users" ("name" VARCHAR NOT NULL, "email" VARCHAR NOT NULL)', 'INSERT INTO "users" ("name", "email") SELECT name, email FROM __temp__users', "DROP TABLE 
__temp__users", ] self.assertEqual(sql, self.platform.compile_alter_sql(diff)) ================================================ FILE: tests/utils.py ================================================ from unittest import mock from src.masoniteorm.connections.ConnectionFactory import ConnectionFactory from src.masoniteorm.connections.MySQLConnection import MySQLConnection from src.masoniteorm.connections.SQLiteConnection import SQLiteConnection from src.masoniteorm.schema.platforms import MySQLPlatform class MockMySQLConnection(MySQLConnection): def make_connection(self): self._connection = mock.MagicMock() self._cursor = mock.MagicMock() return self @classmethod def get_default_platform(cls): return MySQLPlatform class MockMSSQLConnection(MySQLConnection): def make_connection(self): self._connection = mock.MagicMock() self._cursor = mock.MagicMock() return self @classmethod def get_default_platform(cls): return MySQLPlatform class MockPostgresConnection(MySQLConnection): def make_connection(self): self._connection = mock.MagicMock() return self class MockSQLiteConnection(SQLiteConnection): def make_connection(self): self._connection = mock.MagicMock() return self def query(self, *args, **kwargs): return {} class MockConnectionFactory(ConnectionFactory): _connections = { "mysql": MockMySQLConnection, "mssql": MockMSSQLConnection, "postgres": MockPostgresConnection, "sqlite": MockSQLiteConnection, "oracle": "", }