Repository: MasoniteFramework/orm
Branch: 2.0
Commit: 0d31e53f1881
Files: 224
Total size: 1020.6 KB
Directory structure:
gitextract_q9lmgekj/
├── .deepsource.toml
├── .env-example
├── .envrc
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.md
│ │ └── feature_request.md
│ └── workflows/
│ ├── pythonapp.yml
│ └── pythonpublish.yml
├── .gitignore
├── .pre-commit-config.yaml
├── .pypirc
├── .tool-versions
├── CONTRIBUTING.md
├── LICENSE
├── MANIFEST.in
├── README.md
├── TODO.md
├── app/
│ └── observers/
│ └── UserObserver.py
├── cc.py
├── conda/
│ ├── conda_build_config.yaml
│ └── meta.yaml
├── config/
│ └── test-database.py
├── databases/
│ ├── migrations/
│ │ ├── 2018_01_09_043202_create_users_table.py
│ │ ├── 2020_04_17_000000_create_friends_table.py
│ │ ├── 2020_04_17_00000_create_articles_table.py
│ │ ├── 2020_10_20_152904_create_table_schema_migration.py
│ │ └── __init__.py
│ └── seeds/
│ ├── database_seeder.py
│ └── user_table_seeder.py
├── makefile
├── orm
├── pyproject.toml
├── pytest.ini
├── requirements.dev
├── requirements.txt
├── setup.py
├── src/
│ └── masoniteorm/
│ ├── .gitignore
│ ├── __init__.py
│ ├── collection/
│ │ ├── Collection.py
│ │ └── __init__.py
│ ├── commands/
│ │ ├── CanOverrideConfig.py
│ │ ├── CanOverrideOptionsDefault.py
│ │ ├── Command.py
│ │ ├── Entry.py
│ │ ├── MakeMigrationCommand.py
│ │ ├── MakeModelCommand.py
│ │ ├── MakeModelDocstringCommand.py
│ │ ├── MakeObserverCommand.py
│ │ ├── MakeSeedCommand.py
│ │ ├── MigrateCommand.py
│ │ ├── MigrateFreshCommand.py
│ │ ├── MigrateRefreshCommand.py
│ │ ├── MigrateResetCommand.py
│ │ ├── MigrateRollbackCommand.py
│ │ ├── MigrateStatusCommand.py
│ │ ├── SeedRunCommand.py
│ │ ├── ShellCommand.py
│ │ ├── __init__.py
│ │ └── stubs/
│ │ ├── create_migration.stub
│ │ ├── create_seed.stub
│ │ ├── model.stub
│ │ ├── observer.stub
│ │ └── table_migration.stub
│ ├── config.py
│ ├── connections/
│ │ ├── .gitignore
│ │ ├── BaseConnection.py
│ │ ├── ConnectionFactory.py
│ │ ├── ConnectionResolver.py
│ │ ├── MSSQLConnection.py
│ │ ├── MySQLConnection.py
│ │ ├── PostgresConnection.py
│ │ ├── SQLiteConnection.py
│ │ └── __init__.py
│ ├── exceptions.py
│ ├── expressions/
│ │ ├── __init__.py
│ │ └── expressions.py
│ ├── factories/
│ │ ├── Factory.py
│ │ └── __init__.py
│ ├── helpers/
│ │ ├── __init__.py
│ │ └── misc.py
│ ├── migrations/
│ │ ├── Migration.py
│ │ └── __init__.py
│ ├── models/
│ │ ├── MigrationModel.py
│ │ ├── Model.py
│ │ ├── Model.pyi
│ │ ├── Pivot.py
│ │ └── __init__.py
│ ├── observers/
│ │ ├── ObservesEvents.py
│ │ └── __init__.py
│ ├── pagination/
│ │ ├── BasePaginator.py
│ │ ├── LengthAwarePaginator.py
│ │ ├── SimplePaginator.py
│ │ └── __init__.py
│ ├── providers/
│ │ ├── ORMProvider.py
│ │ └── __init__.py
│ ├── query/
│ │ ├── EagerRelation.py
│ │ ├── QueryBuilder.py
│ │ ├── __init__.py
│ │ ├── grammars/
│ │ │ ├── BaseGrammar.py
│ │ │ ├── MSSQLGrammar.py
│ │ │ ├── MySQLGrammar.py
│ │ │ ├── PostgresGrammar.py
│ │ │ ├── SQLiteGrammar.py
│ │ │ └── __init__.py
│ │ └── processors/
│ │ ├── MSSQLPostProcessor.py
│ │ ├── MySQLPostProcessor.py
│ │ ├── PostgresPostProcessor.py
│ │ ├── SQLitePostProcessor.py
│ │ └── __init__.py
│ ├── relationships/
│ │ ├── BaseRelationship.py
│ │ ├── BelongsTo.py
│ │ ├── BelongsToMany.py
│ │ ├── HasMany.py
│ │ ├── HasManyThrough.py
│ │ ├── HasOne.py
│ │ ├── HasOneThrough.py
│ │ ├── MorphMany.py
│ │ ├── MorphOne.py
│ │ ├── MorphTo.py
│ │ ├── MorphToMany.py
│ │ └── __init__.py
│ ├── schema/
│ │ ├── Blueprint.py
│ │ ├── Column.py
│ │ ├── ColumnDiff.py
│ │ ├── Constraint.py
│ │ ├── ForeignKeyConstraint.py
│ │ ├── Index.py
│ │ ├── Schema.py
│ │ ├── Table.py
│ │ ├── TableDiff.py
│ │ ├── __init__.py
│ │ └── platforms/
│ │ ├── MSSQLPlatform.py
│ │ ├── MySQLPlatform.py
│ │ ├── Platform.py
│ │ ├── PostgresPlatform.py
│ │ ├── SQLitePlatform.py
│ │ └── __init__.py
│ ├── scopes/
│ │ ├── BaseScope.py
│ │ ├── SoftDeleteScope.py
│ │ ├── SoftDeletesMixin.py
│ │ ├── TimeStampsMixin.py
│ │ ├── TimeStampsScope.py
│ │ ├── UUIDPrimaryKeyMixin.py
│ │ ├── UUIDPrimaryKeyScope.py
│ │ ├── __init__.py
│ │ └── scope.py
│ ├── seeds/
│ │ ├── Seeder.py
│ │ └── __init__.py
│ ├── stubs/
│ │ ├── create-migration.html
│ │ └── table-migration.html
│ └── testing/
│ ├── BaseTestCaseSelectGrammar.py
│ └── __init__.py
└── tests/
├── User.py
├── collection/
│ └── test_collection.py
├── commands/
│ └── test_shell.py
├── config/
│ └── test_db_url.py
├── connections/
│ └── test_base_connections.py
├── eagers/
│ └── test_eager.py
├── factories/
│ └── test_factories.py
├── integrations/
│ └── config/
│ ├── __init__.py
│ └── database.py
├── models/
│ └── test_models.py
├── mssql/
│ ├── builder/
│ │ ├── test_mssql_query_builder.py
│ │ └── test_mssql_query_builder_relationships.py
│ ├── grammar/
│ │ ├── test_mssql_delete_grammar.py
│ │ ├── test_mssql_insert_grammar.py
│ │ ├── test_mssql_qmark.py
│ │ ├── test_mssql_select_grammar.py
│ │ └── test_mssql_update_grammar.py
│ └── schema/
│ ├── test_mssql_schema_builder.py
│ └── test_mssql_schema_builder_alter.py
├── mysql/
│ ├── builder/
│ │ ├── test_mysql_builder_transaction.py
│ │ ├── test_query_builder.py
│ │ ├── test_query_builder_scopes.py
│ │ └── test_transactions.py
│ ├── connections/
│ │ └── test_mysql_connection_selects.py
│ ├── grammar/
│ │ ├── test_mysql_delete_grammar.py
│ │ ├── test_mysql_insert_grammar.py
│ │ ├── test_mysql_qmark.py
│ │ ├── test_mysql_select_grammar.py
│ │ └── test_mysql_update_grammar.py
│ ├── model/
│ │ ├── test_accessors_and_mutators.py
│ │ └── test_model.py
│ ├── relationships/
│ │ ├── test_belongs_to_many.py
│ │ ├── test_has_many_through.py
│ │ ├── test_has_one_through.py
│ │ └── test_relationships.py
│ ├── schema/
│ │ ├── test_mysql_schema_builder.py
│ │ └── test_mysql_schema_builder_alter.py
│ └── scopes/
│ ├── test_can_use_global_scopes.py
│ ├── test_can_use_scopes.py
│ └── test_soft_delete.py
├── postgres/
│ ├── builder/
│ │ ├── test_postgres_query_builder.py
│ │ └── test_postgres_transaction.py
│ ├── grammar/
│ │ ├── test_delete_grammar.py
│ │ ├── test_insert_grammar.py
│ │ ├── test_select_grammar.py
│ │ └── test_update_grammar.py
│ ├── relationships/
│ │ └── test_postgres_relationships.py
│ └── schema/
│ ├── test_postgres_schema_builder.py
│ └── test_postgres_schema_builder_alter.py
├── scopes/
│ └── test_default_global_scopes.py
├── seeds/
│ └── test_seeds.py
├── sqlite/
│ ├── builder/
│ │ ├── test_sqlite_builder_insert.py
│ │ ├── test_sqlite_builder_pagination.py
│ │ ├── test_sqlite_query_builder.py
│ │ ├── test_sqlite_query_builder_eager_loading.py
│ │ ├── test_sqlite_query_builder_relationships.py
│ │ └── test_sqlite_transaction.py
│ ├── grammar/
│ │ ├── test_sqlite_delete_grammar.py
│ │ ├── test_sqlite_insert_grammar.py
│ │ ├── test_sqlite_select_grammar.py
│ │ └── test_sqlite_update_grammar.py
│ ├── models/
│ │ ├── test_attach_detach.py
│ │ ├── test_observers.py
│ │ └── test_sqlite_model.py
│ ├── relationships/
│ │ ├── test_sqlite_has_many_through_relationship.py
│ │ ├── test_sqlite_has_one_through_relationship.py
│ │ ├── test_sqlite_polymorphic.py
│ │ └── test_sqlite_relationships.py
│ └── schema/
│ ├── test_sqlite_schema_builder.py
│ ├── test_sqlite_schema_builder_alter.py
│ ├── test_table.py
│ └── test_table_diff.py
└── utils.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .deepsource.toml
================================================
# generated by deepsource.io
version = 1
test_patterns = [
'tests/**/*.py'
]
exclude_patterns = [
'databases/migrations/*'
]
[[analyzers]]
name = "python"
enabled = true
runtime_version = "3.x.x"
================================================
FILE: .env-example
================================================
RUN_MYSQL_DATABASE=False
MYSQL_DATABASE_HOST=
MYSQL_DATABASE_USER=
MYSQL_DATABASE_PASSWORD=
MYSQL_DATABASE_DATABASE=
MYSQL_DATABASE_PORT=
POSTGRES_DATABASE_HOST=
POSTGRES_DATABASE_USER=
POSTGRES_DATABASE_PASSWORD=
POSTGRES_DATABASE_DATABASE=
POSTGRES_DATABASE_PORT=
DATABASE_URL=
================================================
FILE: .envrc
================================================
use asdf
layout python
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: A bug would be defined as an issue / problem in the original requirement. If the feature works but could be enhanced please use the feature request option.
title: ''
labels: 'bug'
assignees: ''
---
**Describe the bug**
A clear and concise description of what the bug is.
**To Reproduce**
Steps to reproduce the behavior:
1. Go to '...'
2. Click on '....'
3. Scroll down to '....'
4. See error
**Expected behavior**
What do you believe should be happening?
**Screenshots or code snippets**
Screenshots help a lot. If applicable, add screenshots to help explain your problem.
**Desktop (please complete the following information):**
- OS: [e.g. Mac OSX, Windows]
- Version [e.g. Big Sur, 10]
**What database are you using?**
- Type: [e.g. Postgres, MySQL, SQLite]
- Version [e.g. 8, 9.1, 10.5]
- Masonite ORM [e.g. v1.0.26, v1.0.27]
**Additional context**
Any other steps you are doing or any other related information that will help us debug the problem please put here.
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request or enhancement
about: Suggest an idea or improvement for this project.
title: ''
labels: enhancement, feature request
assignees: ''
---
**Describe the feature as you'd like to see it**
A clear and concise description of what you want to happen.
**What do we currently have to do now?**
Give some examples or code snippets on the current way of doing things.
**Additional context**
Add any other context or screenshots about the feature request here.
- [ ] Is this a breaking change?
================================================
FILE: .github/workflows/pythonapp.yml
================================================
name: Test Application
on: [push, pull_request]
jobs:
build:
runs-on: ubuntu-20.04
services:
postgres:
image: postgres:10.8
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
ports:
# will assign a random free host port
- 5432/tcp
# needed because the postgres container does not provide a healthcheck
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
mysql:
image: mysql:5.7
env:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: orm
ports:
- 3306
options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3
strategy:
matrix:
python-version: ["3.6", "3.7", "3.8", "3.9"]
name: Python ${{ matrix.python-version }}
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
make init-ci
- name: Test with pytest
env:
POSTGRES_DATABASE_HOST: localhost
POSTGRES_DATABASE_DATABASE: postgres
POSTGRES_DATABASE_USER: postgres
POSTGRES_DATABASE_PASSWORD: postgres
POSTGRES_DATABASE_PORT: ${{ job.services.postgres.ports[5432] }}
MYSQL_DATABASE_HOST: localhost
MYSQL_DATABASE_DATABASE: orm
MYSQL_DATABASE_USER: root
MYSQL_DATABASE_PORT: ${{ job.services.mysql.ports[3306] }}
DB_CONFIG_PATH: tests/integrations/config/database.py
run: |
python orm migrate --connection postgres
python orm migrate --connection mysql
make test
lint:
runs-on: ubuntu-20.04
name: Lint
steps:
- uses: actions/checkout@v1
- name: Set up Python 3.6
uses: actions/setup-python@v4
with:
python-version: 3.6
- name: Install Flake8
run: |
pip install flake8-pyproject
- name: Lint
run: make lint
================================================
FILE: .github/workflows/pythonpublish.yml
================================================
name: Upload Python Package
on:
release:
types: [created]
jobs:
build:
runs-on: ubuntu-20.04
services:
postgres:
image: postgres:10.8
env:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: postgres
ports:
# will assign a random free host port
- 5432/tcp
# needed because the postgres container does not provide a healthcheck
options: --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5
mysql:
image: mysql:5.7
env:
MYSQL_ALLOW_EMPTY_PASSWORD: yes
MYSQL_DATABASE: orm
ports:
- 3306
options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3
strategy:
matrix:
python-version: ["3.6"]
name: Python ${{ matrix.python-version }}
steps:
- uses: actions/checkout@v1
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
make init-ci
- name: Test with Pytest and Publish to PYPI
env:
POSTGRES_DATABASE_HOST: localhost
POSTGRES_DATABASE_DATABASE: postgres
POSTGRES_DATABASE_USER: postgres
POSTGRES_DATABASE_PASSWORD: postgres
POSTGRES_DATABASE_PORT: ${{ job.services.postgres.ports[5432] }}
MYSQL_DATABASE_HOST: localhost
MYSQL_DATABASE_DATABASE: orm
MYSQL_DATABASE_USER: root
MYSQL_DATABASE_PORT: ${{ job.services.mysql.ports[3306] }}
TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }}
TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }}
DB_CONFIG_PATH: tests/integrations/config/database.py
run: |
python orm migrate --connection postgres
python orm migrate --connection mysql
make publish
- name: Discord notification
env:
DISCORD_WEBHOOK: ${{ secrets.DISCORD_WEBHOOK }}
uses: Ilshidur/action-discord@master
with:
args: "{{ EVENT_PAYLOAD.repository.full_name }} {{ EVENT_PAYLOAD.release.tag_name }} has been released. Checkout the full release notes here: {{ EVENT_PAYLOAD.release.html_url }}"
publish:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: publish-to-conda
uses: fcakyon/conda-publish-action@v1.3
with:
subdir: "conda"
anacondatoken: ${{ secrets.ANACONDA_TOKEN }}
platforms: "win osx linux"
================================================
FILE: .gitignore
================================================
venv
.direnv
.python-version
.vscode
.pytest_*
**/*__pycache__*
**/*.DS_Store*
masonite_validation*
dist
.env
*.db
*.sqlite3
.idea
**/*.egg-info
htmlcov/*
coverage.xml
.coverage
*.log
build
/orm.sqlite3
/.bootstrapped-pip
/.ignore-pre-commit
================================================
FILE: .pre-commit-config.yaml
================================================
repos:
- repo: https://github.com/psf/black
rev: 25.1.0
hooks:
- id: black
exclude: |
(?x)(
^build|
^conda
)
- repo: https://github.com/pycqa/isort
rev: 6.0.1
hooks:
- id: isort
exclude: |
(?x)(
^build|
^conda
)
- repo: https://github.com/pycqa/flake8
rev: 7.1.2
hooks:
- id: flake8
additional_dependencies: [flake8-pyproject]
exclude: |
(?x)(
^build|
^conda
)
================================================
FILE: .pypirc
================================================
[distutils]
index-servers =
pypi
pypitest
[pypi]
username=username
password=password
================================================
FILE: .tool-versions
================================================
python 3.8.10
================================================
FILE: CONTRIBUTING.md
================================================
# Contributing Guide
This guide is intended to explain how to contribute to this project.
## Preface
Note that you do not need to write code in order to contribute to the project. You can contribute your voice, your ideas, past experiences or just join general discussions we are having in GitHub or the Slack channel. Whether it's 1 hour per day or 1 minute per week, all input and ideas are important for the success of the project. That one sentence could lead to more discussion and ideas.
If you have any questions at all then be sure to join the [Slack Channel](https://slack.masoniteproject.com).
If you are interested in the project then it would be a great idea to read the "White Paper". This is a document about how the project works and how the classes all work together. The White Paper can be [Found Here](https://orm.masoniteproject.com/white-page)
## Issues
Everything really should start with opening an issue or finding an issue. If you feel you have an idea for how the project can be improved, no matter how small, you should open an issue so we can have an open discussion with the maintainers of the project.
We can discuss in that issue the solution to the problem or feature you have. If we do not feel it fits within the project then we will close the issue. Feel free to open a new issue if new information comes up.
If there is already an issue open that you want to contribute ideas to, have information to add to the discussion, or want to contribute to the issue by writing code to complete the issue then please comment on the issue saying you would like to contribute to it.
## Labels
To improve the quality of issues, please add any related labels to the issue you think are most relevant. You may add as many as you think make sense. There are tag descriptions on the labels section of the repo so please read those descriptions to choose which labels best work for the issue.
**Please do not use any of the difficulty labels (easy, medium or hard). A maintainer will label the issue with the difficulty level after reviewing the issue**
## Difficulty Levels
Before contributing, it is assumed you have basic Python or programming skills and you are able to understand the issues enough to have a discussion about it without much information direction. All issues are marked with a difficulty level to determine how much effort will be involved in closing the issue. There are several difficulty level issues:
**good first issue** - Issues marked with this label are great issues to take if you have never contributed to open source before. These issues typically have a step by step solution in the issue and are intended for first-time contributors, to expand the pool of maintainers.
**easy** - Issues marked as easy are great issues to take if you have never contributed to this project before. Take this opportunity to take a simple issue to understand how some of the code works together and a simple test.
**medium** - Issues marked with this should not be worked on by someone who has not contributed to the project before. These issues assume you have basic knowledge of the codebase and can work on the issue with little direction. Discussions should be had on these issues on the best way to solve and close them.
**hard** - These issues should really not be worked on unless you are a maintainer of the Masonite organization. These issues are very involved and assume advanced knowledge of the codebase. You may contribute your voice to the issue but it is not advised you work on these issues unless you are a maintainer or have contributed in the past and have completed a medium difficulty task
## Pull Request Flow
If you choose to contribute to an issue via code contribution then please follow the steps below:
* First you will need to fork the repository. You can do this directly in GitHub by clicking the fork icon in the top right corner of the repository.
* You should then checkout the repository to your computer
* Make the code change and push up your changes to a local branch.
* **The branch should follow a common naming convention. If the issue is #123 then your branch should be called `feature/123`. This helps me identify which issue the branch is supposed to fix.**
* You should then open a pull request to the repository.
* **All tests are required to be written before merging a pull request.** If you do not know how to write tests you can open the pull request without tests and we can discuss the best way to test the code you wrote. A maintainer or contributor could also step in and write tests for you
Once the pull request is open, the code will be reviewed and we will discuss how this particular solution to the problem solves the original issue. If there are code improvements or corrections to be made then they will be discussed with maintainers of the project.
## Running Tests
You should run all tests locally and make sure they pass before writing any code. This way you can be sure if your code is not breaking any tests that may be failing for other reasons.
You should set up a virtual environment and run tests via pytest:
```
$ python -m venv venv
$ source venv/bin/activate
$ python -m pytest
```
This should run all tests successfully. The code was written in a way where you do not need a database to test the code so all tests should run fine.
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2020 Joseph Mancuso
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: MANIFEST.in
================================================
# include src/package/some/directory/*
include src/masoniteorm/commands/stubs/*
================================================
FILE: README.md
================================================
Masonite ORM
## Installation & Usage
All documentation can be found here [https://orm.masoniteproject.com](https://orm.masoniteproject.com).
Hop on [Masonite Discord Community](https://discord.gg/TwKeFahmPZ) to ask any questions you need!
## Contributing
If you would like to contribute please read the [Contributing Documentation](CONTRIBUTING.md) here.
## License
Masonite ORM is open-sourced software licensed under the [MIT License](LICENSE).
================================================
FILE: TODO.md
================================================
- [x] fix scopes - need to find a new way to perform scopes
- [x] scopes need to be set on the model and then passed off to the query builder
- [x] global scopes
- on select need to call a scope
- on delete need to call a scope
- need to be able to remove global scopes
- need to be able to call something like with_trashed()
- this needs to remove global scopes only from the soft deletes class
================================================
FILE: app/observers/UserObserver.py
================================================
"""User Observer"""
from masoniteorm.models import Model
class UserObserver:
    """Observer for User model lifecycle events.

    Every handler below is a no-op placeholder: each receives the model
    instance involved in the event and intentionally does nothing. Add
    behavior to the relevant hook as the application requires.
    """

    def created(self, clients):
        """Handle the Clients "created" event."""
        pass

    def creating(self, clients):
        """Handle the Clients "creating" event."""
        pass

    def saving(self, clients):
        """Handle the Clients "saving" event."""
        pass

    def saved(self, clients):
        """Handle the Clients "saved" event."""
        pass

    def updating(self, clients):
        """Handle the Clients "updating" event."""
        pass

    def updated(self, clients):
        """Handle the Clients "updated" event."""
        pass

    def booted(self, clients):
        """Handle the Clients "booted" event."""
        pass

    def booting(self, clients):
        """Handle the Clients "booting" event."""
        pass

    def hydrating(self, clients):
        """Handle the Clients "hydrating" event."""
        pass

    def hydrated(self, clients):
        """Handle the Clients "hydrated" event."""
        pass

    def deleting(self, clients):
        """Handle the Clients "deleting" event."""
        pass

    def deleted(self, clients):
        """Handle the Clients "deleted" event."""
        pass
================================================
FILE: cc.py
================================================
"""Sandbox experimental file used to quickly feature test features of the package
"""
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.connections import MySQLConnection, PostgresConnection
from src.masoniteorm.query.grammars import MySQLGrammar, PostgresGrammar
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import has_many
import inspect
# builder = QueryBuilder(connection=PostgresConnection, grammar=PostgresGrammar).table("users").on("postgres")
# print(builder.where("id", 1).or_where(lambda q: q.where('id', 2).or_where('id', 3)).get())
class User(Model):
    """Sandbox model mapped to the ``users`` table."""

    # Connection name looked up in the resolver's connection details.
    __connection__ = "t"
    __table__ = "users"
    # Columns treated as datetime values when hydrating/serializing.
    __dates__ = ["verified_at"]

    @has_many("id", "user_id")
    def articles(self):
        # NOTE(review): ``Article`` is not defined anywhere in this file;
        # calling this relationship would raise NameError — confirm intent
        # (this is an experimental sandbox script).
        return Article
class Company(Model):
    """Sandbox model using the sqlite connection; table name is inferred."""

    __connection__ = "sqlite"
# user = User.create({"name": "phill", "email": "phill"})
# print(inspect.isclass(User))
# Smoke-test: fetch the first user row and print its serialized form.
# NOTE(review): fails with a live NameError/connection error unless the
# "t" connection configured on User actually exists — sandbox file only.
user = User.first()
# user.update({"verified_at": None, "updated_at": None})
print(user.serialize())
# print(user.serialize())
# print(User.first())
================================================
FILE: conda/conda_build_config.yaml
================================================
python:
- 3.6
- 3.7
- 3.8
- 3.9
================================================
FILE: conda/meta.yaml
================================================
{% set data = load_setup_py_data() %}
package:
name: masonite-orm
version: {{ data['version'] }}
source:
path: ..
build:
number: 0
script: python setup.py install --single-version-externally-managed --record=record.txt
requirements:
build:
- python
run:
- python
test:
run:
- python -m pytest
about:
home: {{ data['url'] }}
license: {{ data['license'] }}
summary: {{ data['description'] }}
================================================
FILE: config/test-database.py
================================================
from src.masoniteorm.connections import ConnectionResolver
# Connection details for the test suite, keyed by connection name.
# "default" names the connection used when none is given explicitly.
DATABASES = {
    "default": "mysql",
    "mysql": {
        "host": "127.0.0.1",
        "driver": "mysql",
        "database": "masonite",
        "user": "root",
        "password": "",
        "port": 3306,
        # When True, executed queries are logged.
        "log_queries": False,
        # Driver-specific connection options.
        "options": {
            #
        }
    },
    "postgres": {
        "host": "127.0.0.1",
        "driver": "postgres",
        "database": "masonite",
        "user": "root",
        "password": "",
        "port": 5432,
        "log_queries": False,
        "options": {
            #
        }
    },
    "sqlite": {
        "driver": "sqlite",
        # File-backed SQLite database in the project root.
        "database": "masonite.sqlite3",
    }
}

# Register the connection details globally so models/builders can resolve them.
DB = ConnectionResolver().set_connection_details(DATABASES)
================================================
FILE: databases/migrations/2018_01_09_043202_create_users_table.py
================================================
from src.masoniteorm.migrations import Migration
from tests.User import User
class CreateUsersTable(Migration):
    """Creates the ``users`` table and seeds one default user row."""

    def up(self):
        """Run the migrations."""
        with self.schema.create('users') as blueprint:
            blueprint.increments('id')
            blueprint.string('name')
            blueprint.string('email').unique()
            blueprint.string('password')
            blueprint.string('second_password').nullable()
            blueprint.string('remember_token').nullable()
            blueprint.timestamp('verified_at').nullable()
            blueprint.timestamps()

        # On dry runs the schema is not applied, so skip seeding.
        if self.schema._dry:
            return

        User.on(self.connection).set_schema(self.schema_name).create({
            'name': 'Joe',
            'email': 'joe@email.com',
            'password': 'secret'
        })

    def down(self):
        """Revert the migrations."""
        self.schema.drop('users')
================================================
FILE: databases/migrations/2020_04_17_000000_create_friends_table.py
================================================
from src.masoniteorm.migrations.Migration import Migration
class CreateFriendsTable(Migration):
    """Creates the ``friends`` table used by the test suite."""

    def up(self):
        """Run the migrations."""
        with self.schema.create('friends') as blueprint:
            blueprint.increments('id')
            blueprint.string('name')
            blueprint.integer('age')

    def down(self):
        """Revert the migrations."""
        self.schema.drop('friends')
================================================
FILE: databases/migrations/2020_04_17_00000_create_articles_table.py
================================================
from src.masoniteorm.migrations.Migration import Migration
class CreateArticlesTable(Migration):
    """Migration for the ``fans`` table.

    NOTE(review): despite the class/file name ("articles"), this migration
    creates and drops the ``fans`` table — presumably intentional for the
    test fixtures; confirm before renaming.
    """

    def up(self):
        """Run the migrations."""
        with self.schema.create('fans') as blueprint:
            blueprint.increments('id')
            blueprint.string('name')
            blueprint.integer('age')

    def down(self):
        """Revert the migrations."""
        self.schema.drop('fans')
================================================
FILE: databases/migrations/2020_10_20_152904_create_table_schema_migration.py
================================================
"""CreateTableSchemaMigration Migration."""
from src.masoniteorm.migrations import Migration
class CreateTableSchemaMigration(Migration):
    """Creates the ``table_schema`` table used by schema-related tests."""

    def up(self):
        """Run the migrations."""
        with self.schema.create("table_schema") as blueprint:
            blueprint.increments('id')
            blueprint.string('name')
            blueprint.timestamps()

    def down(self):
        """Revert the migrations."""
        self.schema.drop("table_schema")
================================================
FILE: databases/migrations/__init__.py
================================================
import os
import sys

# Make the repository root importable when migration modules are loaded
# directly (migrations import from ``src`` and ``tests`` by absolute path).
project_root = os.getcwd()
sys.path.append(project_root)
================================================
FILE: databases/seeds/database_seeder.py
================================================
"""Base Database Seeder Module."""
from src.masoniteorm.seeds import Seeder
from .user_table_seeder import UserTableSeeder
class DatabaseSeeder(Seeder):
    """Root seeder that dispatches to the individual table seeders."""

    def run(self):
        """Run the database seeds."""
        # Delegate to each table-level seeder in turn.
        self.call(UserTableSeeder)
================================================
FILE: databases/seeds/user_table_seeder.py
================================================
"""UserTableSeeder Seeder."""
from src.masoniteorm.seeds import Seeder
from src.masoniteorm.factories import Factory as factory
from tests.User import User
factory.register(User, lambda faker: {'email': faker.email()})
class UserTableSeeder(Seeder):
    """Seeds five users via the registered User factory."""

    def run(self):
        """Run the database seeds."""
        # Fixed attributes; the factory fills in the random email.
        attributes = {
            'name': 'Joe',
            'password': 'joe',
        }
        factory(User, 5).create(attributes)
================================================
FILE: makefile
================================================
SHELL := /bin/bash
init: .env .bootstrapped-pip .git/hooks/pre-commit
init-ci:
touch .ignore-pre-commit
make init
.bootstrapped-pip: requirements.txt requirements.dev
pip install -r requirements.txt -r requirements.dev
touch .bootstrapped-pip
.git/hooks/pre-commit:
@if ! test -e ".ignore-pre-commit"; then \
pip install pre-commit; \
pre-commit install --install-hooks; \
fi
.env:
cp .env-example .env
# Create MySQL Database
# Create Postgres Database
test: init
python -m pytest tests
ci:
make test
check: format sort lint
lint:
flake8 src/masoniteorm/
format: init
black src/masoniteorm tests/
sort: init
isort src/masoniteorm tests/
coverage:
python -m pytest --cov-report term --cov-report xml --cov=src/masoniteorm tests/
python -m coveralls
show:
python -m pytest --cov-report term --cov-report html --cov=src/masoniteorm tests/
cov:
python -m pytest --cov-report term --cov-report xml --cov=src/masoniteorm tests/
publish:
pip install twine
make test
python setup.py sdist
twine upload dist/*
rm -fr build dist .egg masonite.egg-info
rm -rf dist/*
pub:
python setup.py sdist
twine upload dist/*
rm -fr build dist .egg masonite.egg-info
rm -rf dist/*
pypirc:
cp .pypirc ~/.pypirc
================================================
FILE: orm
================================================
"""Craft Command.
This module is really used for backup only if the masonite CLI cannot import this for you.
This can be used by running "python craft". This module is not run when the CLI can
successfully import commands for you.
"""
from cleo import Application
from src.masoniteorm.commands import (
MigrateCommand,
MigrateRollbackCommand,
MigrateRefreshCommand,
MigrateFreshCommand,
MakeMigrationCommand,
MakeObserverCommand,
MakeModelCommand,
MigrateStatusCommand,
MigrateResetCommand,
MakeSeedCommand,
MakeModelDocstringCommand,
SeedRunCommand,
)
application = Application("ORM Version:", 0.1)
application.add(MigrateCommand())
application.add(MigrateRollbackCommand())
application.add(MigrateRefreshCommand())
application.add(MigrateFreshCommand())
application.add(MakeMigrationCommand())
application.add(MakeModelCommand())
application.add(MakeModelDocstringCommand())
application.add(MakeObserverCommand())
application.add(MigrateResetCommand())
application.add(MigrateStatusCommand())
application.add(MakeSeedCommand())
application.add(SeedRunCommand())
if __name__ == "__main__":
application.run()
================================================
FILE: pyproject.toml
================================================
[tool.black]
target-version = ['py38']
include = '\.pyi?$'
line-length = 79
[tool.isort]
profile = "black"
multi_line_output = 3
include_trailing_comma = true
force_grid_wrap = 0
use_parentheses = true
ensure_newline_before_comments = true
[tool.flake8]
ignore = ['E501', 'E203', 'E128', 'E402', 'E731', 'F821', 'E712', 'W503', 'F811']
#max-line-length = 79
#max-complexity = 18
per-file-ignores = [
'__init__.py:F401',
]
================================================
FILE: pytest.ini
================================================
[pytest]
env =
D:DB_CONFIG_PATH=config/test-database
================================================
FILE: requirements.dev
================================================
flake8-pyproject
black
faker
pytest
pytest-cov
pytest-env
pymysql
isort
================================================
FILE: requirements.txt
================================================
inflection==0.3.1
psycopg2-binary
pyodbc
pendulum>=2.1,<3.1
cleo>=0.8.0,<0.9
python-dotenv==0.14.0
================================================
FILE: setup.py
================================================
"""Setup script for the masonite-orm package."""
from setuptools import setup

# Use the README as the PyPI long description. An explicit encoding avoids
# platform-dependent decoding failures on the README's unicode content.
with open("README.md", "r", encoding="utf-8") as fh:
    long_description = fh.read()

setup(
    name="masonite-orm",
    # Versions should comply with PEP440. For a discussion on single-sourcing
    # the version across setup.py and the project code, see
    # https://packaging.python.org/en/latest/single_source_version.html
    version="2.24.0",
    package_dir={"": "src"},
    description="The Official Masonite ORM",
    long_description=long_description,
    long_description_content_type="text/markdown",
    # The project's main homepage.
    url="https://github.com/masoniteframework/orm",
    # Author details
    author="Joe Mancuso",
    author_email="joe@masoniteproject.com",
    license="MIT",
    # Include non-Python files declared in MANIFEST.in (command stubs, etc.).
    include_package_data=True,
    # Run-time dependencies installed by pip alongside this package. For an
    # analysis of "install_requires" vs pip's requirements files see:
    # https://packaging.python.org/en/latest/requirements.html
    install_requires=[
        "inflection>=0.3,<0.6",
        "pendulum>=2.1,<3.1",
        "faker>=4.1.0,<14.0",
        "cleo>=0.8.0,<0.9",
    ],
    # See https://pypi.python.org/pypi?%3Aaction=list_classifiers
    classifiers=[
        # How mature is this project? Common values are
        #   3 - Alpha
        #   4 - Beta
        #   5 - Production/Stable
        "Development Status :: 5 - Production/Stable",
        # Indicate who your project is intended for
        "Intended Audience :: Developers",
        "Topic :: Software Development :: Build Tools",
        "Environment :: Web Environment",
        # Pick your license as you wish (should match "license" above)
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
        # Specify the Python versions you support here.
        "Programming Language :: Python :: 3.6",
        "Programming Language :: Python :: 3.7",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        # NOTE: "Framework :: Masonite" was listed twice; the duplicate
        # entry has been removed.
        "Framework :: Masonite",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
    # What does your project relate to?
    keywords="Masonite, MasoniteFramework, Python, ORM",
    # Packages are listed manually; find_packages() would also work.
    packages=[
        "masoniteorm",
        "masoniteorm.collection",
        "masoniteorm.commands",
        "masoniteorm.connections",
        "masoniteorm.expressions",
        "masoniteorm.factories",
        "masoniteorm.helpers",
        "masoniteorm.migrations",
        "masoniteorm.models",
        "masoniteorm.observers",
        "masoniteorm.pagination",
        "masoniteorm.providers",
        "masoniteorm.query",
        "masoniteorm.query.grammars",
        "masoniteorm.query.processors",
        "masoniteorm.relationships",
        "masoniteorm.schema",
        "masoniteorm.schema.platforms",
        "masoniteorm.scopes",
        "masoniteorm.seeds",
        "masoniteorm.testing",
    ],
    # Optional dependency groups, installable with:
    #   $ pip install masonite-orm[test]
    extras_require={
        "test": ["coverage", "pytest"],
    },
    # Console entry point for the standalone masonite-orm CLI.
    entry_points={
        "console_scripts": [
            "masonite-orm = masoniteorm.commands.Entry:application.run",
        ],
    },
)
================================================
FILE: src/masoniteorm/.gitignore
================================================
================================================
FILE: src/masoniteorm/__init__.py
================================================
from .models import Model
from .factories.Factory import Factory
================================================
FILE: src/masoniteorm/collection/Collection.py
================================================
import json
import random
import operator
from functools import reduce
class Collection:
    """Wraps various data types to make working with them easier."""

    def __init__(self, items=None):
        # Falls back to an empty list. A dict may also be passed; several
        # methods (items, pluck, __getitem__) account for dict backing.
        self._items = items or []
        self.__appends__ = []

    def take(self, number: int):
        """Takes a specific number of results from the items.

        Arguments:
            number {integer} -- The number of results to take. A negative
                number takes from the end of the items.

        Returns:
            Collection
        """
        if number < 0:
            return self[number:]

        return self[:number]

    def first(self, callback=None):
        """Takes the first result in the items.

        If a callback is given then the first result will be the result after the filter.

        Keyword Arguments:
            callback {callable} -- Used to filter the results before returning the first item. (default: {None})

        Returns:
            mixed -- Returns whatever the first item is, or None when empty.
        """
        filtered = self
        if callback:
            filtered = self.filter(callback)

        response = None

        if filtered:
            response = filtered[0]

        return response

    def items(self):
        """Proxy to the backing dict's items() (dict-backed collections only)."""
        return self._items.items()

    def last(self, callback=None):
        """Takes the last result in the items.

        If a callback is given then the last result will be the result after the filter.

        Keyword Arguments:
            callback {callable} -- Used to filter the results before returning the last item. (default: {None})

        Returns:
            mixed -- Returns whatever the last item is.
        """
        filtered = self
        if callback:
            filtered = self.filter(callback)

        return filtered[-1]

    def all(self):
        """Returns all the items.

        Returns:
            mixed -- Returns all items.
        """
        return self._items

    def avg(self, key=None):
        """Returns the average of the items.

        If a key is given it will return the average of all the values of the key.

        Keyword Arguments:
            key {string} -- The key to use to find the average of all the values of that key. (default: {None})

        Returns:
            int -- Returns the average (0 when the items are not summable).
        """
        result = 0
        items = self._get_value(key) or self._items
        try:
            result = sum(items) / len(items)
        except TypeError:
            pass
        return result

    def max(self, key=None):
        """Returns the max of the items.

        If a key is given it will return the max of all the values of the key.

        Keyword Arguments:
            key {string} -- The key to use to find the max of all the values of that key. (default: {None})

        Returns:
            int -- Returns the max (0 when not comparable or empty).
        """
        result = 0
        items = self._get_value(key) or self._items

        try:
            return max(items)
        except (TypeError, ValueError):
            pass
        return result

    def min(self, key=None):
        """Returns the min of the items.

        If a key is given it will return the min of all the values of the key.

        Keyword Arguments:
            key {string} -- The key to use to find the min of all the values of that key. (default: {None})

        Returns:
            int -- Returns the min (0 when not comparable or empty).
        """
        result = 0
        items = self._get_value(key) or self._items

        try:
            return min(items)
        except (TypeError, ValueError):
            pass
        return result

    def chunk(self, size: int):
        """Chunks the items into sub-collections of at most `size` elements.

        Keyword Arguments:
            size {integer} -- The number of values in each chunk.

        Returns:
            Collection -- A collection of Collection chunks.
        """
        items = []
        for i in range(0, self.count(), size):
            items.append(self[i : i + size])
        return self.__class__(items)

    def collapse(self):
        """Flattens one level of nested collections/lists into a single collection."""
        items = []
        for item in self:
            items += self.__get_items(item)
        return self.__class__(items)

    def contains(self, key, value=None):
        """Check whether the collection contains a value.

        With two arguments, checks for an item whose `key` equals `value`.
        With one argument, `key` may be a callable predicate or a plain
        value to search for.
        """
        # FIX: compare against None so falsy values (0, "", False) are
        # still matched; `if value:` silently ignored them.
        if value is not None:
            return self.contains(lambda x: self._data_get(x, key) == value)

        if self._check_is_callable(key, raise_exception=False):
            return self.first(key) is not None

        return key in self

    def count(self):
        """Return the number of items."""
        return len(self._items)

    def diff(self, items):
        """Return the items not present in `items` (list or Collection)."""
        items = self.__get_items(items)
        return self.__class__([x for x in self if x not in items])

    def each(self, callback):
        """Apply `callback` to each item in place.

        NOTE: iteration stops at the first falsy callback result; items
        before that point are replaced by the callback's return value.
        """
        self._check_is_callable(callback)

        for k, v in enumerate(self):
            result = callback(v)
            if not result:
                break
            self[k] = result

    def every(self, callback):
        """Return True when `callback` is truthy for every item."""
        self._check_is_callable(callback)
        return all([callback(x) for x in self])

    def filter(self, callback):
        """Return a new collection of items for which `callback` is truthy."""
        self._check_is_callable(callback)
        return self.__class__(list(filter(callback, self)))

    def flatten(self):
        """Recursively flatten nested lists and dict values into one collection."""

        def _flatten(items):
            if isinstance(items, dict):
                for v in items.values():
                    for x in _flatten(v):
                        yield x
            elif isinstance(items, list):
                for i in items:
                    for j in _flatten(i):
                        yield j
            else:
                yield items

        return self.__class__(list(_flatten(self._items)))

    def forget(self, *keys):
        """Delete the given indexes/keys from the collection.

        Indexes are removed highest-first so earlier deletions do not
        shift the positions of later ones.
        """
        keys = reversed(sorted(keys))

        for key in keys:
            del self[key]

        return self

    def for_page(self, page, number):
        """Return the slice `self[page:number]` as a new collection."""
        return self.__class__(self[page:number])

    def get(self, key, default=None):
        """Return the item at `key`, or the (possibly callable) default on IndexError."""
        try:
            return self[key]
        except IndexError:
            pass
        return self._value(default)

    def implode(self, glue=",", key=None):
        """Join the items (or the values of `key`) into a single string."""
        first = self.first()
        if not isinstance(first, str) and key:
            return glue.join(self.pluck(key))
        return glue.join([str(x) for x in self])

    def is_empty(self):
        """Return True when the collection has no items."""
        return not self

    def map(self, callback):
        """Return a new collection with `callback` applied to every item."""
        self._check_is_callable(callback)
        items = [callback(x) for x in self]
        return self.__class__(items)

    def map_into(self, cls, method=None, **kwargs):
        """Map each item into `cls(item)`, or `cls.method(item, **kwargs)` when given."""
        results = []
        for item in self:
            if method:
                results.append(getattr(cls, method)(item, **kwargs))
            else:
                results.append(cls(item))

        return self.__class__(results)

    def merge(self, items):
        """Append the given list or Collection onto this collection in place."""
        if isinstance(items, Collection):
            items = items._items
        elif not isinstance(items, list):
            raise ValueError("Unable to merge uncompatible types")

        items = self.__get_items(items)
        self._items += items
        return self

    def pluck(self, value, key=None, keep_nulls=True):
        """Collect the `value` field of each item.

        When `key` is given, returns a dict keyed by each item's `key`
        field. `keep_nulls=False` drops None values.
        """
        if key:
            attributes = {}
        else:
            attributes = []

        if isinstance(self._items, dict):
            return Collection([self._items.get(value)])

        for item in self:
            if isinstance(item, dict):
                iterable = item.items()
            elif hasattr(item, "serialize"):
                iterable = item.serialize().items()
            else:
                iterable = self.all().items()

            for k, v in iterable:
                if keep_nulls is False and v is None:
                    continue
                if k == value:
                    if key:
                        attributes[self._data_get(item, key)] = self._data_get(
                            item, value
                        )
                    else:
                        attributes.append(v)

        return Collection(attributes)

    def pop(self):
        """Remove and return the last item."""
        last = self._items.pop()
        return last

    def prepend(self, value):
        """Insert `value` at the front of the collection."""
        self._items.insert(0, value)
        return self

    def pull(self, key):
        """Remove and return the item at `key`."""
        value = self.get(key)
        self.forget(key)
        return value

    def push(self, value):
        """Append `value` to the end of the collection."""
        self._items.append(value)

    def put(self, key, value):
        """Set the item at `key` to `value`."""
        self._items[key] = value
        return self

    def random(self, count=None):
        """Returns a random item of the collection.

        With `count`, narrows the collection in place to a random sample
        of that size and returns self.
        """
        collection_count = self.count()
        if collection_count == 0:
            return None
        elif count and count > collection_count:
            raise ValueError("count argument must be inferior to collection length.")
        elif count:
            self._items = random.sample(self._items, k=count)
            return self
        else:
            return random.choice(self._items)

    def reduce(self, callback, initial=0):
        """Fold the items with `callback`, starting from `initial`."""
        return reduce(callback, self, initial)

    def reject(self, callback):
        """Replace the items with the truthy results of `callback` per item.

        NOTE: despite the name, this keeps callback results (via
        _get_value), falling back to the original items when none match.
        """
        self._check_is_callable(callback)

        items = self._get_value(callback) or self._items
        self._items = items

    def reverse(self):
        """Reverse the order of the items in place."""
        # FIX: slice the raw list; `self[::-1]` returns a Collection via
        # __getitem__, which previously nested a Collection inside _items.
        self._items = self._items[::-1]

    def serialize(self, *args, **kwargs):
        """Serialize every item (via serialize()/to_dict() when available)."""

        def _serialize(item):
            if self.__appends__:
                item.set_appends(self.__appends__)

            if hasattr(item, "serialize"):
                return item.serialize(*args, **kwargs)
            elif hasattr(item, "to_dict"):
                return item.to_dict()
            return item

        return list(map(_serialize, self))

    def add_relation(self, result=None):
        """Attach the given relation result dict to every model in the collection."""
        for model in self._items:
            model.add_relation(result or {})

        return self

    def shift(self):
        """Remove and return the first item."""
        return self.pull(0)

    def sort(self, key=None):
        """Sort the items in place, optionally by the given item key."""
        if key:
            self._items.sort(key=lambda x: x[key], reverse=False)
            return self

        self._items = sorted(self)
        return self

    def sum(self, key=None):
        """Sum the items, or the values of `key` when given (0 when not summable)."""
        result = 0
        items = self._get_value(key) or self._items
        try:
            result = sum(items)
        except TypeError:
            pass
        return result

    def to_json(self, **kwargs):
        """Serialize the collection to a JSON string."""
        return json.dumps(self.serialize(), **kwargs)

    def group_by(self, key):
        """Group the items into a dict-backed collection keyed by `key`."""
        from itertools import groupby

        # groupby requires adjacent equal keys, so sort by the key first.
        self.sort(key)

        new_dict = {}

        for k, v in groupby(self._items, key=lambda x: x[key]):
            new_dict.update({k: list(v)})

        return Collection(new_dict)

    def transform(self, callback):
        """Replace the items with the truthy results of `callback` per item."""
        self._check_is_callable(callback)
        self._items = self._get_value(callback)

    def unique(self, key=None):
        """Return a collection of unique items, optionally keyed by `key`."""
        if not key:
            items = list(set(self._items))
            return self.__class__(items)

        keys = set()
        items = []
        if isinstance(self.all(), dict):
            return self

        for item in self:
            if isinstance(item, dict):
                comparison = item.get(key)
            elif isinstance(item, str):
                comparison = item
            else:
                comparison = getattr(item, key)
            if comparison not in keys:
                items.append(item)
                keys.add(comparison)

        return self.__class__(items)

    def where(self, key, *args):
        """Filter items where `key` compares to a value.

        Call as where(key, value) for equality or where(key, op, value)
        with an operator string ('<', '<=', '==', '!=', '>', '>=').
        """
        op = "=="
        value = args[0]

        if len(args) >= 2:
            op = args[0]
            value = args[1]

        attributes = []

        for item in self._items:
            if isinstance(item, dict):
                comparison = item.get(key)
            else:
                comparison = getattr(item, key) if hasattr(item, key) else False
            if self._make_comparison(comparison, value, op):
                attributes.append(item)

        return self.__class__(attributes)

    def where_in(self, key, args: list) -> "Collection":
        """Filter items whose `key` value is contained in `args`."""
        # Compatibility patch - allow numeric strings to match integers
        # (if all args are numeric strings).
        # FIX: guard against empty args — all([]) is True, which previously
        # caused infinite recursion on an empty list.
        if args and all(
            [isinstance(arg, str) and arg.isnumeric() for arg in args]
        ):
            return self.where_in(key, [int(arg) for arg in args])

        attributes = []

        for item in self._items:
            if isinstance(item, dict):
                if key not in item:
                    continue
                comparison = item.get(key)
            else:
                if not hasattr(item, key):
                    continue
                comparison = getattr(item, key)

            if comparison in args:
                attributes.append(item)

        return self.__class__(attributes)

    def where_not_in(self, key, args: list) -> "Collection":
        """Filter items whose `key` value is NOT contained in `args`."""
        # Compatibility patch - allow numeric strings to match integers
        # (if all args are numeric strings).
        # FIX: guard against empty args — all([]) is True, which previously
        # caused infinite recursion on an empty list.
        if args and all(
            [isinstance(arg, str) and arg.isnumeric() for arg in args]
        ):
            return self.where_not_in(key, [int(arg) for arg in args])

        attributes = []

        for item in self._items:
            if isinstance(item, dict):
                if key not in item:
                    continue
                comparison = item.get(key)
            else:
                if not hasattr(item, key):
                    continue
                comparison = getattr(item, key)

            if comparison not in args:
                attributes.append(item)

        return self.__class__(attributes)

    def zip(self, items):
        """Pair each item with the corresponding element of `items`."""
        items = self.__get_items(items)
        if not isinstance(items, list):
            raise ValueError("The 'items' parameter must be a list or a Collection")

        _items = []
        for x, y in zip(self, items):
            _items.append([x, y])

        return self.__class__(_items)

    def set_appends(self, appends):
        """
        Set the attributes that should be appended to the Collection.

        :rtype: list
        """
        self.__appends__ += appends
        return self

    def _get_value(self, key):
        """Resolve `key` (attribute/dict-key name or callable) for each item."""
        if not key:
            return None

        items = []
        for item in self:
            if isinstance(key, str):
                # FIX: resolve lazily. getattr(item, key, item[key]) always
                # evaluated item[key] as the default, which raised for
                # objects that are not subscriptable.
                if hasattr(item, key):
                    items.append(getattr(item, key))
                elif key in item:
                    items.append(item[key])
            elif callable(key):
                result = key(item)
                if result:
                    items.append(result)
        return items

    def _data_get(self, item, key, default=None):
        """Get `key` from `item` by indexing or attribute, with a safe default."""
        try:
            if isinstance(item, (list, tuple, dict)):
                item = item[key]
            elif isinstance(item, object):
                item = getattr(item, key)
        except (IndexError, AttributeError, KeyError, TypeError):
            return self._value(default)

        return item

    def _value(self, value):
        """Return `value`, calling it first when it is callable."""
        if callable(value):
            return value()
        return value

    def _check_is_callable(self, callback, raise_exception=True):
        """Validate that `callback` is callable, raising or returning False."""
        if not callable(callback):
            if not raise_exception:
                return False
            raise ValueError("The 'callback' should be a function")
        return True

    def _make_comparison(self, a, b, op):
        """Apply the operator named by `op` to str(a) and str(b).

        NOTE: operands are compared as strings, so ordering operators use
        lexicographic (not numeric) order.
        """
        operators = {
            "<": operator.lt,
            "<=": operator.le,
            "==": operator.eq,
            "!=": operator.ne,
            ">": operator.gt,
            ">=": operator.ge,
        }
        return operators[op](str(a), str(b))

    def __iter__(self):
        for item in self._items:
            yield item

    def __eq__(self, other):
        other = self.__get_items(other)
        return other == self._items

    def __getitem__(self, item):
        if isinstance(item, slice):
            return self.__class__(self._items[item])

        # FIX: the original tested isinstance(item, dict), which can never
        # succeed usefully (dicts are unhashable keys); the intent is to
        # handle a dict-backed collection with a None fallback.
        if isinstance(self._items, dict):
            return self._items.get(item, None)

        try:
            return self._items[item]
        except KeyError:
            return None

    def __setitem__(self, key, value):
        self._items[key] = value

    def __delitem__(self, key):
        del self._items[key]

    def __ne__(self, other):
        other = self.__get_items(other)
        return other != self._items

    def __len__(self):
        return len(self._items)

    def __le__(self, other):
        other = self.__get_items(other)
        return self._items <= other

    def __lt__(self, other):
        other = self.__get_items(other)
        return self._items < other

    def __ge__(self, other):
        other = self.__get_items(other)
        return self._items >= other

    def __gt__(self, other):
        other = self.__get_items(other)
        return self._items > other

    @classmethod
    def __get_items(cls, items):
        """Unwrap a Collection argument to its raw items."""
        if isinstance(items, Collection):
            items = items.all()

        return items
================================================
FILE: src/masoniteorm/collection/__init__.py
================================================
from .Collection import Collection
================================================
FILE: src/masoniteorm/commands/CanOverrideConfig.py
================================================
from cleo import Command
class CanOverrideConfig(Command):
    """Command mixin that registers a global --config/-C option."""

    def __init__(self):
        super().__init__()
        self.add_option()

    def add_option(self):
        """Attach the config-path option to this command's definition."""
        description = (
            "The path to the ORM configuration file. If not given DB_CONFIG_PATH "
            "env variable will be used and finally 'config.database'."
        )
        # 8 is the required flag constant in cleo
        self._config.add_option("config", "C", 8, description=description)
================================================
FILE: src/masoniteorm/commands/CanOverrideOptionsDefault.py
================================================
from inflection import underscore
class CanOverrideOptionsDefault:
    """Command mixin allowing option default values to be overridden when the
    command is instantiated.

    Example: SomeCommand(app, option1="other/default").
    If an option's long name contains a dash, use an underscore in the keyword
    argument: SomeCommand(app, option_1="other/default") overrides 'option-1'.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.overriden_default = kwargs
        for name, option in self.config.options.items():
            # Cleo forbids _ in option names while - is allowed, but - cannot
            # appear in Python identifiers; underscore() bridges the two so an
            # option named 'option-a' is reachable via the 'option_a' kwarg.
            new_default = self.overriden_default.get(underscore(name))
            if not new_default:
                continue
            option.set_default(new_default)
================================================
FILE: src/masoniteorm/commands/Command.py
================================================
from .CanOverrideConfig import CanOverrideConfig
from .CanOverrideOptionsDefault import CanOverrideOptionsDefault
class Command(CanOverrideOptionsDefault, CanOverrideConfig):
    # Base class combining the option-default-override and config-path mixins;
    # concrete ORM commands subclass this and define their cleo signature in
    # their own class docstring. (A docstring is deliberately NOT added here,
    # since cleo derives command signatures from class docstrings.)
    pass
================================================
FILE: src/masoniteorm/commands/Entry.py
================================================
"""Craft Command.
This module is really used for backup only if the masonite CLI cannot import this for you.
This can be used by running "python craft". This module is not ran when the CLI can
successfully import commands for you.
"""
from cleo import Application
from . import (
MigrateCommand,
MigrateRollbackCommand,
MigrateRefreshCommand,
MigrateFreshCommand,
MakeMigrationCommand,
MakeModelCommand,
MakeModelDocstringCommand,
MakeObserverCommand,
MigrateStatusCommand,
MigrateResetCommand,
MakeSeedCommand,
SeedRunCommand,
ShellCommand,
)
application = Application("ORM Version:", 0.1)
application.add(MigrateCommand())
application.add(MigrateRollbackCommand())
application.add(MigrateRefreshCommand())
application.add(MigrateFreshCommand())
application.add(MakeMigrationCommand())
application.add(MakeModelCommand())
application.add(MakeModelDocstringCommand())
application.add(MakeObserverCommand())
application.add(MigrateResetCommand())
application.add(MigrateStatusCommand())
application.add(MakeSeedCommand())
application.add(SeedRunCommand())
application.add(ShellCommand())
if __name__ == "__main__":
application.run()
================================================
FILE: src/masoniteorm/commands/MakeMigrationCommand.py
================================================
import datetime
import os
import pathlib
from inflection import camelize, tableize
from .Command import Command
class MakeMigrationCommand(Command):
    """
    Creates a new migration file.

    migration
        {name : The name of the migration}
        {--c|create=None : The table to create}
        {--t|table=None : The table to alter}
        {--d|directory=databases/migrations : The location of the migration directory}
    """

    def handle(self):
        name = self.argument("name")
        timestamp = datetime.datetime.today()

        # "None" is cleo's string default, meaning the flag was not supplied.
        if self.option("create") != "None":
            table = self.option("create")
            stub_file = "create_migration"
        else:
            table = self.option("table")
            stub_file = "table_migration"

        if table == "None":
            # Neither --create nor --table given: derive the table name from
            # the migration name (e.g. create_users_table -> users).
            table = tableize(name.replace("create_", "").replace("_table", ""))
            stub_file = "create_migration"

        migration_directory = self.option("directory")

        stub_path = os.path.join(
            pathlib.Path(__file__).parent.absolute(), f"stubs/{stub_file}.stub"
        )
        with open(stub_path) as fp:
            output = fp.read()

        output = output.replace("__MIGRATION_NAME__", camelize(name)).replace(
            "__TABLE_NAME__", table
        )

        file_name = f"{timestamp.strftime('%Y_%m_%d_%H%M%S')}_{name}.py"
        with open(os.path.join(os.getcwd(), migration_directory, file_name), "w") as fp:
            fp.write(output)

        self.info(
            f"Migration file created: {os.path.join(migration_directory, file_name)}"
        )
================================================
FILE: src/masoniteorm/commands/MakeModelCommand.py
================================================
import os
import pathlib
from inflection import camelize, tableize, underscore
from .Command import Command
class MakeModelCommand(Command):
    """
    Creates a new model file.

    model
        {name : The name of the model}
        {--m|migration : Optionally create a migration file}
        {--s|seeder : Optionally create a seeder file}
        {--c|create : If the migration file should create a table}
        {--t|table : If the migration file should modify an existing table}
        {--p|pep : Makes the file into pep 8 standards}
        {--d|directory=app : The location of the model directory}
        {--D|migrations-directory=databases/migrations : The location of the migration directory}
        {--S|seeders-directory=databases/seeds : The location of the seeders directory}
    """

    def handle(self):
        name = self.argument("name")
        model_directory = self.option("directory")

        # Load the model stub and fill in the class name.
        with open(
            os.path.join(pathlib.Path(__file__).parent.absolute(), "stubs/model.stub")
        ) as fp:
            output = fp.read()
            output = output.replace("__CLASS__", camelize(name))

        # --pep writes snake_case file names; the default is CamelCase.
        if self.option("pep"):
            file_name = f"{underscore(name)}.py"
        else:
            file_name = f"{camelize(name)}.py"

        full_directory_path = os.path.join(os.getcwd(), model_directory)

        if os.path.exists(os.path.join(full_directory_path, file_name)):
            self.line(
                f'Model "{name}" Already Exists ({full_directory_path}/{file_name})'
            )
            return

        # FIX: create the model directory itself. The previous
        # os.makedirs(os.path.dirname(full_directory_path)) only created the
        # PARENT directory (dirname strips the last path component), so the
        # open() below failed when the target directory did not exist. This
        # now matches MakeObserverCommand's behavior.
        os.makedirs(full_directory_path, exist_ok=True)

        with open(os.path.join(os.getcwd(), model_directory, file_name), "w+") as fp:
            fp.write(output)

        self.info(f"Model created: {os.path.join(model_directory, file_name)}")

        # Optionally chain a migration for this model's table.
        if self.option("migration"):
            migrations_directory = self.option("migrations-directory")
            if self.option("table"):
                self.call(
                    "migration",
                    f"update_{tableize(name)}_table --table {tableize(name)} --directory {migrations_directory}",
                )
            else:
                self.call(
                    "migration",
                    f"create_{tableize(name)}_table --create {tableize(name)} --directory {migrations_directory}",
                )

        # Optionally chain a seeder for this model.
        if self.option("seeder"):
            directory = self.option("seeders-directory")
            self.call("seed", f"{self.argument('name')} --directory {directory}")
================================================
FILE: src/masoniteorm/commands/MakeModelDocstringCommand.py
================================================
from ..config import load_config
from .Command import Command
class MakeModelDocstringCommand(Command):
"""
Generate model docstring and type hints (for auto-completion).
model:docstring
{table : The table you want to generate docstring and type hints}
{--t|type-hints : The table you want to generate docstring and type hints}
{--c|connection=default : The connection you want to use}
"""
def handle(self):
table = self.argument("table")
DB = load_config(self.option("config")).DB
schema = DB.get_schema_builder(self.option("connection"))
if not schema.has_table(table):
return self.line_error(
f"There is no such table {table} for this connection."
)
self.info(f"Model Docstring for table: {table}")
print('"""')
for _, column in schema.get_columns(table).items():
length = f"({column.length})" if column.length else ""
default = f" default: {column.default}"
print(f"{column.name}: {column.column_type}{length}{default}")
print('"""')
if self.option("type-hints"):
self.info(f"Model Type Hints for table: {table}")
for name, column in schema.get_columns(table).items():
print(f" {name}:{column.column_python_type.__name__}")
================================================
FILE: src/masoniteorm/commands/MakeObserverCommand.py
================================================
import os
import pathlib
from inflection import camelize, underscore
from .Command import Command
class MakeObserverCommand(Command):
    """
    Creates a new observer file.

    observer
        {name : The name of the observer}
        {--m|model=None : The name of the model}
        {--d|directory=app/observers : The location of the observers directory}
    """

    def handle(self):
        name = self.argument("name")
        model = self.option("model")
        # "None" is cleo's string default: fall back to the observer name.
        if model == "None":
            model = name
        observer_directory = self.option("directory")

        stub_path = os.path.join(
            pathlib.Path(__file__).parent.absolute(), "stubs/observer.stub"
        )
        with open(stub_path) as fp:
            output = fp.read()

        output = (
            output.replace("__CLASS__", camelize(name))
            .replace("__MODEL_VARIABLE__", underscore(model))
            .replace("__MODEL__", camelize(model))
        )

        file_name = f"{camelize(name)}Observer.py"
        full_directory_path = os.path.join(os.getcwd(), observer_directory)

        if os.path.exists(os.path.join(full_directory_path, file_name)):
            self.line(
                f'Observer "{name}" Already Exists ({full_directory_path}/{file_name})'
            )
            return

        os.makedirs(os.path.join(full_directory_path), exist_ok=True)

        with open(os.path.join(os.getcwd(), observer_directory, file_name), "w+") as fp:
            fp.write(output)

        self.info(f"Observer created: {file_name}")
================================================
FILE: src/masoniteorm/commands/MakeSeedCommand.py
================================================
import os
import pathlib
from inflection import camelize, underscore
from .Command import Command
class MakeSeedCommand(Command):
    """
    Creates a new seed file.

    seed
        {name : The name of the seed}
        {--d|directory=databases/seeds : The location of the seed directory}
    """

    def handle(self):
        # Read the seed stub, substitute the seeder class name, and write the
        # result into the seed directory (unless it already exists).
        name = self.argument("name") + "TableSeeder"
        seed_directory = self.option("directory")
        stub_file = "create_seed"

        with open(
            os.path.join(
                pathlib.Path(__file__).parent.absolute(), f"stubs/{stub_file}.stub"
            )
        ) as fp:
            output = fp.read()
            output = output.replace("__SEEDER_NAME__", camelize(name))

        # FIX: removed a dead `file_name = underscore(name)` assignment that
        # was unconditionally overwritten here.
        file_name = f"{underscore(name)}.py"
        full_path = pathlib.Path(os.path.join(os.getcwd(), seed_directory, file_name))
        path_normalized = pathlib.Path(seed_directory) / pathlib.Path(file_name)

        if os.path.exists(full_path):
            return self.line(f"{path_normalized} already exists.")

        with open(full_path, "w") as fp:
            fp.write(output)

        self.info(f"Seed file created: {path_normalized}")
================================================
FILE: src/masoniteorm/commands/MigrateCommand.py
================================================
import os
from ..migrations import Migration
from .Command import Command
class MigrateCommand(Command):
    """
    Run migrations.

    migrate
        {--m|migration=all : Migration's name to be migrated}
        {--c|connection=default : The connection you want to run migrations on}
        {--f|force : Force migrations without prompt in production}
        {--s|show : Shows the output of SQL for migrations that would be running}
        {--schema=? : Sets the schema to be migrated}
        {--d|directory=databases/migrations : The location of the migration directory}
    """

    def handle(self):
        # prompt user for confirmation in production
        if os.getenv("APP_ENV") == "production" and not self.option("force"):
            answer = ""
            # Keep asking until an explicit y/n answer is given.
            while answer not in ["y", "n"]:
                answer = input(
                    "Do you want to run migrations in PRODUCTION ? (y/n)\n"
                ).lower()
            if answer != "y":
                self.info("Migrations cancelled")
                # Abort the whole process before touching the database.
                exit(0)

        migration = Migration(
            command_class=self,
            connection=self.option("connection"),
            migration_directory=self.option("directory"),
            config_path=self.option("config"),
            schema=self.option("schema"),
        )
        # Ensure the migration bookkeeping table exists before checking
        # for pending work.
        migration.create_table_if_not_exists()
        if not migration.get_unran_migrations():
            self.info("Nothing To Migrate!")
            return

        migration_name = self.option("migration")
        show_output = self.option("show")
        # "all" (the default) runs every pending migration; otherwise only
        # the named one. --show prints the SQL instead of hiding it.
        migration.migrate(migration=migration_name, output=show_output)
================================================
FILE: src/masoniteorm/commands/MigrateFreshCommand.py
================================================
from ..migrations import Migration
from .Command import Command
class MigrateFreshCommand(Command):
    """
    Drops all tables and migrates them again.

    migrate:fresh
        {--c|connection=default : The connection you want to run migrations on}
        {--d|directory=databases/migrations : The location of the migration directory}
        {--f|ignore-fk=? : The connection you want to run migrations on}
        {--s|seed=? : Seed database after fresh. The seeder to be ran can be provided in argument}
        {--schema=? : Sets the schema to be migrated}
        {--D|seed-directory=databases/seeds : The location of the seed directory if seed option is used.}
    """

    def handle(self):
        migration = Migration(
            command_class=self,
            connection=self.option("connection"),
            migration_directory=self.option("directory"),
            config_path=self.option("config"),
            schema=self.option("schema"),
        )
        migration.fresh(ignore_fk=self.option("ignore-fk"))
        self.line("")

        seed = self.option("seed")
        if seed:
            # cleo yields "null" when --seed is passed without a value; in
            # that case forward the literal "None" so seed:run picks the
            # default seeder, otherwise forward the named seeder.
            seeder = "None" if seed == "null" else seed
            self.call(
                "seed:run",
                f"{seeder} --directory {self.option('seed-directory')} --connection {self.option('connection')}",
            )
================================================
FILE: src/masoniteorm/commands/MigrateRefreshCommand.py
================================================
from ..migrations import Migration
from .Command import Command
class MigrateRefreshCommand(Command):
    """
    Rolls back migrations and migrates them again.

    migrate:refresh
        {--m|migration=all : Migration's name to be refreshed}
        {--c|connection=default : The connection you want to run migrations on}
        {--d|directory=databases/migrations : The location of the migration directory}
        {--s|seed=? : Seed database after refresh. The seeder to be ran can be provided in argument}
        {--schema=? : Sets the schema to be migrated}
        {--D|seed-directory=databases/seeds : The location of the seed directory if seed option is used.}
    """

    def handle(self):
        """Roll back and re-run migrations, optionally seeding afterwards."""
        Migration(
            command_class=self,
            connection=self.option("connection"),
            migration_directory=self.option("directory"),
            config_path=self.option("config"),
            schema=self.option("schema"),
        ).refresh(self.option("migration"))
        self.line("")

        seed = self.option("seed")
        # A bare "--seed" flag arrives as the string "null": run the default
        # database seeder; a named value runs that specific seeder.
        seeder = "None" if seed == "null" else seed
        if seeder:
            self.call(
                "seed:run",
                f"{seeder} --directory {self.option('seed-directory')} --connection {self.option('connection')}",
            )
================================================
FILE: src/masoniteorm/commands/MigrateResetCommand.py
================================================
from ..migrations import Migration
from .Command import Command
class MigrateResetCommand(Command):
    """
    Reset migrations.

    migrate:reset
        {--m|migration=all : Migration's name to be rollback}
        {--c|connection=default : The connection you want to run migrations on}
        {--schema=? : Sets the schema to be migrated}
        {--d|directory=databases/migrations : The location of the migration directory}
    """

    def handle(self):
        """Roll back every ran migration on the selected connection."""
        Migration(
            command_class=self,
            connection=self.option("connection"),
            migration_directory=self.option("directory"),
            config_path=self.option("config"),
            schema=self.option("schema"),
        ).reset(self.option("migration"))
================================================
FILE: src/masoniteorm/commands/MigrateRollbackCommand.py
================================================
from ..migrations import Migration
from .Command import Command
class MigrateRollbackCommand(Command):
    """
    Rolls back the last batch of migrations.

    migrate:rollback
        {--m|migration=all : Migration's name to be rollback}
        {--c|connection=default : The connection you want to run migrations on}
        {--s|show : Shows the output of SQL for migrations that would be running}
        {--schema=? : Sets the schema to be migrated}
        {--d|directory=databases/migrations : The location of the migration directory}
    """

    def handle(self):
        """Roll back the most recent batch of migrations."""
        migration = Migration(
            command_class=self,
            connection=self.option("connection"),
            migration_directory=self.option("directory"),
            config_path=self.option("config"),
            schema=self.option("schema"),
        )
        migration.rollback(
            migration=self.option("migration"), output=self.option("show")
        )
================================================
FILE: src/masoniteorm/commands/MigrateStatusCommand.py
================================================
from ..migrations import Migration
from .Command import Command
class MigrateStatusCommand(Command):
    """
    Display migrations status.

    migrate:status
        {--c|connection=default : The connection you want to run migrations on}
        {--schema=? : Sets the schema to be migrated}
        {--d|directory=databases/migrations : The location of the migration directory}
    """

    def handle(self):
        """Render a table of which migrations have and have not been run."""
        migration = Migration(
            command_class=self,
            connection=self.option("connection"),
            migration_directory=self.option("directory"),
            config_path=self.option("config"),
            schema=self.option("schema"),
        )
        migration.create_table_if_not_exists()

        table = self.table()
        table.set_header_row(["Ran?", "Migration", "Batch"])

        # Ran migrations first (with their batch number), then pending ones.
        rows = [
            ["Y", str(ran["migration_file"]), str(ran["batch"])]
            for ran in migration.get_ran_migrations()
        ]
        rows.extend(
            ["N", str(pending), "-"]
            for pending in migration.get_unran_migrations()
        )

        table.set_rows(rows)
        table.render(self.io)
================================================
FILE: src/masoniteorm/commands/SeedRunCommand.py
================================================
from inflection import camelize, underscore
from ..seeds import Seeder
from .Command import Command
class SeedRunCommand(Command):
    """
    Run seeds.

    seed:run
        {--c|connection=default : The connection you want to run migrations on}
        {--dry : If the seed should run in dry mode}
        {table=None : Name of the table to seed}
        {--d|directory=databases/seeds : The location of the seed directory}
    """

    def handle(self):
        """Run the database seeder, or a specific table seeder when given."""
        seeder = Seeder(
            dry=self.option("dry"),
            seed_path=self.option("directory"),
            connection=self.option("connection"),
        )

        table = self.argument("table")
        if table == "None":
            # No table given: run the top-level database seeder.
            seeder.run_database_seed()
            seeded_name = "Database Seeder"
        else:
            seeded_name = f"{camelize(table)}TableSeeder"
            seeder.run_specific_seed(
                f"{underscore(table)}_table_seeder.{seeded_name}"
            )

        self.line(f"{seeded_name} seeded!")
================================================
FILE: src/masoniteorm/commands/ShellCommand.py
================================================
import subprocess
import os
import re
import shlex
from collections import OrderedDict
from ..config import load_config
from .Command import Command
class ShellCommand(Command):
    """
    Connect to your database interactive terminal.

    shell
        {--c|connection=default : The connection you want to use to connect to interactive terminal}
        {--s|show=? : Display the command which will be called to connect}
    """

    # Maps ORM driver names to the interactive client binary to invoke.
    shell_programs = {
        "mysql": "mysql",
        "postgres": "psql",
        "sqlite": "sqlite3",
        "mssql": "sqlcmd",
    }

    def handle(self):
        """Resolve the connection, build the client command, then run it."""
        resolver = load_config(self.option("config")).DB
        connection = self.option("connection")
        if connection == "default":
            connection = resolver.get_connection_details()["default"]
        config = resolver.get_connection_information(connection)
        if not config.get("full_details"):
            self.line(
                f"Connection configuration for '{connection}' not found !"
            )
            exit(-1)
        command, env = self.get_command(config)
        if self.option("show"):
            # Display the command with sensitive values masked out.
            cleaned_command = self.hide_sensitive_options(config, command)
            self.comment(cleaned_command)
            self.line("")
        # let shlex split command in a list as it's safer
        command_args = shlex.split(command)
        try:
            subprocess.run(command_args, check=True, env=env)
        except FileNotFoundError:
            self.line(
                f"Cannot find {config.get('full_details').get('driver')} program ! Please ensure you can call this program in your shell first."
            )
            exit(-1)
        except subprocess.CalledProcessError:
            self.line("An error happened calling the command.")
            exit(-1)

    def get_shell_program(self, connection):
        """Get the database shell program to run."""
        return self.shell_programs.get(connection)

    def get_command(self, config):
        """Get the command to run as a string.

        Returns:
            tuple: (command string, environment dict to run the command with)
        """
        driver = config.get("full_details").get("driver")
        program = self.get_shell_program(driver)
        try:
            get_driver_args = getattr(self, f"get_{driver}_args")
        except AttributeError:
            self.line(
                f"Connecting with driver '{driver}' is not implemented !"
            )
            exit(-1)
        args, options = get_driver_args(config)
        # process positional arguments
        args = " ".join(args)
        # process optional arguments
        options = self.remove_empty_options(options)
        options_string = " ".join(
            f"{option} {value}" if value else option
            for option, value in options.items()
        )
        # finally build command string
        command = program
        if args:
            command += f" {args}"
        if options_string:
            command += f" {options_string}"
        # prepare environment in which command will be run
        # some drivers need to define env variable such as psql for specifying password
        try:
            driver_env = getattr(self, f"get_{driver}_env")(config)
        except AttributeError:
            # driver does not require any specific environment variables
            driver_env = {}
        command_env = {**os.environ.copy(), **driver_env}
        return command, command_env

    def get_mysql_args(self, config):
        """Get command positional arguments and options for MySQL driver."""
        args = [config.get("database")]
        options = OrderedDict(
            {
                "--host": config.get("host"),
                "--port": config.get("port"),
                "--user": config.get("user"),
                "--password": config.get("password"),
                "--default-character-set": config.get("options", {}).get("charset"),
            }
        )
        return args, options

    def get_postgres_args(self, config):
        """Get command positional arguments and options for PostgreSQL driver."""
        args = [config.get("database")]
        options = OrderedDict(
            {
                "--host": config.get("host"),
                "--port": config.get("port"),
                "--username": config.get("user"),
            }
        )
        return args, options

    def get_postgres_env(self, config):
        """psql reads the password from the PGPASSWORD environment variable."""
        return {"PGPASSWORD": config.get("password")}

    def get_mssql_args(self, config):
        """Get command positional arguments and options for MSSQL driver."""
        args = []
        # instance of SQL Server: -S [protocol:]server[instance_name][,port]
        server = f"tcp:{config.get('host')}"
        if config.get("port"):
            server += f",{config.get('port')}"
        # default "options" to {} for consistency with the lookup below and to
        # avoid an AttributeError when options are missing entirely
        trusted_connection = (
            config.get("options", {}).get("trusted_connection") == "Yes"
        )
        options = OrderedDict(
            {
                "-d": config.get("database"),
                "-U": config.get("user"),
                "-P": config.get("password"),
                "-S": server,
                "-E": trusted_connection,
                "-t": config.get("options", {}).get("connection_timeout"),
            }
        )
        return args, options

    def get_sqlite_args(self, config):
        """Get command positional arguments and options for SQLite driver."""
        args = [config.get("database")]
        options = OrderedDict()
        return args, options

    def remove_empty_options(self, options):
        """Remove empty options when value does not evaluate to True.

        Also when value is exactly 'True' we don't want True to be passed as option value but
        we want the option to be passed.
        """
        # remove empty options and remove value when option is True
        cleaned_options = OrderedDict()
        for key, value in options.items():
            if value is True:
                cleaned_options[key] = ""
            elif value:
                cleaned_options[key] = value
        return cleaned_options

    def get_sensitive_options(self, config):
        """Return the option names considered sensitive for this driver."""
        driver = config.get("full_details").get("driver")
        try:
            sensitive_options = getattr(self, f"get_{driver}_sensitive_options")()
        except AttributeError:
            sensitive_options = []
        return sensitive_options

    def get_mysql_sensitive_options(self):
        return ["--password"]

    def get_mssql_sensitive_options(self):
        return ["-P"]

    def hide_sensitive_options(self, config, command):
        """Hide sensitive options in command string before it's displayed in the shell for
        security reasons. All drivers can declare what options are considered sensitive, their
        values will be then replaced by *** when displayed only."""
        cleaned_command = command
        sensitive_options = self.get_sensitive_options(config)
        for option in sensitive_options:
            # if option is used obfuscate its value
            if option in command:
                # escape the option (it starts with dashes) and accept any
                # non-space value; previously a password containing non-word
                # characters produced no match and match.groups() crashed on None
                match = re.search(rf"{re.escape(option)} (\S+)", command)
                if match and match.groups():
                    cleaned_command = cleaned_command.replace(match.groups()[0], "***")
        return cleaned_command
================================================
FILE: src/masoniteorm/commands/__init__.py
================================================
import os
import sys
sys.path.append(os.getcwd())
from .MigrateCommand import MigrateCommand
from .MigrateRollbackCommand import MigrateRollbackCommand
from .MigrateRefreshCommand import MigrateRefreshCommand
from .MigrateFreshCommand import MigrateFreshCommand
from .MigrateResetCommand import MigrateResetCommand
from .MakeModelCommand import MakeModelCommand
from .MakeModelDocstringCommand import MakeModelDocstringCommand
from .MakeObserverCommand import MakeObserverCommand
from .MigrateStatusCommand import MigrateStatusCommand
from .MakeMigrationCommand import MakeMigrationCommand
from .MakeSeedCommand import MakeSeedCommand
from .SeedRunCommand import SeedRunCommand
from .ShellCommand import ShellCommand
================================================
FILE: src/masoniteorm/commands/stubs/create_migration.stub
================================================
"""__MIGRATION_NAME__ Migration."""
from masoniteorm.migrations import Migration
class __MIGRATION_NAME__(Migration):
def up(self):
"""
Run the migrations.
"""
with self.schema.create("__TABLE_NAME__") as table:
table.increments("id")
table.timestamps()
def down(self):
"""
Revert the migrations.
"""
self.schema.drop("__TABLE_NAME__")
================================================
FILE: src/masoniteorm/commands/stubs/create_seed.stub
================================================
"""__SEEDER_NAME__ Seeder."""
from masoniteorm.seeds import Seeder
class __SEEDER_NAME__(Seeder):
def run(self):
"""Run the database seeds."""
pass
================================================
FILE: src/masoniteorm/commands/stubs/model.stub
================================================
""" __CLASS__ Model """
from masoniteorm.models import Model
class __CLASS__(Model):
"""__CLASS__ Model"""
pass
================================================
FILE: src/masoniteorm/commands/stubs/observer.stub
================================================
"""__CLASS__ Observer"""
from masoniteorm.models import Model
class __CLASS__Observer:
def created(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "created" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def creating(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "creating" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def saving(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "saving" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def saved(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "saved" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def updating(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "updating" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def updated(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "updated" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def booted(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "booted" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def booting(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "booting" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def hydrating(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "hydrating" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def hydrated(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "hydrated" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def deleting(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "deleting" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
def deleted(self, __MODEL_VARIABLE__):
"""Handle the __MODEL__ "deleted" event.
Args:
__MODEL_VARIABLE__ (masoniteorm.models.Model): __MODEL__ model.
"""
pass
================================================
FILE: src/masoniteorm/commands/stubs/table_migration.stub
================================================
"""__MIGRATION_NAME__ Migration."""
from masoniteorm.migrations import Migration
class __MIGRATION_NAME__(Migration):
def up(self):
"""
Run the migrations.
"""
with self.schema.table("__TABLE_NAME__") as table:
pass
def down(self):
"""
Revert the migrations.
"""
with self.schema.table("__TABLE_NAME__") as table:
pass
================================================
FILE: src/masoniteorm/config.py
================================================
import os
import pydoc
import urllib.parse as urlparse
from .exceptions import ConfigurationNotFound
from .exceptions import InvalidUrlConfiguration
def load_config(config_path=None):
    """Load ORM configuration from given configuration path (dotted or not).

    If no path is provided:
        1. try to load from DB_CONFIG_PATH environment variable
        2. else try to load from default config_path: config/database

    Args:
        config_path (str, optional): dotted or slash-separated path to the
            configuration module, with or without a trailing ".py".

    Returns:
        module: the located configuration module.

    Raises:
        ConfigurationNotFound: when no module exists at the resolved path.
    """
    selected_config_path = (
        os.getenv("DB_CONFIG_PATH", None) or config_path or "config/database"
    )
    os.environ["DB_CONFIG_PATH"] = selected_config_path
    # format path as python module if needed
    selected_config_path = selected_config_path.replace("/", ".").replace("\\", ".")
    # Strip only a trailing ".py" suffix. The previous str.rstrip(".py") call
    # stripped any trailing '.', 'p' or 'y' characters, mangling legitimate
    # module names (e.g. "config/copy" became "config/co").
    if selected_config_path.endswith(".py"):
        selected_config_path = selected_config_path[: -len(".py")]
    config_module = pydoc.locate(selected_config_path)
    if config_module is None:
        raise ConfigurationNotFound(
            f"ORM configuration file has not been found in {selected_config_path}."
        )
    return config_module
def db_url(database_url=None, prefix="", options=None, log_queries=False):
    """Parse connection configuration from database url format. If no url is provided,
    DATABASE_URL environment variable will be used instead.

    Reference: Code adapted from https://github.com/jacobian/dj-database-url

    Args:
        database_url (str, optional): the url to parse. Falls back to the
            DATABASE_URL environment variable.
        prefix (str, optional): table prefix stored in the configuration.
        options (dict, optional): driver options stored in the configuration.
        log_queries (bool, optional): whether queries should be logged.

    Returns:
        dict: a connection configuration dictionary.

    Raises:
        InvalidUrlConfiguration: when no url is given and DATABASE_URL is unset.
    """
    # Default to None instead of a mutable {} so a caller mutating the
    # returned config's "options" cannot leak state into later calls.
    if options is None:
        options = {}
    url = database_url or os.getenv("DATABASE_URL")
    if not url:
        raise InvalidUrlConfiguration("Database url is empty !")
    # Register database schemes in URLs.
    urlparse.uses_netloc.append("postgres")
    urlparse.uses_netloc.append("postgresql")
    urlparse.uses_netloc.append("pgsql")
    urlparse.uses_netloc.append("postgis")
    urlparse.uses_netloc.append("mysql")
    urlparse.uses_netloc.append("mysql2")
    urlparse.uses_netloc.append("mysqlgis")
    urlparse.uses_netloc.append("mssql")
    urlparse.uses_netloc.append("sqlite")
    DRIVERS_MAP = {
        "postgres": "postgres",
        "postgresql": "postgres",
        "pgsql": "postgres",
        "postgis": "postgres",
        "mysql": "mysql",
        "mysql2": "mysql",
        "mysqlgis": "mysql",
        "mysql-connector": "mysql",
        "mssql": "mssql",
        "sqlite": "sqlite",
    }
    # this is a special case, because if we pass this URL into
    # urlparse, urlparse will choke trying to interpret "memory"
    # as a port number
    if url in ["sqlite://:memory:", "sqlite://memory"]:
        driver = DRIVERS_MAP["sqlite"]
        path = ":memory:"
    # otherwise parse the url as normal
    else:
        url = urlparse.urlparse(url)
        # remove query string from path (not parsed for now)
        path = url.path[1:]
        if "?" in path and not url.query:
            # keep only the part before the first "?"; the previous
            # split("?", 2) raised ValueError on unpack when the path
            # contained more than one "?"
            path, _, _ = path.partition("?")
        # if we are using sqlite and we have no path, then assume we
        # want an in-memory database (this is the behaviour of sqlalchemy)
        if url.scheme == "sqlite" and path == "":
            path = ":memory:"
        # handle postgres percent-encoded paths.
        hostname = url.hostname or ""
        if "%2f" in hostname.lower():
            # Switch to url.netloc to avoid lower cased paths
            hostname = url.netloc
            if "@" in hostname:
                hostname = hostname.rsplit("@", 1)[1]
            if ":" in hostname:
                hostname = hostname.split(":", 1)[0]
            hostname = hostname.replace("%2f", "/").replace("%2F", "/")
        # lookup specified driver
        driver = DRIVERS_MAP[url.scheme]
        port = (
            str(url.port) if url.port and driver in [DRIVERS_MAP["mssql"]] else url.port
        )
    # build final configuration
    config = {
        "driver": driver,
        "database": urlparse.unquote(path or ""),
        "prefix": prefix,
        "options": options,
        "log_queries": log_queries,
    }
    # sqlite only needs the database path; other drivers get network details
    if driver != DRIVERS_MAP["sqlite"]:
        config.update(
            {
                "user": urlparse.unquote(url.username or ""),
                "password": urlparse.unquote(url.password or ""),
                "host": hostname,
                "port": port or "",
            }
        )
    return config
================================================
FILE: src/masoniteorm/connections/.gitignore
================================================
================================================
FILE: src/masoniteorm/connections/BaseConnection.py
================================================
import logging
from timeit import default_timer as timer
from .ConnectionResolver import ConnectionResolver
class BaseConnection:
    """Base class for database connections.

    Provides the plumbing shared by every driver connection class: query
    logging, timed statement execution, global (transaction) connection
    lookup and foreign key toggling.
    """

    _connection = None
    _cursor = None
    _dry = False

    def dry(self):
        """Flag the connection as "dry" so queries are built but not executed.

        Returns:
            self
        """
        self._dry = True
        return self

    def set_schema(self, schema):
        """Set the schema the connection should operate on.

        Returns:
            self
        """
        self.schema = schema
        return self

    def log(
        self, query, bindings, query_time=0, logger="masoniteorm.connections.queries"
    ):
        """Emit a debug log entry for an executed query.

        Args:
            query (str): the executed SQL.
            bindings (tuple): bindings used by the query.
            query_time (str|int, optional): execution time in milliseconds.
            logger (str, optional): logger channel name to emit on.
        """
        # Honor the logger parameter; it was previously ignored in favor of a
        # hardcoded (and differently spelled) channel name.
        channel = logging.getLogger(logger)
        channel.propagate = self.full_details.get("propagate", True)
        channel.debug(
            f"Running query {query}, {bindings}. Executed in {query_time}ms",
            extra={"query": query, "bindings": bindings, "query_time": query_time},
        )

    def statement(self, query, bindings=()):
        """Wrapper around calling the cursor query. Helpful for logging output.

        Args:
            query (string): The query to execute on the cursor
            bindings (tuple, optional): Tuple of query bindings. Defaults to ().
        """
        start = timer()
        if not self._cursor:
            raise AttributeError(
                f"Must set the _cursor attribute on the {self.__class__.__name__} class before calling the 'statement' method."
            )
        self._cursor.execute(query, bindings)
        end = "{:.2f}".format(timer() - start)
        if self.full_details and self.full_details.get("log_queries", False):
            self.log(query, bindings, query_time=end)

    def has_global_connection(self):
        """Whether a globally registered connection exists for this name."""
        return self.name in ConnectionResolver().get_global_connections()

    def get_global_connection(self):
        """Return the globally registered connection for this name."""
        return ConnectionResolver().get_global_connections()[self.name]

    def enable_query_log(self):
        """Turn query logging on for this connection."""
        self.full_details["log_queries"] = True

    def disable_query_log(self):
        """Turn query logging off for this connection."""
        self.full_details["log_queries"] = False

    def format_cursor_results(self, cursor_result):
        """Hook for drivers to normalize raw cursor rows. Default: passthrough."""
        return cursor_result

    def set_cursor(self):
        """Create a fresh cursor from the underlying connection.

        Returns:
            self
        """
        self._cursor = self._connection.cursor()
        return self

    def select_many(self, query, bindings, amount):
        """Execute a query and lazily yield results ``amount`` rows at a time.

        Args:
            query (str): the query to execute.
            bindings (tuple): bindings for the query.
            amount (int): chunk size for each fetch.

        Yields:
            list: formatted rows, at most ``amount`` per chunk.
        """
        # Open the connection before creating a cursor on it; previously the
        # open check happened only after the statement had been executed.
        if not self.open:
            self.make_connection()
        self.set_cursor()
        # Forward the bindings; previously they were silently dropped.
        self.statement(query, bindings)
        result = self.format_cursor_results(self._cursor.fetchmany(amount))
        while result:
            yield result
            result = self.format_cursor_results(self._cursor.fetchmany(amount))

    def enable_disable_foreign_keys(self):
        """Enable or disable foreign key constraints based on the
        "foreign_keys" connection setting; do nothing when it is unset."""
        foreign_keys = self.full_details.get("foreign_keys")
        platform = self.get_default_platform()()
        if foreign_keys:
            self._connection.execute(platform.enable_foreign_key_constraints())
        elif foreign_keys is not None:
            self._connection.execute(platform.disable_foreign_key_constraints())
================================================
FILE: src/masoniteorm/connections/ConnectionFactory.py
================================================
from ..config import load_config
class ConnectionFactory:
    """Class for controlling the registration and creation of connection types."""

    _connections = {}

    def __init__(self, config_path=None):
        # optional dotted path to the configuration module used by make()
        self.config_path = config_path

    @classmethod
    def register(cls, key, connection):
        """Registers new connections

        Arguments:
            key {key} -- The key or driver name you want assigned to this connection
            connection {masoniteorm.connections.BaseConnection} -- An instance of a BaseConnection class.

        Returns:
            cls
        """
        cls._connections.update({key: connection})
        return cls

    def make(self, key):
        """Makes already registered connections

        Arguments:
            key {string} -- The name of the connection you want to make

        Raises:
            Exception: Raises exception if there are no driver keys that match

        Returns:
            masoniteorm.connection.BaseConnection -- Returns an instance of a BaseConnection class.
        """
        DB = load_config(config_path=self.config_path).DB
        connections = DB.get_connection_details()
        if key == "default":
            connection_details = connections.get(connections.get("default"))
        else:
            connection_details = connections.get(key)
        connection = None
        if connection_details:
            # Resolve by the connection's configured driver first; previously
            # named connections were looked up by their name instead of their
            # driver, which failed whenever name and driver differed.
            connection = self._connections.get(connection_details.get("driver"))
        if connection is None:
            # Fall back to treating the key itself as a registered driver name.
            connection = self._connections.get(key)
        if connection:
            return connection
        raise Exception(
            "The '{connection}' connection does not exist".format(connection=key)
        )
================================================
FILE: src/masoniteorm/connections/ConnectionResolver.py
================================================
from contextlib import contextmanager
class ConnectionResolver:
    """Central registry for connection configuration and driver classes.

    Also tracks globally shared connections (used while a transaction is
    open) so subsequent queries on the same connection name reuse them.
    """

    # class-level state intentionally shared across all resolver instances
    _connection_details = {}
    _connections = {}
    _morph_map = {}

    def __init__(self, config_path=None):
        # Imported here (not at module top) to avoid circular imports.
        from ..connections import (
            SQLiteConnection,
            PostgresConnection,
            MySQLConnection,
            MSSQLConnection,
        )

        self.config_path = config_path
        from ..connections import ConnectionFactory

        self.connection_factory = ConnectionFactory(config_path=config_path)
        self.register(SQLiteConnection)
        self.register(PostgresConnection)
        self.register(MySQLConnection)
        self.register(MSSQLConnection)

    def morph_map(self, map):
        # Set the polymorphic morph map. Returns self for chaining.
        self._morph_map = map
        return self

    def set_connection_details(self, connection_details):
        # Stored on the class so every resolver instance shares the details.
        self.__class__._connection_details = connection_details
        return self

    def get_connection_details(self):
        return self._connection_details

    def set_connection_option(self, connection: str, options: dict):
        # Merge extra options into an existing connection's configuration.
        self._connection_details.get(connection).update(options)
        return self

    def get_global_connections(self):
        return self._connections

    def remove_global_connection(self, name=None):
        self._connections.pop(name)

    def register(self, connection):
        # Register a connection class under its driver name.
        self.connection_factory.register(connection.name, connection)

    def begin_transaction(self, name=None):
        """Open a transaction on the named (or default) connection and
        register it globally so subsequent queries reuse it."""
        if name is None:
            name = self.get_connection_details()["default"]

        driver = self.get_connection_details()[name].get("driver")

        connection = (
            self.connection_factory.make(driver)(
                **self.get_connection_information(name)
            )
            .make_connection()
            .begin()
        )

        self.__class__._connections.update({name: connection})

        return connection

    def commit(self, name=None):
        """Commit and deregister the named (or default) global connection."""
        if name is None:
            name = self.get_connection_details()["default"]

        connection = self.get_global_connections()[name]
        # Deregister before committing so the connection cannot be reused.
        self.remove_global_connection(name)
        connection.commit()

    def rollback(self, name=None):
        """Roll back and deregister the named (or default) global connection."""
        if name is None:
            name = self.get_connection_details()["default"]

        connection = self.get_global_connections()[name]
        self.remove_global_connection(name)
        connection.rollback()

    @contextmanager
    def transaction(self, name=None):
        """Context manager wrapping begin/commit with rollback on error."""
        self.begin_transaction(name)
        try:
            yield self
        except Exception:
            # Error inside the "with" body: undo and re-raise.
            self.rollback(name)
            raise

        try:
            self.commit(name)
        except Exception:
            # The commit itself failed: undo and re-raise.
            self.rollback(name)
            raise

    def get_connection_information(self, name):
        """Normalize a connection's configuration into the keyword
        arguments expected by the connection class constructors."""
        details = self.get_connection_details()
        return {
            "host": details.get(name, {}).get("host"),
            "database": details.get(name, {}).get("database"),
            "user": details.get(name, {}).get("user"),
            "port": details.get(name, {}).get("port"),
            "password": details.get(name, {}).get("password"),
            "prefix": details.get(name, {}).get("prefix"),
            "options": details.get(name, {}).get("options", {}),
            "full_details": details.get(name, {}),
        }

    def get_schema_builder(self, connection="default", schema=None):
        """Build a Schema bound to this resolver's connection details."""
        from ..schema import Schema

        return Schema(
            connection=connection,
            connection_details=self.get_connection_details(),
            schema=schema,
        )

    def get_query_builder(self, connection="default"):
        """Build a QueryBuilder bound to this resolver's connection details."""
        from ..query import QueryBuilder

        return QueryBuilder(
            connection=connection, connection_details=self.get_connection_details()
        )

    def statement(self, query, bindings=(), connection="default"):
        """Run a raw statement on the given connection."""
        return self.get_query_builder().on(connection).statement(query, bindings)
================================================
FILE: src/masoniteorm/connections/MSSQLConnection.py
================================================
from ..exceptions import DriverNotFound
from .BaseConnection import BaseConnection
from ..query.grammars import MSSQLGrammar
from ..schema.platforms import MSSQLPlatform
from ..query.processors import MSSQLPostProcessor
from ..exceptions import QueryException
CONNECTION_POOL = []
class MSSQLConnection(BaseConnection):
    """MSSQL Connection class."""

    name = "mssql"

    def __init__(
        self,
        host=None,
        database=None,
        user=None,
        port=None,
        password=None,
        prefix=None,
        options=None,
        full_details=None,
        name=None,
    ):
        """Store connection settings; the connection itself is opened lazily
        by make_connection()."""
        self.host = host
        # pyodbc expects a numeric port; leave falsy values untouched
        if port:
            self.port = int(port)
        else:
            self.port = port
        self.database = database
        self.user = user
        self.password = password
        self.prefix = prefix
        self.full_details = full_details or {}
        self.options = options or {}
        self._cursor = None
        self.transaction_level = 0
        self.open = 0
        if name:
            self.name = name

    def make_connection(self):
        """This sets the connection on the connection class"""
        try:
            import pyodbc
        except ModuleNotFoundError:
            raise DriverNotFound(
                "You must have the 'pyodbc' package installed to make a connection to Microsoft SQL Server. Please install it using 'pip install pyodbc'"
            )

        # Reuse the globally registered connection (e.g. inside a transaction).
        if self.has_global_connection():
            return self.get_global_connection()

        # optional ODBC settings with sensible defaults
        driver = self.options.get("driver", "ODBC Driver 17 for SQL Server")
        integrated_security = self.options.get("integrated_security")
        connection_timeout = str(self.options.get("connection_timeout", "30"))
        authentication = self.options.get("authentication")
        instance = self.options.get("instance", "")
        trusted_connection = self.options.get("trusted_connection")

        if instance:
            instance = "\\" + instance

        self._connection = pyodbc.connect(
            f"DRIVER={driver};SERVER={self.host}{instance if instance else ''},{self.port};Connection Timeout={connection_timeout};DATABASE={self.database}{f';Integrated Security={integrated_security}' if integrated_security else ''};UID={self.user};PWD={self.password}{f';Trusted_Connection={trusted_connection}' if trusted_connection else ''}{f';Authentication={authentication}' if authentication else ''}",
            autocommit=True,
        )
        self.enable_disable_foreign_keys()

        self.open = 1

        return self

    def get_database_name(self):
        return self.database

    @classmethod
    def get_default_query_grammar(cls):
        return MSSQLGrammar

    @classmethod
    def get_default_platform(cls):
        return MSSQLPlatform

    @classmethod
    def get_default_post_processor(cls):
        return MSSQLPostProcessor

    def reconnect(self):
        pass

    def commit(self):
        """Transaction"""
        # Only commit when leaving the outermost transaction level.
        if self.get_transaction_level() == 1:
            self._connection.commit()
            self._connection.autocommit = True
        self.transaction_level -= 1

    def begin(self):
        """MSSQL Transaction"""
        self._connection.autocommit = False
        self.transaction_level += 1
        return self

    def rollback(self):
        """Transaction"""
        if self.get_transaction_level() == 1:
            self._connection.rollback()
            self._connection.autocommit = True
        self.transaction_level -= 1

    def get_transaction_level(self):
        """Transaction"""
        return self.transaction_level

    def get_cursor(self):
        return self._cursor

    def query(self, query, bindings=(), results="*"):
        """Make the actual query that will reach the database and come back with a result.

        Arguments:
            query {string} -- A string query. This could be a qmarked string or a regular query.
            bindings {tuple} -- A tuple of bindings

        Keyword Arguments:
            results {str|1} -- If the results is equal to an asterisks it will call 'fetchAll'
                else it will return 'fetchOne' and return a single record. (default: {"*"})

        Returns:
            dict|None -- Returns a dictionary of results or None
        """
        try:
            if not self.open:
                self.make_connection()
            self._cursor = self._connection.cursor()
            with self._cursor as cursor:
                if isinstance(query, list) and not self._dry:
                    # run a batch of statements with no result set
                    for q in query:
                        self.statement(q, ())
                    return
                query = query.replace("'?'", "?")
                self.statement(query, bindings)
                if results == 1:
                    if not cursor.description:
                        return {}
                    columnNames = [column[0] for column in cursor.description]
                    result = cursor.fetchone()
                    return dict(zip(columnNames, result)) if result is not None else {}
                else:
                    if not cursor.description:
                        return {}
                    return self.format_cursor_results(cursor.fetchall())
                # NOTE: a trailing "return {}" previously here was unreachable
                # (both branches above return) and has been removed.
        except Exception as e:
            raise QueryException(str(e)) from e
        finally:
            # close the connection unless a transaction is still in progress
            if self.get_transaction_level() <= 0:
                self._connection.close()

    def format_cursor_results(self, cursor_result):
        """Turn raw cursor rows into a list of dicts keyed by column name."""
        columnNames = [column[0] for column in self.get_cursor().description]
        results = []
        for record in cursor_result:
            results.append(dict(zip(columnNames, record)))
        return results
================================================
FILE: src/masoniteorm/connections/MySQLConnection.py
================================================
from ..exceptions import DriverNotFound
from .BaseConnection import BaseConnection
from ..query.grammars import MySQLGrammar
from ..schema.platforms import MySQLPlatform
from ..query.processors import MySQLPostProcessor
from ..exceptions import QueryException
CONNECTION_POOL = []
class MySQLConnection(BaseConnection):
    """MYSQL Connection class.

    Wraps a pymysql connection, optionally drawing from / returning to a
    module-level CONNECTION_POOL when `connection_pooling_enabled` is set
    in the connection's full details.
    """

    name = "mysql"
    _dry = False

    def __init__(
        self,
        host=None,
        database=None,
        user=None,
        port=None,
        password=None,
        prefix=None,
        options=None,
        full_details=None,
        name=None,
    ):
        self.host = host
        self.port = port
        if str(port).isdigit():
            self.port = int(self.port)
        self.database = database
        self.user = user
        self.password = password
        self.prefix = prefix
        self.full_details = full_details or {}
        # BUGFIX: read from self.full_details (normalized to a dict) rather
        # than the raw argument, which may be None and raised AttributeError.
        self.connection_pool_size = self.full_details.get(
            "connection_pooling_max_size", 100
        )
        self.options = options or {}
        self._cursor = None
        self.open = 0
        self.transaction_level = 0
        if name:
            self.name = name

    def make_connection(self):
        """This sets the connection on the connection class"""
        if self._dry:
            return

        if self.has_global_connection():
            return self.get_global_connection()

        # Check if there is an available connection in the pool
        self._connection = self.create_connection()
        self.enable_disable_foreign_keys()
        return self

    def close_connection(self):
        """Release the underlying connection: return it to the pool when
        pooling is enabled and the pool has room, otherwise just drop it."""
        if (
            self.full_details.get("connection_pooling_enabled")
            and len(CONNECTION_POOL) < self.connection_pool_size
        ):
            CONNECTION_POOL.append(self._connection)
        self.open = 0
        self._connection = None

    def create_connection(self, autocommit=True):
        """Create (or pop from the pool) a raw pymysql connection.

        Raises:
            DriverNotFound -- when the pymysql package is not installed.
        """
        try:
            import pymysql
        except ModuleNotFoundError:
            raise DriverNotFound(
                "You must have the 'pymysql' package "
                "installed to make a connection to MySQL. "
                "Please install it using 'pip install pymysql'"
            )

        import pendulum
        import pymysql.converters

        # Teach the driver to escape pendulum datetimes like stdlib datetimes.
        pymysql.converters.conversions[pendulum.DateTime] = (
            pymysql.converters.escape_datetime
        )

        # Initialize the connection pool if pooling is enabled and a minimum
        # size is configured. CONSISTENCY FIX: previously the pool was
        # pre-filled even with pooling disabled, creating connections that
        # were never popped (PostgresConnection already guards on enabled).
        initialize_size = self.full_details.get("connection_pooling_min_size")
        if (
            self.full_details.get("connection_pooling_enabled")
            and initialize_size
            and len(CONNECTION_POOL) < initialize_size
        ):
            for _ in range(initialize_size - len(CONNECTION_POOL)):
                connection = pymysql.connect(
                    cursorclass=pymysql.cursors.DictCursor,
                    autocommit=autocommit,
                    host=self.host,
                    user=self.user,
                    password=self.password,
                    port=self.port,
                    database=self.database,
                    **self.options
                )
                CONNECTION_POOL.append(connection)

        if (
            self.full_details.get("connection_pooling_enabled")
            and CONNECTION_POOL
            and len(CONNECTION_POOL) > 0
        ):
            connection = CONNECTION_POOL.pop()
        else:
            connection = pymysql.connect(
                cursorclass=pymysql.cursors.DictCursor,
                autocommit=autocommit,
                host=self.host,
                user=self.user,
                password=self.password,
                port=self.port,
                database=self.database,
                **self.options
            )

        # Route close() through close_connection so pooled connections get
        # recycled instead of actually closed.
        connection.close = self.close_connection
        self.open = 1

        return connection

    def reconnect(self):
        self._connection.connect()
        return self

    @classmethod
    def get_default_query_grammar(cls):
        return MySQLGrammar

    @classmethod
    def get_default_platform(cls):
        return MySQLPlatform

    @classmethod
    def get_default_post_processor(cls):
        return MySQLPostProcessor

    def get_database_name(self):
        return self.database

    def commit(self):
        """Commit the current transaction; close once fully unwound."""
        self._connection.commit()
        self.transaction_level -= 1
        if self.get_transaction_level() <= 0:
            self.open = 0
            self._connection.close()

    def dry(self):
        """Put the connection in dry mode (build SQL without executing)."""
        self._dry = True
        return self

    def begin(self):
        """Mysql Transaction"""
        self._connection.begin()
        self.transaction_level += 1
        return self

    def rollback(self):
        """Roll back the current transaction; close once fully unwound."""
        self._connection.rollback()
        self.transaction_level -= 1
        if self.get_transaction_level() <= 0:
            self.open = 0
            self._connection.close()

    def get_transaction_level(self):
        """Return the current transaction nesting depth."""
        return self.transaction_level

    def get_cursor(self):
        return self._cursor

    def query(self, query, bindings=(), results="*"):
        """Make the actual query that
        will reach the database and come back with a result.

        Arguments:
            query {string|list} -- A string query
                (qmarked or regular) or a list of queries run back to back.
            bindings {tuple} -- A tuple of bindings

        Keyword Arguments:
            results {str|1} -- If the results is equal to an
                asterisks it will call 'fetchAll'
                else it will return 'fetchOne' and
                return a single record. (default: {"*"})

        Returns:
            dict|None -- Returns a dictionary of results or None

        Raises:
            QueryException -- wraps any driver-level error.
        """
        if self._dry:
            return {}

        if not self.open:
            if self._connection is None:
                self._connection = self.create_connection()
            self._connection.connect()

        self._cursor = self._connection.cursor()

        try:
            with self._cursor as cursor:
                if isinstance(query, list):
                    # Batch mode: run each statement with no bindings.
                    for q in query:
                        q = q.replace("'?'", "%s")
                        self.statement(q, ())
                    return
                # pymysql uses %s placeholders; strip quoting around qmarks.
                query = query.replace("'?'", "%s")
                self.statement(query, bindings)
                if results == 1:
                    return self.format_cursor_results(cursor.fetchone())
                else:
                    return self.format_cursor_results(cursor.fetchall())
        except Exception as e:
            raise QueryException(str(e)) from e
        finally:
            self._cursor.close()
            if self.get_transaction_level() <= 0:
                self.open = 0
                self._connection.close()
================================================
FILE: src/masoniteorm/connections/PostgresConnection.py
================================================
from ..exceptions import DriverNotFound
from .BaseConnection import BaseConnection
from ..query.grammars import PostgresGrammar
from ..schema.platforms import PostgresPlatform
from ..query.processors import PostgresPostProcessor
from ..exceptions import QueryException
CONNECTION_POOL = []
class PostgresConnection(BaseConnection):
    """Postgres Connection class.

    Wraps a psycopg2 connection, optionally drawing from / returning to a
    module-level CONNECTION_POOL when `connection_pooling_enabled` is set.
    """

    name = "postgres"

    def __init__(
        self,
        host=None,
        database=None,
        user=None,
        port=None,
        password=None,
        prefix=None,
        options=None,
        full_details=None,
        name=None,
    ):
        self.host = host
        if port:
            self.port = int(port)
        else:
            self.port = port
        self.database = database
        self.user = user
        self.password = password
        self.prefix = prefix
        self.full_details = full_details or {}
        # BUGFIX: read from self.full_details (normalized to a dict) rather
        # than the raw argument, which may be None and raised AttributeError.
        self.connection_pool_size = self.full_details.get(
            "connection_pooling_max_size", 100
        )
        self.options = options or {}
        self._cursor = None
        self.transaction_level = 0
        self.open = 0
        self.schema = None
        if name:
            self.name = name

    def make_connection(self):
        """This sets the connection on the connection class

        Raises:
            DriverNotFound -- when psycopg2 is not installed.
        """
        try:
            import psycopg2  # noqa F401
        except ModuleNotFoundError:
            raise DriverNotFound(
                "You must have the 'psycopg2' package installed to make a connection to Postgres. Please install it using 'pip install psycopg2-binary'"
            )

        if self.has_global_connection():
            return self.get_global_connection()

        self._connection = self.create_connection()
        self._connection.autocommit = True

        self.enable_disable_foreign_keys()

        self.open = 1

        return self

    def create_connection(self):
        """Create (or pop from the pool) a raw psycopg2 connection."""
        import psycopg2

        # Initialize the connection pool if the option is set
        initialize_size = self.full_details.get("connection_pooling_min_size")
        if (
            self.full_details.get("connection_pooling_enabled")
            and initialize_size
            and len(CONNECTION_POOL) < initialize_size
        ):
            for _ in range(initialize_size - len(CONNECTION_POOL)):
                connection = psycopg2.connect(
                    database=self.database,
                    user=self.user,
                    password=self.password,
                    host=self.host,
                    port=self.port,
                    sslmode=self.options.get("sslmode"),
                    sslcert=self.options.get("sslcert"),
                    sslkey=self.options.get("sslkey"),
                    sslrootcert=self.options.get("sslrootcert"),
                    options=(
                        f"-c search_path={self.schema or self.full_details.get('schema')}"
                        if self.schema or self.full_details.get("schema")
                        else ""
                    ),
                )
                CONNECTION_POOL.append(connection)

        if (
            self.full_details.get("connection_pooling_enabled")
            and CONNECTION_POOL
            and len(CONNECTION_POOL) > 0
        ):
            connection = CONNECTION_POOL.pop()
        else:
            connection = psycopg2.connect(
                database=self.database,
                user=self.user,
                password=self.password,
                host=self.host,
                port=self.port,
                sslmode=self.options.get("sslmode"),
                sslcert=self.options.get("sslcert"),
                sslkey=self.options.get("sslkey"),
                sslrootcert=self.options.get("sslrootcert"),
                options=(
                    f"-c search_path={self.schema or self.full_details.get('schema')}"
                    if self.schema or self.full_details.get("schema")
                    else ""
                ),
            )

        return connection

    def get_database_name(self):
        return self.database

    @classmethod
    def get_default_query_grammar(cls):
        return PostgresGrammar

    @classmethod
    def get_default_platform(cls):
        return PostgresPlatform

    @classmethod
    def get_default_post_processor(cls):
        return PostgresPostProcessor

    def reconnect(self):
        pass

    def close_connection(self):
        """Return the connection to the pool when pooling is enabled and the
        pool has room, otherwise really close it."""
        if (
            self.full_details.get("connection_pooling_enabled")
            and len(CONNECTION_POOL) < self.connection_pool_size
        ):
            CONNECTION_POOL.append(self._connection)
        else:
            self._connection.close()
        self._connection = None

    def commit(self):
        """Commit the outermost transaction and restore autocommit."""
        if self.get_transaction_level() == 1:
            self._connection.commit()
            self._connection.autocommit = True
        self.transaction_level -= 1

    def begin(self):
        """Postgres Transaction"""
        # psycopg2 starts a transaction implicitly once autocommit is off.
        self._connection.autocommit = False
        self.transaction_level += 1
        return self

    def rollback(self):
        """Roll back the outermost transaction and restore autocommit."""
        if self.get_transaction_level() == 1:
            self._connection.rollback()
            self._connection.autocommit = True
        self.transaction_level -= 1

    def get_transaction_level(self):
        """Return the current transaction nesting depth."""
        return self.transaction_level

    def set_cursor(self):
        """Create a RealDictCursor so fetches return dict-like rows."""
        from psycopg2.extras import RealDictCursor

        self._cursor = self._connection.cursor(cursor_factory=RealDictCursor)
        return self._cursor

    def query(self, query, bindings=(), results="*"):
        """Make the actual query that will reach the database and come back with a result.

        Arguments:
            query {string|list} -- A string query (qmarked or regular) or a
                list of queries run back to back.
            bindings {tuple} -- A tuple of bindings

        Keyword Arguments:
            results {str|1} -- If the results is equal to an asterisks it will call 'fetchAll'
                else it will return 'fetchOne' and return a single record. (default: {"*"})

        Returns:
            dict|None -- Returns a dictionary of results or None

        Raises:
            QueryException -- wraps any driver-level error.
        """
        try:
            if not self._connection or self._connection.closed:
                self.make_connection()
            self.set_cursor()
            with self._cursor as cursor:
                if isinstance(query, list) and not self._dry:
                    # Batch mode: run each statement with no bindings.
                    for q in query:
                        self.statement(q, ())
                    return
                # psycopg2 uses %s placeholders; strip quoting around qmarks.
                query = query.replace("'?'", "%s")
                self.statement(query, bindings)
                if results == 1:
                    return dict(cursor.fetchone() or {})
                else:
                    if "SELECT" in cursor.statusmessage:
                        return cursor.fetchall()
                    return {}
        except Exception as e:
            raise QueryException(str(e)) from e
        finally:
            # Only release the connection outside an explicit transaction.
            if self.get_transaction_level() <= 0:
                self.open = 0
                self.close_connection()
                # self._connection.close()
================================================
FILE: src/masoniteorm/connections/SQLiteConnection.py
================================================
from ..query.grammars import SQLiteGrammar
from .BaseConnection import BaseConnection
from ..schema.platforms import SQLitePlatform
from ..query.processors import SQLitePostProcessor
from ..exceptions import DriverNotFound, QueryException
import re
def regexp(expr, item):
    """REGEXP implementation registered with SQLite: True when the pattern
    `expr` matches anywhere inside `item`."""
    return re.search(expr, item) is not None
class SQLiteConnection(BaseConnection):
    """SQLite Connection class."""

    name = "sqlite"
    _connection = None

    def __init__(
        self,
        host=None,
        database=None,
        user=None,
        port=None,
        password=None,
        prefix=None,
        full_details=None,
        options=None,
        name=None,
    ):
        self.host = host
        if port:
            self.port = int(port)
        else:
            self.port = port
        self.database = database
        self.user = user
        self.password = password
        self.prefix = prefix
        self.full_details = full_details or {}
        self.options = options or {}
        self._cursor = None
        self.transaction_level = 0
        self.open = 0
        if name:
            self.name = name

    def make_connection(self):
        """This sets the connection on the connection class

        Raises:
            DriverNotFound -- when the sqlite3 module is unavailable.
        """
        try:
            import sqlite3
        except ModuleNotFoundError:
            raise DriverNotFound(
                "You must have the 'sqlite3' package installed to make a connection to SQLite."
            )

        if self.has_global_connection():
            return self.get_global_connection()

        # isolation_level=None puts sqlite3 in autocommit mode; explicit
        # transactions are managed in begin()/commit()/rollback().
        self._connection = sqlite3.connect(self.database, isolation_level=None)
        self._connection.create_function("REGEXP", 2, regexp)
        self._connection.row_factory = sqlite3.Row
        self.enable_disable_foreign_keys()
        self.open = 1

        return self

    @classmethod
    def get_default_query_grammar(cls):
        return SQLiteGrammar

    @classmethod
    def get_default_platform(cls):
        return SQLitePlatform

    @classmethod
    def get_default_post_processor(cls):
        return SQLitePostProcessor

    def get_database_name(self):
        return self.database

    def reconnect(self):
        pass

    def commit(self):
        """Commit the outermost transaction and close the connection.

        BUGFIX: transaction_level was previously decremented twice when
        committing the outermost transaction, leaving the level at -1.
        It is now decremented exactly once per commit().
        """
        if self.get_transaction_level() == 1:
            self._connection.commit()
            self._connection.isolation_level = None
            self._connection.close()
            self.open = 0
        self.transaction_level -= 1
        return self

    def begin(self):
        """Sqlite Transaction"""
        self._connection.isolation_level = "DEFERRED"
        self.transaction_level += 1
        return self

    def rollback(self):
        """Roll back the outermost transaction and close the connection.

        BUGFIX: same double-decrement as commit() -- the level is now
        decremented exactly once per rollback().
        """
        if self.get_transaction_level() == 1:
            self._connection.rollback()
            self._connection.close()
            self.open = 0
        self.transaction_level -= 1
        return self

    def get_cursor(self):
        return self._cursor

    def get_transaction_level(self):
        return self.transaction_level

    def query(self, query, bindings=(), results="*"):
        """Make the actual query that will reach the database and come back with a result.

        Arguments:
            query {string|list} -- A string query (qmarked or regular) or a
                list of queries run back to back (the list form returns None).
            bindings {tuple} -- A tuple of bindings

        Keyword Arguments:
            results {str|1} -- If the results is equal to an asterisks it will call 'fetchAll'
                else it will return 'fetchOne' and return a single record. (default: {"*"})

        Returns:
            dict|list|None -- Returns a dictionary of results or None

        Raises:
            QueryException -- wraps any driver-level error.
        """
        if not self.open:
            self.make_connection()
        try:
            self._cursor = self._connection.cursor()
            if isinstance(query, list):
                # BUGFIX: use a distinct loop variable; the original shadowed
                # the `query` argument inside its own loop.
                for single_query in query:
                    self.statement(single_query)
            else:
                query = query.replace("'?'", "?")
                self.statement(query, bindings)
                if results == 1:
                    result = [dict(row) for row in self._cursor.fetchall()]
                    if result:
                        return result[0]
                    # NOTE: preserved behavior -- returns None when no rows.
                else:
                    return [dict(row) for row in self._cursor.fetchall()]
        except Exception as e:
            raise QueryException(str(e)) from e
        finally:
            if self.get_transaction_level() <= 0:
                self._connection.close()
                self.open = 0

    def format_cursor_results(self, cursor_result):
        """Convert sqlite3.Row objects into plain dictionaries."""
        return [dict(row) for row in cursor_result]

    def select_many(self, query, bindings, amount):
        """Execute `query` and yield results in chunks of `amount` rows.

        BUGFIX: connect before creating a cursor. Previously the cursor was
        created (and the statement executed) before the connection check,
        which fails when no connection exists yet.
        """
        if not self.open:
            self.make_connection()
        self._cursor = self._connection.cursor()
        self.statement(query)
        result = self.format_cursor_results(self._cursor.fetchmany(amount))
        while result:
            yield result
            result = self.format_cursor_results(self._cursor.fetchmany(amount))
================================================
FILE: src/masoniteorm/connections/__init__.py
================================================
from .ConnectionResolver import ConnectionResolver
from .ConnectionFactory import ConnectionFactory
from .MySQLConnection import MySQLConnection
from .PostgresConnection import PostgresConnection
from .SQLiteConnection import SQLiteConnection
from .MSSQLConnection import MSSQLConnection
================================================
FILE: src/masoniteorm/exceptions.py
================================================
class DriverNotFound(Exception):
    """Raised when a required database driver package (pymysql, psycopg2,
    sqlite3, ...) is not installed."""

    pass


class ModelNotFound(Exception):
    """Raised when a model record cannot be found.
    NOTE(review): raising call sites are outside this file -- confirm usage."""

    pass


class HTTP404(Exception):
    """HTTP 404 error type, presumably for web-framework integration."""

    pass


class ConnectionNotRegistered(Exception):
    """Raised when a connection name is not present in the resolver's
    registered connection details."""

    pass


class QueryException(Exception):
    """Raised by connection classes to wrap any driver-level error that
    occurs while executing a query."""

    pass


class MigrationNotFound(Exception):
    """Raised when a referenced migration file cannot be located."""

    pass


class ConfigurationNotFound(Exception):
    """Raised when the ORM configuration module cannot be loaded."""

    pass


class InvalidUrlConfiguration(Exception):
    """Raised when a database URL cannot be parsed into connection details."""

    pass


class MultipleRecordsFound(Exception):
    """Raised when a single-record lookup matches more than one record."""

    pass


class InvalidArgument(Exception):
    """Raised when an API is called with an unsupported argument value."""

    pass
================================================
FILE: src/masoniteorm/expressions/__init__.py
================================================
from .expressions import Raw
from .expressions import JoinClause
================================================
FILE: src/masoniteorm/expressions/expressions.py
================================================
from ..helpers.misc import deprecated
class QueryExpression:
    """Container for a single WHERE-style condition, consumed by the
    grammar layer. Pure data holder: every component is stored as given."""

    def __init__(
        self,
        column,
        equality,
        value,
        value_type="value",
        keyword=None,
        raw=False,
        bindings=(),
    ):
        self.column, self.equality, self.value = column, equality, value
        self.value_type = value_type
        self.keyword = keyword
        self.raw = raw
        self.bindings = bindings
class HavingExpression:
    """A helper class to manage HAVING expressions.

    Supports both call forms:
        HavingExpression("total", 5)        -> equality defaults to "="
        HavingExpression("total", ">=", 5)  -> explicit operator
    """

    def __init__(self, column, equality=None, value=None, raw=False):
        self.column = column
        self.raw = raw

        # Two-argument form: the second positional argument is really the
        # value. BUGFIX: compare against None instead of truthiness so a
        # falsy value (0, "", etc.) passed with an explicit operator no
        # longer triggers the swap and discards the value.
        if equality is not None and value is None:
            value = equality
            equality = "="

        self.equality = equality
        self.value = value
        self.value_type = "having"
class FromTable:
    """Represents a table reference in a FROM clause."""

    def __init__(self, name, raw=False):
        # raw=True means the name is emitted verbatim, without quoting.
        self.name = name
        self.raw = raw
class UpdateQueryExpression:
    """Represents a single column assignment in an UPDATE statement."""

    def __init__(self, column, value=None, update_type="keyvalue"):
        self.column, self.value = column, value
        self.update_type = update_type
class BetweenExpression:
    """Represents a WHERE column BETWEEN low AND high condition."""

    def __init__(self, column, low, high, equality="BETWEEN"):
        self.column = column
        self.low, self.high = low, high
        self.equality = equality
        # Mirror QueryExpression's attributes so grammars can handle all
        # expression objects uniformly.
        self.value = None
        self.value_type = "BETWEEN"
        self.raw = False
class SubSelectExpression:
    """Wraps a query builder used as a sub-select."""

    def __init__(self, builder):
        self.builder = builder
class SubGroupExpression:
    """Wraps a query builder used as a parenthesized sub-group."""

    def __init__(self, builder, alias="group"):
        self.builder = builder
        self.alias = alias
class SelectExpression:
    """A single column in a SELECT clause, with optional `col as alias`
    parsing for non-raw expressions."""

    def __init__(self, column, raw=False):
        self.column = column.strip()
        self.alias = None
        self.raw = raw

        # Raw expressions are emitted verbatim; only parse aliases otherwise.
        if raw is False and " as " in self.column:
            column_part, alias_part = self.column.split(" as ")
            self.column = column_part.strip()
            self.alias = alias_part.strip()
class OrderByExpression:
    """A column in an ORDER BY clause. A trailing ` asc`/` desc` embedded in
    the column string overrides the `direction` keyword (non-raw only)."""

    def __init__(self, column, direction="ASC", raw=False, bindings=()):
        base = column.strip()
        self.raw = raw
        self.direction = direction
        self.bindings = bindings
        if raw is False:
            if base.endswith(" desc"):
                base = base.split(" desc")[0].strip()
                self.direction = "DESC"
            if base.endswith(" asc"):
                base = base.split(" asc")[0].strip()
                self.direction = "ASC"
        self.column = base
class GroupByExpression:
    """A single column in a GROUP BY clause."""

    def __init__(self, column=None, raw=False, bindings=()):
        # BUGFIX: guard the documented default -- previously calling
        # GroupByExpression() raised AttributeError on None.strip().
        self.column = column.strip() if column is not None else None
        self.raw = raw
        self.bindings = bindings
class AggregateExpression:
    """An aggregate (SUM/COUNT/...) over a column, with optional
    `col as alias` parsing."""

    def __init__(self, aggregate=None, column=None, alias=False):
        self.aggregate = aggregate
        # BUGFIX: guard the documented default -- previously calling
        # AggregateExpression() raised AttributeError on None.strip().
        self.column = column.strip() if column is not None else None
        self.alias = alias
        if self.column and " as " in self.column:
            self.column, self.alias = self.column.split(" as ")
class Raw:
    """Marks an expression string to be emitted into SQL verbatim."""

    def __init__(self, expression):
        self.expression = expression
class JoinClause:
    """Builder for a JOIN clause: holds the joined table (with optional
    `table as alias`) and an ordered list of ON conditions (OnClause /
    OnValueClause objects) combined with and/or operators."""

    def __init__(self, table, clause="join"):
        # clause: the join type keyword, e.g. "join", "left", "right".
        self.table = table
        self.alias = None
        self.clause = clause
        self.on_clauses = []

        # Parse an optional "table as alias" form.
        if " as " in self.table:
            self.table = table.split(" as ")[0]
            self.alias = table.split(" as ")[1]

    def on(self, column1, equality, column2):
        """Add a column-to-column ON condition (AND-combined). Returns self."""
        self.on_clauses.append(OnClause(column1, equality, column2))
        return self

    def or_on(self, column1, equality, column2):
        """Add a column-to-column ON condition (OR-combined).
        NOTE(review): the "or" operator is not forwarded to OnClause here,
        so this currently behaves like on() -- confirm intent."""
        self.on_clauses.append(OnClause(column1, equality, column2, "or"))
        return self

    def on_value(self, column, *args):
        """Add an ON condition comparing a column to a literal value.
        Accepts (value) or (operator, value) argument forms. Returns self."""
        equality, value = self._extract_operator_value(*args)
        self.on_clauses += ((OnValueClause(column, equality, value, "value")),)
        return self

    def or_on_value(self, column, *args):
        """Like on_value() but OR-combined with the previous condition."""
        equality, value = self._extract_operator_value(*args)
        self.on_clauses += (
            (OnValueClause(column, equality, value, "value", operator="or")),
        )
        return self

    def on_null(self, column):
        """Specifies an ON expression where the column IS NULL.

        Arguments:
            column {string} -- The name of the column.

        Returns:
            self
        """
        self.on_clauses += ((OnValueClause(column, "=", None, "NULL")),)
        return self

    def on_not_null(self, column: str):
        """Specifies an ON expression where the column IS NOT NULL.

        Arguments:
            column {string} -- The name of the column.

        Returns:
            self
        """
        self.on_clauses += ((OnValueClause(column, "=", True, "NOT NULL")),)
        return self

    def or_on_null(self, column):
        """Specifies an ON expression where the column IS NULL.

        Arguments:
            column {string} -- The name of the column.

        Returns:
            self
        """
        self.on_clauses += ((OnValueClause(column, "=", None, "NULL", operator="or")),)
        return self

    def or_on_not_null(self, column: str):
        """Specifies an ON expression where the column IS NOT NULL.

        Arguments:
            column {string} -- The name of the column.

        Returns:
            self
        """
        self.on_clauses += (
            (OnValueClause(column, "=", True, "NOT NULL", operator="or")),
        )
        return self

    @deprecated("Using where() in a Join clause has been superceded by on_value()")
    def where(self, column, *args):
        """Deprecated alias for on_value()."""
        return self.on_value(column, *args)

    def _extract_operator_value(self, *args):
        """Normalize (value) / (operator, value) argument forms.

        With one argument, the operator defaults to "=". Raises ValueError
        for an unknown operator.
        """
        operators = ["=", ">", ">=", "<", "<=", "!=", "<>", "like", "not like"]

        operator = operators[0]

        value = None

        if (len(args)) >= 2:
            operator = args[0]
            value = args[1]
        elif len(args) == 1:
            value = args[0]

        if operator not in operators:
            raise ValueError(
                "Invalid comparison operator. The operator can be %s"
                % ", ".join(operators)
            )

        return operator, value

    def get_on_clauses(self):
        """Return the accumulated list of ON condition objects."""
        return self.on_clauses
class OnClause:
    """A column-to-column comparison inside a JOIN ... ON clause."""

    def __init__(self, column1, equality, column2, operator="and"):
        # operator: how this condition combines with the previous one.
        self.column1, self.column2 = column1, column2
        self.equality = equality
        self.operator = operator
class OnValueClause:
    """A helper class to manage ON expressions in joins that compare a
    column against a literal value (or NULL / NOT NULL)."""

    def __init__(
        self,
        column,
        equality,
        value,
        value_type="value",
        keyword=None,
        raw=False,
        bindings=(),
        operator="and",
    ):
        self.column, self.equality, self.value = column, equality, value
        self.value_type = value_type
        self.keyword = keyword
        self.raw = raw
        self.bindings = bindings
        # operator: how this condition combines with the previous one.
        self.operator = operator
================================================
FILE: src/masoniteorm/factories/Factory.py
================================================
from faker import Faker
import random
class Factory:
    """Registry-driven factory for building and persisting models.

    Builder callables are registered per model class (and optional name)
    via `register`; `make` hydrates models without saving, `create`
    persists them. `after_creating` callbacks run on every produced model.
    """

    # {model_class: {name: callable(faker) -> dict of attributes}}
    _factories = {}
    # {model_class: {name: callable(model, faker)}}
    _after_creates = {}
    # Shared Faker instance, created lazily by the `faker` property.
    _faker = None

    @property
    def faker(self):
        # Lazily create the shared Faker and reseed it so successive runs
        # don't repeat the same pseudo-random sequence.
        if not Factory._faker:
            Factory._faker = Faker()
            random.seed()
            Factory._faker.seed_instance(random.randint(1, 10000))
        return Factory._faker

    def __init__(self, model, number=1):
        # model: the model class to build.
        # number: default count produced by make()/create().
        self.model = model
        self.number = number

    def make(self, dictionary=None, name="default"):
        """Hydrate model instance(s) WITHOUT persisting them.

        dictionary: attribute overrides merged over the factory output;
        passing a list forces one model per list entry (note: the list's
        items are not themselves merged -- the whole list is passed to
        dict.update, so non-empty item dicts behave unexpectedly).
        Returns a single model or a collection depending on `number`.
        """
        if dictionary is None:
            dictionary = {}
        if self.number == 1 and not isinstance(dictionary, list):
            called = self._factories[self.model][name](self.faker)
            called.update(dictionary)
            model = self.model.hydrate(called)
            self.run_after_creates(model)
            return model
        elif isinstance(dictionary, list):
            results = []
            for index in range(0, len(dictionary)):
                called = self._factories[self.model][name](self.faker)
                called.update(dictionary)
                results.append(called)
            models = self.model.hydrate(results)
            for model in models:
                self.run_after_creates(model)
            return models
        else:
            results = []
            for index in range(0, self.number):
                called = self._factories[self.model][name](self.faker)
                called.update(dictionary)
                results.append(called)
            models = self.model.hydrate(results)
            for model in models:
                self.run_after_creates(model)
            return models

    def create(self, dictionary=None, name="default"):
        """Build AND persist model instance(s) via model.create().

        Same argument semantics as make(). In the multi-count branch each
        model is created individually, then a hydrated collection of the
        attribute dicts is returned.
        """
        if dictionary is None:
            dictionary = {}
        if self.number == 1 and not isinstance(dictionary, list):
            called = self._factories[self.model][name](self.faker)
            called.update(dictionary)
            model = self.model.create(called)
            self.run_after_creates(model)
            return model
        elif isinstance(dictionary, list):
            results = []
            for index in range(0, len(dictionary)):
                called = self._factories[self.model][name](self.faker)
                called.update(dictionary)
                results.append(called)
            models = self.model.create(results)
            for model in models:
                self.run_after_creates(model)
            return models
        else:
            full_collection = []
            for index in range(0, self.number):
                called = self._factories[self.model][name](self.faker)
                called.update(dictionary)
                full_collection.append(called)
                model = self.model.create(called)
                self.run_after_creates(model)
            return self.model.hydrate(full_collection)

    @classmethod
    def register(cls, model, call, name="default"):
        """Register a builder callable for `model` under `name`."""
        if model not in cls._factories:
            cls._factories[model] = {name: call}
        else:
            cls._factories[model][name] = call

    @classmethod
    def after_creating(cls, model, call, name="default"):
        """Register a callback run on each model produced for `model`."""
        if model not in cls._after_creates:
            cls._after_creates[model] = {name: call}
        else:
            cls._after_creates[model][name] = call

    def run_after_creates(self, model):
        """Run every registered after-create callback against `model`.
        NOTE(review): returns the model only on the early-exit path; when
        callbacks exist the return value is None -- callers ignore it."""
        if self.model not in self._after_creates:
            return model
        for name, callback in self._after_creates[self.model].items():
            callback(model, self.faker)
================================================
FILE: src/masoniteorm/factories/__init__.py
================================================
from .Factory import Factory
================================================
FILE: src/masoniteorm/helpers/__init__.py
================================================
================================================
FILE: src/masoniteorm/helpers/misc.py
================================================
"""Module for miscellaneous helper methods."""
import warnings
def deprecated(message):
    """Decorator factory that marks a function as deprecated.

    The wrapped function emits a DeprecationWarning containing `message`
    on every call, then delegates to the original function.

    Arguments:
        message {string} -- guidance appended to the warning text.

    Returns:
        callable -- a decorator to apply to the deprecated function.
    """
    from functools import wraps

    warnings.simplefilter("default", DeprecationWarning)

    def deprecated_decorator(func):
        # wraps() preserves __name__/__doc__ of the decorated function,
        # which the bare wrapper previously clobbered.
        @wraps(func)
        def deprecated_func(*args, **kwargs):
            warnings.warn(
                "{} is a deprecated function. {}".format(func.__name__, message),
                category=DeprecationWarning,
                stacklevel=2,
            )
            return func(*args, **kwargs)

        return deprecated_func

    return deprecated_decorator
================================================
FILE: src/masoniteorm/migrations/Migration.py
================================================
import os
from os import listdir
from os.path import isfile, join
from pydoc import locate
from inflection import camelize
from ..models.MigrationModel import MigrationModel
from ..schema import Schema
from ..config import load_config
from timeit import default_timer as timer
class Migration:
def __init__(
self,
connection="default",
dry=False,
command_class=None,
migration_directory="databases/migrations",
config_path=None,
schema=None,
):
self.connection = connection
self.migration_directory = migration_directory
self.last_migrations_ran = []
self.command_class = command_class
self.schema_name = schema
DB = load_config(config_path).DB
DATABASES = DB.get_connection_details()
self.schema = Schema(
connection=connection,
connection_details=DATABASES,
dry=dry,
schema=self.schema_name,
)
self.migration_model = MigrationModel.on(self.connection)
if self.schema_name:
self.migration_model.set_schema(self.schema_name)
def create_table_if_not_exists(self):
if not self.schema.has_table("migrations"):
with self.schema.create("migrations") as table:
table.increments("migration_id")
table.string("migration")
table.integer("batch")
return True
return False
def get_unran_migrations(self):
directory_path = os.path.join(os.getcwd(), self.migration_directory)
all_migrations = [
f.replace(".py", "")
for f in listdir(directory_path)
if isfile(join(directory_path, f))
and f != "__init__.py"
and not f.startswith(".")
]
all_migrations.sort()
unran_migrations = []
database_migrations = self.migration_model.all()
for migration in all_migrations:
if migration not in database_migrations.pluck("migration"):
unran_migrations.append(migration)
return unran_migrations
def get_rollback_migrations(self):
return (
self.migration_model.where("batch", self.migration_model.all().max("batch"))
.order_by("migration_id", "desc")
.get()
.pluck("migration")
)
def get_all_migrations(self, reverse=False):
if reverse:
return (
self.migration_model.order_by("migration_id", "desc")
.get()
.pluck("migration")
)
return self.migration_model.all().pluck("migration")
def get_last_batch_number(self):
return self.migration_model.select("batch").get().max("batch")
def delete_migration(self, file_path):
return self.migration_model.where("migration", file_path).delete()
def locate(self, file_name):
migration_name = camelize("_".join(file_name.split("_")[4:]).replace(".py", ""))
file_name = file_name.replace(".py", "")
migration_directory = self.migration_directory.replace("/", ".").replace(
"\\", "."
)
return locate(f"{migration_directory}.{file_name}.{migration_name}")
def get_ran_migrations(self):
directory_path = os.path.join(os.getcwd(), self.migration_directory)
all_migrations = [
f.replace(".py", "")
for f in listdir(directory_path)
if isfile(join(directory_path, f))
and f != "__init__.py"
and not f.startswith(".")
]
all_migrations.sort()
ran = []
database_migrations = self.migration_model.all()
for migration in all_migrations:
matched_migration = database_migrations.where(
"migration", migration
).first()
if matched_migration:
ran.append(
{
"migration_file": matched_migration.migration,
"batch": matched_migration.batch,
}
)
return ran
def migrate(self, migration="all", output=False):
default_migrations = self.get_unran_migrations()
migrations = default_migrations if migration == "all" else [migration]
batch = self.get_last_batch_number() + 1
for migration in migrations:
try:
migration_class = self.locate(migration)
except TypeError:
self.command_class.line(f"Not Found: {migration}")
continue
self.last_migrations_ran.append(migration)
if self.command_class:
self.command_class.line(
f"Migrating: {migration}"
)
migration_class = migration_class(
connection=self.connection, schema=self.schema_name
)
if output:
migration_class.schema.dry()
start = timer()
migration_class.up()
duration = "{:.2f}".format(timer() - start)
if output:
if self.command_class:
table = self.command_class.table()
table.set_header_row(["SQL"])
sql = migration_class.schema._blueprint.to_sql()
if isinstance(sql, list):
sql = ",".join(sql)
table.set_rows([[sql]])
table.render(self.command_class.io)
continue
else:
print(migration_class.schema._blueprint.to_sql())
if self.command_class:
self.command_class.line(
f"Migrated: {migration} ({duration}s)"
)
self.migration_model.create(
{"batch": batch, "migration": migration.replace(".py", "")}
)
def rollback(self, migration="all", output=False):
default_migrations = self.get_rollback_migrations()
migrations = default_migrations if migration == "all" else [migration]
for migration in migrations:
if migration.endswith(".py"):
migration = migration.replace(".py", "")
if self.command_class:
self.command_class.line(
f"Rolling back: {migration}"
)
try:
migration_class = self.locate(migration)
except TypeError:
self.command_class.line(f"Not Found: {migration}")
continue
migration_class = migration_class(
connection=self.connection, schema=self.schema_name
)
if output:
migration_class.schema.dry()
start = timer()
migration_class.down()
duration = "{:.2f}".format(timer() - start)
if output:
if self.command_class:
table = self.command_class.table()
table.set_header_row(["SQL"])
if (
hasattr(migration_class.schema, "_blueprint")
and migration_class.schema._blueprint
):
sql = migration_class.schema._blueprint.to_sql()
if isinstance(sql, list):
sql = ",".join(sql)
table.set_rows([[sql]])
elif migration_class.schema._sql:
table.set_rows([[migration_class.schema._sql]])
table.render(self.command_class.io)
continue
else:
print(migration_class.schema._blueprint.to_sql())
self.delete_migration(migration)
if self.command_class:
self.command_class.line(
f"Rolled back: {migration} ({duration}s)"
)
def delete_migrations(self, migrations=None):
return self.migration_model.where_in("migration", migrations or []).delete()
def delete_last_batch(self):
return self.migration_model.where(
"batch", self.get_last_batch_number()
).delete()
def reset(self, migration="all"):
default_migrations = self.get_all_migrations(reverse=True)
migrations = default_migrations if migration == "all" else [migration]
if not len(migrations):
if self.command_class:
self.command_class.line("Nothing to reset")
else:
print("Nothing to reset")
for migration in migrations:
if self.command_class:
self.command_class.line(
f"Rolling back: {migration}"
)
try:
self.locate(migration)(
connection=self.connection, schema=self.schema_name
).down()
except TypeError:
self.command_class.line(f"Not Found: {migration}")
continue
# raise MigrationNotFound(f"Could not find {migration}")
self.delete_migration(migration)
if self.command_class:
self.command_class.line(
f"Rolled back: {migration}"
)
self.delete_migrations([migration])
if self.command_class:
self.command_class.line("")
def refresh(self, migration="all"):
self.reset(migration)
self.migrate(migration)
def drop_all_tables(self, ignore_fk=False):
if self.command_class:
self.command_class.line("Dropping all tables")
if ignore_fk:
self.schema.disable_foreign_key_constraints()
for table in self.schema.get_all_tables():
self.schema.drop(table)
if ignore_fk:
self.schema.enable_foreign_key_constraints()
if self.command_class:
self.command_class.line("All tables dropped")
def fresh(self, ignore_fk=False, migration="all"):
    """Drop every table and re-run migrations from a clean slate.

    Args:
        ignore_fk (bool): disable FK constraints while dropping tables.
        migration (str): "all" or a single migration name to run.
    """
    self.drop_all_tables(ignore_fk=ignore_fk)
    self.create_table_if_not_exists()
    pending = self.get_unran_migrations()
    if pending:
        self.migrate(migration)
        return
    if self.command_class:
        self.command_class.line("Nothing to migrate")
================================================
FILE: src/masoniteorm/migrations/__init__.py
================================================
from .Migration import Migration
================================================
FILE: src/masoniteorm/models/MigrationModel.py
================================================
from .Model import Model
class MigrationModel(Model):
    """Model backing the internal `migrations` bookkeeping table."""

    __table__ = "migrations"
    # Only the migration name and its batch number are mass-assignable.
    __fillable__ = ["migration", "batch"]
    # The migrations table carries no created_at/updated_at columns.
    __timestamps__ = None
    __primary_key__ = "migration_id"
================================================
FILE: src/masoniteorm/models/Model.py
================================================
import inspect
import json
import logging
from datetime import date as datetimedate
from datetime import datetime
from datetime import time as datetimetime
from decimal import Decimal
from typing import Any, Dict
import pendulum
from inflection import tableize, underscore
from ..collection import Collection
from ..config import load_config
from ..exceptions import ModelNotFound
from ..observers import ObservesEvents
from ..query import QueryBuilder
from ..scopes import TimeStampsMixin
"""This is a magic class that will help using models like User.first() instead of having to instatiate a class like
User().first()
"""
class ModelMeta(type):
    """Metaclass that lets models be used without explicit instantiation.

    Accessing a missing attribute on the class (e.g. ``User.where(...)``)
    transparently instantiates the model and resolves the attribute on the
    new instance, so ``User.where(...)`` behaves like ``User().where(...)``.
    """

    def __getattr__(self, attribute, *args, **kwargs):
        """Resolve a missing class attribute against a fresh instance.

        Args:
            attribute (string): The name of the attribute being accessed.

        Returns:
            Model|mixed: The attribute looked up on a new model instance.
        """
        return getattr(self(), attribute)
class BoolCast:
    """Casts a value to a boolean."""

    def get(self, value):
        """Coerce the value to bool when reading from the database."""
        return True if value else False

    def set(self, value):
        """Coerce the value to bool when writing to the database."""
        return True if value else False
class JsonCast:
    """Casts a value to and from JSON."""

    def get(self, value):
        """Decode a JSON string into Python data; non-strings pass through.

        Returns None when the string is not valid JSON.
        """
        if not isinstance(value, str):
            return value
        try:
            return json.loads(value)
        except ValueError:
            return None

    def set(self, value):
        """Encode a value as a JSON string.

        Strings are validated as JSON and returned unchanged; anything else
        is serialized with json.dumps, falling back to str() for types the
        encoder does not know.
        """
        if isinstance(value, str):
            # Raises ValueError when the string is not already valid JSON.
            json.loads(value)
            return value
        return json.dumps(value, default=str)
class IntCast:
    """Casts a value to an integer."""

    def get(self, value):
        """Coerce the value to int when reading from the database."""
        return int(value)

    def set(self, value):
        """Coerce the value to int when writing to the database."""
        return int(value)
class FloatCast:
    """Casts a value to a float."""

    def get(self, value):
        """Coerce the value to float when reading from the database."""
        return float(value)

    def set(self, value):
        """Coerce the value to float when writing to the database."""
        return float(value)
class DateCast:
    """Casts a value to a date string (YYYY-MM-DD)."""

    def get(self, value):
        # pendulum.parse expects a string; only the date portion is kept.
        return pendulum.parse(value).to_date_string()

    def set(self, value):
        # Normalizes any parseable datetime string down to a date string.
        return pendulum.parse(value).to_date_string()
class DecimalCast:
    """Casts a value to Decimal for accuracy."""

    def get(self, value):
        """Return a display string for Decimals, otherwise a Decimal
        built from the value's string form.
        """
        return str(value) if isinstance(value, Decimal) else Decimal(str(value))

    def set(self, value):
        """Store any value as a Decimal (via its string form)."""
        return Decimal(str(value))
class Model(TimeStampsMixin, ObservesEvents, metaclass=ModelMeta):
    """The ORM Model class

    Base Classes:
        TimeStampsMixin (TimeStampsMixin): Adds scopes to add timestamps when something is inserted
        metaclass (ModelMeta, optional): Helps instantiate a class when it hasn't been instantiated. Defaults to ModelMeta.
    """

    # Mass-assignment whitelist; ["*"] allows every column.
    __fillable__ = ["*"]
    # Mass-assignment blacklist; ["*"] blocks every column.
    __guarded__ = []
    # When True the query builder is created in dry mode.
    __dry__ = False
    # Explicit table name; None derives one from the class name (tableize).
    __table__ = None
    __connection__ = "default"
    __resolved_connection__ = None
    # Default select columns.
    __selects__ = []
    __observers__ = {}
    __has_events__ = True
    _booted = False
    _scopes = {}
    __primary_key__ = "id"
    __primary_key_type__ = "int"
    # attribute name -> cast name (or callable) applied on get/set.
    __casts__ = {}
    # Extra attributes treated as dates (timestamp columns always included).
    __dates__ = []
    # Keys hidden from serialize(); mutually exclusive with __visible__.
    __hidden__ = []
    # Per-relationship keys to hide when serializing nested relations.
    __relationship_hidden__ = {}
    __visible__ = []
    __timestamps__ = True
    __timezone__ = "UTC"
    # Relationships to eager-load by default.
    __with__ = ()
    __force_update__ = False
    date_created_at = "created_at"
    date_updated_at = "updated_at"
    builder: QueryBuilder
"""Pass through will pass any method calls to the model directly through to the query builder.
Anytime one of these methods are called on the model it will actually be called on the query builder class.
"""
__passthrough__ = set(
(
"add_select",
"aggregate",
"all",
"avg",
"between",
"bulk_create",
"chunk",
"count",
"decrement",
"delete",
"distinct",
"doesnt_exist",
"doesnt_have",
"exists",
"find_or",
"find_or_404",
"find_or_fail",
"first_or_fail",
"first",
"first_where",
"first_or_create",
"force_update",
"from_",
"from_raw",
"get",
"get_table_schema",
"group_by_raw",
"group_by",
"has",
"having",
"having_raw",
"increment",
"in_random_order",
"join_on",
"join",
"joins",
"last",
"left_join",
"limit",
"lock_for_update",
"make_lock",
"max",
"min",
"new_from_builder",
"new",
"not_between",
"offset",
"on",
"or_where",
"or_where_null",
"order_by_raw",
"order_by",
"paginate",
"right_join",
"select_raw",
"select",
"set_global_scope",
"set_schema",
"shared_lock",
"simple_paginate",
"skip",
"statement",
"sum",
"table_raw",
"take",
"to_qmark",
"to_sql",
"truncate",
"update",
"when",
"where_between",
"where_column",
"where_date",
"or_where_doesnt_have",
"or_has",
"or_where_has",
"or_doesnt_have",
"or_where_not_exists",
"or_where_date",
"where_exists",
"where_from_builder",
"where_has",
"where_in",
"where_like",
"where_not_between",
"where_not_in",
"where_not_like",
"where_not_null",
"where_null",
"where_raw",
"without_global_scopes",
"where",
"where_doesnt_have",
"with_",
"with_count",
"latest",
"oldest",
"value",
)
)
__cast_map__ = {}
__internal_cast_map__ = {
"bool": BoolCast,
"json": JsonCast,
"int": IntCast,
"float": FloatCast,
"date": DateCast,
"decimal": DecimalCast,
}
def __init__(self):
    # Attributes hydrated from the database for this instance.
    self.__attributes__ = {}
    # Snapshot of the attributes as originally loaded, for dirty checks.
    self.__original_attributes__ = {}
    # Attributes set since hydration that have not yet been persisted.
    self.__dirty_attributes__ = {}
    if not hasattr(self, "__appends__"):
        self.__appends__ = []
    self._relationships = {}
    self._global_scopes = {}
    self.boot()
@classmethod
def get_primary_key(cls):
    """Return the name of the primary key column.

    Returns:
        str: the column configured via ``__primary_key__``.
    """
    return cls.__primary_key__
def get_primary_key_type(self):
    """Gets the primary key column type (``__primary_key_type__``).

    Returns:
        mixed
    """
    return self.__primary_key_type__
def get_primary_key_value(self):
    """Return the value stored under the primary key column.

    Raises:
        AttributeError: when the model has no attribute matching the
            configured primary key.

    Returns:
        str|int
    """
    primary_key = self.get_primary_key()
    try:
        return getattr(self, primary_key)
    except AttributeError:
        raise AttributeError(
            f"class '{self.__class__.__name__}' has no attribute {primary_key}. Did you set the primary key correctly on the model using the __primary_key__ attribute?"
        )
def get_foreign_key(self):
    """Build this model's conventional foreign key name.

    A ``User`` model with primary key ``id`` yields ``user_id``.

    Returns:
        str
    """
    class_name = self.__class__.__name__
    return underscore(class_name + "_" + self.get_primary_key())
def query(self):
    """Return the model's query builder (alias for ``get_builder``)."""
    return self.get_builder()
def get_builder(self):
    """Return (and memoize) the QueryBuilder bound to this instance.

    The builder is created once and cached on ``self.builder``; subsequent
    calls return the cached instance.
    """
    if hasattr(self, "builder"):
        return self.builder
    self.builder = QueryBuilder(
        connection=self.__connection__,
        table=self.get_table_name(),
        connection_details=self.get_connection_details(),
        model=self,
        scopes=self._scopes.get(self.__class__),
        dry=self.__dry__,
    )
    return self.builder
def get_selects(self):
    """Return the default select columns configured via ``__selects__``."""
    return self.__selects__
@classmethod
def get_columns(cls):
    """Return the table's column names by inspecting the first record.

    Returns:
        list: attribute names of the first row, or an empty list when the
        table has no rows (previously ``cls.first()`` returning None raised
        an AttributeError here).
    """
    record = cls.first()
    if record is None:
        return []
    return list(record.__attributes__.keys())
def get_connection_details(self):
    """Return the connection details dictionary from the loaded config."""
    return load_config().DB.get_connection_details()
def boot(self):
    """Boot the model once: run mixin boot hooks, validate the
    fillable/guarded configuration and fire booting/booted events.
    """
    if not self._booted:
        self.observe_events(self, "booting")
        for base_class in inspect.getmro(self.__class__):
            class_name = base_class.__name__
            if class_name.endswith("Mixin"):
                # Convention: each mixin exposes a boot_<MixinName> hook
                # that receives the query builder.
                getattr(self, "boot_" + class_name)(self.get_builder())
            elif (
                base_class != Model
                and issubclass(base_class, Model)
                and "__fillable__" in base_class.__dict__
                and "__guarded__" in base_class.__dict__
            ):
                # Declaring both on one class is ambiguous, so refuse it.
                raise AttributeError(
                    f"{type(self).__name__} must specify either __fillable__ or __guarded__ properties, but not both."
                )
        self._booted = True
        self.observe_events(self, "booted")
        # Expose any builder macros as pass-through methods on the model.
        self.append_passthrough(list(self.get_builder()._macros.keys()))
def append_passthrough(self, passthrough):
    """Add method names to the builder pass-through set.

    NOTE(review): ``__passthrough__`` is a class-level set, so updating it
    here mutates the set shared by every model — confirm this is
    intentional (it is how builder macros become visible on all models).

    Args:
        passthrough (list): method names to forward to the query builder.

    Returns:
        self
    """
    self.__passthrough__.update(passthrough)
    return self
@classmethod
def get_table_name(cls):
    """Return the explicit ``__table__`` name, or derive one from the
    class name via ``tableize``.

    Returns:
        str
    """
    if cls.__table__:
        return cls.__table__
    return tableize(cls.__name__)
@classmethod
def table(cls, table):
    """Sets the table name on the model class.

    Args:
        table (str): the table name to use.

    Returns:
        cls
    """
    cls.__table__ = table
    return cls
@classmethod
def find(cls, record_id, query=False):
    """Find record(s) by primary key.

    Arguments:
        record_id {int|list|tuple} -- a single primary key value, or a
            list/tuple of values to fetch several records.
        query {bool} -- when True, return the builder instead of executing.

    Returns:
        Model|Collection|QueryBuilder
    """
    wants_many = isinstance(record_id, (list, tuple))
    if wants_many:
        builder = cls().where_in(cls.get_primary_key(), record_id)
    else:
        builder = cls().where(cls.get_primary_key(), record_id)
    if query:
        return builder
    return builder.get() if wants_many else builder.first()
@classmethod
def find_or_fail(cls, record_id, query=False):
    """Find a row by primary key or raise ModelNotFound.

    Arguments:
        record_id {int} -- The ID of the primary key to fetch.

    Returns:
        Model
    """
    found = cls.find(record_id, query)
    if found:
        return found
    raise ModelNotFound()
def is_loaded(self):
    """Return True when the model has been hydrated with attributes."""
    return len(self.__attributes__) > 0
def is_created(self):
    """Return True when the primary key is present on the hydrated attributes."""
    return self.get_primary_key() in self.__attributes__
def add_relation(self, relations):
    """Merge the given mapping into the model's loaded relationships.

    Returns:
        self
    """
    self._relationships.update(relations)
    return self
@classmethod
def hydrate(cls, result, relations=None):
    """Takes a query result and loads it into model instance(s).

    Args:
        result: a dict/row-like object, a list/tuple of them, an object
            exposing serialize(), or None (passed through).
        relations (dict, optional): preloaded relationships to attach.
            Defaults to {}.

    Returns:
        Model|Collection|None
    """
    relations = relations or {}
    if result is None:
        return None
    if isinstance(result, (list, tuple)):
        # Hydrate each element and wrap them in the model's collection type.
        response = []
        for element in result:
            response.append(cls.hydrate(element))
        return cls.new_collection(response)
    elif isinstance(result, dict):
        model = cls()
        dic = {}
        for key, value in result.items():
            # Convert date columns to pendulum instances up front.
            if key in model.get_dates() and value:
                value = model.get_new_date(value)
            dic.update({key: value})
        logger = logging.getLogger("masoniteorm.models.hydrate")
        logger.setLevel(logging.INFO)
        logger.propagate = False
        logger.info(
            f"Hydrating Model {cls.__name__}",
            extra={"class_name": cls.__name__, "class_module": cls.__module__},
        )
        model.observe_events(model, "hydrating")
        model.__attributes__.update(dic or {})
        model.__original_attributes__.update(dic or {})
        model.add_relation(relations)
        model.observe_events(model, "hydrated")
        return model
    elif hasattr(result, "serialize"):
        # e.g. another model or collection: copy its serialized attributes.
        # NOTE(review): this branch fires no hydrating/hydrated events and
        # ignores `relations` — confirm that is intentional.
        model = cls()
        model.__attributes__.update(result.serialize())
        model.__original_attributes__.update(result.serialize())
        return model
    else:
        # Row-like objects convertible via dict().
        model = cls()
        model.observe_events(model, "hydrating")
        model.__attributes__.update(dict(result))
        model.__original_attributes__.update(dict(result))
        model.observe_events(model, "hydrated")
        return model
def fill(self, attributes):
    """Merge the given attributes into the model's hydrated attributes."""
    self.__attributes__.update(attributes)
    return self
def fill_original(self, attributes):
    """Merge the given attributes into the original-attributes snapshot."""
    self.__original_attributes__.update(attributes)
    return self
@classmethod
def new_collection(cls, data):
    """Takes a result and puts it into a new collection.

    Designed to be overridden by the user to supply a custom collection type.

    Args:
        data (list|dict): loaded directly into a collection.

    Returns:
        Collection
    """
    return Collection(data)
@classmethod
def create(
    cls,
    dictionary: Dict[str, Any] = None,
    query: bool = False,
    cast: bool = False,
    **kwargs,
):
    """Creates new records based off of a dictionary as well as data set on the model
    such as fillable values.

    Args:
        dictionary (dict, optional): column/value pairs to insert.
        query (bool, optional): when True, return the query instead of executing.
        cast (bool, optional): Whether or not to cast passed values.

    Returns:
        self: A hydrated version of a model
    """
    # NOTE(review): ``cls.builder`` resolves through ModelMeta.__getattr__
    # (the class itself stores only an annotation), which instantiates the
    # model and returns that instance's builder — confirm.
    if query:
        return cls.builder.create(dictionary, query=True, cast=cast, **kwargs)
    return cls.builder.create(dictionary, cast=cast, **kwargs)
@classmethod
def cast_value(cls, attribute: str, value: Any):
    """Cast ``value`` with the caster registered for ``attribute``.

    None always passes through un-cast, and the value is returned
    unchanged when no caster is registered.
    """
    caster = cls.__casts__.get(attribute)
    cast_map = cls.get_cast_map(cls)
    if value is None:
        return None
    if isinstance(caster, str):
        return cast_map[caster]().set(value)
    if caster:
        return caster(value)
    return value
@classmethod
def cast_values(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``dictionary`` with every value run through the
    model's casters. The passed dictionary is not mutated.
    """
    return {key: cls.cast_value(key, value) for key, value in dictionary.items()}
def fresh(self):
    """Re-fetch this record from the database by its primary key."""
    primary_key = self.get_primary_key()
    builder = self.get_builder()
    return builder.where(primary_key, self.get_primary_key_value()).first()
def serialize(self, exclude=None, include=None):
    """Takes the data as a model and converts it into a dictionary.

    Args:
        exclude (list, optional): keys to hide; overrides __hidden__.
        include (list, optional): keys to show; overrides __visible__.

    Raises:
        AttributeError: when both exclude and include are given, or when
            the model declares both __visible__ and __hidden__.

    Returns:
        dict
    """
    serialized_dictionary = self.__attributes__.copy()
    # prevent using both exclude and include at the same time
    if exclude is not None and include is not None:
        raise AttributeError("Can not define both includes and exclude values.")
    # NOTE(review): exclude/include are stored back onto the instance, so
    # they persist for later serialize() calls — confirm intended.
    if exclude is not None:
        self.__hidden__ = exclude
    if include is not None:
        self.__visible__ = include
    # prevent using both hidden and visible at the same time
    if self.__visible__ and self.__hidden__:
        raise AttributeError(
            f"class model '{self.__class__.__name__}' defines both __visible__ and __hidden__."
        )
    if self.__visible__:
        # Whitelist mode: keep only the visible keys.
        new_serialized_dictionary = {
            k: serialized_dictionary[k]
            for k in self.__visible__
            if k in serialized_dictionary
        }
        serialized_dictionary = new_serialized_dictionary
    else:
        # Blacklist mode: drop the hidden keys.
        for key in self.__hidden__:
            if key in serialized_dictionary:
                serialized_dictionary.pop(key)
    # Normalize date columns to serialized (ISO-8601) strings.
    for date_column in self.get_dates():
        if (
            date_column in serialized_dictionary
            and serialized_dictionary[date_column]
        ):
            serialized_dictionary[date_column] = self.get_new_serialized_date(
                serialized_dictionary[date_column]
            )
    # Unsaved changes take precedence over persisted values.
    serialized_dictionary.update(self.__dirty_attributes__)
    # The builder is inside the attributes but should not be serialized
    if "builder" in serialized_dictionary:
        serialized_dictionary.pop("builder")
    # Serialize relationships as well
    serialized_dictionary.update(self.relations_to_dict())
    # Computed attributes registered via __appends__ / set_appends().
    for append in self.__appends__:
        serialized_dictionary.update({append: getattr(self, append)})
    remove_keys = []
    for key, value in serialized_dictionary.items():
        if key in self.__hidden__:
            remove_keys.append(key)
        if hasattr(value, "serialize"):
            # Nested models/collections, honoring per-relationship hiding.
            value = value.serialize(self.__relationship_hidden__.get(key, []))
        if isinstance(value, datetime):
            value = self.get_new_serialized_date(value)
        if key in self.__casts__:
            value = self._cast_attribute(key, value)
        serialized_dictionary.update({key: value})
    for key in remove_keys:
        serialized_dictionary.pop(key)
    return serialized_dictionary
def to_json(self):
    """Serialize the model to a JSON string (non-JSON types fall back to str()).

    Returns:
        string
    """
    return json.dumps(self.serialize(), default=str)
@classmethod
def first_or_create(cls, wheres, creates: dict = None):
    """Get the first record matching ``wheres`` or create it.

    The created record merges ``creates`` with ``wheres`` (wheres win on
    key conflicts).

    Returns:
        Model
    """
    instance = cls()
    existing = instance.where(wheres).first()
    if existing:
        return existing
    attributes = dict(creates or {})
    attributes.update(wheres)
    return instance.create(attributes, id_key=cls.get_primary_key())
@classmethod
def update_or_create(cls, wheres, updates):
    """Update the record matching ``wheres`` with the merged values, or
    create it (and re-fetch it) when no match exists.
    """
    instance = cls()
    merged = dict(updates)
    merged.update(wheres)
    existing = instance.where(wheres).first()
    if not existing:
        return instance.create(merged, id_key=cls.get_primary_key()).fresh()
    return instance.where(wheres).update(merged)
def relations_to_dict(self):
    """Serialize every loaded relationship into plain dictionaries.

    Returns:
        dict: relation name -> serialized data ({} for empty/None values).
    """
    serialized = {}
    for name, related in self._relationships.items():
        if related is None or related == {}:
            serialized[name] = {}
            continue
        if isinstance(related, list):
            serialized[name] = Collection(related).serialize()
        elif isinstance(related, dict):
            serialized[name] = related
        else:
            serialized[name] = related.serialize()
    return serialized
def touch(self, date=None, query=True):
    """Updates the current timestamps on the model and saves it.

    Returns False when the model does not use timestamps.
    """
    if not self.__timestamps__:
        return False
    self._update_timestamps(date=date)
    return self.save(query=query)
def _update_timestamps(self, date=None):
    """Set ``updated_at`` to the given date, or to the current time when
    no date is supplied.

    Args:
        date (datetime.datetime, optional): an explicit timestamp.
    """
    self.updated_at = date if date else self._current_timestamp()
def _current_timestamp(self):
    # Naive local datetime; timezone handling happens in get_new_date().
    return datetime.now()
def __getattr__(self, attribute):
    """Magic method that is called when an attribute does not exist on the model.

    Resolution order: accessor method (get_<name>_attribute), dirty
    attributes, hydrated attributes (with date conversion), builder
    pass-through methods, then loaded relationships.

    Args:
        attribute (string): the name of the attribute being accessed or called.

    Returns:
        mixed: Could be anything that a method can return.
    """
    # 1) Explicit accessor: get_<attribute>_attribute(self)
    new_name_accessor = "get_" + attribute + "_attribute"
    if (new_name_accessor) in self.__class__.__dict__:
        return self.__class__.__dict__.get(new_name_accessor)(self)
    # 2) Unsaved (dirty) values win over hydrated ones.
    if (
        "__dirty_attributes__" in self.__dict__
        and attribute in self.__dict__["__dirty_attributes__"]
    ):
        return self.get_dirty_value(attribute)
    # 3) Hydrated attributes; date columns become date instances.
    if (
        "__attributes__" in self.__dict__
        and attribute in self.__dict__["__attributes__"]
    ):
        if attribute in self.get_dates():
            return (
                self.get_new_date(self.get_value(attribute))
                if self.get_value(attribute)
                else None
            )
        return self.get_value(attribute)
    # 4) Forward builder methods listed in __passthrough__.
    if attribute in self.__passthrough__:

        def method(*args, **kwargs):
            return getattr(self.get_builder(), attribute)(*args, **kwargs)

        return method
    # 5) Loaded relationship results.
    if attribute in self.__dict__.get("_relationships", {}):
        return self.__dict__["_relationships"][attribute]
    if attribute not in self.__dict__:
        name = self.__class__.__name__
        raise AttributeError(f"class model '{name}' has no attribute {attribute}")
    return None
def only(self, attributes: list) -> dict:
    """Return a dict containing only the requested attributes.

    Entries may use "column as alias" to rename keys in the result. A
    single string is treated as a one-element list.
    """
    if isinstance(attributes, str):
        attributes = [attributes]
    picked: dict[str, Any] = {}
    for entry in attributes:
        if " as " in entry:
            column, alias = entry.split(" as ")
        else:
            column = alias = entry
        picked[alias.strip()] = self.get_raw_attribute(column.strip())
    return picked
def __setattr__(self, attribute, value):
    """Intercept attribute writes: run mutators, casts and date
    formatting, then stage the value as a dirty attribute (persisted on
    save()). Names starting with an underscore bypass staging and are
    stored directly on the instance.
    """
    # Explicit mutator: set_<attribute>_attribute(value)
    if hasattr(self, "set_" + attribute + "_attribute"):
        method = getattr(self, "set_" + attribute + "_attribute")
        value = method(value)
    if attribute in self.__casts__:
        value = self._set_cast_attribute(attribute, value)
    if attribute in self.get_dates():
        value = self.get_new_datetime_string(value)
    try:
        if not attribute.startswith("_"):
            self.__dict__["__dirty_attributes__"].update({attribute: value})
        else:
            self.__dict__[attribute] = value
    except KeyError:
        # __dirty_attributes__ not initialized yet (before __init__ runs);
        # the write is silently dropped.
        pass
def get_raw_attribute(self, attribute):
    """Fetch an attribute directly, bypassing the magic accessors.

    Gets around infinite recursion loops through ``__getattr__``.

    Args:
        attribute (string): The attribute to fetch.

    Returns:
        mixed: Any value an attribute can be (None when absent).
    """
    return self.__attributes__.get(attribute)
def is_dirty(self):
    """Return True when unsaved attribute changes exist."""
    return bool(self.__dirty_attributes__)
def get_original(self, key):
    """Return the attribute value as originally hydrated (pre-change)."""
    return self.__original_attributes__.get(key)
def get_dirty(self, key):
    """Return the unsaved (dirty) value for the given key, if any."""
    return self.__dirty_attributes__.get(key)
def get_dirty_keys(self):
    """Return the names of all attributes with unsaved changes."""
    return list(self.get_dirty_attributes().keys())
def save(self, query=False):
    """Persist the model's dirty attributes.

    Performs an UPDATE when the model was hydrated from the database,
    otherwise an INSERT. With ``query=True`` the query is returned (dry
    run) instead of being executed.

    Args:
        query (bool): return the query instead of executing it.

    Returns:
        mixed: the persisted result, or the query when query=True.
    """
    builder = self.get_builder()
    # The builder can leak into dirty attributes; never persist it.
    if "builder" in self.__dirty_attributes__:
        self.__dirty_attributes__.pop("builder")
    self.observe_events(self, "saving")
    if not query:
        if self.is_loaded():
            result = builder.update(
                self.__dirty_attributes__, ignore_mass_assignment=True
            )
        else:
            result = self.create(
                self.__dirty_attributes__,
                query=query,
                id_key=self.get_primary_key(),
                ignore_mass_assignment=True,
            )
        self.observe_events(self, "saved")
        # Sync freshly persisted values back onto this instance.
        self.fill(result.__attributes__)
        self.__dirty_attributes__ = {}
        return result
    if self.is_loaded():
        result = builder.update(
            self.__dirty_attributes__, dry=query, ignore_mass_assignment=True
        )
    else:
        result = self.create(self.__dirty_attributes__, query=query)
    return result
def get_value(self, attribute):
    """Return a hydrated attribute value, cast if a caster is registered."""
    raw = self.__attributes__[attribute]
    if attribute in self.__casts__:
        return self._cast_attribute(attribute, raw)
    return raw
def get_dirty_value(self, attribute):
    """Return an unsaved attribute value, cast if a caster is registered."""
    raw = self.__dirty_attributes__[attribute]
    if attribute in self.__casts__:
        return self._cast_attribute(attribute, raw)
    return raw
def all_attributes(self):
    """Return a merged dict of hydrated and dirty attributes with casts
    applied.

    Previously this aliased ``self.__attributes__`` and mutated it in
    place, leaking dirty values and cast results into the model's
    persisted state; it now works on a copy and leaves the model untouched.
    """
    attributes = self.__attributes__.copy()
    attributes.update(self.get_dirty_attributes())
    for key, value in attributes.items():
        if key in self.__casts__:
            attributes[key] = self._cast_attribute(key, value)
    return attributes
def delete_attribute(self, key):
    """Remove a hydrated attribute. Returns True when it existed."""
    try:
        del self.__attributes__[key]
    except KeyError:
        return False
    return True
def get_dirty_attributes(self):
    """Return the unsaved attribute changes (any leaked builder entry is
    dropped first)."""
    self.__dirty_attributes__.pop("builder", None)
    return self.__dirty_attributes__ or {}
def get_cast_map(self):
    """Return the merged cast map: built-in casters overlaid with the
    model's ``__cast_map__``.

    Previously this updated the class-level ``__internal_cast_map__`` in
    place, leaking one model's custom casters into every other model; it
    now merges into a fresh dictionary.
    """
    cast_map = dict(self.__internal_cast_map__)
    cast_map.update(self.__cast_map__)
    return cast_map
def _cast_attribute(self, attribute, value):
    """Apply the registered 'get' cast for the attribute; None passes through."""
    caster = self.__casts__[attribute]
    cast_map = self.get_cast_map()
    if value is None:
        return None
    if isinstance(caster, str):
        return cast_map[caster]().get(value)
    return caster(value)
def _set_cast_attribute(self, attribute, value):
    """Apply the registered 'set' cast for the attribute to the value."""
    caster = self.__casts__[attribute]
    cast_map = self.get_cast_map()
    if isinstance(caster, str):
        return cast_map[caster]().set(value)
    return caster(value)
@classmethod
def load(cls, *loads):
    """Register relationships to eager-load and return the builder.

    NOTE(review): ``cls.boot()`` calls an instance method without an
    instance and ``cls._loads`` is not defined anywhere in this file —
    this method looks broken as written; verify against callers before
    relying on it.
    """
    cls.boot()
    cls._loads += loads
    return cls.builder
def __getitem__(self, attribute):
    """Allow dictionary-style access: ``model["name"]``."""
    return getattr(self, attribute)
def get_dates(self):
    """Return the attribute names treated as dates: the user-declared
    ``__dates__`` plus the created/updated timestamp columns.

    :rtype: list
    """
    return self.__dates__ + [self.date_created_at, self.date_updated_at]
def get_new_date(self, _datetime=None):
    """Convert any supported date/time representation into a pendulum
    instance in the model's timezone.

    Accepts None (now), parseable strings, datetime, date and time objects.
    """
    import pendulum  # also imported at module level; kept for parity

    if not _datetime:
        return pendulum.now(tz=self.__timezone__)
    elif isinstance(_datetime, str):
        return pendulum.parse(_datetime, tz=self.__timezone__)
    elif isinstance(_datetime, datetime):
        return pendulum.instance(_datetime, tz=self.__timezone__)
    elif isinstance(_datetime, datetimedate):
        # A bare date becomes midnight of that day.
        return pendulum.datetime(
            _datetime.year, _datetime.month, _datetime.day, tz=self.__timezone__
        )
    elif isinstance(_datetime, datetimetime):
        # A bare time is parsed as that time on today's date.
        return pendulum.parse(
            f"{_datetime.hour}:{_datetime.minute}:{_datetime.second}",
            tz=self.__timezone__,
        )
    return pendulum.instance(_datetime, tz=self.__timezone__)
def get_new_datetime_string(self, _datetime=None):
    """Format the given value (or the current time when omitted) as a
    datetime string.

    :rtype: str
    """
    return self.get_new_date(_datetime).to_datetime_string()
def get_new_serialized_date(self, _datetime):
    """Format the given value as an ISO-8601 string for serialization.

    :rtype: str
    """
    return self.get_new_date(_datetime).isoformat()
def set_appends(self, appends):
    """Add computed attribute names to include during serialization.

    :rtype: self
    """
    self.__appends__ += appends
    return self
def save_many(self, relation, relating_records):
    """Attach every record in the iterable to this model's relation.

    Raises:
        ValueError: when a single Model is passed instead of an iterable.
    """
    if isinstance(relating_records, Model):
        raise ValueError(
            "Saving many records requires an iterable like a collection or a list of models and not a Model object. To attach a model, use the 'attach' method."
        )
    for record in relating_records:
        self.attach(relation, record)
def detach_many(self, relation, relating_records):
    """Detach every record in the iterable from this model's relation.

    Unsaved records are persisted before being detached.

    Raises:
        ValueError: when a single Model is passed instead of an iterable.
    """
    if isinstance(relating_records, Model):
        raise ValueError(
            "Detaching many records requires an iterable like a collection or a list of models and not a Model object. To detach a model, use the 'detach' method."
        )
    related = getattr(self.__class__, relation)
    for record in relating_records:
        if record.is_created():
            record.save()
        else:
            record.create(record.all_attributes())
        related.detach(self, record)
def related(self, relation):
    """Return the related result of the named relationship for this instance."""
    return getattr(self.__class__, relation).relate(self)
def get_related(self, relation):
    """Return the relationship descriptor declared on the model class."""
    related = getattr(self.__class__, relation)
    return related
def attach(self, relation, related_record):
    """Attach the given record through the named relationship."""
    relationship = getattr(self.__class__, relation)
    return relationship.attach(self, related_record)
def detach(self, relation, related_record):
    """Detach the given record through the named relationship, persisting
    it first when it has not been created yet.
    """
    relationship = getattr(self.__class__, relation)
    if related_record.is_created():
        related_record.save()
    else:
        related_record = related_record.create(related_record.all_attributes())
    return relationship.detach(self, related_record)
def save_quietly(self):
    """Save the model with the saving/saved observer events suppressed.

    Events are toggled back on once the save completes.
    """
    self.without_events()
    result = self.save()
    self.with_events()
    return result
def delete_quietly(self):
    """Delete the model with the deleting/deleted observer events
    suppressed. Events are toggled back on once the delete completes.

    Returns:
        the result of the delete query.
    """
    silenced = self.without_events()
    result = silenced.where(
        self.get_primary_key(), self.get_primary_key_value()
    ).delete()
    self.with_events()
    return result
def attach_related(self, relation, related_record):
    """Alias for ``attach``."""
    return self.attach(relation, related_record)
@classmethod
def filter_fillable(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
    """Return only the entries whose keys appear in ``__fillable__``.

    A fillable list of ``["*"]`` allows everything. The passed dictionary
    is not mutated.
    """
    if cls.__fillable__ == ["*"]:
        return dictionary
    return {key: dictionary[key] for key in cls.__fillable__ if key in dictionary}
@classmethod
def filter_mass_assignment(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
    """Apply both fillable and guarded filtering ahead of a
    mass-assignment operation. The passed dictionary is not mutated.
    """
    fillable = cls.filter_fillable(dictionary)
    return cls.filter_guarded(fillable)
@classmethod
def filter_guarded(cls, dictionary: Dict[str, Any]) -> Dict[str, Any]:
    """Drop every entry whose key appears in ``__guarded__``.

    A guarded list of ``["*"]`` blocks all fields. The passed dictionary
    is not mutated.
    """
    if cls.__guarded__ == ["*"]:
        # If all fields are guarded, all data should be filtered
        return {}
    guarded = cls.__guarded__
    return {key: value for key, value in dictionary.items() if key not in guarded}
================================================
FILE: src/masoniteorm/models/Model.pyi
================================================
from typing import Any, Dict
from typing_extensions import Self
from ..query.QueryBuilder import QueryBuilder
class Model:
def add_select(alias: str, callable: Any):
"""Specifies a select subquery."""
pass
def aggregate(aggregate: str, column: str, alias: str):
"""Helper function to aggregate.
Arguments:
aggregate {string} -- The name of the aggregation.
column {string} -- The name of the column to aggregate.
"""
def all(selects: list = [], query: bool = False):
"""Returns all records from the table.
Returns:
dictionary -- Returns a dictionary of results.
"""
pass
def get(selects: list = []):
"""Runs the select query built from the query builder.
Returns:
self
"""
pass
def avg(column: str):
"""Aggregates a columns values.
Arguments:
column {string} -- The name of the column to aggregate.
Returns:
self
"""
pass
def between(column: str, low: str | int, high: str | int):
"""Specifies a where between expression.
Arguments:
column {string} -- The name of the column.
low {string} -- The value on the low end.
high {string} -- The value on the high end.
Returns:
self
"""
pass
def bulk_create(creates: dict, query: bool = False):
pass
def cast_value(attribute: str, value: Any):
"""
Given an attribute name and a value, casts the value using the model's registered caster.
If no registered caster exists, returns the unmodified value.
"""
pass
def cast_values(dictionary: Dict[str, Any]) -> Dict[str, Any]:
"""
Runs provided dictionary through all model casters and returns the result.
Does not mutate the passed dictionary.
"""
pass
def chunk(chunk_amount: str | int):
pass
def count(column: str = None):
"""Aggregates a columns values.
Arguments:
column {string} -- The name of the column to aggregate.
Returns:
self
"""
pass
def create(
dictionary: Dict[str, Any] = None,
query: bool = False,
cast: bool = False,
**kwargs
):
"""Creates new records based off of a dictionary as well as data set on the model
such as fillable values.
Args:
dictionary (dict, optional): [description]. Defaults to {}.
query (bool, optional): [description]. Defaults to False.
cast (bool, optional): [description]. Whether or not to cast passed values.
Returns:
self: A hydrated version of a model
"""
pass
def decrement(column: str, value: int = 1):
"""Decrements a column's value.
Arguments:
column {string} -- The name of the column.
Keyword Arguments:
value {int} -- The value to decrement by. (default: {1})
Returns:
self
"""
def delete(column: str = None, value: str = None, query: bool = False):
"""Specify the column and value to delete
or deletes everything based on a previously used where expression.
Keyword Arguments:
column {string} -- The name of the column (default: {None})
value {string|int} -- The value of the column (default: {None})
Returns:
self
"""
pass
def distinct(boolean: bool = True):
"""Species that the select query should be a SELECT DISTINCT query."""
pass
def doesnt_exist() -> bool:
"""Determines if any rows exist for the current query.
Returns:
Bool - True or False
"""
pass
def doesnt_have() -> bool:
"""Determine if any related rows exist for the current query.
Returns:
Bool - True or False
"""
pass
def exists() -> bool:
"""Determine if rows exist for the current query.
Returns:
Bool - True or False
"""
pass
def filter_fillable(dictionary: Dict[str, Any]) -> Dict[str, Any]:
"""
Filters provided dictionary to only include fields specified in the model's __fillable__ property
Passed dictionary is not mutated.
"""
pass
def filter_mass_assignment(dictionary: Dict[str, Any]) -> Dict[str, Any]:
    """
    Filters the provided dictionary in preparation for a mass-assignment operation
    Wrapper around filter_fillable() & filter_guarded(). Passed dictionary is not mutated.
    """
    pass

def filter_guarded(dictionary: Dict[str, Any]) -> Dict[str, Any]:
    """
    Filters provided dictionary to exclude fields specified in the model's __guarded__ property
    Passed dictionary is not mutated.
    """
    pass

def find_or_404(record_id: str | int):
    """Finds a row by the primary key ID (Requires a model) or raise an HTTP404 exception.
    Arguments:
        record_id {int} -- The ID of the primary key to fetch.
    Returns:
        Model|HTTP404
    """
    pass

def find(record_id: str | list) -> Self:
    """Finds a row (or several rows when given a list of IDs) by the primary key (Requires a model).
    Arguments:
        record_id {int|list} -- The ID, or list of IDs, of the primary key to fetch.
    Returns:
        Model|Collection
    """
    pass

def find_or_fail(record_id: str | int):
    """Finds a row by the primary key ID (Requires a model) or raise a ModelNotFound exception.
    Arguments:
        record_id {int} -- The ID of the primary key to fetch.
    Returns:
        Model|ModelNotFound
    """
    pass

def first_or_fail(query: bool = False):
    """Returns the first row from the database. If no result is found, a ModelNotFound
    exception is raised.
    Returns:
        dictionary|ModelNotFound
    """

def first(fields: list = None, query: bool = False):
    """Gets the first record.
    Returns:
        dictionary -- Returns a dictionary of results.
    """
    pass

def first_where(column: str, *args):
    """Gets the first record with the given key / value pair"""
    pass

def first_or_create(wheres: dict, creates: dict = None):
    """Get the first record matching the attributes or create it.
    Returns:
        Model
    """
    pass

def force_update(updates: dict, dry: bool = False):
    """Updates the given columns, forcing the update to run even when no
    attributes have changed (see update(force=...))."""
    pass

def from_(table: str):
    """Alias for the table method
    Arguments:
        table {string} -- The name of the table
    Returns:
        self
    """
    pass

def from_raw(table: str):
    """Sets a raw table expression (alias for the table method with raw=True)
    Arguments:
        table {string} -- The raw table expression
    Returns:
        self
    """
    pass

def last(column: str = None, query: bool = False):
    """Gets the last record, ordered by column in descendant order or primary
    key if no column is given.
    Returns:
        dictionary -- Returns a dictionary of results.
    """
    pass
def group_by_raw(query: str, bindings: list = []):
    """Specifies a raw SQL group-by expression.
    Arguments:
        query {string} -- A raw query
    Keyword Arguments:
        bindings {list} -- query bindings for the raw expression (default: {[]})
    Returns:
        self
    """
    pass

def group_by(column: str):
    """Specifies a column to group by.
    Arguments:
        column {string} -- The name of the column to group by.
    Returns:
        self
    """
    pass

def has(*relationships: str):
    """Filters the query to records that have the given relationship(s)."""
    pass

def having_raw(string: str):
    """Specifies raw SQL that should be injected into the having expression.
    Arguments:
        string {string} -- The raw query string.
    Returns:
        self
    """
    pass

def increment(column: str, value: int = 1):
    """Increments a column's value.
    Arguments:
        column {string} -- The name of the column.
    Keyword Arguments:
        value {int} -- The value to increment by. (default: {1})
    Returns:
        self
    """
    pass

def in_random_order():
    """Puts Query results in random order"""
    pass

def join_on(relationship: str, callback: callable = None, clause: str = ["inner"]):
    """Joins on the given relationship name, optionally refined by a callback.
    NOTE(review): the clause default looks like it should be the string
    "inner" rather than a list -- confirm against the implementation.
    """
    pass

def join(
    self,
    table: str,
    column1: str = None,
    equality: str = None,
    column2: str = None,
    clause: str = "inner",
):
    """Specifies a join expression.
    Arguments:
        table {string} -- The name of the table or an instance of JoinClause.
        column1 {string} -- The name of the foreign table.
        equality {string} -- The equality to join on.
        column2 {string} -- The name of the local column.
    Keyword Arguments:
        clause {string} -- The action clause. (default: {"inner"})
    Returns:
        self
    """
    pass

def joins(*relationships: list[str], clause: str = "inner"):
    """Adds a join expression for each of the given relationship names."""
    pass

def left_join(
    table: str, column1: str = None, equality: str = None, column2: str = None
):
    """A helper method to add a left join expression.
    Arguments:
        table {string} -- The name of the table to join on.
        column1 {string} -- The name of the foreign table.
        equality {string} -- The equality to join on.
        column2 {string} -- The name of the local column.
    Returns:
        self
    """
    pass

def limit(amount: int):
    """Specifies a limit expression.
    Arguments:
        amount {int} -- The number of rows to limit.
    Returns:
        self
    """
    pass

def lock_for_update():
    """Applies an exclusive ("update") row lock to the query (see make_lock)."""
    pass

def make_lock(lock: bool):
    """Sets the lock mode used when the query is compiled/executed."""
    pass

def max(column: str):
    """Aggregates a columns values.
    Arguments:
        column {string} -- The name of the column to aggregate.
    Returns:
        self
    """
    pass

def min(column: str):
    """Aggregates a columns values.
    Arguments:
        column {string} -- The name of the column to aggregate.
    Returns:
        self
    """
    pass

def new_from_builder(from_builder: QueryBuilder = None):
    """Creates a new QueryBuilder class from an existing builder.
    Returns:
        QueryBuilder -- The ORM QueryBuilder class.
    """
    pass

def new():
    """Creates a new QueryBuilder class.
    Returns:
        QueryBuilder -- The ORM QueryBuilder class.
    """
    pass
def not_between(column: str, low: str | int, high: str | int):
    """Specifies a where not between expression.
    Arguments:
        column {string} -- The name of the column.
        low {string} -- The value on the low end.
        high {string} -- The value on the high end.
    Returns:
        self
    """
    pass

def offset(amount: int):
    """Specifies an offset expression.
    Arguments:
        amount {int} -- The number of rows to offset.
    Returns:
        self
    """
    pass

def on(connection: str):
    """Sets, by name, the connection that the builder should use."""
    pass

def or_where(column: str | int, *args) -> QueryBuilder:
    """Specifies an or where query expression.
    Arguments:
        column {string|callable} -- the column name, or a callable that builds a sub-group
        args -- the operator and/or value to compare against
    Returns:
        QueryBuilder
    """
    pass

def or_where_null(column: str):
    """Specifies an or-where expression where the column is NULL.
    Arguments:
        column {string} -- The name of the column.
    Returns:
        self
    """
    pass

def order_by_raw(query: str, bindings: list = []):
    """Specifies a raw SQL order-by expression.
    Arguments:
        query {string} -- A raw query string.
    Keyword Arguments:
        bindings {list} -- query bindings for the raw expression (default: {[]})
    Returns:
        self
    """
    pass

def order_by(column: str, direction: str = "ASC|DESC"):
    """Specifies a column to order by.
    NOTE(review): the "ASC|DESC" default is documentation shorthand; the
    implementation default is "ASC".
    Arguments:
        column {string} -- The name of the column.
    Keyword Arguments:
        direction {string} -- Specify either ASC or DESC order. (default: {"ASC"})
    Returns:
        self
    """
    pass

def paginate(per_page: int, page: int = 1):
    """Executes the query and returns a length-aware paginator for the given page."""
    pass

def right_join(
    table: str, column1: str = None, equality: str = None, column2: str = None
):
    """A helper method to add a right join expression.
    Arguments:
        table {string} -- The name of the table to join on.
        column1 {string} -- The name of the foreign table.
        equality {string} -- The equality to join on.
        column2 {string} -- The name of the local column.
    Returns:
        self
    """
    pass

def select_raw(query: str):
    """Specifies raw SQL that should be injected into the select expression.
    Returns:
        self
    """
    pass

def select(*args: str):
    """Specifies columns that should be selected
    Returns:
        self
    """
    pass

def set_global_scope(
    self,
    name: str = "",
    callable: callable = None,
    action: str = ["select", "update", "create", "delete"],
):
    """Sets the global scopes that should be used before creating the SQL.
    NOTE(review): the action default here is documentation shorthand for the
    accepted values; the implementation default is "select".
    Arguments:
        name {string} -- The name of the global scope.
    Returns:
        self
    """
    pass

def shared_lock():
    """Applies a shared ("share") row lock to the query (see make_lock)."""
    pass

def simple_paginate(per_page: int, page: int = 1):
    """Executes the query and returns a simple (not length-aware) paginator."""
    pass
def skip(*args, **kwargs):
    """Alias for limit method.
    NOTE(review): "skip" conventionally aliases offset, not limit -- confirm
    against the implementation.
    """
    pass

def statement(query: str, bindings: list = ()):
    """Executes a raw SQL statement with the given bindings on the connection."""
    pass

def sum(column: str):
    """Aggregates a columns values.
    Arguments:
        column {string} -- The name of the column to aggregate.
    Returns:
        self
    """
    pass

def table_raw(query: str):
    """Sets a query as the table
    Arguments:
        query {string} -- The query to use for the table
    Returns:
        self
    """
    pass

def take(*args, **kwargs):
    """Alias for limit method"""
    pass

def to_qmark() -> str:
    """Compiles the QueryBuilder class into a qmark-parameterized SQL statement.
    Returns:
        str -- the compiled SQL
    """
    pass

def to_sql() -> str:
    """Compiles the QueryBuilder class into a SQL statement.
    Returns:
        str -- the compiled SQL
    """
    pass

def truncate(foreign_keys: bool = False):
    """Truncates the table. The foreign_keys flag presumably toggles foreign
    key constraint checks around the truncate -- confirm against the grammar."""
    pass

def update(
    updates: dict, dry: bool = False, force: bool = False, cast: bool = False
):
    """Specifies columns and values to be updated.
    Arguments:
        updates {dictionary} -- A dictionary of columns and values to update.
        dry {bool, optional} -- Whether a query should actually run
        force {bool, optional} -- Force the update even if there are no changes
        cast {bool, optional} -- Run all values through model's casters
    Returns:
        self
    """
    pass

def when(conditional: bool, callback: callable):
    """Applies the callback to the builder only when conditional is truthy."""
    pass

def where_between(*args, **kwargs):
    """Alias for between"""
    pass

def where_column(column1: str, column2: str):
    """Specifies where two columns equal eachother.
    Arguments:
        column1 {string} -- The name of the column.
        column2 {string} -- The name of the column.
    Returns:
        self
    """
    pass

# NOTE(review): the following two stubs duplicate (and shadow) the take and
# where_column definitions above -- likely a generation artifact.
def take(*args: Any, **kwargs: Any):
    """Alias for limit method"""
    pass

def where_column(column1: str, column2: str):
    """Specifies where two columns equal eachother.
    Arguments:
        column1 {string} -- The name of the column.
        column2 {string} -- The name of the column.
    Returns:
        self
    """
    pass
def where_date(column: str, date: Any):
    """Specifies a where DATE expression
    Arguments:
        column {string} -- The name of the column.
        date {string|datetime|pendulum} -- The date to match.
    Returns:
        self
    """
    pass

def or_where_date(column: str, date: Any):
    """Specifies an or-where DATE expression
    Arguments:
        column {string} -- The name of the column.
        date {string|datetime|pendulum} -- The date to match.
    Returns:
        self
    """
    pass

def where_exists(value: Any):
    """Specifies a where exists expression.
    Arguments:
        value {string|int|QueryBuilder} -- A value to check for the existence of a query expression.
    Returns:
        self
    """
    pass

def where_from_builder(builder: QueryBuilder):
    """Specifies a where expression built from a nested QueryBuilder.
    Arguments:
        builder {QueryBuilder} -- the builder whose wheres become a sub-group
    Returns:
        self
    """
    pass

def where_has(relationship: str, callback: Any):
    """Filters by the existence of a relationship, refined by the given callback."""
    pass

def where_in(column: str, wheres: list = []):
    """Specifies where a column contains a list of a values.
    Arguments:
        column {string} -- The name of the column.
    Keyword Arguments:
        wheres {list} -- A list of values (default: {[]})
    Returns:
        self
    """
    pass

def where_like(column: str, value: str):
    """Specifies a where LIKE expression.
    Arguments:
        column {string} -- The name of the column to search
        value {string} -- The value of the column to match
    Returns:
        self
    """
    pass

def where_not_between(*args: Any, **kwargs: Any):
    """Alias for not_between"""
    pass

def where_not_in(column: str, wheres: list = []):
    """Specifies where a column does not contain a list of a values.
    Arguments:
        column {string} -- The name of the column.
    Keyword Arguments:
        wheres {list} -- A list of values (default: {[]})
    Returns:
        self
    """
    pass

def where_not_like(column: str, value: str):
    """Specifies a where NOT LIKE expression.
    Arguments:
        column {string} -- The name of the column to search
        value {string} -- The value of the column to match
    Returns:
        self
    """
    pass

def where_not_null(column: str):
    """Specifies a where expression where the column is not NULL.
    Arguments:
        column {string} -- The name of the column.
    Returns:
        self
    """
    pass

def where_null(column: str):
    """Specifies a where expression where the column is NULL.
    Arguments:
        column {string} -- The name of the column.
    Returns:
        self
    """
    pass

def where_raw(query: str, bindings: list = []):
    """Specifies raw SQL that should be injected into the where expression.
    Arguments:
        query {string} -- The raw query string.
    Keyword Arguments:
        bindings {tuple} -- query bindings that should be added to the connection. (default: {()})
    Returns:
        self
    """
    pass

def without_global_scopes():
    """Removes every registered global scope from the builder."""
    pass

def where(column: str, *args: Any):
    """Specifies a where expression.
    Arguments:
        column {string} -- The name of the column to search
    Keyword Arguments:
        args {List} -- The operator and the value of the column to search. (default: {None})
    Returns:
        self
    """
    pass

def with_(*eagers: str):
    """Registers relationships to eager-load alongside the query results."""
    pass

def with_count(relationship: str, callback: Any = None):
    """Adds a count of the given relationship to the selected results."""
    pass
================================================
FILE: src/masoniteorm/models/Pivot.py
================================================
from .Model import Model
class Pivot(Model):
    """Model used for records of a pivot (intermediate) table."""

    # Pivot rows key off a plain "id" primary key column.
    __primary_key__ = "id"
================================================
FILE: src/masoniteorm/models/__init__.py
================================================
from .Model import Model
================================================
FILE: src/masoniteorm/observers/ObservesEvents.py
================================================
class ObservesEvents:
    """Mixin that dispatches model lifecycle events to registered observers."""

    def observe_events(self, model, event):
        """Invoke the ``event`` handler on every observer registered for the model's class.

        Observers that do not define a handler for the event are skipped.

        Arguments:
            model -- the model instance the event relates to.
            event {string} -- the event name, e.g. "creating" or "deleted".
        """
        if not model.__has_events__:
            return
        for observer in model.__observers__.get(model.__class__, []):
            # Look the handler up with a default instead of try/except so an
            # AttributeError raised *inside* the handler is no longer swallowed.
            handler = getattr(observer, event, None)
            if handler is not None:
                handler(model)

    @classmethod
    def observe(cls, observer):
        """Register an observer instance for this model class."""
        cls.__observers__.setdefault(cls, []).append(observer)

    @classmethod
    def without_events(cls):
        """Sets __has_events__ attribute on model to false."""
        cls.__has_events__ = False
        return cls

    @classmethod
    def with_events(cls):
        """Sets __has_events__ attribute on model to True."""
        cls.__has_events__ = True
        return cls
================================================
FILE: src/masoniteorm/observers/__init__.py
================================================
from .ObservesEvents import ObservesEvents
================================================
FILE: src/masoniteorm/pagination/BasePaginator.py
================================================
import json
class BasePaginator:
    """Behaviour shared by all paginator implementations."""

    def __iter__(self):
        """Iterate directly over the page's underlying result set."""
        yield from self.result

    def to_json(self):
        """Return this page serialized as a JSON string."""
        return json.dumps(self.serialize())
================================================
FILE: src/masoniteorm/pagination/LengthAwarePaginator.py
================================================
import math
from .BasePaginator import BasePaginator
class LengthAwarePaginator(BasePaginator):
    """Paginator that knows the total number of records in the result set."""

    def __init__(self, result, per_page, current_page, total, url=None):
        self.result = result
        self.per_page = per_page
        self.current_page = current_page
        self.total = total
        self.url = url
        self.count = len(result)
        # last_page must be computed before next_page: has_more_pages reads it.
        self.last_page = int(math.ceil(total / per_page))
        self.next_page = int(current_page) + 1 if self.has_more_pages() else None
        # page 1 yields 0, which collapses to None (no previous page)
        self.previous_page = (int(current_page) - 1) or None

    def serialize(self, *args, **kwargs):
        """Return the page as a plain dict of serialized data plus pagination meta."""
        meta = {
            "total": self.total,
            "next_page": self.next_page,
            "count": self.count,
            "previous_page": self.previous_page,
            "last_page": self.last_page,
            "current_page": self.current_page,
        }
        return {"data": self.result.serialize(*args, **kwargs), "meta": meta}

    def has_more_pages(self):
        """Whether pages exist beyond the current one."""
        return self.current_page < self.last_page
================================================
FILE: src/masoniteorm/pagination/SimplePaginator.py
================================================
from .BasePaginator import BasePaginator
class SimplePaginator(BasePaginator):
    """Paginator that does not know the total record count."""

    def __init__(self, result, per_page, current_page, url=None):
        self.result = result
        self.per_page = per_page
        self.current_page = current_page
        self.url = url
        self.count = len(result)
        self.next_page = int(current_page) + 1 if self.has_more_pages() else None
        # page 1 yields 0, which collapses to None (no previous page)
        self.previous_page = (int(current_page) - 1) or None

    def serialize(self, *args, **kwargs):
        """Return the page as a plain dict of serialized data plus pagination meta."""
        meta = {
            "next_page": self.next_page,
            "count": self.count,
            "previous_page": self.previous_page,
            "current_page": self.current_page,
        }
        return {"data": self.result.serialize(*args, **kwargs), "meta": meta}

    def has_more_pages(self):
        """More pages exist when the result holds more rows than one page."""
        return len(self.result) > self.per_page
================================================
FILE: src/masoniteorm/pagination/__init__.py
================================================
from .LengthAwarePaginator import LengthAwarePaginator
from .SimplePaginator import SimplePaginator
================================================
FILE: src/masoniteorm/providers/ORMProvider.py
================================================
from masonite.providers import Provider
from masoniteorm.commands import (
MigrateCommand,
MigrateRollbackCommand,
MigrateRefreshCommand,
MigrateResetCommand,
MakeModelCommand,
MakeObserverCommand,
MigrateStatusCommand,
MakeMigrationCommand,
MakeSeedCommand,
SeedRunCommand,
)
class ORMProvider(Provider):
    """Masonite ORM database provider to be used inside
    Masonite based projects."""

    def __init__(self, application):
        self.application = application

    def register(self):
        """Register all ORM-related commands with the application container."""
        # BUGFIX: removed the stray trailing comma that followed add(...),
        # which turned this statement into a throwaway 1-tuple expression.
        self.application.make("commands").add(
            MakeMigrationCommand(),
            MakeSeedCommand(),
            MakeObserverCommand(),
            MigrateCommand(),
            MigrateResetCommand(),
            MakeModelCommand(),
            MigrateStatusCommand(),
            MigrateRefreshCommand(),
            MigrateRollbackCommand(),
            SeedRunCommand(),
        )

    def boot(self):
        """Nothing to boot for the ORM provider."""
        pass
================================================
FILE: src/masoniteorm/providers/__init__.py
================================================
from .ORMProvider import ORMProvider
================================================
FILE: src/masoniteorm/query/EagerRelation.py
================================================
class EagerRelations:
    """Collects the relations a query should eager load.

    Plain relation names land in ``eagers``, dotted names ("parent.child")
    in ``nested_eagers``, and {name: callback} mappings in ``callback_eagers``.
    """

    def __init__(self, relation=None):
        self.eagers = []
        self.nested_eagers = {}
        self.callback_eagers = {}
        self.is_nested = False
        self.relation = relation

    def register(self, *relations, callback=None):
        """Register one or more relations to eager load.

        Accepts strings, dotted strings, lists/tuples of either, and dicts
        mapping a relation name to a callback.

        Returns:
            self
        """
        for relation in relations:
            if isinstance(relation, str) and "." in relation:
                self.is_nested = True
                parent, *children = relation.split(".")
                if parent not in self.nested_eagers:
                    # BUGFIX: add the key instead of replacing the whole dict,
                    # which previously dropped earlier nested registrations.
                    self.nested_eagers[parent] = children
                else:
                    self.nested_eagers[parent] += children
            elif isinstance(relation, str):
                self.eagers += [relation]
            elif isinstance(relation, (tuple, list)):
                # BUGFIX: recurse only into this relation; iterating over all
                # of `relations` here re-registered every argument once per
                # list/tuple argument, producing duplicates.
                for eager in relation:
                    self.register(eager)
            elif isinstance(relation, dict):
                self.callback_eagers.update(relation)
        return self

    def get_eagers(self):
        """Return the non-empty eager groups in registration-category order."""
        eagers = []
        if self.eagers:
            eagers.append(self.eagers)
        if self.nested_eagers:
            eagers.append(self.nested_eagers)
        if self.callback_eagers:
            eagers.append(self.callback_eagers)
        return eagers
================================================
FILE: src/masoniteorm/query/QueryBuilder.py
================================================
import inspect
from copy import deepcopy
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional
from ..collection.Collection import Collection
from ..config import load_config
from ..exceptions import (
HTTP404,
ConnectionNotRegistered,
InvalidArgument,
ModelNotFound,
MultipleRecordsFound,
)
from ..expressions.expressions import (
AggregateExpression,
BetweenExpression,
FromTable,
GroupByExpression,
HavingExpression,
JoinClause,
OrderByExpression,
QueryExpression,
SelectExpression,
SubGroupExpression,
SubSelectExpression,
UpdateQueryExpression,
)
from ..observers import ObservesEvents
from ..pagination import LengthAwarePaginator, SimplePaginator
from ..schema import Schema
from ..scopes import BaseScope
from .EagerRelation import EagerRelations
class QueryBuilder(ObservesEvents):
"""A builder class to manage the building and creation of query expressions."""
def __init__(
    self,
    grammar=None,
    connection="default",
    connection_class=None,
    table=None,
    connection_details=None,
    connection_driver="default",
    model=None,
    scopes=None,
    schema=None,
    dry=False,
    config_path=None,
):
    """QueryBuilder initializer

    Keyword Arguments:
        grammar {masoniteorm.grammar.Grammar} -- A grammar class (default: {None})
        connection {str} -- the name of the connection to use (default: {"default"})
        connection_class -- an explicit connection class (default: {None})
        table {str} -- the name of the table (default: {None})
        connection_details {dict} -- connection configuration; loaded from the
            config file when not given (default: {None})
        model {masoniteorm.models.Model} -- the model this builder is bound to (default: {None})
        scopes {dict} -- named query scopes (default: {None})
        dry {bool} -- when True, queries are built but never executed (default: {False})
        config_path {str} -- path of the configuration module (default: {None})
    """
    self.config_path = config_path
    self.grammar = grammar
    self.table(table)
    self.dry = dry
    self._creates_related = {}
    self.connection = connection
    self.connection_class = connection_class
    self._connection = None
    self._connection_details = connection_details or {}
    self._connection_driver = connection_driver
    self._scopes = scopes or {}
    self.lock = False
    self._schema = schema
    self._eager_relation = EagerRelations()
    if model:
        # inherit the model's global scopes and default eager loads
        self._global_scopes = model._global_scopes
        if model.__with__:
            self.with_(model.__with__)
    else:
        self._global_scopes = {}
    self.builder = self
    self._columns = ()
    self._creates = {}
    self._sql = ""
    self._bindings = ()
    self._updates = ()
    self._wheres = ()
    self._order_by = ()
    self._group_by = ()
    self._joins = ()
    self._having = ()
    self._macros = {}
    self._aggregates = ()
    self._limit = False
    self._offset = False
    self._distinct = False
    self._model = model
    self.set_action("select")
    if not self._connection_details:
        # no explicit details given: fall back to the configured connections
        DB = load_config(config_path=self.config_path).DB
        self._connection_details = DB.get_connection_details()
    self.on(connection)
    # on() resolves a grammar/connection class from config; explicit
    # constructor arguments take precedence over the resolved ones.
    if grammar:
        self.grammar = grammar
    if connection_class:
        self.connection_class = connection_class
def _set_creates_related(self, fields: dict):
    """Stores related-record fields that will be merged into the next create().
    Returns:
        self
    """
    self._creates_related = fields
    return self

def set_schema(self, schema):
    """Sets the schema the builder should operate in.
    Returns:
        self
    """
    self._schema = schema
    return self

def shared_lock(self):
    """Applies a shared ("share") row lock to the query.
    Returns:
        self
    """
    return self.make_lock("share")

def lock_for_update(self):
    """Applies an exclusive ("update") row lock to the query.
    Returns:
        self
    """
    return self.make_lock("update")

def make_lock(self, lock):
    """Sets the lock mode the grammar should compile into the query.
    Arguments:
        lock {string|bool} -- the lock mode, e.g. "share" or "update".
    Returns:
        self
    """
    self.lock = lock
    return self
def reset(self):
    """Resets the query builder instance so you can make multiple calls with the same builder instance

    NOTE(review): only the action, updates, wheres, order/group by, joins and
    havings are cleared; columns, limit/offset, bindings and aggregates are
    kept -- confirm whether that is intentional.

    Returns:
        self
    """
    self.set_action("select")
    self._updates = ()
    self._wheres = ()
    self._order_by = ()
    self._group_by = ()
    self._joins = ()
    self._having = ()
    return self
def get_connection_information(self):
    """Return a normalized dict of settings for the active connection.

    Missing keys resolve to None; "options" defaults to an empty dict and
    "full_details" carries the raw configuration mapping.
    """
    details = self._connection_details.get(self.connection, {})
    info = {
        key: details.get(key)
        for key in ("host", "database", "user", "port", "password", "prefix")
    }
    info["options"] = details.get("options", {})
    info["full_details"] = details
    return info
def table(self, table, raw=False):
    """Sets a table on the query builder
    Arguments:
        table {string} -- The name of the table
    Keyword Arguments:
        raw {bool} -- treat the table name as a raw expression (default: {False})
    Returns:
        self
    """
    if table:
        self._table = FromTable(table, raw=raw)
    else:
        # keep the falsy value as-is so an unset table stays unset
        self._table = table
    return self

def from_(self, table):
    """Alias for the table method
    Arguments:
        table {string} -- The name of the table
    Returns:
        self
    """
    return self.table(table)

def from_raw(self, table):
    """Sets a raw table expression (alias for table with raw=True)
    Arguments:
        table {string} -- The raw table expression
    Returns:
        self
    """
    return self.table(table, raw=True)

def table_raw(self, query):
    """Sets a query on the query builder
    Arguments:
        query {string} -- The query to use for the table
    Returns:
        self
    """
    return self.from_raw(query)
def get_table_name(self):
    """Gets the name of the table currently set on the builder.
    Returns:
        str
    """
    return self._table.name

def get_connection(self):
    """Gets the connection class the builder will use.
    Returns:
        the connection class
    """
    return self.connection_class

def begin(self):
    """Begins a transaction on a new connection.
    Returns:
        the connection with an open transaction
    """
    return self.new_connection().begin()

def begin_transaction(self, *args, **kwargs):
    """Alias for begin()."""
    return self.begin(*args, **kwargs)

def get_schema_builder(self):
    """Builds a Schema instance bound to this builder's connection and grammar.
    Returns:
        masoniteorm.schema.Schema
    """
    return Schema(connection=self.connection_class, grammar=self.grammar)

def commit(self):
    """Commits the current transaction on the active connection."""
    return self._connection.commit()

def rollback(self):
    """Rolls back the current transaction on the active connection.
    Returns:
        self
    """
    self._connection.rollback()
    return self

def get_relation(self, key):
    """Gets the named relation off the owner model.

    NOTE(review): reads self.owner, which is not assigned anywhere in this
    class -- confirm where that attribute is set before relying on this.
    """
    return getattr(self.owner, key)
def set_scope(self, name, callable):
    """Maps a scope callable to a name on the builder.

    The scope is later invoked through __getattr__ when the name is accessed.

    Arguments:
        name {string} -- The name of the scope to use.
        callable -- the scope function; receives (model, builder, *args, **kwargs).
    Returns:
        self
    """
    self._scopes.update({name: callable})
    return self

def set_global_scope(self, name="", callable=None, action="select"):
    """Sets the global scopes that should be used before creating the SQL.
    Arguments:
        name {string|BaseScope} -- The name of the global scope, or a
            BaseScope instance whose on_boot hook is run instead.
    Keyword Arguments:
        action {string} -- the action the scope applies to (default: {"select"})
    Returns:
        self
    """
    if isinstance(name, BaseScope):
        name.on_boot(self)
        return self
    if action not in self._global_scopes:
        self._global_scopes[action] = {}
    self._global_scopes[action].update({name: callable})
    return self

def without_global_scopes(self):
    """Clears every registered global scope from this builder.
    Returns:
        self
    """
    self._global_scopes = {}
    return self
def remove_global_scope(self, scope, action=None):
    """Removes a previously registered global scope.

    Arguments:
        scope {string|BaseScope} -- the scope name, or a BaseScope instance
            whose on_remove hook should run instead.
    Keyword Arguments:
        action {string} -- the action the named scope was registered under
            (default: {None})
    Returns:
        self
    """
    if isinstance(scope, BaseScope):
        scope.on_remove(self)
        return self
    # BUGFIX: `del self._global_scopes.get(action, {})[scope]` raised a
    # KeyError when the action or scope was never registered (deleting from
    # a temporary default dict). pop() makes removal a safe no-op instead.
    self._global_scopes.get(action, {}).pop(scope, None)
    return self
def __getattr__(self, attribute):
    """Magic method for fetching query scopes.
    This method is only used when a method or attribute does not already exist.
    Arguments:
        attribute {string} -- The attribute to fetch.
    Raises:
        AttributeError: Raised when there is no attribute or scope on the builder class.
    Returns:
        a bound scope/macro callable
    """
    # Guard for copy/pickle machinery: probes for __setstate__ must get a
    # plain AttributeError rather than falling through to scope lookup.
    if attribute == "__setstate__":
        raise AttributeError(
            "'QueryBuilder' object has no attribute '{}'".format(attribute)
        )
    if attribute in self._scopes:

        def method(*args, **kwargs):
            # scopes are called with the bound model and this builder
            return self._scopes[attribute](self._model, self, *args, **kwargs)

        return method
    if attribute in self._macros:

        def method(*args, **kwargs):
            return self._macros[attribute](self._model, self, *args, **kwargs)

        return method
    raise AttributeError(
        "'QueryBuilder' object has no attribute '{}'".format(attribute)
    )
def on(self, connection):
    """Sets the connection the builder should use, resolving the driver,
    connection class and grammar from the configuration.

    Arguments:
        connection {string} -- a configured connection name, or "default"
            to use the configured default connection.
    Raises:
        ConnectionNotRegistered: when the connection name is not configured.
    Returns:
        self
    """
    DB = load_config(self.config_path).DB
    if connection == "default":
        # resolve "default" to the actual configured connection name
        self.connection = self._connection_details.get("default")
    else:
        self.connection = connection
    if self.connection not in self._connection_details:
        raise ConnectionNotRegistered(
            f"Could not find the '{self.connection}' connection details"
        )
    self._connection_driver = self._connection_details.get(self.connection).get(
        "driver"
    )
    self.connection_class = DB.connection_factory.make(self._connection_driver)
    self.grammar = self.connection_class.get_default_query_grammar()
    return self
def select(self, *args):
    """Specifies columns that should be selected

    Accepts lists of column names or comma-separated column strings.

    Returns:
        self
    """
    for arg in args:
        if isinstance(arg, list):
            for column in arg:
                self._columns += (SelectExpression(column),)
        else:
            for column in arg.split(","):
                self._columns += (SelectExpression(column),)
    return self

def distinct(self, boolean=True):
    """Specifies that all columns should be distinct
    Returns:
        self
    """
    self._distinct = boolean
    return self

def add_select(self, alias, callable):
    """Adds an aliased sub-select to the selected columns.
    Arguments:
        alias {string} -- the alias for the sub-select.
        callable -- receives a fresh builder and returns the subquery builder.
    Returns:
        self
    """
    builder = callable(self.new())
    self._columns += (SubGroupExpression(builder, alias=alias),)
    return self

def statement(self, query, bindings=None):
    """Runs a raw SQL statement on a new connection.
    Arguments:
        query {string} -- the raw SQL to execute.
    Keyword Arguments:
        bindings {list} -- bindings for the statement (default: {None})
    Returns:
        the prepared query result
    """
    if bindings is None:
        bindings = []
    result = self.new_connection().query(query, bindings)
    return self.prepare_result(result)

def select_raw(self, query):
    """Specifies raw SQL that should be injected into the select expression.
    Returns:
        self
    """
    self._columns += (SelectExpression(query, raw=True),)
    return self

def get_processor(self):
    """Instantiates the post processor for the current connection class."""
    return self.connection_class.get_default_post_processor()()
def bulk_create(
    self, creates: List[Dict[str, Any]], query: bool = False, cast: bool = False
):
    """Inserts many records in a single bulk insert.
    Arguments:
        creates {list} -- a list of dictionaries of columns and values.
    Keyword Arguments:
        query {bool} -- return the builder instead of executing (default: {False})
        cast {bool} -- run values through the model's casters (default: {False})
    Returns:
        hydrated model(s) when bound to a model, otherwise the inserted rows
    """
    self.set_action("bulk_create")
    model = None
    if self._model:
        model = self._model
    self._creates = []
    for unsorted_create in creates:
        if model:
            # respect __fillable__/__guarded__ before inserting
            unsorted_create = model.filter_mass_assignment(unsorted_create)
            if cast:
                unsorted_create = model.cast_values(unsorted_create)
        # sort the dicts by key so the values inserted align with the correct column
        self._creates.append(dict(sorted(unsorted_create.items())))
    if query:
        return self
    if model:
        model = model.hydrate(self._creates)
    if not self.dry:
        connection = self.new_connection()
        query_result = connection.query(self.to_qmark(), self._bindings, results=1)
        processed_results = query_result or self._creates
    else:
        processed_results = self._creates
    if model:
        return model
    return processed_results
def create(
    self,
    creates: Optional[Dict[str, Any]] = None,
    query: bool = False,
    id_key: str = "id",
    cast: bool = False,
    ignore_mass_assignment: bool = False,
    **kwargs,
):
    """Specifies a dictionary that should be used to create new values.
    Arguments:
        creates {dict} -- A dictionary of columns and values; keyword
            arguments are used instead when omitted.
    Keyword Arguments:
        query {bool} -- return the builder instead of executing (default: {False})
        id_key {string} -- primary key name for insert-id processing (default: {"id"})
        cast {bool} -- run values through the model's casters (default: {False})
        ignore_mass_assignment {bool} -- skip __fillable__/__guarded__ filtering (default: {False})
    Returns:
        the hydrated model when bound to a model, otherwise the inserted values
    """
    self.set_action("insert")
    model = None
    self._creates = creates if creates else kwargs
    if self._model:
        model = self._model
        # Update values with related record's
        self._creates.update(self._creates_related)
        # Filter __fillable__/__guarded__ fields
        if not ignore_mass_assignment:
            self._creates = model.filter_mass_assignment(self._creates)
        # Cast values if necessary
        if cast:
            self._creates = model.cast_values(self._creates)
    if query:
        return self
    if model:
        model = model.hydrate(self._creates)
        self.observe_events(model, "creating")
        # if attributes were modified during model observer then we need to update the creates here
        self._creates.update(model.get_dirty_attributes())
    if not self.dry:
        connection = self.new_connection()
        query_result = connection.query(self.to_qmark(), self._bindings, results=1)
        if model:
            id_key = model.get_primary_key()
        processed_results = self.get_processor().process_insert_get_id(
            self, query_result or self._creates, id_key
        )
    else:
        processed_results = self._creates
    if model:
        model = model.fill(processed_results)
        self.observe_events(model, "created")
        return model
    return processed_results
def hydrate(self, result, relations=None):
    """Hydrates the bound model with the given result (and optional relations)."""
    return self._model.hydrate(result, relations)

def delete(self, column=None, value=None, query=False):
    """Specify the column and value to delete
    or deletes everything based on a previously used where expression.
    Keyword Arguments:
        column {string} -- The name of the column (default: {None})
        value {string|int|list|tuple} -- The value(s) of the column (default: {None})
        query {bool} -- return the builder instead of executing (default: {False})
    Returns:
        the query result
    """
    model = None
    self.set_action("delete")
    if self._model:
        model = self._model
    if column and value:
        # a list/tuple of values becomes a WHERE IN
        if isinstance(value, (list, tuple)):
            self.where_in(column, value)
        else:
            self.where(column, value)
    if query:
        return self
    if model and model.is_loaded():
        # scope the delete to the loaded model's primary key; "deleting" is
        # only observed for loaded models
        self.where(model.get_primary_key(), model.get_primary_key_value())
        self.observe_events(model, "deleting")
    result = self.new_connection().query(self.to_qmark(), self._bindings)
    if model:
        self.observe_events(model, "deleted")
    return result
def where(self, column, *args):
    """Specifies a where expression.
    Arguments:
        column {string|callable|dict} -- The name of the column to search, a
            callable that builds a sub-group, or a dict of column/value pairs.
    Keyword Arguments:
        args {List} -- The operator and the value of the column to search. (default: {None})
    Returns:
        self
    """
    operator, value = self._extract_operator_value(*args)
    if inspect.isfunction(column):
        # callable: build a grouped sub-expression from a fresh builder
        builder = column(self.new())
        self._wheres += (
            (QueryExpression(None, operator, SubGroupExpression(builder))),
        )
    elif isinstance(column, dict):
        # dict: one equality expression per key/value pair
        for key, value in column.items():
            self._wheres += ((QueryExpression(key, "=", value, "value")),)
    elif isinstance(value, QueryBuilder):
        # builder value: compare the column against a sub-select
        self._wheres += (
            (QueryExpression(column, operator, SubSelectExpression(value))),
        )
    else:
        self._wheres += ((QueryExpression(column, operator, value, "value")),)
    return self

def where_from_builder(self, builder):
    """Specifies a where expression built from a nested QueryBuilder.
    Arguments:
        builder {QueryBuilder} -- the builder whose wheres become a sub-group.
    Returns:
        self
    """
    self._wheres += ((QueryExpression(None, "=", SubGroupExpression(builder))),)
    return self

def where_like(self, column, value):
    """Specifies a where LIKE expression.
    Arguments:
        column {string} -- The name of the column to search
        value {string} -- The value of the column to match
    Returns:
        self
    """
    return self.where(column, "like", value)

def where_not_like(self, column, value):
    """Specifies a where NOT LIKE expression.
    Arguments:
        column {string} -- The name of the column to search
        value {string} -- The value of the column to match
    Returns:
        self
    """
    return self.where(column, "not like", value)

def where_raw(self, query: str, bindings=()):
    """Specifies raw SQL that should be injected into the where expression.
    Arguments:
        query {string} -- The raw query string.
    Keyword Arguments:
        bindings {tuple} -- query bindings that should be added to the connection. (default: {()})
    Returns:
        self
    """
    self._wheres += (
        (QueryExpression(query, "=", None, "value", raw=True, bindings=bindings)),
    )
    return self
def or_where(self, column, *args):
    """Specifies an or where query expression.
    Arguments:
        column {string|callable} -- The name of the column to search, or a
            callable that builds a sub-group.
    Keyword Arguments:
        args {List} -- The operator and the value of the column to search.
    Returns:
        self
    """
    operator, value = self._extract_operator_value(*args)
    if inspect.isfunction(column):
        builder = column(self.new())
        self._wheres += (
            (
                QueryExpression(
                    None, operator, SubGroupExpression(builder), keyword="or"
                )
            ),
        )
    elif isinstance(value, QueryBuilder):
        # BUGFIX: keyword="or" was missing in this branch, so an or_where
        # against a sub-select silently compiled as an AND condition.
        self._wheres += (
            (
                QueryExpression(
                    column, operator, SubSelectExpression(value), keyword="or"
                )
            ),
        )
    else:
        self._wheres += (
            (QueryExpression(column, operator, value, "value", keyword="or")),
        )
    return self
def where_exists(self, value: "str|int|QueryBuilder"):
"""Specifies a where exists expression.
Arguments:
value {string|int|QueryBuilder} -- A value to check for the existence of a query expression.
Returns:
self
"""
if inspect.isfunction(value):
self._wheres += (
(
QueryExpression(
None, "EXISTS", SubSelectExpression(value(self.new()))
)
),
)
elif isinstance(value, QueryBuilder):
self._wheres += (
(QueryExpression(None, "EXISTS", SubSelectExpression(value))),
)
else:
self._wheres += ((QueryExpression(None, "EXISTS", value, "value")),)
return self
def or_where_exists(self, value: "str|int|QueryBuilder"):
    """Adds an OR-joined WHERE EXISTS expression.

    Arguments:
        value {string|int|QueryBuilder|callable} -- A sub-query builder, a
            callable receiving a fresh builder, or a plain value.

    Returns:
        self
    """
    if inspect.isfunction(value):
        expression = QueryExpression(
            None,
            "EXISTS",
            SubSelectExpression(value(self.new())),
            keyword="or",
        )
    elif isinstance(value, QueryBuilder):
        expression = QueryExpression(
            None, "EXISTS", SubSelectExpression(value), keyword="or"
        )
    else:
        expression = QueryExpression(None, "EXISTS", value, "value", keyword="or")

    self._wheres += (expression,)
    return self
def where_not_exists(self, value: "str|int|QueryBuilder"):
    """Adds a WHERE NOT EXISTS expression.

    Arguments:
        value {string|int|QueryBuilder|callable} -- A sub-query builder, a
            callable receiving a fresh builder, or a plain value.

    Returns:
        self
    """
    if inspect.isfunction(value):
        expression = QueryExpression(
            None, "NOT EXISTS", SubSelectExpression(value(self.new()))
        )
    elif isinstance(value, QueryBuilder):
        expression = QueryExpression(None, "NOT EXISTS", SubSelectExpression(value))
    else:
        expression = QueryExpression(None, "NOT EXISTS", value, "value")

    self._wheres += (expression,)
    return self
def or_where_not_exists(self, value: "str|int|QueryBuilder"):
    """Adds an OR-joined WHERE NOT EXISTS expression.

    Arguments:
        value {string|int|QueryBuilder|callable} -- A sub-query builder, a
            callable receiving a fresh builder, or a plain value.

    Returns:
        self
    """
    if inspect.isfunction(value):
        expression = QueryExpression(
            None,
            "NOT EXISTS",
            SubSelectExpression(value(self.new())),
            keyword="or",
        )
    elif isinstance(value, QueryBuilder):
        expression = QueryExpression(
            None, "NOT EXISTS", SubSelectExpression(value), keyword="or"
        )
    else:
        expression = QueryExpression(
            None, "NOT EXISTS", value, "value", keyword="or"
        )

    self._wheres += (expression,)
    return self
def having(self, column, equality="", value=""):
    """Adds a HAVING expression.

    Arguments:
        column {string} -- The name of the column.

    Keyword Arguments:
        equality {string} -- An equality operator. (default: {""})
        value {string} -- The value of the having expression. (default: {""})

    Returns:
        self
    """
    self._having += (HavingExpression(column, equality, value),)
    return self
def having_raw(self, string):
    """Appends raw SQL to the HAVING clause.

    Arguments:
        string {string} -- The raw query string.

    Returns:
        self
    """
    self._having += (HavingExpression(string, raw=True),)
    return self
def where_null(self, column):
    """Adds a where expression checking that the column IS NULL.

    Arguments:
        column {string} -- The name of the column.

    Returns:
        self
    """
    self._wheres += (QueryExpression(column, "=", None, "NULL"),)
    return self
def or_where_null(self, column):
    """Adds an OR-joined where expression checking that the column IS NULL.

    Arguments:
        column {string} -- The name of the column.

    Returns:
        self
    """
    self._wheres += (QueryExpression(column, "=", None, "NULL", keyword="or"),)
    return self
def chunk(self, chunk_amount):
    """Generator that streams prepared results in chunks of `chunk_amount`.

    Uses the connection's `select_many` so results are fetched in batches
    rather than loaded all at once; each raw batch is run through
    `prepare_result` before being yielded.

    Arguments:
        chunk_amount {int} -- The number of rows to fetch per chunk.

    Yields:
        The prepared result for each chunk.
    """
    chunk_connection = self.new_connection()
    for result in chunk_connection.select_many(self.to_sql(), (), chunk_amount):
        yield self.prepare_result(result)
def where_not_null(self, column: str):
    """Adds a where expression checking that the column IS NOT NULL.

    Arguments:
        column {string} -- The name of the column.

    Returns:
        self
    """
    self._wheres += (QueryExpression(column, "=", True, "NOT NULL"),)
    return self
def _get_date_string(self, date):
    """Normalizes a date-like value to a string.

    Strings pass through unchanged; pendulum-style objects use their
    `to_date_string()`; datetime-like objects are formatted with
    strftime("%m-%d-%Y"). Implicitly returns None for anything else.
    """
    if isinstance(date, str):
        return date
    if hasattr(date, "to_date_string"):
        return date.to_date_string()
    if hasattr(date, "strftime"):
        return date.strftime("%m-%d-%Y")
def where_date(self, column: str, date: "str|datetime"):
    """Adds a where DATE expression.

    Arguments:
        column {string} -- The name of the column.
        date {string|datetime|pendulum} -- The date to compare against.

    Returns:
        self
    """
    date_string = self._get_date_string(date)
    self._wheres += (QueryExpression(column, "=", date_string, "DATE"),)
    return self
def or_where_date(self, column: str, date: "str|datetime"):
    """Adds an OR-joined where DATE expression.

    Arguments:
        column {string} -- The name of the column.
        date {string|datetime|pendulum} -- The date to compare against.

    Returns:
        self
    """
    date_string = self._get_date_string(date)
    self._wheres += (
        QueryExpression(column, "=", date_string, "DATE", keyword="or"),
    )
    return self
def between(self, column: str, low: int, high: int):
    """Adds a where BETWEEN expression.

    Arguments:
        column {string} -- The name of the column.
        low -- The value on the low end.
        high -- The value on the high end.

    Returns:
        self
    """
    self._wheres += (BetweenExpression(column, low, high),)
    return self
def where_between(self, *args, **kwargs):
    """Alias for the `between` method."""
    return self.between(*args, **kwargs)
def where_not_between(self, *args, **kwargs):
    """Alias for the `not_between` method."""
    return self.not_between(*args, **kwargs)
def not_between(self, column: str, low: str, high: str):
    """Adds a where NOT BETWEEN expression.

    Arguments:
        column {string} -- The name of the column.
        low -- The value on the low end.
        high -- The value on the high end.

    Returns:
        self
    """
    self._wheres += (BetweenExpression(column, low, high, equality="NOT BETWEEN"),)
    return self
def where_in(self, column, wheres=None):
    """Specifies where a column is contained in a list of values.

    Arguments:
        column {string} -- The name of the column.

    Keyword Arguments:
        wheres {list|QueryBuilder|callable} -- Values, a sub-query builder,
            or a callable receiving a fresh builder. (default: {None})

    Returns:
        self
    """
    wheres = wheres or []
    if not wheres:
        # An empty IN () list can never match; emit an always-false 0 = 1.
        expression = QueryExpression(0, "=", 1, "value_equals")
    elif isinstance(wheres, QueryBuilder):
        # Must be checked before callable(): QueryBuilder defines __call__.
        expression = QueryExpression(column, "IN", SubSelectExpression(wheres))
    elif callable(wheres):
        expression = QueryExpression(
            column, "IN", SubSelectExpression(wheres(self.new()))
        )
    else:
        expression = QueryExpression(column, "IN", list(wheres))

    self._wheres += (expression,)
    return self
def get_relation(self, relationship, builder=None):
    """Fetches a relationship accessor from a builder's model.

    Arguments:
        relationship {string} -- The relationship attribute name.

    Keyword Arguments:
        builder {QueryBuilder} -- The builder whose model should be used.
            Defaults to this builder. (default: {None})

    Returns:
        The relationship attribute from the model.

    Raises:
        AttributeError -- When the builder has no model attached.
    """
    builder = builder or self
    if not builder._model:
        raise AttributeError(
            "You must specify a model in order to use relationship methods"
        )
    return getattr(builder._model, relationship)
def has(self, *relationships):
    """Filters the query to rows that have the given relationship(s).

    Dotted names (e.g. "profile.address") walk nested relationships,
    chaining an EXISTS sub-query builder per segment.

    Arguments:
        *relationships {string} -- One or more relationship names.

    Raises:
        AttributeError -- When no model is attached to this builder.

    Returns:
        self
    """
    if not self._model:
        raise AttributeError(
            "You must specify a model in order to use 'has' relationship methods"
        )
    for relationship in relationships:
        if "." in relationship:
            last_builder = self._model.builder
            for split_relationship in relationship.split("."):
                related = last_builder.get_relation(split_relationship)
                # Each segment nests the next EXISTS inside the previous builder.
                last_builder = related.query_has(last_builder)
        else:
            related = getattr(self._model, relationship)
            related.query_has(self)
    return self
def or_has(self, *relationships):
    """OR-filters the query to rows that have the given relationship(s).

    For dotted names, all but the final segment are attached with
    "or_where_exists" (the OR lands on the outer query); the innermost
    segment uses a plain "where_exists".

    Arguments:
        *relationships {string} -- One or more relationship names.

    Raises:
        AttributeError -- When no model is attached to this builder.

    Returns:
        self
    """
    if not self._model:
        raise AttributeError(
            "You must specify a model in order to use 'has' relationship methods"
        )
    for relationship in relationships:
        if "." in relationship:
            last_builder = self._model.builder
            split_count = len(relationship.split("."))
            for index, split_relationship in enumerate(relationship.split(".")):
                related = last_builder.get_relation(split_relationship)
                if index + 1 == split_count:
                    # Final (innermost) segment: plain EXISTS.
                    last_builder = related.query_has(
                        last_builder, method="where_exists"
                    )
                    continue
                last_builder = related.query_has(
                    last_builder, method="or_where_exists"
                )
        else:
            related = getattr(self._model, relationship)
            related.query_has(self, method="or_where_exists")
    return self
def doesnt_have(self, *relationships):
    """Filters the query to rows that do NOT have the given relationship(s).

    For dotted names, non-final segments are attached with
    "where_not_exists" while the final (innermost) segment uses a plain
    "where_exists".
    NOTE(review): for chains of 3+ segments every non-final segment gets
    NOT EXISTS — confirm this nesting is the intended semantics.

    Arguments:
        *relationships {string} -- One or more relationship names.

    Raises:
        AttributeError -- When no model is attached to this builder.

    Returns:
        self
    """
    if not self._model:
        raise AttributeError(
            "You must specify a model in order to use the 'doesnt_have' relationship methods"
        )
    for relationship in relationships:
        if "." in relationship:
            last_builder = self._model.builder
            split_count = len(relationship.split("."))
            for index, split_relationship in enumerate(relationship.split(".")):
                related = last_builder.get_relation(split_relationship)
                if index + 1 == split_count:
                    last_builder = related.query_has(
                        last_builder, method="where_exists"
                    )
                    continue
                last_builder = related.query_has(
                    last_builder, method="where_not_exists"
                )
        else:
            related = getattr(self._model, relationship)
            related.query_has(self, method="where_not_exists")
    return self
def or_doesnt_have(self, *relationships):
    """OR-filters the query to rows that do NOT have the given relationship(s).

    For dotted names, non-final segments are attached with
    "or_where_not_exists" while the final (innermost) segment uses a plain
    "where_exists".

    Arguments:
        *relationships {string} -- One or more relationship names.

    Raises:
        AttributeError -- When no model is attached to this builder.

    Returns:
        self
    """
    if not self._model:
        raise AttributeError(
            "You must specify a model in order to use the 'doesnt_have' relationship methods"
        )
    for relationship in relationships:
        if "." in relationship:
            last_builder = self._model.builder
            split_count = len(relationship.split("."))
            for index, split_relationship in enumerate(relationship.split(".")):
                related = last_builder.get_relation(split_relationship)
                if index + 1 == split_count:
                    last_builder = related.query_has(
                        last_builder, method="where_exists"
                    )
                    continue
                last_builder = related.query_has(
                    last_builder, method="or_where_not_exists"
                )
        else:
            related = getattr(self._model, relationship)
            related.query_has(self, method="or_where_not_exists")
    return self
def where_has(self, relationship, callback):
    """Filters to rows having the relationship, constrained by a callback.

    Arguments:
        relationship {string} -- Relationship name; may be dotted for nesting.
        callback {callable} -- Receives the related builder to add constraints.

    Raises:
        AttributeError -- When no model is attached to this builder.

    Returns:
        self
    """
    if not self._model:
        raise AttributeError(
            "You must specify a model in order to use 'has' relationship methods"
        )

    if "." in relationship:
        last_builder = self._model.builder
        splits = relationship.split(".")
        split_count = len(splits)
        for index, split_relationship in enumerate(splits):
            related = last_builder.get_relation(split_relationship)
            if index + 1 == split_count:
                # Only the final segment receives the constraint callback.
                last_builder = related.query_where_exists(
                    last_builder, callback, method="where_exists"
                )
                continue
            last_builder = related.query_has(last_builder, method="where_exists")
    else:
        related = getattr(self._model, relationship)
        related.query_where_exists(self, callback, method="where_exists")
    return self
def or_where_has(self, relationship, callback):
    """OR-filters to rows having the relationship, constrained by a callback.

    Arguments:
        relationship {string} -- Relationship name; may be dotted for nesting.
        callback {callable} -- Receives the related builder to add constraints.

    Raises:
        AttributeError -- When no model is attached to this builder.

    Returns:
        self
    """
    if not self._model:
        raise AttributeError(
            "You must specify a model in order to use 'has' relationship methods"
        )

    if "." in relationship:
        last_builder = self._model.builder
        splits = relationship.split(".")
        split_count = len(splits)
        for index, split_relationship in enumerate(splits):
            related = last_builder.get_relation(split_relationship)
            if index + 1 == split_count:
                # Only the final segment receives the constraint callback.
                last_builder = related.query_where_exists(
                    last_builder, callback, method="where_exists"
                )
                continue
            last_builder = related.query_has(last_builder, method="or_where_exists")
    else:
        related = getattr(self._model, relationship)
        related.query_where_exists(self, callback, method="or_where_exists")
    return self
def where_doesnt_have(self, relationship, callback):
    """Filters to rows NOT having the relationship, constrained by a callback.

    Arguments:
        relationship {string} -- Relationship name; may be dotted for nesting.
        callback {callable} -- Receives the related builder to add constraints.

    Raises:
        AttributeError -- When no model is attached to this builder.

    Returns:
        self
    """
    if not self._model:
        raise AttributeError(
            "You must specify a model in order to use the 'doesnt_have' relationship methods"
        )

    if "." in relationship:
        last_builder = self._model.builder
        split_count = len(relationship.split("."))
        for index, split_relationship in enumerate(relationship.split(".")):
            related = last_builder.get_relation(split_relationship)
            if index + 1 == split_count:
                # NOTE(review): passes `self` (not `last_builder`) here,
                # unlike where_has which threads last_builder — confirm
                # whether this is intentional.
                last_builder = getattr(
                    last_builder._model, split_relationship
                ).query_where_exists(self, callback, method="where_exists")
                continue
            last_builder = related.query_has(
                last_builder, method="where_not_exists"
            )
    else:
        related = getattr(self._model, relationship)
        related.query_where_exists(self, callback, method="where_not_exists")
    return self
def or_where_doesnt_have(self, relationship, callback):
    """OR-filters to rows NOT having the relationship, constrained by a callback.

    Arguments:
        relationship {string} -- Relationship name; may be dotted for nesting.
        callback {callable} -- Receives the related builder to add constraints.

    Raises:
        AttributeError -- When no model is attached to this builder.

    Returns:
        self
    """
    if not self._model:
        raise AttributeError(
            "You must specify a model in order to use the 'doesnt_have' relationship methods"
        )

    if "." in relationship:
        last_builder = self._model.builder
        split_count = len(relationship.split("."))
        for index, split_relationship in enumerate(relationship.split(".")):
            related = last_builder.get_relation(split_relationship)
            if index + 1 == split_count:
                # NOTE(review): passes `self` (not `last_builder`) here,
                # mirroring where_doesnt_have — confirm intended.
                last_builder = getattr(
                    last_builder._model, split_relationship
                ).query_where_exists(self, callback, method="where_exists")
                continue
            last_builder = related.query_has(
                last_builder, method="or_where_not_exists"
            )
    else:
        related = getattr(self._model, relationship)
        related.query_where_exists(self, callback, method="or_where_not_exists")
    return self
def with_count(self, relationship, callback=None):
    """Adds a related-record count to the query via the relationship's
    `get_with_count_query`. Requires a model on the builder.

    Arguments:
        relationship {string} -- The relationship attribute name.

    Keyword Arguments:
        callback {callable} -- Optional constraint applied to the counted
            query. (default: {None})

    Returns:
        QueryBuilder -- Whatever `get_with_count_query` returns.
    """
    self.select(*self._model.get_selects())
    return getattr(self._model, relationship).get_with_count_query(
        self, callback=callback
    )
def where_not_in(self, column, wheres=None):
    """Specifies where a column is NOT contained in a list of values.

    Arguments:
        column {string} -- The name of the column.

    Keyword Arguments:
        wheres {list|QueryBuilder} -- Values or a sub-query builder. (default: {None})

    Returns:
        self
    """
    wheres = wheres or []
    if isinstance(wheres, QueryBuilder):
        expression = QueryExpression(column, "NOT IN", SubSelectExpression(wheres))
    else:
        expression = QueryExpression(column, "NOT IN", list(wheres))
    self._wheres += (expression,)
    return self
def join(
    self, table: str, column1=None, equality=None, column2=None, clause="inner"
):
    """Adds a join expression.

    Arguments:
        table {string|JoinClause} -- The table name or a JoinClause instance.
        column1 {string|callable} -- The foreign column, or a callable that
            receives a fresh JoinClause to configure.
        equality {string} -- The equality to join on.
        column2 {string} -- The local column.

    Keyword Arguments:
        clause {string} -- The join clause type. (default: {"inner"})

    Returns:
        self
    """
    if inspect.isfunction(column1):
        # Callable configures its own JoinClause.
        join_clause = column1(JoinClause(table, clause=clause))
    elif isinstance(table, str):
        join_clause = JoinClause(table, clause=clause).on(column1, equality, column2)
    else:
        # Caller supplied a pre-built JoinClause.
        join_clause = table

    self._joins += (join_clause,)
    return self
def left_join(self, table, column1=None, equality=None, column2=None):
    """Helper that adds a LEFT join expression.

    Arguments:
        table {string} -- The table to join on.
        column1 {string} -- The foreign column.
        equality {string} -- The equality to join on.
        column2 {string} -- The local column.

    Returns:
        self
    """
    return self.join(
        table=table,
        column1=column1,
        equality=equality,
        column2=column2,
        clause="left",
    )
def right_join(self, table, column1=None, equality=None, column2=None):
    """Helper that adds a RIGHT join expression.

    Arguments:
        table {string} -- The table to join on.
        column1 {string} -- The foreign column.
        equality {string} -- The equality to join on.
        column2 {string} -- The local column.

    Returns:
        self
    """
    return self.join(
        table=table,
        column1=column1,
        equality=equality,
        column2=column2,
        clause="right",
    )
def joins(self, *relationships, clause="inner"):
    """Joins each named model relationship onto this query.

    Keyword Arguments:
        clause {string} -- The join clause type. (default: {"inner"})

    Returns:
        self
    """
    for name in relationships:
        relation = getattr(self._model, name)
        relation.joins(self, clause=clause)
    return self
def join_on(self, relationship, callback=None, clause="inner"):
    """Joins a model relationship and optionally constrains it via a callback.

    Arguments:
        relationship {string} -- The model relationship name.

    Keyword Arguments:
        callback {callable} -- Receives a fresh builder scoped to the related
            table; its resulting wheres are merged into this query. (default: {None})
        clause {string} -- The join clause type. (default: {"inner"})

    Returns:
        self
    """
    relation = getattr(self._model, relationship)
    relation.joins(self, clause=clause)
    if callback:
        new_from_builder = self.new_from_builder()
        new_from_builder.table(relation.get_builder().get_table_name())
        self.where_from_builder(callback(new_from_builder))
    return self
def where_column(self, column1, column2):
    """Adds a where expression comparing two columns for equality.

    Arguments:
        column1 {string} -- The first column name.
        column2 {string} -- The second column name.

    Returns:
        self
    """
    self._wheres += (QueryExpression(column1, "=", column2, "column"),)
    return self
def take(self, *args, **kwargs):
    """Alias for the `limit` method."""
    return self.limit(*args, **kwargs)
def limit(self, amount):
    """Sets the LIMIT for the query.

    Arguments:
        amount {int} -- The maximum number of rows to return.

    Returns:
        self
    """
    self._limit = amount
    return self
def offset(self, amount):
    """Sets the OFFSET for the query.

    Arguments:
        amount {int} -- The number of rows to skip.

    Returns:
        self
    """
    self._offset = amount
    return self
def skip(self, *args, **kwargs):
    """Alias for the `offset` method."""
    return self.offset(*args, **kwargs)
def update(
    self,
    updates: Dict[str, Any],
    dry: bool = False,
    force: bool = False,
    cast: bool = False,
    ignore_mass_assignment: bool = False,
):
    """Specifies columns and values to be updated and executes the update.

    Arguments:
        updates {dictionary} -- A dictionary of columns and values to update.
        dry {bool, optional}: Do everything except execute the query against the DB
        force {bool, optional}: Force an update statement to be executed even if nothing was changed
        cast {bool, optional}: Run all values through model's casters
        ignore_mass_assignment {bool, optional}: Whether the update should ignore mass assignment on the model

    Returns:
        self|Model|dict -- self when dry (or nothing changed), the refreshed
        model when one is attached, otherwise a dict of the updated values.
    """
    model = None
    additional = {}

    if self._model:
        model = self._model
        # Filter fields through the model's __fillable__/__guarded__ rules.
        if not ignore_mass_assignment:
            updates = model.filter_mass_assignment(updates)

    if model and model.is_loaded():
        # Scope the update to the loaded model's primary key row.
        self.where(model.get_primary_key(), model.get_primary_key_value())
        additional.update({model.get_primary_key(): model.get_primary_key_value()})
        self.observe_events(model, "updating")

    if model:
        if not model.__force_update__ and not force:
            # Filter updates to only those with changes.
            updates = {
                attr: value
                for attr, value in updates.items()
                if (
                    value is None
                    or model.__original_attributes__.get(attr, None) != value
                )
            }

        # Do not execute query if no changes.
        if not updates:
            return self if dry or self.dry else model

        # Cast date fields to the model's datetime string format.
        date_fields = model.get_dates()
        for key, value in updates.items():
            if key in date_fields:
                if value:
                    updates[key] = model.get_new_datetime_string(value)
                else:
                    updates[key] = value
            # Cast value if necessary.
            # NOTE(review): when `cast` is set this overwrites the datetime
            # string produced just above for date fields — confirm intended.
            if cast:
                if value:
                    updates[key] = model.cast_value(value)
                else:
                    updates[key] = value
    elif not updates:
        # Do not perform query if there are no updates.
        return self

    self._updates = (UpdateQueryExpression(updates),)
    self.set_action("update")
    if dry or self.dry:
        return self

    additional.update(updates)

    self.new_connection().query(self.to_qmark(), self._bindings)
    if model:
        model.fill(updates)
        self.observe_events(model, "updated")
        model.fill_original(updates)
        return model
    return additional
def force_update(self, updates: dict, dry=False):
    """Runs `update` with force=True, bypassing the changed-attributes check."""
    return self.update(updates, dry=dry, force=True)
def set_updates(self, updates: dict, dry=False):
    """Registers columns and values to be updated without executing.

    Arguments:
        updates {dictionary} -- A dictionary of columns and values to update.

    Keyword Arguments:
        dry {bool} -- Unused here; kept for signature compatibility. (default: {False})

    Returns:
        self
    """
    self._updates += (UpdateQueryExpression(updates),)
    return self
def increment(self, column, value=1, dry=False):
    """Increments a column's value.

    Arguments:
        column {string} -- The name of the column.

    Keyword Arguments:
        value {int} -- The value to increment by. (default: {1})
        dry {bool} -- Return the compiled SQL instead of executing. (default: {False})

    Returns:
        The processed column value, or the SQL string when dry.
    """
    model = None
    # NOTE(review): id_key is hard-coded to "id" even when the model's
    # primary key differs — confirm against get_column_value's expectations.
    id_key = "id"
    id_value = None
    additional = {}

    if self._model:
        model = self._model
        id_value = self._model.get_primary_key_value()

    if model and model.is_loaded():
        self.where(model.get_primary_key(), model.get_primary_key_value())
        # `additional` is collected but not used further in this method.
        additional.update({model.get_primary_key(): model.get_primary_key_value()})
        self.observe_events(model, "updating")

    self._updates += (
        UpdateQueryExpression(column, value, update_type="increment"),
    )
    if dry or self.dry:
        return self.get_grammar().compile("update").to_sql()
    self.set_action("update")
    results = self.new_connection().query(self.to_qmark(), self._bindings)
    processed_results = self.get_processor().get_column_value(
        self, column, results, id_key, id_value
    )
    return processed_results
def decrement(self, column, value=1, dry=False):
    """Decrements a column's value.

    Arguments:
        column {string} -- The name of the column.

    Keyword Arguments:
        value {int} -- The value to decrement by. (default: {1})
        dry {bool} -- Return the compiled SQL instead of executing. (default: {False})

    Returns:
        The processed column value, or the SQL string when dry.
    """
    model = None
    # NOTE(review): id_key hard-coded to "id"; see increment().
    id_key = "id"
    id_value = None
    additional = {}

    if self._model:
        model = self._model
        id_value = self._model.get_primary_key_value()

    if model and model.is_loaded():
        self.where(model.get_primary_key(), model.get_primary_key_value())
        # `additional` is collected but not used further in this method.
        additional.update({model.get_primary_key(): model.get_primary_key_value()})
        self.observe_events(model, "updating")

    self._updates += (
        UpdateQueryExpression(column, value, update_type="decrement"),
    )
    if dry or self.dry:
        return self.get_grammar().compile("update").to_sql()
    self.set_action("update")
    result = self.new_connection().query(self.to_qmark(), self._bindings)
    processed_results = self.get_processor().get_column_value(
        self, column, result, id_key, id_value
    )
    return processed_results
def sum(self, column):
    """Registers a SUM aggregate for the given column.

    Arguments:
        column {string} -- The name of the column to aggregate.

    Returns:
        self
    """
    self.aggregate("SUM", f"{column}")
    return self
def count(self, column=None, dry=False):
    """Registers a COUNT aggregate and, when counting all rows, executes it.

    Arguments:
        column {string|None} -- The column to count. None or "*" counts all
            rows under a reserved alias.

    Keyword Arguments:
        dry {bool} -- Register the aggregate without executing. (default: {False})

    Returns:
        int|self -- The count when no column was given; otherwise self
        (the aggregate is registered but not executed here).
    """
    alias = "m_count_reserved" if (column == "*" or column is None) else column
    if column == "*":
        self.aggregate("COUNT", f"{column} as {alias}")
    elif column is None:
        self.aggregate("COUNT", f"* as {alias}")
    else:
        self.aggregate("COUNT", f"{column}")

    if dry or self.dry:
        return self

    if not column:
        result = self.new_connection().query(
            self.to_qmark(), self._bindings, results=1
        )
        if isinstance(result, dict):
            return result.get(alias, 0)
        # Non-dict results: fall back to the first value in the row.
        prepared_result = list(result.values())
        if not prepared_result:
            return 0
        return prepared_result[0]
    else:
        return self
def max(self, column):
    """Registers a MAX aggregate for the given column.

    Arguments:
        column {string} -- The name of the column to aggregate.

    Returns:
        self
    """
    self.aggregate("MAX", f"{column}")
    return self
def order_by(self, column, direction="ASC"):
    """Orders the query by one or more comma-separated columns.

    Arguments:
        column {string} -- Column name(s); comma-separated for multiple.

    Keyword Arguments:
        direction {string} -- Either ASC or DESC, applied to each column. (default: {"ASC"})

    Returns:
        self
    """
    for col in column.split(","):
        self._order_by += (OrderByExpression(col, direction=direction),)
    return self
def order_by_raw(self, query, bindings=None):
    """Orders the query using a raw SQL snippet.

    Arguments:
        query {string} -- A raw ORDER BY clause.

    Keyword Arguments:
        bindings {list} -- Bindings for the raw query. (default: {None})

    Returns:
        self
    """
    if bindings is None:
        bindings = []
    self._order_by += (OrderByExpression(query, raw=True, bindings=bindings),)
    return self
def group_by(self, column):
    """Groups the query by one or more comma-separated columns.

    Arguments:
        column {string} -- Column name(s); comma-separated for multiple.

    Returns:
        self
    """
    for col in column.split(","):
        self._group_by += (GroupByExpression(column=col),)
    return self
def group_by_raw(self, query, bindings=None):
    """Groups the query using a raw SQL snippet.

    Arguments:
        query {string} -- A raw GROUP BY clause.

    Keyword Arguments:
        bindings {list} -- Bindings for the raw query. (default: {None})

    Returns:
        self
    """
    if bindings is None:
        bindings = []
    self._group_by += (
        GroupByExpression(column=query, raw=True, bindings=bindings),
    )
    return self
def aggregate(self, aggregate, column, alias=None):
    """Registers an aggregate expression (SUM, COUNT, MAX, ...).

    Arguments:
        aggregate {string} -- The aggregate function name.
        column {string} -- The column (optionally with an "as" alias) to aggregate.

    Keyword Arguments:
        alias {string} -- Optional alias for the aggregate. (default: {None})
    """
    expression = AggregateExpression(aggregate=aggregate, column=column, alias=alias)
    self._aggregates += (expression,)
def first(self, fields=None, query=False):
    """Gets the first record.

    Keyword Arguments:
        fields {list} -- Columns to select. (default: {None})
        query {bool} -- Return the builder instead of executing. (default: {False})

    Returns:
        dictionary|Model|self -- The prepared first result.
    """
    if not fields:
        fields = []
    # NOTE(review): passes the list itself (not *fields) to select —
    # confirm select() handles a list argument; other methods unpack.
    self.select(fields).limit(1)
    if query:
        return self
    result = self.new_connection().query(self.to_qmark(), self._bindings, results=1)
    return self.prepare_result(result)
def first_or_create(self, wheres, creates: dict = None):
    """Gets the first record matching `wheres` or creates it.

    Arguments:
        wheres {dict} -- Attributes to match on (also applied when creating).

    Keyword Arguments:
        creates {dict} -- Extra attributes used when creating. (default: {None})

    Returns:
        Model
    """
    if creates is None:
        creates = {}
    record = self.where(wheres).first()
    # Merge existing attributes (if any), creates, wheres, and any
    # pending related creates into one attribute dict.
    total = {}
    if record:
        if hasattr(record, "serialize"):
            total.update(record.serialize())
        else:
            total.update(record)
    total.update(creates)
    total.update(wheres)
    total.update(self._creates_related)
    if not record:
        return self.create(total, id_key=self.get_primary_key())
    return record
def sole(self, query=False):
    """Gets the only record matching the current criteria.

    Keyword Arguments:
        query {bool} -- Unused; kept for signature compatibility. (default: {False})

    Raises:
        ModelNotFound -- When no record matches.
        MultipleRecordsFound -- When more than one record matches.

    Returns:
        The single matching record.
    """
    # Fetch two rows: enough to detect "none" and "more than one".
    result = self.take(2).get()

    if result.is_empty():
        raise ModelNotFound()

    if result.count() > 1:
        raise MultipleRecordsFound()

    return result.first()
def sole_value(self, column: str, query=False):
    """Returns a single column's value from the sole matching record."""
    record = self.sole()
    return record[column]
def first_where(self, column, *args):
    """Gets the first record matching the given key / value pair.

    With no extra arguments, returns the first record where the column
    is not null.
    """
    if args:
        return self.where(column, *args).first()
    return self.where_not_null(column).first()
def last(self, column=None, query=False):
    """Gets the last record, ordered by `column` (or the model's primary
    key) in descending order.

    Keyword Arguments:
        column {string} -- Column to order by. (default: {None})
        query {bool} -- Return the builder instead of executing. (default: {False})

    Returns:
        dictionary -- Returns a dictionary of results.
    """
    _column = column if column else self._model.get_primary_key()
    self.limit(1).order_by(_column, direction="DESC")
    if query:
        return self
    result = self.new_connection().query(
        self.to_qmark(),
        self._bindings,
        results=1,
    )
    return self.prepare_result(result)
def _get_eager_load_result(self, related, collection):
    """Delegates eager loading of a relation onto an existing collection."""
    return related.eager_load_from_collection(collection)
def find(self, record_id, column=None, query=False):
    """Finds row(s) by primary key (requires a model) or an explicit column.

    Arguments:
        record_id {int|list|tuple} -- The primary-key value(s) to fetch.

    Keyword Arguments:
        column {string} -- Column to search instead of the model's primary key. (default: {None})
        query {bool} -- Return the builder instead of executing. (default: {False})

    Returns:
        Model|None

    Raises:
        InvalidArgument -- When no column is given and no model is set.
    """
    if not column:
        if not self._model:
            # Fixed typo in the error message ("colum" -> "column").
            raise InvalidArgument("A column to search is required")
        column = self._model.get_primary_key()
    if isinstance(record_id, (list, tuple)):
        # Multiple IDs: match any of them.
        self.where_in(column, record_id)
    else:
        self.where(column, record_id)
    if query:
        return self
    return self.first()
def find_or(self, record_id: int, callback: Callable, args=None, column=None):
    """Finds a row by primary key or invokes the callback when missing.

    Arguments:
        record_id {int} -- The ID of the primary key to fetch.
        callback {Callable} -- Called (optionally with `args`) when no record is found.

    Keyword Arguments:
        args {tuple} -- Positional arguments for the callback. (default: {None})
        column {string} -- Column to search instead of the primary key. (default: {None})

    Returns:
        Model|Any -- The found record, or the callback's return value.

    Raises:
        InvalidArgument -- When callback is not callable.
    """
    if not callable(callback):
        raise InvalidArgument("A callback must be callable.")
    result = self.find(record_id=record_id, column=column)
    if result:
        return result
    return callback(*args) if args else callback()
def find_or_fail(self, record_id, column=None):
    """Finds a row by primary key or raises ModelNotFound.

    Arguments:
        record_id {int} -- The ID of the primary key to fetch.

    Keyword Arguments:
        column {string} -- Column to search instead of the primary key. (default: {None})

    Returns:
        Model

    Raises:
        ModelNotFound -- When no record matches.
    """
    result = self.find(record_id=record_id, column=column)
    if result:
        return result
    raise ModelNotFound()
def find_or_404(self, record_id, column=None):
    """Finds a row by primary key or raises an HTTP404 exception.

    Arguments:
        record_id {int} -- The ID of the primary key to fetch.

    Keyword Arguments:
        column {string} -- Column to search instead of the primary key. (default: {None})

    Returns:
        Model

    Raises:
        HTTP404 -- When no record matches.
    """
    try:
        return self.find_or_fail(record_id=record_id, column=column)
    except ModelNotFound:
        raise HTTP404()
def first_or_fail(self, query=False):
    """Returns the first row or raises ModelNotFound.

    Keyword Arguments:
        query {bool} -- Return the builder instead of executing. (default: {False})

    Returns:
        dictionary|Model|self

    Raises:
        ModelNotFound -- When no record is found.
    """
    if query:
        return self.first(query=True)
    result = self.first()
    if not result:
        raise ModelNotFound()
    return result
def get_primary_key(self):
    """Returns the primary key name from the attached model."""
    return self._model.get_primary_key()
def prepare_result(self, result, collection=False):
    """Hydrates raw query results into models and attaches eager-loaded relations.

    Arguments:
        result -- Raw result(s) returned by the connection.

    Keyword Arguments:
        collection {bool} -- Return a Collection instead of a single result. (default: {False})

    Returns:
        Model|Collection|dict|None -- Hydrated model(s) when a model is
        attached, otherwise the raw result (wrapped in a Collection when
        `collection` is True), or None/empty Collection when empty.
    """
    if self._model and result:
        # eager load here
        hydrated_model = self._model.hydrate(result)
        if (
            self._eager_relation.eagers
            or self._eager_relation.nested_eagers
            or self._eager_relation.callback_eagers
        ) and hydrated_model:
            for eager_load in self._eager_relation.get_eagers():
                if isinstance(eager_load, dict):
                    # Nested eager loads: {relation: nested-eagers-or-callback}
                    for relation, eagers in eager_load.items():
                        callback = None
                        if inspect.isclass(self._model):
                            related = getattr(self._model, relation)
                        elif callable(eagers):
                            # A callable "eagers" entry acts as a constraint callback.
                            related = getattr(self._model, relation)
                            callback = eagers
                        else:
                            related = self._model.get_related(relation)
                        result_set = related.get_related(
                            self, hydrated_model, eagers=eagers, callback=callback
                        )
                        self._register_relationships_to_model(
                            related,
                            result_set,
                            hydrated_model,
                            relation_key=relation,
                        )
                else:
                    # Not Nested
                    for eager in eager_load:
                        if inspect.isclass(self._model):
                            related = getattr(self._model, eager)
                        else:
                            related = self._model.get_related(eager)
                        result_set = related.get_related(self, hydrated_model)
                        self._register_relationships_to_model(
                            related, result_set, hydrated_model, relation_key=eager
                        )
        if collection:
            return hydrated_model if result else Collection([])
        else:
            return hydrated_model if result else None
    if collection:
        return Collection(result) if result else Collection([])
    else:
        return result or None
def _register_relationships_to_model(
    self, related, related_result, hydrated_model, relation_key
):
    """Takes a related result and a hydrated model and registers them to eachother using the relation key.

    Args:
        related: The relationship descriptor doing the registering/mapping.
        related_result (Model|Collection): Will be the related result based on the type of relationship.
        hydrated_model (Model|Collection): If a collection we will need to loop through the collection of models
                                            and register each one individually. Else we can just load the
                                            related_result into the hydrated_models
        relation_key (string): A key to bind the relationship with. Defaults to None.

    Returns:
        self
    """
    if related_result and isinstance(hydrated_model, Collection):
        # Map once, then distribute the mapped result to each model.
        map_related = self._map_related(related_result, related)
        for model in hydrated_model:
            if isinstance(related_result, Collection):
                related.register_related(relation_key, model, map_related)
            else:
                model.add_relation({relation_key: map_related or None})
    else:
        hydrated_model.add_relation({relation_key: related_result or None})
    return self
def _map_related(self, related_result, related):
    """Delegates mapping of a related result set to the relationship."""
    return related.map_related(related_result)
def all(self, selects=None, query=False):
    """Returns all records from the table.

    Keyword Arguments:
        selects {list} -- Columns to select. (default: {None})
        query {bool} -- Return the builder instead of executing. (default: {False})

    Returns:
        Collection -- A collection of results.
    """
    # None default avoids the shared mutable-default-argument pitfall.
    self.select(*(selects or []))
    if query:
        return self
    result = self.new_connection().query(self.to_qmark(), self._bindings) or []
    return self.prepare_result(result, collection=True)
def get(self, selects=None):
    """Runs the select query built from the query builder.

    Keyword Arguments:
        selects {list} -- Columns to select. (default: {None})

    Returns:
        Collection -- A collection of results.
    """
    # None default avoids the shared mutable-default-argument pitfall.
    self.select(*(selects or []))
    result = self.new_connection().query(self.to_qmark(), self._bindings)
    return self.prepare_result(result, collection=True)
def new_connection(self):
    """Returns the cached connection, or creates, caches, and returns one.

    Returns:
        The connection instance for this builder's configured driver.
    """
    if self._connection:
        return self._connection

    self._connection = (
        self.connection_class(
            **self.get_connection_information(), name=self.connection
        )
        .set_schema(self._schema)
        .make_connection()
    )
    return self._connection
def get_connection(self):
    """Returns the current connection (possibly None) without creating one."""
    return self._connection
def without_eager(self):
    """Disables eager loading for this builder.

    Returns:
        self
    """
    self._should_eager = False
    return self
def with_(self, *eagers):
    """Registers relationships to eager load.

    Returns:
        self
    """
    self._eager_relation.register(eagers)
    return self
def paginate(self, per_page, page=1):
    """Executes the query for one page and returns a length-aware paginator.

    Arguments:
        per_page {int} -- Number of records per page.

    Keyword Arguments:
        page {int} -- 1-based page number. (default: {1})

    Returns:
        LengthAwarePaginator
    """
    if page == 1:
        offset = 0
    else:
        offset = (int(page) * per_page) - per_page

    # A separate builder (without ordering/columns) computes the total count.
    new_from_builder = self.new_from_builder()
    new_from_builder._order_by = ()
    new_from_builder._columns = ()
    result = self.limit(per_page).offset(offset).get()
    total = new_from_builder.count()
    paginator = LengthAwarePaginator(result, per_page, page, total)
    return paginator
def simple_paginate(self, per_page, page=1):
    """Executes the query for one page and returns a simple paginator
    (no total-count query).

    Arguments:
        per_page {int} -- Number of records per page.

    Keyword Arguments:
        page {int} -- 1-based page number. (default: {1})

    Returns:
        SimplePaginator
    """
    if page == 1:
        offset = 0
    else:
        offset = (int(page) * per_page) - per_page

    result = self.limit(per_page).offset(offset).get()
    paginator = SimplePaginator(result, per_page, page)
    return paginator
def set_action(self, action):
    """Sets the action the builder should take when the query is compiled.

    Arguments:
        action {string} -- e.g. "select", "update", "delete".

    Returns:
        self
    """
    self._action = action
    return self
def get_grammar(self):
    """Initializes and returns the grammar class.

    Returns:
        masoniteorm.grammar.Grammar -- An ORM grammar class.
    """
    # Either _creates when creating, otherwise use columns
    columns = self._creates or self._columns

    # Default to the model's selects when nothing was chosen explicitly.
    if not columns and not self._aggregates and self._model:
        self.select(*self._model.get_selects())
        columns = self._columns

    return self.grammar(
        columns=columns,
        table=self._table,
        wheres=self._wheres,
        limit=self._limit,
        offset=self._offset,
        updates=self._updates,
        aggregates=self._aggregates,
        order_by=self._order_by,
        group_by=self._group_by,
        distinct=self._distinct,
        lock=self.lock,
        joins=self._joins,
        having=self._having,
    )
def to_sql(self):
    """Compiles the current query into a SQL string (values inlined).

    Returns:
        str
    """
    self.run_scopes()
    return self.get_grammar().compile(self._action, qmark=False).to_sql()
def explain(self):
    """Runs EXPLAIN on the compiled query.

    Returns:
        Collection -- The execution plan rows.
    """
    return self.statement(f"EXPLAIN {self.to_sql()}")
def run_scopes(self):
    """Applies all registered global scopes for the current action.

    Returns:
        self
    """
    scopes = self._global_scopes.get(self._action, {})
    for scope in scopes.values():
        scope(self)
    return self
def to_qmark(self):
    """Compiles the QueryBuilder into a parameterized (qmark) SQL statement.

    Side effects: stores the grammar's bindings on self._bindings and
    resets the builder's accumulated state afterwards.

    Returns:
        str
    """
    self.run_scopes()
    grammar = self.get_grammar()
    sql = grammar.compile(self._action, qmark=True).to_sql()
    self._bindings = grammar._bindings
    self.reset()
    return sql
def new(self):
    """Creates a new QueryBuilder class.
    Returns:
        QueryBuilder -- The ORM QueryBuilder class.
    """
    # A fresh builder sharing this builder's connection/grammar/model
    # configuration but none of its accumulated query state.
    builder = QueryBuilder(
        grammar=self.grammar,
        connection_class=self.connection_class,
        connection=self.connection,
        connection_driver=self._connection_driver,
        model=self._model,
    )
    if self._table:
        builder.table(self._table.name)
    return builder
def avg(self, column):
    """Add an AVG aggregate for the given column.

    Arguments:
        column {string} -- The name of the column to aggregate.
    Returns:
        self
    """
    self.aggregate("AVG", f"{column}")
    return self
def min(self, column):
    """Add a MIN aggregate for the given column.

    Arguments:
        column {string} -- The name of the column to aggregate.
    Returns:
        self
    """
    self.aggregate("MIN", f"{column}")
    return self
def _extract_operator_value(self, *args):
operators = [
"=",
">",
">=",
"<",
"<=",
"!=",
"<>",
"like",
"not like",
"regexp",
"not regexp",
]
operator = operators[0]
value = None
if (len(args)) >= 2:
operator = args[0]
value = args[1]
elif len(args) == 1:
value = args[0]
if operator not in operators:
raise ValueError(
"Invalid comparison operator. The operator can be %s"
% ", ".join(operators)
)
return operator, value
def __call__(self):
    """Magic method to standardize what happens when the query builder object is called.
    Returns:
        self
    """
    return self
def macro(self, name, callable):
    """Register a named macro callable on this builder.

    Returns:
        self
    """
    self._macros[name] = callable
    return self
def when(self, conditional, callback):
    """Invoke the callback with this builder only when the conditional is truthy.

    Returns:
        self
    """
    if not conditional:
        return self
    callback(self)
    return self
def truncate(self, foreign_keys=False, dry=False):
    """Build (and optionally execute) a TRUNCATE TABLE statement.

    Keyword Arguments:
        foreign_keys {bool} -- whether foreign key checks should be handled
            around the truncate (grammar-specific). (default: {False})
        dry {bool} -- when True (or when the builder is in dry mode) return
            the SQL string instead of executing it. (default: {False})
    """
    sql = self.get_grammar().truncate_table(self.get_table_name(), foreign_keys)
    if dry or self.dry:
        return sql
    return self.new_connection().query(sql, ())
def exists(self):
    """Determine if rows exist for the current query.
    Returns:
        Bool - True or False
    """
    return bool(self.first())
def doesnt_exist(self):
    """Determine if no rows exist for the current query.
    Returns:
        Bool - True or False
    """
    return not self.exists()
def in_random_order(self):
    """Puts Query results in random order"""
    # NOTE(review): compile_random is called with no arguments here — confirm
    # every grammar's compile_random accepts a no-argument call (a grammar
    # requiring a positional seed would raise TypeError).
    return self.order_by_raw(self.grammar().compile_random())
def new_from_builder(self, from_builder=None):
    """Creates a new QueryBuilder class.
    Returns:
        QueryBuilder -- The ORM QueryBuilder class.
    """
    source = from_builder if from_builder is not None else self
    builder = QueryBuilder(
        grammar=self.grammar,
        connection_class=self.connection_class,
        connection=self.connection,
        connection_driver=self._connection_driver,
    )
    if self._table:
        builder.table(self._table.name)
    # Deep-copy all collected query state so neither builder can mutate
    # the other's internals through shared references.
    for attribute in (
        "_columns", "_creates", "_bindings", "_updates", "_wheres",
        "_order_by", "_group_by", "_joins", "_having", "_macros",
        "_aggregates", "_global_scopes",
    ):
        setattr(builder, attribute, deepcopy(getattr(source, attribute)))
    builder._sql = ""
    return builder
def get_table_columns(self):
    """Return the column listing of this builder's table via the schema layer."""
    return self.get_schema().get_columns(self._table.name)
def get_schema(self):
    """Build a Schema object bound to this builder's connection settings."""
    return Schema(
        connection=self.connection, connection_details=self._connection_details
    )
def latest(self, *fields):
    """Order the query by the given fields descending (defaults to created_at).

    Returns:
        querybuilder
    """
    columns = ",".join(fields) if fields else "created_at"
    return self.order_by(column=columns, direction="DESC")
def oldest(self, *fields):
    """Order the query by the given fields ascending (defaults to created_at).

    Returns:
        querybuilder
    """
    columns = ",".join(fields) if fields else "created_at"
    return self.order_by(column=columns, direction="ASC")
def value(self, column: str):
    """Run the query and return the given column of the first result row.

    NOTE(review): if the query returns no rows, first() presumably yields
    None and this raises TypeError — confirm callers guard against that.
    """
    return self.get().first()[column]
================================================
FILE: src/masoniteorm/query/__init__.py
================================================
from .QueryBuilder import QueryBuilder
================================================
FILE: src/masoniteorm/query/grammars/BaseGrammar.py
================================================
import re
from ...expressions.expressions import (
JoinClause,
OnClause,
SelectExpression,
SubGroupExpression,
SubSelectExpression,
)
class BaseGrammar:
    """Base SQL grammar.

    Dialect subclasses supply the format strings (column/table/value/where/
    join strings, aggregate_options, join_keywords, locks, ...) and this base
    class assembles them into complete SQL statements from the state the
    QueryBuilder collected.
    """

    # Default table name; overwritten per instance in __init__.
    table = "users"

    def __init__(
        self,
        columns=(),
        table="users",
        database=None,
        wheres=(),
        limit=False,
        offset=False,
        updates=None,
        aggregates=(),
        order_by=(),
        distinct=False,
        group_by=(),
        joins=(),
        lock=False,
        having=(),
        connection_details=None,
    ):
        self._columns = columns
        self.table = table
        self.database = database
        self._wheres = wheres
        self._limit = limit
        self._offset = offset
        self._updates = updates or {}
        self._aggregates = aggregates
        self._order_by = order_by
        self._group_by = group_by
        self._distinct = distinct
        self._joins = joins
        self._having = having
        self.lock = lock
        self._connection_details = connection_details or {}
        self._column = None
        # Bindings collected while compiling qmark queries.
        self._bindings = []
        self._sql = ""
        self._sql_qmark = ""
        self._action = "select"
        self.queries = []

    def compile(self, action, qmark=False):
        """Dispatch to the _compile_<action> method for the given action."""
        self._action = action
        return getattr(self, "_compile_" + action)(qmark=qmark)

    def _compile_select(self, qmark=False):
        """Compile a select query statement.
        Keyword Arguments:
            qmark {bool} -- Whether the query should use qmark. (default: {False})
        Returns:
            self
        """
        if not self.table:
            self._sql = (
                self.select_no_table()
                .format(
                    columns=self.process_columns(separator=", ", qmark=qmark),
                    table=self.process_table(self.table),
                    joins=self.process_joins(qmark=qmark),
                    wheres=self.process_wheres(qmark=qmark),
                    limit=self.process_limit(),
                    offset=self.process_offset(),
                    aggregates=self.process_aggregates(),
                    order_by=self.process_order_by(),
                    group_by=self.process_group_by(),
                    having=self.process_having(),
                    lock=self.process_locks(),
                )
                .strip()
            )
        else:
            self._sql = (
                self.select_format()
                .format(
                    columns=self.process_columns(separator=", ", qmark=qmark),
                    keyword="DISTINCT" if self._distinct else "",
                    table=self.process_table(self.table),
                    joins=self.process_joins(qmark=qmark),
                    wheres=self.process_wheres(qmark=qmark),
                    limit=self.process_limit(),
                    offset=self.process_offset(),
                    aggregates=self.process_aggregates(),
                    order_by=self.process_order_by(),
                    group_by=self.process_group_by(),
                    having=self.process_having(),
                    lock=self.process_locks(),
                )
                .strip()
            )
        return self

    def _compile_update(self, qmark=False):
        """Compiles an update query statement.
        Keyword Arguments:
            qmark {bool} -- Whether the query should use qmark. (default: {False})
        Returns:
            self
        """
        self._sql = self.update_format().format(
            key_equals=self._compile_key_value_equals(qmark=qmark),
            table=self.process_table(self.table),
            wheres=self.process_wheres(qmark=qmark),
        )
        return self

    def _compile_insert(self, qmark=False):
        """Compiles an insert expression.
        Returns:
            self
        """
        self._sql = self.insert_format().format(
            key_equals=self._compile_key_value_equals(qmark=qmark),
            table=self.process_table(self.table),
            columns=self.process_columns(separator=", ", action="insert", qmark=qmark),
            values=self.process_values(separator=", ", qmark=qmark),
        )
        return self

    def _compile_bulk_create(self, qmark=False):
        """Compiles a bulk insert expression.

        Assumes self._columns is a non-empty list of dicts that all share
        the same keys (the first dict supplies the column list).
        Returns:
            self
        """
        all_values = [list(x.values()) for x in self._columns]
        self._sql = self.bulk_insert_format().format(
            key_equals=self._compile_key_value_equals(qmark=qmark),
            table=self.process_table(self.table),
            columns=self.columnize_bulk_columns(list(self._columns[0].keys())),
            values=self.columnize_bulk_values(all_values, qmark=qmark),
        )
        return self

    # fix: default was a shared mutable list ([]); use an immutable tuple.
    def columnize_bulk_columns(self, columns=()):
        """Join bulk-insert column names using the dialect's column format."""
        return ", ".join(
            self.column_string().format(column=x, separator="") for x in columns
        ).rstrip(",")

    # fix: default was a shared mutable list ([]); use an immutable tuple.
    def columnize_bulk_values(self, columns=(), qmark=False):
        """Compile the VALUES lists for a bulk insert, binding values on qmark."""
        sql = ""
        for x in columns:
            inner = ""
            if isinstance(x, list):
                for y in x:
                    if qmark:
                        self.add_binding(y)
                    inner += (
                        "'?', "
                        if qmark
                        else self.value_string().format(value=y, separator=", ")
                    )
                inner = inner.rstrip(", ")
                sql += self.process_value_string().format(value=inner, separator=", ")
            else:
                if qmark:
                    self.add_binding(x)
                sql += (
                    "'?', "
                    if qmark
                    else self.process_value_string().format(
                        value="?" if qmark else x, separator=", "
                    )
                )
        return sql.rstrip(", ")

    def process_value_string(self):
        """Format string wrapping one parenthesised VALUES group."""
        return "({value}){separator}"

    def _compile_delete(self, qmark=False):
        """Compiles a delete expression.
        Returns:
            self
        """
        self._sql = self.delete_format().format(
            key_equals=self._compile_key_value_equals(qmark=qmark),
            table=self.process_table(self.table),
            wheres=self.process_wheres(qmark=qmark),
        )
        return self

    # TODO: Columnize?
    def _get_multiple_columns(self, columns):
        """Compiles a string or a list of strings into the grammars column syntax.
        Arguments:
            columns {string|list} -- A column or list of columns
        Returns:
            string -- the compiled column list
        """
        if isinstance(columns, list):
            column_string = ""
            for col in columns:
                column_string += self.process_column(col) + ", "
            return column_string.rstrip(", ")
        return self.process_column(columns)

    def process_joins(self, qmark=False):
        """Compiles all join clauses into one join expression.
        Returns:
            string
        """
        sql = ""
        for join in self._joins:
            if isinstance(join, JoinClause):
                on_string = ""
                for clause_idx, clause in enumerate(join.get_on_clauses()):
                    # The first clause is introduced by ON; subsequent ones
                    # use their own AND/OR operator.
                    keyword = clause.operator.upper() if clause_idx else "ON"
                    if isinstance(clause, OnClause):
                        on_string += f"{keyword} {self._table_column_string(clause.column1)} {clause.equality} {self._table_column_string(clause.column2)} "
                    else:
                        if clause.value_type == "NULL":
                            sql_string = f"{self.where_null_string()} "
                            on_string += sql_string.format(
                                keyword=keyword,
                                column=self.process_column(clause.column),
                            )
                        elif clause.value_type == "NOT NULL":
                            sql_string = f"{self.where_not_null_string()} "
                            on_string += sql_string.format(
                                keyword=keyword,
                                column=self.process_column(clause.column),
                            )
                        else:
                            if qmark:
                                value = "'?'"
                                self.add_binding(clause.value)
                            else:
                                value = self._compile_value(clause.value)
                            on_string += f"{keyword} {self._table_column_string(clause.column)} {clause.equality} {value} "
                sql += self.join_string().format(
                    foreign_table=self.process_table(join.table),
                    alias=f" AS {self.process_table(join.alias)}" if join.alias else "",
                    on=on_string,
                    keyword=self.join_keywords[join.clause],
                )
            sql += " "
        return sql

    # TODO: Clean
    def _compile_key_value_equals(self, qmark=False):
        """Compiles key value pairs.
        Keyword Arguments:
            qmark {bool} -- Whether the query should use qmark. (default: {False})
        Returns:
            string
        """
        sql = ""
        for update in self._updates:
            if update.update_type == "increment":
                sql_string = self.increment_string()
            elif update.update_type == "decrement":
                sql_string = self.decrement_string()
            else:
                sql_string = self.key_value_string()
            column = update.column
            value = update.value
            if isinstance(column, dict):
                # A mapping of column -> value pairs.
                for key, value in column.items():
                    if hasattr(value, "expression"):
                        # Raw expressions are inlined, never bound.
                        sql += self.column_value_string().format(
                            column=self._table_column_string(key),
                            value=value.expression,
                            separator=", ",
                        )
                    else:
                        sql += sql_string.format(
                            column=self._table_column_string(key),
                            value=value if not qmark else "?",
                            separator=", ",
                        )
                        if qmark:
                            self._bindings += (value,)
            else:
                sql += sql_string.format(
                    column=self._table_column_string(column),
                    value=value if not qmark else "?",
                    separator=", ",
                )
                if qmark:
                    self._bindings += (value,)
        sql = sql.rstrip(", ")
        return sql

    def process_aggregates(self):
        """Compiles aggregates to be used in a query expression.
        Returns:
            string
        """
        sql = ""
        for aggregates in self._aggregates:
            aggregate = aggregates.aggregate
            column = aggregates.column
            aggregate_function = self.aggregate_options.get(aggregate, "")
            if not aggregates.alias and column == "*":
                aggregate_string = self.aggregate_string_without_alias()
            else:
                aggregate_string = self.aggregate_string_with_alias()
            sql += (
                aggregate_string.format(
                    aggregate_function=aggregate_function,
                    column="*" if column == "*" else self._table_column_string(column),
                    alias=self.process_alias(aggregates.alias or column),
                )
                + ", "
            )
        return sql

    def process_order_by(self):
        """Compiles an order by for a query expression.
        Returns:
            string
        """
        sql = ""
        if self._order_by:
            order_crit = ""
            for order_bys in self._order_by:
                if order_bys.raw:
                    # Raw order-bys are used verbatim; their bindings are
                    # validated and collected.
                    order_crit += order_bys.column
                    if not isinstance(order_bys.bindings, (list, tuple)):
                        raise ValueError(
                            f"Bindings must be tuple or list. Received {type(order_bys.bindings)}"
                        )
                    if order_bys.bindings:
                        self.add_binding(*order_bys.bindings)
                    continue
                if len(order_crit):
                    order_crit += ", "
                column = order_bys.column
                direction = order_bys.direction
                if "." in column:
                    column_string = self._table_column_string(column)
                else:
                    column_string = self.column_string().format(
                        column=column, separator=""
                    )
                order_crit += self.order_by_format().format(
                    column=column_string, direction=direction.upper()
                )
            sql += self.order_by_string().format(order_columns=order_crit)
        return sql

    def process_group_by(self):
        """Compiles a group by for a query expression.
        Returns:
            string
        """
        sql = ""
        columns = []
        for group_by in self._group_by:
            if group_by.raw:
                # A raw group-by short-circuits everything else.
                if group_by.bindings:
                    self.add_binding(*group_by.bindings)
                sql += "GROUP BY " + group_by.column
                return sql
            else:
                columns.append(self._table_column_string(group_by.column))
        if columns:
            sql += " GROUP BY {column}".format(column=", ".join(columns))
        return sql

    def process_alias(self, column):
        """Compiles an alias for a column. The base grammar keeps it as-is.
        Arguments:
            column {string} -- The name of the column.
        Returns:
            string
        """
        return column

    def process_table(self, table):
        """Compiles a given table name.
        Arguments:
            table {string|Table} -- The table name to compile.
        Returns:
            string
        """
        if not table:
            return ""
        if isinstance(table, str):
            return ".".join(
                self.table_string().format(
                    table=t,
                    database=self._connection_details.get("database", ""),
                    prefix=self._connection_details.get("prefix", ""),
                )
                for t in table.split(".")
            )
        if table.raw:
            return table.name
        return ".".join(
            self.table_string().format(
                table=t,
                database=self._connection_details.get("database", ""),
                prefix=self._connection_details.get("prefix", ""),
            )
            for t in table.name.split(".")
        )

    def process_limit(self):
        """Compiles the limit expression.
        Returns:
            string
        """
        if not self._limit:
            return ""
        return self.limit_string(offset=self._offset).format(limit=self._limit)

    def process_offset(self):
        """Compiles the offset expression.
        Returns:
            string
        """
        if not self._offset:
            return ""
        return self.offset_string().format(offset=self._offset, limit=self._limit or 1)

    def process_locks(self):
        """Return the dialect's lock clause for the current lock mode (or "")."""
        return self.locks.get(self.lock, "")

    def process_having(self, qmark=False):
        """Compiles having expression.
        Keyword Arguments:
            qmark {bool} -- Whether or not to use Qmark (default: {False})
        Returns:
            string
        """
        # NOTE(review): qmark is accepted but not used here — having values
        # are always inlined; confirm whether they should be bound instead.
        sql = ""
        for having in self._having:
            column = having.column
            equality = having.equality
            value = having.value
            raw = having.raw
            if not equality and not value:
                sql_string = self.having_string()
            else:
                sql_string = self.having_equality_string()
            sql += sql_string.format(
                column=self._table_column_string(column) if raw is False else column,
                equality=equality,
                value=self._compile_value(value),
            )
        return sql

    def process_wheres(self, qmark=False, strip_first_where=False):
        """Compiles the where expression.
        Keyword Arguments:
            qmark {bool} -- Whether or not to use Qmark. (default: {False})
            strip_first_where {bool} -- Whether or not to strip out the first where keyword.
                This is useful when using subselects (default: {False})
        Returns:
            string
        """
        sql = ""
        loop_count = 0
        for where in self._wheres:
            column = where.column
            equality = where.equality
            value = where.value
            value_type = where.value_type
            """Need to get a specific keyword here. This keyword either needs to be
            something like WHERE, AND, OR.
            Depending on the loop depends on the placement of the AND
            """
            if loop_count == 0:
                if strip_first_where:
                    keyword = ""
                else:
                    keyword = " " + self.first_where_string()
            elif hasattr(where, "keyword") and where.keyword == "or":
                keyword = " " + self.or_where_string()
            else:
                keyword = " " + self.additional_where_string()
            if where.raw:
                """If we have a raw query we just want to use the query supplied
                and don't need to compile anything.
                """
                sql += self.raw_query_string().format(
                    keyword=keyword, query=where.column
                )
                if not isinstance(where.bindings, (list, tuple)):
                    raise ValueError(
                        f"Bindings must be tuple or list. Received {type(where.bindings)}"
                    )
                if where.bindings:
                    self.add_binding(*where.bindings)
                loop_count += 1
                continue
            """The column is an easy compile
            """
            column = self._table_column_string(column)
            """Need to find which type of where string it is.
            If it is a WHERE NULL, WHERE EXISTS, WHERE `col` = 'val' etc
            """
            equality = equality.upper()
            if equality == "BETWEEN":
                low = where.low
                high = where.high
                if qmark:
                    self.add_binding(low)
                    self.add_binding(high)
                    low = "?"
                    high = "?"
                sql_string = self.between_string().format(
                    low=self._compile_value(low),
                    high=self._compile_value(high),
                    column=self._table_column_string(where.column),
                    keyword=keyword,
                )
            elif equality == "NOT BETWEEN":
                sql_string = self.not_between_string().format(
                    low=self._compile_value(where.low),
                    high=self._compile_value(where.high),
                    column=self._table_column_string(where.column),
                    keyword=keyword,
                )
            elif value_type == "value_equals":
                sql_string = self.value_equal_string().format(
                    value1=where.column, value2=where.value, keyword=keyword
                )
            elif value_type == "NULL":
                sql_string = self.where_null_string()
            elif value_type == "DATE":
                sql_string = self.where_date_string()
            elif value_type == "NOT NULL":
                sql_string = self.where_not_null_string()
            elif equality == "EXISTS":
                sql_string = self.where_exists_string()
            elif equality == "NOT EXISTS":
                sql_string = self.where_not_exists_string()
            elif equality == "LIKE":
                sql_string = self.where_like_string()
            elif equality == "REGEXP":
                sql_string = self.where_regexp_string()
            elif equality == "NOT REGEXP":
                sql_string = self.where_not_regexp_string()
            elif equality == "NOT LIKE":
                sql_string = self.where_not_like_string()
            else:
                sql_string = self.where_string()
            """If the value should actually be a sub query then we need to wrap it in a query here
            """
            if isinstance(value, SubGroupExpression):
                grammar = value.builder.get_grammar()
                query_value = (
                    self.subquery_string()
                    .format(
                        query=grammar.process_wheres(
                            qmark=qmark, strip_first_where=True
                        )
                    )
                    .replace("( ", "(")
                )
                if grammar._bindings:
                    self.add_binding(*grammar._bindings)
                sql_string = self.where_group_string()
            elif isinstance(value, SubSelectExpression):
                if qmark:
                    query_from_builder = value.builder.to_qmark()
                    if value.builder._bindings:
                        self.add_binding(*value.builder._bindings)
                else:
                    query_from_builder = value.builder.to_sql()
                query_value = self.subquery_string().format(query=query_from_builder)
            elif isinstance(value, list):
                # e.g. WHERE IN (...) style value lists.
                query_value = "("
                for val in value:
                    if qmark:
                        query_value += "'?', "
                        self.add_binding(val)
                    else:
                        query_value += self.value_string().format(
                            value=val, separator=","
                        )
                query_value = query_value.rstrip(",").rstrip(", ") + ")"
            elif value is True and value_type != "NOT NULL":
                sql_string = self.get_true_column_string()
                query_value = 1
            elif value is False and value_type != "NOT NULL":
                sql_string = self.get_false_column_string()
                query_value = 0
            elif qmark and value_type != "column":
                query_value = "'?'"
                if (
                    value is not True
                    and value_type != "value_equals"
                    and value_type != "NULL"
                    and value_type != "BETWEEN"
                ):
                    self.add_binding(value)
            elif value_type == "value":
                if qmark:
                    query_value = "'?'"
                else:
                    query_value = self.value_string().format(value=value, separator="")
                self.add_binding(value)
            elif value_type == "column":
                query_value = self._table_column_string(column=value, separator="")
            elif value_type == "DATE":
                query_value = self.value_string().format(value=value, separator="")
            elif value_type == "having":
                query_value = self._table_column_string(column=value, separator="")
            else:
                query_value = ""
            sql += sql_string.format(
                keyword=keyword, column=column, equality=equality, value=query_value
            )
            loop_count += 1
        return sql

    def get_true_column_string(self):
        """Format string for a boolean-true comparison."""
        return "{keyword} {column} = '1'"

    def get_false_column_string(self):
        """Format string for a boolean-false comparison."""
        return "{keyword} {column} = '0'"

    def add_binding(self, *bindings):
        """Adds one or more bindings to the bindings collection.
        Arguments:
            *bindings -- values to bind.
        """
        self._bindings += bindings

    def column_exists(self, column):
        """Check if a column exists
        Arguments:
            column {string} -- The name of the column to check for existence.
        Returns:
            self
        """
        self._column = column
        self._sql = self.process_exists()
        return self

    def table_exists(self):
        """Checks if a table exists.
        Returns:
            self
        """
        self._sql = self.table_exists_string().format(
            table=self.process_table(self.table),
            database=self.database,
            clean_table=self.table,
        )
        return self

    def wrap_table(self, table_name):
        """Quote a bare table name using the dialect's table format."""
        return self.table_string().format(table=table_name)

    def process_exists(self):
        """Specifies the column exists expression.
        Returns:
            string
        """
        return self.column_exists_string().format(
            table=self.process_table(self.table),
            clean_table=self.table,
            value=self._compile_value(self._column),
        )

    def to_sql(self):
        """Cleans up the SQL string and returns the SQL
        Returns:
            string
        """
        return re.sub(" +", " ", self._sql.strip())

    def to_qmark(self):
        """Cleans up the SQL string and returns the SQL
        Returns:
            string
        """
        return re.sub(" +", " ", self._sql.strip())

    # TODO: Inspect this can't just be used by another method. seems duplicative
    def process_columns(self, separator="", action="select", qmark=False):
        """Specifies the columns in a selection expression.
        Keyword Arguments:
            separator {str} -- The separator used between columns (default: {""})
        Returns:
            string
        """
        sql = ""
        for column in self._columns:
            alias = None
            if isinstance(column, SelectExpression):
                alias = column.alias
                if column.raw:
                    sql += column.column + ", "
                    continue
                column = column.column
            if isinstance(column, SubGroupExpression):
                if qmark:
                    builder_sql = column.builder.to_qmark()
                    if column.builder._bindings:
                        self.add_binding(*column.builder._bindings)
                else:
                    builder_sql = column.builder.to_sql()
                sql += f"({builder_sql}) AS {column.alias}, "
                continue
            sql += self._table_column_string(column, alias=alias, separator=separator)
        if self._aggregates:
            sql += self.process_aggregates()
        if sql == "":
            return "*"
        return sql.rstrip(",").rstrip(", ")

    # TODO: Duplicative?
    def process_values(self, separator="", qmark=False):
        """Compiles column values for insert expressions.
        Keyword Arguments:
            separator {str} -- The separator used between columns (default: {""})
        Returns:
            string
        """
        sql = ""
        if self._columns == "*":
            return self._columns
        elif isinstance(self._columns, list):
            for c in self._columns:
                for column, value in dict(c).items():
                    if qmark:
                        self.add_binding(value)
                        sql += f"'?'{separator}".strip()
                    else:
                        sql += self._compile_value(value, separator=separator)
        else:
            for column, value in dict(self._columns).items():
                if qmark:
                    self.add_binding(value)
                    sql += f"'?'{separator}".strip()
                else:
                    sql += self._compile_value(value, separator=separator)
        if not qmark:
            # Drop the trailing separator (assumes a 2-character separator).
            return sql[:-2]
        return sql.rstrip(separator.strip())

    def process_column(self, column, separator=""):
        """Compiles a column into the column syntax.
        Arguments:
            column {string} -- The name of the column.
        Keyword Arguments:
            separator {string} -- The separator used between columns (default: {""})
        Returns:
            string
        """
        table = None
        if column and "." in column:
            table, column = column.split(".")
        return self.column_string().format(
            column=column, separator=separator, table=table or self.table
        )

    def _table_column_string(self, column, alias=None, separator=""):
        """Compiles a column into the table-qualified column syntax.
        Arguments:
            column {string} -- The name of the column.
        Keyword Arguments:
            separator {string} -- The separator used between columns (default: {""})
        Returns:
            string
        """
        table = None
        if column and "." in column:
            table, column = column.split(".")
        if column == "*":
            return self.column_strings.get("select_all").format(
                column=column,
                separator=separator,
                table=self.process_table(table or self.table),
            )
        if alias:
            alias_string = self.subquery_alias_string().format(alias=alias)
        return self.column_strings.get(self._action).format(
            column=column,
            separator=separator,
            alias=" " + alias_string if alias else "",
            table=self.process_table(table or self.table),
        )

    def _compile_value(self, value, separator=""):
        """Compiles a value using the value syntax.
        Arguments:
            value {string} -- The value to compile.
        Keyword Arguments:
            separator {string} -- The separator used between columns (default: {""})
        Returns:
            string
        """
        return self.value_string().format(value=value, separator=separator)

    def drop_table(self, table):
        """Specifies a drop table expression.
        Arguments:
            table {string} -- The table to drop.
        Returns:
            self
        """
        self._sql = self.drop_table_string().format(table=self.process_column(table))
        return self

    def drop_table_if_exists(self, table):
        """Specifies a drop table if exists expression.
        Arguments:
            table {string} -- The name of the table to drop.
        Returns:
            self
        """
        self._sql = self.drop_table_if_exists_string().format(
            table=self.process_column(table)
        )
        return self

    def rename_table(self, current_table_name, new_table_name):
        """Specifies a rename table expression.
        Arguments:
            current_table_name {string} -- The name of the table currently.
            new_table_name {string} -- The name you want to use now for the table.
        Returns:
            self
        """
        self._sql = self.rename_table_string().format(
            current_table_name=self.process_column(current_table_name),
            new_table_name=self.process_column(new_table_name),
        )
        return self

    def truncate_table(self, table, foreign_keys=False):
        """Specifies a truncate table expression.
        Arguments:
            table {string} -- The name of the table to truncate.
        Raises:
            NotImplementedError -- unless a dialect grammar overrides this.
        """
        raise NotImplementedError(
            f"'{self.__class__.__name__}' does not support truncating"
        )

    def where_regexp_string(self):
        """Format string for a REGEXP where clause."""
        return "{keyword} {column} REGEXP {value}"

    def where_not_regexp_string(self):
        """Format string for a NOT REGEXP where clause."""
        return "{keyword} {column} NOT REGEXP {value}"
================================================
FILE: src/masoniteorm/query/grammars/MSSQLGrammar.py
================================================
from .BaseGrammar import BaseGrammar
class MSSQLGrammar(BaseGrammar):
    """Microsoft SQL Server grammar class.

    Supplies the T-SQL specific format strings consumed by BaseGrammar:
    square-bracket identifier quoting, TOP / OFFSET...FETCH pagination and
    NEWID() random ordering.
    """

    aggregate_options = {
        "SUM": "SUM",
        "MAX": "MAX",
        "MIN": "MIN",
        "AVG": "AVG",
        "COUNT": "COUNT",
    }

    join_keywords = {
        "inner": "INNER JOIN",
        "join": "INNER JOIN",
        "outer": "OUTER JOIN",
        "left": "LEFT JOIN",
        "right": "RIGHT JOIN",
        "left_inner": "LEFT INNER JOIN",
        "right_inner": "RIGHT INNER JOIN",
    }

    # Per-action column formats; SQL Server quotes identifiers with [brackets].
    column_strings = {
        "select": "{table}.[{column}]{alias}{separator}",
        "select_all": "{table}.*{separator}",
        "insert": "{table}.[{column}]{separator}",
        "update": "{table}.[{column}]{separator}",
        "delete": "{table}.[{column}]{separator}",
    }

    locks = {"share": "WITH(ROWLOCK)", "update": "WITH(ROWLOCK)"}

    def select_no_table(self):
        """Format string for a SELECT without a FROM clause."""
        return "SELECT {columns}"

    def select_format(self):
        """Format string for a full SELECT statement (TOP before columns)."""
        return "SELECT {keyword} {limit} {columns} FROM {table} {lock} {joins} {wheres} {group_by} {having} {order_by} {offset}"

    def update_format(self):
        """Format string for an UPDATE statement."""
        return "UPDATE {table} SET {key_equals} {wheres}"

    def insert_format(self):
        """Format string for a single-row INSERT statement."""
        return "INSERT INTO {table} ({columns}) VALUES ({values})"

    def bulk_insert_format(self):
        """Format string for a multi-row INSERT statement."""
        return "INSERT INTO {table} ({columns}) VALUES {values}"

    def delete_format(self):
        """Format string for a DELETE statement."""
        return "DELETE FROM {table} {wheres}"

    def create_column_string(self):
        """Format string for one column definition in CREATE TABLE."""
        return "{column} {data_type}{length}{nullable}, "

    def create_start(self):
        """Format string opening a CREATE TABLE statement."""
        return "CREATE TABLE {table} "

    def having_string(self):
        """Format string for a bare HAVING clause."""
        return "HAVING {column}"

    def where_exists_string(self):
        """Format string for a WHERE EXISTS clause."""
        return "{keyword} EXISTS {value}"

    def where_not_exists_string(self):
        """Format string for a WHERE NOT EXISTS clause."""
        return "{keyword} NOT EXISTS {value}"

    def where_like_string(self):
        """Format string for a LIKE where clause."""
        return "{keyword} {column} LIKE {value}"

    def where_not_like_string(self):
        """Format string for a NOT LIKE where clause."""
        return "{keyword} {column} NOT LIKE {value}"

    def where_date_string(self):
        """Format string for a DATE() comparison where clause."""
        return "{keyword} DATE({column}) {equality} {value}"

    def where_regexp_string(self):
        """SQL Server has no REGEXP operator; fall back to LIKE."""
        return self.where_like_string()

    def where_not_regexp_string(self):
        """SQL Server has no NOT REGEXP operator; fall back to NOT LIKE."""
        return self.where_not_like_string()

    def having_equality_string(self):
        """Format string for a HAVING comparison clause."""
        return "HAVING {column} {equality} {value}"

    def aggregate_string_without_alias(self):
        """Format string for an aggregate without an alias."""
        return "{aggregate_function}({column})"

    def create_column_length(self):
        """Format string for a column length specifier."""
        return "({length})"

    def limit_string(self, offset=False):
        """TOP is only valid without OFFSET; pagination uses OFFSET...FETCH."""
        if offset:
            return ""
        return "TOP {limit}"

    def first_where_string(self):
        """Keyword that opens the where expression."""
        return "WHERE"

    def additional_where_string(self):
        """Keyword joining additional where clauses."""
        return "AND"

    def join_string(self):
        """Format string for one join clause."""
        return "{keyword} {foreign_table}{alias} {on}"

    def aggregate_string(self):
        """Format string for an aliased aggregate."""
        return "{aggregate_function}({column}) AS {alias}"

    def subquery_string(self):
        """Format string wrapping a subquery in parentheses."""
        return "({query})"

    def subquery_alias_string(self):
        """Format string for aliasing a subquery/column."""
        return "AS {alias}"

    def where_group_string(self):
        """Format string for a grouped (parenthesised) where expression."""
        return "{keyword} {value}"

    def or_where_string(self):
        """Keyword joining OR'd where clauses."""
        return "OR"

    def raw_query_string(self):
        """Format string for a raw where fragment."""
        return "{keyword} {query}"

    def where_in_string(self):
        """Format string for a WHERE IN clause."""
        return "WHERE IN ({values})"

    def value_equal_string(self):
        """Format string comparing two literal values."""
        return "{keyword} {value1} = {value2}"

    def where_null_string(self):
        """Format string for an IS NULL where clause."""
        return " {keyword} {column} IS NULL"

    def between_string(self):
        """Format string for a BETWEEN where clause."""
        return "{keyword} {column} BETWEEN {low} AND {high}"

    def not_between_string(self):
        """Format string for a NOT BETWEEN where clause."""
        return "{keyword} {column} NOT BETWEEN {low} AND {high}"

    def where_not_null_string(self):
        """Format string for an IS NOT NULL where clause."""
        return " {keyword} {column} IS NOT NULL"

    def where_string(self):
        """Format string for a standard comparison where clause."""
        return " {keyword} {column} {equality} {value}"

    def offset_string(self):
        """Format string for OFFSET...FETCH pagination."""
        return "OFFSET {offset} ROWS FETCH NEXT {limit} ROWS ONLY"

    def increment_string(self):
        """Format string incrementing a column by a value."""
        return "{column} = {column} + '{value}'{separator}"

    def decrement_string(self):
        """Format string decrementing a column by a value."""
        return "{column} = {column} - '{value}'{separator}"

    def aggregate_string_with_alias(self):
        """Format string for an aliased aggregate."""
        return "{aggregate_function}({column}) AS {alias}"

    def key_value_string(self):
        """Format string for a SET column = value pair."""
        return "{column} = '{value}'{separator}"

    def column_value_string(self):
        """Format string for a SET column = expression pair (no quoting)."""
        return "{column} = {value}{separator}"

    def table_string(self):
        """Format string quoting a table name."""
        return "[{table}]"

    def order_by_format(self):
        """Format string for one ORDER BY criterion."""
        return "{column} {direction}"

    def order_by_string(self):
        """Format string for the ORDER BY clause."""
        return "ORDER BY {order_columns}"

    def column_string(self):
        """Format string quoting a bare column name."""
        return "[{column}]{separator}"

    def table_column_string(self):
        """Format string quoting a table-qualified column name."""
        return "[{table}].[{column}]{separator}"

    def table_update_column_string(self):
        """Format string for a table-qualified column in UPDATE."""
        return "[{table}].[{column}]{separator}"

    def table_insert_column_string(self):
        """Format string for a table-qualified column in INSERT."""
        return "[{table}].[{column}]{separator}"

    def value_string(self):
        """Format string quoting a literal value."""
        return "'{value}'{separator}"

    def wrap_table(self, table_name):
        """Quote a bare table name with square brackets."""
        return self.table_string().format(table=table_name)

    def truncate_table(self, table, foreign_keys=False):
        """Build a TRUNCATE TABLE statement (foreign_keys is unused on MSSQL)."""
        return f"TRUNCATE TABLE {self.wrap_table(table)}"

    # fix: seed was a required positional argument, but
    # QueryBuilder.in_random_order calls compile_random() with no arguments,
    # which raised TypeError on SQL Server. NEWID() cannot be seeded anyway,
    # so the parameter is kept (for interface parity) with a default.
    def compile_random(self, seed=None):
        """Return the T-SQL expression used to randomize row order."""
        return "NEWID()"
================================================
FILE: src/masoniteorm/query/grammars/MySQLGrammar.py
================================================
from .BaseGrammar import BaseGrammar
class MySQLGrammar(BaseGrammar):
"""MySQL grammar class."""
aggregate_options = {
"SUM": "SUM",
"MAX": "MAX",
"MIN": "MIN",
"AVG": "AVG",
"COUNT": "COUNT",
}
join_keywords = {
"inner": "INNER JOIN",
"join": "INNER JOIN",
"outer": "OUTER JOIN",
"left": "LEFT JOIN",
"right": "RIGHT JOIN",
"left_inner": "LEFT INNER JOIN",
"right_inner": "RIGHT INNER JOIN",
}
"""Column strings are formats for how columns and key values should be formatted
on specific queries. These can be different depending on the type of query.
For example for Postgres, You can specify columns as "users"."name":
SELECT "users"."name" from "users"
But on updates we can only specify the column name and cannot have the table prefixed:
UPDATE "users" SET "name" = "value"
This dictionary allows you to modify the format depending on the type
of query we are generating. For most databases these will be the same
but this allows you to modify formats depending on the database.
"""
column_strings = {
"select": "{table}.`{column}`{alias}{separator}",
"select_all": "{table}.*{separator}",
"insert": "{table}.`{column}`{separator}",
"update": "{table}.`{column}`{separator}",
"delete": "{table}.`{column}`{separator}",
}
locks = {"share": "LOCK IN SHARE MODE", "update": "FOR UPDATE"}
def select_format(self):
    """Format string for a full MySQL SELECT statement."""
    return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {having} {order_by} {limit} {offset} {lock}"
def select_no_table(self):
    """Format string for a SELECT without a FROM clause."""
    return "SELECT {columns} {lock}"
def update_format(self):
return "UPDATE {table} SET {key_equals} {wheres}"
def insert_format(self):
return "INSERT INTO {table} ({columns}) VALUES ({values})"
def bulk_insert_format(self):
return "INSERT INTO {table} ({columns}) VALUES {values}"
def delete_format(self):
return "DELETE FROM {table} {wheres}"
def aggregate_string_with_alias(self):
return "{aggregate_function}({column}) AS {alias}"
def aggregate_string_without_alias(self):
return "{aggregate_function}({column})"
def subquery_string(self):
return "({query})"
def raw_query_string(self):
return "{keyword} {query}"
def where_group_string(self):
return "{keyword} {value}"
def between_string(self):
return "{keyword} {column} BETWEEN {low} AND {high}"
def not_between_string(self):
return "{keyword} {column} NOT BETWEEN {low} AND {high}"
def where_exists_string(self):
return "{keyword} EXISTS {value}"
def where_date_string(self):
return "{keyword} DATE({column}) {equality} {value}"
def where_not_exists_string(self):
return "{keyword} NOT EXISTS {value}"
def where_like_string(self):
return "{keyword} {column} LIKE {value}"
def where_not_like_string(self):
return "{keyword} {column} NOT LIKE {value}"
def get_true_column_string(self):
return "{keyword} {column} = '1'"
def get_false_column_string(self):
return "{keyword} {column} = '0'"
def process_table(self, table):
"""Compiles a given table name.
Arguments:
table {string} -- The table name to compile.
Returns:
self
"""
if not table:
return ""
if isinstance(table, str):
return ".".join(
self.table_string().format(table=t) for t in table.split(".")
)
if table.raw:
return table.name
return ".".join(
self.table_string().format(table=t) for t in table.name.split(".")
)
def subquery_alias_string(self):
return "AS {alias}"
def key_value_string(self):
return "{column} = '{value}'{separator}"
def column_value_string(self):
return "{column} = {value}{separator}"
def increment_string(self):
return "{column} = {column} + '{value}'{separator}"
def decrement_string(self):
return "{column} = {column} - '{value}'{separator}"
def create_column_string(self):
return "{column} {data_type}{length}{nullable}{default_value}, "
def column_exists_string(self):
return "SHOW COLUMNS FROM {table} LIKE {value}"
def table_exists_string(self):
return "SELECT * from information_schema.tables where table_name='{clean_table}' AND table_schema = '{database}'"
def create_column_length(self, column_type):
return "({length})"
def table_string(self):
return "`{table}`"
def order_by_format(self):
return "{column} {direction}"
def order_by_string(self):
return "ORDER BY {order_columns}"
def column_string(self):
return "`{column}`{separator}"
def value_string(self):
return "'{value}'{separator}"
def join_string(self):
return "{keyword} {foreign_table}{alias} {on}"
def limit_string(self, offset=False):
return "LIMIT {limit}"
def offset_string(self):
return "OFFSET {offset}"
def first_where_string(self):
return "WHERE"
def additional_where_string(self):
return "AND"
def or_where_string(self):
return "OR"
def where_in_string(self):
return "WHERE IN ({values})"
def value_equal_string(self):
return "{keyword} {value1} = {value2}"
def where_string(self):
return " {keyword} {column} {equality} {value}"
def having_string(self):
return "HAVING {column}"
def having_equality_string(self):
return "HAVING {column} {equality} {value}"
def where_null_string(self):
return " {keyword} {column} IS NULL"
def where_not_null_string(self):
return " {keyword} {column} IS NOT NULL"
def enable_foreign_key_constraints(self):
return "SET FOREIGN_KEY_CHECKS=1"
def disable_foreign_key_constraints(self):
return "SET FOREIGN_KEY_CHECKS=0"
def truncate_table(self, table, foreign_keys=False):
"""Specifies a truncate table expression.
Arguments;
table {string} -- The name of the table to truncate.
Returns:
self
"""
if not foreign_keys:
return f"TRUNCATE TABLE {self.wrap_table(table)}"
return [
self.disable_foreign_key_constraints(),
f"TRUNCATE TABLE {self.wrap_table(table)}",
self.enable_foreign_key_constraints(),
]
def compile_random(self):
return "RAND()"
================================================
FILE: src/masoniteorm/query/grammars/PostgresGrammar.py
================================================
import re
from .BaseGrammar import BaseGrammar
class PostgresGrammar(BaseGrammar):
    """Postgres grammar class.

    Supplies the Postgres-specific SQL fragments (format templates) that the
    BaseGrammar compilation machinery stitches together into complete
    queries. Identifiers are quoted with double quotes.
    """

    # Aggregate functions this grammar can compile (option name -> SQL keyword).
    aggregate_options = {
        "SUM": "SUM",
        "MAX": "MAX",
        "MIN": "MIN",
        "AVG": "AVG",
        "COUNT": "COUNT",
    }

    # Supported join directives mapped to their SQL join keywords.
    join_keywords = {
        "inner": "INNER JOIN",
        "join": "INNER JOIN",
        "outer": "OUTER JOIN",
        "left": "LEFT JOIN",
        "right": "RIGHT JOIN",
        "left_inner": "LEFT INNER JOIN",
        "right_inner": "RIGHT INNER JOIN",
    }

    # Column format templates keyed by query type. Note that Postgres
    # inserts/updates must NOT prefix columns with the table name.
    column_strings = {
        "select": '{table}."{column}"{alias}{separator}',
        "select_all": "{table}.*{separator}",
        "insert": '"{column}"{separator}',
        "update": '"{column}"{separator}',
        "delete": '{table}."{column}"{separator}',
    }

    # Row locking clauses appended to SELECT statements.
    locks = {"share": "FOR SHARE", "update": "FOR UPDATE"}

    def select_no_table(self):
        """Template for a SELECT with no FROM clause (e.g. raw expressions)."""
        return "SELECT {columns} {lock}"

    def select_format(self):
        """Top-level template for a SELECT statement."""
        return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {having} {order_by} {limit} {offset} {lock}"

    def update_format(self):
        """Top-level template for an UPDATE statement."""
        return "UPDATE {table} SET {key_equals} {wheres}"

    def insert_format(self):
        """Template for a single-row INSERT. RETURNING * gives the post
        processor the full inserted record (including the generated id)."""
        return "INSERT INTO {table} ({columns}) VALUES ({values}) RETURNING *"

    def bulk_insert_format(self):
        """Template for a multi-row INSERT, also returning inserted rows."""
        return "INSERT INTO {table} ({columns}) VALUES {values} RETURNING *"

    def delete_format(self):
        """Top-level template for a DELETE statement."""
        return "DELETE FROM {table} {wheres}"

    def aggregate_string_with_alias(self):
        """Template for an aggregate expression with an AS alias."""
        return "{aggregate_function}({column}) AS {alias}"

    def aggregate_string_without_alias(self):
        """Template for a bare aggregate expression."""
        return "{aggregate_function}({column})"

    def get_true_column_string(self):
        """Template for a truthy boolean comparison (IS True)."""
        return "{keyword} {column} IS True"

    def get_false_column_string(self):
        """Template for a falsy boolean comparison (IS False)."""
        return "{keyword} {column} IS False"

    def subquery_string(self):
        """Template wrapping a subquery in parentheses."""
        return "({query})"

    def raw_query_string(self):
        """Template for splicing a raw query fragment after a keyword."""
        return "{keyword} {query}"

    def where_group_string(self):
        """Template for a parenthesized/grouped where clause."""
        return "{keyword} {value}"

    def between_string(self):
        """Template for a BETWEEN clause."""
        return "{keyword} {column} BETWEEN {low} AND {high}"

    def not_between_string(self):
        """Template for a NOT BETWEEN clause."""
        return "{keyword} {column} NOT BETWEEN {low} AND {high}"

    def where_exists_string(self):
        """Template for a WHERE EXISTS clause."""
        return "{keyword} EXISTS {value}"

    def where_not_exists_string(self):
        """Template for a WHERE NOT EXISTS clause."""
        return "{keyword} NOT EXISTS {value}"

    def where_like_string(self):
        """Template for a LIKE clause. Uses ILIKE (case-insensitive)."""
        return "{keyword} {column} ILIKE {value}"

    def where_not_like_string(self):
        """Template for a NOT LIKE clause. Uses NOT ILIKE."""
        return "{keyword} {column} NOT ILIKE {value}"

    def subquery_alias_string(self):
        """Template for aliasing a subquery."""
        return "AS {alias}"

    def key_value_string(self):
        """Template for a quoted key = 'value' pair (updates)."""
        return "{column} = '{value}'{separator}"

    def column_value_string(self):
        """Template for an unquoted column = value pair."""
        return "{column} = {value}{separator}"

    def increment_string(self):
        """Template for incrementing a column by a value."""
        return "{column} = {column} + '{value}'{separator}"

    def decrement_string(self):
        """Template for decrementing a column by a value."""
        return "{column} = {column} - '{value}'{separator}"

    def create_column_string(self):
        """Template for a column definition inside CREATE TABLE."""
        return "{column} {data_type}{length}{nullable}, "

    def column_exists_string(self):
        """Query template to check whether a column exists on a table."""
        return "SELECT column_name FROM information_schema.columns WHERE table_name='{clean_table}' and column_name={value}"

    def table_exists_string(self):
        """Query template to check whether a table exists."""
        return (
            "SELECT * from information_schema.tables where table_name='{clean_table}'"
        )

    def create_column_length(self, column_type):
        """Template for a column length modifier; some types take no length."""
        if column_type in self.types_without_lengths:
            return ""
        return "({length})"

    def to_sql(self):
        """Cleans up the SQL string and returns the SQL

        Joins any queued schema queries with "; " and collapses runs of
        spaces. When only queued queries exist (no columns/creates), the
        main ``_sql`` buffer is skipped entirely.

        Returns:
            string
        """
        if self.queries and (not self._columns and not self._creates):
            sql = ""
            for query in self.queries:
                query += "; "
                sql += re.sub(" +", " ", query)
            return sql.rstrip(" ")
        else:
            # ",)" can be produced by trailing separators; normalize it.
            sql = re.sub(" +", " ", self._sql.strip().replace(",)", ")"))
            for query in self.queries:
                sql += "; "
                sql += re.sub(" +", " ", query.strip())
            return sql

    def table_string(self):
        """Template for a double-quoted table identifier."""
        return '"{table}"'

    def order_by_format(self):
        """Template for a single ORDER BY column with direction."""
        return "{column} {direction}"

    def order_by_string(self):
        """Template for the full ORDER BY clause."""
        return "ORDER BY {order_columns}"

    def column_string(self):
        """Template for a double-quoted column identifier."""
        return '"{column}"{separator}'

    def value_string(self):
        """Template for a quoted literal value."""
        return "'{value}'{separator}"

    def join_string(self):
        """Template for a JOIN clause."""
        return "{keyword} {foreign_table}{alias} {on}"

    def limit_string(self, offset=False):
        """Template for the LIMIT clause."""
        return "LIMIT {limit}"

    def offset_string(self):
        """Template for the OFFSET clause."""
        return "OFFSET {offset}"

    def first_where_string(self):
        """Keyword that introduces the first where condition."""
        return "WHERE"

    def additional_where_string(self):
        """Keyword joining subsequent AND conditions."""
        return "AND"

    def or_where_string(self):
        """Keyword joining OR conditions."""
        return "OR"

    def where_in_string(self):
        """Template for a WHERE IN clause."""
        return "WHERE IN ({values})"

    def where_date_string(self):
        """Template for comparing the DATE() of a column."""
        return "{keyword} DATE({column}) {equality} {value}"

    def value_equal_string(self):
        """Template for comparing two values for equality."""
        return "{keyword} {value1} = {value2}"

    def where_string(self):
        """Template for a generic where condition."""
        return " {keyword} {column} {equality} {value}"

    def having_string(self):
        """Template for a HAVING clause without a comparison."""
        return "HAVING {column}"

    def having_equality_string(self):
        """Template for a HAVING clause with a comparison."""
        return "HAVING {column} {equality} {value}"

    def where_null_string(self):
        """Template for an IS NULL condition."""
        return " {keyword} {column} IS NULL"

    def where_not_null_string(self):
        """Template for an IS NOT NULL condition."""
        return " {keyword} {column} IS NOT NULL"

    def truncate_table(self, table, foreign_keys=False):
        """Specifies a truncate table expression.

        The ``foreign_keys`` flag is accepted for interface parity with
        other grammars but has no effect here.

        Arguments:
            table {string} -- The name of the table to truncate.
        Returns:
            string
        """
        return f"TRUNCATE TABLE {self.wrap_table(table)}"

    def compile_random(self):
        """Return the Postgres expression used for random ordering."""
        return "random()"
================================================
FILE: src/masoniteorm/query/grammars/SQLiteGrammar.py
================================================
import re
from .BaseGrammar import BaseGrammar
class SQLiteGrammar(BaseGrammar):
    """SQLite grammar class.

    Supplies the SQLite-specific SQL fragments (format templates) that the
    BaseGrammar compilation machinery stitches together into complete
    queries. Identifiers are quoted with double quotes.
    """

    # Aggregate functions this grammar can compile (option name -> SQL keyword).
    aggregate_options = {
        "SUM": "SUM",
        "MAX": "MAX",
        "MIN": "MIN",
        "AVG": "AVG",
        "COUNT": "COUNT",
    }

    # Supported join directives mapped to their SQL join keywords.
    # NOTE(review): right joins are mapped to LEFT JOIN here — presumably
    # because SQLite historically lacked RIGHT JOIN support.
    join_keywords = {
        "inner": "INNER JOIN",
        "join": "INNER JOIN",
        "outer": "OUTER JOIN",
        "left": "LEFT JOIN",
        "right": "LEFT JOIN",
        "left_inner": "LEFT INNER JOIN",
        "right_inner": "LEFT INNER JOIN",
    }

    # Column format templates keyed by query type. Inserts, updates and
    # deletes must not prefix columns with the table name.
    column_strings = {
        "select": '{table}."{column}"{alias}{separator}',
        "select_all": "{table}.*{separator}",
        "insert": '"{column}"{separator}',
        "update": '"{column}"{separator}',
        "delete": '"{column}"{separator}',
    }

    # SQLite has no row-level locking clauses, so both map to empty strings.
    locks = {"share": "", "update": ""}

    def select_format(self):
        """Top-level template for a SELECT statement."""
        return "SELECT {keyword} {columns} FROM {table} {joins} {wheres} {group_by} {having} {order_by} {limit} {offset} {lock}"

    def select_no_table(self):
        """Template for a SELECT with no FROM clause (e.g. raw expressions)."""
        return "SELECT {columns} {lock}"

    def update_format(self):
        """Top-level template for an UPDATE statement."""
        return "UPDATE {table} SET {key_equals} {wheres}"

    def insert_format(self):
        """Template for a single-row INSERT statement."""
        return "INSERT INTO {table} ({columns}) VALUES ({values})"

    def bulk_insert_format(self):
        """Template for a multi-row INSERT statement."""
        return "INSERT INTO {table} ({columns}) VALUES {values}"

    def delete_format(self):
        """Top-level template for a DELETE statement."""
        return "DELETE FROM {table} {wheres}"

    def aggregate_string_with_alias(self):
        """Template for an aggregate expression with an AS alias."""
        return "{aggregate_function}({column}) AS {alias}"

    def aggregate_string_without_alias(self):
        """Template for a bare aggregate expression."""
        return "{aggregate_function}({column})"

    def subquery_string(self):
        """Template wrapping a subquery in parentheses."""
        return "({query})"

    def default_string(self):
        """Template for a DEFAULT value clause in a column definition."""
        return " DEFAULT {default} "

    def raw_query_string(self):
        """Template for splicing a raw query fragment after a keyword."""
        return "{keyword} {query}"

    def where_group_string(self):
        """Template for a parenthesized/grouped where clause."""
        return "{keyword} {value}"

    def between_string(self):
        """Template for a BETWEEN clause."""
        return "{keyword} {column} BETWEEN {low} AND {high}"

    def not_between_string(self):
        """Template for a NOT BETWEEN clause."""
        return "{keyword} {column} NOT BETWEEN {low} AND {high}"

    def where_exists_string(self):
        """Template for a WHERE EXISTS clause."""
        return "{keyword} EXISTS {value}"

    def where_not_exists_string(self):
        """Template for a WHERE NOT EXISTS clause."""
        return "{keyword} NOT EXISTS {value}"

    def where_like_string(self):
        """Template for a LIKE clause."""
        return "{keyword} {column} LIKE {value}"

    def where_not_like_string(self):
        """Template for a NOT LIKE clause."""
        return "{keyword} {column} NOT LIKE {value}"

    def subquery_alias_string(self):
        """Template for aliasing a subquery."""
        return "AS {alias}"

    def key_value_string(self):
        """Template for a quoted key = 'value' pair (updates)."""
        return "{column} = '{value}'{separator}"

    def column_value_string(self):
        """Template for an unquoted column = value pair."""
        return "{column} = {value}{separator}"

    def increment_string(self):
        """Template for incrementing a column by a value."""
        return "{column} = {column} + '{value}'{separator}"

    def decrement_string(self):
        """Template for decrementing a column by a value."""
        return "{column} = {column} - '{value}'{separator}"

    def column_exists_string(self):
        """Query template to check whether a column exists on a table."""
        return "SELECT column_name FROM information_schema.columns WHERE table_name='{clean_table}' and column_name={value}"

    def table_exists_string(self):
        """Query template to check whether a table exists (sqlite_master)."""
        return (
            "SELECT name FROM sqlite_master WHERE type='table' AND name='{clean_table}'"
        )

    def to_sql(self):
        """Cleans up the SQL string and returns the SQL

        Joins any queued schema queries with "; " and collapses runs of
        spaces. When only queued queries exist (no columns/creates), the
        main ``_sql`` buffer is skipped entirely.

        Returns:
            string
        """
        if self.queries and (not self._columns and not self._creates):
            sql = ""
            for query in self.queries:
                query += "; "
                sql += re.sub(" +", " ", query)
            return sql.rstrip(" ")
        else:
            # ",)" can be produced by trailing separators; normalize it.
            sql = re.sub(" +", " ", self._sql.strip().replace(",)", ")"))
            for query in self.queries:
                sql += "; "
                sql += re.sub(" +", " ", query.strip())
            return sql

    def table_string(self):
        """Template for a double-quoted table identifier."""
        return '"{table}"'

    def order_by_format(self):
        """Template for a single ORDER BY column with direction."""
        return "{column} {direction}"

    def order_by_string(self):
        """Template for the full ORDER BY clause."""
        return "ORDER BY {order_columns}"

    def column_string(self):
        """Template for a double-quoted column identifier."""
        return '"{column}"{separator}'

    def value_string(self):
        """Template for a quoted literal value."""
        return "'{value}'{separator}"

    def join_string(self):
        """Template for a JOIN clause."""
        return "{keyword} {foreign_table}{alias} {on}"

    def limit_string(self, offset=False):
        """Template for the LIMIT clause.

        When an offset is present, the limit is emitted by offset_string()
        instead (SQLite requires LIMIT before OFFSET), so this returns "".
        """
        if offset:
            return ""
        return "LIMIT {limit}"

    def offset_string(self):
        """Template for the offset clause; carries the LIMIT as well."""
        return "LIMIT {limit} OFFSET {offset}"

    def first_where_string(self):
        """Keyword that introduces the first where condition."""
        return "WHERE"

    def additional_where_string(self):
        """Keyword joining subsequent AND conditions."""
        return "AND"

    def or_where_string(self):
        """Keyword joining OR conditions."""
        return "OR"

    def where_in_string(self):
        """Template for a WHERE IN clause."""
        return "WHERE IN ({values})"

    def where_string(self):
        """Template for a generic where condition."""
        return " {keyword} {column} {equality} {value}"

    def having_string(self):
        """Template for a HAVING clause without a comparison."""
        return "HAVING {column}"

    def having_equality_string(self):
        """Template for a HAVING clause with a comparison."""
        return "HAVING {column} {equality} {value}"

    def where_null_string(self):
        """Template for an IS NULL condition."""
        return " {keyword} {column} IS NULL"

    def where_date_string(self):
        """Template for comparing the DATE() of a column."""
        return "{keyword} DATE({column}) {equality} {value}"

    def value_equal_string(self):
        """Template for comparing two values for equality."""
        return "{keyword} {value1} = {value2}"

    def where_not_null_string(self):
        """Template for an IS NOT NULL condition."""
        return " {keyword} {column} IS NOT NULL"

    def enable_foreign_key_constraints(self):
        """Statement that re-enables foreign key enforcement."""
        return "PRAGMA foreign_keys = ON"

    def disable_foreign_key_constraints(self):
        """Statement that disables foreign key enforcement."""
        return "PRAGMA foreign_keys = OFF"

    def get_true_column_string(self):
        """Template for a truthy boolean comparison ('1')."""
        return "{keyword} {column} = '1'"

    def get_false_column_string(self):
        """Template for a falsy boolean comparison ('0')."""
        return "{keyword} {column} = '0'"

    def truncate_table(self, table, foreign_keys=False):
        """Build a truncate expression for *table*.

        SQLite does not have a TRUNCATE TABLE command, so a DELETE of all
        rows is used instead. With ``foreign_keys=True`` a list of three
        statements is returned that disables and re-enables foreign key
        enforcement around the delete.
        """
        if not foreign_keys:
            return f"DELETE FROM {self.wrap_table(table)}"
        return [
            self.disable_foreign_key_constraints(),
            f"DELETE FROM {self.wrap_table(table)}",
            self.enable_foreign_key_constraints(),
        ]

    def compile_random(self):
        """Return the SQLite expression used for random ordering."""
        return "random()"

    def process_offset(self):
        """Compiles the offset expression.

        SQLite cannot express OFFSET without a LIMIT, so when no limit was
        set we default to -1 (SQLite's "no limit" sentinel) before
        delegating to the base implementation.

        Returns:
            self
        """
        if not self._limit:
            self._limit = -1  # was int(-1); the int() call was a no-op
        return super().process_offset()
================================================
FILE: src/masoniteorm/query/grammars/__init__.py
================================================
from .MSSQLGrammar import MSSQLGrammar
from .MySQLGrammar import MySQLGrammar
from .PostgresGrammar import PostgresGrammar
from .SQLiteGrammar import SQLiteGrammar
================================================
FILE: src/masoniteorm/query/processors/MSSQLPostProcessor.py
================================================
class MSSQLPostProcessor:
    """Post processor classes are responsible for modifying the result after a query.

    Post Processors are called after the connection calls the database in the
    Query Builder but before the result is returned in that builder method.
    We can use this opportunity to get things like the inserted ID.

    For the MSSQL Post Processor the insert result does not contain the
    generated ID, so we fetch it explicitly with a ``SELECT @@Identity``
    query on a new connection.
    """

    def process_insert_get_id(self, builder, results, id_key):
        """Process the results from the query to the database.

        Args:
            builder (masoniteorm.builder.QueryBuilder): The query builder class
            results (dict): The result from an insert query or the creates from the query builder.
                This is usually a dictionary.
            id_key (string): The key to set the primary key to. This is usually the primary key of the table.

        Returns:
            dictionary: Should return the modified dictionary.
        """
        last_id = builder.new_connection().query(
            "SELECT @@Identity as [id]", results=1
        )
        # @@Identity may come back as a string; cast numeric values to int
        # so the hydrated model gets a proper integer primary key.
        # (renamed from `id` to avoid shadowing the builtin)
        inserted_id = last_id["id"]
        if str(inserted_id).isdigit():
            inserted_id = int(inserted_id)
        else:
            inserted_id = str(inserted_id)
        results.update({id_key: inserted_id})
        return results

    def get_column_value(self, builder, column, results, id_key, id_value):
        """Gets the specific column value from a table. Typically done after an update to
        refetch the new value of a field.

        Args:
            builder (masoniteorm.builder.QueryBuilder): The query builder class
            column (string): The column to refetch the value for.
            results (dict): The result from an update query from the query builder.
                This is usually a dictionary.
            id_key (string): The key to fetch the primary key for. This is usually the primary key of the table.
            id_value (string): The value of the primary key to fetch

        Returns:
            The refetched column value, or an empty dict when no primary
            key information was supplied.
        """
        new_builder = builder.select(column)
        if id_key and id_value:
            new_builder.where(id_key, id_value)
            return new_builder.first()[column]
        return {}
================================================
FILE: src/masoniteorm/query/processors/MySQLPostProcessor.py
================================================
class MySQLPostProcessor:
    """Post processor classes are responsible for modifying the result after a query.

    Post Processors are called after the connection calls the database in the
    Query Builder but before the result is returned in that builder method.
    We can use this opportunity to get things like the inserted ID.

    For the MySQL Post Processor the inserted ID is read from the
    ``lastrowid`` attribute on the connection's cursor.
    """

    def process_insert_get_id(self, builder, results, id_key):
        """Process the results from the query to the database.

        Args:
            builder (masoniteorm.builder.QueryBuilder): The query builder class
            results (dict): The result from an insert query or the creates from the query builder.
                This is usually a dictionary.
            id_key (string): The key to set the primary key to. This is usually the primary key of the table.

        Returns:
            dictionary: Should return the modified dictionary.
        """
        if id_key not in results:
            results.update({id_key: builder._connection.get_cursor().lastrowid})
        return results

    def get_column_value(self, builder, column, results, id_key, id_value):
        """Gets the specific column value from a table. Typically done after an update to
        refetch the new value of a field.

        Args:
            builder (masoniteorm.builder.QueryBuilder): The query builder class
            column (string): The column to refetch the value for.
            results (dict): The result from an update query from the query builder.
                This is usually a dictionary.
            id_key (string): The key to fetch the primary key for. This is usually the primary key of the table.
            id_value (string): The value of the primary key to fetch

        Returns:
            The refetched column value, or an empty dict when no primary
            key information was supplied.
        """
        new_builder = builder.select(column)
        if id_key and id_value:
            new_builder.where(id_key, id_value)
            return new_builder.first()[column]
        return {}
================================================
FILE: src/masoniteorm/query/processors/PostgresPostProcessor.py
================================================
class PostgresPostProcessor:
    """Post processor classes are responsible for modifying the result after a query.

    Post Processors are called after the connection calls the database in the
    Query Builder but before the result is returned in that builder method.
    We can use this opportunity to get things like the inserted ID.

    For the Postgres Post Processor we have a RETURNING * string in the insert so the result
    will already have the full inserted record in the results. Therefore, we can just return
    the results
    """

    def process_insert_get_id(self, builder, results, id_key):
        """Return the insert results unchanged.

        Postgres inserts use RETURNING *, so the primary key is already
        present in ``results`` and nothing needs to be fetched.

        Args:
            builder (masoniteorm.builder.QueryBuilder): The query builder class
            results (dict): The result from an insert query or the creates from the query builder.
            id_key (string): The primary key name (unused here).

        Returns:
            dictionary: the results, untouched.
        """
        return results

    def get_column_value(self, builder, column, results, id_key, id_value):
        """Fetch a single column value, typically after an update.

        Uses the value already present in ``results`` when available;
        otherwise refetches the row by primary key.

        Args:
            builder (masoniteorm.builder.QueryBuilder): The query builder class
            column (string): The column to refetch the value for.
            results (dict): The result from an update query from the query builder.
            id_key (string): The primary key name.
            id_value (string): The primary key value to fetch by.

        Returns:
            The column value, or an empty dict when no primary key
            information was supplied.
        """
        if column in results:
            return results[column]
        refetch_query = builder.select(column)
        if not (id_key and id_value):
            return {}
        refetch_query.where(id_key, id_value)
        return refetch_query.first()[column]
================================================
FILE: src/masoniteorm/query/processors/SQLitePostProcessor.py
================================================
class SQLitePostProcessor:
    """Post processor classes are responsible for modifying the result after a query.

    Post Processors are called after the connection calls the database in the
    Query Builder but before the result is returned in that builder method.
    We can use this opportunity to get things like the inserted ID.

    For the SQLite Post Processor the inserted ID is read from the
    ``lastrowid`` attribute on the connection's cursor.
    """

    def process_insert_get_id(self, builder, results, id_key="id"):
        """Attach the last inserted row id to the insert results.

        Args:
            builder (masoniteorm.builder.QueryBuilder): The query builder class
            results (dict): The result from an insert query or the creates from the query builder.
            id_key (string): The key to set the primary key to. Defaults to "id".

        Returns:
            dictionary: the results with the primary key filled in.
        """
        if id_key in results:
            return results
        results.update({id_key: builder.get_connection().get_cursor().lastrowid})
        return results

    def get_column_value(self, builder, column, results, id_key, id_value):
        """Fetch a single column value, typically after an update.

        Args:
            builder (masoniteorm.builder.QueryBuilder): The query builder class
            column (string): The column to refetch the value for.
            results (dict): The result from an update query from the query builder.
            id_key (string): The primary key name.
            id_value (string): The primary key value to fetch by.

        Returns:
            The refetched column value, or an empty dict when no primary
            key information was supplied.
        """
        refetch_query = builder.select(column)
        if not (id_key and id_value):
            return {}
        refetch_query.where(id_key, id_value)
        return refetch_query.first()[column]
================================================
FILE: src/masoniteorm/query/processors/__init__.py
================================================
from .MSSQLPostProcessor import MSSQLPostProcessor
from .MySQLPostProcessor import MySQLPostProcessor
from .PostgresPostProcessor import PostgresPostProcessor
from .SQLitePostProcessor import SQLitePostProcessor
================================================
FILE: src/masoniteorm/relationships/BaseRelationship.py
================================================
class BaseRelationship:
    """Base class for relationship decorators (belongs_to, has_many, ...).

    Instances act both as decorators on model methods and as descriptors
    that, when accessed, compile and run the related query.
    """

    def __init__(self, fn, local_key=None, foreign_key=None):
        """Capture either the decorated function or the key arguments.

        When used as ``@belongs_to("local", "foreign")`` the first argument
        is a string key, not the function; the function arrives later via
        __call__.
        """
        if isinstance(fn, str):
            self.fn = None
            self.local_key = fn
            self.foreign_key = local_key
        else:
            self.fn = fn
            self.local_key = local_key
            self.foreign_key = foreign_key

    def __set_name__(self, cls, name):
        """This method is called right after the decorator is registered.

        At this point we finally have access to the model cls

        Arguments:
            name {object} -- The model class.
        """
        pass

    def __call__(self, fn=None, *args, **kwargs):
        """This method is called when the decorator contains arguments.

        When you do something like this:

            @belongs_to('id', 'user_id')

        In this case, the {fn} argument will be the callable.
        """
        if callable(fn):
            self.fn = fn
        return self

    def get_builder(self):
        """Return the related model's query builder (set during __get__)."""
        return self._related_builder

    def __get__(self, instance, owner):
        """
        This method is called when the decorated method is accessed.

        Arguments:
            instance {object|None} -- The instance we called.
                If we didn't call the attribute and only accessed it then this will be None.
            owner {object} -- The current model that the property was accessed on.

        Returns:
            object -- Either returns a builder or a hydrated model.
        """
        attribute = self.fn.__name__
        relationship = self.fn(instance)()
        self.set_keys(instance, attribute)
        self._related_builder = relationship.builder
        if not instance.is_loaded():
            # No record hydrated yet; expose the relationship itself so
            # query-building attributes can still be accessed.
            return self
        if attribute in instance._relationships:
            # Eager-loaded result is cached on the instance.
            return instance._relationships[attribute]
        return self.apply_query(self._related_builder, instance)

    def __getattr__(self, attribute):
        """Proxy unknown attributes to the related model's builder."""
        relationship = self.fn(self)()
        return getattr(relationship.builder, attribute)

    def apply_query(self, foreign, owner):
        """Return a dictionary to hydrate the model with

        Arguments:
            foreign {object} -- The relationship object
            owner {object} -- The current model object.

        Returns:
            dict -- A dictionary of data which will be hydrated.
        """
        klass = self.__class__.__name__
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'apply_query' method"
        )

    def query_where_exists(self, builder, callback, method="where_exists"):
        """Adds a criteria clause to the query filter for existing related records"""
        klass = self.__class__.__name__
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'query_where_exists' method"
        )

    def joins(self, builder, clause=None):
        """Helper method for adding join clauses to a relationship"""
        other_table = self.get_builder().get_table_name()
        local_table = builder.get_table_name()
        return builder.join(
            other_table,
            f"{local_table}.{self.local_key}",
            "=",
            f"{other_table}.{self.foreign_key}",
            clause=clause,
        )

    def get_with_count_query(self, builder, callback):
        """Adds a clause to the query to get the record count of the relationship"""
        klass = self.__class__.__name__
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'get_with_count_query' method"
        )

    def attach(self, current_model, related_record):
        """Link a related model to the current model"""
        klass = self.__class__.__name__
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'attach' method"
        )

    def get_related(self, query, relation, eagers=None, callback=None):
        """Fetch the related records for a relation (implemented by subclasses)."""
        klass = self.__class__.__name__
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'get_related' method"
        )

    def relate(self, related_record):
        """Build the related query scoped to the given record (subclasses)."""
        klass = self.__class__.__name__
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'relate' method"
        )

    def detach(self, current_model, related_record):
        """Unlink a related model from the current model"""
        klass = self.__class__.__name__
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'detach' method"
        )

    def attach_related(self, current_model, related_record):
        """Link a related model to the current model"""
        klass = self.__class__.__name__
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'attach_related' method"
        )

    def detach_related(self, current_model, related_record):
        """Unlink a related model from the current model"""
        klass = self.__class__.__name__
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'detach_related' method"
        )

    def query_has(self, current_query_builder, method="where_exists"):
        """Adds a clause to the query to check if a relation exists"""
        klass = self.__class__.__name__
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'query_has' method"
        )

    def map_related(self, related_result):
        """Group a related result set for hydration (subclasses)."""
        klass = self.__class__.__name__
        # was "'related_result'" — the message named the argument instead
        # of the method.
        raise NotImplementedError(
            f"{klass} relationship does not implement the 'map_related' method"
        )
================================================
FILE: src/masoniteorm/relationships/BelongsTo.py
================================================
from ..collection import Collection
from .BaseRelationship import BaseRelationship
class BelongsTo(BaseRelationship):
    """Belongs To Relationship Class."""

    def __init__(self, fn, local_key=None, foreign_key=None):
        # Mirrors BaseRelationship.__init__ but defaults the owner-side key
        # to "id" when the decorator is used with keys only.
        if isinstance(fn, str):
            self.fn = None
            self.local_key = fn or "id"
            self.foreign_key = local_key
        else:
            self.fn = fn
            self.local_key = local_key or "id"
            self.foreign_key = foreign_key

    def set_keys(self, owner, attribute):
        """Fill in conventional key defaults based on the attribute name.

        e.g. an attribute named ``profile`` defaults to local key
        ``profile_id`` and foreign key ``id``.
        """
        self.local_key = self.local_key or f"{attribute}_id"
        self.foreign_key = self.foreign_key or "id"
        return self

    def apply_query(self, foreign, owner):
        """Apply the query and return a dictionary to be hydrated

        Arguments:
            foreign {object} -- The relationship object
            owner {object} -- The current model object.

        Returns:
            dict -- A dictionary of data which will be hydrated.
        """
        return foreign.where(
            self.foreign_key, owner.__attributes__[self.local_key]
        ).first()

    def query_has(self, current_query_builder, method="where_exists"):
        """Add an existence clause for this relation to the owner's query."""
        related_builder = self.get_builder()
        getattr(current_query_builder, method)(
            related_builder.where_column(
                f"{related_builder.get_table_name()}.{self.foreign_key}",
                f"{current_query_builder.get_table_name()}.{self.local_key}",
            )
        )
        return related_builder

    def query_where_exists(self, builder, callback, method="where_exists"):
        """Add an existence clause, letting *callback* refine the subquery."""
        query = self.get_builder()
        getattr(builder, method)(
            callback(
                query.where_column(
                    f"{query.get_table_name()}.{self.foreign_key}",
                    f"{builder.get_table_name()}.{self.local_key}",
                )
            )
        )
        return query

    def get_related(self, query, relation, eagers=(), callback=None):
        """Gets the relation needed between the relation and the related builder. If the relation is a collection
        then will need to pluck out all the keys from the collection and fetch from the related builder. If
        relation is just a Model then we can just call the model based on the value of the related
        builders primary key.

        Args:
            relation (Model|Collection):

        Returns:
            Model|Collection
        """
        builder = self.get_builder().with_(eagers)
        if callback:
            callback(builder)
        if isinstance(relation, Collection):
            # Batch-fetch: one query for all distinct local keys.
            return builder.where_in(
                f"{builder.get_table_name()}.{self.foreign_key}",
                Collection(relation._get_value(self.local_key)).unique(),
            ).get()
        else:
            return builder.where(
                f"{builder.get_table_name()}.{self.foreign_key}",
                getattr(relation, self.local_key),
            ).first()

    def register_related(self, key, model, collection):
        """Attach the matching related record (or None) onto *model*.

        ``collection`` is expected to be the result of map_related, i.e.
        grouped by the foreign key.
        """
        related = collection.get(getattr(model, self.local_key), None)
        model.add_relation({key: related[0] if related else None})

    def map_related(self, related_result):
        """Group the related result set by foreign key for fast lookup."""
        return related_result.group_by(self.foreign_key)

    def attach(self, current_model, related_record):
        """Point the current model's local key at the related record."""
        foreign_key_value = getattr(related_record, self.foreign_key)
        if not current_model.is_created():
            # Not persisted yet: create with the relation key included.
            current_model.fill({self.local_key: foreign_key_value})
            return current_model.create(current_model.all_attributes(), cast=True)
        return current_model.update({self.local_key: foreign_key_value})

    def detach(self, current_model, related_record):
        """Clear the current model's local key, unlinking the relation."""
        return current_model.update({self.local_key: None})

    def relate(self, related_record):
        """Return a builder scoped to the related record's key, carrying
        the relation key as a default for subsequent creates."""
        return (
            self.get_builder()
            .where(self.foreign_key, related_record.__attributes__[self.local_key])
            ._set_creates_related(
                {self.foreign_key: related_record.__attributes__[self.local_key]}
            )
        )
================================================
FILE: src/masoniteorm/relationships/BelongsToMany.py
================================================
import pendulum
from inflection import singularize
from ..collection import Collection
from ..models.Pivot import Pivot
from .BaseRelationship import BaseRelationship
class BelongsToMany(BaseRelationship):
    """Many-to-many relationship resolved through an intermediate pivot table.

    The pivot table name and key columns are either given explicitly or
    inferred by convention (alphabetically sorted singularized table names
    joined with "_", keys "<singular>_id").

    The column aliases ``m_reserved2`` .. ``m_reserved5`` are internal: they
    carry pivot-table values through the main query and are stripped off the
    hydrated models afterwards, ending up on a ``Pivot`` model stored under
    the attribute configured by ``self._as`` (default ``"pivot"``).
    """

    def __init__(
        self,
        fn=None,
        local_foreign_key=None,
        other_foreign_key=None,
        local_owner_key=None,
        other_owner_key=None,
        table=None,
        with_timestamps=False,
        pivot_id="id",
        attribute="pivot",
        with_fields=None,
    ):
        # When the decorator is called with plain strings instead of
        # wrapping a method, `fn` holds the first key name.
        if isinstance(fn, str):
            self.fn = None
            self.local_key = fn
            self.foreign_key = local_foreign_key
            self.local_owner_key = other_foreign_key or "id"
            self.other_owner_key = local_owner_key or "id"
        else:
            self.fn = fn
            self.local_key = local_foreign_key
            self.foreign_key = other_foreign_key
            self.local_owner_key = local_owner_key or "id"
            self.other_owner_key = other_owner_key or "id"
        self._table = table
        self.with_timestamps = with_timestamps
        self._as = attribute
        self.pivot_id = pivot_id
        # BUGFIX: the signature previously used the mutable default
        # `with_fields=[]`, which is shared across every instance that does
        # not pass the argument. Default to None and build a fresh list.
        self.with_fields = with_fields if with_fields is not None else []

    def set_keys(self, owner, attribute):
        """Fall back to conventional key names when none were given."""
        self.local_key = self.local_key or "id"
        self.foreign_key = self.foreign_key or f"{attribute}_id"
        return self

    def _hydrate_pivot_keys(self, get_left_table, get_right_table):
        """Infer the pivot table name and key columns when not explicitly set.

        Arguments:
            get_left_table {callable} -- zero-arg callable returning the owner
                side's table name. Called lazily, only when inference is needed.
            get_right_table {callable} -- zero-arg callable returning the
                related side's table name.
        """
        if not self._table:
            pivot_tables = [
                singularize(get_left_table()),
                singularize(get_right_table()),
            ]
            pivot_tables.sort()
            pivot_table_1, pivot_table_2 = pivot_tables
            # Convention: sorted singular names joined with "_".
            self._table = "_".join(pivot_tables)
            self.foreign_key = self.foreign_key or f"{pivot_table_1}_id"
            self.local_key = self.local_key or f"{pivot_table_2}_id"
        elif self.local_key is None or self.foreign_key is None:
            # An explicit pivot table was given; derive key names from it.
            pivot_table_1, pivot_table_2 = self._table.split("_", 1)
            self.foreign_key = self.foreign_key or f"{pivot_table_1}_id"
            self.local_key = self.local_key or f"{pivot_table_2}_id"

    def apply_query(self, query, owner):
        """Apply the query and return a dictionary to be hydrated.

        Used during accessing a relationship on a model.

        Arguments:
            query {object} -- The relationship's query builder.
            owner {object} -- The current model object.

        Returns:
            Collection -- related models, each carrying a hydrated Pivot
            model under the attribute named by ``self._as``.
        """
        self._hydrate_pivot_keys(
            lambda: owner.builder.get_table_name(),
            lambda: query.get_table_name(),
        )
        table1 = owner.get_table_name()
        table2 = query.get_table_name()
        result = query.select(
            f"{query.get_table_name()}.*",
            f"{self._table}.{self.local_key} as {self._table}_id",
            f"{self._table}.{self.foreign_key} as m_reserved2",
        ).table(f"{table1}")
        if self.pivot_id:
            result.select(f"{self._table}.{self.pivot_id} as m_reserved3")
        if self.with_timestamps:
            result.select(
                f"{self._table}.updated_at as m_reserved4",
                f"{self._table}.created_at as m_reserved5",
            )
        # Owner -> pivot, then pivot -> related table.
        result.join(
            f"{self._table}",
            f"{self._table}.{self.local_key}",
            "=",
            f"{table1}.{self.local_owner_key}",
        )
        result.join(
            f"{table2}",
            f"{self._table}.{self.foreign_key}",
            "=",
            f"{table2}.{self.other_owner_key}",
        )
        if hasattr(owner, self.local_owner_key):
            result.where(
                f"{table1}.{self.local_owner_key}", getattr(owner, self.local_owner_key)
            )
        if self.with_fields:
            for field in self.with_fields:
                result.select(f"{self._table}.{field}")
        result = result.get()
        for model in result:
            pivot_data = {
                self.local_key: getattr(model, f"{self._table}_id"),
                self.foreign_key: getattr(model, "m_reserved2"),
            }
            if self.with_timestamps:
                # BUGFIX: merge the timestamps into pivot_data instead of
                # replacing it. The old code rebound pivot_data here, dropping
                # the key columns — unlike the identical loop in get_related.
                pivot_data.update(
                    {
                        "created_at": getattr(model, "m_reserved5"),
                        "updated_at": getattr(model, "m_reserved4"),
                    }
                )
                model.delete_attribute("m_reserved4")
                model.delete_attribute("m_reserved5")
            model.delete_attribute("m_reserved2")
            if self.pivot_id:
                pivot_data.update({self.pivot_id: getattr(model, "m_reserved3")})
                model.delete_attribute("m_reserved3")
            if self.with_fields:
                for field in self.with_fields:
                    pivot_data.update({field: getattr(model, field)})
                    model.delete_attribute(field)
            model.__original_attributes__.update(
                {
                    self._as: (
                        Pivot.on(query.connection)
                        .table(self._table)
                        .hydrate(pivot_data)
                        .activate_timestamps(self.with_timestamps)
                    )
                }
            )
        return result

    def table(self, table):
        """Explicitly set the pivot table name. Fluent."""
        self._table = table
        return self

    def make_builder(self, eagers=None):
        """Return the related builder with the given eager loads registered."""
        builder = self.get_builder().with_(eagers)
        return builder

    def make_query(self, query, relation, eagers=None, callback=None):
        """Build and run the pivot-joined query used during eager loading.

        Args:
            query {object} -- the source model's query builder
            relation {object} -- a Model or Collection of models to relate to
            eagers (list, optional): List of eager loaded relationships. Defaults to None.
            callback (callable, optional): hook to further constrain the query.

        Returns:
            Collection -- the raw related rows (pivot aliases still attached).
        """
        eagers = eagers or []
        builder = self.get_builder().with_(eagers)
        self._hydrate_pivot_keys(
            lambda: builder.get_table_name(),
            lambda: query.get_table_name(),
        )
        table2 = builder.get_table_name()
        table1 = query.get_table_name()
        result = (
            builder.select(
                f"{table2}.*",
                f"{self._table}.{self.local_key} as {self._table}_id",
                f"{self._table}.{self.foreign_key} as m_reserved2",
            )
            .run_scopes()
            .table(f"{table1}")
        )
        if self.with_fields:
            for field in self.with_fields:
                result.select(f"{self._table}.{field}")
        result.join(
            f"{self._table}",
            f"{self._table}.{self.local_key}",
            "=",
            f"{table1}.{self.local_owner_key}",
        )
        result.join(
            f"{table2}",
            f"{self._table}.{self.foreign_key}",
            "=",
            f"{table2}.{self.other_owner_key}",
        )
        if self.with_timestamps:
            result.select(
                f"{self._table}.updated_at as m_reserved4",
                f"{self._table}.created_at as m_reserved5",
            )
        if self.pivot_id:
            result.select(f"{self._table}.{self.pivot_id} as m_reserved3")
        result.without_global_scopes()
        if callback:
            callback(result)
        if isinstance(relation, Collection):
            # Eager loading for many owners: fetch all rows in one query.
            return result.where_in(
                self.local_owner_key,
                Collection(relation._get_value(self.local_owner_key)).unique(),
            ).get()
        else:
            return result.where(
                self.local_owner_key, getattr(relation, self.local_owner_key)
            ).get()

    def get_related(self, query, relation, eagers=None, callback=None):
        """Fetch related rows and attach a Pivot model to each result.

        Used during eager loading; delegates the SQL to ``make_query`` and
        then peels the internal ``m_reserved*`` aliases off each model.
        """
        final_result = self.make_query(
            query, relation, eagers=eagers, callback=callback
        )
        builder = self.make_builder(eagers)
        for model in final_result:
            pivot_data = {
                self.local_key: getattr(model, f"{self._table}_id"),
                self.foreign_key: getattr(model, "m_reserved2"),
            }
            model.delete_attribute("m_reserved2")
            if self.with_timestamps:
                # NOTE(review): m_reserved4/m_reserved5 are not deleted from
                # the model here, unlike in apply_query — confirm if intended.
                pivot_data.update(
                    {
                        "updated_at": getattr(model, "m_reserved4"),
                        "created_at": getattr(model, "m_reserved5"),
                    }
                )
            if self.pivot_id:
                pivot_data.update({self.pivot_id: getattr(model, "m_reserved3")})
                model.delete_attribute("m_reserved3")
            if self.with_fields:
                for field in self.with_fields:
                    pivot_data.update({field: getattr(model, field)})
                    model.delete_attribute(field)
            model.__original_attributes__.update(
                {
                    self._as: (
                        Pivot.on(builder.connection)
                        .table(self._table)
                        .hydrate(pivot_data)
                        .activate_timestamps(self.with_timestamps)
                    )
                }
            )
        return final_result

    def relate(self, related_record):
        """Return a pivot-joined builder scoped to the given record.

        Returns the (unexecuted) query so callers can refine or run it.
        """
        owner = related_record.get_builder()
        query = self.get_builder()
        # NOTE(review): `owner` is a query builder here, so `owner.builder`
        # below looks suspect (the parallel code in apply_query receives a
        # model) — preserved as-is; confirm against QueryBuilder's API.
        self._hydrate_pivot_keys(
            lambda: owner.builder.get_table_name(),
            lambda: query.get_table_name(),
        )
        table1 = owner.get_table_name()
        table2 = query.get_table_name()
        result = query.select(
            f"{query.get_table_name()}.*",
            f"{self._table}.{self.local_key} as {self._table}_id",
            f"{self._table}.{self.foreign_key} as m_reserved2",
        ).table(f"{table1}")
        if self.pivot_id:
            result.select(f"{self._table}.{self.pivot_id} as m_reserved3")
        if self.with_timestamps:
            result.select(
                f"{self._table}.updated_at as m_reserved4",
                f"{self._table}.created_at as m_reserved5",
            )
        result.join(
            f"{self._table}",
            f"{self._table}.{self.local_key}",
            "=",
            f"{table1}.{self.local_owner_key}",
        )
        result.join(
            f"{table2}",
            f"{self._table}.{self.foreign_key}",
            "=",
            f"{table2}.{self.other_owner_key}",
        )
        if hasattr(owner, self.local_owner_key):
            result.where(
                f"{table1}.{self.local_owner_key}", getattr(owner, self.local_owner_key)
            )
        if self.with_fields:
            for field in self.with_fields:
                result.select(f"{self._table}.{field}")
        return result

    def register_related(self, key, model, collection):
        """Attach this model's slice of the eager-loaded collection.

        Matches on the aliased pivot column ``<pivot_table>_id`` carried by
        each related row.
        """
        model.add_relation(
            {
                key: collection.where(
                    f"{self._table}_id", getattr(model, self.local_owner_key)
                )
            }
        )

    def joins(self, builder, clause=None):
        """Add the pivot-table joins to an arbitrary builder (joins/has API).

        Arguments:
            builder {object} -- the builder to add joins to.
            clause {str|None} -- join clause for the second join (e.g. "left").

        Returns:
            object -- the same builder, with selects and joins applied.
        """
        self._hydrate_pivot_keys(
            lambda: self.get_builder().get_table_name(),
            lambda: builder.get_table_name(),
        )
        query = self.get_builder()
        table1 = query.get_table_name()
        table2 = builder.get_table_name()
        result = builder
        if not builder._columns:
            result = result.select(
                f"{table2}.*",
                f"{self._table}.{self.local_key} as {self._table}_id",
                f"{self._table}.{self.foreign_key} as m_reserved2",
            )
        if self.pivot_id:
            result.select(f"{self._table}.{self.pivot_id} as m_reserved3")
        if self.with_timestamps:
            result.select(
                f"{self._table}.updated_at as m_reserved4",
                f"{self._table}.created_at as m_reserved5",
            )
        # BUGFIX: the with_fields selects were previously added twice (once
        # here and once after the joins), duplicating every pivot column in
        # the select list. Select each field once.
        if self.with_fields:
            for field in self.with_fields:
                result.select(f"{self._table}.{field}")
        # Join pivot table with an inner join
        result.join(
            f"{self._table}",
            f"{self._table}.{self.local_key}",
            "=",
            f"{table2}.{self.local_owner_key}",
            clause="inner",
        )
        # NOTE(review): this join compares the pivot table's local_owner_key
        # column — the other query paths use local_key/foreign_key here;
        # preserved as-is, confirm intended.
        result.join(
            f"{table1}",
            f"{self._table}.{self.local_owner_key}",
            "=",
            f"{table1}.{self.other_owner_key}",
            clause=clause,
        )
        return result

    def query_where_exists(self, builder, callback, method="where_exists"):
        """Apply a correlated EXISTS subquery constrained by `callback`."""
        query = self.get_builder()
        pivot_table = self._table or self.get_pivot_table_name(query, builder)
        table = self.get_builder().get_table_name()
        getattr(builder, method)(
            query.new()
            .table(table)
            .join(
                f"{pivot_table}",
                f"{table}.{self.other_owner_key}",
                "=",
                f"{pivot_table}.{self.foreign_key}",
            )
            .where_column(
                f"{pivot_table}.{self.local_key}",
                f"{builder.get_table_name()}.{self.local_owner_key}",
            )
            .where_in(
                self.other_owner_key, callback(query.select(self.other_owner_key))
            )
        )

    def query_has(self, builder, method="where_exists"):
        """Apply an unconstrained correlated EXISTS subquery (has())."""
        query = self.get_builder()
        pivot_table = self._table or self.get_pivot_table_name(query, builder)
        table = self.get_builder().get_table_name()
        return getattr(builder, method)(
            query.new()
            .table(table)
            .join(
                f"{pivot_table}",
                f"{table}.{self.other_owner_key}",
                "=",
                f"{pivot_table}.{self.foreign_key}",
            )
            .where_column(
                f"{pivot_table}.{self.local_key}",
                f"{builder.get_table_name()}.{self.local_owner_key}",
            )
        )

    def get_pivot_table_name(self, query, builder):
        """Derive the conventional pivot table name from the two tables."""
        pivot_tables = [
            singularize(query.get_table_name()),
            singularize(builder.get_table_name()),
        ]
        pivot_tables.sort()
        return "_".join(pivot_tables)

    def get_with_count_query(self, builder, callback):
        """Add a correlated count subselect (`<related_table>_count`)."""
        query = self.get_builder()
        self._table = self._table or self.get_pivot_table_name(query, builder)
        if not builder._columns:
            builder = builder.select("*")
        return_query = builder.add_select(
            f"{query.get_table_name()}_count",
            lambda q: (
                (
                    q.count("*")
                    .where_column(
                        f"{builder.get_table_name()}.{self.local_owner_key}",
                        f"{self._table}.{self.local_key}",
                    )
                    .table(self._table)
                    .when(
                        callback,
                        lambda q: (
                            q.where_in(
                                self.foreign_key,
                                callback(query.select(self.other_owner_key)),
                            )
                        ),
                    )
                )
            ),
        )
        return return_query

    def attach(self, current_model, related_record):
        """Insert a pivot row linking the two already-persisted models."""
        data = {
            self.local_key: getattr(current_model, self.local_owner_key),
            self.foreign_key: getattr(related_record, self.other_owner_key),
        }
        self._table = self._table or self.get_pivot_table_name(
            current_model, related_record
        )
        if self.with_timestamps:
            data.update(
                {
                    "created_at": pendulum.now().to_datetime_string(),
                    "updated_at": pendulum.now().to_datetime_string(),
                }
            )
        return (
            Pivot.on(current_model.get_builder().connection)
            .table(self._table)
            .without_global_scopes()
            .create(data)
        )

    def detach(self, current_model, related_record):
        """Delete the pivot row linking the two models."""
        data = {
            self.local_key: getattr(current_model, self.local_owner_key),
            self.foreign_key: getattr(related_record, self.other_owner_key),
        }
        self._table = self._table or self.get_pivot_table_name(
            current_model, related_record
        )
        return (
            Pivot.on(current_model.get_builder().connection)
            .table(self._table)
            .without_global_scopes()
            .where(data)
            .delete()
        )

    def attach_related(self, current_model, related_record):
        """Insert a pivot row; variant used when relating via attach_related."""
        data = {
            self.local_key: getattr(current_model, self.local_owner_key),
            self.foreign_key: getattr(related_record, self.other_owner_key),
        }
        self._table = self._table or self.get_pivot_table_name(
            current_model, related_record
        )
        if self.with_timestamps:
            data.update(
                {
                    "created_at": pendulum.now().to_datetime_string(),
                    "updated_at": pendulum.now().to_datetime_string(),
                }
            )
        # NOTE(review): call order here is Pivot.table(...).on(...), the
        # reverse of attach() — preserved as-is; confirm both are equivalent.
        return (
            Pivot.table(self._table)
            .on(current_model.get_builder().connection)
            .without_global_scopes()
            .create(data)
        )

    def detach_related(self, current_model, related_record):
        """Delete the pivot row; timestamps are included in the match when
        enabled (mirrors attach_related's data payload)."""
        data = {
            self.local_key: getattr(current_model, self.local_owner_key),
            self.foreign_key: getattr(related_record, self.other_owner_key),
        }
        self._table = self._table or self.get_pivot_table_name(
            current_model, related_record
        )
        if self.with_timestamps:
            # NOTE(review): matching on freshly generated timestamps will
            # rarely equal the stored row's values — confirm intended.
            data.update(
                {
                    "created_at": pendulum.now().to_datetime_string(),
                    "updated_at": pendulum.now().to_datetime_string(),
                }
            )
        return (
            Pivot.on(current_model.get_builder().connection)
            .table(self._table)
            .without_global_scopes()
            .where(data)
            .delete()
        )
================================================
FILE: src/masoniteorm/relationships/HasMany.py
================================================
from ..collection import Collection
from .BaseRelationship import BaseRelationship
class HasMany(BaseRelationship):
    """One-to-many relationship: the related table carries the foreign key."""

    def apply_query(self, foreign, owner):
        """Fetch all related rows for an already-loaded owner model.

        Arguments:
            foreign {object} -- the related model's query builder
            owner {object} -- the owning model instance

        Returns:
            Collection -- the matching related records
        """
        owner_value = owner.__attributes__[self.local_key]
        return foreign.where(self.foreign_key, owner_value).get()

    def set_keys(self, owner, attribute):
        """Fill in conventional defaults: local "id", foreign "<attribute>_id"."""
        if not self.local_key:
            self.local_key = "id"
        if not self.foreign_key:
            self.foreign_key = f"{attribute}_id"
        return self

    def register_related(self, key, model, collection):
        """Attach this model's group of related records (empty Collection
        when no group matches the model's local key)."""
        matches = collection.get(getattr(model, self.local_key))
        model.add_relation({key: matches or Collection()})

    def map_related(self, related_result):
        """Group related rows by foreign key for constant-time lookup."""
        return related_result.group_by(self.foreign_key)

    def attach(self, current_model, related_record):
        """Point the related record at the current model via the foreign key."""
        owner_value = getattr(current_model, self.local_key)
        if related_record.is_created():
            return related_record.update({self.foreign_key: owner_value})
        related_record.fill({self.foreign_key: owner_value})
        return related_record.create(related_record.all_attributes(), cast=True)

    def get_related(self, query, relation, eagers=None, callback=None):
        """Fetch related rows for one model or a whole collection (eager load)."""
        builder = self.get_builder().with_(eagers or [])
        if callback:
            callback(builder)
        column = f"{builder.get_table_name()}.{self.foreign_key}"
        if isinstance(relation, Collection):
            keys = Collection(relation._get_value(self.local_key)).unique()
            return builder.where_in(column, keys).get()
        return builder.where(column, getattr(relation, self.local_key)).get()
================================================
FILE: src/masoniteorm/relationships/HasManyThrough.py
================================================
from ..collection import Collection
from .BaseRelationship import BaseRelationship
class HasManyThrough(BaseRelationship):
    """HasManyThrough Relationship Class.

    Links a source model to many distant records through an intermediary
    (linking) table.
    """

    def __init__(
        self,
        fn=None,
        local_foreign_key=None,
        other_foreign_key=None,
        local_owner_key=None,
        other_owner_key=None,
    ):
        # `fn` is either the decorated method or, when the relationship is
        # declared with plain strings, the first key name.
        if isinstance(fn, str):
            self.fn = None
            self.local_key = fn
            self.foreign_key = local_foreign_key
            # NOTE(review): in this branch the owner keys are filled from the
            # *_foreign_key positionals (shifted by one) — confirm intended.
            self.local_owner_key = other_foreign_key or "id"
            self.other_owner_key = local_owner_key or "id"
        else:
            self.fn = fn
            self.local_key = local_foreign_key
            self.foreign_key = other_foreign_key
            self.local_owner_key = local_owner_key or "id"
            self.other_owner_key = other_owner_key or "id"

    def set_keys(self, distant_builder, intermediary_builder, attribute):
        # Fall back to conventional defaults for any key not explicitly set.
        self.local_key = self.local_key or "id"
        self.foreign_key = self.foreign_key or f"{attribute}_id"
        self.local_owner_key = self.local_owner_key or "id"
        self.other_owner_key = self.other_owner_key or "id"
        return self

    def __get__(self, instance, owner):
        """This method is called when the decorated method is accessed.

        Arguments:
            instance {object|None} -- The instance we called.
                If we didn't call the attribute and only accessed it then this will be None.
            owner {object} -- The current model that the property was accessed on.

        Returns:
            object -- Either returns a builder or a hydrated model.
        """
        attribute = self.fn.__name__
        self.attribute = attribute
        # The decorated method returns a pair of relationship factories:
        # index 0 -> the distant model, index 1 -> the intermediary model.
        relationship1 = self.fn(self)[0]()
        relationship2 = self.fn(self)[1]()
        self.distant_builder = relationship1.builder
        self.intermediary_builder = relationship2.builder
        self.set_keys(self.distant_builder, self.intermediary_builder, attribute)
        if not instance.is_loaded():
            # Accessed without a loaded model: return the relationship itself
            # so it can be used for query building.
            return self
        if attribute in instance._relationships:
            # Already eager loaded — return the cached result.
            return instance._relationships[attribute]
        return self.apply_related_query(
            self.distant_builder, self.intermediary_builder, instance
        )

    def apply_related_query(self, distant_builder, intermediary_builder, owner):
        """
        Apply the query to return a Collection of data for the distant models to be hydrated with.

        Method is used when accessing a relationship on a model if its not
        already eager loaded.

        Arguments
            distant_builder (QueryBuilder): QueryBuilder attached to the distant table
            intermediate_builder (QueryBuilder): QueryBuilder attached to the intermediate (linking) table
            owner (Any): the model this relationship is starting from

        Returns
            Collection: Collection of dicts which will be used for hydrating models.
        """
        distant_table = distant_builder.get_table_name()
        intermediate_table = intermediary_builder.get_table_name()
        return (
            self.distant_builder.select(
                f"{distant_table}.*, {intermediate_table}.{self.local_key}"
            )
            .join(
                f"{intermediate_table}",
                f"{intermediate_table}.{self.foreign_key}",
                "=",
                f"{distant_table}.{self.other_owner_key}",
            )
            .where(
                f"{intermediate_table}.{self.local_key}",
                getattr(owner, self.local_owner_key),
            )
            .get()
        )

    def relate(self, related_model):
        # Build (but do not run) the joined query scoped to the given model.
        return self.distant_builder.join(
            f"{self.intermediary_builder.get_table_name()}",
            f"{self.intermediary_builder.get_table_name()}.{self.foreign_key}",
            "=",
            f"{self.distant_builder.get_table_name()}.{self.other_owner_key}",
        ).where(
            f"{self.intermediary_builder.get_table_name()}.{self.local_key}",
            getattr(related_model, self.local_owner_key),
        )

    def get_builder(self):
        # The "builder" for a through-relationship is the distant table's.
        return self.distant_builder

    def make_builder(self, eagers=None):
        # Return the distant builder with eager loads registered.
        builder = self.get_builder().with_(eagers)
        return builder

    def register_related(self, key, model, collection):
        """
        Attach the related model to source models attribute

        Arguments
            key (str): The attribute name
            model (Any): The model instance
            collection (Collection): The data for the related models

        Returns
            None
        """
        related = collection.get(getattr(model, self.local_owner_key), None)
        # Normalize single results so the attribute is always a Collection.
        if related and not isinstance(related, Collection):
            related = Collection(related)
        model.add_relation({key: related if related else None})

    def get_related(self, current_builder, relation, eagers=None, callback=None):
        """
        Get a Collection to hydrate the models for the distant table with

        Used when eager loading the model attribute

        Arguments
            current_builder (QueryBuilder): The source models QueryBuilder object
            relation (HasManyThrough): this relationship object
            eagers (Any):
            callback (Any):

        Returns
            Collection the collection of dicts to hydrate the distant models with
        """
        distant_table = self.distant_builder.get_table_name()
        intermediate_table = self.intermediary_builder.get_table_name()
        if callback:
            callback(current_builder)
        # This expression's value is discarded on purpose: select()/join()
        # mutate self.distant_builder in place before the get() calls below.
        (
            self.distant_builder.select(
                f"{distant_table}.*, {intermediate_table}.{self.local_key}"
            ).join(
                f"{intermediate_table}",
                f"{intermediate_table}.{self.foreign_key}",
                "=",
                f"{distant_table}.{self.other_owner_key}",
            )
        )
        if isinstance(relation, Collection):
            # Eager loading many owners: fetch all distant rows in one query.
            return self.distant_builder.where_in(
                f"{intermediate_table}.{self.local_key}",
                Collection(relation._get_value(self.local_owner_key)).unique(),
            ).get()
        else:
            return self.distant_builder.where(
                f"{intermediate_table}.{self.local_key}",
                getattr(relation, self.local_owner_key),
            ).get()

    def query_has(self, current_builder, method="where_exists"):
        # Correlated EXISTS subquery joining distant and intermediary tables.
        distant_table = self.distant_builder.get_table_name()
        intermediate_table = self.intermediary_builder.get_table_name()
        getattr(current_builder, method)(
            self.distant_builder.join(
                f"{intermediate_table}",
                f"{intermediate_table}.{self.foreign_key}",
                "=",
                f"{distant_table}.{self.other_owner_key}",
            ).where_column(
                f"{intermediate_table}.{self.local_key}",
                f"{current_builder.get_table_name()}.{self.local_owner_key}",
            )
        )
        return self.distant_builder

    def query_where_exists(self, current_builder, callback, method="where_exists"):
        # Like query_has, but lets the caller further constrain the subquery.
        distant_table = self.distant_builder.get_table_name()
        intermediate_table = self.intermediary_builder.get_table_name()
        getattr(current_builder, method)(
            self.distant_builder.join(
                f"{intermediate_table}",
                f"{intermediate_table}.{self.foreign_key}",
                "=",
                f"{distant_table}.{self.other_owner_key}",
            )
            .where_column(
                f"{intermediate_table}.{self.local_key}",
                f"{current_builder.get_table_name()}.{self.local_owner_key}",
            )
            .when(callback, lambda q: (callback(q)))
        )

    def get_with_count_query(self, current_builder, callback):
        # Adds a correlated count subselect named "<attribute>_count".
        distant_table = self.distant_builder.get_table_name()
        intermediate_table = self.intermediary_builder.get_table_name()
        if not current_builder._columns:
            current_builder.select("*")
        return_query = current_builder.add_select(
            f"{self.attribute}_count",
            lambda q: (
                (
                    q.count("*")
                    .join(
                        f"{intermediate_table}",
                        f"{intermediate_table}.{self.foreign_key}",
                        "=",
                        f"{distant_table}.{self.other_owner_key}",
                    )
                    .where_column(
                        f"{intermediate_table}.{self.local_key}",
                        f"{current_builder.get_table_name()}.{self.local_owner_key}",
                    )
                    .table(distant_table)
                    .when(
                        callback,
                        lambda q: (
                            q.where_in(
                                self.foreign_key,
                                callback(
                                    self.distant_builder.select(self.other_owner_key)
                                ),
                            )
                        ),
                    )
                )
            ),
        )
        return return_query

    def map_related(self, related_result):
        # Group by the intermediary's local key for register_related lookups.
        return related_result.group_by(self.local_key)
================================================
FILE: src/masoniteorm/relationships/HasOne.py
================================================
from ..collection import Collection
from .BaseRelationship import BaseRelationship
class HasOne(BaseRelationship):
    """Has One Relationship Class.

    The related table holds the foreign key and at most one row is returned.
    (The previous docstring said "Belongs To" — a copy-paste error.)
    """

    def __init__(self, fn, foreign_key=None, local_key=None):
        if isinstance(fn, str):
            # Declared with explicit column names rather than decorating a
            # method. BUGFIX: explicitly null out `fn` like every sibling
            # relationship class (BelongsToMany, HasOneThrough, ...) does;
            # it was previously left unset, so any later read of `self.fn`
            # raised AttributeError instead of seeing None.
            self.fn = None
            self.foreign_key = fn
            self.local_key = foreign_key or "id"
        else:
            self.fn = fn
            self.local_key = local_key or "id"
            self.foreign_key = foreign_key

    def set_keys(self, owner, attribute):
        """Fill in conventional defaults for any key not explicitly set."""
        self.local_key = self.local_key or "id"
        self.foreign_key = self.foreign_key or f"{attribute}_id"
        return self

    def apply_query(self, foreign, owner):
        """Apply the query and return a dictionary to be hydrated.

        Arguments:
            foreign {object} -- The relationship's query builder.
            owner {object} -- The current model object.

        Returns:
            dict -- A dictionary of data which will be hydrated.
        """
        return foreign.where(
            self.foreign_key, owner.__attributes__[self.local_key]
        ).first()

    def get_related(self, query, relation, eagers=(), callback=None):
        """Gets the relation needed between the relation and the related builder.

        If the relation is a collection then we pluck all the local keys from
        the collection and fetch every matching row at once; if it is a single
        model we fetch just its one related row.

        Args:
            relation (Model|Collection):

        Returns:
            Model|Collection
        """
        builder = self.get_builder().with_(eagers)
        if callback:
            callback(builder)
        if isinstance(relation, Collection):
            return builder.where_in(
                f"{builder.get_table_name()}.{self.foreign_key}",
                Collection(relation._get_value(self.local_key)).unique(),
            ).get()
        else:
            return builder.where(
                f"{builder.get_table_name()}.{self.foreign_key}",
                getattr(relation, self.local_key),
            ).first()

    def query_has(self, current_query_builder, method="where_exists"):
        """Apply a correlated EXISTS subquery for has() style filtering."""
        related_builder = self.get_builder()
        getattr(current_query_builder, method)(
            related_builder.where_column(
                f"{related_builder.get_table_name()}.{self.foreign_key}",
                f"{current_query_builder.get_table_name()}.{self.local_key}",
            )
        )
        return related_builder

    def query_where_exists(self, builder, callback, method="where_exists"):
        """Apply a correlated EXISTS subquery further constrained by callback."""
        query = self.get_builder()
        getattr(builder, method)(
            callback(
                query.where_column(
                    f"{query.get_table_name()}.{self.foreign_key}",
                    f"{builder.get_table_name()}.{self.local_key}",
                )
            )
        )
        return query

    def register_related(self, key, model, collection):
        """Attach the single matching related model (or None) to the model."""
        related = collection.where(
            self.foreign_key, getattr(model, self.local_key)
        ).first()
        model.add_relation({key: related or None})

    def map_related(self, related_result):
        """No grouping needed — a has-one lookup is done per foreign key."""
        return related_result

    def attach(self, current_model, related_record):
        """Point the related record at the current model via the foreign key."""
        local_key_value = getattr(current_model, self.local_key)
        if not related_record.is_created():
            related_record.fill({self.foreign_key: local_key_value})
            return related_record.create(related_record.all_attributes(), cast=True)
        return related_record.update({self.foreign_key: local_key_value})

    def detach(self, current_model, related_record):
        """Break the link by nulling the related record's foreign key."""
        return related_record.update({self.foreign_key: None})
================================================
FILE: src/masoniteorm/relationships/HasOneThrough.py
================================================
from ..collection import Collection
from .BaseRelationship import BaseRelationship
class HasOneThrough(BaseRelationship):
    """HasOneThrough Relationship Class.

    Links a source model to a single distant record through an intermediary
    (linking) table.
    """

    def __init__(
        self,
        fn=None,
        local_foreign_key=None,
        other_foreign_key=None,
        local_owner_key=None,
        other_owner_key=None,
    ):
        # `fn` is either the decorated method or, when the relationship is
        # declared with plain strings, the first key name.
        if isinstance(fn, str):
            self.fn = None
            self.local_key = fn
            self.foreign_key = local_foreign_key
            # NOTE(review): owner keys are filled from the *_foreign_key
            # positionals (shifted by one) in this branch — confirm intended.
            self.local_owner_key = other_foreign_key or "id"
            self.other_owner_key = local_owner_key or "id"
        else:
            self.fn = fn
            self.local_key = local_foreign_key
            self.foreign_key = other_foreign_key
            self.local_owner_key = local_owner_key or "id"
            self.other_owner_key = other_owner_key or "id"

    def __getattr__(self, attribute):
        # Unknown attributes are delegated to the intermediary relationship's
        # query builder (index 1 of the pair returned by the decorated method).
        relationship = self.fn(self)[1]()
        return getattr(relationship.builder, attribute)

    def set_keys(self, distant_builder, intermediary_builder, attribute):
        # Fall back to conventional defaults for any key not explicitly set.
        self.local_key = self.local_key or "id"
        self.foreign_key = self.foreign_key or f"{attribute}_id"
        self.local_owner_key = self.local_owner_key or "id"
        self.other_owner_key = self.other_owner_key or "id"
        return self

    def __get__(self, instance, owner):
        """
        This method is called when the decorated method is accessed.

        Arguments
            instance (object|None): The instance we called.
                If we didn't call the attribute and only accessed it then this will be None.
            owner (object): The current model that the property was accessed on.

        Returns
            QueryBuilder|Model: Either returns a builder or a hydrated model.
        """
        attribute = self.fn.__name__
        self.attribute = attribute
        # The decorated method returns a pair of relationship factories:
        # index 0 -> the distant model, index 1 -> the intermediary model.
        relationship1 = self.fn(self)[0]()
        relationship2 = self.fn(self)[1]()
        self.distant_builder = relationship1.builder
        self.intermediary_builder = relationship2.builder
        self.set_keys(self.distant_builder, self.intermediary_builder, attribute)
        if instance.is_loaded():
            if attribute in instance._relationships:
                # Already eager loaded — return the cached result.
                return instance._relationships[attribute]
            return self.apply_relation_query(
                self.distant_builder, self.intermediary_builder, instance
            )
        else:
            # Accessed without a loaded model: return the relationship itself
            # so it can be used for query building.
            return self

    def apply_relation_query(self, distant_builder, intermediary_builder, owner):
        """
        Apply the query and return a dict of data for the distant model to be hydrated with.

        Method is used when accessing a relationship on a model if its not
        already eager loaded.

        Arguments
            distant_builder (QueryBuilder): QueryBuilder attached to the distant table
            intermediate_builder (QueryBuilder): QueryBuilder attached to the intermediate (linking) table
            owner (Any): the model this relationship is starting from

        Returns
            dict: A dictionary of data which will be hydrated.
        """
        dist_table = distant_builder.get_table_name()
        int_table = intermediary_builder.get_table_name()
        return (
            distant_builder.select(
                f"{dist_table}.*, {int_table}.{self.local_owner_key} as {self.local_key}"
            )
            .join(
                f"{int_table}",
                f"{int_table}.{self.foreign_key}",
                "=",
                f"{dist_table}.{self.other_owner_key}",
            )
            .where(
                f"{int_table}.{self.local_owner_key}",
                getattr(owner, self.local_key),
            )
            .first()
        )

    def relate(self, related_model):
        # Build (but do not run) the joined query scoped to the given model.
        dist_table = self.distant_builder.get_table_name()
        int_table = self.intermediary_builder.get_table_name()
        # NOTE(review): where_column is given a *value* as its second
        # argument here (other call sites pass a column name) — confirm.
        return self.distant_builder.join(
            f"{int_table}",
            f"{int_table}.{self.foreign_key}",
            "=",
            f"{dist_table}.{self.other_owner_key}",
        ).where_column(
            f"{int_table}.{self.local_owner_key}",
            getattr(related_model, self.local_key),
        )

    def get_builder(self):
        # The "builder" for a through-relationship is the distant table's.
        return self.distant_builder

    def make_builder(self, eagers=None):
        # Return the distant builder with eager loads registered.
        builder = self.get_builder().with_(eagers)
        return builder

    def register_related(self, key, model, collection):
        """
        Attach the related model to source models attribute

        Arguments
            key (str): The attribute name
            model (Any): The model instance
            collection (Collection): The data for the related models

        Returns
            None
        """
        related = collection.get(getattr(model, self.local_key), None)
        # Has-one: only the first grouped match is attached.
        model.add_relation({key: related[0] if related else None})

    def get_related(self, current_builder, relation, eagers=None, callback=None):
        """
        Get the data to hydrate the model for the distant table with

        Used when eager loading the model attribute

        Arguments
            query (QueryBuilder): The source models QueryBuilder object
            relation (HasOneThrough): this relationship object
            eagers (Any):
            callback (Any):

        Returns
            dict: the dict to hydrate the distant model with
        """
        dist_table = self.distant_builder.get_table_name()
        int_table = self.intermediary_builder.get_table_name()
        if callback:
            callback(current_builder)
        # This expression's value is discarded on purpose: select()/join()
        # mutate self.distant_builder in place before the calls below.
        (
            self.distant_builder.select(
                f"{dist_table}.*, {int_table}.{self.local_owner_key} as {self.local_key}"
            ).join(
                f"{int_table}",
                f"{int_table}.{self.foreign_key}",
                "=",
                f"{dist_table}.{self.other_owner_key}",
            )
        )
        if isinstance(relation, Collection):
            # Eager loading many owners: fetch all distant rows in one query.
            return self.distant_builder.where_in(
                f"{int_table}.{self.local_owner_key}",
                Collection(relation._get_value(self.local_key)).unique(),
            ).get()
        else:
            return self.distant_builder.where(
                f"{int_table}.{self.local_owner_key}",
                getattr(relation, self.local_key),
            ).first()

    def query_has(self, current_builder, method="where_exists"):
        # Correlated EXISTS subquery joining distant and intermediary tables.
        dist_table = self.distant_builder.get_table_name()
        int_table = self.intermediary_builder.get_table_name()
        getattr(current_builder, method)(
            self.distant_builder.join(
                f"{int_table}",
                f"{int_table}.{self.foreign_key}",
                "=",
                f"{dist_table}.{self.other_owner_key}",
            ).where_column(
                f"{int_table}.{self.local_owner_key}",
                f"{current_builder.get_table_name()}.{self.local_key}",
            )
        )
        return self.distant_builder

    def query_where_exists(self, current_builder, callback, method="where_exists"):
        # Like query_has, but lets the caller further constrain the subquery.
        dist_table = self.distant_builder.get_table_name()
        int_table = self.intermediary_builder.get_table_name()
        getattr(current_builder, method)(
            self.distant_builder.join(
                f"{int_table}",
                f"{int_table}.{self.foreign_key}",
                "=",
                f"{dist_table}.{self.other_owner_key}",
            )
            .where_column(
                f"{int_table}.{self.local_owner_key}",
                f"{current_builder.get_table_name()}.{self.local_key}",
            )
            .when(callback, lambda q: (callback(q)))
        )

    def get_with_count_query(self, current_builder, callback):
        # Adds a correlated count subselect named "<attribute>_count".
        dist_table = self.distant_builder.get_table_name()
        int_table = self.intermediary_builder.get_table_name()
        if not current_builder._columns:
            current_builder.select("*")
        return_query = current_builder.add_select(
            f"{self.attribute}_count",
            lambda q: (
                (
                    q.count("*")
                    .join(
                        f"{int_table}",
                        f"{int_table}.{self.foreign_key}",
                        "=",
                        f"{dist_table}.{self.other_owner_key}",
                    )
                    .where_column(
                        f"{int_table}.{self.local_owner_key}",
                        f"{current_builder.get_table_name()}.{self.local_key}",
                    )
                    .table(dist_table)
                    .when(
                        callback,
                        lambda q: (
                            q.where_in(
                                self.foreign_key,
                                callback(
                                    self.distant_builder.select(self.other_owner_key)
                                ),
                            )
                        ),
                    )
                )
            ),
        )
        return return_query

    def map_related(self, related_result):
        # Group by the aliased local key for register_related lookups.
        return related_result.group_by(self.local_key)
================================================
FILE: src/masoniteorm/relationships/MorphMany.py
================================================
from ..collection import Collection
from ..config import load_config
from .BaseRelationship import BaseRelationship
class MorphMany(BaseRelationship):
    """Descriptor implementing a polymorphic one-to-many relationship.

    The related table stores the owning model's morph-map key under
    `morph_key` (default "record_type") and the owner's primary key under
    `morph_id` (default "record_id"); accessing the attribute returns every
    matching related record.
    """

    def __init__(self, fn, morph_key="record_type", morph_id="record_id"):
        # Supports both bare usage (@morph_many) and configured usage
        # (@morph_many("type_col", "id_col")). In the configured form the
        # decorator factory is called first, so `fn` arrives as the
        # morph_key string and `morph_key` carries the morph_id value;
        # `fn` is presumably attached later by the base class — confirm.
        if isinstance(fn, str):
            self.fn = None
            self.morph_key = fn
            self.morph_id = morph_key
        else:
            self.fn = fn
            self.morph_id = morph_id
            self.morph_key = morph_key

    def get_builder(self):
        # Builder of the model that owns this relationship attribute.
        return self._related_builder

    def set_keys(self, owner, attribute):
        # Backfill the defaults in case either key was explicitly set to None.
        self.morph_id = self.morph_id or "record_id"
        self.morph_key = self.morph_key or "record_type"
        return self

    def __get__(self, instance, owner):
        """This method is called when the decorated method is accessed.

        Arguments:
            instance {object|None} -- The instance we called.
                If we didn't call the attribute and only accessed it then this will be None.
            owner {object} -- The current model that the property was accessed on.

        Returns:
            object -- Either returns a builder or a hydrated model.
        """
        attribute = self.fn.__name__
        self._related_builder = instance.builder
        # self.fn(self)() presumably resolves and instantiates the related
        # model's builder — confirm against BaseRelationship.
        self.polymorphic_builder = self.fn(self)()
        self.set_keys(owner, self.fn)
        if not instance.is_loaded():
            # Unhydrated model: expose the descriptor itself so query
            # helpers can keep working with it.
            return self
        if attribute in instance._relationships:
            # An eager-loaded result takes priority over a fresh query.
            return instance._relationships[attribute]
        return self.apply_query(self._related_builder, instance)

    def __getattr__(self, attribute):
        # Fallback for unknown attributes: proxy onto the related model's
        # query builder.
        relationship = self.fn(self)()
        return getattr(relationship.builder, attribute)

    def apply_query(self, builder, instance):
        """Apply the query and return a dictionary to be hydrated

        Arguments:
            builder {object} -- The relationship object
            instance {object} -- The current model object.

        Returns:
            dict -- A dictionary of data which will be hydrated.
        """
        polymorphic_key = self.get_record_key_lookup(builder._model)
        polymorphic_builder = self.polymorphic_builder
        # Fetch every record whose morph type and owner id match.
        return (
            polymorphic_builder.where(self.morph_key, polymorphic_key)
            .where(self.morph_id, instance.get_primary_key_value())
            .get()
        )

    def get_related(self, query, relation, eagers=None, callback=None):
        """Gets the relation needed between the relation and the related builder. If the relation is a collection
        then will need to pluck out all the keys from the collection and fetch from the related builder. If
        relation is just a Model then we can just call the model based on the value of the related
        builders primary key.

        Args:
            relation (Model|Collection):

        Returns:
            Model|Collection
        """
        if isinstance(relation, Collection):
            # Batch (eager) load: one query covering every parent primary key.
            record_type = self.get_record_key_lookup(relation.first())
            if callback:
                return callback(
                    self.polymorphic_builder.where(
                        f"{self.polymorphic_builder.get_table_name()}.{self.morph_key}",
                        record_type,
                    ).where_in(
                        self.morph_id,
                        relation.pluck(
                            relation.first().get_primary_key(), keep_nulls=False
                        ).unique(),
                    )
                ).get()
            return (
                self.polymorphic_builder.where(
                    f"{self.polymorphic_builder.get_table_name()}.{self.morph_key}",
                    record_type,
                )
                .where_in(
                    self.morph_id,
                    relation.pluck(
                        relation.first().get_primary_key(), keep_nulls=False
                    ).unique(),
                )
                .get()
            )
        else:
            # Single-parent load.
            record_type = self.get_record_key_lookup(relation)
            if callback:
                return callback(
                    self.polymorphic_builder.where(self.morph_key, record_type).where(
                        self.morph_id, relation.get_primary_key_value()
                    )
                ).get()
            return (
                self.polymorphic_builder.where(self.morph_key, record_type)
                .where(self.morph_id, relation.get_primary_key_value())
                .get()
            )

    def register_related(self, key, model, collection):
        # Attach the subset of the eager-loaded collection belonging to this
        # model instance (matching morph type and primary key).
        record_type = self.get_record_key_lookup(model)
        related = collection.where(self.morph_key, record_type).where(
            self.morph_id, model.get_primary_key_value()
        )
        model.add_relation({key: related})

    def morph_map(self):
        # Global {record_type_string: model_class} map from the DB config.
        return load_config().DB._morph_map

    def get_record_key_lookup(self, relation):
        """Reverse-lookup the morph-map key registered for `relation`'s class.

        Raises:
            ValueError -- when the class is not present in the morph map.
        """
        record_type = None
        for record_type_loop, model in self.morph_map().items():
            if model == relation.__class__:
                record_type = record_type_loop
                break
        if not record_type:
            raise ValueError(
                f"Could not find the record type key for the {relation} class"
            )
        return record_type
================================================
FILE: src/masoniteorm/relationships/MorphOne.py
================================================
from ..collection import Collection
from ..config import load_config
from .BaseRelationship import BaseRelationship
class MorphOne(BaseRelationship):
    """Descriptor implementing a polymorphic one-to-one relationship.

    Identical to MorphMany except that accessing the attribute yields a
    single related record (`.first()`) instead of a collection.
    """

    def __init__(self, fn, morph_key="record_type", morph_id="record_id"):
        # Supports both bare usage (@morph_one) and configured usage
        # (@morph_one("type_col", "id_col")). In the configured form the
        # decorator factory is called first, so `fn` arrives as the
        # morph_key string and `morph_key` carries the morph_id value;
        # `fn` is presumably attached later by the base class — confirm.
        if isinstance(fn, str):
            self.fn = None
            self.morph_key = fn
            self.morph_id = morph_key
        else:
            self.fn = fn
            self.morph_id = morph_id
            self.morph_key = morph_key

    def get_builder(self):
        # Builder of the model that owns this relationship attribute.
        return self._related_builder

    def set_keys(self, owner, attribute):
        # Backfill the defaults in case either key was explicitly set to None.
        self.morph_id = self.morph_id or "record_id"
        self.morph_key = self.morph_key or "record_type"
        return self

    def __get__(self, instance, owner):
        """This method is called when the decorated method is accessed.

        Arguments:
            instance {object|None} -- The instance we called.
                If we didn't call the attribute and only accessed it then this will be None.
            owner {object} -- The current model that the property was accessed on.

        Returns:
            object -- Either returns a builder or a hydrated model.
        """
        attribute = self.fn.__name__
        self._related_builder = instance.builder
        # self.fn(self)() presumably resolves and instantiates the related
        # model's builder — confirm against BaseRelationship.
        self.polymorphic_builder = self.fn(self)()
        self.set_keys(owner, self.fn)
        if not instance.is_loaded():
            # Unhydrated model: expose the descriptor itself so query
            # helpers can keep working with it.
            return self
        if attribute in instance._relationships:
            # An eager-loaded result takes priority over a fresh query.
            return instance._relationships[attribute]
        return self.apply_query(self._related_builder, instance)

    def __getattr__(self, attribute):
        # Fallback for unknown attributes: proxy onto the related model's
        # query builder.
        relationship = self.fn(self)()
        return getattr(relationship.builder, attribute)

    def apply_query(self, builder, instance):
        """Apply the query and return a dictionary to be hydrated

        Arguments:
            builder {object} -- The relationship object
            instance {object} -- The current model object.

        Returns:
            dict -- A dictionary of data which will be hydrated.
        """
        polymorphic_key = self.get_record_key_lookup(builder._model)
        polymorphic_builder = self.polymorphic_builder
        # One-to-one: only the first matching record is returned.
        return (
            polymorphic_builder.where(self.morph_key, polymorphic_key)
            .where(self.morph_id, instance.get_primary_key_value())
            .first()
        )

    def get_related(self, query, relation, eagers=None, callback=None):
        """Gets the relation needed between the relation and the related builder. If the relation is a collection
        then will need to pluck out all the keys from the collection and fetch from the related builder. If
        relation is just a Model then we can just call the model based on the value of the related
        builders primary key.

        Args:
            relation (Model|Collection):

        Returns:
            Model|Collection
        """
        if isinstance(relation, Collection):
            # Batch (eager) load: one query covering every parent primary key.
            record_type = self.get_record_key_lookup(relation.first())
            if callback:
                return callback(
                    self.polymorphic_builder.where(
                        f"{self.polymorphic_builder.get_table_name()}.{self.morph_key}",
                        record_type,
                    ).where_in(
                        self.morph_id,
                        relation.pluck(
                            relation.first().get_primary_key(), keep_nulls=False
                        ).unique(),
                    )
                ).get()
            return (
                self.polymorphic_builder.where(
                    f"{self.polymorphic_builder.get_table_name()}.{self.morph_key}",
                    record_type,
                )
                .where_in(
                    self.morph_id,
                    relation.pluck(
                        relation.first().get_primary_key(), keep_nulls=False
                    ).unique(),
                )
                .get()
            )
        else:
            # Single-parent load: one-to-one, so only the first match.
            record_type = self.get_record_key_lookup(relation)
            if callback:
                return callback(
                    self.polymorphic_builder.where(self.morph_key, record_type).where(
                        self.morph_id, relation.get_primary_key_value()
                    )
                ).first()
            return (
                self.polymorphic_builder.where(self.morph_key, record_type)
                .where(self.morph_id, relation.get_primary_key_value())
                .first()
            )

    def register_related(self, key, model, collection):
        # Attach the single eager-loaded record belonging to this model
        # instance (matching morph type and primary key).
        record_type = self.get_record_key_lookup(model)
        related = (
            collection.where(self.morph_key, record_type)
            .where(self.morph_id, model.get_primary_key_value())
            .first()
        )
        model.add_relation({key: related})

    def morph_map(self):
        # Global {record_type_string: model_class} map from the DB config.
        return load_config().DB._morph_map

    def get_record_key_lookup(self, relation):
        """Reverse-lookup the morph-map key registered for `relation`'s class.

        Raises:
            ValueError -- when the class is not present in the morph map.
        """
        record_type = None
        for record_type_loop, model in self.morph_map().items():
            if model == relation.__class__:
                record_type = record_type_loop
                break
        if not record_type:
            raise ValueError(
                f"Could not find the record type key for the {relation} class"
            )
        return record_type
================================================
FILE: src/masoniteorm/relationships/MorphTo.py
================================================
from ..collection import Collection
from ..config import load_config
from .BaseRelationship import BaseRelationship
class MorphTo(BaseRelationship):
    """Descriptor implementing the inverse side of a polymorphic relation.

    The owning row stores the target class under `morph_key` and the target
    primary key under `morph_id`; accessing the attribute resolves and
    fetches the concrete model via the configured morph map.
    """

    def __init__(self, fn, morph_key="record_type", morph_id="record_id"):
        # Supports both bare usage (@morph_to) and configured usage
        # (@morph_to("type_col", "id_col")). In the configured form the
        # decorator factory is called first, so `fn` arrives as the
        # morph_key string and `morph_key` carries the morph_id value;
        # `fn` is presumably attached later by the base class — confirm.
        if isinstance(fn, str):
            self.fn = None
            self.morph_key = fn
            self.morph_id = morph_key
        else:
            self.fn = fn
            self.morph_id = morph_id
            self.morph_key = morph_key

    def get_builder(self):
        # Builder of the model that owns this relationship attribute.
        return self._related_builder

    def set_keys(self, owner, attribute):
        # Backfill the defaults in case either key was explicitly set to None.
        self.morph_id = self.morph_id or "record_id"
        self.morph_key = self.morph_key or "record_type"
        return self

    def __get__(self, instance, owner):
        """This method is called when the decorated method is accessed.

        Arguments:
            instance {object|None} -- The instance we called.
                If we didn't call the attribute and only accessed it then this will be None.
            owner {object} -- The current model that the property was accessed on.

        Returns:
            object -- Either returns a builder or a hydrated model.
        """
        attribute = self.fn.__name__
        self._related_builder = instance.builder
        self.set_keys(owner, self.fn)
        if not instance.is_loaded():
            # Unhydrated model: expose the descriptor itself so query
            # helpers can keep working with it.
            return self
        if attribute in instance._relationships:
            # An eager-loaded result takes priority over a fresh query.
            return instance._relationships[attribute]
        return self.apply_query(self._related_builder, instance)

    def __getattr__(self, attribute):
        # NOTE(review): the sibling Morph* classes proxy onto
        # `relationship.builder` here; this one reads `_related_builder`
        # off the result of self.fn(self)() — confirm this is intentional.
        relationship = self.fn(self)()
        return getattr(relationship._related_builder, attribute)

    def apply_query(self, builder, instance):
        """Apply the query and return a dictionary to be hydrated

        Arguments:
            builder {object} -- The relationship object
            instance {object} -- The current model object.

        Returns:
            dict -- A dictionary of data which will be hydrated.
        """
        # Resolve the concrete model class from the stored morph type and
        # fetch the row named by the stored morph id.
        model = self.morph_map().get(instance.__attributes__[self.morph_key])
        record = instance.__attributes__[self.morph_id]
        return model.where(model.get_primary_key(), record).first()

    def get_related(self, query, relation, eagers=None, callback=None):
        """Gets the relation needed between the relation and the related builder. If the relation is a collection
        then will need to pluck out all the keys from the collection and fetch from the related builder. If
        relation is just a Model then we can just call the model based on the value of the related
        builders primary key.

        Args:
            relation (Model|Collection):

        Returns:
            Model|Collection
        """
        if isinstance(relation, Collection):
            # Batch load: group parents by morph type so each concrete model
            # is queried once, then merge the results.
            relations = Collection()
            for group, items in relation.group_by(self.morph_key).items():
                morphed_model = self.morph_map().get(group)
                relations.merge(
                    morphed_model.where_in(
                        f"{morphed_model.get_table_name()}.{morphed_model.get_primary_key()}",
                        Collection(items)
                        .pluck(self.morph_id, keep_nulls=False)
                        .unique(),
                    ).get()
                )
            return relations
        else:
            # Single parent: resolve the class and fetch by primary key.
            # Silently returns None when the type is not in the morph map.
            model = self.morph_map().get(getattr(relation, self.morph_key))
            if model:
                return model.find(getattr(relation, self.morph_id))

    def register_related(self, key, model, collection):
        # Attach the single eager-loaded record whose primary key matches
        # this model's stored morph id.
        morphed_model = self.morph_map().get(getattr(model, self.morph_key))
        related = collection.where(
            morphed_model.get_primary_key(), getattr(model, self.morph_id)
        ).first()
        model.add_relation({key: related})

    def morph_map(self):
        # Global {record_type_string: model_class} map from the DB config.
        return load_config().DB._morph_map

    def map_related(self, related_result):
        # No grouping needed for the inverse side.
        return related_result
================================================
FILE: src/masoniteorm/relationships/MorphToMany.py
================================================
from ..collection import Collection
from ..config import load_config
from .BaseRelationship import BaseRelationship
class MorphToMany(BaseRelationship):
    """Descriptor implementing a polymorphic to-many relationship resolved
    through the configured morph map.

    NOTE(review): several bodies below are byte-identical to MorphTo (a
    to-one relation) — see the hedged notes on apply_query/get_related.
    """

    def __init__(self, fn, morph_key="record_type", morph_id="record_id"):
        # Supports both bare usage (@morph_to_many) and configured usage
        # (@morph_to_many("type_col", "id_col")). In the configured form the
        # decorator factory is called first, so `fn` arrives as the
        # morph_key string and `morph_key` carries the morph_id value;
        # `fn` is presumably attached later by the base class — confirm.
        if isinstance(fn, str):
            self.fn = None
            self.morph_key = fn
            self.morph_id = morph_key
        else:
            self.fn = fn
            self.morph_id = morph_id
            self.morph_key = morph_key

    def get_builder(self):
        # Builder of the model that owns this relationship attribute.
        return self._related_builder

    def set_keys(self, owner, attribute):
        # Backfill the defaults in case either key was explicitly set to None.
        self.morph_id = self.morph_id or "record_id"
        self.morph_key = self.morph_key or "record_type"
        return self

    def __get__(self, instance, owner):
        """This method is called when the decorated method is accessed.

        Arguments:
            instance {object|None} -- The instance we called.
                If we didn't call the attribute and only accessed it then this will be None.
            owner {object} -- The current model that the property was accessed on.

        Returns:
            object -- Either returns a builder or a hydrated model.
        """
        attribute = self.fn.__name__
        self._related_builder = instance.builder
        self.set_keys(owner, self.fn)
        if not instance.is_loaded():
            # Unhydrated model: expose the descriptor itself so query
            # helpers can keep working with it.
            return self
        if attribute in instance._relationships:
            # An eager-loaded result takes priority over a fresh query.
            return instance._relationships[attribute]
        return self.apply_query(self._related_builder, instance)

    def __getattr__(self, attribute):
        # Fallback for unknown attributes: proxy onto the related model's
        # query builder.
        relationship = self.fn(self)()
        return getattr(relationship.builder, attribute)

    def apply_query(self, builder, instance):
        """Apply the query and return a dictionary to be hydrated

        Arguments:
            builder {object} -- The relationship object
            instance {object} -- The current model object.

        Returns:
            dict -- A dictionary of data which will be hydrated.
        """
        # NOTE(review): returns `.first()` (a single record) even though this
        # is a to-many relation — looks copied from MorphTo; confirm intended.
        model = self.morph_map().get(instance.__attributes__[self.morph_key])
        record = instance.__attributes__[self.morph_id]
        return model.where(model.get_primary_key(), record).first()

    def get_related(self, query, relation, eagers=None, callback=None):
        """Gets the relation needed between the relation and the related builder. If the relation is a collection
        then will need to pluck out all the keys from the collection and fetch from the related builder. If
        relation is just a Model then we can just call the model based on the value of the related
        builders primary key.

        Args:
            relation (Model|Collection):

        Returns:
            Model|Collection
        """
        if isinstance(relation, Collection):
            # Batch load: group parents by morph type so each concrete model
            # is queried once, then merge the results.
            relations = Collection()
            for group, items in relation.group_by(self.morph_key).items():
                morphed_model = self.morph_map().get(group)
                relations.merge(
                    morphed_model.where_in(
                        f"{morphed_model.get_table_name()}.{morphed_model.get_primary_key()}",
                        Collection(items)
                        .pluck(self.morph_id, keep_nulls=False)
                        .unique(),
                    ).get()
                )
            return relations
        else:
            model = self.morph_map().get(getattr(relation, self.morph_key))
            if model:
                # NOTE(review): `find` is passed a one-element list here,
                # unlike MorphTo which passes the scalar — confirm intended.
                return model.find([getattr(relation, self.morph_id)])

    def register_related(self, key, model, collection):
        # Attach every eager-loaded record whose primary key matches this
        # model's stored morph id (no .first(): to-many keeps the subset).
        morphed_model = self.morph_map().get(getattr(model, self.morph_key))
        related = collection.where(
            morphed_model.get_primary_key(), getattr(model, self.morph_id)
        )
        model.add_relation({key: related})

    def morph_map(self):
        # Global {record_type_string: model_class} map from the DB config.
        return load_config().DB._morph_map
================================================
FILE: src/masoniteorm/relationships/__init__.py
================================================
from .BelongsTo import BelongsTo as belongs_to
from .BelongsToMany import BelongsToMany as belongs_to_many
from .HasMany import HasMany as has_many
from .HasManyThrough import HasManyThrough as has_many_through
from .HasOne import HasOne as has_one
from .HasOneThrough import HasOneThrough as has_one_through
from .MorphMany import MorphMany as morph_many
from .MorphOne import MorphOne as morph_one
from .MorphTo import MorphTo as morph_to
from .MorphToMany import MorphToMany as morph_to_many
================================================
FILE: src/masoniteorm/schema/Blueprint.py
================================================
class Blueprint:
"""Used for building schemas for creating, modifying or altering schema."""
def __init__(
self,
grammar,
table="",
connection=None,
platform=None,
schema=None,
action=None,
default_string_length=None,
dry=False,
):
self.grammar = grammar
self.table = table
self._last_column = None
self._default_string_length = default_string_length
self.platform = platform
self.schema = schema
self._dry = dry
self._action = action
self.connection = connection
if not platform:
self.platform = self.connection.get_default_platform()
def string(self, column, length=255, nullable=False):
"""Sets a column to be the string representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column. (default: {255})
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "string", length=length, nullable=nullable
)
return self
def tiny_integer(self, column, length=1, nullable=False):
"""Sets a column to be the tiny_integer representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column. (default: {1})
nullable {bool} -- Whether the column is nullable (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "tiny_integer", length=length, nullable=nullable
)
return self
def small_integer(self, column, length=5, nullable=False):
"""Sets a column to be the small_integer representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column. (default: {5})
nullable {bool} -- Whether the column is nullable (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "small_integer", length=length, nullable=nullable
)
return self
def medium_integer(self, column, length=7, nullable=False):
"""Sets a column to be the medium_integer representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column. (default: {7})
nullable {bool} -- Whether the column is nullable (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "medium_integer", length=length, nullable=nullable
)
return self
def integer(self, column, length=11, nullable=False):
"""Sets a column to be the integer representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column. (default: {11})
nullable {bool} -- Whether the column is nullable (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "integer", length=length, nullable=nullable
)
return self
def big_integer(self, column, length=32, nullable=False):
"""Sets a column to be the big_integer representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column. (default: {32})
nullable {bool} -- Whether the column is nullable (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "big_integer", length=length, nullable=nullable
)
return self
def unsigned_big_integer(self, column, length=32, nullable=False):
"""Sets a column to be the unsigned big_integer representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column. (default: {32})
nullable {bool} -- Whether the column is nullable (default: {False})
Returns:
self
"""
return self.big_integer(column, length=length, nullable=nullable).unsigned()
def _compile_create(self):
return self.grammar(creates=self._columns, table=self.table)._compile_create()
    def _compile_alter(self):
        # NOTE(review): this delegates to the grammar's _compile_create(),
        # not an alter-specific compiler — identical to _compile_create()
        # above. Possibly a legacy shim; confirm before changing.
        return self.grammar(creates=self._columns, table=self.table)._compile_create()
def increments(self, column, nullable=False):
"""Sets a column to be the auto incrementing primary key representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "increments", nullable=nullable
)
self.primary(column)
return self
def tiny_increments(self, column, nullable=False):
"""Sets a column to be the auto tiny incrementing primary key representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "tiny_increments", nullable=nullable
)
self.primary(column)
return self
    def id(self, column="id"):
        """Sets a column to be the auto-incrementing big integer (8-byte) primary key representation for the table.

        Arguments:
            column {string} -- The column name. Defaults to "id".

        Returns:
            self
        """
        # Convenience alias for big_increments("id").
        return self.big_increments(column)
def uuid(self, column, nullable=False, length=36):
"""Sets a column to be the UUID4 representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "uuid", nullable=nullable, length=length
)
return self
def big_increments(self, column, nullable=False):
"""Sets a column to be the the big integer increments representation for the table
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "big_increments", nullable=nullable
)
self.primary(column)
return self
def binary(self, column, nullable=False):
"""Sets a column to be the binary representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(column, "binary", nullable=nullable)
return self
def boolean(self, column, nullable=False):
"""Sets a column to be the boolean representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(column, "boolean", nullable=nullable)
return self
def default(self, value, raw=False):
self._last_column.default = value
self._last_column.default_is_raw = raw
return self
def default_raw(self, value):
self.default(value, True)
return self
def char(self, column, length=1, nullable=False):
"""Sets a column to be the char representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length for the column (default: {1})
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "char", length=length, nullable=nullable
)
return self
def date(self, column, nullable=False):
"""Sets a column to be the date representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(column, "date", nullable=nullable)
return self
def time(self, column, nullable=False):
"""Sets a column to be the time representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(column, "time", nullable=nullable)
return self
def datetime(self, column, nullable=False, now=False):
"""Sets a column to be the datetime representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
now {bool} -- Whether the default for the column should be the current time. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(column, "datetime", nullable=nullable)
if now:
self._last_column.use_current()
return self
def timestamp(self, column, nullable=False, now=False):
"""Sets a column to be the timestamp representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
now {bool} -- Whether the default for the column should be the current time. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "timestamp", nullable=nullable
)
if now:
self._last_column.use_current()
return self
def timestamps(self):
"""Creates `created_at` and `updated_at` timestamp columns.
Returns:
self
"""
self.datetime("created_at", nullable=True, now=True)
self.datetime("updated_at", nullable=True, now=True)
return self
def decimal(self, column, length=17, precision=6, nullable=False):
"""Sets a column to be the decimal representation for the table.
Arguments:
column {string} -- The name of the column.
Keyword Arguments:
length {int} -- The total length of the decimal number (default: {17})
precision {int} -- The number of places that should be used for floating numbers. (default: {6})
nullable {bool} -- Whether the column is nullable (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column,
"decimal",
length="{length}, {precision}".format(length=length, precision=precision),
nullable=nullable,
)
return self
def float(self, column, length=19, precision=4, nullable=False):
"""Sets a column to be the float representation for the table.
Arguments:
column {string} -- The name of the column.
Keyword Arguments:
length {int} -- The total length of the float number (default: {17})
precision {int} -- The number of places that should be used for floating numbers. (default: {6})
nullable {bool} -- Whether the column is nullable (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column,
"float",
length="{length}, {precision}".format(length=length, precision=precision),
nullable=nullable,
)
return self
def double(self, column, nullable=False):
"""Sets a column to be the the double representation for the table
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(column, "double", nullable=nullable)
return self
def enum(self, column, options=None, nullable=False):
"""Sets a column to be the enum representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
options {list} -- A list of available options for the enum. (default: {False})
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
options = options or []
new_options = ""
for option in options:
new_options += "'{}',".format(option)
new_options = new_options.rstrip(",")
self._last_column = self.table.add_column(
column, "enum", length="255", values=options, nullable=nullable
)
return self
def text(self, column, length=None, nullable=False):
"""Sets a column to be the text representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column if any. (default: {False})
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "text", length=length, nullable=nullable
)
return self
def tiny_text(self, column, length=None, nullable=False):
"""Sets a column to be the text representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column if any. (default: {False})
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "tiny_text", length=length, nullable=nullable
)
return self
def unsigned_decimal(self, column, length=17, precision=6, nullable=False):
"""Sets a column to be the text representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column if any. (default: {False})
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column,
"decimal",
length="{length}, {precision}".format(length=length, precision=precision),
nullable=nullable,
).unsigned()
return self
return self
def long_text(self, column, length=None, nullable=False):
"""Sets a column to be the long_text representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column if any. (default: {False})
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "long_text", length=length, nullable=nullable
)
return self
def json(self, column, nullable=False):
"""Sets a column to be the json representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(column, "json", nullable=nullable)
return self
def jsonb(self, column, nullable=False):
"""Sets a column to be the jsonb representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(column, "jsonb", nullable=nullable)
return self
def inet(self, column, length=255, nullable=False):
"""Sets a column to be the inet representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "inet", length=255, nullable=nullable
)
return self
def cidr(self, column, length=255, nullable=False):
"""Sets a column to be the cidr representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "cidr", length=255, nullable=nullable
)
return self
def macaddr(self, column, length=255, nullable=False):
"""Sets a column to be the macaddr representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "macaddr", length=255, nullable=nullable
)
return self
def point(self, column, nullable=False):
"""Sets a column to be the point representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(column, "point", nullable=nullable)
return self
def geometry(self, column, nullable=False):
"""Sets a column to be the geometry representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(column, "geometry", nullable=nullable)
return self
def year(self, column, length=4, default=None, nullable=False):
"""Sets a column to be the year representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "year", length=length, nullable=nullable, default=default
)
return self
def unsigned(self, column=None, length=None, nullable=False):
"""Sets a column to be the unsigned integer representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
length {int} -- The length of the column. (default: {False})
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
if not column:
self._last_column.unsigned()
return self
self._last_column = self.table.add_column(
column, "unsigned", length=length, nullable=nullable
).unsigned()
return self
def unsigned_integer(self, column, nullable=False):
"""Sets a column to be the unsigned integer representation for the table.
Arguments:
column {string} -- The column name.
Keyword Arguments:
nullable {bool} -- Whether the column is nullable. (default: {False})
Returns:
self
"""
self._last_column = self.table.add_column(
column, "integer", nullable=nullable
).unsigned()
return self
def morphs(self, column, nullable=False, indexes=True):
    """Creates the column pair used by a polymorphic relationship.

    Adds an unsigned integer "{column}_id" column and a string
    "{column}_type" column, and (by default) an index on each.

    Arguments:
        column {string} -- The base name for the polymorphic columns.

    Keyword Arguments:
        nullable {bool} -- Whether the columns are nullable. (default: {False})
        indexes {bool} -- Whether to index both columns. (default: {True})

    Returns:
        self
    """
    _columns = [
        self.table.add_column(
            "{}_id".format(column), "integer", nullable=nullable
        ).unsigned(),
        self.table.add_column(
            "{}_type".format(column),
            "string",
            nullable=nullable,
            length=self._default_string_length,
        ),
    ]
    if indexes:
        # Use a distinct loop variable; the original reused `column`, which
        # shadowed the argument with Column objects inside the loop.
        for added in _columns:
            self.index(added.name)
    self._last_column = _columns
    return self
def to_sql(self):
    """Compiles the blueprint class into a sql statement.

    Returns:
        string -- The SQL statement generated.
    """
    if self._action == "create":
        return self.platform().compile_create_sql(self.table)
    if self._action == "create_table_if_not_exists":
        return self.platform().compile_create_sql(self.table, if_not_exists=True)
    # Alter path: hydrate the current table schema first, unless this is
    # a dry run (no live connection available).
    if not self._dry:
        self.table.from_table = self.platform().get_current_schema(
            self.connection, self.table.name, schema=self.schema
        )
    return self.platform().compile_alter_sql(self.table)
def __enter__(self):
    """Enters the schema context manager, yielding the blueprint itself."""
    return self
def __exit__(self, exc_type, exc_value, exc_traceback):
    """Executes the compiled SQL when leaving the context.

    Dry runs exit without touching the connection.
    """
    if self._dry:
        return None
    return self.connection.query(self.to_sql(), ())
def nullable(self):
    """Sets the last column(s) created as nullable.

    After methods such as ``morphs`` the last "column" is a list; the
    previous implementation left ``_columns`` empty in that case, so the
    call silently did nothing. Both shapes are now handled.

    Returns:
        self
    """
    if isinstance(self._last_column, list):
        _columns = self._last_column
    else:
        _columns = [self._last_column]
    for column in _columns:
        column.nullable()
    return self
def soft_deletes(self, name="deleted_at"):
    """Adds a nullable datetime column used for soft deletes.

    Keyword Arguments:
        name {string} -- The column name. (default: {"deleted_at"})

    Returns:
        self
    """
    return self.datetime(name, nullable=True).nullable()
def unique(self, column=None, name=None):
    """Adds a unique constraint on the given column(s), or on the last
    column created when no column name is passed.

    Keyword Arguments:
        column {string|list} -- The column name(s). (default: {None})
        name {string} -- An explicit constraint name. (default: {None})

    Returns:
        self
    """
    target = column or self._last_column.name
    columns = target if isinstance(target, list) else [target]
    constraint_name = name or f"{self.table.name}_{'_'.join(columns)}_unique"
    self.table.add_constraint(constraint_name, "unique", columns=columns)
    return self
def index(self, column=None, name=None):
    """Adds an index on the given column(s), or on the last column created
    when no column name is passed.

    Keyword Arguments:
        column {string|list} -- The column name(s). (default: {None})
        name {string} -- An explicit index name. (default: {None})

    Returns:
        self
    """
    target = column or self._last_column.name
    columns = target if isinstance(target, list) else [target]
    index_name = name or f"{self.table.name}_{'_'.join(columns)}_index"
    self.table.add_index(columns, index_name, "index")
    return self
def fulltext(self, column=None, name=None):
    """Adds a full text constraint on the given column(s), or on the last
    column created when no column name is passed.

    Keyword Arguments:
        column {string|list} -- The column name(s). (default: {None})
        name {string} -- An explicit constraint name. (default: {None})

    Returns:
        self
    """
    target = column or self._last_column.name
    columns = target if isinstance(target, list) else [target]
    self.table.add_constraint(
        name or f"{'_'.join(columns)}_fulltext", "fulltext", columns
    )
    return self
def primary(self, column=None, name=None):
    """Adds a primary key constraint on the given column(s), or on the last
    column created when no column name is passed.

    Keyword Arguments:
        column {string|list} -- The column name(s). (default: {None})
        name {string} -- An explicit constraint name. (default: {None})

    Returns:
        self
    """
    if column is None:
        column = self._last_column.name
    columns = column if isinstance(column, list) else [column]
    self.table.add_constraint(
        name or f"{self.table.name}_{'_'.join(columns)}_primary",
        "primary_key",
        columns=columns,
    )
    return self
def add_foreign(self, columns, name=None):
    """Creates a foreign key from a dotted descriptor containing the local
    column, the referenced column and the referenced table.

    Arguments:
        columns {string} -- A descriptor in "from_column.to_column.table" form.

    Keyword Arguments:
        name {string} -- An explicit constraint name. (default: {None})

    Raises:
        Exception -- When the descriptor does not have exactly three parts.

    Returns:
        self
    """
    parts = columns.split(".")
    if len(parts) != 3:
        # Exception type kept for backward compatibility; message typo
        # ("struncture") fixed.
        raise Exception(
            "Wrong add_foreign argument, the structure is from_column.to_column.table"
        )
    from_column, to_column, table = parts
    return self.foreign(from_column, name=name).references(to_column).on(table)
def foreign(self, column, name=None):
    """Begins building a foreign key constraint on the given column.

    Arguments:
        column {string} -- The local column name.

    Keyword Arguments:
        name {string} -- An explicit constraint name. (default: {None})

    Returns:
        self
    """
    constraint_name = name or f"{self.table.name}_{column}_foreign"
    self._last_foreign = self.table.add_foreign_key(column, name=constraint_name)
    return self
def foreign_id(self, column):
    """Adds an unsigned big integer column and starts a foreign key on it.

    Arguments:
        column {string} -- The name of the column to reference.

    Returns:
        self
    """
    return self.unsigned_big_integer(column).foreign(column)
def foreign_uuid(self, column):
    """Adds a UUID column and starts a foreign key on it.

    Arguments:
        column {string} -- The name of the column to reference.

    Returns:
        self
    """
    return self.uuid(column).foreign(column)
def foreign_id_for(self, model, column=None):
    """Creates a foreign key column for the given model, choosing an
    unsigned big integer when the model's primary key is an int and a
    UUID column otherwise.

    Arguments:
        model {Model} -- The model to reference.

    Keyword Arguments:
        column {string} -- An explicit column name; defaults to the model's
            foreign key name. (default: {None})

    Returns:
        self
    """
    clm = column if column else model.get_foreign_key()
    # Bug fix: the UUID branch previously passed the raw `column` argument,
    # which is None whenever the name comes from the model.
    if model.get_primary_key_type() == "int":
        return self.foreign_id(clm)
    return self.foreign_uuid(clm)
def references(self, column):
    """Sets the column on the foreign table that the local column points to.

    Arguments:
        column {string} -- The referenced column name.

    Returns:
        self
    """
    self._last_foreign.references(column)
    return self
def on(self, table):
    """Sets the foreign table that the local column references.

    Arguments:
        table {string} -- The foreign table name.

    Returns:
        self
    """
    self._last_foreign.on(table)
    return self
def on_delete(self, action):
    """Sets the ON DELETE action of the last foreign key.

    Arguments:
        action {string} -- The action to perform on delete.

    Returns:
        self
    """
    self._last_foreign.on_delete(action)
    return self
def on_update(self, action):
    """Sets the ON UPDATE action of the last foreign key.

    Arguments:
        action {string} -- The action to perform on update.

    Returns:
        self
    """
    self._last_foreign.on_update(action)
    return self
def comment(self, comment):
    """Attaches a comment to the last column created.

    Arguments:
        comment {string} -- The comment text.

    Returns:
        self
    """
    self._last_column.add_comment(comment)
    return self
def table_comment(self, comment):
    """Attaches a comment to the table itself.

    Arguments:
        comment {string} -- The comment text.

    Returns:
        self
    """
    self.table.add_comment(comment)
    return self
def rename(self, old_column, new_column, data_type, length=None):
    """Renames a column from its old name to a new one.

    Arguments:
        old_column {string} -- The original column name.
        new_column {string} -- The new column name.
        data_type {string} -- The column's data type.

    Keyword Arguments:
        length {int} -- The column length. (default: {None})

    Returns:
        self
    """
    self.table.rename_column(old_column, new_column, data_type, length=length)
    return self
def after(self, old_column):
    """Positions the last column created after an existing column in the
    table schema.

    Arguments:
        old_column {string} -- The column this new column should follow.

    Returns:
        self
    """
    self._last_column.after(old_column)
    return self
def drop_column(self, *columns):
    """Marks one or more columns to be dropped.

    Arguments:
        *columns {string} -- Any number of column names.

    Returns:
        self
    """
    for column_name in columns:
        self.table.drop_column(column_name)
    return self
def drop_index(self, index):
    """Marks one or more indexes to be dropped.

    A list is treated as column names and expanded to the conventional
    "<table>_<column>_index" names; a string is used verbatim.

    Arguments:
        index {list|string} -- Column names or a literal index name.

    Returns:
        self
    """
    if not isinstance(index, list):
        self.table.remove_index(index)
        return self
    for column in index:
        self.table.remove_index(f"{self.table.name}_{column}_index")
    return self
def change(self):
    """Flags the last column created as a modification of an existing
    column rather than a new one.

    Returns:
        self
    """
    last = self._last_column
    self.table.change_column(last)
    return self
def drop_unique(self, index):
    """Marks one or more unique indexes to be dropped.

    A list is treated as column names and expanded to the conventional
    "<table>_<column>_unique" names; a string is used verbatim.

    Arguments:
        index {list|string} -- Column names or a literal index name.

    Returns:
        self
    """
    if isinstance(index, list):
        for column in index:
            self.table.remove_unique_index(f"{self.table.name}_{column}_unique")
        return self
    self.table.remove_unique_index(index)
    # Bug fix: previously returned None on this path, breaking chaining.
    return self
def drop_primary(self, index):
    """Marks one or more primary key constraints to be dropped.

    A list is treated as column names and expanded to the conventional
    "<table>_<column>_primary" names; a string is used verbatim.

    Arguments:
        index {list|string} -- Column names or a literal constraint name.

    Returns:
        self
    """
    if isinstance(index, list):
        for column in index:
            self.table.drop_primary(f"{self.table.name}_{column}_primary")
        return self
    self.table.drop_primary(index)
    # Bug fix: previously returned None on this path, breaking chaining.
    return self
def drop_foreign(self, index):
    """Marks one or more foreign key constraints to be dropped.

    A list is treated as column names and expanded to the conventional
    "<table>_<column>_foreign" names; a string is used verbatim.

    Arguments:
        index {list|string} -- Column names or a literal constraint name.

    Returns:
        self
    """
    if not isinstance(index, list):
        self.table.drop_foreign(index)
        return self
    for column in index:
        self.table.drop_foreign(f"{self.table.name}_{column}_foreign")
    return self
================================================
FILE: src/masoniteorm/schema/Column.py
================================================
class Column:
    """Used for creating or modifying columns."""

    def __init__(
        self,
        name,
        column_type,
        length=None,
        values=None,
        nullable=False,
        default=None,
        signed=None,
        default_is_raw=False,
        column_python_type=str,
    ):
        self.column_type = column_type
        self.column_python_type = column_python_type
        self.name = name
        self.length = length
        self.values = values or []  # enum values, when the type is "enum"
        self.is_null = nullable
        self._after = None  # column this one should be placed after, if any
        self.old_column = ""  # previous name when this column is a rename
        # NOTE(review): assigning to ``self.default`` shadows the ``default()``
        # method below on every instance, so the fluent ``column.default(...)``
        # setter raises TypeError at runtime — confirm whether it is still used.
        self.default = default
        self._signed = signed  # "signed" / "unsigned" / None
        self.default_is_raw = default_is_raw  # when True, default is emitted unquoted
        self.primary = False
        self.comment = None

    def nullable(self):
        """Sets this column to be nullable

        Returns:
            self
        """
        self.is_null = True
        return self

    def signed(self):
        """Sets this column to be signed

        Returns:
            self
        """
        self._signed = "signed"
        return self

    def unsigned(self):
        """Sets this column to be unsigned

        Returns:
            self
        """
        self._signed = "unsigned"
        return self

    def not_nullable(self):
        """Sets this column to be not nullable

        Returns:
            self
        """
        self.is_null = False
        return self

    def set_as_primary(self):
        # Flags this column as (part of) the table's primary key.
        self.primary = True

    def rename(self, column):
        """Renames this column to a new name

        Arguments:
            column {string} -- The old column name

        Returns:
            self
        """
        self.old_column = column
        return self

    def after(self, after):
        """Sets the column that this new column should be created after.
        This is useful for setting the location of the new column in the table schema.

        Arguments:
            after {string} -- The column that this new column should be created after

        Returns:
            self
        """
        self._after = after
        return self

    def get_after_column(self):
        """Returns the column that this new column should be created after,
        or None when no position was set.

        Returns:
            string|None
        """
        return self._after

    # NOTE(review): this method is shadowed by the ``self.default`` attribute
    # set in __init__ and is therefore unreachable on instances.
    def default(self, value, raw=False):
        """Sets a default value for this column

        Arguments:
            value {string} -- A default value.
            raw {bool} -- should the value be quoted

        Returns:
            self
        """
        self.default = value
        self.default_is_raw = raw
        return self

    def change(self):
        """Sets the schema to create a modify sql statement.

        Returns:
            self
        """
        # NOTE(review): _action is only ever assigned here; it does not exist
        # on the instance until change() is called.
        self._action = "modify"
        return self

    def use_current(self):
        """Sets the column to use a current timestamp.
        Used for timestamp columns.

        Returns:
            self
        """
        self.default = "current"
        return self

    def add_comment(self, comment):
        # Attaches a comment to this column and returns self for chaining.
        self.comment = comment
        return self
================================================
FILE: src/masoniteorm/schema/ColumnDiff.py
================================================
================================================
FILE: src/masoniteorm/schema/Constraint.py
================================================
class Constraint:
    """A named table constraint (e.g. unique, primary_key, fulltext).

    Attributes:
        name -- the constraint's name.
        constraint_type -- the kind of constraint, e.g. "unique".
        columns -- the list of column names the constraint covers.
    """

    def __init__(self, name, constraint_type, columns=None):
        self.name = name
        self.constraint_type = constraint_type
        self.columns = columns or []
================================================
FILE: src/masoniteorm/schema/ForeignKeyConstraint.py
================================================
class ForeignKeyConstraint:
    """Builder for a single FOREIGN KEY constraint.

    Tracks the local column, the referenced table and column, the
    ON DELETE / ON UPDATE actions, and the constraint name. Every setter
    returns self so calls can be chained fluently.
    """

    def __init__(self, column, foreign_table, foreign_column, name=None):
        self.column = column
        self.foreign_table = foreign_table
        self.foreign_column = foreign_column
        self.constraint_name = name
        # Actions stay unset until on_delete/on_update are called.
        self.delete_action = None
        self.update_action = None

    def references(self, foreign_column):
        """Sets the referenced column on the foreign table."""
        self.foreign_column = foreign_column
        return self

    def on(self, foreign_table):
        """Sets the referenced (foreign) table."""
        self.foreign_table = foreign_table
        return self

    def on_delete(self, action):
        """Sets the ON DELETE action."""
        self.delete_action = action
        return self

    def on_update(self, action):
        """Sets the ON UPDATE action."""
        self.update_action = action
        return self

    def name(self, name):
        """Overrides the constraint's name."""
        self.constraint_name = name
        return self
================================================
FILE: src/masoniteorm/schema/Index.py
================================================
class Index:
    """A named index over one or more columns.

    Attributes:
        column -- the column(s) the index covers.
        name -- the index name.
        index_type -- the kind of index, e.g. "index".
    """

    def __init__(self, column, name, index_type):
        self.column = column
        self.name = name
        self.index_type = index_type
================================================
FILE: src/masoniteorm/schema/Schema.py
================================================
from .Blueprint import Blueprint
from .Table import Table
from .TableDiff import TableDiff
from ..exceptions import ConnectionNotRegistered
from ..config import load_config
class Schema:
    """Entry point for schema (DDL) operations.

    Resolves a connection class and platform from the configured connection
    details and hands out Blueprint objects for create/alter operations.
    In dry mode SQL is compiled but never executed.
    """

    # Default length applied to string columns when none is given.
    _default_string_length = "255"

    # Maps blueprint column type names to the Python type their values map to.
    _type_hints_map = {
        "string": str,
        "char": str,
        "big_increments": int,
        "integer": int,
        "big_integer": int,
        "tiny_integer": int,
        "small_integer": int,
        "medium_integer": int,
        "integer_unsigned": int,
        "big_integer_unsigned": int,
        "tiny_integer_unsigned": int,
        "small_integer_unsigned": int,
        "medium_integer_unsigned": int,
        "increments": int,
        "uuid": str,
        "binary": bytes,
        "boolean": bool,
        "decimal": float,
        "double": float,
        "enum": str,
        "text": str,
        "float": float,
        "geometry": str,  # ?
        "json": dict,
        "jsonb": bytes,
        "inet": str,
        "cidr": str,
        "macaddr": str,
        "long_text": str,
        "point": str,  # ?
        "time": str,  # or pendulum.DateTime
        "timestamp": str,  # or pendulum.DateTime
        "date": str,  # or pendulum.DateTime
        "year": str,
        "datetime": str,  # or pendulum.DateTime
        "tiny_increments": int,
        "unsigned": int,
        "unsigned_integer": int,
    }

    def __init__(
        self,
        dry=False,
        connection="default",
        connection_class=None,
        platform=None,
        grammar=None,
        connection_details=None,
        schema=None,
        config_path=None,
    ):
        # dry: compile SQL only, never execute.
        self._dry = dry
        self.connection = connection
        self.connection_class = connection_class
        self._connection = None  # the live connection, set by new_connection()
        self.grammar = grammar
        self.platform = platform
        self.connection_details = connection_details or {}
        self._blueprint = None  # the last blueprint handed out
        self._sql = None  # the last compiled SQL in dry mode
        self.schema = schema
        self.config_path = config_path
        # Resolve the connection class from config when not supplied directly.
        if not self.connection_class:
            self.on(self.connection)
        if not self.platform:
            self.platform = self.connection_class.get_default_platform()

    def on(self, connection_key):
        """Change the connection from the default connection

        Arguments:
            connection {string} -- A connection string like 'mysql' or 'mssql'.
                It will be made with the connection factory.

        Raises:
            ConnectionNotRegistered -- When no details exist for the key.

        Returns:
            cls
        """
        DB = load_config(config_path=self.config_path).DB
        # "default" is an alias; resolve it to the configured connection name.
        if connection_key == "default":
            self.connection = self.connection_details.get("default")
        connection_detail = self._connection_driver = self.connection_details.get(
            self.connection
        )
        if connection_detail:
            self._connection_driver = connection_detail.get("driver")
        else:
            raise ConnectionNotRegistered(
                f"Could not find the '{connection_key}' connection details"
            )
        self.connection_class = DB.connection_factory.make(self._connection_driver)
        return self

    def dry(self):
        """Whether the query should be executed. (default: {False})

        Returns:
            self
        """
        self._dry = True
        return self

    def create(self, table):
        """Sets the table and returns the blueprint.
        This should be used as a context manager.

        Arguments:
            table {string} -- The name of a table like 'users'

        Returns:
            masoniteorm.blueprint.Blueprint -- The Masonite ORM blueprint object.
        """
        self._table = table
        self._blueprint = Blueprint(
            self.grammar,
            connection=self.new_connection(),
            table=Table(table),
            action="create",
            platform=self.platform,
            schema=self.schema,
            default_string_length=self._default_string_length,
            dry=self._dry,
        )
        return self._blueprint

    def create_table_if_not_exists(self, table):
        """Like create() but compiles a create-if-not-exists statement.

        Arguments:
            table {string} -- The name of a table like 'users'

        Returns:
            masoniteorm.blueprint.Blueprint -- The Masonite ORM blueprint object.
        """
        self._table = table
        self._blueprint = Blueprint(
            self.grammar,
            connection=self.new_connection(),
            table=Table(table),
            action="create_table_if_not_exists",
            platform=self.platform,
            schema=self.schema,
            default_string_length=self._default_string_length,
            dry=self._dry,
        )
        return self._blueprint

    def table(self, table):
        """Sets the table and returns the blueprint.
        This should be used as a context manager.

        Arguments:
            table {string} -- The name of a table like 'users'

        Returns:
            masoniteorm.blueprint.Blueprint -- The Masonite ORM blueprint object.
        """
        self._table = table
        # Alter operations diff against the existing table, hence TableDiff.
        self._blueprint = Blueprint(
            self.grammar,
            connection=self.new_connection(),
            table=TableDiff(table),
            action="alter",
            platform=self.platform,
            schema=self.schema,
            default_string_length=self._default_string_length,
            dry=self._dry,
        )
        return self._blueprint

    def get_connection_information(self):
        """Extracts the current connection's settings into the keyword
        arguments expected by the connection class constructor."""
        return {
            "host": self.connection_details.get(self.connection, {}).get("host"),
            "database": self.connection_details.get(self.connection, {}).get(
                "database"
            ),
            "user": self.connection_details.get(self.connection, {}).get("user"),
            "port": self.connection_details.get(self.connection, {}).get("port"),
            "password": self.connection_details.get(self.connection, {}).get(
                "password"
            ),
            "prefix": self.connection_details.get(self.connection, {}).get("prefix"),
            "options": self.connection_details.get(self.connection, {}).get(
                "options", {}
            ),
            "full_details": self.connection_details.get(self.connection),
        }

    def new_connection(self):
        """Builds and opens a new connection.

        Returns None in dry mode (no connection is ever opened).
        """
        if self._dry:
            return
        self._connection = (
            self.connection_class(**self.get_connection_information())
            .set_schema(self.schema)
            .make_connection()
        )
        return self._connection

    def has_column(self, table, column, query_only=False):
        """Checks if the a table has a specific column

        Arguments:
            table {string} -- The name of a table like 'users'
            column {string} -- The column name to check for.

        Returns:
            bool|string -- The compiled SQL in dry mode, otherwise a bool.
        """
        # NOTE(review): query_only is accepted but never used — confirm intent.
        sql = self.platform().compile_column_exists(table, column)
        if self._dry:
            self._sql = sql
            return sql
        return bool(self.new_connection().query(sql, ()))

    def get_columns(self, table, dict=True):
        """Returns the columns of an existing table.

        NOTE(review): the ``dict`` parameter shadows the builtin of the
        same name.

        Arguments:
            table {string} -- The name of a table like 'users'

        Keyword Arguments:
            dict {bool} -- Return a {name: Column} mapping when True,
                otherwise the raw items view. (default: {True})
        """
        table = self.platform().get_current_schema(
            self.new_connection(), table, schema=self.get_schema()
        )
        result = {}
        if dict:
            for column in table.get_added_columns().items():
                result.update({column[0]: column[1]})
            return result
        else:
            return table.get_added_columns().items()

    @classmethod
    def set_default_string_length(cls, length):
        """Overrides the default string column length for all schemas."""
        cls._default_string_length = length
        return cls

    def drop_table(self, table, query_only=False):
        """Drops a table. Returns the SQL in dry mode, otherwise a bool."""
        sql = self.platform().compile_drop_table(table)
        if self._dry:
            self._sql = sql
            return sql
        return bool(self.new_connection().query(sql, ()))

    def drop(self, *args, **kwargs):
        """Alias for drop_table()."""
        return self.drop_table(*args, **kwargs)

    def drop_table_if_exists(self, table, exists=False, query_only=False):
        """Drops a table only if it exists. Returns the SQL in dry mode."""
        sql = self.platform().compile_drop_table_if_exists(table)
        if self._dry:
            self._sql = sql
            return sql
        return bool(self.new_connection().query(sql, ()))

    def rename(self, table, new_name):
        """Renames a table. Returns the SQL in dry mode, otherwise a bool."""
        sql = self.platform().compile_rename_table(table, new_name)
        if self._dry:
            self._sql = sql
            return sql
        return bool(self.new_connection().query(sql, ()))

    def truncate(self, table, foreign_keys=False):
        """Truncates a table, optionally disabling foreign key checks around
        the truncate. Returns the SQL in dry mode, otherwise a bool."""
        sql = self.platform().compile_truncate(table, foreign_keys=foreign_keys)
        if self._dry:
            self._sql = sql
            return sql
        return bool(self.new_connection().query(sql, ()))

    def get_schema(self):
        """Gets the schema set on the migration class"""
        return self.schema or self.get_connection_information().get("full_details").get(
            "schema"
        )

    def get_all_tables(self):
        """Gets all tables in the database"""
        sql = self.platform().compile_get_all_tables(
            database=self.get_connection_information().get("database"),
            schema=self.get_schema(),
        )
        if self._dry:
            self._sql = sql
            return sql
        result = self.new_connection().query(sql, ())
        # Each row is a one-entry mapping; pull out the single value (the name).
        return list(map(lambda t: list(t.values())[0], result)) if result else []

    def has_table(self, table, query_only=False):
        """Checks if the a database has a specific table

        Arguments:
            table {string} -- The name of a table like 'users'

        Returns:
            bool|string -- The compiled SQL in dry mode, otherwise a bool.
        """
        sql = self.platform().compile_table_exists(
            table,
            database=self.get_connection_information().get("database"),
            schema=self.get_schema(),
        )
        if self._dry:
            self._sql = sql
            return sql
        return bool(self.new_connection().query(sql, ()))

    def enable_foreign_key_constraints(self):
        """Enables foreign key checks globally on the connection.
        Returns the SQL in dry mode, otherwise a bool."""
        sql = self.platform().enable_foreign_key_constraints()
        if self._dry:
            self._sql = sql
            return sql
        return bool(self.new_connection().query(sql, ()))

    def disable_foreign_key_constraints(self):
        """Disables foreign key checks globally on the connection.
        Returns the SQL in dry mode, otherwise a bool."""
        sql = self.platform().disable_foreign_key_constraints()
        if self._dry:
            self._sql = sql
            return sql
        return bool(self.new_connection().query(sql, ()))
================================================
FILE: src/masoniteorm/schema/Table.py
================================================
from .Column import Column
from .Constraint import Constraint
from .Index import Index
from .ForeignKeyConstraint import ForeignKeyConstraint
class Table:
    """In-memory description of a table that is being created.

    The schema blueprint pushes columns, constraints, indexes and foreign
    keys into this object; a platform class later compiles it to SQL.
    """

    def __init__(self, table):
        self.name = table
        self.added_columns = {}
        self.added_constraints = {}
        self.added_indexes = {}
        self.added_foreign_keys = {}
        self.renamed_columns = {}
        self.drop_indexes = {}
        self.foreign_keys = {}
        self.primary_key = None
        self.comment = None

    def add_column(
        self,
        name=None,
        column_type=None,
        length=None,
        values=None,
        nullable=False,
        default=None,
        signed=None,
        default_is_raw=False,
        primary=False,
        column_python_type=str,
    ):
        """Creates a Column, registers it under its name and returns it."""
        column = Column(
            name,
            column_type,
            length=length,
            nullable=nullable,
            values=values or [],
            default=default,
            signed=signed,
            default_is_raw=default_is_raw,
            column_python_type=column_python_type,
        )
        if primary:
            column.set_as_primary()
        self.added_columns.update({name: column})
        return column

    def add_constraint(self, name, constraint_type, columns=None):
        """Registers a named constraint (unique, primary_key, ...)."""
        self.added_constraints.update(
            {name: Constraint(name, constraint_type, columns=columns or [])}
        )

    def add_foreign_key(self, column, table=None, foreign_column=None, name=None):
        """Registers and returns a ForeignKeyConstraint for the column."""
        foreign_key = ForeignKeyConstraint(
            column, table, foreign_column, name=name or f"{self.name}_{column}_foreign"
        )
        self.added_foreign_keys.update({column: foreign_key})
        return foreign_key

    def get_added_foreign_keys(self):
        """Returns the mapping of column name -> ForeignKeyConstraint."""
        return self.added_foreign_keys

    def get_constraint(self, name):
        """Returns a previously added constraint by name."""
        return self.added_constraints[name]

    def get_added_constraints(self):
        """Returns the mapping of constraint name -> Constraint."""
        return self.added_constraints

    def get_added_columns(self):
        """Returns the mapping of column name -> Column."""
        return self.added_columns

    def get_renamed_columns(self):
        """Returns the mapping of old column name -> renamed Column.

        Bug fix: this previously returned ``added_columns`` (a copy-paste of
        get_added_columns), hiding any renames tracked on a plain Table.
        TableDiff already overrode this correctly.
        """
        return self.renamed_columns

    def set_primary_key(self, columns):
        """Sets the table's primary key column(s)."""
        self.primary_key = columns
        return self

    def add_index(self, column, name, index_type):
        """Registers a named index over the given column(s)."""
        self.added_indexes.update({name: Index(column, name, index_type)})

    def get_index(self, name):
        """Returns a previously added index by name."""
        return self.added_indexes[name]

    def add_comment(self, comment):
        """Attaches a table-level comment."""
        self.comment = comment
        return self
================================================
FILE: src/masoniteorm/schema/TableDiff.py
================================================
from .Column import Column
from .Table import Table
class TableDiff(Table):
    """Accumulates pending modifications to an existing table.

    Used by alter operations: instead of describing a full table it tracks
    the deltas (added/changed/dropped columns, removed indexes, etc.) that
    a platform class compiles into ALTER statements. ``from_table`` holds
    the current schema of the real table once it has been fetched.
    """

    def __init__(self, name):
        self.name = name
        self.from_table = None
        self.new_name = None
        # Index / constraint removals.
        self.removed_indexes = []
        self.removed_unique_indexes = []
        self.removed_constraints = {}
        self.dropped_foreign_keys = []
        self.dropped_primary_keys = []
        # Column-level deltas.
        self.added_columns = {}
        self.changed_columns = {}
        self.dropped_columns = []
        self.renamed_columns = {}
        # Additions.
        self.added_indexes = {}
        self.added_constraints = {}
        self.added_foreign_keys = {}
        self.comment = None

    def remove_constraint(self, name):
        """Records a constraint (looked up on the source table) for removal."""
        constraint = self.from_table.get_constraint(name)
        self.removed_constraints.update({name: constraint})

    def get_removed_constraints(self):
        """Returns the mapping of constraint name -> removed Constraint."""
        return self.removed_constraints

    def get_renamed_columns(self):
        """Returns the mapping of old column name -> renamed Column."""
        return self.renamed_columns

    def rename_column(
        self,
        original_name,
        new_name,
        column_type=None,
        length=None,
        nullable=False,
        default=None,
    ):
        """Records that original_name should become new_name."""
        replacement = Column(
            new_name,
            column_type,
            length=length,
            nullable=nullable,
            default=default,
        )
        self.renamed_columns.update({original_name: replacement})

    def remove_index(self, name):
        """Records an index name for removal."""
        self.removed_indexes.append(name)

    def remove_unique_index(self, name):
        """Records a unique index name for removal."""
        self.removed_unique_indexes.append(name)

    def drop_column(self, name):
        """Records a column name to be dropped."""
        self.dropped_columns.append(name)

    def get_dropped_columns(self):
        """Returns the list of column names to drop."""
        return self.dropped_columns

    def get_dropped_foreign_keys(self):
        """Returns the list of foreign key constraint names to drop."""
        return self.dropped_foreign_keys

    def drop_foreign(self, name):
        """Records a foreign key constraint name to be dropped."""
        self.dropped_foreign_keys.append(name)
        return self

    def drop_primary(self, name):
        """Records a primary key constraint name to be dropped."""
        self.dropped_primary_keys.append(name)
        return self

    def change_column(self, added_column):
        """Moves a previously added column into the changed-columns bucket."""
        self.added_columns.pop(added_column.name)
        self.changed_columns.update({added_column.name: added_column})

    def add_comment(self, comment):
        """Attaches a table-level comment."""
        self.comment = comment
        return self
================================================
FILE: src/masoniteorm/schema/__init__.py
================================================
from .Schema import Schema
from .Table import Table
from .Column import Column
================================================
FILE: src/masoniteorm/schema/platforms/MSSQLPlatform.py
================================================
from .Platform import Platform
from ..Table import Table
class MSSQLPlatform(Platform):
types_without_lengths = [
"integer",
"big_integer",
"tiny_integer",
"small_integer",
"medium_integer",
]
type_map = {
"string": "VARCHAR",
"char": "CHAR",
"big_increments": "BIGINT IDENTITY",
"integer": "INT",
"big_integer": "BIGINT",
"tiny_integer": "TINYINT",
"small_integer": "SMALLINT",
"medium_integer": "MEDIUMINT",
"integer_unsigned": "INT",
"big_integer_unsigned": "BIGINT",
"tiny_integer_unsigned": "TINYINT",
"small_integer_unsigned": "SMALLINT",
"medium_integer_unsigned": "MEDIUMINT",
"increments": "INT IDENTITY",
"uuid": "CHAR",
"binary": "LONGBLOB",
"boolean": "BOOLEAN",
"decimal": "DECIMAL",
"double": "DOUBLE",
"enum": "VARCHAR",
"text": "TEXT",
"tiny_text": "TINYTEXT",
"float": "FLOAT",
"geometry": "GEOMETRY",
"json": "JSON",
"jsonb": "LONGBLOB",
"inet": "VARCHAR",
"cidr": "VARCHAR",
"macaddr": "VARCHAR",
"long_text": "LONGTEXT",
"point": "POINT",
"time": "TIME",
"timestamp": "DATETIME",
"date": "DATE",
"year": "YEAR",
"datetime": "DATETIME",
"tiny_increments": "TINYINT IDENTITY",
"unsigned": "INT",
"unsigned_integer": "INT",
}
premapped_nulls = {True: "NULL", False: "NOT NULL"}
premapped_defaults = {
"current": " DEFAULT CURRENT_TIMESTAMP",
"now": " DEFAULT NOW()",
"null": " DEFAULT NULL",
}
def compile_create_sql(self, table, if_not_exists=False):
sql = []
table_create_format = (
self.create_if_not_exists_format()
if if_not_exists
else self.create_format()
)
sql.append(
table_create_format.format(
table=self.wrap_table(table.name),
columns=", ".join(self.columnize(table.get_added_columns())).strip(),
constraints=(
", "
+ ", ".join(
self.constraintize(table.get_added_constraints(), table)
)
if table.get_added_constraints()
else ""
),
foreign_keys=(
", "
+ ", ".join(
self.foreign_key_constraintize(
table.name, table.added_foreign_keys
)
)
if table.added_foreign_keys
else ""
),
)
)
if table.added_indexes:
for name, index in table.added_indexes.items():
sql.append(
"CREATE INDEX {name} ON {table}({column})".format(
name=index.name,
table=self.wrap_table(table.name),
column=",".join(index.column),
)
)
return sql
def compile_alter_sql(self, table):
sql = []
if table.added_columns:
sql.append(
self.alter_format().format(
table=self.wrap_table(table.name),
columns="ADD "
+ ", ".join(self.columnize(table.added_columns)).strip(),
)
)
if table.changed_columns:
sql.append(
self.alter_format().format(
table=self.wrap_table(table.name),
columns="ALTER COLUMN "
+ ", ".join(self.columnize(table.changed_columns)).strip(),
)
)
if table.renamed_columns:
for name, column in table.get_renamed_columns().items():
sql.append(
self.rename_column_string(table.name, name, column.name).strip()
)
if table.dropped_columns:
dropped_sql = []
for name in table.get_dropped_columns():
dropped_sql.append(self.drop_column_string().format(name=name).strip())
sql.append(
self.alter_format().format(
table=self.wrap_table(table.name),
columns="DROP COLUMN " + ", ".join(dropped_sql),
)
)
if table.added_foreign_keys:
for (
column,
foreign_key_constraint,
) in table.get_added_foreign_keys().items():
cascade = ""
if foreign_key_constraint.delete_action:
cascade += f" ON DELETE {self.foreign_key_actions.get(foreign_key_constraint.delete_action.lower())}"
if foreign_key_constraint.update_action:
cascade += f" ON UPDATE {self.foreign_key_actions.get(foreign_key_constraint.update_action.lower())}"
sql.append(
f"ALTER TABLE {self.wrap_table(table.name)} ADD "
+ self.get_foreign_key_constraint_string().format(
constraint_name=foreign_key_constraint.constraint_name,
column=self.wrap_column(column),
table=self.wrap_table(table.name),
foreign_table=self.wrap_table(
foreign_key_constraint.foreign_table
),
foreign_column=self.wrap_column(
foreign_key_constraint.foreign_column
),
cascade=cascade,
)
)
if table.dropped_foreign_keys:
constraints = table.dropped_foreign_keys
for constraint in constraints:
sql.append(
f"ALTER TABLE {self.wrap_table(table.name)} DROP CONSTRAINT {constraint}"
)
if table.added_indexes:
for name, index in table.added_indexes.items():
sql.append(
"CREATE INDEX {name} ON {table}({column})".format(
name=index.name,
table=self.wrap_table(table.name),
column=",".join(index.column),
)
)
if (
table.removed_indexes
or table.removed_unique_indexes
or table.dropped_primary_keys
):
constraints = table.removed_indexes
constraints += table.removed_unique_indexes
constraints += table.dropped_primary_keys
for constraint in constraints:
sql.append(
f"DROP INDEX {self.wrap_table(table.name)}.{self.wrap_table(constraint)}"
)
if table.added_constraints:
for name, constraint in table.added_constraints.items():
if constraint.constraint_type == "unique":
sql.append(
f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} UNIQUE({','.join(constraint.columns)})"
)
elif constraint.constraint_type == "fulltext":
pass
elif constraint.constraint_type == "primary_key":
sql.append(
f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} PRIMARY KEY ({','.join(constraint.columns)})"
)
return sql
def add_column_string(self):
return "{name} {data_type}{length}"
def drop_column_string(self):
return "{name}"
def rename_column_string(self, table, old, new):
return f"EXEC sp_rename '{table}.{old}', '{new}', 'COLUMN'"
def columnize(self, columns):
sql = []
for name, column in columns.items():
if column.length:
length = self.create_column_length(column.column_type).format(
length=column.length
)
else:
length = ""
if column.default == "":
default = " DEFAULT ''"
elif column.default in (0,):
default = f" DEFAULT {column.default}"
elif column.default in self.premapped_defaults.keys():
default = self.premapped_defaults.get(column.default)
elif column.default:
if isinstance(column.default, (str,)) and not column.default_is_raw:
default = f" DEFAULT '{column.default}'"
else:
default = f" DEFAULT {column.default}"
else:
default = ""
constraint = ""
column_constraint = ""
if column.primary:
constraint = " PRIMARY KEY"
if column.column_type == "enum":
values = ", ".join(f"'{x}'" for x in column.values)
column_constraint = f" CHECK([{column.name}] IN ({values}))"
sql.append(
self.columnize_string()
.format(
name=column.name,
data_type=self.type_map.get(column.column_type, ""),
column_constraint=column_constraint,
length=length,
constraint=constraint,
nullable=self.premapped_nulls.get(column.is_null) or "",
default=default,
)
.strip()
)
return sql
def columnize_string(self):
return "[{name}] {data_type}{length} {nullable}{default}{column_constraint}{constraint}"
def constraintize(self, constraints, table):
sql = []
for name, constraint in constraints.items():
sql.append(
getattr(
self, f"get_{constraint.constraint_type}_constraint_string"
)().format(
columns=", ".join(constraint.columns),
name_columns="_".join(constraint.columns),
constraint_name=constraint.name,
table=table.name,
)
)
return sql
def get_table_string(self):
return "[{table}]"
def get_column_string(self):
return "[{column}]"
def create_format(self):
return "CREATE TABLE {table} ({columns}{constraints}{foreign_keys})"
def create_if_not_exists_format(self):
return (
"CREATE TABLE IF NOT EXISTS {table} ({columns}{constraints}{foreign_keys})"
)
def alter_format(self):
return "ALTER TABLE {table} {columns}"
def get_foreign_key_constraint_string(self):
    """Template for a named FOREIGN KEY constraint clause."""
    fmt = (
        "CONSTRAINT {constraint_name} FOREIGN KEY ({column}) "
        "REFERENCES {foreign_table}({foreign_column}){cascade}"
    )
    return fmt
def get_primary_key_constraint_string(self):
    """Template for a named PRIMARY KEY constraint clause."""
    fmt = "CONSTRAINT {constraint_name} PRIMARY KEY ({columns})"
    return fmt
def get_unique_constraint_string(self):
    """Template for a named UNIQUE constraint clause."""
    fmt = "CONSTRAINT {constraint_name} UNIQUE ({columns})"
    return fmt
def compile_table_exists(self, table, database=None, schema=None):
    """SQL selecting the INFORMATION_SCHEMA row for *table*, if any.

    database/schema are accepted for interface parity but unused on MSSQL.
    """
    query = (
        "SELECT * FROM INFORMATION_SCHEMA.TABLES "
        f"WHERE TABLE_NAME = '{table}'"
    )
    return query
def compile_truncate(self, table, foreign_keys=False):
    """TRUNCATE statement; with foreign_keys, wrap it in NOCHECK/CHECK toggles."""
    wrapped = self.wrap_table(table)
    truncate = f"TRUNCATE TABLE {wrapped}"
    if not foreign_keys:
        return truncate
    # Per-table constraint toggling — MSSQL has no global FK switch.
    return [
        f"ALTER TABLE {wrapped} NOCHECK CONSTRAINT ALL",
        truncate,
        f"ALTER TABLE {wrapped} WITH CHECK CHECK CONSTRAINT ALL",
    ]
def compile_rename_table(self, current_name, new_name):
    """EXEC sp_rename call renaming a table."""
    old_table = self.wrap_table(current_name)
    new_table = self.wrap_table(new_name)
    return f"EXEC sp_rename {old_table}, {new_table}"
def compile_drop_table_if_exists(self, table):
    """Guarded DROP TABLE statement."""
    wrapped = self.wrap_table(table)
    return f"DROP TABLE IF EXISTS {wrapped}"
def compile_drop_table(self, table):
    """Unconditional DROP TABLE statement."""
    wrapped = self.wrap_table(table)
    return f"DROP TABLE {wrapped}"
def compile_column_exists(self, table, column):
    """SQL probing sys.columns for *column* on *table*."""
    query = (
        f"SELECT 1 FROM sys.columns WHERE Name = N'{column}' "
        f"AND Object_ID = Object_ID(N'{table}')"
    )
    return query
def compile_get_all_tables(self, database, schema=None):
    """SQL listing every table name in *database* (schema is unused here)."""
    return "SELECT name FROM {}.sys.tables".format(database)
def get_current_schema(self, connection, table_name, schema=None):
    """Return a Table blueprint for *table_name*.

    NOTE(review): unlike the MySQL/Postgres platforms, this never queries
    *connection* for existing columns — it returns an empty Table, so
    MSSQL schema introspection is effectively unimplemented here.
    """
    return Table(table_name)
def enable_foreign_key_constraints(self):
    """MSSQL has no session-wide switch to enable FK checks; emit nothing."""
    no_op = ""
    return no_op
def disable_foreign_key_constraints(self):
    """MSSQL has no session-wide switch to disable FK checks; emit nothing."""
    no_op = ""
    return no_op
================================================
FILE: src/masoniteorm/schema/platforms/MySQLPlatform.py
================================================
from ...schema import Schema
from .Platform import Platform
from ..Table import Table
import re
class MySQLPlatform(Platform):
    """Compiles schema Table blueprints into MySQL-dialect DDL statements."""

    # ENUM declares its length through its value list rather than "(n)".
    types_without_lengths = ["enum"]

    # Blueprint column types -> MySQL column types.
    type_map = {
        "string": "VARCHAR",
        "char": "CHAR",
        "integer": "INT",
        "big_integer": "BIGINT",
        "tiny_integer": "TINYINT",
        "small_integer": "SMALLINT",
        "medium_integer": "MEDIUMINT",
        "integer_unsigned": "INT UNSIGNED",
        "big_integer_unsigned": "BIGINT UNSIGNED",
        "tiny_integer_unsigned": "TINYINT UNSIGNED",
        "small_integer_unsigned": "SMALLINT UNSIGNED",
        "medium_integer_unsigned": "MEDIUMINT UNSIGNED",
        "big_increments": "BIGINT UNSIGNED AUTO_INCREMENT",
        "increments": "INT UNSIGNED AUTO_INCREMENT",
        "uuid": "CHAR",
        "binary": "LONGBLOB",
        "boolean": "BOOLEAN",
        "decimal": "DECIMAL",
        "double": "DOUBLE",
        "enum": "ENUM",
        "text": "TEXT",
        "tiny_text": "TINYTEXT",
        "float": "FLOAT",
        "geometry": "GEOMETRY",
        "json": "JSON",
        "jsonb": "LONGBLOB",
        "inet": "VARCHAR",
        "cidr": "VARCHAR",
        "macaddr": "VARCHAR",
        "long_text": "LONGTEXT",
        "point": "POINT",
        "time": "TIME",
        "timestamp": "TIMESTAMP",
        "date": "DATE",
        "year": "YEAR",
        "datetime": "DATETIME",
        "tiny_increments": "TINYINT AUTO_INCREMENT",
        "unsigned": "INT UNSIGNED",
    }

    premapped_nulls = {True: "NULL", False: "NOT NULL"}

    # Symbolic defaults ("current", "now", "null") -> ready-made SQL clauses.
    premapped_defaults = {
        "current": " DEFAULT CURRENT_TIMESTAMP",
        "now": " DEFAULT NOW()",
        "null": " DEFAULT NULL",
    }

    signed = {"unsigned": "UNSIGNED", "signed": "SIGNED"}

    def columnize(self, columns):
        """Render each column blueprint into a MySQL column-definition fragment."""
        sql = []
        for name, column in columns.items():
            if column.length:
                length = self.create_column_length(column.column_type).format(
                    length=column.length
                )
            else:
                length = ""
            # '' and 0 are falsy but are still real defaults, so they are
            # tested before the generic truthiness branch below.
            if column.default == "":
                default = " DEFAULT ''"
            elif column.default in (0,):
                default = f" DEFAULT {column.default}"
            elif column.default in self.premapped_defaults.keys():
                default = self.premapped_defaults.get(column.default)
            elif column.default:
                # Quote plain strings; raw defaults are emitted verbatim.
                if isinstance(column.default, (str,)) and not column.default_is_raw:
                    default = f" DEFAULT '{column.default}'"
                else:
                    default = f" DEFAULT {column.default}"
            else:
                default = ""
            constraint = ""
            column_constraint = ""
            if column.primary:
                constraint = "PRIMARY KEY"
            # MySQL ENUM takes its value list in place of a length.
            if column.column_type == "enum":
                values = ", ".join(f"'{x}'" for x in column.values)
                column_constraint = f"({values})"
            sql.append(
                self.columnize_string()
                .format(
                    name=self.get_column_string().format(column=column.name),
                    data_type=self.type_map.get(column.column_type, ""),
                    column_constraint=column_constraint,
                    length=length,
                    constraint=constraint,
                    nullable=self.premapped_nulls.get(column.is_null) or "",
                    default=default,
                    signed=(
                        " " + self.signed.get(column._signed) if column._signed else ""
                    ),
                    comment=(
                        "COMMENT '" + column.comment + "'" if column.comment else ""
                    ),
                )
                .strip()
            )
        return sql

    def compile_create_sql(self, table, if_not_exists=False):
        """Build the CREATE TABLE statement (plus CREATE INDEX statements)
        for *table*; returns a list of SQL strings."""
        sql = []
        table_create_format = (
            self.create_if_not_exists_format()
            if if_not_exists
            else self.create_format()
        )
        sql.append(
            table_create_format.format(
                table=self.get_table_string().format(table=table.name),
                columns=", ".join(self.columnize(table.get_added_columns())).strip(),
                constraints=(
                    ", "
                    + ", ".join(
                        self.constraintize(table.get_added_constraints(), table)
                    )
                    if table.get_added_constraints()
                    else ""
                ),
                foreign_keys=(
                    ", "
                    + ", ".join(
                        self.foreign_key_constraintize(
                            table.name, table.added_foreign_keys
                        )
                    )
                    if table.added_foreign_keys
                    else ""
                ),
                comment=f" COMMENT '{table.comment}'" if table.comment else "",
            )
        )
        # Indexes cannot be declared inline in the CREATE TABLE body.
        if table.added_indexes:
            for name, index in table.added_indexes.items():
                sql.append(
                    "CREATE INDEX {name} ON {table}({column})".format(
                        name=index.name,
                        table=self.wrap_table(table.name),
                        column=",".join(index.column),
                    )
                )
        return sql

    def compile_alter_sql(self, table):
        """Build the list of ALTER TABLE (and related) statements that apply
        every pending change recorded on *table*."""
        sql = []

        if table.added_columns:
            add_columns = []
            for name, column in table.get_added_columns().items():
                if column.length:
                    length = self.create_column_length(column.column_type).format(
                        length=column.length
                    )
                else:
                    length = ""
                default = ""
                # NOTE(review): unlike columnize(), this branch does not
                # handle default == "" and ignores column.default_is_raw —
                # confirm whether the asymmetry is intentional.
                if column.default in (0,):
                    default = f" DEFAULT {column.default}"
                elif column.default in self.premapped_defaults.keys():
                    default = self.premapped_defaults.get(column.default)
                elif column.default:
                    if isinstance(column.default, (str,)):
                        default = f" DEFAULT '{column.default}'"
                    else:
                        default = f" DEFAULT {column.default}"
                else:
                    default = ""
                column_constraint = ""
                if column.column_type == "enum":
                    values = ", ".join(f"'{x}'" for x in column.values)
                    column_constraint = f"({values})"
                add_columns.append(
                    self.add_column_string()
                    .format(
                        name=self.get_column_string().format(column=column.name),
                        data_type=self.type_map.get(column.column_type, ""),
                        column_constraint=column_constraint,
                        length=length,
                        constraint="PRIMARY KEY" if column.primary else "",
                        nullable="NULL" if column.is_null else "NOT NULL",
                        default=default,
                        signed=(
                            " " + self.signed.get(column._signed)
                            if column._signed
                            else ""
                        ),
                        after=(
                            (" AFTER " + self.wrap_column(column._after))
                            if column._after
                            else ""
                        ),
                        comment=(
                            " COMMENT '" + column.comment + "'"
                            if column.comment
                            else ""
                        ),
                    )
                    .strip()
                )
            sql.append(
                self.alter_format().format(
                    table=self.wrap_table(table.name),
                    columns=", ".join(add_columns).strip(),
                    comment=f" COMMENT '{table.comment}'" if table.comment else "",
                )
            )

        if table.renamed_columns:
            renamed_sql = []
            for name, column in table.get_renamed_columns().items():
                # NOTE(review): length is computed here but never used —
                # columnize() below re-derives it from the column itself.
                if column.length:
                    length = self.create_column_length(column.column_type).format(
                        length=column.length
                    )
                else:
                    length = ""
                renamed_sql.append(
                    self.rename_column_string()
                    .format(
                        to=self.columnize({column.name: column})[0],
                        old=self.get_column_string().format(column=name),
                    )
                    .strip()
                )
            sql.append(
                self.alter_format().format(
                    table=self.wrap_table(table.name),
                    columns=", ".join(renamed_sql).strip(),
                )
            )

        if table.changed_columns:
            sql.append(
                self.alter_format().format(
                    table=self.wrap_table(table.name),
                    columns=", ".join(
                        f"MODIFY {x}" for x in self.columnize(table.changed_columns)
                    ),
                )
            )

        if table.dropped_columns:
            dropped_sql = []
            for name in table.get_dropped_columns():
                dropped_sql.append(
                    self.drop_column_string()
                    .format(name=self.get_column_string().format(column=name))
                    .strip()
                )
            sql.append(
                self.alter_format().format(
                    table=self.wrap_table(table.name), columns=", ".join(dropped_sql)
                )
            )

        if table.added_foreign_keys:
            for (
                column,
                foreign_key_constraint,
            ) in table.get_added_foreign_keys().items():
                cascade = ""
                if foreign_key_constraint.delete_action:
                    cascade += f" ON DELETE {self.foreign_key_actions.get(foreign_key_constraint.delete_action.lower())}"
                if foreign_key_constraint.update_action:
                    cascade += f" ON UPDATE {self.foreign_key_actions.get(foreign_key_constraint.update_action.lower())}"
                sql.append(
                    f"ALTER TABLE {self.wrap_table(table.name)} ADD "
                    + self.get_foreign_key_constraint_string().format(
                        column=column,
                        constraint_name=foreign_key_constraint.constraint_name,
                        table=table.name,
                        foreign_table=foreign_key_constraint.foreign_table,
                        foreign_column=foreign_key_constraint.foreign_column,
                        cascade=cascade,
                    )
                )

        if table.dropped_foreign_keys:
            constraints = table.dropped_foreign_keys
            for constraint in constraints:
                sql.append(
                    f"ALTER TABLE {self.wrap_table(table.name)} DROP FOREIGN KEY {constraint}"
                )

        if table.added_indexes:
            for name, index in table.added_indexes.items():
                sql.append(
                    "CREATE INDEX {name} ON {table}({column})".format(
                        name=index.name,
                        table=self.wrap_table(table.name),
                        column=",".join(index.column),
                    )
                )

        if table.added_constraints:
            for name, constraint in table.added_constraints.items():
                if constraint.constraint_type == "unique":
                    sql.append(
                        f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT UNIQUE INDEX {constraint.name}({','.join(constraint.columns)})"
                    )
                elif constraint.constraint_type == "fulltext":
                    sql.append(
                        f"ALTER TABLE {self.wrap_table(table.name)} ADD FULLTEXT {constraint.name}({','.join(constraint.columns)})"
                    )
                elif constraint.constraint_type == "primary_key":
                    sql.append(
                        f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} PRIMARY KEY ({','.join(constraint.columns)})"
                    )

        # MySQL drops plain, unique and primary-key indexes with the same
        # DROP INDEX clause.
        if (
            table.removed_indexes
            or table.removed_unique_indexes
            or table.dropped_primary_keys
        ):
            constraints = table.removed_indexes
            constraints += table.removed_unique_indexes
            constraints += table.dropped_primary_keys
            for constraint in constraints:
                sql.append(
                    f"ALTER TABLE {self.wrap_table(table.name)} DROP INDEX {constraint}"
                )

        if table.comment:
            sql.append(
                f"ALTER TABLE {self.wrap_table(table.name)} COMMENT '{table.comment}'"
            )

        return sql

    def add_column_string(self):
        """Template for an ADD column clause inside ALTER TABLE."""
        return "ADD {name} {data_type}{length}{column_constraint}{signed} {nullable}{default}{after}{comment}"

    def drop_column_string(self):
        """Template for a DROP COLUMN clause."""
        return "DROP COLUMN {name}"

    def change_column_string(self):
        """Template for a MODIFY column clause."""
        return "MODIFY {name}{data_type}{length}{column_constraint} {nullable}{default} {constraint}"

    def rename_column_string(self):
        """Template for a CHANGE (rename) clause; {to} is a full definition."""
        return "CHANGE {old} {to}"

    def columnize_string(self):
        """Template for a single backtick-quoted column definition."""
        return "{name} {data_type}{length}{column_constraint}{signed} {nullable}{default} {constraint}{comment}"

    def constraintize(self, constraints, table):
        """Render each table-level constraint via its type-specific template."""
        sql = []
        for name, constraint in constraints.items():
            sql.append(
                getattr(
                    self, f"get_{constraint.constraint_type}_constraint_string"
                )().format(
                    columns=", ".join(constraint.columns),
                    name_columns="_".join(constraint.columns),
                    table=table.name,
                    constraint_name=constraint.name,
                )
            )
        return sql

    def get_table_string(self):
        """MySQL quotes table identifiers with backticks."""
        return "`{table}`"

    def get_column_string(self):
        """MySQL quotes column identifiers with backticks."""
        return "`{column}`"

    def create_format(self):
        """Template for an unconditional CREATE TABLE statement."""
        return "CREATE TABLE {table} ({columns}{constraints}{foreign_keys}){comment}"

    def create_if_not_exists_format(self):
        """Template for a guarded CREATE TABLE statement."""
        return "CREATE TABLE IF NOT EXISTS {table} ({columns}{constraints}{foreign_keys}){comment}"

    def alter_format(self):
        """Template for an ALTER TABLE statement."""
        return "ALTER TABLE {table} {columns}"

    def get_foreign_key_constraint_string(self):
        """Template for a named FOREIGN KEY constraint clause."""
        return "CONSTRAINT {constraint_name} FOREIGN KEY ({column}) REFERENCES {foreign_table}({foreign_column}){cascade}"

    def get_primary_key_constraint_string(self):
        """Template for a named PRIMARY KEY constraint clause."""
        return "CONSTRAINT {constraint_name} PRIMARY KEY ({columns})"

    def get_unique_constraint_string(self):
        """Template for a named UNIQUE constraint clause."""
        return "CONSTRAINT {constraint_name} UNIQUE ({columns})"

    def compile_table_exists(self, table, database=None, schema=None):
        """SQL selecting the information_schema row for *table* in *database*."""
        return f"SELECT * from information_schema.tables where table_name='{table}' AND table_schema = '{database}'"

    def compile_truncate(self, table, foreign_keys=False):
        """TRUNCATE statement; with foreign_keys, bracket it with FK toggles."""
        if not foreign_keys:
            return f"TRUNCATE {self.wrap_table(table)}"
        return [
            self.disable_foreign_key_constraints(),
            f"TRUNCATE {self.wrap_table(table)}",
            self.enable_foreign_key_constraints(),
        ]

    def compile_rename_table(self, current_name, new_name):
        """ALTER TABLE ... RENAME TO statement."""
        return f"ALTER TABLE {self.wrap_table(current_name)} RENAME TO {self.wrap_table(new_name)}"

    def compile_drop_table_if_exists(self, table):
        """Guarded DROP TABLE statement."""
        return f"DROP TABLE IF EXISTS {self.wrap_table(table)}"

    def compile_drop_table(self, table):
        """Unconditional DROP TABLE statement."""
        return f"DROP TABLE {self.wrap_table(table)}"

    def compile_column_exists(self, table, column):
        """SQL probing information_schema.columns for *column* on *table*."""
        return f"SELECT column_name FROM information_schema.columns WHERE table_name='{table}' and column_name='{column}'"

    def compile_get_all_tables(self, database, schema=None):
        """SQL listing every table name in *database*."""
        return f"SELECT table_name FROM information_schema.tables WHERE table_schema = '{database}'"

    def get_current_schema(self, connection, table_name, schema=None):
        """Introspect *table_name* via DESCRIBE and return a Table blueprint
        populated with the existing columns."""
        table = Table(table_name)
        sql = f"DESCRIBE {table_name}"
        result = connection.query(sql, ())
        # Invert type_map to translate MySQL types back to blueprint types.
        reversed_type_map = {v: k for k, v in self.type_map.items()}
        for column in result:
            column_type = self.get_column_type(
                reversed_type_map, column["Type"].upper()
            )
            length = self.get_column_length(column["Type"])
            default = column.get("Default")
            table.add_column(
                column["Field"],
                column_type,
                column_python_type=Schema._type_hints_map.get(column_type, str),
                default=default,
                length=length,
            )
        return table

    def get_column_type(self, reversed_type_map, column_type):
        """Translate a raw MySQL type string (e.g. "CHAR(36)") back to the
        blueprint type name, special-casing uuid and year columns."""
        # BUGFIX: read the declared length from the raw string BEFORE
        # stripping the parenthesised part. Previously the length was taken
        # from the already-truncated name, so it was always None and the
        # CHAR(36) -> uuid / VARCHAR(4) -> year branches could never fire.
        length = self.get_column_length(column_type)
        if "(" in column_type:
            parenthesis_index = column_type.find("(")
            column_type = column_type[:parenthesis_index]
        if column_type == "CHAR":
            if length == "1":
                return "char"
            elif length == "36":
                # Convention: CHAR(36) stores a canonical UUID string.
                return "uuid"
            else:
                return "char"
        elif column_type == "VARCHAR":
            if length == "4":
                return "year"
            else:
                return "string"
        return reversed_type_map.get(column_type)

    def get_column_length(self, raw_column_type):
        """Extract the "(n)" length from a raw type string, or None."""
        regex = re.compile(r"^\w+\((\d+)\)")
        match = regex.match(raw_column_type)
        if match:
            return match.groups()[0]
        # Fallback for multi-word types the regex does not cover.
        if "(" in raw_column_type:
            parenthesis_index = raw_column_type.find("(")
            return raw_column_type[parenthesis_index + 1 : -1]
        else:
            return None

    def enable_foreign_key_constraints(self):
        """Session-level statement re-enabling FK checks."""
        return "SET FOREIGN_KEY_CHECKS=1"

    def disable_foreign_key_constraints(self):
        """Session-level statement disabling FK checks."""
        return "SET FOREIGN_KEY_CHECKS=0"
================================================
FILE: src/masoniteorm/schema/platforms/Platform.py
================================================
class Platform:
    """Base class for the SQL-dialect platforms.

    Subclasses supply the dialect specifics — ``type_map``,
    ``premapped_defaults``, ``premapped_nulls``, ``types_without_lengths``
    and the various ``*_string`` template methods — while this base class
    holds the shared rendering helpers.
    """

    # User-facing ON DELETE / ON UPDATE action names -> SQL keywords.
    # FIX: the original literal listed "cascade" twice; the duplicate key
    # was silently overwritten by the dict literal and has been removed.
    foreign_key_actions = {
        "cascade": "CASCADE",
        "set null": "SET NULL",
        "restrict": "RESTRICT",
        "no action": "NO ACTION",
        "default": "SET DEFAULT",
    }

    signed = {"signed": "SIGNED", "unsigned": "UNSIGNED"}

    def columnize(self, columns):
        """Render each column blueprint into a column-definition fragment.

        Generic fallback implementation; most subclasses override it.
        NOTE(review): unlike the dialect overrides, this version has no
        special case for an explicit empty-string default — confirm whether
        that matters for any platform that falls through to it.
        """
        sql = []
        for name, column in columns.items():
            if column.length:
                length = self.create_column_length(column.column_type).format(
                    length=column.length
                )
            else:
                length = ""
            # 0 is falsy but still a real default, so test it first.
            if column.default in (0,):
                default = f" DEFAULT {column.default}"
            elif column.default in self.premapped_defaults.keys():
                default = self.premapped_defaults.get(column.default)
            elif column.default:
                # Quote plain strings; raw defaults are emitted verbatim.
                if isinstance(column.default, (str,)) and not column.default_is_raw:
                    default = f" DEFAULT '{column.default}'"
                else:
                    default = f" DEFAULT {column.default}"
            else:
                default = ""
            sql.append(
                self.columnize_string()
                .format(
                    name=column.name,
                    data_type=self.type_map.get(column.column_type, ""),
                    length=length,
                    constraint="PRIMARY KEY" if column.primary else "",
                    nullable=self.premapped_nulls.get(column.is_null) or "",
                    default=default,
                )
                .strip()
            )
        return sql

    def columnize_string(self):
        """Subclasses must provide the per-column format template."""
        raise NotImplementedError

    def create_column_length(self, column_type):
        """Return the "(n)" format template, or "" for length-less types."""
        if column_type in self.types_without_lengths:
            return ""
        return "({length})"

    def foreign_key_constraintize(self, table, foreign_keys):
        """Render each foreign-key blueprint into a CONSTRAINT clause,
        including any ON DELETE / ON UPDATE actions."""
        sql = []
        for name, foreign_key in foreign_keys.items():
            cascade = ""
            if foreign_key.delete_action:
                cascade += f" ON DELETE {self.foreign_key_actions.get(foreign_key.delete_action.lower())}"
            if foreign_key.update_action:
                cascade += f" ON UPDATE {self.foreign_key_actions.get(foreign_key.update_action.lower())}"
            sql.append(
                self.get_foreign_key_constraint_string().format(
                    column=self.wrap_column(foreign_key.column),
                    constraint_name=foreign_key.constraint_name,
                    table=self.wrap_table(table),
                    foreign_table=self.wrap_table(foreign_key.foreign_table),
                    foreign_column=self.wrap_column(foreign_key.foreign_column),
                    cascade=cascade,
                )
            )
        return sql

    def constraintize(self, constraints):
        """Render each table-level constraint via its type-specific template.

        NOTE(review): subclasses override this with an extra ``table``
        parameter — this base signature is only safe for callers that use
        templates referencing ``{columns}`` alone.
        """
        sql = []
        for name, constraint in constraints.items():
            sql.append(
                getattr(
                    self, f"get_{constraint.constraint_type}_constraint_string"
                )().format(columns=", ".join(constraint.columns))
            )
        return sql

    def wrap_table(self, table_name):
        """Quote *table_name* using the dialect's identifier style."""
        return self.get_table_string().format(table=table_name)

    def wrap_column(self, column_name):
        """Quote *column_name* using the dialect's identifier style."""
        return self.get_column_string().format(column=column_name)
================================================
FILE: src/masoniteorm/schema/platforms/PostgresPlatform.py
================================================
from ...schema import Schema
from .Platform import Platform
from ..Table import Table
class PostgresPlatform(Platform):
    """Compiles schema Table blueprints into PostgreSQL-dialect DDL."""

    # Integer types (and network/uuid types) take no "(n)" length suffix.
    types_without_lengths = [
        "integer",
        "big_integer",
        "tiny_integer",
        "small_integer",
        "medium_integer",
        "inet",
        "cidr",
        "macaddr",
        "uuid",
    ]

    # Blueprint column types -> Postgres column types.
    # NOTE(review): TINYINT/MEDIUMINT/YEAR are not native Postgres types —
    # confirm these mappings against the target server version.
    type_map = {
        "string": "VARCHAR",
        "char": "CHAR",
        "integer": "INTEGER",
        "big_integer": "BIGINT",
        "tiny_integer": "TINYINT",
        "big_increments": "BIGSERIAL UNIQUE",
        "small_integer": "SMALLINT",
        "medium_integer": "MEDIUMINT",
        # Postgres database does not implement unsigned types
        # So the below types are the same as the normal ones
        "integer_unsigned": "INTEGER",
        "big_integer_unsigned": "BIGINT",
        "tiny_integer_unsigned": "TINYINT",
        "small_integer_unsigned": "SMALLINT",
        "medium_integer_unsigned": "MEDIUMINT",
        "increments": "SERIAL UNIQUE",
        "uuid": "UUID",
        "binary": "BYTEA",
        "boolean": "BOOLEAN",
        "decimal": "DECIMAL",
        "double": "DOUBLE PRECISION",
        "enum": "VARCHAR",
        "text": "TEXT",
        "tiny_text": "TEXT",
        "float": "FLOAT",
        "geometry": "GEOMETRY",
        "json": "JSON",
        "jsonb": "JSONB",
        "inet": "INET",
        "cidr": "CIDR",
        "macaddr": "MACADDR",
        "long_text": "TEXT",
        "point": "POINT",
        "time": "TIME",
        "timestamp": "TIMESTAMP",
        "date": "DATE",
        "year": "YEAR",
        "datetime": "TIMESTAMPTZ",
        "tiny_increments": "TINYINT AUTO_INCREMENT",
        "unsigned": "INT",
    }

    # Extra reverse-mapping entries for information_schema data_type names
    # that do not appear verbatim in type_map.
    table_info_map = {
        "CHARACTER VARYING": "string",
        "TIMESTAMP WITH TIME ZONE": "datetime",
        "TIMESTAMP WITHOUT TIME ZONE": "datetime",
    }

    # Symbolic defaults ("current", "now", "null") -> ready-made SQL clauses.
    premapped_defaults = {
        "current": " DEFAULT CURRENT_TIMESTAMP",
        "now": " DEFAULT NOW()",
        "null": " DEFAULT NULL",
    }

    premapped_nulls = {True: "NULL", False: "NOT NULL"}

    def compile_create_sql(self, table, if_not_exists=False):
        """Build the CREATE TABLE statement plus any CREATE INDEX and
        COMMENT ON statements for *table*; returns a list of SQL strings."""
        sql = []
        table_create_format = (
            self.create_if_not_exists_format()
            if if_not_exists
            else self.create_format()
        )
        sql.append(
            table_create_format.format(
                table=self.wrap_table(table.name),
                columns=", ".join(self.columnize(table.get_added_columns())).strip(),
                constraints=(
                    ", "
                    + ", ".join(
                        self.constraintize(table.get_added_constraints(), table)
                    )
                    if table.get_added_constraints()
                    else ""
                ),
                foreign_keys=(
                    ", "
                    + ", ".join(
                        self.foreign_key_constraintize(
                            table.name, table.added_foreign_keys
                        )
                    )
                    if table.added_foreign_keys
                    else ""
                ),
            )
        )
        # Indexes are created with separate statements.
        if table.added_indexes:
            for name, index in table.added_indexes.items():
                sql.append(
                    "CREATE INDEX {name} ON {table}({column})".format(
                        name=index.name,
                        table=self.wrap_table(table.name),
                        column=",".join(index.column),
                    )
                )
        # Postgres attaches comments via COMMENT ON, not inline.
        for name, column in table.get_added_columns().items():
            if column.comment:
                sql.append(
                    f"""COMMENT ON COLUMN "{table.name}"."{name}" is '{column.comment}'"""
                )
        if table.comment:
            sql.append(f"""COMMENT ON TABLE "{table.name}" is '{table.comment}'""")
        return sql

    def columnize(self, columns):
        """Render each column blueprint into a Postgres column-definition
        fragment; returns a list of strings."""
        sql = []
        for name, column in columns.items():
            # "(length)" suffix unless the type is declared length-less.
            if column.length:
                length = self.create_column_length(column.column_type).format(
                    length=column.length
                )
            else:
                length = ""
            # '' and 0 are falsy but still real defaults — test them first.
            if column.default == "":
                default = " DEFAULT ''"
            elif column.default in (0,):
                default = f" DEFAULT {column.default}"
            elif column.default in self.premapped_defaults.keys():
                default = self.premapped_defaults.get(column.default)
            elif column.default:
                # Quote plain strings; raw defaults are emitted verbatim.
                if isinstance(column.default, (str,)) and not column.default_is_raw:
                    default = f" DEFAULT '{column.default}'"
                else:
                    default = f" DEFAULT {column.default}"
            else:
                default = ""
            constraint = ""
            column_constraint = ""
            if column.primary:
                constraint = "PRIMARY KEY"
            # "enum" is emulated with VARCHAR plus a CHECK constraint.
            if column.column_type == "enum":
                values = ", ".join(f"'{x}'" for x in column.values)
                column_constraint = f" CHECK({column.name} IN ({values}))"
            sql.append(
                self.columnize_string()
                .format(
                    name=self.wrap_column(column.name),
                    data_type=self.type_map.get(column.column_type, ""),
                    column_constraint=column_constraint,
                    length=length,
                    constraint=constraint,
                    nullable=self.premapped_nulls.get(column.is_null) or "",
                    default=default,
                )
                .strip()
            )
        return sql

    def compile_alter_sql(self, table):
        """Build the list of ALTER TABLE (and related) statements applying
        every pending change recorded on *table*."""
        sql = []

        if table.added_columns:
            add_columns = []
            for name, column in table.get_added_columns().items():
                if column.length:
                    length = self.create_column_length(column.column_type).format(
                        length=column.length
                    )
                else:
                    length = ""
                default = ""
                # NOTE(review): unlike columnize(), this branch does not
                # handle default == "" and ignores column.default_is_raw —
                # confirm whether the asymmetry is intentional.
                if column.default in (0,):
                    default = f" DEFAULT {column.default}"
                elif column.default in self.premapped_defaults.keys():
                    default = self.premapped_defaults.get(column.default)
                elif column.default:
                    if isinstance(column.default, (str,)):
                        default = f" DEFAULT '{column.default}'"
                    else:
                        default = f" DEFAULT {column.default}"
                else:
                    default = ""
                column_constraint = ""
                if column.column_type == "enum":
                    values = ", ".join(f"'{x}'" for x in column.values)
                    column_constraint = f" CHECK({column.name} IN ({values}))"
                add_columns.append(
                    self.add_column_string()
                    .format(
                        name=self.wrap_column(column.name),
                        data_type=self.type_map.get(column.column_type, ""),
                        length=length,
                        constraint="PRIMARY KEY" if column.primary else "",
                        column_constraint=column_constraint,
                        nullable="NULL" if column.is_null else "NOT NULL",
                        default=default,
                        after=(
                            # NOTE(review): Postgres has no AFTER clause in
                            # ALTER TABLE ADD COLUMN — verify this path.
                            (" AFTER " + self.wrap_column(column._after))
                            if column._after
                            else ""
                        ),
                    )
                    .strip()
                )
            sql.append(
                self.alter_format().format(
                    table=self.wrap_table(table.name),
                    columns=", ".join(add_columns).strip(),
                )
            )

        if table.renamed_columns:
            renamed_sql = []
            for name, column in table.get_renamed_columns().items():
                # NOTE(review): length is computed here but never used.
                if column.length:
                    length = self.create_column_length(column.column_type).format(
                        length=column.length
                    )
                else:
                    length = ""
                renamed_sql.append(
                    self.rename_column_string()
                    .format(
                        to=self.wrap_column(column.name), old=self.wrap_column(name)
                    )
                    .strip()
                )
            sql.append(
                self.alter_format().format(
                    table=self.wrap_table(table.name),
                    columns=", ".join(renamed_sql).strip(),
                )
            )

        if table.dropped_columns:
            dropped_sql = []
            for name in table.get_dropped_columns():
                dropped_sql.append(
                    self.drop_column_string()
                    .format(name=self.wrap_column(name))
                    .strip()
                )
            sql.append(
                self.alter_format().format(
                    table=self.wrap_table(table.name), columns=", ".join(dropped_sql)
                )
            )

        if table.changed_columns:
            changed_sql = []
            for name, column in table.changed_columns.items():
                column_constraint = ""
                if column.column_type == "enum":
                    values = ", ".join(f"'{x}'" for x in column.values)
                    column_constraint = f" CHECK({column.name} IN ({values}))"
                changed_sql.append(
                    self.modify_column_string()
                    .format(
                        name=self.wrap_column(name),
                        data_type=self.type_map.get(column.column_type),
                        column_constraint=column_constraint,
                        constraint="PRIMARY KEY" if column.primary else "",
                        length=(
                            # NOTE(review): if column.length is None this
                            # renders "(None)" — confirm the blueprint
                            # always carries a length for these types.
                            "(" + str(column.length) + ")"
                            if column.column_type not in self.types_without_lengths
                            else ""
                        ),
                    )
                    .strip()
                )
                # Nullability is changed with a separate ALTER COLUMN clause.
                if column.is_null:
                    changed_sql.append(
                        f"ALTER COLUMN {self.wrap_column(name)} DROP NOT NULL"
                    )
                else:
                    changed_sql.append(
                        f"ALTER COLUMN {self.wrap_column(name)} SET NOT NULL"
                    )
                if column.default is not None:
                    # NOTE(review): the default is interpolated unquoted —
                    # string defaults may need quoting here; verify.
                    changed_sql.append(
                        f"ALTER COLUMN {self.wrap_column(name)} SET DEFAULT {column.default}"
                    )
            sql.append(
                self.alter_format().format(
                    table=self.wrap_table(table.name), columns=", ".join(changed_sql)
                )
            )

        if table.added_foreign_keys:
            for (
                column,
                foreign_key_constraint,
            ) in table.get_added_foreign_keys().items():
                cascade = ""
                if foreign_key_constraint.delete_action:
                    cascade += f" ON DELETE {self.foreign_key_actions.get(foreign_key_constraint.delete_action.lower())}"
                if foreign_key_constraint.update_action:
                    cascade += f" ON UPDATE {self.foreign_key_actions.get(foreign_key_constraint.update_action.lower())}"
                sql.append(
                    f"ALTER TABLE {self.wrap_table(table.name)} ADD "
                    + self.get_foreign_key_constraint_string().format(
                        column=self.wrap_column(column),
                        constraint_name=foreign_key_constraint.constraint_name,
                        table=self.wrap_table(table.name),
                        foreign_table=self.wrap_table(
                            foreign_key_constraint.foreign_table
                        ),
                        foreign_column=self.wrap_column(
                            foreign_key_constraint.foreign_column
                        ),
                        cascade=cascade,
                    )
                )

        # Plain indexes are dropped standalone; constraints via ALTER TABLE.
        if table.removed_indexes:
            constraints = table.removed_indexes
            for constraint in constraints:
                sql.append(f"DROP INDEX {constraint}")

        if (
            table.dropped_foreign_keys
            or table.removed_unique_indexes
            or table.dropped_primary_keys
        ):
            constraints = table.dropped_foreign_keys
            constraints += table.removed_unique_indexes
            constraints += table.dropped_primary_keys
            for constraint in constraints:
                sql.append(
                    f"ALTER TABLE {self.wrap_table(table.name)} DROP CONSTRAINT {constraint}"
                )

        if table.added_indexes:
            for name, index in table.added_indexes.items():
                sql.append(
                    "CREATE INDEX {name} ON {table}({column})".format(
                        name=index.name,
                        table=self.wrap_table(table.name),
                        column=",".join(index.column),
                    )
                )

        if table.added_constraints:
            for name, constraint in table.added_constraints.items():
                if constraint.constraint_type == "unique":
                    sql.append(
                        f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} UNIQUE({','.join(constraint.columns)})"
                    )
                elif constraint.constraint_type == "primary_key":
                    sql.append(
                        f"ALTER TABLE {self.wrap_table(table.name)} ADD CONSTRAINT {constraint.name} PRIMARY KEY ({','.join(constraint.columns)})"
                    )

        # Comments are attached with separate COMMENT ON statements.
        for name, column in table.get_added_columns().items():
            if column.comment:
                sql.append(
                    f"""COMMENT ON COLUMN {self.wrap_table(table.name)}.{self.wrap_column(name)} is '{column.comment}'"""
                )
        if table.comment:
            sql.append(
                f"""COMMENT ON TABLE {self.wrap_table(table.name)} is '{table.comment}'"""
            )
        return sql

    def alter_format(self):
        """Template for an ALTER TABLE statement."""
        return "ALTER TABLE {table} {columns}"

    def alter_format_add_foreign_key(self):
        """Template for an ALTER TABLE statement adding a foreign key."""
        return "ALTER TABLE {table} {columns}"

    def add_column_string(self):
        """Template for an ADD COLUMN clause."""
        return "ADD COLUMN {name} {data_type}{length}{column_constraint} {nullable}{default} {constraint}"

    def drop_column_string(self):
        """Template for a DROP COLUMN clause."""
        return "DROP COLUMN {name}"

    def modify_column_string(self):
        """Template for an ALTER COLUMN ... TYPE clause."""
        return "ALTER COLUMN {name} TYPE {data_type}{length}{column_constraint} {constraint}"

    def rename_column_string(self):
        """Template for a RENAME COLUMN clause."""
        return "RENAME COLUMN {old} TO {to}"

    def columnize_string(self):
        """Template for a single double-quoted column definition."""
        return "{name} {data_type}{length}{column_constraint} {nullable}{default} {constraint}"

    def constraintize(self, constraints, table):
        """Render each table-level constraint via its type-specific template."""
        sql = []
        for name, constraint in constraints.items():
            sql.append(
                getattr(
                    self, f"get_{constraint.constraint_type}_constraint_string"
                )().format(
                    columns=", ".join(constraint.columns),
                    name_columns="_".join(constraint.columns),
                    constraint_name=constraint.name,
                    table=table.name,
                )
            )
        return sql

    def create_format(self):
        """Template for an unconditional CREATE TABLE statement."""
        return "CREATE TABLE {table} ({columns}{constraints}{foreign_keys})"

    def create_if_not_exists_format(self):
        """Template for a guarded CREATE TABLE statement."""
        return (
            "CREATE TABLE IF NOT EXISTS {table} ({columns}{constraints}{foreign_keys})"
        )

    def get_foreign_key_constraint_string(self):
        """Template for a named FOREIGN KEY constraint clause."""
        return "CONSTRAINT {constraint_name} FOREIGN KEY ({column}) REFERENCES {foreign_table}({foreign_column}){cascade}"

    def get_primary_key_constraint_string(self):
        """Template for a named PRIMARY KEY constraint clause."""
        return "CONSTRAINT {constraint_name} PRIMARY KEY ({columns})"

    def get_unique_constraint_string(self):
        """Template for a named UNIQUE constraint clause."""
        return "CONSTRAINT {constraint_name} UNIQUE ({columns})"

    def get_table_string(self):
        """Postgres quotes table identifiers with double quotes."""
        return '"{table}"'

    def get_column_string(self):
        """Postgres quotes column identifiers with double quotes."""
        return '"{column}"'

    def table_information_string(self):
        """Template for the information_schema column-listing query."""
        return "SELECT * FROM information_schema.columns WHERE table_schema = '{schema}' AND table_name = '{table}'"

    def compile_table_exists(self, table, database=None, schema=None):
        """SQL selecting the information_schema row for *table*; defaults to
        the 'public' schema when none is given."""
        return f"SELECT * from information_schema.tables where table_name='{table}' AND table_schema = '{schema or 'public'}'"

    def compile_truncate(self, table, foreign_keys=False):
        """TRUNCATE statement; with foreign_keys, bracket it with per-table
        trigger toggles (Postgres enforces FKs via system triggers)."""
        if not foreign_keys:
            return f"TRUNCATE {self.wrap_table(table)}"
        return [
            f"ALTER TABLE {self.wrap_table(table)} DISABLE TRIGGER ALL",
            f"TRUNCATE {self.wrap_table(table)}",
            f"ALTER TABLE {self.wrap_table(table)} ENABLE TRIGGER ALL",
        ]

    def compile_rename_table(self, current_name, new_name):
        """ALTER TABLE ... RENAME TO statement."""
        return f"ALTER TABLE {self.wrap_table(current_name)} RENAME TO {self.wrap_table(new_name)}"

    def compile_drop_table_if_exists(self, table):
        """Guarded DROP TABLE statement."""
        return f"DROP TABLE IF EXISTS {self.wrap_table(table)}"

    def compile_drop_table(self, table):
        """Unconditional DROP TABLE statement."""
        return f"DROP TABLE {self.wrap_table(table)}"

    def compile_column_exists(self, table, column):
        """SQL probing information_schema.columns for *column* on *table*."""
        return f"SELECT column_name FROM information_schema.columns WHERE table_name='{table}' and column_name='{column}'"

    def compile_get_all_tables(self, database=None, schema=None):
        """SQL listing public-schema table names in *database*."""
        return f"SELECT table_name FROM information_schema.tables WHERE table_schema = 'public' AND table_catalog = '{database}'"

    def get_current_schema(self, connection, table_name, schema=None):
        """Introspect *table_name* via information_schema and return a Table
        blueprint populated with the existing columns."""
        sql = self.table_information_string().format(
            table=table_name, schema=schema or "public"
        )
        # Invert type_map, then overlay information_schema-specific names.
        reversed_type_map = {v: k for k, v in self.type_map.items()}
        reversed_type_map.update(self.table_info_map)
        table = Table(table_name)
        result = connection.query(sql, ())
        for column in result:
            column_type = reversed_type_map.get(column["data_type"].upper())
            # find length
            if column.get("character_maximum_length", None):
                length = column.get("character_maximum_length")
            elif column.get("numeric_precision", None):
                length = column.get("numeric_precision")
            elif column.get("datetime_precision", None):
                length = column.get("datetime_precision")
            else:
                length = None
            # find default
            default = column.get("dflt_value", "") or column.get("column_default", "")
            # A nextval(...) default marks a serial/identity primary key.
            if default and default.startswith("nextval"):
                table.set_primary_key(column["column_name"])
                default = None
            table.add_column(
                column["column_name"],
                column_type,
                default=default,
                column_python_type=Schema._type_hints_map.get(column_type, str),
                length=length,
            )
        return table

    def enable_foreign_key_constraints(self):
        """Postgres does not allow a global way to enable foreign key constraints"""
        return ""

    def disable_foreign_key_constraints(self):
        """Postgres does not allow a global way to disable foreign key constraints"""
        return ""
================================================
FILE: src/masoniteorm/schema/platforms/SQLitePlatform.py
================================================
from ...schema import Schema
from ..Table import Table
from .Platform import Platform
class SQLitePlatform(Platform):
    """Compiles schema Table/TableDiff objects into SQLite-flavored SQL strings.

    Because SQLite's ALTER TABLE is limited, column renames/changes/drops are
    implemented with the documented rebuild dance: copy rows to a temp table,
    drop the original, recreate it with the new definition, copy back.
    """

    # SQLite integer affinities never take a length specifier.
    types_without_lengths = [
        "integer",
        "big_integer",
        "tiny_integer",
        "small_integer",
        "medium_integer",
    ]

    # Types that must never receive a SIGNED/UNSIGNED modifier.
    types_without_signs = ["decimal"]

    # Maps the ORM's abstract column types onto SQLite column type names.
    type_map = {
        "string": "VARCHAR",
        "char": "CHAR",
        "integer": "INTEGER",
        "big_integer": "BIGINT",
        "tiny_integer": "TINYINT",
        "big_increments": "BIGINT",
        "small_integer": "SMALLINT",
        "medium_integer": "MEDIUMINT",
        "integer_unsigned": "INT UNSIGNED",
        "big_integer_unsigned": "BIGINT UNSIGNED",
        "tiny_integer_unsigned": "TINYINT UNSIGNED",
        "small_integer_unsigned": "SMALLINT UNSIGNED",
        "medium_integer_unsigned": "MEDIUMINT UNSIGNED",
        "increments": "INTEGER",
        "uuid": "CHAR",
        "binary": "LONGBLOB",
        "boolean": "BOOLEAN",
        "decimal": "DECIMAL",
        "double": "DOUBLE",
        "enum": "VARCHAR",
        "text": "TEXT",
        "tiny_text": "TEXT",
        "float": "FLOAT",
        "geometry": "GEOMETRY",
        "json": "JSON",
        "jsonb": "LONGBLOB",
        "inet": "VARCHAR",
        "cidr": "VARCHAR",
        "macaddr": "VARCHAR",
        "long_text": "LONGTEXT",
        "point": "POINT",
        "time": "TIME",
        "timestamp": "TIMESTAMP",
        "date": "DATE",
        "year": "VARCHAR",
        "datetime": "DATETIME",
        "tiny_increments": "TINYINT AUTO_INCREMENT",
        "unsigned": "INT UNSIGNED",
    }

    # Special default markers mapped straight to their SQL fragments.
    premapped_defaults = {
        "current": " DEFAULT CURRENT_TIMESTAMP",
        "now": " DEFAULT NOW()",
        "null": " DEFAULT NULL",
    }

    # Nullability flag -> SQL keyword used when columnizing.
    premapped_nulls = {True: "NULL", False: "NOT NULL"}

    def compile_create_sql(self, table, if_not_exists=False):
        """Compile CREATE TABLE (plus CREATE INDEX statements) for `table`.

        Returns a list of SQL strings: the CREATE TABLE first, then one
        CREATE INDEX per added index.
        """
        sql = []
        table_create_format = (
            self.create_if_not_exists_format()
            if if_not_exists
            else self.create_format()
        )
        sql.append(
            table_create_format.format(
                table=self.get_table_string().format(table=table.name).strip(),
                columns=", ".join(self.columnize(table.get_added_columns())).strip(),
                constraints=(
                    ", " + ", ".join(self.constraintize(table.get_added_constraints()))
                    if table.get_added_constraints()
                    else ""
                ),
                foreign_keys=(
                    ", "
                    + ", ".join(
                        self.foreign_key_constraintize(
                            table.name, table.added_foreign_keys
                        )
                    )
                    if table.added_foreign_keys
                    else ""
                ),
            )
        )
        if table.added_indexes:
            for name, index in table.added_indexes.items():
                sql.append(
                    f"CREATE INDEX {index.name} ON {self.wrap_table(table.name)}({','.join(index.column)})"
                )
        return sql

    def columnize(self, columns):
        """Render each Column object into a full column definition fragment."""
        sql = []
        for name, column in columns.items():
            if column.length:
                length = self.create_column_length(column.column_type).format(
                    length=column.length
                )
            else:
                length = ""
            if column.default == "":
                default = " DEFAULT ''"
            # NOTE(review): `in (0,)` also matches False (False == 0), so a
            # boolean False default is rendered unquoted here, not below.
            elif column.default in (0,):
                default = f" DEFAULT {column.default}"
            elif column.default in self.premapped_defaults.keys():
                default = self.premapped_defaults.get(column.default)
            elif column.default:
                # Quote plain string defaults; raw defaults pass through as-is.
                if isinstance(column.default, (str,)) and not column.default_is_raw:
                    default = f" DEFAULT '{column.default}'"
                else:
                    default = f" DEFAULT {column.default}"
            else:
                default = ""
            constraint = ""
            column_constraint = ""
            if column.primary:
                constraint = "PRIMARY KEY"
            if column.column_type == "enum":
                # SQLite has no ENUM type; emulate with a CHECK constraint.
                values = ", ".join(f"'{x}'" for x in column.values)
                column_constraint = f" CHECK({column.name} IN ({values}))"
            sql.append(
                self.columnize_string()
                .format(
                    name=self.wrap_column(column.name),
                    data_type=self.type_map.get(column.column_type, ""),
                    column_constraint=column_constraint,
                    length=length,
                    signed=(
                        " " + self.signed.get(column._signed)
                        if column.column_type not in self.types_without_signs
                        and column._signed
                        else ""
                    ),
                    constraint=constraint,
                    nullable=self.premapped_nulls.get(column.is_null) or "",
                    default=default,
                )
                .strip()
            )
        return sql

    def compile_alter_sql(self, diff):
        """Compile a TableDiff into the list of statements needed to alter a table.

        Simple additions use ALTER TABLE ADD COLUMN; renames, drops, changes and
        new foreign keys trigger SQLite's temp-table rebuild procedure.
        """
        sql = []
        if diff.removed_indexes or diff.removed_unique_indexes:
            indexes = diff.removed_indexes
            indexes += diff.removed_unique_indexes
            for name in indexes:
                # NOTE(review): index name is not wrapped/quoted here, unlike
                # table names elsewhere in this class.
                sql.append("DROP INDEX {name}".format(name=name))
        if diff.added_columns:
            for name, column in diff.added_columns.items():
                default = ""
                # NOTE(review): unlike columnize(), this branch handles neither
                # empty-string defaults nor column.default_is_raw — string
                # defaults are always quoted here.
                if column.default in (0,):
                    default = f" DEFAULT {column.default}"
                elif column.default in self.premapped_defaults.keys():
                    default = self.premapped_defaults.get(column.default)
                elif column.default:
                    if isinstance(column.default, (str,)):
                        default = f" DEFAULT '{column.default}'"
                    else:
                        default = f" DEFAULT {column.default}"
                else:
                    default = ""
                constraint = ""
                column_constraint = ""
                if column.name in diff.added_foreign_keys:
                    foreign_key = diff.added_foreign_keys[column.name]
                    constraint = f" REFERENCES {self.wrap_table(foreign_key.foreign_table)}({self.wrap_column(foreign_key.foreign_column)})"
                if column.column_type == "enum":
                    # NOTE(review): CHECK formatting differs from columnize()
                    # (column name quoted, no space before the paren).
                    values = ", ".join(f"'{x}'" for x in column.values)
                    column_constraint = f" CHECK('{column.name}' IN({values}))"
                sql.append(
                    self.add_column_string()
                    .format(
                        table=self.wrap_table(diff.name),
                        name=self.wrap_column(column.name),
                        data_type=self.type_map.get(column.column_type, ""),
                        column_constraint=column_constraint,
                        nullable="NULL" if column.is_null else "NOT NULL",
                        default=default,
                        signed=(
                            " " + self.signed.get(column._signed)
                            if column.column_type not in self.types_without_signs
                            and column._signed
                            else ""
                        ),
                        constraint=constraint,
                    )
                    .strip()
                )
        if (
            diff.renamed_columns
            or diff.dropped_columns
            or diff.changed_columns
            or diff.added_foreign_keys
        ):
            # SQLite rebuild procedure: snapshot rows, drop, recreate, restore.
            original_columns = diff.from_table.added_columns
            # pop off the dropped columns. No need for them here
            # NOTE(review): this mutates diff.from_table.added_columns in
            # place; the same dict feeds the temp-table SELECT list below.
            for column in diff.dropped_columns:
                original_columns.pop(column)
            sql.append(
                "CREATE TEMPORARY TABLE __temp__{table} AS SELECT {original_column_names} FROM {table}".format(
                    table=diff.name,
                    original_column_names=", ".join(
                        diff.from_table.added_columns.keys()
                    ),
                )
            )
            sql.append("DROP TABLE {table}".format(table=self.wrap_table(diff.name)))
            # Merge renamed/changed/added definitions over the originals.
            columns = diff.from_table.added_columns
            columns.update(diff.renamed_columns)
            columns.update(diff.changed_columns)
            columns.update(diff.added_columns)
            sql.append(
                self.create_format().format(
                    table=self.get_table_string().format(table=diff.name).strip(),
                    columns=", ".join(self.columnize(columns)).strip(),
                    constraints=(
                        ", "
                        + ", ".join(self.constraintize(diff.get_added_constraints()))
                        if diff.get_added_constraints()
                        else ""
                    ),
                    foreign_keys=(
                        ", "
                        + ", ".join(
                            self.foreign_key_constraintize(
                                diff.name, diff.added_foreign_keys
                            )
                        )
                        if diff.added_foreign_keys
                        else ""
                    ),
                )
            )
            # Newly added columns have no data in the temp table, so they are
            # excluded from the INSERT ... SELECT restore below.
            for column in diff.added_columns:
                columns.pop(column)
            sql.append(
                "INSERT INTO {quoted_table} ({new_columns}) SELECT {original_column_names} FROM __temp__{table}".format(
                    quoted_table=self.wrap_table(diff.name),
                    table=diff.name,
                    new_columns=", ".join(self.columnize_names(columns)),
                    original_column_names=", ".join(
                        diff.from_table.added_columns.keys()
                    ),
                )
            )
            sql.append("DROP TABLE __temp__{table}".format(table=diff.name))
        if diff.new_name:
            sql.append(
                "ALTER TABLE {old_name} RENAME TO {new_name}".format(
                    old_name=self.wrap_table(diff.name),
                    new_name=self.wrap_table(diff.new_name),
                )
            )
        if diff.added_indexes:
            for name, index in diff.added_indexes.items():
                sql.append(
                    f"CREATE INDEX {index.name} ON {self.wrap_table(diff.name)}({','.join(index.column)})"
                )
        if diff.added_constraints:
            for name, constraint in diff.added_constraints.items():
                if constraint.constraint_type == "unique":
                    sql.append(
                        f"CREATE UNIQUE INDEX {constraint.name} ON {self.wrap_table(diff.name)}({','.join(constraint.columns if isinstance(constraint.columns, list) else [constraint.columns])})"
                    )
                elif constraint.constraint_type == "primary_key":
                    # NOTE(review): SQLite does not support ADD CONSTRAINT via
                    # ALTER TABLE; this statement would fail if executed.
                    sql.append(
                        f"ALTER TABLE {self.wrap_table(diff.name)} ADD CONSTRAINT {constraint.name} PRIMARY KEY ({','.join(constraint.columns)})"
                    )
        return sql

    def create_format(self):
        """Template for CREATE TABLE statements."""
        return "CREATE TABLE {table} ({columns}{constraints}{foreign_keys})"

    def create_if_not_exists_format(self):
        """Template for CREATE TABLE IF NOT EXISTS statements."""
        return (
            "CREATE TABLE IF NOT EXISTS {table} ({columns}{constraints}{foreign_keys})"
        )

    def get_table_string(self):
        """Template for a double-quoted table identifier."""
        return '"{table}"'

    def get_column_string(self):
        """Template for a double-quoted column identifier."""
        return '"{column}"'

    def add_column_string(self):
        """Template for ALTER TABLE ... ADD COLUMN statements."""
        return "ALTER TABLE {table} ADD COLUMN {name} {data_type}{column_constraint}{signed} {nullable}{default}{constraint}"

    def create_column_length(self, column_type):
        """Length-suffix template, or empty for types that take no length."""
        if column_type in self.types_without_lengths:
            return ""
        return "({length})"

    def columnize_string(self):
        """Template for a single column definition inside CREATE TABLE."""
        return "{name} {data_type}{length}{column_constraint}{signed} {nullable}{default} {constraint}"

    def get_unique_constraint_string(self):
        """Template for a UNIQUE table constraint."""
        return "UNIQUE({columns})"

    def get_foreign_key_constraint_string(self):
        """Template for a named FOREIGN KEY table constraint."""
        return "CONSTRAINT {constraint_name} FOREIGN KEY ({column}) REFERENCES {foreign_table}({foreign_column}){cascade}"

    def get_primary_key_constraint_string(self):
        """Template for a named PRIMARY KEY table constraint."""
        return "CONSTRAINT {constraint_name} PRIMARY KEY ({columns})"

    def constraintize(self, constraints):
        """Render each Constraint via its type-specific template method."""
        sql = []
        for name, constraint in constraints.items():
            sql.append(
                getattr(
                    self, f"get_{constraint.constraint_type}_constraint_string"
                )().format(
                    columns=", ".join(constraint.columns),
                    constraint_name=constraint.name,
                )
            )
        return sql

    def foreign_key_constraintize(self, table, foreign_keys):
        """Render FOREIGN KEY constraints, including ON DELETE/UPDATE actions."""
        sql = []
        for name, foreign_key in foreign_keys.items():
            cascade = ""
            if foreign_key.delete_action:
                cascade += f" ON DELETE {self.foreign_key_actions.get(foreign_key.delete_action.lower())}"
            if foreign_key.update_action:
                cascade += f" ON UPDATE {self.foreign_key_actions.get(foreign_key.update_action.lower())}"
            sql.append(
                self.get_foreign_key_constraint_string().format(
                    column=self.wrap_column(foreign_key.column),
                    constraint_name=foreign_key.constraint_name,
                    table=self.wrap_table(table),
                    foreign_table=self.wrap_table(foreign_key.foreign_table),
                    foreign_column=self.wrap_column(foreign_key.foreign_column),
                    cascade=cascade,
                )
            )
        return sql

    def columnize_names(self, columns):
        """Return just the wrapped column names for the given Column mapping."""
        names = []
        for name, column in columns.items():
            names.append(self.wrap_column(column.name))
        return names

    def get_current_schema(self, connection, table_name, schema=None):
        """Introspect an existing table via PRAGMA table_info into a Table object."""
        sql = f"PRAGMA table_info({table_name})"
        reversed_type_map = {v: k for k, v in self.type_map.items()}
        table = Table(table_name)
        result = connection.query(sql, ())
        for column in result:
            column_type = self.get_column_type(
                reversed_type_map, column["type"].upper()
            )
            length = self.get_column_length(column["type"])
            # find default
            default = column.get("dflt_value")
            if default:
                # PRAGMA reports string defaults with quotes; strip them.
                default = default.replace("'", "")
            table.add_column(
                column["name"],
                column_type,
                column_python_type=Schema._type_hints_map.get(column_type, str),
                default=default,
                length=length,
                nullable=int(column.get("notnull")) == 0,
            )
            if column.get("pk") == 1:
                table.set_primary_key(column["name"])
        return table

    def get_column_length(self, column_type):
        """Extract the length from a declared type like 'VARCHAR(255)', else None."""
        if "(" in column_type:
            parenthesis_index = column_type.find("(")
            return column_type[parenthesis_index + 1 : -1]
        else:
            return None

    def get_column_type(self, reversed_type_map, column_type):
        """Map a declared SQLite type back to the ORM's abstract type name.

        CHAR(36) is treated as uuid and VARCHAR(4) as year, mirroring how
        those abstract types are compiled in the first place.
        """
        if "(" in column_type:
            parenthesis_index = column_type.find("(")
            db_type = column_type[:parenthesis_index]
            length = self.get_column_length(column_type)
            if db_type == "CHAR":
                if length == "1":
                    return "char"
                elif length == "36":
                    return "uuid"
                else:
                    return "char"
            elif db_type == "VARCHAR":
                if length == "4":
                    return "year"
                else:
                    return "string"
        else:
            return reversed_type_map.get(column_type)

    def compile_table_exists(self, table, database=None, schema=None):
        """SQL to check a table's existence via sqlite_master."""
        return f"SELECT name FROM sqlite_master WHERE type='table' AND name='{table}'"

    def compile_column_exists(self, table, column):
        # NOTE(review): queries information_schema, which SQLite does not
        # provide; this looks copied from another platform.
        return f"SELECT column_name FROM information_schema.columns WHERE table_name='{table}' and column_name='{column}'"

    def compile_get_all_tables(self, database, schema=None):
        """SQL listing every table in the database."""
        return "SELECT name FROM sqlite_master WHERE type='table'"

    def compile_truncate(self, table, foreign_keys=False):
        """SQLite has no TRUNCATE; emulate with DELETE, optionally toggling FKs."""
        if not foreign_keys:
            return f"DELETE FROM {self.wrap_table(table)}"
        return [
            self.disable_foreign_key_constraints(),
            f"DELETE FROM {self.wrap_table(table)}",
            self.enable_foreign_key_constraints(),
        ]

    def compile_rename_table(self, current_table, new_name):
        """SQL to rename a table."""
        return f"ALTER TABLE {self.wrap_table(current_table)} RENAME TO {self.wrap_table(new_name)}"

    def compile_drop_table_if_exists(self, current_table):
        """SQL to drop a table only when present."""
        return f"DROP TABLE IF EXISTS {self.wrap_table(current_table)}"

    def compile_drop_table(self, current_table):
        """SQL to drop a table unconditionally."""
        return f"DROP TABLE {self.wrap_table(current_table)}"

    def enable_foreign_key_constraints(self):
        """PRAGMA that turns foreign key enforcement on for the connection."""
        return "PRAGMA foreign_keys = ON"

    def disable_foreign_key_constraints(self):
        """PRAGMA that turns foreign key enforcement off for the connection."""
        return "PRAGMA foreign_keys = OFF"
================================================
FILE: src/masoniteorm/schema/platforms/__init__.py
================================================
from .SQLitePlatform import SQLitePlatform
from .MySQLPlatform import MySQLPlatform
from .MSSQLPlatform import MSSQLPlatform
from .PostgresPlatform import PostgresPlatform
================================================
FILE: src/masoniteorm/scopes/BaseScope.py
================================================
class BaseScope:
    """Abstract contract for query scopes.

    Concrete scopes hook themselves onto a query builder in ``on_boot`` and
    detach in ``on_remove``; both must be overridden.
    """

    def on_boot(self, builder):
        """Attach this scope to the given builder. Subclasses must override."""
        raise NotImplementedError

    def on_remove(self, builder):
        """Detach this scope from the given builder. Subclasses must override."""
        raise NotImplementedError
================================================
FILE: src/masoniteorm/scopes/SoftDeleteScope.py
================================================
from .BaseScope import BaseScope
class SoftDeleteScope(BaseScope):
    """Global scope that hides soft-deleted rows and adds trash-handling macros."""

    def __init__(self, deleted_at_column="deleted_at"):
        self.deleted_at_column = deleted_at_column

    def on_boot(self, builder):
        """Register the select/delete hooks and the trash macros on the builder."""
        builder.set_global_scope("_where_null", self._where_null, action="select")
        builder.set_global_scope(
            "_query_set_null_on_delete", self._query_set_null_on_delete, action="delete"
        )
        for macro_name, handler in (
            ("with_trashed", self._with_trashed),
            ("only_trashed", self._only_trashed),
            ("force_delete", self._force_delete),
            ("restore", self._restore),
        ):
            builder.macro(macro_name, handler)

    def on_remove(self, builder):
        """Detach both global hooks from the builder."""
        builder.remove_global_scope("_where_null", action="select")
        builder.remove_global_scope("_query_set_null_on_delete", action="delete")

    def _where_null(self, builder):
        """Restrict selects to rows whose deleted-at column is still NULL."""
        qualified_column = f"{builder.get_table_name()}.{self.deleted_at_column}"
        return builder.where_null(qualified_column)

    def _with_trashed(self, model, builder):
        """Drop the NULL filter so soft-deleted rows are included too."""
        builder.remove_global_scope("_where_null", action="select")
        return builder

    def _only_trashed(self, model, builder):
        """Return only soft-deleted rows (deleted-at column is set)."""
        builder.remove_global_scope("_where_null", action="select")
        return builder.where_not_null(self.deleted_at_column)

    def _force_delete(self, model, builder, query=False):
        """Hard-delete; with query=True only set the delete action for compilation."""
        unscoped = builder.remove_global_scope(self)
        if query:
            return unscoped.set_action("delete")
        return unscoped.delete()

    def _restore(self, model, builder):
        """Un-delete by clearing the deleted-at column."""
        return builder.remove_global_scope(self).update({self.deleted_at_column: None})

    def _query_set_null_on_delete(self, builder):
        """Turn a delete into an update that stamps the deleted-at column."""
        return builder.set_action("update").set_updates(
            {self.deleted_at_column: builder._model.get_new_datetime_string()}
        )
================================================
FILE: src/masoniteorm/scopes/SoftDeletesMixin.py
================================================
from .SoftDeleteScope import SoftDeleteScope
class SoftDeletesMixin:
    """Mixin that wires soft-delete behavior into a model via SoftDeleteScope."""

    # Column stamped when a record is soft-deleted.
    __deleted_at__ = "deleted_at"

    def boot_SoftDeletesMixin(self, builder):
        """Attach a SoftDeleteScope configured with this model's column name."""
        builder.set_global_scope(SoftDeleteScope(self.__deleted_at__))

    def get_deleted_at_column(self):
        """Return the name of the soft-delete timestamp column."""
        column_name = self.__deleted_at__
        return column_name
================================================
FILE: src/masoniteorm/scopes/TimeStampsMixin.py
================================================
from .TimeStampsScope import TimeStampsScope
class TimeStampsMixin:
    """Mixin that wires automatic created/updated timestamps into a model."""

    def boot_TimeStampsMixin(self, builder):
        """Attach the timestamp scope so inserts and updates get stamped."""
        builder.set_global_scope(TimeStampsScope())

    def activate_timestamps(self, boolean=True):
        """Enable (or disable) automatic timestamping; returns self for chaining."""
        self.__timestamps__ = boolean
        return self
================================================
FILE: src/masoniteorm/scopes/TimeStampsScope.py
================================================
from ..expressions.expressions import UpdateQueryExpression
from .BaseScope import BaseScope
class TimeStampsScope(BaseScope):
    """Global scope that stamps created/updated datetime columns on insert and update."""

    def on_boot(self, builder):
        """Register the insert and update hooks that maintain timestamp columns."""
        builder.set_global_scope(
            "_timestamps", self.set_timestamp_create, action="insert"
        )
        builder.set_global_scope(
            "_timestamp_update", self.set_timestamp_update, action="update"
        )

    def on_remove(self, builder):
        # Nothing to undo; the hooks early-return when __timestamps__ is falsy.
        pass

    def set_timestamp(owner_cls, query):
        # NOTE(review): no `self` parameter -- if called as an instance method,
        # `owner_cls` binds to the scope instance. Looks like legacy/dead code;
        # confirm before removing.
        owner_cls.updated_at = "now"

    def set_timestamp_create(self, builder):
        """Add created-at/updated-at values to the pending insert (unless disabled)."""
        if not builder._model.__timestamps__:
            return builder
        builder._creates.update(
            {
                builder._model.date_updated_at: builder._model.get_new_date().to_datetime_string(),
                builder._model.date_created_at: builder._model.get_new_date().to_datetime_string(),
            }
        )

    def set_timestamp_update(self, builder):
        """Append an updated-at expression unless the caller already set that column."""
        if not builder._model.__timestamps__:
            return builder
        for update in builder._updates:
            # NOTE(review): membership test against update.column — presumably
            # the column name(s) of the expression; verify it cannot produce a
            # false positive via substring matching.
            if builder._model.date_updated_at in update.column:
                return
        # _updates is a tuple, hence the tuple-concatenation append.
        builder._updates += (
            UpdateQueryExpression(
                {
                    builder._model.date_updated_at: builder._model.get_new_date().to_datetime_string()
                }
            ),
        )
================================================
FILE: src/masoniteorm/scopes/UUIDPrimaryKeyMixin.py
================================================
from .UUIDPrimaryKeyScope import UUIDPrimaryKeyScope
class UUIDPrimaryKeyMixin:
    """Mixin that makes a model generate UUID primary keys via a global scope."""

    def boot_UUIDPrimaryKeyMixin(self, builder):
        """Attach the UUID primary-key scope to the model's builder."""
        builder.set_global_scope(UUIDPrimaryKeyScope())
================================================
FILE: src/masoniteorm/scopes/UUIDPrimaryKeyScope.py
================================================
import uuid
from .BaseScope import BaseScope
class UUIDPrimaryKeyScope(BaseScope):
    """Global scope that fills in a UUID primary key on single and bulk inserts."""

    def on_boot(self, builder):
        """Hook UUID generation into both the insert and bulk-create paths."""
        builder.set_global_scope(
            "_UUID_primary_key", self.set_uuid_create, action="insert"
        )
        builder.set_global_scope(
            "_UUID_primary_key", self.set_bulk_uuid_create, action="bulk_create"
        )

    def on_remove(self, builder):
        # No teardown required.
        pass

    def generate_uuid(self, builder, uuid_version, bytes=False):
        """Create a UUID of the requested version.

        Versions 3 and 5 pull namespace and name from the model; the result is
        raw bytes when `bytes` is truthy, otherwise the canonical string form.
        """
        factory = getattr(uuid, f"uuid{uuid_version}")
        if uuid_version in [3, 5]:
            factory_args = [builder._model.__uuid_namespace__, builder._model.__uuid_name__]
        else:
            factory_args = []
        generated = factory(*factory_args)
        return generated.bytes if bytes else str(generated)

    def build_uuid_pk(self, builder):
        """Map the model's primary-key name to a freshly generated UUID value."""
        version = getattr(builder._model, "__uuid_version__", 4)
        as_bytes = getattr(builder._model, "__uuid_bytes__", False)
        return {
            builder._model.__primary_key__: self.generate_uuid(
                builder, version, as_bytes
            )
        }

    def set_uuid_create(self, builder):
        """Fill in the primary key on a single insert unless the caller set one."""
        if builder._model.__primary_key__ not in builder._creates:
            builder._creates.update(self.build_uuid_pk(builder))

    def set_bulk_uuid_create(self, builder):
        """Fill in the primary key for each bulk-insert row that lacks one."""
        for idx, create_atts in enumerate(builder._creates):
            if builder._model.__primary_key__ not in create_atts:
                builder._creates[idx].update(self.build_uuid_pk(builder))
================================================
FILE: src/masoniteorm/scopes/__init__.py
================================================
from .scope import scope
from .BaseScope import BaseScope
from .SoftDeletesMixin import SoftDeletesMixin
from .SoftDeleteScope import SoftDeleteScope
from .TimeStampsMixin import TimeStampsMixin
from .TimeStampsScope import TimeStampsScope
from .UUIDPrimaryKeyScope import UUIDPrimaryKeyScope
from .UUIDPrimaryKeyMixin import UUIDPrimaryKeyMixin
================================================
FILE: src/masoniteorm/scopes/scope.py
================================================
class scope:
    """Descriptor that registers a decorated model method as a query scope.

    ``__set_name__`` records the callback in the owning class's ``_scopes``
    registry; calling the attribute runs the callback against a fresh model
    instance and its builder.
    """

    def __init__(self, callback, *params, **kwargs):
        self.fn = callback

    def __set_name__(self, cls, name):
        # NOTE(review): assumes the owner class exposes a `_scopes` dict.
        registry = cls._scopes
        if cls in registry:
            registry[cls].update({name: self.fn})
        else:
            registry[cls] = {name: self.fn}
        self.cls = cls

    def __call__(self, *args, **kwargs):
        model_instance = self.cls()
        builder = model_instance.get_builder()
        return self.fn(model_instance, builder, *args, **kwargs)
================================================
FILE: src/masoniteorm/seeds/Seeder.py
================================================
import pydoc
class Seeder:
    """Discovers and runs database seeder classes from a seeds directory.

    With ``dry=True`` seeders are only recorded in ``ran_seeds``; nothing is
    executed against the database.
    """

    def __init__(self, dry=False, seed_path="databases/seeds", connection=None):
        self.ran_seeds = []
        self.dry = dry
        self.seed_path = seed_path
        self.connection = connection
        # Convert the filesystem path into a dotted module path for pydoc.locate.
        self.seed_module = seed_path.replace("/", ".").replace("\\", ".")

    def call(self, *seeder_classes):
        """Record each seeder class and, unless dry, instantiate and run it."""
        for seeder_class in seeder_classes:
            self.ran_seeds.append(seeder_class)
            if not self.dry:
                seeder_class(connection=self.connection).run()

    def run_database_seed(self):
        """Locate and run the root DatabaseSeeder.

        Raises:
            ValueError: when the database_seeder module/class cannot be found.
                (Previously a missing seeder was appended as None and failed
                later with an opaque TypeError; this matches the explicit
                check already done in run_specific_seed.)
        """
        file_name = f"{self.seed_module}.database_seeder.DatabaseSeeder"
        database_seeder = pydoc.locate(file_name)
        if not database_seeder:
            raise ValueError(f"Could not find the {file_name} seeder file")
        self.ran_seeds.append(database_seeder)
        if not self.dry:
            database_seeder(connection=self.connection).run()

    def run_specific_seed(self, seed):
        """Locate and run a single seeder given as 'module.ClassName'.

        Raises:
            ValueError: when the seeder cannot be located.
        """
        file_name = f"{self.seed_module}.{seed}"
        database_seeder = pydoc.locate(file_name)
        if not database_seeder:
            raise ValueError(f"Could not find the {file_name} seeder file")
        self.ran_seeds.append(database_seeder)
        if not self.dry:
            database_seeder(connection=self.connection).run()
        else:
            print(f"Running {database_seeder}")
================================================
FILE: src/masoniteorm/seeds/__init__.py
================================================
from .Seeder import Seeder
================================================
FILE: src/masoniteorm/stubs/create-migration.html
================================================
from masoniteorm.migrations import Migration
class CreateUsersTable(Migration):
def up(self):
"""Run the migrations."""
with self.schema.create('users') as table:
pass
def down(self):
"""Revert the migrations."""
self.schema.drop('users')
================================================
FILE: src/masoniteorm/stubs/table-migration.html
================================================
================================================
FILE: src/masoniteorm/testing/BaseTestCaseSelectGrammar.py
================================================
import inspect
from ..query import QueryBuilder
from ..expressions import JoinClause
from ..models import Model
class MockConnection:
    """Minimal stand-in for a connection class used by grammar-only tests."""

    # Grammar compilation never consults real connection settings.
    connection_details = {}

    def make_connection(self):
        """Mimic the real factory by handing back this object itself."""
        return self
class BaseTestCaseSelectGrammar:
def setUp(self):
    """Build a dry QueryBuilder over a mock connection for the grammar under test.

    ``self.grammar`` is supplied by the concrete subclass; ``dry=True`` means
    queries are compiled to SQL but never executed.
    """
    self.builder = QueryBuilder(
        self.grammar,
        table="users",
        connection_class=MockConnection,
        model=Model(),
        dry=True,
    )
# Each test compiles a query and compares it against a grammar-specific
# fixture: the expected SQL comes from a subclass method named after the test
# with the "test_" prefix stripped, looked up via the current frame's function
# name — so these method names are load-bearing.
def test_can_compile_select(self):
    to_sql = self.builder.to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_order_by_and_first(self):
    to_sql = self.builder.order_by("id", "asc").first(query=True).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_columns(self):
    to_sql = self.builder.select("username", "password").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_where(self):
    to_sql = self.builder.select("username", "password").where("id", 1).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_or_where(self):
    to_sql = self.builder.where("name", 2).or_where("name", 3).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_grouped_where(self):
    to_sql = self.builder.where(
        lambda query: query.where("age", 2).where("name", "Joe")
    ).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_several_where(self):
    to_sql = (
        self.builder.select("username", "password")
        .where("id", 1)
        .where("username", "joe")
        .to_sql()
    )
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_several_where_and_limit(self):
    to_sql = (
        self.builder.select("username", "password")
        .where("id", 1)
        .where("username", "joe")
        .limit(10)
        .to_sql()
    )
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_sum(self):
    to_sql = self.builder.sum("age").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_max(self):
    to_sql = self.builder.max("age").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_max_and_columns(self):
    to_sql = self.builder.select("username").max("age").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_max_and_columns_different_order(self):
    to_sql = self.builder.max("age").select("username").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_order_by(self):
    to_sql = self.builder.select("username").order_by("age", "desc").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_multiple_order_by(self):
    to_sql = (
        self.builder.select("username")
        .order_by("age", "desc")
        .order_by("name")
        .to_sql()
    )
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_with_group_by(self):
    to_sql = self.builder.select("username").group_by("age").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_where_in(self):
    to_sql = self.builder.select("username").where_in("age", [1, 2, 3]).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_where_in_empty(self):
    to_sql = self.builder.where_in("age", []).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)
# Expected SQL for each test comes from a subclass fixture method named after
# the test with the "test_" prefix stripped (looked up via the current frame).
def test_can_compile_where_not_in(self):
    to_sql = self.builder.select("username").where_not_in("age", [1, 2, 3]).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_where_null(self):
    to_sql = self.builder.select("username").where_null("age").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_where_not_null(self):
    to_sql = self.builder.select("username").where_not_null("age").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_count(self):
    to_sql = self.builder.count("*").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_count_column(self):
    to_sql = self.builder.count("money").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_where_column(self):
    to_sql = self.builder.where_column("name", "email").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_sub_select(self):
    to_sql = self.builder.where_in(
        "name", self.builder.new().select("age")
    ).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_complex_sub_select(self):
    to_sql = self.builder.where_in(
        "name",
        (
            self.builder.new()
            .select("age")
            .where_in("email", self.builder.new().select("email"))
        ),
    ).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_sub_select_where(self):
    to_sql = self.builder.where_in(
        "age", self.builder.new().select("age").where("age", 2).where("name", "Joe")
    ).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_sub_select_from_lambda(self):
    to_sql = (
        self.builder.new()
        .where_in(
            "age", lambda q: (q.select("age").where("age", 2).where("name", "Joe"))
        )
        .to_sql()
    )
    # Deliberately reuses the sub_select_where fixture: the lambda form must
    # compile to the same SQL as the explicit sub-builder form above.
    sql = getattr(self, "can_compile_sub_select_where")()
    self.assertEqual(to_sql, sql)

def test_can_compile_sub_select_value(self):
    to_sql = self.builder.where("name", self.builder.new().sum("age")).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_exists(self):
    to_sql = (
        self.builder.select("age")
        .where_exists(self.builder.new().select("username").where("age", 12))
        .to_sql()
    )
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_not_exists(self):
    to_sql = (
        self.builder.select("age")
        .where_not_exists(self.builder.new().select("username").where("age", 12))
        .to_sql()
    )
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_having(self):
    to_sql = self.builder.sum("age").group_by("age").having("age").to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)
# Expected SQL for each test comes from a subclass fixture method named after
# the test with the "test_" prefix stripped (looked up via the current frame).
def test_can_compile_having_with_expression(self):
    to_sql = self.builder.sum("age").group_by("age").having("age", 10).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_having_with_greater_than_expression(self):
    to_sql = self.builder.sum("age").group_by("age").having("age", ">", 10).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_join(self):
    to_sql = self.builder.join(
        "contacts", "users.id", "=", "contacts.user_id"
    ).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_join_clause(self):
    clause = (
        JoinClause("report_groups as rg")
        .on("bgt.fund", "=", "rg.fund")
        .on("bgt.dept", "=", "rg.dept")
        .on("bgt.acct", "=", "rg.acct")
        .on("bgt.sub", "=", "rg.sub")
    )
    to_sql = self.builder.join(clause).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_join_clause_with_value(self):
    clause = (
        JoinClause("report_groups as rg")
        .on_value("bgt.active", "=", "1")
        .or_on_value("bgt.acct", "=", "1234")
    )
    to_sql = self.builder.join(clause).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_join_clause_with_null(self):
    clause = (
        JoinClause("report_groups as rg")
        .on_null("bgt.acct")
        .or_on_null("bgt.dept")
        .on_value("rg.abc", 10)
    )
    to_sql = self.builder.join(clause).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_join_clause_with_not_null(self):
    clause = (
        JoinClause("report_groups as rg")
        .on_not_null("bgt.acct")
        .or_on_not_null("bgt.dept")
        .on_value("rg.abc", 10)
    )
    to_sql = self.builder.join(clause).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_join_clause_with_lambda(self):
    to_sql = self.builder.join(
        "report_groups as rg",
        lambda clause: (clause.on("bgt.fund", "=", "rg.fund").on_null("bgt")),
    ).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_left_join_clause_with_lambda(self):
    to_sql = self.builder.left_join(
        "report_groups as rg",
        lambda clause: (clause.on("bgt.fund", "=", "rg.fund").or_on_null("bgt")),
    ).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)

def test_can_compile_right_join_clause_with_lambda(self):
    to_sql = self.builder.right_join(
        "report_groups as rg",
        lambda clause: (clause.on("bgt.fund", "=", "rg.fund").or_on_null("bgt")),
    ).to_sql()
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(to_sql, sql)
def test_can_compile_left_join(self):
to_sql = self.builder.left_join(
"contacts", "users.id", "=", "contacts.user_id"
).to_sql()
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(to_sql, sql)
def test_can_compile_multiple_join(self):
to_sql = (
self.builder.join("contacts", "users.id", "=", "contacts.user_id")
.join("posts", "comments.post_id", "=", "posts.id")
.to_sql()
)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(to_sql, sql)
def test_can_compile_limit_and_offset(self):
to_sql = self.builder.limit(10).offset(10).to_sql()
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(to_sql, sql)
def test_can_compile_between(self):
to_sql = self.builder.between("age", 18, 21).to_sql()
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(to_sql, sql)
def test_can_compile_not_between(self):
to_sql = self.builder.not_between("age", 18, 21).to_sql()
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(to_sql, sql)
def test_can_user_where_raw_and_where(self):
    """where_raw() should compose with a regular where()."""
    query = self.builder.where_raw("age = '18'").where("name", "=", "James")
    expected = self.can_user_where_raw_and_where()
    self.assertEqual(query.to_sql(), expected)
def test_can_compile_where_raw_and_where_with_multiple_bindings(self):
    """Bindings from where_raw() should precede bindings added by where()."""
    query = self.builder.where_raw(
        "`age` = '?' AND `is_admin` = '?'", [18, True]
    ).where("email", "test@example.com")
    expected = self.can_compile_where_raw_and_where_with_multiple_bindings()
    self.assertEqual(query.to_qmark(), expected)
    # Order matters: raw bindings first, then the where() value.
    self.assertEqual(query._bindings, [18, True, "test@example.com"])
def test_can_compile_first_or_fail(self):
    """first_or_fail(query=True) should return the builder so SQL can be inspected."""
    query = self.builder.where("is_admin", "=", True).first_or_fail(query=True)
    self.assertEqual(query.to_sql(), self.can_compile_first_or_fail())
def test_where_like(self):
    """where() with the 'like' operator should compile a LIKE clause."""
    built = self.builder.where("age", "like", "%name%").to_sql()
    self.assertEqual(built, self.where_like())
def test_where_regexp(self):
    """where() with the 'regexp' operator should compile a REGEXP clause."""
    built = self.builder.where("age", "regexp", "Joe").to_sql()
    self.assertEqual(built, self.where_regexp())
def test_where_exists_with_lambda(self):
    """where_exists() with a lambda subquery should compile the expected SQL.

    Fix: removed a leftover ``print(to_sql)`` debug statement that polluted
    the test runner's output.
    """
    to_sql = self.builder.where_exists(lambda q: q.where("age", 1)).to_sql()
    expected = self.where_exists_with_lambda()
    self.assertEqual(to_sql, expected)
def test_where_not_exists_with_lambda(self):
    """where_not_exists() with a lambda subquery should compile the expected SQL.

    Fix: removed a leftover ``print(to_sql)`` debug statement that polluted
    the test runner's output.
    """
    to_sql = self.builder.where_not_exists(lambda q: q.where("age", 1)).to_sql()
    expected = self.where_not_exists_with_lambda()
    self.assertEqual(to_sql, expected)
def test_where_not_regexp(self):
    """where() with the 'not regexp' operator should compile a NOT REGEXP clause."""
    built = self.builder.where("age", "not regexp", "Joe").to_sql()
    self.assertEqual(built, self.where_not_regexp())
def test_where_not_like(self):
    """where() with the 'not like' operator should compile a NOT LIKE clause."""
    built = self.builder.where("age", "not like", "%name%").to_sql()
    self.assertEqual(built, self.where_not_like())
def test_shared_lock(self):
    """shared_lock() should append the grammar's shared-lock suffix."""
    built = self.builder.where("votes", ">=", 100).shared_lock().to_sql()
    self.assertEqual(built, self.shared_lock())
def test_update_lock(self):
    """lock_for_update() should append the grammar's update-lock suffix."""
    built = self.builder.where("votes", ">=", 100).lock_for_update().to_sql()
    self.assertEqual(built, self.update_lock())
def test_where_date(self):
    """where_date() should compile a date comparison clause."""
    built = self.builder.where_date("created_at", "2022-06-01").to_sql()
    self.assertEqual(built, self.where_date())
def test_or_where_null(self):
    """or_where_null() should OR a second IS NULL clause onto the first."""
    built = self.builder.where_null("column1").or_where_null("column2").to_sql()
    self.assertEqual(built, self.or_where_null())
def test_select_distinct(self):
    """distinct() should compile a SELECT DISTINCT on the chosen column."""
    built = self.builder.select("group").distinct().to_sql()
    self.assertEqual(built, self.select_distinct())
================================================
FILE: src/masoniteorm/testing/__init__.py
================================================
from .BaseTestCaseSelectGrammar import BaseTestCaseSelectGrammar
================================================
FILE: tests/User.py
================================================
""" User Model """
from src.masoniteorm import Model
class User(Model):
    """User Model"""

    # Mass-assignable columns for this model.
    __fillable__ = ["name", "email", "password"]
    # Named connection resolved against the test database config
    # (presumably the sqlite "t" connection defined in
    # tests/integrations/config/database.py — confirm if that file moves).
    __connection__ = "t"
    # Column used to identify the user for authentication.
    __auth__ = "email"

    @property
    def meta(self):
        # Constant accessor; exercised by the collection serialize/append tests.
        return 1
================================================
FILE: tests/collection/test_collection.py
================================================
import unittest
from src.masoniteorm.collection import Collection
from src.masoniteorm.factories import Factory as factory
from src.masoniteorm.models import Model
from tests.User import User
class TestCollection(unittest.TestCase):
    """Unit tests for ``Collection``: element access, aggregation, filtering,
    transformation, serialization and model-append behavior.

    Fix applied in ``test_reduce``: the locals no longer shadow the ``sum``
    builtin (or the common ``reduce`` name), and the lambda assignment
    (PEP 8 E731) was replaced with a named function. All other tests are
    unchanged.
    """

    def test_take(self):
        collection = Collection([1, 2, 3, 4])
        self.assertEqual(collection.take(2), [1, 2])

    def test_first(self):
        collection = Collection([1, 2, 3, 4])
        self.assertEqual(collection.first(), 1)
        self.assertEqual(collection.last(), 4)
        self.assertEqual(collection.first(lambda x: x < 3), 1)

    def test_last(self):
        collection = Collection([1, 2, 3, 4])
        self.assertEqual(collection.last(), 4)
        self.assertEqual(collection.last(lambda x: x < 3), 2)

    def test_pluck(self):
        collection = Collection([{"id": 1, "name": "Joe"}, {"id": 2, "name": "Bob"}])
        self.assertEqual(collection.pluck("id"), [1, 2])
        self.assertEqual(collection.pluck("id").serialize(), [1, 2])
        # Two-argument pluck returns a mapping keyed by the second column.
        self.assertEqual(collection.pluck("name", "id"), {1: "Joe", 2: "Bob"})

    def test_pluck_with_models(self):
        factory.register(Model, lambda faker: {"id": 1, "batch": 1})
        collection = factory(Model, 5).make()
        self.assertEqual(collection.pluck("batch"), [1, 1, 1, 1, 1])

    def test_where(self):
        collection = Collection(
            [
                {"id": 1, "name": "Joe"},
                {"id": 2, "name": "Joe"},
                {"id": 3, "name": "Bob"},
            ]
        )
        self.assertEqual(len(collection.where("name", "Joe")), 2)
        self.assertEqual(len(collection.where("id", "!=", 1)), 2)
        self.assertEqual(len(collection.where("id", ">", 1)), 2)
        self.assertEqual(len(collection.where("id", ">=", 1)), 3)
        self.assertEqual(len(collection.where("id", "<=", 1)), 1)
        self.assertEqual(len(collection.where("id", "<", 3)), 2)

    def test_where_in(self):
        collection = Collection(
            [
                {"id": 1, "name": "Joe"},
                {"id": 2, "name": "Joe"},
                {"id": 3, "name": "Bob"},
            ]
        )
        self.assertEqual(len(collection.where_in("id", [1, 2])), 2)
        self.assertEqual(len(collection.where_in("id", [3])), 1)
        self.assertEqual(len(collection.where_in("id", [4])), 0)
        # String haystacks should match numeric values too.
        self.assertEqual(len(collection.where_in("id", ["1", "2"])), 2)
        self.assertEqual(len(collection.where_in("id", ["3"])), 1)
        self.assertEqual(len(collection.where_in("id", ["4"])), 0)
        self.assertEqual(len(collection.where_in("name", ["Joe"])), 2)

    def test_where_not_in(self):
        collection = Collection(
            [
                {"id": 1, "name": "Joe"},
                {"id": 2, "name": "Joe"},
                {"id": 3, "name": "Bob"},
            ]
        )
        self.assertEqual(len(collection.where_not_in("id", [1, 2])), 1)
        self.assertEqual(len(collection.where_not_in("id", [3])), 2)
        self.assertEqual(len(collection.where_not_in("id", [4])), 3)
        self.assertEqual(len(collection.where_not_in("id", ["1", "2"])), 1)
        self.assertEqual(len(collection.where_not_in("id", ["3"])), 2)
        self.assertEqual(len(collection.where_not_in("id", ["4"])), 3)
        self.assertEqual(len(collection.where_not_in("name", ["Joe"])), 1)

    def test_where_in_bool(self):
        nested_collection = Collection(
            [
                {"id": 1, "is_active": True},
                {"id": 2, "is_active": True},
                {"id": 3, "is_active": True},
                {"id": 4},
            ]
        )
        self.assertEqual(len(nested_collection.where_in("is_active", [False])), 0)
        self.assertEqual(len(nested_collection.where_in("is_active", [True])), 3)
        self.assertEqual(len(nested_collection.where_in("is_active", [True, False])), 3)
        # Attribute-based access (objects rather than dicts).
        obj_collection = Collection(
            [
                type("", (), {"is_active": True, "is_disabled": False}),
                type("", (), {"is_active": False, "is_disabled": True}),
                type("", (), {"is_active": True, "is_disabled": True}),
            ]
        )
        self.assertEqual(len(obj_collection.where_in("is_active", [False])), 1)
        self.assertEqual(len(obj_collection.where_in("is_active", [True])), 2)
        self.assertEqual(len(obj_collection.where_in("is_active", [True, False])), 3)
        self.assertEqual(len(obj_collection.where_in("nonexistent_key", [False])), 0)
        self.assertEqual(len(obj_collection.where_in("nonexistent_key", [True])), 0)

    def test_where_in_bytes(self):
        byte_strs = [bytes("should find this", "utf-8"), bytes("and this", "utf-8")]
        collection = Collection(
            [
                {"id": 1, "name": "Joe", "bytes_val": byte_strs[0]},
                {"id": 2, "name": "Joe", "bytes_val": byte_strs[1]},
                {
                    "id": 3,
                    "name": "Bob",
                    "bytes_val": bytes("should not find", "utf-8"),
                },
                {"id": 4, "name": "Bob"},
            ]
        )
        self.assertEqual(len(collection.where_in("bytes_val", byte_strs)), 2)
        self.assertEqual(len(collection.where_in("bytes_val", [byte_strs[0]])), 1)

    def test_pop(self):
        collection = Collection([1, 2, 3])
        self.assertEqual(collection.pop(), 3)
        self.assertEqual(collection.all(), [1, 2])

    def test_is_empty(self):
        collection = Collection([])
        self.assertEqual(collection.is_empty(), True)
        collection = Collection([1, 2, 3])
        self.assertEqual(collection.is_empty(), False)

    def test_sum(self):
        collection = Collection([1, 1, 2, 4])
        self.assertEqual(collection.sum(), 8)
        collection = Collection(
            [
                {"name": "Corentin All", "age": 1},
                {"name": "Corentin All", "age": 2},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 4},
            ]
        )
        self.assertEqual(collection.sum("age"), 10)
        # Summing dicts without a key yields 0.
        self.assertEqual(collection.sum(), 0)
        collection = Collection(
            [
                {"name": "chair", "colours": ["green", "black"]},
                {"name": "desk", "colours": ["red", "yellow"]},
                {"name": "bookcase", "colours": ["white"]},
            ]
        )
        self.assertEqual(collection.sum(lambda x: len(x["colours"])), 5)
        self.assertEqual(collection.sum(lambda x: len(x)), 6)

    def test_avg(self):
        collection = Collection([1, 1, 2, 4])
        self.assertEqual(collection.avg(), 2)
        collection = Collection(
            [
                {"name": "Corentin All", "age": 1},
                {"name": "Corentin All", "age": 2},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 4},
            ]
        )
        self.assertEqual(collection.avg("age"), 2.5)
        self.assertEqual(collection.avg(), 0)
        collection = Collection(
            [
                {"name": "chair", "colours": ["green", "black"]},
                {"name": "desk", "colours": ["red", "yellow"]},
                {"name": "bookcase", "colours": ["white"]},
            ]
        )
        self.assertEqual(collection.avg(lambda x: len(x["colours"])), 5 / 3)
        self.assertEqual(collection.avg(lambda x: len(x)), 2)

    def test_max(self):
        collection = Collection([1, 1, 2, 4])
        self.assertEqual(collection.max(), 4)
        collection = Collection(
            [
                {"name": "Corentin All", "age": 1},
                {"name": "Corentin All", "age": 2},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 4},
            ]
        )
        self.assertEqual(collection.max("age"), 4)
        self.assertEqual(collection.max(), 0)
        collection = Collection([{"batch": 1}, {"batch": 1}])
        self.assertEqual(collection.max("batch"), 1)

    def test_min(self):
        collection = Collection([1, 1, 2, 4])
        self.assertEqual(collection.min(), 1)
        collection = Collection(
            [
                {"name": "Corentin All", "age": 1},
                {"name": "Corentin All", "age": 2},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 4},
            ]
        )
        self.assertEqual(collection.min("age"), 1)
        self.assertEqual(collection.min(), 0)
        collection = Collection([{"batch": 1}, {"batch": 1}])
        self.assertEqual(collection.min("batch"), 1)

    def test_count(self):
        collection = Collection([1, 1, 2, 4])
        self.assertEqual(collection.count(), 4)
        collection = Collection(
            [{"name": "Corentin All", "age": 1}, {"name": "Corentin All", "age": 2}]
        )
        self.assertEqual(collection.count(), 2)

    def test_chunk(self):
        collection = Collection([1, 1, 2, 4])
        chunked = collection.chunk(2)
        self.assertEqual(chunked, Collection([Collection([1, 1]), Collection([2, 4])]))
        collection = Collection(
            [
                {"name": "chair", "colours": ["green", "black"]},
                {"name": "desk", "colours": ["red", "yellow"]},
                {"name": "bookcase", "colours": ["white"]},
            ]
        )
        chunked = collection.chunk(2)
        self.assertEqual(
            chunked,
            Collection(
                [
                    Collection(
                        [
                            {"name": "chair", "colours": ["green", "black"]},
                            {"name": "desk", "colours": ["red", "yellow"]},
                        ]
                    ),
                    Collection([{"name": "bookcase", "colours": ["white"]}]),
                ]
            ),
        )

    def test_collapse(self):
        collection = Collection([[1, 1], [2, 4]])
        collapsed = collection.collapse()
        self.assertEqual(collapsed, Collection([1, 1, 2, 4]))

    def test_get(self):
        collection = Collection([[1, 1], [2, 4]])
        self.assertEqual(collection.get(0), [1, 1])
        self.assertIsNone(collection.get(2))
        self.assertEqual(collection.get(2, 0), 0)

    def test_merge(self):
        collection = Collection([[1, 1], [2, 4]])
        collection.merge([[2, 1]])
        self.assertEqual(collection.all(), [[1, 1], [2, 4], [2, 1]])

    def test_reduce(self):
        # Named function instead of a lambda assignment (PEP 8 E731); locals
        # renamed so they no longer shadow the ``sum``/``reduce`` builtins.
        def add(x, y):
            return x + y

        collection = Collection([1, 1, 2, 4])
        total = collection.sum()
        reduced = collection.reduce(add)
        self.assertEqual(total, reduced)
        reduced = collection.reduce(add, 10)
        self.assertEqual(10 + total, reduced)

    def test_forget(self):
        collection = Collection([1, 2, 3, 4])
        collection.forget(0)
        self.assertEqual(collection.all(), [2, 3, 4])
        collection.forget(1, 2)
        self.assertEqual(collection.all(), [2])
        collection.forget(0)
        self.assertTrue(collection.is_empty())

    def test_prepend(self):
        collection = Collection([1, 2, 3, 4])
        collection.prepend(0)
        self.assertEqual(collection.get(0), 0)
        self.assertEqual(collection.all(), [0, 1, 2, 3, 4])

    def test_pull(self):
        collection = Collection([1, 2, 3, 4])
        value = collection.pull(0)
        self.assertEqual(value, 1)
        self.assertEqual(collection.all(), [2, 3, 4])

    def test_push(self):
        collection = Collection([1, 2, 3, 4])
        collection.push(5)
        self.assertEqual(collection.get(4), 5)
        self.assertEqual(collection.all(), [1, 2, 3, 4, 5])

    def test_put(self):
        collection = Collection([1, 2, 3, 4])
        collection.put(2, 5)
        self.assertEqual(collection.get(2), 5)
        self.assertEqual(collection.all(), [1, 2, 5, 4])

    def test_reject(self):
        collection = Collection([1, 2, 3, 4])
        collection.reject(lambda x: x if x > 2 else None)
        self.assertEqual(collection.all(), [3, 4])
        collection = Collection(
            [
                {"name": "Corentin All", "age": 1},
                {"name": "Corentin All", "age": 2},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 4},
            ]
        )
        collection.reject(lambda x: x if x["age"] > 2 else None)
        self.assertEqual(
            Collection(
                [{"name": "Corentin All", "age": 3}, {"name": "Corentin All", "age": 4}]
            ),
            collection.all(),
        )
        collection.reject(lambda x: x["age"] if x["age"] > 2 else None)
        self.assertEqual(collection.all(), [3, 4])

    def test_for_page(self):
        collection = Collection([1, 2, 3, 4])
        chunked = collection.for_page(0, 3)
        self.assertEqual(chunked.all(), [1, 2, 3])

    def test_unique(self):
        collection = Collection([1, 1, 2, 3, 4])
        unique = collection.unique()
        self.assertEqual(unique.all(), [1, 2, 3, 4])
        collection = Collection(
            [
                {"name": "Corentin All", "age": 1},
                {"name": "Corentin All", "age": 1},
                {"name": "Corentin All", "age": 2},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 4},
            ]
        )
        unique = collection.unique("age")
        self.assertEqual(
            unique.all(),
            [
                {"name": "Corentin All", "age": 1},
                {"name": "Corentin All", "age": 2},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 4},
            ],
        )
        self.assertEqual(collection.pluck("name").unique().all(), ["Corentin All"])

    def test_transform(self):
        collection = Collection([1, 1, 2, 3, 4])
        collection.transform(lambda x: x * 2)
        self.assertEqual(collection.all(), [2, 2, 4, 6, 8])

    def test_shift(self):
        collection = Collection([1, 2, 3, 4])
        value = collection.shift()
        self.assertEqual(value, 1)
        self.assertEqual(collection.all(), [2, 3, 4])
        collection = Collection(
            [
                {"name": "Corentin All", "age": 1},
                {"name": "Corentin All", "age": 2},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 4},
            ]
        )
        value = collection.shift()
        self.assertEqual(value, {"name": "Corentin All", "age": 1})
        self.assertEqual(
            collection.all(),
            [
                {"name": "Corentin All", "age": 2},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 4},
            ],
        )

    def test_sort(self):
        collection = Collection([4, 1, 2, 3])
        collection.sort()
        self.assertEqual(collection.all(), [1, 2, 3, 4])

    def test_reverse(self):
        collection = Collection([4, 1, 2, 3])
        collection.reverse()
        self.assertEqual(collection.all(), [3, 2, 1, 4])
        collection = Collection(
            [
                {"name": "Corentin All", "age": 2},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 4},
            ]
        )
        collection.reverse()
        self.assertEqual(
            collection.all(),
            [
                {"name": "Corentin All", "age": 4},
                {"name": "Corentin All", "age": 3},
                {"name": "Corentin All", "age": 2},
            ],
        )

    def test_zip(self):
        collection = Collection(["Chair", "Desk"])
        zipped = collection.zip([100, 200])
        self.assertEqual(zipped.all(), [["Chair", 100], ["Desk", 200]])

    def test_diff(self):
        collection = Collection(["Chair", "Desk"])
        diff = collection.diff([100, 200])
        self.assertEqual(diff.all(), ["Chair", "Desk"])

    def test_each(self):
        collection = Collection([1, 2, 3, 4])
        collection.each(lambda x: x + 2)
        self.assertEqual(collection.all(), [x + 2 for x in range(1, 5)])

    def test_every(self):
        collection = Collection([1, 2, 3, 4])
        self.assertFalse(collection.every(lambda x: x > 2))
        collection = Collection([1, 2, 3, 4])
        self.assertTrue(collection.every(lambda x: x >= 1))

    def test_filter(self):
        collection = Collection([1, 2, 3, 4])
        filtered = collection.filter(lambda x: x > 2)
        self.assertEqual(filtered.all(), Collection([3, 4]))

    def test_implode(self):
        collection = Collection([1, 2, 3, 4])
        result = collection.implode("-")
        self.assertEqual(result, "1-2-3-4")
        collection = Collection(
            [{"name": "Corentin"}, {"name": "Joe"}, {"name": "Marlysson"}]
        )
        result = collection.implode(key="name")
        self.assertEqual(result, "Corentin,Joe,Marlysson")

    def test_map_into(self):
        collection = Collection(["USD", "EUR", "GBP"])

        class Currency:
            def __init__(self, code):
                self.code = code

            def __eq__(self, other):
                return self.code == other.code

        currencies = collection.map_into(Currency)
        self.assertEqual(
            currencies.all(), [Currency("USD"), Currency("EUR"), Currency("GBP")]
        )

    def test_map(self):
        collection = Collection([1, 2, 3, 4])
        multiplied = collection.map(lambda x: x * 2)
        self.assertEqual(multiplied.all(), [2, 4, 6, 8])

        def callback(x):
            x["age"] = x["age"] + 2
            return x

        collection = Collection(
            [
                {"name": "Corentin", "age": 10},
                {"name": "Joe", "age": 20},
                {"name": "Marlysson", "age": 15},
            ]
        )
        result = collection.map(callback)
        self.assertEqual(
            result.all(),
            [
                {"name": "Corentin", "age": 12},
                {"name": "Joe", "age": 22},
                {"name": "Marlysson", "age": 17},
            ],
        )

    def test_serialize(self):
        class Currency:
            def __init__(self, code):
                self.code = code

            def __eq__(self, other):
                return self.code == other.code

            def to_dict(self):
                return {"code": self.code}

        collection = Collection(
            [
                Collection([{"name": "Corentin", "age": 12}]),
                {"name": "Joe", "age": 22},
                {"name": "Marlysson", "age": 17},
                Currency("USD"),
            ]
        )
        serialized_data = collection.serialize()
        self.assertEqual(
            serialized_data,
            [
                [{"name": "Corentin", "age": 12}],
                {"name": "Joe", "age": 22},
                {"name": "Marlysson", "age": 17},
                {"code": "USD"},
            ],
        )

    def test_json(self):
        collection = Collection(
            [
                {"name": "Corentin", "age": 10},
                {"name": "Joe", "age": 20},
                {"name": "Marlysson", "age": 15},
            ]
        )
        json_data = collection.to_json()
        self.assertEqual(
            json_data,
            '[{"name": "Corentin", "age": 10}, '
            '{"name": "Joe", "age": 20}, {"name": "Marlysson", "age": 15}]',
        )

    def test_contains(self):
        collection = Collection([1, 2, 3, 4])
        self.assertTrue(collection.contains(3))
        self.assertFalse(collection.contains(5))
        collection = Collection(
            [
                {"name": "Corentin", "age": 10},
                {"name": "Joe", "age": 20},
                {"name": "Marlysson", "age": 15},
            ]
        )
        self.assertTrue(collection.contains(lambda x: x["age"] == 10))
        self.assertFalse(collection.contains("age"))
        self.assertTrue(collection.contains("age", 10))
        self.assertFalse(collection.contains("age", 11))

    def test_all(self):
        collection = Collection([1, 2, 3, 4])
        self.assertEqual(collection.all(), [1, 2, 3, 4])
        collection = Collection(
            [
                {"name": "Corentin", "age": 10},
                {"name": "Joe", "age": 20},
                {"name": "Marlysson", "age": 15},
            ]
        )
        self.assertEqual(
            collection.all(),
            [
                {"name": "Corentin", "age": 10},
                {"name": "Joe", "age": 20},
                {"name": "Marlysson", "age": 15},
            ],
        )

    def test_flatten(self):
        collection = Collection([1, 2, [3, 4, 5, {"foo": "bar"}]])
        flattened = collection.flatten()
        self.assertEqual(flattened.all(), [1, 2, 3, 4, 5, "bar"])

    def test_group_by(self):
        collection = Collection(
            [
                {"name": "Corentin", "age": 10},
                {"name": "Joe", "age": 10},
                {"name": "Marlysson", "age": 20},
            ]
        )
        grouped = collection.group_by("age")
        self.assertIsInstance(grouped, Collection)
        self.assertEqual(
            grouped,
            {
                10: [{"name": "Corentin", "age": 10}, {"name": "Joe", "age": 10}],
                20: [{"name": "Marlysson", "age": 20}],
            },
        )

    def test_serialize_with_model_appends(self):
        User.__appends__ = ["meta"]
        users = User.all().serialize()
        self.assertTrue(users[0].get("meta"))

    def test_serialize_with_on_the_fly_appends(self):
        users = User.all().set_appends(["meta"]).serialize()
        self.assertTrue(users[0].get("meta"))

    def test_random(self):
        collection = Collection([1, 2, 3, 4])
        item = collection.random()
        self.assertIn(item, collection)
        collection = Collection([])
        item = collection.random()
        self.assertIsNone(item)
        collection = Collection([3])
        item = collection.random()
        self.assertEqual(item, 3)

    def test_random_with_count(self):
        collection = Collection([1, 2, 3, 4])
        items = collection.random(2)
        self.assertEqual(items.count(), 2)
        self.assertIsInstance(items, Collection)
        # Requesting more items than exist must raise.
        with self.assertRaises(ValueError):
            items = collection.random(6)
        items = collection.random(1)
        self.assertEqual(items.count(), 1)
        self.assertIsInstance(items, Collection)

    def test_make_comparison(self):
        collection = Collection([])
        self.assertTrue(collection._make_comparison(1, 1, "=="))
        # Comparison coerces across int/str representations.
        self.assertTrue(collection._make_comparison(1, "1", "=="))

    def test_eq(self):
        collection = Collection([1, 2, 3, 4])
        other = Collection([1, 2, 3, 4])
        self.assertTrue(collection == other)
        different = Collection([1, 2, 3])
        self.assertFalse(collection == different)
================================================
FILE: tests/commands/test_shell.py
================================================
import unittest
from cleo import CommandTester
from src.masoniteorm.commands import ShellCommand
class TestShellCommand(unittest.TestCase):
    """Tests that ShellCommand.get_command builds the correct CLI invocation
    (and environment) for each database driver.

    Fix applied in ``test_for_postgres``: the PGPASSWORD assertion used
    ``env.get(key, default)``, which is always truthy here and could never
    fail; it now compares the value explicitly.
    """

    def setUp(self):
        self.command = ShellCommand()
        self.command_tester = CommandTester(self.command)

    def test_for_mysql(self):
        config = {
            "host": "localhost",
            "database": "orm",
            "user": "root",
            "port": "1234",
            "password": "secret",
            "prefix": "",
            "options": {"charset": "utf8mb4"},
            "full_details": {"driver": "mysql"},
        }
        command, _ = self.command.get_command(config)
        assert (
            command
            == "mysql orm --host localhost --port 1234 --user root --password secret --default-character-set utf8mb4"
        )

    def test_for_postgres(self):
        config = {
            "host": "localhost",
            "database": "orm",
            "user": "root",
            "port": "1234",
            "password": "secretpostgres",
            "prefix": "",
            "options": {"charset": "utf8mb4"},
            "full_details": {"driver": "postgres"},
        }
        command, env = self.command.get_command(config)
        assert command == "psql orm --host localhost --port 1234 --username root"
        # The password is passed via the environment, not on the command line.
        assert env.get("PGPASSWORD") == "secretpostgres"

    def test_for_sqlite(self):
        config = {
            "database": "orm.sqlite3",
            "prefix": "",
            "full_details": {"driver": "sqlite"},
        }
        command, _ = self.command.get_command(config)
        assert command == "sqlite3 orm.sqlite3"

    def test_for_mssql(self):
        config = {
            "host": "db.masonite.com",
            "database": "orm",
            "user": "root",
            "port": "1234",
            "password": "secretpostgres",
            "prefix": "",
            "options": {"charset": "utf8mb4"},
            "full_details": {"driver": "mssql"},
        }
        command, _ = self.command.get_command(config)
        assert (
            command
            == "sqlcmd -d orm -U root -P secretpostgres -S tcp:db.masonite.com,1234"
        )

    def test_running_command_with_sqlite(self):
        # Without --show (-s) the command must not be echoed.
        self.command_tester.execute("-c dev")
        assert "sqlite3" not in self.command_tester.io.fetch_output()
        self.command_tester.execute("-c dev -s")
        assert "sqlite3 orm.sqlite3" in self.command_tester.io.fetch_output()

    def test_hiding_sensitive_options(self):
        config = {
            "host": "localhost",
            "database": "orm",
            "user": "root",
            "port": "",
            "password": "secret",
            "full_details": {"driver": "mysql"},
        }
        command, _ = self.command.get_command(config)
        cleaned_command = self.command.hide_sensitive_options(config, command)
        assert (
            cleaned_command == "mysql orm --host localhost --user root --password ***"
        )
================================================
FILE: tests/config/test_db_url.py
================================================
import os
import pytest
import unittest
from src.masoniteorm.config import db_url, load_config
from src.masoniteorm.exceptions import InvalidUrlConfiguration
from src.masoniteorm.connections import ConnectionResolver
class TestDbUrlHelper(unittest.TestCase):
    """Tests for the db_url() helper that parses DATABASE_URL-style strings.

    Fixes applied:
    * ``tearDown`` is restored — the previous commented-out version crashed
      when DATABASE_URL was unset (os.environ values must be str, not None).
      The new version restores or clears the variable so tests don't leak
      environment state.
    * ``test_parse_sqlite`` used ``dict.get(key, default)`` in its asserts,
      which is always truthy here and could never fail; it now compares the
      values explicitly (pinned to the values the original defaults implied —
      NOTE(review): confirm db_url's exact in-memory database value).
    """

    def setUp(self):
        self.original_db_url = os.getenv("DATABASE_URL")

    def tearDown(self):
        if self.original_db_url is None:
            os.environ.pop("DATABASE_URL", None)
        else:
            os.environ["DATABASE_URL"] = self.original_db_url

    def test_parse_env_by_default(self):
        os.environ["DATABASE_URL"] = "mysql://root:@localhost:3306/orm"
        config = db_url()
        assert config.get("driver") == "mysql"

    # def test_raise_error_if_no_url(self):
    #     # no DATABASE_URL is defined yet
    #     with pytest.raises(InvalidUrlConfiguration):
    #         db_url()

    def test_parse_sqlite(self):
        # check in memory use
        config = db_url("sqlite://")
        assert config.get("driver") == "sqlite"
        assert config.get("database") == ":memory:"
        assert not config.get("user")
        config = db_url("sqlite://:memory:")
        assert config.get("driver") == "sqlite"
        assert config.get("database") == ":memory:"
        assert not config.get("user")
        config = db_url("sqlite://db.sqlite3")
        assert config.get("driver") == "sqlite"
        assert config.get("database") == "db.sqlite3"
        assert not config.get("user")

    def test_parse_mysql(self):
        config = db_url("mysql://root:@localhost:3306/orm")
        assert config == {
            "driver": "mysql",
            "database": "orm",
            "prefix": "",
            "options": {},
            "log_queries": False,
            "user": "root",
            "password": "",
            "host": "localhost",
            "port": 3306,
        }

    def test_parse_postgres(self):
        config = db_url(
            "postgres://utpcrbiihfqqys:de0a0d847094a66e32274262aa5b5f0ad78e5e34197875fc6089a2d9185d0032@ec2-54-225-242-183.compute-1.amazonaws.com:5432/da455n1ef8kout"
        )
        assert config == {
            "driver": "postgres",
            "database": "da455n1ef8kout",
            "prefix": "",
            "options": {},
            "log_queries": False,
            "user": "utpcrbiihfqqys",
            "password": "de0a0d847094a66e32274262aa5b5f0ad78e5e34197875fc6089a2d9185d0032",
            "host": "ec2-54-225-242-183.compute-1.amazonaws.com",
            "port": 5432,
        }

    def test_parse_mssql(self):
        config = db_url("mssql://john:secret@127.0.0.1:1433/mssql_db")
        assert config == {
            "driver": "mssql",
            "database": "mssql_db",
            "prefix": "",
            "options": {},
            "log_queries": False,
            "user": "john",
            "password": "secret",
            "host": "127.0.0.1",
            "port": "1433",
        }

    def test_parse_with_params(self):
        config = db_url(
            "mysql://root:@localhost:3306/orm",
            log_queries=True,
            prefix="a",
            options={"key": "value"},
        )
        assert config == {
            "driver": "mysql",
            "database": "orm",
            "prefix": "a",
            "options": {"key": "value"},
            "log_queries": True,
            "user": "root",
            "password": "",
            "host": "localhost",
            "port": 3306,
        }

    def test_using_it_with_connection_resolver(self):
        TEST_DATABASES = {
            "default": "test",
            "test": {
                **db_url("mysql://root:@localhost:3306/orm"),
                "prefix": "",
                "log_queries": True,
            },
        }
        resolver = ConnectionResolver().set_connection_details(TEST_DATABASES)
        config = resolver.get_connection_details().get("test")
        assert config.get("database") == "orm"
        assert config.get("user") == "root"
        assert config.get("password") == ""
        assert config.get("port") == 3306
        assert config.get("host") == "localhost"
        assert config.get("log_queries")
        # reset connection resolver to default for other tests to continue working
        from tests.integrations.config.database import DATABASES

        ConnectionResolver().set_connection_details(DATABASES)
================================================
FILE: tests/connections/test_base_connections.py
================================================
import unittest
from src.masoniteorm.connections import ConnectionResolver
from tests.integrations.config.database import DB
class TestDefaultBehaviorConnections(unittest.TestCase):
    """Exercises enabling/disabling query logging on an open "dev" transaction."""

    def test_should_return_connection_with_enabled_logs(self):
        txn = DB.begin_transaction("dev")
        logging_enabled = txn.full_details.get("log_queries")
        DB.commit("dev")
        self.assertTrue(logging_enabled)

    def test_should_disable_log_queries_in_connection(self):
        txn = DB.begin_transaction("dev")
        # Toggling off must be reflected in the connection's full_details...
        txn.disable_query_log()
        self.assertFalse(txn.full_details.get("log_queries"))
        # ...and toggling back on must restore it before the commit.
        txn.enable_query_log()
        logging_enabled = txn.full_details.get("log_queries")
        DB.commit("dev")
        self.assertTrue(logging_enabled)
================================================
FILE: tests/eagers/test_eager.py
================================================
import os
import unittest
from src.masoniteorm.query.EagerRelation import EagerRelations
class TestEagerRelation(unittest.TestCase):
    """Verifies how EagerRelations normalizes string, tuple and list
    registrations into flat lists and nested {relation: [subrelations]} maps."""

    def test_can_register_string_eager_load(self):
        self.assertEqual(
            EagerRelations().register("profile").get_eagers(), [["profile"]]
        )
        self.assertEqual(EagerRelations().register("profile").is_nested, False)
        nested = EagerRelations().register("profile.user").get_eagers()
        self.assertEqual(nested, [{"profile": ["user"]}])
        merged = EagerRelations().register("profile.user", "profile.logo").get_eagers()
        self.assertEqual(merged, [{"profile": ["user", "logo"]}])
        triple = (
            EagerRelations()
            .register("profile.user", "profile.logo", "profile.bio")
            .get_eagers()
        )
        self.assertEqual(triple, [{"profile": ["user", "logo", "bio"]}])
        flat = EagerRelations().register("user", "logo", "bio").get_eagers()
        self.assertEqual(flat, [["user", "logo", "bio"]])

    def test_can_register_tuple_eager_load(self):
        single = EagerRelations().register(("profile",)).get_eagers()
        self.assertEqual(single, [["profile"]])
        pair = EagerRelations().register(("profile", "user")).get_eagers()
        self.assertEqual(pair, [["profile", "user"]])
        # Dotted names inside a tuple are still collapsed into a nested map.
        nested = EagerRelations().register(("profile.name", "profile.user")).get_eagers()
        self.assertEqual(nested, [{"profile": ["name", "user"]}])

    def test_can_register_list_eager_load(self):
        # Table-driven: (registered list, expected normalized structure).
        cases = [
            (["profile"], [["profile"]]),
            (["profile", "user"], [["profile", "user"]]),
            (["profile.name", "profile.user"], [{"profile": ["name", "user"]}]),
            (["profile.name"], [{"profile": ["name"]}]),
            (["profile.name", "logo"], [["logo"], {"profile": ["name"]}]),
            (
                ["profile.name", "logo", "profile.user"],
                [["logo"], {"profile": ["name", "user"]}],
            ),
        ]
        for registered, expected in cases:
            self.assertEqual(
                EagerRelations().register(registered).get_eagers(), expected
            )
================================================
FILE: tests/factories/test_factories.py
================================================
import os
import unittest
from src.masoniteorm import Factory as factory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
class User(Model):
    # Bare model used only as a factory registration target in these tests.
    pass
class AfterCreatedModel(Model):
    """Model used to exercise the factory ``after_creating`` hooks.

    Fix: dropped a redundant ``pass`` statement that followed the class
    attribute.
    """

    # presumably keeps create() from hitting a real database — TODO confirm
    __dry__ = True
class TestFactories(unittest.TestCase):
    """Tests for the model Factory: default, named, counted and hooked builds."""

    # Registers a default and a named ("admin") factory for User, plus an
    # after-creating hook for AfterCreatedModel.
    def setUp(self):
        factory.register(User, self.user_factory)
        factory.register(User, self.named_user_factory, name="admin")
        factory.register(AfterCreatedModel, self.named_user_factory)
        factory.after_creating(AfterCreatedModel, self.after_creating)

    # Default factory payload: fixed id plus a random faker name.
    def user_factory(self, faker):
        return {"id": 1, "name": faker.name()}

    # Named-factory payload that also flags the record as admin.
    def named_user_factory(self, faker):
        return {"id": 1, "name": faker.name(), "admin": 1}

    # Hook run by the factory after building; marks the instance.
    def after_creating(self, model, faker):
        model.after_created = True

    def test_can_make_single(self):
        user = factory(User).make({"id": 1, "name": "Joe"})
        self.assertEqual(user.name, "Joe")
        self.assertIsInstance(user, User)

    def test_can_make_several(self):
        # Passing a list of dicts yields a collection, one model per dict.
        users = factory(User).make([{"id": 1, "name": "Joe"}, {"id": 2, "name": "Bob"}])
        self.assertEqual(users.count(), 2)

    def test_can_make_any_number(self):
        users = factory(User, 50).make()
        self.assertEqual(users.count(), 50)

    def test_can_make_named_factory(self):
        user = factory(User).make(name="admin")
        self.assertEqual(user.admin, 1)

    def test_after_creates(self):
        # The after_creating hook must fire for create() and make(), with or
        # without explicit attributes, and for every model in a counted build.
        user = factory(AfterCreatedModel).create()
        self.assertTrue(user.name)
        self.assertEqual(user.after_created, True)
        users = factory(AfterCreatedModel).create({"name": "billy"})
        self.assertEqual(users.name, "billy")
        self.assertEqual(users.after_created, True)
        users = factory(AfterCreatedModel).make()
        self.assertEqual(users.after_created, True)
        users = factory(AfterCreatedModel, 2).make()
        for user in users:
            self.assertEqual(user.after_created, True)
================================================
FILE: tests/integrations/config/__init__.py
================================================
================================================
FILE: tests/integrations/config/database.py
================================================
""" Database Settings """
import os
import logging
from dotenv import load_dotenv
from src.masoniteorm.connections import ConnectionResolver
from src.masoniteorm.config import db_url
"""
|--------------------------------------------------------------------------
| Load Environment Variables
|--------------------------------------------------------------------------
|
| Loads in the environment variables when this page is imported.
|
"""
load_dotenv(".env")
"""
The connections here don't determine the database but determine the "connection".
They can be named whatever you want.
"""
# Connection map keyed by connection *name* (not database engine); the
# "default" entry picks which named connection is used when none is given.
DATABASES = {
    "default": "mysql",
    # Primary MySQL connection, credentials from the environment.
    "mysql": {
        "driver": "mysql",
        "host": os.getenv("MYSQL_DATABASE_HOST"),
        "user": os.getenv("MYSQL_DATABASE_USER"),
        "password": os.getenv("MYSQL_DATABASE_PASSWORD"),
        "database": os.getenv("MYSQL_DATABASE_DATABASE"),
        "port": os.getenv("MYSQL_DATABASE_PORT"),
        "prefix": "",
        "options": {"charset": "utf8mb4"},
        "log_queries": True,
        "propagate": False,
        "connection_pooling_enabled": True,
        "connection_pooling_max_size": 10,
        "connection_pooling_min_size": None,
    },
    # Short-named local SQLite connection.
    "t": {"driver": "sqlite", "database": "orm.sqlite3", "log_queries": True, "foreign_keys": True},
    # Same MySQL server but a fixed DEVPROD database.
    "devprod": {
        "driver": "mysql",
        "host": os.getenv("MYSQL_DATABASE_HOST"),
        "user": os.getenv("MYSQL_DATABASE_USER"),
        "password": os.getenv("MYSQL_DATABASE_PASSWORD"),
        "database": "DEVPROD",
        "port": os.getenv("MYSQL_DATABASE_PORT"),
        "prefix": "",
        "options": {"charset": "utf8mb4"},
        "log_queries": True,
        "propagate": False,
    },
    # Local root connection against the "replicate" database.
    "many": {
        "driver": "mysql",
        "host": "localhost",
        "user": "root",
        "password": "",
        "database": "replicate",
        "port": os.getenv("MYSQL_DATABASE_PORT"),
        "options": {"charset": "utf8mb4"},
        "log_queries": True,
        "propagate": False,
    },
    # PostgreSQL connection with pooling enabled.
    "postgres": {
        "driver": "postgres",
        "host": os.getenv("POSTGRES_DATABASE_HOST"),
        "user": os.getenv("POSTGRES_DATABASE_USER"),
        "password": os.getenv("POSTGRES_DATABASE_PASSWORD"),
        "database": os.getenv("POSTGRES_DATABASE_DATABASE"),
        "port": os.getenv("POSTGRES_DATABASE_PORT"),
        "connection_pooling_enabled": True,
        "connection_pooling_max_size": 10,
        "connection_pooling_min_size": 2,
        "prefix": "",
        "log_queries": True,
        "propagate": False,
    },
    # Example with db_url()
    # "postgres": db_url(
    #     "postgres://user:@localhost:5432/postgres", log_queries=True
    # ),
    # Local SQLite connection used for development runs.
    "dev": {
        "driver": "sqlite",
        "database": "orm.sqlite3",
        "prefix": "",
        "log_queries": True,
    },
    # Example with db_url()
    # "dev": {**db_url("sqlite://orm.sqlite3"), "prefix": "", "log_queries": True},
    # SQL Server connection; options are passed through to the ODBC driver.
    "mssql": {
        "driver": "mssql",
        "host": os.getenv("MSSQL_DATABASE_HOST"),
        "user": os.getenv("MSSQL_DATABASE_USER"),
        "password": os.getenv("MSSQL_DATABASE_PASSWORD"),
        "database": os.getenv("MSSQL_DATABASE_DATABASE"),
        "port": os.getenv("MSSQL_DATABASE_PORT"),
        "prefix": "",
        "log_queries": True,
        "options": {
            "trusted_connection": "Yes",
            "integrated_security": "sspi",
            "instance": "SQLExpress",
            "authentication": "ActiveDirectoryPassword",
            "driver": "ODBC Driver 17 for SQL Server",
            "connection_timeout": 15,
            "connection_pooling": False,
            "connection_pooling_size": 100,
        },
    },
}
# Register the connection map with the global resolver; DB is the shared handle.
DB = ConnectionResolver().set_connection_details(DATABASES)
# Emit masoniteorm query logs to both the console and queries.log.
logger = logging.getLogger("masoniteorm.connection.queries")
logger.setLevel(logging.DEBUG)

# The format string expects a `query` field supplied via the logging call's
# extra mapping — presumably provided by masoniteorm's query logging; confirm.
formatter = logging.Formatter("It executed the query %(query)s")

stream_handler = logging.StreamHandler()
file_handler = logging.FileHandler("queries.log")

# BUG FIX: the formatter was created but never attached to either handler, so
# both fell back to the default format and the custom message was dead code.
stream_handler.setFormatter(formatter)
file_handler.setFormatter(formatter)

logger.addHandler(stream_handler)
logger.addHandler(file_handler)
# DB = QueryBuilder(connection_details=DATABASES)
# DATABASES = {
# 'default': os.environ.get('DB_DRIVER'),
# 'sqlite': {
# 'driver': 'sqlite',
# 'database': os.environ.get('DB_DATABASE')
# },
# 'postgres': {
# 'driver': 'postgres',
# 'host': env('DB_HOST'),
# 'database': env('DB_DATABASE'),
# 'port': env('DB_PORT'),
# 'user': env('DB_USERNAME'),
# 'password': env('DB_PASSWORD'),
# 'log_queries': env('DB_LOG'),
# },
# }
# DB = DatabaseManager(DATABASES)
# Model.set_connection_resolver(DB)
================================================
FILE: tests/models/test_models.py
================================================
import datetime
import json
import unittest
import pendulum
from src.masoniteorm.models import Model
class ModelTest(Model):
    """Model with date columns and attribute casts used throughout TestModels."""

    # Columns hydrated as pendulum datetime instances (asserted in the tests).
    __dates__ = ["due_date"]
    # Attribute name -> cast type applied on access/serialization.
    __casts__ = {
        "is_vip": "bool",
        "payload": "json",
        "x": "int",
        "f": "float",
        "d": "decimal",
    }
class FillableModelTest(Model):
    """Declares only __fillable__ — a valid mass-assignment configuration."""

    __fillable__ = ["due_date", "is_vip"]
class InvalidFillableGuardedModelTest(Model):
    """Declares both __fillable__ and __guarded__; instantiation should raise
    AttributeError (see test_both_fillable_and_guarded_attributes_raise)."""

    __fillable__ = ["due_date"]
    __guarded__ = ["is_vip", "payload"]
class InvalidFillableGuardedChildModelTest(ModelTest):
    """Same invalid fillable+guarded combination, but inherited through an
    intermediary model class — must still raise on instantiation."""

    __fillable__ = ["due_date"]
    __guarded__ = ["is_vip", "payload"]
class ModelTestForced(Model):
    """users model with __force_update__ so saves include unchanged attributes
    (asserted in test_force_update_on_model_class)."""

    __table__ = "users"
    __force_update__ = True
class BaseModel(Model):
    """Base model supplying a default SELECT of every column on its table."""

    __dry__ = True

    def get_selects(self):
        # Default select list: "<table>.*", overridable by an explicit select().
        table_name = self.get_table_name()
        return ["{}.*".format(table_name)]
class ModelWithBaseModel(BaseModel):
    """users model inheriting BaseModel's default "<table>.*" select."""

    __table__ = "users"
class TestModels(unittest.TestCase):
    """Model behavior tests: date hydration, casts, dirty tracking, save SQL,
    fillable/guarded validation, and default selects."""

    def test_model_can_access_str_dates_as_pendulum(self):
        model = ModelTest.hydrate({"user": "joe", "due_date": "2020-11-28 11:42:07"})
        self.assertTrue(model.user)
        self.assertTrue(model.due_date)
        self.assertIsInstance(model.due_date, pendulum.now().__class__)

    def test_model_can_access_str_dates_as_pendulum_from_correct_datetimes(self):
        model = ModelTest()
        self.assertEqual(
            model.get_new_date(datetime.datetime(2021, 1, 1, 7, 10)).hour, 7
        )
        self.assertEqual(model.get_new_date(datetime.date(2021, 1, 1)).hour, 0)
        self.assertEqual(model.get_new_date(datetime.time(1, 1, 1)).hour, 1)
        self.assertEqual(model.get_new_date("2020-11-28 11:42:07").hour, 11)

    def test_model_can_access_str_dates_on_relationships(self):
        model = ModelTest.hydrate({"user": "joe", "due_date": "2020-11-28 11:42:07"})
        model.add_relation(
            {
                "profile": ModelTest.hydrate(
                    {"name": "bob", "due_date": "2020-11-28 11:42:07"}
                )
            }
        )
        self.assertEqual(model.profile.name, "bob")
        self.assertTrue(model.profile.due_date.is_past())

    def test_model_original_and_dirty_attributes(self):
        model = ModelTest.hydrate({"username": "joe", "admin": True})
        self.assertEqual(model.username, "joe")
        self.assertEqual(
            model.__original_attributes__, {"username": "joe", "admin": True}
        )
        # Mutating an attribute marks it dirty but leaves the originals intact.
        model.username = "bob"
        self.assertEqual(model.username, "bob")
        self.assertEqual(model.get_original("username"), "joe")
        self.assertEqual(model.get_dirty("username"), "bob")
        self.assertEqual(model.__dirty_attributes__["username"], "bob")
        self.assertEqual(model.get_dirty_keys(), ["username"])
        self.assertTrue(model.is_dirty() is True)
        self.assertEqual(
            model.__original_attributes__, {"username": "joe", "admin": True}
        )

    def test_model_creates_when_new(self):
        # Hydrated (existing) models UPDATE; fresh models INSERT.
        model = ModelTest.hydrate({"id": 1, "username": "joe", "admin": True})
        model.name = "Bill"
        sql = model.save(query=True).to_sql()
        self.assertTrue(sql.startswith("UPDATE"))

        model = ModelTest()
        model.name = "Bill"
        sql = model.save(query=True).to_sql()
        self.assertTrue(sql.startswith("INSERT"))

    def test_model_can_cast_attributes(self):
        model = ModelTest.hydrate(
            {
                "is_vip": 1,
                "payload": '["item1", "item2"]',
                "x": True,
                "f": "10.5",
                "d": 3.14,
            }
        )
        self.assertEqual(type(model.payload), list)
        self.assertEqual(type(model.x), int)
        self.assertEqual(type(model.f), float)
        self.assertEqual(type(model.is_vip), bool)
        self.assertEqual(type(model.serialize()["is_vip"]), bool)

    def test_model_can_cast_dict_attributes(self):
        """test cast with dict object to json field"""
        dictcasttest = {}
        dictcasttest["key"] = "value"
        model = ModelTest.hydrate(
            {"is_vip": 1, "payload": dictcasttest, "x": True, "f": "10.5"}
        )
        self.assertEqual(type(model.payload), dict)
        self.assertEqual(type(model.x), int)
        self.assertEqual(type(model.f), float)
        self.assertEqual(type(model.is_vip), bool)
        self.assertEqual(type(model.serialize()["is_vip"]), bool)

    def test_valid_json_cast(self):
        model = ModelTest.hydrate(
            {"payload": {"this": "dict", "is": "usable", "as": "json"}}
        )
        self.assertEqual(type(model.payload), dict)

        model = ModelTest.hydrate(
            {"payload": {"this": "dict", "is": "invalid", "as": "json"}}
        )
        self.assertEqual(type(model.payload), dict)

        model = ModelTest.hydrate(
            {"payload": '{"this": "dict", "is": "usable", "as": "json"}'}
        )
        self.assertEqual(type(model.payload), dict)

        model = ModelTest.hydrate({"payload": '{"valid": "json", "int": 1}'})
        self.assertEqual(type(model.payload), dict)

        # Single-quoted pseudo-JSON is not parseable: hydrating yields None,
        # while assigning and saving raises.
        model = ModelTest.hydrate({"payload": "{'this': 'should', 'throw': 'error'}"})
        self.assertEqual(model.payload, None)

        with self.assertRaises(ValueError):
            model.payload = "{'this': 'should', 'throw': 'error'}"
            model.save()

    def test_model_update_without_changes(self):
        model = ModelTest.hydrate(
            {"id": 1, "username": "joe", "name": "Joe", "admin": True}
        )
        # Re-assigning the same value must not mark the attribute dirty.
        model.username = "joe"
        model.name = "Bill"
        sql = model.save(query=True).to_sql()
        self.assertTrue(sql.startswith("UPDATE"))
        self.assertNotIn("username", sql)

    def test_force_update_on_model_class(self):
        model = ModelTestForced.hydrate(
            {"id": 1, "username": "joe", "name": "Joe", "admin": True}
        )
        model.username = "joe"
        model.name = "Bill"
        sql = model.save(query=True).to_sql()
        self.assertTrue(sql.startswith("UPDATE"))
        # __force_update__ includes even unchanged attributes in the SET clause.
        self.assertIn("username", sql)
        self.assertIn("name", sql)

    def test_only_method(self):
        model = ModelTestForced.hydrate(
            {"id": 1, "username": "joe", "name": "Joe", "admin": True}
        )
        # only() accepts either a single name or a list of names.
        self.assertEqual({"username": "joe"}, model.only("username"))
        self.assertEqual({"username": "joe"}, model.only(["username"]))

    def test_model_update_without_changes_at_all(self):
        model = ModelTest.hydrate(
            {"id": 1, "username": "joe", "name": "Joe", "admin": True}
        )
        model.username = "joe"
        model.name = "Joe"
        # Nothing dirty -> no UPDATE statement is produced.
        sql = model.save(query=True).to_sql()
        self.assertFalse(sql.startswith("UPDATE"))

    def test_model_using_or_where(self):
        model = ModelTest()
        sql = model.where("name", "=", "joe").or_where("is_vip", True).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `model_tests` WHERE `model_tests`.`name` = 'joe' OR `model_tests`.`is_vip` = '1'""",
        )

    def test_model_using_or_where_and_chaining_wheres(self):
        model = ModelTest()
        sql = (
            model.where("name", "=", "joe")
            .or_where(
                lambda query: query.where("username", "Joseph").or_where(
                    "age", ">=", 18
                )
            )
            .to_sql()
        )
        # BUG FIX: this used assertTrue(sql, "..."), which treats the expected
        # SQL as the failure *message* and passes for any non-empty string —
        # the statement was never actually compared. The unbalanced trailing
        # "))" in the old expected string confirms it was never exercised.
        self.assertEqual(
            sql,
            """SELECT * FROM `model_tests` WHERE `model_tests`.`name` = 'joe' OR (`model_tests`.`username` = 'Joseph' OR `model_tests`.`age` >= '18')""",
        )

    def test_both_fillable_and_guarded_attributes_raise(self):
        # Both fillable and guarded props are populated on this class
        with self.assertRaises(AttributeError):
            InvalidFillableGuardedModelTest()

        # Child that inherits from an intermediary class also fails
        with self.assertRaises(AttributeError):
            InvalidFillableGuardedChildModelTest()

        # Still shouldn't be allowed to define even if empty
        InvalidFillableGuardedModelTest.__fillable__ = []
        with self.assertRaises(AttributeError):
            InvalidFillableGuardedModelTest()

        # Or wildcard
        InvalidFillableGuardedModelTest.__fillable__ = ["*"]
        with self.assertRaises(AttributeError):
            InvalidFillableGuardedModelTest()

        # Empty guarded attr still raises
        InvalidFillableGuardedModelTest.__guarded__ = []
        with self.assertRaises(AttributeError):
            InvalidFillableGuardedModelTest()

        # Removing one of the props allows us to instantiate
        delattr(InvalidFillableGuardedModelTest, "__guarded__")
        InvalidFillableGuardedModelTest()

    def test_model_can_provide_default_select(self):
        sql = ModelWithBaseModel.to_sql()
        self.assertEqual(
            sql,
            """SELECT `users`.* FROM `users`""",
        )

    def test_model_can_override_to_default_select(self):
        sql = ModelWithBaseModel.select(["products.name", "products.id", "store.name"]).to_sql()
        self.assertEqual(
            sql,
            """SELECT `products`.`name`, `products`.`id`, `store`.`name` FROM `users`""",
        )

    def test_model_can_use_aggregate_funcs_with_default_selects(self):
        sql = ModelWithBaseModel.count().to_sql()
        self.assertEqual(
            sql,
            """SELECT COUNT(*) AS m_count_reserved FROM `users`""",
        )
        sql = ModelWithBaseModel.max("id").to_sql()
        self.assertEqual(
            sql,
            """SELECT MAX(`users`.`id`) AS id FROM `users`""",
        )
        sql = ModelWithBaseModel.min("id").to_sql()
        self.assertEqual(
            sql,
            """SELECT MIN(`users`.`id`) AS id FROM `users`""",
        )
================================================
FILE: tests/mssql/builder/test_mssql_query_builder.py
================================================
import inspect
import unittest
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import PostgresGrammar
from tests.utils import MockConnectionFactory
class MockConnection:
    """Minimal connection double: hands itself back instead of opening a database."""

    connection_details = {}

    def make_connection(self):
        # The mock acts as its own "connection".
        return self

    @classmethod
    def get_default_query_grammar(cls):
        # No default grammar is supplied by the mock.
        return
class ModelTest(Model):
    """Model handed to the builder under test.

    __timestamps__ disabled — presumably keeps automatic created_at/updated_at
    handling out of the generated queries; confirm against Model internals.
    """

    __timestamps__ = False
class TestMSSQLQueryBuilder(unittest.TestCase):
    """Checks the SQL text the query builder compiles for the MSSQL grammar
    (bracket-quoted identifiers, TOP/OFFSET syntax, etc.)."""

    maxDiff = None

    def get_builder(self, table="users", dry=True):
        """Build a dry QueryBuilder wired to the mocked mssql connection."""
        connection = MockConnectionFactory().make("mssql")
        return QueryBuilder(
            connection_class=connection,
            connection="mssql",
            table=table,
            model=ModelTest(),
            dry=dry,
        )

    def test_sum(self):
        builder = self.get_builder()
        builder.sum("age")
        self.assertEqual(
            builder.to_sql(), "SELECT SUM([users].[age]) AS age FROM [users]"
        )

    def test_where_like(self):
        builder = self.get_builder()
        builder.where("age", "like", "%name%")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] LIKE '%name%'"
        )

    def test_where_not_like(self):
        builder = self.get_builder()
        builder.where("age", "not like", "%name%")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[age] NOT LIKE '%name%'",
        )

    def test_max(self):
        builder = self.get_builder()
        builder.max("age")
        self.assertEqual(
            builder.to_sql(), "SELECT MAX([users].[age]) AS age FROM [users]"
        )

    def test_min(self):
        builder = self.get_builder()
        builder.min("age")
        self.assertEqual(
            builder.to_sql(), "SELECT MIN([users].[age]) AS age FROM [users]"
        )

    def test_avg(self):
        builder = self.get_builder()
        builder.avg("age")
        self.assertEqual(
            builder.to_sql(), "SELECT AVG([users].[age]) AS age FROM [users]"
        )

    def test_all(self):
        builder = self.get_builder()
        builder.all()
        self.assertEqual(builder.to_sql(), "SELECT * FROM [users]")

    def test_get(self):
        builder = self.get_builder()
        builder.get()
        self.assertEqual(builder.to_sql(), "SELECT * FROM [users]")

    def test_first(self):
        # MSSQL expresses LIMIT 1 as TOP 1.
        builder = self.get_builder().first(query=True)
        self.assertEqual(builder.to_sql(), "SELECT TOP 1 * FROM [users]")

    def test_select(self):
        builder = self.get_builder()
        builder.select("name", "email")
        self.assertEqual(
            builder.to_sql(), "SELECT [users].[name], [users].[email] FROM [users]"
        )

    def test_add_select_no_table(self):
        # add_select builds correlated subselects even with no base table set.
        builder = self.get_builder(table=None)
        builder.add_select(
            "other_test", lambda q: q.max("updated_at").table("different_table")
        ).add_select("some_alias", lambda q: q.max("updated_at").table("another_table"))
        self.assertEqual(
            builder.to_sql(),
            (
                "SELECT "
                "(SELECT MAX([different_table].[updated_at]) AS updated_at FROM [different_table]) AS other_test, "
                "(SELECT MAX([another_table].[updated_at]) AS updated_at FROM [another_table]) AS some_alias"
            ),
        )

    def test_select_raw(self):
        # Raw selects are passed through without identifier quoting.
        builder = self.get_builder()
        builder.select_raw("count(email) as email_count")
        self.assertEqual(
            builder.to_sql(), "SELECT count(email) as email_count FROM [users]"
        )

    def test_create(self):
        builder = self.get_builder().without_global_scopes()
        builder.create(
            {"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True
        )
        self.assertEqual(
            builder.to_sql(),
            "INSERT INTO [users] ([users].[name], [users].[email]) VALUES ('Corentin All', 'corentin@yopmail.com')",
        )

    def test_delete(self):
        builder = self.get_builder()
        builder.delete("name", "Joe", query=True)
        self.assertEqual(
            builder.to_sql(), "DELETE FROM [users] WHERE [users].[name] = 'Joe'"
        )

    def test_where(self):
        builder = self.get_builder()
        builder.where("name", "Joe")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[name] = 'Joe'"
        )

    def test_where_exists(self):
        builder = self.get_builder()
        builder.where_exists("name")
        self.assertEqual(builder.to_sql(), "SELECT * FROM [users] WHERE EXISTS 'name'")

    def test_limit(self):
        builder = self.get_builder()
        builder.limit(5)
        self.assertEqual(builder.to_sql(), "SELECT TOP 5 * FROM [users]")

    def test_offset(self):
        builder = self.get_builder()
        builder.offset(5)
        # NOTE(review): offset without an explicit limit compiles to
        # "FETCH NEXT 1 ROWS ONLY" — looks like a default of 1 row; confirm
        # this is the intended grammar behavior.
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] OFFSET 5 ROWS FETCH NEXT 1 ROWS ONLY",
        )

    def test_join(self):
        builder = self.get_builder()
        builder.join("profiles", "users.id", "=", "profiles.user_id")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] INNER JOIN [profiles] ON [users].[id] = [profiles].[user_id]",
        )

    def test_left_join(self):
        builder = self.get_builder()
        builder.left_join("profiles", "users.id", "=", "profiles.user_id")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] LEFT JOIN [profiles] ON [users].[id] = [profiles].[user_id]",
        )

    def test_right_join(self):
        builder = self.get_builder()
        builder.right_join("profiles", "users.id", "=", "profiles.user_id")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] RIGHT JOIN [profiles] ON [users].[id] = [profiles].[user_id]",
        )

    def test_update(self):
        builder = self.get_builder().update(
            {"name": "Joe", "email": "joe@yopmail.com"}, dry=True
        )
        self.assertEqual(
            builder.to_sql(),
            "UPDATE [users] SET [users].[name] = 'Joe', [users].[email] = 'joe@yopmail.com'",
        )

    # NOTE(review): increment/decrement compilation is not implemented or not
    # stable for MSSQL yet; these stay disabled until it is.
    # def test_increment(self):
    #     builder = self.get_builder()
    #     builder.increment("age", 1)
    #     self.assertEqual(
    #         builder.to_sql(), "UPDATE [users] SET [users].[age] = [users].[age] + '1'"
    #     )

    # def test_decrement(self):
    #     builder = self.get_builder()
    #     builder.decrement("age", 1)
    #     self.assertEqual(
    #         builder.to_sql(), "UPDATE [users] SET [users].[age] = [users].[age] - '1'"
    #     )

    def test_count(self):
        builder = self.get_builder()
        builder.count("id")
        self.assertEqual(
            builder.to_sql(), "SELECT COUNT([users].[id]) AS id FROM [users]"
        )

    def test_order_by_asc(self):
        builder = self.get_builder()
        builder.order_by("email", "asc")
        self.assertEqual(builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] ASC")

    def test_order_by_desc(self):
        builder = self.get_builder()
        builder.order_by("email", "desc")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] DESC"
        )

    def test_where_column(self):
        builder = self.get_builder()
        builder.where_column("name", "username")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[name] = [users].[username]",
        )

    def test_where_not_in(self):
        builder = self.get_builder()
        builder.where_not_in("id", [1, 2, 3])
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[id] NOT IN ('1','2','3')",
        )

    def test_between(self):
        builder = self.get_builder()
        builder.between("id", 2, 5)
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[id] BETWEEN '2' AND '5'",
        )

    def test_not_between(self):
        builder = self.get_builder()
        builder.not_between("id", 2, 5)
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[id] NOT BETWEEN '2' AND '5'",
        )

    def test_where_in(self):
        builder = self.get_builder()
        builder.where_in("id", [1, 2, 3])
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[id] IN ('1','2','3')",
        )

    def test_where_null(self):
        builder = self.get_builder()
        builder.where_null("name")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[name] IS NULL"
        )

    def test_where_not_null(self):
        builder = self.get_builder()
        builder.where_not_null("name")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[name] IS NOT NULL"
        )

    def test_having(self):
        builder = self.get_builder(table="payments")
        builder.select("user_id").avg("salary").group_by("user_id").having(
            "salary", ">=", "1000"
        )
        self.assertEqual(
            builder.to_sql(),
            "SELECT [payments].[user_id], AVG([payments].[salary]) AS salary FROM [payments] GROUP BY [payments].[user_id] HAVING [payments].[salary] >= '1000'",
        )

    def test_group_by(self):
        builder = self.get_builder(table="payments")
        builder.select("user_id").min("salary").group_by("user_id")
        self.assertEqual(
            builder.to_sql(),
            "SELECT [payments].[user_id], MIN([payments].[salary]) AS salary FROM [payments] GROUP BY [payments].[user_id]",
        )

    def test_builder_alone(self):
        # A builder can be constructed standalone from explicit connection details.
        self.assertTrue(
            QueryBuilder(
                connection_class=MockConnection,
                connection="mssql",
                connection_details={
                    "default": "mssql",
                    "mssql": {
                        "driver": "mssql",
                        "host": "localhost",
                        "user": "root",
                        "password": "root",
                        "database": "orm",
                        "port": "5432",
                        "prefix": "",
                        "grammar": "mssql",
                    },
                },
            ).table("users")
        )

    def test_where_lt(self):
        builder = self.get_builder()
        builder.where("age", "<", "20")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] < '20'"
        )

    def test_where_lte(self):
        builder = self.get_builder()
        builder.where("age", "<=", "20")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] <= '20'"
        )

    def test_where_gt(self):
        builder = self.get_builder()
        builder.where("age", ">", "20")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] > '20'"
        )

    def test_where_gte(self):
        builder = self.get_builder()
        builder.where("age", ">=", "20")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] >= '20'"
        )

    def test_where_ne(self):
        builder = self.get_builder()
        builder.where("age", "!=", "20")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] WHERE [users].[age] != '20'"
        )

    def test_or_where(self):
        builder = self.get_builder()
        builder.where("age", "20").or_where("age", "<", 20)
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] WHERE [users].[age] = '20' OR [users].[age] < '20'",
        )

    def test_can_call_with_schema(self):
        # Dotted table names are quoted per segment.
        builder = self.get_builder()
        sql = (
            builder.table("information_schema.columns")
            .select("table_name")
            .where("table_name", "users")
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT [information_schema].[columns].[table_name] FROM [information_schema].[columns] WHERE [information_schema].[columns].[table_name] = 'users'""",
        )

    def test_truncate(self):
        builder = self.get_builder(dry=True)
        sql = builder.truncate()
        self.assertEqual(sql, "TRUNCATE TABLE [users]")

    def test_truncate_without_foreign_keys(self):
        # foreign_keys=True produces the same statement for MSSQL.
        builder = self.get_builder(dry=True)
        sql = builder.truncate(foreign_keys=True)
        self.assertEqual(sql, "TRUNCATE TABLE [users]")

    def test_latest(self):
        builder = self.get_builder()
        builder.latest("email")
        self.assertEqual(
            builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] DESC"
        )

    def test_latest_multiple(self):
        builder = self.get_builder()
        builder.latest("email", "created_at")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] ORDER BY [email] DESC, [created_at] DESC",
        )

    def test_oldest(self):
        builder = self.get_builder()
        builder.oldest("email")
        self.assertEqual(builder.to_sql(), "SELECT * FROM [users] ORDER BY [email] ASC")

    def test_oldest_multiple(self):
        builder = self.get_builder()
        builder.oldest("email", "created_at")
        self.assertEqual(
            builder.to_sql(),
            "SELECT * FROM [users] ORDER BY [email] ASC, [created_at] ASC",
        )
================================================
FILE: tests/mssql/builder/test_mssql_query_builder_relationships.py
================================================
import inspect
import unittest
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MSSQLGrammar
from src.masoniteorm.relationships import belongs_to
from tests.utils import MockConnectionFactory
from dotenv import load_dotenv
load_dotenv(".env")
class Logo(Model):
    """Related model reached through Article.logo (logos.article_id)."""

    __connection__ = "mssql"
class Article(Model):
    """Model owned by User; exposes a logo relationship for nested has() tests."""

    __connection__ = "mssql"

    @belongs_to("id", "article_id")
    def logo(self):
        return Logo
class Profile(Model):
    """Related model reached through User.profile (profiles.user_id)."""

    __connection__ = "mssql"
class User(Model):
    """Root model for the relationship tests, including two self-referencing
    relationships (one returning its own class dynamically, one by name)."""

    __connection__ = "mssql"

    @belongs_to("id", "user_id")
    def articles(self):
        return Article

    @belongs_to("id", "user_id")
    def profile(self):
        return Profile

    # Self-reference resolved dynamically through self.__class__.
    @belongs_to("id", "parent_dynamic_id")
    def parent_dynamic(self):
        return self.__class__

    # Self-reference resolved by naming the class directly.
    @belongs_to("id", "parent_specified_id")
    def parent_specified(self):
        return User
class BaseTestQueryRelationships(unittest.TestCase):
    """Relationship existence queries (has / where_has) compiled for MSSQL."""

    maxDiff = None

    def get_builder(self, table="users"):
        """Build a QueryBuilder for the mocked mssql connection with a User model."""
        connection = MockConnectionFactory().make("mssql")
        return QueryBuilder(
            grammar=MSSQLGrammar,
            connection_class=connection,
            connection="mssql",
            table=table,
            model=User(),
        )

    def test_has(self):
        # has() compiles to a correlated EXISTS subquery.
        builder = self.get_builder()
        sql = builder.has("articles").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [articles] WHERE [articles].[user_id] = [users].[id]"""
            """)""",
        )

    def test_has_reference_to_self(self):
        # Self-referencing relationship resolved via self.__class__.
        builder = self.get_builder()
        sql = builder.has("parent_dynamic").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [users] WHERE [users].[parent_dynamic_id] = [users].[id]"""
            """)""",
        )

    def test_has_reference_to_self_using_class(self):
        # Self-referencing relationship resolved by naming the class directly.
        builder = self.get_builder()
        sql = builder.has("parent_specified").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [users] WHERE [users].[parent_specified_id] = [users].[id]"""
            """)""",
        )

    def test_where_has_query(self):
        # where_has() adds the callback's constraints inside the EXISTS subquery.
        builder = self.get_builder()
        sql = builder.where_has("articles", lambda q: q.where("active", 1)).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [articles] WHERE [articles].[user_id] = [users].[id] AND [articles].[active] = '1'"""
            """)""",
        )

    def test_relationship_multiple_has(self):
        # Several relationships in one has() call AND the EXISTS clauses together.
        to_sql = User.has("articles", "profile").to_sql()
        self.assertEqual(
            to_sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [articles] WHERE [articles].[user_id] = [users].[id]"""
            """) AND EXISTS ("""
            """SELECT * FROM [profiles] WHERE [profiles].[user_id] = [users].[id]"""
            """)""",
        )

    def test_relationship_multiple_has_calls(self):
        # Chained has() calls compile identically to one call with two names.
        to_sql = User.has("articles").has("profile").to_sql()
        self.assertEqual(
            to_sql,
            """SELECT * FROM [users] WHERE EXISTS ("""
            """SELECT * FROM [articles] WHERE [articles].[user_id] = [users].[id]"""
            """) AND EXISTS ("""
            """SELECT * FROM [profiles] WHERE [profiles].[user_id] = [users].[id]"""
            """)""",
        )

    def test_nested_has(self):
        # Dot notation nests the EXISTS subqueries (users -> articles -> logos).
        to_sql = User.has("articles.logo").to_sql()
        self.assertEqual(
            to_sql,
            """SELECT * FROM [users] WHERE EXISTS (SELECT * FROM [articles] WHERE [articles].[user_id] = [users].[id] AND EXISTS (SELECT * FROM [logos] WHERE [logos].[article_id] = [articles].[id]))""",
        )
================================================
FILE: tests/mssql/grammar/test_mssql_delete_grammar.py
================================================
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MSSQLGrammar
class TestMySQLDeleteGrammar(unittest.TestCase):
    """DELETE-statement compilation with the MSSQL grammar.

    NOTE(review): the class name says "MySQL" but MSSQLGrammar is under test —
    looks like a copy/paste leftover; confirm before renaming.
    """

    def setUp(self):
        self.builder = QueryBuilder(MSSQLGrammar, table="users")

    def test_can_compile_delete(self):
        compiled = self.builder.delete("id", 1, query=True).to_sql()
        expected = "DELETE FROM [users] WHERE [users].[id] = '1'"
        self.assertEqual(compiled, expected)

    def test_can_compile_delete_with_where(self):
        # Two chained wheres both land in the DELETE's WHERE clause.
        query = self.builder.where("age", 20).where("profile", 1)
        query = query.set_action("delete").delete(query=True)
        compiled = query.to_sql()
        expected = (
            "DELETE FROM [users] WHERE [users].[age] = '20' AND [users].[profile] = '1'"
        )
        self.assertEqual(compiled, expected)
================================================
FILE: tests/mssql/grammar/test_mssql_insert_grammar.py
================================================
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MSSQLGrammar
class TestMySQLInsertGrammar(unittest.TestCase):
    """INSERT-statement compilation with the MSSQL grammar.

    NOTE(review): the class name says "MySQL" but MSSQLGrammar is under test —
    looks like a copy/paste leftover; confirm before renaming.
    """

    def setUp(self):
        self.builder = QueryBuilder(MSSQLGrammar, table="users")

    def test_can_compile_insert(self):
        to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql()
        sql = "INSERT INTO [users] ([users].[name]) VALUES ('Joe')"
        self.assertEqual(to_sql, sql)

    def test_can_compile_bulk_create(self):
        to_sql = self.builder.bulk_create(
            # These keys are intentionally out of order to show column to value alignment works
            [
                {"name": "Joe", "age": 5},
                {"age": 35, "name": "Bill"},
                {"name": "John", "age": 10},
            ],
            query=True,
        ).to_sql()
        sql = "INSERT INTO [users] ([age], [name]) VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John')"
        self.assertEqual(to_sql, sql)

    def test_can_compile_bulk_create_qmark(self):
        # to_qmark substitutes a placeholder per value.
        to_sql = self.builder.bulk_create(
            [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True
        ).to_qmark()
        sql = "INSERT INTO [users] ([name]) VALUES ('?'), ('?'), ('?')"
        self.assertEqual(to_sql, sql)
================================================
FILE: tests/mssql/grammar/test_mssql_qmark.py
================================================
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MSSQLGrammar
class TestMSSQLQmark(unittest.TestCase):
    """Parameterized (qmark) compilation: placeholders in the SQL with the
    literal values collected into _bindings."""

    def setUp(self):
        self.builder = QueryBuilder(MSSQLGrammar, table="users")

    def test_can_compile_select(self):
        mark = self.builder.select("username").where("name", "Joe")
        sql = "SELECT [users].[username] FROM [users] WHERE [users].[name] = '?'"
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, ["Joe"])

    def test_can_compile_update(self):
        # Bindings appear in statement order: SET values first, then wheres.
        mark = self.builder.update({"name": "Bob"}, dry=True).where("name", "Joe")
        sql = "UPDATE [users] SET [users].[name] = '?' WHERE [users].[name] = '?'"
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, ["Bob", "Joe"])

    def test_can_compile_insert(self):
        mark = self.builder.create({"name": "Bob"}, query=True)
        sql = "INSERT INTO [users] ([users].[name]) VALUES ('?')"
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, ["Bob"])

    def test_can_compile_where_in(self):
        mark = self.builder.where_in("id", [1, 2, 3])
        qmark_sql = "SELECT * FROM [users] WHERE [users].[id] IN ('?', '?', '?')"
        sql = "SELECT * FROM [users] WHERE [users].[id] IN ('1','2','3')"
        self.assertEqual(mark.to_qmark(), qmark_sql)
        self.assertEqual(mark._bindings, [1, 2, 3])
        self.assertEqual(self.builder.where_in("id", [1, 2, 3]).to_sql(), sql)
        self.builder.reset()
        # Assert that when passed string values it generates synonymous sql
        self.assertEqual(self.builder.where_in("id", ["1", "2", "3"]).to_sql(), sql)
================================================
FILE: tests/mssql/grammar/test_mssql_select_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query.grammars import MSSQLGrammar
from src.masoniteorm.testing import BaseTestCaseSelectGrammar
class TestMSSQLGrammar(BaseTestCaseSelectGrammar, unittest.TestCase):
grammar = MSSQLGrammar
def can_compile_select(self):
    """Expected SQL for: self.builder.to_sql()"""
    return "SELECT * FROM [users]"
def can_compile_with_columns(self):
    """Expected SQL for: self.builder.select('username', 'password').to_sql()"""
    return "SELECT [users].[username], [users].[password] FROM [users]"
def can_compile_with_where(self):
    """Expected SQL for: self.builder.select('username', 'password').where('id', 1).to_sql()"""
    return "SELECT [users].[username], [users].[password] FROM [users] WHERE [users].[id] = '1'"
def can_compile_with_several_where(self):
    """Expected SQL for: self.builder.select('username', 'password').where('id', 1).where('username', 'joe').to_sql()"""
    return "SELECT [users].[username], [users].[password] FROM [users] WHERE [users].[id] = '1' AND [users].[username] = 'joe'"
def can_compile_with_several_where_and_limit(self):
    """Expected SQL for: self.builder.select('username', 'password').where('id', 1).where('username', 'joe').limit(10).to_sql()

    MSSQL renders the limit as TOP 10.
    """
    return "SELECT TOP 10 [users].[username], [users].[password] FROM [users] WHERE [users].[id] = '1' AND [users].[username] = 'joe'"
def can_compile_with_sum(self):
    """Expected SQL for: self.builder.sum('age').to_sql()"""
    return "SELECT SUM([users].[age]) AS age FROM [users]"
def can_compile_order_by_and_first(self):
    """Expected SQL for: self.builder.order_by('id', 'asc').first()"""
    return """SELECT TOP 1 * FROM [users] ORDER BY [id] ASC"""
def can_compile_with_max(self):
    """Expected SQL for: self.builder.max('age').to_sql()"""
    return "SELECT MAX([users].[age]) AS age FROM [users]"
def can_compile_with_max_and_columns(self):
    """Expected SQL for: self.builder.select('username').max('age').to_sql()"""
    return "SELECT [users].[username], MAX([users].[age]) AS age FROM [users]"
def can_compile_with_max_and_columns_different_order(self):
"""
self.builder.max('age').select('username').to_sql()
"""
return "SELECT [users].[username], MAX([users].[age]) AS age FROM [users]"
def can_compile_with_order_by(self):
"""
self.builder.select('username').order_by('age', 'desc').to_sql()
"""
return "SELECT [users].[username] FROM [users] ORDER BY [age] DESC"
def can_compile_with_multiple_order_by(self):
"""
self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql()
"""
return "SELECT [users].[username] FROM [users] ORDER BY [age] DESC, [name] ASC"
def can_compile_with_group_by(self):
"""
self.builder.select('username').group_by('age').to_sql()
"""
return "SELECT [users].[username] FROM [users] GROUP BY [users].[age]"
def can_compile_where_in(self):
"""
self.builder.select('username').where_in('age', [1,2,3]).to_sql()
"""
return "SELECT [users].[username] FROM [users] WHERE [users].[age] IN ('1','2','3')"
def can_compile_where_in_empty(self):
"""
self.builder.where_in('age', []).to_sql()
"""
return """SELECT * FROM [users] WHERE 0 = 1"""
def can_compile_where_null(self):
"""
self.builder.select('username').where_null('age').to_sql()
"""
return "SELECT [users].[username] FROM [users] WHERE [users].[age] IS NULL"
def can_compile_where_not_null(self):
"""
self.builder.select('username').where_not_null('age').to_sql()
"""
return "SELECT [users].[username] FROM [users] WHERE [users].[age] IS NOT NULL"
def can_compile_where_raw(self):
"""
self.builder.where_raw("`age` = '18'").to_sql()
"""
return "SELECT * FROM [users] WHERE [users].[age] = '18'"
def test_can_compile_where_raw_and_where_with_multiple_bindings(self):
query = self.builder.where_raw(
"[age] = '?' AND [is_admin] = '?'", [18, True]
).where("email", "test@example.com")
self.assertEqual(
query.to_qmark(),
"SELECT * FROM [users] WHERE [age] = '?' AND [is_admin] = '?' AND [users].[email] = '?'",
)
self.assertEqual(query._bindings, [18, True, "test@example.com"])
def can_compile_select_raw(self):
"""
self.builder.select_raw("COUNT(*)").to_sql()
"""
return "SELECT COUNT(*) FROM [users]"
def can_compile_limit_and_offset(self):
"""
self.builder.limit(10).offset(10).to_sql()
"""
return "SELECT * FROM [users] OFFSET 10 ROWS FETCH NEXT 10 ROWS ONLY"
def can_compile_select_raw_with_select(self):
"""
self.builder.select('id').select_raw("COUNT(*)").to_sql()
"""
return "SELECT [users].[id], COUNT(*) FROM [users]"
def can_compile_having_raw(self):
"""
self.builder.select_raw("COUNT(*) as counts").having_raw("counts > 18").to_sql()
"""
return "SELECT COUNT(*) as counts FROM [users] HAVING counts > 18"
def can_compile_count(self):
"""
self.builder.count().to_sql()
"""
return "SELECT COUNT(*) AS m_count_reserved FROM [users]"
def can_compile_count_column(self):
"""
self.builder.count().to_sql()
"""
return "SELECT COUNT([users].[money]) AS money FROM [users]"
def can_compile_where_column(self):
"""
self.builder.where_column('name', 'email').to_sql()
"""
return "SELECT * FROM [users] WHERE [users].[name] = [users].[email]"
def can_compile_or_where(self):
"""
self.builder.where('name', 2).or_where('name', 3).to_sql()
"""
return (
"SELECT * FROM [users] WHERE [users].[name] = '2' OR [users].[name] = '3'"
)
def can_grouped_where(self):
"""
self.builder.where(lambda query: query.where('age', 2).where('name', 'Joe')).to_sql()
"""
return "SELECT * FROM [users] WHERE ([users].[age] = '2' AND [users].[name] = 'Joe')"
def can_compile_sub_select(self):
"""
self.builder.where_in('name',
QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age')
).to_sql()
"""
return "SELECT * FROM [users] WHERE [users].[name] IN (SELECT [users].[age] FROM [users])"
def can_compile_sub_select_where(self):
"""
self.builder.where_in('age',
QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age').where('age', 2).where('name', 'Joe')
).to_sql()
"""
return "SELECT * FROM [users] WHERE [users].[age] IN (SELECT [users].[age] FROM [users] WHERE [users].[age] = '2' AND [users].[name] = 'Joe')"
def can_compile_sub_select_value(self):
"""
self.builder.where('name',
self.builder.new().sum('age')
).to_sql()
"""
return "SELECT * FROM [users] WHERE [users].[name] = (SELECT SUM([users].[age]) AS age FROM [users])"
def can_compile_complex_sub_select(self):
"""
self.builder.where_in('name',
(QueryBuilder(GrammarFactory.make(self.grammar), table='users')
.select('age').where_in('email',
QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('email')
))
).to_sql()
"""
return "SELECT * FROM [users] WHERE [users].[name] IN (SELECT [users].[age] FROM [users] WHERE [users].[email] IN (SELECT [users].[email] FROM [users]))"
def can_compile_exists(self):
"""
self.builder.select('age').where_exists(
self.builder.new().select('username').where('age', 12)
).to_sql()
"""
return "SELECT [users].[age] FROM [users] WHERE EXISTS (SELECT [users].[username] FROM [users] WHERE [users].[age] = '12')"
def can_compile_not_exists(self):
"""
self.builder.select('age').where_not_exists(
self.builder.new().select('username').where('age', 12)
).to_sql()
"""
return "SELECT [users].[age] FROM [users] WHERE NOT EXISTS (SELECT [users].[username] FROM [users] WHERE [users].[age] = '12')"
def can_compile_having(self):
"""
builder.sum('age').group_by('age').having('age').to_sql()
"""
return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age]"
def can_compile_having_order(self):
"""
builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql()
"""
return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] ORDER [users].[age] DESC"
def can_compile_between(self):
"""
builder.between('age', 18, 21).to_sql()
"""
return "SELECT * FROM [users] WHERE [users].[age] BETWEEN '18' AND '21'"
def can_compile_not_between(self):
"""
builder.not_between('age', 18, 21).to_sql()
"""
return "SELECT * FROM [users] WHERE [users].[age] NOT BETWEEN '18' AND '21'"
def can_compile_where_not_in(self):
"""
self.builder.select('username').where_not_in('age', [1,2,3]).to_sql()
"""
return "SELECT [users].[username] FROM [users] WHERE [users].[age] NOT IN ('1','2','3')"
def can_compile_having_with_expression(self):
"""
builder.sum('age').group_by('age').having('age', 10).to_sql()
"""
return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] = '10'"
def can_compile_having_with_greater_than_expression(self):
"""
builder.sum('age').group_by('age').having('age', '>', 10).to_sql()
"""
return "SELECT SUM([users].[age]) AS age FROM [users] GROUP BY [users].[age] HAVING [users].[age] > '10'"
def can_compile_join(self):
"""
builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
"""
return "SELECT * FROM [users] INNER JOIN [contacts] ON [users].[id] = [contacts].[user_id]"
def can_compile_left_join(self):
"""
builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
"""
return "SELECT * FROM [users] LEFT JOIN [contacts] ON [users].[id] = [contacts].[user_id]"
def can_compile_multiple_join(self):
"""
builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
"""
return "SELECT * FROM [users] INNER JOIN [contacts] ON [users].[id] = [contacts].[user_id] INNER JOIN [posts] ON [comments].[post_id] = [posts].[id]"
def test_can_compile_where_raw(self):
to_sql = self.builder.where_raw("[age] = '18'").to_sql()
self.assertEqual(to_sql, "SELECT * FROM [users] WHERE [age] = '18'")
def test_can_compile_having_raw(self):
to_sql = (
self.builder.select_raw("COUNT(*) as counts")
.having_raw("counts > 10")
.to_sql()
)
self.assertEqual(
to_sql, "SELECT COUNT(*) as counts FROM [users] HAVING counts > 10"
)
def test_can_compile_having_raw_order(self):
to_sql = (
self.builder.select_raw("COUNT(*) as counts")
.having_raw("counts > 10")
.order_by_raw("counts DESC")
.to_sql()
)
self.assertEqual(
to_sql,
"SELECT COUNT(*) as counts FROM [users] HAVING counts > 10 ORDER BY counts DESC",
)
def test_can_compile_select_raw(self):
to_sql = self.builder.select_raw("COUNT(*)").to_sql()
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(to_sql, sql)
def test_can_compile_select_raw_with_select(self):
to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql()
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(to_sql, sql)
def can_compile_first_or_fail(self):
"""
builder = self.get_builder()
builder.where("is_admin", "=", True).first_or_fail()
"""
return """SELECT TOP 1 * FROM [users] WHERE [users].[is_admin] = '1'"""
def where_like(self):
"""
builder = self.get_builder()
builder.where("age", "like", "%name%")
"""
return """SELECT * FROM [users] WHERE [users].[age] LIKE '%name%'"""
def where_not_like(self):
"""
builder = self.get_builder()
builder.where("age", "like", "%name%")
"""
return """SELECT * FROM [users] WHERE [users].[age] NOT LIKE '%name%'"""
def where_regexp(self):
"""
builder = self.get_builder()
builder.where("age", "regexp", "Joe")
"""
return """SELECT * FROM [users] WHERE [users].[age] LIKE 'Joe'"""
def where_not_regexp(self):
"""
builder = self.get_builder()
builder.where("age", "not regexp", "Joe")
"""
return """SELECT * FROM [users] WHERE [users].[age] NOT LIKE 'Joe'"""
def can_compile_join_clause(self):
"""
builder = self.get_builder()
clause = (
JoinClause("report_groups as rg")
.on("bgt.fund", "=", "rg.fund")
.on_value("bgt.active", "=", "1")
.or_on_value("bgt.acct", "=", "1234")
)
builder.join(clause).to_sql()
"""
return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [bgt].[fund] = [rg].[fund] AND [bgt].[dept] = [rg].[dept] AND [bgt].[acct] = [rg].[acct] AND [bgt].[sub] = [rg].[sub]"
def can_compile_join_clause_with_value(self):
"""
builder = self.get_builder()
clause = (
JoinClause("report_groups as rg")
.on_value("bgt.active", "=", "1")
.or_on_value("bgt.acct", "=", "1234")
)
builder.join(clause).to_sql()
"""
return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [bgt].[active] = '1' OR [bgt].[acct] = '1234'"
def can_compile_join_clause_with_null(self):
"""
builder = self.get_builder()
clause = (
JoinClause("report_groups as rg")
.on_null("bgt.acct")
.or_on_null("bgt.dept")
.on_value("rg.abc", 10)
)
builder.join(clause).to_sql()
"""
return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [acct] IS NULL OR [dept] IS NULL AND [rg].[abc] = '10'"
def can_compile_join_clause_with_not_null(self):
"""
builder = self.get_builder()
clause = (
JoinClause("report_groups as rg")
.on_not_null("bgt.acct")
.or_on_not_null("bgt.dept")
.on_value("rg.abc", 10)
)
builder.join(clause).to_sql()
"""
return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [acct] IS NOT NULL OR [dept] IS NOT NULL AND [rg].[abc] = '10'"
def can_compile_join_clause_with_lambda(self):
"""
builder = self.get_builder()
builder.join(
"report_groups as rg",
lambda clause: (
clause.on("bgt.fund", "=", "rg.fund")
.on_null("bgt")
),
).to_sql()
"""
return "SELECT * FROM [users] INNER JOIN [report_groups] AS [rg] ON [bgt].[fund] = [rg].[fund] AND [bgt] IS NULL"
def can_compile_left_join_clause_with_lambda(self):
"""
builder = self.get_builder()
builder.left_join(
"report_groups as rg",
lambda clause: (
clause.on("bgt.fund", "=", "rg.fund")
.or_on_null("bgt")
),
).to_sql()
"""
return "SELECT * FROM [users] LEFT JOIN [report_groups] AS [rg] ON [bgt].[fund] = [rg].[fund] OR [bgt] IS NULL"
def can_compile_right_join_clause_with_lambda(self):
"""
builder = self.get_builder()
builder.right_join(
"report_groups as rg",
lambda clause: (
clause.on("bgt.fund", "=", "rg.fund")
.or_on_null("bgt")
),
).to_sql()
"""
return "SELECT * FROM [users] RIGHT JOIN [report_groups] AS [rg] ON [bgt].[fund] = [rg].[fund] OR [bgt] IS NULL"
def shared_lock(self):
"""
builder = self.get_builder()
builder.where("age", "not like", "%name%").to_sql()
"""
return "SELECT * FROM [users] WITH(ROWLOCK) WHERE [users].[votes] >= '100'"
def update_lock(self):
"""
builder = self.get_builder()
builder.where("age", "not like", "%name%").to_sql()
"""
return "SELECT * FROM [users] WITH(ROWLOCK) WHERE [users].[votes] >= '100'"
def can_user_where_raw_and_where(self):
"""
builder.where_raw("`age` = '18'").where("name", "=", "James").to_sql()
"""
return "SELECT * FROM [users] WHERE age = '18' AND [users].[name] = 'James'"
def where_exists_with_lambda(self):
return """SELECT * FROM [users] WHERE EXISTS (SELECT * FROM [users] WHERE [users].[age] = '1')"""
def where_not_exists_with_lambda(self):
return """SELECT * FROM [users] WHERE NOT EXISTS (SELECT * FROM [users] WHERE [users].[age] = '1')"""
def where_date(self):
return (
"""SELECT * FROM [users] WHERE DATE([users].[created_at]) = '2022-06-01'"""
)
def or_where_null(self):
return """SELECT * FROM [users] WHERE [users].[column1] IS NULL OR [users].[column2] IS NULL"""
def select_distinct(self):
return """SELECT DISTINCT [users].[group] FROM [users]"""
================================================
FILE: tests/mssql/grammar/test_mssql_update_grammar.py
================================================
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MSSQLGrammar
from src.masoniteorm.expressions import Raw
class TestMSSQLUpdateGrammar(unittest.TestCase):
    """Compile-only checks for UPDATE statements produced by the MSSQL grammar.

    Every update is issued with ``dry=True`` so nothing touches a database;
    we only assert on the SQL string that ``to_sql()`` renders.
    """

    def setUp(self):
        # Fresh builder per test so where-clauses never leak between cases.
        self.builder = QueryBuilder(MSSQLGrammar, table="users")

    def test_can_compile_update(self):
        expected = "UPDATE [users] SET [users].[name] = 'Joe' WHERE [users].[name] = 'bob'"
        query = self.builder.where("name", "bob").update({"name": "Joe"}, dry=True)
        self.assertEqual(query.to_sql(), expected)

    def test_can_compile_update_with_multiple_where(self):
        # Two chained wheres must both appear, joined with AND.
        expected = "UPDATE [users] SET [users].[name] = 'Joe' WHERE [users].[name] = 'bob' AND [users].[age] = '20'"
        query = (
            self.builder.where("name", "bob")
            .where("age", 20)
            .update({"name": "Joe"}, dry=True)
        )
        self.assertEqual(query.to_sql(), expected)

    def test_raw_expression(self):
        # A Raw value is emitted verbatim, without quoting.
        expected = "UPDATE [users] SET [users].[name] = [username]"
        query = self.builder.update({"name": Raw("[username]")}, dry=True)
        self.assertEqual(query.to_sql(), expected)
================================================
FILE: tests/mssql/schema/test_mssql_schema_builder.py
================================================
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import MSSQLConnection
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import MSSQLPlatform
class TestMSSQLSchemaBuilder(unittest.TestCase):
    """Dry-run CREATE TABLE / schema-helper SQL for the MSSQL platform.

    The Schema is built with ``dry=True``, so ``blueprint.to_sql()`` and the
    schema helpers return SQL strings (or lists of statements) instead of
    executing against a database.
    """

    # Show full diffs when the long SQL strings below mismatch.
    maxDiff = None

    def setUp(self):
        self.schema = Schema(
            connection_class=MSSQLConnection,
            connection="mssql",
            connection_details=DATABASES,
            platform=MSSQLPlatform,
            dry=True,
        ).on("mssql")

    def test_can_add_columns(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")
        self.assertEqual(len(blueprint.table.added_columns), 2)
        self.assertEqual(
            blueprint.to_sql(),
            ["CREATE TABLE [users] ([name] VARCHAR(255) NOT NULL, [age] INT NOT NULL)"],
        )

    def test_can_add_tiny_text(self):
        with self.schema.create("users") as blueprint:
            blueprint.tiny_text("description")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            ["CREATE TABLE [users] ([description] TINYTEXT NOT NULL)"],
        )

    def test_can_add_unsigned_decimal(self):
        with self.schema.create("users") as blueprint:
            blueprint.unsigned_decimal("amount", 19, 4)
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            ["CREATE TABLE [users] ([amount] DECIMAL(19, 4) NOT NULL)"],
        )

    def test_can_add_columns_with_constaint(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")
            blueprint.unique("name")
        # unique() adds a constraint, not a column — still 2 added columns.
        self.assertEqual(len(blueprint.table.added_columns), 2)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] ([name] VARCHAR(255) NOT NULL, [age] INT NOT NULL, CONSTRAINT users_name_unique UNIQUE (name))"
            ],
        )

    def test_can_have_float_type(self):
        with self.schema.create("users") as blueprint:
            blueprint.float("amount")
        self.assertEqual(
            blueprint.to_sql(),
            ["""CREATE TABLE [users] (""" """[amount] FLOAT(19, 4) NOT NULL)"""],
        )

    def test_can_have_unsigned_columns(self):
        # MSSQL has no UNSIGNED modifier; .unsigned() is a no-op in the output.
        with self.schema.create("users") as blueprint:
            blueprint.integer("profile_id").unsigned()
            blueprint.big_integer("big_profile_id").unsigned()
            blueprint.tiny_integer("tiny_profile_id").unsigned()
            blueprint.small_integer("small_profile_id").unsigned()
            blueprint.medium_integer("medium_profile_id").unsigned()
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] ("
                "[profile_id] INT NOT NULL, "
                "[big_profile_id] BIGINT NOT NULL, "
                "[tiny_profile_id] TINYINT NOT NULL, "
                "[small_profile_id] SMALLINT NOT NULL, "
                "[medium_profile_id] MEDIUMINT NOT NULL)"
            ],
        )

    def test_can_add_columns_with_foreign_key_constaint(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.foreign("profile_id").references("id").on("profiles")
        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] "
                "([name] VARCHAR(255) NOT NULL, "
                "[age] INT NOT NULL, "
                "[profile_id] INT NOT NULL, "
                "CONSTRAINT users_name_unique UNIQUE (name), "
                "CONSTRAINT users_profile_id_foreign FOREIGN KEY ([profile_id]) REFERENCES [profiles]([id]))"
            ],
        )

    def test_can_add_columns_with_add_foreign_constaint(self):
        # add_foreign("column.references.table") is shorthand for
        # foreign("column").references("references").on("table").
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.add_foreign("profile_id.id.profiles")
        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] "
                "([name] VARCHAR(255) NOT NULL, "
                "[age] INT NOT NULL, "
                "[profile_id] INT NOT NULL, "
                "CONSTRAINT users_name_unique UNIQUE (name), "
                "CONSTRAINT users_profile_id_foreign FOREIGN KEY ([profile_id]) REFERENCES [profiles]([id]))"
            ],
        )

    def test_can_advanced_table_creation(self):
        with self.schema.create("users") as blueprint:
            blueprint.increments("id")
            blueprint.string("name")
            blueprint.string("email").unique()
            blueprint.string("password")
            blueprint.integer("admin").default(0)
            blueprint.string("remember_token").nullable()
            blueprint.timestamp("verified_at").nullable()
            blueprint.timestamp("registered_at").default_raw("CURRENT_TIMESTAMP")
            blueprint.timestamps()
        # timestamps() adds created_at + updated_at, hence 10 columns total.
        self.assertEqual(len(blueprint.table.added_columns), 10)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] ([id] INT IDENTITY NOT NULL, [name] VARCHAR(255) NOT NULL, [email] VARCHAR(255) NOT NULL, "
                "[password] VARCHAR(255) NOT NULL, [admin] INT NOT NULL DEFAULT 0, [remember_token] VARCHAR(255) NULL, "
                "[verified_at] DATETIME NULL, [registered_at] DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, [created_at] DATETIME NULL DEFAULT CURRENT_TIMESTAMP, "
                "[updated_at] DATETIME NULL DEFAULT CURRENT_TIMESTAMP, CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_email_unique UNIQUE (email))"
            ],
        )

    def test_can_advanced_table_creation2(self):
        # enum compiles to VARCHAR + CHECK; inet/cidr/macaddr fall back to VARCHAR.
        with self.schema.create("users") as blueprint:
            blueprint.increments("id")
            blueprint.enum("gender", ["male", "female"])
            blueprint.string("name")
            blueprint.string("duration")
            blueprint.string("url")
            blueprint.inet("last_address").nullable()
            blueprint.cidr("route_origin").nullable()
            blueprint.macaddr("mac_address").nullable()
            blueprint.datetime("published_at")
            blueprint.string("thumbnail").nullable()
            blueprint.integer("premium")
            blueprint.integer("author_id").unsigned().nullable()
            blueprint.foreign("author_id").references("id").on("users").on_delete(
                "CASCADE"
            )
            blueprint.text("description")
            blueprint.timestamps()
        self.assertEqual(len(blueprint.table.added_columns), 15)
        self.assertEqual(
            blueprint.to_sql(),
            (
                [
                    "CREATE TABLE [users] ([id] INT IDENTITY NOT NULL, [gender] VARCHAR(255) NOT NULL CHECK([gender] IN ('male', 'female')), [name] VARCHAR(255) NOT NULL, [duration] VARCHAR(255) NOT NULL, "
                    "[url] VARCHAR(255) NOT NULL, [last_address] VARCHAR(255) NULL, [route_origin] VARCHAR(255) NULL, [mac_address] VARCHAR(255) NULL, [published_at] DATETIME NOT NULL, [thumbnail] VARCHAR(255) NULL, [premium] INT NOT NULL, "
                    "[author_id] INT NULL, [description] TEXT NOT NULL, [created_at] DATETIME NULL DEFAULT CURRENT_TIMESTAMP, "
                    "[updated_at] DATETIME NULL DEFAULT CURRENT_TIMESTAMP, "
                    "CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_author_id_foreign FOREIGN KEY ([author_id]) REFERENCES [users]([id]) ON DELETE CASCADE)"
                ]
            ),
        )

    def test_can_add_columns_with_foreign_key_constraint_name(self):
        # An explicit name= overrides the generated constraint name.
        with self.schema.create("users") as blueprint:
            blueprint.integer("profile_id")
            blueprint.foreign("profile_id", name="profile_foreign").references("id").on(
                "profiles"
            )
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] ("
                "[profile_id] INT NOT NULL, "
                "CONSTRAINT profile_foreign FOREIGN KEY ([profile_id]) REFERENCES [profiles]([id]))"
            ],
        )

    def test_can_have_composite_keys(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.primary(["name", "age"])
        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] "
                "([name] VARCHAR(255) NOT NULL, "
                "[age] INT NOT NULL, "
                "[profile_id] INT NOT NULL, "
                "CONSTRAINT users_name_unique UNIQUE (name), "
                "CONSTRAINT users_name_age_primary PRIMARY KEY (name, age))"
            ],
        )

    def test_can_have_column_primary_key(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").primary()
            blueprint.integer("age")
            blueprint.integer("profile_id")
        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] "
                "([name] VARCHAR(255) NOT NULL, "
                "[age] INT NOT NULL, "
                "[profile_id] INT NOT NULL, "
                "CONSTRAINT users_name_primary PRIMARY KEY (name))"
            ],
        )

    def test_has_table(self):
        schema_sql = self.schema.has_table("users")
        sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'users'"
        self.assertEqual(schema_sql, sql)

    def test_can_truncate(self):
        sql = self.schema.truncate("users")
        self.assertEqual(sql, "TRUNCATE TABLE [users]")

    def test_can_rename_table(self):
        sql = self.schema.rename("users", "clients")
        self.assertEqual(sql, "EXEC sp_rename [users], [clients]")

    def test_can_drop_table_if_exists(self):
        # NOTE(review): the second argument ("clients") does not appear in the
        # output — presumably ignored by this helper; confirm its signature.
        sql = self.schema.drop_table_if_exists("users", "clients")
        self.assertEqual(sql, "DROP TABLE IF EXISTS [users]")

    def test_can_drop_table(self):
        sql = self.schema.drop_table("users", "clients")
        self.assertEqual(sql, "DROP TABLE [users]")

    def test_has_column(self):
        sql = self.schema.has_column("users", "name")
        self.assertEqual(
            sql,
            "SELECT 1 FROM sys.columns WHERE Name = N'name' AND Object_ID = Object_ID(N'users')",
        )

    def test_can_enable_foreign_keys(self):
        # MSSQL platform emits no statement for enabling FK constraints globally.
        sql = self.schema.enable_foreign_key_constraints()
        self.assertEqual(sql, "")

    def test_can_disable_foreign_keys(self):
        sql = self.schema.disable_foreign_key_constraints()
        self.assertEqual(sql, "")

    def test_can_truncate_without_foreign_keys(self):
        # With foreign_keys=True the truncate is wrapped in NOCHECK/CHECK.
        sql = self.schema.truncate("users", foreign_keys=True)
        self.assertEqual(
            sql,
            [
                "ALTER TABLE [users] NOCHECK CONSTRAINT ALL",
                "TRUNCATE TABLE [users]",
                "ALTER TABLE [users] WITH CHECK CHECK CONSTRAINT ALL",
            ],
        )

    def test_can_add_enum(self):
        with self.schema.create("users") as blueprint:
            blueprint.enum("status", ["active", "inactive"]).default("active")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE [users] ([status] VARCHAR(255) NOT NULL DEFAULT 'active' CHECK([status] IN ('active', 'inactive')))"
            ],
        )

    def test_can_change_column_enum(self):
        with self.schema.table("users") as blueprint:
            blueprint.enum("status", ["active", "inactive"]).default("active").change()
        self.assertEqual(len(blueprint.table.changed_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "ALTER TABLE [users] ALTER COLUMN [status] VARCHAR(255) NOT NULL DEFAULT 'active' CHECK([status] IN ('active', 'inactive'))"
            ],
        )
================================================
FILE: tests/mssql/schema/test_mssql_schema_builder_alter.py
================================================
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import MSSQLConnection
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import MSSQLPlatform
from src.masoniteorm.schema.Table import Table
class TestMySQLSchemaBuilderAlter(unittest.TestCase):
maxDiff = None
def setUp(self):
self.schema = Schema(
connection_class=MSSQLConnection,
connection="mssql",
connection_details=DATABASES,
platform=MSSQLPlatform,
dry=True,
)
def test_can_add_columns(self):
with self.schema.table("users") as blueprint:
blueprint.string("name")
blueprint.integer("age")
self.assertEqual(len(blueprint.table.added_columns), 2)
sql = [
"ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL, [age] INT NOT NULL"
]
self.assertEqual(blueprint.to_sql(), sql)
def test_can_adds_column_with_default(self):
with self.schema.table("users") as blueprint:
blueprint.string("name").default(0)
self.assertEqual(len(blueprint.table.added_columns), 1)
sql = ["ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL DEFAULT 0"]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_rename(self):
with self.schema.table("users") as blueprint:
blueprint.rename("post", "comment", "integer")
table = Table("users")
table.add_column("post", "integer")
blueprint.table.from_table = table
sql = ["EXEC sp_rename 'users.post', 'comment', 'COLUMN'"]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_add_and_rename(self):
with self.schema.table("users") as blueprint:
blueprint.string("name")
blueprint.rename("post", "comment", "integer")
table = Table("users")
table.add_column("post", "integer")
blueprint.table.from_table = table
sql = [
"ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL",
"EXEC sp_rename 'users.post', 'comment', 'COLUMN'",
]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_drop1(self):
with self.schema.table("users") as blueprint:
blueprint.drop_column("post")
sql = ["ALTER TABLE [users] DROP COLUMN post"]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_add_column_and_foreign_key(self):
with self.schema.table("users") as blueprint:
blueprint.unsigned_integer("playlist_id").nullable()
blueprint.foreign("playlist_id").references("id").on("playlists").on_delete(
"cascade"
)
sql = [
"ALTER TABLE [users] ADD [playlist_id] INT NULL",
"ALTER TABLE [users] ADD CONSTRAINT users_playlist_id_foreign FOREIGN KEY ([playlist_id]) REFERENCES [playlists]([id]) ON DELETE CASCADE",
]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_add_column_and_add_foreign(self):
with self.schema.table("users") as blueprint:
blueprint.unsigned_integer("playlist_id").nullable()
blueprint.add_foreign("playlist_id.id.playlists").on_delete("cascade")
sql = [
"ALTER TABLE [users] ADD [playlist_id] INT NULL",
"ALTER TABLE [users] ADD CONSTRAINT users_playlist_id_foreign FOREIGN KEY ([playlist_id]) REFERENCES [playlists]([id]) ON DELETE CASCADE",
]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_drop_foreign_key(self):
with self.schema.table("users") as blueprint:
blueprint.drop_foreign("users_playlist_id_foreign")
sql = ["ALTER TABLE [users] DROP CONSTRAINT users_playlist_id_foreign"]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_drop_foreign_key_shortcut(self):
with self.schema.table("users") as blueprint:
blueprint.drop_foreign(["playlist_id"])
sql = ["ALTER TABLE [users] DROP CONSTRAINT users_playlist_id_foreign"]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_drop_unique_constraint(self):
with self.schema.table("users") as blueprint:
blueprint.drop_unique("users_playlist_id_unique")
sql = ["DROP INDEX [users].[users_playlist_id_unique]"]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_add_primary(self):
with self.schema.table("users") as blueprint:
blueprint.primary("playlist_id")
sql = [
"ALTER TABLE [users] ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)"
]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_add_index(self):
with self.schema.table("users") as blueprint:
blueprint.index("playlist_id")
sql = ["CREATE INDEX users_playlist_id_index ON [users](playlist_id)"]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_drop_index(self):
with self.schema.table("users") as blueprint:
blueprint.drop_index("users_playlist_id_index")
sql = ["DROP INDEX [users].[users_playlist_id_index]"]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_drop_index_shortcut(self):
with self.schema.table("users") as blueprint:
blueprint.drop_index(["playlist_id"])
sql = ["DROP INDEX [users].[users_playlist_id_index]"]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_drop_unique_constraint_shortcut(self):
with self.schema.table("users") as blueprint:
blueprint.drop_unique(["playlist_id"])
sql = ["DROP INDEX [users].[users_playlist_id_unique]"]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_drop_primary(self):
with self.schema.table("users") as blueprint:
blueprint.drop_primary(["id"])
sql = ["DROP INDEX [users].[users_id_primary]"]
self.assertEqual(blueprint.to_sql(), sql)
def test_has_table(self):
schema_sql = self.schema.has_table("users")
sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'users'"
self.assertEqual(schema_sql, sql)
def test_drop_table(self):
schema_sql = self.schema.has_table("users")
sql = "SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE TABLE_NAME = 'users'"
self.assertEqual(schema_sql, sql)
def test_change(self):
with self.schema.table("users") as blueprint:
blueprint.integer("age").change()
blueprint.string("name")
blueprint.string("external_type").default("external")
self.assertEqual(len(blueprint.table.added_columns), 2)
self.assertEqual(len(blueprint.table.changed_columns), 1)
table = Table("users")
table.add_column("age", "string")
blueprint.table.from_table = table
sql = [
"ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL, [external_type] VARCHAR(255) NOT NULL DEFAULT 'external'",
"ALTER TABLE [users] ALTER COLUMN [age] INT NOT NULL",
]
self.assertEqual(blueprint.to_sql(), sql)
def test_drop_add_and_change(self):
with self.schema.table("users") as blueprint:
blueprint.integer("age").default(0).change()
blueprint.string("name")
blueprint.drop_column("email")
self.assertEqual(len(blueprint.table.added_columns), 1)
self.assertEqual(len(blueprint.table.changed_columns), 1)
table = Table("users")
table.add_column("age", "string")
table.add_column("email", "string")
blueprint.table.from_table = table
sql = [
"ALTER TABLE [users] ADD [name] VARCHAR(255) NOT NULL",
"ALTER TABLE [users] ALTER COLUMN [age] INT NOT NULL DEFAULT 0",
"ALTER TABLE [users] DROP COLUMN email",
]
self.assertEqual(blueprint.to_sql(), sql)
def test_can_create_indexes(self):
with self.schema.table("users") as blueprint:
blueprint.index("name")
blueprint.index(["name", "email"])
blueprint.unique("name")
blueprint.unique(["name", "email"])
blueprint.fulltext("description")
self.assertEqual(len(blueprint.table.added_columns), 0)
print(blueprint.to_sql())
self.assertEqual(
blueprint.to_sql(),
[
"CREATE INDEX users_name_index ON [users](name)",
"CREATE INDEX users_name_email_index ON [users](name,email)",
"ALTER TABLE [users] ADD CONSTRAINT users_name_unique UNIQUE(name)",
"ALTER TABLE [users] ADD CONSTRAINT users_name_email_unique UNIQUE(name,email)",
],
)
def test_timestamp_alter_add_nullable_column(self):
with self.schema.table("users") as blueprint:
blueprint.timestamp("due_date").nullable()
self.assertEqual(len(blueprint.table.added_columns), 1)
table = Table("users")
table.add_column("age", "string")
blueprint.table.from_table = table
sql = ["ALTER TABLE [users] ADD [due_date] DATETIME NULL"]
self.assertEqual(blueprint.to_sql(), sql)
def test_can_add_column_enum(self):
    """enum() compiles to VARCHAR with a CHECK constraint plus the default."""
    with self.schema.table("users") as blueprint:
        blueprint.enum("status", ["active", "inactive"]).default("active")

    self.assertEqual(len(blueprint.table.added_columns), 1)

    expected = [
        "ALTER TABLE [users] ADD [status] VARCHAR(255) NOT NULL DEFAULT 'active' CHECK([status] IN ('active', 'inactive'))"
    ]
    self.assertEqual(blueprint.to_sql(), expected)
================================================
FILE: tests/mysql/builder/test_mysql_builder_transaction.py
================================================
import inspect
import os
import unittest
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
from src.masoniteorm.relationships import belongs_to
from tests.utils import MockConnectionFactory
from tests.integrations.config.database import DB
# These tests hit a live MySQL database, so they only run when explicitly
# enabled through the environment.
if os.getenv("RUN_MYSQL_DATABASE") == "True":

    class User(Model):
        # Bind to the "mysql" connection and skip timestamp handling.
        __connection__ = "mysql"
        __timestamps__ = False

    class BaseTestQueryRelationships(unittest.TestCase):
        # Show full diffs on assertion failures.
        maxDiff = None

        def get_builder(self, table="users"):
            """Return a QueryBuilder bound to the live MySQL connection."""
            connection = ConnectionFactory().make("mysql")
            return QueryBuilder(
                grammar=MySQLGrammar, connection=connection, table=table
            ).on("mysql")

        def test_transaction(self):
            # An uncommitted insert is visible inside the open transaction...
            builder = self.get_builder()
            builder.begin()
            builder.create({"name": "phillip2", "email": "phillip2"})
            # builder.commit()
            user = builder.where("name", "phillip2").first()
            self.assertEqual(user["name"], "phillip2")
            # ...and gone again after the rollback.
            builder.rollback()
            user = builder.where("name", "phillip2").first()
            self.assertEqual(user, None)

        def test_transaction_default_globally(self):
            # While a global transaction is open, a fresh builder should
            # resolve to the same connection object DB handed out.
            connection = DB.begin_transaction()
            self.assertEqual(connection, self.get_builder().new_connection())
            DB.commit()
            DB.begin_transaction()
            DB.rollback()
================================================
FILE: tests/mysql/builder/test_query_builder.py
================================================
import datetime
import inspect
import unittest
from src.masoniteorm.exceptions import InvalidArgument
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
from src.masoniteorm.relationships import has_many
from src.masoniteorm.scopes import SoftDeleteScope
from tests.integrations.config.database import DATABASES
from tests.utils import MockConnectionFactory
class Articles(Model):
    # Bare model used as the target of the User.articles relationship.
    pass
class User(Model):
    # Disable created_at/updated_at handling so compiled SQL is predictable.
    __timestamps__ = False

    @has_many("id", "user_id")
    def articles(self):
        """Relationship to Articles, keyed on "id" / "user_id"."""
        return Articles
class BaseTestQueryBuilder:
    """Grammar-agnostic query builder tests.

    Concrete subclasses set ``grammar`` and define one expectation method
    per test (the test name minus the ``test_`` prefix) returning the SQL
    string that test expects; ``_expected()`` performs that lookup.

    Fixes over the previous revision:
    * ``test_or_where`` was defined twice (the second shadowed the first);
      one copy remains.
    * the early ``test_where_like``/``test_where_not_like`` definitions were
      silently shadowed by later same-named methods and were exact
      duplicates of ``test_where_like_as_operator``/
      ``test_where_not_like_as_operator``; they are removed.
    * dead ``sql = (...).to_sql()`` assignments (immediately overwritten)
      in ``test_add_select``/``test_add_select_no_table`` are gone.
    * the repeated ``inspect.currentframe()`` boilerplate is factored into
      ``_expected()``.
    """

    # Show full diffs for failed SQL comparisons.
    maxDiff = None

    def get_builder(self, table="users", dry=True):
        """Return a dry QueryBuilder wired to the mock connection and User model."""
        connection = MockConnectionFactory().make("default")
        return QueryBuilder(
            grammar=self.grammar,
            connection_class=connection,
            connection="mysql",
            table=table,
            model=User(),
            dry=dry,
            connection_details=DATABASES,
        )

    def _expected(self):
        """Return the expected SQL for the *calling* test method.

        A test named ``test_foo`` resolves to the ``foo()`` expectation
        method defined by the concrete grammar subclass.
        """
        caller = inspect.currentframe().f_back.f_code.co_name
        return getattr(self, caller.replace("test_", ""))()

    def test_sum(self):
        builder = self.get_builder()
        builder.sum("age")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_sum_chained(self):
        builder = self.get_builder()
        builder.sum("age").max("salary")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_with_(self):
        builder = self.get_builder()
        builder.with_("articles").sum("age")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_max(self):
        builder = self.get_builder()
        builder.max("age")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_min(self):
        builder = self.get_builder()
        builder.min("age")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_avg(self):
        builder = self.get_builder()
        builder.avg("age")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_all(self):
        builder = self.get_builder()
        builder.all()
        self.assertEqual(builder.to_sql(), self._expected())

    def test_get(self):
        builder = self.get_builder()
        builder.get()
        self.assertEqual(builder.to_sql(), self._expected())

    def test_first(self):
        builder = self.get_builder().first(query=True)
        self.assertEqual(builder.to_sql(), self._expected())

    def test_find_with_model(self):
        builder = self.get_builder()
        builder.find(1000, query=True)
        sql = '''SELECT * FROM `users` WHERE `users`.`id` = '1000\''''
        self.assertEqual(builder.to_sql(), sql)

    def test_find_with_model_and_list(self):
        builder = self.get_builder()
        builder.find([1000, 2000, 3000], query=True)
        sql = '''SELECT * FROM `users` WHERE `users`.`id` IN ('1000','2000','3000')'''
        self.assertEqual(builder.to_sql(), sql)

    def test_find_with_model_custom_column(self):
        builder = self.get_builder()
        builder.find(10, column="age", query=True)
        sql = '''SELECT * FROM `users` WHERE `users`.`age` = '10\''''
        self.assertEqual(builder.to_sql(), sql)

    def test_find_with_builder(self):
        builder = self.get_builder()
        builder._model = None
        builder.find(10, column="age", query=True)
        sql = '''SELECT * FROM `users` WHERE `users`.`age` = '10\''''
        self.assertEqual(builder.to_sql(), sql)

    def test_find_with_builder_and_list(self):
        builder = self.get_builder()
        builder._model = None
        builder.find([10, 20, 30], column="age", query=True)
        sql = '''SELECT * FROM `users` WHERE `users`.`age` IN ('10','20','30')'''
        self.assertEqual(builder.to_sql(), sql)

    def test_find_with_builder_without_column(self):
        builder = self.get_builder()
        builder._model = None
        # With no model there is no primary key to infer the column from.
        with self.assertRaises(InvalidArgument):
            builder.find(10, query=True)

    def test_select(self):
        builder = self.get_builder()
        builder.select("name", "email")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_select_with_table(self):
        builder = self.get_builder()
        builder.select("users.*")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_select_with_table_raw(self):
        builder = self.get_builder()
        builder.select("users.*").from_raw("orders, customers")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_select_with_alias(self):
        builder = self.get_builder()
        builder.select("users.username as name")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_select_raw(self):
        builder = self.get_builder()
        builder.select_raw("count(email) as email_count")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_add_select(self):
        builder = self.get_builder()
        # The chained calls mutate the builder; the SQL is asserted below.
        builder.select("name").add_select(
            "phone_count", lambda q: q.count("*").table("phones")
        ).add_select("salary", lambda q: q.count("*").table("salary"))
        self.assertEqual(builder.to_sql(), self._expected())

    def test_add_select_no_table(self):
        builder = self.get_builder(table=None)
        builder.add_select(
            "other_test", lambda q: q.max("updated_at").table("different_table")
        ).add_select(
            "some_alias", lambda q: q.max("updated_at").table("another_table")
        )
        self.assertEqual(builder.to_sql(), self._expected())

    def test_create(self):
        builder = self.get_builder().without_global_scopes()
        builder.create(
            {"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True
        )
        self.assertEqual(builder.to_sql(), self._expected())

    def test_delete(self):
        builder = self.get_builder()
        builder.delete("name", "Joe", query=True)
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where(self):
        builder = self.get_builder()
        builder.where("name", "Joe")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_exists(self):
        builder = self.get_builder()
        builder.where_exists("name")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_limit(self):
        builder = self.get_builder()
        builder.limit(5)
        self.assertEqual(builder.to_sql(), self._expected())

    def test_offset(self):
        builder = self.get_builder()
        builder.offset(5)
        self.assertEqual(builder.to_sql(), self._expected())

    def test_join(self):
        builder = self.get_builder()
        builder.join("profiles", "users.id", "=", "profiles.user_id")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_left_join(self):
        builder = self.get_builder()
        builder.left_join("profiles", "users.id", "=", "profiles.user_id")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_right_join(self):
        builder = self.get_builder()
        builder.right_join("profiles", "users.id", "=", "profiles.user_id")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_update(self):
        builder = self.get_builder().update(
            {"name": "Joe", "email": "joe@yopmail.com"}, dry=True
        )
        self.assertEqual(builder.to_sql(), self._expected())

    # increment()/decrement() expectation methods exist on the subclasses,
    # but the tests remain disabled for now.
    # def test_increment(self):
    #     builder = self.get_builder()
    #     builder.increment("age", 1)
    #     self.assertEqual(builder.to_sql(), self._expected())

    # def test_decrement(self):
    #     builder = self.get_builder()
    #     builder.decrement("age", 1)
    #     self.assertEqual(builder.to_sql(), self._expected())

    def test_count(self):
        builder = self.get_builder()
        builder.count("id")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_order_by_asc(self):
        builder = self.get_builder()
        builder.order_by("email", "asc")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_order_by_desc(self):
        builder = self.get_builder()
        builder.order_by("email", "desc")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_column(self):
        builder = self.get_builder()
        builder.where_column("name", "username")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_not_in(self):
        builder = self.get_builder()
        builder.where_not_in("id", [1, 2, 3])
        self.assertEqual(builder.to_sql(), self._expected())

    def test_between(self):
        builder = self.get_builder()
        builder.between("id", 2, 5)
        self.assertEqual(builder.to_sql(), self._expected())

    def test_not_between(self):
        builder = self.get_builder()
        builder.not_between("id", 2, 5)
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_in(self):
        builder = self.get_builder()
        builder.where_in("id", [1, 2, 3])
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_null(self):
        builder = self.get_builder()
        builder.where_null("name")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_not_null(self):
        builder = self.get_builder()
        builder.where_not_null("name")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_having(self):
        builder = self.get_builder(table="payments")
        builder.select("user_id").avg("salary").group_by("user_id").having(
            "salary", ">=", "1000"
        )
        self.assertEqual(builder.to_sql(), self._expected())

    def test_group_by(self):
        builder = self.get_builder(table="payments")
        builder.select("user_id").min("salary").group_by("user_id")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_builder_alone(self):
        # A QueryBuilder built from raw connection details still works.
        self.assertTrue(
            QueryBuilder(
                dry=True,
                connection_details={
                    "default": "mysql",
                    "mysql": {
                        "driver": "mysql",
                        "host": "localhost",
                        "username": "root",
                        "password": "",
                        "database": "orm",
                        "port": "3306",
                        "prefix": "",
                        "grammar": "mysql",
                        "options": {"charset": "utf8mb4"},
                    },
                },
            ).table("users")
        )

    def test_where_lt(self):
        builder = self.get_builder()
        builder.where("age", "<", "20")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_lte(self):
        builder = self.get_builder()
        builder.where("age", "<=", "20")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_gt(self):
        builder = self.get_builder()
        builder.where("age", ">", "20")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_gte(self):
        builder = self.get_builder()
        builder.where("age", ">=", "20")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_ne(self):
        builder = self.get_builder()
        builder.where("age", "!=", "20")
        self.assertEqual(builder.to_sql(), self._expected())

    def test_or_where(self):
        builder = self.get_builder()
        builder.where("age", "20").or_where("age", "<", 20)
        self.assertEqual(builder.to_sql(), self._expected())

    def test_where_like_as_operator(self):
        # "like" as a where() operator compiles the same as where_like().
        builder = self.get_builder()
        builder.where("age", "like", "%name%")
        self.assertEqual(builder.to_sql(), self.where_like())

    def test_where_like(self):
        builder = self.get_builder()
        builder.where_like("age", "%name%")
        self.assertEqual(builder.to_sql(), self.where_like())

    def test_where_not_like_as_operator(self):
        builder = self.get_builder()
        builder.where("age", "not like", "%name%")
        self.assertEqual(builder.to_sql(), self.where_not_like())

    def test_where_not_like(self):
        builder = self.get_builder()
        builder.where_not_like("age", "%name%")
        self.assertEqual(builder.to_sql(), self.where_not_like())

    def test_can_call_with_multi_tables(self):
        builder = self.get_builder()
        sql = (
            builder.table("information_schema.columns")
            .select("table_name")
            .where("table_name", "users")
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT `information_schema`.`columns`.`table_name` FROM `information_schema`.`columns` WHERE `information_schema`.`columns`.`table_name` = 'users'""",
        )

    def test_truncate(self):
        builder = self.get_builder(dry=True)
        sql = builder.truncate()
        self.assertEqual(sql, self._expected())

    def test_truncate_without_foreign_keys(self):
        builder = self.get_builder(dry=True)
        sql = builder.truncate(foreign_keys=True)
        self.assertEqual(sql, self._expected())

    def test_shared_lock(self):
        builder = self.get_builder(dry=True)
        sql = builder.where("votes", ">=", 100).shared_lock().to_sql()
        self.assertEqual(sql, self._expected())

    def test_update_lock(self):
        builder = self.get_builder(dry=True)
        sql = builder.where("votes", ">=", 100).lock_for_update().to_sql()
        self.assertEqual(sql, self._expected())
class MySQLQueryBuilderTest(BaseTestQueryBuilder, unittest.TestCase):
    """Runs the shared query-builder tests against the MySQL grammar.

    Each non-test method below is named after a base-class test (minus the
    ``test_`` prefix) and returns the SQL string that test expects; the
    base class resolves it by name via ``inspect``.
    """

    grammar = MySQLGrammar

    def sum(self):
        """
        builder = self.get_builder()
        builder.sum('age')
        """
        return "SELECT SUM(`users`.`age`) AS age FROM `users`"

    def sum_chained(self):
        """
        builder = self.get_builder()
        builder.sum('age').max('salary')
        """
        return "SELECT SUM(`users`.`age`) AS age, MAX(`users`.`salary`) AS salary FROM `users`"

    def with_(self):
        """
        builder = self.get_builder()
        builder.with_('articles').sum('age')
        """
        return "SELECT SUM(`users`.`age`) AS age FROM `users`"

    def max(self):
        """
        builder = self.get_builder()
        builder.max('age')
        """
        return "SELECT MAX(`users`.`age`) AS age FROM `users`"

    def min(self):
        """
        builder = self.get_builder()
        builder.min('age')
        """
        return "SELECT MIN(`users`.`age`) AS age FROM `users`"

    def avg(self):
        """
        builder = self.get_builder()
        builder.avg('age')
        """
        return "SELECT AVG(`users`.`age`) AS age FROM `users`"

    def first(self):
        """
        builder = self.get_builder()
        builder.first()
        """
        return "SELECT * FROM `users` LIMIT 1"

    def all(self):
        """
        builder = self.get_builder()
        builder.all()
        """
        return "SELECT * FROM `users`"

    def get(self):
        """
        builder = self.get_builder()
        builder.get()
        """
        return "SELECT * FROM `users`"

    def select(self):
        """
        builder = self.get_builder()
        builder.select('name', 'email')
        """
        return "SELECT `users`.`name`, `users`.`email` FROM `users`"

    def select_with_table(self):
        """
        builder = self.get_builder()
        builder.select('users.*')
        """
        return "SELECT `users`.* FROM `users`"

    def select_with_table_raw(self):
        """
        builder = self.get_builder()
        builder.select('users.*').from_raw('orders, customers')
        """
        return "SELECT `users`.* FROM orders, customers"

    def select_with_alias(self):
        """
        builder = self.get_builder()
        builder.select('users.username as name')
        """
        return "SELECT `users`.`username` AS name FROM `users`"

    def select_raw(self):
        """
        builder = self.get_builder()
        builder.select_raw('count(email) as email_count')
        """
        return "SELECT count(email) as email_count FROM `users`"

    def add_select(self):
        """
        builder = self.get_builder()
        builder.select('name').add_select(alias, callback) x2
        """
        return "SELECT `users`.`name`, (SELECT COUNT(*) AS m_count_reserved FROM `phones`) AS phone_count, (SELECT COUNT(*) AS m_count_reserved FROM `salary`) AS salary FROM `users`"

    def add_select_no_table(self):
        """
        builder = self.get_builder(table=None)
        builder.add_select(alias, callback) x2
        """
        return (
            "SELECT "
            "(SELECT MAX(`different_table`.`updated_at`) AS updated_at FROM `different_table`) AS other_test, "
            "(SELECT MAX(`another_table`.`updated_at`) AS updated_at FROM `another_table`) AS some_alias"
        )

    def create(self):
        """
        builder = get_builder()
        builder.create({"name": "Corentin All", 'email': 'corentin@yopmail.com'})
        """
        return "INSERT INTO `users` (`users`.`name`, `users`.`email`) VALUES ('Corentin All', 'corentin@yopmail.com')"

    def delete(self):
        """
        builder = get_builder()
        builder.delete('name', 'Joe')
        """
        return "DELETE FROM `users` WHERE `users`.`name` = 'Joe'"

    def where(self):
        """
        builder = get_builder()
        builder.where('name', 'Joe')
        """
        return "SELECT * FROM `users` WHERE `users`.`name` = 'Joe'"

    def where_exists(self):
        """
        builder = get_builder()
        builder.where_exists('name')
        """
        return "SELECT * FROM `users` WHERE EXISTS 'name'"

    def limit(self):
        """
        builder = get_builder()
        builder.limit(5)
        """
        return "SELECT * FROM `users` LIMIT 5"

    def offset(self):
        """
        builder = get_builder()
        builder.offset(5)
        """
        return "SELECT * FROM `users` OFFSET 5"

    def join(self):
        """
        builder.join("profiles", "users.id", "=", "profiles.user_id")
        """
        return "SELECT * FROM `users` INNER JOIN `profiles` ON `users`.`id` = `profiles`.`user_id`"

    def left_join(self):
        """
        builder.left_join("profiles", "users.id", "=", "profiles.user_id")
        """
        return "SELECT * FROM `users` LEFT JOIN `profiles` ON `users`.`id` = `profiles`.`user_id`"

    def right_join(self):
        """
        builder.right_join("profiles", "users.id", "=", "profiles.user_id")
        """
        return "SELECT * FROM `users` RIGHT JOIN `profiles` ON `users`.`id` = `profiles`.`user_id`"

    def update(self):
        """
        builder.update({"name": "Joe", "email": "joe@yopmail.com"})
        """
        return "UPDATE `users` SET `users`.`name` = 'Joe', `users`.`email` = 'joe@yopmail.com'"

    def increment(self):
        """
        builder.increment('age', 1)
        """
        return "UPDATE `users` SET `users`.`age` = `users`.`age` + '1'"

    def decrement(self):
        """
        builder.decrement('age', 1)
        """
        return "UPDATE `users` SET `users`.`age` = `users`.`age` - '1'"

    def count(self):
        """
        builder.count(id)
        """
        return "SELECT COUNT(`users`.`id`) AS id FROM `users`"

    def order_by_asc(self):
        """
        builder.order_by('email', 'asc')
        """
        return "SELECT * FROM `users` ORDER BY `email` ASC"

    def order_by_desc(self):
        """
        builder.order_by('email', 'desc')
        """
        return "SELECT * FROM `users` ORDER BY `email` DESC"

    def where_column(self):
        """
        builder.where_column('name', 'username')
        """
        return "SELECT * FROM `users` WHERE `users`.`name` = `users`.`username`"

    def where_null(self):
        """
        builder.where_null('name')
        """
        return "SELECT * FROM `users` WHERE `users`.`name` IS NULL"

    def where_not_null(self):
        """
        builder.where_not_null('name')
        """
        return "SELECT * FROM `users` WHERE `users`.`name` IS NOT NULL"

    def where_not_in(self):
        """
        builder.where_not_in('id', [1, 2, 3])
        """
        return "SELECT * FROM `users` WHERE `users`.`id` NOT IN ('1','2','3')"

    def where_in(self):
        """
        builder.where_in('id', [1, 2, 3])
        """
        return "SELECT * FROM `users` WHERE `users`.`id` IN ('1','2','3')"

    def between(self):
        """
        builder.between('id', 2, 5)
        """
        return "SELECT * FROM `users` WHERE `users`.`id` BETWEEN '2' AND '5'"

    def not_between(self):
        """
        builder.not_between('id', 2, 5)
        """
        return "SELECT * FROM `users` WHERE `users`.`id` NOT BETWEEN '2' AND '5'"

    def having(self):
        """
        builder.select('user_id').avg('salary').group_by('user_id').having('salary', '>=', '1000')
        """
        return "SELECT `payments`.`user_id`, AVG(`payments`.`salary`) AS salary FROM `payments` GROUP BY `payments`.`user_id` HAVING `payments`.`salary` >= '1000'"

    def group_by(self):
        """
        builder.select('user_id').min('salary').group_by('user_id')
        """
        return "SELECT `payments`.`user_id`, MIN(`payments`.`salary`) AS salary FROM `payments` GROUP BY `payments`.`user_id`"

    def where_lt(self):
        """
        builder = self.get_builder()
        builder.where('age', '<', '20')
        """
        return "SELECT * FROM `users` WHERE `users`.`age` < '20'"

    def where_lte(self):
        """
        builder = self.get_builder()
        builder.where('age', '<=', '20')
        """
        return "SELECT * FROM `users` WHERE `users`.`age` <= '20'"

    def where_gt(self):
        """
        builder = self.get_builder()
        builder.where('age', '>', '20')
        """
        return "SELECT * FROM `users` WHERE `users`.`age` > '20'"

    def where_gte(self):
        """
        builder = self.get_builder()
        builder.where('age', '>=', '20')
        """
        return "SELECT * FROM `users` WHERE `users`.`age` >= '20'"

    def where_ne(self):
        """
        builder = self.get_builder()
        builder.where('age', '!=', '20')
        """
        return "SELECT * FROM `users` WHERE `users`.`age` != '20'"

    def or_where(self):
        """
        builder = self.get_builder()
        builder.where('age', '20').or_where('age','<', 20)
        """
        return (
            "SELECT * FROM `users` WHERE `users`.`age` = '20' OR `users`.`age` < '20'"
        )

    def where_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "like", "%name%")
        """
        return "SELECT * FROM `users` WHERE `users`.`age` LIKE '%name%'"

    def where_not_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "not like", "%name%")
        """
        return "SELECT * FROM `users` WHERE `users`.`age` NOT LIKE '%name%'"

    def truncate(self):
        """
        builder = self.get_builder()
        builder.truncate()
        """
        return """TRUNCATE TABLE `users`"""

    def truncate_without_foreign_keys(self):
        """
        builder = self.get_builder()
        builder.truncate(foreign_keys=True)
        """
        return [
            "SET FOREIGN_KEY_CHECKS=0",
            "TRUNCATE TABLE `users`",
            "SET FOREIGN_KEY_CHECKS=1",
        ]

    def shared_lock(self):
        """
        builder = self.get_builder()
        builder.where('votes', '>=', 100).shared_lock()
        """
        return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' LOCK IN SHARE MODE"

    def update_lock(self):
        """
        builder = self.get_builder()
        builder.where('votes', '>=', 100).lock_for_update()
        """
        return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' FOR UPDATE"

    def test_latest(self):
        builder = self.get_builder()
        builder.latest("email")
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(builder.to_sql(), sql)

    def test_oldest(self):
        builder = self.get_builder()
        builder.oldest("email")
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(builder.to_sql(), sql)

    def latest(self):
        """
        builder.order_by('email', 'desc')
        """
        return "SELECT * FROM `users` ORDER BY `email` DESC"

    def oldest(self):
        """
        builder.order_by('email', 'asc')
        """
        return "SELECT * FROM `users` ORDER BY `email` ASC"
================================================
FILE: tests/mysql/builder/test_query_builder_scopes.py
================================================
import inspect
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
from src.masoniteorm.relationships import has_many
from src.masoniteorm.scopes import SoftDeleteScope
from tests.utils import MockConnectionFactory
class BaseTestQueryBuilderScopes(unittest.TestCase):
    """Local and global query scopes are merged into the compiled MySQL query."""

    grammar = "mysql"

    def get_builder(self, table="users"):
        """Return a MySQL-grammar QueryBuilder backed by the mock connection."""
        mock_connection = MockConnectionFactory().make("default")
        return QueryBuilder(
            grammar=MySQLGrammar,
            connection_class=mock_connection,
            connection="mysql",
            table=table,
            connection_details=DATABASES,
        )

    def test_scopes(self):
        # A named scope becomes callable on the builder and injects its wheres.
        builder = self.get_builder().set_scope(
            "gender", lambda model, q: q.where("gender", "w")
        )
        compiled = builder.gender().where("id", 1).to_sql()
        self.assertEqual(
            compiled,
            "SELECT * FROM `users` WHERE `users`.`gender` = 'w' AND `users`.`id` = '1'",
        )

    def test_global_scopes(self):
        # A global scope registered for "select" applies without being called.
        builder = self.get_builder().set_global_scope(
            "where_not_null", lambda q: q.where_not_null("deleted_at"), action="select"
        )
        compiled = builder.where("id", 1).to_sql()
        self.assertEqual(
            compiled,
            "SELECT * FROM `users` WHERE `users`.`id` = '1' AND `users`.`deleted_at` IS NOT NULL",
        )

    def test_global_scope_from_class(self):
        # Scope objects work the same as callables.
        builder = self.get_builder().set_global_scope(SoftDeleteScope())
        compiled = builder.where("id", 1).to_sql()
        self.assertEqual(
            compiled,
            "SELECT * FROM `users` WHERE `users`.`id` = '1' AND `users`.`deleted_at` IS NULL",
        )

    def test_global_scope_remove_from_class(self):
        # Removing a scope restores the plain query.
        builder = (
            self.get_builder()
            .set_global_scope(SoftDeleteScope())
            .remove_global_scope(SoftDeleteScope())
        )
        compiled = builder.where("id", 1).to_sql()
        self.assertEqual(
            compiled,
            "SELECT * FROM `users` WHERE `users`.`id` = '1'",
        )

    def test_global_scope_adds_method(self):
        # SoftDeleteScope contributes the with_trashed() escape hatch.
        builder = self.get_builder().set_global_scope(SoftDeleteScope())
        self.assertEqual(builder.with_trashed().to_sql(), "SELECT * FROM `users`")
================================================
FILE: tests/mysql/builder/test_transactions.py
================================================
import inspect
import os
import unittest
from src.masoniteorm.connections.ConnectionFactory import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
from src.masoniteorm.relationships import has_many
from tests.utils import MockConnectionFactory
class Articles(Model):
    # Bare model used as the target of the User.articles relationship.
    pass
class User(Model):
    @has_many("id", "user_id")
    def articles(self):
        """Relationship to Articles, keyed on "id" / "user_id"."""
        return Articles
# Only exercised against a live MySQL database.
# NOTE(review): the getenv default is the bool False but it is compared to
# the string "True"; harmless, yet a "False" string default would be more
# consistent with the comparison.
if os.getenv("RUN_MYSQL_DATABASE", False) == "True":

    class TestTransactions(unittest.TestCase):
        # Placeholder: the transaction tests below are currently disabled.
        pass

        # def get_builder(self, table="users"):
        #     connection = ConnectionFactory().make("default")
        #     return QueryBuilder(MySQLGrammar, connection, table=table, model=User())

        # def test_can_start_transaction(self, table="users"):
        #     builder = self.get_builder()
        #     builder.begin()
        #     builder.create({"name": "mike", "email": "mike@email.com"})
        #     builder.rollback()
        #     self.assertFalse(builder.where("email", "mike@email.com").first())
================================================
FILE: tests/mysql/connections/test_mysql_connection_selects.py
================================================
import os
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
class MockUser(Model):
    # Map this model onto the existing "users" table.
    __table__ = "users"
# These tests query a live MySQL database seeded with user rows, so they
# only run when explicitly enabled through the environment.
if os.getenv("RUN_MYSQL_DATABASE", False) == "True":

    class TestMySQLSelectConnection(unittest.TestCase):
        def setUp(self):
            # A plain builder bound to the users table (model queries below
            # go through MockUser directly).
            self.builder = QueryBuilder(MySQLGrammar, table="users")

        def test_can_compile_select(self):
            to_sql = MockUser.where("id", 1).to_sql()
            sql = "SELECT * FROM `users` WHERE `users`.`id` = '1'"
            self.assertEqual(to_sql, sql)

        def test_can_get_first_record(self):
            user = MockUser.where("id", 1).first()
            self.assertEqual(user.id, 1)

        def test_can_find_first_record(self):
            user = MockUser.find(1)
            self.assertEqual(user.id, 1)

        def test_can_get_all_records(self):
            users = MockUser.all()
            self.assertGreater(len(users), 1)

        def test_can_get_5_records(self):
            users = MockUser.limit(5).get()
            self.assertEqual(len(users), 5)

        def test_can_get_1_record_with_get(self):
            # The where clause wins over the limit regardless of call order.
            users = MockUser.where("id", 1).limit(5).get()
            self.assertEqual(len(users), 1)
            users = MockUser.limit(5).where("id", 1).get()
            self.assertEqual(len(users), 1)
================================================
FILE: tests/mysql/grammar/test_mysql_delete_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
class BaseDeleteGrammarTest:
    """Shared DELETE-compilation checks.

    Grammar-specific subclasses define one expectation method per test
    (the test name minus the ``test_`` prefix) returning the expected SQL.
    """

    def setUp(self):
        self.builder = QueryBuilder(MySQLGrammar, table="users")

    def test_can_compile_delete(self):
        compiled = self.builder.delete("id", 1, query=True).to_sql()
        expectation = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(compiled, expectation)

    def test_can_compile_delete_in(self):
        compiled = self.builder.delete("id", [1, 2, 3], query=True).to_sql()
        expectation = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(compiled, expectation)

    def test_can_compile_delete_with_where(self):
        compiled = (
            self.builder.where("age", 20)
            .where("profile", 1)
            .set_action("delete")
            .delete(query=True)
            .to_sql()
        )
        expectation = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(compiled, expectation)
class TestMySQLDeleteGrammar(BaseDeleteGrammarTest, unittest.TestCase):
    """Expected DELETE SQL for the MySQL grammar.

    Each method mirrors a ``test_*`` in BaseDeleteGrammarTest (minus the
    prefix) and returns the SQL that test should compile to.
    """

    grammar = "mysql"

    def can_compile_delete(self):
        """
        (
            self.builder
            .delete('id', 1)
            .to_sql()
        )
        """
        return "DELETE FROM `users` WHERE `users`.`id` = '1'"

    def can_compile_delete_in(self):
        """
        (
            self.builder
            .delete('id', [1, 2, 3])
            .to_sql()
        )
        """
        return "DELETE FROM `users` WHERE `users`.`id` IN ('1','2','3')"

    def can_compile_delete_with_where(self):
        """
        (
            self.builder
            .where('age', 20)
            .where('profile', 1)
            .set_action('delete')
            .delete()
            .to_sql()
        )
        """
        return (
            "DELETE FROM `users` WHERE `users`.`age` = '20' AND `users`.`profile` = '1'"
        )
================================================
FILE: tests/mysql/grammar/test_mysql_insert_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
class BaseInsertGrammarTest:
    """Shared INSERT compilation tests; subclasses supply the expected SQL."""

    def setUp(self):
        self.builder = QueryBuilder(MySQLGrammar, table="users")

    def _expected_sql(self, test_name):
        # The expected SQL lives on the subclass in a method named after the
        # test with its "test_" prefix removed.
        return getattr(self, test_name.replace("test_", ""))()

    def test_can_compile_insert(self):
        to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql()
        self.assertEqual(
            to_sql, self._expected_sql(inspect.currentframe().f_code.co_name)
        )

    def test_can_compile_insert_with_keywords(self):
        to_sql = self.builder.create(name="Joe", query=True).to_sql()
        self.assertEqual(
            to_sql, self._expected_sql(inspect.currentframe().f_code.co_name)
        )

    def test_can_compile_bulk_create(self):
        # Keys are intentionally out of order to prove that columns and
        # values stay aligned per row.
        to_sql = self.builder.bulk_create(
            [
                {"name": "Joe", "age": 5},
                {"age": 35, "name": "Bill"},
                {"name": "John", "age": 10},
            ],
            query=True,
        ).to_sql()
        self.assertEqual(
            to_sql, self._expected_sql(inspect.currentframe().f_code.co_name)
        )

    def test_can_compile_bulk_create_qmark(self):
        to_sql = self.builder.bulk_create(
            [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True
        ).to_qmark()
        self.assertEqual(
            to_sql, self._expected_sql(inspect.currentframe().f_code.co_name)
        )

    def test_can_compile_bulk_create_multiple(self):
        to_sql = self.builder.bulk_create(
            [
                {"name": "Joe", "active": "1"},
                {"name": "Bill", "active": "1"},
                {"name": "John", "active": "1"},
            ],
            query=True,
        ).to_sql()
        self.assertEqual(
            to_sql, self._expected_sql(inspect.currentframe().f_code.co_name)
        )
class TestMySQLUpdateGrammar(BaseInsertGrammarTest, unittest.TestCase):
    """Expected MySQL SQL for the shared INSERT grammar tests.

    NOTE(review): despite its name this class exercises INSERT grammar;
    a clearer name (TestMySQLInsertGrammar) would change test discovery,
    so the rename is left to a deliberate follow-up.
    """

    grammar = "mysql"

    def can_compile_insert(self):
        """Expected SQL for: builder.create({'name': 'Joe'}, query=True).to_sql()"""
        return "INSERT INTO `users` (`users`.`name`) VALUES ('Joe')"

    def can_compile_insert_with_keywords(self):
        """Expected SQL for: builder.create(name='Joe', query=True).to_sql()"""
        return "INSERT INTO `users` (`users`.`name`) VALUES ('Joe')"

    def can_compile_bulk_create(self):
        """Expected SQL for a three-row bulk_create with out-of-order keys."""
        return """INSERT INTO `users` (`age`, `name`) VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John')"""

    def can_compile_bulk_create_multiple(self):
        """Expected SQL for a three-row bulk_create with two columns per row."""
        return """INSERT INTO `users` (`active`, `name`) VALUES ('1', 'Joe'), ('1', 'Bill'), ('1', 'John')"""

    def can_compile_bulk_create_qmark(self):
        """Expected qmark SQL for a three-row bulk_create."""
        return """INSERT INTO `users` (`name`) VALUES ('?'), ('?'), ('?')"""
================================================
FILE: tests/mysql/grammar/test_mysql_qmark.py
================================================
import inspect
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
class BaseQMarkTest:
    """Shared qmark (parameterized SQL) tests.

    Subclasses supply a (sql, bindings) pair per test via a method named
    after the test minus its "test_" prefix.
    """

    def setUp(self):
        self.builder = QueryBuilder(grammar=MySQLGrammar, table="users")

    def _expected(self, test_name):
        # Resolve the subclass fixture method for this test.
        return getattr(self, test_name.replace("test_", ""))()

    def test_can_compile_select(self):
        mark = self.builder.select("username").where("name", "Joe")
        sql, bindings = self._expected(inspect.currentframe().f_code.co_name)
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, bindings)

    def test_can_compile_delete(self):
        mark = self.builder.where("name", "Joe").delete(query=True)
        sql, bindings = self._expected(inspect.currentframe().f_code.co_name)
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, bindings)

    def test_can_compile_update(self):
        mark = self.builder.update({"name": "Bob"}, dry=True).where("name", "Joe")
        sql, bindings = self._expected(inspect.currentframe().f_code.co_name)
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, bindings)

    def test_can_compile_where_in(self):
        mark = self.builder.where_in("id", [1, 2, 3])
        sql, bindings = self._expected(inspect.currentframe().f_code.co_name)
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, bindings)

    def test_can_compile_where_not_null(self):
        mark = self.builder.where_not_null("id")
        sql, bindings = self._expected(inspect.currentframe().f_code.co_name)
        self.assertEqual(mark.to_qmark(), sql)
        # NOTE(review): the fixture's bindings value is ignored here; a NULL
        # check binds nothing, so the empty list is asserted directly.
        self.assertEqual(mark._bindings, [])

    def test_can_compile_where_with_falsy_values(self):
        mark = self.builder.where("name", 0)
        sql, bindings = self._expected(inspect.currentframe().f_code.co_name)
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, bindings)

    def test_can_compile_where_with_true_value(self):
        mark = self.builder.where("is_admin", True)
        sql, bindings = self._expected(inspect.currentframe().f_code.co_name)
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, bindings)

    def test_can_compile_where_with_false_value(self):
        mark = self.builder.where("is_admin", False)
        sql, bindings = self._expected(inspect.currentframe().f_code.co_name)
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, bindings)

    def test_can_compile_sub_group_bindings(self):
        mark = self.builder.where(
            lambda query: (
                query.where("challenger", 1)
                .or_where("proposer", 1)
                .or_where("referee", 1)
            )
        )
        sql, bindings = self._expected(inspect.currentframe().f_code.co_name)
        self.assertEqual(mark.to_qmark(), sql)
        self.assertEqual(mark._bindings, bindings)
class TestMySQLQmark(BaseQMarkTest, unittest.TestCase):
    """Expected (sql, bindings) pairs for the shared qmark tests."""

    def can_compile_select(self):
        """builder.select('username').where('name', 'Joe').to_qmark()"""
        return (
            "SELECT `users`.`username` FROM `users` WHERE `users`.`name` = '?'",
            ["Joe"],
        )

    def can_compile_delete(self):
        """builder.where('name', 'Joe').delete(query=True).to_qmark()"""
        return "DELETE FROM `users` WHERE `users`.`name` = '?'", ["Joe"]

    def can_compile_update(self):
        """builder.update({'name': 'Bob'}, dry=True).where('name', 'Joe').to_qmark()"""
        return (
            "UPDATE `users` SET `users`.`name` = '?' WHERE `users`.`name` = '?'",
            ["Bob", "Joe"],
        )

    def can_compile_where_in(self):
        """builder.where_in('id', [1, 2, 3]).to_qmark()"""
        return (
            "SELECT * FROM `users` WHERE `users`.`id` IN ('?', '?', '?')",
            [1, 2, 3],
        )

    def can_compile_where_not_null(self):
        """builder.where_not_null('id').to_qmark() — binds no values."""
        return ("SELECT * FROM `users` WHERE `users`.`id` IS NOT NULL", ())

    def can_compile_where_with_falsy_values(self):
        """builder.where('name', 0).to_qmark() — falsy 0 must still be bound."""
        return ("SELECT * FROM `users` WHERE `users`.`name` = '?'", [0])

    def can_compile_where_with_true_value(self):
        """builder.where('is_admin', True).to_qmark() — True compiles to '1' with no binding."""
        return ("SELECT * FROM `users` WHERE `users`.`is_admin` = '1'", [])

    def can_compile_where_with_false_value(self):
        """builder.where('is_admin', False).to_qmark() — False compiles to '0' with no binding."""
        return ("SELECT * FROM `users` WHERE `users`.`is_admin` = '0'", [])

    def can_compile_sub_group_bindings(self):
        """builder.where(lambda q: q.where('challenger', 1)
        .or_where('proposer', 1).or_where('referee', 1)).to_qmark()
        """
        return (
            "SELECT * FROM `users` WHERE (`users`.`challenger` = '?' OR `users`.`proposer` = '?' OR `users`.`referee` = '?')",
            [1, 1, 1],
        )
================================================
FILE: tests/mysql/grammar/test_mysql_select_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query.grammars import MySQLGrammar
from src.masoniteorm.testing import BaseTestCaseSelectGrammar
class TestMySQLGrammar(BaseTestCaseSelectGrammar, unittest.TestCase):
    """MySQL SELECT-grammar expectations.

    Each ``can_*`` method returns the exact SQL string that the shared
    test of the same name (with a ``test_`` prefix) in
    BaseTestCaseSelectGrammar expects the builder to compile. The few
    ``test_*`` methods defined directly on this class run against
    ``self.builder`` themselves. The docstrings inside the ``can_*``
    methods show the builder call the fixture corresponds to.
    """

    grammar = MySQLGrammar

    def can_compile_select(self):
        """
        self.builder.to_sql()
        """
        return "SELECT * FROM `users`"

    def can_compile_with_columns(self):
        """
        self.builder.select('username', 'password').to_sql()
        """
        return "SELECT `users`.`username`, `users`.`password` FROM `users`"

    def can_compile_order_by_and_first(self):
        """
        self.builder.order_by('id', 'asc').first()
        """
        return """SELECT * FROM `users` ORDER BY `id` ASC LIMIT 1"""

    def can_compile_with_where(self):
        """
        self.builder.select('username', 'password').where('id', 1).to_sql()
        """
        return "SELECT `users`.`username`, `users`.`password` FROM `users` WHERE `users`.`id` = '1'"

    def can_compile_with_several_where(self):
        """
        self.builder.select('username', 'password').where('id', 1).where('username', 'joe').to_sql()
        """
        return "SELECT `users`.`username`, `users`.`password` FROM `users` WHERE `users`.`id` = '1' AND `users`.`username` = 'joe'"

    def can_compile_with_several_where_and_limit(self):
        """
        self.builder.select('username', 'password').where('id', 1).where('username', 'joe').limit(10).to_sql()
        """
        return "SELECT `users`.`username`, `users`.`password` FROM `users` WHERE `users`.`id` = '1' AND `users`.`username` = 'joe' LIMIT 10"

    def can_compile_with_sum(self):
        """
        self.builder.sum('age').to_sql()
        """
        return "SELECT SUM(`users`.`age`) AS age FROM `users`"

    def can_compile_with_max(self):
        """
        self.builder.max('age').to_sql()
        """
        return "SELECT MAX(`users`.`age`) AS age FROM `users`"

    def can_compile_with_max_and_columns(self):
        """
        self.builder.select('username').max('age').to_sql()
        """
        return "SELECT `users`.`username`, MAX(`users`.`age`) AS age FROM `users`"

    def can_compile_with_max_and_columns_different_order(self):
        """
        self.builder.max('age').select('username').to_sql()
        """
        # Aggregates compile after plain columns regardless of call order.
        return "SELECT `users`.`username`, MAX(`users`.`age`) AS age FROM `users`"

    def can_compile_with_order_by(self):
        """
        self.builder.select('username').order_by('age', 'desc').to_sql()
        """
        return "SELECT `users`.`username` FROM `users` ORDER BY `age` DESC"

    def can_compile_with_multiple_order_by(self):
        """
        self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql()
        """
        return "SELECT `users`.`username` FROM `users` ORDER BY `age` DESC, `name` ASC"

    def can_compile_with_group_by(self):
        """
        self.builder.select('username').group_by('age').to_sql()
        """
        return "SELECT `users`.`username` FROM `users` GROUP BY `users`.`age`"

    def can_compile_where_in(self):
        """
        self.builder.select('username').where_in('age', [1,2,3]).to_sql()
        """
        return "SELECT `users`.`username` FROM `users` WHERE `users`.`age` IN ('1','2','3')"

    def can_compile_where_in_empty(self):
        """
        self.builder.where_in('age', []).to_sql()
        """
        # An empty IN list compiles to an always-false predicate.
        return """SELECT * FROM `users` WHERE 0 = 1"""

    def can_compile_where_not_in(self):
        """
        self.builder.select('username').where_not_in('age', [1,2,3]).to_sql()
        """
        return "SELECT `users`.`username` FROM `users` WHERE `users`.`age` NOT IN ('1','2','3')"

    def can_compile_where_null(self):
        """
        self.builder.select('username').where_null('age').to_sql()
        """
        return "SELECT `users`.`username` FROM `users` WHERE `users`.`age` IS NULL"

    def can_compile_where_not_null(self):
        """
        self.builder.select('username').where_not_null('age').to_sql()
        """
        return "SELECT `users`.`username` FROM `users` WHERE `users`.`age` IS NOT NULL"

    def can_compile_where_raw(self):
        """
        self.builder.where_raw("`age` = '18'").to_sql()
        """
        # NOTE(review): the inline test_can_compile_where_raw below expects
        # the raw fragment unqualified (`age`), unlike this fixture — verify
        # which one the base-class test actually consumes.
        return "SELECT * FROM `users` WHERE `users`.`age` = '18'"

    def can_compile_where_raw_and_where_with_multiple_bindings(self):
        """
        self.builder.where_raw("`age` = '?' AND `is_admin` = '?'", [18, True]).where("email", "test@example.com")
        """
        return "SELECT * FROM `users` WHERE `age` = '?' AND `is_admin` = '?' AND `users`.`email` = '?'"

    def can_compile_having_raw(self):
        """
        self.builder.select_raw("COUNT(*) as counts").having_raw("counts > 18").to_sql()
        """
        return "SELECT COUNT(*) as counts FROM `users` HAVING counts > 18"

    def can_compile_select_raw(self):
        """
        self.builder.select_raw("COUNT(*)").to_sql()
        """
        return "SELECT COUNT(*) FROM `users`"

    def can_compile_limit_and_offset(self):
        """
        self.builder.limit(10).offset(10).to_sql()
        """
        return "SELECT * FROM `users` LIMIT 10 OFFSET 10"

    def can_compile_select_raw_with_select(self):
        """
        self.builder.select('id').select_raw("COUNT(*)").to_sql()
        """
        return "SELECT `users`.`id`, COUNT(*) FROM `users`"

    def can_compile_count(self):
        """
        self.builder.count().to_sql()
        """
        return "SELECT COUNT(*) AS m_count_reserved FROM `users`"

    def can_compile_count_column(self):
        """
        self.builder.count().to_sql()
        """
        # NOTE(review): docstring looks copy-pasted — presumably the base
        # test calls count('money'); confirm against BaseTestCaseSelectGrammar.
        return "SELECT COUNT(`users`.`money`) AS money FROM `users`"

    def can_compile_where_column(self):
        """
        self.builder.where_column('name', 'email').to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`name` = `users`.`email`"

    def can_compile_or_where(self):
        """
        self.builder.where('name', 2).or_where('name', 3).to_sql()
        """
        return (
            "SELECT * FROM `users` WHERE `users`.`name` = '2' OR `users`.`name` = '3'"
        )

    def can_grouped_where(self):
        """
        self.builder.where(lambda query: query.where('age', 2).where('name', 'Joe')).to_sql()
        """
        return "SELECT * FROM `users` WHERE (`users`.`age` = '2' AND `users`.`name` = 'Joe')"

    def can_compile_sub_select(self):
        """
        self.builder.where_in('name',
        QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age')
        ).to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`name` IN (SELECT `users`.`age` FROM `users`)"

    def can_compile_sub_select_where(self):
        """
        self.builder.where_in('age',
        QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age').where('age', 2).where('name', 'Joe')
        ).to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`age` IN (SELECT `users`.`age` FROM `users` WHERE `users`.`age` = '2' AND `users`.`name` = 'Joe')"

    def can_compile_sub_select_value(self):
        """
        self.builder.where('name',
        self.builder.new().sum('age')
        ).to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`name` = (SELECT SUM(`users`.`age`) AS age FROM `users`)"

    def can_compile_complex_sub_select(self):
        """
        self.builder.where_in('name',
        (QueryBuilder(GrammarFactory.make(self.grammar), table='users')
        .select('age').where_in('email',
        QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('email')
        ))
        ).to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`name` IN (SELECT `users`.`age` FROM `users` WHERE `users`.`email` IN (SELECT `users`.`email` FROM `users`))"

    def can_compile_exists(self):
        """
        self.builder.select('age').where_exists(
        self.builder.new().select('username').where('age', 12)
        ).to_sql()
        """
        return "SELECT `users`.`age` FROM `users` WHERE EXISTS (SELECT `users`.`username` FROM `users` WHERE `users`.`age` = '12')"

    def can_compile_not_exists(self):
        """
        self.builder.select('age').where_not_exists(
        self.builder.new().select('username').where('age', 12)
        ).to_sql()
        """
        return "SELECT `users`.`age` FROM `users` WHERE NOT EXISTS (SELECT `users`.`username` FROM `users` WHERE `users`.`age` = '12')"

    def can_compile_having(self):
        """
        builder.sum('age').group_by('age').having('age').to_sql()
        """
        return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age`"

    def can_compile_having_order(self):
        """
        builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql()
        """
        # NOTE(review): "ORDER" (not "ORDER BY") is what the grammar emits
        # here per this fixture — looks like a grammar bug worth confirming.
        return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` ORDER `users`.`age` DESC"

    def can_compile_having_with_expression(self):
        """
        builder.sum('age').group_by('age').having('age', 10).to_sql()
        """
        return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` = '10'"

    def can_compile_having_with_greater_than_expression(self):
        """
        builder.sum('age').group_by('age').having('age', '>', 10).to_sql()
        """
        return "SELECT SUM(`users`.`age`) AS age FROM `users` GROUP BY `users`.`age` HAVING `users`.`age` > '10'"

    def can_compile_join(self):
        """
        builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
        """
        return "SELECT * FROM `users` INNER JOIN `contacts` ON `users`.`id` = `contacts`.`user_id`"

    def can_compile_left_join(self):
        """
        builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
        """
        # NOTE(review): docstring looks copy-pasted from can_compile_join —
        # presumably the base test calls left_join(); confirm.
        return "SELECT * FROM `users` LEFT JOIN `contacts` ON `users`.`id` = `contacts`.`user_id`"

    def can_compile_multiple_join(self):
        """
        builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
        """
        # NOTE(review): docstring shows only one join; the expected SQL has
        # two — verify the real call in BaseTestCaseSelectGrammar.
        return "SELECT * FROM `users` INNER JOIN `contacts` ON `users`.`id` = `contacts`.`user_id` INNER JOIN `posts` ON `comments`.`post_id` = `posts`.`id`"

    def can_compile_between(self):
        """
        builder.between('age', 18, 21).to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`age` BETWEEN '18' AND '21'"

    def can_compile_not_between(self):
        """
        builder.not_between('age', 18, 21).to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`age` NOT BETWEEN '18' AND '21'"

    # The following test_* methods run directly (not via the fixture lookup).

    def test_can_compile_where_raw(self):
        to_sql = self.builder.where_raw("`age` = '18'").to_sql()
        self.assertEqual(to_sql, "SELECT * FROM `users` WHERE `age` = '18'")

    def test_can_compile_having_raw(self):
        to_sql = (
            self.builder.select_raw("COUNT(*) as counts")
            .having_raw("counts > 10")
            .to_sql()
        )
        self.assertEqual(
            to_sql, "SELECT COUNT(*) as counts FROM `users` HAVING counts > 10"
        )

    def test_can_compile_having_raw_order(self):
        to_sql = (
            self.builder.select_raw("COUNT(*) as counts")
            .having_raw("counts > 10")
            .order_by_raw("counts DESC")
            .to_sql()
        )
        self.assertEqual(
            to_sql,
            "SELECT COUNT(*) as counts FROM `users` HAVING counts > 10 ORDER BY counts DESC",
        )

    def test_can_compile_select_raw(self):
        to_sql = self.builder.select_raw("COUNT(*)").to_sql()
        self.assertEqual(to_sql, "SELECT COUNT(*) FROM `users`")

    def test_can_compile_select_raw_with_select(self):
        to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql()
        self.assertEqual(to_sql, "SELECT `users`.`id`, COUNT(*) FROM `users`")

    def can_compile_first_or_fail(self):
        """
        builder = self.get_builder()
        builder.where("is_admin", "=", True).first_or_fail()
        """
        return """SELECT * FROM `users` WHERE `users`.`is_admin` = '1' LIMIT 1"""

    def where_not_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "not like", "%name%").to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`age` NOT LIKE '%name%'"

    def where_regexp(self):
        """
        builder = self.get_builder()
        builder.where("age", "regexp", "Joe").to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`age` REGEXP 'Joe'"

    def where_not_regexp(self):
        """
        builder = self.get_builder()
        builder.where("age", "not regexp", "Joe").to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`age` NOT REGEXP 'Joe'"

    def where_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "like", "%name%").to_sql()
        """
        return "SELECT * FROM `users` WHERE `users`.`age` LIKE '%name%'"

    def can_compile_join_clause(self):
        """
        builder = self.get_builder()
        clause = (
        JoinClause("report_groups as rg")
        .on("bgt.fund", "=", "rg.fund")
        .on_value("bgt.active", "=", "1")
        .or_on_value("bgt.acct", "=", "1234")
        )
        builder.join(clause).to_sql()
        """
        # NOTE(review): docstring doesn't match the dept/acct/sub columns in
        # the expected SQL — likely copy-pasted; verify the real clause.
        return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `bgt`.`fund` = `rg`.`fund` AND `bgt`.`dept` = `rg`.`dept` AND `bgt`.`acct` = `rg`.`acct` AND `bgt`.`sub` = `rg`.`sub`"

    def can_compile_join_clause_with_value(self):
        """
        builder = self.get_builder()
        clause = (
        JoinClause("report_groups as rg")
        .on_value("bgt.active", "=", "1")
        .or_on_value("bgt.acct", "=", "1234")
        )
        builder.join(clause).to_sql()
        """
        return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `bgt`.`active` = '1' OR `bgt`.`acct` = '1234'"

    def can_compile_join_clause_with_null(self):
        """
        builder = self.get_builder()
        clause = (
        JoinClause("report_groups as rg")
        .on_null("bgt.acct")
        .or_on_null("bgt.dept")
        .on_value("rg.abc", 10)
        )
        builder.join(clause).to_sql()
        """
        return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `acct` IS NULL OR `dept` IS NULL AND `rg`.`abc` = '10'"

    def can_compile_join_clause_with_not_null(self):
        """
        builder = self.get_builder()
        clause = (
        JoinClause("report_groups as rg")
        .on_not_null("bgt.acct")
        .or_on_not_null("bgt.dept")
        .on_value("rg.abc", 10)
        )
        builder.join(clause).to_sql()
        """
        return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `acct` IS NOT NULL OR `dept` IS NOT NULL AND `rg`.`abc` = '10'"

    def can_compile_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.join(
        "report_groups as rg",
        lambda clause: (
        clause.on("bgt.fund", "=", "rg.fund")
        .on_null("bgt")
        ),
        ).to_sql()
        """
        return "SELECT * FROM `users` INNER JOIN `report_groups` AS `rg` ON `bgt`.`fund` = `rg`.`fund` AND `bgt` IS NULL"

    def can_compile_left_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.left_join(
        "report_groups as rg",
        lambda clause: (
        clause.on("bgt.fund", "=", "rg.fund")
        .or_on_null("bgt")
        ),
        ).to_sql()
        """
        return "SELECT * FROM `users` LEFT JOIN `report_groups` AS `rg` ON `bgt`.`fund` = `rg`.`fund` OR `bgt` IS NULL"

    def can_compile_right_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.right_join(
        "report_groups as rg",
        lambda clause: (
        clause.on("bgt.fund", "=", "rg.fund")
        .or_on_null("bgt")
        ),
        ).to_sql()
        """
        return "SELECT * FROM `users` RIGHT JOIN `report_groups` AS `rg` ON `bgt`.`fund` = `rg`.`fund` OR `bgt` IS NULL"

    def shared_lock(self):
        """
        builder = self.get_builder()
        builder.where("age", "not like", "%name%").to_sql()
        """
        # NOTE(review): docstring is copy-pasted from where_not_like —
        # presumably the base test uses shared_lock(); confirm.
        return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' LOCK IN SHARE MODE"

    def update_lock(self):
        """
        builder = self.get_builder()
        builder.where("age", "not like", "%name%").to_sql()
        """
        # NOTE(review): docstring is copy-pasted from where_not_like —
        # presumably the base test uses lock_for_update(); confirm.
        return "SELECT * FROM `users` WHERE `users`.`votes` >= '100' FOR UPDATE"

    def can_user_where_raw_and_where(self):
        """
        builder.where_raw("`age` = '18'").where("name", "=", "James").to_sql()
        """
        return "SELECT * FROM `users` WHERE age = '18' AND `users`.`name` = 'James'"

    def where_exists_with_lambda(self):
        # Expected SQL for a where_exists(lambda ...) sub-query.
        return """SELECT * FROM `users` WHERE EXISTS (SELECT * FROM `users` WHERE `users`.`age` = '1')"""

    def where_not_exists_with_lambda(self):
        # Expected SQL for a where_not_exists(lambda ...) sub-query.
        return """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `users` WHERE `users`.`age` = '1')"""

    def where_date(self):
        # Expected SQL for a date-only comparison on a datetime column.
        return (
            """SELECT * FROM `users` WHERE DATE(`users`.`created_at`) = '2022-06-01'"""
        )

    def or_where_null(self):
        # Expected SQL for chained where_null / or_where_null clauses.
        return """SELECT * FROM `users` WHERE `users`.`column1` IS NULL OR `users`.`column2` IS NULL"""

    def select_distinct(self):
        # Expected SQL for a distinct select of one column.
        return """SELECT DISTINCT `users`.`group` FROM `users`"""
================================================
FILE: tests/mysql/grammar/test_mysql_update_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
from src.masoniteorm.expressions import Raw
class BaseTestCaseUpdateGrammar:
    """Shared UPDATE compilation tests; subclasses set `grammar` and expected SQL."""

    def setUp(self):
        self.builder = QueryBuilder(self.grammar, table="users")

    def _expected_sql(self, test_name):
        # The expected SQL lives on the subclass in a method named after the
        # test with its "test_" prefix removed.
        return getattr(self, test_name.replace("test_", ""))()

    def test_can_compile_update(self):
        to_sql = (
            self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql()
        )
        self.assertEqual(
            to_sql, self._expected_sql(inspect.currentframe().f_code.co_name)
        )

    def test_can_compile_multiple_update(self):
        to_sql = self.builder.update(
            {"name": "Joe", "email": "user@email.com"}, dry=True
        ).to_sql()
        self.assertEqual(
            to_sql, self._expected_sql(inspect.currentframe().f_code.co_name)
        )

    def test_can_compile_update_with_multiple_where(self):
        to_sql = (
            self.builder.where("name", "bob")
            .where("age", 20)
            .update({"name": "Joe"}, dry=True)
            .to_sql()
        )
        self.assertEqual(
            to_sql, self._expected_sql(inspect.currentframe().f_code.co_name)
        )

    # TODO: increment/decrement tests are currently disabled; subclasses still
    # define can_compile_increment / can_compile_decrement expectations.

    def test_raw_expression(self):
        to_sql = self.builder.update({"name": Raw("`username`")}, dry=True).to_sql()
        self.assertEqual(
            to_sql, self._expected_sql(inspect.currentframe().f_code.co_name)
        )
class TestMySQLUpdateGrammar(BaseTestCaseUpdateGrammar, unittest.TestCase):
    """Expected MySQL SQL for the shared UPDATE grammar tests."""

    grammar = MySQLGrammar

    def can_compile_update(self):
        """builder.where('name', 'bob').update({'name': 'Joe'}, dry=True).to_sql()"""
        return "UPDATE `users` SET `users`.`name` = 'Joe' WHERE `users`.`name` = 'bob'"

    def raw_expression(self):
        """builder.update({'name': Raw('`username`')}, dry=True).to_sql()"""
        return "UPDATE `users` SET `users`.`name` = `username`"

    def can_compile_multiple_update(self):
        """builder.update({'name': 'Joe', 'email': 'user@email.com'}, dry=True).to_sql()"""
        return "UPDATE `users` SET `users`.`name` = 'Joe', `users`.`email` = 'user@email.com'"

    def can_compile_update_with_multiple_where(self):
        """builder.where('name', 'bob').where('age', 20).update({'name': 'Joe'}, dry=True).to_sql()"""
        return "UPDATE `users` SET `users`.`name` = 'Joe' WHERE `users`.`name` = 'bob' AND `users`.`age` = '20'"

    def can_compile_increment(self):
        """Expected SQL for builder.increment('age') — test currently disabled."""
        return "UPDATE `users` SET `users`.`age` = `users`.`age` + '1'"

    def can_compile_decrement(self):
        """Expected SQL for builder.decrement('age', 20) — test currently disabled."""
        return "UPDATE `users` SET `users`.`age` = `users`.`age` - '20'"
================================================
FILE: tests/mysql/model/test_accessors_and_mutators.py
================================================
import datetime
import json
import os
import unittest
import pendulum
from src.masoniteorm.collection import Collection
from src.masoniteorm.models import Model
from src.masoniteorm.query.grammars import MSSQLGrammar
from tests.User import User
class User(Model):
    """Model fixture with both a name accessor and a name mutator."""

    __casts__ = {"is_admin": "bool"}

    def get_name_attribute(self):
        # Accessor: reads of .name see the raw column value with a prefix.
        raw_name = self.get_raw_attribute("name")
        return f"Hello, {raw_name}"

    def set_name_attribute(self, attribute):
        # Mutator: assignments to .name are stored upper-cased.
        return str(attribute).upper()
class SetUser(Model):
    """Model fixture with only a mutator (no accessor) for .name."""

    __casts__ = {"is_admin": "bool"}

    def set_name_attribute(self, attribute):
        # Upper-case whatever is assigned to .name before it is stored.
        normalized = str(attribute)
        return normalized.upper()
class TestAccessor(unittest.TestCase):
    """Accessor/mutator behavior on hydrated models."""

    def test_can_get_accessor(self):
        """The accessor transforms reads of .name; other attributes pass through."""
        hydrated = User.hydrate(
            {"name": "joe", "email": "joe@masoniteproject.com", "is_admin": 1}
        )
        self.assertEqual(hydrated.email, "joe@masoniteproject.com")
        self.assertEqual(hydrated.name, "Hello, joe")
        self.assertTrue(hydrated.is_admin is True, f"{hydrated.is_admin} is not True")

    def test_mutator(self):
        """The mutator upper-cases values assigned to .name."""
        hydrated = SetUser.hydrate({"email": "joe@masoniteproject.com", "is_admin": 1})
        hydrated.name = "joe"
        self.assertEqual(hydrated.name, "JOE")
================================================
FILE: tests/mysql/model/test_model.py
================================================
import datetime
import json
import os
import unittest
import pendulum
from src.masoniteorm.collection import Collection
from src.masoniteorm.exceptions import ModelNotFound
from src.masoniteorm.models import Model
from tests.User import User
class ProfileFillable(Model):
    # Only "name" is mass-assignable; timestamps disabled.
    __fillable__ = ["name"]
    __table__ = "profiles"
    __timestamps__ = None
class ProfileFillTimeStamped(Model):
    # Everything mass-assignable; timestamps left at the model default.
    __fillable__ = ["*"]
    __table__ = "profiles"
class ProfileFillAsterisk(Model):
    # Everything mass-assignable; timestamps disabled.
    __fillable__ = ["*"]
    __table__ = "profiles"
    __timestamps__ = None
class ProfileGuarded(Model):
    # "email" is blocked from mass-assignment; timestamps disabled.
    __guarded__ = ["email"]
    __table__ = "profiles"
    __timestamps__ = None
class ProfileGuardedAsterisk(Model):
    # Every column is blocked from mass-assignment; timestamps disabled.
    __guarded__ = ["*"]
    __table__ = "profiles"
    __timestamps__ = None
class ProfileSerialize(Model):
    # "password" is stripped from serialized output.
    __fillable__ = ["*"]
    __table__ = "profiles"
    __hidden__ = ["password"]
class ProfileSerializeWithVisible(Model):
    # Only "name" and "email" appear in serialized output.
    __fillable__ = ["*"]
    __table__ = "profiles"
    __visible__ = ["name", "email"]
class ProfileSerializeWithVisibleAndHidden(Model):
    # Declares both __visible__ and __hidden__ — the conflicting combination
    # the raise-error test below exercises.
    __fillable__ = ["*"]
    __table__ = "profiles"
    __visible__ = ["name", "email"]
    __hidden__ = ["password"]
class Profile(Model):
    # Bare model; table name is derived from the class name.
    pass
class Company(Model):
    # Bare model; used to check "y" -> "ies" pluralization of the table name.
    pass
class User(Model):
    # Model with a computed attribute used via __appends__ / set_appends().
    @property
    def meta(self):
        return {"is_subscribed": True}
class ProductNames(Model):
    # Bare model; used to check snake_casing of multi-word class names.
    pass
class TestModel(unittest.TestCase):
def test_create_can_use_fillable(self):
    """Only __fillable__ columns survive a mass-assignment create."""
    generated = ProfileFillable.create(
        {"name": "Joe", "email": "user@example.com"}, query=True
    ).to_sql()
    self.assertEqual(
        generated, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')"
    )
def test_create_can_use_fillable_asterisk(self):
    """__fillable__ = ['*'] lets every column through on create."""
    generated = ProfileFillAsterisk.create(
        {"name": "Joe", "email": "user@example.com"}, query=True
    ).to_sql()
    self.assertEqual(
        generated,
        "INSERT INTO `profiles` (`profiles`.`name`, `profiles`.`email`) VALUES ('Joe', 'user@example.com')",
    )
def test_create_can_use_guarded(self):
    """__guarded__ columns are dropped from a mass-assignment create."""
    generated = ProfileGuarded.create(
        {"name": "Joe", "email": "user@example.com"}, query=True
    ).to_sql()
    self.assertEqual(
        generated, "INSERT INTO `profiles` (`profiles`.`name`) VALUES ('Joe')"
    )
def test_create_can_use_guarded_asterisk(self):
    """Guarding '*' excludes every field from mass-assignment on create."""
    generated = ProfileGuardedAsterisk.create(
        {"name": "Joe", "email": "user@example.com"}, query=True
    ).to_sql()
    # With all fields excluded the SQL is degenerate; the database would
    # reject it if any column were required.
    self.assertEqual(generated, "INSERT INTO `profiles` (*) VALUES ()")
def test_bulk_create_can_use_fillable(self):
    """bulk_create drops columns outside __fillable__ for every row."""
    builder = ProfileFillable.bulk_create(
        [
            {"name": "Joe", "email": "user@example.com"},
            {"name": "Joe II", "email": "userII@example.com"},
        ],
        query=True,
    )
    self.assertEqual(
        builder.to_sql(),
        "INSERT INTO `profiles` (`name`) VALUES ('Joe'), ('Joe II')",
    )
def test_bulk_create_can_use_fillable_asterisk(self):
    """__fillable__ = ['*'] lets every column through on bulk_create."""
    builder = ProfileFillAsterisk.bulk_create(
        [
            {"name": "Joe", "email": "user@example.com"},
            {"name": "Joe II", "email": "userII@example.com"},
        ],
        query=True,
    )
    self.assertEqual(
        builder.to_sql(),
        "INSERT INTO `profiles` (`email`, `name`) VALUES ('user@example.com', 'Joe'), ('userII@example.com', 'Joe II')",
    )
def test_bulk_create_can_use_guarded(self):
    """bulk_create drops __guarded__ columns from every row."""
    builder = ProfileGuarded.bulk_create(
        [
            {"name": "Joe", "email": "user@example.com"},
            {"name": "Joe II", "email": "userII@example.com"},
        ],
        query=True,
    )
    self.assertEqual(
        builder.to_sql(),
        "INSERT INTO `profiles` (`name`) VALUES ('Joe'), ('Joe II')",
    )
def test_bulk_create_can_use_guarded_asterisk(self):
    """Guarding '*' strips every field from a bulk_create."""
    builder = ProfileGuardedAsterisk.bulk_create(
        [
            {"name": "Joe", "email": "user@example.com"},
            {"name": "Joe II", "email": "userII@example.com"},
        ],
        query=True,
    )
    # All fields excluded -> syntactically invalid SQL if executed.
    # TODO: Raise a clearer error?
    self.assertEqual(
        builder.to_sql(), "INSERT INTO `profiles` () VALUES (), ()"
    )
def test_update_can_use_fillable(self):
    """Only __fillable__ columns survive a mass-assignment update."""
    builder = ProfileFillable().update(
        {"name": "Joe", "email": "user@example.com"}, dry=True
    )
    self.assertEqual(
        builder.to_sql(), "UPDATE `profiles` SET `profiles`.`name` = 'Joe'"
    )
def test_update_can_use_fillable_asterisk(self):
    """__fillable__ = ['*'] lets every column through on update."""
    builder = ProfileFillAsterisk().update(
        {"name": "Joe", "email": "user@example.com"}, dry=True
    )
    self.assertEqual(
        builder.to_sql(),
        "UPDATE `profiles` SET `profiles`.`name` = 'Joe', `profiles`.`email` = 'user@example.com'",
    )
def test_update_can_use_guarded(self):
    """__guarded__ columns are dropped from a mass-assignment update."""
    builder = ProfileGuarded().update(
        {"name": "Joe", "email": "user@example.com"}, dry=True
    )
    self.assertEqual(
        builder.to_sql(), "UPDATE `profiles` SET `profiles`.`name` = 'Joe'"
    )
def test_update_can_use_guarded_asterisk(self):
    """Guarding '*' blocks every column, leaving the update query untouched."""
    profile = ProfileGuardedAsterisk()
    sql_before = profile.get_builder().to_sql()
    builder = profile.update(
        {"name": "Joe", "email": "user@example.com"}, dry=True
    )
    # Nothing was mass-assignable, so the builder's SQL must be unchanged.
    self.assertEqual(builder.to_sql(), sql_before)
def test_table_name(self):
    """Table names derive from the pluralized, snake_cased class name."""
    self.assertEqual(Profile.get_table_name(), "profiles")
    self.assertEqual(Company.get_table_name(), "companies")
    self.assertEqual(ProductNames.get_table_name(), "product_names")
def test_serialize(self):
    """serialize() returns the hydrated attributes as a plain dict."""
    hydrated = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1})
    self.assertEqual(hydrated.serialize(), {"name": "Joe", "id": 1})
def test_json(self):
profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1})
self.assertEqual(profile.to_json(), '{"name": "Joe", "id": 1}')
def test_serialize_with_hidden(self):
profile = ProfileSerialize.hydrate(
{"name": "Joe", "id": 1, "password": "secret"}
)
self.assertTrue(profile.serialize().get("name"))
self.assertTrue(profile.serialize().get("id"))
self.assertFalse(profile.serialize().get("password"))
def test_serialize_with_visible(self):
profile = ProfileSerializeWithVisible.hydrate(
{"name": "Joe", "id": 1, "password": "secret", "email": "joe@masonite.com"}
)
self.assertTrue(
{"name": "Joe", "email": "joe@masonite.com"}, profile.serialize()
)
def test_serialize_with_visible_and_hidden_raise_error(self):
profile = ProfileSerializeWithVisibleAndHidden.hydrate(
{"name": "Joe", "id": 1, "password": "secret", "email": "joe@masonite.com"}
)
with self.assertRaises(AttributeError):
profile.serialize()
def test_serialize_with_on_the_fly_appends(self):
user = User.hydrate({"name": "Joe", "id": 1})
user.set_appends(["meta"])
serialized = user.serialize()
self.assertEqual(serialized["id"], 1)
self.assertEqual(serialized["name"], "Joe")
self.assertEqual(serialized["meta"]["is_subscribed"], True)
def test_serialize_with_model_appends(self):
User.__appends__ = ["meta"]
user = User.hydrate({"name": "Joe", "id": 1})
serialized = user.serialize()
self.assertEqual(serialized["id"], 1)
self.assertEqual(serialized["name"], "Joe")
self.assertEqual(serialized["meta"]["is_subscribed"], True)
def test_serialize_with_date(self):
user = User.hydrate({"name": "Joe", "created_at": pendulum.now()})
self.assertTrue(json.dumps(user.serialize()))
def test_set_as_date(self):
user = User.hydrate(
{
"name": "Joe",
"created_at": pendulum.now().add(days=10).to_datetime_string(),
}
)
self.assertTrue(user.created_at)
self.assertTrue(user.created_at.is_future())
def test_access_as_date(self):
user = User.hydrate(
{
"name": "Joe",
"created_at": datetime.datetime.now() + datetime.timedelta(days=1),
}
)
self.assertTrue(user.created_at)
self.assertTrue(user.created_at.is_future())
def test_hydrate_with_none(self):
profile = ProfileFillAsterisk.hydrate(None)
self.assertEqual(profile, None)
def test_serialize_with_dirty_attribute(self):
profile = ProfileFillAsterisk.hydrate({"name": "Joe", "id": 1})
profile.age = 18
self.assertEqual(profile.serialize(), {"age": 18, "name": "Joe", "id": 1})
def test_attribute_check_with_hasattr(self):
self.assertFalse(hasattr(Profile(), "__password__"))
# Integration tests against a live MySQL server; opt-in via env var so the
# default test run stays database-free.
if os.getenv("RUN_MYSQL_DATABASE", "false").lower() == "true":
    class MysqlTestModel(unittest.TestCase):
        # TODO: these tests aren't getting run in CI... is that intentional?
        def test_can_find_first(self):
            # NOTE(review): the result is unused and nothing is asserted —
            # this only verifies find() does not raise against the live DB.
            profile = User.find(1)
        def test_can_touch(self):
            # touch(..., query=True) returns the builder instead of executing,
            # so the generated UPDATE can be inspected.
            profile = ProfileFillTimeStamped.hydrate({"name": "Joe", "id": 1})
            sql = profile.touch("now", query=True).to_sql()
            self.assertEqual(
                sql,
                "UPDATE `profiles` SET `profiles`.`updated_at` = 'now' WHERE `profiles`.`id` = '1'",
            )
        def test_find_or_fail_raise_an_exception_if_not_exists(self):
            with self.assertRaises(ModelNotFound):
                User.find(100)
        def test_returns_correct_data_type(self):
            self.assertIsInstance(User.all(), Collection)
            # self.assertIsInstance(User.first(), User)
            # self.assertIsInstance(User.first(), User)
================================================
FILE: tests/mysql/relationships/test_belongs_to_many.py
================================================
import unittest
# from src.masoniteorm import query
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import (
has_one,
belongs_to_many,
has_one_through,
has_many,
)
from dotenv import load_dotenv
load_dotenv(".env")
class User(Model):
    """Test model with a has_one Profile relationship."""
    @has_one
    def profile(self):
        return Profile
class Profile(Model):
    """Bare test model; table name is derived from the class name."""
    pass
class Permission(Model):
    """Test model: many-to-many to Role via the permission_role pivot."""
    @belongs_to_many("permission_id", "role_id", "id", "id")
    def role(self):
        return Role
class PermissionSelect(Model):
    """Same as Permission but restricts selected columns via __selects__."""
    __table__ = "permissions"
    __selects__ = ["permission_id"]
    @belongs_to_many("permission_id", "role_id", "id", "id")
    def role(self):
        return Role
class Role(Model):
    """Inverse side of the Permission many-to-many relationship."""
    @belongs_to_many("role_id", "permission_id", "id", "id")
    def permissions(self):
        return Permission
class MySQLRelationships(unittest.TestCase):
    """Pins the exact MySQL SQL generated for belongs_to_many (pivot) queries.

    All tests are dry: they compare to_sql() output against literal SQL
    strings, so no database connection is required.
    """
    # Show full diffs for the long SQL-string comparisons below.
    maxDiff = None
    def test_belongs_to_many(self):
        sql = Permission.where_has(
            "role", lambda query: (query.where("slug", "users"))
        ).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `permissions` WHERE EXISTS (SELECT * FROM `roles` INNER JOIN `permission_role` ON `roles`.`id` = `permission_role`.`role_id` WHERE `permission_role`.`permission_id` = `permissions`.`id` AND `roles`.`id` IN (SELECT `roles`.`id` FROM `roles` WHERE `roles`.`slug` = 'users'))""",
        )
    def test_belongs_to_many_has(self):
        sql = Role.has("permissions").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `roles` WHERE EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id`)""",
        )
    def test_belongs_to_many_or_has(self):
        sql = Role.where("name", "role_name").or_has("permissions").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `roles` WHERE `roles`.`name` = 'role_name' OR EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id`)""",
        )
    def test_belongs_to_many_or_where_has(self):
        sql = (
            Role.where("name", "role_name")
            .or_where_has("permissions", lambda q: q.where("permission_id", 1))
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `roles` WHERE `roles`.`name` = 'role_name' OR EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id` AND `permissions`.`id` IN (SELECT `permissions`.`id` FROM `permissions` WHERE `permissions`.`permission_id` = '1'))""",
        )
    def test_belongs_to_many_or_doesnt_have(self):
        sql = Role.where("name", "role_name").or_doesnt_have("permissions").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `roles` WHERE `roles`.`name` = 'role_name' OR NOT EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id`)""",
        )
    def test_where_doesnt_have(self):
        sql = (
            Role.where("name", "role_name")
            .where_doesnt_have(
                "permissions", lambda q: q.where("name", "Creates Users")
            )
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `roles` WHERE `roles`.`name` = 'role_name' AND NOT EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id` AND `permissions`.`id` IN (SELECT `permissions`.`id` FROM `permissions` WHERE `permissions`.`name` = 'Creates Users'))""",
        )
    def test_or_where_doesnt_have(self):
        sql = (
            Role.where("name", "role_name")
            .or_where_doesnt_have(
                "permissions", lambda q: q.where("name", "Creates Users")
            )
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `roles` WHERE `roles`.`name` = 'role_name' OR NOT EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id` AND `permissions`.`id` IN (SELECT `permissions`.`id` FROM `permissions` WHERE `permissions`.`name` = 'Creates Users'))""",
        )
    def test_belongs_to_many_where_has(self):
        sql = Role.where_has(
            "permissions", lambda q: q.where("name", "Creates Users")
        ).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `roles` WHERE EXISTS (SELECT * FROM `permissions` INNER JOIN `permission_role` ON `permissions`.`id` = `permission_role`.`permission_id` WHERE `permission_role`.`role_id` = `roles`.`id` AND `permissions`.`id` IN (SELECT `permissions`.`id` FROM `permissions` WHERE `permissions`.`name` = 'Creates Users'))""",
        )
    def test_belongs_to_many_relate_method(self):
        # related() on a hydrated model builds the pivot join from its side.
        permission = Permission.hydrate({"id": 1, "name": "Create Users"})
        sql = permission.related("role").to_sql()
        self.assertEqual(
            sql,
            """SELECT `roles`.*, `permission_role`.`permission_id` AS permission_role_id, `permission_role`.`role_id` AS m_reserved2, `permission_role`.`id` AS m_reserved3 FROM `permissions` INNER JOIN `permission_role` ON `permission_role`.`permission_id` = `permissions`.`id` INNER JOIN `roles` ON `permission_role`.`role_id` = `roles`.`id`""",
        )
    def test_belongs_to_many_relate_method_reversed(self):
        role = Role.hydrate({"id": 1, "name": "Create Users"})
        sql = role.related("permissions").to_sql()
        self.assertEqual(
            sql,
            """SELECT `permissions`.*, `permission_role`.`role_id` AS permission_role_id, `permission_role`.`permission_id` AS m_reserved2, `permission_role`.`id` AS m_reserved3 FROM `roles` INNER JOIN `permission_role` ON `permission_role`.`role_id` = `roles`.`id` INNER JOIN `permissions` ON `permission_role`.`permission_id` = `permissions`.`id`""",
        )
    def test_belongs_to_many_joins(self):
        sql = Role.joins("permissions").to_sql()
        self.assertEqual(
            sql,
            """SELECT `roles`.*, `permission_role`.`role_id` AS permission_role_id, `permission_role`.`permission_id` AS m_reserved2, `permission_role`.`id` AS m_reserved3 FROM `roles` INNER JOIN `permission_role` ON `permission_role`.`role_id` = `roles`.`id` INNER JOIN `permissions` ON `permission_role`.`id` = `permissions`.`id`""",
        )
    def test_with_count(self):
        sql = Permission.with_count("role").to_sql()
        self.assertEqual(
            sql,
            """SELECT `permissions`.*, (SELECT COUNT(*) AS m_count_reserved FROM `permission_role` WHERE `permissions`.`id` = `permission_role`.`permission_id`) AS roles_count FROM `permissions`""",
        )
    def test_with_count_with_selects(self):
        sql = PermissionSelect.with_count("role").to_sql()
        self.assertEqual(
            sql,
            """SELECT `permissions`.`permission_id`, (SELECT COUNT(*) AS m_count_reserved FROM `permission_role` WHERE `permissions`.`id` = `permission_role`.`permission_id`) AS roles_count FROM `permissions`""",
        )
================================================
FILE: tests/mysql/relationships/test_has_many_through.py
================================================
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import (
has_many_through,
)
from dotenv import load_dotenv
load_dotenv(".env")
class InboundShipment(Model):
    """Test model reaching Country through the Port intermediate table."""
    @has_many_through("port_id", "country_id", "from_port_id", "country_id")
    def from_country(self):
        return Country, Port
class Country(Model):
    """Target model of the has_many_through relationship."""
    pass
class Port(Model):
    """Intermediate model of the has_many_through relationship."""
    pass
class MySQLRelationships(unittest.TestCase):
    """Pins the MySQL SQL generated for has_many_through existence queries.

    Dry tests only — to_sql() output is compared against literal SQL.
    """
    # Show full diffs for the long SQL-string comparisons below.
    maxDiff = None
    def test_has_query(self):
        sql = InboundShipment.has("from_country").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
        )
    def test_or_has(self):
        sql = InboundShipment.where("name", "Joe").or_has("from_country").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
        )
    def test_where_has_query(self):
        sql = InboundShipment.where_has(
            "from_country", lambda query: query.where("name", "USA")
        ).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
        )
    def test_or_where_has(self):
        sql = (
            InboundShipment.where("name", "Joe")
            .or_where_has("from_country", lambda query: query.where("name", "USA"))
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
        )
    def test_doesnt_have(self):
        sql = InboundShipment.doesnt_have("from_country").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
        )
    def test_or_where_doesnt_have(self):
        sql = (
            InboundShipment.where("name", "Joe")
            .or_where_doesnt_have(
                "from_country", lambda query: query.where("name", "USA")
            )
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
        )
================================================
FILE: tests/mysql/relationships/test_has_one_through.py
================================================
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import (
has_one_through,
)
from dotenv import load_dotenv
load_dotenv(".env")
class InboundShipment(Model):
    """Test model reaching a single Country through the Port table."""
    @has_one_through(None, "from_port_id", "country_id", "port_id", "country_id")
    def from_country(self):
        return Country, Port
class Country(Model):
    """Target model of the has_one_through relationship."""
    pass
class Port(Model):
    """Intermediate model of the has_one_through relationship."""
    pass
class MySQLHasOneThroughRelationship(unittest.TestCase):
    """Pins the MySQL SQL generated for has_one_through existence queries.

    Dry tests only — to_sql() output is compared against literal SQL.
    """
    # Show full diffs for the long SQL-string comparisons below.
    maxDiff = None
    def test_has_query(self):
        sql = InboundShipment.has("from_country").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
        )
    def test_or_has(self):
        sql = InboundShipment.where("name", "Joe").or_has("from_country").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
        )
    def test_where_has_query(self):
        sql = InboundShipment.where_has(
            "from_country", lambda query: query.where("name", "USA")
        ).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
        )
    def test_or_where_has(self):
        sql = (
            InboundShipment.where("name", "Joe")
            .or_where_has("from_country", lambda query: query.where("name", "USA"))
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
        )
    def test_doesnt_have(self):
        sql = InboundShipment.doesnt_have("from_country").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`)""",
        )
    def test_or_where_doesnt_have(self):
        sql = (
            InboundShipment.where("name", "Joe")
            .or_where_doesnt_have(
                "from_country", lambda query: query.where("name", "USA")
            )
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `inbound_shipments` WHERE `inbound_shipments`.`name` = 'Joe' OR NOT EXISTS (SELECT * FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id` AND `countries`.`name` = 'USA')""",
        )
    def test_has_one_through_with_count(self):
        sql = InboundShipment.with_count("from_country").to_sql()
        self.assertEqual(
            sql,
            """SELECT `inbound_shipments`.*, (SELECT COUNT(*) AS m_count_reserved FROM `countries` INNER JOIN `ports` ON `ports`.`country_id` = `countries`.`country_id` WHERE `ports`.`port_id` = `inbound_shipments`.`from_port_id`) AS from_country_count FROM `inbound_shipments`""",
        )
================================================
FILE: tests/mysql/relationships/test_relationships.py
================================================
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import (
has_one,
belongs_to_many,
has_one_through,
has_many,
)
from dotenv import load_dotenv
load_dotenv(".env")
class User(Model):
    """Root of the User -> Profile -> Identification has_one chain."""
    @has_one
    def profile(self):
        return Profile
class Profile(Model):
    """Middle link of the nested has_one chain."""
    @has_one
    def identification(self):
        return Identification
class Identification(Model):
    """Leaf model of the nested has_one chain."""
    pass
class MySQLRelationships(unittest.TestCase):
    """Pins the MySQL SQL generated for (nested) has_one existence queries.

    Dry tests only — to_sql() output is compared against literal SQL.
    Dotted relation names ("profile.identification") exercise nesting.
    """
    # Show full diffs for the long SQL-string comparisons below.
    maxDiff = None
    def test_has(self):
        sql = User.has("profile").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id`)""",
        )
    def test_has_nested(self):
        sql = User.has("profile.identification").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `profiles`.`id`))""",
        )
    def test_or_has(self):
        sql = User.where("name", "Joe").or_has("profile").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' OR EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id`)""",
        )
    def test_or_has_nested(self):
        sql = User.where("name", "Joe").or_has("profile.identification").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' OR EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `profiles`.`id`))""",
        )
    def test_relationship_where_has(self):
        sql = (
            User.where("name", "Joe")
            .where_has("profile", lambda q: q.where("profile_id", 1))
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' AND EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND `profiles`.`profile_id` = '1')""",
        )
    def test_relationship_where_has_nested(self):
        sql = (
            User.where("name", "Joe")
            .where_has(
                "profile.identification", lambda q: q.where("identification_id", 1)
            )
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' AND EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `profiles`.`id` AND `identifications`.`identification_id` = '1'))""",
        )
    def test_relationship_or_where_has(self):
        sql = (
            User.where("name", "Joe")
            .or_where_has("profile", lambda q: q.where("profile_id", 1))
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' OR EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND `profiles`.`profile_id` = '1')""",
        )
    def test_relationship_or_where_has_nested(self):
        sql = (
            User.where("name", "Joe")
            .or_where_has(
                "profile.identification", lambda q: q.where("identification_id", 1)
            )
            .to_sql()
        )
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE `users`.`name` = 'Joe' OR EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `profiles`.`id` AND `identifications`.`identification_id` = '1'))""",
        )
    def test_relationship_doesnt_have(self):
        sql = User.doesnt_have("profile").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id`)""",
        )
    def test_relationship_doesnt_have_nested(self):
        sql = User.doesnt_have("profile.identification").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `profiles`.`id`))""",
        )
    def test_relationship_where_doesnt_have(self):
        sql = User.where_doesnt_have(
            "profile", lambda q: q.where("profile_id", 1)
        ).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND `profiles`.`profile_id` = '1')""",
        )
    def test_relationship_where_doesnt_have_nested(self):
        # NOTE(review): unlike test_relationship_doesnt_have_nested, the nested
        # EXISTS here sits *outside* the NOT EXISTS and joins identifications
        # directly to `users`.`id` — this looks like it pins a nesting bug in
        # the builder rather than intended behavior; worth confirming.
        sql = User.where_doesnt_have(
            "profile.identification", lambda q: q.where("identification_id", 1)
        ).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id`) AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `users`.`id` AND `identifications`.`identification_id` = '1')""",
        )
    def test_relationship_or_where_doesnt_have(self):
        sql = User.or_where_doesnt_have(
            "profile", lambda q: q.where("profile_id", 1)
        ).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id` AND `profiles`.`profile_id` = '1')""",
        )
    def test_relationship_or_where_doesnt_have_nested(self):
        # NOTE(review): same suspicious flattened nesting as
        # test_relationship_where_doesnt_have_nested above — confirm intent.
        sql = User.or_where_doesnt_have(
            "profile.identification", lambda q: q.where("identification_id", 1)
        ).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` WHERE NOT EXISTS (SELECT * FROM `profiles` WHERE `profiles`.`profile_id` = `users`.`id`) AND EXISTS (SELECT * FROM `identifications` WHERE `identifications`.`identification_id` = `users`.`id` AND `identifications`.`identification_id` = '1')""",
        )
    def test_joins(self):
        sql = User.joins("profile").to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` INNER JOIN `profiles` ON `users`.`id` = `profiles`.`profile_id`""",
        )
    def test_join_on(self):
        sql = User.join_on("profile", lambda q: (q.where("active", 1))).to_sql()
        self.assertEqual(
            sql,
            """SELECT * FROM `users` INNER JOIN `profiles` ON `users`.`id` = `profiles`.`profile_id` WHERE (`profiles`.`active` = '1')""",
        )
================================================
FILE: tests/mysql/schema/test_mysql_schema_builder.py
================================================
import os
import unittest
from src.masoniteorm import Model
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import MySQLConnection
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import MySQLPlatform
from tests.integrations.config.database import DATABASES
class Discussion(Model):
    """Minimal model used to exercise foreign_id_for()."""
    pass
class TestMySQLSchemaBuilder(unittest.TestCase):
maxDiff = None
    def setUp(self):
        # Build a dry (no-execute) schema builder bound to the MySQL platform
        # so every test can inspect the generated SQL without a live server.
        self.schema = Schema(
            connection_class=MySQLConnection,
            connection="mysql",
            connection_details=DATABASES,
            platform=MySQLPlatform,
            dry=True,
        ).on("mysql")
def test_can_add_columns1(self):
with self.schema.create("users") as blueprint:
blueprint.string("name")
blueprint.integer("age")
self.assertEqual(len(blueprint.table.added_columns), 2)
self.assertEqual(
blueprint.to_sql(),
[
"CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL, `age` INT(11) NOT NULL)"
],
)
def test_can_add_tiny_text(self):
with self.schema.create("users") as blueprint:
blueprint.tiny_text("description")
self.assertEqual(len(blueprint.table.added_columns), 1)
self.assertEqual(
blueprint.to_sql(),
["CREATE TABLE `users` (`description` TINYTEXT NOT NULL)"],
)
def test_can_add_unsigned_decimal(self):
with self.schema.create("users") as blueprint:
blueprint.unsigned_decimal("amount", 19, 4)
self.assertEqual(len(blueprint.table.added_columns), 1)
self.assertEqual(
blueprint.to_sql(),
["CREATE TABLE `users` (`amount` DECIMAL(19, 4) UNSIGNED NOT NULL)"],
)
def test_can_create_table_if_not_exists(self):
with self.schema.create_table_if_not_exists("users") as blueprint:
blueprint.string("name")
blueprint.integer("age")
self.assertEqual(len(blueprint.table.added_columns), 2)
self.assertEqual(
blueprint.to_sql(),
[
"CREATE TABLE IF NOT EXISTS `users` (`name` VARCHAR(255) NOT NULL, `age` INT(11) NOT NULL)"
],
)
def test_can_add_columns_with_constaint(self):
with self.schema.create("users") as blueprint:
blueprint.string("name")
blueprint.integer("age")
blueprint.unique("name"),
blueprint.unique("name", name="table_unique"),
self.assertEqual(len(blueprint.table.added_columns), 2)
self.assertEqual(
blueprint.to_sql(),
[
"CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL, `age` INT(11) NOT NULL, CONSTRAINT users_name_unique UNIQUE (name), CONSTRAINT table_unique UNIQUE (name))"
],
)
def test_add_column_comment(self):
with self.schema.create("users") as blueprint:
blueprint.string("name").comment("A users username")
self.assertEqual(len(blueprint.table.added_columns), 1)
self.assertEqual(
blueprint.to_sql(),
[
"CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL COMMENT 'A users username')"
],
)
def test_can_add_table_comment(self):
with self.schema.create("users") as blueprint:
blueprint.string("name")
blueprint.table_comment("A users table")
self.assertEqual(len(blueprint.table.added_columns), 1)
self.assertEqual(
blueprint.to_sql(),
[
"CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL) COMMENT 'A users table'"
],
)
    def test_can_add_columns_with_foreign_key_constaint(self):
        # NOTE(review): this method is shadowed by the identically-named
        # definition below, so it NEVER runs. Its expected SQL also looks
        # malformed (the users_profile_id_foreign constraint name appears
        # twice, and there are unbalanced closing parentheses) — rename it
        # and repair the expectation before re-enabling.
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.foreign("profile_id").references("id").on("profiles")
            blueprint.foreign_id("post_id").references("id").on("posts")
            blueprint.foreign_id_for(Discussion).references("id").on("discussions")
        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL, "
                "`age` INT(11) NOT NULL, "
                "`profile_id` INT(11) NOT NULL, "
                "`post_id` BIGINT UNSIGNED NOT NULL, "
                "CONSTRAINT users_name_unique UNIQUE (name), "
                "CONSTRAINT users_profile_id_foreign FOREIGN KEY (`profile_id`) REFERENCES `profiles`(`id`), "
                "CONSTRAINT users_profile_id_foreign FOREIGN KEY (`post_id`) REFERENCES `posts`(`id`)), "
                "CONSTRAINT users_discussions_id_foreign FOREIGN KEY (`discussion_id`) REFERENCES `posts`(`id`))"
            ],
        )
    def test_can_add_columns_with_foreign_key_constaint(self):
        # add_foreign() accepts a "column.references.table" dotted shorthand.
        # NOTE(review): this redefines (and silently shadows) the method of
        # the same name directly above — one of the two should be renamed.
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.add_foreign("profile_id.id.profiles")
        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE `users` (`name` VARCHAR(255) NOT NULL, "
                "`age` INT(11) NOT NULL, "
                "`profile_id` INT(11) NOT NULL, "
                "CONSTRAINT users_name_unique UNIQUE (name), "
                "CONSTRAINT users_profile_id_foreign FOREIGN KEY (`profile_id`) REFERENCES `profiles`(`id`))"
            ],
        )
    def test_can_advanced_table_creation(self):
        # Exercises most column types, defaults, nullability and the
        # auto-generated primary/unique constraints in one CREATE TABLE.
        with self.schema.create("users") as blueprint:
            blueprint.increments("id")
            blueprint.id("id2")
            blueprint.string("name")
            blueprint.tiny_integer("active")
            blueprint.string("email").unique()
            blueprint.enum("gender", ["male", "female"])
            blueprint.string("password")
            blueprint.decimal("money")
            blueprint.integer("admin").default(0)
            blueprint.string("option").default("ADMIN")
            blueprint.string("remember_token").nullable()
            blueprint.timestamp("verified_at").nullable()
            blueprint.timestamps()
        self.assertEqual(len(blueprint.table.added_columns), 14)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE `users` (`id` INT UNSIGNED AUTO_INCREMENT NOT NULL, "
                "`id2` BIGINT UNSIGNED AUTO_INCREMENT NOT NULL, "
                "`name` VARCHAR(255) NOT NULL, `active` TINYINT(1) NOT NULL, `email` VARCHAR(255) NOT NULL, `gender` ENUM('male', 'female') NOT NULL, "
                "`password` VARCHAR(255) NOT NULL, `money` DECIMAL(17, 6) NOT NULL, "
                "`admin` INT(11) NOT NULL DEFAULT 0, `option` VARCHAR(255) NOT NULL DEFAULT 'ADMIN', `remember_token` VARCHAR(255) NULL, `verified_at` TIMESTAMP NULL, "
                "`created_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, `updated_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_id2_primary PRIMARY KEY (id2), CONSTRAINT users_email_unique UNIQUE (email))"
            ],
        )
def test_can_add_primary_constraint_without_column_name(self):
with self.schema.create("users") as blueprint:
blueprint.integer("user_id").primary()
blueprint.string("name")
blueprint.string("email")
self.assertEqual(len(blueprint.table.added_columns), 3)
self.assertEqual(len(blueprint.table.added_constraints), 1)
self.assertTrue(
blueprint.to_sql()[0].startswith(
"CREATE TABLE `users` (`user_id` INT(11) NOT NULL"
)
)
    def test_can_advanced_table_creation2(self):
        # Second advanced scenario: network-address column types (mapped to
        # VARCHAR on MySQL) and a cascading foreign key.
        with self.schema.create("users") as blueprint:
            blueprint.big_increments("id")
            blueprint.string("name")
            blueprint.string("duration")
            blueprint.string("url")
            blueprint.inet("last_address").nullable()
            blueprint.cidr("route_origin").nullable()
            blueprint.macaddr("mac_address").nullable()
            blueprint.datetime("published_at")
            blueprint.string("thumbnail").nullable()
            blueprint.integer("premium")
            blueprint.integer("author_id").unsigned().nullable()
            blueprint.foreign("author_id").references("id").on("users").on_delete(
                "CASCADE"
            )
            blueprint.text("description")
            blueprint.timestamps()
        self.assertEqual(len(blueprint.table.added_columns), 14)
        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE TABLE `users` (`id` BIGINT UNSIGNED AUTO_INCREMENT NOT NULL, `name` VARCHAR(255) NOT NULL, "
                "`duration` VARCHAR(255) NOT NULL, `url` VARCHAR(255) NOT NULL, `last_address` VARCHAR(255) NULL, `route_origin` VARCHAR(255) NULL, `mac_address` VARCHAR(255) NULL, "
                "`published_at` DATETIME NOT NULL, `thumbnail` VARCHAR(255) NULL, "
                "`premium` INT(11) NOT NULL, `author_id` INT(11) UNSIGNED NULL, `description` TEXT NOT NULL, `created_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, "
                "`updated_at` DATETIME NULL DEFAULT CURRENT_TIMESTAMP, CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_author_id_foreign FOREIGN KEY (`author_id`) REFERENCES `users`(`id`) ON DELETE CASCADE)"
            ],
        )
def test_can_add_columns_with_foreign_key_constraint_name(self):
with self.schema.create("users") as blueprint:
blueprint.integer("profile_id")
blueprint.foreign("profile_id", name="profile_foreign").references("id").on(
"profiles"
)
self.assertEqual(len(blueprint.table.added_columns), 1)
self.assertEqual(
blueprint.to_sql(),
[
"CREATE TABLE `users` ("
"`profile_id` INT(11) NOT NULL, "
"CONSTRAINT profile_foreign FOREIGN KEY (`profile_id`) REFERENCES `profiles`(`id`))"
],
)
def test_can_have_composite_keys(self):
with self.schema.create("users") as blueprint:
blueprint.string("name").unique()
blueprint.integer("age")
blueprint.integer("profile_id")
blueprint.primary(["name", "age"])
self.assertEqual(len(blueprint.table.added_columns), 3)
print(blueprint.to_sql())
self.assertEqual(
blueprint.to_sql(),
[
"CREATE TABLE `users` "
"(`name` VARCHAR(255) NOT NULL, "
"`age` INT(11) NOT NULL, "
"`profile_id` INT(11) NOT NULL, "
"CONSTRAINT users_name_unique UNIQUE (name), "
"CONSTRAINT users_name_age_primary PRIMARY KEY (name, age))"
],
)
def test_can_have_column_primary_key(self):
with self.schema.create("users") as blueprint:
blueprint.string("name").primary()
blueprint.integer("age")
blueprint.integer("profile_id")
self.assertEqual(len(blueprint.table.added_columns), 3)
self.assertEqual(
blueprint.to_sql(),
[
"CREATE TABLE `users` "
"(`name` VARCHAR(255) NOT NULL, "
"`age` INT(11) NOT NULL, "
"`profile_id` INT(11) NOT NULL, "
"CONSTRAINT users_name_primary PRIMARY KEY (name))"
],
)
def test_can_have_unsigned_columns(self):
    """Each integer flavor should accept .unsigned() (or the unsigned_*
    shortcut) and render with the UNSIGNED modifier and its MySQL width."""
    with self.schema.create("users") as blueprint:
        blueprint.integer("profile_id").unsigned()
        blueprint.big_integer("big_profile_id").unsigned()
        blueprint.tiny_integer("tiny_profile_id").unsigned()
        blueprint.small_integer("small_profile_id").unsigned()
        blueprint.medium_integer("medium_profile_id").unsigned()
        # Shortcut forms: note unsigned_integer renders INT with no width.
        blueprint.unsigned_integer("unsigned_profile_id")
        blueprint.unsigned_big_integer("unsigned_big_profile_id")
    self.assertEqual(
        blueprint.to_sql(),
        [
            "CREATE TABLE `users` ("
            "`profile_id` INT(11) UNSIGNED NOT NULL, "
            "`big_profile_id` BIGINT(32) UNSIGNED NOT NULL, "
            "`tiny_profile_id` TINYINT(1) UNSIGNED NOT NULL, "
            "`small_profile_id` SMALLINT(5) UNSIGNED NOT NULL, "
            "`medium_profile_id` MEDIUMINT(7) UNSIGNED NOT NULL, "
            "`unsigned_profile_id` INT UNSIGNED NOT NULL, "
            "`unsigned_big_profile_id` BIGINT(32) UNSIGNED NOT NULL)"
        ],
    )
def test_can_have_default_blank_string(self):
    """An empty-string default must still render as ``DEFAULT ''``."""
    with self.schema.create("users") as blueprint:
        blueprint.string("profile_id").default("")
    expected = [
        "CREATE TABLE `users` (`profile_id` VARCHAR(255) NOT NULL DEFAULT '')"
    ]
    self.assertEqual(blueprint.to_sql(), expected)
def test_can_have_float_type(self):
    """float() should compile to MySQL's FLOAT(19, 4)."""
    with self.schema.create("users") as blueprint:
        blueprint.float("amount")
    expected = ["CREATE TABLE `users` (`amount` FLOAT(19, 4) NOT NULL)"]
    self.assertEqual(blueprint.to_sql(), expected)
def test_has_table(self):
    """has_table() queries information_schema scoped to the configured
    database name (read from the MYSQL_DATABASE_DATABASE env var)."""
    schema_sql = self.schema.has_table("users")
    sql = f"SELECT * from information_schema.tables where table_name='users' AND table_schema = '{os.getenv('MYSQL_DATABASE_DATABASE')}'"
    self.assertEqual(schema_sql, sql)
def test_can_truncate(self):
    """truncate() compiles to a plain TRUNCATE statement."""
    self.assertEqual(self.schema.truncate("users"), "TRUNCATE `users`")
def test_can_rename_table(self):
    """rename() compiles to ALTER TABLE ... RENAME TO."""
    self.assertEqual(
        self.schema.rename("users", "clients"),
        "ALTER TABLE `users` RENAME TO `clients`",
    )
def test_can_drop_table_if_exists(self):
    """drop_table_if_exists() compiles DROP TABLE IF EXISTS for one table."""
    # NOTE(review): the second argument "clients" looks like a copy-paste
    # leftover — only `users` appears in the compiled SQL. Confirm the
    # method signature before removing it.
    sql = self.schema.drop_table_if_exists("users", "clients")
    self.assertEqual(sql, "DROP TABLE IF EXISTS `users`")
def test_can_drop_table(self):
    """drop_table() compiles a plain DROP TABLE statement."""
    # NOTE(review): the second argument "clients" appears unused by the
    # compiled SQL — presumably a copy-paste leftover; verify the signature.
    sql = self.schema.drop_table("users", "clients")
    self.assertEqual(sql, "DROP TABLE `users`")
def test_has_column(self):
    """has_column() queries information_schema for the given column."""
    expected = (
        "SELECT column_name FROM information_schema.columns "
        "WHERE table_name='users' and column_name='name'"
    )
    self.assertEqual(self.schema.has_column("users", "name"), expected)
def test_can_enable_foreign_keys(self):
    """Enabling constraint checks compiles to FOREIGN_KEY_CHECKS=1."""
    self.assertEqual(
        self.schema.enable_foreign_key_constraints(), "SET FOREIGN_KEY_CHECKS=1"
    )
def test_can_disable_foreign_keys(self):
    """Disabling constraint checks compiles to FOREIGN_KEY_CHECKS=0."""
    self.assertEqual(
        self.schema.disable_foreign_key_constraints(), "SET FOREIGN_KEY_CHECKS=0"
    )
def test_can_truncate_without_foreign_keys(self):
    """truncate(foreign_keys=True) wraps TRUNCATE in statements that
    disable and re-enable foreign key checks."""
    sql = self.schema.truncate("users", foreign_keys=True)
    self.assertEqual(
        sql,
        [
            "SET FOREIGN_KEY_CHECKS=0",
            "TRUNCATE `users`",
            "SET FOREIGN_KEY_CHECKS=1",
        ],
    )
def test_can_add_enum(self):
    """enum() should render an ENUM column with its allowed values and
    honor a default chosen from that set."""
    with self.schema.create("users") as blueprint:
        blueprint.enum("status", ["active", "inactive"]).default("active")
    self.assertEqual(len(blueprint.table.added_columns), 1)
    self.assertEqual(
        blueprint.to_sql(),
        [
            "CREATE TABLE `users` (`status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active')"
        ],
    )
================================================
FILE: tests/mysql/schema/test_mysql_schema_builder_alter.py
================================================
import unittest
import os
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import MySQLConnection
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import MySQLPlatform
from src.masoniteorm.schema.Table import Table
class TestMySQLSchemaBuilderAlter(unittest.TestCase):
    """Dry-run tests for the ALTER TABLE SQL compiled by MySQLPlatform.

    Each test stages schema changes on a blueprint via ``schema.table()``
    and compares ``blueprint.to_sql()`` with the expected statement list.
    Nothing is executed against a real database (``dry=True``).

    Fixes in this revision:
    - ``test_can_add_column_comments`` built its expected SQL but never
      asserted it; the missing assertion is restored.
    - a leftover debug ``print`` was removed from ``test_can_create_indexes``.
    """

    # Show full diffs for the long SQL string comparisons.
    maxDiff = None

    def setUp(self):
        self.schema = Schema(
            connection_class=MySQLConnection,
            connection="mysql",
            connection_details=DATABASES,
            platform=MySQLPlatform,
            dry=True,
        ).on("mysql")

    def test_can_add_columns(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")

        self.assertEqual(len(blueprint.table.added_columns), 2)

        sql = [
            "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL, ADD `age` INT(11) NOT NULL"
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_add_column_comments(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name").comment("A users username")

        self.assertEqual(len(blueprint.table.added_columns), 1)

        sql = [
            "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL COMMENT 'A users username'"
        ]
        # Bug fix: `sql` was previously built but never compared, so this
        # test passed regardless of the compiled output.
        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_add_table_comment(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name")
            blueprint.table_comment("A users username")

        self.assertEqual(len(blueprint.table.added_columns), 1)

        # TODO(review): `sql` is never asserted and looks copy-pasted from
        # the column-comment test above (it contains no table COMMENT
        # clause). Left unasserted until the correct expected output for a
        # table-level comment plus an added column is confirmed.
        sql = [
            "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL COMMENT 'A users username'"
        ]

    def test_can_add_table_comment_with_no_columns(self):
        with self.schema.table("users") as blueprint:
            blueprint.table_comment("A users username")

        self.assertEqual(len(blueprint.table.added_columns), 0)

        sql = ["ALTER TABLE `users` COMMENT 'A users username'"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_add_column_after(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name").after("age")

        self.assertEqual(len(blueprint.table.added_columns), 1)

        sql = ["ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL AFTER `age`"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_rename(self):
        with self.schema.table("users") as blueprint:
            blueprint.rename("post", "comment", "integer")

        # Simulate the existing table so the platform can resolve the
        # renamed column's prior definition.
        table = Table("users")
        table.add_column("post", "integer")
        blueprint.table.from_table = table

        sql = ["ALTER TABLE `users` CHANGE `post` `comment` INT NOT NULL"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_add_and_rename(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name")
            blueprint.rename("post", "comment", "string")

        table = Table("users")
        table.add_column("post", "string")
        blueprint.table.from_table = table

        sql = [
            "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL",
            "ALTER TABLE `users` CHANGE `post` `comment` VARCHAR NOT NULL",
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_add_and_rename_to_string(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name")
            blueprint.rename("post", "comment", "string", length=200)

        table = Table("users")
        table.add_column("post", "integer")
        blueprint.table.from_table = table

        sql = [
            "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL",
            "ALTER TABLE `users` CHANGE `post` `comment` VARCHAR(200) NOT NULL",
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop1(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_column("post")

        sql = ["ALTER TABLE `users` DROP COLUMN `post`"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_add_column_and_foreign_key(self):
        with self.schema.table("users") as blueprint:
            blueprint.unsigned_integer("playlist_id").nullable()
            blueprint.foreign("playlist_id").references("id").on("playlists").on_delete(
                "cascade"
            )

        sql = [
            "ALTER TABLE `users` ADD `playlist_id` INT UNSIGNED NULL",
            "ALTER TABLE `users` ADD CONSTRAINT users_playlist_id_foreign FOREIGN KEY (playlist_id) REFERENCES playlists(id) ON DELETE CASCADE",
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_foreign_key(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_foreign("users_playlist_id_foreign")

        sql = ["ALTER TABLE `users` DROP FOREIGN KEY users_playlist_id_foreign"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_foreign_key_shortcut(self):
        # Passing a column list generates the conventional constraint name.
        with self.schema.table("users") as blueprint:
            blueprint.drop_foreign(["playlist_id"])

        sql = ["ALTER TABLE `users` DROP FOREIGN KEY users_playlist_id_foreign"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_unique_constraint(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_unique("users_playlist_id_unique")

        sql = ["ALTER TABLE `users` DROP INDEX users_playlist_id_unique"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_add_index(self):
        with self.schema.table("users") as blueprint:
            blueprint.index("playlist_id")

        sql = ["CREATE INDEX users_playlist_id_index ON `users`(playlist_id)"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_index(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_index("users_playlist_id_index")

        sql = ["ALTER TABLE `users` DROP INDEX users_playlist_id_index"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_add_primary(self):
        with self.schema.table("users") as blueprint:
            blueprint.primary("playlist_id")

        sql = [
            "ALTER TABLE `users` ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)"
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_index_shortcut(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_index(["playlist_id"])

        sql = ["ALTER TABLE `users` DROP INDEX users_playlist_id_index"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_unique_constraint_shortcut(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_unique(["playlist_id"])

        sql = ["ALTER TABLE `users` DROP INDEX users_playlist_id_unique"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_primary(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_primary("users_id_primary")

        sql = ["ALTER TABLE `users` DROP INDEX users_id_primary"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_change(self):
        # ADDs and MODIFYs are batched into separate ALTER statements.
        with self.schema.table("users") as blueprint:
            blueprint.integer("age").change()
            blueprint.string("external_type").default("external")
            blueprint.integer("gender").nullable().change()
            blueprint.string("name")

        self.assertEqual(len(blueprint.table.added_columns), 2)
        self.assertEqual(len(blueprint.table.changed_columns), 2)

        table = Table("users")
        table.add_column("age", "string")
        blueprint.table.from_table = table

        sql = [
            "ALTER TABLE `users` ADD `external_type` VARCHAR(255) NOT NULL DEFAULT 'external', ADD `name` VARCHAR(255) NOT NULL",
            "ALTER TABLE `users` MODIFY `age` INT(11) NOT NULL, MODIFY `gender` INT(11) NULL",
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_timestamp_alter_add_nullable_column(self):
        with self.schema.table("users") as blueprint:
            blueprint.timestamp("due_date").nullable()

        self.assertEqual(len(blueprint.table.added_columns), 1)

        table = Table("users")
        table.add_column("age", "string")
        blueprint.table.from_table = table

        sql = ["ALTER TABLE `users` ADD `due_date` TIMESTAMP NULL"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_drop_add_and_change(self):
        with self.schema.table("users") as blueprint:
            blueprint.integer("age").default(0).change()
            blueprint.string("name")
            blueprint.drop_column("email")

        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(len(blueprint.table.changed_columns), 1)

        table = Table("users")
        table.add_column("age", "string")
        table.add_column("email", "string")
        blueprint.table.from_table = table

        sql = [
            "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL",
            "ALTER TABLE `users` MODIFY `age` INT(11) NOT NULL DEFAULT 0",
            "ALTER TABLE `users` DROP COLUMN `email`",
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_create_indexes(self):
        with self.schema.table("users") as blueprint:
            blueprint.index("name")
            blueprint.index(["name", "email"])
            blueprint.unique("name")
            blueprint.unique("name", name="table_unique")
            blueprint.unique(["name", "email"])
            blueprint.fulltext("description")

        self.assertEqual(len(blueprint.table.added_columns), 0)

        self.assertEqual(
            blueprint.to_sql(),
            [
                "CREATE INDEX users_name_index ON `users`(name)",
                "CREATE INDEX users_name_email_index ON `users`(name,email)",
                "ALTER TABLE `users` ADD CONSTRAINT UNIQUE INDEX users_name_unique(name)",
                "ALTER TABLE `users` ADD CONSTRAINT UNIQUE INDEX table_unique(name)",
                "ALTER TABLE `users` ADD CONSTRAINT UNIQUE INDEX users_name_email_unique(name,email)",
                "ALTER TABLE `users` ADD FULLTEXT description_fulltext(description)",
            ],
        )

    def test_can_add_column_enum(self):
        with self.schema.table("users") as blueprint:
            blueprint.enum("status", ["active", "inactive"]).default("active")

        self.assertEqual(len(blueprint.table.added_columns), 1)

        sql = [
            "ALTER TABLE `users` ADD `status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active'"
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_change_column_enum(self):
        with self.schema.table("users") as blueprint:
            blueprint.enum("status", ["active", "inactive"]).default("active").change()

        self.assertEqual(len(blueprint.table.changed_columns), 1)

        sql = [
            "ALTER TABLE `users` MODIFY `status` ENUM('active', 'inactive') NOT NULL DEFAULT 'active'"
        ]
        self.assertEqual(blueprint.to_sql(), sql)
================================================
FILE: tests/mysql/scopes/test_can_use_global_scopes.py
================================================
import inspect
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.scopes import (
SoftDeleteScope,
SoftDeletesMixin,
TimeStampsMixin,
scope,
)
class UserSoft(Model, SoftDeletesMixin):
    """User model with the soft-delete global scope; __dry__ makes query
    methods compile SQL without executing it."""

    __dry__ = True
class User(Model):
    """Plain dry-mode user model used for the timestamp-scope assertion."""

    __dry__ = True
class TestMySQLGlobalScopes(unittest.TestCase):
    """Checks that global scopes are woven into the compiled SQL."""

    def test_can_use_global_scopes_on_select(self):
        # SoftDeletesMixin should append a `deleted_at IS NULL` filter to
        # every SELECT automatically.
        sql = "SELECT * FROM `user_softs` WHERE `user_softs`.`name` = 'joe' AND `user_softs`.`deleted_at` IS NULL"
        self.assertEqual(sql, UserSoft.where("name", "joe").to_sql())

    # def test_can_use_global_scopes_on_delete(self):
    #     sql = "UPDATE `users` SET `users`.`deleted_at` = 'now' WHERE `users`.`name` = 'joe'"
    #     self.assertEqual(
    #         sql,
    #         User.apply_scope(SoftDeletes)
    #         .where("name", "joe")
    #         .delete(query=True)
    #         .to_sql(),
    #     )

    def test_can_use_global_scopes_on_time(self):
        # The timestamp scope injects created_at/updated_at on insert; only
        # the prefix is asserted because the timestamp values vary per run.
        sql = "INSERT INTO `users` (`users`.`name`, `users`.`updated_at`, `users`.`created_at`) VALUES ('Joe'"
        self.assertTrue(User.create({"name": "Joe"}, query=True).to_sql().startswith(sql))

    # def test_can_use_global_scopes_on_inherit(self):
    #     sql = "SELECT * FROM `user_softs` WHERE `user_softs`.`deleted_at` IS NULL"
    #     self.assertEqual(sql, UserSoft.all(query=True))
================================================
FILE: tests/mysql/scopes/test_can_use_scopes.py
================================================
import inspect
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.scopes import SoftDeletesMixin, scope
from tests.User import User
class User(Model):
    """Dry user model exposing two chainable @scope query scopes.

    NOTE(review): this definition shadows the ``User`` imported from
    ``tests.User`` above — presumably intentional, but worth confirming.
    """

    __dry__ = True

    @scope
    def active(self, query, status):
        # Chained as User.active(status): adds an `active = status` filter.
        return query.where("active", status)

    @scope
    def gender(self, query, status):
        # Chained as User.gender(status): adds a `gender = status` filter.
        return query.where("gender", status)
class UserSoft(Model, SoftDeletesMixin):
    """Soft-deleting dry-mode user model (unused in the active tests below)."""

    __dry__ = True
class TestMySQLScopes(unittest.TestCase):
    """Verifies user-defined @scope methods compose into the compiled SQL."""

    def test_can_get_sql(self):
        sql = "SELECT * FROM `users` WHERE `users`.`name` = 'joe'"
        self.assertEqual(sql, User.where("name", "joe").to_sql())

    def test_active_scope(self):
        # Scope applied after an explicit where(): its condition comes last.
        sql = "SELECT * FROM `users` WHERE `users`.`name` = 'joe' AND `users`.`active` = '1'"
        self.assertEqual(sql, User.where("name", "joe").active(1).to_sql())

    def test_active_scope_with_params(self):
        # Scope invoked first: its condition precedes the where() clause.
        sql = "SELECT * FROM `users` WHERE `users`.`active` = '2' AND `users`.`name` = 'joe'"
        self.assertEqual(sql, User.active(2).where("name", "joe").to_sql())

    def test_can_chain_scopes(self):
        sql = "SELECT * FROM `users` WHERE `users`.`active` = '2' AND `users`.`gender` = 'W' AND `users`.`name` = 'joe'"
        self.assertEqual(sql, User.active(2).gender("W").where("name", "joe").to_sql())
================================================
FILE: tests/mysql/scopes/test_soft_delete.py
================================================
import unittest
import pendulum
from tests.integrations.config.database import DATABASES
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import MySQLGrammar
from src.masoniteorm.scopes import SoftDeleteScope
from tests.utils import MockConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.scopes import SoftDeletesMixin
class UserSoft(Model, SoftDeletesMixin):
    """Soft-deleting model explicitly bound to the `users` table (dry mode)."""

    __dry__ = True
    __table__ = "users"
class UserSoftArchived(Model, SoftDeletesMixin):
    """Soft-deleting model using a custom `archived_at` column instead of
    the default `deleted_at` (dry mode, bound to `users`)."""

    __dry__ = True
    __deleted_at__ = "archived_at"
    __table__ = "users"
class TestSoftDeleteScope(unittest.TestCase):
    """Dry-run tests for the soft-delete global scope with the MySQL grammar.

    Fix: removed a dead local in ``test_force_delete_with_wheres`` — a
    scoped builder was constructed there but the assertion exercises the
    ``UserSoft`` model instead.
    """

    def get_builder(self, table="users"):
        """Return a dry QueryBuilder wired to the mocked default connection."""
        connection = MockConnectionFactory().make("default")
        return QueryBuilder(
            grammar=MySQLGrammar,
            connection_class=connection,
            connection="mysql",
            table=table,
            connection_details=DATABASES,
            dry=True,
        )

    def test_with_trashed(self):
        # with_trashed() removes the `deleted_at IS NULL` filter entirely.
        sql = "SELECT * FROM `users`"
        builder = self.get_builder().set_global_scope(SoftDeleteScope())
        self.assertEqual(sql, builder.with_trashed().to_sql())

    def test_force_delete(self):
        # force_delete() issues a hard DELETE rather than a timestamp update.
        sql = "DELETE FROM `users`"
        builder = self.get_builder().set_global_scope(SoftDeleteScope())
        self.assertEqual(sql, builder.force_delete(query=True).to_sql())

    def test_restore(self):
        # restore() nulls the deleted_at column (stringified as 'None' here).
        sql = "UPDATE `users` SET `users`.`deleted_at` = 'None'"
        builder = self.get_builder().set_global_scope(SoftDeleteScope())
        self.assertEqual(sql, builder.restore().to_sql())

    def test_force_delete_with_wheres(self):
        sql = "DELETE FROM `users` WHERE `users`.`active` = '1'"
        self.assertEqual(
            sql, UserSoft.where("active", 1).force_delete(query=True).to_sql()
        )

    def test_that_trashed_users_are_not_returned_by_default(self):
        sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NULL"
        builder = self.get_builder().set_global_scope(SoftDeleteScope())
        self.assertEqual(sql, builder.to_sql())

    def test_only_trashed(self):
        sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL"
        builder = self.get_builder().set_global_scope(SoftDeleteScope())
        self.assertEqual(sql, builder.only_trashed().to_sql())

    def test_only_trashed_on_model(self):
        sql = "SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL"
        self.assertEqual(sql, UserSoft.only_trashed().to_sql())

    def test_can_change_column(self):
        # __deleted_at__ override should swap the column used by the scope.
        sql = "SELECT * FROM `users` WHERE `users`.`archived_at` IS NOT NULL"
        self.assertEqual(sql, UserSoftArchived.only_trashed().to_sql())

    def test_find_with_global_scope(self):
        find_sql = UserSoft.find("1", query=True).to_sql()
        raw_sql = """SELECT * FROM `users` WHERE `users`.`id` = '1' AND `users`.`deleted_at` IS NULL"""
        self.assertEqual(find_sql, raw_sql)

    def test_find_with_trashed_scope(self):
        find_sql = UserSoft.with_trashed().find("1", query=True).to_sql()
        raw_sql = """SELECT * FROM `users` WHERE `users`.`id` = '1'"""
        self.assertEqual(find_sql, raw_sql)

    def test_find_with_only_trashed_scope(self):
        find_sql = UserSoft.only_trashed().find("1", query=True).to_sql()
        raw_sql = """SELECT * FROM `users` WHERE `users`.`deleted_at` IS NOT NULL AND `users`.`id` = '1'"""
        self.assertEqual(find_sql, raw_sql)
================================================
FILE: tests/postgres/builder/test_postgres_query_builder.py
================================================
import inspect
import unittest
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import PostgresGrammar
from tests.utils import MockConnectionFactory
class MockConnection:
    """Minimal stand-in for a connection class used by QueryBuilder tests."""

    connection_details = {}

    def make_connection(self):
        # The mock acts as its own connection object.
        return self

    @classmethod
    def get_default_query_grammar(cls):
        # No default grammar: callers are expected to supply one explicitly.
        return
class ModelTest(Model):
    """Model fixture with timestamps disabled so compiled INSERT/UPDATE SQL
    stays deterministic across runs."""

    __timestamps__ = False
class BaseTestQueryBuilder:
def get_builder(self, table="users", dry=True):
connection = MockConnectionFactory().make("postgres")
return QueryBuilder(
self.grammar,
connection_class=connection,
connection="postgres",
table=table,
model=ModelTest(),
dry=dry,
)
def test_sum(self):
builder = self.get_builder()
builder.sum("age")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_like(self):
builder = self.get_builder()
builder.where("age", "like", "%name%")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_not_like(self):
builder = self.get_builder()
builder.where("age", "not like", "%name%")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_max(self):
builder = self.get_builder()
builder.max("age")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_min(self):
builder = self.get_builder()
builder.min("age")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_avg(self):
builder = self.get_builder()
builder.avg("age")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_all(self):
builder = self.get_builder()
builder.all()
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_get(self):
builder = self.get_builder()
builder.get()
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_first(self):
builder = self.get_builder().first(query=True)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_select(self):
builder = self.get_builder()
builder.select("name", "email")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_add_select_no_table(self):
builder = self.get_builder(table=None)
sql = (
builder.add_select(
"other_test", lambda q: q.max("updated_at").table("different_table")
)
.add_select(
"some_alias", lambda q: q.max("updated_at").table("another_table")
)
.to_sql()
)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_select_raw(self):
builder = self.get_builder()
builder.select_raw("count(email) as email_count")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_create(self):
builder = self.get_builder().without_global_scopes()
builder.create(
{"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True
)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_delete(self):
builder = self.get_builder()
builder.delete("name", "Joe", query=True)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where(self):
builder = self.get_builder()
builder.where("name", "Joe")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_exists(self):
builder = self.get_builder()
builder.where_exists("name")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_limit(self):
builder = self.get_builder()
builder.limit(5)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_offset(self):
builder = self.get_builder()
builder.offset(5)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_join(self):
builder = self.get_builder()
builder.join("profiles", "users.id", "=", "profiles.user_id")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_left_join(self):
builder = self.get_builder()
builder.left_join("profiles", "users.id", "=", "profiles.user_id")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_right_join(self):
builder = self.get_builder()
builder.right_join("profiles", "users.id", "=", "profiles.user_id")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_update(self):
builder = self.get_builder().update(
{"name": "Joe", "email": "joe@yopmail.com"}, dry=True
)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
# def test_increment(self):
# builder = self.get_builder()
# builder.increment("age", 1)
# sql = getattr(
# self, inspect.currentframe().f_code.co_name.replace("test_", "")
# )()
# self.assertEqual(builder.to_sql(), sql)
# def test_decrement(self):
# builder = self.get_builder()
# builder.decrement("age", 1)
# sql = getattr(
# self, inspect.currentframe().f_code.co_name.replace("test_", "")
# )()
# self.assertEqual(builder.to_sql(), sql)
def test_count(self):
builder = self.get_builder()
builder.count("id")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_order_by_asc(self):
builder = self.get_builder()
builder.order_by("email", "asc")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_order_by_desc(self):
builder = self.get_builder()
builder.order_by("email", "desc")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_column(self):
builder = self.get_builder()
builder.where_column("name", "username")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_not_in(self):
builder = self.get_builder()
builder.where_not_in("id", [1, 2, 3])
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_between(self):
builder = self.get_builder()
builder.between("id", 2, 5)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_not_between(self):
builder = self.get_builder()
builder.not_between("id", 2, 5)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_in(self):
builder = self.get_builder()
builder.where_in("id", [1, 2, 3])
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_null(self):
builder = self.get_builder()
builder.where_null("name")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_not_null(self):
builder = self.get_builder()
builder.where_not_null("name")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_having(self):
builder = self.get_builder(table="payments")
builder.select("user_id").avg("salary").group_by("user_id").having(
"salary", ">=", "1000"
)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_group_by(self):
builder = self.get_builder(table="payments")
builder.select("user_id").min("salary").group_by("user_id")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_builder_alone(self):
self.assertTrue(
QueryBuilder(
connection_class=MockConnection,
connection="postgres",
connection_details={
"default": "postgres",
"postgres": {
"driver": "postgres",
"host": "localhost",
"user": "postgres",
"password": "postgres",
"database": "orm",
"port": "5432",
"prefix": "",
"grammar": "postgres",
},
},
).table("users")
)
def test_where_lt(self):
builder = self.get_builder()
builder.where("age", "<", "20")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_lte(self):
builder = self.get_builder()
builder.where("age", "<=", "20")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_gt(self):
builder = self.get_builder()
builder.where("age", ">", "20")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_gte(self):
builder = self.get_builder()
builder.where("age", ">=", "20")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_ne(self):
builder = self.get_builder()
builder.where("age", "!=", "20")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_or_where(self):
builder = self.get_builder()
builder.where("age", "20").or_where("age", "<", 20)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_can_call_with_schema(self):
builder = self.get_builder()
sql = (
builder.table("information_schema.columns")
.select("table_name")
.where("table_name", "users")
.to_sql()
)
self.assertEqual(
sql,
"""SELECT "information_schema"."columns"."table_name" FROM "information_schema"."columns" WHERE "information_schema"."columns"."table_name" = 'users'""",
)
def test_truncate(self):
builder = self.get_builder(dry=True)
sql = builder.truncate()
sql_ref = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(sql, sql_ref)
def test_truncate_without_foreign_keys(self):
builder = self.get_builder(dry=True)
sql = builder.truncate(foreign_keys=True)
sql_ref = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(sql, sql_ref)
def test_shared_lock(self):
builder = self.get_builder(dry=True)
sql = builder.where("votes", ">=", 100).shared_lock().to_sql()
sql_ref = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(sql, sql_ref)
def test_update_lock(self):
builder = self.get_builder(dry=True)
sql = builder.where("votes", ">=", 100).lock_for_update().to_sql()
sql_ref = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(sql, sql_ref)
class PostgresQueryBuilderTest(BaseTestQueryBuilder, unittest.TestCase):
    """Postgres-grammar variant of the shared query builder test suite.

    The inherited test_* methods compile a query and compare it against a
    reference SQL string.  They locate that reference by stripping the
    "test_" prefix from their own name (via inspect.currentframe) and
    calling the method of that name on this class.  Each method below
    returns the expected Postgres SQL for one builder feature; its
    docstring shows the builder call under test.
    """

    grammar = PostgresGrammar

    def sum(self):
        """
        builder = self.get_builder()
        builder.sum('age')
        """
        return """SELECT SUM("users"."age") AS age FROM "users\""""

    def max(self):
        """
        builder = self.get_builder()
        builder.max('age')
        """
        return """SELECT MAX("users"."age") AS age FROM "users\""""

    def min(self):
        """
        builder = self.get_builder()
        builder.min('age')
        """
        return """SELECT MIN("users"."age") AS age FROM "users\""""

    def avg(self):
        """
        builder = self.get_builder()
        builder.avg('age')
        """
        return """SELECT AVG("users"."age") AS age FROM "users\""""

    def first(self):
        """
        builder = self.get_builder()
        builder.first()
        """
        return """SELECT * FROM "users" LIMIT 1"""

    def all(self):
        """
        builder = self.get_builder()
        builder.all()
        """
        return """SELECT * FROM "users\""""

    def get(self):
        """
        builder = self.get_builder()
        builder.get()
        """
        return """SELECT * FROM "users\""""

    def select(self):
        """
        builder = self.get_builder()
        builder.select('name', 'email')
        """
        return """SELECT "users"."name", "users"."email" FROM "users\""""

    def add_select_no_table(self):
        """Reference SQL for add_select() sub-queries when no base table is
        selected (see the matching test in BaseTestQueryBuilder)."""
        return (
            "SELECT "
            '(SELECT MAX("different_table"."updated_at") AS updated_at FROM "different_table") AS other_test, '
            '(SELECT MAX("another_table"."updated_at") AS updated_at FROM "another_table") AS some_alias'
        )

    def select_raw(self):
        """
        builder = self.get_builder()
        builder.select_raw('count(email) as email_count')
        """
        return """SELECT count(email) as email_count FROM "users\""""

    def create(self):
        """
        builder = self.get_builder()
        builder.create({"name": "Corentin All", 'email': 'corentin@yopmail.com'})
        """
        # Postgres inserts append RETURNING * so the created row comes back.
        return """INSERT INTO "users" ("name", "email") VALUES ('Corentin All', 'corentin@yopmail.com') RETURNING *"""

    def delete(self):
        """
        builder = self.get_builder()
        builder.delete('name', 'Joe')
        """
        return """DELETE FROM "users" WHERE "users"."name" = 'Joe'"""

    def where(self):
        """
        builder = self.get_builder()
        builder.where('name', 'Joe')
        """
        return """SELECT * FROM "users" WHERE "users"."name" = 'Joe'"""

    def where_exists(self):
        """
        builder = self.get_builder()
        builder.where_exists('name')
        """
        return """SELECT * FROM "users" WHERE EXISTS 'name'"""

    def limit(self):
        """
        builder = self.get_builder()
        builder.limit(5)
        """
        return """SELECT * FROM "users" LIMIT 5"""

    def offset(self):
        """
        builder = self.get_builder()
        builder.offset(5)
        """
        return """SELECT * FROM "users" OFFSET 5"""

    def join(self):
        """
        builder.join("profiles", "users.id", "=", "profiles.user_id")
        """
        return """SELECT * FROM "users" INNER JOIN "profiles" ON "users"."id" = "profiles"."user_id\""""

    def left_join(self):
        """
        builder.left_join("profiles", "users.id", "=", "profiles.user_id")
        """
        return """SELECT * FROM "users" LEFT JOIN "profiles" ON "users"."id" = "profiles"."user_id\""""

    def right_join(self):
        """
        builder.right_join("profiles", "users.id", "=", "profiles.user_id")
        """
        return """SELECT * FROM "users" RIGHT JOIN "profiles" ON "users"."id" = "profiles"."user_id\""""

    def update(self):
        """
        builder.update({"name": "Joe", "email": "joe@yopmail.com"})
        """
        return """UPDATE "users" SET "name" = 'Joe', "email" = 'joe@yopmail.com'"""

    def increment(self):
        """
        builder.increment('age', 1)
        """
        return """UPDATE "users" SET "age" = "age" + '1'"""

    def decrement(self):
        """
        builder.decrement('age', 1)
        """
        return """UPDATE "users" SET "age" = "age" - '1'"""

    def count(self):
        """
        builder.count('id')
        """
        return """SELECT COUNT("users"."id") AS id FROM "users\""""

    def order_by_asc(self):
        """
        builder.order_by('email', 'asc')
        """
        return """SELECT * FROM "users" ORDER BY "email" ASC"""

    def order_by_desc(self):
        """
        builder.order_by('email', 'desc')
        """
        return """SELECT * FROM "users" ORDER BY "email" DESC"""

    def where_column(self):
        """
        builder.where_column('name', 'username')
        """
        return """SELECT * FROM "users" WHERE "users"."name" = "users"."username\""""

    def where_null(self):
        """
        builder.where_null('name')
        """
        return """SELECT * FROM "users" WHERE "users"."name" IS NULL"""

    def where_not_null(self):
        """
        builder.where_not_null('name')
        """
        return """SELECT * FROM "users" WHERE "users"."name" IS NOT NULL"""

    def where_not_in(self):
        """
        builder.where_not_in('id', [1, 2, 3])
        """
        return """SELECT * FROM "users" WHERE "users"."id" NOT IN ('1','2','3')"""

    def where_in(self):
        """
        builder.where_in('id', [1, 2, 3])
        """
        return """SELECT * FROM "users" WHERE "users"."id" IN ('1','2','3')"""

    def between(self):
        """
        builder.between('id', 2, 5)
        """
        return """SELECT * FROM "users" WHERE "users"."id" BETWEEN '2' AND '5'"""

    def not_between(self):
        """
        builder.not_between('id', 2, 5)
        """
        return """SELECT * FROM "users" WHERE "users"."id" NOT BETWEEN '2' AND '5'"""

    def having(self):
        """
        builder.select('user_id').avg('salary').group_by('user_id').having('salary', '>=', '1000')
        """
        return """SELECT "payments"."user_id", AVG("payments"."salary") AS salary FROM "payments" GROUP BY "payments"."user_id" HAVING "payments"."salary" >= '1000'"""

    def group_by(self):
        """
        builder.select('user_id').min('salary').group_by('user_id')
        """
        return """SELECT "payments"."user_id", MIN("payments"."salary") AS salary FROM "payments" GROUP BY "payments"."user_id\""""

    def where_lt(self):
        """
        builder = self.get_builder()
        builder.where('age', '<', '20')
        """
        return """SELECT * FROM "users" WHERE "users"."age" < '20'"""

    def where_lte(self):
        """
        builder = self.get_builder()
        builder.where('age', '<=', '20')
        """
        return """SELECT * FROM "users" WHERE "users"."age" <= '20'"""

    def where_gt(self):
        """
        builder = self.get_builder()
        builder.where('age', '>', '20')
        """
        return """SELECT * FROM "users" WHERE "users"."age" > '20'"""

    def where_gte(self):
        """
        builder = self.get_builder()
        builder.where('age', '>=', '20')
        """
        return """SELECT * FROM "users" WHERE "users"."age" >= '20'"""

    def where_ne(self):
        """
        builder = self.get_builder()
        builder.where('age', '!=', '20')
        """
        return """SELECT * FROM "users" WHERE "users"."age" != '20'"""

    def or_where(self):
        """
        builder = self.get_builder()
        builder.where('age', '20').or_where('age','<', 20)
        """
        return """SELECT * FROM "users" WHERE "users"."age" = '20' OR "users"."age" < '20'"""

    def where_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "like", "%name%")
        """
        # Postgres compiles the builder's "like" to case-insensitive ILIKE.
        return """SELECT * FROM "users" WHERE "users"."age" ILIKE '%name%'"""

    def where_not_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "not like", "%name%")
        """
        return """SELECT * FROM "users" WHERE "users"."age" NOT ILIKE '%name%'"""

    def truncate(self):
        """
        builder = self.get_builder()
        builder.truncate()
        """
        return """TRUNCATE TABLE "users\""""

    def truncate_without_foreign_keys(self):
        """
        builder = self.get_builder()
        builder.truncate()  # with foreign-key checks disabled
        """
        return """TRUNCATE TABLE "users\""""

    def update_lock(self):
        """
        builder = self.get_builder()
        builder.where("votes", ">=", 100).lock_for_update()
        """
        return """SELECT * FROM "users" WHERE "users"."votes" >= '100' FOR UPDATE"""

    def shared_lock(self):
        """
        builder = self.get_builder()
        builder.where("votes", ">=", 100).shared_lock()
        """
        return """SELECT * FROM "users" WHERE "users"."votes" >= '100' FOR SHARE"""

    def test_latest(self):
        builder = self.get_builder()
        builder.latest("email")
        # Reference SQL comes from latest() below.
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(builder.to_sql(), sql)

    def test_oldest(self):
        builder = self.get_builder()
        builder.oldest("email")
        # Reference SQL comes from oldest() below.
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(builder.to_sql(), sql)

    def oldest(self):
        """
        builder.oldest('email')
        """
        return """SELECT * FROM "users" ORDER BY "email" ASC"""

    def latest(self):
        """
        builder.latest('email')
        """
        return """SELECT * FROM "users" ORDER BY "email" DESC"""
================================================
FILE: tests/postgres/builder/test_postgres_transaction.py
================================================
import inspect
import os
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import PostgresGrammar
from src.masoniteorm.relationships import belongs_to
from tests.utils import MockConnectionFactory
if os.getenv("RUN_POSTGRES_DATABASE") == "True":
    # These tests run against a real Postgres server, so they only execute
    # when the environment explicitly opts in.

    class User(Model):
        # Use the "postgres" connection; timestamps disabled so the test
        # inserts don't require created_at/updated_at columns.
        __connection__ = "postgres"
        __timestamps__ = False

    class BaseTestQueryRelationships(unittest.TestCase):
        """Transaction behavior (begin/rollback) against a live Postgres."""

        maxDiff = None

        def get_builder(self, table="users"):
            """Return a QueryBuilder wired to the Postgres test connection."""
            connection = ConnectionFactory().make("postgres")
            return QueryBuilder(
                grammar=PostgresGrammar,
                connection=connection,
                table=table,
                connection_details=DATABASES,
            ).on("postgres")

        def test_transaction(self):
            builder = self.get_builder()
            builder.begin()
            builder.create({"name": "phillip2", "email": "phillip2"})
            # builder.commit()
            # The uncommitted row is visible inside the open transaction...
            user = builder.where("name", "phillip2").first()
            self.assertEqual(user["name"], "phillip2")
            builder.rollback()
            # ...and gone after rollback.
            user = builder.where("name", "phillip2").first()
            self.assertEqual(user, None)
================================================
FILE: tests/postgres/grammar/test_delete_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import PostgresGrammar
class BaseDeleteGrammarTest:
    """Delete-statement compilation checks shared across grammars.

    The expected SQL for each test is supplied by the concrete subclass
    via a method named after the test without its "test_" prefix.
    """

    def setUp(self):
        self.builder = QueryBuilder(PostgresGrammar, table="users")

    def test_can_compile_delete(self):
        compiled = self.builder.delete("id", 1, query=True).to_sql()
        self.assertEqual(compiled, self.can_compile_delete())

    def test_can_compile_delete_in(self):
        compiled = self.builder.delete("id", [1, 2, 3], query=True).to_sql()
        self.assertEqual(compiled, self.can_compile_delete_in())

    def test_can_compile_delete_with_where(self):
        compiled = (
            self.builder.where("age", 20)
            .where("profile", 1)
            .set_action("delete")
            .delete(query=True)
            .to_sql()
        )
        self.assertEqual(compiled, self.can_compile_delete_with_where())
class TestPostgresDeleteGrammar(BaseDeleteGrammarTest, unittest.TestCase):
    """Expected Postgres SQL for the shared delete-grammar tests above."""

    grammar = "postgres"

    def can_compile_delete(self):
        """
        (
            self.builder
            .delete('id', 1)
            .to_sql()
        )
        """
        return """DELETE FROM "users" WHERE "users"."id" = '1'"""

    def can_compile_delete_in(self):
        """
        (
            self.builder
            .delete('id', [1, 2, 3])
            .to_sql()
        )
        """
        return """DELETE FROM "users" WHERE "users"."id" IN ('1','2','3')"""

    def can_compile_delete_with_where(self):
        """
        (
            self.builder
            .where('age', 20)
            .where('profile', 1)
            .set_action('delete')
            .delete()
            .to_sql()
        )
        """
        return """DELETE FROM "users" WHERE "users"."age" = '20' AND "users"."profile" = '1'"""
================================================
FILE: tests/postgres/grammar/test_insert_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import PostgresGrammar
class BaseInsertGrammarTest:
    """Insert-statement compilation checks.

    The expected SQL for each test is looked up on the subclass via a
    method named after the test without its "test_" prefix.
    """

    def setUp(self):
        self.builder = QueryBuilder(PostgresGrammar, table="users")

    def test_can_compile_insert(self):
        to_sql = self.builder.create({"name": "Joe"}, query=True).to_sql()
        # Reference SQL: test name minus the "test_" prefix.
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(to_sql, sql)

    def test_can_compile_insert_with_keywords(self):
        to_sql = self.builder.create(name="Joe", query=True).to_sql()
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(to_sql, sql)

    def test_can_compile_bulk_create(self):
        to_sql = self.builder.bulk_create(
            # These keys are intentionally out of order to show column to value alignment works
            [
                {"name": "Joe", "age": 5},
                {"age": 35, "name": "Bill"},
                {"name": "John", "age": 10},
            ],
            query=True,
        ).to_sql()
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(to_sql, sql)

    def test_can_compile_bulk_create_qmark(self):
        # to_qmark() compiles with '?' placeholders instead of inline values.
        to_sql = self.builder.bulk_create(
            [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}], query=True
        ).to_qmark()
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(to_sql, sql)
class TestPostgresUpdateGrammar(BaseInsertGrammarTest, unittest.TestCase):
    """Expected Postgres SQL for the shared INSERT grammar tests above.

    NOTE(review): despite "Update" in the name, this class backs the
    insert-grammar test suite — presumably a copy/paste leftover; renaming
    would be cosmetic, so it is left as-is.
    """

    grammar = "postgres"

    def can_compile_insert(self):
        """
        self.builder.create({
            'name': 'Joe'
        }).to_sql()
        """
        return """INSERT INTO "users" ("name") VALUES ('Joe') RETURNING *"""

    def can_compile_insert_with_keywords(self):
        """
        self.builder.create(name="Joe").to_sql()
        """
        return """INSERT INTO "users" ("name") VALUES ('Joe') RETURNING *"""

    def can_compile_bulk_create(self):
        """
        self.builder.bulk_create([...], query=True).to_sql()
        """
        return """INSERT INTO "users" ("age", "name") VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John') RETURNING *"""

    def can_compile_bulk_create_qmark(self):
        """
        self.builder.bulk_create([...], query=True).to_qmark()
        """
        return """INSERT INTO "users" ("name") VALUES ('?'), ('?'), ('?') RETURNING *"""
================================================
FILE: tests/postgres/grammar/test_select_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query.grammars import PostgresGrammar
from src.masoniteorm.testing import BaseTestCaseSelectGrammar
class TestPostgresGrammar(BaseTestCaseSelectGrammar, unittest.TestCase):
    """Postgres expected SQL for the shared select-grammar test suite.

    Tests inherited from BaseTestCaseSelectGrammar compile a query and
    look up the reference string on this class via the test's own name
    with the "test_" prefix stripped.  Each reference method's docstring
    shows the builder call that should produce the returned SQL.
    """

    grammar = PostgresGrammar

    def can_compile_select(self):
        """
        self.builder.to_sql()
        """
        return """SELECT * FROM "users\""""

    def can_compile_with_columns(self):
        """
        self.builder.select('username', 'password').to_sql()
        """
        return """SELECT "users"."username", "users"."password" FROM "users\""""

    def can_compile_with_where(self):
        """
        self.builder.select('username', 'password').where('id', 1).to_sql()
        """
        return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1'"""

    def can_compile_with_several_where(self):
        """
        self.builder.select('username', 'password').where('id', 1).where('username', 'joe').to_sql()
        """
        return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1' AND "users"."username" = 'joe'"""

    def can_compile_with_several_where_and_limit(self):
        """
        self.builder.select('username', 'password').where('id', 1).where('username', 'joe').limit(10).to_sql()
        """
        return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1' AND "users"."username" = 'joe' LIMIT 10"""

    def can_compile_with_sum(self):
        """
        self.builder.sum('age').to_sql()
        """
        return """SELECT SUM("users"."age") AS age FROM "users\""""

    def can_compile_with_max(self):
        """
        self.builder.max('age').to_sql()
        """
        return """SELECT MAX("users"."age") AS age FROM "users\""""

    def can_compile_with_max_and_columns(self):
        """
        self.builder.select('username').max('age').to_sql()
        """
        return """SELECT "users"."username", MAX("users"."age") AS age FROM "users\""""

    def can_compile_with_max_and_columns_different_order(self):
        """
        self.builder.max('age').select('username').to_sql()
        """
        # Column order in the output is fixed regardless of call order.
        return """SELECT "users"."username", MAX("users"."age") AS age FROM "users\""""

    def can_compile_with_order_by(self):
        """
        self.builder.select('username').order_by('age', 'desc').to_sql()
        """
        return """SELECT "users"."username" FROM "users" ORDER BY "age" DESC"""

    def can_compile_with_multiple_order_by(self):
        """
        self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql()
        """
        return (
            """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC"""
        )

    def can_compile_with_group_by(self):
        """
        self.builder.select('username').group_by('age').to_sql()
        """
        return """SELECT "users"."username" FROM "users" GROUP BY "users"."age\""""

    def can_compile_where_in(self):
        """
        self.builder.select('username').where_in('age', [1,2,3]).to_sql()
        """
        return """SELECT "users"."username" FROM "users" WHERE "users"."age" IN ('1','2','3')"""

    def can_compile_where_in_empty(self):
        """
        self.builder.where_in('age', []).to_sql()
        """
        # An empty IN list compiles to an always-false predicate.
        return """SELECT * FROM "users" WHERE 0 = 1"""

    def can_compile_where_not_in(self):
        """
        self.builder.select('username').where_not_in('age', [1,2,3]).to_sql()
        """
        return """SELECT "users"."username" FROM "users" WHERE "users"."age" NOT IN ('1','2','3')"""

    def can_compile_where_null(self):
        """
        self.builder.select('username').where_null('age').to_sql()
        """
        return """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NULL"""

    def can_compile_where_not_null(self):
        """
        self.builder.select('username').where_not_null('age').to_sql()
        """
        return (
            """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL"""
        )

    def can_compile_where_raw(self):
        """
        self.builder.where_raw("\\"age\\" = '18'").to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" = '18'"""

    def can_compile_having_raw(self):
        """
        self.builder.select_raw("COUNT(*) as counts").having_raw("counts > 18").to_sql()
        """
        return """SELECT COUNT(*) as counts FROM "users" HAVING counts > 18"""

    def can_compile_select_raw(self):
        """
        self.builder.select_raw("COUNT(*)").to_sql()
        """
        return """SELECT COUNT(*) FROM "users\""""

    def can_compile_limit_and_offset(self):
        """
        self.builder.limit(10).offset(10).to_sql()
        """
        return """SELECT * FROM "users" LIMIT 10 OFFSET 10"""

    def can_compile_select_raw_with_select(self):
        """
        self.builder.select('id').select_raw("COUNT(*)").to_sql()
        """
        return """SELECT "users"."id", COUNT(*) FROM "users\""""

    def can_compile_count(self):
        """
        self.builder.count().to_sql()
        """
        # Bare count() uses a reserved alias to avoid column-name clashes.
        return """SELECT COUNT(*) AS m_count_reserved FROM "users\""""

    def can_compile_count_column(self):
        """
        self.builder.count('money').to_sql()
        """
        return """SELECT COUNT("users"."money") AS money FROM "users\""""

    def can_compile_where_column(self):
        """
        self.builder.where_column('name', 'email').to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" = "users"."email\""""

    def can_compile_or_where(self):
        """
        self.builder.where('name', 2).or_where('name', 3).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" = '2' OR "users"."name" = '3'"""

    def can_grouped_where(self):
        """
        self.builder.where(lambda query: query.where('age', 2).where('name', 'Joe')).to_sql()
        """
        return """SELECT * FROM "users" WHERE ("users"."age" = '2' AND "users"."name" = 'Joe')"""

    def can_compile_sub_select(self):
        """
        self.builder.where_in('name',
            QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age')
        ).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" IN (SELECT "users"."age" FROM "users")"""

    def can_compile_sub_select_from_lambda(self):
        """
        self.builder.where_in('name',
            lambda query: query.select('age')
        ).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" IN (SELECT "users"."age" FROM "users")"""

    def can_compile_sub_select_where(self):
        """
        self.builder.where_in('age',
            QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age').where('age', 2).where('name', 'Joe')
        ).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" IN (SELECT "users"."age" FROM "users" WHERE "users"."age" = '2' AND "users"."name" = 'Joe')"""

    def can_compile_sub_select_value(self):
        """
        self.builder.where('name',
            self.builder.new().sum('age')
        ).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" = (SELECT SUM("users"."age") AS age FROM "users")"""

    def can_compile_complex_sub_select(self):
        """
        self.builder.where_in('name',
            (QueryBuilder(GrammarFactory.make(self.grammar), table='users')
                .select('age').where_in('email',
                    QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('email')
            ))
        ).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" IN (SELECT "users"."age" FROM "users" WHERE "users"."email" IN (SELECT "users"."email" FROM "users"))"""

    def can_compile_exists(self):
        """
        self.builder.select('age').where_exists(
            self.builder.new().select('username').where('age', 12)
        ).to_sql()
        """
        return """SELECT "users"."age" FROM "users" WHERE EXISTS (SELECT "users"."username" FROM "users" WHERE "users"."age" = '12')"""

    def can_compile_not_exists(self):
        """
        self.builder.select('age').where_not_exists(
            self.builder.new().select('username').where('age', 12)
        ).to_sql()
        """
        return """SELECT "users"."age" FROM "users" WHERE NOT EXISTS (SELECT "users"."username" FROM "users" WHERE "users"."age" = '12')"""

    def can_compile_having(self):
        """
        builder.sum('age').group_by('age').having('age').to_sql()
        """
        return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\""""

    def can_compile_having_order(self):
        """
        builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql()
        """
        # NOTE(review): the reference says ORDER (no BY) — this pins what the
        # grammar currently emits; verify against the grammar before changing.
        return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" ORDER "users"."age" DESC"""

    def can_compile_having_with_expression(self):
        """
        builder.sum('age').group_by('age').having('age', 10).to_sql()
        """
        return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" = '10'"""

    def can_compile_order_by_and_first(self):
        """
        self.builder.order_by('id', 'asc').first()
        """
        return """SELECT * FROM "users" ORDER BY "id" ASC LIMIT 1"""

    def can_compile_having_with_greater_than_expression(self):
        """
        builder.sum('age').group_by('age').having('age', '>', 10).to_sql()
        """
        return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" > '10'"""

    def can_compile_join(self):
        """
        builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "contacts" ON "users"."id" = "contacts"."user_id\""""

    def can_compile_left_join(self):
        """
        builder.left_join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
        """
        return """SELECT * FROM "users" LEFT JOIN "contacts" ON "users"."id" = "contacts"."user_id\""""

    def can_compile_multiple_join(self):
        """
        builder.join('contacts', 'users.id', '=', 'contacts.user_id')
               .join('posts', 'comments.post_id', '=', 'posts.id').to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "contacts" ON "users"."id" = "contacts"."user_id" INNER JOIN "posts" ON "comments"."post_id" = "posts"."id\""""

    def can_compile_between(self):
        """
        builder.between('age', 18, 21).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" BETWEEN '18' AND '21'"""

    def can_compile_not_between(self):
        """
        builder.not_between('age', 18, 21).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" NOT BETWEEN '18' AND '21'"""

    def test_can_compile_where_raw(self):
        to_sql = self.builder.where_raw(""" "age" = '18'""").to_sql()
        self.assertEqual(to_sql, """SELECT * FROM "users" WHERE "age" = '18'""")

    def test_can_compile_having_raw(self):
        to_sql = (
            self.builder.select_raw("COUNT(*) as counts")
            .having_raw("counts > 10")
            .to_sql()
        )
        self.assertEqual(
            to_sql, """SELECT COUNT(*) as counts FROM "users" HAVING counts > 10"""
        )

    def test_can_compile_having_raw_order(self):
        to_sql = (
            self.builder.select_raw("COUNT(*) as counts")
            .having_raw("counts > 10")
            .order_by_raw("counts DESC")
            .to_sql()
        )
        self.assertEqual(
            to_sql,
            """SELECT COUNT(*) as counts FROM "users" HAVING counts > 10 ORDER BY counts DESC""",
        )

    def test_can_compile_where_raw_and_where_with_multiple_bindings(self):
        # Raw bindings and regular where bindings must be collected in order.
        query = self.builder.where_raw(
            """ "age" = '?' AND "is_admin" = '?'""", [18, True]
        ).where("email", "test@example.com")
        self.assertEqual(
            query.to_qmark(),
            """SELECT * FROM "users" WHERE "age" = '?' AND "is_admin" = '?' AND "users"."email" = '?'""",
        )
        self.assertEqual(query._bindings, [18, True, "test@example.com"])

    def test_can_compile_select_raw(self):
        to_sql = self.builder.select_raw("COUNT(*)").to_sql()
        self.assertEqual(to_sql, """SELECT COUNT(*) FROM "users\"""")

    def test_can_compile_select_raw_with_select(self):
        to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql()
        self.assertEqual(to_sql, """SELECT "users"."id", COUNT(*) FROM "users\"""")

    def can_compile_first_or_fail(self):
        """
        builder = self.get_builder()
        builder.where("is_admin", "=", True).first_or_fail()
        """
        return """SELECT * FROM "users" WHERE "users"."is_admin" IS True LIMIT 1"""

    def where_not_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "not like", "%name%").to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" NOT ILIKE '%name%'"""

    def where_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "like", "%name%").to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" ILIKE '%name%'"""

    def where_regexp(self):
        """
        builder = self.get_builder()
        builder.where("age", "regexp", "Joe").to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" REGEXP 'Joe'"""

    def where_not_regexp(self):
        """
        builder = self.get_builder()
        builder.where("age", "not regexp", "Joe").to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" NOT REGEXP 'Joe'"""

    def can_compile_join_clause(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on("bgt.fund", "=", "rg.fund")
            .on("bgt.dept", "=", "rg.dept")
            .on("bgt.acct", "=", "rg.acct")
            .on("bgt.sub", "=", "rg.sub")
        )
        builder.join(clause).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" AND "bgt"."dept" = "rg"."dept" AND "bgt"."acct" = "rg"."acct" AND "bgt"."sub" = "rg"."sub\""""

    def can_compile_join_clause_with_value(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on_value("bgt.active", "=", "1")
            .or_on_value("bgt.acct", "=", "1234")
        )
        builder.join(clause).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."active" = '1' OR "bgt"."acct" = '1234'"""

    def can_compile_join_clause_with_null(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on_null("bgt.acct")
            .or_on_null("bgt.dept")
            .on_value("rg.abc", 10)
        )
        builder.join(clause).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NULL OR "dept" IS NULL AND "rg"."abc" = '10'"""

    def can_compile_join_clause_with_not_null(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on_not_null("bgt.acct")
            .or_on_not_null("bgt.dept")
            .on_value("rg.abc", 10)
        )
        builder.join(clause).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NOT NULL OR "dept" IS NOT NULL AND "rg"."abc" = '10'"""

    def can_compile_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.join(
            "report_groups as rg",
            lambda clause: (
                clause.on("bgt.fund", "=", "rg.fund")
                .on_null("bgt")
            ),
        ).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" AND "bgt" IS NULL"""

    def can_compile_left_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.left_join(
            "report_groups as rg",
            lambda clause: (
                clause.on("bgt.fund", "=", "rg.fund")
                .or_on_null("bgt")
            ),
        ).to_sql()
        """
        return """SELECT * FROM "users" LEFT JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" OR "bgt" IS NULL"""

    def can_compile_right_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.right_join(
            "report_groups as rg",
            lambda clause: (
                clause.on("bgt.fund", "=", "rg.fund")
                .or_on_null("bgt")
            ),
        ).to_sql()
        """
        return """SELECT * FROM "users" RIGHT JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" OR "bgt" IS NULL"""

    def shared_lock(self):
        """
        builder = self.get_builder()
        builder.where("votes", ">=", 100).shared_lock().to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."votes" >= '100' FOR SHARE"""

    def update_lock(self):
        """
        builder = self.get_builder()
        builder.where("votes", ">=", 100).lock_for_update().to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."votes" >= '100' FOR UPDATE"""

    def can_user_where_raw_and_where(self):
        """
        builder.where_raw("`age` = '18'").where("name", "=", "James").to_sql()
        """
        return """SELECT * FROM "users" WHERE age = '18' AND "users"."name" = 'James'"""

    def where_exists_with_lambda(self):
        """Reference SQL for where_exists() given a lambda sub-query."""
        return """SELECT * FROM "users" WHERE EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')"""

    def where_not_exists_with_lambda(self):
        """Reference SQL for where_not_exists() given a lambda sub-query."""
        return """SELECT * FROM "users" WHERE NOT EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')"""

    def where_date(self):
        """Reference SQL for where_date() on a datetime column."""
        return (
            """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'"""
        )

    def or_where_null(self):
        """Reference SQL for chained where_null().or_where_null()."""
        return """SELECT * FROM "users" WHERE "users"."column1" IS NULL OR "users"."column2" IS NULL"""

    def select_distinct(self):
        """Reference SQL for a distinct select on one column."""
        return """SELECT DISTINCT "users"."group" FROM "users\""""
================================================
FILE: tests/postgres/grammar/test_update_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.connections import PostgresConnection
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import PostgresGrammar
from src.masoniteorm.expressions import Raw
class BaseTestCaseUpdateGrammar:
    """Update-statement compilation checks.

    The expected SQL for each test is looked up on the subclass via a
    method named after the test without its "test_" prefix.
    """

    def setUp(self):
        self.builder = QueryBuilder(
            PostgresGrammar, connection_class=PostgresConnection, table="users"
        )

    def test_can_compile_update(self):
        to_sql = (
            self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql()
        )
        # Reference SQL: test name minus the "test_" prefix.
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(to_sql, sql)

    def test_can_compile_multiple_update(self):
        to_sql = self.builder.update(
            {"name": "Joe", "email": "user@email.com"}, dry=True
        ).to_sql()
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(to_sql, sql)

    def test_can_compile_update_with_multiple_where(self):
        to_sql = (
            self.builder.where("name", "bob")
            .where("age", 20)
            .update({"name": "Joe"}, dry=True)
            .to_sql()
        )
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(to_sql, sql)

    # NOTE(review): the increment/decrement tests are disabled; the reference
    # methods (can_compile_increment / can_compile_decrement) still exist on
    # the subclass and can be re-enabled with these.
    # def test_can_compile_increment(self):
    #     to_sql = self.builder.increment("age").to_sql()
    #     sql = getattr(
    #         self, inspect.currentframe().f_code.co_name.replace("test_", "")
    #     )()
    #     self.assertEqual(to_sql, sql)

    # def test_can_compile_decrement(self):
    #     to_sql = self.builder.decrement("age", 20).to_sql()
    #     sql = getattr(
    #         self, inspect.currentframe().f_code.co_name.replace("test_", "")
    #     )()
    #     self.assertEqual(to_sql, sql)

    def test_raw_expression(self):
        # Raw values must be emitted verbatim, not quoted as string bindings.
        to_sql = self.builder.update({"name": Raw('"username"')}, dry=True).to_sql()
        sql = getattr(
            self, inspect.currentframe().f_code.co_name.replace("test_", "")
        )()
        self.assertEqual(to_sql, sql)

    def test_update_null(self):
        """Pins current behavior: None is stringified to 'None' rather than
        compiled to SQL NULL."""
        # Removed a stray debug print(to_sql) that polluted test output.
        to_sql = self.builder.update({"name": None}, dry=True).to_sql()
        sql = """UPDATE "users" SET "name" = \'None\'"""
        self.assertEqual(to_sql, sql)
class TestPostgresUpdateGrammar(BaseTestCaseUpdateGrammar, unittest.TestCase):
    """Expected Postgres SQL for the shared update-grammar tests above."""

    grammar = "postgres"

    def can_compile_update(self):
        """
        builder.where('name', 'bob').update({
            'name': 'Joe'
        }).to_sql()
        """
        return """UPDATE "users" SET "name" = 'Joe' WHERE "name" = 'bob'"""

    def raw_expression(self):
        """
        builder.update({'name': Raw('"username"')}, dry=True).to_sql()
        """
        return """UPDATE "users" SET "name" = "username\""""

    def can_compile_multiple_update(self):
        """
        self.builder.update({"name": "Joe", "email": "user@email.com"}, dry=True).to_sql()
        """
        return """UPDATE "users" SET "name" = 'Joe', "email" = 'user@email.com'"""

    def can_compile_update_with_multiple_where(self):
        """
        builder.where('name', 'bob').where('age', 20).update({
            'name': 'Joe'
        }).to_sql()
        """
        return """UPDATE "users" SET "name" = 'Joe' WHERE "name" = 'bob' AND "age" = '20'"""

    def can_compile_increment(self):
        """
        builder.increment('age').to_sql()
        """
        return """UPDATE "users" SET "age" = "age" + '1'"""

    def can_compile_decrement(self):
        """
        builder.decrement('age', 20).to_sql()
        """
        return """UPDATE "users" SET "age" = "age" - '20'"""
================================================
FILE: tests/postgres/relationships/test_postgres_relationships.py
================================================
import os
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import belongs_to, has_many
if os.getenv("RUN_POSTGRES_DATABASE", False) == "True":
    # Live-database relationship tests; opt-in via environment variable.

    class Profile(Model):
        __table__ = "profiles"
        __connection__ = "postgres"

    class Articles(Model):
        __table__ = "articles"
        __connection__ = "postgres"

        @belongs_to("id", "article_id")
        def logo(self):
            # articles.id -> logos.article_id
            return Logo

    class Logo(Model):
        __table__ = "logos"
        __connection__ = "postgres"

    class User(Model):
        __connection__ = "postgres"
        _eager_loads = ()
        __casts__ = {"is_admin": "bool"}

        @belongs_to("id", "user_id")
        def profile(self):
            # users.id -> profiles.user_id
            return Profile

        @has_many("id", "user_id")
        def articles(self):
            # users.id -> articles.user_id
            return Articles

        def get_is_admin(self):
            # Presumably a get_* accessor intercepting is_admin reads —
            # confirm against the model passthrough/accessor implementation.
            return "You are an admin"

    class TestRelationships(unittest.TestCase):
        """Relationship querying, eager loading and EXISTS sub-queries
        against a live Postgres database."""

        maxDiff = None

        def test_relationship_can_be_callable(self):
            # Calling the relationship on the class returns a builder.
            self.assertEqual(
                User.profile().where("name", "Joe").to_sql(),
                """SELECT * FROM "profiles" WHERE "profiles"."name" = 'Joe'""",
            )

        def test_can_access_relationship(self):
            for user in User.where("id", 1).get():
                self.assertIsInstance(user.profile, Profile)

        def test_can_access_has_many_relationship(self):
            user = User.hydrate(User.where("id", 1).first())
            self.assertEqual(len(user.articles), 4)

        def test_can_access_relationship_multiple_times(self):
            # Second access should hit the cached relation, not requery.
            user = User.hydrate(User.where("id", 1).first())
            self.assertEqual(len(user.articles), 4)
            self.assertEqual(len(user.articles), 4)

        def test_loading(self):
            users = User.with_("articles").get()
            for user in users:
                user

        def test_casting(self):
            users = User.with_("articles").where("is_admin", True).get()
            for user in users:
                user

        def test_setting(self):
            users = User.with_("articles").where("is_admin", True).get()
            for user in users:
                user.name = "Joe"
                user.is_admin = 1
                user.save()

        def test_relationship_has(self):
            to_sql = User.has("articles").to_sql()
            self.assertEqual(
                to_sql,
                """SELECT * FROM "users" WHERE EXISTS ("""
                """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\""""
                """)""",
            )

        def test_relationship_has_off_builder(self):
            to_sql = User.where("active", 1).has("articles").to_sql()
            self.assertEqual(
                to_sql,
                """SELECT * FROM "users" WHERE "users"."active" = '1' AND EXISTS ("""
                """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\""""
                """)""",
            )

        def test_relationship_multiple_has(self):
            to_sql = User.has("articles", "profile").to_sql()
            self.assertEqual(
                to_sql,
                """SELECT * FROM "users" WHERE EXISTS ("""
                """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\""""
                """) AND EXISTS ("""
                """SELECT * FROM "profiles" WHERE "profiles"."user_id" = "users"."id\""""
                """)""",
            )
            count = User.has("articles", "profile").get().count()
            self.assertEqual(count, 2)

        def test_nested_has(self):
            # Dotted path nests the EXISTS sub-queries.
            to_sql = User.has("articles.logo").to_sql()
            self.assertEqual(
                to_sql,
                """SELECT * FROM "users" WHERE EXISTS (SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id" AND EXISTS (SELECT * FROM "logos" WHERE "logos"."article_id" = "articles"."id"))""",
            )
            count = User.has("articles.logo").get().count()
            self.assertEqual(count, 2)

        def test_relationship_where_has(self):
            to_sql = User.where_has("articles", lambda q: q.where("status", 1)).to_sql()
            self.assertEqual(
                to_sql,
                """SELECT * FROM "users" WHERE EXISTS ("""
                """SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id" AND "articles"."status" = '1'"""
                """)""",
            )
================================================
FILE: tests/postgres/schema/test_postgres_schema_builder.py
================================================
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import PostgresConnection
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import PostgresPlatform
class TestPostgresSchemaBuilder(unittest.TestCase):
    """Dry-run assertions on the exact SQL PostgresPlatform generates for
    CREATE TABLE statements and other non-ALTER schema operations.
    """

    maxDiff = None

    def setUp(self):
        # dry=True: only build SQL strings; nothing is executed.
        self.schema = Schema(
            connection_class=PostgresConnection,
            connection="postgres",
            connection_details=DATABASES,
            platform=PostgresPlatform,
            dry=True,
        )

    def test_can_add_columns(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")
        self.assertEqual(len(blueprint.table.added_columns), 2)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)'
            ],
        )

    def test_can_add_tiny_text(self):
        # Postgres has no TINYTEXT type; tiny_text maps to TEXT.
        with self.schema.create("users") as blueprint:
            blueprint.tiny_text("description")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)']
        )

    def test_can_add_unsigned_decimal(self):
        with self.schema.create("users") as blueprint:
            blueprint.unsigned_decimal("amount", 19, 4)
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            ['CREATE TABLE "users" ("amount" DECIMAL(19, 4) NOT NULL)'],
        )

    def test_can_create_table_if_not_exists(self):
        with self.schema.create_table_if_not_exists("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")
        self.assertEqual(len(blueprint.table.added_columns), 2)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE IF NOT EXISTS "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)'
            ],
        )

    def test_can_add_column_comment(self):
        # Column comments are emitted as a separate COMMENT ON statement.
        with self.schema.create("users") as blueprint:
            blueprint.string("name").comment("A users username")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL)',
                """COMMENT ON COLUMN "users"."name" is 'A users username'""",
            ],
        )

    def test_can_add_table_comment(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name")
            blueprint.table_comment("A users table")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL)',
                """COMMENT ON TABLE "users" is 'A users table'""",
            ],
        )

    def test_can_truncate(self):
        sql = self.schema.truncate("users")
        self.assertEqual(sql, 'TRUNCATE "users"')

    def test_can_rename_table(self):
        sql = self.schema.rename("users", "clients")
        self.assertEqual(sql, 'ALTER TABLE "users" RENAME TO "clients"')

    def test_can_drop_table_if_exists(self):
        # NOTE(review): the second "clients" argument does not appear in the
        # generated SQL — confirm the intended drop_table_if_exists signature.
        sql = self.schema.drop_table_if_exists("users", "clients")
        self.assertEqual(sql, 'DROP TABLE IF EXISTS "users"')

    def test_can_drop_table(self):
        # NOTE(review): second argument appears unused here as well.
        sql = self.schema.drop_table("users", "clients")
        self.assertEqual(sql, 'DROP TABLE "users"')

    def test_has_column(self):
        sql = self.schema.has_column("users", "name")
        self.assertEqual(
            sql,
            "SELECT column_name FROM information_schema.columns WHERE table_name='users' and column_name='name'",
        )

    def test_can_add_columns_with_constaint(self):
        # (sic) "constaint" typo kept — renaming would change the test id.
        with self.schema.create("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")
            blueprint.unique("name")
        self.assertEqual(len(blueprint.table.added_columns), 2)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL, CONSTRAINT users_name_unique UNIQUE (name))'
            ],
        )

    def test_can_add_columns_with_long_text(self):
        with self.schema.create("users") as blueprint:
            blueprint.long_text("description")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)']
        )

    def test_can_have_unsigned_columns(self):
        # NOTE(review): TINYINT/MEDIUMINT are not native Postgres types; the
        # expectations mirror the platform's current output as-is.
        with self.schema.create("users") as blueprint:
            blueprint.integer("profile_id").unsigned()
            blueprint.big_integer("big_profile_id").unsigned()
            blueprint.tiny_integer("tiny_profile_id").unsigned()
            blueprint.small_integer("small_profile_id").unsigned()
            blueprint.medium_integer("medium_profile_id").unsigned()
        self.assertEqual(
            blueprint.to_sql(),
            [
                """CREATE TABLE "users" ("""
                """"profile_id" INTEGER NOT NULL, """
                """"big_profile_id" BIGINT NOT NULL, """
                """"tiny_profile_id" TINYINT NOT NULL, """
                """"small_profile_id" SMALLINT NOT NULL, """
                """"medium_profile_id" MEDIUMINT NOT NULL)"""
            ],
        )

    def test_can_add_columns_with_foreign_key_constaint(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.foreign("profile_id").references("id").on("profiles")
        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL, '
                '"profile_id" INTEGER NOT NULL, CONSTRAINT users_name_unique UNIQUE (name), '
                'CONSTRAINT users_profile_id_foreign FOREIGN KEY ("profile_id") REFERENCES "profiles"("id"))'
            ],
        )

    def test_can_advanced_table_creation(self):
        with self.schema.create("users") as blueprint:
            blueprint.increments("id")
            blueprint.string("name")
            blueprint.string("email").unique()
            blueprint.string("password")
            blueprint.integer("admin").default(0)
            blueprint.string("remember_token").nullable()
            blueprint.timestamp("verified_at").nullable()
            # timestamps() adds the created_at/updated_at pair (2 columns).
            blueprint.timestamps()
        self.assertEqual(len(blueprint.table.added_columns), 9)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("id" SERIAL UNIQUE NOT NULL, "name" VARCHAR(255) NOT NULL, '
                '"email" VARCHAR(255) NOT NULL, "password" VARCHAR(255) NOT NULL, "admin" INTEGER NOT NULL DEFAULT 0, '
                '"remember_token" VARCHAR(255) NULL, "verified_at" TIMESTAMP NULL, '
                '"created_at" TIMESTAMPTZ NULL DEFAULT CURRENT_TIMESTAMP, "updated_at" TIMESTAMPTZ NULL DEFAULT CURRENT_TIMESTAMP, '
                "CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_email_unique UNIQUE (email))"
            ],
        )

    def test_can_advanced_table_creation2(self):
        # Covers the Postgres-specific types (JSONB, INET, CIDR, MACADDR)
        # plus enum emulation and an ON DELETE CASCADE foreign key.
        with self.schema.create("users") as blueprint:
            blueprint.big_increments("id")
            blueprint.string("name")
            blueprint.enum("gender", ["male", "female"])
            blueprint.string("duration")
            blueprint.decimal("money")
            blueprint.string("url")
            blueprint.string("option").default("ADMIN")
            blueprint.jsonb("payload")
            blueprint.inet("last_address").nullable()
            blueprint.cidr("route_origin").nullable()
            blueprint.macaddr("mac_address").nullable()
            blueprint.datetime("published_at")
            blueprint.string("thumbnail").nullable()
            blueprint.integer("premium")
            blueprint.double("amount").default(0.0)
            blueprint.integer("author_id").unsigned().nullable()
            blueprint.foreign("author_id").references("id").on("authors").on_delete(
                "CASCADE"
            )
            blueprint.text("description")
            blueprint.timestamps()
        self.assertEqual(len(blueprint.table.added_columns), 19)
        self.assertEqual(
            blueprint.to_sql(),
            (
                [
                    """CREATE TABLE "users" ("id" BIGSERIAL UNIQUE NOT NULL, "name" VARCHAR(255) NOT NULL, "gender" VARCHAR(255) CHECK(gender IN ('male', 'female')) NOT NULL, """
                    """"duration" VARCHAR(255) NOT NULL, "money" DECIMAL(17, 6) NOT NULL, "url" VARCHAR(255) NOT NULL, "option" VARCHAR(255) NOT NULL DEFAULT 'ADMIN', "payload" JSONB NOT NULL, "last_address" INET NULL, """
                    '"route_origin" CIDR NULL, "mac_address" MACADDR NULL, "published_at" TIMESTAMPTZ NOT NULL, "thumbnail" VARCHAR(255) NULL, "premium" INTEGER NOT NULL, "amount" DOUBLE PRECISION NOT NULL DEFAULT 0.0, '
                    '"author_id" INTEGER NULL, "description" TEXT NOT NULL, "created_at" TIMESTAMPTZ NULL DEFAULT CURRENT_TIMESTAMP, '
                    '"updated_at" TIMESTAMPTZ NULL DEFAULT CURRENT_TIMESTAMP, '
                    'CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_author_id_foreign FOREIGN KEY ("author_id") REFERENCES "authors"("id") ON DELETE CASCADE)'
                ]
            ),
        )

    def test_can_add_uuid_column(self):
        # might not be the right place for this test + other column types
        # are not tested => just for testing the PR now
        with self.schema.create("users") as table:
            table.uuid("id").default_raw("uuid_generate_v4()")
            table.primary("id")
            table.string("name")
            table.uuid("public_id").nullable()
            table.uuid("other_id").default_raw("uuid_generate_v4()")
        self.assertEqual(len(table.table.added_columns), 4)
        self.assertEqual(
            table.to_sql(),
            [
                'CREATE TABLE "users" ("id" UUID NOT NULL DEFAULT uuid_generate_v4(), "name" VARCHAR(255) NOT NULL, "public_id" UUID NULL, "other_id" UUID NOT NULL DEFAULT uuid_generate_v4(), CONSTRAINT users_id_primary PRIMARY KEY (id))'
            ],
        )

    def test_can_add_columns_with_foreign_key_constraint_name(self):
        # An explicit name= overrides the auto-generated constraint name.
        with self.schema.create("users") as blueprint:
            blueprint.integer("profile_id")
            blueprint.foreign("profile_id", name="profile_foreign").references("id").on(
                "profiles"
            )
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ('
                '"profile_id" INTEGER NOT NULL, '
                'CONSTRAINT profile_foreign FOREIGN KEY ("profile_id") REFERENCES "profiles"("id"))'
            ],
        )

    def test_can_have_composite_keys(self):
        # primary() with a list produces a composite PRIMARY KEY constraint.
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.primary(["name", "age"])
        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" '
                '("name" VARCHAR(255) NOT NULL, '
                '"age" INTEGER NOT NULL, '
                '"profile_id" INTEGER NOT NULL, '
                "CONSTRAINT users_name_unique UNIQUE (name), "
                "CONSTRAINT users_name_age_primary PRIMARY KEY (name, age))"
            ],
        )

    def test_can_have_column_primary_key(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").primary()
            blueprint.integer("age")
            blueprint.integer("profile_id")
        self.assertEqual(len(blueprint.table.added_columns), 3)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" '
                '("name" VARCHAR(255) NOT NULL, '
                '"age" INTEGER NOT NULL, '
                '"profile_id" INTEGER NOT NULL, '
                "CONSTRAINT users_name_primary PRIMARY KEY (name))"
            ],
        )

    def test_can_add_other_integer_types_column(self):
        with self.schema.create("integer_types") as table:
            table.tiny_integer("tiny")
            table.small_integer("small")
            table.medium_integer("medium")
            table.big_integer("big")
        self.assertEqual(len(table.table.added_columns), 4)
        self.assertEqual(
            table.to_sql(),
            [
                'CREATE TABLE "integer_types" ("tiny" TINYINT NOT NULL, "small" SMALLINT NOT NULL, "medium" MEDIUMINT NOT NULL, "big" BIGINT NOT NULL)'
            ],
        )

    def test_can_add_binary_column(self):
        # Binary data maps to BYTEA on Postgres.
        with self.schema.create("binary_storing") as table:
            table.binary("filecontent")
        self.assertEqual(len(table.table.added_columns), 1)
        self.assertEqual(
            table.to_sql(),
            ['CREATE TABLE "binary_storing" ("filecontent" BYTEA NOT NULL)'],
        )

    def test_can_have_float_type(self):
        with self.schema.create("users") as blueprint:
            blueprint.float("amount")
        self.assertEqual(
            blueprint.to_sql(),
            ["""CREATE TABLE "users" (""" """\"amount" FLOAT(19, 4) NOT NULL)"""],
        )

    def test_can_enable_foreign_keys(self):
        # Postgres has no session-level FK pragma; an empty string is expected.
        sql = self.schema.enable_foreign_key_constraints()
        self.assertEqual(sql, "")

    def test_can_disable_foreign_keys(self):
        sql = self.schema.disable_foreign_key_constraints()
        self.assertEqual(sql, "")

    def test_can_truncate_without_foreign_keys(self):
        # foreign_keys=True wraps the TRUNCATE in trigger disable/enable.
        sql = self.schema.truncate("users", foreign_keys=True)
        self.assertEqual(
            sql,
            [
                'ALTER TABLE "users" DISABLE TRIGGER ALL',
                'TRUNCATE "users"',
                'ALTER TABLE "users" ENABLE TRIGGER ALL',
            ],
        )

    def test_can_add_enum(self):
        # Enums are emulated with VARCHAR plus a CHECK constraint.
        with self.schema.create("users") as blueprint:
            blueprint.enum("status", ["active", "inactive"]).default("active")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL ' 'DEFAULT \'active\')'
            ],
        )
================================================
FILE: tests/postgres/schema/test_postgres_schema_builder_alter.py
================================================
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import PostgresConnection
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import PostgresPlatform
from src.masoniteorm.schema.Table import Table
class TestPostgresSchemaBuilderAlter(unittest.TestCase):
    """Dry-run assertions on the ALTER TABLE SQL PostgresPlatform emits."""

    maxDiff = None

    def setUp(self):
        # dry=True: only build SQL strings; nothing is executed.
        self.schema = Schema(
            connection_class=PostgresConnection,
            connection="postgres",
            connection_details=DATABASES,
            platform=PostgresPlatform,
            dry=True,
        ).on("postgres")

    def test_can_add_columns(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")
        self.assertEqual(len(blueprint.table.added_columns), 2)
        sql = [
            'ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL, ADD COLUMN "age" INTEGER NOT NULL'
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_add_column_comments(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name").comment("A users username")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        sql = [
            'ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL',
            """COMMENT ON COLUMN "users"."name" is 'A users username'""",
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_add_table_comment(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name")
            blueprint.table_comment("A users table")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        sql = [
            'ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL',
            """COMMENT ON TABLE "users" is 'A users table'""",
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_rename(self):
        with self.schema.table("users") as blueprint:
            blueprint.rename("post", "comment", "integer")
        # Renames are resolved against a snapshot of the existing table.
        table = Table("users")
        table.add_column("post", "integer")
        blueprint.table.from_table = table
        sql = ['ALTER TABLE "users" RENAME COLUMN "post" TO "comment"']
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_add_and_rename(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name")
            blueprint.rename("post", "comment", "integer")
        table = Table("users")
        table.add_column("post", "integer")
        blueprint.table.from_table = table
        sql = [
            'ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL',
            'ALTER TABLE "users" RENAME COLUMN "post" TO "comment"',
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_column("post")
        sql = ['ALTER TABLE "users" DROP COLUMN "post"']
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_add_column_and_foreign_key(self):
        with self.schema.table("users") as blueprint:
            blueprint.unsigned_integer("playlist_id").nullable()
            blueprint.foreign("playlist_id").references("id").on("playlists").on_delete(
                "cascade"
            )
        sql = [
            'ALTER TABLE "users" ADD COLUMN "playlist_id" INTEGER NULL',
            'ALTER TABLE "users" ADD CONSTRAINT users_playlist_id_foreign FOREIGN KEY ("playlist_id") REFERENCES "playlists"("id") ON DELETE CASCADE',
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_create_indexes(self):
        # NOTE(review): fulltext() produces no statement on Postgres — only
        # four statements are expected below; confirm that is intentional.
        with self.schema.table("users") as blueprint:
            blueprint.index("name")
            blueprint.index(["name", "email"])
            blueprint.unique("name")
            blueprint.unique(["name", "email"])
            blueprint.fulltext("description")
        self.assertEqual(len(blueprint.table.added_columns), 0)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE INDEX users_name_index ON "users"(name)',
                'CREATE INDEX users_name_email_index ON "users"(name,email)',
                'ALTER TABLE "users" ADD CONSTRAINT users_name_unique UNIQUE(name)',
                'ALTER TABLE "users" ADD CONSTRAINT users_name_email_unique UNIQUE(name,email)',
            ],
        )

    def test_alter_drop_foreign_key(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_foreign("users_playlist_id_foreign")
        sql = ['ALTER TABLE "users" DROP CONSTRAINT users_playlist_id_foreign']
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_foreign_key_shortcut(self):
        # Passing a list of column names derives the conventional constraint
        # name (users_<column>_foreign).
        with self.schema.table("users") as blueprint:
            blueprint.drop_foreign(["playlist_id"])
        sql = ['ALTER TABLE "users" DROP CONSTRAINT users_playlist_id_foreign']
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_unique_constraint(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_unique("users_playlist_id_unique")
        sql = ['ALTER TABLE "users" DROP CONSTRAINT users_playlist_id_unique']
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_add_index(self):
        with self.schema.table("users") as blueprint:
            blueprint.index("playlist_id")
        sql = ['CREATE INDEX users_playlist_id_index ON "users"(playlist_id)']
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_add_primary(self):
        with self.schema.table("users") as blueprint:
            blueprint.primary("playlist_id")
        sql = [
            'ALTER TABLE "users" ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)'
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_index(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_index("users_playlist_id_index")
        sql = ["DROP INDEX users_playlist_id_index"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_index_shortcut(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_index(["playlist_id"])
        sql = ["DROP INDEX users_playlist_id_index"]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_primary(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_primary("users_id_primary")
        sql = ['ALTER TABLE "users" DROP CONSTRAINT users_id_primary']
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_unique_constraint_shortcut(self):
        with self.schema.table("users") as blueprint:
            blueprint.drop_unique(["playlist_id"])
        sql = ['ALTER TABLE "users" DROP CONSTRAINT users_playlist_id_unique']
        self.assertEqual(blueprint.to_sql(), sql)

    def test_has_table(self):
        schema_sql = self.schema.has_table("users")
        sql = "SELECT * from information_schema.tables where table_name='users' AND table_schema = 'public'"
        self.assertEqual(schema_sql, sql)

    def test_drop_table(self):
        # NOTE(review): duplicates test_has_table verbatim; likely meant to
        # assert drop_table SQL instead — confirm.
        schema_sql = self.schema.has_table("users")
        sql = "SELECT * from information_schema.tables where table_name='users' AND table_schema = 'public'"
        self.assertEqual(schema_sql, sql)

    def test_change(self):
        with self.schema.table("users") as blueprint:
            blueprint.integer("age").change()
            blueprint.string("name")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(len(blueprint.table.changed_columns), 1)
        # Snapshot of the pre-alter table the change is diffed against.
        table = Table("users")
        table.add_column("age", "string")
        blueprint.table.from_table = table
        sql = [
            'ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL',
            'ALTER TABLE "users" ALTER COLUMN "age" TYPE INTEGER, ALTER COLUMN "age" SET NOT NULL',
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_change_string(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("name", 93).change()
        self.assertEqual(len(blueprint.table.changed_columns), 1)
        table = Table("users")
        table.add_column("age", "string")
        blueprint.table.from_table = table
        sql = [
            'ALTER TABLE "users" ALTER COLUMN "name" TYPE VARCHAR(93), ALTER COLUMN "name" SET NOT NULL'
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_drop_add_and_change(self):
        # Mixed alteration: adds, a drop, and a type/nullability change are
        # grouped into separate ALTER statements.
        with self.schema.table("users") as blueprint:
            blueprint.integer("age").default(0).nullable().change()
            blueprint.string("name")
            blueprint.string("external_type").default("external")
            blueprint.drop_column("email")
        self.assertEqual(len(blueprint.table.added_columns), 2)
        self.assertEqual(len(blueprint.table.changed_columns), 1)
        table = Table("users")
        table.add_column("age", "string")
        table.add_column("email", "string")
        blueprint.table.from_table = table
        sql = [
            """ALTER TABLE "users" ADD COLUMN "name" VARCHAR(255) NOT NULL, ADD COLUMN "external_type" VARCHAR(255) NOT NULL DEFAULT 'external'""",
            'ALTER TABLE "users" DROP COLUMN "email"',
            'ALTER TABLE "users" ALTER COLUMN "age" TYPE INTEGER, ALTER COLUMN "age" DROP NOT NULL, ALTER COLUMN "age" SET DEFAULT 0',
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_timestamp_alter_add_nullable_column(self):
        with self.schema.table("users") as blueprint:
            blueprint.timestamp("due_date").nullable()
        self.assertEqual(len(blueprint.table.added_columns), 1)
        table = Table("users")
        table.add_column("age", "string")
        blueprint.table.from_table = table
        sql = ['ALTER TABLE "users" ADD COLUMN "due_date" TIMESTAMP NULL']
        self.assertEqual(blueprint.to_sql(), sql)

    def test_alter_drop_on_table_schema_table(self):
        # NOTE(review): this Schema is built without dry=True, so these two
        # blocks appear to execute real statements against the configured
        # postgres connection — confirm that is intentional.
        schema = Schema(
            connection_class=PostgresConnection,
            connection="postgres",
            connection_details=DATABASES,
        ).on("postgres")
        with schema.table("table_schema") as blueprint:
            blueprint.drop_column("name")
        with schema.table("table_schema") as blueprint:
            blueprint.string("name")

    def test_can_add_column_enum(self):
        with self.schema.table("users") as blueprint:
            blueprint.enum("status", ["active", "inactive"]).default("active")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        sql = [
            'ALTER TABLE "users" ADD COLUMN "status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL DEFAULT \'active\'',
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_change_column_enum(self):
        # NOTE(review): the expected SET DEFAULT active is unquoted, unlike
        # the add-column case above — mirrors current platform output.
        with self.schema.table("users") as blueprint:
            blueprint.enum("status", ["active", "inactive"]).default("active").change()
        self.assertEqual(len(blueprint.table.changed_columns), 1)
        sql = [
            'ALTER TABLE "users" ALTER COLUMN "status" TYPE VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')), ALTER COLUMN "status" SET NOT NULL, ALTER COLUMN "status" SET DEFAULT active',
        ]
        self.assertEqual(blueprint.to_sql(), sql)
================================================
FILE: tests/scopes/test_default_global_scopes.py
================================================
"""Test the default global scopes available in ORM."""
import unittest
import uuid
import pendulum
from src.masoniteorm.models import Model
from src.masoniteorm.scopes import (
SoftDeletesMixin,
TimeStampsScope,
TimeStampsMixin,
UUIDPrimaryKeyScope,
UUIDPrimaryKeyMixin,
)
class MockBuilder:
    """Minimal query-builder stand-in for unit-testing scope classes.

    Exposes only what the scopes touch: the hydrated model instance plus
    the pending create/update attribute dictionaries.
    """

    def __init__(self, model):
        # The scopes expect a model *instance*, so instantiate the class.
        instance = model()
        self._model = instance
        self._creates, self._updates = {}, {}
class UserWithUUID(Model, UUIDPrimaryKeyMixin):
    # __dry__ keeps every query from reaching a real database.
    __dry__ = True
class UserWithTimeStamps(Model, TimeStampsMixin):
    # Uses the mixin's default timestamp column names (created_at/updated_at).
    __dry__ = True
class UserWithCustomTimeStamps(Model, TimeStampsMixin):
    __dry__ = True
    # Custom timestamp column names that TimeStampsScope should honor.
    date_updated_at = "updated_ts"
    date_created_at = "created_ts"
class UserSoft(Model, SoftDeletesMixin):
    # Soft-delete model: delete() is expected to issue an UPDATE
    # (exercised in TestSoftDeletesScope below).
    __dry__ = True
class TestUUIDPrimaryKeyScope(unittest.TestCase):
    """Tests UUID primary-key generation performed by the create scope."""

    def setUp(self):
        self.builder = MockBuilder(UserWithUUID)
        self.scope = UUIDPrimaryKeyScope()
        # reset User attributes before each test
        UserWithUUID.__primary_key__ = "id"
        flags = {
            "__uuid_version__",
            "__uuid_namespace__",
            "__uuid_name__",
            "__uuid_bytes__",
        }
        for flag in flags:
            if hasattr(UserWithUUID, flag):
                delattr(UserWithUUID, flag)

    def test_default_to_uuid4(self):
        self.scope.set_uuid_create(self.builder)
        # Round-trip through uuid.UUID validates the generated value.
        uuid_value = uuid.UUID(self.builder._creates["id"])
        self.assertEqual(4, uuid_value.version)

    def test_can_set_uuid_version(self):
        # required for uuid 3 and 5
        UserWithUUID.__uuid_namespace__ = uuid.NAMESPACE_DNS
        UserWithUUID.__uuid_name__ = "domain.com"
        for version in [1, 3, 4, 5]:
            UserWithUUID.__uuid_version__ = version
            self.scope.set_uuid_create(self.builder)
            uuid_value = uuid.UUID(self.builder._creates["id"])
            self.assertEqual(version, uuid_value.version)
            # Clear the generated key so the next iteration starts clean.
            del self.builder._creates["id"]

    def test_default_to_uuid_str(self):
        # Generates UUIDs as strings by default
        self.scope.set_uuid_create(self.builder)
        self.assertIsInstance(self.builder._creates["id"], str)

    def test_can_set_uuid_bytes(self):
        # Generates UUIDs in bytes instead of strings
        UserWithUUID.__uuid_bytes__ = True
        self.scope.set_uuid_create(self.builder)
        self.assertIsInstance(self.builder._creates["id"], bytes)

    def test_works_with_custom_pk_column(self):
        UserWithUUID.__primary_key__ = "ref"
        self.scope.set_uuid_create(self.builder)
        self.assertIn("ref", self.builder._creates)
class TestSoftDeletesScope(unittest.TestCase):
    def test_soft_deletes_changes_delete_to_update(self):
        """delete() on a soft-delete model issues an UPDATE, not a DELETE."""
        UserSoft.__timestamps__ = False
        user = UserSoft.hydrate({"id": 1})
        # query=True returns the builder so the SQL can be inspected.
        sql = user.delete(query=True).to_sql()
        self.assertTrue(sql.startswith("UPDATE"))
class TestTimeStampsScope(unittest.TestCase):
    """Tests automatic created_at/updated_at handling on create and update."""

    def setUp(self):
        self.builder = MockBuilder(UserWithTimeStamps)
        self.scope = TimeStampsScope()
        # Remove any class-level override left behind by a previous test so
        # each test starts from the mixin default (timestamps enabled).
        # Narrowed from a bare `except:` (E722): only AttributeError — the
        # attribute simply was never set — should be swallowed here.
        try:
            del UserWithTimeStamps.__timestamps__
        except AttributeError:
            pass

    def test_updated_and_created_dates_are_set_when_create(self):
        self.scope.set_timestamp_create(self.builder)
        self.assertIn("created_at", self.builder._creates)
        self.assertIn("updated_at", self.builder._creates)
        # The stored values must be parseable timestamps.
        created_at = pendulum.parse(self.builder._creates["created_at"])
        updated_at = pendulum.parse(self.builder._creates["updated_at"])
        self.assertIsInstance(created_at, pendulum.DateTime)
        self.assertIsInstance(updated_at, pendulum.DateTime)

    def test_timestamps_can_be_disabled(self):
        UserWithTimeStamps.__timestamps__ = False
        self.scope.set_timestamp_create(self.builder)
        self.assertNotIn("created_at", self.builder._creates)
        self.assertNotIn("updated_at", self.builder._creates)

    def test_uses_custom_timestamp_columns_on_create(self):
        self.builder = MockBuilder(UserWithCustomTimeStamps)
        self.scope.set_timestamp_create(self.builder)
        created_column = UserWithCustomTimeStamps.date_created_at
        updated_column = UserWithCustomTimeStamps.date_updated_at
        # The default column names must not be used when custom ones are set.
        self.assertNotIn("created_at", self.builder._creates)
        self.assertNotIn("updated_at", self.builder._creates)
        self.assertIn(created_column, self.builder._creates)
        self.assertIn(updated_column, self.builder._creates)
        self.assertIsInstance(
            pendulum.parse(self.builder._creates[created_column]), pendulum.DateTime
        )
        self.assertIsInstance(
            pendulum.parse(self.builder._creates[updated_column]), pendulum.DateTime
        )

    def test_uses_custom_updated_column_on_update(self):
        user = UserWithCustomTimeStamps.hydrate({"id": 1})
        sql = user.update({"id": 2}).to_sql()
        # assertIn gives a useful failure message, unlike assertTrue(x in y).
        self.assertIn(UserWithCustomTimeStamps.date_updated_at, sql)
================================================
FILE: tests/seeds/test_seeds.py
================================================
import unittest
from databases.seeds.user_table_seeder import UserTableSeeder
from src.masoniteorm.seeds import Seeder
class TestSeeds(unittest.TestCase):
    def test_can_run_seeds(self):
        """A dry Seeder records which seeders ran without touching the DB."""
        seeder = Seeder(dry=True)
        seeder.call(UserTableSeeder)
        self.assertEqual(seeder.ran_seeds, [UserTableSeeder])
================================================
FILE: tests/sqlite/builder/test_sqlite_builder_insert.py
================================================
import inspect
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
class User(Model):
    """Test model bound to the sqlite "dev" connection.

    __timestamps__ = False presumably disables automatic timestamp columns
    so inserts carry only the supplied attributes.
    """

    # The original class ended with a redundant `pass` after these
    # attribute assignments (W0107); removed — the body is non-empty.
    __connection__ = "dev"
    __timestamps__ = False
class BaseTestQueryRelationships(unittest.TestCase):
    """Integration test: INSERT through the sqlite "dev" connection."""

    maxDiff = None

    def get_builder(self, table="users"):
        # Live (non-dry) builder bound to the dev sqlite database.
        connection = ConnectionFactory().make("sqlite")
        return QueryBuilder(
            grammar=SQLiteGrammar,
            connection_class=connection,
            connection="dev",
            table=table,
            connection_details=DATABASES,
        ).on("dev")

    def test_insert(self):
        builder = self.get_builder()
        result = builder.create(
            {"name": "Joe", "email": "joe@masoniteproject.com", "password": "secret"}
        )
        # create() returns the inserted record including its generated id.
        self.assertIsInstance(result["id"], int)
================================================
FILE: tests/sqlite/builder/test_sqlite_builder_pagination.py
================================================
import inspect
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
class User(Model):
    # Model bound to the sqlite "dev" connection for the pagination tests.
    __connection__ = "dev"
class BaseTestQueryRelationships(unittest.TestCase):
    """Integration tests for length-aware and simple pagination on sqlite."""

    maxDiff = None

    def get_builder(self, table="users", model=User()):
        # NOTE(review): a model *instance* as a default argument is shared
        # across calls (mutable default) — tolerable in tests, but confirm
        # it is intentional.
        connection = ConnectionFactory().make("sqlite")
        return QueryBuilder(
            grammar=SQLiteGrammar,
            connection_class=connection,
            connection="dev",
            table=table,
            model=model,
            connection_details=DATABASES,
        ).on("dev")

    def test_pagination(self):
        builder = self.get_builder()
        # Length-aware paginator: knows totals and last/next/previous pages.
        paginator = builder.table("users").paginate(1)
        self.assertTrue(paginator.count)
        self.assertTrue(paginator.serialize()["data"])
        self.assertTrue(paginator.serialize()["meta"])
        self.assertTrue(paginator.result)
        self.assertTrue(paginator.current_page)
        self.assertTrue(paginator.per_page)
        self.assertTrue(paginator.count)
        self.assertTrue(paginator.last_page)
        self.assertTrue(paginator.next_page)
        self.assertEqual(paginator.previous_page, None)
        self.assertTrue(paginator.total)
        # Iterating the paginator yields hydrated model instances.
        for user in paginator:
            self.assertIsInstance(user, User)
        # Simple paginator: no total count, so next/previous are None here.
        paginator = builder.table("users").simple_paginate(10, 1)
        self.assertIsInstance(paginator.to_json(), str)
        self.assertTrue(paginator.count)
        self.assertTrue(paginator.serialize()["data"])
        self.assertTrue(paginator.serialize()["meta"])
        self.assertTrue(paginator.result)
        self.assertTrue(paginator.current_page)
        self.assertTrue(paginator.per_page)
        self.assertTrue(paginator.count)
        self.assertEqual(paginator.next_page, None)
        self.assertEqual(paginator.previous_page, None)
        for user in paginator:
            self.assertIsInstance(user, User)
        self.assertIsInstance(paginator.to_json(), str)
================================================
FILE: tests/sqlite/builder/test_sqlite_query_builder.py
================================================
import inspect
import unittest
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from tests.utils import MockConnectionFactory
from src.masoniteorm.exceptions import ModelNotFound, HTTP404
class UserMock(Model):
    # Explicit table name so the builder does not derive it from the class.
    __connection__ = "dev"
    __table__ = "users"
class BaseTestQueryBuilder:
maxDiff = None
    def get_builder(self, table="users"):
        """Return a dry QueryBuilder using this subclass's grammar.

        `self.grammar` is supplied by the concrete grammar test subclass.
        """
        connection = MockConnectionFactory().make("sqlite")
        # NOTE(review): connection name "mysql" looks odd in a sqlite test;
        # harmless under dry=True, but confirm it is intentional.
        return QueryBuilder(
            self.grammar,
            connection_class=connection,
            connection="mysql",
            table=table,
            dry=True,
        )
def test_sum(self):
builder = self.get_builder()
builder.sum("age")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_sum_aggregate(self):
builder = self.get_builder()
builder.aggregate("SUM", "age")
sql = getattr(self, "sum")()
self.assertEqual(builder.to_sql(), sql)
def test_sum_aggregate_with_alias(self):
builder = self.get_builder()
builder.aggregate("SUM", "age", alias="number")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_sum_aggregate_with_alias_in_column_name(self):
builder = self.get_builder()
builder.sum("age as number")
sql = getattr(self, "sum_aggregate_with_alias")()
self.assertEqual(builder.to_sql(), sql)
def test_where_like(self):
builder = self.get_builder()
builder.where("age", "like", "%name%")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_not_like(self):
builder = self.get_builder()
builder.where("age", "not like", "%name%")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_max(self):
builder = self.get_builder()
builder.max("age")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_min(self):
builder = self.get_builder()
builder.min("age")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_avg(self):
builder = self.get_builder()
builder.avg("age")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_all(self):
builder = self.get_builder()
builder.all()
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_get(self):
builder = self.get_builder()
builder.get()
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_first(self):
builder = self.get_builder().first(query=True)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_last(self):
    """last(column) should match ordering by that column desc and taking first.

    The original body was a bare `==` expression whose result was
    discarded, so the test could never fail; assert it explicitly.
    """
    self.assertEqual(
        UserMock.order_by("id", "DESC").first().id, UserMock.last("id").id
    )
def test_last_with_default_primary_key(self):
    """last() with no column should fall back to the primary key ordering.

    The original body was a bare `==` expression whose result was
    discarded, so the test could never fail; assert it explicitly.
    """
    self.assertEqual(
        UserMock.order_by("id", "DESC").first().id, UserMock.last().id
    )
def test_first_or_fail_exception(self):
with self.assertRaises(ModelNotFound):
user = self.get_builder().where("name", "=", "Marlysson").first_or_fail()
def test_find_or_fail_exception(self):
with self.assertRaises(ModelNotFound):
user = UserMock.find_or_fail(1000)
def test_find_or_404_exception(self):
with self.assertRaises(HTTP404):
user = UserMock.find_or_404(10000)
def test_select(self):
builder = self.get_builder()
builder.select("name", "email")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_select_multiple(self):
builder = self.get_builder()
builder.select("name, email")
sql = getattr(self, "select")()
self.assertEqual(builder.to_sql(), sql)
def test_add_select(self):
    """add_select should compile each callback as an aliased subquery.

    The original bound the chained ``.to_sql()`` result to ``sql`` and then
    immediately overwrote it with the expected string, leaving a dead
    assignment. The chain mutates ``builder`` in place, so we only need to
    build it and compare ``builder.to_sql()`` afterwards.
    """
    builder = self.get_builder()
    builder.select("name").add_select(
        "phone_count", lambda q: q.count("*").table("phones")
    ).add_select("salary", lambda q: q.count("*").table("salary"))
    # Expected SQL comes from the snippet method matching this test's name.
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(builder.to_sql(), sql)
def test_add_select_no_table(self):
builder = self.get_builder(table=None)
sql = (
builder.add_select(
"other_test", lambda q: q.max("updated_at").table("different_table")
)
.add_select(
"some_alias", lambda q: q.max("updated_at").table("another_table")
)
.to_sql()
)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_add_select_with_raw(self):
    """select_raw combined with add_select should compile into one statement.

    The original bound the builder chain to ``sql`` and then immediately
    overwrote it with the expected string — a dead assignment. The chain
    mutates ``builder`` in place, so building it is sufficient.
    """
    builder = self.get_builder(table=None)
    builder.select_raw("max(updated_at) as test").from_("some_table").add_select(
        "other_test",
        lambda query: (
            query.max("updated_at")
            .from_("different_table")
            .where("some_id", "=", "3")
        ),
    )
    # Expected SQL comes from the snippet method matching this test's name.
    sql = getattr(
        self, inspect.currentframe().f_code.co_name.replace("test_", "")
    )()
    self.assertEqual(builder.to_sql(), sql)
def test_select_raw(self):
builder = self.get_builder()
builder.select_raw("count(email) as email_count")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_create(self):
builder = self.get_builder()
builder.create(
{"name": "Corentin All", "email": "corentin@yopmail.com"}, query=True
)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_delete(self):
builder = self.get_builder()
builder.delete("name", "Joe", query=True)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where(self):
builder = self.get_builder()
builder.where("name", "Joe")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_dictionary(self):
builder = self.get_builder()
builder.where({"name": "Joe"})
sql = getattr(self, "where")()
self.assertEqual(builder.to_sql(), sql)
def test_where_exists(self):
builder = self.get_builder()
builder.where_exists("name")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_limit(self):
builder = self.get_builder()
builder.limit(5)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_offset(self):
builder = self.get_builder()
builder.offset(5)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_offset_with_limit(self):
builder = self.get_builder()
builder.limit(2).offset(5)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_join(self):
builder = self.get_builder()
builder.join("profiles", "users.id", "=", "profiles.user_id")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_left_join(self):
builder = self.get_builder()
builder.left_join("profiles", "users.id", "=", "profiles.user_id")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_right_join(self):
builder = self.get_builder()
builder.right_join("profiles", "users.id", "=", "profiles.user_id")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_update(self):
builder = self.get_builder().update(
{"name": "Joe", "email": "joe@yopmail.com"}, dry=True
)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_increment(self):
builder = self.get_builder()
builder_sql = builder.increment("age", 1)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder_sql, sql)
def test_decrement(self):
builder = self.get_builder()
builder_sql = builder.decrement("age", 1)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder_sql, sql)
def test_count(self):
builder = self.get_builder()
builder.count("id")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_order_by_asc(self):
builder = self.get_builder()
builder.order_by("email", "asc")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_order_by_multiple(self):
builder = self.get_builder()
builder.order_by("email, name, active")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_order_by_reference_direction(self):
builder = self.get_builder()
builder.order_by("email, name desc")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_order_by_raw(self):
builder = self.get_builder()
builder.order_by_raw("col asc")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_order_by_desc(self):
builder = self.get_builder()
builder.order_by("email", "desc")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_column(self):
builder = self.get_builder()
builder.where_column("name", "username")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_not_in(self):
builder = self.get_builder()
builder.where_not_in("id", [1, 2, 3])
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_between(self):
builder = self.get_builder()
builder.between("id", 2, 5)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_between_persisted(self):
builder = QueryBuilder().table("users").on("dev")
users = builder.between("age", 1, 2).count()
self.assertEqual(users, 2)
def test_not_between(self):
builder = self.get_builder()
builder.not_between("id", 2, 5)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_not_between_persisted(self):
builder = QueryBuilder().table("users").on("dev")
users = builder.where_not_null("id").not_between("age", 1, 2).count()
self.assertEqual(users, 0)
def test_where_in(self):
builder = self.get_builder()
builder.where_in("id", [1, 2, 3])
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_null(self):
builder = self.get_builder()
builder.where_null("name")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_not_null(self):
builder = self.get_builder()
builder.where_not_null("name")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_having(self):
builder = self.get_builder(table="payments")
builder.select("user_id").avg("salary").group_by("user_id").having(
"salary", ">=", "1000"
)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_group_by(self):
builder = self.get_builder(table="payments")
builder.select("user_id").min("salary").group_by("user_id")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_group_by_raw(self):
builder = self.get_builder(table="payments")
builder.select("user_id").min("salary").group_by_raw("count(*)")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_group_by_multiple(self):
builder = self.get_builder(table="payments")
builder.select("user_id").min("salary").group_by("user_id").group_by("salary")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_group_by_multiple_in_same_group_by(self):
builder = self.get_builder(table="payments")
builder.select("user_id").min("salary").group_by("user_id, salary")
sql = getattr(self, "group_by_multiple")()
self.assertEqual(builder.to_sql(), sql)
def test_builder_alone(self):
self.assertTrue(
QueryBuilder(
connection_details={
"default": "sqlite",
"sqlite": {
"driver": "sqlite",
"database": "orm.sqlite3",
"prefix": "",
},
}
).table("users")
)
def test_where_lt(self):
builder = self.get_builder()
builder.where("age", "<", "20")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_lte(self):
builder = self.get_builder()
builder.where("age", "<=", "20")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_gt(self):
builder = self.get_builder()
builder.where("age", ">", "20")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_gte(self):
builder = self.get_builder()
builder.where("age", ">=", "20")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_where_ne(self):
builder = self.get_builder()
builder.where("age", "!=", "20")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_or_where(self):
builder = self.get_builder()
builder.where("age", "20").or_where("age", "<", 20)
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_can_call_with_schema(self):
builder = self.get_builder()
sql = (
builder.table("information_schema.columns")
.select("table_name")
.where("table_name", "users")
.to_sql()
)
self.assertEqual(
sql,
"""SELECT "information_schema"."columns"."table_name" FROM "information_schema"."columns" WHERE "information_schema"."columns"."table_name" = 'users'""",
)
def test_can_call_with_raw(self):
builder = self.get_builder()
sql = builder.on("dev").statement("select * from users")
self.assertTrue(sql)
def test_truncate(self):
builder = self.get_builder()
sql = builder.truncate(dry=True)
sql_ref = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(sql, sql_ref)
def test_truncate_without_foreign_keys(self):
builder = self.get_builder()
sql = builder.truncate(foreign_keys=True)
sql_ref = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(sql, sql_ref)
class SQLiteQueryBuilderTest(BaseTestQueryBuilder, unittest.TestCase):
grammar = SQLiteGrammar
def sum(self):
    """Expected SQL for: self.get_builder().sum('age')."""
    return 'SELECT SUM("users"."age") AS age FROM "users"'
def sum_aggregate_with_alias(self):
"""
builder = self.get_builder()
builder.sum('age')
"""
return """SELECT SUM("users"."age") AS number FROM "users\""""
def max(self):
"""
builder = self.get_builder()
builder.max('age')
"""
return """SELECT MAX("users"."age") AS age FROM "users\""""
def min(self):
"""
builder = self.get_builder()
builder.min('age')
"""
return """SELECT MIN("users"."age") AS age FROM "users\""""
def avg(self):
"""
builder = self.get_builder()
builder.avg('age')
"""
return """SELECT AVG("users"."age") AS age FROM "users\""""
def first(self):
"""
builder = self.get_builder()
builder.first()
"""
return """SELECT * FROM "users" LIMIT 1"""
def all(self):
"""
builder = self.get_builder()
builder.all()
"""
return """SELECT * FROM "users\""""
def get(self):
"""
builder = self.get_builder()
builder.get()
"""
return """SELECT * FROM "users\""""
def select(self):
"""
builder = self.get_builder()
builder.select('name', 'email')
"""
return """SELECT "users"."name", "users"."email" FROM "users\""""
def select_multiple(self):
"""
builder = self.get_builder()
builder.select('name', 'email')
"""
return """SELECT "users"."name", "users"."email" FROM "users\""""
def add_select(self):
"""
builder = self.get_builder()
builder.select('name', 'email')
"""
return """SELECT "users"."name", (SELECT COUNT(*) AS m_count_reserved FROM "phones") AS phone_count, (SELECT COUNT(*) AS m_count_reserved FROM "salary") AS salary FROM "users\""""
def add_select_no_table(self):
"""
builder = self.get_builder()
builder.select('name', 'email')
"""
return (
"SELECT "
'(SELECT MAX("different_table"."updated_at") AS updated_at FROM "different_table") AS other_test, '
'(SELECT MAX("another_table"."updated_at") AS updated_at FROM "another_table") AS some_alias'
)
def add_select_with_raw(self):
"""
builder
.select_raw("max(updated_at) as test").from_("some_table")
.add_select(
"other_test",
lambda query: (
query.max("updated_at")
.from_("different_table")
.where(
"some_id", "=",
"3"
)
),
)
"""
return (
"SELECT max(updated_at) as test, "
'(SELECT MAX("different_table"."updated_at") AS updated_at '
'FROM "different_table" '
'WHERE "different_table"."some_id" = \'3\') AS other_test '
'FROM "some_table"'
)
def select_raw(self):
"""
builder = self.get_builder()
builder.select_raw('count(email) as email_count')
"""
return """SELECT count(email) as email_count FROM "users\""""
def create(self):
"""
builder = get_builder()
builder.create({"name": "Corentin All", 'email': 'corentin@yopmail.com'})
"""
return """INSERT INTO "users" ("name", "email") VALUES ('Corentin All', 'corentin@yopmail.com')"""
def delete(self):
"""
builder = get_builder()
builder.delete("name', 'Joe')
"""
return """DELETE FROM "users" WHERE "name" = 'Joe'"""
def where(self):
"""
builder = get_builder()
builder.where('name', 'Joe')
"""
return """SELECT * FROM "users" WHERE "users"."name" = 'Joe'"""
def where_exists(self):
"""
builder = get_builder()
builder.where_exists('name')
"""
return """SELECT * FROM "users" WHERE EXISTS 'name'"""
def limit(self):
"""
builder = get_builder()
builder.limit(5)
"""
return """SELECT * FROM "users" LIMIT 5"""
def offset(self):
"""
builder = get_builder()
builder.offset(5)
"""
return """SELECT * FROM "users" LIMIT -1 OFFSET 5"""
def offset_with_limit(self):
"""
builder = get_builder()
builder.limit(2).offset(5)
"""
return """SELECT * FROM "users" LIMIT 2 OFFSET 5"""
def join(self):
"""
builder.join("profiles", "users.id", "=", "profiles.user_id")
"""
return """SELECT * FROM "users" INNER JOIN "profiles" ON "users"."id" = "profiles"."user_id\""""
def left_join(self):
"""
builder.left_join("profiles", "users.id", "=", "profiles.user_id")
"""
return """SELECT * FROM "users" LEFT JOIN "profiles" ON "users"."id" = "profiles"."user_id\""""
def right_join(self):
    """Expected SQL for: builder.right_join("profiles", "users.id", "=", "profiles.user_id").

    NOTE(review): the expectation is LEFT JOIN — presumably because SQLite
    historically (pre-3.39) has no RIGHT JOIN and the grammar substitutes
    one; confirm against SQLiteGrammar before assuming this is a typo.
    """
    return 'SELECT * FROM "users" LEFT JOIN "profiles" ON "users"."id" = "profiles"."user_id"'
def update(self):
"""
builder.update({"name": "Joe", "email": "joe@yopmail.com"})
"""
return """UPDATE "users" SET "name" = 'Joe', "email" = 'joe@yopmail.com'"""
def increment(self):
"""
builder.increment('age', 1)
"""
return """UPDATE "users" SET "age" = "age" + '1'"""
def decrement(self):
"""
builder.decrement('age', 1)
"""
return """UPDATE "users" SET "age" = "age" - '1'"""
def count(self):
"""
builder.count(id)
"""
return """SELECT COUNT("users"."id") AS id FROM "users\""""
def order_by_asc(self):
"""
builder.order_by('email', 'asc')
"""
return """SELECT * FROM "users" ORDER BY "email" ASC"""
def order_by_multiple(self):
"""
builder.order_by('email', 'asc')
"""
return (
"""SELECT * FROM "users" ORDER BY "email" ASC, "name" ASC, "active" ASC"""
)
def order_by_raw(self):
"""
builder.order_by('email', 'asc')
"""
return """SELECT * FROM "users" ORDER BY col asc"""
def order_by_reference_direction(self):
"""
builder.order_by('email', 'asc')
"""
return """SELECT * FROM "users" ORDER BY "email" ASC, "name" DESC"""
def order_by_desc(self):
"""
builder.order_by('email', 'des')
"""
return """SELECT * FROM "users" ORDER BY "email" DESC"""
def where_column(self):
"""
builder.where_column('name', 'username')
"""
return """SELECT * FROM "users" WHERE "users"."name" = "users"."username\""""
def where_null(self):
"""
builder.where_null('name')
"""
return """SELECT * FROM "users" WHERE "users"."name" IS NULL"""
def where_not_null(self):
"""
builder.where_null('name')
"""
return """SELECT * FROM "users" WHERE "users"."name" IS NOT NULL"""
def where_not_in(self):
"""
builder.where_not_in('id', [1, 2, 3])
"""
return """SELECT * FROM "users" WHERE "users"."id" NOT IN ('1','2','3')"""
def where_in(self):
"""
builder.where_in('id', [1, 2, 3])
"""
return """SELECT * FROM "users" WHERE "users"."id" IN ('1','2','3')"""
def between(self):
"""
builder.between('id', 2, 5)
"""
return """SELECT * FROM "users" WHERE "users"."id" BETWEEN '2' AND '5'"""
def not_between(self):
"""
builder.not_between('id', 2, 5)
"""
return """SELECT * FROM "users" WHERE "users"."id" NOT BETWEEN '2' AND '5'"""
def having(self):
"""
builder.select('user_id').avg('salary').group_by('user_id').having('salary', '>=', '1000')
"""
return """SELECT "payments"."user_id", AVG("payments"."salary") AS salary FROM "payments" GROUP BY "payments"."user_id" HAVING "payments"."salary" >= '1000'"""
def group_by(self):
"""
builder.select('user_id').min('salary').group_by('user_id')
"""
return """SELECT "payments"."user_id", MIN("payments"."salary") AS salary FROM "payments" GROUP BY "payments"."user_id\""""
def group_by_multiple(self):
"""
builder.select('user_id').min('salary').group_by('user_id')
"""
return """SELECT "payments"."user_id", MIN("payments"."salary") AS salary FROM "payments" GROUP BY "payments"."user_id", "payments"."salary\""""
def group_by_raw(self):
"""
builder.select('user_id').min('salary').group_by('user_id')
"""
return """SELECT "payments"."user_id", MIN("payments"."salary") AS salary FROM "payments" GROUP BY count(*)"""
def where_lt(self):
"""
builder = self.get_builder()
builder.where('age', '<', '20')
"""
return """SELECT * FROM "users" WHERE "users"."age" < '20'"""
def where_lte(self):
"""
builder = self.get_builder()
builder.where('age', '<=', '20')
"""
return """SELECT * FROM "users" WHERE "users"."age" <= '20'"""
def where_gt(self):
"""
builder = self.get_builder()
builder.where('age', '>', '20')
"""
return """SELECT * FROM "users" WHERE "users"."age" > '20'"""
def where_gte(self):
"""
builder = self.get_builder()
builder.where('age', '>=', '20')
"""
return """SELECT * FROM "users" WHERE "users"."age" >= '20'"""
def where_ne(self):
"""
builder = self.get_builder()
builder.where('age', '!=', '20')
"""
return """SELECT * FROM "users" WHERE "users"."age" != '20'"""
def or_where(self):
"""
builder = self.get_builder()
builder.where('age', '20').or_where('age','<', 20)
"""
return """SELECT * FROM "users" WHERE "users"."age" = '20' OR "users"."age" < '20'"""
def where_like(self):
"""
builder = self.get_builder()
builder.where("age", "like", "%name%")
"""
return """SELECT * FROM "users" WHERE "users"."age" LIKE '%name%'"""
def where_not_like(self):
"""
builder = self.get_builder()
builder.where("age", "like", "%name%")
"""
return """SELECT * FROM "users" WHERE "users"."age" NOT LIKE '%name%'"""
def test_when(self):
    """when(condition, callback) applies the callback only when truthy.

    The original `return`ed after the truthy-condition assertion, making
    the falsy-condition scenario below it unreachable; both branches now
    execute.
    """
    builder = self.get_builder()
    sql = builder.when(19 > 18, lambda q: q.where("age_restricted", 1)).to_sql()
    self.assertEqual(
        sql, """SELECT * FROM "users" WHERE "users"."age_restricted" = '1'"""
    )
    builder = self.get_builder()
    sql = builder.when(17 > 18, lambda q: q.where("age_restricted", 1)).to_sql()
    self.assertEqual(sql, """SELECT * FROM "users\"""")
def truncate(self):
"""
builder = self.get_builder()
builder.truncate()
"""
return """DELETE FROM "users\""""
def truncate_without_foreign_keys(self):
    """Expected statements for: builder.truncate(foreign_keys=True).

    SQLite has no TRUNCATE statement; the grammar emits a DELETE wrapped
    in PRAGMA toggles that disable and re-enable foreign key enforcement.
    """
    return [
        "PRAGMA foreign_keys = OFF",
        'DELETE FROM "users"',
        "PRAGMA foreign_keys = ON",
    ]
def test_latest(self):
builder = self.get_builder()
builder.latest("email")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def test_oldest(self):
builder = self.get_builder()
builder.oldest("email")
sql = getattr(
self, inspect.currentframe().f_code.co_name.replace("test_", "")
)()
self.assertEqual(builder.to_sql(), sql)
def oldest(self):
"""
builder.order_by('email', 'asc')
"""
return """SELECT * FROM "users" ORDER BY "email" ASC"""
def latest(self):
"""
builder.order_by('email', 'des')
"""
return """SELECT * FROM "users" ORDER BY "email" DESC"""
================================================
FILE: tests/sqlite/builder/test_sqlite_query_builder_eager_loading.py
================================================
import inspect
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from src.masoniteorm.relationships import belongs_to, has_many
class Logo(Model):
__connection__ = "dev"
class Article(Model):
__connection__ = "dev"
@belongs_to("id", "article_id")
def logo(self):
return Logo
@belongs_to("user_id", "id")
def user(self):
return User
class Profile(Model):
__connection__ = "dev"
class User(Model):
__connection__ = "dev"
__with__ = ["articles.logo"]
@has_many("id", "user_id")
def articles(self):
return Article
@belongs_to("id", "user_id")
def profile(self):
return Profile
class EagerUser(Model):
__connection__ = "dev"
__with__ = ("profile",)
__table__ = "users"
@belongs_to("id", "user_id")
def profile(self):
return Profile
class BaseTestQueryRelationships(unittest.TestCase):
maxDiff = None
def get_builder(self, table="users", model=User()):
connection = ConnectionFactory().make("sqlite")
return QueryBuilder(
grammar=SQLiteGrammar,
connection="dev",
table=table,
model=model,
connection_details=DATABASES,
).on("dev")
def test_with(self):
builder = self.get_builder()
result = builder.with_("profile").get()
for model in result:
if model.profile:
self.assertEqual(model.profile.title, "title")
def test_with_from_model(self):
builder = EagerUser
result = builder.get()
for model in result:
if model.profile:
self.assertEqual(model.profile.title, "title")
def test_with_first(self):
builder = self.get_builder()
result = builder.with_("profile").where("id", 1).first()
self.assertEqual(result.profile.title, "title")
def test_with_where_no_relation(self):
builder = self.get_builder()
result = builder.with_("profile").where("id", 5).first()
result.serialize()
def test_with_multiple_per_same_relation(self):
builder = self.get_builder()
result = User.with_("articles", "articles.logo").where("id", 1).first()
self.assertTrue(result.serialize()["articles"])
self.assertTrue(result.serialize()["articles"][0]["logo"])
================================================
FILE: tests/sqlite/builder/test_sqlite_query_builder_relationships.py
================================================
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from src.masoniteorm.relationships import belongs_to
from tests.utils import MockConnectionFactory
from dotenv import load_dotenv
load_dotenv(".env")
class Logo(Model):
__connection__ = "dev"
class Article(Model):
__connection__ = "dev"
@belongs_to("id", "article_id")
def logo(self):
return Logo
class Profile(Model):
__connection__ = "dev"
class User(Model):
__connection__ = "dev"
@belongs_to("id", "user_id")
def articles(self):
return Article
@belongs_to("id", "user_id")
def profile(self):
return Profile
class BaseTestQueryRelationships(unittest.TestCase):
maxDiff = None
def get_builder(self, table="users"):
connection = MockConnectionFactory().make("sqlite")
return QueryBuilder(
grammar=SQLiteGrammar, connection_class=connection, table=table, model=User()
)
def test_has(self):
builder = self.get_builder()
sql = builder.has("articles").to_sql()
self.assertEqual(
sql,
"""SELECT * FROM "users" WHERE EXISTS ("""
"""SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\""""
""")""",
)
def test_doesnt_have(self):
builder = self.get_builder()
sql = builder.doesnt_have("articles").to_sql()
self.assertEqual(
sql,
"""SELECT * FROM "users" WHERE NOT EXISTS ("""
"""SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\""""
""")""",
)
def test_where_doesnt_have(self):
builder = self.get_builder()
sql = builder.where_doesnt_have(
"articles", lambda q: q.where("title", "Eggs and Ham")
).to_sql()
self.assertEqual(
sql,
"""SELECT * FROM "users" WHERE NOT EXISTS ("""
"""SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id" AND "articles"."title" = 'Eggs and Ham'"""
""")""",
)
def test_where_has_query(self):
builder = self.get_builder()
sql = builder.where_has("articles", lambda q: q.where("active", 1)).to_sql()
self.assertEqual(
sql,
"""SELECT * FROM "users" WHERE EXISTS ("""
"""SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id" AND "articles"."active" = '1'"""
""")""",
)
def test_relationship_multiple_has(self):
to_sql = User.has("articles", "profile").to_sql()
self.assertEqual(
to_sql,
"""SELECT * FROM "users" WHERE EXISTS ("""
"""SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\""""
""") AND EXISTS ("""
"""SELECT * FROM "profiles" WHERE "profiles"."user_id" = "users"."id\""""
""")""",
)
def test_relationship_multiple_has_calls(self):
to_sql = User.has("articles").has("profile").to_sql()
self.assertEqual(
to_sql,
"""SELECT * FROM "users" WHERE EXISTS ("""
"""SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id\""""
""") AND EXISTS ("""
"""SELECT * FROM "profiles" WHERE "profiles"."user_id" = "users"."id\""""
""")""",
)
def test_nested_has(self):
to_sql = User.has("articles.logo").to_sql()
self.assertEqual(
to_sql,
"""SELECT * FROM "users" WHERE EXISTS (SELECT * FROM "articles" WHERE "articles"."user_id" = "users"."id" AND EXISTS (SELECT * FROM "logos" WHERE "logos"."article_id" = "articles"."id"))""",
)
def test_joins(self):
to_sql = self.get_builder().joins("articles").to_sql()
self.assertEqual(
to_sql,
"""SELECT * FROM "users" INNER JOIN "articles" ON "users"."id" = "articles"."user_id\"""",
)
================================================
FILE: tests/sqlite/builder/test_sqlite_transaction.py
================================================
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from tests.integrations.config.database import DB
from src.masoniteorm.collection import Collection
class User(Model):
__connection__ = "dev"
__timestamps__ = False
class BaseTestQueryRelationships(unittest.TestCase):
maxDiff = None
def get_builder(self, table="users"):
connection = ConnectionFactory().make("sqlite")
return QueryBuilder(
grammar=SQLiteGrammar,
connection="dev",
table=table,
model=User(),
connection_details=DATABASES,
).on("dev")
def test_transaction(self):
    """A rolled-back transaction must leave no trace of the inserted row."""
    builder = self.get_builder()
    builder.begin()
    builder.create({"name": "phillip3", "email": "phillip3"})
    # Inside the open transaction the row is visible...
    created = builder.where("name", "phillip3").first()
    self.assertEqual(created["name"], "phillip3")
    # ...and after rollback it is gone.
    builder.rollback()
    self.assertEqual(builder.where("name", "phillip3").first(), None)
def test_transaction_globally(self):
connection = DB.begin_transaction("dev")
self.assertEqual(connection, self.get_builder().new_connection())
DB.commit("dev")
DB.begin_transaction("dev")
DB.rollback("dev")
def test_chunking(self):
for users in self.get_builder().chunk(10):
self.assertIsInstance(users, Collection)
================================================
FILE: tests/sqlite/grammar/test_sqlite_delete_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
class BaseDeleteGrammarTest:
    """Shared DELETE-grammar assertions.

    Subclasses provide a fixture method per test (same name minus the
    "test_" prefix) returning the SQL string the builder should produce.
    """

    def setUp(self):
        self.builder = QueryBuilder(SQLiteGrammar, table="users")

    def _expected_sql(self):
        # Resolve the fixture method named after the *calling* test,
        # with its "test_" prefix stripped, and return its SQL string.
        caller_name = inspect.currentframe().f_back.f_code.co_name
        return getattr(self, caller_name.replace("test_", ""))()

    def test_can_compile_delete(self):
        generated = self.builder.delete("id", 1, query=True).to_sql()
        self.assertEqual(generated, self._expected_sql())

    def test_can_compile_delete_in(self):
        generated = self.builder.delete("id", [1, 2, 3], query=True).to_sql()
        self.assertEqual(generated, self._expected_sql())

    def test_can_compile_delete_with_where(self):
        generated = (
            self.builder.where("age", 20)
            .where("profile", 1)
            .delete(query=True)
            .to_sql()
        )
        self.assertEqual(generated, self._expected_sql())
class TestSqliteDeleteGrammar(BaseDeleteGrammarTest, unittest.TestCase):
    # Expected-SQL fixtures for BaseDeleteGrammarTest. Each method name
    # matches a test_ method (minus the prefix) and returns the exact SQL
    # the SQLite grammar is expected to emit.
    grammar = "sqlite"

    def can_compile_delete(self):
        """
        (
            self.builder
            .delete('id', 1, query=True)
            .to_sql()
        )
        """
        return """DELETE FROM "users" WHERE "id" = '1'"""

    def can_compile_delete_in(self):
        """
        (
            self.builder
            .delete('id', [1, 2, 3], query=True)
            .to_sql()
        )
        """
        return """DELETE FROM "users" WHERE "id" IN ('1','2','3')"""

    def can_compile_delete_with_where(self):
        """
        (
            self.builder
            .where('age', 20)
            .where('profile', 1)
            .delete(query=True)
            .to_sql()
        )
        """
        return """DELETE FROM "users" WHERE "age" = '20' AND "profile" = '1'"""
================================================
FILE: tests/sqlite/grammar/test_sqlite_insert_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
class BaseInsertGrammarTest:
    """Shared INSERT-grammar assertions.

    Subclasses provide a fixture method per test (same name minus the
    "test_" prefix) returning the SQL string the builder should produce.
    """

    def setUp(self):
        self.builder = QueryBuilder(SQLiteGrammar, table="users")

    def _expected_sql(self):
        # Fixture method = calling test's name without the "test_" prefix.
        caller_name = inspect.currentframe().f_back.f_code.co_name
        return getattr(self, caller_name.replace("test_", ""))()

    def test_can_compile_insert(self):
        generated = self.builder.create({"name": "Joe"}, query=True).to_sql()
        self.assertEqual(generated, self._expected_sql())

    def test_can_compile_insert_with_keywords(self):
        generated = self.builder.create(name="Joe", query=True).to_sql()
        self.assertEqual(generated, self._expected_sql())

    def test_can_compile_bulk_create(self):
        # Keys are deliberately out of order to prove that column-to-value
        # alignment is handled per record.
        records = [
            {"name": "Joe", "age": 5},
            {"age": 35, "name": "Bill"},
            {"name": "John", "age": 10},
        ]
        generated = self.builder.bulk_create(records, query=True).to_sql()
        self.assertEqual(generated, self._expected_sql())

    def test_can_compile_bulk_create_qmark(self):
        records = [{"name": "Joe"}, {"name": "Bill"}, {"name": "John"}]
        generated = self.builder.bulk_create(records, query=True).to_qmark()
        self.assertEqual(generated, self._expected_sql())

    def test_can_compile_bulk_create_multiple(self):
        records = [
            {"name": "Joe", "active": "1"},
            {"name": "Bill", "active": "1"},
            {"name": "John", "active": "1"},
        ]
        generated = self.builder.bulk_create(records, query=True).to_sql()
        self.assertEqual(generated, self._expected_sql())
# NOTE(review): this class carries the INSERT-grammar fixtures but is named
# "TestSqliteUpdateGrammar" — likely copied from the update-grammar tests;
# consider renaming to TestSqliteInsertGrammar.
class TestSqliteUpdateGrammar(BaseInsertGrammarTest, unittest.TestCase):
    # Expected-SQL fixtures for BaseInsertGrammarTest (see base class for
    # the name-lookup convention).
    grammar = "sqlite"

    def can_compile_insert(self):
        """
        self.builder.create({
            'name': 'Joe'
        }).to_sql()
        """
        return """INSERT INTO "users" ("name") VALUES ('Joe')"""

    def can_compile_insert_with_keywords(self):
        """
        self.builder.create(name="Joe").to_sql()
        """
        return """INSERT INTO "users" ("name") VALUES ('Joe')"""

    def can_compile_bulk_create(self):
        """
        self.builder.bulk_create([{'name': 'Joe', 'age': 5}, ...], query=True).to_sql()
        """
        return """INSERT INTO "users" ("age", "name") VALUES ('5', 'Joe'), ('35', 'Bill'), ('10', 'John')"""

    def can_compile_bulk_create_multiple(self):
        """
        self.builder.bulk_create([{'name': 'Joe', 'active': '1'}, ...], query=True).to_sql()
        """
        return """INSERT INTO "users" ("active", "name") VALUES ('1', 'Joe'), ('1', 'Bill'), ('1', 'John')"""

    def can_compile_bulk_create_qmark(self):
        """
        self.builder.bulk_create([{'name': 'Joe'}, ...], query=True).to_qmark()
        """
        return """INSERT INTO "users" ("name") VALUES ('?'), ('?'), ('?')"""
================================================
FILE: tests/sqlite/grammar/test_sqlite_select_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query.grammars import SQLiteGrammar
from src.masoniteorm.testing import BaseTestCaseSelectGrammar
class TestSQLiteGrammar(BaseTestCaseSelectGrammar, unittest.TestCase):
    """Expected-SQL fixtures for the shared SELECT-grammar test suite.

    Each non-test method returns the exact SQL string the SQLite grammar is
    expected to emit for the builder call shown in its docstring; the base
    class pairs each with the test of the same name. The few test_ methods
    below run their own assertions directly.
    """

    grammar = SQLiteGrammar
    maxDiff = None

    def can_compile_select(self):
        """
        self.builder.to_sql()
        """
        return """SELECT * FROM "users\""""

    def can_compile_order_by_and_first(self):
        """
        self.builder.order_by('id', 'asc').first()
        """
        return """SELECT * FROM "users" ORDER BY "id" ASC LIMIT 1"""

    def can_compile_with_columns(self):
        """
        self.builder.select('username', 'password').to_sql()
        """
        return """SELECT "users"."username", "users"."password" FROM "users\""""

    def can_compile_with_where(self):
        """
        self.builder.select('username', 'password').where('id', 1).to_sql()
        """
        return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1'"""

    def can_compile_with_several_where(self):
        """
        self.builder.select('username', 'password').where('id', 1).where('username', 'joe').to_sql()
        """
        return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1' AND "users"."username" = 'joe'"""

    def can_compile_with_several_where_and_limit(self):
        """
        self.builder.select('username', 'password').where('id', 1).where('username', 'joe').limit(10).to_sql()
        """
        return """SELECT "users"."username", "users"."password" FROM "users" WHERE "users"."id" = '1' AND "users"."username" = 'joe' LIMIT 10"""

    def can_compile_with_sum(self):
        """
        self.builder.sum('age').to_sql()
        """
        return """SELECT SUM("users"."age") AS age FROM "users\""""

    def can_compile_with_max(self):
        """
        self.builder.max('age').to_sql()
        """
        return """SELECT MAX("users"."age") AS age FROM "users\""""

    def can_compile_with_max_and_columns(self):
        """
        self.builder.select('username').max('age').to_sql()
        """
        return """SELECT "users"."username", MAX("users"."age") AS age FROM "users\""""

    def can_compile_with_max_and_columns_different_order(self):
        """
        self.builder.max('age').select('username').to_sql()
        """
        return """SELECT "users"."username", MAX("users"."age") AS age FROM "users\""""

    def can_compile_with_order_by(self):
        """
        self.builder.select('username').order_by('age', 'desc').to_sql()
        """
        return """SELECT "users"."username" FROM "users" ORDER BY "age" DESC"""

    def can_compile_with_multiple_order_by(self):
        """
        self.builder.select('username').order_by('age', 'desc').order_by('name').to_sql()
        """
        return (
            """SELECT "users"."username" FROM "users" ORDER BY "age" DESC, "name" ASC"""
        )

    def can_compile_with_group_by(self):
        """
        self.builder.select('username').group_by('age').to_sql()
        """
        return """SELECT "users"."username" FROM "users" GROUP BY "users"."age\""""

    def can_compile_where_in(self):
        """
        self.builder.select('username').where_in('age', [1,2,3]).to_sql()
        """
        return """SELECT "users"."username" FROM "users" WHERE "users"."age" IN ('1','2','3')"""

    def can_compile_where_in_empty(self):
        """
        self.builder.where_in('age', []).to_sql()
        """
        # An empty IN list compiles to an always-false predicate.
        return """SELECT * FROM "users" WHERE 0 = 1"""

    def can_compile_where_not_in(self):
        """
        self.builder.select('username').where_not_in('age', [1,2,3]).to_sql()
        """
        return """SELECT "users"."username" FROM "users" WHERE "users"."age" NOT IN ('1','2','3')"""

    def can_compile_where_null(self):
        """
        self.builder.select('username').where_null('age').to_sql()
        """
        return """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NULL"""

    def can_compile_where_not_null(self):
        """
        self.builder.select('username').where_not_null('age').to_sql()
        """
        return (
            """SELECT "users"."username" FROM "users" WHERE "users"."age" IS NOT NULL"""
        )

    def can_compile_where_raw(self):
        """
        self.builder.where_raw(" \\"age\\" = '18'").to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" = '18'"""

    def can_compile_select_raw(self):
        """
        self.builder.select_raw("COUNT(*)").to_sql()
        """
        return """SELECT COUNT(*) FROM "users\""""

    def can_compile_limit_and_offset(self):
        """
        self.builder.limit(10).offset(10).to_sql()
        """
        return """SELECT * FROM "users" LIMIT 10 OFFSET 10"""

    def can_compile_select_raw_with_select(self):
        """
        self.builder.select('id').select_raw("COUNT(*)").to_sql()
        """
        return """SELECT "users"."id", COUNT(*) FROM "users\""""

    def can_compile_count(self):
        """
        self.builder.count().to_sql()
        """
        return """SELECT COUNT(*) AS m_count_reserved FROM "users\""""

    def can_compile_count_column(self):
        """
        self.builder.count('money').to_sql()
        """
        return """SELECT COUNT("users"."money") AS money FROM "users\""""

    def can_compile_where_column(self):
        """
        self.builder.where_column('name', 'email').to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" = "users"."email\""""

    def can_compile_or_where(self):
        """
        self.builder.where('name', 2).or_where('name', 3).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" = '2' OR "users"."name" = '3'"""

    def can_grouped_where(self):
        """
        self.builder.where(lambda query: query.where('age', 2).where('name', 'Joe')).to_sql()
        """
        return """SELECT * FROM "users" WHERE ("users"."age" = '2' AND "users"."name" = 'Joe')"""

    def can_compile_sub_select(self):
        """
        self.builder.where_in('name',
            QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age')
        ).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" IN (SELECT "users"."age" FROM "users")"""

    def can_compile_sub_select_where(self):
        """
        self.builder.where_in('age',
            QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('age').where('age', 2).where('name', 'Joe')
        ).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" IN (SELECT "users"."age" FROM "users" WHERE "users"."age" = '2' AND "users"."name" = 'Joe')"""

    def can_compile_sub_select_value(self):
        """
        self.builder.where('name',
            self.builder.new().sum('age')
        ).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" = (SELECT SUM("users"."age") AS age FROM "users")"""

    def can_compile_complex_sub_select(self):
        """
        self.builder.where_in('name',
            (QueryBuilder(GrammarFactory.make(self.grammar), table='users')
                .select('age').where_in('email',
                    QueryBuilder(GrammarFactory.make(self.grammar), table='users').select('email')
                ))
        ).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."name" IN (SELECT "users"."age" FROM "users" WHERE "users"."email" IN (SELECT "users"."email" FROM "users"))"""

    def can_compile_exists(self):
        """
        self.builder.select('age').where_exists(
            self.builder.new().select('username').where('age', 12)
        ).to_sql()
        """
        return """SELECT "users"."age" FROM "users" WHERE EXISTS (SELECT "users"."username" FROM "users" WHERE "users"."age" = '12')"""

    def can_compile_not_exists(self):
        """
        self.builder.select('age').where_not_exists(
            self.builder.new().select('username').where('age', 12)
        ).to_sql()
        """
        return """SELECT "users"."age" FROM "users" WHERE NOT EXISTS (SELECT "users"."username" FROM "users" WHERE "users"."age" = '12')"""

    def can_compile_having(self):
        """
        builder.sum('age').group_by('age').having('age').to_sql()
        """
        return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\""""

    def can_compile_having_order(self):
        """
        builder.sum('age').group_by('age').having('age').order_by('age', 'desc').to_sql()
        """
        # NOTE(review): the expected SQL reads `ORDER ... DESC` (no BY) and
        # embeds a stray escaped quote after HAVING "users"."age. This
        # presumably mirrors the grammar's current output — confirm against
        # the grammar before "correcting" the fixture.
        return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age\" ORDER "users"."age" DESC"""

    def can_compile_having_raw(self):
        """
        builder.select_raw("COUNT(*) as counts").having_raw("counts > 18").to_sql()
        """
        return """SELECT COUNT(*) as counts FROM "users" HAVING counts > 18"""

    def can_compile_having_with_expression(self):
        """
        builder.sum('age').group_by('age').having('age', 10).to_sql()
        """
        return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" = '10'"""

    def can_compile_having_with_greater_than_expression(self):
        """
        builder.sum('age').group_by('age').having('age', '>', 10).to_sql()
        """
        return """SELECT SUM("users"."age") AS age FROM "users" GROUP BY "users"."age" HAVING "users"."age" > '10'"""

    def can_compile_join(self):
        """
        builder.join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "contacts" ON "users"."id" = "contacts"."user_id\""""

    def can_compile_left_join(self):
        """
        builder.left_join('contacts', 'users.id', '=', 'contacts.user_id').to_sql()
        """
        return """SELECT * FROM "users" LEFT JOIN "contacts" ON "users"."id" = "contacts"."user_id\""""

    def can_compile_multiple_join(self):
        """
        builder.join('contacts', ...).join('posts', ...).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "contacts" ON "users"."id" = "contacts"."user_id" INNER JOIN "posts" ON "comments"."post_id" = "posts"."id\""""

    def can_compile_between(self):
        """
        builder.between('age', 18, 21).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" BETWEEN '18' AND '21'"""

    def can_compile_not_between(self):
        """
        builder.not_between('age', 18, 21).to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" NOT BETWEEN '18' AND '21'"""

    def test_can_compile_where_raw(self):
        # Raw WHERE fragments are passed through untouched.
        to_sql = self.builder.where_raw(""" "age" = '18'""").to_sql()
        self.assertEqual(to_sql, """SELECT * FROM "users" WHERE "age" = '18'""")

    def test_can_compile_where_raw_and_where_with_multiple_bindings(self):
        # Raw bindings must be preserved in order ahead of normal where bindings.
        query = self.builder.where_raw(
            """ "age" = '?' AND "is_admin" = '?' """, [18, True]
        ).where("email", "test@example.com")
        self.assertEqual(
            query.to_qmark(),
            """SELECT * FROM "users" WHERE "age" = '?' AND "is_admin" = '?' AND "users"."email" = '?'""",
        )
        self.assertEqual(query._bindings, [18, True, "test@example.com"])

    def test_can_compile_select_raw(self):
        to_sql = self.builder.select_raw("COUNT(*)").to_sql()
        self.assertEqual(to_sql, """SELECT COUNT(*) FROM "users\"""")

    def test_can_compile_select_raw_with_select(self):
        to_sql = self.builder.select("id").select_raw("COUNT(*)").to_sql()
        self.assertEqual(to_sql, """SELECT "users"."id", COUNT(*) FROM "users\"""")

    def can_compile_first_or_fail(self):
        """
        builder = self.get_builder()
        builder.where("is_admin", "=", True).first_or_fail()
        """
        return """SELECT * FROM "users" WHERE "users"."is_admin" = '1' LIMIT 1"""

    def where_not_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "not like", "%name%").to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" NOT LIKE '%name%'"""

    def where_like(self):
        """
        builder = self.get_builder()
        builder.where("age", "like", "%name%").to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" LIKE '%name%'"""

    def where_regexp(self):
        """
        builder = self.get_builder()
        builder.where("age", "regexp", "Joe").to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" REGEXP 'Joe'"""

    def where_not_regexp(self):
        """
        builder = self.get_builder()
        builder.where("age", "not regexp", "Joe").to_sql()
        """
        return """SELECT * FROM "users" WHERE "users"."age" NOT REGEXP 'Joe'"""

    def can_compile_join_clause(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on("bgt.fund", "=", "rg.fund")
            .on_value("bgt.active", "=", "1")
            .or_on_value("bgt.acct", "=", "1234")
        )
        builder.join(clause).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" AND "bgt"."dept" = "rg"."dept" AND "bgt"."acct" = "rg"."acct" AND "bgt"."sub" = "rg"."sub\""""

    def can_compile_join_clause_with_value(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on_value("bgt.active", "=", "1")
            .or_on_value("bgt.acct", "=", "1234")
        )
        builder.join(clause).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."active" = '1' OR "bgt"."acct" = '1234'"""

    def can_compile_join_clause_with_null(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on_null("bgt.acct")
            .or_on_null("bgt.dept")
            .on_value("rg.abc", 10)
        )
        builder.join(clause).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NULL OR "dept" IS NULL AND "rg"."abc" = '10'"""

    def can_compile_join_clause_with_not_null(self):
        """
        builder = self.get_builder()
        clause = (
            JoinClause("report_groups as rg")
            .on_not_null("bgt.acct")
            .or_on_not_null("bgt.dept")
            .on_value("rg.abc", 10)
        )
        builder.join(clause).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "acct" IS NOT NULL OR "dept" IS NOT NULL AND "rg"."abc" = '10'"""

    def can_compile_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.join(
            "report_groups as rg",
            lambda clause: (
                clause.on("bgt.fund", "=", "rg.fund")
                .on_null("bgt")
            ),
        ).to_sql()
        """
        return """SELECT * FROM "users" INNER JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" AND "bgt" IS NULL"""

    def can_compile_left_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.left_join(
            "report_groups as rg",
            lambda clause: (
                clause.on("bgt.fund", "=", "rg.fund")
                .or_on_null("bgt")
            ),
        ).to_sql()
        """
        return """SELECT * FROM "users" LEFT JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" OR "bgt" IS NULL"""

    def can_compile_right_join_clause_with_lambda(self):
        """
        builder = self.get_builder()
        builder.right_join(
            "report_groups as rg",
            lambda clause: (
                clause.on("bgt.fund", "=", "rg.fund")
                .or_on_null("bgt")
            ),
        ).to_sql()
        """
        # NOTE(review): right_join() is expected to compile to LEFT JOIN here.
        # Possibly deliberate (older SQLite has no RIGHT JOIN) — confirm
        # against the grammar before treating this as a fixture typo.
        return """SELECT * FROM "users" LEFT JOIN "report_groups" AS "rg" ON "bgt"."fund" = "rg"."fund" OR "bgt" IS NULL"""

    def update_lock(self):
        """
        builder = self.get_builder()
        builder.where("votes", ">=", 100).lock_for_update().to_sql()
        """
        # SQLite has no FOR UPDATE clause, so the lock is a no-op in the SQL.
        return """SELECT * FROM "users" WHERE "users"."votes" >= '100'"""

    def shared_lock(self):
        """
        builder = self.get_builder()
        builder.where("votes", ">=", 100).shared_lock().to_sql()
        """
        # SQLite has no shared-lock clause, so the lock is a no-op in the SQL.
        return """SELECT * FROM "users" WHERE "users"."votes" >= '100'"""

    def can_user_where_raw_and_where(self):
        """
        builder.where_raw("age = '18'").where("name", "=", "James").to_sql()
        """
        return """SELECT * FROM "users" WHERE age = '18' AND "users"."name" = 'James'"""

    def where_exists_with_lambda(self):
        return """SELECT * FROM "users" WHERE EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')"""

    def where_not_exists_with_lambda(self):
        return """SELECT * FROM "users" WHERE NOT EXISTS (SELECT * FROM "users" WHERE "users"."age" = '1')"""

    def where_date(self):
        return (
            """SELECT * FROM "users" WHERE DATE("users"."created_at") = '2022-06-01'"""
        )

    def or_where_null(self):
        return """SELECT * FROM "users" WHERE "users"."column1" IS NULL OR "users"."column2" IS NULL"""

    def select_distinct(self):
        return """SELECT DISTINCT "users"."group" FROM "users\""""
================================================
FILE: tests/sqlite/grammar/test_sqlite_update_grammar.py
================================================
import inspect
import unittest
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from src.masoniteorm.expressions import Raw
class BaseTestCaseUpdateGrammar:
    """Shared UPDATE-grammar assertions.

    Subclasses provide a fixture method per test (same name minus the
    "test_" prefix) returning the SQL string the builder should produce.
    """

    def setUp(self):
        self.builder = QueryBuilder(SQLiteGrammar, table="users")

    def _expected_sql(self):
        # Fixture method = calling test's name without the "test_" prefix.
        caller_name = inspect.currentframe().f_back.f_code.co_name
        return getattr(self, caller_name.replace("test_", ""))()

    def test_can_compile_update(self):
        generated = (
            self.builder.where("name", "bob").update({"name": "Joe"}, dry=True).to_sql()
        )
        self.assertEqual(generated, self._expected_sql())

    def test_can_compile_multiple_update(self):
        generated = self.builder.update(
            {"name": "Joe", "email": "user@email.com"}, dry=True
        ).to_sql()
        self.assertEqual(generated, self._expected_sql())

    def test_can_compile_update_with_multiple_where(self):
        generated = (
            self.builder.where("name", "bob")
            .where("age", 20)
            .update({"name": "Joe"}, dry=True)
            .to_sql()
        )
        self.assertEqual(generated, self._expected_sql())

    # Disabled upstream; kept for reference.
    # def test_can_compile_increment(self):
    #     to_sql = self.builder.increment("age")
    #     print(to_sql)
    #     self.assertTrue(to_sql.isnumeric())
    # def test_can_compile_decrement(self):
    #     to_sql = self.builder.decrement("age", 20).to_sql()
    #     sql = getattr(
    #         self, inspect.currentframe().f_code.co_name.replace("test_", "")
    #     )()
    #     self.assertEqual(to_sql, sql)

    def test_raw_expression(self):
        generated = self.builder.update({"name": Raw('"username"')}, dry=True).to_sql()
        self.assertEqual(generated, self._expected_sql())
class TestSqliteUpdateGrammar(BaseTestCaseUpdateGrammar, unittest.TestCase):
    # Expected-SQL fixtures for BaseTestCaseUpdateGrammar. Each method name
    # matches a test_ method (minus the prefix) and returns the exact SQL
    # the SQLite grammar is expected to emit.
    grammar = "sqlite"

    def can_compile_update(self):
        """
        builder.where('name', 'bob').update({
            'name': 'Joe'
        }).to_sql()
        """
        return """UPDATE "users" SET "name" = 'Joe' WHERE "name" = 'bob'"""

    def raw_expression(self):
        """
        builder.update({'name': Raw('"username"')}, dry=True).to_sql()
        """
        # A Raw value is injected verbatim, not quoted as a string literal.
        return 'UPDATE "users" SET "name" = "username"'

    def can_compile_multiple_update(self):
        """
        self.builder.update({"name": "Joe", "email": "user@email.com"}, dry=True).to_sql()
        """
        return """UPDATE "users" SET "name" = 'Joe', "email" = 'user@email.com'"""

    def can_compile_update_with_multiple_where(self):
        """
        builder.where('name', 'bob').where('age', 20).update({
            'name': 'Joe'
        }).to_sql()
        """
        return """UPDATE "users" SET "name" = 'Joe' WHERE "name" = 'bob' AND "age" = '20'"""

    # The two fixtures below back tests that are currently disabled in the
    # base class; kept so re-enabling those tests needs no fixture work.
    def can_compile_increment(self):
        """
        builder.increment('age').to_sql()
        """
        return """UPDATE "users" SET "age" = "age" + '1'"""

    def can_compile_decrement(self):
        """
        builder.decrement('age', 20).to_sql()
        """
        return """UPDATE "users" SET "age" = "age" - '20'"""
================================================
FILE: tests/sqlite/models/test_attach_detach.py
================================================
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import belongs_to, has_one
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import SQLitePlatform
from tests.integrations.config.database import DATABASES
class Bottle(Model):
    # One side of a one-to-one pair used by the attach/detach tests;
    # the lid row points back via bottle_lids.bottle_id.
    __table__ = "bottles"
    __connection__ = "dev"
    __timestamps__ = False
    __fillable__ = ["label"]

    @has_one(None, "bottle_id", "id")
    def lid(self):
        return BottleLid
class BottleLid(Model):
    # Child side of the one-to-one pair: holds the bottle_id foreign key.
    __table__ = "bottle_lids"
    __connection__ = "dev"
    __timestamps__ = False
    __fillable__ = ["colour", "bottle_id"]

    @belongs_to(None, "bottle_id", "id")
    def bottle(self):
        return Bottle
class TestAttachDetach(unittest.TestCase):
    """Integration tests for attach()/detach() across has_one and
    belongs_to relationships, backed by real SQLite tables."""

    def setUp(self):
        # Create the two tables (idempotently) before each test.
        self.schema = Schema(
            connection="dev",
            connection_details=DATABASES,
            platform=SQLitePlatform,
        ).on("dev")
        with self.schema.create_table_if_not_exists("bottles") as table:
            table.integer("id").primary()
            table.string("label")
        with self.schema.create_table_if_not_exists("bottle_lids") as table:
            table.integer("id").primary()
            table.string("colour")
            table.integer("bottle_id", nullable=True)  # HasOne / BelongsTo relationship

    def tearDown(self):
        # Clear both tables so tests stay independent.
        BottleLid.delete()
        Bottle.delete()

    def test_has_one_attach_detach(self):
        bottle = Bottle.create(
            {
                "label": "cola",
            }
        )
        # test with an unsaved record
        red_lid = BottleLid().fill(
            {
                "colour": "Red",
            }
        )
        current_lid = bottle.attach("lid", red_lid)
        self.assertIsNotNone(bottle.lid)
        self.assertIsInstance(current_lid, BottleLid)
        # attach() must persist the unsaved record.
        self.assertTrue(current_lid.is_created())
        self.assertEqual(bottle.id, current_lid.bottle_id)
        bottle.detach("lid", current_lid)
        # detach() nulls the foreign key rather than deleting the row.
        test_lid = BottleLid.find(current_lid.id)
        self.assertIsNone(test_lid.bottle_id)
        self.assertIsNone(bottle.lid)
        # test using a pre-saved record
        green_lid = BottleLid.create(
            {
                "colour": "Green",
            }
        )
        current_lid = bottle.attach("lid", green_lid)
        self.assertIsNotNone(bottle.lid)
        self.assertIsInstance(current_lid, BottleLid)
        self.assertEqual(bottle.id, current_lid.bottle_id)
        bottle.detach("lid", current_lid)
        test_lid = BottleLid.find(current_lid.id)
        self.assertIsNone(test_lid.bottle_id)
        self.assertIsNone(bottle.lid)

    def test_belongs_to_attach_detach(self):
        bottle = Bottle.create(
            {
                "label": "milk",
            }
        )
        # test with an unsaved record
        red_lid = BottleLid().fill(
            {
                "colour": "Red",
            }
        )
        current_lid = red_lid.attach("bottle", bottle)
        self.assertIsNotNone(bottle.lid)
        self.assertIsInstance(current_lid, BottleLid)
        self.assertTrue(current_lid.is_created())
        self.assertEqual(bottle.id, current_lid.bottle_id)
        current_lid.detach("bottle", bottle)
        test_lid = BottleLid.find(current_lid.id)
        self.assertIsNone(test_lid.bottle_id)
        self.assertIsNone(bottle.lid)
        # test using a pre-saved record
        green_lid = BottleLid.create(
            {
                "colour": "Green",
            }
        )
        current_lid = green_lid.attach("bottle", bottle)
        self.assertIsNotNone(bottle.lid)
        self.assertIsInstance(current_lid, BottleLid)
        self.assertEqual(bottle.id, current_lid.bottle_id)
        current_lid.detach("bottle", bottle)
        test_lid = BottleLid.find(current_lid.id)
        self.assertIsNone(test_lid.bottle_id)
        self.assertIsNone(bottle.lid)
================================================
FILE: tests/sqlite/models/test_observers.py
================================================
import inspect
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from src.masoniteorm.relationships import belongs_to
from tests.utils import MockConnectionFactory
from tests.integrations.config.database import DB
class TestM:
    # Scratch namespace: UserObserver sets observed_* flags on this class
    # so tests can assert which lifecycle hooks fired.
    pass
class UserObserver:
    """Model observer that records each lifecycle hook it receives by
    setting the matching observed_* flag on TestM."""

    def created(self, user):
        TestM.observed_created = 1

    def creating(self, user):
        TestM.observed_creating = 1

    def saving(self, user):
        TestM.observed_saving = 1

    def saved(self, user):
        TestM.observed_saved = 1

    def updating(self, user):
        TestM.observed_updating = 1

    def updated(self, user):
        TestM.observed_updated = 1

    def booting(self, user):
        # Fixed: booting()/booted() previously set each other's flag
        # (the assignments were swapped).
        TestM.observed_booting = 1

    def booted(self, user):
        TestM.observed_booted = 1

    def hydrating(self, user):
        TestM.observed_hydrating = 1

    def hydrated(self, user):
        TestM.observed_hydrated = 1

    def deleting(self, user):
        TestM.observed_deleting = 1

    def deleted(self, user):
        TestM.observed_deleted = 1
class Observer(Model):
    # Model under observation; __observers__ starts empty and is populated
    # via the observe() call below.
    __connection__ = "dev"
    __timestamps__ = False
    __observers__ = {}


# Register the observer at import time so every test in this module
# sees its lifecycle hooks.
Observer.observe(UserObserver())
class BaseTestQueryRelationships(unittest.TestCase):
    """Asserts that each model lifecycle event reaches the registered
    observer by checking the observed_* flags on TestM."""

    maxDiff = None

    def test_created_is_observed(self):
        # DB.begin_transaction("dev")
        user = Observer.create({"name": "joe"})
        self.assertEqual(TestM.observed_creating, 1)
        self.assertEqual(TestM.observed_created, 1)
        # DB.rollback("dev")

    def test_saving_is_observed(self):
        # DB.begin_transaction("dev")
        user = Observer.hydrate({"id": 1, "name": "joe"})
        user.name = "bill"
        user.save()
        self.assertEqual(TestM.observed_saving, 1)
        self.assertEqual(TestM.observed_saved, 1)
        # DB.rollback("dev")

    def test_updating_is_observed(self):
        # DB.begin_transaction("dev")
        user = Observer.hydrate({"id": 1, "name": "joe"})
        re = user.update({"name": "bill"})
        self.assertEqual(TestM.observed_updated, 1)
        self.assertEqual(TestM.observed_updating, 1)
        # DB.rollback("dev")

    def test_booting_is_observed(self):
        # Any model interaction boots the model, which should fire both
        # booting and booted hooks.
        # DB.begin_transaction("dev")
        user = Observer.hydrate({"id": 1, "name": "joe"})
        re = user.update({"name": "bill"})
        self.assertEqual(TestM.observed_booting, 1)
        self.assertEqual(TestM.observed_booted, 1)
        # DB.rollback("dev")

    def test_deleting_is_observed(self):
        # Run inside a transaction so the delete leaves no trace.
        DB.begin_transaction("dev")
        user = Observer.hydrate({"id": 10, "name": "joe"})
        re = user.delete()
        self.assertEqual(TestM.observed_deleting, 1)
        self.assertEqual(TestM.observed_deleted, 1)
        DB.rollback("dev")

    def test_hydrating_is_observed(self):
        DB.begin_transaction("dev")
        user = Observer.hydrate({"id": 10, "name": "joe"})
        self.assertEqual(TestM.observed_hydrating, 1)
        self.assertEqual(TestM.observed_hydrated, 1)
        DB.rollback("dev")
================================================
FILE: tests/sqlite/models/test_sqlite_model.py
================================================
import inspect
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import ConnectionFactory
from src.masoniteorm.models import Model
from src.masoniteorm.query import QueryBuilder
from src.masoniteorm.query.grammars import SQLiteGrammar
from src.masoniteorm.relationships import belongs_to, belongs_to_many
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms.SQLitePlatform import SQLitePlatform
from tests.utils import MockConnectionFactory
class User(Model):
    # Dry model: per the tests below, update()/create() return something
    # whose SQL can be inspected via to_sql() instead of executing.
    __connection__ = "dev"
    __timestamps__ = False
    __dry__ = True
class UserForced(Model):
    # Same "users" table as User, but with __force_update__ so updates
    # include attributes even when their values are unchanged
    # (see test_can_force_update_on_model).
    __connection__ = "dev"
    __table__ = "users"
    __timestamps__ = False
    __dry__ = True
    __force_update__ = True
class Select(Model):
    # Default select columns applied to every query on this model.
    # NOTE(review): "rememember_token" is (mis)spelled identically in the
    # assertion that consumes it (test_model_can_use_selects); renaming
    # would require changing both places together.
    __connection__ = "dev"
    __selects__ = ["username", "rememember_token as token"]
    __dry__ = True
class SelectPass(Model):
    # Model with no __selects__, used to verify columns can instead be
    # passed to query methods (see test_model_can_use_selects_from_methods).
    __connection__ = "dev"
    __dry__ = True
class UserHydrateHidden(Model):
    # Model whose token/password must be excluded when serialized
    # (see test_should_return_relation_applying_hidden_attributes).
    __connection__ = "dev"
    __table__ = "users_hidden"
    __hidden__ = ["token", "password"]
class Group(Model):
    __connection__ = "dev"
    __table__ = "groups"
    # NOTE(review): "__fillable" is missing the trailing underscores, so
    # Python name-mangles it to _Group__fillable and the ORM never sees it —
    # almost certainly meant "__fillable__". Left as-is because enabling
    # mass-assignment protection could change what Group.create() accepts;
    # confirm against the Model fillable handling before fixing.
    __fillable = ["name"]
    # Eager-load the "team" relation on every query.
    __with__ = ["team"]

    @belongs_to_many("group_id", "user_id", "id", "id", table="group_user")
    def team(self):
        # Many-to-many to UserHydrateHidden via the group_user pivot table.
        return UserHydrateHidden
class BaseTestQueryRelationships(unittest.TestCase):
maxDiff = None
def test_update_specific_record(self):
user = User.first()
sql = user.update({"name": "joe"}).to_sql()
self.assertEqual(
sql,
"""UPDATE "users" SET "name" = 'joe' WHERE "id" = '{}'""".format(user.id),
)
def test_update_all_records(self):
sql = User.update({"name": "joe"}).to_sql()
self.assertEqual(sql, """UPDATE "users" SET "name" = 'joe'""")
def test_can_find_list(self):
sql = User.find(1, query=True).to_sql()
self.assertEqual(sql, """SELECT * FROM "users" WHERE "users"."id" = '1'""")
sql = User.find([1, 2, 3], query=True).to_sql()
self.assertEqual(
sql, """SELECT * FROM "users" WHERE "users"."id" IN ('1','2','3')"""
)
def test_find_or_if_record_not_found(self):
# Insane record number so record cannot be found
record_id = 1_000_000_000_000_000
result = User.find_or(record_id, lambda: "Record not found.")
self.assertEqual(result, "Record not found.")
def test_find_or_if_record_found(self):
record_id = 1
result_id = User.find_or(record_id, lambda: "Record not found.").id
self.assertEqual(result_id, record_id)
def test_can_set_and_retreive_attribute(self):
user = User.hydrate({"id": 1, "name": "joe", "customer_id": 1})
user.customer_id = "CUST1"
self.assertEqual(user.customer_id, "CUST1")
def test_model_can_use_selects(self):
self.assertEqual(
Select.to_sql(),
'SELECT "selects"."username", "selects"."rememember_token" AS token FROM "selects"',
)
def test_model_can_use_selects_from_methods(self):
self.assertEqual(
SelectPass.all(["username"], query=True).to_sql(),
'SELECT "select_passes"."username" FROM "select_passes"',
)
def test_update_only_changed_attributes(self):
user = User.first()
sql = user.update({"name": user.name, "username": "new"}).to_sql()
# unchanged name attribute is not updated
self.assertEqual(
sql,
"""UPDATE "users" SET "username" = 'new' WHERE "id" = '{}'""".format(
user.id
),
)
def test_can_force_update_on_method(self):
user = User.first()
sql = user.update({"name": user.name, "username": "new"}, force=True).to_sql()
self.assertEqual(
sql,
"""UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{}'""".format(
user.id
),
)
def test_can_force_update_on_model(self):
user = UserForced.first()
sql = user.update({"name": user.name, "username": "new"}).to_sql()
self.assertEqual(
sql,
"""UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{}'""".format(
user.id
),
)
def test_force_update(self):
user = User.first()
sql = user.force_update({"name": user.name, "username": "new"}).to_sql()
self.assertEqual(
sql,
"""UPDATE "users" SET "name" = 'bill', "username" = 'new' WHERE "id" = '{}'""".format(
user.id
),
)
def test_update_is_not_done_when_no_changes(self):
user = User.first()
sql = user.update({"name": user.name}).to_sql()
self.assertNotIn("UPDATE", sql)
def test_should_collect_correct_amount_data_using_between(self):
class ModelUser(Model):
__connection__ = "dev"
__table__ = "users"
count = User.between("age", 1, 2).get().count()
self.assertEqual(count, 2)
def test_should_collect_correct_amount_data_using_not_between(self):
    """not_between() should exclude every record within the bounds.

    All non-null-id users have an age within [1, 2], so the result is empty.
    """
    # Removed a dead local `ModelUser` class that was defined but never used.
    count = User.where_not_null("id").not_between("age", 1, 2).get().count()
    self.assertEqual(count, 0)
def test_get_columns(self):
    """get_columns() should report the users table columns in schema order."""
    expected = [
        "id",
        "name",
        "email",
        "password",
        "remember_token",
        "created_at",
        "is_admin",
        "age",
        "boo",
        "tool1",
        "tool2",
        "active",
        "updated_at",
        "profile_id",
        "name5",
        "name6",
        "age6",
        "age7",
        "age8",
        "age10",
    ]
    self.assertEqual(User.get_columns(), expected)
def test_should_return_relation_applying_hidden_attributes(self):
    """Serializing a relation should honor the related model's hidden fields.

    Builds a users/groups many-to-many fixture, attaches a user to a group,
    then checks the serialized "team" relation omits password/token.
    """
    schema = Schema(
        connection_details=DATABASES, connection="dev", platform=SQLitePlatform
    ).on("dev")
    tables = ["users_hidden", "group_user", "groups"]
    # Start from a clean slate in case a previous run left tables behind.
    for table in tables:
        schema.drop_table_if_exists(table)
    with schema.create("users_hidden") as blueprint:
        blueprint.increments("id")
        blueprint.string("name")
        blueprint.integer("token")
        blueprint.string("password")
        blueprint.timestamps()
    with schema.create("groups") as blueprint:
        blueprint.increments("id")
        blueprint.string("name")
        blueprint.timestamps()
    with schema.create("group_user") as blueprint:
        blueprint.increments("id")
        blueprint.unsigned_integer("group_id")
        blueprint.unsigned_integer("user_id")
        blueprint.foreign("group_id").references("id").on("groups")
        blueprint.foreign("user_id").references("id").on("users_hidden")
        blueprint.timestamps()
    UserHydrateHidden.create(
        name="Name", password="pass_value", token="token_value"
    )
    Group.create(name="Group")
    user = UserHydrateHidden.first()
    group = Group.first()
    group.attach_related("team", user)
    serialized = Group.first().serialize()
    self.assertIn("team", serialized)
    # Bug fix: removed `self.assertTrue("team", serialized)` — with two
    # arguments assertTrue treats the second as the failure *message*, so
    # the call could never fail and merely duplicated the assertIn above.
    relation_serialized = serialized.get("team")
    self.assertNotIn("password", relation_serialized)
    self.assertNotIn("token", relation_serialized)
    # Clean up the fixture rows for subsequent tests.
    for table in tables:
        schema.truncate(table)
================================================
FILE: tests/sqlite/relationships/test_sqlite_has_many_through_relationship.py
================================================
import unittest
from src.masoniteorm.collection import Collection
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import has_many_through
from tests.integrations.config.database import DATABASES
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import SQLitePlatform
class Enrolment(Model):
    """Intermediary (pivot) model linking students to their courses."""

    __connection__ = "dev"
    __table__ = "enrolment"
    __fillable__ = ["active_student_id", "in_course_id"]
class Student(Model):
    """Distant model reached from a course through enrolments."""

    __connection__ = "dev"
    __table__ = "student"
    __fillable__ = ["student_id", "name"]
class Course(Model):
    # Local side of the has-many-through chain:
    # Course -> Enrolment (intermediary) -> Student.
    __table__ = "course"
    __connection__ = "dev"
    __fillable__ = ["course_id", "name"]

    # Decorator keys in order: intermediary FK pointing at this model
    # (in_course_id), intermediary FK pointing at the distant model
    # (active_student_id), this model's key (course_id), the distant model's
    # key (student_id).
    # NOTE(review): key roles are inferred from the fixture data in the test
    # class below — confirm against the has_many_through implementation.
    @has_many_through(
        None,
        "in_course_id",
        "active_student_id",
        "course_id",
        "student_id"
    )
    def students(self):
        # The relationship resolves [distant model, intermediary model].
        return [Student, Enrolment]
class TestHasManyThroughRelationship(unittest.TestCase):
    """Integration tests for has_many_through on the dev SQLite connection."""

    def setUp(self):
        self.schema = Schema(
            connection="dev",
            connection_details=DATABASES,
            platform=SQLitePlatform,
        ).on("dev")

        with self.schema.create_table_if_not_exists("student") as table:
            table.integer("student_id").primary()
            table.string("name")

        with self.schema.create_table_if_not_exists("course") as table:
            table.integer("course_id").primary()
            table.string("name")

        with self.schema.create_table_if_not_exists("enrolment") as table:
            table.integer("enrolment_id").primary()
            table.integer("active_student_id")
            table.integer("in_course_id")

        # Seed each table only once; later tests in the same run reuse the data.
        if not Course.count():
            Course.builder.new().bulk_create(
                [
                    {"course_id": 10, "name": "Math 101"},
                    {"course_id": 20, "name": "History 101"},
                    {"course_id": 30, "name": "Math 302"},
                    {"course_id": 40, "name": "Biology 302"},
                ]
            )

        if not Student.count():
            Student.builder.new().bulk_create(
                [
                    {"student_id": 100, "name": "Bob"},
                    {"student_id": 200, "name": "Alice"},
                    {"student_id": 300, "name": "Steve"},
                    {"student_id": 400, "name": "Megan"},
                ]
            )

        if not Enrolment.count():
            Enrolment.builder.new().bulk_create(
                [
                    {"active_student_id": 100, "in_course_id": 30},
                    {"active_student_id": 200, "in_course_id": 10},
                    {"active_student_id": 100, "in_course_id": 10},
                    {"active_student_id": 400, "in_course_id": 20},
                ]
            )

    def test_has_many_through_can_eager_load(self):
        """Eager loading should hydrate the distant students collection."""
        courses = Course.where("name", "Math 101").with_("students").get()
        students = courses.first().students
        self.assertIsInstance(students, Collection)
        # Math 101 (course_id 10) has two enrolments: Alice and Bob.
        self.assertEqual(students.count(), 2)
        student1 = students.shift()
        self.assertIsInstance(student1, Student)
        self.assertEqual(student1.name, "Alice")
        student2 = students.shift()
        self.assertIsInstance(student2, Student)
        self.assertEqual(student2.name, "Bob")
        # check .first() and .get() produce the same result
        single = (
            Course.where("name", "History 101")
            .with_("students")
            .first()
        )
        self.assertIsInstance(single.students, Collection)
        single_get = (
            Course.where("name", "History 101").with_("students").get()
        )
        # Bug fix: removed two leftover debug print() calls; the assertions
        # below already cover the .first()/.get() equivalence check.
        self.assertEqual(single.students.count(), 1)
        self.assertEqual(single_get.first().students.count(), 1)
        single_name = single.students.first().name
        single_get_name = single_get.first().students.first().name
        self.assertEqual(single_name, single_get_name)

    def test_has_many_through_eager_load_can_be_empty(self):
        """A course with no enrolments eager loads its students as None."""
        courses = (
            Course.where("name", "Biology 302")
            .with_("students")
            .get()
        )
        self.assertIsNone(courses.first().students)

    def test_has_many_through_can_get_related(self):
        """Lazy attribute access should return the related students."""
        course = Course.where("name", "Math 101").first()
        self.assertIsInstance(course.students, Collection)
        self.assertIsInstance(course.students.first(), Student)
        self.assertEqual(course.students.count(), 2)

    def test_has_many_through_has_query(self):
        """where_has should filter courses by a constraint on students."""
        # Bob is enrolled in Math 101 and Math 302.
        courses = Course.where_has(
            "students", lambda query: query.where("name", "Bob")
        )
        self.assertEqual(courses.count(), 2)
================================================
FILE: tests/sqlite/relationships/test_sqlite_has_one_through_relationship.py
================================================
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import has_one_through
from tests.integrations.config.database import DATABASES
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import SQLitePlatform
class Port(Model):
    """Intermediary model between a shipment and its country of origin."""

    __connection__ = "dev"
    __table__ = "ports"
    __fillable__ = ["port_id", "name", "port_country_id"]
class Country(Model):
    """Distant model reached from a shipment through its port."""

    __connection__ = "dev"
    __table__ = "countries"
    __fillable__ = ["country_id", "name"]
class IncomingShipment(Model):
    # Local side of the has-one-through chain:
    # IncomingShipment -> Port (intermediary) -> Country.
    __table__ = "incoming_shipments"
    __connection__ = "dev"
    __fillable__ = ["shipment_id", "name", "from_port_id"]

    # Decorator keys in order: shipment's FK to the port (from_port_id),
    # port's FK to the country (port_country_id), the port's key (port_id)
    # and the country's key (country_id).
    # NOTE(review): key roles inferred from the fixtures in the test class
    # below — confirm against the has_one_through implementation.
    @has_one_through(None, "from_port_id", "port_country_id", "port_id", "country_id")
    def from_country(self):
        # The relationship resolves [distant model, intermediary model].
        return [Country, Port]
class TestHasOneThroughRelationship(unittest.TestCase):
    """Integration tests for has_one_through on the dev SQLite connection."""

    def setUp(self):
        # Build (or reuse) the three tables in the chain and seed them once;
        # reruns in the same process keep the existing rows.
        self.schema = Schema(
            connection="dev",
            connection_details=DATABASES,
            platform=SQLitePlatform,
        ).on("dev")
        with self.schema.create_table_if_not_exists("incoming_shipments") as table:
            table.integer("shipment_id").primary()
            table.string("name")
            table.integer("from_port_id")
        with self.schema.create_table_if_not_exists("ports") as table:
            table.integer("port_id").primary()
            table.string("name")
            table.integer("port_country_id")
        with self.schema.create_table_if_not_exists("countries") as table:
            table.integer("country_id").primary()
            table.string("name")
        if not Country.count():
            Country.builder.new().bulk_create(
                [
                    {"country_id": 10, "name": "Australia"},
                    {"country_id": 20, "name": "USA"},
                    {"country_id": 30, "name": "Canada"},
                    {"country_id": 40, "name": "United Kingdom"},
                ]
            )
        if not Port.count():
            Port.builder.new().bulk_create(
                [
                    {"port_id": 100, "name": "Melbourne", "port_country_id": 10},
                    {"port_id": 200, "name": "Darwin", "port_country_id": 10},
                    {"port_id": 300, "name": "South Louisiana", "port_country_id": 20},
                    {"port_id": 400, "name": "Houston", "port_country_id": 20},
                    {"port_id": 500, "name": "Montreal", "port_country_id": 30},
                    {"port_id": 600, "name": "Vancouver", "port_country_id": 30},
                    {"port_id": 700, "name": "Southampton", "port_country_id": 40},
                    {"port_id": 800, "name": "London Gateway", "port_country_id": 40},
                ]
            )
        if not IncomingShipment.count():
            IncomingShipment.builder.new().bulk_create(
                [
                    {"name": "Bread", "from_port_id": 300},
                    {"name": "Milk", "from_port_id": 100},
                    {"name": "Tractor Parts", "from_port_id": 100},
                    {"name": "Fridges", "from_port_id": 700},
                    {"name": "Wheat", "from_port_id": 600},
                    {"name": "Kettles", "from_port_id": 400},
                    {"name": "Bread", "from_port_id": 700},
                ]
            )

    def test_has_one_through_can_eager_load(self):
        # Two "Bread" shipments exist: from South Louisiana (USA, country 20)
        # and Southampton (United Kingdom, country 40).
        shipments = IncomingShipment.where("name", "Bread").with_("from_country").get()
        self.assertEqual(shipments.count(), 2)
        shipment1 = shipments.shift()
        self.assertIsInstance(shipment1.from_country, Country)
        self.assertEqual(shipment1.from_country.country_id, 20)
        shipment2 = shipments.shift()
        self.assertIsInstance(shipment2.from_country, Country)
        self.assertEqual(shipment2.from_country.country_id, 40)
        # check .first() and .get() produce the same result
        single = (
            IncomingShipment.where("name", "Tractor Parts")
            .with_("from_country")
            .first()
        )
        single_get = (
            IncomingShipment.where("name", "Tractor Parts").with_("from_country").get()
        )
        self.assertEqual(single.from_country.country_id, 10)
        self.assertEqual(single_get.count(), 1)
        self.assertEqual(
            single.from_country.country_id, single_get.first().from_country.country_id
        )

    def test_has_one_through_eager_load_can_be_empty(self):
        # "Ueaguay" deliberately matches no seeded country, so the where_has
        # filter removes every shipment from the result.
        shipments = (
            IncomingShipment.where("name", "Bread")
            .where_has("from_country", lambda query: query.where("name", "Ueaguay"))
            .with_(
                "from_country",
            )
            .get()
        )
        self.assertEqual(shipments.count(), 0)

    def test_has_one_through_can_get_related(self):
        # Milk ships from Melbourne, which belongs to Australia (country 10).
        shipment = IncomingShipment.where("name", "Milk").first()
        self.assertIsInstance(shipment.from_country, Country)
        self.assertEqual(shipment.from_country.country_id, 10)

    def test_has_one_through_has_query(self):
        # Only Bread (South Louisiana) and Kettles (Houston) ship from USA ports.
        shipments = IncomingShipment.where_has(
            "from_country", lambda query: query.where("name", "USA")
        )
        self.assertEqual(shipments.count(), 2)
================================================
FILE: tests/sqlite/relationships/test_sqlite_polymorphic.py
================================================
import os
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import belongs_to, has_many, morph_to
from tests.integrations.config.database import DB
class Profile(Model):
    """Model mapped to the profiles table on the dev connection."""

    __connection__ = "dev"
    __table__ = "profiles"
class Articles(Model):
    # Registered as "article" in the morph map below.
    __table__ = "articles"
    __connection__ = "dev"

    @belongs_to("id", "article_id")
    def logo(self):
        # NOTE(review): decorator keys are ("id", "article_id") — confirm
        # which is local vs foreign against the belongs_to implementation.
        return Logo
class Logo(Model):
    # Related model for Articles.logo.
    __table__ = "logos"
    __connection__ = "dev"
class Like(Model):
    # No __table__ declared — the table name is presumably derived from the
    # class name; confirm against the model's table-naming convention.
    __connection__ = "dev"

    @morph_to("record_type", "record_id")
    def record(self):
        # Polymorphic parent: record_type selects the model class via
        # DB.morph_map, record_id identifies the related row.
        return self
class User(Model):
    # Registered as "user" in the morph map below.
    __connection__ = "dev"
    # Presumably resets class-level eager loads so state does not leak
    # between test modules — confirm against the Model base class.
    _eager_loads = ()
# Register polymorphic type names so Like.record can resolve stored
# record_type values ("user"/"article") back to their model classes.
DB.morph_map({"user": User, "article": Articles})
class TestRelationships(unittest.TestCase):
    """Polymorphic (morph_to) relationship tests against SQLite."""

    maxDiff = None

    def test_can_get_polymorphic_relation(self):
        """Lazy access resolves each like to its concrete record model."""
        for current in Like.get():
            self.assertIsInstance(current.record, (Articles, User))

    def test_can_get_eager_load_polymorphic_relation(self):
        """Eager loading resolves each like to its concrete record model."""
        for current in Like.with_("record").get():
            self.assertIsInstance(current.record, (Articles, User))
================================================
FILE: tests/sqlite/relationships/test_sqlite_relationships.py
================================================
import unittest
from src.masoniteorm.models import Model
from src.masoniteorm.relationships import belongs_to, has_many, has_one, belongs_to_many
from tests.integrations.config.database import DB
class Profile(Model):
    """Model mapped to the profiles table on the dev connection."""

    __connection__ = "dev"
    __table__ = "profiles"
class Articles(Model):
    __table__ = "articles"
    __connection__ = "dev"
    # Disables automatic created_at/updated_at handling for this model.
    __timestamps__ = None
    # published_date is cast to a date object on access (see
    # test_can_access_relationship_date, which calls .is_past() on it).
    __dates__ = ["published_date"]

    @belongs_to("id", "article_id")
    def logo(self):
        # NOTE(review): decorator keys are ("id", "article_id") — confirm
        # which is local vs foreign against the belongs_to implementation.
        return Logo
class Logo(Model):
    """Model mapped to the logos table; published_date is date-cast."""

    __connection__ = "dev"
    __table__ = "logos"
    __dates__ = ["published_date"]
    __timestamps__ = None
class User(Model):
    __connection__ = "dev"
    # Presumably resets class-level eager loads so state does not leak
    # between test modules — confirm against the Model base class.
    _eager_loads = ()
    # is_admin is cast to a boolean when hydrated from the database.
    __casts__ = {"is_admin": "bool"}

    @belongs_to("id", "user_id")
    def profile(self):
        return Profile

    @has_many("id", "user_id")
    def articles(self):
        return Articles

    def get_is_admin(self):
        # Accessor-style method; presumably intercepts is_admin access —
        # confirm against the model accessor convention.
        return "You are an admin"
class Store(Model):
    __connection__ = "dev"

    # Pivot keys store_id/product_id with "id" on both sides;
    # with_timestamps exposes the pivot's created_at/updated_at when
    # serializing (asserted in test_belongs_to_many).
    @belongs_to_many("store_id", "product_id", "id", "id", with_timestamps=True)
    def products(self):
        return Product

    # Same relation, but with an explicitly named pivot table.
    @belongs_to_many("store_id", "product_id", "id", "id", table="product_table")
    def products_table(self):
        return Product

    # Bare decorator form: keys and pivot table are presumably derived from
    # naming conventions — confirm against belongs_to_many defaults.
    @belongs_to_many
    def store_products(self):
        return Product
class Product(Model):
    # No __table__ declared; only the connection is set. The table name is
    # presumably derived from the class name — confirm against conventions.
    __connection__ = "dev"
class UserHasOne(Model):
    # Same users table as User above, but exposing profile as has_one.
    __table__ = "users"
    __connection__ = "dev"

    @has_one("user_id", "user_id")
    def profile(self):
        return Profile
class TestRelationships(unittest.TestCase):
    """Relationship tests: belongs_to, has_one, has_many, belongs_to_many."""

    maxDiff = None

    def test_relationship_can_be_callable(self):
        """Calling a relationship returns a builder over the related table."""
        self.assertEqual(
            User.profile().where("name", "Joe").to_sql(),
            """SELECT * FROM "profiles" WHERE "profiles"."name" = 'Joe'""",
        )

    def test_can_access_relationship(self):
        for user in User.where("id", 1).get():
            self.assertIsInstance(user.profile, Profile)

    def test_can_access_has_many_relationship(self):
        user = User.hydrate(User.where("id", 1).first())
        self.assertEqual(len(user.articles), 1)

    def test_can_access_relationship_multiple_times(self):
        # Accessing the relation twice must yield the same result.
        user = User.hydrate(User.where("id", 1).first())
        self.assertEqual(len(user.articles), 1)
        self.assertEqual(len(user.articles), 1)

    def test_can_access_relationship_date(self):
        """__dates__ casting exposes date helpers on eager-loaded relations."""
        user = User.with_("articles").where("id", 1).first()
        for article in user.articles:
            # Bug fix: removed leftover debug print(); calling is_past()
            # still exercises the date cast on the nested relation.
            self.assertIsNotNone(article.logo.published_date.is_past())

    def test_loading(self):
        """Smoke test: eager loading articles must not raise."""
        users = User.with_("articles").get()
        for user in users:
            self.assertIsNotNone(user)

    def test_relationship_has_one_sql(self):
        self.assertEqual(UserHasOne.profile().to_sql(), 'SELECT * FROM "profiles"')

    def test_loading_with_nested_with(self):
        """Smoke test: nested eager loads resolve without raising."""
        users = User.with_("articles", "articles.logo").get()
        for user in users:
            for article in user.articles:
                if article.logo:
                    # Bug fix: removed leftover debug print(); the attribute
                    # access still exercises the nested eager load.
                    self.assertIsNotNone(article.logo.url)

    def test_casting(self):
        """Smoke test: filtering on a bool-cast column must not raise."""
        users = User.with_("articles").where("is_admin", True).get()
        for user in users:
            self.assertIsNotNone(user)

    def test_setting(self):
        """Smoke test: mutating and saving hydrated models must not raise."""
        users = User.with_("articles").where("is_admin", True).get()
        for user in users:
            user.name = "Joe"
            user.is_admin = 1
            user.save()

    def test_related(self):
        user = User.first()
        related_query = user.related("profile").where("active", 1).to_sql()
        self.assertEqual(
            related_query,
            """SELECT * FROM "profiles" WHERE "profiles"."user_id" = '1' AND "profiles"."active" = '1'""",
        )

    def test_associate_records(self):
        # Wrap in a transaction and roll back so no rows are left behind.
        DB.begin_transaction("dev")
        user = User.first()
        articles = [Articles.hydrate({"id": 1, "title": "associate records"})]
        user.save_many("articles", articles)
        DB.rollback("dev")

    def test_belongs_to_many(self):
        store = Store.hydrate({"id": 2, "name": "Walmart"})
        self.assertEqual(store.products.count(), 3)
        # Serialize once instead of re-serializing for every assertion.
        first_product = store.products.serialize()[0]
        self.assertEqual(first_product["id"], 4)
        self.assertEqual(first_product["name"], "Handgun")
        self.assertEqual(first_product["updated_at"], "2020-01-01T00:00:00+00:00")
        self.assertEqual(first_product["created_at"], "2020-01-01T00:00:00+00:00")

    def test_belongs_to_eager_many(self):
        store = Store.hydrate({"id": 2, "name": "Walmart"})
        store = Store.with_("products").first()
        self.assertEqual(store.products.count(), 3)
================================================
FILE: tests/sqlite/schema/test_sqlite_schema_builder.py
================================================
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import SQLiteConnection
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import SQLitePlatform
class TestSQLiteSchemaBuilder(unittest.TestCase):
    """Dry-run assertions on the exact SQL emitted by SQLitePlatform.

    Each test builds a blueprint against a dry schema (nothing executes)
    and compares blueprint.to_sql() with the expected statement list.
    """

    # Show full diffs for the long SQL string comparisons.
    maxDiff = None

    def setUp(self):
        # dry=True: SQL is compiled but never run against the database.
        self.schema = Schema(
            connection="dev",
            connection_details=DATABASES,
            platform=SQLitePlatform,
            dry=True,
        ).on("dev")

    def test_can_add_columns(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")

        self.assertEqual(len(blueprint.table.added_columns), 2)

        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)'
            ],
        )

    def test_can_add_tiny_text(self):
        # tiny_text maps to plain TEXT on SQLite.
        with self.schema.create("users") as blueprint:
            blueprint.tiny_text("description")
        self.assertEqual(len(blueprint.table.added_columns), 1)

        self.assertEqual(
            blueprint.to_sql(), ['CREATE TABLE "users" ("description" TEXT NOT NULL)']
        )

    def test_can_add_unsigned_decimal(self):
        with self.schema.create("users") as blueprint:
            blueprint.unsigned_decimal("amount", 19, 4)
        self.assertEqual(len(blueprint.table.added_columns), 1)

        self.assertEqual(
            blueprint.to_sql(),
            ['CREATE TABLE "users" ("amount" DECIMAL(19, 4) NOT NULL)'],
        )

    def test_can_create_table_if_not_exists(self):
        with self.schema.create_table_if_not_exists("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")

        self.assertEqual(len(blueprint.table.added_columns), 2)

        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE IF NOT EXISTS "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL)'
            ],
        )

    def test_can_add_columns_with_constraint(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name")
            blueprint.integer("age")
            blueprint.unique("name")

        self.assertEqual(len(blueprint.table.added_columns), 2)

        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("name" VARCHAR(255) NOT NULL, "age" INTEGER NOT NULL, UNIQUE(name))'
            ],
        )

    def test_can_have_float_type(self):
        with self.schema.create("users") as blueprint:
            blueprint.float("amount")

        self.assertEqual(
            blueprint.to_sql(),
            ["""CREATE TABLE "users" (""" """\"amount" FLOAT(19, 4) NOT NULL)"""],
        )

    def test_can_add_columns_with_foreign_key_constraint(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.foreign("profile_id").references("id").on("profiles")

        self.assertEqual(len(blueprint.table.added_columns), 3)

        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" '
                '("name" VARCHAR(255) NOT NULL, '
                '"age" INTEGER NOT NULL, '
                '"profile_id" INTEGER NOT NULL, '
                "UNIQUE(name), "
                'CONSTRAINT users_profile_id_foreign FOREIGN KEY ("profile_id") REFERENCES "profiles"("id"))'
            ],
        )

    def test_can_add_columns_with_foreign_key_constraint_name(self):
        # An explicit name overrides the generated constraint name.
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.foreign("profile_id", name="profile_foreign").references("id").on(
                "profiles"
            )

        self.assertEqual(len(blueprint.table.added_columns), 3)

        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" '
                '("name" VARCHAR(255) NOT NULL, '
                '"age" INTEGER NOT NULL, '
                '"profile_id" INTEGER NOT NULL, '
                "UNIQUE(name), "
                'CONSTRAINT profile_foreign FOREIGN KEY ("profile_id") REFERENCES "profiles"("id"))'
            ],
        )

    def test_can_use_morphs_for_polymorphism_relationships(self):
        # morphs() adds both the *_id and *_type columns plus their indexes.
        with self.schema.create("likes") as blueprint:
            blueprint.morphs("record")
        self.assertEqual(len(blueprint.table.added_columns), 2)
        sql = [
            'CREATE TABLE "likes" ("record_id" INTEGER UNSIGNED NOT NULL, "record_type" VARCHAR(255) NOT NULL)',
            'CREATE INDEX likes_record_id_index ON "likes"(record_id)',
            'CREATE INDEX likes_record_type_index ON "likes"(record_type)',
        ]
        self.assertEqual(blueprint.to_sql(), sql)

    def test_can_advanced_table_creation(self):
        with self.schema.create("users") as blueprint:
            blueprint.increments("id")
            blueprint.string("name")
            blueprint.enum("gender", ["male", "female"])
            blueprint.string("email").unique()
            blueprint.string("password")
            blueprint.string("option").default("ADMIN")
            blueprint.integer("admin").default(0)
            blueprint.string("remember_token").nullable()
            blueprint.timestamp("verified_at").nullable()
            blueprint.unique(["email", "name"])
            blueprint.timestamps()

        self.assertEqual(len(blueprint.table.added_columns), 11)

        self.assertEqual(
            blueprint.to_sql(),
            [
                """CREATE TABLE "users" ("id" INTEGER NOT NULL, "name" VARCHAR(255) NOT NULL, "gender" VARCHAR(255) CHECK(gender IN ('male', 'female')) NOT NULL, "email" VARCHAR(255) NOT NULL, """
                """"password" VARCHAR(255) NOT NULL, "option" VARCHAR(255) NOT NULL DEFAULT 'ADMIN', "admin" INTEGER NOT NULL DEFAULT 0, "remember_token" VARCHAR(255) NULL, """
                '"verified_at" TIMESTAMP NULL, "created_at" DATETIME NULL DEFAULT CURRENT_TIMESTAMP, '
                '"updated_at" DATETIME NULL DEFAULT CURRENT_TIMESTAMP, CONSTRAINT users_id_primary PRIMARY KEY (id), '
                "UNIQUE(email), UNIQUE(email, name))"
            ],
        )

    def test_can_create_indexes(self):
        # NOTE(review): fulltext() is requested but no statement for it
        # appears in the expected SQL — presumably unsupported on SQLite
        # and silently dropped; confirm against the platform implementation.
        with self.schema.table("users") as blueprint:
            blueprint.index("name")
            blueprint.index("active", "active_idx")
            blueprint.index(["name", "email"])
            blueprint.unique("name")
            blueprint.unique(["name", "email"])
            blueprint.fulltext("description")

        self.assertEqual(len(blueprint.table.added_columns), 0)

        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE INDEX users_name_index ON "users"(name)',
                'CREATE INDEX active_idx ON "users"(active)',
                'CREATE INDEX users_name_email_index ON "users"(name,email)',
                'CREATE UNIQUE INDEX users_name_unique ON "users"(name)',
                'CREATE UNIQUE INDEX users_name_email_unique ON "users"(name,email)',
            ],
        )

    def test_can_create_indexes_on_previous_column(self):
        with self.schema.table("users") as blueprint:
            blueprint.string("email").index()
            blueprint.string("active").index(name="email_idx")

        self.assertEqual(len(blueprint.table.added_columns), 2)

        self.assertEqual(
            blueprint.to_sql(),
            [
                'ALTER TABLE "users" ADD COLUMN "email" VARCHAR NOT NULL',
                'ALTER TABLE "users" ADD COLUMN "active" VARCHAR NOT NULL',
                'CREATE INDEX users_email_index ON "users"(email)',
                'CREATE INDEX email_idx ON "users"(active)',
            ],
        )

    def test_can_have_composite_keys(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").unique()
            blueprint.integer("age")
            blueprint.integer("profile_id")
            blueprint.primary(["name", "age"])

        self.assertEqual(len(blueprint.table.added_columns), 3)

        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" '
                '("name" VARCHAR(255) NOT NULL, '
                '"age" INTEGER NOT NULL, '
                '"profile_id" INTEGER NOT NULL, '
                "UNIQUE(name), "
                "CONSTRAINT users_name_age_primary PRIMARY KEY (name, age))"
            ],
        )

    def test_can_have_column_primary_key(self):
        with self.schema.create("users") as blueprint:
            blueprint.string("name").primary()
            blueprint.integer("age")
            blueprint.integer("profile_id")

        self.assertEqual(len(blueprint.table.added_columns), 3)

        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" '
                '("name" VARCHAR(255) NOT NULL, '
                '"age" INTEGER NOT NULL, '
                '"profile_id" INTEGER NOT NULL, '
                "CONSTRAINT users_name_primary PRIMARY KEY (name))"
            ],
        )

    def test_can_advanced_table_creation2(self):
        with self.schema.create("users") as blueprint:
            blueprint.big_increments("id")
            blueprint.string("name")
            blueprint.string("duration")
            blueprint.string("url")
            blueprint.json("payload")
            blueprint.year("birth")
            blueprint.inet("last_address").nullable()
            blueprint.cidr("route_origin").nullable()
            blueprint.macaddr("mac_address").nullable()
            blueprint.datetime("published_at")
            blueprint.time("wakeup_at")
            blueprint.string("thumbnail").nullable()
            blueprint.integer("premium")
            blueprint.integer("author_id").unsigned().nullable()
            blueprint.foreign("author_id").references("id").on("users").on_delete(
                "set null"
            )
            blueprint.text("description")
            blueprint.timestamps()

        self.assertEqual(len(blueprint.table.added_columns), 17)

        self.assertEqual(
            blueprint.to_sql(),
            (
                [
                    'CREATE TABLE "users" ("id" BIGINT NOT NULL, "name" VARCHAR(255) NOT NULL, "duration" VARCHAR(255) NOT NULL, '
                    '"url" VARCHAR(255) NOT NULL, "payload" JSON NOT NULL, "birth" VARCHAR(4) NOT NULL, "last_address" VARCHAR(255) NULL, "route_origin" VARCHAR(255) NULL, "mac_address" VARCHAR(255) NULL, '
                    '"published_at" DATETIME NOT NULL, "wakeup_at" TIME NOT NULL, "thumbnail" VARCHAR(255) NULL, "premium" INTEGER NOT NULL, "author_id" INTEGER UNSIGNED NULL, "description" TEXT NOT NULL, '
                    '"created_at" DATETIME NULL DEFAULT CURRENT_TIMESTAMP, "updated_at" DATETIME NULL DEFAULT CURRENT_TIMESTAMP, '
                    'CONSTRAINT users_id_primary PRIMARY KEY (id), CONSTRAINT users_author_id_foreign FOREIGN KEY ("author_id") REFERENCES "users"("id") ON DELETE SET NULL)'
                ]
            ),
        )

    def test_has_table(self):
        # In dry mode has_table returns the query it would run.
        schema_sql = self.schema.has_table("users")

        sql = "SELECT name FROM sqlite_master WHERE type='table' AND name='users'"

        self.assertEqual(schema_sql, sql)

    def test_can_truncate(self):
        # SQLite has no TRUNCATE statement; DELETE is emitted instead.
        sql = self.schema.truncate("users")

        self.assertEqual(sql, 'DELETE FROM "users"')

    def test_can_rename_table(self):
        sql = self.schema.rename("users", "clients")

        self.assertEqual(sql, 'ALTER TABLE "users" RENAME TO "clients"')

    def test_can_drop_table_if_exists(self):
        sql = self.schema.drop_table_if_exists("users", "clients")

        self.assertEqual(sql, 'DROP TABLE IF EXISTS "users"')

    def test_can_drop_table(self):
        sql = self.schema.drop_table("users", "clients")

        self.assertEqual(sql, 'DROP TABLE "users"')

    def test_has_column(self):
        sql = self.schema.has_column("users", "name")

        self.assertEqual(
            sql,
            "SELECT column_name FROM information_schema.columns WHERE table_name='users' and column_name='name'",
        )

    def test_can_have_unsigned_columns(self):
        with self.schema.create("users") as blueprint:
            blueprint.integer("profile_id").unsigned()
            blueprint.big_integer("big_profile_id").unsigned()
            blueprint.tiny_integer("tiny_profile_id").unsigned()
            blueprint.small_integer("small_profile_id").unsigned()
            blueprint.medium_integer("medium_profile_id").unsigned()

        self.assertEqual(
            blueprint.to_sql(),
            [
                """CREATE TABLE "users" ("""
                """"profile_id" INTEGER UNSIGNED NOT NULL, """
                """"big_profile_id" BIGINT UNSIGNED NOT NULL, """
                """"tiny_profile_id" TINYINT UNSIGNED NOT NULL, """
                """"small_profile_id" SMALLINT UNSIGNED NOT NULL, """
                """"medium_profile_id" MEDIUMINT UNSIGNED NOT NULL)"""
            ],
        )

    def test_can_enable_foreign_keys(self):
        sql = self.schema.enable_foreign_key_constraints()

        self.assertEqual(sql, "PRAGMA foreign_keys = ON")

    def test_can_disable_foreign_keys(self):
        sql = self.schema.disable_foreign_key_constraints()

        self.assertEqual(sql, "PRAGMA foreign_keys = OFF")

    def test_can_truncate_without_foreign_keys(self):
        # foreign_keys=True wraps the delete in PRAGMA toggles.
        sql = self.schema.truncate("users", foreign_keys=True)

        self.assertEqual(
            sql,
            [
                "PRAGMA foreign_keys = OFF",
                'DELETE FROM "users"',
                "PRAGMA foreign_keys = ON",
            ],
        )

    def test_can_add_enum(self):
        # SQLite emulates enums with a CHECK constraint on a VARCHAR.
        with self.schema.create("users") as blueprint:
            blueprint.enum("status", ["active", "inactive"]).default("active")
        self.assertEqual(len(blueprint.table.added_columns), 1)
        self.assertEqual(
            blueprint.to_sql(),
            [
                'CREATE TABLE "users" ("status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL DEFAULT \'active\')'
            ],
        )
================================================
FILE: tests/sqlite/schema/test_sqlite_schema_builder_alter.py
================================================
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import SQLiteConnection
from src.masoniteorm.schema import Schema
from src.masoniteorm.schema.platforms import SQLitePlatform
from src.masoniteorm.schema.Table import Table
class TestSQLiteSchemaBuilderAlter(unittest.TestCase):
maxDiff = None
def setUp(self):
self.schema = Schema(
connection="dev",
connection_details=DATABASES,
platform=SQLitePlatform,
dry=True,
).on("dev")
def test_can_add_columns(self):
with self.schema.table("users") as blueprint:
blueprint.string("name")
blueprint.string("external_type").default("external")
blueprint.integer("age")
self.assertEqual(len(blueprint.table.added_columns), 3)
sql = [
'ALTER TABLE "users" ADD COLUMN "name" VARCHAR NOT NULL',
"""ALTER TABLE "users" ADD COLUMN "external_type" VARCHAR NOT NULL DEFAULT 'external'""",
'ALTER TABLE "users" ADD COLUMN "age" INTEGER NOT NULL',
]
self.assertEqual(blueprint.to_sql(), sql)
def test_can_add_constraints(self):
with self.schema.table("users") as blueprint:
blueprint.unique("name", name="table_unique")
self.assertEqual(len(blueprint.table.added_columns), 0)
sql = ['CREATE UNIQUE INDEX table_unique ON "users"(name)']
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_rename(self):
with self.schema.table("users") as blueprint:
blueprint.rename("post", "comment", "integer")
table = Table("users")
table.add_column("post", "integer")
blueprint.table.from_table = table
sql = [
"CREATE TEMPORARY TABLE __temp__users AS SELECT post FROM users",
'DROP TABLE "users"',
'CREATE TABLE "users" ("comment" INTEGER NOT NULL)',
'INSERT INTO "users" ("comment") SELECT post FROM __temp__users',
"DROP TABLE __temp__users",
]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_drop(self):
with self.schema.table("users") as blueprint:
blueprint.drop_column("post")
table = Table("users")
table.add_column("post", "string")
table.add_column("name", "string")
table.add_column("email", "string")
blueprint.table.from_table = table
sql = [
"CREATE TEMPORARY TABLE __temp__users AS SELECT name, email FROM users",
'DROP TABLE "users"',
'CREATE TABLE "users" ("name" VARCHAR NOT NULL, "email" VARCHAR NOT NULL)',
'INSERT INTO "users" ("name", "email") SELECT name, email FROM __temp__users',
"DROP TABLE __temp__users",
]
self.assertEqual(blueprint.to_sql(), sql)
def test_change(self):
with self.schema.table("users") as blueprint:
blueprint.integer("age").change()
blueprint.string("name")
self.assertEqual(len(blueprint.table.added_columns), 1)
self.assertEqual(len(blueprint.table.changed_columns), 1)
table = Table("users")
table.add_column("age", "string")
blueprint.table.from_table = table
sql = [
'ALTER TABLE "users" ADD COLUMN "name" VARCHAR NOT NULL',
"CREATE TEMPORARY TABLE __temp__users AS SELECT age FROM users",
'DROP TABLE "users"',
'CREATE TABLE "users" ("age" INTEGER NOT NULL, "name" VARCHAR(255) NOT NULL)',
'INSERT INTO "users" ("age") SELECT age FROM __temp__users',
"DROP TABLE __temp__users",
]
self.assertEqual(blueprint.to_sql(), sql)
def test_drop_add_and_change(self):
with self.schema.table("users") as blueprint:
blueprint.integer("age").change()
blueprint.string("name")
blueprint.drop_column("email")
self.assertEqual(len(blueprint.table.added_columns), 1)
self.assertEqual(len(blueprint.table.changed_columns), 1)
table = Table("users")
table.add_column("age", "string")
table.add_column("email", "string")
blueprint.table.from_table = table
sql = [
'ALTER TABLE "users" ADD COLUMN "name" VARCHAR',
"CREATE TEMPORARY TABLE __temp__users AS SELECT age FROM users",
'DROP TABLE "users"',
'CREATE TABLE "users" ("age" INTEGER NOT NULL, "name" VARCHAR(255) NOT NULL)',
'INSERT INTO "users" ("age") SELECT age FROM __temp__users',
"DROP TABLE __temp__users",
]
def test_timestamp_alter_add_nullable_column(self):
with self.schema.table("users") as blueprint:
blueprint.timestamp("due_date").nullable()
self.assertEqual(len(blueprint.table.added_columns), 1)
table = Table("users")
table.add_column("age", "string")
blueprint.table.from_table = table
sql = ['ALTER TABLE "users" ADD COLUMN "due_date" TIMESTAMP NULL']
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_drop_on_table_schema_table(self):
schema = Schema(connection="dev", connection_details=DATABASES).on("dev")
with schema.table("table_schema") as blueprint:
blueprint.drop_column("name")
with schema.table("table_schema") as blueprint:
blueprint.string("name").nullable()
def test_alter_add_primary(self):
with self.schema.table("users") as blueprint:
blueprint.primary("playlist_id")
sql = [
'ALTER TABLE "users" ADD CONSTRAINT users_playlist_id_primary PRIMARY KEY (playlist_id)'
]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_add_column_and_foreign_key(self):
with self.schema.table("users") as blueprint:
blueprint.unsigned_integer("playlist_id").nullable()
blueprint.foreign("playlist_id").references("id").on("playlists").on_delete(
"cascade"
).on_update("SET NULL")
table = Table("users")
table.add_column("age", "string")
table.add_column("email", "string")
blueprint.table.from_table = table
sql = [
'ALTER TABLE "users" ADD COLUMN "playlist_id" INTEGER UNSIGNED NULL REFERENCES "playlists"("id")',
"CREATE TEMPORARY TABLE __temp__users AS SELECT age, email FROM users",
'DROP TABLE "users"',
'CREATE TABLE "users" ("age" VARCHAR NOT NULL, "email" VARCHAR NOT NULL, "playlist_id" INTEGER UNSIGNED NULL, '
'CONSTRAINT users_playlist_id_foreign FOREIGN KEY ("playlist_id") REFERENCES "playlists"("id") ON DELETE CASCADE ON UPDATE SET NULL)',
'INSERT INTO "users" ("age", "email") SELECT age, email FROM __temp__users',
"DROP TABLE __temp__users",
]
self.assertEqual(blueprint.to_sql(), sql)
def test_alter_add_foreign_key_only(self):
    """A foreign key added on its own (no new column) still requires the
    full temp-table rebuild so the constraint can live in the CREATE."""
    with self.schema.table("users") as blueprint:
        foreign = blueprint.foreign("playlist_id").references("id").on("playlists")
        foreign.on_delete("cascade").on_update("set null")
    existing = Table("users")
    existing.add_column("age", "string")
    existing.add_column("email", "string")
    blueprint.table.from_table = existing
    expected = [
        "CREATE TEMPORARY TABLE __temp__users AS SELECT age, email FROM users",
        'DROP TABLE "users"',
        'CREATE TABLE "users" ("age" VARCHAR NOT NULL, "email" VARCHAR NOT NULL, '
        'CONSTRAINT users_playlist_id_foreign FOREIGN KEY ("playlist_id") REFERENCES "playlists"("id") ON DELETE CASCADE ON UPDATE SET NULL)',
        'INSERT INTO "users" ("age", "email") SELECT age, email FROM __temp__users',
        "DROP TABLE __temp__users",
    ]
    self.assertEqual(blueprint.to_sql(), expected)
def test_can_add_column_enum(self):
    """Enum columns compile to VARCHAR plus a CHECK constraint restricting
    the value to the declared options."""
    with self.schema.table("users") as blueprint:
        blueprint.enum("status", ["active", "inactive"]).default("active")
    self.assertEqual(len(blueprint.table.added_columns), 1)
    expected = [
        'ALTER TABLE "users" ADD COLUMN "status" VARCHAR CHECK(\'status\' IN(\'active\', \'inactive\')) NOT NULL DEFAULT \'active\''
    ]
    self.assertEqual(blueprint.to_sql(), expected)
def test_can_change_column_enum(self):
    # Changing a column on SQLite always goes through the temp-table
    # rebuild path (copy out, drop, recreate, copy back).
    with self.schema.table("users") as blueprint:
        blueprint.enum("status", ["active", "inactive"]).default("active").change()
    # The "existing" table is deliberately empty here, so the rebuild
    # only carries the changed column.
    blueprint.table.from_table = Table("users")
    self.assertEqual(len(blueprint.table.changed_columns), 1)
    # NOTE(review): the first expected statement reads "SELECT FROM users"
    # with no column list -- a consequence of the empty from_table above.
    # Confirm this is the intended platform output rather than a latent
    # bug being pinned by the test.
    sql = [
        'CREATE TEMPORARY TABLE __temp__users AS SELECT FROM users',
        'DROP TABLE "users"',
        'CREATE TABLE "users" ("status" VARCHAR(255) CHECK(status IN (\'active\', \'inactive\')) NOT NULL DEFAULT \'active\')',
        'INSERT INTO "users" ("status") SELECT status FROM __temp__users',
        'DROP TABLE __temp__users'
    ]
    self.assertEqual(blueprint.to_sql(), sql)
================================================
FILE: tests/sqlite/schema/test_table.py
================================================
import unittest
from tests.integrations.config.database import DATABASES
from src.masoniteorm.connections import SQLiteConnection
from src.masoniteorm.schema import Column, Table
from src.masoniteorm.schema.platforms.SQLitePlatform import SQLitePlatform
class TestTable(unittest.TestCase):
    """SQLite platform: compiling CREATE TABLE statements from a Table
    definition, including primary keys, unique and foreign-key constraints."""

    maxDiff = None

    def setUp(self):
        self.platform = SQLitePlatform()

    @staticmethod
    def _build_table(name, columns):
        """Return a Table with the given (column_name, column_type) pairs
        added in order."""
        built = Table(name)
        for column_name, column_type in columns:
            built.add_column(column_name, column_type)
        return built

    def test_add_columns(self):
        users = self._build_table("users", [("name", "string")])
        self.assertIsInstance(users.added_columns["name"], Column)

    def test_primary_key(self):
        users = self._build_table("users", [("id", "integer")])
        users.set_primary_key("id")
        self.assertEqual(users.primary_key, "id")

    def test_create_sql(self):
        users = self._build_table("users", [("id", "integer"), ("name", "string")])
        expected = 'CREATE TABLE "users" ("id" INTEGER NOT NULL, "name" VARCHAR NOT NULL)'
        self.assertEqual([expected], self.platform.compile_create_sql(users))

    def test_create_sql_with_primary_key(self):
        # The primary key does not change the compiled CREATE statement.
        users = self._build_table("users", [("id", "integer"), ("name", "string")])
        users.set_primary_key("id")
        expected = 'CREATE TABLE "users" ("id" INTEGER NOT NULL, "name" VARCHAR NOT NULL)'
        self.assertEqual([expected], self.platform.compile_create_sql(users))

    def test_create_sql_with_unique_constraint(self):
        users = self._build_table("users", [("id", "integer"), ("name", "string")])
        users.add_constraint("name", "unique", ["name"])
        users.set_primary_key("id")
        expected = 'CREATE TABLE "users" ("id" INTEGER NOT NULL, "name" VARCHAR NOT NULL, UNIQUE(name))'
        self.platform.constraintize(users.get_added_constraints())
        self.assertEqual(self.platform.compile_create_sql(users), [expected])

    def test_create_sql_with_multiple_unique_constraints(self):
        # Two single-column UNIQUE constraints appear in declaration order.
        users = self._build_table(
            "users", [("id", "integer"), ("email", "string"), ("name", "string")]
        )
        users.add_constraint("name", "unique", ["name"])
        users.add_constraint("email", "unique", ["email"])
        users.set_primary_key("id")
        expected = 'CREATE TABLE "users" ("id" INTEGER NOT NULL, "email" VARCHAR NOT NULL, "name" VARCHAR NOT NULL, UNIQUE(name), UNIQUE(email))'
        self.platform.constraintize(users.get_added_constraints())
        self.assertEqual(self.platform.compile_create_sql(users), [expected])

    def test_create_sql_with_multiple_unique_constraint(self):
        # A single composite UNIQUE constraint over two columns.
        users = self._build_table(
            "users", [("id", "integer"), ("email", "string"), ("name", "string")]
        )
        users.add_constraint("name", "unique", ["name", "email"])
        users.set_primary_key("id")
        expected = 'CREATE TABLE "users" ("id" INTEGER NOT NULL, "email" VARCHAR NOT NULL, "name" VARCHAR NOT NULL, UNIQUE(name, email))'
        self.platform.constraintize(users.get_added_constraints())
        self.assertEqual(self.platform.compile_create_sql(users), [expected])

    def test_create_sql_with_foreign_key_constraint(self):
        users = self._build_table(
            "users",
            [("id", "integer"), ("profile_id", "integer"), ("comment_id", "integer")],
        )
        users.add_foreign_key("profile_id", "profiles", "id")
        users.add_foreign_key("comment_id", "comments", "id")
        users.set_primary_key("id")
        expected = (
            'CREATE TABLE "users" ('
            '"id" INTEGER NOT NULL, "profile_id" INTEGER NOT NULL, "comment_id" INTEGER NOT NULL, '
            'CONSTRAINT users_profile_id_foreign FOREIGN KEY ("profile_id") REFERENCES "profiles"("id"), '
            'CONSTRAINT users_comment_id_foreign FOREIGN KEY ("comment_id") REFERENCES "comments"("id"))'
        )
        self.platform.constraintize(users.get_added_constraints())
        self.assertEqual(self.platform.compile_create_sql(users), [expected])

    def test_can_build_table_from_connection_call(self):
        # Introspects the real "dev" SQLite database to reconstruct a Table.
        connection_info = DATABASES["dev"]
        reflected = self.platform.get_current_schema(
            SQLiteConnection(
                database=connection_info["database"], name="dev"
            ).make_connection(),
            "table_schema",
        )
        self.assertEqual(len(reflected.added_columns), 4)
================================================
FILE: tests/sqlite/schema/test_table_diff.py
================================================
import unittest
from src.masoniteorm.schema import Column, Table
from src.masoniteorm.schema.platforms.SQLitePlatform import SQLitePlatform
from src.masoniteorm.schema.TableDiff import TableDiff
class TestTableDiff(unittest.TestCase):
    """SQLite platform: compiling ALTER statements from a TableDiff.

    SQLite cannot alter most things in place, so renames/drops of columns
    compile to the temp-table rebuild sequence (copy out, drop, recreate,
    copy back)."""

    def setUp(self):
        self.platform = SQLitePlatform()

    def test_rename_table(self):
        existing = Table("users")
        existing.add_column("name", "string")
        changes = TableDiff("users")
        changes.from_table = existing
        changes.new_name = "clients"
        expected = ['ALTER TABLE "users" RENAME TO "clients"']
        self.assertEqual(expected, self.platform.compile_alter_sql(changes))

    def test_drop_index(self):
        existing = Table("users")
        existing.add_index("name", "name_index", "unique")
        changes = TableDiff("users")
        changes.from_table = existing
        changes.remove_index("name")
        # DROP INDEX uses the name passed to remove_index verbatim.
        expected = ["DROP INDEX name"]
        self.assertEqual(expected, self.platform.compile_alter_sql(changes))

    def test_drop_index_and_rename_table(self):
        existing = Table("users")
        existing.add_index("name", "name_unique", "unique")
        changes = TableDiff("users")
        changes.from_table = existing
        changes.new_name = "clients"
        changes.remove_index("name_unique")
        # Index drops come before the table rename.
        expected = ["DROP INDEX name_unique", 'ALTER TABLE "users" RENAME TO "clients"']
        self.assertEqual(expected, self.platform.compile_alter_sql(changes))

    def test_alter_add_column(self):
        changes = TableDiff("users")
        changes.from_table = Table("users")
        changes.add_column("name", "string")
        changes.add_column("email", "string")
        # Plain column additions compile to one ADD COLUMN per column.
        expected = [
            'ALTER TABLE "users" ADD COLUMN "name" VARCHAR NOT NULL',
            'ALTER TABLE "users" ADD COLUMN "email" VARCHAR NOT NULL',
        ]
        self.assertEqual(expected, self.platform.compile_alter_sql(changes))

    def test_alter_rename(self):
        existing = Table("users")
        existing.add_column("post", "integer")
        changes = TableDiff("users")
        changes.from_table = existing
        changes.rename_column("post", "comment", "integer")
        expected = [
            "CREATE TEMPORARY TABLE __temp__users AS SELECT post FROM users",
            'DROP TABLE "users"',
            'CREATE TABLE "users" ("comment" INTEGER NOT NULL)',
            'INSERT INTO "users" ("comment") SELECT post FROM __temp__users',
            "DROP TABLE __temp__users",
        ]
        self.assertEqual(expected, self.platform.compile_alter_sql(changes))

    def test_alter_advanced_rename_columns(self):
        # Untouched columns (user, email) ride along through the rebuild.
        existing = Table("users")
        existing.add_column("post", "integer")
        existing.add_column("user", "integer")
        existing.add_column("email", "integer")
        changes = TableDiff("users")
        changes.from_table = existing
        changes.rename_column("post", "comment", "integer")
        expected = [
            "CREATE TEMPORARY TABLE __temp__users AS SELECT post, user, email FROM users",
            'DROP TABLE "users"',
            'CREATE TABLE "users" ("comment" INTEGER NOT NULL, "user" INTEGER NOT NULL, "email" INTEGER NOT NULL)',
            'INSERT INTO "users" ("comment", "user", "email") SELECT post, user, email FROM __temp__users',
            "DROP TABLE __temp__users",
        ]
        self.assertEqual(expected, self.platform.compile_alter_sql(changes))

    def test_alter_rename_column_and_rename_table(self):
        # The table rename happens last, after the rebuild completes.
        existing = Table("users")
        existing.add_column("post", "integer")
        changes = TableDiff("users")
        changes.from_table = existing
        changes.new_name = "clients"
        changes.rename_column("post", "comment", "integer")
        expected = [
            "CREATE TEMPORARY TABLE __temp__users AS SELECT post FROM users",
            'DROP TABLE "users"',
            'CREATE TABLE "users" ("comment" INTEGER NOT NULL)',
            'INSERT INTO "users" ("comment") SELECT post FROM __temp__users',
            "DROP TABLE __temp__users",
            'ALTER TABLE "users" RENAME TO "clients"',
        ]
        self.assertEqual(expected, self.platform.compile_alter_sql(changes))

    def test_alter_rename_column_and_rename_table_and_drop_index(self):
        # Ordering: index drops first, then the rebuild, then the rename.
        existing = Table("users")
        existing.add_column("post", "integer")
        existing.add_index("name", "name_unique", "unique")
        changes = TableDiff("users")
        changes.from_table = existing
        changes.new_name = "clients"
        changes.rename_column("post", "comment", "integer")
        changes.remove_index("name")
        expected = [
            "DROP INDEX name",
            "CREATE TEMPORARY TABLE __temp__users AS SELECT post FROM users",
            'DROP TABLE "users"',
            'CREATE TABLE "users" ("comment" INTEGER NOT NULL)',
            'INSERT INTO "users" ("comment") SELECT post FROM __temp__users',
            "DROP TABLE __temp__users",
            'ALTER TABLE "users" RENAME TO "clients"',
        ]
        self.assertEqual(expected, self.platform.compile_alter_sql(changes))

    def test_alter_can_drop_column(self):
        # Dropping a column rebuilds the table without it.
        existing = Table("users")
        existing.add_column("post", "integer")
        existing.add_column("name", "string")
        existing.add_column("email", "string")
        changes = TableDiff("users")
        changes.from_table = existing
        changes.drop_column("post")
        expected = [
            "CREATE TEMPORARY TABLE __temp__users AS SELECT name, email FROM users",
            'DROP TABLE "users"',
            'CREATE TABLE "users" ("name" VARCHAR NOT NULL, "email" VARCHAR NOT NULL)',
            'INSERT INTO "users" ("name", "email") SELECT name, email FROM __temp__users',
            "DROP TABLE __temp__users",
        ]
        self.assertEqual(expected, self.platform.compile_alter_sql(changes))
================================================
FILE: tests/utils.py
================================================
from unittest import mock
from src.masoniteorm.connections.ConnectionFactory import ConnectionFactory
from src.masoniteorm.connections.MySQLConnection import MySQLConnection
from src.masoniteorm.connections.SQLiteConnection import SQLiteConnection
from src.masoniteorm.schema.platforms import MySQLPlatform
class MockMySQLConnection(MySQLConnection):
    """MySQL connection test double that never opens a real connection."""

    def make_connection(self):
        # Stand in MagicMocks for the driver connection and cursor so
        # queries can be "executed" without a server.
        self._connection, self._cursor = mock.MagicMock(), mock.MagicMock()
        return self

    @classmethod
    def get_default_platform(cls):
        """Compile SQL with the real MySQL platform (grammar) layer."""
        return MySQLPlatform
class MockMSSQLConnection(MySQLConnection):
    # MSSQL connection test double.
    # NOTE(review): this subclasses MySQLConnection and returns
    # MySQLPlatform -- presumably an intentional shortcut because the
    # tests using it only need *a* working mock, but confirm it should
    # not use the real MSSQL connection/platform classes instead.
    def make_connection(self):
        # Replace the driver connection and cursor with MagicMocks.
        self._connection = mock.MagicMock()
        self._cursor = mock.MagicMock()
        return self
    @classmethod
    def get_default_platform(cls):
        return MySQLPlatform
class MockPostgresConnection(MySQLConnection):
    # Postgres connection test double.
    # NOTE(review): like MockMSSQLConnection this subclasses
    # MySQLConnection, and unlike the other mocks it neither mocks
    # self._cursor nor overrides get_default_platform -- verify the tests
    # relying on it really intend MySQL grammar and no cursor mock.
    def make_connection(self):
        self._connection = mock.MagicMock()
        return self
class MockSQLiteConnection(SQLiteConnection):
    """SQLite connection test double: fake driver handle, no-op queries."""

    def make_connection(self):
        # Only the connection object itself is mocked for SQLite.
        self._connection = mock.MagicMock()
        return self

    def query(self, *args, **kwargs):
        """Swallow every query and hand the caller an empty result."""
        return {}
class MockConnectionFactory(ConnectionFactory):
    # Overrides the factory's driver registry so tests resolve the mock
    # connection classes above instead of real database drivers.
    # "oracle" is an empty placeholder with no mock implementation.
    _connections = {
        "mysql": MockMySQLConnection,
        "mssql": MockMSSQLConnection,
        "postgres": MockPostgresConnection,
        "sqlite": MockSQLiteConnection,
        "oracle": "",
    }