Repository: go-gorm/gorm Branch: master Commit: 4380dd6dd1a5 Files: 184 Total size: 1.2 MB Directory structure: gitextract_lfhwtntz/ ├── .github/ │ ├── FUNDING.yml │ ├── dependabot.yml │ ├── labels.json │ ├── release-drafter.yml │ └── workflows/ │ ├── create-release.yml │ ├── golangci-lint.yml │ ├── invalid_question.yml │ ├── labeler.yml │ ├── missing_playground.yml │ ├── stale.yml │ └── tests.yml ├── .gitignore ├── .golangci.yml ├── CODE_OF_CONDUCT.md ├── LICENSE ├── README.md ├── association.go ├── callbacks/ │ ├── associations.go │ ├── callbacks.go │ ├── callmethod.go │ ├── create.go │ ├── create_test.go │ ├── delete.go │ ├── helper.go │ ├── helper_test.go │ ├── interfaces.go │ ├── preload.go │ ├── query.go │ ├── raw.go │ ├── row.go │ ├── transaction.go │ └── update.go ├── callbacks.go ├── chainable_api.go ├── clause/ │ ├── association.go │ ├── benchmarks_test.go │ ├── clause.go │ ├── clause_test.go │ ├── delete.go │ ├── delete_test.go │ ├── expression.go │ ├── expression_test.go │ ├── from.go │ ├── from_test.go │ ├── group_by.go │ ├── group_by_test.go │ ├── insert.go │ ├── insert_test.go │ ├── joins.go │ ├── joins_test.go │ ├── limit.go │ ├── limit_test.go │ ├── locking.go │ ├── locking_test.go │ ├── on_conflict.go │ ├── order_by.go │ ├── order_by_test.go │ ├── returning.go │ ├── returning_test.go │ ├── select.go │ ├── select_test.go │ ├── set.go │ ├── set_test.go │ ├── update.go │ ├── update_test.go │ ├── values.go │ ├── values_test.go │ ├── where.go │ ├── where_test.go │ └── with.go ├── errors.go ├── finisher_api.go ├── generics.go ├── go.mod ├── go.sum ├── gorm.go ├── interfaces.go ├── internal/ │ ├── lru/ │ │ └── lru.go │ └── stmt_store/ │ └── stmt_store.go ├── logger/ │ ├── logger.go │ ├── slog.go │ ├── slog_test.go │ ├── sql.go │ └── sql_test.go ├── migrator/ │ ├── column_type.go │ ├── index.go │ ├── migrator.go │ └── table_type.go ├── migrator.go ├── model.go ├── prepare_stmt.go ├── scan.go ├── schema/ │ ├── callbacks_test.go │ ├── constraint.go │ ├── 
constraint_test.go │ ├── field.go │ ├── field_test.go │ ├── index.go │ ├── index_test.go │ ├── interfaces.go │ ├── model_test.go │ ├── naming.go │ ├── naming_test.go │ ├── pool.go │ ├── relationship.go │ ├── relationship_test.go │ ├── schema.go │ ├── schema_helper_test.go │ ├── schema_test.go │ ├── serializer.go │ ├── serializer_test.go │ ├── utils.go │ └── utils_test.go ├── soft_delete.go ├── statement.go ├── statement_test.go ├── tests/ │ ├── .gitignore │ ├── README.md │ ├── association_generics_test.go │ ├── associations_belongs_to_test.go │ ├── associations_has_many_test.go │ ├── associations_has_one_test.go │ ├── associations_many2many_test.go │ ├── associations_test.go │ ├── benchmark_test.go │ ├── callbacks_test.go │ ├── chainable_api_test.go │ ├── compose.yml │ ├── connection_test.go │ ├── connpool_test.go │ ├── count_test.go │ ├── create_test.go │ ├── customize_field_test.go │ ├── default_value_test.go │ ├── delete_test.go │ ├── distinct_test.go │ ├── embedded_struct_test.go │ ├── error_translator_test.go │ ├── gaussdb_test.go │ ├── generics_test.go │ ├── go.mod │ ├── gorm_test.go │ ├── group_by_test.go │ ├── helper_test.go │ ├── hooks_test.go │ ├── joins_table_test.go │ ├── joins_test.go │ ├── lru_test.go │ ├── main_test.go │ ├── migrate_test.go │ ├── multi_primary_keys_test.go │ ├── named_argument_test.go │ ├── named_polymorphic_test.go │ ├── non_std_test.go │ ├── postgres_test.go │ ├── preload_suits_test.go │ ├── preload_test.go │ ├── prepared_stmt_test.go │ ├── query_test.go │ ├── scan_test.go │ ├── scanner_valuer_test.go │ ├── scopes_test.go │ ├── serializer_test.go │ ├── soft_delete_test.go │ ├── sql_builder_test.go │ ├── submodel_test.go │ ├── table_test.go │ ├── tests_all.sh │ ├── tests_test.go │ ├── tracer_test.go │ ├── transaction_test.go │ ├── update_belongs_to_test.go │ ├── update_has_many_test.go │ ├── update_has_one_test.go │ ├── update_many2many_test.go │ ├── update_test.go │ └── upsert_test.go └── utils/ ├── tests/ │ ├── dummy_dialecter.go 
│ ├── models.go │ └── utils.go ├── utils.go ├── utils_test.go ├── utils_unix_test.go └── utils_windows_test.go ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/FUNDING.yml ================================================ # These are supported funding model platforms github: [jinzhu] patreon: jinzhu open_collective: gorm ================================================ FILE: .github/dependabot.yml ================================================ --- version: 2 updates: - package-ecosystem: gomod directory: / schedule: interval: weekly - package-ecosystem: github-actions directory: / schedule: interval: weekly - package-ecosystem: gomod directory: /tests schedule: interval: weekly ================================================ FILE: .github/labels.json ================================================ { "labels": { "critical": { "name": "type:critical", "colour": "#E84137", "description": "critical questions" }, "question": { "name": "type:question", "colour": "#EDEDED", "description": "general questions" }, "feature": { "name": "type:feature_request", "colour": "#43952A", "description": "feature request" }, "invalid_question": { "name": "type:invalid question", "colour": "#CF2E1F", "description": "invalid question (not related to GORM or described in document or not enough information provided)" }, "with_playground": { "name": "type:with reproduction steps", "colour": "#00ff00", "description": "with reproduction steps" }, "without_playground": { "name": "type:missing reproduction steps", "colour": "#CF2E1F", "description": "missing reproduction steps" }, "has_pr": { "name": "type:has pull request", "colour": "#43952A", "description": "has pull request" }, "not_tested": { "name": "type:not tested", "colour": "#CF2E1F", "description": "not tested" }, "tested": { "name": "type:tested", "colour": "#00ff00", "description": "tested" }, 
"breaking_change": { "name": "type:breaking change", "colour": "#CF2E1F", "description": "breaking change" } }, "issue": { "with_playground": { "requires": 1, "conditions": [ { "type": "descriptionMatches", "pattern": "/github.com\/go-gorm\/playground\/pull\/\\d\\d+/s" } ] }, "critical": { "requires": 1, "conditions": [ { "type": "descriptionMatches", "pattern": "/(critical|urgent)/i" }, { "type": "titleMatches", "pattern": "/(critical|urgent)/i" } ] }, "question": { "requires": 1, "conditions": [ { "type": "titleMatches", "pattern": "/question/i" }, { "type": "descriptionMatches", "pattern": "/question/i" } ] }, "feature": { "requires": 1, "conditions": [ { "type": "titleMatches", "pattern": "/feature/i" }, { "type": "descriptionMatches", "pattern": "/Describe the feature/i" } ] }, "without_playground": { "requires": 6, "conditions": [ { "type": "descriptionMatches", "pattern": "/^((?!github.com\/go-gorm\/playground\/pull\/\\d\\d+).)*$/s" }, { "type": "titleMatches", "pattern": "/^((?!question).)*$/s" }, { "type": "descriptionMatches", "pattern": "/^((?!question).)*$/is" }, { "type": "descriptionMatches", "pattern": "/^((?!Describe the feature).)*$/is" }, { "type": "titleMatches", "pattern": "/^((?!critical|urgent).)*$/s" }, { "type": "descriptionMatches", "pattern": "/^((?!critical|urgent).)*$/s" } ] } }, "pr": { "critical": { "requires": 1, "conditions": [ { "type": "descriptionMatches", "pattern": "/(critical|urgent)/i" }, { "type": "titleMatches", "pattern": "/(critical|urgent)/i" } ] }, "not_tested": { "requires": 1, "conditions": [ { "type": "descriptionMatches", "pattern": "/\\[\\] Tested/" } ] }, "breaking_change": { "requires": 1, "conditions": [ { "type": "descriptionMatches", "pattern": "/\\[\\] Non breaking API changes/" } ] } } } ================================================ FILE: .github/release-drafter.yml ================================================ name-template: 'v Release $NEXT_PATCH_VERSION 🌈' tag-template: 'v$NEXT_PATCH_VERSION' 
categories: - title: '🚀 Features' labels: - 'feature' - 'enhancement' - title: '🐛 Bug Fixes' labels: - 'fix' - 'bugfix' - 'bug' - title: '🧰 Maintenance' label: 'chore' change-template: '- $TITLE @$AUTHOR (#$NUMBER)' change-title-escapes: '\<*_&' template: | ## Changes $CHANGES ================================================ FILE: .github/workflows/create-release.yml ================================================ name: Create Release on: push: tags: - 'v*.*.*' permissions: contents: write pull-requests: read jobs: create_release: name: Create Release runs-on: ubuntu-latest steps: - name: Checkout code uses: actions/checkout@v4 - name: Generate Release Notes and Publish id: generate_release_notes uses: release-drafter/release-drafter@v6 with: config-name: 'release-drafter.yml' name: "Release ${{ github.ref_name }}" tag: ${{ github.ref_name }} publish: true prerelease: false env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} ================================================ FILE: .github/workflows/golangci-lint.yml ================================================ name: golangci-lint on: push: branches: - main - master pull_request: permissions: contents: read pull-requests: read jobs: golangci: name: lint runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 - uses: actions/setup-go@v5 with: go-version: stable - name: golangci-lint uses: golangci/golangci-lint-action@v7 with: version: v2.0 only-new-issues: true ================================================ FILE: .github/workflows/invalid_question.yml ================================================ name: "Close invalid questions issues" on: schedule: - cron: "*/10 * * * *" permissions: contents: read jobs: stale: permissions: issues: write # for actions/stale to close stale issues pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues uses: actions/stale@v8 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: 
"This issue has been marked as invalid question, please give more information by following the `Question` template, if you believe there is a bug of GORM, please create a pull request that could reproduce the issue on [https://github.com/go-gorm/playground](https://github.com/go-gorm/playground), the issue will be closed in 30 days if no further activity occurs. most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 days-before-close: 30 remove-stale-when-updated: true only-labels: "type:invalid question" ================================================ FILE: .github/workflows/labeler.yml ================================================ name: "Issue Labeler" on: issues: types: [opened, edited, reopened] pull_request: types: [opened, edited, reopened] jobs: triage: runs-on: ubuntu-latest name: Label issues and pull requests steps: - name: check out uses: actions/checkout@v4 - name: labeler uses: jinzhu/super-labeler-action@develop with: GITHUB_TOKEN: "${{ secrets.GITHUB_TOKEN }}" ================================================ FILE: .github/workflows/missing_playground.yml ================================================ name: "Close Missing Playground issues" on: schedule: - cron: "*/10 * * * *" permissions: contents: read jobs: stale: permissions: issues: write # for actions/stale to close stale issues pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues uses: actions/stale@v8 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "The issue has been automatically marked as stale as it missing playground pull request link, which is important to help others understand your issue effectively and make sure the issue hasn't been fixed on latest master, checkout 
[https://github.com/go-gorm/playground](https://github.com/go-gorm/playground) for details. it will be closed in 30 days if no further activity occurs. if you are asking question, please use the `Question` template, most likely your question already answered https://github.com/go-gorm/gorm/issues or described in the document https://gorm.io ✨ [Search Before Asking](https://stackoverflow.com/help/how-to-ask) ✨" stale-issue-label: "status:stale" days-before-stale: 0 days-before-close: 30 remove-stale-when-updated: true only-labels: "type:missing reproduction steps" ================================================ FILE: .github/workflows/stale.yml ================================================ name: "Stale" on: schedule: - cron: "0 2 * * *" permissions: contents: read jobs: stale: permissions: issues: write # for actions/stale to close stale issues pull-requests: write # for actions/stale to close stale PRs runs-on: ubuntu-latest env: ACTIONS_STEP_DEBUG: true steps: - name: Close Stale Issues uses: actions/stale@v8 with: repo-token: ${{ secrets.GITHUB_TOKEN }} stale-issue-message: "This issue has been automatically marked as stale because it has been open 360 days with no activity. 
Remove stale label or comment or this will be closed in 180 days" days-before-stale: 360 days-before-close: 180 stale-issue-label: "status:stale" exempt-issue-labels: 'type:feature,type:with reproduction steps,type:has pull request' stale-pr-label: 'status:stale' exempt-pr-labels: 'type:feature,type:with reproduction steps,type:has pull request' ================================================ FILE: .github/workflows/tests.yml ================================================ name: tests on: push: branches-ignore: - 'gh-pages' pull_request: branches-ignore: - 'gh-pages' permissions: contents: read jobs: # Label of the container job sqlite: strategy: matrix: go: ['1.24', '1.25'] platform: [ubuntu-latest] # can not run in windows OS runs-on: ${{ matrix.platform }} steps: - name: Set up Go 1.x uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=sqlite ./tests/tests_all.sh mysql: strategy: matrix: dbversion: ['mysql:9', 'mysql:8', 'mysql:5.7'] go: ['1.24', '1.25'] platform: [ubuntu-latest] runs-on: ${{ matrix.platform }} services: mysql: image: ${{ matrix.dbversion }} env: MYSQL_DATABASE: gorm MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" ports: - 9910:3306 options: >- --health-cmd "mysqladmin ping -ugorm -pgorm" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 steps: - name: Set up Go 1.x uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests run: GITHUB_ACTION=true 
GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh mariadb: strategy: matrix: dbversion: [ 'mariadb:latest' ] go: ['1.24', '1.25'] platform: [ ubuntu-latest ] runs-on: ${{ matrix.platform }} services: mysql: image: ${{ matrix.dbversion }} env: MYSQL_DATABASE: gorm MYSQL_USER: gorm MYSQL_PASSWORD: gorm MYSQL_RANDOM_ROOT_PASSWORD: "yes" ports: - 9910:3306 options: >- --health-cmd "mariadb-admin ping -ugorm -pgorm" --health-interval 10s --health-start-period 10s --health-timeout 5s --health-retries 10 steps: - name: Set up Go 1.x uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=mysql GORM_DSN="gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True" ./tests/tests_all.sh postgres: strategy: matrix: dbversion: ['postgres:latest', 'postgres:15', 'postgres:14', 'postgres:13'] go: ['1.24', '1.25'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} services: postgres: image: ${{ matrix.dbversion }} env: POSTGRES_PASSWORD: gorm POSTGRES_USER: gorm POSTGRES_DB: gorm TZ: Asia/Shanghai ports: - 9920:5432 # Set health checks to wait until postgres has started options: >- --health-cmd pg_isready --health-interval 10s --health-timeout 5s --health-retries 5 steps: - name: Set up Go 1.x uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=postgres GORM_DSN="user=gorm password=gorm dbname=gorm 
host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh sqlserver: strategy: matrix: go: ['1.24', '1.25'] platform: [ubuntu-latest] # can not run test in macOS and windows runs-on: ${{ matrix.platform }} services: mssql: image: mcr.microsoft.com/mssql/server:2022-latest env: TZ: Asia/Shanghai ACCEPT_EULA: Y MSSQL_SA_PASSWORD: LoremIpsum86 ports: - 9930:1433 options: >- --health-cmd="/opt/mssql-tools18/bin/sqlcmd -S localhost -U sa -P ${MSSQL_SA_PASSWORD} -N -C -l 30 -Q \"SELECT 1\" || exit 1" --health-start-period 10s --health-interval 10s --health-timeout 5s --health-retries 10 steps: - name: Set up Go 1.x uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=sqlserver GORM_DSN="sqlserver://sa:LoremIpsum86@localhost:9930?database=master" ./tests/tests_all.sh tidb: strategy: matrix: dbversion: [ 'v6.5.0' ] go: ['1.24', '1.25'] platform: [ ubuntu-latest ] runs-on: ${{ matrix.platform }} steps: - name: Setup TiDB uses: Icemap/tidb-action@main with: port: 9940 version: ${{matrix.dbversion}} - name: Set up Go 1.x uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory uses: actions/checkout@v4 - name: go mod package cache uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=tidb GORM_DSN="root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ./tests/tests_all.sh gaussdb: strategy: matrix: dbversion: ['opengauss/opengauss:7.0.0-RC1.B023'] go: ['1.24', '1.25'] platform: [ubuntu-latest] # can not run in macOS and Windows runs-on: ${{ matrix.platform }} services: 
gaussdb: image: ${{ matrix.dbversion }} env: # GaussDB has password limitations GS_PASSWORD: Gaussdb@123 TZ: Asia/Shanghai ports: - 9950:5432 steps: - name: Set up Go 1.x uses: actions/setup-go@v4 with: go-version: ${{ matrix.go }} - name: Check out code into the Go module directory uses: actions/checkout@v4 - name: Waiting for GaussDB to be ready run: | container_name=$(docker ps --filter "ancestor=opengauss/opengauss:7.0.0-RC1.B023" --format "{{.Names}}") if [ -z "$container_name" ]; then echo "Error: failed to find a container created from the 'opengauss/opengauss:7.0.0-RC1.B023' image." exit 1 fi max_retries=12 retry_count=0 if [ -t 0 ]; then TTY_FLAG="-t" else TTY_FLAG="" fi while [ $retry_count -lt $max_retries ]; do if docker exec -i "${container_name}" bash -c "su - omm -c 'gsql -U omm -c \"select 1;\"'" then echo "Creating database gorm..." sql_file='/tmp/create_database.sql' echo "CREATE DATABASE gorm DBCOMPATIBILITY 'PG';" > ${sql_file} docker cp "${sql_file}" "${container_name}":"${sql_file}" docker exec -i ${TTY_FLAG} "${container_name}" bash -c "su - omm -c 'gsql -U omm -f ${sql_file}'" echo "Database initialization completed." break fi echo "Waiting for database to be ready... 
(attempt $((retry_count + 1))/$max_retries)" sleep 10 ((++retry_count)) done exit 0 - name: go mod package cache uses: actions/cache@v4 with: path: ~/go/pkg/mod key: ${{ runner.os }}-go-${{ matrix.go }}-${{ hashFiles('tests/go.mod') }} - name: Tests run: GITHUB_ACTION=true GORM_DIALECT=gaussdb GORM_DSN="user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" ./tests/tests_all.sh ================================================ FILE: .gitignore ================================================ TODO* documents coverage.txt _book .idea vendor .vscode ================================================ FILE: .golangci.yml ================================================ version: "2" linters: default: standard enable: - cyclop - gocritic - gosec - ineffassign - misspell - prealloc - unconvert - unparam - whitespace formatters: enable: - gofumpt - goimports ================================================ FILE: CODE_OF_CONDUCT.md ================================================ # Contributor Covenant Code of Conduct ## Our Pledge We as members, contributors, and leaders pledge to make participation in our community a harassment-free experience for everyone, regardless of age, body size, visible or invisible disability, ethnicity, sex characteristics, gender identity and expression, level of experience, education, socio-economic status, nationality, personal appearance, race, religion, or sexual identity and orientation. We pledge to act and interact in ways that contribute to an open, welcoming, diverse, inclusive, and healthy community.
## Our Standards Examples of behavior that contributes to a positive environment for our community include: * Demonstrating empathy and kindness toward other people * Being respectful of differing opinions, viewpoints, and experiences * Giving and gracefully accepting constructive feedback * Accepting responsibility and apologizing to those affected by our mistakes, and learning from the experience * Focusing on what is best not just for us as individuals, but for the overall community Examples of unacceptable behavior include: * The use of sexualized language or imagery, and sexual attention or advances of any kind * Trolling, insulting or derogatory comments, and personal or political attacks * Public or private harassment * Publishing others' private information, such as a physical or email address, without their explicit permission * Other conduct which could reasonably be considered inappropriate in a professional setting ## Enforcement Responsibilities Community leaders are responsible for clarifying and enforcing our standards of acceptable behavior and will take appropriate and fair corrective action in response to any behavior that they deem inappropriate, threatening, offensive, or harmful. Community leaders have the right and responsibility to remove, edit, or reject comments, commits, code, wiki edits, issues, and other contributions that are not aligned to this Code of Conduct, and will communicate reasons for moderation decisions when appropriate. ## Scope This Code of Conduct applies within all community spaces and also applies when an individual is officially representing the community in public spaces. Examples of representing our community include using an official e-mail address, posting via an official social media account, or acting as an appointed representative at an online or offline event.
## Enforcement Instances of abusive, harassing, or otherwise unacceptable behavior may be reported to the community leaders responsible for enforcement at . All complaints will be reviewed and investigated promptly and fairly. All community leaders are obligated to respect the privacy and security of the reporter of any incident. ## Enforcement Guidelines Community leaders will follow these Community Impact Guidelines in determining the consequences for any action they deem in violation of this Code of Conduct: ### 1. Correction **Community Impact**: Use of inappropriate language or other behavior deemed unprofessional or unwelcome in the community. **Consequence**: A private, written warning from community leaders, providing clarity around the nature of the violation and an explanation of why the behavior was inappropriate. A public apology may be requested. ### 2. Warning **Community Impact**: A violation through a single incident or series of actions. **Consequence**: A warning with consequences for continued behavior. No interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, for a specified period. This includes avoiding interactions in community spaces and external channels like social media. Violating these terms may lead to a temporary or permanent ban. ### 3. Temporary Ban **Community Impact**: A serious violation of community standards, including sustained inappropriate behavior. **Consequence**: A temporary ban from any interaction or public communication with the community for a specified period. No public or private interaction with the people involved, including unsolicited interaction with those enforcing the Code of Conduct, is allowed during this period. Violating these terms may lead to a permanent ban. ### 4. 
Permanent Ban **Community Impact**: Demonstrating a pattern of violation of community standards, including sustained inappropriate behavior, harassment of an individual, or aggression toward or disparagement of classes of individuals. **Consequence**: A permanent ban from any sort of public interaction within the community. ## Attribution This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 2.0, available at https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. Community Impact Guidelines were inspired by [Mozilla's code of conduct enforcement ladder](https://github.com/mozilla/diversity). [homepage]: https://www.contributor-covenant.org For answers to common questions about this code of conduct, see the FAQ at https://www.contributor-covenant.org/faq. Translations are available at https://www.contributor-covenant.org/translations. ================================================ FILE: LICENSE ================================================ The MIT License (MIT) Copyright (c) 2013-present Jinzhu Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # GORM The fantastic ORM library for Golang, which aims to be developer-friendly. [![go report card](https://goreportcard.com/badge/github.com/go-gorm/gorm "go report card")](https://goreportcard.com/report/github.com/go-gorm/gorm) [![test status](https://github.com/go-gorm/gorm/actions/workflows/tests.yml/badge.svg)](https://github.com/go-gorm/gorm/actions) [![MIT license](https://img.shields.io/badge/license-MIT-brightgreen.svg)](https://opensource.org/licenses/MIT) [![Go.Dev reference](https://img.shields.io/badge/go.dev-reference-blue?logo=go&logoColor=white)](https://pkg.go.dev/gorm.io/gorm?tab=doc) ## Overview * Full-Featured ORM * Associations (Has One, Has Many, Belongs To, Many To Many, Polymorphism, Single-table inheritance) * Hooks (Before/After Create/Save/Update/Delete/Find) * Eager loading with `Preload`, `Joins` * Transactions, Nested Transactions, Save Point, RollbackTo to Saved Point * Context, Prepared Statement Mode, DryRun Mode * Batch Insert, FindInBatches, Find To Map * SQL Builder, Upsert, Locking, Optimizer/Index/Comment Hints, NamedArg, Search/Update/Create with SQL Expr * Composite Primary Key * Auto Migrations * Logger * Extendable, flexible plugin API: Database Resolver (Multiple Databases, Read/Write Splitting) / Prometheus… * Every feature comes with tests * Developer Friendly ## Getting Started * GORM Guides [https://gorm.io](https://gorm.io) * Gen Guides [https://gorm.io/gen/index.html](https://gorm.io/gen/index.html) ## Contributing [You can help deliver a better GORM; check out the things you can do](https://gorm.io/contribute.html) ## Contributors [Thank
you](https://github.com/go-gorm/gorm/graphs/contributors) for contributing to the GORM framework! ## License © Jinzhu, 2013~time.Now Released under the [MIT License](https://github.com/go-gorm/gorm/blob/master/LICENSE) ================================================ FILE: association.go ================================================ package gorm import ( "fmt" "reflect" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) // Association Mode contains some helper methods to handle relationship things easily. type Association struct { DB *DB Relationship *schema.Relationship Unscope bool Error error } /* Association returns an Association helper for the relationship named column on the statement's model. It parses the model (on success Statement.Table is restored to its pre-parse value), looks column up in the schema's relations, and points Statement.ReflectValue at the fully dereferenced model. Error records a parse failure, or ErrUnsupportedRelation when column is not a known relation. */ func (db *DB) Association(column string) *Association { association := &Association{DB: db, Unscope: db.Statement.Unscoped} table := db.Statement.Table if association.Error = db.Statement.Parse(db.Statement.Model); association.Error == nil { db.Statement.Table = table association.Relationship = db.Statement.Schema.Relationships.Relations[column] if association.Relationship == nil { association.Error = fmt.Errorf("%w: %s", ErrUnsupportedRelation, column) } db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) for db.Statement.ReflectValue.Kind() == reflect.Ptr { db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() } } return association } /* Unscoped returns a copy of the association with Unscope set. With it, Replace deletes detached has-one/has-many rows (and, for belongs-to, the previously referenced rows) instead of only nulling foreign keys. */ func (association *Association) Unscoped() *Association { return &Association{ DB: association.DB, Relationship: association.Relationship, Error: association.Error, Unscope: true, } } /* Find loads the associated records into out, applying the optional conds. If the association already carries an error it does nothing and returns that error. */ func (association *Association) Find(out interface{}, conds ...interface{}) error { if association.Error == nil { association.Error = association.buildCondition().Find(out, conds...).Error } return association.Error } /* Append adds values to the association. For has-one and belongs-to relations it delegates to Replace when values is non-empty; all other relation types are saved via saveAssociation with clear=false. */ func (association *Association) Append(values ...interface{}) error { values = expandValues(values) if association.Error == nil { switch association.Relationship.Type { case schema.HasOne, schema.BelongsTo: if len(values) > 0 { association.Error = association.Replace(values...)
} default: association.saveAssociation( /*clear*/ false, values...) } } return association.Error } /* Replace saves values as the association (saveAssociation with clear=true), then detaches the previous targets. Belongs-to with no values: zeroes the relation field and nulls the foreign key columns; with Unscope the previously referenced rows are also deleted. Has-one/has-many: rows that reference the owner but are no longer associated get their foreign key nulled, or are deleted when Unscope is set. Many2many: join-table rows for the owner that do not match values are deleted; ErrPrimaryKeyRequired is returned when the owner has no primary key values. The first error encountered is kept in association.Error and returned. */ func (association *Association) Replace(values ...interface{}) error { values = expandValues(values) if association.Error == nil { reflectValue := association.DB.Statement.ReflectValue rel := association.Relationship var oldBelongsToExpr clause.Expression // we have to record the old BelongsTo value if association.Unscope && rel.Type == schema.BelongsTo { var foreignFields []*schema.Field for _, ref := range rel.References { if !ref.OwnPrimaryKey { foreignFields = append(foreignFields, ref.ForeignKey) } } if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) oldBelongsToExpr = clause.IN{Column: column, Values: values} } } // save associations if association.saveAssociation( /*clear*/ true, values...); association.Error != nil { return association.Error } // set old association's foreign key to null switch rel.Type { case schema.BelongsTo: /* Keep the FIRST failure: each step below used to overwrite association.Error, so e.g. a failed UpdateColumns followed by a successful hard delete was reported as success. Control flow is unchanged. */ if len(values) == 0 { updateMap := map[string]interface{}{} switch reflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { if err := rel.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(rel.Field.FieldType).Interface()); err != nil && association.Error == nil { association.Error = err } } case reflect.Struct: if err := rel.Field.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(rel.Field.FieldType).Interface()); err != nil && association.Error == nil { association.Error = err } } for _, ref := range rel.References { updateMap[ref.ForeignKey.DBName] = nil } if err := association.DB.UpdateColumns(updateMap).Error; err != nil && association.Error == nil { association.Error = err } } if association.Unscope && oldBelongsToExpr != nil { if err := association.DB.Model(nil).Where(oldBelongsToExpr).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error; err != nil && association.Error == nil { association.Error = err } } case schema.HasOne, schema.HasMany: var ( primaryFields []*schema.Field
foreignKeys []string updateMap = map[string]interface{}{} relValues = schema.GetRelationsValues(association.DB.Statement.Context, reflectValue, []*schema.Relationship{rel}) modelValue = reflect.New(rel.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) if _, rvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, relValues, rel.FieldSchema.PrimaryFields); len(rvs) > 0 { if column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs); len(values) > 0 { tx.Not(clause.IN{Column: column, Values: values}) } } for _, ref := range rel.References { if ref.OwnPrimaryKey { primaryFields = append(primaryFields, ref.PrimaryKey) foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateMap[ref.ForeignKey.DBName] = nil } else if ref.PrimaryValue != "" { tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } } if _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields); len(pvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs) if association.Unscope { association.Error = tx.Where(clause.IN{Column: column, Values: values}).Delete(modelValue).Error } else { association.Error = tx.Where(clause.IN{Column: column, Values: values}).UpdateColumns(updateMap).Error } } case schema.Many2Many: var ( primaryFields, relPrimaryFields []*schema.Field joinPrimaryKeys, joinRelPrimaryKeys []string modelValue = reflect.New(rel.JoinTable.ModelType).Interface() tx = association.DB.Model(modelValue) ) for _, ref := range rel.References { if ref.PrimaryValue == "" { if ref.OwnPrimaryKey { primaryFields = append(primaryFields, ref.PrimaryKey) joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) } else { relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) } } else { tx.Clauses(clause.Eq{Column: ref.ForeignKey.DBName,
Value: ref.PrimaryValue}) } } _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if column, values := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(values) > 0 { tx.Where(clause.IN{Column: column, Values: values}) } else { return ErrPrimaryKeyRequired } _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) if relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs); len(relValues) > 0 { tx.Where(clause.Not(clause.IN{Column: relColumn, Values: relValues})) } association.Error = tx.Delete(modelValue).Error } } return association.Error } func (association *Association) Delete(values ...interface{}) error { values = expandValues(values) if association.Error == nil { var ( reflectValue = association.DB.Statement.ReflectValue rel = association.Relationship primaryFields []*schema.Field foreignKeys []string updateAttrs = map[string]interface{}{} conds []clause.Expression ) for _, ref := range rel.References { if ref.PrimaryValue == "" { primaryFields = append(primaryFields, ref.PrimaryKey) foreignKeys = append(foreignKeys, ref.ForeignKey.DBName) updateAttrs[ref.ForeignKey.DBName] = nil } else { conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } } switch rel.Type { case schema.BelongsTo: associationDB := association.DB.Session(&Session{}) tx := associationDB.Model(reflect.New(rel.Schema.ModelType).Interface()) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, rel.Schema.PrimaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.Schema.Table, rel.Schema.PrimaryFieldDBNames, pvs); len(pvalues) > 0 { conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) } else { return ErrPrimaryKeyRequired } _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, primaryFields) relColumn,
relValues := schema.ToQueryValues(rel.Schema.Table, foreignKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error if association.Unscope { var foreignFields []*schema.Field for _, ref := range rel.References { if !ref.OwnPrimaryKey { foreignFields = append(foreignFields, ref.ForeignKey) } } if _, fvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, foreignFields); len(fvs) > 0 { column, values := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, fvs) association.Error = associationDB.Model(nil).Where(clause.IN{Column: column, Values: values}).Delete(reflect.New(rel.FieldSchema.ModelType).Interface()).Error } } case schema.HasOne, schema.HasMany: model := reflect.New(rel.FieldSchema.ModelType).Interface() tx := association.DB.Model(model) _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.FieldSchema.Table, foreignKeys, pvs); len(pvalues) > 0 { conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) } else { return ErrPrimaryKeyRequired } _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.FieldSchema.Table, rel.FieldSchema.PrimaryFieldDBNames, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) if association.Unscope { association.Error = tx.Clauses(conds...).Delete(model).Error } else { association.Error = tx.Clauses(conds...).UpdateColumns(updateAttrs).Error } case schema.Many2Many: var ( primaryFields, relPrimaryFields []*schema.Field joinPrimaryKeys, joinRelPrimaryKeys []string joinValue = reflect.New(rel.JoinTable.ModelType).Interface() ) for _, ref := range rel.References { if ref.PrimaryValue == "" { if ref.OwnPrimaryKey { primaryFields = 
append(primaryFields, ref.PrimaryKey) joinPrimaryKeys = append(joinPrimaryKeys, ref.ForeignKey.DBName) } else { relPrimaryFields = append(relPrimaryFields, ref.PrimaryKey) joinRelPrimaryKeys = append(joinRelPrimaryKeys, ref.ForeignKey.DBName) } } else { conds = append(conds, clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } } _, pvs := schema.GetIdentityFieldValuesMap(association.DB.Statement.Context, reflectValue, primaryFields) if pcolumn, pvalues := schema.ToQueryValues(rel.JoinTable.Table, joinPrimaryKeys, pvs); len(pvalues) > 0 { conds = append(conds, clause.IN{Column: pcolumn, Values: pvalues}) } else { return ErrPrimaryKeyRequired } _, rvs := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, relPrimaryFields) relColumn, relValues := schema.ToQueryValues(rel.JoinTable.Table, joinRelPrimaryKeys, rvs) conds = append(conds, clause.IN{Column: relColumn, Values: relValues}) association.Error = association.DB.Where(clause.Where{Exprs: conds}).Model(nil).Delete(joinValue).Error } if association.Error == nil { // clean up deleted values' foreign key relValuesMap, _ := schema.GetIdentityFieldValuesMapFromValues(association.DB.Statement.Context, values, rel.FieldSchema.PrimaryFields) cleanUpDeletedRelations := func(data reflect.Value) { if _, zero := rel.Field.ValueOf(association.DB.Statement.Context, data); !zero { fieldValue := reflect.Indirect(rel.Field.ReflectValueOf(association.DB.Statement.Context, data)) primaryValues := make([]interface{}, len(rel.FieldSchema.PrimaryFields)) switch fieldValue.Kind() { case reflect.Slice, reflect.Array: validFieldValues := reflect.Zero(rel.Field.IndirectFieldType) for i := 0; i < fieldValue.Len(); i++ { for idx, field := range rel.FieldSchema.PrimaryFields { primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue.Index(i)) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; !ok { validFieldValues = reflect.Append(validFieldValues, 
fieldValue.Index(i)) } } association.Error = rel.Field.Set(association.DB.Statement.Context, data, validFieldValues.Interface()) case reflect.Struct: for idx, field := range rel.FieldSchema.PrimaryFields { primaryValues[idx], _ = field.ValueOf(association.DB.Statement.Context, fieldValue) } if _, ok := relValuesMap[utils.ToStringKey(primaryValues...)]; ok { if association.Error = rel.Field.Set(association.DB.Statement.Context, data, reflect.Zero(rel.FieldSchema.ModelType).Interface()); association.Error != nil { break } if rel.JoinTable == nil { for _, ref := range rel.References { if ref.OwnPrimaryKey || ref.PrimaryValue != "" { association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, fieldValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } else { association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, data, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } } } } } switch reflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { cleanUpDeletedRelations(reflect.Indirect(reflectValue.Index(i))) } case reflect.Struct: cleanUpDeletedRelations(reflectValue) } } } return association.Error } func (association *Association) Clear() error { return association.Replace() } func (association *Association) Count() (count int64) { if association.Error == nil { association.Error = association.buildCondition().Count(&count).Error } return } type assignBack struct { Source reflect.Value Index int Dest reflect.Value } func (association *Association) saveAssociation(clear bool, values ...interface{}) { var ( reflectValue = association.DB.Statement.ReflectValue assignBacks []assignBack // assign association values back to arguments after save ) appendToRelations := func(source, rv reflect.Value, clear bool) { switch association.Relationship.Type { case schema.HasOne, schema.BelongsTo: switch rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() > 0 { association.Error = 
association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Index(0).Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv.Index(0)}) } } case reflect.Struct: if !rv.CanAddr() { association.Error = ErrInvalidValue return } association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, rv.Addr().Interface()) if association.Relationship.Field.FieldType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: rv}) } } case schema.HasMany, schema.Many2Many: elemType := association.Relationship.Field.IndirectFieldType.Elem() oldFieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, source)) var fieldValue reflect.Value if clear { fieldValue = reflect.MakeSlice(oldFieldValue.Type(), 0, oldFieldValue.Cap()) } else { fieldValue = reflect.MakeSlice(oldFieldValue.Type(), oldFieldValue.Len(), oldFieldValue.Cap()) reflect.Copy(fieldValue, oldFieldValue) } appendToFieldValues := func(ev reflect.Value) { if ev.Type().AssignableTo(elemType) { fieldValue = reflect.Append(fieldValue, ev) } else if ev.Type().Elem().AssignableTo(elemType) { fieldValue = reflect.Append(fieldValue, ev.Elem()) } else { association.Error = fmt.Errorf("unsupported data type: %v for relation %s", ev.Type(), association.Relationship.Name) } if elemType.Kind() == reflect.Struct { assignBacks = append(assignBacks, assignBack{Source: source, Dest: ev, Index: fieldValue.Len()}) } } processMap := func(mapv reflect.Value) { child := reflect.New(association.Relationship.FieldSchema.ModelType) switch association.Relationship.Type { case schema.HasMany: for _, ref := range association.Relationship.References { key := reflect.ValueOf(ref.ForeignKey.DBName) if ref.OwnPrimaryKey { v := ref.PrimaryKey.ReflectValueOf(association.DB.Statement.Context, source) 
mapv.SetMapIndex(key, v) } else if ref.PrimaryValue != "" { mapv.SetMapIndex(key, reflect.ValueOf(ref.PrimaryValue)) } } association.Error = association.DB.Session(&Session{ NewDB: true, }).Model(child.Interface()).Create(mapv.Interface()).Error case schema.Many2Many: association.Error = association.DB.Session(&Session{ NewDB: true, }).Model(child.Interface()).Create(mapv.Interface()).Error for _, key := range mapv.MapKeys() { k := strings.ToLower(key.String()) if f, ok := association.Relationship.FieldSchema.FieldsByDBName[k]; ok { _ = f.Set(association.DB.Statement.Context, child, mapv.MapIndex(key).Interface()) } } appendToFieldValues(child) } } switch rv.Kind() { case reflect.Map: processMap(rv) case reflect.Slice, reflect.Array: for i := 0; i < rv.Len(); i++ { elem := reflect.Indirect(rv.Index(i)) if elem.Kind() == reflect.Map { processMap(elem) continue } appendToFieldValues(elem.Addr()) } case reflect.Struct: if !rv.CanAddr() { association.Error = ErrInvalidValue return } appendToFieldValues(rv.Addr()) } if association.Error == nil { association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, source, fieldValue.Interface()) } } } selectedSaveColumns := []string{association.Relationship.Name} omitColumns := []string{} selectColumns, _ := association.DB.Statement.SelectAndOmitColumns(true, false) for name, ok := range selectColumns { columnName := "" if strings.HasPrefix(name, association.Relationship.Name) { if columnName = strings.TrimPrefix(name, association.Relationship.Name); columnName == ".*" { columnName = name } } else if strings.HasPrefix(name, clause.Associations) { columnName = name } if columnName != "" { if ok { selectedSaveColumns = append(selectedSaveColumns, columnName) } else { omitColumns = append(omitColumns, columnName) } } } for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey { selectedSaveColumns = append(selectedSaveColumns, ref.ForeignKey.Name) } } associationDB := 
association.DB.Session(&Session{}).Model(nil) if !association.DB.FullSaveAssociations { associationDB.Select(selectedSaveColumns) } if len(omitColumns) > 0 { associationDB.Omit(omitColumns...) } associationDB = associationDB.Session(&Session{}) switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if len(values) != reflectValue.Len() { // clear old data if clear && len(values) == 0 { for i := 0; i < reflectValue.Len(); i++ { if err := association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.New(association.Relationship.Field.IndirectFieldType).Interface()); err != nil { association.Error = err break } if association.Relationship.JoinTable == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { if err := ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue.Index(i), reflect.Zero(ref.ForeignKey.FieldType).Interface()); err != nil { association.Error = err break } } } } } break } association.Error = ErrInvalidValueOfLength return } for i := 0; i < reflectValue.Len(); i++ { appendToRelations(reflectValue.Index(i), reflect.Indirect(reflect.ValueOf(values[i])), clear) if association.Error != nil { return } // TODO support save slice data, sql with case? 
association.Error = associationDB.Updates(reflectValue.Index(i).Addr().Interface()).Error } case reflect.Struct: // clear old data if clear && len(values) == 0 { association.Error = association.Relationship.Field.Set(association.DB.Statement.Context, reflectValue, reflect.New(association.Relationship.Field.IndirectFieldType).Interface()) if association.Relationship.JoinTable == nil && association.Error == nil { for _, ref := range association.Relationship.References { if !ref.OwnPrimaryKey && ref.PrimaryValue == "" { association.Error = ref.ForeignKey.Set(association.DB.Statement.Context, reflectValue, reflect.Zero(ref.ForeignKey.FieldType).Interface()) } } } } for idx, value := range values { rv := reflect.Indirect(reflect.ValueOf(value)) appendToRelations(reflectValue, rv, clear && idx == 0) if association.Error != nil { return } } if len(values) > 0 { association.Error = associationDB.Updates(reflectValue.Addr().Interface()).Error } } for _, assignBack := range assignBacks { fieldValue := reflect.Indirect(association.Relationship.Field.ReflectValueOf(association.DB.Statement.Context, assignBack.Source)) if assignBack.Index > 0 { reflect.Indirect(assignBack.Dest).Set(fieldValue.Index(assignBack.Index - 1)) } else { reflect.Indirect(assignBack.Dest).Set(fieldValue) } } } func (association *Association) buildCondition() *DB { var ( queryConds = association.Relationship.ToQueryConditions(association.DB.Statement.Context, association.DB.Statement.ReflectValue) modelValue = reflect.New(association.Relationship.FieldSchema.ModelType).Interface() tx = association.DB.Model(modelValue) ) if association.Relationship.JoinTable != nil { if !tx.Statement.Unscoped && len(association.Relationship.JoinTable.QueryClauses) > 0 { joinStmt := Statement{DB: tx, Context: tx.Statement.Context, Schema: association.Relationship.JoinTable, Table: association.Relationship.JoinTable.Table, Clauses: map[string]clause.Clause{}} for _, queryClause := range 
association.Relationship.JoinTable.QueryClauses { joinStmt.AddClause(queryClause) } joinStmt.Build("WHERE") if len(joinStmt.SQL.String()) > 0 { tx.Clauses(clause.Expr{SQL: strings.Replace(joinStmt.SQL.String(), "WHERE ", "", 1), Vars: joinStmt.Vars}) } } tx = tx.Session(&Session{QueryFields: true}).Clauses(clause.From{Joins: []clause.Join{{ Table: clause.Table{Name: association.Relationship.JoinTable.Table}, ON: clause.Where{Exprs: queryConds}, }}}) } else { tx.Clauses(clause.Where{Exprs: queryConds}) } return tx } func expandValues(values ...any) (results []any) { appendToResult := func(rv reflect.Value) { // unwrap interface if rv.IsValid() && rv.Kind() == reflect.Interface { rv = rv.Elem() } if rv.IsValid() && rv.Kind() == reflect.Struct { p := reflect.New(rv.Type()) p.Elem().Set(rv) results = append(results, p.Interface()) } else if rv.IsValid() { results = append(results, rv.Interface()) } } // Process each argument; if an argument is a slice/array, expand its elements for _, value := range values { rv := reflect.ValueOf(value) if rv.Kind() == reflect.Slice || rv.Kind() == reflect.Array { for i := 0; i < rv.Len(); i++ { appendToResult(rv.Index(i)) } } else { appendToResult(rv) } } return } ================================================ FILE: callbacks/associations.go ================================================ package callbacks import ( "reflect" "strings" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) func SaveBeforeAssociations(create bool) func(db *gorm.DB) { return func(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) // Save Belongs To associations for _, rel := range db.Statement.Schema.Relationships.BelongsTo { if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { continue } setupReferences := func(obj reflect.Value, elem reflect.Value) { for _, ref := range rel.References { if !ref.OwnPrimaryKey { 
pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) db.AddError(ref.ForeignKey.Set(db.Statement.Context, obj, pv)) if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { dest[ref.ForeignKey.DBName] = pv if _, ok := dest[rel.Name]; ok { dest[rel.Name] = elem.Interface() } } } } } switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var ( rValLen = db.Statement.ReflectValue.Len() objs = make([]reflect.Value, 0, rValLen) fieldType = rel.Field.FieldType isPtr = fieldType.Kind() == reflect.Ptr ) if !isPtr { fieldType = reflect.PointerTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) identityMap := map[string]bool{} for i := 0; i < rValLen; i++ { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() != reflect.Struct { break } if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { // check belongs to relation value rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) // relation reflect value if !isPtr { rv = rv.Addr() } objs = append(objs, obj) elems = reflect.Append(elems, rv) relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) for _, pf := range rel.FieldSchema.PrimaryFields { if pfv, ok := pf.ValueOf(db.Statement.Context, rv); !ok { relPrimaryValues = append(relPrimaryValues, pfv) } } cacheKey := utils.ToStringKey(relPrimaryValues...) 
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if cacheKey != "" { // has primary fields identityMap[cacheKey] = true } distinctElems = reflect.Append(distinctElems, rv) } } } if elems.Len() > 0 { if saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) == nil { for i := 0; i < elems.Len(); i++ { setupReferences(objs[i], elems.Index(i)) } } } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { rv := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) // relation reflect value if rv.Kind() != reflect.Ptr { rv = rv.Addr() } if saveAssociations(db, rel, rv, selectColumns, restricted, nil) == nil { setupReferences(db.Statement.ReflectValue, rv) } } } } } } } func SaveAfterAssociations(create bool) func(db *gorm.DB) { return func(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := db.Statement.SelectAndOmitColumns(create, !create) // Save Has One associations for _, rel := range db.Statement.Schema.Relationships.HasOne { if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { continue } switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: var ( fieldType = rel.Field.FieldType isPtr = fieldType.Kind() == reflect.Ptr ) if !isPtr { fieldType = reflect.PointerTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() == reflect.Struct { if _, zero := rel.Field.ValueOf(db.Statement.Context, obj); !zero { rv := rel.Field.ReflectValueOf(db.Statement.Context, obj) if rv.Kind() != reflect.Ptr { rv = rv.Addr() } for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, fv)) } else if 
ref.PrimaryValue != "" { db.AddError(ref.ForeignKey.Set(db.Statement.Context, rv, ref.PrimaryValue)) } } elems = reflect.Append(elems, rv) } } } if elems.Len() > 0 { assignmentColumns := make([]string, 0, len(rel.References)) for _, ref := range rel.References { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) } case reflect.Struct: if _, zero := rel.Field.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !zero { f := rel.Field.ReflectValueOf(db.Statement.Context, db.Statement.ReflectValue) if f.Kind() != reflect.Ptr { f = f.Addr() } assignmentColumns := make([]string, 0, len(rel.References)) for _, ref := range rel.References { if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, db.Statement.ReflectValue) db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, fv)) } else if ref.PrimaryValue != "" { db.AddError(ref.ForeignKey.Set(db.Statement.Context, f, ref.PrimaryValue)) } assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } saveAssociations(db, rel, f, selectColumns, restricted, assignmentColumns) } } } // Save Has Many associations for _, rel := range db.Statement.Schema.Relationships.HasMany { if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { continue } fieldType := rel.Field.IndirectFieldType.Elem() isPtr := fieldType.Kind() == reflect.Ptr if !isPtr { fieldType = reflect.PointerTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) for _, ref := range rel.References { if ref.OwnPrimaryKey { pv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, v) db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, 
pv)) } else if ref.PrimaryValue != "" { db.AddError(ref.ForeignKey.Set(db.Statement.Context, elem, ref.PrimaryValue)) } } relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) for _, pf := range rel.FieldSchema.PrimaryFields { if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { relPrimaryValues = append(relPrimaryValues, pfv) } } cacheKey := utils.ToStringKey(relPrimaryValues...) if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if cacheKey != "" { // has primary fields identityMap[cacheKey] = true } if isPtr { elems = reflect.Append(elems, elem) } else { elems = reflect.Append(elems, elem.Addr()) } } } } } switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() == reflect.Struct { appendToElems(obj) } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) } if elems.Len() > 0 { assignmentColumns := make([]string, 0, len(rel.References)) for _, ref := range rel.References { assignmentColumns = append(assignmentColumns, ref.ForeignKey.DBName) } saveAssociations(db, rel, elems, selectColumns, restricted, assignmentColumns) } } // Save Many2Many associations for _, rel := range db.Statement.Schema.Relationships.Many2Many { if v, ok := selectColumns[rel.Name]; (ok && !v) || (!ok && restricted) { continue } fieldType := rel.Field.IndirectFieldType.Elem() isPtr := fieldType.Kind() == reflect.Ptr if !isPtr { fieldType = reflect.PointerTo(fieldType) } elems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) distinctElems := reflect.MakeSlice(reflect.SliceOf(fieldType), 0, 10) joins := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.JoinTable.ModelType)), 0, 10) objs := []reflect.Value{} appendToJoins := func(obj reflect.Value, elem reflect.Value) { joinValue := reflect.New(rel.JoinTable.ModelType) for _, ref := range rel.References 
{ if ref.OwnPrimaryKey { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, obj) db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) } else if ref.PrimaryValue != "" { db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, ref.PrimaryValue)) } else { fv, _ := ref.PrimaryKey.ValueOf(db.Statement.Context, elem) db.AddError(ref.ForeignKey.Set(db.Statement.Context, joinValue, fv)) } } joins = reflect.Append(joins, joinValue) } identityMap := map[string]bool{} appendToElems := func(v reflect.Value) { if _, zero := rel.Field.ValueOf(db.Statement.Context, v); !zero { f := reflect.Indirect(rel.Field.ReflectValueOf(db.Statement.Context, v)) for i := 0; i < f.Len(); i++ { elem := f.Index(i) if !isPtr { elem = elem.Addr() } objs = append(objs, v) elems = reflect.Append(elems, elem) relPrimaryValues := make([]interface{}, 0, len(rel.FieldSchema.PrimaryFields)) for _, pf := range rel.FieldSchema.PrimaryFields { if pfv, ok := pf.ValueOf(db.Statement.Context, elem); !ok { relPrimaryValues = append(relPrimaryValues, pfv) } } cacheKey := utils.ToStringKey(relPrimaryValues...) 
if len(relPrimaryValues) != len(rel.FieldSchema.PrimaryFields) || !identityMap[cacheKey] { if cacheKey != "" { // has primary fields identityMap[cacheKey] = true } distinctElems = reflect.Append(distinctElems, elem) } } } } switch db.Statement.ReflectValue.Kind() { case reflect.Slice, reflect.Array: for i := 0; i < db.Statement.ReflectValue.Len(); i++ { obj := db.Statement.ReflectValue.Index(i) if reflect.Indirect(obj).Kind() == reflect.Struct { appendToElems(obj) } } case reflect.Struct: appendToElems(db.Statement.ReflectValue) } // optimize elems of reflect value length if elemLen := elems.Len(); elemLen > 0 { if v, ok := selectColumns[rel.Name+".*"]; !ok || v { saveAssociations(db, rel, distinctElems, selectColumns, restricted, nil) } for i := 0; i < elemLen; i++ { appendToJoins(objs[i], elems.Index(i)) } } if joins.Len() > 0 { db.AddError(db.Session(&gorm.Session{NewDB: true}).Clauses(clause.OnConflict{DoNothing: true}).Session(&gorm.Session{ SkipHooks: db.Statement.SkipHooks, DisableNestedTransaction: true, }).Create(joins.Interface()).Error) } } } } } func onConflictOption(stmt *gorm.Statement, s *schema.Schema, defaultUpdatingColumns []string) (onConflict clause.OnConflict) { if len(defaultUpdatingColumns) > 0 || stmt.DB.FullSaveAssociations { onConflict.Columns = make([]clause.Column, 0, len(s.PrimaryFieldDBNames)) for _, dbName := range s.PrimaryFieldDBNames { onConflict.Columns = append(onConflict.Columns, clause.Column{Name: dbName}) } onConflict.UpdateAll = stmt.DB.FullSaveAssociations if !onConflict.UpdateAll { onConflict.DoUpdates = clause.AssignmentColumns(defaultUpdatingColumns) } } else { onConflict.DoNothing = true } return } func saveAssociations(db *gorm.DB, rel *schema.Relationship, rValues reflect.Value, selectColumns map[string]bool, restricted bool, defaultUpdatingColumns []string) error { // stop save association loop if checkAssociationsSaved(db, rValues) { return nil } var ( selects, omits []string onConflict = 
onConflictOption(db.Statement, rel.FieldSchema, defaultUpdatingColumns)
		refName    = rel.Name + "."
		values     = rValues.Interface()
	)

	for name, ok := range selectColumns {
		columnName := ""
		if strings.HasPrefix(name, refName) {
			columnName = strings.TrimPrefix(name, refName)
		}

		if columnName != "" {
			if ok {
				selects = append(selects, columnName)
			} else {
				omits = append(omits, columnName)
			}
		}
	}

	tx := db.Session(&gorm.Session{NewDB: true}).Clauses(onConflict).Session(&gorm.Session{
		FullSaveAssociations:     db.FullSaveAssociations,
		SkipHooks:                db.Statement.SkipHooks,
		DisableNestedTransaction: true,
	})

	db.Statement.Settings.Range(func(k, v interface{}) bool {
		tx.Statement.Settings.Store(k, v)
		return true
	})

	if tx.Statement.FullSaveAssociations {
		tx = tx.Set("gorm:update_track_time", true)
	}

	if len(selects) > 0 {
		tx = tx.Select(selects)
	} else if restricted && len(omits) == 0 {
		tx = tx.Omit(clause.Associations)
	}

	if len(omits) > 0 {
		tx = tx.Omit(omits...)
	}

	return db.AddError(tx.Create(values).Error)
}

// visitMapStoreKey is the settings key under which checkAssociationsSaved
// keeps its *visitMap (read with db.Get, written with db.Set).
var visitMapStoreKey = "gorm:saved_association_map"

// checkAssociationsSaved reports whether association values have already
// been visited (i.e. saved) during the current operation:
//   - if values kind is Struct, check it has been saved;
//   - if values kind is Slice/Array, check all items have been saved.
//
// On the first call no visit map exists yet. A new one is created, values are
// recorded in it, and false is returned. On later calls values are recorded
// in the existing map, and the result is true only when loadOrStoreVisitMap
// reports them as already present.
func checkAssociationsSaved(db *gorm.DB, values reflect.Value) bool {
	if visit, ok := db.Get(visitMapStoreKey); ok {
		if v, ok := visit.(*visitMap); ok {
			if loadOrStoreVisitMap(v, values) {
				return true
			}
		}
	} else {
		// First visit: create and store the map; values are recorded but
		// reported as not yet saved.
		vistMap := make(visitMap)
		loadOrStoreVisitMap(&vistMap, values)
		db.Set(visitMapStoreKey, &vistMap)
	}

	return false
}
================================================
FILE: callbacks/callbacks.go
================================================
package callbacks

import (
	"gorm.io/gorm"
)

// Default clause lists, used by RegisterDefaultCallbacks when the matching
// Config field is empty.
var (
	createClauses = []string{"INSERT", "VALUES", "ON CONFLICT"}
	queryClauses  = []string{"SELECT", "FROM", "WHERE", "GROUP BY", "ORDER BY", "LIMIT", "FOR"}
	updateClauses = []string{"UPDATE", "SET", "WHERE"}
	deleteClauses = []string{"DELETE", "FROM", "WHERE"}
)

// Config customizes the default callbacks installed by RegisterDefaultCallbacks.
type Config struct {
	// LastInsertIDReversed tells Create that the driver's LastInsertId
	// belongs to the last row of a batch insert. IDs are then assigned
	// backwards from it instead of forwards.
	LastInsertIDReversed bool
	// CreateClauses, QueryClauses, UpdateClauses and DeleteClauses are
	// assigned to the Clauses of the corresponding callback processors.
	// Empty lists are replaced with the package defaults. Listing
	// "RETURNING" in CreateClauses/DeleteClauses enables the RETURNING code
	// paths of Create/Delete. QueryClauses is also used for Row and Raw.
	CreateClauses []string
	QueryClauses  []string
	UpdateClauses []string
	DeleteClauses []string
}

// RegisterDefaultCallbacks registers gorm's built-in create, query, delete,
// update, row and raw callbacks on db.
//
// Empty clause lists in config are filled in with the package defaults;
// config is modified in place. The begin/commit transaction steps are
// registered with Match(enableTransaction), which conditions them on
// !db.SkipDefaultTransaction. Callbacks are registered in the order listed.
func RegisterDefaultCallbacks(db *gorm.DB, config *Config) {
	enableTransaction := func(db *gorm.DB) bool {
		return !db.SkipDefaultTransaction
	}

	if len(config.CreateClauses) == 0 {
		config.CreateClauses = createClauses
	}
	if len(config.QueryClauses) == 0 {
		config.QueryClauses = queryClauses
	}
	if len(config.DeleteClauses) == 0 {
		config.DeleteClauses = deleteClauses
	}
	if len(config.UpdateClauses) == 0 {
		config.UpdateClauses = updateClauses
	}

	createCallback := db.Callback().Create()
	createCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
	createCallback.Register("gorm:before_create", BeforeCreate)
	createCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(true))
	createCallback.Register("gorm:create", Create(config))
	createCallback.Register("gorm:save_after_associations", SaveAfterAssociations(true))
	createCallback.Register("gorm:after_create", AfterCreate)
	createCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
	createCallback.Clauses = config.CreateClauses

	queryCallback := db.Callback().Query()
	queryCallback.Register("gorm:query", Query)
	queryCallback.Register("gorm:preload", Preload)
	queryCallback.Register("gorm:after_query", AfterQuery)
	queryCallback.Clauses = config.QueryClauses

	deleteCallback := db.Callback().Delete()
	deleteCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
	deleteCallback.Register("gorm:before_delete", BeforeDelete)
	deleteCallback.Register("gorm:delete_before_associations", DeleteBeforeAssociations)
	deleteCallback.Register("gorm:delete", Delete(config))
	deleteCallback.Register("gorm:after_delete", AfterDelete)
	deleteCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
	deleteCallback.Clauses = config.DeleteClauses

	updateCallback := db.Callback().Update()
	updateCallback.Match(enableTransaction).Register("gorm:begin_transaction", BeginTransaction)
	updateCallback.Register("gorm:setup_reflect_value", SetupUpdateReflectValue)
	updateCallback.Register("gorm:before_update", BeforeUpdate)
	updateCallback.Register("gorm:save_before_associations", SaveBeforeAssociations(false))
	updateCallback.Register("gorm:update", Update(config))
	updateCallback.Register("gorm:save_after_associations", SaveAfterAssociations(false))
	updateCallback.Register("gorm:after_update", AfterUpdate)
	updateCallback.Match(enableTransaction).Register("gorm:commit_or_rollback_transaction", CommitOrRollbackTransaction)
	updateCallback.Clauses = config.UpdateClauses

	// Row and Raw share the query clause list.
	rowCallback := db.Callback().Row()
	rowCallback.Register("gorm:row", RowQuery)
	rowCallback.Clauses = config.QueryClauses

	rawCallback := db.Callback().Raw()
	rawCallback.Register("gorm:raw", RawExec)
	rawCallback.Clauses = config.QueryClauses
}
================================================
FILE: callbacks/callmethod.go
================================================
package callbacks

import (
	"reflect"

	"gorm.io/gorm"
)

// callMethod invokes the hook dispatcher fc on the statement's destination.
//
// fc is first called with the whole ReflectValue. If it reports false (no
// hook applied), fc is called again:
//   - for each element of a slice/array, keeping Statement.CurDestIndex in
//     step with the element index;
//   - or for the address of a struct.
//
// A non-addressable element or struct records gorm.ErrInvalidValue; for
// slices this also stops the iteration. Every call shares a single session
// tx created with NewDB: true.
func callMethod(db *gorm.DB, fc func(value interface{}, tx *gorm.DB) bool) {
	tx := db.Session(&gorm.Session{NewDB: true})
	if called := fc(db.Statement.ReflectValue.Interface(), tx); !called {
		switch db.Statement.ReflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			db.Statement.CurDestIndex = 0
			for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
				if value := reflect.Indirect(db.Statement.ReflectValue.Index(i)); value.CanAddr() {
					fc(value.Addr().Interface(), tx)
				} else {
					db.AddError(gorm.ErrInvalidValue)
					return
				}
				db.Statement.CurDestIndex++
			}
		case reflect.Struct:
			if db.Statement.ReflectValue.CanAddr() {
				fc(db.Statement.ReflectValue.Addr().Interface(), tx)
			} else {
				db.AddError(gorm.ErrInvalidValue)
			}
		}
	}
}
================================================
FILE: callbacks/create.go
================================================
package
callbacks

import (
	"fmt"
	"reflect"
	"strings"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils"
)

// BeforeCreate runs the BeforeSave and then BeforeCreate hooks on the
// statement's destination. It does so only when the schema declares them,
// hooks are not skipped and no error is pending. Hook errors are added to db.
func BeforeCreate(db *gorm.DB) {
	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeCreate) {
		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
			if db.Statement.Schema.BeforeSave {
				if i, ok := value.(BeforeSaveInterface); ok {
					called = true
					db.AddError(i.BeforeSave(tx))
				}
			}

			if db.Statement.Schema.BeforeCreate {
				if i, ok := value.(BeforeCreateInterface); ok {
					called = true
					db.AddError(i.BeforeCreate(tx))
				}
			}
			return called
		})
	}
}

// Create returns the callback that builds and executes the INSERT statement.
//
// When "RETURNING" is listed in config.CreateClauses, readable fields with
// database-side defaults are requested through a RETURNING clause and
// scanned back into the destination. Otherwise the statement runs via
// ExecContext, and the driver's LastInsertId back-fills the prioritized
// auto-increment primary key. For map destinations the key written is the
// field's DB name, or "@id" when there is no schema. config.LastInsertIDReversed
// controls whether that ID belongs to the first or the last inserted row.
func Create(config *Config) func(db *gorm.DB) {
	supportReturning := utils.Contains(config.CreateClauses, "RETURNING")

	return func(db *gorm.DB) {
		if db.Error != nil {
			return
		}

		if db.Statement.Schema != nil {
			if !db.Statement.Unscoped {
				for _, c := range db.Statement.Schema.CreateClauses {
					db.Statement.AddClause(c)
				}
			}

			// Ask for readable DB-default fields back, unless the caller
			// already supplied a RETURNING clause.
			if supportReturning && len(db.Statement.Schema.FieldsWithDefaultDBValue) > 0 {
				if _, ok := db.Statement.Clauses["RETURNING"]; !ok {
					fromColumns := make([]clause.Column, 0, len(db.Statement.Schema.FieldsWithDefaultDBValue))
					for _, field := range db.Statement.Schema.FieldsWithDefaultDBValue {
						if field.Readable {
							fromColumns = append(fromColumns, clause.Column{Name: field.DBName})
						}
					}
					if len(fromColumns) > 0 {
						db.Statement.AddClause(clause.Returning{Columns: fromColumns})
					}
				}
			}
		}

		if db.Statement.SQL.Len() == 0 {
			db.Statement.SQL.Grow(180)
			db.Statement.AddClauseIfNotExists(clause.Insert{})
			db.Statement.AddClause(ConvertToCreateValues(db.Statement))

			db.Statement.Build(db.Statement.BuildClauses...)
		}

		// The SQL is built either way. It is only sent to the database
		// outside dry-run mode and when building succeeded.
		shouldExecute := !db.DryRun && db.Error == nil
		if !shouldExecute {
			return
		}

		ok, mode := hasReturning(db, supportReturning)
		if ok {
			// Refine the scan mode for upserts: DO NOTHING adds
			// ScanOnConflictDoNothing, DO UPDATE / UpdateAll adds ScanUpdate.
			if c, ok := db.Statement.Clauses["ON CONFLICT"]; ok {
				onConflict, _ := c.Expression.(clause.OnConflict)
				if onConflict.DoNothing {
					mode |= gorm.ScanOnConflictDoNothing
				} else if len(onConflict.DoUpdates) > 0 || onConflict.UpdateAll {
					mode |= gorm.ScanUpdate
				}
			}

			rows, err := db.Statement.ConnPool.QueryContext(
				db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
			)
			if db.AddError(err) == nil {
				defer func() {
					db.AddError(rows.Close())
				}()
				gorm.Scan(rows, db, mode)

				if db.Statement.Result != nil {
					db.Statement.Result.RowsAffected = db.RowsAffected
				}
			}

			return
		}

		result, err := db.Statement.ConnPool.ExecContext(
			db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...,
		)
		if err != nil {
			db.AddError(err)
			return
		}

		db.RowsAffected, _ = result.RowsAffected()
		if db.Statement.Result != nil {
			db.Statement.Result.Result = result
			db.Statement.Result.RowsAffected = db.RowsAffected
		}

		if db.RowsAffected == 0 {
			return
		}

		var (
			pkField     *schema.Field
			pkFieldName = "@id"
		)

		if db.Statement.Schema != nil {
			// Only a readable primary key with a default value (e.g. an
			// auto-increment column) is back-filled.
			if db.Statement.Schema.PrioritizedPrimaryField == nil ||
				!db.Statement.Schema.PrioritizedPrimaryField.HasDefaultValue ||
				!db.Statement.Schema.PrioritizedPrimaryField.Readable {
				return
			}
			pkField = db.Statement.Schema.PrioritizedPrimaryField
			pkFieldName = db.Statement.Schema.PrioritizedPrimaryField.DBName
		}

		insertID, err := result.LastInsertId()
		insertOk := err == nil && insertID > 0

		if !insertOk {
			// Without RETURNING support the caller has no other way to learn
			// the generated key, so surface the LastInsertId error.
			if !supportReturning {
				db.AddError(err)
			}
			return
		}

		// Write the generated ID(s) back into the destination. The values
		// are only correct when:
		//  1. the auto-increment primary key was not set by the caller, and
		//  2. the database's AutoIncrementIncrement is 1.
		switch values := db.Statement.Dest.(type) {
		case map[string]interface{}:
			values[pkFieldName] = insertID
		case *map[string]interface{}:
			(*values)[pkFieldName] = insertID
		case []map[string]interface{}, *[]map[string]interface{}:
			mapValues, ok := values.([]map[string]interface{})
			if !ok {
				// Guard against a typed-nil *[]map before dereferencing it.
				if v, ok := values.(*[]map[string]interface{}); ok && v != nil {
					mapValues = *v
				}
			}

			// With reversed IDs insertID belongs to the last row; step back
			// to the first one before assigning forwards.
			if config.LastInsertIDReversed {
				insertID -= int64(len(mapValues)-1) * schema.DefaultAutoIncrementIncrement
			}

			for _, mapValue := range mapValues {
				if mapValue != nil {
					mapValue[pkFieldName] = insertID
				}
				insertID += schema.DefaultAutoIncrementIncrement
			}
		default:
			if pkField == nil {
				return
			}

			switch db.Statement.ReflectValue.Kind() {
			case reflect.Slice, reflect.Array:
				if config.LastInsertIDReversed {
					// insertID belongs to the last row: assign from the end.
					for i := db.Statement.ReflectValue.Len() - 1; i >= 0; i-- {
						rv := db.Statement.ReflectValue.Index(i)
						if reflect.Indirect(rv).Kind() != reflect.Struct {
							break
						}

						if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
							db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
							insertID -= pkField.AutoIncrementIncrement
						}
					}
				} else {
					for i := 0; i < db.Statement.ReflectValue.Len(); i++ {
						rv := db.Statement.ReflectValue.Index(i)
						if reflect.Indirect(rv).Kind() != reflect.Struct {
							break
						}

						if _, isZero := pkField.ValueOf(db.Statement.Context, rv); isZero {
							db.AddError(pkField.Set(db.Statement.Context, rv, insertID))
							insertID += pkField.AutoIncrementIncrement
						}
					}
				}
			case reflect.Struct:
				if _, isZero := pkField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); isZero {
					db.AddError(pkField.Set(db.Statement.Context, db.Statement.ReflectValue, insertID))
				}
			}
		}
	}
}

// AfterCreate runs the AfterCreate and then AfterSave hooks on the
// statement's destination. It does so only when the schema declares them,
// hooks are not skipped and no error is pending. Hook errors are added to db.
func AfterCreate(db *gorm.DB) {
	if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterCreate) {
		callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) {
			if
db.Statement.Schema.AfterCreate {
				if i, ok := value.(AfterCreateInterface); ok {
					called = true
					db.AddError(i.AfterCreate(tx))
				}
			}

			if db.Statement.Schema.AfterSave {
				if i, ok := value.(AfterSaveInterface); ok {
					called = true
					db.AddError(i.AfterSave(tx))
				}
			}
			return called
		})
	}
}

// ConvertToCreateValues builds the INSERT column list and row values for
// stmt.Dest.
//
// Map destinations (map, *map, []map, *[]map) are delegated to
// ConvertMapToValuesForCreate / ConvertSliceOfMapToValuesForCreate. For a
// struct or a slice/array of structs it:
//   - takes the schema's columns that have no default value, or whose
//     default is known as a Go value (DefaultValueInterface), honoring
//     Select/Omit. Auto create/update time columns are kept under a
//     restriction unless explicitly omitted;
//   - replaces zero values with the Go default, or with the current time for
//     auto-time fields, and writes them back into the destination;
//   - appends columns that rely on a database default only when a row
//     supplies a non-zero value for them (other rows get
//     stmt.DefaultValueOf);
//   - when an ON CONFLICT clause with UpdateAll is present, fills in its
//     DoUpdates assignments and default conflict columns.
//
// Problems (empty slice, invalid element, unsupported kind) are added to
// stmt via AddError.
func ConvertToCreateValues(stmt *gorm.Statement) (values clause.Values) {
	curTime := stmt.DB.NowFunc()

	switch value := stmt.Dest.(type) {
	case map[string]interface{}:
		values = ConvertMapToValuesForCreate(stmt, value)
	case *map[string]interface{}:
		values = ConvertMapToValuesForCreate(stmt, *value)
	case []map[string]interface{}:
		values = ConvertSliceOfMapToValuesForCreate(stmt, value)
	case *[]map[string]interface{}:
		values = ConvertSliceOfMapToValuesForCreate(stmt, *value)
	default:
		var (
			selectColumns, restricted = stmt.SelectAndOmitColumns(true, false)
			_, updateTrackTime        = stmt.Get("gorm:update_track_time")
			isZero                    bool
		)
		// The "update track time" flag is consumed by this call.
		stmt.Settings.Delete("gorm:update_track_time")

		values = clause.Values{Columns: make([]clause.Column, 0, len(stmt.Schema.DBNames))}

		for _, db := range stmt.Schema.DBNames {
			if field := stmt.Schema.FieldsByDBName[db]; !field.HasDefaultValue || field.DefaultValueInterface != nil {
				if v, ok := selectColumns[db]; (ok && v) || (!ok && (!restricted || field.AutoCreateTime > 0 || field.AutoUpdateTime > 0)) {
					values.Columns = append(values.Columns, clause.Column{Name: db})
				}
			}
		}

		switch stmt.ReflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			rValLen := stmt.ReflectValue.Len()
			if rValLen == 0 {
				stmt.AddError(gorm.ErrEmptySlice)
				return
			}

			stmt.SQL.Grow(rValLen * 18)
			// Pre-size the bound variables (this replaces any existing stmt.Vars).
			stmt.Vars = make([]interface{}, 0, rValLen*len(values.Columns))
			values.Values = make([][]interface{}, rValLen)

			// Per DB-default field: the non-zero value supplied by each row
			// (nil where the row left it zero).
			defaultValueFieldsHavingValue := map[*schema.Field][]interface{}{}
			for i := 0; i < rValLen; i++ {
				rv := reflect.Indirect(stmt.ReflectValue.Index(i))
				if !rv.IsValid() {
					stmt.AddError(fmt.Errorf("slice data #%v is invalid: %w", i, gorm.ErrInvalidData))
					return
				}

				values.Values[i] = make([]interface{}, len(values.Columns))
				for idx, column := range values.Columns {
					field := stmt.Schema.FieldsByDBName[column.Name]
					if values.Values[i][idx], isZero = field.ValueOf(stmt.Context, rv); isZero {
						if field.DefaultValueInterface != nil {
							values.Values[i][idx] = field.DefaultValueInterface
							stmt.AddError(field.Set(stmt.Context, rv, field.DefaultValueInterface))
						} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
							stmt.AddError(field.Set(stmt.Context, rv, curTime))
							values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
						}
					} else if field.AutoUpdateTime > 0 && updateTrackTime {
						// Refresh a non-zero update timestamp when tracking is requested.
						stmt.AddError(field.Set(stmt.Context, rv, curTime))
						values.Values[i][idx], _ = field.ValueOf(stmt.Context, rv)
					}
				}

				for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
					if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
						if rvOfvalue, isZero := field.ValueOf(stmt.Context, rv); !isZero {
							if len(defaultValueFieldsHavingValue[field]) == 0 {
								defaultValueFieldsHavingValue[field] = make([]interface{}, rValLen)
							}
							defaultValueFieldsHavingValue[field][i] = rvOfvalue
						}
					}
				}
			}

			// Iterate the schema slice (not the map) so the appended column
			// order is stable.
			for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
				if vs, ok := defaultValueFieldsHavingValue[field]; ok {
					values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
					for idx := range values.Values {
						if vs[idx] == nil {
							values.Values[idx] = append(values.Values[idx], stmt.DefaultValueOf(field))
						} else {
							values.Values[idx] = append(values.Values[idx], vs[idx])
						}
					}
				}
			}
		case reflect.Struct:
			values.Values = [][]interface{}{make([]interface{}, len(values.Columns))}
			for idx, column := range values.Columns {
				field := stmt.Schema.FieldsByDBName[column.Name]
				if values.Values[0][idx], isZero = field.ValueOf(stmt.Context, stmt.ReflectValue); isZero {
					if field.DefaultValueInterface != nil {
						values.Values[0][idx] = field.DefaultValueInterface
						stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, field.DefaultValueInterface))
					} else if field.AutoCreateTime > 0 || field.AutoUpdateTime > 0 {
						stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
						values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
					}
				} else if field.AutoUpdateTime > 0 && updateTrackTime {
					stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, curTime))
					values.Values[0][idx], _ = field.ValueOf(stmt.Context, stmt.ReflectValue)
				}
			}

			for _, field := range stmt.Schema.FieldsWithDefaultDBValue {
				// NOTE(review): && binds tighter than ||, so the
				// DefaultValueInterface check only applies to the
				// "not explicitly selected" branch of this condition.
				if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) && field.DefaultValueInterface == nil {
					if rvOfvalue, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero {
						values.Columns = append(values.Columns, clause.Column{Name: field.DBName})
						values.Values[0] = append(values.Values[0], rvOfvalue)
					}
				}
			}
		default:
			stmt.AddError(gorm.ErrInvalidData)
		}
	}

	// For ON CONFLICT ... UpdateAll, derive the columns to update from the
	// inserted columns. Primary keys, DB-only defaults (other than "NULL")
	// and auto-create-time columns are excluded. Auto-update-time columns are
	// assigned curTime in the unit the field uses, and the conflict target
	// defaults to the primary key columns.
	if c, ok := stmt.Clauses["ON CONFLICT"]; ok {
		if onConflict, _ := c.Expression.(clause.OnConflict); onConflict.UpdateAll {
			if stmt.Schema != nil && len(values.Columns) >= 1 {
				selectColumns, restricted := stmt.SelectAndOmitColumns(true, true)

				columns := make([]string, 0, len(values.Columns)-1)
				for _, column := range values.Columns {
					if field := stmt.Schema.LookUpField(column.Name); field != nil {
						if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) {
							if !field.PrimaryKey && (!field.HasDefaultValue || field.DefaultValueInterface != nil ||
								strings.EqualFold(field.DefaultValue, "NULL")) && field.AutoCreateTime == 0 {
								if field.AutoUpdateTime > 0 {
									assignment := clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: curTime}
									switch field.AutoUpdateTime {
									case schema.UnixNanosecond:
										assignment.Value = curTime.UnixNano()
									case schema.UnixMillisecond:
										assignment.Value = curTime.UnixMilli()
									case schema.UnixSecond:
										assignment.Value = curTime.Unix()
									}

									onConflict.DoUpdates = append(onConflict.DoUpdates, assignment)
								} else {
									columns = append(columns, column.Name)
								}
							}
						}
					}
				}

				onConflict.DoUpdates = append(onConflict.DoUpdates, clause.AssignmentColumns(columns)...)
				// Nothing left to update: degrade to DO NOTHING.
				if len(onConflict.DoUpdates) == 0 {
					onConflict.DoNothing = true
				}

				// use primary fields as default OnConflict columns
				if len(onConflict.Columns) == 0 {
					for _, field := range stmt.Schema.PrimaryFields {
						onConflict.Columns = append(onConflict.Columns, clause.Column{Name: field.DBName})
					}
				}

				stmt.AddClause(onConflict)
			}
		}
	}

	return values
}
================================================
FILE: callbacks/create_test.go
================================================
package callbacks

import (
	"reflect"
	"sync"
	"testing"
	"time"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
)

// schemaCache is shared by the schema.Parse calls in this file's tests.
var schemaCache = &sync.Map{}

// TestConvertToCreateValues_DestType_Slice checks a slice-of-pointers
// destination. Columns without a default come first, followed by the columns
// that have DB defaults and carry values, in a stable order.
func TestConvertToCreateValues_DestType_Slice(t *testing.T) {
	type user struct {
		ID    int `gorm:"primaryKey"`
		Name  string
		Email string `gorm:"default:(-)"`
		Age   int    `gorm:"default:(-)"`
	}

	s, err := schema.Parse(&user{}, schemaCache, schema.NamingStrategy{})
	if err != nil {
		t.Errorf("parse schema error: %v, is not expected", err)
		return
	}
	dest := []*user{
		{
			ID:    1,
			Name:  "alice",
			Email: "email",
			Age:   18,
		},
		{
			ID:    2,
			Name:  "bob",
			Email: "email",
			Age:   19,
		},
	}
	stmt := &gorm.Statement{
		DB: &gorm.DB{
			Config: &gorm.Config{
				NowFunc: func() time.Time { return time.Time{} },
			},
			Statement: &gorm.Statement{
				Settings: sync.Map{},
				Schema:   s,
			},
		},
		ReflectValue: reflect.ValueOf(dest),
		Dest:         dest,
	}
	stmt.Schema = s

	values := ConvertToCreateValues(stmt)
	expected := clause.Values{
		// column has value + defaultValue column has value (which should have a stable order)
		Columns: []clause.Column{{Name: "name"}, {Name: "email"}, {Name: "age"}, {Name: "id"}},
		Values: [][]interface{}{
			{"alice", "email", 18, 1},
			{"bob", "email", 19, 2},
		},
	}
	if !reflect.DeepEqual(expected, values) {
		t.Errorf("expected: %v got %v", expected, values)
	}
}
================================================
FILE: callbacks/delete.go
================================================
package callbacks

import (
	"reflect"
	"strings"

	"gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) func BeforeDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.BeforeDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(BeforeDeleteInterface); ok { db.AddError(i.BeforeDelete(tx)) return true } return false }) } } func DeleteBeforeAssociations(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { selectColumns, restricted := db.Statement.SelectAndOmitColumns(true, false) if !restricted { return } for column, v := range selectColumns { if !v { continue } rel, ok := db.Statement.Schema.Relationships.Relations[column] if !ok { continue } switch rel.Type { case schema.HasOne, schema.HasMany: queryConds := rel.ToQueryConditions(db.Statement.Context, db.Statement.ReflectValue) modelValue := reflect.New(rel.FieldSchema.ModelType).Interface() tx := db.Session(&gorm.Session{NewDB: true}).Model(modelValue) withoutConditions := false if db.Statement.Unscoped { tx = tx.Unscoped() } if len(db.Statement.Selects) > 0 { selects := make([]string, 0, len(db.Statement.Selects)) for _, s := range db.Statement.Selects { if s == clause.Associations { selects = append(selects, s) } else if columnPrefix := column + "."; strings.HasPrefix(s, columnPrefix) { selects = append(selects, strings.TrimPrefix(s, columnPrefix)) } } if len(selects) > 0 { tx = tx.Select(selects) } } for _, cond := range queryConds { if c, ok := cond.(clause.IN); ok && len(c.Values) == 0 { withoutConditions = true break } } if !withoutConditions && db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { return } case schema.Many2Many: var ( queryConds = make([]clause.Expression, 0, len(rel.References)) foreignFields = make([]*schema.Field, 0, len(rel.References)) relForeignKeys = make([]string, 0, len(rel.References)) modelValue = reflect.New(rel.JoinTable.ModelType).Interface() table = 
rel.JoinTable.Table tx = db.Session(&gorm.Session{NewDB: true}).Model(modelValue).Table(table) ) for _, ref := range rel.References { if ref.OwnPrimaryKey { foreignFields = append(foreignFields, ref.PrimaryKey) relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) } else if ref.PrimaryValue != "" { queryConds = append(queryConds, clause.Eq{ Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, Value: ref.PrimaryValue, }) } } _, foreignValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, foreignFields) column, values := schema.ToQueryValues(table, relForeignKeys, foreignValues) queryConds = append(queryConds, clause.IN{Column: column, Values: values}) if db.AddError(tx.Clauses(clause.Where{Exprs: queryConds}).Delete(modelValue).Error) != nil { return } } } } } func Delete(config *Config) func(db *gorm.DB) { supportReturning := utils.Contains(config.DeleteClauses, "RETURNING") return func(db *gorm.DB) { if db.Error != nil { return } if db.Statement.Schema != nil { for _, c := range db.Statement.Schema.DeleteClauses { db.Statement.AddClause(c) } } if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) db.Statement.AddClauseIfNotExists(clause.Delete{}) if db.Statement.Schema != nil { _, queryValues := schema.GetIdentityFieldValuesMap(db.Statement.Context, db.Statement.ReflectValue, db.Statement.Schema.PrimaryFields) column, values := schema.ToQueryValues(db.Statement.Table, db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } if db.Statement.ReflectValue.CanAddr() && db.Statement.Dest != db.Statement.Model && db.Statement.Model != nil { _, queryValues = schema.GetIdentityFieldValuesMap(db.Statement.Context, reflect.ValueOf(db.Statement.Model), db.Statement.Schema.PrimaryFields) column, values = schema.ToQueryValues(db.Statement.Table, 
db.Statement.Schema.PrimaryFieldDBNames, queryValues) if len(values) > 0 { db.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } } } db.Statement.AddClauseIfNotExists(clause.From{}) db.Statement.Build(db.Statement.BuildClauses...) } checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { ok, mode := hasReturning(db, supportReturning) if !ok { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() if db.Statement.Result != nil { db.Statement.Result.Result = result db.Statement.Result.RowsAffected = db.RowsAffected } } return } if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { gorm.Scan(rows, db, mode) if db.Statement.Result != nil { db.Statement.Result.RowsAffected = db.RowsAffected } db.AddError(rows.Close()) } } } } func AfterDelete(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterDelete { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterDeleteInterface); ok { db.AddError(i.AfterDelete(tx)) return true } return false }) } } ================================================ FILE: callbacks/helper.go ================================================ package callbacks import ( "reflect" "sort" "gorm.io/gorm" "gorm.io/gorm/clause" ) // ConvertMapToValuesForCreate convert map to values func ConvertMapToValuesForCreate(stmt *gorm.Statement, mapValue map[string]interface{}) (values clause.Values) { values.Columns = make([]clause.Column, 0, len(mapValue)) selectColumns, restricted := stmt.SelectAndOmitColumns(true, false) keys := make([]string, 0, len(mapValue)) for k := range mapValue { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { value := 
mapValue[k] if stmt.Schema != nil { if field := stmt.Schema.LookUpField(k); field != nil { k = field.DBName } } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { values.Columns = append(values.Columns, clause.Column{Name: k}) if len(values.Values) == 0 { values.Values = [][]interface{}{{}} } values.Values[0] = append(values.Values[0], value) } } return } // ConvertSliceOfMapToValuesForCreate convert slice of map to values func ConvertSliceOfMapToValuesForCreate(stmt *gorm.Statement, mapValues []map[string]interface{}) (values clause.Values) { columns := make([]string, 0, len(mapValues)) // when the length of mapValues is zero,return directly here // no need to call stmt.SelectAndOmitColumns method if len(mapValues) == 0 { stmt.AddError(gorm.ErrEmptySlice) return } var ( result = make(map[string][]interface{}, len(mapValues)) selectColumns, restricted = stmt.SelectAndOmitColumns(true, false) ) for idx, mapValue := range mapValues { for k, v := range mapValue { if stmt.Schema != nil { if field := stmt.Schema.LookUpField(k); field != nil { k = field.DBName } } if _, ok := result[k]; !ok { if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { result[k] = make([]interface{}, len(mapValues)) columns = append(columns, k) } else { continue } } result[k][idx] = v } } sort.Strings(columns) values.Values = make([][]interface{}, len(mapValues)) values.Columns = make([]clause.Column, len(columns)) for idx, column := range columns { values.Columns[idx] = clause.Column{Name: column} for i, v := range result[column] { if len(values.Values[i]) == 0 { values.Values[i] = make([]interface{}, len(columns)) } values.Values[i][idx] = v } } return } func hasReturning(tx *gorm.DB, supportReturning bool) (bool, gorm.ScanMode) { if supportReturning { if c, ok := tx.Statement.Clauses["RETURNING"]; ok { returning, _ := c.Expression.(clause.Returning) if len(returning.Columns) == 0 || (len(returning.Columns) == 1 && returning.Columns[0].Name == "*") { return true, 0 
} return true, gorm.ScanUpdate } } return false, 0 } func checkMissingWhereConditions(db *gorm.DB) { if !db.AllowGlobalUpdate && db.Error == nil { where, withCondition := db.Statement.Clauses["WHERE"] if withCondition { if _, withSoftDelete := db.Statement.Clauses["soft_delete_enabled"]; withSoftDelete { whereClause, _ := where.Expression.(clause.Where) withCondition = len(whereClause.Exprs) > 1 } } if !withCondition { db.AddError(gorm.ErrMissingWhereClause) } return } } type visitMap = map[reflect.Value]bool // Check if circular values, return true if loaded func loadOrStoreVisitMap(visitMap *visitMap, v reflect.Value) (loaded bool) { if v.Kind() == reflect.Ptr { v = v.Elem() } switch v.Kind() { case reflect.Slice, reflect.Array: loaded = true for i := 0; i < v.Len(); i++ { if !loadOrStoreVisitMap(visitMap, v.Index(i)) { loaded = false } } case reflect.Struct, reflect.Interface: if v.CanAddr() { p := v.Addr() if _, ok := (*visitMap)[p]; ok { return true } (*visitMap)[p] = true } } return } ================================================ FILE: callbacks/helper_test.go ================================================ package callbacks import ( "reflect" "testing" "gorm.io/gorm" "gorm.io/gorm/clause" ) func TestLoadOrStoreVisitMap(t *testing.T) { var vm visitMap var loaded bool type testM struct { Name string } t1 := testM{Name: "t1"} t2 := testM{Name: "t2"} t3 := testM{Name: "t3"} vm = make(visitMap) if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); loaded { t.Fatalf("loaded should be false") } if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf(&t1)); !loaded { t.Fatalf("loaded should be true") } // t1 already exist but t2 not if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t1, &t2, &t3})); loaded { t.Fatalf("loaded should be false") } if loaded = loadOrStoreVisitMap(&vm, reflect.ValueOf([]*testM{&t2, &t3})); !loaded { t.Fatalf("loaded should be true") } } func TestConvertMapToValuesForCreate(t *testing.T) { testCase := []struct { name 
string input map[string]interface{} expect clause.Values }{ { name: "Test convert string value", input: map[string]interface{}{ "name": "my name", }, expect: clause.Values{ Columns: []clause.Column{{Name: "name"}}, Values: [][]interface{}{{"my name"}}, }, }, { name: "Test convert int value", input: map[string]interface{}{ "age": 18, }, expect: clause.Values{ Columns: []clause.Column{{Name: "age"}}, Values: [][]interface{}{{18}}, }, }, { name: "Test convert float value", input: map[string]interface{}{ "score": 99.5, }, expect: clause.Values{ Columns: []clause.Column{{Name: "score"}}, Values: [][]interface{}{{99.5}}, }, }, { name: "Test convert bool value", input: map[string]interface{}{ "active": true, }, expect: clause.Values{ Columns: []clause.Column{{Name: "active"}}, Values: [][]interface{}{{true}}, }, }, } for _, tc := range testCase { t.Run(tc.name, func(t *testing.T) { actual := ConvertMapToValuesForCreate(&gorm.Statement{}, tc.input) if !reflect.DeepEqual(actual, tc.expect) { t.Errorf("expect %v got %v", tc.expect, actual) } }) } } func TestConvertSliceOfMapToValuesForCreate(t *testing.T) { testCase := []struct { name string input []map[string]interface{} expect clause.Values }{ { name: "Test convert slice of string value", input: []map[string]interface{}{ {"name": "my name"}, }, expect: clause.Values{ Columns: []clause.Column{{Name: "name"}}, Values: [][]interface{}{{"my name"}}, }, }, { name: "Test convert slice of int value", input: []map[string]interface{}{ {"age": 18}, }, expect: clause.Values{ Columns: []clause.Column{{Name: "age"}}, Values: [][]interface{}{{18}}, }, }, { name: "Test convert slice of float value", input: []map[string]interface{}{ {"score": 99.5}, }, expect: clause.Values{ Columns: []clause.Column{{Name: "score"}}, Values: [][]interface{}{{99.5}}, }, }, { name: "Test convert slice of bool value", input: []map[string]interface{}{ {"active": true}, }, expect: clause.Values{ Columns: []clause.Column{{Name: "active"}}, Values: 
[][]interface{}{{true}}, }, }, } for _, tc := range testCase { t.Run(tc.name, func(t *testing.T) { actual := ConvertSliceOfMapToValuesForCreate(&gorm.Statement{}, tc.input) if !reflect.DeepEqual(actual, tc.expect) { t.Errorf("expected %v but got %v", tc.expect, actual) } }) } } ================================================ FILE: callbacks/interfaces.go ================================================ package callbacks import "gorm.io/gorm" type BeforeCreateInterface interface { BeforeCreate(*gorm.DB) error } type AfterCreateInterface interface { AfterCreate(*gorm.DB) error } type BeforeUpdateInterface interface { BeforeUpdate(*gorm.DB) error } type AfterUpdateInterface interface { AfterUpdate(*gorm.DB) error } type BeforeSaveInterface interface { BeforeSave(*gorm.DB) error } type AfterSaveInterface interface { AfterSave(*gorm.DB) error } type BeforeDeleteInterface interface { BeforeDelete(*gorm.DB) error } type AfterDeleteInterface interface { AfterDelete(*gorm.DB) error } type AfterFindInterface interface { AfterFind(*gorm.DB) error } ================================================ FILE: callbacks/preload.go ================================================ package callbacks import ( "fmt" "reflect" "sort" "strings" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) // parsePreloadMap extracts nested preloads. e.g. 
// // // schema has a "k0" relation and a "k7.k8" embedded relation // parsePreloadMap(schema, map[string][]interface{}{ // clause.Associations: {"arg1"}, // "k1": {"arg2"}, // "k2.k3": {"arg3"}, // "k4.k5.k6": {"arg4"}, // }) // // preloadMap is // map[string]map[string][]interface{}{ // "k0": {}, // "k7": { // "k8": {}, // }, // "k1": {}, // "k2": { // "k3": {"arg3"}, // }, // "k4": { // "k5.k6": {"arg4"}, // }, // } func parsePreloadMap(s *schema.Schema, preloads map[string][]interface{}) map[string]map[string][]interface{} { preloadMap := map[string]map[string][]interface{}{} setPreloadMap := func(name, value string, args []interface{}) { if _, ok := preloadMap[name]; !ok { preloadMap[name] = map[string][]interface{}{} } if value != "" { preloadMap[name][value] = args } } for name, args := range preloads { preloadFields := strings.Split(name, ".") value := strings.TrimPrefix(strings.TrimPrefix(name, preloadFields[0]), ".") if preloadFields[0] == clause.Associations { for _, relation := range s.Relationships.Relations { if relation.Schema == s { setPreloadMap(relation.Name, value, args) } } for embedded, embeddedRelations := range s.Relationships.EmbeddedRelations { for _, value := range embeddedValues(embeddedRelations) { setPreloadMap(embedded, value, args) } } } else { setPreloadMap(preloadFields[0], value, args) } } return preloadMap } func embeddedValues(embeddedRelations *schema.Relationships) []string { if embeddedRelations == nil { return nil } names := make([]string, 0, len(embeddedRelations.Relations)+len(embeddedRelations.EmbeddedRelations)) for _, relation := range embeddedRelations.Relations { // skip first struct name names = append(names, strings.Join(relation.Field.EmbeddedBindNames[1:], ".")) } for _, relations := range embeddedRelations.EmbeddedRelations { names = append(names, embeddedValues(relations)...) } return names } // preloadEntryPoint enters layer by layer. It will call real preload if it finds the right entry point. 
// If the current relationship is embedded or joined, current query will be ignored. // //nolint:cyclop func preloadEntryPoint(db *gorm.DB, joins []string, relationships *schema.Relationships, preloads map[string][]interface{}, associationsConds []interface{}) error { preloadMap := parsePreloadMap(db.Statement.Schema, preloads) // avoid random traversal of the map preloadNames := make([]string, 0, len(preloadMap)) for key := range preloadMap { preloadNames = append(preloadNames, key) } sort.Strings(preloadNames) isJoined := func(name string) (joined bool, nestedJoins []string) { for _, join := range joins { if _, ok := relationships.Relations[join]; ok && name == join { joined = true continue } join0, join1, cut := strings.Cut(join, ".") if cut { if _, ok := relationships.Relations[join0]; ok && name == join0 { joined = true nestedJoins = append(nestedJoins, join1) } } } return joined, nestedJoins } for _, name := range preloadNames { if relations := relationships.EmbeddedRelations[name]; relations != nil { if err := preloadEntryPoint(db, joins, relations, preloadMap[name], associationsConds); err != nil { return err } } else if rel := relationships.Relations[name]; rel != nil { if joined, nestedJoins := isJoined(name); joined { switch rv := db.Statement.ReflectValue; rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() > 0 { reflectValue := rel.FieldSchema.MakeSlice().Elem() for i := 0; i < rv.Len(); i++ { frv := rel.Field.ReflectValueOf(db.Statement.Context, rv.Index(i)) if frv.Kind() != reflect.Ptr { reflectValue = reflect.Append(reflectValue, frv.Addr()) } else { if frv.IsNil() { continue } reflectValue = reflect.Append(reflectValue, frv) } } tx := preloadDB(db, reflectValue, reflectValue.Interface()) if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { return err } } case reflect.Struct, reflect.Pointer: reflectValue := rel.Field.ReflectValueOf(db.Statement.Context, rv) tx := 
preloadDB(db, reflectValue, reflectValue.Interface()) if err := preloadEntryPoint(tx, nestedJoins, &tx.Statement.Schema.Relationships, preloadMap[name], associationsConds); err != nil { return err } default: return gorm.ErrInvalidData } } else { tx := db.Table("").Session(&gorm.Session{Context: db.Statement.Context, SkipHooks: db.Statement.SkipHooks}) tx.Statement.ReflectValue = db.Statement.ReflectValue tx.Statement.Unscoped = db.Statement.Unscoped if err := preload(tx, rel, append(preloads[name], associationsConds...), preloadMap[name]); err != nil { return err } } } else { return fmt.Errorf("%s: %w for schema %s", name, gorm.ErrUnsupportedRelation, db.Statement.Schema.Name) } } return nil } func preloadDB(db *gorm.DB, reflectValue reflect.Value, dest interface{}) *gorm.DB { tx := db.Session(&gorm.Session{Context: db.Statement.Context, NewDB: true, SkipHooks: db.Statement.SkipHooks, Initialized: true}) db.Statement.Settings.Range(func(k, v interface{}) bool { tx.Statement.Settings.Store(k, v) return true }) if err := tx.Statement.Parse(dest); err != nil { tx.AddError(err) return tx } tx.Statement.ReflectValue = reflectValue tx.Statement.Unscoped = db.Statement.Unscoped return tx } func preload(tx *gorm.DB, rel *schema.Relationship, conds []interface{}, preloads map[string][]interface{}) error { var ( reflectValue = tx.Statement.ReflectValue relForeignKeys []string relForeignFields []*schema.Field foreignFields []*schema.Field foreignValues [][]interface{} identityMap = map[string][]reflect.Value{} inlineConds []interface{} ) if rel.JoinTable != nil { var ( joinForeignFields = make([]*schema.Field, 0, len(rel.References)) joinRelForeignFields = make([]*schema.Field, 0, len(rel.References)) joinForeignKeys = make([]string, 0, len(rel.References)) ) for _, ref := range rel.References { if ref.OwnPrimaryKey { joinForeignKeys = append(joinForeignKeys, ref.ForeignKey.DBName) joinForeignFields = append(joinForeignFields, ref.ForeignKey) foreignFields = 
append(foreignFields, ref.PrimaryKey) } else if ref.PrimaryValue != "" { tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } else { joinRelForeignFields = append(joinRelForeignFields, ref.ForeignKey) relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) relForeignFields = append(relForeignFields, ref.PrimaryKey) } } joinIdentityMap, joinForeignValues := schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(joinForeignValues) == 0 { return nil } joinResults := rel.JoinTable.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, joinForeignKeys, joinForeignValues) if err := tx.Where(clause.IN{Column: column, Values: values}).Find(joinResults.Addr().Interface()).Error; err != nil { return err } // convert join identity map to relation identity map fieldValues := make([]interface{}, len(joinForeignFields)) joinFieldValues := make([]interface{}, len(joinRelForeignFields)) for i := 0; i < joinResults.Len(); i++ { joinIndexValue := joinResults.Index(i) for idx, field := range joinForeignFields { fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } for idx, field := range joinRelForeignFields { joinFieldValues[idx], _ = field.ValueOf(tx.Statement.Context, joinIndexValue) } if results, ok := joinIdentityMap[utils.ToStringKey(fieldValues...)]; ok { joinKey := utils.ToStringKey(joinFieldValues...) identityMap[joinKey] = append(identityMap[joinKey], results...) 
} } _, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, joinResults, joinRelForeignFields) } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) relForeignFields = append(relForeignFields, ref.ForeignKey) foreignFields = append(foreignFields, ref.PrimaryKey) } else if ref.PrimaryValue != "" { tx = tx.Where(clause.Eq{Column: ref.ForeignKey.DBName, Value: ref.PrimaryValue}) } else { relForeignKeys = append(relForeignKeys, ref.PrimaryKey.DBName) relForeignFields = append(relForeignFields, ref.PrimaryKey) foreignFields = append(foreignFields, ref.ForeignKey) } } identityMap, foreignValues = schema.GetIdentityFieldValuesMap(tx.Statement.Context, reflectValue, foreignFields) if len(foreignValues) == 0 { return nil } } // nested preload for p, pvs := range preloads { tx = tx.Preload(p, pvs...) } reflectResults := rel.FieldSchema.MakeSlice().Elem() column, values := schema.ToQueryValues(clause.CurrentTable, relForeignKeys, foreignValues) if len(values) != 0 { tx = tx.Model(reflectResults.Addr().Interface()).Where(clause.IN{Column: column, Values: values}) for _, cond := range conds { if fc, ok := cond.(func(*gorm.DB) *gorm.DB); ok { tx = fc(tx) } else { inlineConds = append(inlineConds, cond) } } if len(inlineConds) > 0 { tx = tx.Where(inlineConds[0], inlineConds[1:]...) 
} if err := tx.Find(reflectResults.Addr().Interface()).Error; err != nil { return err } } fieldValues := make([]interface{}, len(relForeignFields)) // clean up old values before preloading switch reflectValue.Kind() { case reflect.Struct: switch rel.Type { case schema.HasMany, schema.Many2Many: tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue, reflect.New(rel.Field.FieldType).Interface())) } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { switch rel.Type { case schema.HasMany, schema.Many2Many: tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.MakeSlice(rel.Field.IndirectFieldType, 0, 10).Interface())) default: tx.AddError(rel.Field.Set(tx.Statement.Context, reflectValue.Index(i), reflect.New(rel.Field.FieldType).Interface())) } } } for i := 0; i < reflectResults.Len(); i++ { elem := reflectResults.Index(i) for idx, field := range relForeignFields { fieldValues[idx], _ = field.ValueOf(tx.Statement.Context, elem) } datas, ok := identityMap[utils.ToStringKey(fieldValues...)] if !ok { return fmt.Errorf("failed to assign association %#v, make sure foreign fields exists", elem.Interface()) } for _, data := range datas { reflectFieldValue := rel.Field.ReflectValueOf(tx.Statement.Context, data) if reflectFieldValue.Kind() == reflect.Ptr && reflectFieldValue.IsNil() { reflectFieldValue.Set(reflect.New(rel.Field.FieldType.Elem())) } reflectFieldValue = reflect.Indirect(reflectFieldValue) switch reflectFieldValue.Kind() { case reflect.Struct: tx.AddError(rel.Field.Set(tx.Statement.Context, data, elem.Interface())) case reflect.Slice, reflect.Array: if reflectFieldValue.Type().Elem().Kind() == reflect.Ptr { tx.AddError(rel.Field.Set(tx.Statement.Context, data, reflect.Append(reflectFieldValue, elem).Interface())) } else { tx.AddError(rel.Field.Set(tx.Statement.Context, 
data, reflect.Append(reflectFieldValue, elem.Elem()).Interface())) } } } } return tx.Error } ================================================ FILE: callbacks/query.go ================================================ package callbacks import ( "fmt" "reflect" "strings" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) func Query(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db) if !db.DryRun && db.Error == nil { rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if err != nil { db.AddError(err) return } defer func() { db.AddError(rows.Close()) }() gorm.Scan(rows, db, 0) if db.Statement.Result != nil { db.Statement.Result.RowsAffected = db.RowsAffected } } } } func BuildQuerySQL(db *gorm.DB) { if db.Statement.Schema != nil { for _, c := range db.Statement.Schema.QueryClauses { db.Statement.AddClause(c) } } if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(100) clauseSelect := clause.Select{Distinct: db.Statement.Distinct} if db.Statement.ReflectValue.Kind() == reflect.Struct && db.Statement.ReflectValue.Type() == db.Statement.Schema.ModelType { var conds []clause.Expression for _, primaryField := range db.Statement.Schema.PrimaryFields { if v, isZero := primaryField.ValueOf(db.Statement.Context, db.Statement.ReflectValue); !isZero { conds = append(conds, clause.Eq{Column: clause.Column{Table: db.Statement.Table, Name: primaryField.DBName}, Value: v}) } } if len(conds) > 0 { db.Statement.AddClause(clause.Where{Exprs: conds}) } } if len(db.Statement.Selects) > 0 { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Selects)) for idx, name := range db.Statement.Selects { if db.Statement.Schema == nil { clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: true} } else if f := db.Statement.Schema.LookUpField(name); f != nil { clauseSelect.Columns[idx] = clause.Column{Name: f.DBName} } else { clauseSelect.Columns[idx] = clause.Column{Name: name, Raw: 
true} } } } else if db.Statement.Schema != nil && len(db.Statement.Omits) > 0 { selectColumns, _ := db.Statement.SelectAndOmitColumns(false, false) clauseSelect.Columns = make([]clause.Column, 0, len(db.Statement.Schema.DBNames)) for _, dbName := range db.Statement.Schema.DBNames { if v, ok := selectColumns[dbName]; (ok && v) || !ok { clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{Table: db.Statement.Table, Name: dbName}) } } } else if db.Statement.Schema != nil && db.Statement.ReflectValue.IsValid() { queryFields := db.QueryFields if !queryFields { switch db.Statement.ReflectValue.Kind() { case reflect.Struct: queryFields = db.Statement.ReflectValue.Type() != db.Statement.Schema.ModelType case reflect.Slice: queryFields = db.Statement.ReflectValue.Type().Elem() != db.Statement.Schema.ModelType } } if queryFields { stmt := gorm.Statement{DB: db} // smaller struct if err := stmt.Parse(db.Statement.Dest); err == nil && (db.QueryFields || stmt.Schema.ModelType != db.Statement.Schema.ModelType) { clauseSelect.Columns = make([]clause.Column, len(stmt.Schema.DBNames)) for idx, dbName := range stmt.Schema.DBNames { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} } } } } // inline joins fromClause := clause.From{} if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { fromClause = v } if len(db.Statement.Joins) != 0 || len(fromClause.Joins) != 0 { if len(db.Statement.Selects) == 0 && len(db.Statement.Omits) == 0 && db.Statement.Schema != nil { clauseSelect.Columns = make([]clause.Column, len(db.Statement.Schema.DBNames)) for idx, dbName := range db.Statement.Schema.DBNames { clauseSelect.Columns[idx] = clause.Column{Table: db.Statement.Table, Name: dbName} } } specifiedRelationsName := map[string]string{clause.CurrentTable: clause.CurrentTable} for _, join := range db.Statement.Joins { if db.Statement.Schema != nil { var isRelations bool // is relations or raw sql var relations []*schema.Relationship 
relation, ok := db.Statement.Schema.Relationships.Relations[join.Name] if ok { isRelations = true relations = append(relations, relation) } else { // handle nested join like "Manager.Company" nestedJoinNames := strings.Split(join.Name, ".") if len(nestedJoinNames) > 1 { isNestedJoin := true guessNestedRelations := make([]*schema.Relationship, 0, len(nestedJoinNames)) currentRelations := db.Statement.Schema.Relationships.Relations for _, relname := range nestedJoinNames { // incomplete match, only treated as raw sql if relation, ok = currentRelations[relname]; ok { guessNestedRelations = append(guessNestedRelations, relation) currentRelations = relation.FieldSchema.Relationships.Relations } else { isNestedJoin = false break } } if isNestedJoin { isRelations = true relations = guessNestedRelations } } } if isRelations { genJoinClause := func(joinType clause.JoinType, tableAliasName string, parentTableName string, relation *schema.Relationship) clause.Join { columnStmt := gorm.Statement{ Table: tableAliasName, DB: db, Schema: relation.FieldSchema, Selects: join.Selects, Omits: join.Omits, } selectColumns, restricted := columnStmt.SelectAndOmitColumns(false, false) for _, s := range relation.FieldSchema.DBNames { if v, ok := selectColumns[s]; (ok && v) || (!ok && !restricted) { clauseSelect.Columns = append(clauseSelect.Columns, clause.Column{ Table: tableAliasName, Name: s, Alias: utils.NestedRelationName(tableAliasName, s), }) } } if join.Expression != nil { return clause.Join{ Type: join.JoinType, Expression: join.Expression, } } exprs := make([]clause.Expression, len(relation.References)) for idx, ref := range relation.References { if ref.OwnPrimaryKey { exprs[idx] = clause.Eq{ Column: clause.Column{Table: parentTableName, Name: ref.PrimaryKey.DBName}, Value: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, } } else { if ref.PrimaryValue == "" { exprs[idx] = clause.Eq{ Column: clause.Column{Table: parentTableName, Name: ref.ForeignKey.DBName}, 
Value: clause.Column{Table: tableAliasName, Name: ref.PrimaryKey.DBName}, } } else { exprs[idx] = clause.Eq{ Column: clause.Column{Table: tableAliasName, Name: ref.ForeignKey.DBName}, Value: ref.PrimaryValue, } } } } { onStmt := gorm.Statement{Table: tableAliasName, DB: db, Clauses: map[string]clause.Clause{}} for _, c := range relation.FieldSchema.QueryClauses { onStmt.AddClause(c) } if join.On != nil { onStmt.AddClause(join.On) } if cs, ok := onStmt.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { where.Build(&onStmt) if onSQL := onStmt.SQL.String(); onSQL != "" { vars := onStmt.Vars for idx, v := range vars { bindvar := strings.Builder{} onStmt.Vars = vars[0 : idx+1] db.Dialector.BindVarTo(&bindvar, &onStmt, v) onSQL = strings.Replace(onSQL, bindvar.String(), "?", 1) } exprs = append(exprs, clause.Expr{SQL: onSQL, Vars: vars}) } } } } return clause.Join{ Type: joinType, Table: clause.Table{Name: relation.FieldSchema.Table, Alias: tableAliasName}, ON: clause.Where{Exprs: exprs}, } } parentTableName := clause.CurrentTable for idx, rel := range relations { // joins table alias like "Manager, Company, Manager__Company" curAliasName := rel.Name if parentTableName != clause.CurrentTable { curAliasName = utils.NestedRelationName(parentTableName, curAliasName) } if _, ok := specifiedRelationsName[curAliasName]; !ok { aliasName := curAliasName if idx == len(relations)-1 && join.Alias != "" { aliasName = join.Alias } fromClause.Joins = append(fromClause.Joins, genJoinClause(join.JoinType, aliasName, specifiedRelationsName[parentTableName], rel)) specifiedRelationsName[curAliasName] = aliasName } parentTableName = curAliasName } } else { fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } } else { fromClause.Joins = append(fromClause.Joins, clause.Join{ Expression: clause.NamedExpr{SQL: join.Name, Vars: join.Conds}, }) } } db.Statement.AddClause(fromClause) } else { 
db.Statement.AddClauseIfNotExists(clause.From{}) } db.Statement.AddClauseIfNotExists(clauseSelect) db.Statement.Build(db.Statement.BuildClauses...) } } func Preload(db *gorm.DB) { if db.Error == nil && len(db.Statement.Preloads) > 0 { if db.Statement.Schema == nil { db.AddError(fmt.Errorf("%w when using preload", gorm.ErrModelValueRequired)) return } joins := make([]string, 0, len(db.Statement.Joins)) for _, join := range db.Statement.Joins { joins = append(joins, join.Name) } tx := preloadDB(db, db.Statement.ReflectValue, db.Statement.Dest) if tx.Error != nil { return } db.AddError(preloadEntryPoint(tx, joins, &tx.Statement.Schema.Relationships, db.Statement.Preloads, db.Statement.Preloads[clause.Associations])) } } func AfterQuery(db *gorm.DB) { // clear the joins after query because preload need it if v, ok := db.Statement.Clauses["FROM"].Expression.(clause.From); ok { fromClause := db.Statement.Clauses["FROM"] fromClause.Expression = clause.From{Tables: v.Tables, Joins: utils.RTrimSlice(v.Joins, len(db.Statement.Joins))} // keep the original From Joins db.Statement.Clauses["FROM"] = fromClause } if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && db.Statement.Schema.AfterFind && db.RowsAffected > 0 { callMethod(db, func(value interface{}, tx *gorm.DB) bool { if i, ok := value.(AfterFindInterface); ok { db.AddError(i.AfterFind(tx)) return true } return false }) } } ================================================ FILE: callbacks/raw.go ================================================ package callbacks import ( "gorm.io/gorm" ) func RawExec(db *gorm.DB) { if db.Error == nil && !db.DryRun { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) 
if err != nil { db.AddError(err) return } db.RowsAffected, _ = result.RowsAffected() if db.Statement.Result != nil { db.Statement.Result.Result = result db.Statement.Result.RowsAffected = db.RowsAffected } } } ================================================ FILE: callbacks/row.go ================================================ package callbacks import ( "gorm.io/gorm" ) func RowQuery(db *gorm.DB) { if db.Error == nil { BuildQuerySQL(db) if db.DryRun || db.Error != nil { return } if isRows, ok := db.Get("rows"); ok && isRows.(bool) { db.Statement.Settings.Delete("rows") db.Statement.Dest, db.Error = db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } else { db.Statement.Dest = db.Statement.ConnPool.QueryRowContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) } db.RowsAffected = -1 } } ================================================ FILE: callbacks/transaction.go ================================================ package callbacks import ( "gorm.io/gorm" ) func BeginTransaction(db *gorm.DB) { if !db.Config.SkipDefaultTransaction && db.Error == nil { if tx := db.Begin(); tx.Error == nil { db.Statement.ConnPool = tx.Statement.ConnPool db.InstanceSet("gorm:started_transaction", true) } else if tx.Error == gorm.ErrInvalidTransaction { tx.Error = nil } else { db.Error = tx.Error } } } func CommitOrRollbackTransaction(db *gorm.DB) { if !db.Config.SkipDefaultTransaction { if _, ok := db.InstanceGet("gorm:started_transaction"); ok { if db.Error != nil { db.Rollback() } else { db.Commit() } db.Statement.ConnPool = db.ConnPool } } } ================================================ FILE: callbacks/update.go ================================================ package callbacks import ( "reflect" "sort" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) func SetupUpdateReflectValue(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil { if 
!db.Statement.ReflectValue.CanAddr() || db.Statement.Model != db.Statement.Dest { db.Statement.ReflectValue = reflect.ValueOf(db.Statement.Model) for db.Statement.ReflectValue.Kind() == reflect.Ptr { db.Statement.ReflectValue = db.Statement.ReflectValue.Elem() } if dest, ok := db.Statement.Dest.(map[string]interface{}); ok { for _, rel := range db.Statement.Schema.Relationships.BelongsTo { if _, ok := dest[rel.Name]; ok { db.AddError(rel.Field.Set(db.Statement.Context, db.Statement.ReflectValue, dest[rel.Name])) } } } } } } // BeforeUpdate before update hooks func BeforeUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.BeforeSave || db.Statement.Schema.BeforeUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.BeforeSave { if i, ok := value.(BeforeSaveInterface); ok { called = true db.AddError(i.BeforeSave(tx)) } } if db.Statement.Schema.BeforeUpdate { if i, ok := value.(BeforeUpdateInterface); ok { called = true db.AddError(i.BeforeUpdate(tx)) } } return called }) } } // Update update hook func Update(config *Config) func(db *gorm.DB) { supportReturning := utils.Contains(config.UpdateClauses, "RETURNING") return func(db *gorm.DB) { if db.Error != nil { return } if db.Statement.Schema != nil { for _, c := range db.Statement.Schema.UpdateClauses { db.Statement.AddClause(c) } } if db.Statement.SQL.Len() == 0 { db.Statement.SQL.Grow(180) db.Statement.AddClauseIfNotExists(clause.Update{}) if _, ok := db.Statement.Clauses["SET"]; !ok { if set := ConvertToAssignments(db.Statement); len(set) != 0 { defer delete(db.Statement.Clauses, "SET") db.Statement.AddClause(set) } else { return } } db.Statement.Build(db.Statement.BuildClauses...) 
} checkMissingWhereConditions(db) if !db.DryRun && db.Error == nil { if ok, mode := hasReturning(db, supportReturning); ok { if rows, err := db.Statement.ConnPool.QueryContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...); db.AddError(err) == nil { dest := db.Statement.Dest db.Statement.Dest = db.Statement.ReflectValue.Addr().Interface() gorm.Scan(rows, db, mode) db.Statement.Dest = dest db.AddError(rows.Close()) if db.Statement.Result != nil { db.Statement.Result.RowsAffected = db.RowsAffected } } } else { result, err := db.Statement.ConnPool.ExecContext(db.Statement.Context, db.Statement.SQL.String(), db.Statement.Vars...) if db.AddError(err) == nil { db.RowsAffected, _ = result.RowsAffected() } if db.Statement.Result != nil { db.Statement.Result.Result = result db.Statement.Result.RowsAffected = db.RowsAffected } } } } } // AfterUpdate after update hooks func AfterUpdate(db *gorm.DB) { if db.Error == nil && db.Statement.Schema != nil && !db.Statement.SkipHooks && (db.Statement.Schema.AfterSave || db.Statement.Schema.AfterUpdate) { callMethod(db, func(value interface{}, tx *gorm.DB) (called bool) { if db.Statement.Schema.AfterUpdate { if i, ok := value.(AfterUpdateInterface); ok { called = true db.AddError(i.AfterUpdate(tx)) } } if db.Statement.Schema.AfterSave { if i, ok := value.(AfterSaveInterface); ok { called = true db.AddError(i.AfterSave(tx)) } } return called }) } } // ConvertToAssignments convert to update assignments func ConvertToAssignments(stmt *gorm.Statement) (set clause.Set) { var ( selectColumns, restricted = stmt.SelectAndOmitColumns(false, true) assignValue func(field *schema.Field, value interface{}) ) switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: assignValue = func(field *schema.Field, value interface{}) { for i := 0; i < stmt.ReflectValue.Len(); i++ { if stmt.ReflectValue.CanAddr() { field.Set(stmt.Context, stmt.ReflectValue.Index(i), value) } } } case reflect.Struct: assignValue = func(field 
*schema.Field, value interface{}) { if stmt.ReflectValue.CanAddr() { field.Set(stmt.Context, stmt.ReflectValue, value) } } default: assignValue = func(field *schema.Field, value interface{}) { } } updatingValue := reflect.ValueOf(stmt.Dest) for updatingValue.Kind() == reflect.Ptr { updatingValue = updatingValue.Elem() } if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: if size := stmt.ReflectValue.Len(); size > 0 { var isZero bool for i := 0; i < size; i++ { for _, field := range stmt.Schema.PrimaryFields { _, isZero = field.ValueOf(stmt.Context, stmt.ReflectValue.Index(i)) if !isZero { break } } } if !isZero { _, primaryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields) column, values := schema.ToQueryValues("", stmt.Schema.PrimaryFieldDBNames, primaryValues) stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}}) } } case reflect.Struct: for _, field := range stmt.Schema.PrimaryFields { if value, isZero := field.ValueOf(stmt.Context, stmt.ReflectValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } } } switch value := updatingValue.Interface().(type) { case map[string]interface{}: set = make([]clause.Assignment, 0, len(value)) keys := make([]string, 0, len(value)) for k := range value { keys = append(keys, k) } sort.Strings(keys) for _, k := range keys { kv := value[k] if _, ok := kv.(*gorm.DB); ok { kv = []interface{}{kv} } if stmt.Schema != nil { if field := stmt.Schema.LookUpField(k); field != nil { if field.DBName != "" { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: kv}) assignValue(field, value[k]) } } else if v, ok := selectColumns[field.Name]; (ok && v) || (!ok && !restricted) { assignValue(field, 
value[k]) } continue } } if v, ok := selectColumns[k]; (ok && v) || (!ok && !restricted) { set = append(set, clause.Assignment{Column: clause.Column{Name: k}, Value: kv}) } } if !stmt.SkipHooks && stmt.Schema != nil { for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.LookUpField(dbName) if field.AutoUpdateTime > 0 && value[field.Name] == nil && value[field.DBName] == nil { if v, ok := selectColumns[field.DBName]; (ok && v) || !ok { now := stmt.DB.NowFunc() assignValue(field, now) if field.AutoUpdateTime == schema.UnixNanosecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixNano()}) } else if field.AutoUpdateTime == schema.UnixMillisecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.UnixMilli()}) } else if field.AutoUpdateTime == schema.UnixSecond { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now.Unix()}) } else { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: now}) } } } } } default: updatingSchema := stmt.Schema var isDiffSchema bool if !updatingValue.CanAddr() || stmt.Dest != stmt.Model { // different schema updatingStmt := &gorm.Statement{DB: stmt.DB} if err := updatingStmt.Parse(stmt.Dest); err == nil { updatingSchema = updatingStmt.Schema isDiffSchema = true } } switch updatingValue.Kind() { case reflect.Struct: set = make([]clause.Assignment, 0, len(stmt.Schema.FieldsByDBName)) for _, dbName := range stmt.Schema.DBNames { if field := updatingSchema.LookUpField(dbName); field != nil { if !field.PrimaryKey || !updatingValue.CanAddr() || stmt.Dest != stmt.Model { if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && (!restricted || (!stmt.SkipHooks && field.AutoUpdateTime > 0))) { value, isZero := field.ValueOf(stmt.Context, updatingValue) if !stmt.SkipHooks && field.AutoUpdateTime > 0 { if field.AutoUpdateTime == schema.UnixNanosecond { value = 
stmt.DB.NowFunc().UnixNano() } else if field.AutoUpdateTime == schema.UnixMillisecond { value = stmt.DB.NowFunc().UnixMilli() } else if field.AutoUpdateTime == schema.UnixSecond { value = stmt.DB.NowFunc().Unix() } else { value = stmt.DB.NowFunc() } isZero = false } if (ok || !isZero) && field.Updatable { set = append(set, clause.Assignment{Column: clause.Column{Name: field.DBName}, Value: value}) assignField := field if isDiffSchema { if originField := stmt.Schema.LookUpField(dbName); originField != nil { assignField = originField } } assignValue(assignField, value) } } } else { if value, isZero := field.ValueOf(stmt.Context, updatingValue); !isZero { stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.Eq{Column: field.DBName, Value: value}}}) } } } } default: stmt.AddError(gorm.ErrInvalidData) } } return } ================================================ FILE: callbacks.go ================================================ package gorm import ( "context" "errors" "fmt" "reflect" "sort" "time" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) func initializeCallbacks(db *DB) *callbacks { return &callbacks{ processors: map[string]*processor{ "create": {db: db}, "query": {db: db}, "update": {db: db}, "delete": {db: db}, "row": {db: db}, "raw": {db: db}, }, } } // callbacks gorm callbacks manager type callbacks struct { processors map[string]*processor } type processor struct { db *DB Clauses []string fns []func(*DB) callbacks []*callback } type callback struct { name string before string after string remove bool replace bool match func(*DB) bool handler func(*DB) processor *processor } func (cs *callbacks) Create() *processor { return cs.processors["create"] } func (cs *callbacks) Query() *processor { return cs.processors["query"] } func (cs *callbacks) Update() *processor { return cs.processors["update"] } func (cs *callbacks) Delete() *processor { return cs.processors["delete"] } func (cs *callbacks) Row() *processor { return cs.processors["row"] } func (cs 
*callbacks) Raw() *processor { return cs.processors["raw"] } func (p *processor) Execute(db *DB) *DB { // call scopes for len(db.Statement.scopes) > 0 { db = db.executeScopes() } var ( curTime = time.Now() stmt = db.Statement resetBuildClauses bool ) if len(stmt.BuildClauses) == 0 { stmt.BuildClauses = p.Clauses resetBuildClauses = true } if optimizer, ok := stmt.Dest.(StatementModifier); ok { optimizer.ModifyStatement(stmt) } if db.DefaultContextTimeout > 0 { if _, ok := stmt.Context.Deadline(); !ok { stmt.Context, _ = context.WithTimeout(stmt.Context, db.DefaultContextTimeout) } } // assign model values if stmt.Model == nil { stmt.Model = stmt.Dest } else if stmt.Dest == nil { stmt.Dest = stmt.Model } // parse model values if stmt.Model != nil { if err := stmt.Parse(stmt.Model); err != nil && (!errors.Is(err, schema.ErrUnsupportedDataType) || (stmt.Table == "" && stmt.TableExpr == nil && stmt.SQL.Len() == 0)) { if errors.Is(err, schema.ErrUnsupportedDataType) && stmt.Table == "" && stmt.TableExpr == nil { db.AddError(fmt.Errorf("%w: Table not set, please set it like: db.Model(&user) or db.Table(\"users\")", err)) } else { db.AddError(err) } } } // assign stmt.ReflectValue if stmt.Dest != nil { stmt.ReflectValue = reflect.ValueOf(stmt.Dest) for stmt.ReflectValue.Kind() == reflect.Ptr { if stmt.ReflectValue.IsNil() && stmt.ReflectValue.CanAddr() { stmt.ReflectValue.Set(reflect.New(stmt.ReflectValue.Type().Elem())) } stmt.ReflectValue = stmt.ReflectValue.Elem() } if !stmt.ReflectValue.IsValid() { db.AddError(ErrInvalidValue) } } for _, f := range p.fns { f(db) } if stmt.SQL.Len() > 0 { db.Logger.Trace(stmt.Context, curTime, func() (string, int64) { sql, vars := stmt.SQL.String(), stmt.Vars if filter, ok := db.Logger.(ParamsFilter); ok { sql, vars = filter.ParamsFilter(stmt.Context, stmt.SQL.String(), stmt.Vars...) 
} return db.Dialector.Explain(sql, vars...), db.RowsAffected }, db.Error) } if !stmt.DB.DryRun { stmt.SQL.Reset() stmt.Vars = nil } if resetBuildClauses { stmt.BuildClauses = nil } return db } func (p *processor) Get(name string) func(*DB) { for i := len(p.callbacks) - 1; i >= 0; i-- { if v := p.callbacks[i]; v.name == name && !v.remove { return v.handler } } return nil } func (p *processor) Before(name string) *callback { return &callback{before: name, processor: p} } func (p *processor) After(name string) *callback { return &callback{after: name, processor: p} } func (p *processor) Match(fc func(*DB) bool) *callback { return &callback{match: fc, processor: p} } func (p *processor) Register(name string, fn func(*DB)) error { return (&callback{processor: p}).Register(name, fn) } func (p *processor) Remove(name string) error { return (&callback{processor: p}).Remove(name) } func (p *processor) Replace(name string, fn func(*DB)) error { return (&callback{processor: p}).Replace(name, fn) } func (p *processor) compile() (err error) { var callbacks []*callback removedMap := map[string]bool{} for _, callback := range p.callbacks { if callback.match == nil || callback.match(p.db) { callbacks = append(callbacks, callback) } if callback.remove { removedMap[callback.name] = true } } if len(removedMap) > 0 { callbacks = removeCallbacks(callbacks, removedMap) } p.callbacks = callbacks if p.fns, err = sortCallbacks(p.callbacks); err != nil { p.db.Logger.Error(context.Background(), "Got error when compile callbacks, got %v", err) } return } func (c *callback) Before(name string) *callback { c.before = name return c } func (c *callback) After(name string) *callback { c.after = name return c } func (c *callback) Register(name string, fn func(*DB)) error { c.name = name c.handler = fn c.processor.callbacks = append(c.processor.callbacks, c) return c.processor.compile() } func (c *callback) Remove(name string) error { c.processor.db.Logger.Warn(context.Background(), "removing 
callback `%s` from %s\n", name, utils.FileWithLineNum()) c.name = name c.remove = true c.processor.callbacks = append(c.processor.callbacks, c) return c.processor.compile() } func (c *callback) Replace(name string, fn func(*DB)) error { c.processor.db.Logger.Info(context.Background(), "replacing callback `%s` from %s\n", name, utils.FileWithLineNum()) c.name = name c.handler = fn c.replace = true c.processor.callbacks = append(c.processor.callbacks, c) return c.processor.compile() } // getRIndex get right index from string slice func getRIndex(strs []string, str string) int { for i := len(strs) - 1; i >= 0; i-- { if strs[i] == str { return i } } return -1 } func sortCallbacks(cs []*callback) (fns []func(*DB), err error) { var ( names, sorted []string sortCallback func(*callback) error ) sort.SliceStable(cs, func(i, j int) bool { if cs[j].before == "*" && cs[i].before != "*" { return true } if cs[j].after == "*" && cs[i].after != "*" { return true } return false }) for _, c := range cs { // show warning message the callback name already exists if idx := getRIndex(names, c.name); idx > -1 && !c.replace && !c.remove && !cs[idx].remove { c.processor.db.Logger.Warn(context.Background(), "duplicated callback `%s` from %s\n", c.name, utils.FileWithLineNum()) } names = append(names, c.name) } sortCallback = func(c *callback) error { if c.before != "" { // if defined before callback if c.before == "*" && len(sorted) > 0 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { sorted = append([]string{c.name}, sorted...) } } else if sortedIdx := getRIndex(sorted, c.before); sortedIdx != -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if before callback already sorted, append current callback just after it sorted = append(sorted[:sortedIdx], append([]string{c.name}, sorted[sortedIdx:]...)...) 
} else if curIdx > sortedIdx { return fmt.Errorf("conflicting callback %s with before %s", c.name, c.before) } } else if idx := getRIndex(names, c.before); idx != -1 { // if before callback exists cs[idx].after = c.name } } if c.after != "" { // if defined after callback if c.after == "*" && len(sorted) > 0 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { sorted = append(sorted, c.name) } } else if sortedIdx := getRIndex(sorted, c.after); sortedIdx != -1 { if curIdx := getRIndex(sorted, c.name); curIdx == -1 { // if after callback sorted, append current callback to last sorted = append(sorted, c.name) } else if curIdx < sortedIdx { return fmt.Errorf("conflicting callback %s with before %s", c.name, c.after) } } else if idx := getRIndex(names, c.after); idx != -1 { // if after callback exists but haven't sorted // set after callback's before callback to current callback after := cs[idx] if after.before == "" { after.before = c.name } if err := sortCallback(after); err != nil { return err } if err := sortCallback(c); err != nil { return err } } } // if current callback haven't been sorted, append it to last if getRIndex(sorted, c.name) == -1 { sorted = append(sorted, c.name) } return nil } for _, c := range cs { if err = sortCallback(c); err != nil { return } } for _, name := range sorted { if idx := getRIndex(names, name); !cs[idx].remove { fns = append(fns, cs[idx].handler) } } return } func removeCallbacks(cs []*callback, nameMap map[string]bool) []*callback { callbacks := make([]*callback, 0, len(cs)) for _, callback := range cs { if nameMap[callback.name] { continue } callbacks = append(callbacks, callback) } return callbacks } ================================================ FILE: chainable_api.go ================================================ package gorm import ( "fmt" "regexp" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/utils" ) // Model specify the model you would like to run db operations // // // update all users's name to `hello` // 
db.Model(&User{}).Update("name", "hello") // // if user's primary key is non-blank, will use it as condition, then will only update that user's name to `hello` // db.Model(&user).Update("name", "hello") func (db *DB) Model(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Model = value return } // Clauses Add clauses // // This supports both standard clauses (clause.OrderBy, clause.Limit, clause.Where) and more // advanced techniques like specifying lock strength and optimizer hints. See the // [docs] for more depth. // // // add a simple limit clause // db.Clauses(clause.Limit{Limit: 1}).Find(&User{}) // // tell the optimizer to use the `idx_user_name` index // db.Clauses(hints.UseIndex("idx_user_name")).Find(&User{}) // // specify the lock strength to UPDATE // db.Clauses(clause.Locking{Strength: "UPDATE"}).Find(&users) // // [docs]: https://gorm.io/docs/sql_builder.html#Clauses func (db *DB) Clauses(conds ...clause.Expression) (tx *DB) { tx = db.getInstance() var whereConds []interface{} for _, cond := range conds { if c, ok := cond.(clause.Interface); ok { tx.Statement.AddClause(c) } else if optimizer, ok := cond.(StatementModifier); ok { optimizer.ModifyStatement(tx.Statement) } else { whereConds = append(whereConds, cond) } } if len(whereConds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: tx.Statement.BuildCondition(whereConds[0], whereConds[1:]...)}) } return } var tableRegexp = regexp.MustCompile(`(?i)(?:.+? 
AS (\w+)\s*(?:$|,)|^\w+\s+(\w+)$)`) // Table specify the table you would like to run db operations // // // Get a user // db.Table("users").Take(&result) func (db *DB) Table(name string, args ...interface{}) (tx *DB) { tx = db.getInstance() if strings.Contains(name, " ") || strings.Contains(name, "`") || len(args) > 0 { tx.Statement.TableExpr = &clause.Expr{SQL: name, Vars: args} if results := tableRegexp.FindStringSubmatch(name); len(results) == 3 { if results[1] != "" { tx.Statement.Table = results[1] } else { tx.Statement.Table = results[2] } } } else if tables := strings.Split(name, "."); len(tables) == 2 { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = tables[1] } else if name != "" { tx.Statement.TableExpr = &clause.Expr{SQL: tx.Statement.Quote(name)} tx.Statement.Table = name } else { tx.Statement.TableExpr = nil tx.Statement.Table = "" } return } // Distinct specify distinct fields that you want querying // // // Select distinct names of users // db.Distinct("name").Find(&results) // // Select distinct name/age pairs from users // db.Distinct("name", "age").Find(&results) func (db *DB) Distinct(args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Distinct = true if len(args) > 0 { tx = tx.Select(args[0], args[1:]...) } return } // Select specify fields that you want when querying, creating, updating // // Use Select when you only want a subset of the fields. By default, GORM will select all fields. // Select accepts both string arguments and arrays. 
//
//	// Select name and age of user using multiple arguments
//	db.Select("name", "age").Find(&users)
//	// Select name and age of user using an array
//	db.Select([]string{"name", "age"}).Find(&users)
//
// A string query containing "?" placeholders (at least as many as args) or
// "@name" placeholders, with args, becomes a raw SELECT expression; otherwise
// the strings are recorded as column names. Unsupported argument types add
// an error to tx.
func (db *DB) Select(query interface{}, args ...interface{}) (tx *DB) {
	tx = db.getInstance()

	switch v := query.(type) {
	case []string:
		// NOTE(review): v is stored without copying, so the appends below may
		// write into the caller's backing array if it has spare capacity.
		tx.Statement.Selects = v

		for _, arg := range args {
			switch arg := arg.(type) {
			case string:
				tx.Statement.Selects = append(tx.Statement.Selects, arg)
			case []string:
				tx.Statement.Selects = append(tx.Statement.Selects, arg...)
			default:
				// returns with Selects partially extended by earlier args
				tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
				return
			}
		}

		// clear any previously set SELECT expression, presumably so that the
		// column list in Statement.Selects takes effect instead
		if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
			clause.Expression = nil
			tx.Statement.Clauses["SELECT"] = clause
		}
	case string:
		// NOTE(review): Distinct is read from db rather than tx; both appear to
		// share the flag via getInstance, but confirm.
		if strings.Count(v, "?") >= len(args) && len(args) > 0 {
			tx.Statement.AddClause(clause.Select{
				Distinct:   db.Statement.Distinct,
				Expression: clause.Expr{SQL: v, Vars: args},
			})
		} else if strings.Count(v, "@") > 0 && len(args) > 0 {
			tx.Statement.AddClause(clause.Select{
				Distinct:   db.Statement.Distinct,
				Expression: clause.NamedExpr{SQL: v, Vars: args},
			})
		} else {
			tx.Statement.Selects = []string{v}

			for _, arg := range args {
				switch arg := arg.(type) {
				case string:
					tx.Statement.Selects = append(tx.Statement.Selects, arg)
				case []string:
					tx.Statement.Selects = append(tx.Statement.Selects, arg...)
				default:
					// a non-string arg: fall back to treating v as a raw SELECT expression
					tx.Statement.AddClause(clause.Select{
						Distinct:   db.Statement.Distinct,
						Expression: clause.Expr{SQL: v, Vars: args},
					})
					return
				}
			}

			if clause, ok := tx.Statement.Clauses["SELECT"]; ok {
				clause.Expression = nil
				tx.Statement.Clauses["SELECT"] = clause
			}
		}
	default:
		tx.AddError(fmt.Errorf("unsupported select args %v %v", query, args))
	}

	return
}

// Omit specify fields that you want to ignore when creating, updating and querying
//
// A single argument containing commas is split into column names on
// characters that utils.IsInvalidDBNameChar rejects.
func (db *DB) Omit(columns ...string) (tx *DB) {
	tx = db.getInstance()

	if len(columns) == 1 && strings.ContainsRune(columns[0], ',') {
		tx.Statement.Omits = strings.FieldsFunc(columns[0], utils.IsInvalidDBNameChar)
	} else {
		tx.Statement.Omits = columns
	}
	return
}

// MapColumns modify the column names in the query results to facilitate align to the corresponding structural fields
func (db *DB) MapColumns(m map[string]string) (tx *DB) {
	tx = db.getInstance()
	tx.Statement.ColumnMapping = m
	return
}

// Where add conditions
//
// See the [docs] for details on the various formats that where clauses can take. By default, where clauses chain with AND.
//
//	// Find the first user with name jinzhu
//	db.Where("name = ?", "jinzhu").First(&user)
//	// Find the first user with name jinzhu and age 20
//	db.Where(&User{Name: "jinzhu", Age: 20}).First(&user)
//	// Find the first user with name jinzhu and age not equal to 20
//	db.Where("name = ?", "jinzhu").Where("age <> ?", "20").First(&user)
//
// [docs]: https://gorm.io/docs/query.html#Conditions
func (db *DB) Where(query interface{}, args ...interface{}) (tx *DB) {
	tx = db.getInstance()
	// nothing is added when the query builds no conditions
	if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 {
		tx.Statement.AddClause(clause.Where{Exprs: conds})
	}
	return
}

// Not add NOT conditions
//
// Not works similarly to where, and has the same syntax.
// // // Find the first user with name not equal to jinzhu // db.Not("name = ?", "jinzhu").First(&user) func (db *DB) Not(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Not(conds...)}}) } return } // Or add OR conditions // // Or is used to chain together queries with an OR. // // // Find the first user with name equal to jinzhu or john // db.Where("name = ?", "jinzhu").Or("name = ?", "john").First(&user) func (db *DB) Or(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() if conds := tx.Statement.BuildCondition(query, args...); len(conds) > 0 { tx.Statement.AddClause(clause.Where{Exprs: []clause.Expression{clause.Or(clause.And(conds...))}}) } return } // Joins specify Joins conditions // // db.Joins("Account").Find(&user) // db.Joins("JOIN emails ON emails.user_id = users.id AND emails.email = ?", "jinzhu@example.org").Find(&user) // db.Joins("Account", DB.Select("id").Where("user_id = users.id AND name = ?", "someName").Model(&Account{})) func (db *DB) Joins(query string, args ...interface{}) (tx *DB) { return joins(db, clause.LeftJoin, query, args...) } // InnerJoins specify inner joins conditions // db.InnerJoins("Account").Find(&user) func (db *DB) InnerJoins(query string, args ...interface{}) (tx *DB) { return joins(db, clause.InnerJoin, query, args...) 
} func joins(db *DB, joinType clause.JoinType, query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if len(args) == 1 { if db, ok := args[0].(*DB); ok { j := join{ Name: query, Conds: args, Selects: db.Statement.Selects, Omits: db.Statement.Omits, JoinType: joinType, } if where, ok := db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { j.On = &where } tx.Statement.Joins = append(tx.Statement.Joins, j) return } } tx.Statement.Joins = append(tx.Statement.Joins, join{Name: query, Conds: args, JoinType: joinType}) return } // Group specify the group method on the find // // // Select the sum age of users with given names // db.Model(&User{}).Select("name, sum(age) as total").Group("name").Find(&results) func (db *DB) Group(name string) (tx *DB) { tx = db.getInstance() fields := strings.FieldsFunc(name, utils.IsInvalidDBNameChar) tx.Statement.AddClause(clause.GroupBy{ Columns: []clause.Column{{Name: name, Raw: len(fields) != 1}}, }) return } // Having specify HAVING conditions for GROUP BY // // // Select the sum age of users with name jinzhu // db.Model(&User{}).Select("name, sum(age) as total").Group("name").Having("name = ?", "jinzhu").Find(&result) func (db *DB) Having(query interface{}, args ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.GroupBy{ Having: tx.Statement.BuildCondition(query, args...), }) return } // Order specify order when retrieving records from database // // db.Order("name DESC") // db.Order(clause.OrderByColumn{Column: clause.Column{Name: "name"}, Desc: true}) // db.Order(clause.OrderBy{Columns: []clause.OrderByColumn{ // {Column: clause.Column{Name: "name"}, Desc: true}, // {Column: clause.Column{Name: "age"}, Desc: true}, // }}) func (db *DB) Order(value interface{}) (tx *DB) { tx = db.getInstance() switch v := value.(type) { case clause.OrderBy: tx.Statement.AddClause(v) case clause.OrderByColumn: tx.Statement.AddClause(clause.OrderBy{ Columns: []clause.OrderByColumn{v}, }) case string: 
if v != "" { tx.Statement.AddClause(clause.OrderBy{ Columns: []clause.OrderByColumn{{ Column: clause.Column{Name: v, Raw: true}, }}, }) } } return } // Limit specify the number of records to be retrieved // // Limit conditions can be cancelled by using `Limit(-1)`. // // // retrieve 3 users // db.Limit(3).Find(&users) // // retrieve 3 users into users1, and all users into users2 // db.Limit(3).Find(&users1).Limit(-1).Find(&users2) func (db *DB) Limit(limit int) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Limit: &limit}) return } // Offset specify the number of records to skip before starting to return the records // // Offset conditions can be cancelled by using `Offset(-1)`. // // // select the third user // db.Offset(2).First(&user) // // select the first user by cancelling an earlier chained offset // db.Offset(5).Offset(-1).First(&user) func (db *DB) Offset(offset int) (tx *DB) { tx = db.getInstance() tx.Statement.AddClause(clause.Limit{Offset: offset}) return } // Scopes pass current database connection to arguments `func(DB) DB`, which could be used to add conditions dynamically // // func AmountGreaterThan1000(db *gorm.DB) *gorm.DB { // return db.Where("amount > ?", 1000) // } // // func OrderStatus(status []string) func (db *gorm.DB) *gorm.DB { // return func (db *gorm.DB) *gorm.DB { // return db.Scopes(AmountGreaterThan1000).Where("status in (?)", status) // } // } // // db.Scopes(AmountGreaterThan1000, OrderStatus([]string{"paid", "shipped"})).Find(&orders) func (db *DB) Scopes(funcs ...func(*DB) *DB) (tx *DB) { tx = db.getInstance() tx.Statement.scopes = append(tx.Statement.scopes, funcs...) 
return tx } func (db *DB) executeScopes() (tx *DB) { scopes := db.Statement.scopes db.Statement.scopes = nil for _, scope := range scopes { db = scope(db) } return db } // Preload preload associations with given conditions // // // get all users, and preload all non-cancelled orders // db.Preload("Orders", "state NOT IN (?)", "cancelled").Find(&users) func (db *DB) Preload(query string, args ...interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Preloads == nil { tx.Statement.Preloads = map[string][]interface{}{} } tx.Statement.Preloads[query] = args return } // Attrs provide attributes used in [FirstOrCreate] or [FirstOrInit] // // Attrs only adds attributes if the record is not found. // // // assign an email if the record is not found // db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) // // user -> User{Name: "non_existing", Email: "fake@fake.org"} // // // assign an email if the record is not found, otherwise ignore provided email // db.Where(User{Name: "jinzhu"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) // // user -> User{Name: "jinzhu", Age: 20} // // [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate // [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit func (db *DB) Attrs(attrs ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.attrs = attrs return } // Assign provide attributes used in [FirstOrCreate] or [FirstOrInit] // // Assign adds attributes even if the record is found. If using FirstOrCreate, this means that // records will be updated even if they are found. 
// // // assign an email regardless of if the record is not found // db.Where(User{Name: "non_existing"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) // // user -> User{Name: "non_existing", Email: "fake@fake.org"} // // // assign email regardless of if record is found // db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} // // [FirstOrCreate]: https://gorm.io/docs/advanced_query.html#FirstOrCreate // [FirstOrInit]: https://gorm.io/docs/advanced_query.html#FirstOrInit func (db *DB) Assign(attrs ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.assigns = attrs return } // Unscoped disables the global scope of soft deletion in a query. // By default, GORM uses soft deletion, marking records as "deleted" // by setting a timestamp on a specific field (e.g., `deleted_at`). // Unscoped allows queries to include records marked as deleted, // overriding the soft deletion behavior. // Example: // // var users []User // db.Unscoped().Find(&users) // // Retrieves all users, including deleted ones. 
func (db *DB) Unscoped() (tx *DB) { tx = db.getInstance() tx.Statement.Unscoped = true return } func (db *DB) Raw(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} if strings.Contains(sql, "@") { clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) } else { clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) } return } ================================================ FILE: clause/association.go ================================================ package clause // AssociationOpType represents association operation types type AssociationOpType int const ( OpUnlink AssociationOpType = iota // Unlink association OpDelete // Delete association records OpUpdate // Update association records OpCreate // Create association records with assignments ) // Association represents an association operation type Association struct { Association string // Association name Type AssociationOpType // Operation type Conditions []Expression // Filter conditions Set []Assignment // Assignment operations (for Update and Create) Values []interface{} // Values for Create operation } // AssociationAssigner is an interface for association operation providers type AssociationAssigner interface { AssociationAssignments() []Association } // Assignments implements the Assigner interface so that AssociationOperation can be used as a Set method parameter func (ao Association) Assignments() []Assignment { return []Assignment{} } // AssociationAssignments implements the AssociationAssigner interface func (ao Association) AssociationAssignments() []Association { return []Association{ao} } ================================================ FILE: clause/benchmarks_test.go ================================================ package clause_test import ( "sync" "testing" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) func BenchmarkSelect(b *testing.B) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, 
db.NamingStrategy) for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clauses := []clause.Interface{clause.Select{}, clause.From{}, clause.Where{Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})}}} for _, clause := range clauses { stmt.AddClause(clause) } stmt.Build("SELECT", "FROM", "WHERE") _ = stmt.SQL.String() } } func BenchmarkComplexSelect(b *testing.B) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) limit10 := 10 for i := 0; i < b.N; i++ { stmt := gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clauses := []clause.Interface{ clause.Select{}, clause.From{}, clause.Where{Exprs: []clause.Expression{ clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), }}, clause.Where{Exprs: []clause.Expression{ clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"}), }}, clause.GroupBy{Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}}, clause.Limit{Limit: &limit10, Offset: 20}, clause.OrderBy{Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}}, } for _, clause := range clauses { stmt.AddClause(clause) } stmt.Build("SELECT", "FROM", "WHERE", "GROUP BY", "LIMIT", "ORDER BY") _ = stmt.SQL.String() } } ================================================ FILE: clause/clause.go ================================================ package clause // Interface clause interface type Interface interface { Name() string Build(Builder) MergeClause(*Clause) } // ClauseBuilder clause builder, allows to customize how to build clause type ClauseBuilder func(Clause, Builder) type Writer interface { WriteByte(byte) error 
WriteString(string) (int, error) } // Builder builder interface type Builder interface { Writer WriteQuoted(field interface{}) AddVar(Writer, ...interface{}) AddError(error) error } // Clause type Clause struct { Name string // WHERE BeforeExpression Expression AfterNameExpression Expression AfterExpression Expression Expression Expression Builder ClauseBuilder } // Build build clause func (c Clause) Build(builder Builder) { if c.Builder != nil { c.Builder(c, builder) } else if c.Expression != nil { if c.BeforeExpression != nil { c.BeforeExpression.Build(builder) builder.WriteByte(' ') } if c.Name != "" { builder.WriteString(c.Name) builder.WriteByte(' ') } if c.AfterNameExpression != nil { c.AfterNameExpression.Build(builder) builder.WriteByte(' ') } c.Expression.Build(builder) if c.AfterExpression != nil { builder.WriteByte(' ') c.AfterExpression.Build(builder) } } } const ( PrimaryKey string = "~~~py~~~" // primary key CurrentTable string = "~~~ct~~~" // current table Associations string = "~~~as~~~" // associations ) var ( currentTable = Table{Name: CurrentTable} PrimaryColumn = Column{Table: CurrentTable, Name: PrimaryKey} ) // Column quote with name type Column struct { Table string Name string Alias string Raw bool } // Table quote with name type Table struct { Name string Alias string Raw bool } ================================================ FILE: clause/clause_test.go ================================================ package clause_test import ( "reflect" "strings" "sync" "testing" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) var db, _ = gorm.Open(tests.DummyDialector{}, nil) func checkBuildClauses(t *testing.T, clauses []clause.Interface, result string, vars []interface{}) { var ( buildNames []string buildNamesMap = map[string]bool{} user, _ = schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt = gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} ) for _, c 
:= range clauses { if _, ok := buildNamesMap[c.Name()]; !ok { buildNames = append(buildNames, c.Name()) buildNamesMap[c.Name()] = true } stmt.AddClause(c) } stmt.Build(buildNames...) if strings.TrimSpace(stmt.SQL.String()) != result { t.Errorf("SQL expects %v got %v", result, stmt.SQL.String()) } if !reflect.DeepEqual(stmt.Vars, vars) { t.Errorf("Vars expects %+v got %v", stmt.Vars, vars) } } ================================================ FILE: clause/delete.go ================================================ package clause type Delete struct { Modifier string } func (d Delete) Name() string { return "DELETE" } func (d Delete) Build(builder Builder) { builder.WriteString("DELETE") if d.Modifier != "" { builder.WriteByte(' ') builder.WriteString(d.Modifier) } } func (d Delete) MergeClause(clause *Clause) { clause.Name = "" clause.Expression = d } ================================================ FILE: clause/delete_test.go ================================================ package clause_test import ( "fmt" "testing" "gorm.io/gorm/clause" ) func TestDelete(t *testing.T) { results := []struct { Clauses []clause.Interface Result string Vars []interface{} }{ { []clause.Interface{clause.Delete{}, clause.From{}}, "DELETE FROM `users`", nil, }, { []clause.Interface{clause.Delete{Modifier: "LOW_PRIORITY"}, clause.From{}}, "DELETE LOW_PRIORITY FROM `users`", nil, }, } for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { checkBuildClauses(t, result.Clauses, result.Result, result.Vars) }) } } ================================================ FILE: clause/expression.go ================================================ package clause import ( "database/sql" "database/sql/driver" "go/ast" "reflect" ) // Expression expression interface type Expression interface { Build(builder Builder) } // NegationExpressionBuilder negation expression builder type NegationExpressionBuilder interface { NegationBuild(builder Builder) } // Expr raw expression 
type Expr struct { SQL string Vars []interface{} WithoutParentheses bool } // Build build raw expression func (expr Expr) Build(builder Builder) { var ( afterParenthesis bool idx int ) for _, v := range []byte(expr.SQL) { if v == '?' && len(expr.Vars) > idx { if afterParenthesis || expr.WithoutParentheses { processValue(builder, expr.Vars[idx]) } else { builder.AddVar(builder, expr.Vars[idx]) } idx++ } else { if v == '(' { afterParenthesis = true } else { afterParenthesis = false } builder.WriteByte(v) } } if idx < len(expr.Vars) { for _, v := range expr.Vars[idx:] { builder.AddVar(builder, sql.NamedArg{Value: v}) } } } // NamedExpr raw expression for named expr type NamedExpr struct { SQL string Vars []interface{} } // Build build raw expression func (expr NamedExpr) Build(builder Builder) { var ( idx int inName bool afterParenthesis bool namedMap = make(map[string]interface{}, len(expr.Vars)) ) for _, v := range expr.Vars { switch value := v.(type) { case sql.NamedArg: namedMap[value.Name] = value.Value case map[string]interface{}: for k, v := range value { namedMap[k] = v } default: var appendFieldsToMap func(reflect.Value) appendFieldsToMap = func(reflectValue reflect.Value) { reflectValue = reflect.Indirect(reflectValue) switch reflectValue.Kind() { case reflect.Struct: modelType := reflectValue.Type() for i := 0; i < modelType.NumField(); i++ { if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { namedMap[fieldStruct.Name] = reflectValue.Field(i).Interface() if fieldStruct.Anonymous { appendFieldsToMap(reflectValue.Field(i)) } } } } } appendFieldsToMap(reflect.ValueOf(value)) } } name := make([]byte, 0, 10) for _, v := range []byte(expr.SQL) { if v == '@' && !inName { inName = true name = name[:0] } else if v == ' ' || v == ',' || v == ')' || v == '"' || v == '\'' || v == '`' || v == '\r' || v == '\n' || v == ';' { if inName { if nv, ok := namedMap[string(name)]; ok { if afterParenthesis { processValue(builder, nv) } else { 
builder.AddVar(builder, nv) } } else { builder.WriteByte('@') builder.WriteString(string(name)) } inName = false } afterParenthesis = false builder.WriteByte(v) } else if v == '?' && len(expr.Vars) > idx { if afterParenthesis { processValue(builder, expr.Vars[idx]) } else { builder.AddVar(builder, expr.Vars[idx]) } idx++ } else if inName { name = append(name, v) } else { if v == '(' { afterParenthesis = true } else { afterParenthesis = false } builder.WriteByte(v) } } if inName { if nv, ok := namedMap[string(name)]; ok { builder.AddVar(builder, nv) } else { builder.WriteByte('@') builder.WriteString(string(name)) } } } // processValue handles different value types appropriately for SQL parameter binding // It checks for driver.Valuer first, then handles slices/arrays, and finally adds single values func processValue(builder Builder, value interface{}) { if _, ok := value.(driver.Valuer); ok { builder.AddVar(builder, value) return } switch rv := reflect.ValueOf(value); rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() == 0 { builder.AddVar(builder, nil) } else { for i := 0; i < rv.Len(); i++ { if i > 0 { builder.WriteByte(',') } builder.AddVar(builder, rv.Index(i).Interface()) } } default: builder.AddVar(builder, value) } } // IN Whether a value is within a set of values type IN struct { Column interface{} Values []interface{} } func (in IN) Build(builder Builder) { builder.WriteQuoted(in.Column) switch len(in.Values) { case 0: builder.WriteString(" IN (NULL)") case 1: if _, ok := in.Values[0].([]interface{}); !ok { builder.WriteString(" = ") builder.AddVar(builder, in.Values[0]) break } fallthrough default: builder.WriteString(" IN (") builder.AddVar(builder, in.Values...) 
builder.WriteByte(')') } } func (in IN) NegationBuild(builder Builder) { builder.WriteQuoted(in.Column) switch len(in.Values) { case 0: builder.WriteString(" IS NOT NULL") case 1: if _, ok := in.Values[0].([]interface{}); !ok { builder.WriteString(" <> ") builder.AddVar(builder, in.Values[0]) break } fallthrough default: builder.WriteString(" NOT IN (") builder.AddVar(builder, in.Values...) builder.WriteByte(')') } } // Eq equal to for where type Eq struct { Column interface{} Value interface{} } func (eq Eq) Build(builder Builder) { builder.WriteQuoted(eq.Column) switch eq.Value.(type) { case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: rv := reflect.ValueOf(eq.Value) if rv.Len() == 0 { builder.WriteString(" IN (NULL)") } else { builder.WriteString(" IN (") for i := 0; i < rv.Len(); i++ { if i > 0 { builder.WriteByte(',') } builder.AddVar(builder, rv.Index(i).Interface()) } builder.WriteByte(')') } default: if eqNil(eq.Value) { builder.WriteString(" IS NULL") } else { builder.WriteString(" = ") builder.AddVar(builder, eq.Value) } } } func (eq Eq) NegationBuild(builder Builder) { Neq(eq).Build(builder) } // Neq not equal to for where type Neq Eq func (neq Neq) Build(builder Builder) { builder.WriteQuoted(neq.Column) switch neq.Value.(type) { case []string, []int, []int32, []int64, []uint, []uint32, []uint64, []interface{}: builder.WriteString(" NOT IN (") rv := reflect.ValueOf(neq.Value) for i := 0; i < rv.Len(); i++ { if i > 0 { builder.WriteByte(',') } builder.AddVar(builder, rv.Index(i).Interface()) } builder.WriteByte(')') default: if eqNil(neq.Value) { builder.WriteString(" IS NOT NULL") } else { builder.WriteString(" <> ") builder.AddVar(builder, neq.Value) } } } func (neq Neq) NegationBuild(builder Builder) { Eq(neq).Build(builder) } // Gt greater than for where type Gt Eq func (gt Gt) Build(builder Builder) { builder.WriteQuoted(gt.Column) builder.WriteString(" > ") builder.AddVar(builder, gt.Value) } func (gt Gt) 
NegationBuild(builder Builder) { Lte(gt).Build(builder) } // Gte greater than or equal to for where type Gte Eq func (gte Gte) Build(builder Builder) { builder.WriteQuoted(gte.Column) builder.WriteString(" >= ") builder.AddVar(builder, gte.Value) } func (gte Gte) NegationBuild(builder Builder) { Lt(gte).Build(builder) } // Lt less than for where type Lt Eq func (lt Lt) Build(builder Builder) { builder.WriteQuoted(lt.Column) builder.WriteString(" < ") builder.AddVar(builder, lt.Value) } func (lt Lt) NegationBuild(builder Builder) { Gte(lt).Build(builder) } // Lte less than or equal to for where type Lte Eq func (lte Lte) Build(builder Builder) { builder.WriteQuoted(lte.Column) builder.WriteString(" <= ") builder.AddVar(builder, lte.Value) } func (lte Lte) NegationBuild(builder Builder) { Gt(lte).Build(builder) } // Like whether string matches regular expression type Like Eq func (like Like) Build(builder Builder) { builder.WriteQuoted(like.Column) builder.WriteString(" LIKE ") builder.AddVar(builder, like.Value) } func (like Like) NegationBuild(builder Builder) { builder.WriteQuoted(like.Column) builder.WriteString(" NOT LIKE ") builder.AddVar(builder, like.Value) } func eqNil(value interface{}) bool { if valuer, ok := value.(driver.Valuer); ok && !eqNilReflect(valuer) { value, _ = valuer.Value() } return value == nil || eqNilReflect(value) } func eqNilReflect(value interface{}) bool { reflectValue := reflect.ValueOf(value) return reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() } ================================================ FILE: clause/expression_test.go ================================================ package clause_test import ( "database/sql" "fmt" "reflect" "sync" "testing" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) func TestExpr(t *testing.T) { results := []struct { SQL string Result string Vars []interface{} }{{ SQL: "create table ? (? ?, ? 
?)", Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, Result: "create table `users` (`id` int, `name` text)", }} for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clause.Expr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) if stmt.SQL.String() != result.Result { t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) } }) } } func TestNamedExpr(t *testing.T) { type Base struct { Name2 string } type NamedArgument struct { Name1 string Base } results := []struct { SQL string Result string Vars []interface{} ExpectedVars []interface{} }{{ SQL: "create table ? (? ?, ? ?)", Vars: []interface{}{clause.Table{Name: "users"}, clause.Column{Name: "id"}, clause.Expr{SQL: "int"}, clause.Column{Name: "name"}, clause.Expr{SQL: "text"}}, Result: "create table `users` (`id` int, `name` text)", }, { SQL: "name1 = @name AND name2 = @name", Vars: []interface{}{sql.Named("name", "jinzhu")}, Result: "name1 = ? AND name2 = ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, }, { SQL: "name1 = @name AND name2 = @@name", Vars: []interface{}{map[string]interface{}{"name": "jinzhu"}}, Result: "name1 = ? AND name2 = @@name", ExpectedVars: []interface{}{"jinzhu"}, }, { SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, Result: "name1 = ? AND name2 = ? AND name3 = ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, }, { SQL: "name1 = @name1 AND name2 = @name2 AND name3 = @name1", Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu2"}}, Result: "name1 = ? AND name2 = ? 
AND name3 = ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, }, { SQL: "@@test AND name1 = @name1 AND name2 = @name2 AND name3 = @name1 @notexist", Vars: []interface{}{sql.Named("name1", "jinzhu"), sql.Named("name2", "jinzhu2")}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, }, { SQL: "@@test AND name1 = @Name1 AND name2 = @Name2 AND name3 = @Name1 @notexist", Vars: []interface{}{NamedArgument{Name1: "jinzhu", Base: Base{Name2: "jinzhu2"}}}, Result: "@@test AND name1 = ? AND name2 = ? AND name3 = ? @notexist", ExpectedVars: []interface{}{"jinzhu", "jinzhu2", "jinzhu"}, }, { SQL: "name in (@names)", Vars: []interface{}{map[string]interface{}{"names": []interface{}{"jinzhu", "jinzhu2"}}}, Result: "name in (?,?)", ExpectedVars: []interface{}{"jinzhu", "jinzhu2"}, }, { SQL: "name in (@names)", Vars: []interface{}{map[string]interface{}{"names": "jinzhu"}}, Result: "name in (?)", ExpectedVars: []interface{}{"jinzhu"}, }, { SQL: "create table ? (? ?, ? ?)", Vars: []interface{}{}, Result: "create table ? (? ?, ? ?)", }, { SQL: "name1 = @name AND name2 = @name;", Vars: []interface{}{sql.Named("name", "jinzhu")}, Result: "name1 = ? 
AND name2 = ?;", ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, }, { SQL: "name1 = @name1\r\n AND name2 = @name2", Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}}, Result: "name1 = ?\r\n AND name2 = ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, }, { SQL: "name1 = @name1\r AND name2 = @name2", Vars: []interface{}{map[string]interface{}{"name1": "jinzhu", "name2": "jinzhu"}}, Result: "name1 = ?\r AND name2 = ?", ExpectedVars: []interface{}{"jinzhu", "jinzhu"}, }, { SQL: "?", Vars: []interface{}{clause.Column{Table: "table", Name: "col"}}, Result: "`table`.`col`", }, { SQL: "?", Vars: []interface{}{clause.Column{Table: "table", Name: "col", Raw: true}}, Result: "table.col", }, { SQL: "?", Vars: []interface{}{clause.Column{Table: "table", Name: clause.PrimaryKey, Raw: true}}, Result: "table.id", }, { SQL: "?", Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias"}}, Result: "`table`.`col` AS `alias`", }, { SQL: "?", Vars: []interface{}{clause.Column{Table: "table", Name: "col", Alias: "alias", Raw: true}}, Result: "table.col AS alias", }, { SQL: "?", Vars: []interface{}{clause.Table{Name: "table", Alias: "alias"}}, Result: "`table` `alias`", }, { SQL: "?", Vars: []interface{}{clause.Table{Name: "table", Alias: "alias", Raw: true}}, Result: "table alias", }} for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} clause.NamedExpr{SQL: result.SQL, Vars: result.Vars}.Build(stmt) if stmt.SQL.String() != result.Result { t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) } if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) } }) } } func TestExpression(t 
*testing.T) { column := "column-name" results := []struct { Expressions []clause.Expression ExpectedVars []interface{} Result string }{{ Expressions: []clause.Expression{ clause.Eq{Column: column, Value: "column-value"}, }, ExpectedVars: []interface{}{"column-value"}, Result: "`column-name` = ?", }, { Expressions: []clause.Expression{ clause.Eq{Column: column, Value: nil}, clause.Eq{Column: column, Value: (*string)(nil)}, clause.Eq{Column: column, Value: (*int)(nil)}, clause.Eq{Column: column, Value: (*bool)(nil)}, clause.Eq{Column: column, Value: (interface{})(nil)}, clause.Eq{Column: column, Value: sql.NullString{String: "", Valid: false}}, }, Result: "`column-name` IS NULL", }, { Expressions: []clause.Expression{ clause.Neq{Column: column, Value: "column-value"}, }, ExpectedVars: []interface{}{"column-value"}, Result: "`column-name` <> ?", }, { Expressions: []clause.Expression{ clause.Neq{Column: column, Value: nil}, clause.Neq{Column: column, Value: (*string)(nil)}, clause.Neq{Column: column, Value: (*int)(nil)}, clause.Neq{Column: column, Value: (*bool)(nil)}, clause.Neq{Column: column, Value: (interface{})(nil)}, }, Result: "`column-name` IS NOT NULL", }, { Expressions: []clause.Expression{ clause.Eq{Column: column, Value: []string{"a", "b"}}, }, ExpectedVars: []interface{}{"a", "b"}, Result: "`column-name` IN (?,?)", }, { Expressions: []clause.Expression{ clause.Neq{Column: column, Value: []string{"a", "b"}}, }, ExpectedVars: []interface{}{"a", "b"}, Result: "`column-name` NOT IN (?,?)", }, { Expressions: []clause.Expression{ clause.Eq{Column: column, Value: []string{}}, }, Result: "`column-name` IN (NULL)", }, { Expressions: []clause.Expression{ clause.Eq{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Name: "id"}}}, Value: 100}, }, ExpectedVars: []interface{}{100}, Result: "SUM(`id`) = ?", }, { Expressions: []clause.Expression{ clause.Gte{Column: clause.Expr{SQL: "SUM(?)", Vars: []interface{}{clause.Column{Table: "users", Name: 
"id"}}}, Value: 100}, }, ExpectedVars: []interface{}{100}, Result: "SUM(`users`.`id`) >= ?", }} for idx, result := range results { for idy, expression := range result.Expressions { t.Run(fmt.Sprintf("case #%v.%v", idx, idy), func(t *testing.T) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} expression.Build(stmt) if stmt.SQL.String() != result.Result { t.Errorf("generated SQL is not equal, expects %v, but got %v", result.Result, stmt.SQL.String()) } if !reflect.DeepEqual(result.ExpectedVars, stmt.Vars) { t.Errorf("generated vars is not equal, expects %v, but got %v", result.ExpectedVars, stmt.Vars) } }) } } } ================================================ FILE: clause/from.go ================================================ package clause // From from clause type From struct { Tables []Table Joins []Join } // Name from clause name func (from From) Name() string { return "FROM" } // Build build from clause func (from From) Build(builder Builder) { if len(from.Tables) > 0 { for idx, table := range from.Tables { if idx > 0 { builder.WriteByte(',') } builder.WriteQuoted(table) } } else { builder.WriteQuoted(currentTable) } for _, join := range from.Joins { builder.WriteByte(' ') join.Build(builder) } } // MergeClause merge from clause func (from From) MergeClause(clause *Clause) { clause.Expression = from } ================================================ FILE: clause/from_test.go ================================================ package clause_test import ( "fmt" "testing" "gorm.io/gorm/clause" ) func TestFrom(t *testing.T) { results := []struct { Clauses []clause.Interface Result string Vars []interface{} }{ { []clause.Interface{clause.Select{}, clause.From{}}, "SELECT * FROM `users`", nil, }, { []clause.Interface{ clause.Select{}, clause.From{ Tables: []clause.Table{{Name: "users"}}, Joins: []clause.Join{ { Type: clause.InnerJoin, Table: 
clause.Table{Name: "articles"}, ON: clause.Where{ []clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}}, }, }, }, }, }, "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id`", nil, }, { []clause.Interface{ clause.Select{}, clause.From{ Tables: []clause.Table{{Name: "users"}}, Joins: []clause.Join{ { Type: clause.RightJoin, Table: clause.Table{Name: "profiles"}, ON: clause.Where{ []clause.Expression{clause.Eq{clause.Column{Table: "profiles", Name: "email"}, clause.Column{Table: clause.CurrentTable, Name: "email"}}}, }, }, }, }, clause.From{ Joins: []clause.Join{ { Type: clause.InnerJoin, Table: clause.Table{Name: "articles"}, ON: clause.Where{ []clause.Expression{clause.Eq{clause.Column{Table: "articles", Name: "id"}, clause.PrimaryColumn}}, }, }, { Type: clause.LeftJoin, Table: clause.Table{Name: "companies"}, Using: []string{"company_name"}, }, }, }, }, "SELECT * FROM `users` INNER JOIN `articles` ON `articles`.`id` = `users`.`id` LEFT JOIN `companies` USING (`company_name`)", nil, }, } for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { checkBuildClauses(t, result.Clauses, result.Result, result.Vars) }) } } ================================================ FILE: clause/group_by.go ================================================ package clause // GroupBy group by clause type GroupBy struct { Columns []Column Having []Expression } // Name from clause name func (groupBy GroupBy) Name() string { return "GROUP BY" } // Build build group by clause func (groupBy GroupBy) Build(builder Builder) { for idx, column := range groupBy.Columns { if idx > 0 { builder.WriteByte(',') } builder.WriteQuoted(column) } if len(groupBy.Having) > 0 { builder.WriteString(" HAVING ") Where{Exprs: groupBy.Having}.Build(builder) } } // MergeClause merge group by clause func (groupBy GroupBy) MergeClause(clause *Clause) { if v, ok := clause.Expression.(GroupBy); ok { copiedColumns := 
make([]Column, len(v.Columns)) copy(copiedColumns, v.Columns) groupBy.Columns = append(copiedColumns, groupBy.Columns...) copiedHaving := make([]Expression, len(v.Having)) copy(copiedHaving, v.Having) groupBy.Having = append(copiedHaving, groupBy.Having...) } clause.Expression = groupBy if len(groupBy.Columns) == 0 { clause.Name = "" } else { clause.Name = groupBy.Name() } } ================================================ FILE: clause/group_by_test.go ================================================ package clause_test import ( "fmt" "testing" "gorm.io/gorm/clause" ) func TestGroupBy(t *testing.T) { results := []struct { Clauses []clause.Interface Result string Vars []interface{} }{ { []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}, }}, "SELECT * FROM `users` GROUP BY `role` HAVING `role` = ?", []interface{}{"admin"}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.GroupBy{ Columns: []clause.Column{{Name: "role"}}, Having: []clause.Expression{clause.Eq{"role", "admin"}}, }, clause.GroupBy{ Columns: []clause.Column{{Name: "gender"}}, Having: []clause.Expression{clause.Neq{"gender", "U"}}, }}, "SELECT * FROM `users` GROUP BY `role`,`gender` HAVING `role` = ? 
AND `gender` <> ?", []interface{}{"admin", "U"}, }, } for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { checkBuildClauses(t, result.Clauses, result.Result, result.Vars) }) } } ================================================ FILE: clause/insert.go ================================================ package clause type Insert struct { Table Table Modifier string } // Name insert clause name func (insert Insert) Name() string { return "INSERT" } // Build build insert clause func (insert Insert) Build(builder Builder) { if insert.Modifier != "" { builder.WriteString(insert.Modifier) builder.WriteByte(' ') } builder.WriteString("INTO ") if insert.Table.Name == "" { builder.WriteQuoted(currentTable) } else { builder.WriteQuoted(insert.Table) } } // MergeClause merge insert clause func (insert Insert) MergeClause(clause *Clause) { if v, ok := clause.Expression.(Insert); ok { if insert.Modifier == "" { insert.Modifier = v.Modifier } if insert.Table.Name == "" { insert.Table = v.Table } } clause.Expression = insert } ================================================ FILE: clause/insert_test.go ================================================ package clause_test import ( "fmt" "testing" "gorm.io/gorm/clause" ) func TestInsert(t *testing.T) { results := []struct { Clauses []clause.Interface Result string Vars []interface{} }{ { []clause.Interface{clause.Insert{}}, "INSERT INTO `users`", nil, }, { []clause.Interface{clause.Insert{Modifier: "LOW_PRIORITY"}}, "INSERT LOW_PRIORITY INTO `users`", nil, }, { []clause.Interface{clause.Insert{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}}, "INSERT LOW_PRIORITY INTO `products`", nil, }, } for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { checkBuildClauses(t, result.Clauses, result.Result, result.Vars) }) } } ================================================ FILE: clause/joins.go ================================================ package 
clause import "gorm.io/gorm/utils" type JoinType string const ( CrossJoin JoinType = "CROSS" InnerJoin JoinType = "INNER" LeftJoin JoinType = "LEFT" RightJoin JoinType = "RIGHT" ) type JoinTarget struct { Type JoinType Association string Subquery Expression Table string } func Has(name string) JoinTarget { return JoinTarget{Type: InnerJoin, Association: name} } func (jt JoinType) Association(name string) JoinTarget { return JoinTarget{Type: jt, Association: name} } func (jt JoinType) AssociationFrom(name string, subquery Expression) JoinTarget { return JoinTarget{Type: jt, Association: name, Subquery: subquery} } func (jt JoinTarget) As(name string) JoinTarget { jt.Table = name return jt } // Join clause for from type Join struct { Type JoinType Table Table ON Where Using []string Expression Expression } func JoinTable(names ...string) Table { return Table{ Name: utils.JoinNestedRelationNames(names), } } func (join Join) Build(builder Builder) { if join.Expression != nil { join.Expression.Build(builder) } else { if join.Type != "" { builder.WriteString(string(join.Type)) builder.WriteByte(' ') } builder.WriteString("JOIN ") builder.WriteQuoted(join.Table) if len(join.ON.Exprs) > 0 { builder.WriteString(" ON ") join.ON.Build(builder) } else if len(join.Using) > 0 { builder.WriteString(" USING (") for idx, c := range join.Using { if idx > 0 { builder.WriteByte(',') } builder.WriteQuoted(c) } builder.WriteByte(')') } } } ================================================ FILE: clause/joins_test.go ================================================ package clause_test import ( "sync" "testing" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) func TestJoin(t *testing.T) { results := []struct { name string join clause.Join sql string }{ { name: "LEFT JOIN", join: clause.Join{ Type: clause.LeftJoin, Table: clause.Table{Name: "user"}, ON: clause.Where{ Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: 
"user_id"}, clause.PrimaryColumn}}, }, }, sql: "LEFT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", }, { name: "RIGHT JOIN", join: clause.Join{ Type: clause.RightJoin, Table: clause.Table{Name: "user"}, ON: clause.Where{ Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, }, }, sql: "RIGHT JOIN `user` ON `user_info`.`user_id` = `users`.`id`", }, { name: "INNER JOIN", join: clause.Join{ Type: clause.InnerJoin, Table: clause.Table{Name: "user"}, ON: clause.Where{ Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, }, }, sql: "INNER JOIN `user` ON `user_info`.`user_id` = `users`.`id`", }, { name: "CROSS JOIN", join: clause.Join{ Type: clause.CrossJoin, Table: clause.Table{Name: "user"}, ON: clause.Where{ Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, }, }, sql: "CROSS JOIN `user` ON `user_info`.`user_id` = `users`.`id`", }, { name: "USING", join: clause.Join{ Type: clause.InnerJoin, Table: clause.Table{Name: "user"}, Using: []string{"id"}, }, sql: "INNER JOIN `user` USING (`id`)", }, { name: "Expression", join: clause.Join{ // Invalid Type: clause.LeftJoin, Table: clause.Table{Name: "user"}, ON: clause.Where{ Exprs: []clause.Expression{clause.Eq{clause.Column{Table: "user_info", Name: "user_id"}, clause.PrimaryColumn}}, }, // Valid Expression: clause.Join{ Type: clause.InnerJoin, Table: clause.Table{Name: "user"}, Using: []string{"id"}, }, }, sql: "INNER JOIN `user` USING (`id`)", }, } for _, result := range results { t.Run(result.name, func(t *testing.T) { user, _ := schema.Parse(&tests.User{}, &sync.Map{}, db.NamingStrategy) stmt := &gorm.Statement{DB: db, Table: user.Table, Schema: user, Clauses: map[string]clause.Clause{}} result.join.Build(stmt) if result.sql != stmt.SQL.String() { t.Errorf("want: %s, got: %s", result.sql, stmt.SQL.String()) } }) } } 
================================================ FILE: clause/limit.go ================================================ package clause // Limit limit clause type Limit struct { Limit *int Offset int } // Name where clause name func (limit Limit) Name() string { return "LIMIT" } // Build build where clause func (limit Limit) Build(builder Builder) { if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteString("LIMIT ") builder.AddVar(builder, *limit.Limit) } if limit.Offset > 0 { if limit.Limit != nil && *limit.Limit >= 0 { builder.WriteByte(' ') } builder.WriteString("OFFSET ") builder.AddVar(builder, limit.Offset) } } // MergeClause merge order by clauses func (limit Limit) MergeClause(clause *Clause) { clause.Name = "" if v, ok := clause.Expression.(Limit); ok { if (limit.Limit == nil || *limit.Limit == 0) && v.Limit != nil { limit.Limit = v.Limit } if limit.Offset == 0 && v.Offset > 0 { limit.Offset = v.Offset } else if limit.Offset < 0 { limit.Offset = 0 } } clause.Expression = limit } ================================================ FILE: clause/limit_test.go ================================================ package clause_test import ( "fmt" "testing" "gorm.io/gorm/clause" ) func TestLimit(t *testing.T) { limit0 := 0 limit10 := 10 limit50 := 50 limitNeg10 := -10 results := []struct { Clauses []clause.Interface Result string Vars []interface{} }{ { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{ Limit: &limit10, Offset: 20, }}, "SELECT * FROM `users` LIMIT ? 
OFFSET ?", []interface{}{limit10, 20}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}}, "SELECT * FROM `users` LIMIT ?", []interface{}{limit0}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit0}, clause.Limit{Offset: 0}}, "SELECT * FROM `users` LIMIT ?", []interface{}{limit0}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}}, "SELECT * FROM `users` OFFSET ?", []interface{}{20}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Offset: 30}}, "SELECT * FROM `users` OFFSET ?", []interface{}{30}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Offset: 20}, clause.Limit{Limit: &limit10}}, "SELECT * FROM `users` LIMIT ? OFFSET ?", []interface{}{limit10, 20}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}}, "SELECT * FROM `users` LIMIT ? OFFSET ?", []interface{}{limit10, 30}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Offset: -10}}, "SELECT * FROM `users` LIMIT ?", []interface{}{limit10}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limitNeg10}}, "SELECT * FROM `users` OFFSET ?", []interface{}{30}, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Limit{Limit: &limit10, Offset: 20}, clause.Limit{Offset: 30}, clause.Limit{Limit: &limit50}}, "SELECT * FROM `users` LIMIT ? 
OFFSET ?", []interface{}{limit50, 30}, }, } for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { checkBuildClauses(t, result.Clauses, result.Result, result.Vars) }) } } ================================================ FILE: clause/locking.go ================================================ package clause const ( LockingStrengthUpdate = "UPDATE" LockingStrengthShare = "SHARE" LockingOptionsSkipLocked = "SKIP LOCKED" LockingOptionsNoWait = "NOWAIT" ) type Locking struct { Strength string Table Table Options string } // Name where clause name func (locking Locking) Name() string { return "FOR" } // Build build where clause func (locking Locking) Build(builder Builder) { builder.WriteString(locking.Strength) if locking.Table.Name != "" { builder.WriteString(" OF ") builder.WriteQuoted(locking.Table) } if locking.Options != "" { builder.WriteByte(' ') builder.WriteString(locking.Options) } } // MergeClause merge order by clauses func (locking Locking) MergeClause(clause *Clause) { clause.Expression = locking } ================================================ FILE: clause/locking_test.go ================================================ package clause_test import ( "fmt" "testing" "gorm.io/gorm/clause" ) func TestLocking(t *testing.T) { results := []struct { Clauses []clause.Interface Result string Vars []interface{} }{ { []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate}}, "SELECT * FROM `users` FOR UPDATE", nil, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthShare, Table: clause.Table{Name: clause.CurrentTable}}}, "SELECT * FROM `users` FOR SHARE OF `users`", nil, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsNoWait}}, "SELECT * FROM `users` FOR UPDATE NOWAIT", nil, }, { []clause.Interface{clause.Select{}, clause.From{}, 
clause.Locking{Strength: clause.LockingStrengthUpdate, Options: clause.LockingOptionsSkipLocked}}, "SELECT * FROM `users` FOR UPDATE SKIP LOCKED", nil, }, } for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { checkBuildClauses(t, result.Clauses, result.Result, result.Vars) }) } } ================================================ FILE: clause/on_conflict.go ================================================ package clause type OnConflict struct { Columns []Column Where Where TargetWhere Where OnConstraint string DoNothing bool DoUpdates Set UpdateAll bool } func (OnConflict) Name() string { return "ON CONFLICT" } // Build build onConflict clause func (onConflict OnConflict) Build(builder Builder) { if onConflict.OnConstraint != "" { builder.WriteString("ON CONSTRAINT ") builder.WriteString(onConflict.OnConstraint) builder.WriteByte(' ') } else { if len(onConflict.Columns) > 0 { builder.WriteByte('(') for idx, column := range onConflict.Columns { if idx > 0 { builder.WriteByte(',') } builder.WriteQuoted(column) } builder.WriteString(`) `) } if len(onConflict.TargetWhere.Exprs) > 0 { builder.WriteString(" WHERE ") onConflict.TargetWhere.Build(builder) builder.WriteByte(' ') } } if onConflict.DoNothing { builder.WriteString("DO NOTHING") } else { builder.WriteString("DO UPDATE SET ") onConflict.DoUpdates.Build(builder) } if len(onConflict.Where.Exprs) > 0 { builder.WriteString(" WHERE ") onConflict.Where.Build(builder) builder.WriteByte(' ') } } // MergeClause merge onConflict clauses func (onConflict OnConflict) MergeClause(clause *Clause) { clause.Expression = onConflict } ================================================ FILE: clause/order_by.go ================================================ package clause type OrderByColumn struct { Column Column Desc bool Reorder bool } type OrderBy struct { Columns []OrderByColumn Expression Expression } // Name where clause name func (orderBy OrderBy) Name() string { return "ORDER BY" } // 
Build build where clause func (orderBy OrderBy) Build(builder Builder) { if orderBy.Expression != nil { orderBy.Expression.Build(builder) } else { for idx, column := range orderBy.Columns { if idx > 0 { builder.WriteByte(',') } builder.WriteQuoted(column.Column) if column.Desc { builder.WriteString(" DESC") } } } } // MergeClause merge order by clauses func (orderBy OrderBy) MergeClause(clause *Clause) { if v, ok := clause.Expression.(OrderBy); ok { for i := len(orderBy.Columns) - 1; i >= 0; i-- { if orderBy.Columns[i].Reorder { orderBy.Columns = orderBy.Columns[i:] clause.Expression = orderBy return } } copiedColumns := make([]OrderByColumn, len(v.Columns)) copy(copiedColumns, v.Columns) orderBy.Columns = append(copiedColumns, orderBy.Columns...) } clause.Expression = orderBy } ================================================ FILE: clause/order_by_test.go ================================================ package clause_test import ( "fmt" "testing" "gorm.io/gorm/clause" ) func TestOrderBy(t *testing.T) { results := []struct { Clauses []clause.Interface Result string Vars []interface{} }{ { []clause.Interface{clause.Select{}, clause.From{}, clause.OrderBy{ Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, }}, "SELECT * FROM `users` ORDER BY `users`.`id` DESC", nil, }, { []clause.Interface{ clause.Select{}, clause.From{}, clause.OrderBy{ Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, }, clause.OrderBy{ Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}}}, }, }, "SELECT * FROM `users` ORDER BY `users`.`id` DESC,`name`", nil, }, { []clause.Interface{ clause.Select{}, clause.From{}, clause.OrderBy{ Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}}, }, clause.OrderBy{ Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "name"}, Reorder: true}}, }, }, "SELECT * FROM `users` ORDER BY `name`", nil, }, { []clause.Interface{ clause.Select{}, clause.From{}, 
clause.OrderBy{ Expression: clause.Expr{SQL: "FIELD(id, ?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true}, }, }, "SELECT * FROM `users` ORDER BY FIELD(id, ?,?,?)", []interface{}{1, 2, 3}, }, } for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { checkBuildClauses(t, result.Clauses, result.Result, result.Vars) }) } } ================================================ FILE: clause/returning.go ================================================ package clause type Returning struct { Columns []Column } // Name where clause name func (returning Returning) Name() string { return "RETURNING" } // Build build where clause func (returning Returning) Build(builder Builder) { if len(returning.Columns) > 0 { for idx, column := range returning.Columns { if idx > 0 { builder.WriteByte(',') } builder.WriteQuoted(column) } } else { builder.WriteByte('*') } } // MergeClause merge order by clauses func (returning Returning) MergeClause(clause *Clause) { if v, ok := clause.Expression.(Returning); ok && len(returning.Columns) > 0 { if v.Columns != nil { returning.Columns = append(v.Columns, returning.Columns...) 
} else { returning.Columns = nil } } clause.Expression = returning } ================================================ FILE: clause/returning_test.go ================================================ package clause_test import ( "fmt" "testing" "gorm.io/gorm/clause" ) func TestReturning(t *testing.T) { results := []struct { Clauses []clause.Interface Result string Vars []interface{} }{ { []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ []clause.Column{clause.PrimaryColumn}, }}, "SELECT * FROM `users` RETURNING `users`.`id`", nil, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ []clause.Column{clause.PrimaryColumn}, }, clause.Returning{ []clause.Column{{Name: "name"}, {Name: "age"}}, }}, "SELECT * FROM `users` RETURNING `users`.`id`,`name`,`age`", nil, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ []clause.Column{clause.PrimaryColumn}, }, clause.Returning{}, clause.Returning{ []clause.Column{{Name: "name"}, {Name: "age"}}, }}, "SELECT * FROM `users` RETURNING *", nil, }, { []clause.Interface{clause.Select{}, clause.From{}, clause.Returning{ []clause.Column{clause.PrimaryColumn}, }, clause.Returning{ []clause.Column{{Name: "name"}, {Name: "age"}}, }, clause.Returning{}}, "SELECT * FROM `users` RETURNING *", nil, }, } for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { checkBuildClauses(t, result.Clauses, result.Result, result.Vars) }) } } ================================================ FILE: clause/select.go ================================================ package clause // Select select attrs when querying, updating, creating type Select struct { Distinct bool Columns []Column Expression Expression } func (s Select) Name() string { return "SELECT" } func (s Select) Build(builder Builder) { if len(s.Columns) > 0 { if s.Distinct { builder.WriteString("DISTINCT ") } for idx, column := range s.Columns { if idx > 0 { builder.WriteByte(',') } 
builder.WriteQuoted(column) } } else { builder.WriteByte('*') } } func (s Select) MergeClause(clause *Clause) { if s.Expression != nil { if s.Distinct { if expr, ok := s.Expression.(Expr); ok { expr.SQL = "DISTINCT " + expr.SQL clause.Expression = expr return } } clause.Expression = s.Expression } else { clause.Expression = s } } // CommaExpression represents a group of expressions separated by commas. type CommaExpression struct { Exprs []Expression } func (comma CommaExpression) Build(builder Builder) { for idx, expr := range comma.Exprs { if idx > 0 { _, _ = builder.WriteString(", ") } expr.Build(builder) } } ================================================ FILE: clause/select_test.go ================================================ package clause_test import ( "fmt" "testing" "gorm.io/gorm/clause" ) func TestSelect(t *testing.T) { results := []struct { Clauses []clause.Interface Result string Vars []interface{} }{ { []clause.Interface{clause.Select{}, clause.From{}}, "SELECT * FROM `users`", nil, }, { []clause.Interface{clause.Select{ Columns: []clause.Column{clause.PrimaryColumn}, }, clause.From{}}, "SELECT `users`.`id` FROM `users`", nil, }, { []clause.Interface{clause.Select{ Columns: []clause.Column{clause.PrimaryColumn}, }, clause.Select{ Columns: []clause.Column{{Name: "name"}}, }, clause.From{}}, "SELECT `name` FROM `users`", nil, }, { []clause.Interface{clause.Select{ Expression: clause.CommaExpression{ Exprs: []clause.Expression{ clause.NamedExpr{"?", []interface{}{clause.Column{Name: "id"}}}, clause.NamedExpr{"?", []interface{}{clause.Column{Name: "name"}}}, clause.NamedExpr{"LENGTH(?)", []interface{}{clause.Column{Name: "mobile"}}}, }, }, }, clause.From{}}, "SELECT `id`, `name`, LENGTH(`mobile`) FROM `users`", nil, }, { []clause.Interface{clause.Select{ Expression: clause.CommaExpression{ Exprs: []clause.Expression{ clause.Expr{ SQL: "? 
as name", Vars: []interface{}{ clause.Eq{ Column: clause.Column{Name: "age"}, Value: 18, }, }, }, }, }, }, clause.From{}}, "SELECT `age` = ? as name FROM `users`", []interface{}{18}, }, } for idx, result := range results { t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) { checkBuildClauses(t, result.Clauses, result.Result, result.Vars) }) } } ================================================ FILE: clause/set.go ================================================ package clause import "sort" type Set []Assignment type Assignment struct { Column Column Value interface{} } // Assigner assignments provider interface type Assigner interface { Assignments() []Assignment } func (set Set) Name() string { return "SET" } func (set Set) Build(builder Builder) { if len(set) > 0 { for idx, assignment := range set { if idx > 0 { builder.WriteByte(',') } builder.WriteQuoted(assignment.Column) builder.WriteByte('=') builder.AddVar(builder, assignment.Value) } } else { builder.WriteQuoted(Column{Name: PrimaryKey}) builder.WriteByte('=') builder.WriteQuoted(Column{Name: PrimaryKey}) } } // MergeClause merge assignments clauses func (set Set) MergeClause(clause *Clause) { copiedAssignments := make([]Assignment, len(set)) copy(copiedAssignments, set) clause.Expression = Set(copiedAssignments) } // Assignments implements Assigner for Set. 
func (set Set) Assignments() []Assignment {
	return []Assignment(set)
}

// Assignments builds a Set from a column-name→value map. Keys are sorted so
// the generated SQL (and its bind-variable order) is deterministic despite
// Go's random map iteration order.
func Assignments(values map[string]interface{}) Set {
	keys := make([]string, 0, len(values))
	for key := range values {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	assignments := make([]Assignment, len(keys))
	for idx, key := range keys {
		assignments[idx] = Assignment{Column: Column{Name: key}, Value: values[key]}
	}
	return assignments
}

// AssignmentColumns builds assignments that set each named column to the same
// column of the "excluded" table. That is the name upsert dialects such as
// PostgreSQL/SQLite use for the row proposed for insertion; confirm for other
// dialects.
func AssignmentColumns(values []string) Set {
	assignments := make([]Assignment, len(values))
	for idx, value := range values {
		assignments[idx] = Assignment{Column: Column{Name: value}, Value: Column{Table: "excluded", Name: value}}
	}
	return assignments
}

// Assignments implements Assigner for a single Assignment.
func (a Assignment) Assignments() []Assignment {
	return []Assignment{a}
}

================================================
FILE: clause/set_test.go
================================================
package clause_test

import (
	"fmt"
	"sort"
	"strings"
	"testing"

	"gorm.io/gorm/clause"
)

// Compile-time assertions that types implement clause.Assigner
var (
	_ clause.Assigner = clause.Assignment{}
	_ clause.Assigner = clause.Set{}
)

func TestSet(t *testing.T) {
	results := []struct {
		Clauses []clause.Interface
		Result  string
		Vars    []interface{}
	}{
		{
			[]clause.Interface{
				clause.Update{},
				clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
			},
			"UPDATE `users` SET `users`.`id`=?",
			[]interface{}{1},
		},
		{
			// A second SET replaces the first rather than extending it.
			[]clause.Interface{
				clause.Update{},
				clause.Set([]clause.Assignment{{clause.PrimaryColumn, 1}}),
				clause.Set([]clause.Assignment{{clause.Column{Name: "name"}, "jinzhu"}}),
			},
			"UPDATE `users` SET `name`=?",
			[]interface{}{"jinzhu"},
		},
	}

	for idx, result := range results {
		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
			checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
		})
	}
}

func TestAssignments(t *testing.T) {
	set := clause.Assignments(map[string]interface{}{
		"name": "jinzhu",
		"age":  18,
	})

	assignments := []clause.Assignment(set)

	// Re-sort descending so the expected order below is "name", then "age".
	sort.Slice(assignments, func(i, j int) bool {
		return strings.Compare(assignments[i].Column.Name, assignments[j].Column.Name) > 0
	})

	if len(assignments) != 2 || assignments[0].Column.Name != "name" || assignments[0].Value.(string) != "jinzhu" || assignments[1].Column.Name != "age" || assignments[1].Value.(int) != 18 {
		t.Errorf("invalid assignments, got %v", assignments)
	}
}

================================================
FILE: clause/update.go
================================================
package clause

// Update is the UPDATE clause: an optional modifier (e.g. LOW_PRIORITY) and
// a target table. With no table name, the current-table placeholder is used.
type Update struct {
	Modifier string
	Table    Table
}

// Name update clause name
func (update Update) Name() string {
	return "UPDATE"
}

// Build writes the optional modifier, then the quoted target table.
func (update Update) Build(builder Builder) {
	if update.Modifier != "" {
		builder.WriteString(update.Modifier)
		builder.WriteByte(' ')
	}

	if update.Table.Name == "" {
		builder.WriteQuoted(currentTable)
	} else {
		builder.WriteQuoted(update.Table)
	}
}

// MergeClause merges with a previous UPDATE clause: fields left empty here
// inherit their values from the previous clause.
func (update Update) MergeClause(clause *Clause) {
	if v, ok := clause.Expression.(Update); ok {
		if update.Modifier == "" {
			update.Modifier = v.Modifier
		}
		if update.Table.Name == "" {
			update.Table = v.Table
		}
	}
	clause.Expression = update
}

================================================
FILE: clause/update_test.go
================================================
package clause_test

import (
	"fmt"
	"testing"

	"gorm.io/gorm/clause"
)

func TestUpdate(t *testing.T) {
	results := []struct {
		Clauses []clause.Interface
		Result  string
		Vars    []interface{}
	}{
		{
			[]clause.Interface{clause.Update{}},
			"UPDATE `users`", nil,
		},
		{
			[]clause.Interface{clause.Update{Modifier: "LOW_PRIORITY"}},
			"UPDATE LOW_PRIORITY `users`", nil,
		},
		{
			[]clause.Interface{clause.Update{Table: clause.Table{Name: "products"}, Modifier: "LOW_PRIORITY"}},
			"UPDATE LOW_PRIORITY `products`", nil,
		},
	}

	for idx, result := range results {
		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
			checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
		})
	}
}
================================================
FILE: clause/values.go
================================================
package clause

// Values holds the column list and row tuples of an INSERT statement.
type Values struct {
	Columns []Column
	Values  [][]interface{}
}

// Name from clause name
func (Values) Name() string {
	return "VALUES"
}

// Build writes "(cols) VALUES (?,?),(?,?)..." with one bind-variable group per
// row, or "DEFAULT VALUES" when there are no columns.
func (values Values) Build(builder Builder) {
	if len(values.Columns) > 0 {
		builder.WriteByte('(')
		for idx, column := range values.Columns {
			if idx > 0 {
				builder.WriteByte(',')
			}
			builder.WriteQuoted(column)
		}
		builder.WriteByte(')')

		builder.WriteString(" VALUES ")

		for idx, value := range values.Values {
			if idx > 0 {
				builder.WriteByte(',')
			}

			builder.WriteByte('(')
			builder.AddVar(builder, value...)
			builder.WriteByte(')')
		}
	} else {
		builder.WriteString("DEFAULT VALUES")
	}
}

// MergeClause replaces any previous VALUES clause. It also clears the clause
// name, presumably so the generic clause builder does not emit "VALUES" before
// the column list; Build writes the keyword itself.
func (values Values) MergeClause(clause *Clause) {
	clause.Name = ""
	clause.Expression = values
}

================================================
FILE: clause/values_test.go
================================================
package clause_test

import (
	"fmt"
	"testing"

	"gorm.io/gorm/clause"
)

func TestValues(t *testing.T) {
	results := []struct {
		Clauses []clause.Interface
		Result  string
		Vars    []interface{}
	}{
		{
			[]clause.Interface{
				clause.Insert{},
				clause.Values{
					Columns: []clause.Column{{Name: "name"}, {Name: "age"}},
					Values:  [][]interface{}{{"jinzhu", 18}, {"josh", 1}},
				},
			},
			"INSERT INTO `users` (`name`,`age`) VALUES (?,?),(?,?)",
			[]interface{}{"jinzhu", 18, "josh", 1},
		},
	}

	for idx, result := range results {
		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
			checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
		})
	}
}

================================================
FILE: clause/where.go
================================================
package clause

import (
	"strings"
)

// Separators used to join conditions. buildExprs and NotConditions.Build also
// search upper-cased raw SQL for them to decide whether to add parentheses.
const (
	AndWithSpace = " AND "
	OrWithSpace  = " OR "
)

// Where where clause
type Where struct {
	Exprs []Expression
}

// Name where clause name
func (where Where) Name() string {
	return "WHERE"
}

// Build writes the conditions joined by AND; a single-element OrConditions
// term is joined with OR instead.
func (where Where) Build(builder Builder) {
	// A lone AndConditions is flattened so its children are joined at top
	// level without an extra pair of parentheses.
	if len(where.Exprs) == 1 {
		if andCondition, ok := where.Exprs[0].(AndConditions); ok {
			where.Exprs = andCondition.Exprs
		}
	}

	// Switch position if the first query expression is a single Or condition.
	// buildExprs writes no separator before the first term, so a leading
	// single OR would otherwise lose its " OR " join. The swap is done in
	// place on the Exprs backing array.
	for idx, expr := range where.Exprs {
		if v, ok := expr.(OrConditions); !ok || len(v.Exprs) > 1 {
			if idx != 0 {
				where.Exprs[0], where.Exprs[idx] = where.Exprs[idx], where.Exprs[0]
			}
			break
		}
	}

	buildExprs(where.Exprs, builder, AndWithSpace)
}

// buildExprs writes exprs separated by joinCond, except that a single-element
// OrConditions is always preceded by " OR ". When there are several exprs, a
// raw Expr/NamedExpr whose SQL contains " AND " or " OR " is wrapped in
// parentheses to keep its precedence. The same applies to an Or/And wrapping
// exactly one Expr; a wrapped NamedExpr is not checked.
func buildExprs(exprs []Expression, builder Builder, joinCond string) {
	wrapInParentheses := false

	for idx, expr := range exprs {
		if idx > 0 {
			if v, ok := expr.(OrConditions); ok && len(v.Exprs) == 1 {
				builder.WriteString(OrWithSpace)
			} else {
				builder.WriteString(joinCond)
			}
		}

		if len(exprs) > 1 {
			switch v := expr.(type) {
			case OrConditions:
				if len(v.Exprs) == 1 {
					if e, ok := v.Exprs[0].(Expr); ok {
						sql := strings.ToUpper(e.SQL)
						wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
					}
				}
			case AndConditions:
				if len(v.Exprs) == 1 {
					if e, ok := v.Exprs[0].(Expr); ok {
						sql := strings.ToUpper(e.SQL)
						wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
					}
				}
			case Expr:
				sql := strings.ToUpper(v.SQL)
				wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
			case NamedExpr:
				sql := strings.ToUpper(v.SQL)
				wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace)
			}
		}

		if wrapInParentheses {
			builder.WriteByte('(')
			expr.Build(builder)
			builder.WriteByte(')')
			// Reset so the flag does not leak into the next term.
			wrapInParentheses = false
		} else {
			expr.Build(builder)
		}
	}
}

// MergeClause appends these conditions to an existing WHERE clause, using a
// freshly allocated slice so neither input slice is modified.
func (where Where) MergeClause(clause *Clause) {
	if w, ok := clause.Expression.(Where); ok {
		exprs := make([]Expression, len(w.Exprs)+len(where.Exprs))
		copy(exprs, w.Exprs)
		copy(exprs[len(w.Exprs):], where.Exprs)
		where.Exprs = exprs
	}

	clause.Expression = where
}

// And combines exprs with AND. It returns nil for no exprs, and a single expr
// as-is unless it is an OrConditions. A lone OrConditions stays wrapped,
// presumably so that, used as a term, it is joined with AND rather than being
// treated as an OR term by buildExprs.
func And(exprs ...Expression) Expression {
	if len(exprs) == 0 {
		return nil
	}

	if len(exprs) == 1 {
		if _, ok := exprs[0].(OrConditions); !ok {
			return exprs[0]
		}
	}

	return AndConditions{Exprs: exprs}
}

// AndConditions is a group of conditions joined by AND.
type AndConditions struct {
	Exprs []Expression
}

// Build writes the group, parenthesized when it has more than one condition.
func (and AndConditions) Build(builder Builder) {
	if len(and.Exprs) > 1 {
		builder.WriteByte('(')
		buildExprs(and.Exprs, builder, AndWithSpace)
		builder.WriteByte(')')
	} else {
		buildExprs(and.Exprs, builder, AndWithSpace)
	}
}

// Or combines exprs with OR; it returns nil for no exprs.
func Or(exprs ...Expression) Expression {
	if len(exprs) == 0 {
		return nil
	}
	return OrConditions{Exprs: exprs}
}

// OrConditions is a group of conditions joined by OR.
type OrConditions struct {
	Exprs []Expression
}

// Build writes the group, parenthesized when it has more than one condition.
func (or OrConditions) Build(builder Builder) {
	if len(or.Exprs) > 1 {
		builder.WriteByte('(')
		buildExprs(or.Exprs, builder, OrWithSpace)
		builder.WriteByte(')')
	} else {
		buildExprs(or.Exprs, builder, OrWithSpace)
	}
}

// Not negates exprs; it returns nil for no exprs. A single AndConditions
// argument is unwrapped so the negation applies to its children directly.
func Not(exprs ...Expression) Expression {
	if len(exprs) == 0 {
		return nil
	}
	if len(exprs) == 1 {
		if andCondition, ok := exprs[0].(AndConditions); ok {
			exprs = andCondition.Exprs
		}
	}
	return NotConditions{Exprs: exprs}
}

// NotConditions is a negated group of conditions.
type NotConditions struct {
	Exprs []Expression
}

// Build writes the negation in one of two forms:
//   - if any child implements NegationExpressionBuilder, each child is negated
//     individually (NegationBuild, or a "NOT " prefix) and joined with AND;
//   - otherwise a single "NOT " prefixes the whole group, parenthesized when
//     there are several children; OrConditions children are joined with OR.
//
// NOTE(review): with several children, the first form yields
// "NOT a AND NOT b" while the second yields "NOT (a AND b)". These are not
// logically equivalent, but both are pinned by TestWhere.
func (not NotConditions) Build(builder Builder) {
	anyNegationBuilder := false
	for _, c := range not.Exprs {
		if _, ok := c.(NegationExpressionBuilder); ok {
			anyNegationBuilder = true
			break
		}
	}

	if anyNegationBuilder {
		if len(not.Exprs) > 1 {
			builder.WriteByte('(')
		}

		for idx, c := range not.Exprs {
			if idx > 0 {
				builder.WriteString(AndWithSpace)
			}

			if negationBuilder, ok := c.(NegationExpressionBuilder); ok {
				negationBuilder.NegationBuild(builder)
			} else {
				builder.WriteString("NOT ")
				// Parenthesize a raw Expr only if its SQL contains AND/OR.
				e, wrapInParentheses := c.(Expr)
				if wrapInParentheses {
					sql := strings.ToUpper(e.SQL)
					if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
						builder.WriteByte('(')
					}
				}

				c.Build(builder)

				if wrapInParentheses {
					builder.WriteByte(')')
				}
			}
		}

		if len(not.Exprs) > 1 {
			builder.WriteByte(')')
		}
	} else {
		builder.WriteString("NOT ")
		if len(not.Exprs) > 1 {
			builder.WriteByte('(')
		}

		for idx, c := range not.Exprs {
			if idx > 0 {
				switch c.(type) {
				case OrConditions:
					builder.WriteString(OrWithSpace)
				default:
					builder.WriteString(AndWithSpace)
				}
			}

			// Parenthesize a raw Expr only if its SQL contains AND/OR.
			e, wrapInParentheses := c.(Expr)
			if wrapInParentheses {
				sql := strings.ToUpper(e.SQL)
				if wrapInParentheses = strings.Contains(sql, AndWithSpace) || strings.Contains(sql, OrWithSpace); wrapInParentheses {
					builder.WriteByte('(')
				}
			}

			c.Build(builder)

			if wrapInParentheses {
				builder.WriteByte(')')
			}
		}

		if len(not.Exprs) > 1 {
			builder.WriteByte(')')
		}
	}
}

================================================
FILE: clause/where_test.go
================================================
package clause_test

import (
	"fmt"
	"testing"

	"gorm.io/gorm/clause"
)

func TestWhere(t *testing.T) {
	results := []struct {
		Clauses []clause.Interface
		Result  string
		Vars    []interface{}
	}{
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
			}},
			"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ?",
			[]interface{}{"1", 18, "jinzhu"},
		},
		{
			// A leading single OR is swapped behind the first non-OR term.
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
			}},
			"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?",
			[]interface{}{"1", "jinzhu", 18},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}), clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}},
			}},
			"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ? AND `age` > ?",
			[]interface{}{"1", "jinzhu", 18},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Or(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
			}},
			"SELECT * FROM `users` WHERE `users`.`id` = ? OR `name` <> ?",
			[]interface{}{"1", "jinzhu"},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
			}, clause.Where{
				Exprs: []clause.Expression{clause.Or(clause.Gt{Column: "score", Value: 100}, clause.Like{Column: "name", Value: "%linus%"})},
			}},
			"SELECT * FROM `users` WHERE `users`.`id` = ? AND `age` > ? OR `name` <> ? AND (`score` > ? OR `name` LIKE ?)",
			[]interface{}{"1", 18, "jinzhu", 100, "%linus%"},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Neq{Column: "name", Value: "jinzhu"})},
			}, clause.Where{
				Exprs: []clause.Expression{clause.Or(clause.Not(clause.Gt{Column: "score", Value: 100}), clause.Like{Column: "name", Value: "%linus%"})},
			}},
			"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `name` <> ? AND (`score` <= ? OR `name` LIKE ?)",
			[]interface{}{"1", 18, "jinzhu", 100, "%linus%"},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.And(clause.Eq{Column: "age", Value: 18}, clause.Or(clause.Neq{Column: "name", Value: "jinzhu"}))},
			}},
			"SELECT * FROM `users` WHERE `age` = ? OR `name` <> ?",
			[]interface{}{18, "jinzhu"},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
			}},
			"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
			[]interface{}{"1", 18, 100},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}},
			}},
			"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) AND `score` <= ?",
			[]interface{}{"1", 18, 100},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}, clause.Gt{Column: "age", Value: 18}), clause.Or(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})},
			}},
			"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND `age` <= ?) OR `score` <= ?",
			[]interface{}{"1", 18, 100},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{
					clause.And(clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"}),
						clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false})),
				},
			}},
			"SELECT * FROM `users` WHERE `users`.`id` <> ? AND `score` <= ?",
			[]interface{}{"1", 100},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Not(clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
					clause.And(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}, WithoutParentheses: false}))},
			}},
			"SELECT * FROM `users` WHERE (`users`.`id` <> ? AND NOT `score` <= ?)",
			[]interface{}{"1", 100},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{clause.Not(clause.Expr{SQL: "`score` <= ?", Vars: []interface{}{100}},
					clause.Expr{SQL: "`age` <= ?", Vars: []interface{}{60}})},
			}},
			"SELECT * FROM `users` WHERE NOT (`score` <= ? AND `age` <= ?)",
			[]interface{}{100, 60},
		},
		{
			[]clause.Interface{clause.Select{}, clause.From{}, clause.Where{
				Exprs: []clause.Expression{
					clause.Not(clause.AndConditions{
						Exprs: []clause.Expression{
							clause.Eq{Column: clause.PrimaryColumn, Value: "1"},
							clause.Gt{Column: "age", Value: 18},
						}}, clause.OrConditions{
						Exprs: []clause.Expression{
							clause.Lt{Column: "score", Value: 100},
						},
					}),
				}}},
			"SELECT * FROM `users` WHERE NOT ((`users`.`id` = ? AND `age` > ?) OR `score` < ?)",
			[]interface{}{"1", 18, 100},
		},
	}

	for idx, result := range results {
		t.Run(fmt.Sprintf("case #%v", idx), func(t *testing.T) {
			checkBuildClauses(t, result.Clauses, result.Result, result.Vars)
		})
	}
}

================================================
FILE: clause/with.go
================================================
package clause

// With is an empty placeholder type: no fields, and no methods in this file.
type With struct{}

================================================
FILE: errors.go
================================================
package gorm

import (
	"errors"

	"gorm.io/gorm/logger"
)

var (
	// ErrRecordNotFound record not found error
	ErrRecordNotFound = logger.ErrRecordNotFound
	// ErrInvalidTransaction invalid transaction when you are trying to `Commit` or `Rollback`
	ErrInvalidTransaction = errors.New("invalid transaction")
	// ErrNotImplemented not implemented
	ErrNotImplemented = errors.New("not implemented")
	// ErrMissingWhereClause missing where clause
	ErrMissingWhereClause = errors.New("WHERE conditions required")
	// ErrUnsupportedRelation unsupported relations
	ErrUnsupportedRelation = errors.New("unsupported relations")
	// ErrPrimaryKeyRequired primary keys required
	ErrPrimaryKeyRequired = errors.New("primary key required")
	// ErrModelValueRequired model value required
	ErrModelValueRequired = errors.New("model value required")
	// ErrModelAccessibleFieldsRequired model accessible fields required
	ErrModelAccessibleFieldsRequired = errors.New("model accessible fields required")
	// ErrSubQueryRequired sub query required
	ErrSubQueryRequired = errors.New("sub query required")
	// ErrInvalidData unsupported data
	ErrInvalidData = errors.New("unsupported data")
	// ErrUnsupportedDriver unsupported driver
	ErrUnsupportedDriver = errors.New("unsupported driver")
	// ErrRegistered registered
	ErrRegistered = errors.New("registered")
	// ErrInvalidField invalid field
	ErrInvalidField = errors.New("invalid field")
	// ErrEmptySlice empty slice found
	ErrEmptySlice = errors.New("empty slice found")
	// ErrDryRunModeUnsupported dry run mode unsupported
	ErrDryRunModeUnsupported = errors.New("dry run mode unsupported")
	// ErrInvalidDB invalid db
	ErrInvalidDB = errors.New("invalid db")
	// ErrInvalidValue invalid value
	ErrInvalidValue = errors.New("invalid value, should be pointer to struct or slice")
	// ErrInvalidValueOfLength invalid values do not match length
	ErrInvalidValueOfLength = errors.New("invalid association values, length doesn't match")
	// ErrPreloadNotAllowed preload is not allowed when count is used
	ErrPreloadNotAllowed = errors.New("preload is not allowed when count is used")
	// ErrDuplicatedKey occurs when there is a unique key constraint violation
	ErrDuplicatedKey = errors.New("duplicated key not allowed")
	// ErrForeignKeyViolated occurs when there is a foreign key constraint violation
	ErrForeignKeyViolated = errors.New("violates foreign key constraint")
	// ErrCheckConstraintViolated occurs when there is a check constraint violation
	ErrCheckConstraintViolated = errors.New("violates check constraint")
)

================================================
FILE: finisher_api.go
================================================
package gorm

import (
	"context"
	"database/sql"
	"errors"
	"fmt"
	"hash/maphash"
	"reflect"
	"strings"

	"gorm.io/gorm/clause"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils"
)

// Create inserts value, returning the inserted data's primary key in value's id.
// When CreateBatchSize is positive, it delegates to CreateInBatches.
func (db *DB) Create(value interface{}) (tx *DB) {
	if db.CreateBatchSize > 0 {
		return db.CreateInBatches(value, db.CreateBatchSize)
	}

	tx = db.getInstance()
	tx.Statement.Dest = value
	return tx.callbacks.Create().Execute(tx)
}

// CreateInBatches inserts value in batches of batchSize.
//
// Slices and arrays are split into consecutive chunks of at most batchSize
// elements. All chunks run inside a single transaction, unless
// SkipDefaultTransaction is set or everything fits in one batch. The first
// failing chunk stops the loop, and its error is added to the returned DB.
// RowsAffected is the sum over all executed chunks. Any other kind of value
// is created in a single call.
//
// NOTE(review): batchSize is not validated. A negative value makes
// reflect.Value.Slice panic. Zero never advances the loop index, so the loop
// relies on the create callback rejecting the empty chunk.
func (db *DB) CreateInBatches(value interface{}, batchSize int) (tx *DB) {
	reflectValue := reflect.Indirect(reflect.ValueOf(value))

	switch reflectValue.Kind() {
	case reflect.Slice, reflect.Array:
		var rowsAffected int64
		tx = db.getInstance()

		// Read the length once instead of calling Len on every iteration.
		reflectLen := reflectValue.Len()

		callFc := func(tx *DB) error {
			for i := 0; i < reflectLen; i += batchSize {
				ends := i + batchSize
				if ends > reflectLen {
					ends = reflectLen
				}

				subtx := tx.getInstance()
				subtx.Statement.Dest = reflectValue.Slice(i, ends).Interface()
				subtx.callbacks.Create().Execute(subtx)
				if subtx.Error != nil {
					return subtx.Error
				}
				rowsAffected += subtx.RowsAffected
			}
			return nil
		}

		if tx.SkipDefaultTransaction || reflectLen <= batchSize {
			tx.AddError(callFc(tx.Session(&Session{})))
		} else {
			tx.AddError(tx.Transaction(callFc))
		}

		tx.RowsAffected = rowsAffected
	default:
		tx = db.getInstance()
		tx.Statement.Dest = value
		tx = tx.callbacks.Create().Execute(tx)
	}
	return
}

// Save updates value in database. If value doesn't contain a matching primary key, value is inserted.
func (db *DB) Save(value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = value reflectValue := reflect.Indirect(reflect.ValueOf(value)) for reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface { reflectValue = reflect.Indirect(reflectValue) } switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := tx.Statement.Clauses["ON CONFLICT"]; !ok { tx = tx.Clauses(clause.OnConflict{UpdateAll: true}) } tx = tx.callbacks.Create().Execute(tx.Set("gorm:update_track_time", true)) case reflect.Struct: if err := tx.Statement.Parse(value); err == nil && tx.Statement.Schema != nil { for _, pf := range tx.Statement.Schema.PrimaryFields { if _, isZero := pf.ValueOf(tx.Statement.Context, reflectValue); isZero { return tx.callbacks.Create().Execute(tx) } } } fallthrough default: selectedUpdate := len(tx.Statement.Selects) != 0 // when updating, use all fields including those zero-value fields if !selectedUpdate { tx.Statement.Selects = append(tx.Statement.Selects, "*") } updateTx := tx.callbacks.Update().Execute(tx.Session(&Session{Initialized: true})) if updateTx.Error == nil && updateTx.RowsAffected == 0 && !updateTx.DryRun && !selectedUpdate { return tx.Session(&Session{SkipHooks: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(value) } return updateTx } return } // First finds the first record ordered by primary key, matching given conditions conds func (db *DB) First(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { tx.Statement.AddClause(clause.Where{Exprs: exprs}) } } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest return tx.callbacks.Query().Execute(tx) } // Take finds the first record returned by the database in no specified order, matching given conditions conds 
func (db *DB) Take(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1) if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { tx.Statement.AddClause(clause.Where{Exprs: exprs}) } } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest return tx.callbacks.Query().Execute(tx) } // Last finds the last record ordered by primary key, matching given conditions conds func (db *DB) Last(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Desc: true, }) if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { tx.Statement.AddClause(clause.Where{Exprs: exprs}) } } tx.Statement.RaiseErrorOnNotFound = true tx.Statement.Dest = dest return tx.callbacks.Query().Execute(tx) } // Find finds all records matching given conditions conds func (db *DB) Find(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() if len(conds) > 0 { if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 { tx.Statement.AddClause(clause.Where{Exprs: exprs}) } } tx.Statement.Dest = dest return tx.callbacks.Query().Execute(tx) } // FindInBatches finds all records in batches of batchSize func (db *DB) FindInBatches(dest interface{}, batchSize int, fc func(tx *DB, batch int) error) *DB { var ( tx = db.Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }).Session(&Session{}) queryDB = tx rowsAffected int64 batch int ) // user specified offset or limit var totalSize int if c, ok := tx.Statement.Clauses["LIMIT"]; ok { if limit, ok := c.Expression.(clause.Limit); ok { if limit.Limit != nil { totalSize = *limit.Limit } if totalSize > 0 && batchSize > totalSize { batchSize = totalSize } // reset to offset to 0 in next batch tx = tx.Offset(-1).Session(&Session{}) } } for { result := 
queryDB.Limit(batchSize).Find(dest) rowsAffected += result.RowsAffected batch++ if result.Error == nil && result.RowsAffected != 0 { fcTx := result.Session(&Session{NewDB: true}) fcTx.RowsAffected = result.RowsAffected tx.AddError(fc(fcTx, batch)) } else if result.Error != nil { tx.AddError(result.Error) } if tx.Error != nil || int(result.RowsAffected) < batchSize { break } if totalSize > 0 { if totalSize <= int(rowsAffected) { break } if totalSize/batchSize == batch { batchSize = totalSize % batchSize } } // Optimize for-break resultsValue := reflect.Indirect(reflect.ValueOf(dest)) if result.Statement.Schema.PrioritizedPrimaryField == nil { tx.AddError(ErrPrimaryKeyRequired) break } primaryValue, zero := result.Statement.Schema.PrioritizedPrimaryField.ValueOf(tx.Statement.Context, resultsValue.Index(resultsValue.Len()-1)) if zero { tx.AddError(ErrPrimaryKeyRequired) break } queryDB = tx.Clauses(clause.Gt{Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, Value: primaryValue}) } tx.RowsAffected = rowsAffected return tx } func (db *DB) assignInterfacesToValue(values ...interface{}) { for _, value := range values { switch v := value.(type) { case []clause.Expression: for _, expr := range v { if eq, ok := expr.(clause.Eq); ok { switch column := eq.Column.(type) { case string: if field := db.Statement.Schema.LookUpField(column); field != nil { db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) } case clause.Column: if field := db.Statement.Schema.LookUpField(column.Name); field != nil { db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, eq.Value)) } } } else if andCond, ok := expr.(clause.AndConditions); ok { db.assignInterfacesToValue(andCond.Exprs) } } case clause.Expression, map[string]string, map[interface{}]interface{}, map[string]interface{}: if exprs := db.Statement.BuildCondition(value); len(exprs) > 0 { db.assignInterfacesToValue(exprs) } default: if s, err := schema.Parse(value, 
db.cacheStore, db.NamingStrategy); err == nil { reflectValue := reflect.Indirect(reflect.ValueOf(value)) switch reflectValue.Kind() { case reflect.Struct: for _, f := range s.Fields { if f.Readable { if v, isZero := f.ValueOf(db.Statement.Context, reflectValue); !isZero { if field := db.Statement.Schema.LookUpField(f.Name); field != nil { db.AddError(field.Set(db.Statement.Context, db.Statement.ReflectValue, v)) } } } } } } else if len(values) > 0 { if exprs := db.Statement.BuildCondition(values[0], values[1:]...); len(exprs) > 0 { db.assignInterfacesToValue(exprs) } return } } } } // FirstOrInit finds the first matching record, otherwise if not found initializes a new instance with given conds. // Each conds must be a struct or map. // // FirstOrInit never modifies the database. It is often used with Assign and Attrs. // // // assign an email if the record is not found // db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrInit(&user) // // user -> User{Name: "non_existing", Email: "fake@fake.org"} // // // assign email regardless of if record is found // db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrInit(&user) // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} func (db *DB) FirstOrInit(dest interface{}, conds ...interface{}) (tx *DB) { queryTx := db.Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) if tx = queryTx.Find(dest, conds...); tx.RowsAffected == 0 { if c, ok := tx.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { tx.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(tx.Statement.attrs) > 0 { tx.assignInterfacesToValue(tx.Statement.attrs...) } } // initialize with attrs, conds if len(tx.Statement.assigns) > 0 { tx.assignInterfacesToValue(tx.Statement.assigns...) 
} return } // FirstOrCreate finds the first matching record, otherwise if not found creates a new instance with given conds. // Each conds must be a struct or map. // // Using FirstOrCreate in conjunction with Assign will result in an update to the database even if the record exists. // // // assign an email if the record is not found // result := db.Where(User{Name: "non_existing"}).Attrs(User{Email: "fake@fake.org"}).FirstOrCreate(&user) // // user -> User{Name: "non_existing", Email: "fake@fake.org"} // // result.RowsAffected -> 1 // // // assign email regardless of if record is found // result := db.Where(User{Name: "jinzhu"}).Assign(User{Email: "fake@fake.org"}).FirstOrCreate(&user) // // user -> User{Name: "jinzhu", Age: 20, Email: "fake@fake.org"} // // result.RowsAffected -> 1 func (db *DB) FirstOrCreate(dest interface{}, conds ...interface{}) (tx *DB) { tx = db.getInstance() queryTx := db.Session(&Session{}).Limit(1).Order(clause.OrderByColumn{ Column: clause.Column{Table: clause.CurrentTable, Name: clause.PrimaryKey}, }) result := queryTx.Find(dest, conds...) if result.Error != nil { tx.Error = result.Error return tx } if result.RowsAffected == 0 { if c, ok := result.Statement.Clauses["WHERE"]; ok { if where, ok := c.Expression.(clause.Where); ok { result.assignInterfacesToValue(where.Exprs) } } // initialize with attrs, conds if len(db.Statement.attrs) > 0 { result.assignInterfacesToValue(db.Statement.attrs...) } // initialize with attrs, conds if len(db.Statement.assigns) > 0 { result.assignInterfacesToValue(db.Statement.assigns...) } return tx.Create(dest) } else if len(db.Statement.assigns) > 0 { exprs := tx.Statement.BuildCondition(db.Statement.assigns[0], db.Statement.assigns[1:]...) assigns := map[string]interface{}{} for i := 0; i < len(exprs); i++ { expr := exprs[i] if eq, ok := expr.(clause.AndConditions); ok { exprs = append(exprs, eq.Exprs...) 
} else if eq, ok := expr.(clause.Eq); ok { switch column := eq.Column.(type) { case string: assigns[column] = eq.Value case clause.Column: assigns[column.Name] = eq.Value } } } return tx.Model(dest).Updates(assigns) } return tx } // Update updates column with value using callbacks. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Update(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} return tx.callbacks.Update().Execute(tx) } // Updates updates attributes using callbacks. values must be a struct or map. Reference: https://gorm.io/docs/update.html#Update-Changed-Fields func (db *DB) Updates(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values return tx.callbacks.Update().Execute(tx) } func (db *DB) UpdateColumn(column string, value interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = map[string]interface{}{column: value} tx.Statement.SkipHooks = true return tx.callbacks.Update().Execute(tx) } func (db *DB) UpdateColumns(values interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.Dest = values tx.Statement.SkipHooks = true return tx.callbacks.Update().Execute(tx) } // Delete deletes value matching given conditions. If value contains primary key it is included in the conditions. If // value includes a deleted_at field, then Delete performs a soft delete instead by setting deleted_at with the current // time if null. 
//
// Any conds are turned into an additional WHERE clause before the delete
// callbacks run with value as the statement destination.
func (db *DB) Delete(value interface{}, conds ...interface{}) (tx *DB) {
	tx = db.getInstance()
	if len(conds) > 0 {
		if exprs := tx.Statement.BuildCondition(conds[0], conds[1:]...); len(exprs) > 0 {
			tx.Statement.AddClause(clause.Where{Exprs: exprs})
		}
	}
	tx.Statement.Dest = value
	return tx.callbacks.Delete().Execute(tx)
}

// Count stores the number of matching records in count.
//
// If no Model is set, Dest is used as the model for the duration of the call.
// The SELECT list is rewritten for the query:
//   - no Select: count(*)
//   - a Select starting with "count(": left as is
//   - one plain column (optionally "a AS b" or "t.col"): COUNT(col), or
//     COUNT(DISTINCT(col)) when Distinct is set; "*" keeps count(*)
//   - anything else: count(*)
//
// ORDER BY is dropped for the query unless GROUP BY is present. The original
// SELECT and ORDER BY clauses (and the nil Model) are restored by deferred
// functions afterwards. With GROUP BY, or whenever the query does not return
// exactly one row, count is set to the number of rows returned.
func (db *DB) Count(count *int64) (tx *DB) {
	tx = db.getInstance()
	if tx.Statement.Model == nil {
		tx.Statement.Model = tx.Statement.Dest
		defer func() {
			tx.Statement.Model = nil
		}()
	}

	// Restore (or remove) the SELECT clause once the count query has run.
	if selectClause, ok := db.Statement.Clauses["SELECT"]; ok {
		defer func() {
			tx.Statement.Clauses["SELECT"] = selectClause
		}()
	} else {
		defer delete(tx.Statement.Clauses, "SELECT")
	}

	if len(tx.Statement.Selects) == 0 {
		tx.Statement.AddClause(clause.Select{Expression: clause.Expr{SQL: "count(*)"}})
	} else if !strings.HasPrefix(strings.TrimSpace(strings.ToLower(tx.Statement.Selects[0])), "count(") {
		expr := clause.Expr{SQL: "count(*)"}

		if len(tx.Statement.Selects) == 1 {
			dbName := tx.Statement.Selects[0]
			fields := strings.FieldsFunc(dbName, utils.IsInvalidDBNameChar)
			// Only rewrite simple column references: "col", "a AS b" or "t.col".
			if len(fields) == 1 || (len(fields) == 3 && (strings.ToUpper(fields[1]) == "AS" || fields[1] == ".")) {
				// Translate a struct field name into its database column name.
				if tx.Statement.Parse(tx.Statement.Model) == nil {
					if f := tx.Statement.Schema.LookUpField(dbName); f != nil {
						dbName = f.DBName
					}
				}

				if tx.Statement.Distinct {
					expr = clause.Expr{SQL: "COUNT(DISTINCT(?))", Vars: []interface{}{clause.Column{Name: dbName}}}
				} else if dbName != "*" {
					expr = clause.Expr{SQL: "COUNT(?)", Vars: []interface{}{clause.Column{Name: dbName}}}
				}
			}
		}

		tx.Statement.AddClause(clause.Select{Expression: expr})
	}

	// ORDER BY is dropped for a plain count and restored afterwards; it is kept
	// when GROUP BY is present.
	if orderByClause, ok := db.Statement.Clauses["ORDER BY"]; ok {
		if _, ok := db.Statement.Clauses["GROUP BY"]; !ok {
			delete(tx.Statement.Clauses, "ORDER BY")
			defer func() {
				tx.Statement.Clauses["ORDER BY"] = orderByClause
			}()
		}
	}

	tx.Statement.Dest = count
	tx = tx.callbacks.Query().Execute(tx)

	// With GROUP BY each group yields one row, so the row count is the result.
	if _, ok := db.Statement.Clauses["GROUP BY"]; ok || tx.RowsAffected != 1 {
		*count = tx.RowsAffected
	}

	return
}

// Row runs the statement through the Row callbacks (with setting "rows" =
// false) and returns the resulting *sql.Row. The returned row is nil if the
// callbacks did not produce one; in DryRun mode that case is also logged via
// db.Logger.Error with ErrDryRunModeUnsupported.
func (db *DB) Row() *sql.Row {
	tx := db.getInstance().Set("rows", false)
	tx = tx.callbacks.Row().Execute(tx)
	row, ok := tx.Statement.Dest.(*sql.Row)
	if !ok && tx.DryRun {
		db.Logger.Error(tx.Statement.Context, ErrDryRunModeUnsupported.Error())
	}
	return row
}

// Rows runs the statement through the Row callbacks (with setting "rows" =
// true) and returns the resulting *sql.Rows together with the statement
// error. In DryRun mode, when no *sql.Rows was produced and no other error is
// set, the error is ErrDryRunModeUnsupported.
func (db *DB) Rows() (*sql.Rows, error) {
	tx := db.getInstance().Set("rows", true)
	tx = tx.callbacks.Row().Execute(tx)
	rows, ok := tx.Statement.Dest.(*sql.Rows)
	if !ok && tx.DryRun && tx.Error == nil {
		tx.Error = ErrDryRunModeUnsupported
	}
	return rows, tx.Error
}

// Scan scans selected value to the struct dest
//
// The query runs through Rows with a logger.Recorder swapped into a copy of
// the Config, so the executed SQL is captured rather than logged directly.
// Afterwards a single Trace entry with the recorded SQL and RowsAffected is
// sent to the original logger, which is then put back on tx.
func (db *DB) Scan(dest interface{}) (tx *DB) {
	config := *db.Config
	currentLogger, newLogger := config.Logger, logger.Recorder.New()
	config.Logger = newLogger

	tx = db.getInstance()
	tx.Config = &config

	if rows, err := tx.Rows(); err == nil {
		// rows.Next is advanced here, before ScanRows is called.
		if rows.Next() {
			tx.ScanRows(rows, dest)
		} else {
			tx.RowsAffected = 0
			tx.AddError(rows.Err())
		}
		tx.AddError(rows.Close())
	}

	currentLogger.Trace(tx.Statement.Context, newLogger.BeginAt, func() (string, int64) {
		return newLogger.SQL, tx.RowsAffected
	}, tx.Error)
	tx.Logger = currentLogger
	return
}

// Pluck queries a single column from a model, returning in the slice dest.
E.g.: // // var ages []int64 // db.Model(&users).Pluck("age", &ages) func (db *DB) Pluck(column string, dest interface{}) (tx *DB) { tx = db.getInstance() if tx.Statement.Model != nil { if tx.Statement.Parse(tx.Statement.Model) == nil { if f := tx.Statement.Schema.LookUpField(column); f != nil { column = f.DBName } } } if len(tx.Statement.Selects) != 1 { fields := strings.FieldsFunc(column, utils.IsInvalidDBNameChar) tx.Statement.AddClauseIfNotExists(clause.Select{ Distinct: tx.Statement.Distinct, Columns: []clause.Column{{Name: column, Raw: len(fields) != 1}}, }) } tx.Statement.Dest = dest return tx.callbacks.Query().Execute(tx) } func (db *DB) ScanRows(rows *sql.Rows, dest interface{}) error { tx := db.getInstance() if err := tx.Statement.Parse(dest); !errors.Is(err, schema.ErrUnsupportedDataType) { tx.AddError(err) } tx.Statement.Dest = dest tx.Statement.ReflectValue = reflect.ValueOf(dest) for tx.Statement.ReflectValue.Kind() == reflect.Ptr { elem := tx.Statement.ReflectValue.Elem() if !elem.IsValid() { elem = reflect.New(tx.Statement.ReflectValue.Type().Elem()) tx.Statement.ReflectValue.Set(elem) } tx.Statement.ReflectValue = elem } Scan(rows, tx, ScanInitialized) return tx.Error } // Connection uses a db connection to execute an arbitrary number of commands in fc. When finished, the connection is // returned to the connection pool. func (db *DB) Connection(fc func(tx *DB) error) (err error) { if db.Error != nil { return db.Error } tx := db.getInstance() sqlDB, err := tx.DB() if err != nil { return } conn, err := sqlDB.Conn(tx.Statement.Context) if err != nil { return } defer conn.Close() tx.Statement.ConnPool = conn return fc(tx) } // Transaction start a transaction as a block, return error will rollback, otherwise to commit. Transaction executes an // arbitrary number of commands in fc within a transaction. On success the changes are committed; if an error occurs // they are rolled back. 
func (db *DB) Transaction(fc func(tx *DB) error, opts ...*sql.TxOptions) (err error) { panicked := true if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { // nested transaction if !db.DisableNestedTransaction { spID := new(maphash.Hash).Sum64() err = db.SavePoint(fmt.Sprintf("sp%d", spID)).Error if err != nil { return } defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { db.RollbackTo(fmt.Sprintf("sp%d", spID)) } }() } err = fc(db.Session(&Session{NewDB: db.clone == 1})) } else { tx := db.Begin(opts...) if tx.Error != nil { return tx.Error } defer func() { // Make sure to rollback when panic, Block error or Commit error if panicked || err != nil { tx.Rollback() } }() if err = fc(tx); err == nil { panicked = false return tx.Commit().Error } } panicked = false return } // Begin begins a transaction with any transaction options opts func (db *DB) Begin(opts ...*sql.TxOptions) *DB { var ( // clone statement tx = db.getInstance().Session(&Session{Context: db.Statement.Context, NewDB: db.clone == 1}) opt *sql.TxOptions err error ) if len(opts) > 0 { opt = opts[0] } ctx := tx.Statement.Context if db.DefaultTransactionTimeout > 0 { if _, ok := ctx.Deadline(); !ok { ctx, _ = context.WithTimeout(ctx, db.DefaultTransactionTimeout) } } switch beginner := tx.Statement.ConnPool.(type) { case TxBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) case ConnPoolBeginner: tx.Statement.ConnPool, err = beginner.BeginTx(ctx, opt) default: err = ErrInvalidTransaction } if err != nil { tx.AddError(err) } return tx } // Commit commits the changes in a transaction func (db *DB) Commit() *DB { if committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil && !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Commit()) } else { db.AddError(ErrInvalidTransaction) } return db } // Rollback rollbacks the changes in a transaction func (db *DB) Rollback() *DB { if 
committer, ok := db.Statement.ConnPool.(TxCommitter); ok && committer != nil { if !reflect.ValueOf(committer).IsNil() { db.AddError(committer.Rollback()) } } else { db.AddError(ErrInvalidTransaction) } return db } func (db *DB) SavePoint(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { // close prepared statement, because SavePoint not support prepared statement. // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html var ( preparedStmtTx *PreparedStmtTX isPreparedStmtTx bool ) // close prepared statement, because SavePoint not support prepared statement. if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { db.Statement.ConnPool = preparedStmtTx.Tx } db.AddError(savePointer.SavePoint(db, name)) // restore prepared statement if isPreparedStmtTx { db.Statement.ConnPool = preparedStmtTx } } else { db.AddError(ErrUnsupportedDriver) } return db } func (db *DB) RollbackTo(name string) *DB { if savePointer, ok := db.Dialector.(SavePointerDialectorInterface); ok { // close prepared statement, because RollbackTo not support prepared statement. // e.g. mysql8.0 doc: https://dev.mysql.com/doc/refman/8.0/en/sql-prepared-statements.html var ( preparedStmtTx *PreparedStmtTX isPreparedStmtTx bool ) // close prepared statement, because SavePoint not support prepared statement. 
if preparedStmtTx, isPreparedStmtTx = db.Statement.ConnPool.(*PreparedStmtTX); isPreparedStmtTx { db.Statement.ConnPool = preparedStmtTx.Tx } db.AddError(savePointer.RollbackTo(db, name)) // restore prepared statement if isPreparedStmtTx { db.Statement.ConnPool = preparedStmtTx } } else { db.AddError(ErrUnsupportedDriver) } return db } // Exec executes raw sql func (db *DB) Exec(sql string, values ...interface{}) (tx *DB) { tx = db.getInstance() tx.Statement.SQL = strings.Builder{} if strings.Contains(sql, "@") { clause.NamedExpr{SQL: sql, Vars: values}.Build(tx.Statement) } else { clause.Expr{SQL: sql, Vars: values}.Build(tx.Statement) } return tx.callbacks.Raw().Execute(tx) } ================================================ FILE: generics.go ================================================ package gorm import ( "context" "database/sql" "errors" "fmt" "reflect" "sort" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) type result struct { Result sql.Result RowsAffected int64 } func (info *result) ModifyStatement(stmt *Statement) { stmt.Result = info } // Build implements clause.Expression interface func (result) Build(clause.Builder) { } func WithResult() *result { return &result{} } type Interface[T any] interface { Raw(sql string, values ...interface{}) ExecInterface[T] Exec(ctx context.Context, sql string, values ...interface{}) error CreateInterface[T] } type CreateInterface[T any] interface { ExecInterface[T] // chain methods available at start; Select/Omit keep CreateInterface to allow Create chaining Scopes(scopes ...func(db *Statement)) ChainInterface[T] Where(query interface{}, args ...interface{}) ChainInterface[T] Not(query interface{}, args ...interface{}) ChainInterface[T] Or(query interface{}, args ...interface{}) ChainInterface[T] Limit(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T] Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) 
ChainInterface[T] Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] Select(query string, args ...interface{}) CreateInterface[T] Omit(columns ...string) CreateInterface[T] MapColumns(m map[string]string) ChainInterface[T] Distinct(args ...interface{}) ChainInterface[T] Group(name string) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T] Build(builder clause.Builder) Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) Count(ctx context.Context, column string) (result int64, err error) Table(name string, args ...interface{}) CreateInterface[T] Create(ctx context.Context, r *T) error CreateInBatches(ctx context.Context, r *[]T, batchSize int) error Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] } type ChainInterface[T any] interface { ExecInterface[T] Scopes(scopes ...func(db *Statement)) ChainInterface[T] Where(query interface{}, args ...interface{}) ChainInterface[T] Not(query interface{}, args ...interface{}) ChainInterface[T] Or(query interface{}, args ...interface{}) ChainInterface[T] Limit(offset int) ChainInterface[T] Offset(offset int) ChainInterface[T] Joins(query clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] Select(query string, args ...interface{}) ChainInterface[T] Omit(columns ...string) ChainInterface[T] MapColumns(m map[string]string) ChainInterface[T] Distinct(args ...interface{}) ChainInterface[T] Group(name string) ChainInterface[T] Having(query interface{}, args ...interface{}) ChainInterface[T] Order(value interface{}) ChainInterface[T] Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] Build(builder 
clause.Builder) Table(name string, args ...interface{}) ChainInterface[T] Delete(ctx context.Context) (rowsAffected int, err error) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) Updates(ctx context.Context, t T) (rowsAffected int, err error) Count(ctx context.Context, column string) (result int64, err error) } // SetUpdateOnlyInterface is returned by Set after chaining; only Update is allowed type SetUpdateOnlyInterface[T any] interface { Update(ctx context.Context) (rowsAffected int, err error) } // SetCreateOrUpdateInterface is returned by Set at start; Create or Update are allowed type SetCreateOrUpdateInterface[T any] interface { Create(ctx context.Context) error Update(ctx context.Context) (rowsAffected int, err error) } type ExecInterface[T any] interface { Scan(ctx context.Context, r interface{}) error First(context.Context) (T, error) Last(ctx context.Context) (T, error) Take(context.Context) (T, error) Find(ctx context.Context) ([]T, error) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error Row(ctx context.Context) *sql.Row Rows(ctx context.Context) (*sql.Rows, error) } type JoinBuilder interface { Select(...string) JoinBuilder Omit(...string) JoinBuilder Where(query interface{}, args ...interface{}) JoinBuilder Not(query interface{}, args ...interface{}) JoinBuilder Or(query interface{}, args ...interface{}) JoinBuilder } type PreloadBuilder interface { Select(...string) PreloadBuilder Omit(...string) PreloadBuilder Where(query interface{}, args ...interface{}) PreloadBuilder Not(query interface{}, args ...interface{}) PreloadBuilder Or(query interface{}, args ...interface{}) PreloadBuilder Limit(offset int) PreloadBuilder Offset(offset int) PreloadBuilder Order(value interface{}) PreloadBuilder LimitPerRecord(num int) PreloadBuilder } type op func(*DB) *DB func G[T any](db *DB, opts ...clause.Expression) Interface[T] { v := &g[T]{ db: db, ops: make([]op, 0, 5), } if len(opts) > 
0 { v.ops = append(v.ops, func(db *DB) *DB { return db.Clauses(opts...) }) } v.createG = &createG[T]{ chainG: chainG[T]{ execG: execG[T]{g: v}, }, } return v } type g[T any] struct { *createG[T] db *DB ops []op } func (g *g[T]) apply(ctx context.Context) *DB { db := g.db if !db.DryRun { db = db.Session(&Session{NewDB: true, Context: ctx}).getInstance() } for _, op := range g.ops { db = op(db) } return db } func (c *g[T]) Raw(sql string, values ...interface{}) ExecInterface[T] { return execG[T]{g: &g[T]{ db: c.db, ops: append(c.ops, func(db *DB) *DB { var r T return db.Model(r).Raw(sql, values...) }), }} } func (c *g[T]) Exec(ctx context.Context, sql string, values ...interface{}) error { var r T return c.apply(ctx).Model(r).Exec(sql, values...).Error } type createG[T any] struct { chainG[T] } func (c createG[T]) Table(name string, args ...interface{}) CreateInterface[T] { return createG[T]{c.with(func(db *DB) *DB { return db.Table(name, args...) })} } func (c createG[T]) Select(query string, args ...interface{}) CreateInterface[T] { return createG[T]{c.with(func(db *DB) *DB { return db.Select(query, args...) })} } func (c createG[T]) Omit(columns ...string) CreateInterface[T] { return createG[T]{c.with(func(db *DB) *DB { return db.Omit(columns...) })} } func (c createG[T]) Set(assignments ...clause.Assigner) SetCreateOrUpdateInterface[T] { return c.processSet(assignments...) 
} func (c createG[T]) Create(ctx context.Context, r *T) error { return c.g.apply(ctx).Create(r).Error } func (c createG[T]) CreateInBatches(ctx context.Context, r *[]T, batchSize int) error { return c.g.apply(ctx).CreateInBatches(r, batchSize).Error } type chainG[T any] struct { execG[T] } func (c chainG[T]) getInstance() *DB { var r T return c.g.apply(context.Background()).Model(r).getInstance() } func (c chainG[T]) with(v op) chainG[T] { return chainG[T]{ execG: execG[T]{g: &g[T]{ db: c.g.db, ops: append(append([]op(nil), c.g.ops...), v), }}, } } func (c chainG[T]) Table(name string, args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Table(name, args...) }) } func (c chainG[T]) Scopes(scopes ...func(db *Statement)) ChainInterface[T] { return c.with(func(db *DB) *DB { for _, fc := range scopes { fc(db.Statement) } return db }) } func (c chainG[T]) Where(query interface{}, args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Where(query, args...) }) } func (c chainG[T]) Not(query interface{}, args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Not(query, args...) }) } func (c chainG[T]) Or(query interface{}, args ...interface{}) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Or(query, args...) }) } func (c chainG[T]) Limit(offset int) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Limit(offset) }) } func (c chainG[T]) Offset(offset int) ChainInterface[T] { return c.with(func(db *DB) *DB { return db.Offset(offset) }) } type joinBuilder struct { db *DB } func (q *joinBuilder) Where(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } func (q *joinBuilder) Or(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) return q } func (q *joinBuilder) Not(query interface{}, args ...interface{}) JoinBuilder { q.db.Where(query, args...) 
return q } func (q *joinBuilder) Select(columns ...string) JoinBuilder { q.db.Select(columns) return q } func (q *joinBuilder) Omit(columns ...string) JoinBuilder { q.db.Omit(columns...) return q } type preloadBuilder struct { limitPerRecord int db *DB } func (q *preloadBuilder) Where(query interface{}, args ...interface{}) PreloadBuilder { q.db.Where(query, args...) return q } func (q *preloadBuilder) Or(query interface{}, args ...interface{}) PreloadBuilder { q.db.Where(query, args...) return q } func (q *preloadBuilder) Not(query interface{}, args ...interface{}) PreloadBuilder { q.db.Where(query, args...) return q } func (q *preloadBuilder) Select(columns ...string) PreloadBuilder { q.db.Select(columns) return q } func (q *preloadBuilder) Omit(columns ...string) PreloadBuilder { q.db.Omit(columns...) return q } func (q *preloadBuilder) Limit(limit int) PreloadBuilder { q.db.Limit(limit) return q } func (q *preloadBuilder) Offset(offset int) PreloadBuilder { q.db.Offset(offset) return q } func (q *preloadBuilder) Order(value interface{}) PreloadBuilder { q.db.Order(value) return q } func (q *preloadBuilder) LimitPerRecord(num int) PreloadBuilder { q.limitPerRecord = num return q } func (c chainG[T]) Joins(jt clause.JoinTarget, on func(db JoinBuilder, joinTable clause.Table, curTable clause.Table) error) ChainInterface[T] { return c.with(func(db *DB) *DB { if jt.Table == "" { jt.Table = clause.JoinTable(strings.Split(jt.Association, ".")...).Name } q := joinBuilder{db: db.Session(&Session{NewDB: true, Initialized: true}).Table(jt.Table)} if on != nil { if err := on(&q, clause.Table{Name: jt.Table}, clause.Table{Name: clause.CurrentTable}); err != nil { db.AddError(err) } } j := join{ Name: jt.Association, Alias: jt.Table, Selects: q.db.Statement.Selects, Omits: q.db.Statement.Omits, JoinType: jt.Type, } if where, ok := q.db.Statement.Clauses["WHERE"].Expression.(clause.Where); ok { j.On = &where } if jt.Subquery != nil { joinType := j.JoinType if joinType == "" { 
joinType = clause.LeftJoin
			}

			if db, ok := jt.Subquery.(interface{ getInstance() *DB }); ok {
				stmt := db.getInstance().Statement
				if len(j.Selects) == 0 {
					j.Selects = stmt.Selects
				}
				if len(j.Omits) == 0 {
					j.Omits = stmt.Omits
				}
			}

			expr := clause.NamedExpr{SQL: fmt.Sprintf("%s JOIN (?) AS ?", joinType), Vars: []interface{}{jt.Subquery, clause.Table{Name: j.Alias}}}

			if j.On != nil {
				expr.SQL += " ON ?"
				expr.Vars = append(expr.Vars, clause.AndConditions{Exprs: j.On.Exprs})
			}

			j.Expression = expr
		}

		db.Statement.Joins = append(db.Statement.Joins, j)
		sort.Slice(db.Statement.Joins, func(i, j int) bool {
			return db.Statement.Joins[i].Name < db.Statement.Joins[j].Name
		})
		return db
	})
}

// Select restricts the selected columns by delegating to DB.Select.
func (c chainG[T]) Select(query string, args ...interface{}) ChainInterface[T] {
	return c.with(func(db *DB) *DB {
		return db.Select(query, args...)
	})
}

// Omit excludes the given columns by delegating to DB.Omit.
func (c chainG[T]) Omit(columns ...string) ChainInterface[T] {
	return c.with(func(db *DB) *DB {
		return db.Omit(columns...)
	})
}

// MapColumns applies the column mapping m by delegating to DB.MapColumns.
func (c chainG[T]) MapColumns(m map[string]string) ChainInterface[T] {
	return c.with(func(db *DB) *DB {
		return db.MapColumns(m)
	})
}

// Set collects column assignments and association operations; they are only
// executed when the returned value's Update is called (see processSet).
func (c chainG[T]) Set(assignments ...clause.Assigner) SetUpdateOnlyInterface[T] {
	return c.processSet(assignments...)
}

// Distinct adds DISTINCT by delegating to DB.Distinct.
func (c chainG[T]) Distinct(args ...interface{}) ChainInterface[T] {
	return c.with(func(db *DB) *DB {
		return db.Distinct(args...)
	})
}

// Group adds a GROUP BY on name by delegating to DB.Group.
func (c chainG[T]) Group(name string) ChainInterface[T] {
	return c.with(func(db *DB) *DB {
		return db.Group(name)
	})
}

// Having adds a HAVING condition by delegating to DB.Having.
func (c chainG[T]) Having(query interface{}, args ...interface{}) ChainInterface[T] {
	return c.with(func(db *DB) *DB {
		return db.Having(query, args...)
	})
}

// Order adds an ORDER BY by delegating to DB.Order.
func (c chainG[T]) Order(value interface{}) ChainInterface[T] {
	return c.with(func(db *DB) *DB {
		return db.Order(value)
	})
}

// Preload eager-loads association. If query is non-nil it is called with a
// PreloadBuilder so the caller can customise the preload query.
//
// When the builder requests a per-record limit (q.limitPerRecord > 0), the
// preload query is wrapped in a derived table. Rows are numbered with
// ROW_NUMBER() partitioned by the foreign key column(s) that reference the
// owner, and ordered by the builder's ORDER BY (primary key descending when none
// is given). Only rows numbered <= limitPerRecord are kept. Many2many relations
// do not support a per-record limit and report an error instead.
//
// association may be a dotted path ("A.B"), which is resolved relationship by
// relationship starting from the chain's schema. If it cannot be resolved, the
// error is recorded on the preload *DB.
func (c chainG[T]) Preload(association string, query func(db PreloadBuilder) error) ChainInterface[T] {
	return c.with(func(db *DB) *DB {
		return db.Preload(association, func(tx *DB) *DB {
			q := preloadBuilder{db: tx.getInstance()}
			if query != nil {
				if err := query(&q); err != nil {
					db.AddError(err)
				}
			}

			relation, ok := db.Statement.Schema.Relationships.Relations[association]
			if !ok {
				if preloadFields := strings.Split(association, "."); len(preloadFields) > 1 {
					relationships := &db.Statement.Schema.Relationships
					for _, field := range preloadFields {
						var ok bool
						relation, ok = relationships.Relations[field]
						if ok {
							relationships = &relation.FieldSchema.Relationships
						} else {
							// Record the error on tx and hand tx back, as the
							// LimitPerRecord branch below does. Returning nil would
							// give the caller a nil *DB to continue with.
							tx.AddError(fmt.Errorf("relation %s not found", association))
							return tx
						}
					}
				} else {
					tx.AddError(fmt.Errorf("relation %s not found", association))
					return tx
				}
			}

			if q.limitPerRecord > 0 {
				if relation.JoinTable != nil {
					tx.AddError(fmt.Errorf("many2many relation %s don't support LimitPerRecord", association))
					return tx
				}

				refColumns := []clause.Column{}
				for _, rel := range relation.References {
					if rel.OwnPrimaryKey {
						refColumns = append(refColumns, clause.Column{Name: rel.ForeignKey.DBName})
					}
				}

				if len(refColumns) != 0 {
					selectExpr := clause.CommaExpression{}
					for _, column := range q.db.Statement.Selects {
						selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column}}})
					}

					if len(selectExpr.Exprs) == 0 {
						selectExpr.Exprs = []clause.Expression{clause.Expr{SQL: "*", Vars: []interface{}{}}}
					}

					partitionBy := clause.CommaExpression{}
					for _, column := range refColumns {
						partitionBy.Exprs = append(partitionBy.Exprs, clause.Expr{SQL: "?", Vars: []interface{}{clause.Column{Name: column.Name}}})
					}

					rnnColumn := clause.Column{Name: "gorm_preload_rnn"}
					sql := "ROW_NUMBER() OVER (PARTITION BY ? ?)"
					vars := []interface{}{partitionBy}
					if orderBy, ok := q.db.Statement.Clauses["ORDER BY"]; ok {
						vars = append(vars, orderBy)
					} else {
						vars = append(vars, clause.Clause{Name: "ORDER BY", Expression: clause.OrderBy{
							Columns: []clause.OrderByColumn{{Column: clause.PrimaryColumn, Desc: true}},
						}})
					}
					vars = append(vars, rnnColumn)

					selectExpr.Exprs = append(selectExpr.Exprs, clause.Expr{SQL: sql + " AS ?", Vars: vars})

					q.db.Clauses(clause.Select{Expression: selectExpr})

					// Filter on the row number in an outer query over the numbered rows.
					return q.db.Session(&Session{NewDB: true}).Unscoped().Table("(?) t", q.db).Where("? <= ?", rnnColumn, q.limitPerRecord)
				}
			}

			return q.db
		})
	})
}

// Delete deletes the rows matched by the chain, using a new(T) as the model,
// and returns the number of affected rows.
func (c chainG[T]) Delete(ctx context.Context) (rowsAffected int, err error) {
	r := new(T)
	res := c.g.apply(ctx).Delete(r)
	return int(res.RowsAffected), res.Error
}

// Update sets column name to value on the rows matched by the chain.
func (c chainG[T]) Update(ctx context.Context, name string, value any) (rowsAffected int, err error) {
	var r T
	res := c.g.apply(ctx).Model(r).Update(name, value)
	return int(res.RowsAffected), res.Error
}

// Updates applies the fields of t to the rows matched by the chain via DB.Updates.
func (c chainG[T]) Updates(ctx context.Context, t T) (rowsAffected int, err error) {
	res := c.g.apply(ctx).Updates(t)
	return int(res.RowsAffected), res.Error
}

// Count counts the rows matched by the chain, passing column to Select first.
func (c chainG[T]) Count(ctx context.Context, column string) (result int64, err error) {
	var r T
	err = c.g.apply(ctx).Model(r).Select(column).Count(&result).Error
	return
}

// Build lets the chain be embedded as a subquery. It renders the chain's SQL in
// DryRun mode, with logging discarded, and writes it into builder. builder must
// be a *Statement; any other builder is ignored.
//
// If the instance already holds SQL text, each bind-var placeholder is turned
// back into "?" and the SQL is rebuilt against the outer statement's Vars, so
// the placeholders are numbered in the context of the outer query. Otherwise
// the query callbacks run with the outer Vars prepended.
func (c chainG[T]) Build(builder clause.Builder) {
	subdb := c.getInstance()
	subdb.Logger = logger.Discard
	subdb.DryRun = true

	if stmt, ok := builder.(*Statement); ok {
		if subdb.Statement.SQL.Len() > 0 {
			var (
				vars = subdb.Statement.Vars
				sql  = subdb.Statement.SQL.String()
			)

			// Re-add the vars one by one so BindVarTo renders each placeholder
			// exactly as it appears in sql, then swap that placeholder for "?".
			subdb.Statement.Vars = make([]interface{}, 0, len(vars))
			for _, vv := range vars {
				subdb.Statement.Vars = append(subdb.Statement.Vars, vv)
				bindvar := strings.Builder{}
				subdb.BindVarTo(&bindvar, subdb.Statement, vv)
				sql = strings.Replace(sql, bindvar.String(), "?", 1)
			}

			subdb.Statement.SQL.Reset()
			subdb.Statement.Vars = stmt.Vars
			if strings.Contains(sql, "@") {
				clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement)
			} else {
				clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement)
			}
		} else {
			subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...)
			subdb.callbacks.Query().Execute(subdb)
		}

		builder.WriteString(subdb.Statement.SQL.String())
		stmt.Vars = subdb.Statement.Vars
	}
}

// execG[T] exposes the terminal, executing operations of a generic query.
type execG[T any] struct {
	g *g[T]
}

// First returns the first record ordered by primary key (see DB.First).
func (g execG[T]) First(ctx context.Context) (T, error) {
	var r T
	err := g.g.apply(ctx).First(&r).Error
	return r, err
}

// Scan runs the query with T as the model and scans the rows into result.
func (g execG[T]) Scan(ctx context.Context, result interface{}) error {
	var r T
	err := g.g.apply(ctx).Model(r).Find(result).Error
	return err
}

// Last returns the last record ordered by primary key (see DB.Last).
func (g execG[T]) Last(ctx context.Context) (T, error) {
	var r T
	err := g.g.apply(ctx).Last(&r).Error
	return r, err
}

// Take returns one record without an explicit ordering (see DB.Take).
func (g execG[T]) Take(ctx context.Context) (T, error) {
	var r T
	err := g.g.apply(ctx).Take(&r).Error
	return r, err
}

// Find returns every record matched by the query.
func (g execG[T]) Find(ctx context.Context) ([]T, error) {
	var r []T
	err := g.g.apply(ctx).Find(&r).Error
	return r, err
}

// FindInBatches loads matching records batchSize at a time and calls fc with
// each batch and its batch number. The same data slice variable is filled by
// DB.FindInBatches for every batch.
func (g execG[T]) FindInBatches(ctx context.Context, batchSize int, fc func(data []T, batch int) error) error {
	var data []T
	return g.g.apply(ctx).FindInBatches(&data, batchSize, func(tx *DB, batch int) error {
		return fc(data, batch)
	}).Error
}

// Row runs the query with T as the model and returns a single *sql.Row.
func (g execG[T]) Row(ctx context.Context) *sql.Row {
	var r T
	return g.g.apply(ctx).Model(r).Row()
}

// Rows runs the query with T as the model and returns the raw *sql.Rows.
func (g execG[T]) Rows(ctx context.Context) (*sql.Rows, error) {
	var r T
	return g.g.apply(ctx).Model(r).Rows()
}

// processSet splits items into association operations (items implementing
// clause.AssociationAssigner) and plain column assignments (all other items).
func (c chainG[T]) processSet(items ...clause.Assigner) setCreateOrUpdateG[T] {
	var (
		assigns  []clause.Assignment
		assocOps []clause.Association
	)

	for _, item := range items {
		// Check if it's an AssociationAssigner
		if assocAssigner, ok := item.(clause.AssociationAssigner); ok {
			assocOps = append(assocOps, assocAssigner.AssociationAssignments()...)
		} else {
			assigns = append(assigns, item.Assignments()...)
		}
	}

	return setCreateOrUpdateG[T]{
		c:        c,
		assigns:  assigns,
		assocOps: assocOps,
	}
}

// setCreateOrUpdateG[T] is a struct that holds operations to be executed in a batch.
// It supports regular assignments and association operations. type setCreateOrUpdateG[T any] struct { c chainG[T] assigns []clause.Assignment assocOps []clause.Association } func (s setCreateOrUpdateG[T]) Update(ctx context.Context) (rowsAffected int, err error) { // Execute association operations for _, assocOp := range s.assocOps { if err := s.executeAssociationOperation(ctx, assocOp); err != nil { return 0, err } } // Execute assignment operations if len(s.assigns) > 0 { var r T res := s.c.g.apply(ctx).Model(r).Clauses(clause.Set(s.assigns)).Updates(map[string]interface{}{}) return int(res.RowsAffected), res.Error } return 0, nil } func (s setCreateOrUpdateG[T]) Create(ctx context.Context) error { // Execute association operations for _, assocOp := range s.assocOps { if err := s.executeAssociationOperation(ctx, assocOp); err != nil { return err } } // Execute assignment operations if len(s.assigns) > 0 { data := make(map[string]interface{}, len(s.assigns)) for _, a := range s.assigns { data[a.Column.Name] = a.Value } var r T return s.c.g.apply(ctx).Model(r).Create(data).Error } return nil } // executeAssociationOperation executes an association operation func (s setCreateOrUpdateG[T]) executeAssociationOperation(ctx context.Context, op clause.Association) error { var r T base := s.c.g.apply(ctx).Model(r) switch op.Type { case clause.OpCreate: return s.handleAssociationCreate(ctx, base, op) case clause.OpUnlink, clause.OpDelete, clause.OpUpdate: return s.handleAssociation(ctx, base, op) default: return fmt.Errorf("unknown association operation type: %v", op.Type) } } func (s setCreateOrUpdateG[T]) handleAssociationCreate(ctx context.Context, base *DB, op clause.Association) error { if len(op.Set) > 0 { return s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { data := make(map[string]interface{}, len(op.Set)) for _, a := range op.Set { data[a.Column.Name] = a.Value } return assoc.Append(data) }, op.Association) } return 
s.handleAssociationForOwners(base, ctx, func(owner T, assoc *Association) error { return assoc.Append(op.Values...) }, op.Association) } // handleAssociationForOwners is a helper function that handles associations for all owners func (s setCreateOrUpdateG[T]) handleAssociationForOwners(base *DB, ctx context.Context, handler func(owner T, association *Association) error, associationName string) error { var owners []T if err := base.Find(&owners).Error; err != nil { return err } for _, owner := range owners { assoc := base.Session(&Session{NewDB: true, Context: ctx}).Model(&owner).Association(associationName) if assoc.Error != nil { return assoc.Error } if err := handler(owner, assoc); err != nil { return err } } return nil } func (s setCreateOrUpdateG[T]) handleAssociation(ctx context.Context, base *DB, op clause.Association) error { assoc := base.Association(op.Association) if assoc.Error != nil { return assoc.Error } var ( rel = assoc.Relationship assocModel = reflect.New(rel.FieldSchema.ModelType).Interface() fkNil = map[string]any{} setMap = make(map[string]any, len(op.Set)) ownerPKNames []string ownerFKNames []string primaryColumns []any foreignColumns []any ) for _, a := range op.Set { setMap[a.Column.Name] = a.Value } for _, ref := range rel.References { fkNil[ref.ForeignKey.DBName] = nil if ref.OwnPrimaryKey && ref.PrimaryKey != nil { ownerPKNames = append(ownerPKNames, ref.PrimaryKey.DBName) primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) foreignColumns = append(foreignColumns, clause.Column{Name: ref.ForeignKey.DBName}) } else if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { ownerFKNames = append(ownerFKNames, ref.ForeignKey.DBName) primaryColumns = append(primaryColumns, clause.Column{Name: ref.PrimaryKey.DBName}) } } assocDB := s.c.g.db.Session(&Session{NewDB: true, Context: ctx}).Model(assocModel).Where(op.Conditions) switch rel.Type { case schema.HasOne, schema.HasMany: assocDB = assocDB.Where("? 
IN (?)", foreignColumns, base.Select(ownerPKNames)) switch op.Type { case clause.OpUnlink: return assocDB.Updates(fkNil).Error case clause.OpDelete: return assocDB.Delete(assocModel).Error case clause.OpUpdate: return assocDB.Updates(setMap).Error } case schema.BelongsTo: switch op.Type { case clause.OpDelete: return base.Transaction(func(tx *DB) error { assocDB.Statement.ConnPool = tx.Statement.ConnPool base.Statement.ConnPool = tx.Statement.ConnPool if err := assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Delete(assocModel).Error; err != nil { return err } return base.Updates(fkNil).Error }) case clause.OpUnlink: return base.Updates(fkNil).Error case clause.OpUpdate: return assocDB.Where("? IN (?)", primaryColumns, base.Select(ownerFKNames)).Updates(setMap).Error } case schema.Many2Many: joinModel := reflect.New(rel.JoinTable.ModelType).Interface() joinDB := base.Session(&Session{NewDB: true, Context: ctx}).Model(joinModel) // EXISTS owners: owners.pk = join.owner_fk for all owner refs ownersExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.Schema.Table).Select("1") for _, ref := range rel.References { if ref.OwnPrimaryKey && ref.PrimaryKey != nil { ownersExists = ownersExists.Where(clause.Eq{ Column: clause.Column{Table: rel.Schema.Table, Name: ref.PrimaryKey.DBName}, Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, }) } } // EXISTS related: related.pk = join.rel_fk for all related refs, plus optional conditions relatedExists := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Select("1") for _, ref := range rel.References { if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { relatedExists = relatedExists.Where(clause.Eq{ Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, Value: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, }) } } relatedExists = relatedExists.Where(op.Conditions) switch op.Type { case 
clause.OpUnlink, clause.OpDelete: joinDB = joinDB.Where("EXISTS (?)", ownersExists) if len(op.Conditions) > 0 { joinDB = joinDB.Where("EXISTS (?)", relatedExists) } return joinDB.Delete(nil).Error case clause.OpUpdate: // Update related table rows that have join rows matching owners relatedDB := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.FieldSchema.Table).Where(op.Conditions) // correlated join subquery: join.rel_fk = related.pk AND EXISTS owners joinSub := base.Session(&Session{NewDB: true, Context: ctx}).Table(rel.JoinTable.Table).Select("1") for _, ref := range rel.References { if !ref.OwnPrimaryKey && ref.PrimaryKey != nil { joinSub = joinSub.Where(clause.Eq{ Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, }) } } joinSub = joinSub.Where("EXISTS (?)", ownersExists) return relatedDB.Where("EXISTS (?)", joinSub).Updates(setMap).Error } } return errors.New("unsupported relationship") } ================================================ FILE: go.mod ================================================ module gorm.io/gorm go 1.18 require ( github.com/jinzhu/inflection v1.0.0 github.com/jinzhu/now v1.1.5 golang.org/x/text v0.20.0 ) require ( github.com/mattn/go-sqlite3 v1.14.22 // indirect gorm.io/driver/sqlite v1.6.0 // indirect ) ================================================ FILE: go.sum ================================================ github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= 
golang.org/x/text v0.20.0 h1:gK/Kv2otX8gz+wn7Rmb3vT96ZwuoxnQlY+HlJVj7Qug=
golang.org/x/text v0.20.0/go.mod h1:D4IsuqiFMhST5bX19pQ9ikHC2GsaKyk/oF+pn3ducp4=
gorm.io/driver/sqlite v1.6.0 h1:WHRRrIiulaPiPFmDcod6prc4l2VGVWHz80KspNsxSfQ=
gorm.io/driver/sqlite v1.6.0/go.mod h1:AO9V1qIQddBESngQUKWL9yoH93HIeA1X6V633rBwyT8=

================================================
FILE: gorm.go
================================================
package gorm

import (
	"context"
	"database/sql"
	"fmt"
	"reflect"
	"sort"
	"sync"
	"time"

	"gorm.io/gorm/clause"
	"gorm.io/gorm/logger"
	"gorm.io/gorm/schema"
)

// preparedStmtDBKey is the Config.cacheStore key under which the shared
// *PreparedStmtDB is stored. Open stores it; Session loads it or creates it.
const preparedStmtDBKey = "preparedStmt"

// Config holds GORM's configuration.
//
// A *Config is itself an Option (it implements Apply and AfterInitialize), which
// is how it is normally passed to Open.
type Config struct {
	// GORM perform single create, update, delete operations in transactions by default to ensure database data integrity
	// You can disable it by setting `SkipDefaultTransaction` to true
	SkipDefaultTransaction    bool
	DefaultTransactionTimeout time.Duration
	DefaultContextTimeout     time.Duration

	// NamingStrategy tables, columns naming strategy.
	// Open defaults it to schema.NamingStrategy{IdentifierMaxLength: 64} when nil.
	NamingStrategy schema.Namer
	// FullSaveAssociations full save associations
	FullSaveAssociations bool
	// Logger receives GORM's log output. Open defaults it to logger.Default when nil.
	Logger logger.Interface
	// NowFunc the function to be used when creating a new timestamp.
	// Open defaults it to time.Now().Local() when nil.
	NowFunc func() time.Time
	// DryRun generate sql without execute
	DryRun bool
	// PrepareStmt executes the given query in cached statement.
	// When set, Open wraps ConnPool in a PreparedStmtDB.
	PrepareStmt bool
	// PrepareStmt cache support LRU expired,
	// default maxsize=int64 Max value and ttl=1h
	// (both values are passed to NewPreparedStmtDB).
	PrepareStmtMaxSize int
	PrepareStmtTTL     time.Duration

	// DisableAutomaticPing skips the Ping that Open issues on ConnPool, when
	// ConnPool implements Ping, after initialization.
	DisableAutomaticPing bool
	// DisableForeignKeyConstraintWhenMigrating
	DisableForeignKeyConstraintWhenMigrating bool
	// IgnoreRelationshipsWhenMigrating
	IgnoreRelationshipsWhenMigrating bool
	// DisableNestedTransaction disable nested transaction
	DisableNestedTransaction bool
	// AllowGlobalUpdate allow global update
	AllowGlobalUpdate bool
	// QueryFields executes the SQL query with all fields of the table
	QueryFields bool
	// CreateBatchSize default create batch size
	CreateBatchSize int
	// TranslateError enabling error translation: AddError passes errors through
	// the Dialector's ErrorTranslator when it implements one. Open logs a
	// warning when it does not.
	TranslateError bool
	// PropagateUnscoped propagate Unscoped to every other nested statement
	// (getInstance copies Statement.Unscoped into fresh statements when set).
	PropagateUnscoped bool

	// ClauseBuilders clause builder (Open initializes it to an empty map when nil)
	ClauseBuilders map[string]clause.ClauseBuilder
	// ConnPool db conn pool
	ConnPool ConnPool
	// Dialector database dialector. It is embedded, so its methods (e.g. BindVarTo,
	// Explain) are promoted onto Config and DB. Open sets it from its argument.
	Dialector
	// Plugins registered plugins (added by Use; Open initializes an empty map)
	Plugins map[string]Plugin

	// callbacks is set by Open via initializeCallbacks.
	callbacks *callbacks
	// cacheStore is shared by sessions derived from the same Open call; it holds
	// the PreparedStmtDB under preparedStmtDBKey.
	cacheStore *sync.Map
}

// Apply copies c into config, unless both are the same pointer. Together with
// AfterInitialize, this lets a *Config be passed to Open as an Option.
func (c *Config) Apply(config *Config) error {
	if config != c {
		*config = *c
	}
	return nil
}

// AfterInitialize calls Initialize on every registered plugin and returns the
// first error. It does nothing when db is nil.
func (c *Config) AfterInitialize(db *DB) error {
	if db != nil {
		for _, plugin := range c.Plugins {
			if err := plugin.Initialize(db); err != nil {
				return err
			}
		}
	}
	return nil
}

// Option configures Open. Apply runs before the dialector is initialized.
// AfterInitialize is deferred until Open returns and is skipped when dialector
// initialization fails.
type Option interface {
	Apply(*Config) error
	AfterInitialize(*DB) error
}

// DB GORM DB definition.
//
// clone controls what the next chain method operates on (see getInstance):
// 0 means this instance is mutated in place; 1 means a new instance with an
// empty Statement is created; 2 means a new instance with a clone of the
// current Statement is created.
type DB struct {
	*Config
	Error        error
	RowsAffected int64
	Statement    *Statement
	clone        int
}

// Session holds the options accepted by DB.Session. Boolean fields left false
// and nil fields leave the corresponding setting unchanged. The exception is
// NewDB: when false, the new session keeps the current statement's clauses;
// when true, it starts from an empty statement.
type Session struct {
	DryRun                   bool
	PrepareStmt              bool
	NewDB                    bool
	Initialized              bool
	SkipHooks                bool
	SkipDefaultTransaction   bool
	DisableNestedTransaction bool
	AllowGlobalUpdate        bool
	FullSaveAssociations     bool
	PropagateUnscoped        bool
	QueryFields              bool
	Context                  context.Context
	Logger                   logger.Interface
	NowFunc                  func() time.Time
	CreateBatchSize          int
}

// Open initializes a db session based on dialector.
//
// Options are applied in order, with any *Config first. The dialector's own
// Apply (if any) and defaults for unset Config fields follow. The dialector is
// then initialized and, unless disabled, the connection pool is pinged. Each
// option's AfterInitialize runs when Open returns (skipped if dialector
// initialization failed), and an error from it replaces the returned err.
// On failure the error is also logged through config.Logger.
func Open(dialector Dialector, opts ...Option) (db *DB, err error) {
	config := &Config{}

	// Move *Config options to the front so opts[0] is the caller's *Config if
	// one was supplied.
	sort.Slice(opts, func(i, j int) bool {
		_, isConfig := opts[i].(*Config)
		_, isConfig2 := opts[j].(*Config)
		return isConfig && !isConfig2
	})

	if len(opts) > 0 {
		if c, ok := opts[0].(*Config); ok {
			config = c
		} else {
			// No *Config supplied: prepend the default one so its
			// AfterInitialize (plugin initialization) is deferred like the others.
			opts = append([]Option{config}, opts...)
		}
	}

	var skipAfterInitialize bool
	for _, opt := range opts {
		if opt != nil {
			if applyErr := opt.Apply(config); applyErr != nil {
				return nil, applyErr
			}
			// Runs when Open returns, after db has been fully set up.
			defer func(opt Option) {
				if skipAfterInitialize {
					return
				}
				if errr := opt.AfterInitialize(db); errr != nil {
					err = errr
				}
			}(opt)
		}
	}

	if d, ok := dialector.(interface{ Apply(*Config) error }); ok {
		if err = d.Apply(config); err != nil {
			return
		}
	}

	if config.NamingStrategy == nil {
		config.NamingStrategy = schema.NamingStrategy{IdentifierMaxLength: 64} // Default Identifier length is 64
	}

	if config.Logger == nil {
		config.Logger = logger.Default
	}

	if config.NowFunc == nil {
		config.NowFunc = func() time.Time { return time.Now().Local() }
	}

	if dialector != nil {
		config.Dialector = dialector
	}

	if config.Plugins == nil {
		config.Plugins = map[string]Plugin{}
	}

	if config.cacheStore == nil {
		config.cacheStore = &sync.Map{}
	}

	db = &DB{Config: config, clone: 1}

	db.callbacks = initializeCallbacks(db)

	if config.ClauseBuilders == nil {
		config.ClauseBuilders = map[string]clause.ClauseBuilder{}
	}

	if config.Dialector != nil {
		err = config.Dialector.Initialize(db)
		if err != nil {
			// Close whatever *sql.DB the dialector may already have opened.
			if db, _ := db.DB(); db != nil {
				_ = db.Close()
			}

			// DB is not initialized, so we skip AfterInitialize
			skipAfterInitialize = true
			return
		}

		if config.TranslateError {
			if _, ok := db.Dialector.(ErrorTranslator); !ok {
				config.Logger.Warn(context.Background(), "The TranslateError option is enabled, but the Dialector %s does not implement ErrorTranslator.", db.Dialector.Name())
			}
		}
	}

	if config.PrepareStmt {
		preparedStmt := NewPreparedStmtDB(db.ConnPool, config.PrepareStmtMaxSize, config.PrepareStmtTTL)
		db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
		db.ConnPool = preparedStmt
	}

	db.Statement = &Statement{
		DB:       db,
		ConnPool: db.ConnPool,
		Context:  context.Background(),
		Clauses:  map[string]clause.Clause{},
	}

	if err == nil && !config.DisableAutomaticPing {
		if pinger, ok := db.ConnPool.(interface{ Ping() error }); ok {
			err = pinger.Ping()
			if err != nil {
				if db, _ := db.DB(); db != nil {
					_ = db.Close()
				}
			}
		}
	}

	if err != nil {
		config.Logger.Error(context.Background(), "failed to initialize database, got error %v", err)
	}

	return
}

// Session creates a new *DB with its own copy of db's Config, so the options
// applied here do not leak back into db. Only true or non-nil fields of config
// take effect. The Statement is cloned when the session changes
// statement-level state (Context, PrepareStmt, SkipHooks); otherwise it is
// shared with db until the next chain method.
func (db *DB) Session(config *Session) *DB {
	var (
		txConfig = *db.Config
		tx       = &DB{
			Config:    &txConfig,
			Statement: db.Statement,
			Error:     db.Error,
			clone:     1,
		}
	)
	if config.CreateBatchSize > 0 {
		tx.Config.CreateBatchSize = config.CreateBatchSize
	}

	if config.SkipDefaultTransaction {
		tx.Config.SkipDefaultTransaction = true
	}

	if config.AllowGlobalUpdate {
		txConfig.AllowGlobalUpdate = true
	}

	if config.FullSaveAssociations {
		txConfig.FullSaveAssociations = true
	}

	if config.PropagateUnscoped {
		txConfig.PropagateUnscoped = true
	}

	if config.Context != nil || config.PrepareStmt || config.SkipHooks {
		tx.Statement = tx.Statement.clone()
		tx.Statement.DB = tx
	}

	if config.Context != nil {
		tx.Statement.Context = config.Context
	}

	if config.PrepareStmt {
		var preparedStmt *PreparedStmtDB

		// Reuse the statement cache shared through cacheStore, creating it on first use.
		if v, ok := db.cacheStore.Load(preparedStmtDBKey); ok {
			preparedStmt = v.(*PreparedStmtDB)
		} else {
			preparedStmt = NewPreparedStmtDB(db.ConnPool, db.PrepareStmtMaxSize, db.PrepareStmtTTL)
			db.cacheStore.Store(preparedStmtDBKey, preparedStmt)
		}

		switch t := tx.Statement.ConnPool.(type) {
		case Tx:
			tx.Statement.ConnPool = &PreparedStmtTX{
				Tx:             t,
				PreparedStmtDB: preparedStmt,
			}
		default:
			tx.Statement.ConnPool = &PreparedStmtDB{
				ConnPool: db.Config.ConnPool,
				Mux:      preparedStmt.Mux,
				Stmts:    preparedStmt.Stmts,
			}
		}
		txConfig.ConnPool = tx.Statement.ConnPool
		txConfig.PrepareStmt = true
	}

	if config.SkipHooks {
		tx.Statement.SkipHooks = true
	}

	if config.DisableNestedTransaction {
		txConfig.DisableNestedTransaction = true
	}

	// clone = 2 makes the next chain method clone the current statement
	// (keeping its clauses); clone = 1 makes it start from an empty statement.
	if !config.NewDB {
		tx.clone = 2
	}

	if config.DryRun {
		tx.Config.DryRun = true
	}

	if config.QueryFields {
		tx.Config.QueryFields = true
	}

	if config.Logger != nil {
		tx.Config.Logger = config.Logger
	}

	if config.NowFunc != nil {
		tx.Config.NowFunc = config.NowFunc
	}

	if config.Initialized {
		tx = tx.getInstance()
	}

	return tx
}

// WithContext change current instance db's context to ctx
func (db *DB) WithContext(ctx context.Context) *DB {
	return db.Session(&Session{Context: ctx})
}

// Debug start debug mode: returns a session whose logger is db.Logger at
// logger.Info level.
func (db *DB) Debug() (tx *DB) {
	tx = db.getInstance()
	return tx.Session(&Session{
		Logger: db.Logger.LogMode(logger.Info),
	})
}

// Set store value with key into current db instance's context
func (db *DB) Set(key string, value interface{}) *DB {
	tx := db.getInstance()
	tx.Statement.Settings.Store(key, value)
	return tx
}

// Get get value with key from current db instance's context
func (db *DB) Get(key string) (interface{}, bool) {
	return db.Statement.Settings.Load(key)
}

// InstanceSet store value with key into current db instance's context.
// The key is prefixed with the Statement's address (%p), so the value is only
// visible through the same Statement (see InstanceGet).
func (db *DB) InstanceSet(key string, value interface{}) *DB {
	tx := db.getInstance()
	tx.Statement.Settings.Store(fmt.Sprintf("%p", tx.Statement)+key, value)
	return tx
}

// InstanceGet get value with key from current db instance's context
func (db *DB) InstanceGet(key string) (interface{}, bool) {
	return db.Statement.Settings.Load(fmt.Sprintf("%p", db.Statement) + key)
}

// Callback returns callback manager
func (db *DB) Callback() *callbacks {
	return db.callbacks
}

// AddError records err on db and returns the accumulated error. If
// TranslateError is set and the Dialector implements ErrorTranslator, err is
// translated first. Later errors are appended as "<previous>; <err>", with err
// wrapped (%w). A nil err is ignored.
func (db *DB) AddError(err error) error {
	if err != nil {
		if db.Config.TranslateError {
			if errTranslator, ok := db.Dialector.(ErrorTranslator); ok {
				err = errTranslator.Translate(err)
			}
		}

		if db.Error == nil {
			db.Error = err
		} else {
			db.Error = fmt.Errorf("%v; %w", db.Error, err)
		}
	}
	return db.Error
}

// DB returns `*sql.DB` behind the current connection pool. It prefers
// Statement.ConnPool over Config.ConnPool. For a *sql.Tx it reads the
// transaction's unexported db field via reflection; for a GetDBConnector it asks
// the connector. It returns ErrInvalidDB when no *sql.DB can be found.
func (db *DB) DB() (*sql.DB, error) {
	connPool := db.ConnPool
	if db.Statement != nil && db.Statement.ConnPool != nil {
		connPool = db.Statement.ConnPool
	}
	if tx, ok := connPool.(*sql.Tx); ok && tx != nil {
		return (*sql.DB)(reflect.ValueOf(tx).Elem().FieldByName("db").UnsafePointer()), nil
	}

	if dbConnector, ok := connPool.(GetDBConnector); ok && dbConnector != nil {
		if sqldb, err := dbConnector.GetDBConn(); sqldb != nil || err != nil {
			return sqldb, err
		}
	}

	if sqldb, ok := connPool.(*sql.DB); ok && sqldb != nil {
		return sqldb, nil
	}

	return nil, ErrInvalidDB
}

// getInstance returns the *DB that a chain method should mutate. With clone == 0
// it is db itself. Otherwise a new *DB sharing db's Config and Error is
// returned, with either a fresh Statement (clone == 1) or a clone of the
// current one (clone > 1).
func (db *DB) getInstance() *DB {
	if db.clone > 0 {
		tx := &DB{Config: db.Config, Error: db.Error}

		if db.clone == 1 {
			// clone with new statement
			tx.Statement = &Statement{
				DB:        tx,
				ConnPool:  db.Statement.ConnPool,
				Context:   db.Statement.Context,
				Clauses:   map[string]clause.Clause{},
				Vars:      make([]interface{}, 0, 8),
				SkipHooks: db.Statement.SkipHooks,
			}
			if db.Config.PropagateUnscoped {
				tx.Statement.Unscoped = db.Statement.Unscoped
			}
		} else {
			// with clone statement
			tx.Statement = db.Statement.clone()
			tx.Statement.DB = tx
		}

		return tx
	}

	return db
}

// Expr returns clause.Expr, which can be used to pass SQL expression as params
func Expr(expr string, args ...interface{}) clause.Expr {
	return clause.Expr{SQL: expr, Vars: args}
}

// SetupJoinTable replaces the join table of model's many2many relation field
// with the schema parsed from joinTable. Each reference's foreign key is
// rebound to the matching joinTable field, inheriting its data types (and its
// size when the joinTable field has none). Relations of the old join table are
// carried over where joinTable does not define them. It returns an error if
// field is not a many2many relation or joinTable lacks a foreign key field.
func (db *DB) SetupJoinTable(model interface{}, field string, joinTable interface{}) error {
	var (
		tx                      = db.getInstance()
		stmt                    = tx.Statement
		modelSchema, joinSchema *schema.Schema
	)

	err := stmt.Parse(model)
	if err != nil {
		return err
	}
	modelSchema = stmt.Schema

	err = stmt.Parse(joinTable)
	if err != nil {
		return err
	}
	joinSchema = stmt.Schema

	relation, ok := modelSchema.Relationships.Relations[field]
	isRelation := ok && relation.JoinTable != nil
	if !isRelation {
		return fmt.Errorf("failed to find relation: %s", field)
	}

	for _, ref := range relation.References {
		f := joinSchema.LookUpField(ref.ForeignKey.DBName)
		if f == nil {
			return fmt.Errorf("missing field %s for join table", ref.ForeignKey.DBName)
		}

		f.DataType = ref.ForeignKey.DataType
		f.GORMDataType = ref.ForeignKey.GORMDataType
		if f.Size == 0 {
			f.Size = ref.ForeignKey.Size
		}
		ref.ForeignKey = f
	}

	for name, rel := range relation.JoinTable.Relationships.Relations {
		if _, ok := joinSchema.Relationships.Relations[name]; !ok {
			rel.Schema = joinSchema
			joinSchema.Relationships.Relations[name] = rel
		}
	}

	relation.JoinTable = joinSchema
	return nil
}

// Use registers plugin under its Name after calling its Initialize. It returns
// ErrRegistered if a plugin with that name already exists.
func (db *DB) Use(plugin Plugin) error {
	name := plugin.Name()
	if _, ok := db.Plugins[name]; ok {
		return ErrRegistered
	}
	if err := plugin.Initialize(db); err != nil {
		return err
	}
	db.Plugins[name] = plugin
	return nil
}

// ToSQL for generate SQL string. queryFn runs on a DryRun session (no default
// transaction), and the resulting SQL and Vars are rendered with the
// Dialector's Explain.
//
//	db.ToSQL(func(tx *gorm.DB) *gorm.DB {
//			return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20})
//				.Limit(10).Offset(5)
//				.Order("name ASC")
//				.First(&User{})
//	})
func (db *DB) ToSQL(queryFn func(tx *DB) *DB) string {
	tx := queryFn(db.Session(&Session{DryRun: true, SkipDefaultTransaction: true}).getInstance())
	stmt := tx.Statement

	return db.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
}

================================================
FILE: interfaces.go
================================================
package gorm

import (
	"context"
	"database/sql"

	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
)

// Dialector GORM database dialector
type Dialector interface {
	Name() string
	Initialize(*DB) error
	Migrator(db *DB) Migrator
	DataTypeOf(*schema.Field) string
	DefaultValueOf(*schema.Field) clause.Expression
	BindVarTo(writer clause.Writer, stmt *Statement, v interface{})
	QuoteTo(clause.Writer, string)
	Explain(sql string, vars ...interface{}) string
}

// Plugin GORM plugin interface (registered with DB.Use or via Config.Plugins)
type Plugin interface {
	Name() string
	Initialize(*DB) error
}

// ParamsFilter rewrites an SQL string and its params. NOTE(review): the code
// that consults this interface is not in this file; confirm its purpose there.
type ParamsFilter interface {
	ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{})
}

// ConnPool db conns pool interface
type ConnPool interface {
	PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
	ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
	QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
	QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}

// SavePointerDialectorInterface save pointer interface
type SavePointerDialectorInterface interface {
	SavePoint(tx *DB, name string) error
	RollbackTo(tx *DB, name string) error
}

// TxBeginner tx beginner
type TxBeginner interface {
	BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
}

// ConnPoolBeginner conn pool beginner
type ConnPoolBeginner interface {
	BeginTx(ctx context.Context, opts *sql.TxOptions) (ConnPool, error)
}

// TxCommitter tx committer
type TxCommitter interface {
	Commit() error
	Rollback() error
}

// Tx sql.Tx interface
type Tx interface {
	ConnPool
	TxCommitter
	StmtContext(ctx context.Context, stmt *sql.Stmt) *sql.Stmt
}

// Valuer gorm valuer interface
type Valuer interface {
	GormValue(context.Context, *DB) clause.Expr
}

// GetDBConnector SQL db connector (consulted by DB.DB)
type GetDBConnector interface {
	GetDBConn() (*sql.DB, error)
}

// Rows rows interface
type Rows interface {
	Columns() ([]string, error)
	ColumnTypes() ([]*sql.ColumnType, error)
	Next() bool
	Scan(dest ...interface{}) error
	Err() error
	Close() error
}

// ErrorTranslator is implemented by dialectors that translate errors.
// DB.AddError uses it when Config.TranslateError is set.
type ErrorTranslator interface {
	Translate(err error) error
}

================================================
FILE: internal/lru/lru.go
================================================
package lru

// golang -lru
// https://github.com/hashicorp/golang-lru
import (
	"sync"
	"time"
)

// EvictCallback is used to get a callback when a cache entry is evicted
type EvictCallback[K comparable, V any] func(key K, value V)

// LRU implements a thread-safe LRU with expirable entries.
type LRU[K comparable, V any] struct { size int evictList *LruList[K, V] items map[K]*Entry[K, V] onEvict EvictCallback[K, V] // expirable options mu sync.RWMutex ttl time.Duration done chan struct{} // buckets for expiration buckets []bucket[K, V] // uint8 because it's number between 0 and numBuckets nextCleanupBucket uint8 } // bucket is a container for holding entries to be expired type bucket[K comparable, V any] struct { entries map[K]*Entry[K, V] newestEntry time.Time } // noEvictionTTL - very long ttl to prevent eviction const noEvictionTTL = time.Hour * 24 * 365 * 10 // because of uint8 usage for nextCleanupBucket, should not exceed 256. // casting it as uint8 explicitly requires type conversions in multiple places const numBuckets = 100 // NewLRU returns a new thread-safe cache with expirable entries. // // Size parameter set to 0 makes cache of unlimited size, e.g. turns LRU mechanism off. // // Providing 0 TTL turns expiring off. // // Delete expired entries every 1/100th of ttl value. Goroutine which deletes expired entries runs indefinitely. func NewLRU[K comparable, V any](size int, onEvict EvictCallback[K, V], ttl time.Duration) *LRU[K, V] { if size < 0 { size = 0 } if ttl <= 0 { ttl = noEvictionTTL } res := LRU[K, V]{ ttl: ttl, size: size, evictList: NewList[K, V](), items: make(map[K]*Entry[K, V]), onEvict: onEvict, done: make(chan struct{}), } // initialize the buckets res.buckets = make([]bucket[K, V], numBuckets) for i := 0; i < numBuckets; i++ { res.buckets[i] = bucket[K, V]{entries: make(map[K]*Entry[K, V])} } // enable deleteExpired() running in separate goroutine for cache with non-zero TTL // // Important: done channel is never closed, so deleteExpired() goroutine will never exit, // it's decided to add functionality to close it in the version later than v2. 
if res.ttl != noEvictionTTL { go func(done <-chan struct{}) { ticker := time.NewTicker(res.ttl / numBuckets) defer ticker.Stop() for { select { case <-done: return case <-ticker.C: res.deleteExpired() } } }(res.done) } return &res } // Purge clears the cache completely. // onEvict is called for each evicted key. func (c *LRU[K, V]) Purge() { c.mu.Lock() defer c.mu.Unlock() for k, v := range c.items { if c.onEvict != nil { c.onEvict(k, v.Value) } delete(c.items, k) } for _, b := range c.buckets { for _, ent := range b.entries { delete(b.entries, ent.Key) } } c.evictList.Init() } // Add adds a value to the cache. Returns true if an eviction occurred. // Returns false if there was no eviction: the item was already in the cache, // or the size was not exceeded. func (c *LRU[K, V]) Add(key K, value V) (evicted bool) { c.mu.Lock() defer c.mu.Unlock() now := time.Now() // Check for existing item if ent, ok := c.items[key]; ok { c.evictList.MoveToFront(ent) c.removeFromBucket(ent) // remove the entry from its current bucket as expiresAt is renewed ent.Value = value ent.ExpiresAt = now.Add(c.ttl) c.addToBucket(ent) return false } // Add new item ent := c.evictList.PushFrontExpirable(key, value, now.Add(c.ttl)) c.items[key] = ent c.addToBucket(ent) // adds the entry to the appropriate bucket and sets entry.expireBucket evict := c.size > 0 && c.evictList.Length() > c.size // Verify size not exceeded if evict { c.removeOldest() } return evict } // Get looks up a key's value from the cache. func (c *LRU[K, V]) Get(key K) (value V, ok bool) { c.mu.Lock() defer c.mu.Unlock() var ent *Entry[K, V] if ent, ok = c.items[key]; ok { // Expired item check if time.Now().After(ent.ExpiresAt) { return value, false } c.evictList.MoveToFront(ent) return ent.Value, true } return } // Contains checks if a key is in the cache, without updating the recent-ness // or deleting it for being stale. 
func (c *LRU[K, V]) Contains(key K) (ok bool) { c.mu.RLock() defer c.mu.RUnlock() _, ok = c.items[key] return ok } // Peek returns the key value (or undefined if not found) without updating // the "recently used"-ness of the key. func (c *LRU[K, V]) Peek(key K) (value V, ok bool) { c.mu.RLock() defer c.mu.RUnlock() var ent *Entry[K, V] if ent, ok = c.items[key]; ok { // Expired item check if time.Now().After(ent.ExpiresAt) { return value, false } return ent.Value, true } return } // Remove removes the provided key from the cache, returning if the // key was contained. func (c *LRU[K, V]) Remove(key K) bool { c.mu.Lock() defer c.mu.Unlock() if ent, ok := c.items[key]; ok { c.removeElement(ent) return true } return false } // RemoveOldest removes the oldest item from the cache. func (c *LRU[K, V]) RemoveOldest() (key K, value V, ok bool) { c.mu.Lock() defer c.mu.Unlock() if ent := c.evictList.Back(); ent != nil { c.removeElement(ent) return ent.Key, ent.Value, true } return } // GetOldest returns the oldest entry func (c *LRU[K, V]) GetOldest() (key K, value V, ok bool) { c.mu.RLock() defer c.mu.RUnlock() if ent := c.evictList.Back(); ent != nil { return ent.Key, ent.Value, true } return } func (c *LRU[K, V]) KeyValues() map[K]V { c.mu.RLock() defer c.mu.RUnlock() maps := make(map[K]V) now := time.Now() for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() { if now.After(ent.ExpiresAt) { continue } maps[ent.Key] = ent.Value // keys = append(keys, ent.Key) } return maps } // Keys returns a slice of the keys in the cache, from oldest to newest. // Expired entries are filtered out. func (c *LRU[K, V]) Keys() []K { c.mu.RLock() defer c.mu.RUnlock() keys := make([]K, 0, len(c.items)) now := time.Now() for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() { if now.After(ent.ExpiresAt) { continue } keys = append(keys, ent.Key) } return keys } // Values returns a slice of the values in the cache, from oldest to newest. 
// Expired entries are filtered out. func (c *LRU[K, V]) Values() []V { c.mu.RLock() defer c.mu.RUnlock() values := make([]V, 0, len(c.items)) now := time.Now() for ent := c.evictList.Back(); ent != nil; ent = ent.PrevEntry() { if now.After(ent.ExpiresAt) { continue } values = append(values, ent.Value) } return values } // Len returns the number of items in the cache. func (c *LRU[K, V]) Len() int { c.mu.RLock() defer c.mu.RUnlock() return c.evictList.Length() } // Resize changes the cache size. Size of 0 means unlimited. func (c *LRU[K, V]) Resize(size int) (evicted int) { c.mu.Lock() defer c.mu.Unlock() if size <= 0 { c.size = 0 return 0 } diff := c.evictList.Length() - size if diff < 0 { diff = 0 } for i := 0; i < diff; i++ { c.removeOldest() } c.size = size return diff } // Close destroys cleanup goroutine. To clean up the cache, run Purge() before Close(). // func (c *LRU[K, V]) Close() { // c.mu.Lock() // defer c.mu.Unlock() // select { // case <-c.done: // return // default: // } // close(c.done) // } // removeOldest removes the oldest item from the cache. Has to be called with lock! func (c *LRU[K, V]) removeOldest() { if ent := c.evictList.Back(); ent != nil { c.removeElement(ent) } } // removeElement is used to remove a given list element from the cache. Has to be called with lock! func (c *LRU[K, V]) removeElement(e *Entry[K, V]) { c.evictList.Remove(e) delete(c.items, e.Key) c.removeFromBucket(e) if c.onEvict != nil { c.onEvict(e.Key, e.Value) } } // deleteExpired deletes expired records from the oldest bucket, waiting for the newest entry // in it to expire first. 
func (c *LRU[K, V]) deleteExpired() { c.mu.Lock() bucketIdx := c.nextCleanupBucket timeToExpire := time.Until(c.buckets[bucketIdx].newestEntry) // wait for newest entry to expire before cleanup without holding lock if timeToExpire > 0 { c.mu.Unlock() time.Sleep(timeToExpire) c.mu.Lock() } for _, ent := range c.buckets[bucketIdx].entries { c.removeElement(ent) } c.nextCleanupBucket = (c.nextCleanupBucket + 1) % numBuckets c.mu.Unlock() } // addToBucket adds entry to expire bucket so that it will be cleaned up when the time comes. Has to be called with lock! func (c *LRU[K, V]) addToBucket(e *Entry[K, V]) { bucketID := (numBuckets + c.nextCleanupBucket - 1) % numBuckets e.ExpireBucket = bucketID c.buckets[bucketID].entries[e.Key] = e if c.buckets[bucketID].newestEntry.Before(e.ExpiresAt) { c.buckets[bucketID].newestEntry = e.ExpiresAt } } // removeFromBucket removes the entry from its corresponding bucket. Has to be called with lock! func (c *LRU[K, V]) removeFromBucket(e *Entry[K, V]) { delete(c.buckets[e.ExpireBucket].entries, e.Key) } // Cap returns the capacity of the cache func (c *LRU[K, V]) Cap() int { return c.size } // Entry is an LRU Entry type Entry[K comparable, V any] struct { // Next and previous pointers in the doubly-linked list of elements. // To simplify the implementation, internally a list l is implemented // as a ring, such that &l.root is both the next element of the last // list element (l.Back()) and the previous element of the first list // element (l.Front()). next, prev *Entry[K, V] // The list to which this element belongs. list *LruList[K, V] // The LRU Key of this element. Key K // The Value stored with this element. Value V // The time this element would be cleaned up, optional ExpiresAt time.Time // The expiry bucket item was put in, optional ExpireBucket uint8 } // PrevEntry returns the previous list element or nil. 
func (e *Entry[K, V]) PrevEntry() *Entry[K, V] { if p := e.prev; e.list != nil && p != &e.list.root { return p } return nil } // LruList represents a doubly linked list. // The zero Value for LruList is an empty list ready to use. type LruList[K comparable, V any] struct { root Entry[K, V] // sentinel list element, only &root, root.prev, and root.next are used len int // current list Length excluding (this) sentinel element } // Init initializes or clears list l. func (l *LruList[K, V]) Init() *LruList[K, V] { l.root.next = &l.root l.root.prev = &l.root l.len = 0 return l } // NewList returns an initialized list. func NewList[K comparable, V any]() *LruList[K, V] { return new(LruList[K, V]).Init() } // Length returns the number of elements of list l. // The complexity is O(1). func (l *LruList[K, V]) Length() int { return l.len } // Back returns the last element of list l or nil if the list is empty. func (l *LruList[K, V]) Back() *Entry[K, V] { if l.len == 0 { return nil } return l.root.prev } // lazyInit lazily initializes a zero List Value. func (l *LruList[K, V]) lazyInit() { if l.root.next == nil { l.Init() } } // insert inserts e after at, increments l.len, and returns e. func (l *LruList[K, V]) insert(e, at *Entry[K, V]) *Entry[K, V] { e.prev = at e.next = at.next e.prev.next = e e.next.prev = e e.list = l l.len++ return e } // insertValue is a convenience wrapper for insert(&Entry{Value: v, ExpiresAt: ExpiresAt}, at). func (l *LruList[K, V]) insertValue(k K, v V, expiresAt time.Time, at *Entry[K, V]) *Entry[K, V] { return l.insert(&Entry[K, V]{Value: v, Key: k, ExpiresAt: expiresAt}, at) } // Remove removes e from its list, decrements l.len func (l *LruList[K, V]) Remove(e *Entry[K, V]) V { e.prev.next = e.next e.next.prev = e.prev e.next = nil // avoid memory leaks e.prev = nil // avoid memory leaks e.list = nil l.len-- return e.Value } // move moves e to next to at. 
func (l *LruList[K, V]) move(e, at *Entry[K, V]) { if e == at { return } e.prev.next = e.next e.next.prev = e.prev e.prev = at e.next = at.next e.prev.next = e e.next.prev = e } // PushFront inserts a new element e with value v at the front of list l and returns e. func (l *LruList[K, V]) PushFront(k K, v V) *Entry[K, V] { l.lazyInit() return l.insertValue(k, v, time.Time{}, &l.root) } // PushFrontExpirable inserts a new expirable element e with Value v at the front of list l and returns e. func (l *LruList[K, V]) PushFrontExpirable(k K, v V, expiresAt time.Time) *Entry[K, V] { l.lazyInit() return l.insertValue(k, v, expiresAt, &l.root) } // MoveToFront moves element e to the front of list l. // If e is not an element of l, the list is not modified. // The element must not be nil. func (l *LruList[K, V]) MoveToFront(e *Entry[K, V]) { if e.list != l || l.root.next == e { return } // see comment in List.Remove about initialization of l l.move(e, &l.root) } ================================================ FILE: internal/stmt_store/stmt_store.go ================================================ package stmt_store import ( "context" "database/sql" "math" "sync" "time" "gorm.io/gorm/internal/lru" ) type Stmt struct { *sql.Stmt Transaction bool prepared chan struct{} prepareErr error } func (stmt *Stmt) Error() error { return stmt.prepareErr } func (stmt *Stmt) Close() error { <-stmt.prepared if stmt.Stmt != nil { return stmt.Stmt.Close() } return nil } // Store defines an interface for managing the caching operations of SQL statements (Stmt). // This interface provides methods for creating new statements, retrieving all cache keys, // getting cached statements, setting cached statements, and deleting cached statements. type Store interface { // New creates a new Stmt object and caches it. // Parameters: // ctx: The context for the request, which can carry deadlines, cancellation signals, etc. 
// key: The key representing the SQL query, used for caching and preparing the statement. // isTransaction: Indicates whether this operation is part of a transaction, which may affect the caching strategy. // connPool: A connection pool that provides database connections. // locker: A synchronization lock that is unlocked after initialization to avoid deadlocks. // Returns: // *Stmt: A newly created statement object for executing SQL operations. // error: An error if the statement preparation fails. New(ctx context.Context, key string, isTransaction bool, connPool ConnPool, locker sync.Locker) (*Stmt, error) // Keys returns a slice of all cache keys in the store. Keys() []string // Get retrieves a Stmt object from the store based on the given key. // Parameters: // key: The key used to look up the Stmt object. // Returns: // *Stmt: The found Stmt object, or nil if not found. // bool: Indicates whether the corresponding Stmt object was successfully found. Get(key string) (*Stmt, bool) // Set stores the given Stmt object in the store and associates it with the specified key. // Parameters: // key: The key used to associate the Stmt object. // value: The Stmt object to be stored. Set(key string, value *Stmt) // Delete removes the Stmt object corresponding to the specified key from the store. // Parameters: // key: The key associated with the Stmt object to be deleted. Delete(key string) } // defaultMaxSize defines the default maximum capacity of the cache. // Its value is the maximum value of the int64 type, which means that when the cache size is not specified, // the cache can theoretically store as many elements as possible. // (1 << 63) - 1 is the maximum value that an int64 type can represent. const ( defaultMaxSize = math.MaxInt // defaultTTL defines the default time-to-live (TTL) for each cache entry. // When the TTL for cache entries is not specified, each cache entry will expire after 24 hours. 
defaultTTL = time.Hour * 24 ) // New creates and returns a new Store instance. // // Parameters: // - size: The maximum capacity of the cache. If the provided size is less than or equal to 0, // it defaults to defaultMaxSize. // - ttl: The time-to-live duration for each cache entry. If the provided ttl is less than or equal to 0, // it defaults to defaultTTL. // // This function defines an onEvicted callback that is invoked when a cache entry is evicted. // The callback ensures that if the evicted value (v) is not nil, its Close method is called asynchronously // to release associated resources. // // Returns: // - A Store instance implemented by lruStore, which internally uses an LRU cache with the specified size, // eviction callback, and TTL. func New(size int, ttl time.Duration) Store { if size <= 0 { size = defaultMaxSize } if ttl <= 0 { ttl = defaultTTL } onEvicted := func(k string, v *Stmt) { if v != nil { go v.Close() } } return &lruStore{lru: lru.NewLRU[string, *Stmt](size, onEvicted, ttl)} } type lruStore struct { lru *lru.LRU[string, *Stmt] } func (s *lruStore) Keys() []string { return s.lru.Keys() } func (s *lruStore) Get(key string) (*Stmt, bool) { stmt, ok := s.lru.Get(key) if ok && stmt != nil { <-stmt.prepared } return stmt, ok } func (s *lruStore) Set(key string, value *Stmt) { s.lru.Add(key, value) } func (s *lruStore) Delete(key string) { s.lru.Remove(key) } type ConnPool interface { PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) } // New creates a new Stmt object for executing SQL queries. // It caches the Stmt object for future use and handles preparation and error states. // Parameters: // // ctx: Context for the request, used to carry deadlines, cancellation signals, etc. // key: The key representing the SQL query, used for caching and preparing the statement. // isTransaction: Indicates whether this operation is part of a transaction, affecting cache strategy. 
// conn: A connection pool that provides database connections. // locker: A synchronization lock that is unlocked after initialization to avoid deadlocks. // // Returns: // // *Stmt: A newly created statement object for executing SQL operations. // error: An error if the statement preparation fails. func (s *lruStore) New(ctx context.Context, key string, isTransaction bool, conn ConnPool, locker sync.Locker) (_ *Stmt, err error) { // Create a Stmt object and set its Transaction property. // The prepared channel is used to synchronize the statement preparation state. cacheStmt := &Stmt{ Transaction: isTransaction, prepared: make(chan struct{}), } // Cache the Stmt object with the associated key. s.Set(key, cacheStmt) // Unlock after completing initialization to prevent deadlocks. locker.Unlock() // Ensure the prepared channel is closed after the function execution completes. defer close(cacheStmt.prepared) // Prepare the SQL statement using the provided connection. cacheStmt.Stmt, err = conn.PrepareContext(ctx, key) if err != nil { // If statement preparation fails, record the error and remove the invalid Stmt object from the cache. cacheStmt.prepareErr = err s.Delete(key) return &Stmt{}, err } // Return the successfully prepared Stmt object. 
return cacheStmt, nil } ================================================ FILE: logger/logger.go ================================================ package logger import ( "context" "errors" "fmt" "io" "log" "os" "time" "gorm.io/gorm/utils" ) // ErrRecordNotFound record not found error var ErrRecordNotFound = errors.New("record not found") // Colors const ( Reset = "\033[0m" Red = "\033[31m" Green = "\033[32m" Yellow = "\033[33m" Blue = "\033[34m" Magenta = "\033[35m" Cyan = "\033[36m" White = "\033[37m" BlueBold = "\033[34;1m" MagentaBold = "\033[35;1m" RedBold = "\033[31;1m" YellowBold = "\033[33;1m" ) // LogLevel log level type LogLevel int const ( // Silent silent log level Silent LogLevel = iota + 1 // Error error log level Error // Warn warn log level Warn // Info info log level Info ) // Writer log writer interface type Writer interface { Printf(string, ...interface{}) } // Config logger config type Config struct { SlowThreshold time.Duration Colorful bool IgnoreRecordNotFoundError bool ParameterizedQueries bool LogLevel LogLevel } // Interface logger interface type Interface interface { LogMode(LogLevel) Interface Info(context.Context, string, ...interface{}) Warn(context.Context, string, ...interface{}) Error(context.Context, string, ...interface{}) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) } var ( // Discard logger will print any log to io.Discard Discard = New(log.New(io.Discard, "", log.LstdFlags), Config{}) // Default Default logger Default = New(log.New(os.Stdout, "\r\n", log.LstdFlags), Config{ SlowThreshold: 200 * time.Millisecond, LogLevel: Warn, IgnoreRecordNotFoundError: false, Colorful: true, }) // Recorder logger records running SQL into a recorder instance Recorder = traceRecorder{Interface: Default, BeginAt: time.Now()} // RecorderParamsFilter defaults to no-op, allows to be run-over by a different implementation RecorderParamsFilter = func(ctx context.Context, sql string, params 
...interface{}) (string, []interface{}) { return sql, params } ) // New initialize logger func New(writer Writer, config Config) Interface { var ( infoStr = "%s\n[info] " warnStr = "%s\n[warn] " errStr = "%s\n[error] " traceStr = "%s\n[%.3fms] [rows:%v] %s" traceWarnStr = "%s %s\n[%.3fms] [rows:%v] %s" traceErrStr = "%s %s\n[%.3fms] [rows:%v] %s" ) if config.Colorful { infoStr = Green + "%s\n" + Reset + Green + "[info] " + Reset warnStr = BlueBold + "%s\n" + Reset + Magenta + "[warn] " + Reset errStr = Magenta + "%s\n" + Reset + Red + "[error] " + Reset traceStr = Green + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" traceWarnStr = Green + "%s " + Yellow + "%s\n" + Reset + RedBold + "[%.3fms] " + Yellow + "[rows:%v]" + Magenta + " %s" + Reset traceErrStr = RedBold + "%s " + MagentaBold + "%s\n" + Reset + Yellow + "[%.3fms] " + BlueBold + "[rows:%v]" + Reset + " %s" } return &logger{ Writer: writer, Config: config, infoStr: infoStr, warnStr: warnStr, errStr: errStr, traceStr: traceStr, traceWarnStr: traceWarnStr, traceErrStr: traceErrStr, } } type logger struct { Writer Config infoStr, warnStr, errStr string traceStr, traceErrStr, traceWarnStr string } // LogMode log mode func (l *logger) LogMode(level LogLevel) Interface { newlogger := *l newlogger.LogLevel = level return &newlogger } // Info print info func (l *logger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { l.Printf(l.infoStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Warn print warn messages func (l *logger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { l.Printf(l.warnStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) } } // Error print error messages func (l *logger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { l.Printf(l.errStr+msg, append([]interface{}{utils.FileWithLineNum()}, data...)...) 
} } // Trace print sql message // //nolint:cyclop func (l *logger) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { if l.LogLevel <= Silent { return } elapsed := time.Since(begin) switch { case err != nil && l.LogLevel >= Error && (!errors.Is(err, ErrRecordNotFound) || !l.IgnoreRecordNotFoundError): sql, rows := fc() if rows == -1 { l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, "-", sql) } else { l.Printf(l.traceErrStr, utils.FileWithLineNum(), err, float64(elapsed.Nanoseconds())/1e6, rows, sql) } case elapsed > l.SlowThreshold && l.SlowThreshold != 0 && l.LogLevel >= Warn: sql, rows := fc() slowLog := fmt.Sprintf("SLOW SQL >= %v", l.SlowThreshold) if rows == -1 { l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, "-", sql) } else { l.Printf(l.traceWarnStr, utils.FileWithLineNum(), slowLog, float64(elapsed.Nanoseconds())/1e6, rows, sql) } case l.LogLevel == Info: sql, rows := fc() if rows == -1 { l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, "-", sql) } else { l.Printf(l.traceStr, utils.FileWithLineNum(), float64(elapsed.Nanoseconds())/1e6, rows, sql) } } } // ParamsFilter filter params func (l *logger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { if l.Config.ParameterizedQueries { return sql, nil } return sql, params } type traceRecorder struct { Interface BeginAt time.Time SQL string RowsAffected int64 Err error } // New trace recorder func (l *traceRecorder) New() *traceRecorder { return &traceRecorder{Interface: l.Interface, BeginAt: time.Now()} } // Trace implement logger interface func (l *traceRecorder) Trace(ctx context.Context, begin time.Time, fc func() (string, int64), err error) { l.BeginAt = begin l.SQL, l.RowsAffected = fc() l.Err = err } func (l *traceRecorder) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, 
[]interface{}) { if RecorderParamsFilter == nil { return sql, params } return RecorderParamsFilter(ctx, sql, params...) } ================================================ FILE: logger/slog.go ================================================ //go:build go1.21 package logger import ( "context" "errors" "fmt" "log/slog" "time" "gorm.io/gorm/utils" ) type slogLogger struct { Logger *slog.Logger LogLevel LogLevel SlowThreshold time.Duration Parameterized bool Colorful bool // Ignored in slog IgnoreRecordNotFoundError bool } func NewSlogLogger(logger *slog.Logger, config Config) Interface { return &slogLogger{ Logger: logger, LogLevel: config.LogLevel, SlowThreshold: config.SlowThreshold, Parameterized: config.ParameterizedQueries, IgnoreRecordNotFoundError: config.IgnoreRecordNotFoundError, } } func (l *slogLogger) LogMode(level LogLevel) Interface { newLogger := *l newLogger.LogLevel = level return &newLogger } func (l *slogLogger) Info(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Info { l.log(ctx, slog.LevelInfo, msg, slog.Any("data", data)) } } func (l *slogLogger) Warn(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Warn { l.log(ctx, slog.LevelWarn, msg, slog.Any("data", data)) } } func (l *slogLogger) Error(ctx context.Context, msg string, data ...interface{}) { if l.LogLevel >= Error { l.log(ctx, slog.LevelError, msg, slog.Any("data", data)) } } func (l *slogLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { if l.LogLevel <= Silent { return } elapsed := time.Since(begin) sql, rows := fc() fields := []slog.Attr{ slog.String("duration", fmt.Sprintf("%.3fms", float64(elapsed.Nanoseconds())/1e6)), slog.String("sql", sql), } if rows != -1 { fields = append(fields, slog.Int64("rows", rows)) } switch { case err != nil && (!l.IgnoreRecordNotFoundError || !errors.Is(err, ErrRecordNotFound)): fields = append(fields, slog.String("error", err.Error())) l.log(ctx, 
slog.LevelError, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) case l.SlowThreshold != 0 && elapsed > l.SlowThreshold: l.log(ctx, slog.LevelWarn, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) case l.LogLevel >= Info: l.log(ctx, slog.LevelInfo, "SQL executed", slog.Attr{ Key: "trace", Value: slog.GroupValue(fields...), }) } } func (l *slogLogger) log(ctx context.Context, level slog.Level, msg string, args ...any) { if ctx == nil { ctx = context.Background() } if !l.Logger.Enabled(ctx, level) { return } r := slog.NewRecord(time.Now(), level, msg, utils.CallerFrame().PC) r.Add(args...) _ = l.Logger.Handler().Handle(ctx, r) } // ParamsFilter filter params func (l *slogLogger) ParamsFilter(ctx context.Context, sql string, params ...interface{}) (string, []interface{}) { if l.Parameterized { return sql, nil } return sql, params } ================================================ FILE: logger/slog_test.go ================================================ //go:build go1.21 package logger import ( "bytes" "context" "log/slog" "strings" "testing" "time" ) func TestSlogLogger(t *testing.T) { buf := &bytes.Buffer{} handler := slog.NewTextHandler(buf, &slog.HandlerOptions{AddSource: true}) logger := NewSlogLogger(slog.New(handler), Config{LogLevel: Info}) logger.Trace(context.Background(), time.Now(), func() (string, int64) { return "select count(*) from users", 0 }, nil) if strings.Contains(buf.String(), "gorm/logger/slog.go") { t.Error("Found internal slog.go reference in caller frame. Expected only test file references.") } if !strings.Contains(buf.String(), "gorm/logger/slog_test.go") { t.Error("Missing expected test file reference. 
'gorm/logger/slog_test.go' should appear in caller frames.") } } ================================================ FILE: logger/sql.go ================================================ package logger import ( "database/sql/driver" "fmt" "reflect" "regexp" "strconv" "strings" "time" "unicode" "gorm.io/gorm/utils" ) const ( tmFmtWithMS = "2006-01-02 15:04:05.999" tmFmtZero = "0000-00-00 00:00:00" nullStr = "NULL" ) func isPrintable(s string) bool { for _, r := range s { if !unicode.IsPrint(r) { return false } } return true } // A list of Go types that should be converted to SQL primitives var convertibleTypes = []reflect.Type{reflect.TypeOf(time.Time{}), reflect.TypeOf(false), reflect.TypeOf([]byte{})} // RegEx matches only numeric values var numericPlaceholderRe = regexp.MustCompile(`\$\d+\$`) func isNumeric(k reflect.Kind) bool { switch k { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: return true case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: return true case reflect.Float32, reflect.Float64: return true default: return false } } // ExplainSQL generate SQL string with given parameters, the generated SQL is expected to be used in logger, execute it might introduce a SQL injection vulnerability func ExplainSQL(sql string, numericPlaceholder *regexp.Regexp, escaper string, avars ...interface{}) string { var ( convertParams func(interface{}, int) vars = make([]string, len(avars)) ) convertParams = func(v interface{}, idx int) { switch v := v.(type) { case bool: vars[idx] = strconv.FormatBool(v) case time.Time: if v.IsZero() { vars[idx] = escaper + tmFmtZero + escaper } else { vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper } case *time.Time: if v != nil { if v.IsZero() { vars[idx] = escaper + tmFmtZero + escaper } else { vars[idx] = escaper + v.Format(tmFmtWithMS) + escaper } } else { vars[idx] = nullStr } case driver.Valuer: reflectValue := reflect.ValueOf(v) if v != nil && reflectValue.IsValid() 
&& ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { r, _ := v.Value() convertParams(r, idx) } else { vars[idx] = nullStr } case fmt.Stringer: reflectValue := reflect.ValueOf(v) switch reflectValue.Kind() { case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: vars[idx] = fmt.Sprintf("%d", reflectValue.Interface()) case reflect.Float32, reflect.Float64: vars[idx] = fmt.Sprintf("%.6f", reflectValue.Interface()) case reflect.Bool: vars[idx] = fmt.Sprintf("%t", reflectValue.Interface()) case reflect.String: vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper default: if v != nil && reflectValue.IsValid() && ((reflectValue.Kind() == reflect.Ptr && !reflectValue.IsNil()) || reflectValue.Kind() != reflect.Ptr) { vars[idx] = escaper + strings.ReplaceAll(fmt.Sprintf("%v", v), escaper, escaper+escaper) + escaper } else { vars[idx] = nullStr } } case []byte: if s := string(v); isPrintable(s) { vars[idx] = escaper + strings.ReplaceAll(s, escaper, escaper+escaper) + escaper } else { vars[idx] = escaper + "" + escaper } case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64: vars[idx] = utils.ToString(v) case float32: vars[idx] = strconv.FormatFloat(float64(v), 'f', -1, 32) case float64: vars[idx] = strconv.FormatFloat(v, 'f', -1, 64) case string: vars[idx] = escaper + strings.ReplaceAll(v, escaper, escaper+escaper) + escaper default: rv := reflect.ValueOf(v) if v == nil || !rv.IsValid() || rv.Kind() == reflect.Ptr && rv.IsNil() { vars[idx] = nullStr } else if valuer, ok := v.(driver.Valuer); ok { v, _ = valuer.Value() convertParams(v, idx) } else if rv.Kind() == reflect.Ptr && !rv.IsZero() { convertParams(reflect.Indirect(rv).Interface(), idx) } else if isNumeric(rv.Kind()) { if rv.CanInt() || rv.CanUint() { vars[idx] = fmt.Sprintf("%d", rv.Interface()) } else { 
vars[idx] = fmt.Sprintf("%.6f", rv.Interface()) } } else { for _, t := range convertibleTypes { if rv.Type().ConvertibleTo(t) { convertParams(rv.Convert(t).Interface(), idx) return } } vars[idx] = escaper + strings.ReplaceAll(fmt.Sprint(v), escaper, escaper+escaper) + escaper } } } for idx, v := range avars { convertParams(v, idx) } if numericPlaceholder == nil { var idx int var newSQL strings.Builder for _, v := range []byte(sql) { if v == '?' { if len(vars) > idx { newSQL.WriteString(vars[idx]) idx++ continue } } newSQL.WriteByte(v) } sql = newSQL.String() } else { sql = numericPlaceholder.ReplaceAllString(sql, "$$$1$$") sql = numericPlaceholderRe.ReplaceAllStringFunc(sql, func(v string) string { num := v[1 : len(v)-1] n, _ := strconv.Atoi(num) // position var start from 1 ($1, $2) n -= 1 if n >= 0 && n <= len(vars)-1 { return vars[n] } return v }) } return sql } ================================================ FILE: logger/sql_test.go ================================================ package logger_test import ( "database/sql/driver" "encoding/json" "fmt" "regexp" "strings" "testing" "github.com/jinzhu/now" "gorm.io/gorm/logger" ) type JSON json.RawMessage func (j JSON) Value() (driver.Value, error) { if len(j) == 0 { return nil, nil } return json.RawMessage(j).MarshalJSON() } type ExampleStruct struct { Name string Val string } func (s ExampleStruct) Value() (driver.Value, error) { return json.Marshal(s) } func format(v []byte, escaper string) string { return escaper + strings.ReplaceAll(string(v), escaper, escaper+escaper) + escaper } func TestExplainSQL(t *testing.T) { type role string type password []byte type intType int type floatType float64 var ( tt = now.MustParse("2020-02-23 11:10:10") myrole = role("admin") pwd = password("pass") jsVal = []byte(`{"Name":"test","Val":"test"}`) js = JSON(jsVal) esVal = []byte(`{"Name":"test","Val":"test"}`) es = ExampleStruct{Name: "test", Val: "test"} intVal intType = 1 floatVal floatType = 1.23 ) results := []struct { 
SQL string NumericRegexp *regexp.Regexp Vars []interface{} Result string }{ { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10, @p11)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ($3, $4, $1, $2, $7, $8, $5, $6, $9, $10, $11)", NumericRegexp: regexp.MustCompile(`\$(\d+)`), Vars: []interface{}{999.99, true, "jinzhu", 1, &tt, nil, 
[]byte("12345"), tt, "w@g.com", myrole, pwd}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values (@p1, @p11, @p2, @p3, @p4, @p5, @p6, @p7, @p8, @p9, @p10)", NumericRegexp: regexp.MustCompile(`@p(\d+)`), Vars: []interface{}{"jinzhu", 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.com", myrole, pwd, 1}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.com", "admin", "pass")`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, js, es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, 
json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, 0.1753607109, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 0.1753607109, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu", 1, float32(999.99), true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, &js, &es}, Result: fmt.Sprintf(`create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, json_struct, example_struct) values ("jinzhu", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", %v, %v)`, format(jsVal, `"`), format(esVal, `"`)), }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, intVal}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, int_val) 
values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1)`, }, { SQL: "create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", NumericRegexp: nil, Vars: []interface{}{"jinzhu?", 1, 999.99, true, []byte("12345"), tt, &tt, nil, "w@g.\"com", myrole, pwd, floatVal}, Result: `create table users (name, age, height, actived, bytes, create_at, update_at, deleted_at, email, role, pass, float_val) values ("jinzhu?", 1, 999.99, true, "12345", "2020-02-23 11:10:10", "2020-02-23 11:10:10", NULL, "w@g.""com", "admin", "pass", 1.230000)`, }, } for idx, r := range results { if result := logger.ExplainSQL(r.SQL, r.NumericRegexp, `"`, r.Vars...); result != r.Result { t.Errorf("Explain SQL #%v expects %v, but got %v", idx, r.Result, result) } } } ================================================ FILE: migrator/column_type.go ================================================ package migrator import ( "database/sql" "reflect" ) // ColumnType column type implements ColumnType interface type ColumnType struct { SQLColumnType *sql.ColumnType NameValue sql.NullString DataTypeValue sql.NullString ColumnTypeValue sql.NullString PrimaryKeyValue sql.NullBool UniqueValue sql.NullBool AutoIncrementValue sql.NullBool LengthValue sql.NullInt64 DecimalSizeValue sql.NullInt64 ScaleValue sql.NullInt64 NullableValue sql.NullBool ScanTypeValue reflect.Type CommentValue sql.NullString DefaultValueValue sql.NullString } // Name returns the name or alias of the column. func (ct ColumnType) Name() string { if ct.NameValue.Valid { return ct.NameValue.String } return ct.SQLColumnType.Name() } // DatabaseTypeName returns the database system name of the column type. If an empty // string is returned, then the driver type name is not supported. // Consult your driver documentation for a list of driver data types. 
Length specifiers // are not included. // Common type names include "VARCHAR", "TEXT", "NVARCHAR", "DECIMAL", "BOOL", // "INT", and "BIGINT". func (ct ColumnType) DatabaseTypeName() string { if ct.DataTypeValue.Valid { return ct.DataTypeValue.String } return ct.SQLColumnType.DatabaseTypeName() } // ColumnType returns the database type of the column. like `varchar(16)` func (ct ColumnType) ColumnType() (columnType string, ok bool) { return ct.ColumnTypeValue.String, ct.ColumnTypeValue.Valid } // PrimaryKey returns the column is primary key or not. func (ct ColumnType) PrimaryKey() (isPrimaryKey bool, ok bool) { return ct.PrimaryKeyValue.Bool, ct.PrimaryKeyValue.Valid } // AutoIncrement returns the column is auto increment or not. func (ct ColumnType) AutoIncrement() (isAutoIncrement bool, ok bool) { return ct.AutoIncrementValue.Bool, ct.AutoIncrementValue.Valid } // Length returns the column type length for variable length column types func (ct ColumnType) Length() (length int64, ok bool) { if ct.LengthValue.Valid { return ct.LengthValue.Int64, true } return ct.SQLColumnType.Length() } // DecimalSize returns the scale and precision of a decimal type. func (ct ColumnType) DecimalSize() (precision int64, scale int64, ok bool) { if ct.DecimalSizeValue.Valid { return ct.DecimalSizeValue.Int64, ct.ScaleValue.Int64, true } return ct.SQLColumnType.DecimalSize() } // Nullable reports whether the column may be null. func (ct ColumnType) Nullable() (nullable bool, ok bool) { if ct.NullableValue.Valid { return ct.NullableValue.Bool, true } return ct.SQLColumnType.Nullable() } // Unique reports whether the column may be unique. func (ct ColumnType) Unique() (unique bool, ok bool) { return ct.UniqueValue.Bool, ct.UniqueValue.Valid } // ScanType returns a Go type suitable for scanning into using Rows.Scan. 
func (ct ColumnType) ScanType() reflect.Type { if ct.ScanTypeValue != nil { return ct.ScanTypeValue } return ct.SQLColumnType.ScanType() } // Comment returns the comment of current column. func (ct ColumnType) Comment() (value string, ok bool) { return ct.CommentValue.String, ct.CommentValue.Valid } // DefaultValue returns the default value of current column. func (ct ColumnType) DefaultValue() (value string, ok bool) { return ct.DefaultValueValue.String, ct.DefaultValueValue.Valid } ================================================ FILE: migrator/index.go ================================================ package migrator import "database/sql" // Index implements gorm.Index interface type Index struct { TableName string NameValue string ColumnList []string PrimaryKeyValue sql.NullBool UniqueValue sql.NullBool OptionValue string } // Table return the table name of the index. func (idx Index) Table() string { return idx.TableName } // Name return the name of the index. func (idx Index) Name() string { return idx.NameValue } // Columns return the columns of the index func (idx Index) Columns() []string { return idx.ColumnList } // PrimaryKey returns the index is primary key or not. func (idx Index) PrimaryKey() (isPrimaryKey bool, ok bool) { return idx.PrimaryKeyValue.Bool, idx.PrimaryKeyValue.Valid } // Unique returns whether the index is unique or not. 
func (idx Index) Unique() (unique bool, ok bool) { return idx.UniqueValue.Bool, idx.UniqueValue.Valid } // Option return the optional attribute of the index func (idx Index) Option() string { return idx.OptionValue } ================================================ FILE: migrator/migrator.go ================================================ package migrator import ( "context" "database/sql" "errors" "fmt" "reflect" "regexp" "strconv" "strings" "time" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) // This regular expression seeks to find a sequence of digits (\d+) among zero or more non-digit characters (\D*), // with a possible trailing non-digit character (\D?). // For example, values that can pass this regular expression are: // - "123" // - "abc456" // -"%$#@789" var regFullDataType = regexp.MustCompile(`\D*(\d+)\D?`) // TODO:? Create const vars for raw sql queries ? var _ gorm.Migrator = (*Migrator)(nil) // Migrator m struct type Migrator struct { Config } // Config schema config type Config struct { CreateIndexAfterCreateTable bool DB *gorm.DB gorm.Dialector } type printSQLLogger struct { logger.Interface } func (l *printSQLLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { sql, _ := fc() fmt.Println(sql + ";") l.Interface.Trace(ctx, begin, fc, err) } // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDBDataType(*gorm.DB, *schema.Field) string } // RunWithValue run migration with statement value func (m Migrator) RunWithValue(value interface{}, fc func(*gorm.Statement) error) error { stmt := &gorm.Statement{DB: m.DB} if m.DB.Statement != nil { stmt.Table = m.DB.Statement.Table stmt.TableExpr = m.DB.Statement.TableExpr } if table, ok := value.(string); ok { stmt.Table = table } else if err := stmt.ParseWithSpecialTableName(value, stmt.Table); err != nil { return err } return fc(stmt) } // DataTypeOf return field's db data type 
func (m Migrator) DataTypeOf(field *schema.Field) string { fieldValue := reflect.New(field.IndirectFieldType) if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok { if dataType := dataTyper.GormDBDataType(m.DB, field); dataType != "" { return dataType } } return m.Dialector.DataTypeOf(field) } // FullDataTypeOf returns field's db full data type func (m Migrator) FullDataTypeOf(field *schema.Field) (expr clause.Expr) { expr.SQL = m.DataTypeOf(field) if field.NotNull { expr.SQL += " NOT NULL" } if field.HasDefaultValue && (field.DefaultValueInterface != nil || field.DefaultValue != "") { if field.DefaultValueInterface != nil { defaultStmt := &gorm.Statement{Vars: []interface{}{field.DefaultValueInterface}} m.Dialector.BindVarTo(defaultStmt, defaultStmt, field.DefaultValueInterface) expr.SQL += " DEFAULT " + m.Dialector.Explain(defaultStmt.SQL.String(), field.DefaultValueInterface) } else if field.DefaultValue != "(-)" { expr.SQL += " DEFAULT " + field.DefaultValue } } return } func (m Migrator) GetQueryAndExecTx() (queryTx, execTx *gorm.DB) { queryTx = m.DB.Session(&gorm.Session{}) execTx = queryTx if m.DB.DryRun { queryTx.DryRun = false execTx = m.DB.Session(&gorm.Session{Logger: &printSQLLogger{Interface: m.DB.Logger}}) } return queryTx, execTx } // AutoMigrate auto migrate values func (m Migrator) AutoMigrate(values ...interface{}) error { for _, value := range m.ReorderModels(values, true) { queryTx, execTx := m.GetQueryAndExecTx() if !queryTx.Migrator().HasTable(value) { if err := execTx.Migrator().CreateTable(value); err != nil { return err } } else { if err := m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema == nil { return errors.New("failed to get schema") } columnTypes, err := queryTx.Migrator().ColumnTypes(value) if err != nil { return err } var ( parseIndexes = stmt.Schema.ParseIndexes() parseCheckConstraints = stmt.Schema.ParseCheckConstraints() ) for _, dbName := range stmt.Schema.DBNames { var foundColumn 
gorm.ColumnType for _, columnType := range columnTypes { if columnType.Name() == dbName { foundColumn = columnType break } } if foundColumn == nil { // not found, add column if err = execTx.Migrator().AddColumn(value, dbName); err != nil { return err } } else { // found, smartly migrate field := stmt.Schema.FieldsByDBName[dbName] if err = execTx.Migrator().MigrateColumn(value, field, foundColumn); err != nil { return err } } } if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { for _, rel := range stmt.Schema.Relationships.Relations { if rel.Field.IgnoreMigration { continue } if constraint := rel.ParseConstraint(); constraint != nil && constraint.Schema == stmt.Schema && !queryTx.Migrator().HasConstraint(value, constraint.Name) { if err := execTx.Migrator().CreateConstraint(value, constraint.Name); err != nil { return err } } } } for _, chk := range parseCheckConstraints { if !queryTx.Migrator().HasConstraint(value, chk.Name) { if err := execTx.Migrator().CreateConstraint(value, chk.Name); err != nil { return err } } } for _, idx := range parseIndexes { if !queryTx.Migrator().HasIndex(value, idx.Name) { if err := execTx.Migrator().CreateIndex(value, idx.Name); err != nil { return err } } } return nil }); err != nil { return err } } } return nil } // GetTables returns tables func (m Migrator) GetTables() (tableList []string, err error) { err = m.DB.Raw("SELECT TABLE_NAME FROM information_schema.tables where TABLE_SCHEMA=?", m.CurrentDatabase()). Scan(&tableList).Error return } // CreateTable create table in database for values func (m Migrator) CreateTable(values ...interface{}) error { for _, value := range m.ReorderModels(values, false) { tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { if stmt.Schema == nil { return errors.New("failed to get schema") } var ( createTableSQL = "CREATE TABLE ? 
(" values = []interface{}{m.CurrentTable(stmt)} hasPrimaryKeyInDataType bool ) for _, dbName := range stmt.Schema.DBNames { field := stmt.Schema.FieldsByDBName[dbName] if !field.IgnoreMigration { createTableSQL += "? ?" hasPrimaryKeyInDataType = hasPrimaryKeyInDataType || strings.Contains(strings.ToUpper(m.DataTypeOf(field)), "PRIMARY KEY") values = append(values, clause.Column{Name: dbName}, m.DB.Migrator().FullDataTypeOf(field)) createTableSQL += "," } } if !hasPrimaryKeyInDataType && len(stmt.Schema.PrimaryFields) > 0 { createTableSQL += "PRIMARY KEY ?," primaryKeys := make([]interface{}, 0, len(stmt.Schema.PrimaryFields)) for _, field := range stmt.Schema.PrimaryFields { primaryKeys = append(primaryKeys, clause.Column{Name: field.DBName}) } values = append(values, primaryKeys) } for _, idx := range stmt.Schema.ParseIndexes() { if m.CreateIndexAfterCreateTable { defer func(value interface{}, name string) { if err == nil { err = tx.Migrator().CreateIndex(value, name) } }(value, idx.Name) } else { if idx.Class != "" { createTableSQL += idx.Class + " " } createTableSQL += "INDEX ? ?" if idx.Comment != "" { createTableSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) } if idx.Option != "" { createTableSQL += " " + idx.Option } createTableSQL += "," values = append(values, clause.Column{Name: idx.Name}, tx.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt)) } } if !m.DB.DisableForeignKeyConstraintWhenMigrating && !m.DB.IgnoreRelationshipsWhenMigrating { for _, rel := range stmt.Schema.Relationships.Relations { if rel.Field.IgnoreMigration { continue } if constraint := rel.ParseConstraint(); constraint != nil { if constraint.Schema == stmt.Schema { sql, vars := constraint.Build() createTableSQL += sql + "," values = append(values, vars...) } } } } for _, uni := range stmt.Schema.ParseUniqueConstraints() { createTableSQL += "CONSTRAINT ? 
UNIQUE (?)," values = append(values, clause.Column{Name: uni.Name}, clause.Expr{SQL: stmt.Quote(uni.Field.DBName)}) } for _, chk := range stmt.Schema.ParseCheckConstraints() { createTableSQL += "CONSTRAINT ? CHECK (?)," values = append(values, clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}) } createTableSQL = strings.TrimSuffix(createTableSQL, ",") createTableSQL += ")" if tableOption, ok := m.DB.Get("gorm:table_options"); ok { createTableSQL += fmt.Sprint(tableOption) } err = tx.Exec(createTableSQL, values...).Error return err }); err != nil { return err } } return nil } // DropTable drop table for values func (m Migrator) DropTable(values ...interface{}) error { values = m.ReorderModels(values, false) for i := len(values) - 1; i >= 0; i-- { tx := m.DB.Session(&gorm.Session{}) if err := m.RunWithValue(values[i], func(stmt *gorm.Statement) error { return tx.Exec("DROP TABLE IF EXISTS ?", m.CurrentTable(stmt)).Error }); err != nil { return err } } return nil } // HasTable returns table exists or not for value, value could be a struct or string func (m Migrator) HasTable(value interface{}) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() return m.DB.Raw("SELECT count(*) FROM information_schema.tables WHERE table_schema = ? AND table_name = ? 
AND table_type = ?", currentDatabase, stmt.Table, "BASE TABLE").Row().Scan(&count) }) return count > 0 } // RenameTable rename table from oldName to newName func (m Migrator) RenameTable(oldName, newName interface{}) error { var oldTable, newTable interface{} if v, ok := oldName.(string); ok { oldTable = clause.Table{Name: v} } else { stmt := &gorm.Statement{DB: m.DB} if err := stmt.Parse(oldName); err == nil { oldTable = m.CurrentTable(stmt) } else { return err } } if v, ok := newName.(string); ok { newTable = clause.Table{Name: v} } else { stmt := &gorm.Statement{DB: m.DB} if err := stmt.Parse(newName); err == nil { newTable = m.CurrentTable(stmt) } else { return err } } return m.DB.Exec("ALTER TABLE ? RENAME TO ?", oldTable, newTable).Error } // AddColumn create `name` column for value func (m Migrator) AddColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { // avoid using the same name field if stmt.Schema == nil { return errors.New("failed to get schema") } f := stmt.Schema.LookUpField(name) if f == nil { return fmt.Errorf("failed to look up field with name: %s", name) } if !f.IgnoreMigration { return m.DB.Exec( "ALTER TABLE ? ADD ? ?", m.CurrentTable(stmt), clause.Column{Name: f.DBName}, m.DB.Migrator().FullDataTypeOf(f), ).Error } return nil }) } // DropColumn drop value's `name` column func (m Migrator) DropColumn(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { name = field.DBName } } return m.DB.Exec( "ALTER TABLE ? 
DROP COLUMN ?", m.CurrentTable(stmt), clause.Column{Name: name}, ).Error }) } // AlterColumn alter value's `field` column' type based on schema definition func (m Migrator) AlterColumn(value interface{}, field string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if field := stmt.Schema.LookUpField(field); field != nil { fileType := m.FullDataTypeOf(field) return m.DB.Exec( "ALTER TABLE ? ALTER COLUMN ? TYPE ?", m.CurrentTable(stmt), clause.Column{Name: field.DBName}, fileType, ).Error } } return fmt.Errorf("failed to look up field with name: %s", field) }) } // HasColumn check has column `field` for value or not func (m Migrator) HasColumn(value interface{}, field string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() name := field if stmt.Schema != nil { if field := stmt.Schema.LookUpField(field); field != nil { name = field.DBName } } return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.columns WHERE table_schema = ? AND table_name = ? AND column_name = ?", currentDatabase, stmt.Table, name, ).Row().Scan(&count) }) return count > 0 } // RenameColumn rename value's field name from oldName to newName func (m Migrator) RenameColumn(value interface{}, oldName, newName string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if field := stmt.Schema.LookUpField(oldName); field != nil { oldName = field.DBName } if field := stmt.Schema.LookUpField(newName); field != nil { newName = field.DBName } } return m.DB.Exec( "ALTER TABLE ? RENAME COLUMN ? 
TO ?", m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName}, ).Error }) } // MigrateColumn migrate column func (m Migrator) MigrateColumn(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { if field.IgnoreMigration { return nil } // found, smart migrate fullDataType := strings.TrimSpace(strings.ToLower(m.DB.Migrator().FullDataTypeOf(field).SQL)) realDataType := strings.ToLower(columnType.DatabaseTypeName()) var ( alterColumn bool isSameType = fullDataType == realDataType ) if !field.PrimaryKey { // check type if !strings.HasPrefix(fullDataType, realDataType) { // check type aliases aliases := m.DB.Migrator().GetTypeAliases(realDataType) for _, alias := range aliases { if strings.HasPrefix(fullDataType, alias) { isSameType = true break } } if !isSameType { alterColumn = true } } } if !isSameType { // check size if length, ok := columnType.Length(); length != int64(field.Size) { if length > 0 && field.Size > 0 { alterColumn = true } else { // has size in data type and not equal // Since the following code is frequently called in the for loop, reg optimization is needed here matches2 := regFullDataType.FindAllStringSubmatch(fullDataType, -1) if !field.PrimaryKey && (len(matches2) == 1 && matches2[0][1] != fmt.Sprint(length) && ok) { alterColumn = true } } } } // check precision if realDataType == "decimal" || realDataType == "numeric" && regexp.MustCompile(realDataType+`\(.*\)`).FindString(fullDataType) != "" { // if realDataType has no precision,ignore precision, scale, ok := columnType.DecimalSize() if ok { if !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d,%d)", realDataType, precision, scale)) && !strings.HasPrefix(fullDataType, fmt.Sprintf("%s(%d)", realDataType, precision)) { alterColumn = true } } } else { if precision, _, ok := columnType.DecimalSize(); ok && int64(field.Precision) != precision { if regexp.MustCompile(fmt.Sprintf("[^0-9]%d[^0-9]", field.Precision)).MatchString(m.DataTypeOf(field)) { 
alterColumn = true } } } // check nullable if nullable, ok := columnType.Nullable(); ok && nullable == field.NotNull { // not primary key & current database is non-nullable(to be nullable) if !field.PrimaryKey && !nullable { alterColumn = true } } // check default value if !field.PrimaryKey { currentDefaultNotNull := field.HasDefaultValue && (field.DefaultValueInterface != nil || !strings.EqualFold(field.DefaultValue, "NULL")) dv, dvNotNull := columnType.DefaultValue() if dvNotNull && !currentDefaultNotNull { // default value -> null alterColumn = true } else if !dvNotNull && currentDefaultNotNull { // null -> default value alterColumn = true } else if currentDefaultNotNull || dvNotNull { switch field.GORMDataType { case schema.Time: if !strings.EqualFold(strings.TrimSuffix(dv, "()"), strings.TrimSuffix(field.DefaultValue, "()")) { alterColumn = true } case schema.Bool: v1, _ := strconv.ParseBool(dv) v2, _ := strconv.ParseBool(field.DefaultValue) alterColumn = v1 != v2 case schema.String: if dv != field.DefaultValue && dv != strings.Trim(field.DefaultValue, "'\"") { alterColumn = true } default: alterColumn = dv != field.DefaultValue } } } // check comment if comment, ok := columnType.Comment(); ok && comment != field.Comment { // not primary key if !field.PrimaryKey { alterColumn = true } } if alterColumn { if err := m.DB.Migrator().AlterColumn(value, field.DBName); err != nil { return err } } if err := m.DB.Migrator().MigrateColumnUnique(value, field, columnType); err != nil { return err } return nil } func (m Migrator) MigrateColumnUnique(value interface{}, field *schema.Field, columnType gorm.ColumnType) error { unique, ok := columnType.Unique() if !ok || field.PrimaryKey { return nil // skip primary key } // By default, ColumnType's Unique is not affected by UniqueIndex, so we don't care about UniqueIndex. 
return m.RunWithValue(value, func(stmt *gorm.Statement) error { // We're currently only receiving boolean values on `Unique` tag, // so the UniqueConstraint name is fixed constraint := m.DB.NamingStrategy.UniqueName(stmt.Table, field.DBName) if unique && !field.Unique { return m.DB.Migrator().DropConstraint(value, constraint) } if !unique && field.Unique { return m.DB.Migrator().CreateConstraint(value, constraint) } return nil }) } // ColumnTypes return columnTypes []gorm.ColumnType and execErr error func (m Migrator) ColumnTypes(value interface{}) ([]gorm.ColumnType, error) { columnTypes := make([]gorm.ColumnType, 0) execErr := m.RunWithValue(value, func(stmt *gorm.Statement) (err error) { rows, err := m.DB.Session(&gorm.Session{}).Table(stmt.Table).Limit(1).Rows() if err != nil { return err } defer func() { err = rows.Close() }() var rawColumnTypes []*sql.ColumnType rawColumnTypes, err = rows.ColumnTypes() if err != nil { return err } for _, c := range rawColumnTypes { columnTypes = append(columnTypes, ColumnType{SQLColumnType: c}) } return }) return columnTypes, execErr } // CreateView create view from Query in gorm.ViewOption. 
// Query in gorm.ViewOption is a [subquery] // // // CREATE VIEW `user_view` AS SELECT * FROM `users` WHERE age > 20 // q := DB.Model(&User{}).Where("age > ?", 20) // DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q}) // // // CREATE OR REPLACE VIEW `users_view` AS SELECT * FROM `users` WITH CHECK OPTION // q := DB.Model(&User{}) // DB.Debug().Migrator().CreateView("user_view", gorm.ViewOption{Query: q, Replace: true, CheckOption: "WITH CHECK OPTION"}) // // [subquery]: https://gorm.io/docs/advanced_query.html#SubQuery func (m Migrator) CreateView(name string, option gorm.ViewOption) error { if option.Query == nil { return gorm.ErrSubQueryRequired } sql := new(strings.Builder) sql.WriteString("CREATE ") if option.Replace { sql.WriteString("OR REPLACE ") } sql.WriteString("VIEW ") m.QuoteTo(sql, name) sql.WriteString(" AS ") m.DB.Statement.AddVar(sql, option.Query) if option.CheckOption != "" { sql.WriteString(" ") sql.WriteString(option.CheckOption) } return m.DB.Exec(m.Explain(sql.String(), m.DB.Statement.Vars...)).Error } // DropView drop view func (m Migrator) DropView(name string) error { return m.DB.Exec("DROP VIEW IF EXISTS ?", clause.Table{Name: name}).Error } // GuessConstraintAndTable guess statement's constraint and it's table based on name // // Deprecated: use GuessConstraintInterfaceAndTable instead. 
func (m Migrator) GuessConstraintAndTable(stmt *gorm.Statement, name string) (*schema.Constraint, *schema.CheckConstraint, string) { constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) switch c := constraint.(type) { case *schema.Constraint: return c, nil, table case *schema.CheckConstraint: return nil, c, table default: return nil, nil, table } } // GuessConstraintInterfaceAndTable guess statement's constraint and it's table based on name // nolint:cyclop func (m Migrator) GuessConstraintInterfaceAndTable(stmt *gorm.Statement, name string) (_ schema.ConstraintInterface, table string) { if stmt.Schema == nil { return nil, stmt.Table } checkConstraints := stmt.Schema.ParseCheckConstraints() if chk, ok := checkConstraints[name]; ok { return &chk, stmt.Table } uniqueConstraints := stmt.Schema.ParseUniqueConstraints() if uni, ok := uniqueConstraints[name]; ok { return &uni, stmt.Table } getTable := func(rel *schema.Relationship) string { switch rel.Type { case schema.HasOne, schema.HasMany: return rel.FieldSchema.Table case schema.Many2Many: return rel.JoinTable.Table } return stmt.Table } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && constraint.Name == name { return constraint, getTable(rel) } } if field := stmt.Schema.LookUpField(name); field != nil { for k := range checkConstraints { if checkConstraints[k].Field == field { v := checkConstraints[k] return &v, stmt.Table } } for k := range uniqueConstraints { if uniqueConstraints[k].Field == field { v := uniqueConstraints[k] return &v, stmt.Table } } for _, rel := range stmt.Schema.Relationships.Relations { if constraint := rel.ParseConstraint(); constraint != nil && rel.Field == field { return constraint, getTable(rel) } } } return nil, stmt.Schema.Table } // CreateConstraint create constraint func (m Migrator) CreateConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { 
constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { vars := []interface{}{clause.Table{Name: table}} if stmt.TableExpr != nil { vars[0] = stmt.TableExpr } sql, values := constraint.Build() return m.DB.Exec("ALTER TABLE ? ADD "+sql, append(vars, values...)...).Error } return nil }) } // DropConstraint drop constraint func (m Migrator) DropConstraint(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { name = constraint.GetName() } return m.DB.Exec("ALTER TABLE ? DROP CONSTRAINT ?", clause.Table{Name: table}, clause.Column{Name: name}).Error }) } // HasConstraint check has constraint or not func (m Migrator) HasConstraint(value interface{}, name string) bool { var count int64 m.RunWithValue(value, func(stmt *gorm.Statement) error { currentDatabase := m.DB.Migrator().CurrentDatabase() constraint, table := m.GuessConstraintInterfaceAndTable(stmt, name) if constraint != nil { name = constraint.GetName() } return m.DB.Raw( "SELECT count(*) FROM INFORMATION_SCHEMA.table_constraints WHERE constraint_schema = ? AND table_name = ? 
AND constraint_name = ?", currentDatabase, table, name, ).Row().Scan(&count) }) return count > 0 } // BuildIndexOptions build index options func (m Migrator) BuildIndexOptions(opts []schema.IndexOption, stmt *gorm.Statement) (results []interface{}) { for _, opt := range opts { str := stmt.Quote(opt.DBName) if opt.Expression != "" { str = opt.Expression } else if opt.Length > 0 { str += fmt.Sprintf("(%d)", opt.Length) } if opt.Collate != "" { str += " COLLATE " + opt.Collate } if opt.Sort != "" { str += " " + opt.Sort } results = append(results, clause.Expr{SQL: str}) } return } // BuildIndexOptionsInterface build index options interface type BuildIndexOptionsInterface interface { BuildIndexOptions([]schema.IndexOption, *gorm.Statement) []interface{} } // CreateIndex create index `name` func (m Migrator) CreateIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema == nil { return errors.New("failed to get schema") } if idx := stmt.Schema.LookIndex(name); idx != nil { opts := m.DB.Migrator().(BuildIndexOptionsInterface).BuildIndexOptions(idx.Fields, stmt) values := []interface{}{clause.Column{Name: idx.Name}, m.CurrentTable(stmt), opts} createIndexSQL := "CREATE " if idx.Class != "" { createIndexSQL += idx.Class + " " } createIndexSQL += "INDEX ? ON ??" if idx.Type != "" { createIndexSQL += " USING " + idx.Type } if idx.Comment != "" { createIndexSQL += fmt.Sprintf(" COMMENT '%s'", idx.Comment) } if idx.Option != "" { createIndexSQL += " " + idx.Option } return m.DB.Exec(createIndexSQL, values...).Error } return fmt.Errorf("failed to create index with name %s", name) }) } // DropIndex drop index `name` func (m Migrator) DropIndex(value interface{}, name string) error { return m.RunWithValue(value, func(stmt *gorm.Statement) error { if stmt.Schema != nil { if idx := stmt.Schema.LookIndex(name); idx != nil { name = idx.Name } } return m.DB.Exec("DROP INDEX ? 
ON ?", clause.Column{Name: name}, m.CurrentTable(stmt)).Error }) }

// HasIndex check has index `name` or not.
//
// If the model schema resolves `name` via Schema.LookIndex, that index's
// declared Name is used for the lookup. The check counts matching rows in
// information_schema.statistics for CurrentDatabase() and stmt.Table.
// NOTE(review): this catalog query is MySQL-style; dialect migrators
// presumably override HasIndex — confirm.
// The error returned by RunWithValue is discarded, so any failure (including
// a failed Scan) leaves count at 0 and reports false.
func (m Migrator) HasIndex(value interface{}, name string) bool {
	var count int64
	m.RunWithValue(value, func(stmt *gorm.Statement) error {
		currentDatabase := m.DB.Migrator().CurrentDatabase()
		if stmt.Schema != nil {
			// Prefer the schema's declared index name when it knows `name`.
			if idx := stmt.Schema.LookIndex(name); idx != nil {
				name = idx.Name
			}
		}

		return m.DB.Raw(
			"SELECT count(*) FROM information_schema.statistics WHERE table_schema = ? AND table_name = ? AND index_name = ?",
			currentDatabase, stmt.Table, name,
		).Row().Scan(&count)
	})

	return count > 0
}

// RenameIndex rename index from oldName to newName by issuing
// `ALTER TABLE ? RENAME INDEX ? TO ?` against the statement's current table.
func (m Migrator) RenameIndex(value interface{}, oldName, newName string) error {
	return m.RunWithValue(value, func(stmt *gorm.Statement) error {
		return m.DB.Exec(
			"ALTER TABLE ? RENAME INDEX ? TO ?",
			m.CurrentTable(stmt), clause.Column{Name: oldName}, clause.Column{Name: newName},
		).Error
	})
}

// CurrentDatabase returns current database name as reported by
// `SELECT DATABASE()`.
// The Scan error is ignored, so on failure the empty string is returned.
// NOTE(review): callers (e.g. HasIndex) cannot tell that apart from a real
// name; consider at least logging the error.
func (m Migrator) CurrentDatabase() (name string) {
	m.DB.Raw("SELECT DATABASE()").Row().Scan(&name)
	return
}

// ReorderModels reorder models according to constraint dependencies.
//
// String entries in values are passed through unchanged and placed at the
// front of results (in input order), ahead of every model. Each model value
// is parsed into a Dependency whose Depends lists the schemas referenced by
// its own constraints. Models are then emitted depth-first so that a model's
// dependencies precede it. When autoAdd is true, dependencies that were not
// passed in (and join tables of relations with a JoinTable) are parsed and
// emitted as well. When autoAdd is false, Depends is never consulted, and
// models keep their input order, minus duplicates of an already-parsed schema.
func (m Migrator) ReorderModels(values []interface{}, autoAdd bool) (results []interface{}) {
	type Dependency struct {
		*gorm.Statement
		Depends []*schema.Schema
	}

	var (
		modelNames, orderedModelNames []string
		orderedModelNamesMap          = map[string]bool{}
		parsedSchemas                 = map[*schema.Schema]bool{}
		valuesMap                     = map[string]Dependency{}
		insertIntoOrderedList         func(name string)
		parseDependence               func(value interface{}, addToList bool)
	)

	// parseDependence parses value, records it in valuesMap keyed by table
	// name and, when addToList is set, appends that table to modelNames.
	parseDependence = func(value interface{}, addToList bool) {
		dep := Dependency{
			Statement: &gorm.Statement{DB: m.DB, Dest: value},
		}
		beDependedOn := map[*schema.Schema]bool{}
		// support for special table name
		if err := dep.ParseWithSpecialTableName(value, m.DB.Statement.Table); err != nil {
			// NOTE(review): processing continues after a parse error; if
			// dep.Schema is left nil it is dereferenced below — confirm.
			m.DB.Logger.Error(context.Background(), "failed to parse value %#v, got error %v", value, err)
		}
		// Each schema is processed at most once.
		if _, ok := parsedSchemas[dep.Statement.Schema]; ok {
			return
		}
		parsedSchemas[dep.Statement.Schema] = true

		if !m.DB.IgnoreRelationshipsWhenMigrating {
			for _, rel := range dep.Schema.Relationships.Relations {
				if rel.Field.IgnoreMigration {
					continue
				}
				// Only constraints owned by this schema, pointing at another
				// schema, create an ordering dependency.
				if c := rel.ParseConstraint(); c != nil && c.Schema == dep.Statement.Schema && c.Schema != c.ReferenceSchema {
					dep.Depends = append(dep.Depends, c.ReferenceSchema)
				}

				// has-one / has-many: rel.FieldSchema depends on this schema.
				if rel.Type == schema.HasOne || rel.Type == schema.HasMany {
					beDependedOn[rel.FieldSchema] = true
				}

				if rel.JoinTable != nil {
					// append join value
					// Deferred so the join table is parsed after this schema is
					// recorded; beDependedOn is read when the deferred call runs,
					// i.e. after every relation above has been scanned.
					// NOTE(review): the dep.Depends append below runs after
					// `valuesMap[dep.Schema.Table] = dep` stored a copy of dep, so
					// it never reaches valuesMap and is effectively dead — confirm
					// whether that dependency was meant to be recorded.
					defer func(rel *schema.Relationship, joinValue interface{}) {
						if !beDependedOn[rel.FieldSchema] {
							dep.Depends = append(dep.Depends, rel.FieldSchema)
						} else {
							fieldValue := reflect.New(rel.FieldSchema.ModelType).Interface()
							parseDependence(fieldValue, autoAdd)
						}
						parseDependence(joinValue, autoAdd)
					}(rel, reflect.New(rel.JoinTable.ModelType).Interface())
				}
			}
		}

		valuesMap[dep.Schema.Table] = dep

		if addToList {
			modelNames = append(modelNames, dep.Schema.Table)
		}
	}

	// insertIntoOrderedList appends name after (when autoAdd) its
	// dependencies; marking before recursing makes cycles terminate.
	insertIntoOrderedList = func(name string) {
		if _, ok := orderedModelNamesMap[name]; ok {
			return // avoid loop
		}
		orderedModelNamesMap[name] = true

		if autoAdd {
			dep := valuesMap[name]
			for _, d := range dep.Depends {
				if _, ok := valuesMap[d.Table]; ok {
					insertIntoOrderedList(d.Table)
				} else {
					// Dependency was not passed in: parse it on demand.
					parseDependence(reflect.New(d.ModelType).Interface(), autoAdd)
					insertIntoOrderedList(d.Table)
				}
			}
		}

		orderedModelNames = append(orderedModelNames, name)
	}

	for _, value := range values {
		if v, ok := value.(string); ok {
			results = append(results, v)
		} else {
			parseDependence(value, true)
		}
	}

	for _, name := range modelNames {
		insertIntoOrderedList(name)
	}

	for _, name := range orderedModelNames {
		results = append(results, valuesMap[name].Statement.Dest)
	}
	return
}

// CurrentTable returns current statement's table expression: the explicit
// TableExpr when one is set, otherwise a clause.Table named stmt.Table.
func (m Migrator) CurrentTable(stmt *gorm.Statement) interface{} {
	if stmt.TableExpr != nil {
		return *stmt.TableExpr
	}
	return clause.Table{Name: stmt.Table}
}

//
// GetIndexes return Indexes []gorm.Index and execErr error.
// The generic migrator always returns a "not support" error.
func (m Migrator) GetIndexes(dst interface{}) ([]gorm.Index, error) {
	return nil, errors.New("not support")
}

// GetTypeAliases return database type aliases; the generic migrator knows
// none and returns nil.
func (m Migrator) GetTypeAliases(databaseTypeName string) []string {
	return nil
}

// TableType return tableType gorm.TableType and execErr error.
// The generic migrator always returns a "not support" error.
func (m Migrator) TableType(dst interface{}) (gorm.TableType, error) {
	return nil, errors.New("not support")
}

================================================
FILE: migrator/table_type.go
================================================
package migrator

import (
	"database/sql"
)

// TableType table type implements TableType interface
type TableType struct {
	SchemaValue  string
	NameValue    string
	TypeValue    string
	CommentValue sql.NullString
}

// Schema returns the schema of the table.
func (ct TableType) Schema() string {
	return ct.SchemaValue
}

// Name returns the name of the table.
func (ct TableType) Name() string {
	return ct.NameValue
}

// Type returns the type of the table.
func (ct TableType) Type() string {
	return ct.TypeValue
}

// Comment returns the comment of current table; ok is false when the
// stored comment is SQL NULL.
func (ct TableType) Comment() (comment string, ok bool) {
	return ct.CommentValue.String, ct.CommentValue.Valid
}

================================================
FILE: migrator.go
================================================
package gorm

import (
	"reflect"

	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
)

// Migrator returns migrator for the dialect, built on a new Session after
// all pending scopes have been applied.
func (db *DB) Migrator() Migrator {
	tx := db.getInstance()

	// apply scopes to migrator; loop until no scopes remain
	for len(tx.Statement.scopes) > 0 {
		tx = tx.executeScopes()
	}

	return tx.Dialector.Migrator(tx.Session(&Session{}))
}

// AutoMigrate run auto migration for given models
func (db *DB) AutoMigrate(dst ...interface{}) error {
	return db.Migrator().AutoMigrate(dst...)
}

// ViewOption view option
type ViewOption struct {
	// Replace: if true, exec `CREATE OR REPLACE`; if false, exec `CREATE`.
	// NOTE(review): the previous comment stated the inverse; this wording
	// follows the field name and gorm's Migrator.CreateView — confirm there.
	Replace     bool
	CheckOption string // optional. e.g. `WITH [ CASCADED | LOCAL ] CHECK OPTION`
	Query       *DB    // required subquery.
}

// ColumnType column type interface
type ColumnType interface {
	Name() string
	DatabaseTypeName() string                 // varchar
	ColumnType() (columnType string, ok bool) // varchar(64)
	PrimaryKey() (isPrimaryKey bool, ok bool)
	AutoIncrement() (isAutoIncrement bool, ok bool)
	Length() (length int64, ok bool)
	DecimalSize() (precision int64, scale int64, ok bool)
	Nullable() (nullable bool, ok bool)
	Unique() (unique bool, ok bool)
	ScanType() reflect.Type
	Comment() (value string, ok bool)
	DefaultValue() (value string, ok bool)
}

// Index describes a database index, as returned by Migrator.GetIndexes.
type Index interface {
	Table() string
	Name() string
	Columns() []string
	PrimaryKey() (isPrimaryKey bool, ok bool)
	Unique() (unique bool, ok bool)
	Option() string
}

// TableType table type interface
type TableType interface {
	Schema() string
	Name() string
	Type() string
	Comment() (comment string, ok bool)
}

// Migrator migrator interface
type Migrator interface {
	// AutoMigrate
	AutoMigrate(dst ...interface{}) error

	// Database
	CurrentDatabase() string
	FullDataTypeOf(*schema.Field) clause.Expr
	GetTypeAliases(databaseTypeName string) []string

	// Tables
	CreateTable(dst ...interface{}) error
	DropTable(dst ...interface{}) error
	HasTable(dst interface{}) bool
	RenameTable(oldName, newName interface{}) error
	GetTables() (tableList []string, err error)
	TableType(dst interface{}) (TableType, error)

	// Columns
	AddColumn(dst interface{}, field string) error
	DropColumn(dst interface{}, field string) error
	AlterColumn(dst interface{}, field string) error
	MigrateColumn(dst interface{}, field *schema.Field, columnType ColumnType) error
	// MigrateColumnUnique migrate column's UNIQUE constraint, it's part of MigrateColumn.
MigrateColumnUnique(dst interface{}, field *schema.Field, columnType ColumnType) error
	HasColumn(dst interface{}, field string) bool
	RenameColumn(dst interface{}, oldName, field string) error
	ColumnTypes(dst interface{}) ([]ColumnType, error)

	// Views
	CreateView(name string, option ViewOption) error
	DropView(name string) error

	// Constraints
	CreateConstraint(dst interface{}, name string) error
	DropConstraint(dst interface{}, name string) error
	HasConstraint(dst interface{}, name string) bool

	// Indexes
	CreateIndex(dst interface{}, name string) error
	DropIndex(dst interface{}, name string) error
	HasIndex(dst interface{}, name string) bool
	RenameIndex(dst interface{}, oldName, newName string) error
	GetIndexes(dst interface{}) ([]Index, error)
}

================================================
FILE: model.go
================================================
package gorm

import "time"

// Model is a basic Go struct that includes the fields ID, CreatedAt,
// UpdatedAt and DeletedAt. It may be embedded into your model, or you may
// build your own model without it:
//
//	type User struct {
//	  gorm.Model
//	}
type Model struct {
	ID        uint `gorm:"primarykey"`
	CreatedAt time.Time
	UpdatedAt time.Time
	DeletedAt DeletedAt `gorm:"index"`
}

================================================
FILE: prepare_stmt.go
================================================
package gorm

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"errors"
	"reflect"
	"sync"
	"time"

	"gorm.io/gorm/internal/stmt_store"
)

// PreparedStmtDB wraps a ConnPool and caches prepared statements in Stmts,
// keyed by query text. Mux guards lookups and insertions into Stmts (see
// prepare and Close).
type PreparedStmtDB struct {
	Stmts stmt_store.Store
	Mux   *sync.RWMutex
	ConnPool
}

// NewPreparedStmtDB creates and initializes a new instance of PreparedStmtDB.
//
// Parameters:
// - connPool: A connection pool that implements the ConnPool interface, used for managing database connections.
// - maxSize: The maximum number of prepared statements that can be stored in the statement store.
// - ttl: The time-to-live duration for each prepared statement in the store. Statements older than this duration will be automatically removed.
//
// Returns:
// - A pointer to a PreparedStmtDB instance, which manages prepared statements using the provided connection pool and configuration.
func NewPreparedStmtDB(connPool ConnPool, maxSize int, ttl time.Duration) *PreparedStmtDB {
	return &PreparedStmtDB{
		ConnPool: connPool,                     // Assigns the provided connection pool to manage database connections.
		Stmts:    stmt_store.New(maxSize, ttl), // Initializes a new statement store with the specified maximum size and TTL.
		Mux:      &sync.RWMutex{},              // Sets up a read-write mutex for synchronizing access to the statement store.
	}
}

// GetDBConn returns the underlying *sql.DB connection: the pool itself when
// it is a *sql.DB, otherwise whatever a GetDBConnector pool reports, else
// ErrInvalidDB.
func (db *PreparedStmtDB) GetDBConn() (*sql.DB, error) {
	if sqldb, ok := db.ConnPool.(*sql.DB); ok {
		return sqldb, nil
	}

	if dbConnector, ok := db.ConnPool.(GetDBConnector); ok && dbConnector != nil {
		return dbConnector.GetDBConn()
	}

	return nil, ErrInvalidDB
}

// Close closes all prepared statements in the store by deleting every key
// while holding the write lock.
func (db *PreparedStmtDB) Close() {
	db.Mux.Lock()
	defer db.Mux.Unlock()

	for _, key := range db.Stmts.Keys() {
		db.Stmts.Delete(key)
	}
}

// Reset closes all prepared statements in the store.
//
// Deprecated: use Close instead.
func (db *PreparedStmtDB) Reset() {
	db.Close()
}

// prepare returns a cached statement for query, preparing one on conn if
// needed. A cached statement is reused unless it is transaction-bound and the
// caller is not in a transaction. Lookup happens under the read lock first,
// then is retried under the write lock before creating a new entry.
// NOTE(review): on the creation path the write lock is still held when
// db.Stmts.New is called with db.Mux; presumably New releases it — verify in
// internal/stmt_store. Also, db.Stmts is nil-checked above but New is called
// unconditionally.
func (db *PreparedStmtDB) prepare(ctx context.Context, conn ConnPool, isTransaction bool, query string) (_ *stmt_store.Stmt, err error) {
	db.Mux.RLock()
	if db.Stmts != nil {
		if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
			db.Mux.RUnlock()
			return stmt, stmt.Error()
		}
	}
	db.Mux.RUnlock()

	// retry
	db.Mux.Lock()
	if db.Stmts != nil {
		if stmt, ok := db.Stmts.Get(query); ok && (!stmt.Transaction || isTransaction) {
			db.Mux.Unlock()
			return stmt, stmt.Error()
		}
	}

	return db.Stmts.New(ctx, query, isTransaction, conn, db.Mux)
}

// BeginTx starts a transaction on the wrapped pool and wraps it in a
// PreparedStmtTX sharing this statement cache.
// NOTE(review): on the TxBeginner path a wrapper is returned even when err is
// non-nil; callers must check err first.
func (db *PreparedStmtDB) BeginTx(ctx context.Context, opt *sql.TxOptions) (ConnPool, error) {
	if beginner, ok := db.ConnPool.(TxBeginner); ok {
		tx, err := beginner.BeginTx(ctx, opt)
		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, err
	}

	beginner, ok := db.ConnPool.(ConnPoolBeginner)
	if !ok {
		return nil, ErrInvalidTransaction
	}

	connPool, err := beginner.BeginTx(ctx, opt)
	if err != nil {
		return nil, err
	}
	if tx, ok := connPool.(Tx); ok {
		return &PreparedStmtTX{PreparedStmtDB: db, Tx: tx}, nil
	}
	return nil, ErrInvalidTransaction
}

// ExecContext executes query through a cached prepared statement; the cache
// entry is evicted when the driver reports driver.ErrBadConn.
func (db *PreparedStmtDB) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
	if err == nil {
		result, err = stmt.ExecContext(ctx, args...)
		if errors.Is(err, driver.ErrBadConn) {
			db.Stmts.Delete(query)
		}
	}
	return result, err
}

// QueryContext runs query through a cached prepared statement; the cache
// entry is evicted when the driver reports driver.ErrBadConn.
func (db *PreparedStmtDB) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
	if err == nil {
		rows, err = stmt.QueryContext(ctx, args...)
		if errors.Is(err, driver.ErrBadConn) {
			db.Stmts.Delete(query)
		}
	}
	return rows, err
}

// QueryRowContext runs query through a cached prepared statement.
// NOTE(review): when prepare fails, a zero &sql.Row{} is returned and the
// prepare error is not surfaced to the caller.
func (db *PreparedStmtDB) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	stmt, err := db.prepare(ctx, db.ConnPool, false, query)
	if err == nil {
		return stmt.QueryRowContext(ctx, args...)
	}
	return &sql.Row{}
}

// Ping pings the underlying *sql.DB resolved by GetDBConn.
func (db *PreparedStmtDB) Ping() error {
	conn, err := db.GetDBConn()
	if err != nil {
		return err
	}
	return conn.Ping()
}

// PreparedStmtTX is a transaction that shares its parent's statement cache.
type PreparedStmtTX struct {
	Tx
	PreparedStmtDB *PreparedStmtDB
}

// GetDBConn delegates to the parent PreparedStmtDB.
func (db *PreparedStmtTX) GetDBConn() (*sql.DB, error) {
	return db.PreparedStmtDB.GetDBConn()
}

// Commit commits the wrapped transaction, or returns ErrInvalidTransaction
// when Tx is nil (including a typed-nil value).
// NOTE(review): reflect.Value.IsNil panics if Tx's dynamic type is not a
// nilable kind (e.g. a struct value).
func (tx *PreparedStmtTX) Commit() error {
	if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
		return tx.Tx.Commit()
	}
	return ErrInvalidTransaction
}

// Rollback rolls back the wrapped transaction, or returns
// ErrInvalidTransaction when Tx is nil (including a typed-nil value).
func (tx *PreparedStmtTX) Rollback() error {
	if tx.Tx != nil && !reflect.ValueOf(tx.Tx).IsNil() {
		return tx.Tx.Rollback()
	}
	return ErrInvalidTransaction
}

// ExecContext prepares (or reuses) query, binds the statement to this
// transaction via StmtContext and executes it; evicts on driver.ErrBadConn.
func (tx *PreparedStmtTX) ExecContext(ctx context.Context, query string, args ...interface{}) (result sql.Result, err error) {
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
	if err == nil {
		result, err = tx.Tx.StmtContext(ctx, stmt.Stmt).ExecContext(ctx, args...)
		if errors.Is(err, driver.ErrBadConn) {
			tx.PreparedStmtDB.Stmts.Delete(query)
		}
	}
	return result, err
}

// QueryContext prepares (or reuses) query, binds the statement to this
// transaction via StmtContext and runs it; evicts on driver.ErrBadConn.
func (tx *PreparedStmtTX) QueryContext(ctx context.Context, query string, args ...interface{}) (rows *sql.Rows, err error) {
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
	if err == nil {
		rows, err = tx.Tx.StmtContext(ctx, stmt.Stmt).QueryContext(ctx, args...)
		if errors.Is(err, driver.ErrBadConn) {
			tx.PreparedStmtDB.Stmts.Delete(query)
		}
	}
	return rows, err
}

// QueryRowContext prepares (or reuses) query and runs it inside this
// transaction. NOTE(review): a prepare error is dropped (zero sql.Row).
func (tx *PreparedStmtTX) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
	stmt, err := tx.PreparedStmtDB.prepare(ctx, tx.Tx, true, query)
	if err == nil {
		return tx.Tx.StmtContext(ctx, stmt.Stmt).QueryRowContext(ctx, args...)
}
	return &sql.Row{}
}

// Ping pings the underlying *sql.DB resolved by GetDBConn.
func (tx *PreparedStmtTX) Ping() error {
	conn, err := tx.GetDBConn()
	if err != nil {
		return err
	}
	return conn.Ping()
}

================================================
FILE: scan.go
================================================
package gorm

import (
	"database/sql"
	"database/sql/driver"
	"reflect"
	"strings"
	"time"

	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils"
)

// prepareValues prepare values slice of scan destinations, one per column.
// With a schema, known columns get a **FieldType (so a NULL can become a nil
// pointer) and unknown ones a *interface{}. Without a schema the driver's
// ScanType is used when available, otherwise *interface{}.
func prepareValues(values []interface{}, db *DB, columnTypes []*sql.ColumnType, columns []string) {
	if db.Statement.Schema != nil {
		for idx, name := range columns {
			if field := db.Statement.Schema.LookUpField(name); field != nil {
				values[idx] = reflect.New(reflect.PointerTo(field.FieldType)).Interface()
				continue
			}
			values[idx] = new(interface{})
		}
	} else if len(columnTypes) > 0 {
		for idx, columnType := range columnTypes {
			if columnType.ScanType() != nil {
				values[idx] = reflect.New(reflect.PointerTo(columnType.ScanType())).Interface()
			} else {
				values[idx] = new(interface{})
			}
		}
	} else {
		for idx := range columns {
			values[idx] = new(interface{})
		}
	}
}

// scanIntoMap copies scanned values into mapValue keyed by column name.
// Values are dereferenced twice (matching the **T / *interface{} shapes from
// prepareValues); driver.Valuer results replace the value (their error is
// discarded) and sql.RawBytes are copied into a string, since RawBytes
// reference driver-owned memory.
func scanIntoMap(mapValue map[string]interface{}, values []interface{}, columns []string) {
	for idx, column := range columns {
		if reflectValue := reflect.Indirect(reflect.Indirect(reflect.ValueOf(values[idx]))); reflectValue.IsValid() {
			mapValue[column] = reflectValue.Interface()
			if valuer, ok := mapValue[column].(driver.Valuer); ok {
				mapValue[column], _ = valuer.Value()
			} else if b, ok := mapValue[column].(sql.RawBytes); ok {
				mapValue[column] = string(b)
			}
		} else {
			mapValue[column] = nil
		}
	}
}

// scanIntoStruct scans the current row into reflectValue. Matched fields take
// pooled destinations from field.NewValuePool, which are returned to the pool
// after Set. A single unmatched column scans directly into reflectValue.
// Columns listed in joinFields are assigned through nested relation values,
// allocating nil relation pointers once per relation path.
func (db *DB) scanIntoStruct(rows Rows, reflectValue reflect.Value, values []interface{}, fields []*schema.Field, joinFields [][]*schema.Field) {
	for idx, field := range fields {
		if field != nil {
			values[idx] = field.NewValuePool.Get()
		} else if len(fields) == 1 {
			if reflectValue.CanAddr() {
				values[idx] = reflectValue.Addr().Interface()
			} else {
				values[idx] = reflectValue.Interface()
			}
		}
	}

	db.RowsAffected++
	db.AddError(rows.Scan(values...))

	// Relation paths whose pointer has already been allocated for this row.
	joinedNestedSchemaMap := make(map[string]interface{})
	for idx, field := range fields {
		if field == nil {
			continue
		}

		if len(joinFields) == 0 || len(joinFields[idx]) == 0 {
			db.AddError(field.Set(db.Statement.Context, reflectValue, values[idx]))
		} else { // joinFields count is larger than 2 when using join
			var isNilPtrValue bool
			var relValue reflect.Value
			// does not contain raw dbname
			nestedJoinSchemas := joinFields[idx][:len(joinFields[idx])-1]
			// current reflect value
			currentReflectValue := reflectValue
			fullRels := make([]string, 0, len(nestedJoinSchemas))
			for _, joinSchema := range nestedJoinSchemas {
				fullRels = append(fullRels, joinSchema.Name)
				relValue = joinSchema.ReflectValueOf(db.Statement.Context, currentReflectValue)
				if relValue.Kind() == reflect.Ptr {
					fullRelsName := utils.JoinNestedRelationNames(fullRels)
					// same nested structure
					if _, ok := joinedNestedSchemaMap[fullRelsName]; !ok {
						// A NULL scanned value leaves the relation pointer nil.
						if value := reflect.ValueOf(values[idx]).Elem(); value.Kind() == reflect.Ptr && value.IsNil() {
							isNilPtrValue = true
							break
						}

						relValue.Set(reflect.New(relValue.Type().Elem()))
						joinedNestedSchemaMap[fullRelsName] = nil
					}
				}
				currentReflectValue = relValue
			}

			if !isNilPtrValue { // ignore if value is nil
				f := joinFields[idx][len(joinFields[idx])-1]
				db.AddError(f.Set(db.Statement.Context, relValue, values[idx]))
			}
		}

		// release data to pool
		field.NewValuePool.Put(values[idx])
	}
}

// ScanMode scan data mode
type ScanMode uint8

// scan modes
const (
	ScanInitialized         ScanMode = 1 << 0 // 1
	ScanUpdate              ScanMode = 1 << 1 // 2
	ScanOnConflictDoNothing ScanMode = 1 << 2 // 4
)

// Scan scan rows into db statement destination (db.Statement.Dest).
//
// Map, slice-of-map and primitive/sql.Null* destinations are scanned directly.
// Any other destination is treated as a struct, slice or array and mapped by
// column name through the statement schema. ScanInitialized means rows.Next()
// was already called for the first row. ScanUpdate writes into the existing
// elements of a non-empty slice/array instead of appending.
// ScanOnConflictDoNothing (with ScanUpdate) skips elements according to the
// fields' ValueOf zero check. RowsAffected counts scanned rows.
// ErrRecordNotFound is added when nothing was scanned and
// RaiseErrorOnNotFound is set.
func Scan(rows Rows, db *DB, mode ScanMode) {
	var (
		columns, _          = rows.Columns()
		values              = make([]interface{}, len(columns))
		initialized         = mode&ScanInitialized != 0
		update              = mode&ScanUpdate != 0
		onConflictDonothing = mode&ScanOnConflictDoNothing != 0
	)

	// Rename result columns as requested by the statement.
	if len(db.Statement.ColumnMapping) > 0 {
		for i, column := range columns {
			v, ok := db.Statement.ColumnMapping[column]
			if ok {
				columns[i] = v
			}
		}
	}

	db.RowsAffected = 0

	switch dest := db.Statement.Dest.(type) {
	case map[string]interface{}, *map[string]interface{}:
		if initialized || rows.Next() {
			columnTypes, _ := rows.ColumnTypes()
			prepareValues(values, db, columnTypes, columns)

			db.RowsAffected++
			db.AddError(rows.Scan(values...))

			mapValue, ok := dest.(map[string]interface{})
			if !ok {
				if v, ok := dest.(*map[string]interface{}); ok {
					if *v == nil {
						*v = map[string]interface{}{}
					}
					mapValue = *v
				}
			}
			scanIntoMap(mapValue, values, columns)
		}
	case *[]map[string]interface{}:
		columnTypes, _ := rows.ColumnTypes()
		for initialized || rows.Next() {
			prepareValues(values, db, columnTypes, columns)

			initialized = false
			db.RowsAffected++
			db.AddError(rows.Scan(values...))

			mapValue := map[string]interface{}{}
			scanIntoMap(mapValue, values, columns)
			*dest = append(*dest, mapValue)
		}
	case *int, *int8, *int16, *int32, *int64,
		*uint, *uint8, *uint16, *uint32, *uint64, *uintptr,
		*float32, *float64,
		*bool, *string, *time.Time,
		*sql.NullInt32, *sql.NullInt64, *sql.NullFloat64,
		*sql.NullBool, *sql.NullString, *sql.NullTime:
		for initialized || rows.Next() {
			initialized = false
			db.RowsAffected++
			db.AddError(rows.Scan(dest))
		}
	default:
		var (
			fields       = make([]*schema.Field, len(columns))
			joinFields   [][]*schema.Field
			sch          = db.Statement.Schema
			reflectValue = db.Statement.ReflectValue
		)

		if reflectValue.Kind() == reflect.Interface {
			reflectValue = reflectValue.Elem()
		}

		// Element type of the destination, with one pointer level removed.
		reflectValueType := reflectValue.Type()
		switch reflectValueType.Kind() {
		case reflect.Array, reflect.Slice:
			reflectValueType = reflectValueType.Elem()
		}
		isPtr := reflectValueType.Kind() == reflect.Ptr
		if isPtr {
			reflectValueType = reflectValueType.Elem()
		}

		if sch != nil {
			// NOTE(review): the Parse error is discarded; if it fails, sch is
			// nil and the single-column check below dereferences sch.ModelType.
			if reflectValueType != sch.ModelType && reflectValueType.Kind() == reflect.Struct {
				sch, _ = schema.Parse(db.Statement.Dest, db.cacheStore, db.NamingStrategy)
			}

			if len(columns) == 1 {
				// Is Pluck
				if _, ok := reflect.New(reflectValueType).Interface().(sql.Scanner); (reflectValueType != sch.ModelType && ok) || // is scanner
					reflectValueType.Kind() != reflect.Struct || // is not struct
					sch.ModelType.ConvertibleTo(schema.TimeReflectType) { // is time
					sch = nil
				}
			}

			// Not Pluck
			if sch != nil {
				matchedFieldCount := make(map[string]int, len(columns))
				for idx, column := range columns {
					if field := sch.LookUpField(column); field != nil && field.Readable {
						fields[idx] = field
						if count, ok := matchedFieldCount[column]; ok {
							// handle duplicate fields
							for _, selectField := range sch.Fields {
								if selectField.DBName == column && selectField.Readable {
									if count == 0 {
										matchedFieldCount[column]++
										fields[idx] = selectField
										break
									}
									count--
								}
							}
						} else {
							matchedFieldCount[column] = 1
						}
					} else if names := utils.SplitNestedRelationName(column); len(names) > 1 { // has nested relation
						// Replace a join alias prefix by the join's relation path.
						aliasName := utils.JoinNestedRelationNames(names[0 : len(names)-1])
						for _, join := range db.Statement.Joins {
							if join.Alias == aliasName {
								names = append(strings.Split(join.Name, "."), names[len(names)-1])
								break
							}
						}

						if rel, ok := sch.Relationships.Relations[names[0]]; ok {
							subNameCount := len(names)
							// nested relation fields
							relFields := make([]*schema.Field, 0, subNameCount-1)
							relFields = append(relFields, rel.Field)
							for _, name := range names[1 : subNameCount-1] {
								rel = rel.FieldSchema.Relationships.Relations[name]
								relFields = append(relFields, rel.Field)
							}
							// latest name is raw dbname
							dbName := names[subNameCount-1]
							if field := rel.FieldSchema.LookUpField(dbName); field != nil && field.Readable {
								fields[idx] = field

								if len(joinFields) == 0 {
									joinFields = make([][]*schema.Field, len(columns))
								}
								relFields = append(relFields, field)
								joinFields[idx] = relFields
								continue
							}
						}
						// Unmatched column: scan into a throwaway interface{}.
						var val interface{}
						values[idx] = &val
					} else {
						var val interface{}
						values[idx] = &val
					}
				}
			}
		}

		switch reflectValue.Kind() {
		case reflect.Slice, reflect.Array:
			var (
				elem        reflect.Value
				isArrayKind = reflectValue.Kind() == reflect.Array
			)

			if !update || reflectValue.Len() == 0 {
				update = false
				if isArrayKind {
					db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
				} else {
					// if the slice cap is externally initialized, the externally initialized slice is directly used here
					if reflectValue.Cap() == 0 {
						db.Statement.ReflectValue.Set(reflect.MakeSlice(reflectValue.Type(), 0, 20))
					} else {
						reflectValue.SetLen(0)
						db.Statement.ReflectValue.Set(reflectValue)
					}
				}
			}

			for initialized || rows.Next() {
			BEGIN:
				initialized = false

				if update {
					if int(db.RowsAffected) >= reflectValue.Len() {
						return
					}
					elem = reflectValue.Index(int(db.RowsAffected))
					if onConflictDonothing {
						// Skip this element (without consuming a row) when a
						// field's ValueOf reports non-zero.
						// NOTE(review): fields may hold nil entries for
						// unmatched columns; field.ValueOf would then panic.
						for _, field := range fields {
							if _, ok := field.ValueOf(db.Statement.Context, elem); !ok {
								db.RowsAffected++
								goto BEGIN
							}
						}
					}
				} else {
					elem = reflect.New(reflectValueType)
				}

				db.scanIntoStruct(rows, elem, values, fields, joinFields)

				if !update {
					if !isPtr {
						elem = elem.Elem()
					}
					if isArrayKind {
						if reflectValue.Len() >= int(db.RowsAffected) {
							reflectValue.Index(int(db.RowsAffected - 1)).Set(elem)
						}
					} else {
						reflectValue = reflect.Append(reflectValue, elem)
					}
				}
			}

			if !update {
				db.Statement.ReflectValue.Set(reflectValue)
			}
		case reflect.Struct, reflect.Ptr:
			if initialized || rows.Next() {
				if mode == ScanInitialized && reflectValue.Kind() == reflect.Struct {
					db.Statement.ReflectValue.Set(reflect.Zero(reflectValue.Type()))
				}
				db.scanIntoStruct(rows, reflectValue, values, fields, joinFields)
			}
		default:
			db.AddError(rows.Scan(dest))
		}
	}

	if err := rows.Err(); err != nil && err != db.Error {
		db.AddError(err)
	}

	if db.RowsAffected == 0 && db.Statement.RaiseErrorOnNotFound && db.Error == nil {
		db.AddError(ErrRecordNotFound)
	}
}

================================================
FILE: schema/callbacks_test.go
================================================
package schema_test

import (
	"reflect"
	"sync"
	"testing"

	"gorm.io/gorm"
	"gorm.io/gorm/schema"
)

// UserWithCallback implements only the BeforeSave and AfterCreate hooks.
type UserWithCallback struct{}

func (UserWithCallback) BeforeSave(*gorm.DB) error {
	return nil
}

func
(UserWithCallback) AfterCreate(*gorm.DB) error {
	return nil
}

// TestCallback checks, via reflection on the parsed schema, that only the
// implemented hook flags (BeforeSave, AfterCreate) are true.
func TestCallback(t *testing.T) {
	user, err := schema.Parse(&UserWithCallback{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse user with callback, got error %v", err)
	}

	for _, str := range []string{"BeforeSave", "AfterCreate"} {
		if !reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
			t.Errorf("%v should be true", str)
		}
	}

	for _, str := range []string{"BeforeCreate", "BeforeUpdate", "AfterUpdate", "AfterSave", "BeforeDelete", "AfterDelete", "AfterFind"} {
		if reflect.Indirect(reflect.ValueOf(user)).FieldByName(str).Interface().(bool) {
			t.Errorf("%v should be false", str)
		}
	}
}

================================================
FILE: schema/constraint.go
================================================
package schema

import (
	"regexp"
	"strings"

	"gorm.io/gorm/clause"
)

// regEnLetterAndMidline matches a non-empty string made only of ASCII
// letters, digits, underscores ('\w' in RE2) and hyphens ("midline"). It is
// used to decide whether the first segment of a `check` tag is a name.
var regEnLetterAndMidline = regexp.MustCompile(`^[\w-]+$`)

// CheckConstraint is a named CHECK constraint declared on a field.
type CheckConstraint struct {
	Name       string
	Constraint string // length(phone) >= 10
	*Field
}

// GetName returns the constraint name.
func (chk *CheckConstraint) GetName() string { return chk.Name }

// Build renders `CONSTRAINT <name> CHECK (<expr>)`, with the expression
// inserted as raw SQL.
func (chk *CheckConstraint) Build() (sql string, vars []interface{}) {
	return "CONSTRAINT ? CHECK (?)", []interface{}{clause.Column{Name: chk.Name}, clause.Expr{SQL: chk.Constraint}}
}

// ParseCheckConstraints parse schema check constraints from `check` tags.
//
// Tag forms:
//   - "name,expr": the first segment is used as the name when it matches
//     regEnLetterAndMidline and there is at least one comma.
//   - ",expr": the empty name is dropped and a name is generated.
//   - "expr": the whole tag is the expression and a name is generated.
//
// Generated names come from namer.CheckerName(table, column). An unnamed
// expression whose text before its first comma happens to match the regex
// would therefore be read as "name,expr".
func (schema *Schema) ParseCheckConstraints() map[string]CheckConstraint {
	checks := map[string]CheckConstraint{}
	for _, field := range schema.FieldsByDBName {
		if chk := field.TagSettings["CHECK"]; chk != "" {
			names := strings.Split(chk, ",")
			if len(names) > 1 && regEnLetterAndMidline.MatchString(names[0]) {
				checks[names[0]] = CheckConstraint{Name: names[0], Constraint: strings.Join(names[1:], ","), Field: field}
			} else {
				if names[0] == "" {
					chk = strings.Join(names[1:], ",")
				}
				name := schema.namer.CheckerName(schema.Table, field.DBName)
				checks[name] = CheckConstraint{Name: name, Constraint: chk, Field: field}
			}
		}
	}
	return checks
}

// UniqueConstraint is a single-column UNIQUE constraint.
type UniqueConstraint struct {
	Name  string
	Field *Field
}

// GetName returns the constraint name.
func (uni *UniqueConstraint) GetName() string { return uni.Name }

// Build renders `CONSTRAINT <name> UNIQUE (<column>)`.
func (uni *UniqueConstraint) Build() (sql string, vars []interface{}) {
	return "CONSTRAINT ? UNIQUE (?)", []interface{}{clause.Column{Name: uni.Name}, clause.Column{Name: uni.Field.DBName}}
}

// ParseUniqueConstraints parse schema unique constraints: one per field with
// Unique set, named by namer.UniqueName(table, column).
func (schema *Schema) ParseUniqueConstraints() map[string]UniqueConstraint {
	uniques := make(map[string]UniqueConstraint)
	for _, field := range schema.Fields {
		if field.Unique {
			name := schema.namer.UniqueName(schema.Table, field.DBName)
			uniques[name] = UniqueConstraint{Name: name, Field: field}
		}
	}
	return uniques
}

================================================
FILE: schema/constraint_test.go
================================================
package schema_test

import (
	"reflect"
	"sync"
	"testing"

	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils/tests"
)

// UserCheck covers the three `check` tag forms: named, unnamed, and
// explicitly empty name.
type UserCheck struct {
	Name  string `gorm:"check:name_checker,name <> 'jinzhu'"`
	Name2 string `gorm:"check:name <> 'jinzhu'"`
	Name3 string `gorm:"check:,name <> 'jinzhu'"`
}

func TestParseCheck(t *testing.T) {
	user, err := schema.Parse(&UserCheck{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse user check, got error %v", err)
	}

	results := map[string]schema.CheckConstraint{
		"name_checker": {
			Name:       "name_checker",
			Constraint: "name <> 'jinzhu'",
		},
		"chk_user_checks_name2": {
			Name:       "chk_user_checks_name2",
			Constraint: "name <> 'jinzhu'",
		},
		"chk_user_checks_name3": {
			Name:       "chk_user_checks_name3",
			Constraint: "name <> 'jinzhu'",
		},
	}

	checks := user.ParseCheckConstraints()

	for k, result := range results {
		v, ok := checks[k]
		if !ok {
			t.Errorf("Failed to found check %v from parsed checks %+v", k, checks)
		}

		for _, name := range []string{"Name", "Constraint"} {
			if reflect.ValueOf(result).FieldByName(name).Interface() != reflect.ValueOf(v).FieldByName(name).Interface() {
				t.Errorf(
					"check %v %v should equal, expects %v, got %v",
					k, name, reflect.ValueOf(result).FieldByName(name).Interface(), reflect.ValueOf(v).FieldByName(name).Interface(),
				)
			}
		}
	}
}

func TestParseUniqueConstraints(t *testing.T) {
	type UserUnique struct {
		Name1 string `gorm:"unique"`
		Name2 string `gorm:"uniqueIndex"`
	}

	user, err := schema.Parse(&UserUnique{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse user unique, got error %v", err)
	}
	constraints := user.ParseUniqueConstraints()

	// Only the `unique` tag yields a UniqueConstraint; `uniqueIndex` does not.
	results := map[string]schema.UniqueConstraint{
		"uni_user_uniques_name1": {
			Name:  "uni_user_uniques_name1",
			Field: &schema.Field{Name: "Name1", Unique: true},
		},
	}
	for k, result := range results {
		v, ok := constraints[k]
		if !ok {
			t.Errorf("Failed to found unique constraint %v from parsed constraints %+v", k, constraints)
		}
		tests.AssertObjEqual(t, result, v, "Name")
		tests.AssertObjEqual(t, result.Field, v.Field, "Name", "Unique", "UniqueIndex")
	}
}

================================================
FILE: schema/field.go
================================================
package schema

import (
	"context"
	"database/sql"
	"database/sql/driver"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"time"

	"github.com/jinzhu/now"
	"gorm.io/gorm/clause"
	"gorm.io/gorm/utils"
)

// special types' reflect type
var (
	TimeReflectType    = reflect.TypeOf(time.Time{})
	TimePtrReflectType = reflect.TypeOf(&time.Time{})
	ByteReflectType    = reflect.TypeOf(uint8(0))
)

type (
	// DataType GORM data type
	DataType string
	// TimeType GORM time type
	TimeType int64
)

// GORM time types
const (
	UnixTime        TimeType = 1
	UnixSecond      TimeType = 2
	UnixMillisecond TimeType = 3
	UnixNanosecond  TimeType = 4
)

// GORM fields types
const (
	Bool   DataType = "bool"
	Int    DataType = "int"
	Uint   DataType = "uint"
	Float  DataType = "float"
	String DataType = "string"
	Time   DataType = "time"
	Bytes  DataType = "bytes"
)

// DefaultAutoIncrementIncrement is the AutoIncrementIncrement assigned to a
// field unless its tag overrides it.
const DefaultAutoIncrementIncrement int64 = 1

// Field is the representation of model schema's field
type Field struct {
	Name                   string
	DBName                 string
	BindNames              []string
	EmbeddedBindNames      []string
	DataType               DataType
	GORMDataType           DataType
	PrimaryKey             bool
	AutoIncrement          bool
	AutoIncrementIncrement int64
	Creatable              bool
	Updatable              bool
	Readable               bool
	AutoCreateTime         TimeType
	AutoUpdateTime
TimeType
	HasDefaultValue bool
	DefaultValue string
	DefaultValueInterface interface{}
	NotNull bool
	Unique bool
	Comment string
	Size int
	Precision int
	Scale int
	IgnoreMigration bool
	FieldType reflect.Type
	IndirectFieldType reflect.Type
	StructField reflect.StructField
	Tag reflect.StructTag
	TagSettings map[string]string
	Schema *Schema
	EmbeddedSchema *Schema
	OwnerSchema *Schema
	ReflectValueOf func(context.Context, reflect.Value) reflect.Value
	ValueOf func(context.Context, reflect.Value) (value interface{}, zero bool)
	Set func(context.Context, reflect.Value, interface{}) error
	Serializer SerializerInterface
	NewValuePool FieldNewValuePool

	// In some db (e.g. MySQL), Unique and UniqueIndex are indistinguishable.
	// When a column has a (not Mul) UniqueIndex, Migrator always reports its gorm.ColumnType is Unique.
	// It causes field unnecessarily migration.
	// Therefore, we need to record the UniqueIndex on this column (exclude Mul UniqueIndex) for MigrateColumnUnique.
	UniqueIndex string
}

// BindName returns the field's BindNames joined with ".".
// For fields hoisted out of an embedded struct, ParseField prepends the
// embedding field's name to BindNames, so the result is a dotted path.
func (field *Field) BindName() string {
	return strings.Join(field.BindNames, ".")
}

// ParseField parses reflect.StructField to Field
//
// The `gorm` struct tag is split on ";" by ParseTagSetting. The Go kind of the
// (pointer-stripped) field type determines DataType, Size and how a DEFAULT
// tag value is parsed. Tags such as TYPE, SIZE, PRECISION, SCALE,
// AUTOCREATETIME/AUTOUPDATETIME, "-", "->" and "<-" then adjust the result.
// Embedded (or anonymous, non-valuer) struct fields have their own schema
// parsed and their sub-fields rebased onto this schema.
//
// Problems are recorded on schema.err rather than returned; the *Field is
// always returned.
func (schema *Schema) ParseField(fieldStruct reflect.StructField) *Field {
	var (
		err error
		tagSetting = ParseTagSetting(fieldStruct.Tag.Get("gorm"), ";")
	)

	field := &Field{
		Name: fieldStruct.Name,
		DBName: tagSetting["COLUMN"],
		BindNames: []string{fieldStruct.Name},
		EmbeddedBindNames: []string{fieldStruct.Name},
		FieldType: fieldStruct.Type,
		IndirectFieldType: fieldStruct.Type,
		StructField: fieldStruct,
		Tag: fieldStruct.Tag,
		TagSettings: tagSetting,
		Schema: schema,
		Creatable: true,
		Updatable: true,
		Readable: true,
		PrimaryKey: utils.CheckTruth(tagSetting["PRIMARYKEY"], tagSetting["PRIMARY_KEY"]),
		AutoIncrement: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
		// an auto-increment column is treated as having a database-side default
		HasDefaultValue: utils.CheckTruth(tagSetting["AUTOINCREMENT"]),
		NotNull: utils.CheckTruth(tagSetting["NOT NULL"], tagSetting["NOTNULL"]),
		Unique: utils.CheckTruth(tagSetting["UNIQUE"]),
		Comment: tagSetting["COMMENT"],
		AutoIncrementIncrement: DefaultAutoIncrementIncrement,
	}

	// Strip every pointer level so IndirectFieldType is the underlying value type.
	for field.IndirectFieldType.Kind() == reflect.Ptr {
		field.IndirectFieldType = field.IndirectFieldType.Elem()
	}

	// fieldValue is a *IndirectFieldType probe value used only for type inspection below.
	fieldValue := reflect.New(field.IndirectFieldType)
	// if field is valuer, used its value or first field as data type
	valuer, isValuer := fieldValue.Interface().(driver.Valuer)
	if isValuer {
		// An explicit GormDataType() takes precedence, so only probe valuers that lack it.
		if _, ok := fieldValue.Interface().(GormDataTypeInterface); !ok {
			// Prefer the concrete value returned by Value() of a zero instance, when it is valid.
			if v, err := valuer.Value(); reflect.ValueOf(v).IsValid() && err == nil {
				fieldValue = reflect.ValueOf(v)
			}

			// Use the field struct's first field type as data type, e.g: use `string` for sql.NullString
			var getRealFieldValue func(reflect.Value)
			getRealFieldValue = func(v reflect.Value) {
				var (
					rv = reflect.Indirect(v)
					rvType = rv.Type()
				)

				if rv.Kind() == reflect.Struct && !rvType.ConvertibleTo(TimeReflectType) {
					// Inherit `gorm` tag settings from the wrapper's inner fields,
					// without overriding settings already present on the outer field.
					for i := 0; i < rvType.NumField(); i++ {
						for key, value := range ParseTagSetting(rvType.Field(i).Tag.Get("gorm"), ";") {
							if _, ok := field.TagSettings[key]; !ok {
								field.TagSettings[key] = value
							}
						}
					}

					for i := 0; i < rvType.NumField(); i++ {
						newFieldType := rvType.Field(i).Type
						for newFieldType.Kind() == reflect.Ptr {
							newFieldType = newFieldType.Elem()
						}

						fieldValue = reflect.New(newFieldType)
						// Recurse into nested wrapper structs, guarding against self-referential types.
						if rvType != reflect.Indirect(fieldValue).Type() {
							getRealFieldValue(fieldValue)
						}

						// reflect.New always yields a valid value, so in practice the first field wins.
						if fieldValue.IsValid() {
							return
						}
					}
				}
			}

			getRealFieldValue(fieldValue)
		}
	}

	if v, isSerializer := fieldValue.Interface().(SerializerInterface); isSerializer {
		field.DataType = String
		field.Serializer = v
	} else {
		// JSON takes precedence over SERIALIZER when both tags are set.
		serializerName := field.TagSettings["JSON"]
		if serializerName == "" {
			serializerName = field.TagSettings["SERIALIZER"]
		}
		if serializerName != "" {
			if serializer, ok := GetSerializer(serializerName); ok {
				// Set default data type to string for serializer
				field.DataType = String
				field.Serializer = serializer
			} else {
				schema.err = fmt.Errorf("invalid serializer type %v", serializerName)
			}
		}
	}

	// Parse errors are ignored: whatever strconv.ParseInt returns is stored.
	if num, ok := field.TagSettings["AUTOINCREMENTINCREMENT"]; ok {
		field.AutoIncrementIncrement, _ = strconv.ParseInt(num, 10, 64)
	}

	if v, ok := field.TagSettings["DEFAULT"]; ok {
		field.HasDefaultValue = true
		field.DefaultValue = v
	}

	// An unparsable SIZE is recorded as -1.
	if num, ok := field.TagSettings["SIZE"]; ok {
		if field.Size, err = strconv.Atoi(num); err != nil {
			field.Size = -1
		}
	}

	// PRECISION / SCALE parse errors are ignored.
	if p, ok := field.TagSettings["PRECISION"]; ok {
		field.Precision, _ = strconv.Atoi(p)
	}

	if s, ok := field.TagSettings["SCALE"]; ok {
		field.Scale, _ = strconv.Atoi(s)
	}

	// default value is function or null or blank (primary keys)
	field.DefaultValue = strings.TrimSpace(field.DefaultValue)
	// && binds tighter than ||: skip when the value contains both "(" and ")",
	// equals "null" case-insensitively, or is empty.
	skipParseDefaultValue := strings.Contains(field.DefaultValue, "(") && strings.Contains(field.DefaultValue, ")") || strings.ToLower(field.DefaultValue) == "null" || field.DefaultValue == ""
	switch reflect.Indirect(fieldValue).Kind() {
	case reflect.Bool:
		field.DataType = Bool
		if field.HasDefaultValue && !skipParseDefaultValue {
			if field.DefaultValueInterface, err = strconv.ParseBool(field.DefaultValue); err != nil {
				schema.err = fmt.Errorf("failed to parse %s as default value for bool, got error: %v", field.DefaultValue, err)
			}
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		field.DataType = Int
		if field.HasDefaultValue && !skipParseDefaultValue {
			// base 0: accepts "0x"/"0o"/"0b" prefixes as well as decimal
			if field.DefaultValueInterface, err = strconv.ParseInt(field.DefaultValue, 0, 64); err != nil {
				schema.err = fmt.Errorf("failed to parse %s as default value for int, got error: %v", field.DefaultValue, err)
			}
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		field.DataType = Uint
		if field.HasDefaultValue && !skipParseDefaultValue {
			if field.DefaultValueInterface, err = strconv.ParseUint(field.DefaultValue, 0, 64); err != nil {
				schema.err = fmt.Errorf("failed to parse %s as default value for uint, got error: %v", field.DefaultValue, err)
			}
		}
	case reflect.Float32, reflect.Float64:
		field.DataType = Float
		if field.HasDefaultValue && !skipParseDefaultValue {
			if field.DefaultValueInterface, err = strconv.ParseFloat(field.DefaultValue, 64); err != nil {
				schema.err = fmt.Errorf("failed to parse %s as default value for float, got error: %v", field.DefaultValue, err)
			}
		}
	case reflect.String:
		field.DataType = String
		if field.HasDefaultValue && !skipParseDefaultValue {
			// strip surrounding single quotes first, then double quotes
			field.DefaultValue = strings.Trim(field.DefaultValue, "'")
			field.DefaultValue = strings.Trim(field.DefaultValue, `"`)
			field.DefaultValueInterface = field.DefaultValue
		}
	case reflect.Struct:
		if _, ok := fieldValue.Interface().(*time.Time); ok {
			field.DataType = Time
		} else if fieldValue.Type().ConvertibleTo(TimeReflectType) {
			field.DataType = Time
		} else if fieldValue.Type().ConvertibleTo(TimePtrReflectType) {
			field.DataType = Time
		}
		// An unparsable time default is silently left without a DefaultValueInterface.
		if field.HasDefaultValue && !skipParseDefaultValue && field.DataType == Time {
			if t, err := now.Parse(field.DefaultValue); err == nil {
				field.DefaultValueInterface = t
			}
		}
	case reflect.Array, reflect.Slice:
		// Only byte sequences map to Bytes, and only if nothing (e.g. a serializer) set DataType already.
		if reflect.Indirect(fieldValue).Type().Elem() == ByteReflectType && field.DataType == "" {
			field.DataType = Bytes
		}
	}

	// A type's own GormDataType() overrides the kind-based choice above.
	if dataTyper, ok := fieldValue.Interface().(GormDataTypeInterface); ok {
		field.DataType = DataType(dataTyper.GormDataType())
	}

	// Auto create time is enabled by an AUTOCREATETIME tag whose value passes
	// utils.CheckTruth, or implicitly for an untagged field named CreatedAt of
	// Time/Int/Uint data type. For integer columns the tag value selects
	// NANO or MILLI precision; otherwise seconds.
	if v, ok := field.TagSettings["AUTOCREATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "CreatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
		if field.DataType == Time {
			field.AutoCreateTime = UnixTime
		} else if strings.ToUpper(v) == "NANO" {
			field.AutoCreateTime = UnixNanosecond
		} else if strings.ToUpper(v) == "MILLI" {
			field.AutoCreateTime = UnixMillisecond
		} else {
			field.AutoCreateTime = UnixSecond
		}
	}

	// Same rules as above, for AUTOUPDATETIME / UpdatedAt.
	if v, ok := field.TagSettings["AUTOUPDATETIME"]; (ok && utils.CheckTruth(v)) || (!ok && field.Name == "UpdatedAt" && (field.DataType == Time || field.DataType == Int || field.DataType == Uint)) {
		if field.DataType == Time {
			field.AutoUpdateTime = UnixTime
		} else if strings.ToUpper(v) == "NANO" {
			field.AutoUpdateTime = UnixNanosecond
		} else if strings.ToUpper(v) == "MILLI" {
			field.AutoUpdateTime = UnixMillisecond
		} else {
			field.AutoUpdateTime = UnixSecond
		}
	}

	// GORMDataType keeps the inferred type before any TYPE tag override below.
	if field.GORMDataType == "" {
		field.GORMDataType = field.DataType
	}

	// TYPE tag: known generic names are normalized to lower case, anything else is kept verbatim.
	if val, ok := field.TagSettings["TYPE"]; ok {
		lowerVal := DataType(strings.ToLower(val))
		switch lowerVal {
		case Bool, Int, Uint, Float, String, Time, Bytes:
			field.DataType = lowerVal
		default:
			field.DataType = DataType(val)
		}
	}

	// Without a SIZE tag, numeric kinds default to their bit width (int/uint count as 64).
	if field.Size == 0 {
		switch reflect.Indirect(fieldValue).Kind() {
		case reflect.Int, reflect.Int64, reflect.Uint, reflect.Uint64, reflect.Float64:
			field.Size = 64
		case reflect.Int8, reflect.Uint8:
			field.Size = 8
		case reflect.Int16, reflect.Uint16:
			field.Size = 16
		case reflect.Int32, reflect.Uint32, reflect.Float32:
			field.Size = 32
		}
	}

	// setup permission
	// "-" tag: "-" and "all" disable read/write; "all" and "migration" also skip migration.
	if val, ok := field.TagSettings["-"]; ok {
		val = strings.ToLower(strings.TrimSpace(val))
		switch val {
		case "-":
			field.Creatable = false
			field.Updatable = false
			field.Readable = false
			field.DataType = ""
		case "all":
			field.Creatable = false
			field.Updatable = false
			field.Readable = false
			field.DataType = ""
			field.IgnoreMigration = true
		case "migration":
			field.IgnoreMigration = true
		}
	}

	// "->" tag: read-only, unless the value is "false" (then not readable either).
	if v, ok := field.TagSettings["->"]; ok {
		field.Creatable = false
		field.Updatable = false
		if strings.ToLower(v) == "false" {
			field.Readable = false
		} else {
			field.Readable = true
		}
	}

	// "<-" tag: writable. A value of exactly "<-" (presumably what ParseTagSetting
	// stores for a bare tag — TODO confirm) means create and update; otherwise
	// the value must mention "create" and/or "update".
	if v, ok := field.TagSettings["<-"]; ok {
		field.Creatable = true
		field.Updatable = true

		if v != "<-" {
			if !strings.Contains(v, "create") {
				field.Creatable = false
			}

			if !strings.Contains(v, "update") {
				field.Updatable = false
			}
		}
	}

	// Normal anonymous field or having `EMBEDDED` tag
	if _, ok := field.TagSettings["EMBEDDED"]; ok || (field.GORMDataType != Time && field.GORMDataType != Bytes && !isValuer &&
		fieldStruct.Anonymous && (field.Creatable || field.Updatable || field.Readable)) {
		kind := reflect.Indirect(fieldValue).Kind()
		switch kind {
		case reflect.Struct:
			var err error
			// The embedding field itself is not a column; its sub-fields are hoisted instead.
			field.Creatable = false
			field.Updatable = false
			field.Readable = false

			cacheStore := &sync.Map{}
			cacheStore.Store(embeddedCacheKey, true)
			if field.EmbeddedSchema, err = getOrParse(fieldValue.Interface(), cacheStore, embeddedNamer{Table: schema.Table, Namer: schema.namer}); err != nil {
				schema.err = err
			}

			// NOTE(review): if getOrParse can return a nil *Schema together with an
			// error, field.EmbeddedSchema.Fields below would panic — confirm its contract.
			for _, ef := range field.EmbeddedSchema.Fields {
				ef.Schema = schema
				ef.OwnerSchema = field.EmbeddedSchema
				ef.BindNames = append([]string{fieldStruct.Name}, ef.BindNames...)
				if _, ok := field.TagSettings["EMBEDDED"]; ok || !fieldStruct.Anonymous {
					ef.EmbeddedBindNames = append([]string{fieldStruct.Name}, ef.EmbeddedBindNames...)
				}
				// index is negative means is pointer
				// (encoded as -(i+1) so index 0 stays distinguishable; decoded as -idx-1 in setupValuerAndSetter)
				if field.FieldType.Kind() == reflect.Struct {
					ef.StructField.Index = append([]int{fieldStruct.Index[0]}, ef.StructField.Index...)
				} else {
					ef.StructField.Index = append([]int{-fieldStruct.Index[0] - 1}, ef.StructField.Index...)
				}

				if prefix, ok := field.TagSettings["EMBEDDEDPREFIX"]; ok && ef.DBName != "" {
					ef.DBName = prefix + ef.DBName
				}

				// A sub-field that is a primary key only by inference (no explicit
				// PRIMARYKEY tag) loses that status once embedded.
				if ef.PrimaryKey {
					if !utils.CheckTruth(ef.TagSettings["PRIMARYKEY"], ef.TagSettings["PRIMARY_KEY"]) {
						ef.PrimaryKey = false

						if val, ok := ef.TagSettings["AUTOINCREMENT"]; !ok || !utils.CheckTruth(val) {
							ef.AutoIncrement = false
						}

						if !ef.AutoIncrement && ef.DefaultValue == "" {
							ef.HasDefaultValue = false
						}
					}
				}

				// The embedding field's tag settings are copied onto every sub-field, overwriting same-named keys.
				for k, v := range field.TagSettings {
					ef.TagSettings[k] = v
				}
			}
		case reflect.Invalid, reflect.Uintptr, reflect.Array, reflect.Chan, reflect.Func, reflect.Interface,
			reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer, reflect.Complex64, reflect.Complex128:
			schema.err = fmt.Errorf("invalid embedded struct for %s's field %s, should be struct, but got %v", field.Schema.Name, field.Name, field.FieldType)
		}
	}

	return field
}

// create valuer, setter when parse struct
//
// setupValuerAndSetter installs field.NewValuePool, field.ValueOf,
// field.ReflectValueOf and field.Set. modelType is not referenced in this body.
func (field *Field) setupValuerAndSetter(modelType reflect.Type) {
	// Setup NewValuePool
	field.setupNewValuePool()

	// ValueOf returns field's value and if it is zero
	fieldIndex := field.StructField.Index[0]
	switch {
	case len(field.StructField.Index) == 1 && fieldIndex >= 0:
		// Fast path: a direct, non-pointer field of the model struct.
		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
			v = reflect.Indirect(v)
			fieldValue := v.Field(fieldIndex)
			return fieldValue.Interface(), fieldValue.IsZero()
		}
	default:
		// Walk the nested index path; negative entries denote embedded pointers.
		// A nil embedded pointer reports (nil, true) without allocating.
		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
			v = reflect.Indirect(v)
			for _, fieldIdx := range field.StructField.Index {
				if fieldIdx >= 0 {
					v = v.Field(fieldIdx)
				} else {
					v = v.Field(-fieldIdx - 1)

					if !v.IsNil() {
						v = v.Elem()
					} else {
						return nil, true
					}
				}
			}

			fv, zero := v.Interface(), v.IsZero()
			return fv, zero
		}
	}

	// For serializer fields, ValueOf returns a *serializer wrapping the raw value.
	// The value's own SerializerValuerInterface is used if it implements one,
	// else the field's Serializer.
	if field.Serializer != nil {
		oldValuerOf := field.ValueOf
		field.ValueOf = func(ctx context.Context, v reflect.Value) (interface{}, bool) {
			value, zero := oldValuerOf(ctx, v)

			s, ok := value.(SerializerValuerInterface)
			if !ok {
				s = field.Serializer
			}

			return &serializer{
				Field: field,
				SerializeValuer: s,
				Destination: v,
				Context: ctx,
				fieldValue: value,
			}, zero
		}
	}

	// ReflectValueOf returns field's reflect value
	switch {
	case len(field.StructField.Index) == 1 && fieldIndex >= 0:
		field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
			v = reflect.Indirect(v)
			return v.Field(fieldIndex)
		}
	default:
		// Unlike ValueOf, this allocates nil embedded pointers along the path
		// (mutating v). A pointer at the last step is returned as-is, not dereferenced.
		field.ReflectValueOf = func(ctx context.Context, v reflect.Value) reflect.Value {
			v = reflect.Indirect(v)
			for idx, fieldIdx := range field.StructField.Index {
				if fieldIdx >= 0 {
					v = v.Field(fieldIdx)
				} else {
					v = v.Field(-fieldIdx - 1)

					if v.IsNil() {
						v.Set(reflect.New(v.Type().Elem()))
					}

					if idx < len(field.StructField.Index)-1 {
						v = v.Elem()
					}
				}
			}
			return v
		}
	}

	// fallbackSetter handles values the kind-specific setters below don't match.
	// Attempts, in order:
	//   - nil: store the zero value;
	//   - assign or convert directly;
	//   - for pointer fields, assign or convert into a (possibly new) pointee;
	//   - dereference pointer inputs and retry via setter;
	//   - unwrap driver.Valuer and retry via setter.
	// clause.Expr values are accepted as a no-op; anything else is an error.
	fallbackSetter := func(ctx context.Context, value reflect.Value, v interface{}, setter func(context.Context, reflect.Value, interface{}) error) (err error) {
		if v == nil {
			field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
		} else {
			reflectV := reflect.ValueOf(v)
			// Optimal value type acquisition for v
			reflectValType := reflectV.Type()

			if reflectValType.AssignableTo(field.FieldType) {
				if reflectV.Kind() == reflect.Ptr && reflectV.Elem().Kind() == reflect.Ptr {
					reflectV = reflect.Indirect(reflectV)
				}
				field.ReflectValueOf(ctx, value).Set(reflectV)
				return
			} else if reflectValType.ConvertibleTo(field.FieldType) {
				field.ReflectValueOf(ctx, value).Set(reflectV.Convert(field.FieldType))
				return
			} else if field.FieldType.Kind() == reflect.Ptr {
				fieldValue := field.ReflectValueOf(ctx, value)
				fieldType := field.FieldType.Elem()

				if reflectValType.AssignableTo(fieldType) {
					// NOTE(review): in the !IsValid branch the new pointer is only a local,
					// so the Set below does not reach the destination — confirm intent.
					if !fieldValue.IsValid() {
						fieldValue = reflect.New(fieldType)
					} else if fieldValue.IsNil() {
						fieldValue.Set(reflect.New(fieldType))
					}
					fieldValue.Elem().Set(reflectV)
					return
				} else if reflectValType.ConvertibleTo(fieldType) {
					if fieldValue.IsNil() {
						fieldValue.Set(reflect.New(fieldType))
					}

					fieldValue.Elem().Set(reflectV.Convert(fieldType))
					return
				}
			}

			if reflectV.Kind() == reflect.Ptr {
				if reflectV.IsNil() {
					// a typed nil pointer resets the field to its zero value
					field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
				} else if reflectV.Type().Elem().AssignableTo(field.FieldType) {
					field.ReflectValueOf(ctx, value).Set(reflectV.Elem())
					return
				} else {
					err = setter(ctx, value, reflectV.Elem().Interface())
				}
			} else if valuer, ok := v.(driver.Valuer); ok {
				if v, err = valuer.Value(); err == nil {
					err = setter(ctx, value, v)
				}
			} else if _, ok := v.(clause.Expr); !ok {
				return fmt.Errorf("failed to set value %#v to field %s", v, field.Name)
			}
		}

		return
	}

	// Set
	// Each kind gets type-switch fast paths; unmatched inputs go to fallbackSetter
	// with field.Set as the retry setter. Numeric cases use Go conversions,
	// so floats truncate and out-of-range integers wrap.
	switch field.FieldType.Kind() {
	case reflect.Bool:
		// int64 > 0 is true; an unparsable string sets false (ParseBool error ignored).
		field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
			switch data := v.(type) {
			case **bool:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetBool(**data)
				}
			case bool:
				field.ReflectValueOf(ctx, value).SetBool(data)
			case int64:
				field.ReflectValueOf(ctx, value).SetBool(data > 0)
			case string:
				b, _ := strconv.ParseBool(data)
				field.ReflectValueOf(ctx, value).SetBool(b)
			default:
				return fallbackSetter(ctx, value, v, field.Set)
			}
			return nil
		}
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
		field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
			switch data := v.(type) {
			case **int64:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetInt(**data)
				}
			case **int:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetInt(int64(**data))
				}
			case **int8:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetInt(int64(**data))
				}
			case **int16:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetInt(int64(**data))
				}
			case **int32:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetInt(int64(**data))
				}
			case int64:
				field.ReflectValueOf(ctx, value).SetInt(data)
			case int:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case int8:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case int16:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case int32:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case uint:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case uint8:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case uint16:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case uint32:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case uint64:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case float32:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case float64:
				field.ReflectValueOf(ctx, value).SetInt(int64(data))
			case []byte:
				return field.Set(ctx, value, string(data))
			case string:
				if i, err := strconv.ParseInt(data, 0, 64); err == nil {
					field.ReflectValueOf(ctx, value).SetInt(i)
				} else {
					return err
				}
			case time.Time:
				// store a Unix timestamp at the precision chosen by the auto create/update time settings
				if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
					field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
				} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
					field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
				} else {
					field.ReflectValueOf(ctx, value).SetInt(data.Unix())
				}
			case *time.Time:
				if data != nil {
					if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
						field.ReflectValueOf(ctx, value).SetInt(data.UnixNano())
					} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
						field.ReflectValueOf(ctx, value).SetInt(data.UnixMilli())
					} else {
						field.ReflectValueOf(ctx, value).SetInt(data.Unix())
					}
				} else {
					field.ReflectValueOf(ctx, value).SetInt(0)
				}
			default:
				return fallbackSetter(ctx, value, v, field.Set)
			}
			return err
		}
	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
		// No *time.Time case here: such inputs reach fallbackSetter, which
		// dereferences them (or zeroes the field for nil).
		field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
			switch data := v.(type) {
			case **uint64:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetUint(**data)
				}
			case **uint:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
				}
			case **uint8:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
				}
			case **uint16:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
				}
			case **uint32:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetUint(uint64(**data))
				}
			case uint64:
				field.ReflectValueOf(ctx, value).SetUint(data)
			case uint:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case uint8:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case uint16:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case uint32:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case int64:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case int:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case int8:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case int16:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case int32:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case float32:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case float64:
				field.ReflectValueOf(ctx, value).SetUint(uint64(data))
			case []byte:
				return field.Set(ctx, value, string(data))
			case time.Time:
				if field.AutoCreateTime == UnixNanosecond || field.AutoUpdateTime == UnixNanosecond {
					field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixNano()))
				} else if field.AutoCreateTime == UnixMillisecond || field.AutoUpdateTime == UnixMillisecond {
					field.ReflectValueOf(ctx, value).SetUint(uint64(data.UnixMilli()))
				} else {
					field.ReflectValueOf(ctx, value).SetUint(uint64(data.Unix()))
				}
			case string:
				if i, err := strconv.ParseUint(data, 0, 64); err == nil {
					field.ReflectValueOf(ctx, value).SetUint(i)
				} else {
					return err
				}
			default:
				return fallbackSetter(ctx, value, v, field.Set)
			}
			return err
		}
	case reflect.Float32, reflect.Float64:
		field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
			switch data := v.(type) {
			case **float64:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetFloat(**data)
				}
			case **float32:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetFloat(float64(**data))
				}
			case float64:
				field.ReflectValueOf(ctx, value).SetFloat(data)
			case float32:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case int64:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case int:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case int8:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case int16:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case int32:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case uint:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case uint8:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case uint16:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case uint32:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case uint64:
				field.ReflectValueOf(ctx, value).SetFloat(float64(data))
			case []byte:
				return field.Set(ctx, value, string(data))
			case string:
				if i, err := strconv.ParseFloat(data, 64); err == nil {
					field.ReflectValueOf(ctx, value).SetFloat(i)
				} else {
					return err
				}
			default:
				return fallbackSetter(ctx, value, v, field.Set)
			}
			return err
		}
	case reflect.String:
		field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
			switch data := v.(type) {
			case **string:
				if data != nil && *data != nil {
					field.ReflectValueOf(ctx, value).SetString(**data)
				}
			case string:
				field.ReflectValueOf(ctx, value).SetString(data)
			case []byte:
				field.ReflectValueOf(ctx, value).SetString(string(data))
			case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
				field.ReflectValueOf(ctx, value).SetString(utils.ToString(data))
			case float64, float32:
				// formatted with field.Precision digits after the decimal point
				field.ReflectValueOf(ctx, value).SetString(fmt.Sprintf("%."+strconv.Itoa(field.Precision)+"f", data))
			default:
				return fallbackSetter(ctx, value, v, field.Set)
			}
			return err
		}
	default:
		// Non-primitive kinds: dispatch on the zero value's dynamic type.
		fieldValue := reflect.New(field.FieldType)
		switch fieldValue.Elem().Interface().(type) {
		case time.Time:
			field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
				switch data := v.(type) {
				case **time.Time:
					// NOTE(review): the error from this recursive Set is discarded.
					if data != nil && *data != nil {
						field.Set(ctx, value, *data)
					}
				case time.Time:
					field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
				case *time.Time:
					if data != nil {
						field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(data).Elem())
					} else {
						field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(time.Time{}))
					}
				case string:
					if t, err := now.Parse(data); err == nil {
						field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(t))
					} else {
						return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
					}
				default:
					return fallbackSetter(ctx, value, v, field.Set)
				}
				return nil
			}
		case *time.Time:
			field.Set = func(ctx context.Context, value reflect.Value, v interface{}) error {
				switch data := v.(type) {
				case **time.Time:
					if data != nil && *data != nil {
						field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(*data))
					}
				case time.Time:
					fieldValue := field.ReflectValueOf(ctx, value)
					if fieldValue.IsNil() {
						fieldValue.Set(reflect.New(field.FieldType.Elem()))
					}
					fieldValue.Elem().Set(reflect.ValueOf(v))
				case *time.Time:
					field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(v))
				case string:
					if t, err := now.Parse(data); err == nil {
						fieldValue := field.ReflectValueOf(ctx, value)
						if fieldValue.IsNil() {
							// keep a nil pointer for an empty string (only reached if now.Parse accepted it)
							if v == "" {
								return nil
							}
							fieldValue.Set(reflect.New(field.FieldType.Elem()))
						}
						fieldValue.Elem().Set(reflect.ValueOf(t))
					} else {
						return fmt.Errorf("failed to set string %v to time.Time field %s, failed to parse it as time, got error %v", v, field.Name, err)
					}
				default:
					return fallbackSetter(ctx, value, v, field.Set)
				}
				return nil
			}
		default:
			if _, ok := fieldValue.Elem().Interface().(sql.Scanner); ok {
				// pointer scanner
				// (FieldType itself implements sql.Scanner; a nil destination is allocated before Scan)
				field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
					reflectV := reflect.ValueOf(v)
					if !reflectV.IsValid() {
						field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
					} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
						return
					} else if reflectV.Type().AssignableTo(field.FieldType) {
						field.ReflectValueOf(ctx, value).Set(reflectV)
					} else if reflectV.Kind() == reflect.Ptr {
						return field.Set(ctx, value, reflectV.Elem().Interface())
					} else {
						fieldValue := field.ReflectValueOf(ctx, value)
						if fieldValue.IsNil() {
							fieldValue.Set(reflect.New(field.FieldType.Elem()))
						}

						// errors from Value() are ignored; Scan receives whatever it returned
						if valuer, ok := v.(driver.Valuer); ok {
							v, _ = valuer.Value()
						}

						err = fieldValue.Interface().(sql.Scanner).Scan(v)
					}
					return
				}
			} else if _, ok := fieldValue.Interface().(sql.Scanner); ok {
				// struct scanner
				// (*FieldType implements sql.Scanner; Scan is called on the field's address)
				field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
					reflectV := reflect.ValueOf(v)
					if !reflectV.IsValid() {
						field.ReflectValueOf(ctx, value).Set(reflect.New(field.FieldType).Elem())
					} else if reflectV.Kind() == reflect.Ptr && reflectV.IsNil() {
						return
					} else if reflectV.Type().AssignableTo(field.FieldType) {
						field.ReflectValueOf(ctx, value).Set(reflectV)
					} else if reflectV.Kind() == reflect.Ptr {
						return field.Set(ctx, value, reflectV.Elem().Interface())
					} else {
						if valuer, ok := v.(driver.Valuer); ok {
							v, _ = valuer.Value()
						}

						err = field.ReflectValueOf(ctx, value).Addr().Interface().(sql.Scanner).Scan(v)
					}
					return
				}
			} else {
				field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
					return fallbackSetter(ctx, value, v, field.Set)
				}
			}
		}
	}

	// Serializer fields: wrap the setter so *serializer inputs are decoded.
	//   - both fieldValue and Serializer nil: zero the field;
	//   - fieldValue set: delegate to the previous setter;
	//   - otherwise: Scan through the serializer, copy the result into the field
	//     when its type matches the serializer's type, then replace s.Serializer
	//     with a fresh copy of the prototype.
	if field.Serializer != nil {
		var (
			oldFieldSetter = field.Set
			sameElemType bool
			sameType = field.FieldType == reflect.ValueOf(field.Serializer).Type()
		)

		if reflect.ValueOf(field.Serializer).Kind() == reflect.Ptr {
			sameElemType = field.FieldType == reflect.ValueOf(field.Serializer).Type().Elem()
		}

		serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
		serializerType := serializerValue.Type()
		field.Set = func(ctx context.Context, value reflect.Value, v interface{}) (err error) {
			if s, ok := v.(*serializer); ok {
				if s.fieldValue == nil && s.Serializer == nil {
					rv := field.ReflectValueOf(ctx, value)
					if rv.IsValid() && rv.CanSet() {
						rv.Set(reflect.Zero(field.FieldType))
					}
					return nil
				}
				if s.fieldValue != nil {
					err = oldFieldSetter(ctx, value, s.fieldValue)
				} else if err = s.Serializer.Scan(ctx, field, value, s.value); err == nil {
					if sameElemType {
						field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer).Elem())
					} else if sameType {
						field.ReflectValueOf(ctx, value).Set(reflect.ValueOf(s.Serializer))
					}
					si := reflect.New(serializerType)
					si.Elem().Set(serializerValue)
					s.Serializer = si.Interface().(SerializerInterface)
				}
			} else {
				err = oldFieldSetter(ctx, value, v)
			}
			return
		}
	}
}

// setupNewValuePool initializes field.NewValuePool.
// Serializer fields get a sync.Pool of *serializer values, each holding a
// fresh copy of the field's serializer prototype. Otherwise, if no pool is set
// yet, poolInitializer (defined elsewhere) is called with *IndirectFieldType.
func (field *Field) setupNewValuePool() {
	if field.Serializer != nil {
		serializerValue := reflect.Indirect(reflect.ValueOf(field.Serializer))
		serializerType := serializerValue.Type()
		field.NewValuePool = &sync.Pool{
			New: func() interface{} {
				si := reflect.New(serializerType)
				si.Elem().Set(serializerValue)
				return &serializer{
					Field: field,
					Serializer: si.Interface().(SerializerInterface),
				}
			},
		}
	}

	if field.NewValuePool == nil {
		field.NewValuePool = poolInitializer(reflect.PointerTo(field.IndirectFieldType))
	}
}

================================================
FILE: schema/field_test.go
================================================
package schema_test

import (
	"context"
	"database/sql"
	"reflect"
	"sync"
	"testing"
	"time"

	"gorm.io/gorm"
	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils/tests"
)

// TestFieldValuerAndSetter checks ValueOf/Set on tests.User with plain values,
// nil / typed-nil inputs, and valuer or custom-typed inputs.
func TestFieldValuerAndSetter(t *testing.T) {
	var (
		userSchema, _ = schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
		user = tests.User{
			Model: gorm.Model{
				ID: 10,
				CreatedAt: time.Now(),
				UpdatedAt: time.Now(),
				DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true},
			},
			Name: "valuer_and_setter",
			Age: 18,
			Birthday: tests.Now(),
			Active: true,
		}
		reflectValue = reflect.ValueOf(&user)
	)

	// test valuer
	values := map[string]interface{}{
		"name": user.Name,
		"id": user.ID,
		"created_at": user.CreatedAt,
		"updated_at": user.UpdatedAt,
		"deleted_at": user.DeletedAt,
		"age": user.Age,
		"birthday": user.Birthday,
		"active": true,
	}
	checkField(t, userSchema, reflectValue, values)

	var f *bool
	// test setter
	newValues := map[string]interface{}{
		"name": "valuer_and_setter_2",
		"id": 2,
		"created_at": time.Now(),
		"updated_at": nil,
		"deleted_at": time.Now(),
		"age": 20,
		"birthday": time.Now(),
		"active": f,
	}

	for k, v := range newValues {
		if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
		}
	}
	// nil and a typed nil *bool are expected to reset the fields to their zero values
	newValues["updated_at"] = time.Time{}
	newValues["active"] = false
	checkField(t, userSchema, reflectValue, newValues)

	// test valuer and other type
	age := myint(10)
	var nilTime *time.Time
	newValues2 := map[string]interface{}{
		"name": sql.NullString{String: "valuer_and_setter_3", Valid: true},
		"id": &sql.NullInt64{Int64: 3, Valid: true},
		"created_at": tests.Now(),
		"updated_at": nilTime,
		"deleted_at": time.Now(),
		"age": &age,
		"birthday": mytime(time.Now()),
		"active": mybool(true),
	}

	for k, v := range newValues2 {
		if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
		}
	}
	newValues2["updated_at"] = time.Time{}
	checkField(t, userSchema, reflectValue, newValues2)
}

// TestPointerFieldValuerAndSetter covers pointer fields and an embedded *gorm.Model.
func TestPointerFieldValuerAndSetter(t *testing.T) {
	var (
		userSchema, _ = schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
		name = "pointer_field_valuer_and_setter"
		age uint = 18
		active = true
		user = User{
			Model: &gorm.Model{
				ID: 10,
				CreatedAt: time.Now(),
				DeletedAt: gorm.DeletedAt{Time: time.Now(), Valid: true},
			},
			Name: &name,
			Age: &age,
			Birthday: tests.Now(),
			Active: &active,
		}
		reflectValue = reflect.ValueOf(&user)
	)

	// test valuer
	values := map[string]interface{}{
		"name": user.Name,
		"id": user.ID,
		"created_at": user.CreatedAt,
		"deleted_at": user.DeletedAt,
		"age": user.Age,
		"birthday": user.Birthday,
		"active": true,
	}
	checkField(t, userSchema, reflectValue, values)

	// test setter
	newValues := map[string]interface{}{
		"name": "valuer_and_setter_2",
		"id": 2,
		"created_at": time.Now(),
		"deleted_at": time.Now(),
		"age": 20,
		"birthday": time.Now(),
		"active": false,
	}

	for k, v := range newValues {
		if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
		}
	}
	checkField(t, userSchema, reflectValue, newValues)

	// test valuer and other type
	age2 := myint(10)
	newValues2 := map[string]interface{}{
		"name": sql.NullString{String: "valuer_and_setter_3", Valid: true},
		"id": &sql.NullInt64{Int64: 3, Valid: true},
		"created_at": tests.Now(),
		"deleted_at": time.Now(),
		"age": &age2,
		"birthday": mytime(time.Now()),
		"active": mybool(true),
	}

	for k, v := range newValues2 {
		if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
		}
	}
	checkField(t, userSchema, reflectValue, newValues2)
}

// TestAdvancedDataTypeValuerAndSetter covers sql.Null* and custom named types,
// set from both matching and plain Go values.
func TestAdvancedDataTypeValuerAndSetter(t *testing.T) {
	var (
		userSchema, _ = schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
		name = "advanced_data_type_valuer_and_setter"
		deletedAt = mytime(time.Now())
		isAdmin = mybool(false)
		user = AdvancedDataTypeUser{
			ID: sql.NullInt64{Int64: 10, Valid: true},
			Name: &sql.NullString{String: name, Valid: true},
			Birthday: sql.NullTime{Time: time.Now(), Valid: true},
			RegisteredAt: mytime(time.Now()),
			DeletedAt: &deletedAt,
			Active: mybool(true),
			Admin: &isAdmin,
		}
		reflectValue = reflect.ValueOf(&user)
	)

	// test valuer
	values := map[string]interface{}{
		"id": user.ID,
		"name": user.Name,
		"birthday": user.Birthday,
		"registered_at": user.RegisteredAt,
		"deleted_at": user.DeletedAt,
		"active": user.Active,
		"admin": user.Admin,
	}
	checkField(t, userSchema, reflectValue, values)

	// test setter
	newDeletedAt := mytime(time.Now())
	newIsAdmin := mybool(true)
	newValues := map[string]interface{}{
		"id": sql.NullInt64{Int64: 1, Valid: true},
		"name": &sql.NullString{String: name + "rename", Valid: true},
		"birthday": time.Now(),
		"registered_at": mytime(time.Now()),
		"deleted_at": &newDeletedAt,
		"active": mybool(false),
		"admin": &newIsAdmin,
	}

	for k, v := range newValues {
		if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
		}
	}
	checkField(t, userSchema, reflectValue, newValues)

	// plain Go values assigned into the advanced types
	newValues2 := map[string]interface{}{
		"id": 5,
		"name": name + "rename2",
		"birthday": time.Now(),
		"registered_at": time.Now(),
		"deleted_at": time.Now(),
		"active": true,
		"admin": false,
	}

	for k, v := range newValues2 {
		if err := userSchema.FieldsByDBName[k].Set(context.Background(), reflectValue, v); err != nil {
			t.Errorf("no error should happen when assign value to field %v, but got %v", k, err)
		}
	}
	checkField(t, userSchema, reflectValue, newValues2)
}

// UserWithPermissionControl has one field per permission-tag combination.
type UserWithPermissionControl struct {
	ID uint
	Name string `gorm:"-"`
	Name2 string `gorm:"->"`
	Name3 string `gorm:"<-"`
	Name4 string `gorm:"<-:create"`
	Name5 string `gorm:"<-:update"`
	Name6 string `gorm:"<-:create,update"`
	Name7 string `gorm:"->:false;<-:create,update"`
	Name8 string `gorm:"->;-:migration"`
}

// TestParseFieldWithPermission checks Creatable/Updatable/Readable/IgnoreMigration
// derived from the "-", "->" and "<-" tags.
func TestParseFieldWithPermission(t *testing.T) {
	user, err := schema.Parse(&UserWithPermissionControl{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("Failed to parse user with permission, got error %v", err)
	}

	fields := []*schema.Field{
		{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, Creatable: true, Updatable: true, Readable: true, HasDefaultValue: true, AutoIncrement: true},
		{Name: "Name", DBName: "", BindNames: []string{"Name"}, DataType: "", Tag: `gorm:"-"`, Creatable: false, Updatable: false, Readable: false},
		{Name: "Name2", DBName: "name2", BindNames: []string{"Name2"}, DataType: schema.String, Tag: `gorm:"->"`, Creatable: false, Updatable: false, Readable: true},
		{Name: "Name3", DBName: "name3", BindNames: []string{"Name3"}, DataType: schema.String, Tag: `gorm:"<-"`, Creatable: true, Updatable: true, Readable: true},
		{Name: "Name4", DBName: "name4", BindNames: []string{"Name4"}, DataType: schema.String, Tag: `gorm:"<-:create"`, Creatable: true, Updatable: false, Readable: true},
		{Name: "Name5", DBName: "name5", BindNames: []string{"Name5"}, DataType: schema.String, Tag: `gorm:"<-:update"`, Creatable: false, Updatable: true, Readable: true},
		{Name: "Name6", DBName: "name6", BindNames: []string{"Name6"}, DataType: schema.String, Tag: `gorm:"<-:create,update"`, Creatable: true, Updatable: true, Readable: true},
		{Name: "Name7", DBName: "name7", BindNames: []string{"Name7"}, DataType: schema.String, Tag: `gorm:"->:false;<-:create,update"`, Creatable: true, Updatable: true, Readable: false},
		{Name: "Name8", DBName: "name8", BindNames: []string{"Name8"}, DataType: schema.String, Tag: `gorm:"->;-:migration"`, Creatable: false, Updatable: false, Readable: true, IgnoreMigration: true},
	}

	for _, f := range fields {
		checkSchemaField(t, user, f, func(f *schema.Field) {})
	}
}

// Named types over builtin kinds: DataType and Size should follow the underlying kind.
type (
	ID int64
	INT int
	INT8 int8
	INT16 int16
	INT32 int32
	INT64 int64
	UINT uint
	UINT8 uint8
	UINT16 uint16
	UINT32 uint32
	UINT64 uint64
	FLOAT32 float32
	FLOAT64 float64
	BOOL bool
	STRING string
	TIME time.Time
	BYTES []byte

	// TypeAlias embeds each named type as an anonymous field with an explicit column.
	TypeAlias struct {
		ID
		INT `gorm:"column:fint"`
		INT8 `gorm:"column:fint8"`
		INT16 `gorm:"column:fint16"`
		INT32 `gorm:"column:fint32"`
		INT64 `gorm:"column:fint64"`
		UINT `gorm:"column:fuint"`
		UINT8 `gorm:"column:fuint8"`
		UINT16 `gorm:"column:fuint16"`
		UINT32 `gorm:"column:fuint32"`
		UINT64 `gorm:"column:fuint64"`
		FLOAT32 `gorm:"column:ffloat32"`
		FLOAT64 `gorm:"column:ffloat64"`
		BOOL `gorm:"column:fbool"`
		STRING `gorm:"column:fstring"`
		TIME `gorm:"column:ftime"`
		BYTES `gorm:"column:fbytes"`
	}
)

// TestTypeAliasField checks the parsed fields of TypeAlias.
func TestTypeAliasField(t *testing.T) {
	alias, err := schema.Parse(&TypeAlias{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("Failed to parse TypeAlias with permission, got error %v", err)
	}

	fields := []*schema.Field{
		{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, PrimaryKey: true, HasDefaultValue: true, AutoIncrement: true},
		{Name: "INT", DBName: "fint", BindNames: []string{"INT"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint"`},
		{Name: "INT8", DBName: "fint8", BindNames: []string{"INT8"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fint8"`},
		{Name: "INT16", DBName: "fint16", BindNames: []string{"INT16"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fint16"`},
		{Name: "INT32", DBName: "fint32", BindNames: []string{"INT32"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fint32"`},
		{Name: "INT64", DBName: "fint64", BindNames: []string{"INT64"}, DataType: schema.Int, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fint64"`},
		{Name: "UINT", DBName: "fuint", BindNames: []string{"UINT"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint"`},
		{Name: "UINT8", DBName: "fuint8", BindNames: []string{"UINT8"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 8, Tag: `gorm:"column:fuint8"`},
		{Name: "UINT16", DBName: "fuint16", BindNames: []string{"UINT16"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 16, Tag: `gorm:"column:fuint16"`},
		{Name: "UINT32", DBName: "fuint32", BindNames: []string{"UINT32"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:fuint32"`},
		{Name: "UINT64", DBName: "fuint64", BindNames: []string{"UINT64"}, DataType: schema.Uint, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:fuint64"`},
		{Name: "FLOAT32", DBName: "ffloat32", BindNames: []string{"FLOAT32"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 32, Tag: `gorm:"column:ffloat32"`},
		{Name: "FLOAT64", DBName: "ffloat64", BindNames: []string{"FLOAT64"}, DataType: schema.Float, Creatable: true, Updatable: true, Readable: true, Size: 64, Tag: `gorm:"column:ffloat64"`},
		{Name: "BOOL", DBName: "fbool", BindNames: []string{"BOOL"}, DataType: schema.Bool, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbool"`},
		{Name: "STRING", DBName: "fstring", BindNames: []string{"STRING"}, DataType: schema.String, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fstring"`},
		{Name: "TIME", DBName: "ftime", BindNames: []string{"TIME"}, DataType: schema.Time, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:ftime"`},
		{Name: "BYTES", DBName: "fbytes", BindNames: []string{"BYTES"}, DataType: schema.Bytes, Creatable: true, Updatable: true, Readable: true, Tag: `gorm:"column:fbytes"`},
	}

	for _, f := range fields {
		checkSchemaField(t, alias, f, func(f *schema.Field) {})
	}
}

================================================
FILE: schema/index.go
================================================
package schema

import (
	"fmt"
	"sort"
	"strconv"
	"strings"
)

// Index is one database index assembled by ParseIndexes from the
// INDEX / UNIQUEINDEX tags of one or more fields sharing the same name.
type Index struct {
	Name string
	Class string // UNIQUE | FULLTEXT | SPATIAL
	Type string // btree, hash, gist, spgist, gin, and brin
	Where string
	Comment string
	Option string // WITH PARSER parser_name
	Fields []IndexOption // Note: IndexOption's Field maybe the same
}

// IndexOption is one column entry of an Index: the field plus per-column options.
type IndexOption struct {
	*Field
	Expression string
	Sort string // DESC, ASC
	Collate string
	Length int
	Priority int
}

// ParseIndexes parse schema indexes
func (schema *Schema) ParseIndexes() []*Index {
	indexesByName := map[string]*Index{}
	indexes := []*Index{}

	for _, field := range schema.Fields {
		if field.TagSettings["INDEX"] != "" || field.TagSettings["UNIQUEINDEX"] != "" {
			fieldIndexes, err := parseFieldIndexes(field)
			if err != nil {
				schema.err = err
				break
			}
			for _, index := range fieldIndexes {
				idx := indexesByName[index.Name]
				if idx == nil {
					idx = &Index{Name: index.Name}
					indexesByName[index.Name] = idx
					indexes = append(indexes, idx)
				}
				idx.Name = index.Name
				if idx.Class == "" {
					idx.Class = index.Class
				}
				if idx.Type == "" {
					idx.Type = index.Type
				}
				if idx.Where == "" {
					idx.Where = index.Where
				}
				if idx.Comment == "" {
					idx.Comment = index.Comment
				}
				if idx.Option == "" {
					idx.Option = index.Option
				}

				idx.Fields = append(idx.Fields, index.Fields...)
sort.Slice(idx.Fields, func(i, j int) bool { return idx.Fields[i].Priority < idx.Fields[j].Priority }) } } } for _, index := range indexes { if index.Class == "UNIQUE" && len(index.Fields) == 1 { index.Fields[0].Field.UniqueIndex = index.Name } } return indexes } func (schema *Schema) LookIndex(name string) *Index { if schema != nil { indexes := schema.ParseIndexes() for _, index := range indexes { if index.Name == name { return index } for _, field := range index.Fields { if field.Name == name { return index } } } } return nil } func parseFieldIndexes(field *Field) (indexes []Index, err error) { for _, value := range strings.Split(field.Tag.Get("gorm"), ";") { if value != "" { v := strings.Split(value, ":") k := strings.TrimSpace(strings.ToUpper(v[0])) if k == "INDEX" || k == "UNIQUEINDEX" { var ( name string tag = strings.Join(v[1:], ":") idx = strings.IndexByte(tag, ',') tagSetting = strings.Join(strings.Split(tag, ",")[1:], ",") settings = ParseTagSetting(tagSetting, ",") length, _ = strconv.Atoi(settings["LENGTH"]) ) if idx == -1 { idx = len(tag) } name = tag[0:idx] if name == "" { subName := field.Name const key = "COMPOSITE" if composite, found := settings[key]; found { if len(composite) == 0 || composite == key { err = fmt.Errorf( "the composite tag of %s.%s cannot be empty", field.Schema.Name, field.Name) return } subName = composite } name = field.Schema.namer.IndexName( field.Schema.Table, subName) } if (k == "UNIQUEINDEX") || settings["UNIQUE"] != "" { settings["CLASS"] = "UNIQUE" } priority, err := strconv.Atoi(settings["PRIORITY"]) if err != nil { priority = 10 } indexes = append(indexes, Index{ Name: name, Class: settings["CLASS"], Type: settings["TYPE"], Where: settings["WHERE"], Comment: settings["COMMENT"], Option: settings["OPTION"], Fields: []IndexOption{{ Field: field, Expression: settings["EXPRESSION"], Sort: settings["SORT"], Collate: settings["COLLATE"], Length: length, Priority: priority, }}, }) } } } err = nil return } 
================================================ FILE: schema/index_test.go ================================================ package schema_test import ( "sync" "testing" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) type UserIndex struct { Name string `gorm:"index"` Name2 string `gorm:"index:idx_name,unique"` Name3 string `gorm:"index:,sort:desc,collate:utf8,type:btree,length:10,where:name3 != 'jinzhu'"` Name4 string `gorm:"uniqueIndex"` Name5 int64 `gorm:"index:,class:FULLTEXT,comment:hello \\, world,where:age > 10"` Name6 int64 `gorm:"index:profile,comment:hello \\, world,where:age > 10"` Age int64 `gorm:"index:profile,expression:ABS(age),option:WITH PARSER parser_name"` OID int64 `gorm:"index:idx_id;index:idx_oid,unique"` MemberNumber string `gorm:"index:idx_id,priority:1"` Name7 string `gorm:"index:type"` Name8 string `gorm:"index:,length:10;index:,collate:utf8"` CompName1 string `gorm:"index:,unique,composite:idx_compname_1,option:NULLS NOT DISTINCT;not null"` CompName2 string `gorm:"index:,composite:idx_compname_1"` // Composite Index: Flattened structure. Data0A string `gorm:"index:,composite:comp_id0"` Data0B string `gorm:"index:,composite:comp_id0"` // Composite Index: Nested structure. Data1A string `gorm:"index:,composite:comp_id1"` CompIdxLevel1C // Composite Index: Unique and priority. 
Data2A string `gorm:"index:,unique,composite:comp_id2,priority:2"` CompIdxLevel2C } type CompIdxLevel1C struct { CompIdxLevel1B Data1C string `gorm:"index:,composite:comp_id1"` } type CompIdxLevel1B struct { Data1B string `gorm:"index:,composite:comp_id1"` } type CompIdxLevel2C struct { CompIdxLevel2B Data2C string `gorm:"index:,unique,composite:comp_id2,priority:1"` } type CompIdxLevel2B struct { Data2B string `gorm:"index:,unique,composite:comp_id2,priority:3"` } func TestParseIndex(t *testing.T) { user, err := schema.Parse(&UserIndex{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user index, got error %v", err) } results := []*schema.Index{ { Name: "idx_user_indices_name", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name"}}}, }, { Name: "idx_name", Class: "UNIQUE", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name2", UniqueIndex: "idx_name"}}}, }, { Name: "idx_user_indices_name3", Type: "btree", Where: "name3 != 'jinzhu'", Fields: []schema.IndexOption{{ Field: &schema.Field{Name: "Name3"}, Sort: "desc", Collate: "utf8", Length: 10, }}, }, { Name: "idx_user_indices_name4", Class: "UNIQUE", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name4", UniqueIndex: "idx_user_indices_name4"}}}, }, { Name: "idx_user_indices_name5", Class: "FULLTEXT", Comment: "hello , world", Where: "age > 10", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name5"}}}, }, { Name: "profile", Comment: "hello , world", Where: "age > 10", Option: "WITH PARSER parser_name", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name6"}}, { Field: &schema.Field{Name: "Age"}, Expression: "ABS(age)", }}, }, { Name: "idx_id", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "MemberNumber"}}, {Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, }, { Name: "idx_oid", Class: "UNIQUE", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "OID", UniqueIndex: "idx_oid"}}}, }, { Name: "type", Type: 
"", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "Name7"}}}, }, { Name: "idx_user_indices_name8", Type: "", Fields: []schema.IndexOption{ {Field: &schema.Field{Name: "Name8"}, Length: 10}, // Note: Duplicate Columns {Field: &schema.Field{Name: "Name8"}, Collate: "utf8"}, }, }, { Class: "UNIQUE", Name: "idx_user_indices_idx_compname_1", Option: "NULLS NOT DISTINCT", Fields: []schema.IndexOption{ {Field: &schema.Field{Name: "CompName1", NotNull: true}}, {Field: &schema.Field{Name: "CompName2"}}, }, }, { Name: "idx_user_indices_comp_id0", Type: "", Fields: []schema.IndexOption{{ Field: &schema.Field{Name: "Data0A"}, }, { Field: &schema.Field{Name: "Data0B"}, }}, }, { Name: "idx_user_indices_comp_id1", Fields: []schema.IndexOption{{ Field: &schema.Field{Name: "Data1A"}, }, { Field: &schema.Field{Name: "Data1B"}, }, { Field: &schema.Field{Name: "Data1C"}, }}, }, { Name: "idx_user_indices_comp_id2", Class: "UNIQUE", Fields: []schema.IndexOption{{ Field: &schema.Field{Name: "Data2C"}, }, { Field: &schema.Field{Name: "Data2A"}, }, { Field: &schema.Field{Name: "Data2B"}, }}, }, } CheckIndices(t, results, user.ParseIndexes()) } func TestParseIndexWithUniqueIndexAndUnique(t *testing.T) { type IndexTest struct { FieldA string `gorm:"unique;index"` // unique and index FieldB string `gorm:"unique"` // unique FieldC string `gorm:"index:,unique"` // uniqueIndex FieldD string `gorm:"uniqueIndex;index"` // uniqueIndex and index FieldE1 string `gorm:"uniqueIndex:uniq_field_e1_e2"` // mul uniqueIndex FieldE2 string `gorm:"uniqueIndex:uniq_field_e1_e2"` FieldF1 string `gorm:"uniqueIndex:uniq_field_f1_f2;index"` // mul uniqueIndex and index FieldF2 string `gorm:"uniqueIndex:uniq_field_f1_f2;"` FieldG string `gorm:"unique;uniqueIndex"` // unique and uniqueIndex FieldH1 string `gorm:"unique;uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex FieldH2 string `gorm:"uniqueIndex:uniq_field_h1_h2"` // unique and mul uniqueIndex } indexSchema, err := 
schema.Parse(&IndexTest{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Fatalf("failed to parse user index, got error %v", err) } indices := indexSchema.ParseIndexes() expectedIndices := []*schema.Index{ { Name: "idx_index_tests_field_a", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldA", Unique: true}}}, }, { Name: "idx_index_tests_field_c", Class: "UNIQUE", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldC", UniqueIndex: "idx_index_tests_field_c"}}}, }, { Name: "idx_index_tests_field_d", Class: "UNIQUE", Fields: []schema.IndexOption{ {Field: &schema.Field{Name: "FieldD"}}, // Note: Duplicate Columns {Field: &schema.Field{Name: "FieldD"}}, }, }, { Name: "uniq_field_e1_e2", Class: "UNIQUE", Fields: []schema.IndexOption{ {Field: &schema.Field{Name: "FieldE1"}}, {Field: &schema.Field{Name: "FieldE2"}}, }, }, { Name: "uniq_field_f1_f2", Class: "UNIQUE", Fields: []schema.IndexOption{ {Field: &schema.Field{Name: "FieldF1"}}, {Field: &schema.Field{Name: "FieldF2"}}, }, }, { Name: "idx_index_tests_field_f1", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldF1"}}}, }, { Name: "idx_index_tests_field_g", Class: "UNIQUE", Fields: []schema.IndexOption{{Field: &schema.Field{Name: "FieldG", Unique: true, UniqueIndex: "idx_index_tests_field_g"}}}, }, { Name: "uniq_field_h1_h2", Class: "UNIQUE", Fields: []schema.IndexOption{ {Field: &schema.Field{Name: "FieldH1", Unique: true}}, {Field: &schema.Field{Name: "FieldH2"}}, }, }, } CheckIndices(t, expectedIndices, indices) } func CheckIndices(t *testing.T, expected, actual []*schema.Index) { if len(expected) != len(actual) { t.Errorf("expected %d indices, but got %d", len(expected), len(actual)) return } for i, ei := range expected { t.Run(ei.Name, func(t *testing.T) { ai := actual[i] tests.AssertObjEqual(t, ai, ei, "Name", "Class", "Type", "Where", "Comment", "Option") if len(ei.Fields) != len(ai.Fields) { t.Errorf("expected index %q field length is %d but actual %d", ei.Name, 
len(ei.Fields), len(ai.Fields)) return } for i, ef := range ei.Fields { af := ai.Fields[i] tests.AssertObjEqual(t, af, ef, "Name", "Unique", "UniqueIndex", "Expression", "Sort", "Collate", "Length", "NotNull") } }) } } ================================================ FILE: schema/interfaces.go ================================================ package schema import ( "gorm.io/gorm/clause" ) // ConstraintInterface database constraint interface type ConstraintInterface interface { GetName() string Build() (sql string, vars []interface{}) } // GormDataTypeInterface gorm data type interface type GormDataTypeInterface interface { GormDataType() string } // FieldNewValuePool field new scan value pool type FieldNewValuePool interface { Get() interface{} Put(interface{}) } // CreateClausesInterface create clauses interface type CreateClausesInterface interface { CreateClauses(*Field) []clause.Interface } // QueryClausesInterface query clauses interface type QueryClausesInterface interface { QueryClauses(*Field) []clause.Interface } // UpdateClausesInterface update clauses interface type UpdateClausesInterface interface { UpdateClauses(*Field) []clause.Interface } // DeleteClausesInterface delete clauses interface type DeleteClausesInterface interface { DeleteClauses(*Field) []clause.Interface } ================================================ FILE: schema/model_test.go ================================================ package schema_test import ( "database/sql" "time" "gorm.io/gorm" "gorm.io/gorm/utils/tests" ) type User struct { *gorm.Model Name *string Age *uint Birthday *time.Time Account *tests.Account Pets []*tests.Pet Toys []*tests.Toy `gorm:"polymorphic:Owner"` CompanyID *int Company *tests.Company ManagerID *uint Manager *User Team []*User `gorm:"foreignkey:ManagerID"` Languages []*tests.Language `gorm:"many2many:UserSpeak"` Friends []*User `gorm:"many2many:user_friends"` Active *bool } type ( mytime time.Time myint int mybool = bool ) type AdvancedDataTypeUser struct 
{ ID sql.NullInt64 Name *sql.NullString Birthday sql.NullTime RegisteredAt mytime DeletedAt *mytime Active mybool Admin *mybool } type BaseModel struct { ID uint CreatedAt time.Time CreatedBy *int Created *VersionUser `gorm:"foreignKey:CreatedBy"` UpdatedAt time.Time DeletedAt gorm.DeletedAt `gorm:"index"` } type VersionModel struct { BaseModel Version int } type VersionUser struct { VersionModel Name string Age uint Birthday *time.Time } ================================================ FILE: schema/naming.go ================================================ package schema import ( "crypto/sha1" "encoding/hex" "regexp" "strings" "unicode/utf8" "github.com/jinzhu/inflection" "golang.org/x/text/cases" "golang.org/x/text/language" ) // Namer namer interface type Namer interface { TableName(table string) string SchemaName(table string) string ColumnName(table, column string) string JoinTableName(joinTable string) string RelationshipFKName(Relationship) string CheckerName(table, column string) string IndexName(table, column string) string UniqueName(table, column string) string } // Replacer replacer interface like strings.Replacer type Replacer interface { Replace(name string) string } var _ Namer = (*NamingStrategy)(nil) // NamingStrategy tables, columns naming strategy type NamingStrategy struct { TablePrefix string SingularTable bool NameReplacer Replacer NoLowerCase bool IdentifierMaxLength int } // TableName convert string to table name func (ns NamingStrategy) TableName(str string) string { if ns.SingularTable { return ns.TablePrefix + ns.toDBName(str) } return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) } // SchemaName generate schema name from table name, don't guarantee it is the reverse value of TableName func (ns NamingStrategy) SchemaName(table string) string { table = strings.TrimPrefix(table, ns.TablePrefix) if ns.SingularTable { return ns.toSchemaName(table) } return ns.toSchemaName(inflection.Singular(table)) } // ColumnName convert string to 
column name func (ns NamingStrategy) ColumnName(table, column string) string { return ns.toDBName(column) } // JoinTableName convert string to join table name func (ns NamingStrategy) JoinTableName(str string) string { if !ns.NoLowerCase && strings.ToLower(str) == str { return ns.TablePrefix + str } if ns.SingularTable { return ns.TablePrefix + ns.toDBName(str) } return ns.TablePrefix + inflection.Plural(ns.toDBName(str)) } // RelationshipFKName generate fk name for relation func (ns NamingStrategy) RelationshipFKName(rel Relationship) string { return ns.formatName("fk", rel.Schema.Table, ns.toDBName(rel.Name)) } // CheckerName generate checker name func (ns NamingStrategy) CheckerName(table, column string) string { return ns.formatName("chk", table, column) } // IndexName generate index name func (ns NamingStrategy) IndexName(table, column string) string { return ns.formatName("idx", table, ns.toDBName(column)) } // UniqueName generate unique constraint name func (ns NamingStrategy) UniqueName(table, column string) string { return ns.formatName("uni", table, ns.toDBName(column)) } func (ns NamingStrategy) formatName(prefix, table, name string) string { formattedName := strings.ReplaceAll(strings.Join([]string{ prefix, table, name, }, "_"), ".", "_") if ns.IdentifierMaxLength == 0 { ns.IdentifierMaxLength = 64 } if utf8.RuneCountInString(formattedName) > ns.IdentifierMaxLength { h := sha1.New() h.Write([]byte(formattedName)) bs := h.Sum(nil) formattedName = formattedName[0:ns.IdentifierMaxLength-8] + hex.EncodeToString(bs)[:8] } return formattedName } var ( // https://github.com/golang/lint/blob/master/lint.go#L770 commonInitialisms = []string{"API", "ASCII", "CPU", "CSS", "DNS", "EOF", "GUID", "HTML", "HTTP", "HTTPS", "ID", "IP", "JSON", "LHS", "QPS", "RAM", "RHS", "RPC", "SLA", "SMTP", "SSH", "TLS", "TTL", "UID", "UI", "UUID", "URI", "URL", "UTF8", "VM", "XML", "XSRF", "XSS"} commonInitialismsReplacer *strings.Replacer ) func init() { commonInitialismsForReplacer 
:= make([]string, 0, len(commonInitialisms)) for _, initialism := range commonInitialisms { commonInitialismsForReplacer = append(commonInitialismsForReplacer, initialism, cases.Title(language.Und).String(initialism)) } commonInitialismsReplacer = strings.NewReplacer(commonInitialismsForReplacer...) } func (ns NamingStrategy) toDBName(name string) string { if name == "" { return "" } if ns.NameReplacer != nil { tmpName := ns.NameReplacer.Replace(name) if tmpName == "" { return name } name = tmpName } if ns.NoLowerCase { return name } var ( value = commonInitialismsReplacer.Replace(name) buf strings.Builder lastCase, nextCase, nextNumber bool // upper case == true curCase = value[0] <= 'Z' && value[0] >= 'A' ) for i, v := range value[:len(value)-1] { nextCase = value[i+1] <= 'Z' && value[i+1] >= 'A' nextNumber = value[i+1] >= '0' && value[i+1] <= '9' if curCase { if lastCase && (nextCase || nextNumber) { buf.WriteRune(v + 32) } else { if i > 0 && value[i-1] != '_' && value[i+1] != '_' { buf.WriteByte('_') } buf.WriteRune(v + 32) } } else { buf.WriteRune(v) } lastCase = curCase curCase = nextCase } if curCase { if !lastCase && len(value) > 1 { buf.WriteByte('_') } buf.WriteByte(value[len(value)-1] + 32) } else { buf.WriteByte(value[len(value)-1]) } ret := buf.String() return ret } func (ns NamingStrategy) toSchemaName(name string) string { result := strings.ReplaceAll(cases.Title(language.Und, cases.NoLower).String(strings.ReplaceAll(name, "_", " ")), " ", "") for _, initialism := range commonInitialisms { result = regexp.MustCompile(cases.Title(language.Und, cases.NoLower).String(strings.ToLower(initialism))+"([A-Z]|$|_)").ReplaceAllString(result, initialism+"$1") } return result } ================================================ FILE: schema/naming_test.go ================================================ package schema import ( "strings" "testing" ) func TestToDBName(t *testing.T) { maps := map[string]string{ "": "", "x": "x", "X": "x", "userRestrictions": 
"user_restrictions", "ThisIsATest": "this_is_a_test", "PFAndESI": "pf_and_esi", "AbcAndJkl": "abc_and_jkl", "EmployeeID": "employee_id", "SKU_ID": "sku_id", "FieldX": "field_x", "HTTPAndSMTP": "http_and_smtp", "HTTPServerHandlerForURLID": "http_server_handler_for_url_id", "UUID": "uuid", "HTTPURL": "http_url", "HTTP_URL": "http_url", "SHA256Hash": "sha256_hash", "SHA256HASH": "sha256_hash", "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIdCanBeUsedAtTheEndAsID": "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id", } ns := NamingStrategy{} for key, value := range maps { if ns.toDBName(key) != value { t.Errorf("%v toName should equal %v, but got %v", key, value, ns.toDBName(key)) } } maps = map[string]string{ "x": "X", "user_restrictions": "UserRestriction", "this_is_a_test": "ThisIsATest", "abc_and_jkl": "AbcAndJkl", "employee_id": "EmployeeID", "field_x": "FieldX", "http_and_smtp": "HTTPAndSMTP", "http_server_handler_for_url_id": "HTTPServerHandlerForURLID", "uuid": "UUID", "http_url": "HTTPURL", "sha256_hash": "Sha256Hash", "this_is_actually_a_test_so_we_may_be_able_to_use_this_code_in_gorm_package_also_id_can_be_used_at_the_end_as_id": "ThisIsActuallyATestSoWeMayBeAbleToUseThisCodeInGormPackageAlsoIDCanBeUsedAtTheEndAsID", } for key, value := range maps { if ns.SchemaName(key) != value { t.Errorf("%v schema name should equal %v, but got %v", key, value, ns.SchemaName(key)) } } } func TestNamingStrategy(t *testing.T) { ns := NamingStrategy{ TablePrefix: "public.", SingularTable: true, NameReplacer: strings.NewReplacer("CID", "Cid"), } idxName := ns.IndexName("public.table", "name") if idxName != "idx_public_table_name" { t.Errorf("invalid index name generated, got %v", idxName) } chkName := ns.CheckerName("public.table", "name") if chkName != "chk_public_table_name" { t.Errorf("invalid checker name generated, got %v", chkName) } joinTable := ns.JoinTableName("user_languages") if 
joinTable != "public.user_languages" { t.Errorf("invalid join table generated, got %v", joinTable) } joinTable2 := ns.JoinTableName("UserLanguage") if joinTable2 != "public.user_language" { t.Errorf("invalid join table generated, got %v", joinTable2) } tableName := ns.TableName("Company") if tableName != "public.company" { t.Errorf("invalid table name generated, got %v", tableName) } columdName := ns.ColumnName("", "NameCID") if columdName != "name_cid" { t.Errorf("invalid column name generated, got %v", columdName) } } type CustomReplacer struct { f func(string) string } func (r CustomReplacer) Replace(name string) string { return r.f(name) } func TestCustomReplacer(t *testing.T) { ns := NamingStrategy{ TablePrefix: "public.", SingularTable: true, NameReplacer: CustomReplacer{ func(name string) string { replaced := "REPLACED_" + strings.ToUpper(name) return strings.NewReplacer("CID", "_Cid").Replace(replaced) }, }, NoLowerCase: false, } idxName := ns.IndexName("public.table", "name") if idxName != "idx_public_table_replaced_name" { t.Errorf("invalid index name generated, got %v", idxName) } chkName := ns.CheckerName("public.table", "name") if chkName != "chk_public_table_name" { t.Errorf("invalid checker name generated, got %v", chkName) } joinTable := ns.JoinTableName("user_languages") if joinTable != "public.user_languages" { // Seems like a bug in NamingStrategy to skip the Replacer when the name is lowercase here. 
t.Errorf("invalid join table generated, got %v", joinTable) } joinTable2 := ns.JoinTableName("UserLanguage") if joinTable2 != "public.replaced_userlanguage" { t.Errorf("invalid join table generated, got %v", joinTable2) } tableName := ns.TableName("Company") if tableName != "public.replaced_company" { t.Errorf("invalid table name generated, got %v", tableName) } columdName := ns.ColumnName("", "NameCID") if columdName != "replaced_name_cid" { t.Errorf("invalid column name generated, got %v", columdName) } } func TestCustomReplacerWithNoLowerCase(t *testing.T) { ns := NamingStrategy{ TablePrefix: "public.", SingularTable: true, NameReplacer: CustomReplacer{ func(name string) string { replaced := "REPLACED_" + strings.ToUpper(name) return strings.NewReplacer("CID", "_Cid").Replace(replaced) }, }, NoLowerCase: true, } idxName := ns.IndexName("public.table", "name") if idxName != "idx_public_table_REPLACED_NAME" { t.Errorf("invalid index name generated, got %v", idxName) } chkName := ns.CheckerName("public.table", "name") if chkName != "chk_public_table_name" { t.Errorf("invalid checker name generated, got %v", chkName) } joinTable := ns.JoinTableName("user_languages") if joinTable != "public.REPLACED_USER_LANGUAGES" { t.Errorf("invalid join table generated, got %v", joinTable) } joinTable2 := ns.JoinTableName("UserLanguage") if joinTable2 != "public.REPLACED_USERLANGUAGE" { t.Errorf("invalid join table generated, got %v", joinTable2) } tableName := ns.TableName("Company") if tableName != "public.REPLACED_COMPANY" { t.Errorf("invalid table name generated, got %v", tableName) } columdName := ns.ColumnName("", "NameCID") if columdName != "REPLACED_NAME_Cid" { t.Errorf("invalid column name generated, got %v", columdName) } } func TestFormatNameWithStringLongerThan63Characters(t *testing.T) { ns := NamingStrategy{IdentifierMaxLength: 63} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") if formattedName != 
"prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVer180f2c67" { t.Errorf("invalid formatted name generated, got %v", formattedName) } } func TestFormatNameWithStringLongerThan64Characters(t *testing.T) { ns := NamingStrategy{IdentifierMaxLength: 64} formattedName := ns.formatName("prefix", "table", "thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString") if formattedName != "prefix_table_thisIsAVeryVeryVeryVeryVeryVeryVeryVeryVery180f2c67" { t.Errorf("invalid formatted name generated, got %v", formattedName) } } func TestReplaceEmptyTableName(t *testing.T) { ns := NamingStrategy{ SingularTable: true, NameReplacer: strings.NewReplacer("Model", ""), } tableName := ns.TableName("Model") if tableName != "Model" { t.Errorf("invalid table name generated, got %v", tableName) } } ================================================ FILE: schema/pool.go ================================================ package schema import ( "reflect" "sync" ) // sync pools var ( normalPool sync.Map poolInitializer = func(reflectType reflect.Type) FieldNewValuePool { v, _ := normalPool.LoadOrStore(reflectType, &sync.Pool{ New: func() interface{} { return reflect.New(reflectType).Interface() }, }) return v.(FieldNewValuePool) } ) ================================================ FILE: schema/relationship.go ================================================ package schema import ( "context" "fmt" "reflect" "strings" "sync" "github.com/jinzhu/inflection" "golang.org/x/text/cases" "golang.org/x/text/language" "gorm.io/gorm/clause" ) // RelationshipType relationship type type RelationshipType string const ( HasOne RelationshipType = "has_one" // HasOneRel has one relationship HasMany RelationshipType = "has_many" // HasManyRel has many relationship BelongsTo RelationshipType = "belongs_to" // BelongsToRel belongs to relationship Many2Many RelationshipType = "many_to_many" // Many2ManyRel many to many relationship has RelationshipType = "has" ) type Relationships struct { HasOne []*Relationship 
BelongsTo []*Relationship HasMany []*Relationship Many2Many []*Relationship Relations map[string]*Relationship EmbeddedRelations map[string]*Relationships Mux sync.RWMutex } type Relationship struct { Name string Type RelationshipType Field *Field Polymorphic *Polymorphic References []*Reference Schema *Schema FieldSchema *Schema JoinTable *Schema foreignKeys, primaryKeys []string } type Polymorphic struct { PolymorphicID *Field PolymorphicType *Field Value string } type Reference struct { PrimaryKey *Field PrimaryValue string ForeignKey *Field OwnPrimaryKey bool } func (schema *Schema) parseRelation(field *Field) *Relationship { var ( err error fieldValue = reflect.New(field.IndirectFieldType).Interface() relation = &Relationship{ Name: field.Name, Field: field, Schema: schema, foreignKeys: toColumns(field.TagSettings["FOREIGNKEY"]), primaryKeys: toColumns(field.TagSettings["REFERENCES"]), } ) if relation.FieldSchema, err = getOrParse(fieldValue, schema.cacheStore, schema.namer); err != nil { schema.err = fmt.Errorf("failed to parse field: %s, error: %w", field.Name, err) return nil } if hasPolymorphicRelation(field.TagSettings) { schema.buildPolymorphicRelation(relation, field) } else if many2many := field.TagSettings["MANY2MANY"]; many2many != "" { schema.buildMany2ManyRelation(relation, field, many2many) } else if belongsTo := field.TagSettings["BELONGSTO"]; belongsTo != "" { schema.guessRelation(relation, field, guessBelongs) } else { switch field.IndirectFieldType.Kind() { case reflect.Struct: schema.guessRelation(relation, field, guessGuess) case reflect.Slice: schema.guessRelation(relation, field, guessHas) default: schema.err = fmt.Errorf("unsupported data type %v for %v on field %s", relation.FieldSchema, schema, field.Name) } } if relation.Type == has { if relation.FieldSchema != relation.Schema && relation.Polymorphic == nil && field.OwnerSchema == nil { relation.FieldSchema.Relationships.Mux.Lock() 
relation.FieldSchema.Relationships.Relations["_"+relation.Schema.Name+"_"+relation.Name] = relation relation.FieldSchema.Relationships.Mux.Unlock() } switch field.IndirectFieldType.Kind() { case reflect.Struct: relation.Type = HasOne case reflect.Slice: relation.Type = HasMany } } if schema.err == nil { schema.setRelation(relation) switch relation.Type { case HasOne: schema.Relationships.HasOne = append(schema.Relationships.HasOne, relation) case HasMany: schema.Relationships.HasMany = append(schema.Relationships.HasMany, relation) case BelongsTo: schema.Relationships.BelongsTo = append(schema.Relationships.BelongsTo, relation) case Many2Many: schema.Relationships.Many2Many = append(schema.Relationships.Many2Many, relation) } } return relation } // hasPolymorphicRelation check if has polymorphic relation // 1. `POLYMORPHIC` tag // 2. `POLYMORPHICTYPE` and `POLYMORPHICID` tag func hasPolymorphicRelation(tagSettings map[string]string) bool { if _, ok := tagSettings["POLYMORPHIC"]; ok { return true } _, hasType := tagSettings["POLYMORPHICTYPE"] _, hasId := tagSettings["POLYMORPHICID"] return hasType && hasId } func (schema *Schema) setRelation(relation *Relationship) { schema.Relationships.Mux.Lock() defer schema.Relationships.Mux.Unlock() // set non-embedded relation if rel := schema.Relationships.Relations[relation.Name]; rel != nil { if len(rel.Field.BindNames) > 1 { schema.Relationships.Relations[relation.Name] = relation } } else { schema.Relationships.Relations[relation.Name] = relation } // set embedded relation if len(relation.Field.EmbeddedBindNames) <= 1 { return } relationships := &schema.Relationships for i, name := range relation.Field.EmbeddedBindNames { if i < len(relation.Field.EmbeddedBindNames)-1 { if relationships.EmbeddedRelations == nil { relationships.EmbeddedRelations = map[string]*Relationships{} } if r := relationships.EmbeddedRelations[name]; r == nil { relationships.EmbeddedRelations[name] = &Relationships{} } relationships = 
relationships.EmbeddedRelations[name] } else { if relationships.Relations == nil { relationships.Relations = map[string]*Relationship{} } relationships.Relations[relation.Name] = relation } } } // User has many Toys, its `Polymorphic` is `Owner`, Pet has one Toy, its `Polymorphic` is `Owner` // // type User struct { // Toys []Toy `gorm:"polymorphic:Owner;"` // } // type Pet struct { // Toy Toy `gorm:"polymorphic:Owner;"` // } // type Toy struct { // OwnerID int // OwnerType string // } func (schema *Schema) buildPolymorphicRelation(relation *Relationship, field *Field) { polymorphic := field.TagSettings["POLYMORPHIC"] relation.Polymorphic = &Polymorphic{ Value: schema.Table, } var ( typeName = polymorphic + "Type" typeId = polymorphic + "ID" ) if value, ok := field.TagSettings["POLYMORPHICTYPE"]; ok { typeName = strings.TrimSpace(value) } if value, ok := field.TagSettings["POLYMORPHICID"]; ok { typeId = strings.TrimSpace(value) } relation.Polymorphic.PolymorphicType = relation.FieldSchema.FieldsByName[typeName] relation.Polymorphic.PolymorphicID = relation.FieldSchema.FieldsByName[typeId] if value, ok := field.TagSettings["POLYMORPHICVALUE"]; ok { relation.Polymorphic.Value = strings.TrimSpace(value) } if relation.Polymorphic.PolymorphicType == nil { schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"Type") } if relation.Polymorphic.PolymorphicID == nil { schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing field %s", relation.FieldSchema, schema, field.Name, polymorphic+"ID") } if schema.err == nil { relation.References = append(relation.References, &Reference{ PrimaryValue: relation.Polymorphic.Value, ForeignKey: relation.Polymorphic.PolymorphicType, }) primaryKeyField := schema.PrioritizedPrimaryField if len(relation.foreignKeys) > 0 { if primaryKeyField = schema.LookUpField(relation.foreignKeys[0]); primaryKeyField == nil || 
len(relation.foreignKeys) > 1 { schema.err = fmt.Errorf("invalid polymorphic foreign keys %+v for %v on field %s", relation.foreignKeys, schema, field.Name) } } if primaryKeyField == nil { schema.err = fmt.Errorf("invalid polymorphic type %v for %v on field %s, missing primaryKey field", relation.FieldSchema, schema, field.Name) return } // use same data type for foreign keys if copyableDataType(primaryKeyField.DataType) { relation.Polymorphic.PolymorphicID.DataType = primaryKeyField.DataType } relation.Polymorphic.PolymorphicID.GORMDataType = primaryKeyField.GORMDataType if relation.Polymorphic.PolymorphicID.Size == 0 { relation.Polymorphic.PolymorphicID.Size = primaryKeyField.Size } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryKeyField, ForeignKey: relation.Polymorphic.PolymorphicID, OwnPrimaryKey: true, }) } relation.Type = has } func (schema *Schema) buildMany2ManyRelation(relation *Relationship, field *Field, many2many string) { relation.Type = Many2Many var ( err error joinTableFields []reflect.StructField fieldsMap = map[string]*Field{} ownFieldsMap = map[string]*Field{} // fix self join many2many referFieldsMap = map[string]*Field{} joinForeignKeys = toColumns(field.TagSettings["JOINFOREIGNKEY"]) joinReferences = toColumns(field.TagSettings["JOINREFERENCES"]) ) ownForeignFields := schema.PrimaryFields refForeignFields := relation.FieldSchema.PrimaryFields if len(relation.foreignKeys) > 0 { ownForeignFields = []*Field{} for _, foreignKey := range relation.foreignKeys { if field := schema.LookUpField(foreignKey); field != nil { ownForeignFields = append(ownForeignFields, field) } else { schema.err = fmt.Errorf("invalid foreign key: %s", foreignKey) return } } } if len(relation.primaryKeys) > 0 { refForeignFields = []*Field{} for _, foreignKey := range relation.primaryKeys { if field := relation.FieldSchema.LookUpField(foreignKey); field != nil { refForeignFields = append(refForeignFields, field) } else { schema.err = 
fmt.Errorf("invalid foreign key: %s", foreignKey) return } } } for idx, ownField := range ownForeignFields { joinFieldName := cases.Title(language.Und, cases.NoLower).String(schema.Name) + ownField.Name if len(joinForeignKeys) > idx { joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinForeignKeys[idx]) } ownFieldsMap[joinFieldName] = ownField fieldsMap[joinFieldName] = ownField joinTableFields = append(joinTableFields, reflect.StructField{ Name: joinFieldName, PkgPath: ownField.StructField.PkgPath, Type: ownField.StructField.Type, Tag: removeSettingFromTag(appendSettingFromTag(ownField.StructField.Tag, "primaryKey"), "column", "autoincrement", "index", "unique", "uniqueindex"), }) } for idx, relField := range refForeignFields { joinFieldName := cases.Title(language.Und, cases.NoLower).String(relation.FieldSchema.Name) + relField.Name if _, ok := ownFieldsMap[joinFieldName]; ok { if field.Name != relation.FieldSchema.Name { joinFieldName = inflection.Singular(field.Name) + relField.Name } else { joinFieldName += "Reference" } } if len(joinReferences) > idx { joinFieldName = cases.Title(language.Und, cases.NoLower).String(joinReferences[idx]) } referFieldsMap[joinFieldName] = relField if _, ok := fieldsMap[joinFieldName]; !ok { fieldsMap[joinFieldName] = relField joinTableFields = append(joinTableFields, reflect.StructField{ Name: joinFieldName, PkgPath: relField.StructField.PkgPath, Type: relField.StructField.Type, Tag: removeSettingFromTag(appendSettingFromTag(relField.StructField.Tag, "primaryKey"), "column", "autoincrement", "index", "unique", "uniqueindex"), }) } } joinTableFields = append(joinTableFields, reflect.StructField{ Name: cases.Title(language.Und, cases.NoLower).String(schema.Name) + field.Name, Type: schema.ModelType, Tag: `gorm:"-"`, }) if relation.JoinTable, err = Parse(reflect.New(reflect.StructOf(joinTableFields)).Interface(), schema.cacheStore, schema.namer); err != nil { schema.err = err } relation.JoinTable.Name = many2many 
relation.JoinTable.Table = schema.namer.JoinTableName(many2many) relation.JoinTable.PrimaryFields = make([]*Field, 0, len(relation.JoinTable.Fields)) relName := relation.Schema.Name relRefName := relation.FieldSchema.Name if relName == relRefName { relRefName = relation.Field.Name } if _, ok := relation.JoinTable.Relationships.Relations[relName]; !ok { relation.JoinTable.Relationships.Relations[relName] = &Relationship{ Name: relName, Type: BelongsTo, Schema: relation.JoinTable, FieldSchema: relation.Schema, } } else { relation.JoinTable.Relationships.Relations[relName].References = []*Reference{} } if _, ok := relation.JoinTable.Relationships.Relations[relRefName]; !ok { relation.JoinTable.Relationships.Relations[relRefName] = &Relationship{ Name: relRefName, Type: BelongsTo, Schema: relation.JoinTable, FieldSchema: relation.FieldSchema, } } else { relation.JoinTable.Relationships.Relations[relRefName].References = []*Reference{} } // build references for _, f := range relation.JoinTable.Fields { if f.Creatable || f.Readable || f.Updatable { // use same data type for foreign keys if copyableDataType(fieldsMap[f.Name].DataType) { f.DataType = fieldsMap[f.Name].DataType } f.GORMDataType = fieldsMap[f.Name].GORMDataType if f.Size == 0 { f.Size = fieldsMap[f.Name].Size } relation.JoinTable.PrimaryFields = append(relation.JoinTable.PrimaryFields, f) if of, ok := ownFieldsMap[f.Name]; ok { joinRel := relation.JoinTable.Relationships.Relations[relName] joinRel.Field = relation.Field joinRel.References = append(joinRel.References, &Reference{ PrimaryKey: of, ForeignKey: f, }) relation.References = append(relation.References, &Reference{ PrimaryKey: of, ForeignKey: f, OwnPrimaryKey: true, }) } if rf, ok := referFieldsMap[f.Name]; ok { joinRefRel := relation.JoinTable.Relationships.Relations[relRefName] if joinRefRel.Field == nil { joinRefRel.Field = relation.Field } joinRefRel.References = append(joinRefRel.References, &Reference{ PrimaryKey: rf, ForeignKey: f, }) 
relation.References = append(relation.References, &Reference{ PrimaryKey: rf, ForeignKey: f, }) } } } } type guessLevel int const ( guessGuess guessLevel = iota guessBelongs guessEmbeddedBelongs guessHas guessEmbeddedHas ) func (schema *Schema) guessRelation(relation *Relationship, field *Field, cgl guessLevel) { var ( primaryFields, foreignFields []*Field primarySchema, foreignSchema = schema, relation.FieldSchema gl = cgl ) if gl == guessGuess { if field.Schema == relation.FieldSchema { gl = guessBelongs } else { gl = guessHas } } reguessOrErr := func() { switch cgl { case guessGuess: schema.guessRelation(relation, field, guessBelongs) case guessBelongs: schema.guessRelation(relation, field, guessEmbeddedBelongs) case guessEmbeddedBelongs: schema.guessRelation(relation, field, guessHas) case guessHas: schema.guessRelation(relation, field, guessEmbeddedHas) // case guessEmbeddedHas: default: schema.err = fmt.Errorf("invalid field found for struct %v's field %s: define a valid foreign key for relations or implement the Valuer/Scanner interface", schema, field.Name) } } switch gl { case guessBelongs: primarySchema, foreignSchema = relation.FieldSchema, schema case guessEmbeddedBelongs: if field.OwnerSchema == nil { reguessOrErr() return } primarySchema, foreignSchema = relation.FieldSchema, field.OwnerSchema case guessHas: case guessEmbeddedHas: if field.OwnerSchema == nil { reguessOrErr() return } primarySchema, foreignSchema = field.OwnerSchema, relation.FieldSchema } if len(relation.foreignKeys) > 0 { for _, foreignKey := range relation.foreignKeys { f := foreignSchema.LookUpField(foreignKey) if f == nil { reguessOrErr() return } foreignFields = append(foreignFields, f) } } else { primarySchemaName := primarySchema.Name if primarySchemaName == "" { primarySchemaName = relation.FieldSchema.Name } if len(relation.primaryKeys) > 0 { for _, primaryKey := range relation.primaryKeys { if f := primarySchema.LookUpField(primaryKey); f != nil { primaryFields = 
append(primaryFields, f) } } } else { primaryFields = primarySchema.PrimaryFields } primaryFieldLoop: for _, primaryField := range primaryFields { lookUpName := primarySchemaName + primaryField.Name if gl == guessBelongs { lookUpName = field.Name + primaryField.Name } lookUpNames := []string{lookUpName} if len(primaryFields) == 1 { lookUpNames = append(lookUpNames, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID", strings.TrimSuffix(lookUpName, primaryField.Name)+"Id", schema.namer.ColumnName(foreignSchema.Table, strings.TrimSuffix(lookUpName, primaryField.Name)+"ID")) } for _, name := range lookUpNames { if f := foreignSchema.LookUpFieldByBindName(field.BindNames, name); f != nil { foreignFields = append(foreignFields, f) primaryFields = append(primaryFields, primaryField) continue primaryFieldLoop } } for _, name := range lookUpNames { if f := foreignSchema.LookUpField(name); f != nil { foreignFields = append(foreignFields, f) primaryFields = append(primaryFields, primaryField) continue primaryFieldLoop } } } } switch { case len(foreignFields) == 0: reguessOrErr() return case len(relation.primaryKeys) > 0: for idx, primaryKey := range relation.primaryKeys { if f := primarySchema.LookUpField(primaryKey); f != nil { if len(primaryFields) < idx+1 { primaryFields = append(primaryFields, f) } else if f != primaryFields[idx] { reguessOrErr() return } } else { reguessOrErr() return } } case len(primaryFields) == 0: if len(foreignFields) == 1 && primarySchema.PrioritizedPrimaryField != nil { primaryFields = append(primaryFields, primarySchema.PrioritizedPrimaryField) } else if len(primarySchema.PrimaryFields) == len(foreignFields) { primaryFields = append(primaryFields, primarySchema.PrimaryFields...) 
} else { reguessOrErr() return } } // build references for idx, foreignField := range foreignFields { // use same data type for foreign keys schema.Relationships.Mux.Lock() if schema != foreignField.Schema { foreignField.Schema.Relationships.Mux.Lock() } if copyableDataType(primaryFields[idx].DataType) { foreignField.DataType = primaryFields[idx].DataType } foreignField.GORMDataType = primaryFields[idx].GORMDataType if foreignField.Size == 0 { foreignField.Size = primaryFields[idx].Size } schema.Relationships.Mux.Unlock() if schema != foreignField.Schema { foreignField.Schema.Relationships.Mux.Unlock() } relation.References = append(relation.References, &Reference{ PrimaryKey: primaryFields[idx], ForeignKey: foreignField, OwnPrimaryKey: (schema == primarySchema && gl == guessHas) || (field.OwnerSchema == primarySchema && gl == guessEmbeddedHas), }) } if gl == guessHas || gl == guessEmbeddedHas { relation.Type = has } else { relation.Type = BelongsTo } } // Constraint is ForeignKey Constraint type Constraint struct { Name string Field *Field Schema *Schema ForeignKeys []*Field ReferenceSchema *Schema References []*Field OnDelete string OnUpdate string } func (constraint *Constraint) GetName() string { return constraint.Name } func (constraint *Constraint) Build() (sql string, vars []interface{}) { sql = "CONSTRAINT ? FOREIGN KEY ? REFERENCES ??" 
if constraint.OnDelete != "" { sql += " ON DELETE " + constraint.OnDelete } if constraint.OnUpdate != "" { sql += " ON UPDATE " + constraint.OnUpdate } foreignKeys := make([]interface{}, 0, len(constraint.ForeignKeys)) for _, field := range constraint.ForeignKeys { foreignKeys = append(foreignKeys, clause.Column{Name: field.DBName}) } references := make([]interface{}, 0, len(constraint.References)) for _, field := range constraint.References { references = append(references, clause.Column{Name: field.DBName}) } vars = append(vars, clause.Table{Name: constraint.Name}, foreignKeys, clause.Table{Name: constraint.ReferenceSchema.Table}, references) return } func (rel *Relationship) ParseConstraint() *Constraint { str := rel.Field.TagSettings["CONSTRAINT"] if str == "-" { return nil } if rel.Type == BelongsTo { for _, r := range rel.FieldSchema.Relationships.Relations { if r != rel && r.FieldSchema == rel.Schema && len(rel.References) == len(r.References) { matched := true for idx, ref := range r.References { if rel.References[idx].PrimaryKey != ref.PrimaryKey || rel.References[idx].ForeignKey != ref.ForeignKey || rel.References[idx].PrimaryValue != ref.PrimaryValue { matched = false break } } if matched { return nil } } } } var ( name string idx = strings.IndexByte(str, ',') settings = ParseTagSetting(str, ",") ) // optimize match english letters and midline // The following code is basically called in for. // In order to avoid the performance problems caused by repeated compilation of regular expressions, // it only needs to be done once outside, so optimization is done here. 
if idx != -1 && regEnLetterAndMidline.MatchString(str[0:idx]) { name = str[0:idx] } else { name = rel.Schema.namer.RelationshipFKName(*rel) } constraint := Constraint{ Name: name, Field: rel.Field, OnUpdate: settings["ONUPDATE"], OnDelete: settings["ONDELETE"], } for _, ref := range rel.References { if ref.PrimaryKey != nil && (rel.JoinTable == nil || ref.OwnPrimaryKey) { constraint.ForeignKeys = append(constraint.ForeignKeys, ref.ForeignKey) constraint.References = append(constraint.References, ref.PrimaryKey) if ref.OwnPrimaryKey { constraint.Schema = ref.ForeignKey.Schema constraint.ReferenceSchema = rel.Schema } else { constraint.Schema = rel.Schema constraint.ReferenceSchema = ref.PrimaryKey.Schema } } } return &constraint } func (rel *Relationship) ToQueryConditions(ctx context.Context, reflectValue reflect.Value) (conds []clause.Expression) { table := rel.FieldSchema.Table foreignFields := []*Field{} relForeignKeys := []string{} if rel.JoinTable != nil { table = rel.JoinTable.Table for _, ref := range rel.References { if ref.OwnPrimaryKey { foreignFields = append(foreignFields, ref.PrimaryKey) relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) } else if ref.PrimaryValue != "" { conds = append(conds, clause.Eq{ Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, Value: ref.PrimaryValue, }) } else { conds = append(conds, clause.Eq{ Column: clause.Column{Table: rel.JoinTable.Table, Name: ref.ForeignKey.DBName}, Value: clause.Column{Table: rel.FieldSchema.Table, Name: ref.PrimaryKey.DBName}, }) } } } else { for _, ref := range rel.References { if ref.OwnPrimaryKey { relForeignKeys = append(relForeignKeys, ref.ForeignKey.DBName) foreignFields = append(foreignFields, ref.PrimaryKey) } else if ref.PrimaryValue != "" { conds = append(conds, clause.Eq{ Column: clause.Column{Table: rel.FieldSchema.Table, Name: ref.ForeignKey.DBName}, Value: ref.PrimaryValue, }) } else { relForeignKeys = append(relForeignKeys, 
ref.PrimaryKey.DBName) foreignFields = append(foreignFields, ref.ForeignKey) } } } _, foreignValues := GetIdentityFieldValuesMap(ctx, reflectValue, foreignFields) column, values := ToQueryValues(table, relForeignKeys, foreignValues) conds = append(conds, clause.IN{Column: column, Values: values}) return } func copyableDataType(str DataType) bool { lowerStr := strings.ToLower(string(str)) for _, s := range []string{"auto_increment", "primary key"} { if strings.Contains(lowerStr, s) { return false } } return true } ================================================ FILE: schema/relationship_test.go ================================================ package schema_test import ( "sync" "testing" "time" "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) func checkStructRelation(t *testing.T, data interface{}, relations ...Relation) { if s, err := schema.Parse(data, &sync.Map{}, schema.NamingStrategy{}); err != nil { t.Errorf("Failed to parse schema, got error %v", err) } else { for _, rel := range relations { checkSchemaRelation(t, s, rel) } } } func TestBelongsToOverrideForeignKey(t *testing.T) { type Profile struct { gorm.Model Name string } type User struct { gorm.Model Profile Profile `gorm:"ForeignKey:ProfileRefer"` ProfileRefer int } checkStructRelation(t, &User{}, Relation{ Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", References: []Reference{{"ID", "Profile", "ProfileRefer", "User", "", false}}, }) } func TestBelongsToOverrideReferences(t *testing.T) { type Profile struct { gorm.Model Refer string Name string } type User struct { gorm.Model Profile Profile `gorm:"ForeignKey:ProfileID;References:Refer"` ProfileID int } checkStructRelation(t, &User{}, Relation{ Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile", References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}}, }) } func TestBelongsToWithOnlyReferences(t *testing.T) { type Profile struct { gorm.Model Refer 
string
		Name  string
	}

	type User struct {
		gorm.Model
		Profile      Profile `gorm:"References:Refer"`
		ProfileRefer int
	}

	// Only `references` is set: the foreign key is found by the field name +
	// referenced field convention ("Profile" + "Refer").
	checkStructRelation(t, &User{}, Relation{
		Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
		References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
	})
}

// TestBelongsToWithOnlyReferences2: with `references:Refer` but no
// ProfileRefer field, the single-key fallback name ProfileID is matched.
func TestBelongsToWithOnlyReferences2(t *testing.T) {
	type Profile struct {
		gorm.Model
		Refer string
		Name  string
	}

	type User struct {
		gorm.Model
		Profile   Profile `gorm:"References:Refer"`
		ProfileID int
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
		References: []Reference{{"Refer", "Profile", "ProfileID", "User", "", false}},
	})
}

// TestSelfReferentialBelongsTo: a *User field on User is guessed as a
// belongs-to through the CreatorID naming convention.
func TestSelfReferentialBelongsTo(t *testing.T) {
	type User struct {
		ID        int32 `gorm:"primaryKey"`
		Name      string
		CreatorID *int32
		Creator   *User
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
		References: []Reference{{"ID", "User", "CreatorID", "User", "", false}},
	})
}

// TestSelfReferentialBelongsToOverrideReferences: explicit foreignKey and
// references tags on a self-referential belongs-to.
func TestSelfReferentialBelongsToOverrideReferences(t *testing.T) {
	type User struct {
		ID        int32 `gorm:"primaryKey"`
		Name      string
		CreatedBy *int32
		Creator   *User `gorm:"foreignKey:CreatedBy;references:ID"`
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Creator", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
		References: []Reference{{"ID", "User", "CreatedBy", "User", "", false}},
	})
}

// TestBelongsToWithMixin: a belongs-to declared inside an anonymously embedded
// struct (ProfileMixin) is reported on the outer User schema.
func TestBelongsToWithMixin(t *testing.T) {
	type Profile struct {
		gorm.Model
		Refer string
		Name  string
	}

	type ProfileMixin struct {
		Profile      Profile `gorm:"References:Refer"`
		ProfileRefer int
	}

	type User struct {
		gorm.Model
		ProfileMixin
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profile", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Profile",
		References: []Reference{{"Refer", "Profile", "ProfileRefer", "User", "", false}},
	})
}

// TestHasOneOverrideForeignKey: has-one whose foreign key on the child
// (Profile.UserRefer) is named explicitly and references User.ID.
func TestHasOneOverrideForeignKey(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name      string
		UserRefer uint
	}

	type User struct {
		gorm.Model
		Profile Profile `gorm:"ForeignKey:UserRefer"`
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
		References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
	})
}

// TestHasOneOverrideReferences: has-one with both the foreign key and the
// referenced owner field (User.Refer) given explicitly.
func TestHasOneOverrideReferences(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name   string
		UserID uint
	}

	type User struct {
		gorm.Model
		Refer   string
		Profile Profile `gorm:"ForeignKey:UserID;References:Refer"`
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
		References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}},
	})
}

// TestHasOneOverrideReferences2: foreignKey:ID;references:ProfileID makes
// Profile.ID point at User.ProfileID; the relation is still classified as
// has-one with an owner-side primary key.
func TestHasOneOverrideReferences2(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name string
	}

	type User struct {
		gorm.Model
		ProfileID uint     `gorm:"column:profile_id"`
		Profile   *Profile `gorm:"foreignKey:ID;references:ProfileID"`
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
		References: []Reference{{"ProfileID", "User", "ID", "Profile", "", true}},
	})
}

// TestHasOneWithOnlyReferences: only `references` is set; the child's foreign
// key is found as <Schema><Referenced field> ("User" + "Refer").
func TestHasOneWithOnlyReferences(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name      string
		UserRefer uint
	}

	type User struct {
		gorm.Model
		Refer   string
		Profile Profile `gorm:"References:Refer"`
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
		References: []Reference{{"Refer", "User", "UserRefer", "Profile", "", true}},
	})
}

// TestHasOneWithOnlyReferences2: only `references` is set and no UserRefer
// exists, so the single-key fallback name UserID is matched.
func TestHasOneWithOnlyReferences2(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name   string
		UserID uint
	}

	type User struct {
		gorm.Model
		Refer   string
		Profile Profile `gorm:"References:Refer"`
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile",
		References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}},
	})
}

// TestHasManyOverrideForeignKey: slice field with an explicit foreign key.
func TestHasManyOverrideForeignKey(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name      string
		UserRefer uint
	}

	type User struct {
		gorm.Model
		Profile []Profile `gorm:"ForeignKey:UserRefer"`
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile",
		References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}},
	})
}

// TestHasManyOverrideReferences: slice field with explicit foreign key and
// referenced owner field.
func TestHasManyOverrideReferences(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name   string
		UserID uint
	}

	type User struct {
		gorm.Model
		Refer   string
		Profile []Profile `gorm:"ForeignKey:UserID;References:Refer"`
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile",
		References: []Reference{{"Refer", "User", "UserID", "Profile", "", true}},
	})
}

// TestMany2ManyOverrideForeignKeyAndReferences: join-table column names come
// from joinForeignKey/joinReferences. The second relation shows that
// lower_snake tag values only get their first letter upper-cased
// ("User_refer_id").
func TestMany2ManyOverrideForeignKeyAndReferences(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name      string
		UserRefer uint
	}

	type User struct {
		gorm.Model
		Profiles  []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;JoinForeignKey:UserReferID;References:UserRefer;JoinReferences:ProfileRefer"`
		Profiles2 []Profile `gorm:"many2many:user_profiles2;ForeignKey:refer;JoinForeignKey:user_refer_id;References:user_refer;JoinReferences:profile_refer"`
		Refer     uint
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
		JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
		References: []Reference{
			{"Refer", "User", "UserReferID", "user_profiles", "", true},
			{"UserRefer", "Profile", "ProfileRefer", "user_profiles", "", false},
		},
	}, Relation{
		Name: "Profiles2", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
		JoinTable: JoinTable{Name: "user_profiles2", Table: "user_profiles2"},
		References: []Reference{
			{"Refer", "User", "User_refer_id", "user_profiles2", "", true},
			{"UserRefer", "Profile", "Profile_refer", "user_profiles2", "", false},
		},
	})
}

// TestMany2ManyOverrideForeignKey: without join overrides, join columns default
// to <Schema><Field> (UserRefer, ProfileUserRefer).
func TestMany2ManyOverrideForeignKey(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name      string
		UserRefer uint
	}

	type User struct {
		gorm.Model
		Profiles []Profile `gorm:"many2many:user_profiles;ForeignKey:Refer;References:UserRefer"`
		Refer    uint
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
		JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
		References: []Reference{
			{"Refer", "User", "UserRefer", "user_profiles", "", true},
			{"UserRefer", "Profile", "ProfileUserRefer", "user_profiles", "", false},
		},
	})
}

// TestMany2ManySharedForeignKey: a column named the same on both sides (Kind)
// maps to one join-table column that both sides reference.
func TestMany2ManySharedForeignKey(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name         string
		Kind         string
		ProfileRefer uint
	}

	type User struct {
		gorm.Model
		Profiles []Profile `gorm:"many2many:user_profiles;foreignKey:Refer,Kind;joinForeignKey:UserRefer,Kind;References:ProfileRefer,Kind;joinReferences:ProfileR,Kind"`
		Kind     string
		Refer    uint
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
		JoinTable: JoinTable{Name: "user_profiles", Table: "user_profiles"},
		References: []Reference{
			{"Refer", "User", "UserRefer", "user_profiles", "", true},
			{"Kind", "User", "Kind", "user_profiles", "", true},
			{"ProfileRefer", "Profile", "ProfileR", "user_profiles", "", false},
			{"Kind", "Profile", "Kind", "user_profiles", "", false},
		},
	})
}

// TestMany2ManyOverrideJoinForeignKey: only the join-table column names are
// overridden; both sides still reference their primary keys.
func TestMany2ManyOverrideJoinForeignKey(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name      string
		UserRefer uint
	}

	type User struct {
		gorm.Model
		Profiles []Profile `gorm:"many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
		Refer    uint
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
		JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
		References: []Reference{
			{"ID", "User", "UserReferID", "user_profile", "", true},
			{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
		},
	})
}

// TestBuildReadonlyMany2ManyRelation: a read-only permission tag (`->`) yields
// the same many2many relation as the writable variant above.
func TestBuildReadonlyMany2ManyRelation(t *testing.T) {
	type Profile struct {
		gorm.Model
		Name      string
		UserRefer uint
	}

	type User struct {
		gorm.Model
		Profiles []Profile `gorm:"->;many2many:user_profile;JoinForeignKey:UserReferID;JoinReferences:ProfileRefer"`
		Refer    uint
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Profiles", Type: schema.Many2Many, Schema: "User", FieldSchema: "Profile",
		JoinTable: JoinTable{Name: "user_profile", Table: "user_profile"},
		References: []Reference{
			{"ID", "User", "UserReferID", "user_profile", "", true},
			{"ID", "Profile", "ProfileRefer", "user_profile", "", false},
		},
	})
}

// TestMany2ManyWithMultiPrimaryKeys: composite primary keys are all used by
// default, or a subset is selected with foreignKey/references.
func TestMany2ManyWithMultiPrimaryKeys(t *testing.T) {
	type Tag struct {
		ID     uint   `gorm:"primary_key"`
		Locale string `gorm:"primary_key"`
		Value  string
	}

	type Blog struct {
		ID         uint   `gorm:"primary_key"`
		Locale     string `gorm:"primary_key"`
		Subject    string
		Body       string
		Tags       []Tag `gorm:"many2many:blog_tags;"`
		SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"`
		LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"`
	}

	checkStructRelation(t, &Blog{},
		Relation{
			Name: "Tags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag",
			JoinTable: JoinTable{Name: "blog_tags", Table: "blog_tags"},
			References: []Reference{
				{"ID", "Blog", "BlogID", "blog_tags", "", true},
				{"Locale", "Blog", "BlogLocale", "blog_tags", "", true},
				{"ID", "Tag", "TagID", "blog_tags", "", false},
				{"Locale", "Tag", "TagLocale", "blog_tags", "", false},
			},
		},
		Relation{
			Name: "SharedTags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag",
			JoinTable: JoinTable{Name: "shared_blog_tags", Table: "shared_blog_tags"},
			References: []Reference{
				{"ID", "Blog", "BlogID", "shared_blog_tags", "", true},
				{"ID", "Tag", "TagID", "shared_blog_tags", "", false},
			},
		},
		Relation{
			Name: "LocaleTags", Type: schema.Many2Many, Schema: "Blog", FieldSchema: "Tag",
			JoinTable: JoinTable{Name: "locale_blog_tags", Table: "locale_blog_tags"},
			References: []Reference{
				{"ID", "Blog", "BlogID", "locale_blog_tags", "", true},
				{"Locale", "Blog", "BlogLocale", "locale_blog_tags", "", true},
				{"ID", "Tag", "TagID", "locale_blog_tags", "", false},
			},
		},
	)
}

// TestMultipleMany2Many: two many2many relations to the same model keep
// separate join tables.
func TestMultipleMany2Many(t *testing.T) {
	type Thing struct {
		ID int
	}

	type Person struct {
		ID       int
		Likes    []Thing `gorm:"many2many:likes"`
		Dislikes []Thing `gorm:"many2many:dislikes"`
	}

	checkStructRelation(t, &Person{},
		Relation{
			Name: "Likes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
			JoinTable: JoinTable{Name: "likes", Table: "likes"},
			References: []Reference{
				{"ID", "Person", "PersonID", "likes", "", true},
				{"ID", "Thing", "ThingID", "likes", "", false},
			},
		},
		Relation{
			Name: "Dislikes", Type: schema.Many2Many, Schema: "Person", FieldSchema: "Thing",
			JoinTable: JoinTable{Name: "dislikes", Table: "dislikes"},
			References: []Reference{
				{"ID", "Person", "PersonID", "dislikes", "", true},
				{"ID", "Thing", "ThingID", "dislikes", "", false},
			},
		},
	)
}

// TestSelfReferentialMany2Many checks a self-referential slice relation and
// that its FieldSchema is the very same *Schema as the parsed model.
// NOTE(review): despite the name, the relation under test is a has-many
// (foreignKey:CreatedBy), not a many2many.
func TestSelfReferentialMany2Many(t *testing.T) {
	type User struct {
		ID         int32 `gorm:"primaryKey"`
		Name       string
		CreatedBy  int32
		Creators   []User      `gorm:"foreignKey:CreatedBy"`
		AnotherPro interface{} `gorm:"-"`
	}

	checkStructRelation(t, &User{}, Relation{
		Name: "Creators", Type: schema.HasMany, Schema: "User", FieldSchema: "User",
		References: []Reference{{"ID", "User", "CreatedBy", "User", "", true}},
	})

	user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse schema")
	}

	relSchema := user.Relationships.Relations["Creators"].FieldSchema
	if user != relSchema {
		t.Fatalf("schema should be same, expects %p but got %p", user, relSchema)
	}
}

// CreatedByModel is embedded by CreatedUser to exercise a self-referential
// belongs-to declared inside an embedded struct.
type CreatedByModel struct {
	CreatedByID uint
	CreatedBy   *CreatedUser
}

type CreatedUser struct {
	gorm.Model
	CreatedByModel
}

// TestEmbeddedRelation: the embedded belongs-to is registered once on
// CreatedUser and points back at the same *Schema instance.
func TestEmbeddedRelation(t *testing.T) {
	checkStructRelation(t, &CreatedUser{}, Relation{
		Name: "CreatedBy", Type: schema.BelongsTo, Schema: "CreatedUser", FieldSchema: "CreatedUser",
		References: []Reference{
			{"ID", "CreatedUser", "CreatedByID", "CreatedUser", "", false},
		},
	})

	userSchema, err := schema.Parse(&CreatedUser{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse schema, got error %v", err)
	}

	if len(userSchema.Relationships.Relations) != 1 {
		t.Fatalf("expects 1 relations, but got %v", len(userSchema.Relationships.Relations))
	}

	if createdByRel, ok := userSchema.Relationships.Relations["CreatedBy"]; ok {
		if createdByRel.FieldSchema != userSchema {
			t.Fatalf("expects same field schema, but got new %p, old %p", createdByRel.FieldSchema, userSchema)
		}
	} else {
		t.Fatalf("expects created by relations, but not found")
	}
}

// TestEmbeddedHas: polymorphic has-one/has-many declared inside an
// `embedded` struct (Cat) are recorded under EmbeddedRelations["Cat"].
// NOTE(review): each expected References list repeats the OwnerType entry;
// buildPolymorphicRelation also adds an ID reference (OwnerID), so the second
// entry was probably meant to describe it — verify how checkEmbeddedRelations
// compares references.
func TestEmbeddedHas(t *testing.T) {
	type Toy struct {
		ID        int
		Name      string
		OwnerID   int
		OwnerType string
	}
	type User struct {
		ID  int
		Cat struct {
			Name string
			Toy  Toy   `gorm:"polymorphic:Owner;"`
			Toys []Toy `gorm:"polymorphic:Owner;"`
		} `gorm:"embedded;embeddedPrefix:cat_"`
		Dog struct {
			ID     int
			Name   string
			UserID int
			Toy    Toy   `gorm:"polymorphic:Owner;"`
			Toys   []Toy `gorm:"polymorphic:Owner;"`
		}
		Toys []Toy `gorm:"polymorphic:Owner;"`
	}

	s, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("Failed to parse schema, got error %v", err)
	}

	checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
		"Cat": {
			Relations: map[string]Relation{
				"Toy": {
					Name:        "Toy",
					Type:        schema.HasOne,
					Schema:      "User",
					FieldSchema: "Toy",
					Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
					References: []Reference{
						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
					},
				},
				"Toys": {
					Name:        "Toys",
					Type:        schema.HasMany,
					Schema:      "User",
					FieldSchema: "Toy",
					Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
					References: []Reference{
						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
						{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
					},
				},
			},
		},
	})
}

// TestPolymorphic covers polymorphic has-one/has-many with default and
// tag-overridden type/id field names.
// NOTE(review): each subtest parses Cat as the top-level model but expects an
// embedded "Cat" entry with Schema "User"; whether these expectations are
// asserted at all depends on how checkEmbeddedRelations walks the maps —
// verify the subtests are not vacuous. The expected Polymorphic.ID also mixes
// field names ("OwnerID") and column names ("ref_id", "owner_id").
func TestPolymorphic(t *testing.T) {
	t.Run("has one", func(t *testing.T) {
		type Toy struct {
			ID        int
			Name      string
			OwnerID   int
			OwnerType string
		}

		type Cat struct {
			ID   int
			Name string
			Toy  Toy `gorm:"polymorphic:Owner;"`
		}

		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
		if err != nil {
			t.Fatalf("Failed to parse schema, got error %v", err)
		}

		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
			"Cat": {
				Relations: map[string]Relation{
					"Toy": {
						Name:        "Toy",
						Type:        schema.HasOne,
						Schema:      "User",
						FieldSchema: "Toy",
						Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
						References: []Reference{
							{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
						},
					},
				},
			},
		})
	})

	t.Run("has one with custom polymorphic type and id", func(t *testing.T) {
		type Toy struct {
			ID    int
			Name  string
			RefId int
			Type  string
		}

		type Cat struct {
			ID   int
			Name string
			Toy  Toy `gorm:"polymorphic:Owner;polymorphicType:Type;polymorphicId:RefId"`
		}

		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
		if err != nil {
			t.Fatalf("Failed to parse schema, got error %v", err)
		}

		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
			"Cat": {
				Relations: map[string]Relation{
					"Toy": {
						Name:        "Toy",
						Type:        schema.HasOne,
						Schema:      "User",
						FieldSchema: "Toy",
						Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
						References: []Reference{
							{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
						},
					},
				},
			},
		})
	})

	t.Run("has one with only polymorphic type", func(t *testing.T) {
		type Toy struct {
			ID      int
			Name    string
			OwnerID int
			Type    string
		}

		type Cat struct {
			ID   int
			Name string
			Toy  Toy `gorm:"polymorphic:Owner;polymorphicType:Type"`
		}

		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
		if err != nil {
			t.Fatalf("Failed to parse schema, got error %v", err)
		}

		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
			"Cat": {
				Relations: map[string]Relation{
					"Toy": {
						Name:        "Toy",
						Type:        schema.HasOne,
						Schema:      "User",
						FieldSchema: "Toy",
						Polymorphic: Polymorphic{ID: "owner_id", Type: "Type", Value: "users"},
						References: []Reference{
							{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
						},
					},
				},
			},
		})
	})

	t.Run("has many", func(t *testing.T) {
		type Toy struct {
			ID        int
			Name      string
			OwnerID   int
			OwnerType string
		}

		type Cat struct {
			ID   int
			Name string
			Toys []Toy `gorm:"polymorphic:Owner;"`
		}

		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
		if err != nil {
			t.Fatalf("Failed to parse schema, got error %v", err)
		}

		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
			"Cat": {
				Relations: map[string]Relation{
					"Toys": {
						Name:        "Toys",
						Type:        schema.HasMany,
						Schema:      "User",
						FieldSchema: "Toy",
						Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
						References: []Reference{
							{ForeignKey: "OwnerType", ForeignSchema: "Toy", PrimaryValue: "users"},
						},
					},
				},
			},
		})
	})

	// Without a `polymorphic` tag, the type/id tags alone mark the relation as polymorphic.
	t.Run("has many with custom polymorphic type and id", func(t *testing.T) {
		type Toy struct {
			ID    int
			Name  string
			RefId int
			Type  string
		}

		type Cat struct {
			ID   int
			Name string
			Toys []Toy `gorm:"polymorphicType:Type;polymorphicId:RefId"`
		}

		s, err := schema.Parse(&Cat{}, &sync.Map{}, schema.NamingStrategy{})
		if err != nil {
			t.Fatalf("Failed to parse schema, got error %v", err)
		}

		checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{
			"Cat": {
				Relations: map[string]Relation{
					"Toys": {
						Name:        "Toys",
						Type:        schema.HasMany,
						Schema:      "User",
						FieldSchema: "Toy",
						Polymorphic: Polymorphic{ID: "ref_id", Type: "Type", Value: "users"},
						References: []Reference{
							{ForeignKey: "Type", ForeignSchema: "Toy", PrimaryValue: "users"},
						},
					},
				},
			},
		})
	})
}

func TestEmbeddedBelongsTo(t *testing.T) {
	type Country struct {
		ID   int `gorm:"primaryKey"`
		Name
string } type Address struct { CountryID int Country Country } type NestedAddress struct { Address } type CountryMixin struct { CountryID int Country Country } type Org struct { ID int PostalAddress Address `gorm:"embedded;embeddedPrefix:postal_address_"` VisitingAddress Address `gorm:"embedded;embeddedPrefix:visiting_address_"` AddressID int Address struct { ID int Address } NestedAddress *NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` CountryMixin } s, err := schema.Parse(&Org{}, &sync.Map{}, schema.NamingStrategy{}) if err != nil { t.Errorf("Failed to parse schema, got error %v", err) } checkEmbeddedRelations(t, s.Relationships.EmbeddedRelations, map[string]EmbeddedRelations{ "PostalAddress": { Relations: map[string]Relation{ "Country": { Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", References: []Reference{ {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, }, }, }, }, "VisitingAddress": { Relations: map[string]Relation{ "Country": { Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", References: []Reference{ {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, }, }, }, }, "NestedAddress": { Relations: map[string]Relation{ "Country": { Name: "Country", Type: schema.BelongsTo, Schema: "Org", FieldSchema: "Country", References: []Reference{ {PrimaryKey: "ID", PrimarySchema: "Country", ForeignKey: "CountryID", ForeignSchema: "Org"}, }, }, }, }, }) } func TestVariableRelation(t *testing.T) { var result struct { User } checkStructRelation(t, &result, Relation{ Name: "Account", Type: schema.HasOne, Schema: "", FieldSchema: "Account", References: []Reference{ {"ID", "", "UserID", "Account", "", true}, }, }) checkStructRelation(t, &result, Relation{ Name: "Company", Type: schema.BelongsTo, Schema: "", FieldSchema: "Company", References: []Reference{ {"ID", "Company", "CompanyID", "", "", false}, }, }) } func 
TestSameForeignKey(t *testing.T) { type UserAux struct { gorm.Model Aux string UUID string } type User struct { gorm.Model Name string UUID string Aux *UserAux `gorm:"foreignkey:UUID;references:UUID"` } checkStructRelation(t, &User{}, Relation{ Name: "Aux", Type: schema.HasOne, Schema: "User", FieldSchema: "UserAux", References: []Reference{ {"UUID", "User", "UUID", "UserAux", "", true}, }, }, ) } func TestBelongsToSameForeignKey(t *testing.T) { type User struct { gorm.Model Name string UUID string } type UserAux struct { gorm.Model Aux string UUID string User User `gorm:"ForeignKey:UUID;references:UUID;belongsTo"` } checkStructRelation(t, &UserAux{}, Relation{ Name: "User", Type: schema.BelongsTo, Schema: "UserAux", FieldSchema: "User", References: []Reference{ {"UUID", "User", "UUID", "UserAux", "", false}, }, }, ) } func TestHasOneWithSameForeignKey(t *testing.T) { type Profile struct { gorm.Model Name string ProfileRefer int // not used in relationship } type User struct { gorm.Model Profile Profile `gorm:"ForeignKey:ID;references:ProfileRefer"` ProfileRefer int } checkStructRelation(t, &User{}, Relation{ Name: "Profile", Type: schema.HasOne, Schema: "User", FieldSchema: "Profile", References: []Reference{{"ProfileRefer", "User", "ID", "Profile", "", true}}, }) } func TestHasManySameForeignKey(t *testing.T) { type Profile struct { gorm.Model Name string UserRefer uint } type User struct { gorm.Model UserRefer uint Profile []Profile `gorm:"ForeignKey:UserRefer"` } checkStructRelation(t, &User{}, Relation{ Name: "Profile", Type: schema.HasMany, Schema: "User", FieldSchema: "Profile", References: []Reference{{"ID", "User", "UserRefer", "Profile", "", true}}, }) } type Author struct { gorm.Model } type Book struct { gorm.Model Author Author AuthorID uint } func (Book) TableName() string { return "my_schema.a_very_very_very_very_very_very_very_very_long_table_name" } func TestParseConstraintNameWithSchemaQualifiedLongTableName(t *testing.T) { s, err := schema.Parse( 
&Book{}, &sync.Map{}, schema.NamingStrategy{IdentifierMaxLength: 64}, ) if err != nil { t.Fatalf("Failed to parse schema") } expectedConstraintName := "fk_my_schema_a_very_very_very_very_very_very_very_very_l4db13eec" constraint := s.Relationships.Relations["Author"].ParseConstraint() if constraint.Name != expectedConstraintName { t.Fatalf( "expected constraint name %s, got %s", expectedConstraintName, constraint.Name, ) } } type InfoRelation struct { ID int Code string Info1 []*Info1 `gorm:"foreignkey:Code;references:Code"` Info2 []*Info2 `gorm:"foreignkey:Code;references:Code"` } type Info1 struct { CreatedAt time.Time UpdatedAt time.Time Code string Relation []*InfoRelation `gorm:"foreignkey:Code;references:Code"` } type Info2 struct { CreatedAt time.Time UpdatedAt time.Time Code string Relation []*InfoRelation `gorm:"foreignkey:Code;references:Code"` } func TestDataRace(t *testing.T) { syncMap := &sync.Map{} for i := 0; i < 10; i++ { go func() { schema.Parse(&Info1{}, syncMap, schema.NamingStrategy{IdentifierMaxLength: 64}) }() go func() { schema.Parse(&Info2{}, syncMap, schema.NamingStrategy{IdentifierMaxLength: 64}) }() go func() { var result User schema.Parse(&result, syncMap, schema.NamingStrategy{IdentifierMaxLength: 64}) }() go func() { var result tests.Account schema.Parse(&result, syncMap, schema.NamingStrategy{IdentifierMaxLength: 64}) }() } } ================================================ FILE: schema/schema.go ================================================ package schema import ( "context" "errors" "fmt" "go/ast" "path" "reflect" "strings" "sync" "gorm.io/gorm/clause" "gorm.io/gorm/logger" ) type callbackType string const ( callbackTypeBeforeCreate callbackType = "BeforeCreate" callbackTypeBeforeUpdate callbackType = "BeforeUpdate" callbackTypeAfterCreate callbackType = "AfterCreate" callbackTypeAfterUpdate callbackType = "AfterUpdate" callbackTypeBeforeSave callbackType = "BeforeSave" callbackTypeAfterSave callbackType = "AfterSave" 
callbackTypeBeforeDelete callbackType = "BeforeDelete" callbackTypeAfterDelete callbackType = "AfterDelete" callbackTypeAfterFind callbackType = "AfterFind" ) // ErrUnsupportedDataType unsupported data type var ErrUnsupportedDataType = errors.New("unsupported data type") type Schema struct { Name string ModelType reflect.Type Table string PrioritizedPrimaryField *Field DBNames []string PrimaryFields []*Field PrimaryFieldDBNames []string Fields []*Field FieldsByName map[string]*Field FieldsByBindName map[string]*Field // embedded fields is 'Embed.Field' FieldsByDBName map[string]*Field FieldsWithDefaultDBValue []*Field // fields with default value assigned by database Relationships Relationships CreateClauses []clause.Interface QueryClauses []clause.Interface UpdateClauses []clause.Interface DeleteClauses []clause.Interface BeforeCreate, AfterCreate bool BeforeUpdate, AfterUpdate bool BeforeDelete, AfterDelete bool BeforeSave, AfterSave bool AfterFind bool err error initialized chan struct{} namer Namer cacheStore *sync.Map } func (schema *Schema) String() string { if schema.ModelType.Name() == "" { return fmt.Sprintf("%s(%s)", schema.Name, schema.Table) } return fmt.Sprintf("%s.%s", schema.ModelType.PkgPath(), schema.ModelType.Name()) } func (schema *Schema) MakeSlice() reflect.Value { slice := reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(schema.ModelType)), 0, 20) results := reflect.New(slice.Type()) results.Elem().Set(slice) return results } func (schema *Schema) LookUpField(name string) *Field { if field, ok := schema.FieldsByDBName[name]; ok { return field } if field, ok := schema.FieldsByName[name]; ok { return field } // Lookup field using namer-driven ColumnName if schema.namer == nil { return nil } namerColumnName := schema.namer.ColumnName(schema.Table, name) if field, ok := schema.FieldsByDBName[namerColumnName]; ok { return field } return nil } // LookUpFieldByBindName looks for the closest field in the embedded struct. 
// // type Struct struct { // Embedded struct { // ID string // is selected by LookUpFieldByBindName([]string{"Embedded", "ID"}, "ID") // } // ID string // is selected by LookUpFieldByBindName([]string{"ID"}, "ID") // } func (schema *Schema) LookUpFieldByBindName(bindNames []string, name string) *Field { for i := len(bindNames) - 1; i >= 0; i-- { find := strings.Join(bindNames[:i], ".") + "." + name if field, ok := schema.FieldsByBindName[find]; ok { return field } } return nil } type Tabler interface { TableName() string } type TablerWithNamer interface { TableName(Namer) string } var callbackTypes = []callbackType{ callbackTypeBeforeCreate, callbackTypeAfterCreate, callbackTypeBeforeUpdate, callbackTypeAfterUpdate, callbackTypeBeforeSave, callbackTypeAfterSave, callbackTypeBeforeDelete, callbackTypeAfterDelete, callbackTypeAfterFind, } // Parse get data type from dialector func Parse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { return ParseWithSpecialTableName(dest, cacheStore, namer, "") } // ParseWithSpecialTableName get data type from dialector with extra schema table func ParseWithSpecialTableName(dest interface{}, cacheStore *sync.Map, namer Namer, specialTableName string) (*Schema, error) { if dest == nil { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } modelType := reflect.ValueOf(dest).Type() if modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { if modelType.Kind() == reflect.Interface { modelType = reflect.Indirect(reflect.ValueOf(dest)).Elem().Type() } for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } } // Cache the Schema for 
performance, // Use the modelType or modelType + schemaTable (if it present) as cache key. var schemaCacheKey interface{} = modelType if specialTableName != "" { schemaCacheKey = fmt.Sprintf("%p-%s", modelType, specialTableName) } // Load exist schema cache, return if exists if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized return s, s.err } var tableName string modelValue := reflect.New(modelType) if specialTableName != "" { tableName = specialTableName } else if en, ok := namer.(embeddedNamer); ok { tableName = en.Table } else if tabler, ok := modelValue.Interface().(Tabler); ok { tableName = tabler.TableName() } else if tabler, ok := modelValue.Interface().(TablerWithNamer); ok { tableName = tabler.TableName(namer) } else { tableName = namer.TableName(modelType.Name()) } schema := &Schema{ Name: modelType.Name(), ModelType: modelType, Table: tableName, DBNames: make([]string, 0, 10), Fields: make([]*Field, 0, 10), FieldsByName: make(map[string]*Field, 10), FieldsByBindName: make(map[string]*Field, 10), FieldsByDBName: make(map[string]*Field, 10), Relationships: Relationships{Relations: map[string]*Relationship{}}, cacheStore: cacheStore, namer: namer, initialized: make(chan struct{}), } // When the schema initialization is completed, the channel will be closed defer close(schema.initialized) // Load exist schema cache, return if exists if v, ok := cacheStore.Load(schemaCacheKey); ok { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized return s, s.err } for i := 0; i < modelType.NumField(); i++ { if fieldStruct := modelType.Field(i); ast.IsExported(fieldStruct.Name) { if field := schema.ParseField(fieldStruct); field.EmbeddedSchema != nil { schema.Fields = append(schema.Fields, field.EmbeddedSchema.Fields...) 
} else { schema.Fields = append(schema.Fields, field) } } } for _, field := range schema.Fields { if field.DBName == "" && field.DataType != "" { field.DBName = namer.ColumnName(schema.Table, field.Name) } bindName := field.BindName() if field.DBName != "" { // nonexistence or shortest path or first appear prioritized if has permission if v, ok := schema.FieldsByDBName[field.DBName]; !ok || ((field.Creatable || field.Updatable || field.Readable) && len(field.BindNames) < len(v.BindNames)) { if _, ok := schema.FieldsByDBName[field.DBName]; !ok { schema.DBNames = append(schema.DBNames, field.DBName) } schema.FieldsByDBName[field.DBName] = field schema.FieldsByName[field.Name] = field schema.FieldsByBindName[bindName] = field if v != nil && v.PrimaryKey { // remove the existing primary key field for idx, f := range schema.PrimaryFields { if f.DBName == v.DBName { schema.PrimaryFields = append(schema.PrimaryFields[0:idx], schema.PrimaryFields[idx+1:]...) } } } if field.PrimaryKey { schema.PrimaryFields = append(schema.PrimaryFields, field) } } } if of, ok := schema.FieldsByName[field.Name]; !ok || of.TagSettings["-"] == "-" { schema.FieldsByName[field.Name] = field } if of, ok := schema.FieldsByBindName[bindName]; !ok || of.TagSettings["-"] == "-" { schema.FieldsByBindName[bindName] = field } field.setupValuerAndSetter(modelType) } prioritizedPrimaryField := schema.LookUpField("id") if prioritizedPrimaryField == nil { prioritizedPrimaryField = schema.LookUpField("ID") } if prioritizedPrimaryField != nil { if prioritizedPrimaryField.PrimaryKey { schema.PrioritizedPrimaryField = prioritizedPrimaryField } else if len(schema.PrimaryFields) == 0 { prioritizedPrimaryField.PrimaryKey = true schema.PrioritizedPrimaryField = prioritizedPrimaryField schema.PrimaryFields = append(schema.PrimaryFields, prioritizedPrimaryField) } } if schema.PrioritizedPrimaryField == nil { if len(schema.PrimaryFields) == 1 { schema.PrioritizedPrimaryField = schema.PrimaryFields[0] } else if 
len(schema.PrimaryFields) > 1 { // If there are multiple primary keys, the AUTOINCREMENT field is prioritized for _, field := range schema.PrimaryFields { if field.AutoIncrement { schema.PrioritizedPrimaryField = field break } } } } for _, field := range schema.PrimaryFields { schema.PrimaryFieldDBNames = append(schema.PrimaryFieldDBNames, field.DBName) } _, embedded := schema.cacheStore.Load(embeddedCacheKey) relationshipFields := []*Field{} for _, field := range schema.Fields { if field.DataType != "" && field.HasDefaultValue && field.DefaultValueInterface == nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } if !embedded { if field.DataType == "" && field.GORMDataType == "" && (field.Creatable || field.Updatable || field.Readable) { relationshipFields = append(relationshipFields, field) schema.FieldsByName[field.Name] = field schema.FieldsByBindName[field.BindName()] = field } fieldValue := reflect.New(field.IndirectFieldType).Interface() if fc, ok := fieldValue.(CreateClausesInterface); ok { field.Schema.CreateClauses = append(field.Schema.CreateClauses, fc.CreateClauses(field)...) } if fc, ok := fieldValue.(QueryClausesInterface); ok { field.Schema.QueryClauses = append(field.Schema.QueryClauses, fc.QueryClauses(field)...) } if fc, ok := fieldValue.(UpdateClausesInterface); ok { field.Schema.UpdateClauses = append(field.Schema.UpdateClauses, fc.UpdateClauses(field)...) } if fc, ok := fieldValue.(DeleteClausesInterface); ok { field.Schema.DeleteClauses = append(field.Schema.DeleteClauses, fc.DeleteClauses(field)...) 
} } } if field := schema.PrioritizedPrimaryField; field != nil { switch field.GORMDataType { case Int, Uint: if _, ok := field.TagSettings["AUTOINCREMENT"]; !ok { if !field.HasDefaultValue || field.DefaultValueInterface != nil { schema.FieldsWithDefaultDBValue = append(schema.FieldsWithDefaultDBValue, field) } field.HasDefaultValue = true field.AutoIncrement = true } } } // Cache the schema if v, loaded := cacheStore.LoadOrStore(schemaCacheKey, schema); loaded { s := v.(*Schema) // Wait for the initialization of other goroutines to complete <-s.initialized return s, s.err } defer func() { if schema.err != nil { logger.Default.Error(context.Background(), schema.err.Error()) cacheStore.Delete(modelType) } }() for _, cbName := range callbackTypes { if methodValue := modelValue.MethodByName(string(cbName)); methodValue.IsValid() { switch methodValue.Type().String() { case "func(*gorm.DB) error": expectedPkgPath := path.Dir(reflect.TypeOf(schema).Elem().PkgPath()) if inVarPkg := methodValue.Type().In(0).Elem().PkgPath(); inVarPkg == expectedPkgPath { reflect.Indirect(reflect.ValueOf(schema)).FieldByName(string(cbName)).SetBool(true) } else { logger.Default.Warn(context.Background(), "In model %v, the hook function `%v(*gorm.DB) error` has an incorrect parameter type. The expected parameter type is `%v`, but the provided type is `%v`.", schema, cbName, expectedPkgPath, inVarPkg) // PASS } default: logger.Default.Warn(context.Background(), "Model %v don't match %vInterface, should be `%v(*gorm.DB) error`. 
Please see https://gorm.io/docs/hooks.html", schema, cbName, cbName) } } } // parse relationships for _, field := range relationshipFields { if schema.parseRelation(field); schema.err != nil { return schema, schema.err } } return schema, schema.err } func getOrParse(dest interface{}, cacheStore *sync.Map, namer Namer) (*Schema, error) { modelType := reflect.ValueOf(dest).Type() if modelType.Kind() != reflect.Struct { for modelType.Kind() == reflect.Slice || modelType.Kind() == reflect.Array || modelType.Kind() == reflect.Ptr { modelType = modelType.Elem() } if modelType.Kind() != reflect.Struct { if modelType.PkgPath() == "" { return nil, fmt.Errorf("%w: %+v", ErrUnsupportedDataType, dest) } return nil, fmt.Errorf("%w: %s.%s", ErrUnsupportedDataType, modelType.PkgPath(), modelType.Name()) } } if v, ok := cacheStore.Load(modelType); ok { return v.(*Schema), nil } return Parse(dest, cacheStore, namer) } ================================================ FILE: schema/schema_helper_test.go ================================================ package schema_test import ( "context" "fmt" "reflect" "strings" "testing" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" ) func checkSchema(t *testing.T, s *schema.Schema, v *schema.Schema, primaryFields []string) { t.Run("CheckSchema/"+s.Name, func(t *testing.T) { tests.AssertObjEqual(t, s, v, "Name", "Table") for idx, field := range primaryFields { var found bool for _, f := range s.PrimaryFields { if f.Name == field { found = true } } if idx == 0 { if field != s.PrioritizedPrimaryField.Name { t.Errorf("schema %v prioritized primary field should be %v, but got %v", s, field, s.PrioritizedPrimaryField.Name) } } if !found { t.Errorf("schema %v failed to found primary key: %v", s, field) } } }) } func checkSchemaField(t *testing.T, s *schema.Schema, f *schema.Field, fc func(*schema.Field)) { t.Run("CheckField/"+f.Name, func(t *testing.T) { if fc != nil { fc(f) } if f.TagSettings == nil { if f.Tag != "" { f.TagSettings = 
schema.ParseTagSetting(f.Tag.Get("gorm"), ";") } else { f.TagSettings = map[string]string{} } } parsedField, ok := s.FieldsByDBName[f.DBName] if !ok { parsedField, ok = s.FieldsByName[f.Name] } if !ok { t.Errorf("schema %v failed to look up field with name %v", s, f.Name) } else { tests.AssertObjEqual(t, parsedField, f, "Name", "DBName", "BindNames", "DataType", "PrimaryKey", "AutoIncrement", "Creatable", "Updatable", "Readable", "HasDefaultValue", "DefaultValue", "NotNull", "Unique", "Comment", "Size", "Precision", "TagSettings") if f.DBName != "" { if field, ok := s.FieldsByDBName[f.DBName]; !ok || parsedField != field { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) } } for _, name := range []string{f.DBName, f.Name} { if name != "" { if field := s.LookUpField(name); field == nil || (field.Name != name && field.DBName != name) { t.Errorf("schema %v failed to look up field with dbname %v", s, f.DBName) } } } if f.PrimaryKey { var found bool for _, primaryField := range s.PrimaryFields { if primaryField == parsedField { found = true } } if !found { t.Errorf("schema %v doesn't include field %v", s, f.Name) } } } }) } type Relation struct { Name string Type schema.RelationshipType Schema string FieldSchema string Polymorphic Polymorphic JoinTable JoinTable References []Reference } type Polymorphic struct { ID string Type string Value string } type JoinTable struct { Name string Table string Fields []schema.Field } type Reference struct { PrimaryKey string PrimarySchema string ForeignKey string ForeignSchema string PrimaryValue string OwnPrimaryKey bool } func checkSchemaRelation(t *testing.T, s *schema.Schema, relation Relation) { t.Run("CheckRelation/"+relation.Name, func(t *testing.T) { if r, ok := s.Relationships.Relations[relation.Name]; ok { if r.Name != relation.Name { t.Errorf("schema %v relation name expects %v, but got %v", s, r.Name, relation.Name) } if r.Type != relation.Type { t.Errorf("schema %v relation name expects %v, but 
got %v", s, r.Type, relation.Type) } if r.Schema.Name != relation.Schema { t.Errorf("schema %v relation's schema expects %v, but got %v", s, relation.Schema, r.Schema.Name) } if r.FieldSchema.Name != relation.FieldSchema { t.Errorf("schema %v field relation's schema expects %v, but got %v", s, relation.FieldSchema, r.FieldSchema.Name) } if r.Polymorphic != nil { if r.Polymorphic.PolymorphicID.Name != relation.Polymorphic.ID { t.Errorf("schema %v relation's polymorphic id field expects %v, but got %v", s, relation.Polymorphic.ID, r.Polymorphic.PolymorphicID.Name) } if r.Polymorphic.PolymorphicType.Name != relation.Polymorphic.Type { t.Errorf("schema %v relation's polymorphic type field expects %v, but got %v", s, relation.Polymorphic.Type, r.Polymorphic.PolymorphicType.Name) } if r.Polymorphic.Value != relation.Polymorphic.Value { t.Errorf("schema %v relation's polymorphic value expects %v, but got %v", s, relation.Polymorphic.Value, r.Polymorphic.Value) } } if r.JoinTable != nil { if r.JoinTable.Name != relation.JoinTable.Name { t.Errorf("schema %v relation's join table name expects %v, but got %v", s, relation.JoinTable.Name, r.JoinTable.Name) } if r.JoinTable.Table != relation.JoinTable.Table { t.Errorf("schema %v relation's join table tablename expects %v, but got %v", s, relation.JoinTable.Table, r.JoinTable.Table) } for i := range relation.JoinTable.Fields { checkSchemaField(t, r.JoinTable, &relation.JoinTable.Fields[i], nil) } } if len(relation.References) != len(r.References) { t.Errorf("schema %v relation's reference's count doesn't match, expects %v, but got %v", s, len(relation.References), len(r.References)) } for _, ref := range relation.References { var found bool for _, rf := range r.References { if (rf.PrimaryKey == nil || (rf.PrimaryKey.Name == ref.PrimaryKey && rf.PrimaryKey.Schema.Name == ref.PrimarySchema)) && (rf.PrimaryValue == ref.PrimaryValue) && (rf.ForeignKey.Name == ref.ForeignKey && rf.ForeignKey.Schema.Name == ref.ForeignSchema) && 
(rf.OwnPrimaryKey == ref.OwnPrimaryKey) { found = true } } if !found { var refs []string for _, rf := range r.References { var primaryKey, primaryKeySchema string if rf.PrimaryKey != nil { primaryKey, primaryKeySchema = rf.PrimaryKey.Name, rf.PrimaryKey.Schema.Name } refs = append(refs, fmt.Sprintf( "{PrimaryKey: %v PrimaryKeySchame: %v ForeignKey: %v ForeignKeySchema: %v PrimaryValue: %v OwnPrimaryKey: %v}", primaryKey, primaryKeySchema, rf.ForeignKey.Name, rf.ForeignKey.Schema.Name, rf.PrimaryValue, rf.OwnPrimaryKey, )) } t.Errorf("schema %v relation %v failed to found reference %+v, has %v", s, relation.Name, ref, strings.Join(refs, ", ")) } } } else { t.Errorf("schema %v failed to find relations by name %v", s, relation.Name) } }) } type EmbeddedRelations struct { Relations map[string]Relation EmbeddedRelations map[string]EmbeddedRelations } func checkEmbeddedRelations(t *testing.T, actual map[string]*schema.Relationships, expected map[string]EmbeddedRelations) { for name, relations := range actual { rs := expected[name] t.Run("CheckEmbeddedRelations/"+name, func(t *testing.T) { if len(relations.Relations) != len(rs.Relations) { t.Errorf("schema relations count don't match, expects %d, got %d", len(rs.Relations), len(relations.Relations)) } if len(relations.EmbeddedRelations) != len(rs.EmbeddedRelations) { t.Errorf("schema embedded relations count don't match, expects %d, got %d", len(rs.EmbeddedRelations), len(relations.EmbeddedRelations)) } for n, rel := range relations.Relations { if r, ok := rs.Relations[n]; !ok { t.Errorf("failed to find relation by name %s", n) } else { checkSchemaRelation(t, &schema.Schema{ Relationships: schema.Relationships{ Relations: map[string]*schema.Relationship{n: rel}, }, }, r) } } checkEmbeddedRelations(t, relations.EmbeddedRelations, rs.EmbeddedRelations) }) } } func checkField(t *testing.T, s *schema.Schema, value reflect.Value, values map[string]interface{}) { for k, v := range values { t.Run("CheckField/"+k, func(t 
*testing.T) {
			fv, _ := s.FieldsByDBName[k].ValueOf(context.Background(), value)
			tests.AssertEqual(t, v, fv)
		})
	}
}

================================================
FILE: schema/schema_test.go
================================================
package schema_test

import (
	"strings"
	"sync"
	"testing"

	"gorm.io/gorm"
	"gorm.io/gorm/schema"
	"gorm.io/gorm/utils/tests"
)

// TestParseSchema parses tests.User and checks the result with checkUserSchema.
func TestParseSchema(t *testing.T) {
	user, err := schema.Parse(&tests.User{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse user, got error %v", err)
	}

	checkUserSchema(t, user)
}

// TestParseSchemaWithMap checks that an explicit `type:` tag setting is used
// verbatim as the DataType of a map-typed field.
func TestParseSchemaWithMap(t *testing.T) {
	type User struct {
		tests.User
		Attrs map[string]string `gorm:"type:Map(String,String);"`
	}

	user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse user with map, got error %v", err)
	}

	if field := user.FieldsByName["Attrs"]; field.DataType != "Map(String,String)" {
		t.Errorf("failed to parse user field Attrs")
	}
}

// TestParseSchemaWithPointerFields parses the package-level User model
// (declared elsewhere in this test package; presumably a pointer-field
// variant of tests.User — confirm) and expects the same schema as tests.User.
func TestParseSchemaWithPointerFields(t *testing.T) {
	user, err := schema.Parse(&User{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse pointer user, got error %v", err)
	}

	checkUserSchema(t, user)
}

// checkUserSchema asserts that user matches the expected schema of the User
// model: table name and primary key, every column field, and every
// relationship (has one, has many, polymorphic has many, belongs to, many2many).
func checkUserSchema(t *testing.T, user *schema.Schema) {
	// check schema
	checkSchema(t, user, &schema.Schema{Name: "User", Table: "users"}, []string{"ID"})

	// check fields
	fields := []schema.Field{
		{Name: "ID", DBName: "id", BindNames: []string{"Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Tag: `gorm:"primarykey"`, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}, Size: 64, HasDefaultValue: true, AutoIncrement: true},
		{Name: "CreatedAt", DBName: "created_at", BindNames: []string{"Model", "CreatedAt"}, DataType: schema.Time},
		{Name: "UpdatedAt", DBName: "updated_at", BindNames: []string{"Model", "UpdatedAt"}, DataType: schema.Time},
		{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"Model", "DeletedAt"}, Tag: `gorm:"index"`, DataType: schema.Time},
		{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
		{Name: "Age", DBName: "age", BindNames: []string{"Age"}, DataType: schema.Uint, Size: 64},
		{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
		{Name: "CompanyID", DBName: "company_id", BindNames: []string{"CompanyID"}, DataType: schema.Int, Size: 64},
		{Name: "ManagerID", DBName: "manager_id", BindNames: []string{"ManagerID"}, DataType: schema.Uint, Size: 64},
		{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
	}

	for i := range fields {
		// the callback fills in the Creatable/Updatable/Readable flags that
		// every regular column is expected to carry
		checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
			f.Creatable = true
			f.Updatable = true
			f.Readable = true
		})
	}

	// check relations
	// Reference literals are positional; their field order is defined by the
	// Reference type declared elsewhere in this test package.
	relations := []Relation{
		{
			Name: "Account", Type: schema.HasOne, Schema: "User", FieldSchema: "Account",
			References: []Reference{{"ID", "User", "UserID", "Account", "", true}},
		},
		{
			Name: "Pets", Type: schema.HasMany, Schema: "User", FieldSchema: "Pet",
			References: []Reference{{"ID", "User", "UserID", "Pet", "", true}},
		},
		{
			Name: "Toys", Type: schema.HasMany, Schema: "User", FieldSchema: "Toy",
			Polymorphic: Polymorphic{ID: "OwnerID", Type: "OwnerType", Value: "users"},
			References:  []Reference{{"ID", "User", "OwnerID", "Toy", "", true}, {"", "", "OwnerType", "Toy", "users", false}},
		},
		{
			Name: "Company", Type: schema.BelongsTo, Schema: "User", FieldSchema: "Company",
			References: []Reference{{"ID", "Company", "CompanyID", "User", "", false}},
		},
		{
			Name: "Manager", Type: schema.BelongsTo, Schema: "User", FieldSchema: "User",
			References: []Reference{{"ID", "User", "ManagerID", "User", "", false}},
		},
		{
			Name: "Team", Type: schema.HasMany, Schema: "User", FieldSchema: "User",
			References: []Reference{{"ID", "User", "ManagerID", "User", "", true}},
		},
		{
			Name: "Languages", Type: schema.Many2Many, Schema: "User", FieldSchema: "Language",
			JoinTable: JoinTable{Name: "UserSpeak", Table: "user_speaks", Fields: []schema.Field{
				{
					Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64,
				},
				{
					Name: "LanguageCode", DBName: "language_code", BindNames: []string{"LanguageCode"}, DataType: schema.String, Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true,
				},
			}},
			References: []Reference{{"ID", "User", "UserID", "UserSpeak", "", true}, {"Code", "Language", "LanguageCode", "UserSpeak", "", false}},
		},
		{
			Name: "Friends", Type: schema.Many2Many, Schema: "User", FieldSchema: "User",
			JoinTable: JoinTable{Name: "user_friends", Table: "user_friends", Fields: []schema.Field{
				{
					Name: "UserID", DBName: "user_id", BindNames: []string{"UserID"}, DataType: schema.Uint, Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64,
				},
				{
					Name: "FriendID", DBName: "friend_id", BindNames: []string{"FriendID"}, DataType: schema.Uint, Tag: `gorm:"primarykey"`, Creatable: true, Updatable: true, Readable: true, PrimaryKey: true, Size: 64,
				},
			}},
			References: []Reference{{"ID", "User", "UserID", "user_friends", "", true}, {"ID", "User", "FriendID", "user_friends", "", false}},
		},
	}

	for _, relation := range relations {
		checkSchemaRelation(t, user, relation)
	}
}

// TestParseSchemaWithAdvancedDataType checks the schema parsed from
// AdvancedDataTypeUser (declared elsewhere in this test package) against the
// expected DataType of each field.
func TestParseSchemaWithAdvancedDataType(t *testing.T) {
	user, err := schema.Parse(&AdvancedDataTypeUser{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse pointer user, got error %v", err)
	}

	// check schema
	checkSchema(t, user, &schema.Schema{Name: "AdvancedDataTypeUser", Table: "advanced_data_type_users"}, []string{"ID"})

	// check fields
	fields := []schema.Field{
		{Name: "ID", DBName: "id", BindNames: []string{"ID"}, DataType: schema.Int, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true},
		{Name: "Name", DBName: "name", BindNames: []string{"Name"}, DataType: schema.String},
		{Name: "Birthday", DBName: "birthday", BindNames: []string{"Birthday"}, DataType: schema.Time},
		{Name: "RegisteredAt", DBName: "registered_at", BindNames: []string{"RegisteredAt"}, DataType: schema.Time},
		{Name: "DeletedAt", DBName: "deleted_at", BindNames: []string{"DeletedAt"}, DataType: schema.Time},
		{Name: "Active", DBName: "active", BindNames: []string{"Active"}, DataType: schema.Bool},
		{Name: "Admin", DBName: "admin", BindNames: []string{"Admin"}, DataType: schema.Bool},
	}

	for i := range fields {
		checkSchemaField(t, user, &fields[i], func(f *schema.Field) {
			f.Creatable = true
			f.Updatable = true
			f.Readable = true
		})
	}
}

// CustomizeTable overrides its table name through a TableName method.
type CustomizeTable struct{}

// TableName returns the custom table name "customize".
func (CustomizeTable) TableName() string {
	return "customize"
}

// TestCustomizeTableName checks that schema.Parse uses the model's TableName method.
func TestCustomizeTableName(t *testing.T) {
	customize, err := schema.Parse(&CustomizeTable{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse pointer user, got error %v", err)
	}

	if customize.Table != "customize" {
		t.Errorf("Failed to customize table with TableName method")
	}
}

// TestNestedModel checks the BindNames of fields promoted through two levels of
// embedded structs (VersionModel -> BaseModel).
func TestNestedModel(t *testing.T) {
	versionUser, err := schema.Parse(&VersionUser{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse nested user, got error %v", err)
	}

	fields := []schema.Field{
		{Name: "ID", DBName: "id", BindNames: []string{"VersionModel", "BaseModel", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true},
		{Name: "CreatedBy", DBName: "created_by", BindNames: []string{"VersionModel", "BaseModel", "CreatedBy"}, DataType: schema.Uint, Size: 64},
		{Name: "Version", DBName: "version", BindNames: []string{"VersionModel", "Version"}, DataType: schema.Int, Size: 64},
	}

	for _, f := range fields {
		checkSchemaField(t, versionUser, &f, func(f *schema.Field) {
			f.Creatable = true
			f.Updatable = true
			f.Readable = true
		})
	}
}

// TestEmbeddedStruct checks two kinds of embedding:
//   - fields promoted from an anonymous struct (CorpBase) keep their own column names;
//   - fields of a struct tagged `embedded;embeddedPrefix:company_` are prefixed with "company_".
//
// A `-` field gets no DB name and none of the Creatable/Updatable/Readable flags.
func TestEmbeddedStruct(t *testing.T) {
	type CorpBase struct {
		gorm.Model
		OwnerID string
	}

	type Company struct {
		ID      int
		OwnerID int
		Name    string
		Ignored string `gorm:"-"`
	}

	type Corp struct {
		CorpBase
		Base Company `gorm:"embedded;embeddedPrefix:company_"`
	}

	cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
	}

	fields := []schema.Field{
		{Name: "ID", DBName: "id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
		{Name: "ID", DBName: "company_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
		{Name: "Name", DBName: "company_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
		{Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
		{Name: "OwnerID", DBName: "company_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
		{Name: "OwnerID", DBName: "owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
	}

	for _, f := range fields {
		checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
			// ignored (`-`) fields are expected to carry no permission flags
			if f.Name != "Ignored" {
				f.Creatable = true
				f.Updatable = true
				f.Readable = true
			}
		})
	}
}

// CustomizedNamingStrategy prefixes every column name with an abbreviation
// derived from the table name (see ColumnName).
type CustomizedNamingStrategy struct {
	schema.NamingStrategy
}

// ColumnName returns the default column name prefixed by an abbreviation of
// table plus "_":
//   - one "_"-separated segment: its first three characters;
//   - two segments: the first character of the first plus two of the second;
//   - more segments: the first character of each of the first three.
//
// An empty table returns the default name unchanged.
// NOTE(review): segments shorter than the sliced length would make this panic.
func (ns CustomizedNamingStrategy) ColumnName(table, column string) string {
	baseColumnName := ns.NamingStrategy.ColumnName(table, column)

	if table == "" {
		return baseColumnName
	}

	s := strings.Split(table, "_")

	var prefix string
	switch len(s) {
	case 1:
		prefix = s[0][:3]
	case 2:
		prefix = s[0][:1] + s[1][:2]
	default:
		prefix = s[0][:1] + s[1][:1] + s[2][:1]
	}
	return prefix + "_" + baseColumnName
}

// TestEmbeddedStructForCustomizedNamingStrategy repeats TestEmbeddedStruct with
// CustomizedNamingStrategy. The expected names show the embedded prefix placed
// before the strategy's table abbreviation (e.g. "company_" + "cor_" + "id").
func TestEmbeddedStructForCustomizedNamingStrategy(t *testing.T) {
	type CorpBase struct {
		gorm.Model
		OwnerID string
	}

	type Company struct {
		ID      int
		OwnerID int
		Name    string
		Ignored string `gorm:"-"`
	}

	type Corp struct {
		CorpBase
		Base Company `gorm:"embedded;embeddedPrefix:company_"`
	}

	cropSchema, err := schema.Parse(&Corp{}, &sync.Map{}, CustomizedNamingStrategy{schema.NamingStrategy{}})
	if err != nil {
		t.Fatalf("failed to parse embedded struct with primary key, got error %v", err)
	}

	fields := []schema.Field{
		{Name: "ID", DBName: "cor_id", BindNames: []string{"CorpBase", "Model", "ID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY"}},
		{Name: "ID", DBName: "company_cor_id", BindNames: []string{"Base", "ID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
		{Name: "Name", DBName: "company_cor_name", BindNames: []string{"Base", "Name"}, DataType: schema.String, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
		{Name: "Ignored", BindNames: []string{"Base", "Ignored"}, TagSettings: map[string]string{"-": "-", "EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
		{Name: "OwnerID", DBName: "company_cor_owner_id", BindNames: []string{"Base", "OwnerID"}, DataType: schema.Int, Size: 64, TagSettings: map[string]string{"EMBEDDED": "EMBEDDED", "EMBEDDEDPREFIX": "company_"}},
		{Name: "OwnerID", DBName: "cor_owner_id", BindNames: []string{"CorpBase", "OwnerID"}, DataType: schema.String},
	}

	for _, f := range fields {
		checkSchemaField(t, cropSchema, &f, func(f *schema.Field) {
			if f.Name != "Ignored" {
				f.Creatable = true
				f.Updatable = true
				f.Readable = true
			}
		})
	}
}

// TestCompositePrimaryKeyWithAutoIncrement checks two composite primary keys:
//   - with `autoIncrement` on one key column, that column becomes the PrioritizedPrimaryField;
//   - with `autoIncrement:false`, no field is prioritized.
func TestCompositePrimaryKeyWithAutoIncrement(t *testing.T) {
	type Product struct {
		ProductID    uint `gorm:"primaryKey;autoIncrement"`
		LanguageCode uint `gorm:"primaryKey"`
		Code         string
Name         string
	}
	type ProductNonAutoIncrement struct {
		ProductID    uint `gorm:"primaryKey;autoIncrement:false"`
		LanguageCode uint `gorm:"primaryKey"`
		Code         string
		Name         string
	}

	product, err := schema.Parse(&Product{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse product struct with composite primary key, got error %v", err)
	}

	prioritizedPrimaryField := schema.Field{
		Name: "ProductID", DBName: "product_id", BindNames: []string{"ProductID"}, DataType: schema.Uint, PrimaryKey: true, Size: 64, HasDefaultValue: true, AutoIncrement: true, TagSettings: map[string]string{"PRIMARYKEY": "PRIMARYKEY", "AUTOINCREMENT": "AUTOINCREMENT"},
	}

	// narrow Fields to the prioritized primary field (presumably so that only
	// this field is visible to checkSchemaField — confirm)
	product.Fields = []*schema.Field{product.PrioritizedPrimaryField}

	checkSchemaField(t, product, &prioritizedPrimaryField, func(f *schema.Field) {
		f.Creatable = true
		f.Updatable = true
		f.Readable = true
	})

	productNonAutoIncrement, err := schema.Parse(&ProductNonAutoIncrement{}, &sync.Map{}, schema.NamingStrategy{})
	if err != nil {
		t.Fatalf("failed to parse productNonAutoIncrement struct with composite primary key, got error %v", err)
	}

	if productNonAutoIncrement.PrioritizedPrimaryField != nil {
		t.Fatalf("PrioritizedPrimaryField of non autoincrement composite key should be nil")
	}
}

================================================
FILE: schema/serializer.go
================================================
package schema

import (
	"bytes"
	"context"
	"database/sql"
	"database/sql/driver"
	"encoding/gob"
	"encoding/json"
	"fmt"
	"math"
	"reflect"
	"strings"
	"sync"
	"time"
)

// serializerMap holds the registered serializers, keyed by lower-cased name.
var serializerMap = sync.Map{}

// RegisterSerializer registers serializer under name.
//
// The name is lower-cased before it is stored, so GetSerializer lookups are
// case-insensitive. Registering an existing name replaces the previous entry.
func RegisterSerializer(name string, serializer SerializerInterface) {
	serializerMap.Store(strings.ToLower(name), serializer)
}

// GetSerializer returns the serializer registered under name (matched
// case-insensitively). ok is false when nothing is registered under that name.
func GetSerializer(name string) (serializer SerializerInterface, ok bool) {
	v, ok := serializerMap.Load(strings.ToLower(name))
	if ok {
		serializer, ok = v.(SerializerInterface)
	}
	return serializer, ok
}

// init registers the built-in "json", "unixtime" and "gob" serializers.
func init() {
	RegisterSerializer("json", JSONSerializer{})
	RegisterSerializer("unixtime", UnixSecondSerializer{})
	RegisterSerializer("gob", GobSerializer{})
}

// serializer binds a SerializerInterface to one field and destination value,
// exposing it through the sql.Scanner / driver.Valuer methods below.
type serializer struct {
	Field           *Field
	Serializer      SerializerInterface
	SerializeValuer SerializerValuerInterface
	Destination     reflect.Value
	Context         context.Context
	value           interface{}
	fieldValue      interface{}
}

// Scan implements sql.Scanner interface.
//
// It only records the raw database value in s.value; decoding is presumably
// done later by the caller via s.Serializer (not visible here).
func (s *serializer) Scan(value interface{}) error {
	s.value = value
	return nil
}

// Value implements driver.Valuer interface.
//
// It delegates to SerializeValuer with the stored context, field, destination
// and fieldValue.
func (s serializer) Value() (driver.Value, error) {
	return s.SerializeValuer.Value(s.Context, s.Field, s.Destination, s.fieldValue)
}

// SerializerInterface is implemented by field serializers.
//
// Scan decodes dbValue into the field of dst; Value (embedded from
// SerializerValuerInterface) encodes a field value for storage.
type SerializerInterface interface {
	Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) error
	SerializerValuerInterface
}

// SerializerValuerInterface converts a field value into the value written to
// the database.
type SerializerValuerInterface interface {
	Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error)
}

// JSONSerializer stores field values as JSON text.
type JSONSerializer struct{}

// Scan implements serializer interface.
//
// Input handling:
//   - a nil dbValue resets the field to its zero value;
//   - []byte and string values are decoded as JSON directly;
//   - any other type is first re-encoded with json.Marshal.
//
// The decoded value is assigned to the field even when json.Unmarshal fails,
// and that error is returned.
func (JSONSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
	fieldValue := reflect.New(field.FieldType)

	if dbValue != nil {
		var bytes []byte
		switch v := dbValue.(type) {
		case []byte:
			bytes = v
		case string:
			bytes = []byte(v)
		default:
			bytes, err = json.Marshal(v)
			if err != nil {
				return err
			}
		}

		if len(bytes) > 0 {
			err = json.Unmarshal(bytes, fieldValue.Interface())
		}
	}

	field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
	return
}

// Value implements serializer interface.
//
// It returns the JSON encoding as a string. A value that encodes to JSON null
// is stored as "" when the field has a NOT NULL tag setting, and as nil
// otherwise.
func (JSONSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
	result, err := json.Marshal(fieldValue)
	if string(result) == "null" {
		if field.TagSettings["NOT NULL"] != "" {
			return "", nil
		}
		return nil, err
	}
	return string(result), err
}

// UnixSecondSerializer maps integer fields holding Unix seconds to UTC
// time.Time database values.
type UnixSecondSerializer struct{}

// Scan implements serializer interface.
//
// It reads dbValue as a sql.NullTime. When the value is non-NULL, its Unix
// seconds are stored in the field via field.Set; a NULL value leaves the field
// unchanged.
func (UnixSecondSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
	t := sql.NullTime{}
	if err = t.Scan(dbValue); err == nil && t.Valid {
		err = field.Set(ctx, dst, t.Time.Unix())
	}
	return
}

// Value implements serializer interface.
//
// Conversion rules:
//   - signed and unsigned integers, and non-nil pointers to them, become UTC time.Time values;
//   - a nil pointer yields nil;
//   - unsigned values above math.MaxInt64, and any other type, yield an error.
func (UnixSecondSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (result interface{}, err error) {
	rv := reflect.ValueOf(fieldValue)
	switch fieldValue.(type) {
	case int, int8, int16, int32, int64:
		result = time.Unix(rv.Int(), 0).UTC()
	case uint, uint8, uint16, uint32, uint64:
		// time.Unix takes int64 seconds; larger unsigned values cannot be converted
		if uv := rv.Uint(); uv > math.MaxInt64 {
			err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv)
		} else {
			result = time.Unix(int64(uv), 0).UTC() //nolint:gosec
		}
	case *int, *int8, *int16, *int32, *int64:
		if rv.IsZero() {
			return nil, nil
		}
		result = time.Unix(rv.Elem().Int(), 0).UTC()
	case *uint, *uint8, *uint16, *uint32, *uint64:
		if rv.IsZero() {
			return nil, nil
		}
		if uv := rv.Elem().Uint(); uv > math.MaxInt64 {
			err = fmt.Errorf("integer overflow conversion uint64(%d) -> int64", uv)
		} else {
			result = time.Unix(int64(uv), 0).UTC() //nolint:gosec
		}
	default:
		err = fmt.Errorf("invalid field type %#v for UnixSecondSerializer, only int, uint supported", fieldValue)
	}
	return
}

// GobSerializer stores field values as encoding/gob bytes.
type GobSerializer struct{}

// Scan implements serializer interface.
//
// A nil dbValue resets the field to its zero value; only []byte input is
// accepted. The decoded value is assigned even if gob decoding fails, and that
// error is returned.
func (GobSerializer) Scan(ctx context.Context, field *Field, dst reflect.Value, dbValue interface{}) (err error) {
	fieldValue := reflect.New(field.FieldType)

	if dbValue != nil {
		var bytesValue []byte
		switch v := dbValue.(type) {
		case []byte:
			bytesValue = v
		default:
			return fmt.Errorf("failed to unmarshal gob value: %#v", dbValue)
		}
		if len(bytesValue) > 0 {
			decoder := gob.NewDecoder(bytes.NewBuffer(bytesValue))
			err = decoder.Decode(fieldValue.Interface())
		}
	}
	field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem())
	return
}

// Value implements serializer interface.
//
// It gob-encodes fieldValue and returns the encoded bytes together with any
// encoding error.
func (GobSerializer) Value(ctx context.Context, field *Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) {
	buf := new(bytes.Buffer)
	err := gob.NewEncoder(buf).Encode(fieldValue)
	return buf.Bytes(), err
}

================================================
FILE: schema/serializer_test.go
================================================
package schema

import (
	"context"
	"math"
	"reflect"
	"testing"
	"time"
)

// TestUnixSecondSerializer_Value exercises UnixSecondSerializer.Value with:
//   - every signed and unsigned integer kind, and pointers to them;
//   - the uint64 overflow boundary just above math.MaxInt64;
//   - a nil pointer and an unsupported type.
func TestUnixSecondSerializer_Value(t *testing.T) {
	var (
		intValue         = math.MaxInt64
		int8Value        = int8(math.MaxInt8)
		int16Value       = int16(math.MaxInt16)
		int32Value       = int32(math.MaxInt32)
		int64Value       = int64(math.MaxInt64)
		uintValue        = uint(math.MaxInt64)
		uint8Value       = uint8(math.MaxUint8)
		uint16Value      = uint16(math.MaxUint16)
		uint32Value      = uint32(math.MaxUint32)
		uint64Value      = uint64(math.MaxInt64)
		maxInt64Plus1    = uint64(math.MaxInt64 + 1)
		intPtrValue      = &intValue
		int8PtrValue     = &int8Value
		int16PtrValue    = &int16Value
		int32PtrValue    = &int32Value
		int64PtrValue    = &int64Value
		uintPtrValue     = &uintValue
		uint8PtrValue    = &uint8Value
		uint16PtrValue   = &uint16Value
		uint32PtrValue   = &uint32Value
		uint64PtrValue   = &uint64Value
		maxInt64Plus1Ptr = &maxInt64Plus1
	)
	tests := []struct {
		name    string
		value   interface{}
		want    interface{}
		wantErr bool
	}{
		{
			name:    "int",
			value:   intValue,
			want:    time.Unix(int64(intValue), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "int8",
			value:   int8Value,
			want:    time.Unix(int64(int8Value), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "int16",
			value:   int16Value,
			want:    time.Unix(int64(int16Value), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "int32",
			value:   int32Value,
			want:    time.Unix(int64(int32Value), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "int64",
			value:   int64Value,
			want:    time.Unix(int64Value, 0).UTC(),
			wantErr: false,
		},
		{
			name:    "uint",
			value:   uintValue,
			want:    time.Unix(int64(uintValue), 0).UTC(), //nolint:gosec
			wantErr: false,
		},
		{
			name:    "uint8",
			value:   uint8Value,
			want:    time.Unix(int64(uint8Value), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "uint16",
			value:   uint16Value,
			want:    time.Unix(int64(uint16Value), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "uint32",
			value:   uint32Value,
			want:    time.Unix(int64(uint32Value), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "uint64",
			value:   uint64Value,
			want:    time.Unix(int64(uint64Value), 0).UTC(), //nolint:gosec
			wantErr: false,
		},
		{
			name:    "maxInt64+1",
			value:   maxInt64Plus1,
			want:    nil,
			wantErr: true,
		},
		{
			name:    "*int",
			value:   intPtrValue,
			want:    time.Unix(int64(*intPtrValue), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "*int8",
			value:   int8PtrValue,
			want:    time.Unix(int64(*int8PtrValue), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "*int16",
			value:   int16PtrValue,
			want:    time.Unix(int64(*int16PtrValue), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "*int32",
			value:   int32PtrValue,
			want:    time.Unix(int64(*int32PtrValue), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "*int64",
			value:   int64PtrValue,
			want:    time.Unix(*int64PtrValue, 0).UTC(),
			wantErr: false,
		},
		{
			name:    "*uint",
			value:   uintPtrValue,
			want:    time.Unix(int64(*uintPtrValue), 0).UTC(), //nolint:gosec
			wantErr: false,
		},
		{
			name:    "*uint8",
			value:   uint8PtrValue,
			want:    time.Unix(int64(*uint8PtrValue), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "*uint16",
			value:   uint16PtrValue,
			want:    time.Unix(int64(*uint16PtrValue), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "*uint32",
			value:   uint32PtrValue,
			want:    time.Unix(int64(*uint32PtrValue), 0).UTC(),
			wantErr: false,
		},
		{
			name:    "*uint64",
			value:   uint64PtrValue,
			want:    time.Unix(int64(*uint64PtrValue), 0).UTC(), //nolint:gosec
			wantErr: false,
		},
		{
			name:    "pointer to maxInt64+1",
			value:   maxInt64Plus1Ptr,
			want:    nil,
			wantErr: true,
		},
		{
			name:    "nil pointer",
			value:   (*int)(nil),
			want:    nil,
			wantErr: false,
		},
		{
			name:    "invalid type",
			value:   "invalid",
			want:    nil,
			wantErr: true,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := UnixSecondSerializer{}.Value(context.Background(), nil, reflect.Value{}, tt.value)
			if (err !=
nil) != tt.wantErr { t.Fatalf("UnixSecondSerializer.Value() error = %v, wantErr %v", err, tt.wantErr) } if err != nil { return } if tt.want == nil && got == nil { return } if tt.want == nil { t.Fatalf("UnixSecondSerializer.Value() = %v, want nil", got) } if got == nil { t.Fatalf("UnixSecondSerializer.Value() = nil, want %v", tt.want) } if gotTime, ok := got.(time.Time); !ok { t.Errorf("UnixSecondSerializer.Value() returned %T, expected time.Time", got) } else if !tt.want.(time.Time).Equal(gotTime) { t.Errorf("UnixSecondSerializer.Value() = %v, want %v", got, tt.want) } }) } } ================================================ FILE: schema/utils.go ================================================ package schema import ( "context" "fmt" "reflect" "regexp" "strings" "gorm.io/gorm/clause" "gorm.io/gorm/utils" ) var embeddedCacheKey = "embedded_cache_store" func ParseTagSetting(str string, sep string) map[string]string { settings := map[string]string{} names := strings.Split(str, sep) var parsedNames []string for i := 0; i < len(names); i++ { s := names[i] for strings.HasSuffix(s, "\\") && i+1 < len(names) { i++ s = s[:len(s)-1] + sep + names[i] } parsedNames = append(parsedNames, s) } for _, tag := range parsedNames { values := strings.Split(tag, ":") k := strings.TrimSpace(strings.ToUpper(values[0])) if len(values) >= 2 { val := strings.Join(values[1:], ":") val = strings.ReplaceAll(val, `\"`, `"`) settings[k] = val } else if k != "" { settings[k] = k } } return settings } func toColumns(val string) (results []string) { if val != "" { for _, v := range strings.Split(val, ",") { results = append(results, strings.TrimSpace(v)) } } return } func removeSettingFromTag(tag reflect.StructTag, names ...string) reflect.StructTag { for _, name := range names { tag = reflect.StructTag(regexp.MustCompile(`(?i)(gorm:.*?)(`+name+`(:.*?)?)(;|("))`).ReplaceAllString(string(tag), "${1}${5}")) } return tag } func appendSettingFromTag(tag reflect.StructTag, value string) reflect.StructTag 
{ t := tag.Get("gorm") if strings.Contains(t, value) { return tag } return reflect.StructTag(fmt.Sprintf(`gorm:"%s;%s"`, value, t)) } // GetRelationsValues get relations's values from a reflect value func GetRelationsValues(ctx context.Context, reflectValue reflect.Value, rels []*Relationship) (reflectResults reflect.Value) { for _, rel := range rels { reflectResults = reflect.MakeSlice(reflect.SliceOf(reflect.PointerTo(rel.FieldSchema.ModelType)), 0, 1) appendToResults := func(value reflect.Value) { if _, isZero := rel.Field.ValueOf(ctx, value); !isZero { result := reflect.Indirect(rel.Field.ReflectValueOf(ctx, value)) switch result.Kind() { case reflect.Struct: reflectResults = reflect.Append(reflectResults, result.Addr()) case reflect.Slice, reflect.Array: for i := 0; i < result.Len(); i++ { if elem := result.Index(i); elem.Kind() == reflect.Ptr { reflectResults = reflect.Append(reflectResults, elem) } else { reflectResults = reflect.Append(reflectResults, elem.Addr()) } } } } } switch reflectValue.Kind() { case reflect.Struct: appendToResults(reflectValue) case reflect.Slice: for i := 0; i < reflectValue.Len(); i++ { appendToResults(reflectValue.Index(i)) } } reflectValue = reflectResults } return } // GetIdentityFieldValuesMap get identity map from fields func GetIdentityFieldValuesMap(ctx context.Context, reflectValue reflect.Value, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { var ( results = [][]interface{}{} dataResults = map[string][]reflect.Value{} loaded = map[interface{}]bool{} notZero, zero bool ) if reflectValue.Kind() == reflect.Ptr || reflectValue.Kind() == reflect.Interface { reflectValue = reflectValue.Elem() } switch reflectValue.Kind() { case reflect.Map: results = [][]interface{}{make([]interface{}, len(fields))} for idx, field := range fields { mapValue := reflectValue.MapIndex(reflect.ValueOf(field.DBName)) if mapValue.IsZero() { mapValue = reflectValue.MapIndex(reflect.ValueOf(field.Name)) } results[0][idx] = 
mapValue.Interface() } dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Struct: results = [][]interface{}{make([]interface{}, len(fields))} for idx, field := range fields { results[0][idx], zero = field.ValueOf(ctx, reflectValue) notZero = notZero || !zero } if !notZero { return nil, nil } dataResults[utils.ToStringKey(results[0]...)] = []reflect.Value{reflectValue} case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { elem := reflectValue.Index(i) elemKey := elem.Interface() if elem.Kind() != reflect.Ptr && elem.CanAddr() { elemKey = elem.Addr().Interface() } if _, ok := loaded[elemKey]; ok { continue } loaded[elemKey] = true fieldValues := make([]interface{}, len(fields)) notZero = false for idx, field := range fields { fieldValues[idx], zero = field.ValueOf(ctx, elem) notZero = notZero || !zero } if notZero { dataKey := utils.ToStringKey(fieldValues...) if _, ok := dataResults[dataKey]; !ok { results = append(results, fieldValues) dataResults[dataKey] = []reflect.Value{elem} } else { dataResults[dataKey] = append(dataResults[dataKey], elem) } } } } return dataResults, results } // GetIdentityFieldValuesMapFromValues get identity map from fields func GetIdentityFieldValuesMapFromValues(ctx context.Context, values []interface{}, fields []*Field) (map[string][]reflect.Value, [][]interface{}) { resultsMap := map[string][]reflect.Value{} results := [][]interface{}{} for _, v := range values { rm, rs := GetIdentityFieldValuesMap(ctx, reflect.Indirect(reflect.ValueOf(v)), fields) for k, v := range rm { resultsMap[k] = append(resultsMap[k], v...) } results = append(results, rs...) 
} return resultsMap, results } // ToQueryValues to query values func ToQueryValues(table string, foreignKeys []string, foreignValues [][]interface{}) (interface{}, []interface{}) { queryValues := make([]interface{}, len(foreignValues)) if len(foreignKeys) == 1 { for idx, r := range foreignValues { queryValues[idx] = r[0] } return clause.Column{Table: table, Name: foreignKeys[0]}, queryValues } columns := make([]clause.Column, len(foreignKeys)) for idx, key := range foreignKeys { columns[idx] = clause.Column{Table: table, Name: key} } for idx, r := range foreignValues { queryValues[idx] = r } return columns, queryValues } type embeddedNamer struct { Table string Namer } ================================================ FILE: schema/utils_test.go ================================================ package schema import ( "reflect" "testing" ) func TestRemoveSettingFromTag(t *testing.T) { tags := map[string]string{ `gorm:"before:value;column:db;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, `gorm:"before:value;column:db;" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, `gorm:"before:value;column:db" other:"before:value;column:db;after:value"`: `gorm:"before:value;" other:"before:value;column:db;after:value"`, `gorm:"column:db" other:"before:value;column:db;after:value"`: `gorm:"" other:"before:value;column:db;after:value"`, `gorm:"before:value;column:db ;after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value;after:value" other:"before:value;column:db;after:value"`, `gorm:"before:value;column:db; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, `gorm:"before:value;column; after:value" other:"before:value;column:db;after:value"`: `gorm:"before:value; after:value" other:"before:value;column:db;after:value"`, } for 
k, v := range tags {
		if string(removeSettingFromTag(reflect.StructTag(k), "column")) != v {
			t.Errorf("%v after removeSettingFromTag should equal %v, but got %v", k, v, removeSettingFromTag(reflect.StructTag(k), "column"))
		}
	}
}

// TestParseTagSettingWithDoubleQuoteEscape checks that a setting value written
// with escaped double quotes in the struct tag comes back from ParseTagSetting
// with plain double quotes. (reflect.StructTag.Get already unquotes the tag
// value before ParseTagSetting sees it.)
func TestParseTagSettingWithDoubleQuoteEscape(t *testing.T) {
	tag := `gorm:"expression:to_tsvector('english', \"Name\")"`
	settings := ParseTagSetting(reflect.StructTag(tag).Get("gorm"), ";")
	if v, ok := settings["EXPRESSION"]; !ok || v != `to_tsvector('english', "Name")` {
		t.Errorf("ParseTagSetting did not handle escaped double quotes correctly: got %#v", v)
	}
}

================================================
FILE: soft_delete.go
================================================
package gorm

import (
	"database/sql"
	"database/sql/driver"
	"encoding/json"
	"reflect"

	"github.com/jinzhu/now"

	"gorm.io/gorm/clause"
	"gorm.io/gorm/schema"
)

// DeletedAt is a nullable timestamp with the same layout as sql.NullTime.
// Its QueryClauses, UpdateClauses and DeleteClauses methods supply the clauses
// that implement soft deletion (presumably collected by the schema package for
// fields of this type — confirm).
type DeletedAt sql.NullTime

// Scan implements the Scanner interface.
func (n *DeletedAt) Scan(value interface{}) error {
	return (*sql.NullTime)(n).Scan(value)
}

// Value implements the driver Valuer interface.
// It returns nil (SQL NULL) when n is not Valid and n.Time otherwise.
func (n DeletedAt) Value() (driver.Value, error) {
	if !n.Valid {
		return nil, nil
	}
	return n.Time, nil
}

// MarshalJSON encodes n.Time when n is Valid and JSON null otherwise.
func (n DeletedAt) MarshalJSON() ([]byte, error) {
	if n.Valid {
		return json.Marshal(n.Time)
	}
	return json.Marshal(nil)
}

// UnmarshalJSON treats JSON null as an invalid (NULL) value, leaving n.Time
// untouched. Any other input is decoded into n.Time and marks n Valid when
// decoding succeeds.
func (n *DeletedAt) UnmarshalJSON(b []byte) error {
	if string(b) == "null" {
		n.Valid = false
		return nil
	}
	err := json.Unmarshal(b, &n.Time)
	if err == nil {
		n.Valid = true
	}
	return err
}

// QueryClauses returns the clause that restricts queries to rows whose f
// column equals the field's configured zero value.
func (DeletedAt) QueryClauses(f *schema.Field) []clause.Interface {
	return []clause.Interface{SoftDeleteQueryClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}

// parseZeroValueTag returns the ZEROVALUE tag setting of f as a valid
// NullString when now.Parse accepts it, and an invalid (NULL) NullString
// otherwise.
func parseZeroValueTag(f *schema.Field) sql.NullString {
	if v, ok := f.TagSettings["ZEROVALUE"]; ok {
		if _, err := now.Parse(v); err == nil {
			return sql.NullString{String: v, Valid: true}
		}
	}
	return sql.NullString{Valid: false}
}

// SoftDeleteQueryClause adds the "not deleted" condition to a statement.
// Name, Build and MergeClause are no-ops; the work happens in ModifyStatement.
type SoftDeleteQueryClause struct {
	ZeroValue sql.NullString
	Field     *schema.Field
}

func (sd SoftDeleteQueryClause) Name() string {
	return ""
}

func (sd SoftDeleteQueryClause) Build(clause.Builder) {
}

func (sd SoftDeleteQueryClause) MergeClause(*clause.Clause) {
}

// ModifyStatement appends `<current table>.<field> = ZeroValue` to the WHERE
// clause, unless the statement is Unscoped. It applies at most once per
// statement, tracked through the "soft_delete_enabled" clause key.
func (sd SoftDeleteQueryClause) ModifyStatement(stmt *Statement) {
	if _, ok := stmt.Clauses["soft_delete_enabled"]; !ok && !stmt.Statement.Unscoped {
		if c, ok := stmt.Clauses["WHERE"]; ok {
			if where, ok := c.Expression.(clause.Where); ok && len(where.Exprs) >= 1 {
				for _, expr := range where.Exprs {
					// If an existing condition is a single-expression OR, group all
					// existing conditions under one AND. Presumably this keeps the
					// soft-delete condition added below from being joined into the OR.
					if orCond, ok := expr.(clause.OrConditions); ok && len(orCond.Exprs) == 1 {
						where.Exprs = []clause.Expression{clause.And(where.Exprs...)}
						c.Expression = where
						stmt.Clauses["WHERE"] = c
						break
					}
				}
			}
		}

		stmt.AddClause(clause.Where{Exprs: []clause.Expression{
			clause.Eq{Column: clause.Column{Table: clause.CurrentTable, Name: sd.Field.DBName}, Value: sd.ZeroValue},
		}})
		stmt.Clauses["soft_delete_enabled"] = clause.Clause{}
	}
}

// UpdateClauses returns the clause that applies the soft-delete query
// condition to update statements.
func (DeletedAt) UpdateClauses(f *schema.Field) []clause.Interface {
	return []clause.Interface{SoftDeleteUpdateClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}

// SoftDeleteUpdateClause adds the soft-delete query condition to updates.
type SoftDeleteUpdateClause struct {
	ZeroValue sql.NullString
	Field     *schema.Field
}

func (sd SoftDeleteUpdateClause) Name() string {
	return ""
}

func (sd SoftDeleteUpdateClause) Build(clause.Builder) {
}

func (sd SoftDeleteUpdateClause) MergeClause(*clause.Clause) {
}

// ModifyStatement applies the SoftDeleteQueryClause condition when no SQL has
// been built yet and the statement is not Unscoped.
func (sd SoftDeleteUpdateClause) ModifyStatement(stmt *Statement) {
	if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
		SoftDeleteQueryClause(sd).ModifyStatement(stmt)
	}
}

// DeleteClauses returns the clause used for deletes of models with field f.
// Its ModifyStatement sets f to the current time and adds an UPDATE clause.
func (DeletedAt) DeleteClauses(f *schema.Field) []clause.Interface {
	return []clause.Interface{SoftDeleteDeleteClause{Field: f, ZeroValue: parseZeroValueTag(f)}}
}

// SoftDeleteDeleteClause is the clause returned by DeletedAt.DeleteClauses.
type SoftDeleteDeleteClause struct {
	ZeroValue sql.NullString
	Field     *schema.Field
}

func (sd SoftDeleteDeleteClause) Name() string {
	return ""
}

func (sd SoftDeleteDeleteClause) Build(clause.Builder) {
}

func (sd SoftDeleteDeleteClause) MergeClause(*clause.Clause) {
}

func (sd SoftDeleteDeleteClause) ModifyStatement(stmt *Statement) {
	if stmt.SQL.Len() == 0 && !stmt.Statement.Unscoped {
		curTime := stmt.DB.NowFunc()
		stmt.AddClause(clause.Set{{Column: clause.Column{Name: sd.Field.DBName}, Value: curTime}})
		stmt.SetColumn(sd.Field.DBName, curTime, true)

		if stmt.Schema != nil {
			_, queryValues := schema.GetIdentityFieldValuesMap(stmt.Context, stmt.ReflectValue, stmt.Schema.PrimaryFields)
			column, values := schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)

			if len(values) > 0 {
				stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
			}

			if stmt.ReflectValue.CanAddr() && stmt.Dest != stmt.Model && stmt.Model != nil {
				_, queryValues = schema.GetIdentityFieldValuesMap(stmt.Context, reflect.ValueOf(stmt.Model), stmt.Schema.PrimaryFields)
				column, values = schema.ToQueryValues(stmt.Table, stmt.Schema.PrimaryFieldDBNames, queryValues)

				if len(values) > 0 {
					stmt.AddClause(clause.Where{Exprs: []clause.Expression{clause.IN{Column: column, Values: values}}})
				}
			}
		}

		SoftDeleteQueryClause(sd).ModifyStatement(stmt)
		stmt.AddClauseIfNotExists(clause.Update{})
stmt.Build(stmt.DB.Callback().Update().Clauses...) } } ================================================ FILE: statement.go ================================================ package gorm import ( "context" "database/sql" "database/sql/driver" "fmt" "reflect" "regexp" "sort" "strconv" "strings" "sync" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/schema" "gorm.io/gorm/utils" ) // Statement statement type Statement struct { *DB TableExpr *clause.Expr Table string Model interface{} Unscoped bool Dest interface{} ReflectValue reflect.Value Clauses map[string]clause.Clause BuildClauses []string Distinct bool Selects []string // selected columns Omits []string // omit columns ColumnMapping map[string]string // map columns Joins []join Preloads map[string][]interface{} Settings sync.Map ConnPool ConnPool Schema *schema.Schema Context context.Context RaiseErrorOnNotFound bool SkipHooks bool SQL strings.Builder Vars []interface{} CurDestIndex int attrs []interface{} assigns []interface{} scopes []func(*DB) *DB Result *result } type join struct { Name string Alias string Conds []interface{} On *clause.Where Selects []string Omits []string Expression clause.Expression JoinType clause.JoinType } // StatementModifier statement modifier interface type StatementModifier interface { ModifyStatement(*Statement) } // WriteString write string func (stmt *Statement) WriteString(str string) (int, error) { return stmt.SQL.WriteString(str) } // WriteByte write byte func (stmt *Statement) WriteByte(c byte) error { return stmt.SQL.WriteByte(c) } // WriteQuoted write quoted value func (stmt *Statement) WriteQuoted(value interface{}) { stmt.QuoteTo(&stmt.SQL, value) } // QuoteTo write quoted value to writer func (stmt *Statement) QuoteTo(writer clause.Writer, field interface{}) { write := func(raw bool, str string) { if raw { writer.WriteString(str) } else { stmt.DB.Dialector.QuoteTo(writer, str) } } switch v := field.(type) { case clause.Table: if v.Name == clause.CurrentTable { 
if stmt.TableExpr != nil { stmt.TableExpr.Build(stmt) } else if stmt.Table != "" { write(v.Raw, stmt.Table) } else if stmt.AddError(stmt.Parse(stmt.Model)) == nil { write(v.Raw, stmt.Table) } } else { write(v.Raw, v.Name) } if v.Alias != "" { writer.WriteByte(' ') write(v.Raw, v.Alias) } case clause.Column: if v.Table != "" { if v.Table == clause.CurrentTable { write(v.Raw, stmt.Table) } else { write(v.Raw, v.Table) } writer.WriteByte('.') } if v.Name == clause.PrimaryKey { if stmt.Schema == nil { stmt.DB.AddError(ErrModelValueRequired) } else if stmt.Schema.PrioritizedPrimaryField != nil { write(v.Raw, stmt.Schema.PrioritizedPrimaryField.DBName) } else if len(stmt.Schema.DBNames) > 0 { write(v.Raw, stmt.Schema.DBNames[0]) } else { stmt.DB.AddError(ErrModelAccessibleFieldsRequired) //nolint:typecheck,errcheck } } else { write(v.Raw, v.Name) } if v.Alias != "" { writer.WriteString(" AS ") write(v.Raw, v.Alias) } case []clause.Column: writer.WriteByte('(') for idx, d := range v { if idx > 0 { writer.WriteByte(',') } stmt.QuoteTo(writer, d) } writer.WriteByte(')') case clause.Expr: v.Build(stmt) case string: stmt.DB.Dialector.QuoteTo(writer, v) case []string: writer.WriteByte('(') for idx, d := range v { if idx > 0 { writer.WriteByte(',') } stmt.DB.Dialector.QuoteTo(writer, d) } writer.WriteByte(')') default: stmt.DB.Dialector.QuoteTo(writer, fmt.Sprint(field)) } } // Quote returns quoted value func (stmt *Statement) Quote(field interface{}) string { var builder strings.Builder stmt.QuoteTo(&builder, field) return builder.String() } // AddVar add var func (stmt *Statement) AddVar(writer clause.Writer, vars ...interface{}) { for idx, v := range vars { if idx > 0 { writer.WriteByte(',') } switch v := v.(type) { case sql.NamedArg: stmt.Vars = append(stmt.Vars, v.Value) case clause.Column, clause.Table: stmt.QuoteTo(writer, v) case Valuer: reflectValue := reflect.ValueOf(v) if reflectValue.Kind() == reflect.Ptr && reflectValue.IsNil() { stmt.AddVar(writer, nil) } else { 
stmt.AddVar(writer, v.GormValue(stmt.Context, stmt.DB)) } case clause.Interface: c := clause.Clause{Name: v.Name()} v.MergeClause(&c) c.Build(stmt) case clause.Expression: v.Build(stmt) case driver.Valuer: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) case []byte: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) case []interface{}: if len(v) > 0 { writer.WriteByte('(') stmt.AddVar(writer, v...) writer.WriteByte(')') } else { writer.WriteString("(NULL)") } case interface{ getInstance() *DB }: cv := v.getInstance() subdb := cv.Session(&Session{Logger: logger.Discard, DryRun: true}).getInstance() if cv.Statement.SQL.Len() > 0 { var ( vars = subdb.Statement.Vars sql = cv.Statement.SQL.String() ) subdb.Statement.Vars = make([]interface{}, 0, len(vars)) for _, vv := range vars { subdb.Statement.Vars = append(subdb.Statement.Vars, vv) bindvar := strings.Builder{} cv.BindVarTo(&bindvar, subdb.Statement, vv) sql = strings.Replace(sql, bindvar.String(), "?", 1) } subdb.Statement.SQL.Reset() subdb.Statement.Vars = stmt.Vars if strings.Contains(sql, "@") { clause.NamedExpr{SQL: sql, Vars: vars}.Build(subdb.Statement) } else { clause.Expr{SQL: sql, Vars: vars}.Build(subdb.Statement) } } else { subdb.Statement.Vars = append(stmt.Vars, subdb.Statement.Vars...) 
subdb.callbacks.Query().Execute(subdb) } writer.WriteString(subdb.Statement.SQL.String()) stmt.Vars = subdb.Statement.Vars default: switch rv := reflect.ValueOf(v); rv.Kind() { case reflect.Slice, reflect.Array: if rv.Len() == 0 { writer.WriteString("(NULL)") } else if rv.Type().Elem() == reflect.TypeOf(uint8(0)) { stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) } else { writer.WriteByte('(') for i := 0; i < rv.Len(); i++ { if i > 0 { writer.WriteByte(',') } stmt.AddVar(writer, rv.Index(i).Interface()) } writer.WriteByte(')') } default: stmt.Vars = append(stmt.Vars, v) stmt.DB.Dialector.BindVarTo(writer, stmt, v) } } } } // AddClause add clause func (stmt *Statement) AddClause(v clause.Interface) { if optimizer, ok := v.(StatementModifier); ok { optimizer.ModifyStatement(stmt) } else { name := v.Name() c := stmt.Clauses[name] c.Name = name v.MergeClause(&c) stmt.Clauses[name] = c } } // AddClauseIfNotExists add clause if not exists func (stmt *Statement) AddClauseIfNotExists(v clause.Interface) { if c, ok := stmt.Clauses[v.Name()]; !ok || c.Expression == nil { stmt.AddClause(v) } } // BuildCondition build condition func (stmt *Statement) BuildCondition(query interface{}, args ...interface{}) []clause.Expression { if s, ok := query.(string); ok { // if it is a number, then treats it as primary key if _, err := strconv.Atoi(s); err != nil { if s == "" && len(args) == 0 { return nil } if len(args) == 0 || (len(args) > 0 && strings.Contains(s, "?")) { // looks like a where condition return []clause.Expression{clause.Expr{SQL: s, Vars: args}} } if len(args) > 0 && strings.Contains(s, "@") { // looks like a named query return []clause.Expression{clause.NamedExpr{SQL: s, Vars: args}} } if strings.Contains(strings.TrimSpace(s), " ") { // looks like a where condition return []clause.Expression{clause.Expr{SQL: s, Vars: args}} } if len(args) == 1 { return []clause.Expression{clause.Eq{Column: s, Value: args[0]}} } } } conds := 
make([]clause.Expression, 0, 4) args = append([]interface{}{query}, args...) for idx, arg := range args { if arg == nil { continue } if valuer, ok := arg.(driver.Valuer); ok { arg, _ = valuer.Value() } curTable := stmt.Table if curTable == "" { curTable = clause.CurrentTable } switch v := arg.(type) { case clause.Expression: conds = append(conds, v) case []clause.Expression: conds = append(conds, v...) case *DB: v.executeScopes() if cs, ok := v.Statement.Clauses["WHERE"]; ok { if where, ok := cs.Expression.(clause.Where); ok { if len(where.Exprs) == 1 { if orConds, ok := where.Exprs[0].(clause.OrConditions); ok { if len(orConds.Exprs) == 1 { where.Exprs[0] = clause.AndConditions(orConds) } } } conds = append(conds, clause.And(where.Exprs...)) } else if cs.Expression != nil { conds = append(conds, cs.Expression) } } case map[interface{}]interface{}: for i, j := range v { conds = append(conds, clause.Eq{Column: i, Value: j}) } case map[string]string: keys := make([]string, 0, len(v)) for i := range v { keys = append(keys, i) } sort.Strings(keys) for _, key := range keys { column := clause.Column{Name: key, Table: curTable} if strings.Contains(key, ".") { column = clause.Column{Name: key} } conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } case map[string]interface{}: keys := make([]string, 0, len(v)) for i := range v { keys = append(keys, i) } sort.Strings(keys) for _, key := range keys { reflectValue := reflect.Indirect(reflect.ValueOf(v[key])) column := clause.Column{Name: key, Table: curTable} if strings.Contains(key, ".") { column = clause.Column{Name: key} } switch reflectValue.Kind() { case reflect.Slice, reflect.Array: if _, ok := v[key].(driver.Valuer); ok { conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } else if _, ok := v[key].(Valuer); ok { conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } else { // optimize reflect value length valueLen := reflectValue.Len() values := make([]interface{}, valueLen) for i := 
0; i < valueLen; i++ { values[i] = reflectValue.Index(i).Interface() } conds = append(conds, clause.IN{Column: column, Values: values}) } default: conds = append(conds, clause.Eq{Column: column, Value: v[key]}) } } default: reflectValue := reflect.Indirect(reflect.ValueOf(arg)) for reflectValue.Kind() == reflect.Ptr { reflectValue = reflectValue.Elem() } if s, err := schema.Parse(arg, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { selectedColumns := map[string]bool{} if idx == 0 { for _, v := range args[1:] { if vs, ok := v.(string); ok { selectedColumns[vs] = true } } } restricted := len(selectedColumns) != 0 switch reflectValue.Kind() { case reflect.Struct: for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { if v, isZero := field.ValueOf(stmt.Context, reflectValue); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) } } } } case reflect.Slice, reflect.Array: for i := 0; i < reflectValue.Len(); i++ { for _, field := range s.Fields { selected := selectedColumns[field.DBName] || selectedColumns[field.Name] if selected || (!restricted && field.Readable) { if v, isZero := field.ValueOf(stmt.Context, reflectValue.Index(i)); !isZero || selected { if field.DBName != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.DBName}, Value: v}) } else if field.DataType != "" { conds = append(conds, clause.Eq{Column: clause.Column{Table: curTable, Name: field.Name}, Value: v}) } } } } } } if restricted { break } } else if !reflectValue.IsValid() { stmt.AddError(ErrInvalidData) } else if len(conds) == 0 { if len(args) == 1 { switch reflectValue.Kind() { case reflect.Slice, reflect.Array: // optimize 
reflect value length valueLen := reflectValue.Len() values := make([]interface{}, valueLen) for i := 0; i < valueLen; i++ { values[i] = reflectValue.Index(i).Interface() } if len(values) > 0 { conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: values}) return []clause.Expression{clause.And(conds...)} } return nil } } conds = append(conds, clause.IN{Column: clause.Column{Table: curTable, Name: clause.PrimaryKey}, Values: args}) } } } if len(conds) > 0 { return []clause.Expression{clause.And(conds...)} } return nil } // Build build sql with clauses names func (stmt *Statement) Build(clauses ...string) { var firstClauseWritten bool for _, name := range clauses { if c, ok := stmt.Clauses[name]; ok { if firstClauseWritten { stmt.WriteByte(' ') } firstClauseWritten = true if b, ok := stmt.DB.ClauseBuilders[name]; ok { b(c, stmt) } else { c.Build(stmt) } } } } func (stmt *Statement) Parse(value interface{}) (err error) { return stmt.ParseWithSpecialTableName(value, "") } func (stmt *Statement) ParseWithSpecialTableName(value interface{}, specialTableName string) (err error) { if stmt.Schema, err = schema.ParseWithSpecialTableName(value, stmt.DB.cacheStore, stmt.DB.NamingStrategy, specialTableName); err == nil && stmt.Table == "" { if tables := strings.Split(stmt.Schema.Table, "."); len(tables) == 2 { stmt.TableExpr = &clause.Expr{SQL: stmt.Quote(stmt.Schema.Table)} stmt.Table = tables[1] return } stmt.Table = stmt.Schema.Table } return err } func (stmt *Statement) clone() *Statement { newStmt := &Statement{ TableExpr: stmt.TableExpr, Table: stmt.Table, Model: stmt.Model, Unscoped: stmt.Unscoped, Dest: stmt.Dest, ReflectValue: stmt.ReflectValue, Clauses: map[string]clause.Clause{}, Distinct: stmt.Distinct, Selects: stmt.Selects, Omits: stmt.Omits, ColumnMapping: stmt.ColumnMapping, Preloads: map[string][]interface{}{}, ConnPool: stmt.ConnPool, Schema: stmt.Schema, Context: stmt.Context, RaiseErrorOnNotFound: 
stmt.RaiseErrorOnNotFound, SkipHooks: stmt.SkipHooks, Result: stmt.Result, } if stmt.SQL.Len() > 0 { newStmt.SQL.WriteString(stmt.SQL.String()) newStmt.Vars = make([]interface{}, 0, len(stmt.Vars)) newStmt.Vars = append(newStmt.Vars, stmt.Vars...) } for k, c := range stmt.Clauses { newStmt.Clauses[k] = c } for k, p := range stmt.Preloads { newStmt.Preloads[k] = p } if len(stmt.Joins) > 0 { newStmt.Joins = make([]join, len(stmt.Joins)) copy(newStmt.Joins, stmt.Joins) } if len(stmt.scopes) > 0 { newStmt.scopes = make([]func(*DB) *DB, len(stmt.scopes)) copy(newStmt.scopes, stmt.scopes) } stmt.Settings.Range(func(k, v interface{}) bool { newStmt.Settings.Store(k, v) return true }) return newStmt } // SetColumn set column's value // // stmt.SetColumn("Name", "jinzhu") // Hooks Method // stmt.SetColumn("Name", "jinzhu", true) // Callbacks Method func (stmt *Statement) SetColumn(name string, value interface{}, fromCallbacks ...bool) { if v, ok := stmt.Dest.(map[string]interface{}); ok { v[name] = value } else if v, ok := stmt.Dest.([]map[string]interface{}); ok { for _, m := range v { m[name] = value } } else if stmt.Schema != nil { if field := stmt.Schema.LookUpField(name); field != nil { destValue := reflect.ValueOf(stmt.Dest) for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } if stmt.ReflectValue != destValue { if !destValue.CanAddr() { destValueCanAddr := reflect.New(destValue.Type()) destValueCanAddr.Elem().Set(destValue) stmt.Dest = destValueCanAddr.Interface() destValue = destValueCanAddr.Elem() } switch destValue.Kind() { case reflect.Struct: stmt.AddError(field.Set(stmt.Context, destValue, value)) default: stmt.AddError(ErrInvalidData) } } switch stmt.ReflectValue.Kind() { case reflect.Slice, reflect.Array: if len(fromCallbacks) > 0 { for i := 0; i < stmt.ReflectValue.Len(); i++ { stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(i), value)) } } else { stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue.Index(stmt.CurDestIndex), 
value)) } case reflect.Struct: if !stmt.ReflectValue.CanAddr() { stmt.AddError(ErrInvalidValue) return } stmt.AddError(field.Set(stmt.Context, stmt.ReflectValue, value)) } } else { stmt.AddError(ErrInvalidField) } } else { stmt.AddError(ErrInvalidField) } } // Changed check model changed or not when updating func (stmt *Statement) Changed(fields ...string) bool { modelValue := stmt.ReflectValue switch modelValue.Kind() { case reflect.Slice, reflect.Array: modelValue = stmt.ReflectValue.Index(stmt.CurDestIndex) } selectColumns, restricted := stmt.SelectAndOmitColumns(false, true) changed := func(field *schema.Field) bool { fieldValue, _ := field.ValueOf(stmt.Context, modelValue) if v, ok := selectColumns[field.DBName]; (ok && v) || (!ok && !restricted) { if mv, mok := stmt.Dest.(map[string]interface{}); mok { if fv, ok := mv[field.Name]; ok { return !utils.AssertEqual(fv, fieldValue) } else if fv, ok := mv[field.DBName]; ok { return !utils.AssertEqual(fv, fieldValue) } } else { destValue := reflect.ValueOf(stmt.Dest) for destValue.Kind() == reflect.Ptr { destValue = destValue.Elem() } if descSchema, err := schema.Parse(stmt.Dest, stmt.DB.cacheStore, stmt.DB.NamingStrategy); err == nil { if destField := descSchema.LookUpField(field.DBName); destField != nil { changedValue, zero := destField.ValueOf(stmt.Context, destValue) if v { return !utils.AssertEqual(changedValue, fieldValue) } return !zero && !utils.AssertEqual(changedValue, fieldValue) } } } } return false } if len(fields) == 0 { for _, field := range stmt.Schema.FieldsByDBName { if changed(field) { return true } } } else { for _, name := range fields { if field := stmt.Schema.LookUpField(name); field != nil { if changed(field) { return true } } } } return false } var matchName = func() func(tableColumn string) (table, column string) { nameMatcher := regexp.MustCompile(`^(?:\W?(\w+?)\W?\.)?(?:(\*)|\W?(\w+?)\W?)$`) return func(tableColumn string) (table, column string) { if matches := 
nameMatcher.FindStringSubmatch(tableColumn); len(matches) == 4 { table = matches[1] star := matches[2] columnName := matches[3] if star != "" { return table, star } return table, columnName } return "", "" } }() // SelectAndOmitColumns get select and omit columns, select -> true, omit -> false func (stmt *Statement) SelectAndOmitColumns(requireCreate, requireUpdate bool) (map[string]bool, bool) { results := map[string]bool{} notRestricted := false processColumn := func(column string, result bool) { if stmt.Schema == nil { results[column] = result } else if column == "*" { notRestricted = result for _, dbName := range stmt.Schema.DBNames { results[dbName] = result } } else if column == clause.Associations { for _, rel := range stmt.Schema.Relationships.Relations { results[rel.Name] = result } } else if field := stmt.Schema.LookUpField(column); field != nil && field.DBName != "" { results[field.DBName] = result } else if table, col := matchName(column); col != "" && (table == stmt.Table || table == "") { if col == "*" { for _, dbName := range stmt.Schema.DBNames { results[dbName] = result } } else { results[col] = result } } else { results[column] = result } } // select columns for _, column := range stmt.Selects { processColumn(column, true) } // omit columns for _, column := range stmt.Omits { processColumn(column, false) } if stmt.Schema != nil { for _, field := range stmt.Schema.FieldsByName { name := field.DBName if name == "" { name = field.Name } if requireCreate && !field.Creatable { results[name] = false } else if requireUpdate && !field.Updatable { results[name] = false } } } return results, !notRestricted && len(stmt.Selects) > 0 } ================================================ FILE: statement_test.go ================================================ package gorm import ( "fmt" "reflect" "testing" "gorm.io/gorm/clause" ) func TestWhereCloneCorruption(t *testing.T) { for whereCount := 1; whereCount <= 8; whereCount++ { t.Run(fmt.Sprintf("w=%d", 
whereCount), func(t *testing.T) { s := new(Statement) for w := 0; w < whereCount; w++ { s = s.clone() s.AddClause(clause.Where{ Exprs: s.BuildCondition(fmt.Sprintf("where%d", w)), }) } s1 := s.clone() s1.AddClause(clause.Where{ Exprs: s.BuildCondition("FINAL1"), }) s2 := s.clone() s2.AddClause(clause.Where{ Exprs: s.BuildCondition("FINAL2"), }) if reflect.DeepEqual(s1.Clauses["WHERE"], s2.Clauses["WHERE"]) { t.Errorf("Where conditions should be different") } }) } } func TestNilCondition(t *testing.T) { s := new(Statement) if len(s.BuildCondition(nil)) != 0 { t.Errorf("Nil condition should be empty") } } func TestNameMatcher(t *testing.T) { for k, v := range map[string][]string{ "table.name": {"table", "name"}, "`table`.`name`": {"table", "name"}, "'table'.'name'": {"table", "name"}, "'table'.name": {"table", "name"}, "table1.name_23": {"table1", "name_23"}, "`table_1`.`name23`": {"table_1", "name23"}, "'table23'.'name_1'": {"table23", "name_1"}, "'table23'.name1": {"table23", "name1"}, "'name1'": {"", "name1"}, "`name_1`": {"", "name_1"}, "`Name_1`": {"", "Name_1"}, "`Table`.`nAme`": {"Table", "nAme"}, "my_table.*": {"my_table", "*"}, "`my_table`.*": {"my_table", "*"}, "User__Company.*": {"User__Company", "*"}, "`User__Company`.*": {"User__Company", "*"}, `"User__Company".*`: {"User__Company", "*"}, `"table"."*"`: {"", ""}, } { if table, column := matchName(k); table != v[0] || column != v[1] { t.Errorf("failed to match value: %v, got %v, expect: %v", k, []string{table, column}, v) } } } ================================================ FILE: tests/.gitignore ================================================ go.sum ================================================ FILE: tests/README.md ================================================ # Test Guide ```bash cd tests # prepare test databases docker-compose up # run all tests ./tests_all.sh ``` ================================================ FILE: tests/association_generics_test.go 
================================================
package tests_test

import (
	"context"
	"testing"

	"gorm.io/gorm"
	"gorm.io/gorm/clause"
	. "gorm.io/gorm/utils/tests"
)

// These tests reuse the shared User, Pet, Toy, Language and Company models
// dot-imported from gorm.io/gorm/utils/tests and run against the real test DB.

// TestClauseAssociationSetCreateWithOpCreate creates a user via Set + Create,
// then attaches a new pet through an OpCreate association in an
// association-only Update, which must report 0 affected rows.
func TestClauseAssociationSetCreateWithOpCreate(t *testing.T) {
	ctx := context.Background()

	// First create a user with Set + Create
	err := gorm.G[User](DB).Set(
		clause.Assignment{Column: clause.Column{Name: "name"}, Value: "TestClauseAssociationSetCreateWithOpCreate"},
		clause.Assignment{Column: clause.Column{Name: "age"}, Value: 25},
	).Create(ctx)
	if err != nil {
		t.Fatalf("Set Create failed: %v", err)
	}

	// Find the created user
	var user User
	if err := DB.Where("name = ?", "TestClauseAssociationSetCreateWithOpCreate").First(&user).Error; err != nil {
		t.Fatalf("failed to find created user: %v", err)
	}

	// Test Set + Update with Association OpCreate
	assocOp := clause.Association{
		Association: "Pets",
		Type:        clause.OpCreate,
		Set: []clause.Assignment{
			{Column: clause.Column{Name: "name"}, Value: "test-pet"},
		},
	}

	rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx)
	if err != nil {
		t.Fatalf("Set Update with association failed: %v", err)
	}
	// Only association operations were executed; no row update is expected
	if rows != 0 {
		t.Fatalf("expected 0 rows affected for association-only update, got %d", rows)
	}

	// Verify the association was created using real database query
	AssertAssociationCount(t, &user, "Pets", 1, "after Set Update with association")
}

// TestClauseAssociationSetUpdateWithOpCreate adds a pet via OpCreate to a user
// that already has one and expects both pets afterwards.
func TestClauseAssociationSetUpdateWithOpCreate(t *testing.T) {
	ctx := context.Background()

	// Create a user with a pet first using real database
	user := User{Name: "TestClauseAssociationSetUpdateWithOpCreate", Age: 25}
	user.Pets = []*Pet{{Name: "original-pet"}}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("failed to create user with pet: %v", err)
	}

	// Verify initial state using real database query
	// NOTE(review): passes user by value here but &user in other tests — confirm
	// AssertAssociationCount accepts both.
	AssertAssociationCount(t, user, "Pets", 1, "before update")

	// Test Set + Update with Association OpCreate
	assocOp := clause.Association{
		Association: "Pets",
		Type:        clause.OpCreate,
		Set: []clause.Assignment{
			{Column: clause.Column{Name: "name"}, Value: "new-pet"},
		},
	}

	rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx)
	if err != nil {
		t.Fatalf("Set Update with association failed: %v", err)
	}
	// Only association operations were executed; no row update is expected
	if rows != 0 {
		t.Fatalf("expected 0 rows affected for association-only update, got %d", rows)
	}

	// Verify the association was updated using real database query
	var updatedUser User
	if err := DB.Preload("Pets").Where("id = ?", user.ID).First(&updatedUser).Error; err != nil {
		t.Fatalf("failed to find updated user: %v", err)
	}

	if len(updatedUser.Pets) != 2 {
		t.Fatalf("expected 2 pets, got %d", len(updatedUser.Pets))
	}

	// Collect names so the check does not depend on preload order.
	petNames := make(map[string]bool)
	for _, pet := range updatedUser.Pets {
		petNames[pet.Name] = true
	}

	if !petNames["original-pet"] {
		t.Error("original pet not found")
	}
	if !petNames["new-pet"] {
		t.Error("new pet not found")
	}
}

// TestClauseAssociationSetCreateWithMultipleAssociations applies two OpCreate
// association operations (Pets and Toys) in one association-only Update.
func TestClauseAssociationSetCreateWithMultipleAssociations(t *testing.T) {
	ctx := context.Background()

	// First create a user with Set + Create using real database
	err := gorm.G[User](DB).Set(
		clause.Assignment{Column: clause.Column{Name: "name"}, Value: "TestClauseAssociationSetCreateWithMultipleAssociations"},
		clause.Assignment{Column: clause.Column{Name: "age"}, Value: 25},
	).Create(ctx)
	if err != nil {
		t.Fatalf("Set Create failed: %v", err)
	}

	// Find the created user using real database query
	var user User
	if err := DB.Where("name = ?", "TestClauseAssociationSetCreateWithMultipleAssociations").First(&user).Error; err != nil {
		t.Fatalf("failed to find created user: %v", err)
	}

	// Test Set + Update with multiple association operations
	assocOp1 := clause.Association{
		Association: "Pets",
		Type:        clause.OpCreate,
		Set: []clause.Assignment{
			{Column: clause.Column{Name: "name"}, Value: "test-pet-1"},
		},
	}

	assocOp2 := clause.Association{
		Association: "Toys",
		Type:        clause.OpCreate,
		Set: []clause.Assignment{
			{Column: clause.Column{Name: "name"}, Value: "test-toy-1"},
		},
	}

	rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp1, assocOp2).Update(ctx)
	if err != nil {
		t.Fatalf("Set Update with multiple associations failed: %v", err)
	}
	// Only association operations were executed; no row update is expected
	if rows != 0 {
		t.Fatalf("expected 0 rows affected for association-only update, got %d", rows)
	}

	// Verify both associations were created using real database queries
	AssertAssociationCount(t, &user, "Pets", 1, "after Set Update with multiple associations")
	AssertAssociationCount(t, &user, "Toys", 1, "after Set Update with multiple associations")
}

// TestClauseAssociationSetUpdateWithMultipleAssociations adds one pet and one
// toy to a user that already has one of each, expecting two of each afterwards.
func TestClauseAssociationSetUpdateWithMultipleAssociations(t *testing.T) {
	ctx := context.Background()

	// Create a user with initial associations using real database
	user := User{Name: "TestClauseAssociationSetUpdateWithMultipleAssociations", Age: 25}
	user.Pets = []*Pet{{Name: "original-pet"}}
	user.Toys = []Toy{{Name: "original-toy"}}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("failed to create user with associations: %v", err)
	}

	// Verify initial state using real database queries
	AssertAssociationCount(t, user, "Pets", 1, "before update")
	AssertAssociationCount(t, user, "Toys", 1, "before update")

	// Test Set + Update with multiple association operations
	assocOp1 := clause.Association{
		Association: "Pets",
		Type:        clause.OpCreate,
		Set: []clause.Assignment{
			{Column: clause.Column{Name: "name"}, Value: "new-pet"},
		},
	}

	assocOp2 := clause.Association{
		Association: "Toys",
		Type:        clause.OpCreate,
		Set: []clause.Assignment{
			{Column: clause.Column{Name: "name"}, Value: "new-toy"},
		},
	}

	rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp1, assocOp2).Update(ctx)
	if err != nil {
		t.Fatalf("Set Update with multiple associations failed: %v", err)
	}
	// Only association operations were executed; no row update is expected
	if rows != 0 {
		t.Fatalf("expected 0 rows affected for association-only update, got %d", rows)
	}

	// Verify both associations were updated using real database queries
	var updatedUser User
	if err := DB.Preload("Pets").Preload("Toys").Where("id = ?", user.ID).First(&updatedUser).Error; err != nil {
		t.Fatalf("failed to find updated user: %v", err)
	}

	if len(updatedUser.Pets) != 2 {
		t.Fatalf("expected 2 pets, got %d", len(updatedUser.Pets))
	}
	if len(updatedUser.Toys) != 2 {
		t.Fatalf("expected 2 toys, got %d", len(updatedUser.Toys))
	}
}

// TestClauseAssociationSetUpdateWithOpUnlink unlinks one of two pets via
// OpUnlink and checks that the pet row itself is kept.
func TestClauseAssociationSetUpdateWithOpUnlink(t *testing.T) {
	ctx := context.Background()

	// Create a user with pets using real database
	user := User{Name: "TestClauseAssociationSetUpdateWithOpUnlink", Age: 25}
	user.Pets = []*Pet{{Name: "pet-to-unlink"}, {Name: "pet-to-keep"}}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("failed to create user with pets: %v", err)
	}

	// Verify initial state using real database query
	AssertAssociationCount(t, user, "Pets", 2, "before unlink")

	// Get the pet to unlink using real database query
	// NOTE(review): the lookup is by name only, not scoped to this user — it
	// could match a same-named pet left by another run.
	var petToUnlink Pet
	if err := DB.Where("name = ?", "pet-to-unlink").First(&petToUnlink).Error; err != nil {
		t.Fatalf("failed to find pet to unlink: %v", err)
	}

	// Test Set + Update with Association OpUnlink
	assocOp := clause.Association{
		Association: "Pets",
		Type:        clause.OpUnlink,
		Conditions: []clause.Expression{
			clause.Eq{Column: clause.Column{Name: "id"}, Value: petToUnlink.ID},
		},
	}

	rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx)
	if err != nil {
		t.Fatalf("Set Update with association unlink failed: %v", err)
	}
	// Only association operations were executed; no row update is expected
	if rows != 0 {
		t.Fatalf("expected 0 rows affected for association-only update, got %d", rows)
	}

	// Verify only one pet remains using real database query
	var updatedUser User
	if err := DB.Preload("Pets").Where("id = ?", user.ID).First(&updatedUser).Error; err != nil {
		t.Fatalf("failed to find updated user: %v", err)
	}

	if len(updatedUser.Pets) != 1 {
		t.Fatalf("expected 1 pet after unlink, got %d", len(updatedUser.Pets))
	}

	if updatedUser.Pets[0].Name != "pet-to-keep" {
		t.Errorf("expected pet-to-keep, got %s", updatedUser.Pets[0].Name)
	}

	// Verify the unlinked pet still exists in the database using real database query
	var count int64
	if err := DB.Model(&Pet{}).Where("id = ?", petToUnlink.ID).Count(&count).Error; err != nil {
		t.Fatalf("failed to count pet: %v", err)
	}

	if count != 1 {
		t.Error("unlinked pet should still exist in database")
	}
}

// TestClauseAssociationSetUpdateWithOpCreateValues attaches a pre-built Pet
// value (Values rather than Set) through OpCreate.
func TestClauseAssociationSetUpdateWithOpCreateValues(t *testing.T) {
	ctx := context.Background()

	// Create a user first using real database
	// NOTE(review): reuses the user name of TestClauseAssociationSetUpdateWithOpCreate;
	// harmless here since lookups are by ID, but a distinct name would be clearer.
	user := User{Name: "TestClauseAssociationSetUpdateWithOpCreate", Age: 25}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("failed to create user: %v", err)
	}

	// Create a pet object
	newPet := Pet{Name: "created-pet"}

	// Test Set + Update with Association OpCreate
	assocOp := clause.Association{
		Association: "Pets",
		Type:        clause.OpCreate,
		Values:      []interface{}{&newPet},
	}

	rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx)
	if err != nil {
		t.Fatalf("Set Update with association create values failed: %v", err)
	}
	// Only association operations were executed; no row update is expected
	if rows != 0 {
		t.Fatalf("expected 0 rows affected for association-only update, got %d", rows)
	}

	// Verify the pet was created and associated using real database query
	var updatedUser User
	if err := DB.Preload("Pets").Where("id = ?", user.ID).First(&updatedUser).Error; err != nil {
		t.Fatalf("failed to find updated user: %v", err)
	}

	if len(updatedUser.Pets) != 1 {
		t.Fatalf("expected 1 pet, got %d", len(updatedUser.Pets))
	}

	if updatedUser.Pets[0].Name != "created-pet" {
		t.Errorf("expected created-pet, got %s", updatedUser.Pets[0].Name)
	}
}

// TestClauseAssociationSetCreateWithManyToMany links two languages to a user
// through a many-to-many OpCreate association.
func TestClauseAssociationSetCreateWithManyToMany(t *testing.T) {
	ctx := context.Background()

	// Create a user first using real database
	user := User{Name: "TestClauseAssociationSetCreateWithManyToMany", Age: 25}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("failed to create user: %v", err)
	}

	// Create languages using real database
	langs := []Language{
		{Code: "en", Name: "English"},
		{Code: "fr", Name: "French"},
	}
	// NOTE(review): FirstOrCreate writes into the loop copy `lang`, not langs[i],
	// and its error is ignored; the langs values passed below are the originals.
	for _, lang := range langs {
		DB.FirstOrCreate(&lang, "code = ?", lang.Code)
	}

	// Test Set + Update with many-to-many association
	assocOp := clause.Association{
		Association: "Languages",
		Type:        clause.OpCreate,
		Values:      []interface{}{langs[0], langs[1]},
	}

	rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx)
	if err != nil {
		t.Fatalf("Set Update with many-to-many association failed: %v", err)
	}
	// Only association operations were executed; no row update is expected
	if rows != 0 {
		t.Fatalf("expected 0 rows affected for association-only update, got %d", rows)
	}

	// Verify the languages were associated using real database query
	var updatedUser User
	if err := DB.Preload("Languages").Where("id = ?", user.ID).First(&updatedUser).Error; err != nil {
		t.Fatalf("failed to find updated user: %v", err)
	}

	if len(updatedUser.Languages) != 2 {
		t.Fatalf("expected 2 languages, got %d", len(updatedUser.Languages))
	}
}

// TestClauseAssociationSetCreateWithBelongsTo sets the belongs-to foreign key
// (company_id) directly in Set + Create and verifies the preloaded Company.
func TestClauseAssociationSetCreateWithBelongsTo(t *testing.T) {
	ctx := context.Background()

	// Create a company first using real database
	company := Company{Name: "Test Company"}
	if err := DB.Create(&company).Error; err != nil {
		t.Fatalf("failed to create company: %v", err)
	}

	// Test Set + Create with belongs-to association using field assignment
	err := gorm.G[User](DB).Set(
		clause.Assignment{Column: clause.Column{Name: "name"}, Value: "TestClauseAssociationSetCreateWithBelongsTo"},
		clause.Assignment{Column: clause.Column{Name: "age"}, Value: 25},
		clause.Assignment{Column: clause.Column{Name: "company_id"}, Value: company.ID},
	).Create(ctx)
	if err != nil {
		t.Fatalf("Set Create with belongs-to association failed: %v", err)
	}

	// Verify the user was created with company association using real database query
	var newUser User
	if err := DB.Preload("Company").Where("name = ?", "TestClauseAssociationSetCreateWithBelongsTo").First(&newUser).Error; err != nil {
		t.Fatalf("failed to find created user: %v", err)
	}

	if newUser.Company.ID != company.ID {
		t.Errorf("expected company ID %d, got %d", company.ID, newUser.Company.ID)
	}
	if newUser.Company.Name != company.Name {
		t.Errorf("expected company name %s, got %s", company.Name, newUser.Company.Name)
	}
}

// BelongsTo: create and assign company via OpCreate
func TestClauseAssociationSetUpdateBelongsToCreateValues(t *testing.T) {
	ctx := context.Background()
	user := User{Name: "TestClauseAssociationSetUpdateBelongsToCreateValues", Age: 26}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("failed to create user: %v", err)
	}

	assocOp := clause.Association{Association: "Company", Type: clause.OpCreate, Values: []interface{}{Company{Name: "Belongs-To-Co"}}}
	if rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx); err != nil {
		t.Fatalf("Set Update belongs-to create values failed: %v", err)
	} else if rows != 0 {
		t.Fatalf("expected 0 rows affected for association-only update, got %d", rows)
	}

	var got User
	if err := DB.Preload("Company").First(&got, user.ID).Error; err != nil {
		t.Fatalf("failed preload company: %v",
err) } if got.Company.ID == 0 || got.Company.Name != "Belongs-To-Co" { t.Fatalf("expected Company assigned, got %+v", got.Company) } } // Mixed fields + association: update Age and create a pet together func TestClauseAssociationSetUpdateMixedFieldAndAssociation(t *testing.T) { ctx := context.Background() user := User{Name: "TestClauseAssociationSetUpdateMixed", Age: 20} if err := DB.Create(&user).Error; err != nil { t.Fatalf("create user: %v", err) } assocOp := clause.Association{Association: "Pets", Type: clause.OpCreate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "mix-pet"}}} rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set( assocOp, clause.Assignment{Column: clause.Column{Name: "age"}, Value: 30}, ).Update(ctx) if err != nil { t.Fatalf("Set Update mixed failed: %v", err) } if rows != 1 { t.Fatalf("expected 1 row affected for field update, got %d", rows) } var got User if err := DB.Preload("Pets").First(&got, user.ID).Error; err != nil { t.Fatalf("load user: %v", err) } if got.Age != 30 { t.Fatalf("expected age 30, got %d", got.Age) } if len(got.Pets) != 1 || got.Pets[0].Name != "mix-pet" { t.Fatalf("expected pet created, got %+v", got.Pets) } } // HasOne unlink clears NamedPet func TestClauseAssociationSetUpdateHasOneUnlink(t *testing.T) { ctx := context.Background() user := User{Name: "TestClauseAssociationSetUpdateHasOneUnlink", Age: 25} user.NamedPet = &Pet{Name: "np"} if err := DB.Create(&user).Error; err != nil { t.Fatalf("create: %v", err) } assocOp := clause.Association{Association: "NamedPet", Type: clause.OpUnlink} if rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx); err != nil { t.Fatalf("Set Update has-one unlink failed: %v", err) } else if rows != 0 { t.Fatalf("expected 0 rows affected for association-only update, got %d", rows) } var got User if err := DB.Preload("NamedPet").First(&got, user.ID).Error; err != nil { t.Fatalf("load user: %v", err) } if got.NamedPet != nil { 
t.Fatalf("expected NamedPet cleared, got %+v", got.NamedPet) } } // Many-to-Many create with Set func TestClauseAssociationSetUpdateManyToManyCreateWithSet(t *testing.T) { ctx := context.Background() user := User{Name: "TestClauseAssociationSetUpdateMany2ManyCreateWithSet", Age: 25} if err := DB.Create(&user).Error; err != nil { t.Fatalf("create user: %v", err) } assocOp := clause.Association{ Association: "Languages", Type: clause.OpCreate, Set: []clause.Assignment{{Column: clause.Column{Name: "code"}, Value: "it"}, {Column: clause.Column{Name: "name"}, Value: "Italian"}}, } if rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx); err != nil { t.Fatalf("Set Update many2many create with set failed: %v", err) } else if rows != 0 { t.Fatalf("expected 0 rows affected, got %d", rows) } AssertAssociationCount(t, user, "Languages", 1, "after create language") } // Many-to-Many clear func TestClauseAssociationSetUpdateManyToManyClear(t *testing.T) { ctx := context.Background() user := User{Name: "TestClauseAssociationSetUpdateMany2ManyClear", Age: 25} if err := DB.Create(&user).Error; err != nil { t.Fatalf("create user: %v", err) } langs := []Language{{Code: "pt", Name: "Portuguese"}, {Code: "ru", Name: "Russian"}} for _, l := range langs { DB.FirstOrCreate(&l, "code = ?", l.Code) } if err := DB.Model(&user).Association("Languages").Append(&langs); err != nil { t.Fatalf("append: %v", err) } AssertAssociationCount(t, user, "Languages", 2, "before clear") assocOp := clause.Association{Association: "Languages", Type: clause.OpUnlink} if rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx); err != nil { t.Fatalf("Set Update many2many clear failed: %v", err) } else if rows != 0 { t.Fatalf("expected 0 rows affected, got %d", rows) } AssertAssociationCount(t, user, "Languages", 0, "after clear") } // Polymorphic Tools create and unlink func TestClauseAssociationSetUpdatePolymorphicTools(t *testing.T) { ctx := 
context.Background()
	user := User{Name: "TestClauseAssociationSetUpdatePolymorphicTools", Age: 25}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("create user: %v", err)
	}
	// Create one tool through the association, then unlink it again.
	createOp := clause.Association{Association: "Tools", Type: clause.OpCreate, Values: []interface{}{Tools{Name: "wrench"}}}
	if rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(createOp).Update(ctx); err != nil {
		t.Fatalf("create tools: %v", err)
	} else if rows != 0 {
		t.Fatalf("rows %d", rows)
	}
	AssertAssociationCount(t, user, "Tools", 1, "after create tool")
	unlinkOp := clause.Association{Association: "Tools", Type: clause.OpUnlink}
	if rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(unlinkOp).Update(ctx); err != nil {
		t.Fatalf("unlink tools: %v", err)
	} else if rows != 0 {
		t.Fatalf("rows %d", rows)
	}
	AssertAssociationCount(t, user, "Tools", 0, "after clear tools")
}

// TestClauseAssociationSetUpdateInvalidAssociation expects Set+Update to return
// an error when clause.Association names an association ("Invalid") that is
// expected not to exist on User.
func TestClauseAssociationSetUpdateInvalidAssociation(t *testing.T) {
	ctx := context.Background()
	user := User{Name: "TestClauseAssociationSetUpdateInvalidAssociation", Age: 25}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("create user: %v", err)
	}
	assocOp := clause.Association{Association: "Invalid", Type: clause.OpCreate}
	if _, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx); err == nil {
		t.Fatalf("expected error for invalid association, got nil")
	}
}

// TestClauseAssociationSetUpdateNoOwnerMatch runs an association OpCreate whose
// WHERE clause matches no user (id = -1) and expects no error and 0 rows.
func TestClauseAssociationSetUpdateNoOwnerMatch(t *testing.T) {
	ctx := context.Background()
	assocOp := clause.Association{Association: "Pets", Type: clause.OpCreate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "won't-create"}}}
	if rows, err := gorm.G[User](DB).Where("id = ?", -1).Set(assocOp).Update(ctx); err != nil {
		t.Fatalf("unexpected error: %v", err)
	} else if rows != 0 {
		t.Fatalf("expected 0 rows, got %d", rows)
	}
}

// TestClauseAssociationSetUpdateAndDelete renames a user's has-many pet with
// OpUpdate and then removes it with OpDelete, both via association-only
// Set+Update calls.
func TestClauseAssociationSetUpdateAndDelete(t *testing.T) {
	ctx := context.Background()
	user := User{Name: "TestClauseAssociationSetUpdateAndDelete", Age: 25}
	user.Pets = []*Pet{{Name: "before"}}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("create user: %v", err)
	}
	AssertAssociationCount(t, user, "Pets", 1, "before update/delete")
	// Update pet name via OpUpdate
	updOp := clause.Association{Association: "Pets", Type: clause.OpUpdate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "x"}}}
	if _, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(updOp).Update(ctx); err != nil {
		t.Fatalf("OpUpdate failed: %v", err)
	}
	var got User
	if err := DB.Preload("Pets").First(&got, user.ID).Error; err != nil {
		t.Fatalf("load user: %v", err)
	}
	if len(got.Pets) != 1 || got.Pets[0].Name != "x" {
		t.Fatalf("expected updated pet name, got %+v", got.Pets)
	}
	// Delete pets via OpDelete
	delOp := clause.Association{Association: "Pets", Type: clause.OpDelete}
	if _, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(delOp).Update(ctx); err != nil {
		t.Fatalf("OpDelete failed: %v", err)
	}
	AssertAssociationCount(t, user, "Pets", 0, "after delete")
}

// TestClauseAssociationSetUpdateAndDeleteHasOne renames the has-one NamedPet
// with OpUpdate and then removes it with OpDelete.
func TestClauseAssociationSetUpdateAndDeleteHasOne(t *testing.T) {
	ctx := context.Background()
	user := User{Name: "TestClauseAssociationSetUpdateAndDeleteHasOne", Age: 25}
	user.NamedPet = &Pet{Name: "np-before"}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("create user: %v", err)
	}
	AssertAssociationCount(t, user, "NamedPet", 1, "before")
	upd := clause.Association{Association: "NamedPet", Type: clause.OpUpdate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "np-after"}}}
	if _, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(upd).Update(ctx); err != nil {
		t.Fatalf("OpUpdate has-one failed: %v", err)
	}
	var u1 User
	if err := DB.Preload("NamedPet").First(&u1, user.ID).Error; err != nil {
		t.Fatalf("load: %v", err)
	}
	if u1.NamedPet == nil || u1.NamedPet.Name !=
"np-after" { t.Fatalf("expected name updated, got %+v", u1.NamedPet) } del := clause.Association{Association: "NamedPet", Type: clause.OpDelete} if _, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(del).Update(ctx); err != nil { t.Fatalf("OpDelete has-one failed: %v", err) } AssertAssociationCount(t, user, "NamedPet", 0, "after delete") } // Many2Many append with map using Association API (regression for map support) func TestAssociationMany2ManyAppendMap_GenericFile(t *testing.T) { user := User{Name: "AssocM2MAppendMapGeneric", Age: 28} if err := DB.Create(&user).Error; err != nil { t.Fatalf("create user: %v", err) } if err := DB.Model(&user).Association("Languages").Append(map[string]interface{}{ "code": "gm2m_map_1", "name": "GMap1", }); err != nil { t.Fatalf("append map: %v", err) } AssertAssociationCount(t, user, "Languages", 1, "after append 1 map (generic file)") // Append more maps individually if err := DB.Model(&user).Association("Languages").Append(map[string]interface{}{"code": "gm2m_map_2", "name": "GMap2"}); err != nil { t.Fatalf("append map 2: %v", err) } if err := DB.Model(&user).Association("Languages").Append(map[string]interface{}{"code": "gm2m_map_3", "name": "GMap3"}); err != nil { t.Fatalf("append map 3: %v", err) } AssertAssociationCount(t, user, "Languages", 3, "after append 3 maps total (generic file)") } // BelongsTo: update and delete Company via OpUpdate/OpDelete func TestClauseAssociationSetUpdateAndDeleteBelongsTo(t *testing.T) { ctx := context.Background() // Create company and user with company company := Company{Name: "Electronics"} if err := DB.Create(&company).Error; err != nil { t.Fatalf("create company: %v", err) } user := User{Name: "John", Age: 25, CompanyID: &company.ID} if err := DB.Create(&user).Error; err != nil { t.Fatalf("create user: %v", err) } // Verify association exists AssertAssociationCount(t, &user, "Company", 1, "before") // Update company name via OpUpdate upd := clause.Association{Association: "Company", 
Type: clause.OpUpdate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "Electronics-New"}}}
	if _, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(upd).Update(ctx); err != nil {
		t.Fatalf("OpUpdate belongs-to failed: %v", err)
	}
	var u1 User
	if err := DB.Preload("Company").First(&u1, user.ID).Error; err != nil {
		t.Fatalf("load: %v", err)
	}
	if u1.Company.ID == 0 || u1.Company.Name != "Electronics-New" {
		t.Fatalf("expected company updated, got %+v", u1.Company)
	}
	// Unlink company association via OpUnlink (instead of OpDelete which would try to delete the company record)
	unlink := clause.Association{Association: "Company", Type: clause.OpUnlink}
	if _, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(unlink).Update(ctx); err != nil {
		t.Fatalf("OpUnlink belongs-to failed: %v", err)
	}
	var u2 User
	if err := DB.Preload("Company").First(&u2, user.ID).Error; err != nil {
		t.Fatalf("load: %v", err)
	}
	if u2.Company.ID != 0 {
		t.Fatalf("expected company association cleared due to unlink, got %+v", u2.Company)
	}
}

// TestClauseAssociationSetUpdateAndDeleteMany2Many updates one linked language
// (selected by Conditions) with OpUpdate, then removes only that link with
// OpDelete and checks that the language row itself is still present.
func TestClauseAssociationSetUpdateAndDeleteMany2Many(t *testing.T) {
	ctx := context.Background()
	user := User{Name: "TestClauseAssociationSetUpdateAndDeleteMany2Many", Age: 25}
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("create user: %v", err)
	}
	langs := []Language{{Code: "es", Name: "Spanish"}, {Code: "de", Name: "German"}}
	// Index into the slice so FirstOrCreate fills langs[i] itself (a range
	// value is only a copy), and surface setup failures instead of ignoring them.
	for i := range langs {
		if err := DB.FirstOrCreate(&langs[i], "code = ?", langs[i].Code).Error; err != nil {
			t.Fatalf("prepare language %q: %v", langs[i].Code, err)
		}
	}
	if err := DB.Model(&user).Association("Languages").Append(&langs); err != nil {
		t.Fatalf("append: %v", err)
	}
	AssertAssociationCount(t, user, "Languages", 2, "before")
	upd := clause.Association{Association: "Languages", Type: clause.OpUpdate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "Espanol"}}, Conditions: []clause.Expression{clause.Eq{Column: clause.Column{Name: "code"}, Value: "es"}}}
	if _, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(upd).Update(ctx); err != nil {
		t.Fatalf("OpUpdate m2m failed: %v", err)
	}
	var es Language
	if err := DB.First(&es, "code = ?", "es").Error; err != nil {
		t.Fatalf("load lang: %v", err)
	}
	if es.Name != "Espanol" {
		t.Fatalf("expected updated language name, got %s", es.Name)
	}
	del := clause.Association{Association: "Languages", Type: clause.OpDelete, Conditions: []clause.Expression{clause.Eq{Column: clause.Column{Name: "code"}, Value: "es"}}}
	if _, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(del).Update(ctx); err != nil {
		t.Fatalf("OpDelete m2m failed: %v", err)
	}
	AssertAssociationCount(t, user, "Languages", 1, "after delete one")
	// language row remains: for many-to-many, OpDelete is expected to drop only
	// the link, not the language itself.
	var count int64
	if err := DB.Model(&Language{}).Where("code = ?", "es").Count(&count).Error; err != nil {
		t.Fatalf("count lang: %v", err)
	}
	if count != 1 {
		t.Fatalf("expected language row still exists, got %d", count)
	}
}

// TestClauseAssociationSetUpdateAndDeleteManyOwnersHasMany applies has-many
// OpUpdate and OpDelete to two owners matched by a single "id IN ?" query.
func TestClauseAssociationSetUpdateAndDeleteManyOwnersHasMany(t *testing.T) {
	ctx := context.Background()
	u1 := User{Name: "MultiOwners-HasMany-1", Age: 21}
	u1.Pets = []*Pet{{Name: "p1"}}
	u2 := User{Name: "MultiOwners-HasMany-2", Age: 22}
	u2.Pets = []*Pet{{Name: "p2"}}
	if err := DB.Create(&u1).Error; err != nil {
		t.Fatalf("create u1: %v", err)
	}
	if err := DB.Create(&u2).Error; err != nil {
		t.Fatalf("create u2: %v", err)
	}
	AssertAssociationCount(t, u1, "Pets", 1, "before")
	AssertAssociationCount(t, u2, "Pets", 1, "before")
	upd := clause.Association{Association: "Pets", Type: clause.OpUpdate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "x"}}}
	if _, err := gorm.G[User](DB).Where("id IN ?", []uint{u1.ID, u2.ID}).Set(upd).Update(ctx); err != nil {
		t.Fatalf("OpUpdate has-many failed: %v", err)
	}
	var got1, got2 User
	if err := DB.Preload("Pets").First(&got1, u1.ID).Error; err != nil {
		t.Fatalf("load u1: %v", err)
	}
	if err := DB.Preload("Pets").First(&got2, u2.ID).Error; err != nil {
		t.Fatalf("load u2: %v", err)
	}
	if len(got1.Pets) != 1 || got1.Pets[0].Name != "x" {
		t.Fatalf("u1 pet not updated: %+v", got1.Pets)
	}
	if len(got2.Pets) != 1 || got2.Pets[0].Name != "x" {
		t.Fatalf("u2 pet not updated: %+v", got2.Pets)
	}
	del := clause.Association{Association: "Pets", Type: clause.OpDelete}
	if _, err := gorm.G[User](DB).Where("id IN ?", []uint{u1.ID, u2.ID}).Set(del).Update(ctx); err != nil {
		t.Fatalf("OpDelete has-many failed: %v", err)
	}
	AssertAssociationCount(t, u1, "Pets", 0, "after delete")
	AssertAssociationCount(t, u2, "Pets", 0, "after delete")
}

// TestClauseAssociationSetUpdateAndDeleteManyOwnersBelongsTo applies a
// belongs-to OpUpdate and then an OpUnlink to two owners matched by "id IN ?".
func TestClauseAssociationSetUpdateAndDeleteManyOwnersBelongsTo(t *testing.T) {
	ctx := context.Background()
	// Create companies
	c1 := Company{Name: "Electronics"}
	c2 := Company{Name: "Books"}
	if err := DB.Create(&c1).Error; err != nil {
		t.Fatalf("create c1: %v", err)
	}
	if err := DB.Create(&c2).Error; err != nil {
		t.Fatalf("create c2: %v", err)
	}
	// Create users with companies
	u1 := User{Name: "John", Age: 25, CompanyID: &c1.ID}
	u2 := User{Name: "Jane", Age: 30, CompanyID: &c2.ID}
	if err := DB.Create(&u1).Error; err != nil {
		t.Fatalf("create u1: %v", err)
	}
	if err := DB.Create(&u2).Error; err != nil {
		t.Fatalf("create u2: %v", err)
	}
	// Verify associations exist
	AssertAssociationCount(t, &u1, "Company", 1, "before")
	AssertAssociationCount(t, &u2, "Company", 1, "before")
	// Update companies via OpUpdate for multiple users
	upd := clause.Association{Association: "Company", Type: clause.OpUpdate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "Category-New"}}}
	if _, err := gorm.G[User](DB).Where("id IN ?", []uint{u1.ID, u2.ID}).Set(upd).Update(ctx); err != nil {
		t.Fatalf("OpUpdate belongs-to failed: %v", err)
	}
	// Check if companies were updated
	var g1, g2 User
	if err := DB.Preload("Company").First(&g1, u1.ID).Error; err != nil {
		t.Fatalf("load u1: %v", err)
	}
	if err := DB.Preload("Company").First(&g2, u2.ID).Error; err != nil {
		t.Fatalf("load u2: %v", err)
	}
	if (g1.Company.ID == 0 || g1.Company.Name !=
"Category-New") || (g2.Company.ID == 0 || g2.Company.Name != "Category-New") { t.Fatalf("company names not updated: %+v, %+v", g1.Company, g2.Company) } // Unlink companies via OpUnlink for multiple users (instead of OpDelete which would try to delete the company records) unlink := clause.Association{Association: "Company", Type: clause.OpUnlink} if _, err := gorm.G[User](DB).Where("id IN ?", []uint{u1.ID, u2.ID}).Set(unlink).Update(ctx); err != nil { t.Fatalf("OpUnlink belongs-to failed: %v", err) } // Reload users to reflect the changes in the database if err := DB.First(&u1, u1.ID).Error; err != nil { t.Fatalf("reload u1: %v", err) } if err := DB.First(&u2, u2.ID).Error; err != nil { t.Fatalf("reload u2: %v", err) } // Check if company associations were cleared AssertAssociationCount(t, &u1, "Company", 0, "after unlink") AssertAssociationCount(t, &u2, "Company", 0, "after unlink") } // Multi-owners: Many2Many update and delete func TestClauseAssociationSetUpdateAndDeleteManyOwnersMany2Many(t *testing.T) { ctx := context.Background() u1 := User{Name: "MultiOwners-M2M-1", Age: 21} u2 := User{Name: "MultiOwners-M2M-2", Age: 22} if err := DB.Create(&u1).Error; err != nil { t.Fatalf("create u1: %v", err) } if err := DB.Create(&u2).Error; err != nil { t.Fatalf("create u2: %v", err) } l1 := Language{Code: "zz", Name: "ZZ"} l2 := Language{Code: "yy", Name: "YY"} DB.FirstOrCreate(&l1, "code = ?", l1.Code) DB.FirstOrCreate(&l2, "code = ?", l2.Code) if err := DB.Model(&u1).Association("Languages").Append(&l1, &l2); err != nil { t.Fatalf("append u1: %v", err) } if err := DB.Model(&u2).Association("Languages").Append(&l1, &l2); err != nil { t.Fatalf("append u2: %v", err) } AssertAssociationCount(t, u1, "Languages", 2, "before") AssertAssociationCount(t, u2, "Languages", 2, "before") upd := clause.Association{Association: "Languages", Type: clause.OpUpdate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "ZZZ"}}, Conditions: 
[]clause.Expression{clause.Eq{Column: clause.Column{Name: "code"}, Value: "zz"}}} if _, err := gorm.G[User](DB).Where("id IN ?", []uint{u1.ID, u2.ID}).Set(upd).Update(ctx); err != nil { t.Fatalf("OpUpdate m2m failed: %v", err) } var l Language if err := DB.First(&l, "code = ?", "zz").Error; err != nil { t.Fatalf("load lang: %v", err) } if l.Name != "ZZZ" { t.Fatalf("expected lang updated, got %s", l.Name) } del := clause.Association{Association: "Languages", Type: clause.OpDelete, Conditions: []clause.Expression{clause.Eq{Column: clause.Column{Name: "code"}, Value: "zz"}}} if _, err := gorm.G[User](DB).Where("id IN ?", []uint{u1.ID, u2.ID}).Set(del).Update(ctx); err != nil { t.Fatalf("OpDelete m2m failed: %v", err) } AssertAssociationCount(t, u1, "Languages", 1, "after delete") AssertAssociationCount(t, u2, "Languages", 1, "after delete") } // Test Set + Update with has-one (NamedPet) using OpCreate func TestClauseAssociationSetUpdateHasOneCreateValues(t *testing.T) { ctx := context.Background() user := User{Name: "TestClauseAssociationSetUpdateHasOneCreateValues", Age: 25} if err := DB.Create(&user).Error; err != nil { t.Fatalf("failed to create user: %v", err) } assocOp := clause.Association{ Association: "NamedPet", Type: clause.OpCreate, Values: []interface{}{Pet{Name: "named-pet"}}, } rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx) if err != nil { t.Fatalf("Set Update has-one create values failed: %v", err) } if rows != 0 { t.Fatalf("expected 0 rows affected for association-only update, got %d", rows) } var updated User if err := DB.Preload("NamedPet").First(&updated, user.ID).Error; err != nil { t.Fatalf("failed to load user: %v", err) } if updated.NamedPet == nil || updated.NamedPet.Name != "named-pet" { t.Fatalf("expected named-pet created, got %+v", updated.NamedPet) } } // Test Set + Update to clear all has-many (Pets) via OpUnlink without conditions func TestClauseAssociationSetUpdateHasManyClear(t *testing.T) { ctx := 
context.Background() user := User{Name: "TestClauseAssociationSetUpdateHasManyClear", Age: 25} user.Pets = []*Pet{{Name: "p1"}, {Name: "p2"}} if err := DB.Create(&user).Error; err != nil { t.Fatalf("failed to create user: %v", err) } AssertAssociationCount(t, user, "Pets", 2, "before clear") assocOp := clause.Association{Association: "Pets", Type: clause.OpUnlink} if rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx); err != nil { t.Fatalf("Set Update has-many clear failed: %v", err) } else if rows != 0 { t.Fatalf("expected 0 rows affected for association-only update, got %d", rows) } AssertAssociationCount(t, user, "Pets", 0, "after clear") } // Test Set + Update with many-to-many unlink specific records using conditions func TestClauseAssociationSetUpdateManyToManyUnlink(t *testing.T) { ctx := context.Background() user := User{Name: "TestClauseAssociationSetUpdateManyToManyUnlink", Age: 25} if err := DB.Create(&user).Error; err != nil { t.Fatalf("failed to create user: %v", err) } langs := []Language{{Code: "es", Name: "Spanish"}, {Code: "de", Name: "German"}} for _, l := range langs { DB.FirstOrCreate(&l, "code = ?", l.Code) } // Associate both if err := DB.Model(&user).Association("Languages").Append(&langs); err != nil { t.Fatalf("failed to append languages: %v", err) } AssertAssociationCount(t, user, "Languages", 2, "before unlink") assocOp := clause.Association{ Association: "Languages", Type: clause.OpUnlink, Conditions: []clause.Expression{clause.Eq{Column: clause.Column{Name: "code"}, Value: "es"}}, } if rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx); err != nil { t.Fatalf("Set Update many-to-many unlink failed: %v", err) } else if rows != 0 { t.Fatalf("expected 0 rows affected for association-only update, got %d", rows) } AssertAssociationCount(t, user, "Languages", 1, "after unlink one") } // Test Set + Update with polymorphic has-many (Toys) using OpCreate func 
TestClauseAssociationSetUpdatePolymorphicCreate(t *testing.T) { ctx := context.Background() user := User{Name: "TestClauseAssociationSetUpdatePolymorphicCreate", Age: 25} if err := DB.Create(&user).Error; err != nil { t.Fatalf("failed to create user: %v", err) } assocOp := clause.Association{ Association: "Toys", Type: clause.OpCreate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "yo-yo"}}, } if rows, err := gorm.G[User](DB).Where("id = ?", user.ID).Set(assocOp).Update(ctx); err != nil { t.Fatalf("Set Update polymorphic create failed: %v", err) } else if rows != 0 { t.Fatalf("expected 0 rows affected for association-only update, got %d", rows) } AssertAssociationCount(t, user, "Toys", 1, "after create toy") } // Test Set + Update across multiple owners func TestClauseAssociationSetUpdateMultipleOwners(t *testing.T) { ctx := context.Background() u1 := User{Name: "SetUpdateMultipleOwners-1", Age: 20} u2 := User{Name: "SetUpdateMultipleOwners-2", Age: 21} if err := DB.Create(&u1).Error; err != nil { t.Fatalf("create u1: %v", err) } if err := DB.Create(&u2).Error; err != nil { t.Fatalf("create u2: %v", err) } assocOp := clause.Association{Association: "Pets", Type: clause.OpCreate, Set: []clause.Assignment{{Column: clause.Column{Name: "name"}, Value: "multi-pet"}}} if rows, err := gorm.G[User](DB).Where("name IN ?", []string{u1.Name, u2.Name}).Set(assocOp).Update(ctx); err != nil { t.Fatalf("Set Update multi owners failed: %v", err) } else if rows != 0 { t.Fatalf("expected 0 rows affected for association-only update, got %d", rows) } AssertAssociationCount(t, u1, "Pets", 1, "u1 after create") AssertAssociationCount(t, u2, "Pets", 1, "u2 after create") } ================================================ FILE: tests/associations_belongs_to_test.go ================================================ package tests_test import ( "testing" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) func TestBelongsToAssociation(t *testing.T) { user := *GetUser("belongs-to", Config{Company: true, Manager: true}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckUser(t, user, user) // Find var user2 User DB.Find(&user2, "id = ?", user.ID) pointerOfUser := &user2 if err := DB.Model(&pointerOfUser).Association("Company").Find(&user2.Company); err != nil { t.Errorf("failed to query users, got error %#v", err) } user2.Manager = &User{} DB.Model(&user2).Association("Manager").Find(user2.Manager) CheckUser(t, user2, user) // Count AssertAssociationCount(t, user, "Company", 1, "") AssertAssociationCount(t, user, "Manager", 1, "") // Append company := Company{Name: "company-belongs-to-append"} manager := GetUser("manager-belongs-to-append", Config{}) if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { t.Fatalf("Error happened when append Company, got %v", err) } if company.ID == 0 { t.Fatalf("Company's ID should be created") } if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { t.Fatalf("Error happened when append Manager, got %v", err) } if manager.ID == 0 { t.Fatalf("Manager's ID should be created") } user.Company = company user.Manager = manager user.CompanyID = &company.ID user.ManagerID = &manager.ID CheckUser(t, user2, user) AssertAssociationCount(t, user2, "Company", 1, "AfterAppend") AssertAssociationCount(t, user2, "Manager", 1, "AfterAppend") // Replace company2 := Company{Name: "company-belongs-to-replace"} manager2 := GetUser("manager-belongs-to-replace", Config{}) if err := DB.Model(&user2).Association("Company").Replace(&company2); err != nil { t.Fatalf("Error happened when replace Company, got %v", err) } if company2.ID == 0 { t.Fatalf("Company's ID should be created") } if err := DB.Model(&user2).Association("Manager").Replace(manager2); err != nil { t.Fatalf("Error happened when replace Manager, got %v", err) } if 
manager2.ID == 0 { t.Fatalf("Manager's ID should be created") } user.Company = company2 user.Manager = manager2 user.CompanyID = &company2.ID user.ManagerID = &manager2.ID CheckUser(t, user2, user) AssertAssociationCount(t, user2, "Company", 1, "AfterReplace") AssertAssociationCount(t, user2, "Manager", 1, "AfterReplace") // Delete if err := DB.Model(&user2).Association("Company").Delete(&Company{}); err != nil { t.Fatalf("Error happened when delete Company, got %v", err) } AssertAssociationCount(t, user2, "Company", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Company").Delete(&company2); err != nil { t.Fatalf("Error happened when delete Company, got %v", err) } AssertAssociationCount(t, user2, "Company", 0, "after delete") if err := DB.Model(&user2).Association("Manager").Delete(&User{}); err != nil { t.Fatalf("Error happened when delete Manager, got %v", err) } AssertAssociationCount(t, user2, "Manager", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Manager").Delete(manager2); err != nil { t.Fatalf("Error happened when delete Manager, got %v", err) } AssertAssociationCount(t, user2, "Manager", 0, "after delete") // Prepare Data for Clear if err := DB.Model(&user2).Association("Company").Append(&company); err != nil { t.Fatalf("Error happened when append Company, got %v", err) } if err := DB.Model(&user2).Association("Manager").Append(manager); err != nil { t.Fatalf("Error happened when append Manager, got %v", err) } AssertAssociationCount(t, user2, "Company", 1, "after prepare data") AssertAssociationCount(t, user2, "Manager", 1, "after prepare data") // Clear if err := DB.Model(&user2).Association("Company").Clear(); err != nil { t.Errorf("Error happened when clear Company, got %v", err) } if err := DB.Model(&user2).Association("Manager").Clear(); err != nil { t.Errorf("Error happened when clear Manager, got %v", err) } AssertAssociationCount(t, user2, "Company", 0, "after clear") 
AssertAssociationCount(t, user2, "Manager", 0, "after clear") // unexist company id unexistCompanyID := company.ID + 9999999 user = User{Name: "invalid-user-with-invalid-belongs-to-foreign-key", CompanyID: &unexistCompanyID} if err := DB.Create(&user).Error; err == nil { tidbSkip(t, "not support the foreign key feature") t.Errorf("should have gotten foreign key violation error") } } func TestBelongsToAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-belongs-to-1", Config{Company: true, Manager: true}), *GetUser("slice-belongs-to-2", Config{Company: true, Manager: false}), *GetUser("slice-belongs-to-3", Config{Company: true, Manager: true}), } DB.Create(&users) AssertAssociationCount(t, users, "Company", 3, "") AssertAssociationCount(t, users, "Manager", 2, "") // Find var companies []Company if DB.Model(&users).Association("Company").Find(&companies); len(companies) != 3 { t.Errorf("companies count should be %v, but got %v", 3, len(companies)) } var managers []User if DB.Model(&users).Association("Manager").Find(&managers); len(managers) != 2 { t.Errorf("managers count should be %v, but got %v", 2, len(managers)) } // Append DB.Model(&users).Association("Company").Append( &Company{Name: "company-slice-append-1"}, &Company{Name: "company-slice-append-2"}, &Company{Name: "company-slice-append-3"}, ) AssertAssociationCount(t, users, "Company", 3, "After Append") DB.Model(&users).Association("Manager").Append( GetUser("manager-slice-belongs-to-1", Config{}), GetUser("manager-slice-belongs-to-2", Config{}), GetUser("manager-slice-belongs-to-3", Config{}), ) AssertAssociationCount(t, users, "Manager", 3, "After Append") if err := DB.Model(&users).Association("Manager").Append( GetUser("manager-slice-belongs-to-test-1", Config{}), ).Error; err == nil { t.Errorf("unmatched length when update user's manager") } // Replace -> same as append // Delete if err := DB.Model(&users).Association("Company").Delete(&users[0].Company); err != nil { t.Errorf("no 
error should happened when deleting company, but got %v", err) } if users[0].CompanyID != nil || users[0].Company.ID != 0 { t.Errorf("users[0]'s company should be deleted'") } AssertAssociationCount(t, users, "Company", 2, "After Delete") // Clear DB.Model(&users).Association("Company").Clear() AssertAssociationCount(t, users, "Company", 0, "After Clear") DB.Model(&users).Association("Manager").Clear() AssertAssociationCount(t, users, "Manager", 0, "After Clear") // shared company company := Company{Name: "shared"} if err := DB.Model(&users[0]).Association("Company").Append(&company); err != nil { t.Errorf("Error happened when append company to user, got %v", err) } if err := DB.Model(&users[1]).Association("Company").Append(&company); err != nil { t.Errorf("Error happened when append company to user, got %v", err) } if users[0].CompanyID == nil || users[1].CompanyID == nil || *users[0].CompanyID != *users[1].CompanyID { t.Errorf("user's company id should exists and equal, but its: %v, %v", users[0].CompanyID, users[1].CompanyID) } DB.Model(&users[0]).Association("Company").Delete(&company) AssertAssociationCount(t, users[0], "Company", 0, "After Delete") AssertAssociationCount(t, users[1], "Company", 1, "After other user Delete") } func TestBelongsToDefaultValue(t *testing.T) { type Org struct { ID string } type BelongsToUser struct { OrgID string Org Org `gorm:"default:NULL"` } tx := DB.Session(&gorm.Session{}) tx.Config.DisableForeignKeyConstraintWhenMigrating = true AssertEqual(t, DB.Config.DisableForeignKeyConstraintWhenMigrating, false) tx.Migrator().DropTable(&BelongsToUser{}, &Org{}) tx.AutoMigrate(&BelongsToUser{}, &Org{}) user := &BelongsToUser{ Org: Org{ ID: "BelongsToUser_Org_1", }, } err := DB.Create(&user).Error AssertEqual(t, err, nil) } func TestBelongsToAssociationUnscoped(t *testing.T) { type ItemParent struct { gorm.Model Logo string `gorm:"not null;type:varchar(50)"` } type ItemChild struct { gorm.Model Name string `gorm:"type:varchar(50)"` 
ItemParentID uint ItemParent ItemParent } tx := DB.Session(&gorm.Session{}) tx.Migrator().DropTable(&ItemParent{}, &ItemChild{}) tx.AutoMigrate(&ItemParent{}, &ItemChild{}) item := ItemChild{ Name: "name", ItemParent: ItemParent{ Logo: "logo", }, } if err := tx.Create(&item).Error; err != nil { t.Fatalf("failed to create items, got error: %v", err) } // test replace if err := tx.Model(&item).Association("ItemParent").Unscoped().Replace(&ItemParent{ Logo: "updated logo", }); err != nil { t.Errorf("failed to replace item parent, got error: %v", err) } var parents []ItemParent if err := tx.Find(&parents).Error; err != nil { t.Errorf("failed to find item parent, got error: %v", err) } if len(parents) != 1 { t.Errorf("expected %d parents, got %d", 1, len(parents)) } // test delete if err := tx.Model(&item).Association("ItemParent").Unscoped().Delete(&parents); err != nil { t.Errorf("failed to delete item parent, got error: %v", err) } if err := tx.Find(&parents).Error; err != nil { t.Errorf("failed to find item parent, got error: %v", err) } if len(parents) != 0 { t.Errorf("expected %d parents, got %d", 0, len(parents)) } } ================================================ FILE: tests/associations_has_many_test.go ================================================ package tests_test import ( "testing" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) func TestHasManyAssociation(t *testing.T) { user := *GetUser("hasmany", Config{Pets: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckUser(t, user, user) // Find var user2 User DB.Find(&user2, "id = ?", user.ID) DB.Model(&user2).Association("Pets").Find(&user2.Pets) CheckUser(t, user2, user) var pets []Pet DB.Model(&user).Where("name = ?", user.Pets[0].Name).Association("Pets").Find(&pets) if len(pets) != 1 { t.Fatalf("should only find one pets, but got %v", len(pets)) } CheckPet(t, pets[0], *user.Pets[0]) if count := DB.Model(&user).Where("name = ?", user.Pets[1].Name).Association("Pets").Count(); count != 1 { t.Fatalf("should only find one pets, but got %v", count) } if count := DB.Model(&user).Where("name = ?", "not found").Association("Pets").Count(); count != 0 { t.Fatalf("should only find no pet with invalid conditions, but got %v", count) } // Count AssertAssociationCount(t, user, "Pets", 2, "") // Append pet := Pet{Name: "pet-has-many-append"} if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { t.Fatalf("Error happened when append account, got %v", err) } if pet.ID == 0 { t.Fatalf("Pet's ID should be created") } user.Pets = append(user.Pets, &pet) CheckUser(t, user2, user) AssertAssociationCount(t, user, "Pets", 3, "AfterAppend") pets2 := []Pet{{Name: "pet-has-many-append-1-1"}, {Name: "pet-has-many-append-1-1"}} if err := DB.Model(&user2).Association("Pets").Append(&pets2); err != nil { t.Fatalf("Error happened when append pet, got %v", err) } for _, pet := range pets2 { pet := pet if pet.ID == 0 { t.Fatalf("Pet's ID should be created") } user.Pets = append(user.Pets, &pet) } CheckUser(t, user2, user) AssertAssociationCount(t, user, "Pets", 5, "AfterAppendSlice") // Replace pet2 := Pet{Name: "pet-has-many-replace"} if err := DB.Model(&user2).Association("Pets").Replace(&pet2); err != nil { t.Fatalf("Error happened when append pet, got %v", err) } if 
pet2.ID == 0 { t.Fatalf("pet2's ID should be created") } user.Pets = []*Pet{&pet2} CheckUser(t, user2, user) AssertAssociationCount(t, user2, "Pets", 1, "AfterReplace") // Delete if err := DB.Model(&user2).Association("Pets").Delete(&Pet{}); err != nil { t.Fatalf("Error happened when delete pet, got %v", err) } AssertAssociationCount(t, user2, "Pets", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Pets").Delete(&pet2); err != nil { t.Fatalf("Error happened when delete Pets, got %v", err) } AssertAssociationCount(t, user2, "Pets", 0, "after delete") // Prepare Data for Clear if err := DB.Model(&user2).Association("Pets").Append(&pet); err != nil { t.Fatalf("Error happened when append Pets, got %v", err) } AssertAssociationCount(t, user2, "Pets", 1, "after prepare data") // Clear if err := DB.Model(&user2).Association("Pets").Clear(); err != nil { t.Errorf("Error happened when clear Pets, got %v", err) } AssertAssociationCount(t, user2, "Pets", 0, "after clear") } func TestSingleTableHasManyAssociation(t *testing.T) { user := *GetUser("hasmany", Config{Team: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckUser(t, user, user) // Find var user2 User DB.Find(&user2, "id = ?", user.ID) DB.Model(&user2).Association("Team").Find(&user2.Team) CheckUser(t, user2, user) // Count AssertAssociationCount(t, user, "Team", 2, "") // Append team := *GetUser("team", Config{}) if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { t.Fatalf("Error happened when append account, got %v", err) } if team.ID == 0 { t.Fatalf("Team's ID should be created") } user.Team = append(user.Team, team) CheckUser(t, user2, user) AssertAssociationCount(t, user, "Team", 3, "AfterAppend") teams := []User{*GetUser("team-append-1", Config{}), *GetUser("team-append-2", Config{})} if err := DB.Model(&user2).Association("Team").Append(&teams); err != nil { t.Fatalf("Error happened when append team, got 
%v", err) } for _, team := range teams { team := team if team.ID == 0 { t.Fatalf("Team's ID should be created") } user.Team = append(user.Team, team) } CheckUser(t, user2, user) AssertAssociationCount(t, user, "Team", 5, "AfterAppendSlice") // Replace team2 := *GetUser("team-replace", Config{}) if err := DB.Model(&user2).Association("Team").Replace(&team2); err != nil { t.Fatalf("Error happened when append team, got %v", err) } if team2.ID == 0 { t.Fatalf("team2's ID should be created") } user.Team = []User{team2} CheckUser(t, user2, user) AssertAssociationCount(t, user2, "Team", 1, "AfterReplace") // Delete if err := DB.Model(&user2).Association("Team").Delete(&User{}); err != nil { t.Fatalf("Error happened when delete team, got %v", err) } AssertAssociationCount(t, user2, "Team", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Team").Delete(&team2); err != nil { t.Fatalf("Error happened when delete Team, got %v", err) } AssertAssociationCount(t, user2, "Team", 0, "after delete") // Prepare Data for Clear if err := DB.Model(&user2).Association("Team").Append(&team); err != nil { t.Fatalf("Error happened when append Team, got %v", err) } AssertAssociationCount(t, user2, "Team", 1, "after prepare data") // Clear if err := DB.Model(&user2).Association("Team").Clear(); err != nil { t.Errorf("Error happened when clear Team, got %v", err) } AssertAssociationCount(t, user2, "Team", 0, "after clear") } func TestHasManyAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-hasmany-1", Config{Pets: 2}), *GetUser("slice-hasmany-2", Config{Pets: 0}), *GetUser("slice-hasmany-3", Config{Pets: 4}), } DB.Create(&users) // Count AssertAssociationCount(t, users, "Pets", 6, "") // Find var pets []Pet if DB.Model(&users).Association("Pets").Find(&pets); len(pets) != 6 { t.Errorf("pets count should be %v, but got %v", 6, len(pets)) } // Append DB.Model(&users).Association("Pets").Append( &Pet{Name: "pet-slice-append-1"}, []*Pet{{Name: 
"pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, &Pet{Name: "pet-slice-append-3"}, ) AssertAssociationCount(t, users, "Pets", 10, "After Append") // Replace -> same as append DB.Model(&users).Association("Pets").Replace( []*Pet{{Name: "pet-slice-replace-1-1"}, {Name: "pet-slice-replace-1-2"}}, []*Pet{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, &Pet{Name: "pet-slice-replace-3"}, ) AssertAssociationCount(t, users, "Pets", 5, "After Append") // Delete if err := DB.Model(&users).Association("Pets").Delete(&users[2].Pets); err != nil { t.Errorf("no error should happened when deleting pet, but got %v", err) } AssertAssociationCount(t, users, "Pets", 4, "after delete") if err := DB.Model(&users).Association("Pets").Delete(users[0].Pets[0], users[1].Pets[1]); err != nil { t.Errorf("no error should happened when deleting pet, but got %v", err) } AssertAssociationCount(t, users, "Pets", 2, "after delete") // Clear DB.Model(&users).Association("Pets").Clear() AssertAssociationCount(t, users, "Pets", 0, "After Clear") } func TestSingleTableHasManyAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-hasmany-1", Config{Team: 2}), *GetUser("slice-hasmany-2", Config{Team: 0}), *GetUser("slice-hasmany-3", Config{Team: 4}), } if err := DB.Create(&users).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } // Count AssertAssociationCount(t, users, "Team", 6, "") // Find var teams []User if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { t.Errorf("teams count should be %v, but got %v", 6, len(teams)) } // Append DB.Model(&users).Association("Team").Append( &User{Name: "pet-slice-append-1"}, []*User{{Name: "pet-slice-append-2-1"}, {Name: "pet-slice-append-2-2"}}, &User{Name: "pet-slice-append-3"}, ) AssertAssociationCount(t, users, "Team", 10, "After Append") // Replace -> same as append DB.Model(&users).Association("Team").Replace( []*User{{Name: "pet-slice-replace-1-1"}, {Name: 
"pet-slice-replace-1-2"}}, []*User{{Name: "pet-slice-replace-2-1"}, {Name: "pet-slice-replace-2-2"}}, &User{Name: "pet-slice-replace-3"}, ) AssertAssociationCount(t, users, "Team", 5, "After Append") // Delete if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { t.Errorf("no error should happened when deleting pet, but got %v", err) } AssertAssociationCount(t, users, "Team", 4, "after delete") if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { t.Errorf("no error should happened when deleting pet, but got %v", err) } AssertAssociationCount(t, users, "Team", 2, "after delete") // Clear DB.Model(&users).Association("Team").Clear() AssertAssociationCount(t, users, "Team", 0, "After Clear") } func TestPolymorphicHasManyAssociation(t *testing.T) { user := *GetUser("hasmany", Config{Toys: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckUser(t, user, user) // Find var user2 User DB.Find(&user2, "id = ?", user.ID) DB.Model(&user2).Association("Toys").Find(&user2.Toys) CheckUser(t, user2, user) // Count AssertAssociationCount(t, user, "Toys", 2, "") // Append toy := Toy{Name: "toy-has-many-append"} if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { t.Fatalf("Error happened when append account, got %v", err) } if toy.ID == 0 { t.Fatalf("Toy's ID should be created") } user.Toys = append(user.Toys, toy) CheckUser(t, user2, user) AssertAssociationCount(t, user, "Toys", 3, "AfterAppend") toys := []Toy{{Name: "toy-has-many-append-1-1"}, {Name: "toy-has-many-append-1-1"}} if err := DB.Model(&user2).Association("Toys").Append(&toys); err != nil { t.Fatalf("Error happened when append toy, got %v", err) } for _, toy := range toys { toy := toy if toy.ID == 0 { t.Fatalf("Toy's ID should be created") } user.Toys = append(user.Toys, toy) } CheckUser(t, user2, user) AssertAssociationCount(t, user, "Toys", 5, 
"AfterAppendSlice") // Replace toy2 := Toy{Name: "toy-has-many-replace"} if err := DB.Model(&user2).Association("Toys").Replace(&toy2); err != nil { t.Fatalf("Error happened when append toy, got %v", err) } if toy2.ID == 0 { t.Fatalf("toy2's ID should be created") } user.Toys = []Toy{toy2} CheckUser(t, user2, user) AssertAssociationCount(t, user2, "Toys", 1, "AfterReplace") // Delete if err := DB.Model(&user2).Association("Toys").Delete(&Toy{}); err != nil { t.Fatalf("Error happened when delete toy, got %v", err) } AssertAssociationCount(t, user2, "Toys", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Toys").Delete(&toy2); err != nil { t.Fatalf("Error happened when delete Toys, got %v", err) } AssertAssociationCount(t, user2, "Toys", 0, "after delete") // Prepare Data for Clear if err := DB.Model(&user2).Association("Toys").Append(&toy); err != nil { t.Fatalf("Error happened when append Toys, got %v", err) } AssertAssociationCount(t, user2, "Toys", 1, "after prepare data") // Clear if err := DB.Model(&user2).Association("Toys").Clear(); err != nil { t.Errorf("Error happened when clear Toys, got %v", err) } AssertAssociationCount(t, user2, "Toys", 0, "after clear") } func TestPolymorphicHasManyAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-hasmany-1", Config{Toys: 2}), *GetUser("slice-hasmany-2", Config{Toys: 0, Tools: 2}), *GetUser("slice-hasmany-3", Config{Toys: 4}), } DB.Create(&users) // Count AssertAssociationCount(t, users, "Toys", 6, "") AssertAssociationCount(t, users, "Tools", 2, "") // Find var toys []Toy if DB.Model(&users).Association("Toys").Find(&toys); len(toys) != 6 { t.Errorf("toys count should be %v, but got %v", 6, len(toys)) } // Find Tools (polymorphic with custom type and id) var tools []Tools DB.Model(&users).Association("Tools").Find(&tools) if len(tools) != 2 { t.Errorf("tools count should be %v, but got %v", 2, len(tools)) } // Append DB.Model(&users).Association("Toys").Append( &Toy{Name: 
"toy-slice-append-1"}, []Toy{{Name: "toy-slice-append-2-1"}, {Name: "toy-slice-append-2-2"}}, &Toy{Name: "toy-slice-append-3"}, ) AssertAssociationCount(t, users, "Toys", 10, "After Append") // Replace -> same as append DB.Model(&users).Association("Toys").Replace( []*Toy{{Name: "toy-slice-replace-1-1"}, {Name: "toy-slice-replace-1-2"}}, []*Toy{{Name: "toy-slice-replace-2-1"}, {Name: "toy-slice-replace-2-2"}}, &Toy{Name: "toy-slice-replace-3"}, ) AssertAssociationCount(t, users, "Toys", 5, "After Append") // Delete if err := DB.Model(&users).Association("Toys").Delete(&users[2].Toys); err != nil { t.Errorf("no error should happened when deleting toy, but got %v", err) } AssertAssociationCount(t, users, "Toys", 4, "after delete") if err := DB.Model(&users).Association("Toys").Delete(users[0].Toys[0], users[1].Toys[1]); err != nil { t.Errorf("no error should happened when deleting toy, but got %v", err) } AssertAssociationCount(t, users, "Toys", 2, "after delete") // Clear DB.Model(&users).Association("Toys").Clear() AssertAssociationCount(t, users, "Toys", 0, "After Clear") } func TestHasManyAssociationUnscoped(t *testing.T) { type ItemContent struct { gorm.Model ItemID uint `gorm:"not null"` Name string `gorm:"not null;type:varchar(50)"` LanguageCode string `gorm:"not null;type:varchar(2)"` } type Item struct { gorm.Model Logo string `gorm:"not null;type:varchar(50)"` Contents []ItemContent `gorm:"foreignKey:ItemID"` } tx := DB.Session(&gorm.Session{}) tx.Migrator().DropTable(&ItemContent{}, &Item{}) tx.AutoMigrate(&ItemContent{}, &Item{}) item := Item{ Logo: "logo", Contents: []ItemContent{ {Name: "name", LanguageCode: "en"}, {Name: "ar name", LanguageCode: "ar"}, }, } if err := tx.Create(&item).Error; err != nil { t.Fatalf("failed to create items, got error: %v", err) } // test Replace if err := tx.Model(&item).Association("Contents").Unscoped().Replace([]ItemContent{ {Name: "updated name", LanguageCode: "en"}, {Name: "ar updated name", LanguageCode: "ar"}, 
{Name: "le nom", LanguageCode: "fr"}, }); err != nil { t.Errorf("failed to replace item content, got error: %v", err) } if count := tx.Model(&item).Association("Contents").Count(); count != 3 { t.Errorf("expected %d contents, got %d", 3, count) } var contents []ItemContent if err := tx.Find(&contents).Error; err != nil { t.Errorf("failed to find contents, got error: %v", err) } if len(contents) != 3 { t.Errorf("expected %d contents, got %d", 3, len(contents)) } // test delete if err := tx.Model(&item).Association("Contents").Unscoped().Delete(&contents[0]); err != nil { t.Errorf("failed to delete Contents, got error: %v", err) } if count := tx.Model(&item).Association("Contents").Count(); count != 2 { t.Errorf("expected %d contents, got %d", 2, count) } // test clear if err := tx.Model(&item).Association("Contents").Unscoped().Clear(); err != nil { t.Errorf("failed to clear contents association, got error: %v", err) } if count := tx.Model(&item).Association("Contents").Count(); count != 0 { t.Errorf("expected %d contents, got %d", 0, count) } if err := tx.Find(&contents).Error; err != nil { t.Errorf("failed to find contents, got error: %v", err) } if len(contents) != 0 { t.Errorf("expected %d contents, got %d", 0, len(contents)) } } func TestHasManyAssociationReplaceWithStructValue(t *testing.T) { user := User{Name: "HasManyAssociationReplaceWithStructValue", Languages: []Language{{Name: "EN", Code: "en"}}} if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } if err := DB.Model(&user).Association("Languages").Replace(Language{Name: "DE", Code: "de"}, Language{Name: "FR", Code: "fr"}); err != nil { t.Error("expected association error to be not nil") } var result User DB.Preload("Languages").Where("name = ?", "HasManyAssociationReplaceWithStructValue").Find(&result) if len(result.Languages) != 2 { t.Errorf("invalid languages found for user, got %v", result.Languages) } } ================================================ 
FILE: tests/associations_has_one_test.go ================================================ package tests_test import ( "testing" . "gorm.io/gorm/utils/tests" ) func TestHasOneAssociation(t *testing.T) { user := *GetUser("hasone", Config{Account: true}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckUser(t, user, user) // Find var user2 User DB.Find(&user2, "id = ?", user.ID) DB.Model(&user2).Association("Account").Find(&user2.Account) CheckUser(t, user2, user) // Count AssertAssociationCount(t, user, "Account", 1, "") // Append account := Account{Number: "account-has-one-append"} if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { t.Fatalf("Error happened when append account, got %v", err) } if account.ID == 0 { t.Fatalf("Account's ID should be created") } user.Account = account CheckUser(t, user2, user) AssertAssociationCount(t, user, "Account", 1, "AfterAppend") // Replace account2 := Account{Number: "account-has-one-replace"} if err := DB.Model(&user2).Association("Account").Replace(&account2); err != nil { t.Fatalf("Error happened when append Account, got %v", err) } if account2.ID == 0 { t.Fatalf("account2's ID should be created") } user.Account = account2 CheckUser(t, user2, user) AssertAssociationCount(t, user2, "Account", 1, "AfterReplace") // Delete if err := DB.Model(&user2).Association("Account").Delete(&Account{}); err != nil { t.Fatalf("Error happened when delete account, got %v", err) } AssertAssociationCount(t, user2, "Account", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Account").Delete(&account2); err != nil { t.Fatalf("Error happened when delete Account, got %v", err) } AssertAssociationCount(t, user2, "Account", 0, "after delete") // Prepare Data for Clear account = Account{Number: "account-has-one-append"} if err := DB.Model(&user2).Association("Account").Append(&account); err != nil { t.Fatalf("Error happened when append 
Account, got %v", err) } AssertAssociationCount(t, user2, "Account", 1, "after prepare data") // Clear if err := DB.Model(&user2).Association("Account").Clear(); err != nil { t.Errorf("Error happened when clear Account, got %v", err) } AssertAssociationCount(t, user2, "Account", 0, "after clear") } func TestHasOneAssociationWithSelect(t *testing.T) { user := *GetUser("hasone", Config{Account: true}) DB.Omit("Account.Number").Create(&user) AssertAssociationCount(t, user, "Account", 1, "") var account Account DB.Model(&user).Association("Account").Find(&account) if account.Number != "" { t.Errorf("account's number should not be saved") } } func TestHasOneAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-hasone-1", Config{Account: true}), *GetUser("slice-hasone-2", Config{Account: false}), *GetUser("slice-hasone-3", Config{Account: true}), } DB.Create(&users) // Count AssertAssociationCount(t, users, "Account", 2, "") // Find var accounts []Account if DB.Model(&users).Association("Account").Find(&accounts); len(accounts) != 2 { t.Errorf("accounts count should be %v, but got %v", 3, len(accounts)) } // Append DB.Model(&users).Association("Account").Append( &Account{Number: "account-slice-append-1"}, &Account{Number: "account-slice-append-2"}, &Account{Number: "account-slice-append-3"}, ) AssertAssociationCount(t, users, "Account", 3, "After Append") // Replace -> same as append // Delete if err := DB.Model(&users).Association("Account").Delete(&users[0].Account); err != nil { t.Errorf("no error should happened when deleting account, but got %v", err) } AssertAssociationCount(t, users, "Account", 2, "after delete") // Clear DB.Model(&users).Association("Account").Clear() AssertAssociationCount(t, users, "Account", 0, "After Clear") } func TestPolymorphicHasOneAssociation(t *testing.T) { pet := Pet{Name: "hasone", Toy: Toy{Name: "toy-has-one"}} if err := DB.Create(&pet).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckPet(t, 
pet, pet) // Find var pet2 Pet DB.Find(&pet2, "id = ?", pet.ID) DB.Model(&pet2).Association("Toy").Find(&pet2.Toy) CheckPet(t, pet2, pet) // Count AssertAssociationCount(t, pet, "Toy", 1, "") // Append toy := Toy{Name: "toy-has-one-append"} if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { t.Fatalf("Error happened when append toy, got %v", err) } if toy.ID == 0 { t.Fatalf("Toy's ID should be created") } pet.Toy = toy CheckPet(t, pet2, pet) AssertAssociationCount(t, pet, "Toy", 1, "AfterAppend") // Replace toy2 := Toy{Name: "toy-has-one-replace"} if err := DB.Model(&pet2).Association("Toy").Replace(&toy2); err != nil { t.Fatalf("Error happened when append Toy, got %v", err) } if toy2.ID == 0 { t.Fatalf("toy2's ID should be created") } pet.Toy = toy2 CheckPet(t, pet2, pet) AssertAssociationCount(t, pet2, "Toy", 1, "AfterReplace") // Delete if err := DB.Model(&pet2).Association("Toy").Delete(&Toy{}); err != nil { t.Fatalf("Error happened when delete toy, got %v", err) } AssertAssociationCount(t, pet2, "Toy", 1, "after delete non-existing data") if err := DB.Model(&pet2).Association("Toy").Delete(&toy2); err != nil { t.Fatalf("Error happened when delete Toy, got %v", err) } AssertAssociationCount(t, pet2, "Toy", 0, "after delete") // Prepare Data for Clear toy = Toy{Name: "toy-has-one-append"} if err := DB.Model(&pet2).Association("Toy").Append(&toy); err != nil { t.Fatalf("Error happened when append Toy, got %v", err) } AssertAssociationCount(t, pet2, "Toy", 1, "after prepare data") // Clear if err := DB.Model(&pet2).Association("Toy").Clear(); err != nil { t.Errorf("Error happened when clear Toy, got %v", err) } AssertAssociationCount(t, pet2, "Toy", 0, "after clear") } func TestPolymorphicHasOneAssociationForSlice(t *testing.T) { pets := []Pet{ {Name: "hasone-1", Toy: Toy{Name: "toy-has-one"}}, {Name: "hasone-2", Toy: Toy{}}, {Name: "hasone-3", Toy: Toy{Name: "toy-has-one"}}, } DB.Create(&pets) // Count AssertAssociationCount(t, pets, "Toy", 2, 
"") // Find var toys []Toy if DB.Model(&pets).Association("Toy").Find(&toys); len(toys) != 2 { t.Errorf("toys count should be %v, but got %v", 3, len(toys)) } // Append DB.Model(&pets).Association("Toy").Append( &Toy{Name: "toy-slice-append-1"}, &Toy{Name: "toy-slice-append-2"}, &Toy{Name: "toy-slice-append-3"}, ) AssertAssociationCount(t, pets, "Toy", 3, "After Append") // Replace -> same as append // Delete if err := DB.Model(&pets).Association("Toy").Delete(&pets[0].Toy); err != nil { t.Errorf("no error should happened when deleting toy, but got %v", err) } AssertAssociationCount(t, pets, "Toy", 2, "after delete") // Clear DB.Model(&pets).Association("Toy").Clear() AssertAssociationCount(t, pets, "Toy", 0, "After Clear") } func TestHasOneAssociationReplaceWithNonValidValue(t *testing.T) { user := User{Name: "jinzhu", Account: Account{Number: "1"}} if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } if err := DB.Model(&user).Association("Languages").Replace(Account{Number: "2"}); err == nil { t.Error("expected association error to be not nil") } } ================================================ FILE: tests/associations_many2many_test.go ================================================ package tests_test import ( "fmt" "sync" "testing" "gorm.io/gorm" "gorm.io/gorm/clause" . 
"gorm.io/gorm/utils/tests" ) func TestMany2ManyAssociation(t *testing.T) { user := *GetUser("many2many", Config{Languages: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckUser(t, user, user) // Find var user2 User DB.Find(&user2, "id = ?", user.ID) DB.Model(&user2).Association("Languages").Find(&user2.Languages) CheckUser(t, user2, user) // Count AssertAssociationCount(t, user, "Languages", 2, "") // Append language := Language{Code: "language-many2many-append", Name: "language-many2many-append"} DB.Create(&language) if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { t.Fatalf("Error happened when append account, got %v", err) } user.Languages = append(user.Languages, language) CheckUser(t, user2, user) AssertAssociationCount(t, user, "Languages", 3, "AfterAppend") languages := []Language{ {Code: "language-many2many-append-1-1", Name: "language-many2many-append-1-1"}, {Code: "language-many2many-append-2-1", Name: "language-many2many-append-2-1"}, } DB.Create(&languages) if err := DB.Model(&user2).Association("Languages").Append(&languages); err != nil { t.Fatalf("Error happened when append language, got %v", err) } user.Languages = append(user.Languages, languages...) 
CheckUser(t, user2, user) AssertAssociationCount(t, user, "Languages", 5, "AfterAppendSlice") // Replace language2 := Language{Code: "language-many2many-replace", Name: "language-many2many-replace"} DB.Create(&language2) if err := DB.Model(&user2).Association("Languages").Replace(&language2); err != nil { t.Fatalf("Error happened when append language, got %v", err) } user.Languages = []Language{language2} CheckUser(t, user2, user) AssertAssociationCount(t, user2, "Languages", 1, "AfterReplace") // Delete if err := DB.Model(&user2).Association("Languages").Delete(&Language{}); err != nil { t.Fatalf("Error happened when delete language, got %v", err) } AssertAssociationCount(t, user2, "Languages", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Languages").Delete(&language2); err != nil { t.Fatalf("Error happened when delete Languages, got %v", err) } AssertAssociationCount(t, user2, "Languages", 0, "after delete") // Prepare Data for Clear if err := DB.Model(&user2).Association("Languages").Append(&language); err != nil { t.Fatalf("Error happened when append Languages, got %v", err) } AssertAssociationCount(t, user2, "Languages", 1, "after prepare data") // Clear if err := DB.Model(&user2).Association("Languages").Clear(); err != nil { t.Errorf("Error happened when clear Languages, got %v", err) } AssertAssociationCount(t, user2, "Languages", 0, "after clear") } func TestMany2ManyOmitAssociations(t *testing.T) { tidbSkip(t, "not support the foreign key feature") user := *GetUser("many2many_omit_associations", Config{Languages: 2}) if err := DB.Omit("Languages.*").Create(&user).Error; err == nil { t.Fatalf("should raise error when create users without languages reference") } if err := DB.Create(&user.Languages).Error; err != nil { t.Fatalf("no error should happen when create languages, but got %v", err) } if err := DB.Omit("Languages.*").Create(&user).Error; err != nil { t.Fatalf("no error should happen when create user when languages 
exists, but got %v", err) } // Find var languages []Language if DB.Model(&user).Association("Languages").Find(&languages); len(languages) != 2 { t.Errorf("languages count should be %v, but got %v", 2, len(languages)) } newLang := Language{Code: "omitmany2many", Name: "omitmany2many"} if err := DB.Model(&user).Omit("Languages.*").Association("Languages").Replace(&newLang); err == nil { t.Errorf("should failed to insert languages due to constraint failed, error: %v", err) } } func TestMany2ManyAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-many2many-1", Config{Languages: 2}), *GetUser("slice-many2many-2", Config{Languages: 0}), *GetUser("slice-many2many-3", Config{Languages: 4}), } DB.Create(&users) // Count AssertAssociationCount(t, users, "Languages", 6, "") // Find var languages []Language if DB.Model(&users).Association("Languages").Find(&languages); len(languages) != 6 { t.Errorf("languages count should be %v, but got %v", 6, len(languages)) } // Append languages1 := []Language{ {Code: "language-many2many-append-1", Name: "language-many2many-append-1"}, } languages2 := []Language{} languages3 := []Language{ {Code: "language-many2many-append-3-1", Name: "language-many2many-append-3-1"}, {Code: "language-many2many-append-3-2", Name: "language-many2many-append-3-2"}, } DB.Create(&languages1) DB.Create(&languages3) DB.Model(&users).Association("Languages").Append(&languages1, &languages2, &languages3) AssertAssociationCount(t, users, "Languages", 9, "After Append") languages2_1 := []*Language{ {Code: "language-slice-replace-1-1", Name: "language-slice-replace-1-1"}, {Code: "language-slice-replace-1-2", Name: "language-slice-replace-1-2"}, } languages2_2 := []*Language{ {Code: "language-slice-replace-2-1", Name: "language-slice-replace-2-1"}, {Code: "language-slice-replace-2-2", Name: "language-slice-replace-2-2"}, } languages2_3 := &Language{Code: "language-slice-replace-3", Name: "language-slice-replace-3"} DB.Create(&languages2_1) 
DB.Create(&languages2_2) DB.Create(&languages2_3) // Replace DB.Model(&users).Association("Languages").Replace(&languages2_1, &languages2_2, languages2_3) AssertAssociationCount(t, users, "Languages", 5, "After Replace") // Delete if err := DB.Model(&users).Association("Languages").Delete(&users[2].Languages); err != nil { t.Errorf("no error should happened when deleting language, but got %v", err) } AssertAssociationCount(t, users, "Languages", 4, "after delete") if err := DB.Model(&users).Association("Languages").Delete(users[0].Languages[0], users[1].Languages[1]); err != nil { t.Errorf("no error should happened when deleting language, but got %v", err) } AssertAssociationCount(t, users, "Languages", 2, "after delete") // Clear DB.Model(&users).Association("Languages").Clear() AssertAssociationCount(t, users, "Languages", 0, "After Clear") } func TestSingleTableMany2ManyAssociation(t *testing.T) { user := *GetUser("many2many", Config{Friends: 2}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckUser(t, user, user) // Find var user2 User DB.Find(&user2, "id = ?", user.ID) DB.Model(&user2).Association("Friends").Find(&user2.Friends) CheckUser(t, user2, user) // Count AssertAssociationCount(t, user, "Friends", 2, "") // Append friend := *GetUser("friend", Config{}) if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { t.Fatalf("Error happened when append account, got %v", err) } user.Friends = append(user.Friends, &friend) CheckUser(t, user2, user) AssertAssociationCount(t, user, "Friends", 3, "AfterAppend") friends := []*User{GetUser("friend-append-1", Config{}), GetUser("friend-append-2", Config{})} if err := DB.Model(&user2).Association("Friends").Append(&friends); err != nil { t.Fatalf("Error happened when append friend, got %v", err) } user.Friends = append(user.Friends, friends...) 
CheckUser(t, user2, user) AssertAssociationCount(t, user, "Friends", 5, "AfterAppendSlice") // Replace friend2 := *GetUser("friend-replace-2", Config{}) if err := DB.Model(&user2).Association("Friends").Replace(&friend2); err != nil { t.Fatalf("Error happened when append friend, got %v", err) } user.Friends = []*User{&friend2} CheckUser(t, user2, user) AssertAssociationCount(t, user2, "Friends", 1, "AfterReplace") // Delete if err := DB.Model(&user2).Association("Friends").Delete(&User{}); err != nil { t.Fatalf("Error happened when delete friend, got %v", err) } AssertAssociationCount(t, user2, "Friends", 1, "after delete non-existing data") if err := DB.Model(&user2).Association("Friends").Delete(&friend2); err != nil { t.Fatalf("Error happened when delete Friends, got %v", err) } AssertAssociationCount(t, user2, "Friends", 0, "after delete") // Prepare Data for Clear if err := DB.Model(&user2).Association("Friends").Append(&friend); err != nil { t.Fatalf("Error happened when append Friends, got %v", err) } AssertAssociationCount(t, user2, "Friends", 1, "after prepare data") // Clear if err := DB.Model(&user2).Association("Friends").Clear(); err != nil { t.Errorf("Error happened when clear Friends, got %v", err) } AssertAssociationCount(t, user2, "Friends", 0, "after clear") } func TestSingleTableMany2ManyAssociationForSlice(t *testing.T) { users := []User{ *GetUser("slice-many2many-1", Config{Team: 2}), *GetUser("slice-many2many-2", Config{Team: 0}), *GetUser("slice-many2many-3", Config{Team: 4}), } DB.Create(&users) // Count AssertAssociationCount(t, users, "Team", 6, "") // Find var teams []User if DB.Model(&users).Association("Team").Find(&teams); len(teams) != 6 { t.Errorf("teams count should be %v, but got %v", 6, len(teams)) } // Append teams1 := []User{*GetUser("friend-append-1", Config{})} teams2 := []User{} teams3 := []*User{GetUser("friend-append-3-1", Config{}), GetUser("friend-append-3-2", Config{})} 
DB.Model(&users).Association("Team").Append(&teams1, &teams2, &teams3) AssertAssociationCount(t, users, "Team", 9, "After Append") teams2_1 := []User{*GetUser("friend-replace-1", Config{}), *GetUser("friend-replace-2", Config{})} teams2_2 := []User{*GetUser("friend-replace-2-1", Config{}), *GetUser("friend-replace-2-2", Config{})} teams2_3 := GetUser("friend-replace-3-1", Config{}) // Replace DB.Model(&users).Association("Team").Replace(&teams2_1, &teams2_2, teams2_3) AssertAssociationCount(t, users, "Team", 5, "After Replace") // Delete if err := DB.Model(&users).Association("Team").Delete(&users[2].Team); err != nil { t.Errorf("no error should happened when deleting team, but got %v", err) } AssertAssociationCount(t, users, "Team", 4, "after delete") if err := DB.Model(&users).Association("Team").Delete(users[0].Team[0], users[1].Team[1]); err != nil { t.Errorf("no error should happened when deleting team, but got %v", err) } AssertAssociationCount(t, users, "Team", 2, "after delete") // Clear DB.Model(&users).Association("Team").Clear() AssertAssociationCount(t, users, "Team", 0, "After Clear") } func TestDuplicateMany2ManyAssociation(t *testing.T) { user1 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ {Code: "TestDuplicateMany2ManyAssociation-language-1"}, {Code: "TestDuplicateMany2ManyAssociation-language-2"}, }} user2 := User{Name: "TestDuplicateMany2ManyAssociation-1", Languages: []Language{ {Code: "TestDuplicateMany2ManyAssociation-language-1"}, {Code: "TestDuplicateMany2ManyAssociation-language-3"}, }} users := []*User{&user1, &user2} var err error err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error AssertEqual(t, nil, err) var findUser1 User err = DB.Preload("Languages").Where("id = ?", user1.ID).First(&findUser1).Error AssertEqual(t, nil, err) AssertEqual(t, user1, findUser1) var findUser2 User err = DB.Preload("Languages").Where("id = ?", user2.ID).First(&findUser2).Error AssertEqual(t, nil, err) 
AssertEqual(t, user2, findUser2) } func TestConcurrentMany2ManyAssociation(t *testing.T) { db, err := OpenTestConnection(&gorm.Config{}) if err != nil { t.Fatalf("open test connection failed, err: %+v", err) } count := 3 var languages []Language for i := 0; i < count; i++ { language := Language{Code: fmt.Sprintf("consurrent %d", i)} db.Create(&language) languages = append(languages, language) } user := User{} db.Create(&user) db.Preload("Languages").FirstOrCreate(&user) var wg sync.WaitGroup for i := 0; i < count; i++ { wg.Add(1) go func(user User, language Language) { err := db.Model(&user).Association("Languages").Append(&language) AssertEqual(t, err, nil) wg.Done() }(user, languages[i]) } wg.Wait() var find User err = db.Preload(clause.Associations).Where("id = ?", user.ID).First(&find).Error AssertEqual(t, err, nil) AssertAssociationCount(t, find, "Languages", int64(count), "after concurrent append") } func TestMany2ManyDuplicateBelongsToAssociation(t *testing.T) { user1 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-1", Friends: []*User{ {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-1", Company: Company{ ID: 1, Name: "Test-company-1", }}, }} user2 := User{Name: "TestMany2ManyDuplicateBelongsToAssociation-2", Friends: []*User{ {Name: "TestMany2ManyDuplicateBelongsToAssociation-friend-2", Company: Company{ ID: 1, Name: "Test-company-1", }}, }} users := []*User{&user1, &user2} var err error err = DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(users).Error AssertEqual(t, nil, err) var findUser1 User err = DB.Preload("Friends.Company").Where("id = ?", user1.ID).First(&findUser1).Error AssertEqual(t, nil, err) AssertEqual(t, user1, findUser1) var findUser2 User err = DB.Preload("Friends.Company").Where("id = ?", user2.ID).First(&findUser2).Error AssertEqual(t, nil, err) AssertEqual(t, user2, findUser2) } ================================================ FILE: tests/associations_test.go 
================================================ package tests_test import ( "testing" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" . "gorm.io/gorm/utils/tests" ) func AssertAssociationCount(t *testing.T, data interface{}, name string, result int64, reason string) { if count := DB.Model(data).Association(name).Count(); count != result { t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) } var newUser User if user, ok := data.(User); ok { DB.Find(&newUser, "id = ?", user.ID) } else if user, ok := data.(*User); ok { DB.Find(&newUser, "id = ?", user.ID) } if newUser.ID != 0 { if count := DB.Model(&newUser).Association(name).Count(); count != result { t.Fatalf("invalid %v count %v, expects: %v got %v", name, reason, result, count) } } } func TestInvalidAssociation(t *testing.T) { user := *GetUser("invalid", Config{Company: true, Manager: true}) if err := DB.Model(&user).Association("Invalid").Find(&user.Company).Error; err == nil { t.Fatalf("should return errors for invalid association, but got nil") } } func TestAssociationNotNullClear(t *testing.T) { type Profile struct { gorm.Model Number string MemberID uint `gorm:"not null"` } type Member struct { gorm.Model Profiles []Profile } DB.Migrator().DropTable(&Member{}, &Profile{}) if err := DB.AutoMigrate(&Member{}, &Profile{}); err != nil { t.Fatalf("Failed to migrate, got error: %v", err) } member := &Member{ Profiles: []Profile{{ Number: "1", }, { Number: "2", }}, } if err := DB.Create(&member).Error; err != nil { t.Fatalf("Failed to create test data, got error: %v", err) } if err := DB.Model(member).Association("Profiles").Clear(); err == nil { t.Fatalf("No error occurred during clearind not null association") } } func TestForeignKeyConstraints(t *testing.T) { tidbSkip(t, "not support the foreign key feature") type Profile struct { ID uint Name string MemberID uint } type Member struct { ID uint Refer uint `gorm:"uniqueIndex"` Name string Profile Profile 
`gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:MemberID;References:Refer"` } DB.Migrator().DropTable(&Profile{}, &Member{}) if err := DB.AutoMigrate(&Profile{}, &Member{}); err != nil { t.Fatalf("Failed to migrate, got error: %v", err) } member := Member{Refer: 1, Name: "foreign_key_constraints", Profile: Profile{Name: "my_profile"}} DB.Create(&member) var profile Profile if err := DB.First(&profile, "id = ?", member.Profile.ID).Error; err != nil { t.Fatalf("failed to find profile, got error: %v", err) } else if profile.MemberID != member.ID { t.Fatalf("member id is not equal: expects: %v, got: %v", member.ID, profile.MemberID) } member.Profile = Profile{} DB.Model(&member).Update("Refer", 100) var profile2 Profile if err := DB.First(&profile2, "id = ?", profile.ID).Error; err != nil { t.Fatalf("failed to find profile, got error: %v", err) } else if profile2.MemberID != 100 { t.Fatalf("member id is not equal: expects: %v, got: %v", 100, profile2.MemberID) } if r := DB.Delete(&member); r.Error != nil || r.RowsAffected != 1 { t.Fatalf("Should delete member, got error: %v, affected: %v", r.Error, r.RowsAffected) } var result Member if err := DB.First(&result, member.ID).Error; err == nil { t.Fatalf("Should not find deleted member") } if err := DB.First(&profile2, profile.ID).Error; err == nil { t.Fatalf("Should not find deleted profile") } } func TestForeignKeyConstraintsBelongsTo(t *testing.T) { tidbSkip(t, "not support the foreign key feature") type Profile struct { ID uint Name string Refer uint `gorm:"uniqueIndex"` } type Member struct { ID uint Name string ProfileID uint Profile Profile `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:ProfileID;References:Refer"` } DB.Migrator().DropTable(&Profile{}, &Member{}) if err := DB.AutoMigrate(&Profile{}, &Member{}); err != nil { t.Fatalf("Failed to migrate, got error: %v", err) } member := Member{Name: "foreign_key_constraints_belongs_to", Profile: Profile{Name: "my_profile_belongs_to", 
Refer: 1}} DB.Create(&member) var profile Profile if err := DB.First(&profile, "id = ?", member.Profile.ID).Error; err != nil { t.Fatalf("failed to find profile, got error: %v", err) } else if profile.Refer != member.ProfileID { t.Fatalf("member id is not equal: expects: %v, got: %v", profile.Refer, member.ProfileID) } DB.Model(&profile).Update("Refer", 100) var member2 Member if err := DB.First(&member2, "id = ?", member.ID).Error; err != nil { t.Fatalf("failed to find member, got error: %v", err) } else if member2.ProfileID != 100 { t.Fatalf("member id is not equal: expects: %v, got: %v", 100, member2.ProfileID) } if r := DB.Delete(&profile); r.Error != nil || r.RowsAffected != 1 { t.Fatalf("Should delete member, got error: %v, affected: %v", r.Error, r.RowsAffected) } var result Member if err := DB.First(&result, member.ID).Error; err == nil { t.Fatalf("Should not find deleted member") } if err := DB.First(&profile, profile.ID).Error; err == nil { t.Fatalf("Should not find deleted profile") } } func TestFullSaveAssociations(t *testing.T) { coupon := &Coupon{ AppliesToProduct: []*CouponProduct{ {ProductId: "full-save-association-product1"}, }, AmountOff: 10, PercentOff: 0.0, } err := DB. Session(&gorm.Session{FullSaveAssociations: true}). Create(coupon).Error if err != nil { t.Errorf("Failed, got error: %v", err) } if DB.First(&Coupon{}, "id = ?", coupon.ID).Error != nil { t.Errorf("Failed to query saved coupon") } if DB.First(&CouponProduct{}, "coupon_id = ? 
AND product_id = ?", coupon.ID, "full-save-association-product1").Error != nil { t.Errorf("Failed to query saved association") } orders := []Order{{Num: "order1", Coupon: coupon}, {Num: "order2", Coupon: coupon}} if err := DB.Create(&orders).Error; err != nil { t.Errorf("failed to create orders, got %v", err) } coupon2 := Coupon{ AppliesToProduct: []*CouponProduct{{Desc: "coupon-description"}}, } DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&coupon2) var result Coupon if err := DB.Preload("AppliesToProduct").First(&result, "id = ?", coupon2.ID).Error; err != nil { t.Errorf("Failed to create coupon w/o name, got error: %v", err) } if len(result.AppliesToProduct) != 1 { t.Errorf("Failed to preload AppliesToProduct") } } func TestSaveBelongsCircularReference(t *testing.T) { parent := Parent{} DB.Create(&parent) child := Child{ParentID: &parent.ID, Parent: &parent} DB.Create(&child) parent.FavChildID = child.ID parent.FavChild = &child DB.Save(&parent) var parent1 Parent DB.First(&parent1, parent.ID) AssertObjEqual(t, parent, parent1, "ID", "FavChildID") // Save and Updates is the same DB.Updates(&parent) DB.First(&parent1, parent.ID) AssertObjEqual(t, parent, parent1, "ID", "FavChildID") } func TestSaveHasManyCircularReference(t *testing.T) { parent := Parent{} DB.Create(&parent) child := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference"} child1 := Child{ParentID: &parent.ID, Parent: &parent, Name: "HasManyCircularReference1"} parent.Children = []*Child{&child, &child1} DB.Save(&parent) var children []*Child DB.Where("parent_id = ?", parent.ID).Find(&children) if len(children) != len(parent.Children) || children[0].ID != parent.Children[0].ID || children[1].ID != parent.Children[1].ID { t.Errorf("circular reference children save not equal children:%v parent.Children:%v", children, parent.Children) } } func TestAssociationError(t *testing.T) { user := *GetUser("TestAssociationError", Config{Pets: 2, Company: true, Account: 
true, Languages: 2}) DB.Create(&user) var user1 User DB.Preload("Company").Preload("Pets").Preload("Account").Preload("Languages").First(&user1) var emptyUser User var err error // belongs to err = DB.Model(&emptyUser).Association("Company").Delete(&user1.Company) AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) // has many err = DB.Model(&emptyUser).Association("Pets").Delete(&user1.Pets) AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) // has one err = DB.Model(&emptyUser).Association("Account").Delete(&user1.Account) AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) // many to many err = DB.Model(&emptyUser).Association("Languages").Delete(&user1.Languages) AssertEqual(t, err, gorm.ErrPrimaryKeyRequired) } type ( myType string emptyQueryClause struct { Field *schema.Field } ) func (myType) QueryClauses(f *schema.Field) []clause.Interface { return []clause.Interface{emptyQueryClause{Field: f}} } func (sd emptyQueryClause) Name() string { return "empty" } func (sd emptyQueryClause) Build(clause.Builder) { } func (sd emptyQueryClause) MergeClause(*clause.Clause) { } func (sd emptyQueryClause) ModifyStatement(stmt *gorm.Statement) { // do nothing } func TestAssociationEmptyQueryClause(t *testing.T) { type Organization struct { gorm.Model Name string } type Region struct { gorm.Model Name string Organizations []Organization `gorm:"many2many:region_orgs;"` } type RegionOrg struct { RegionId uint OrganizationId uint Empty myType } if err := DB.SetupJoinTable(&Region{}, "Organizations", &RegionOrg{}); err != nil { t.Fatalf("Failed to set up join table, got error: %s", err) } if err := DB.Migrator().DropTable(&Organization{}, &Region{}); err != nil { t.Fatalf("Failed to migrate, got error: %s", err) } if err := DB.AutoMigrate(&Organization{}, &Region{}); err != nil { t.Fatalf("Failed to migrate, got error: %v", err) } region := &Region{Name: "Region1"} if err := DB.Create(region).Error; err != nil { t.Fatalf("fail to create region %v", err) } var orgs []Organization if err 
:= DB.Model(&Region{}).Association("Organizations").Find(&orgs); err != nil { t.Fatalf("fail to find region organizations %v", err) } else { AssertEqual(t, len(orgs), 0) } } type AssociationEmptyUser struct { ID uint Name string Pets []AssociationEmptyPet } type AssociationEmptyPet struct { AssociationEmptyUserID *uint `gorm:"uniqueIndex:uniq_user_id_name"` Name string `gorm:"uniqueIndex:uniq_user_id_name;size:256"` } func TestAssociationEmptyPrimaryKey(t *testing.T) { if DB.Dialector.Name() != "mysql" { t.Skip() } DB.Migrator().DropTable(&AssociationEmptyUser{}, &AssociationEmptyPet{}) DB.AutoMigrate(&AssociationEmptyUser{}, &AssociationEmptyPet{}) id := uint(100) user := AssociationEmptyUser{ ID: id, Name: "jinzhu", Pets: []AssociationEmptyPet{ {AssociationEmptyUserID: &id, Name: "bar"}, {AssociationEmptyUserID: &id, Name: "foo"}, }, } err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Create(&user).Error if err != nil { t.Fatalf("Failed to create, got error: %v", err) } var result AssociationEmptyUser err = DB.Preload("Pets").First(&result, &id).Error if err != nil { t.Fatalf("Failed to find, got error: %v", err) } AssertEqual(t, result, user) } // Ensure Association.Append/Replace supports map for many2many func TestAssociationMany2ManyAppendMap(t *testing.T) { user := *GetUser("assoc_m2m_append_map", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("create user: %v", err) } // Append single map if err := DB.Model(&user).Association("Languages").Append(map[string]interface{}{ "code": "am2m_map_1", "name": "AppendMap1", }); err != nil { t.Fatalf("append map: %v", err) } AssertAssociationCount(t, user, "Languages", 1, "after append 1 map") // Append more maps individually if err := DB.Model(&user).Association("Languages").Append(map[string]interface{}{"code": "am2m_map_2", "name": "AppendMap2"}); err != nil { t.Fatalf("append map 2: %v", err) } if err := DB.Model(&user).Association("Languages").Append(map[string]interface{}{"code": 
"am2m_map_3", "name": "AppendMap3"}); err != nil { t.Fatalf("append map 3: %v", err) } AssertAssociationCount(t, user, "Languages", 3, "after append 3 maps total") // Verify codes exist var langs []Language if err := DB.Model(&user).Association("Languages").Find(&langs); err != nil { t.Fatalf("find languages: %v", err) } codeSet := map[string]bool{} for _, l := range langs { codeSet[l.Code] = true } for _, c := range []string{"am2m_map_1", "am2m_map_2", "am2m_map_3"} { if !codeSet[c] { t.Fatalf("expected language code %s present", c) } } } func TestAssociationMany2ManyReplaceMap(t *testing.T) { user := *GetUser("assoc_m2m_replace_map", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("create user: %v", err) } // Prime with one language if err := DB.Model(&user).Association("Languages").Append(&Language{Code: "prime", Name: "Prime"}); err != nil { t.Fatalf("prime append: %v", err) } AssertAssociationCount(t, user, "Languages", 1, "before replace") // Replace with a new map value if err := DB.Model(&user).Association("Languages").Replace(map[string]interface{}{ "code": "rm2m_map_1", "name": "ReplaceMap1", }); err != nil { t.Fatalf("replace map: %v", err) } AssertAssociationCount(t, user, "Languages", 1, "after replace with 1 map") var langs []Language if err := DB.Model(&user).Association("Languages").Find(&langs); err != nil { t.Fatalf("find languages after replace: %v", err) } if len(langs) != 1 || langs[0].Code != "rm2m_map_1" { t.Fatalf("expected only rm2m_map_1 after replace, got %+v", langs) } } ================================================ FILE: tests/benchmark_test.go ================================================ package tests_test import ( "fmt" "testing" . 
"gorm.io/gorm/utils/tests" ) func BenchmarkCreate(b *testing.B) { user := *GetUser("bench", Config{}) for x := 0; x < b.N; x++ { user.ID = 0 DB.Create(&user) } } func BenchmarkFind(b *testing.B) { user := *GetUser("find", Config{}) DB.Create(&user) for x := 0; x < b.N; x++ { DB.Find(&User{}, "id = ?", user.ID) } } func BenchmarkScan(b *testing.B) { user := *GetUser("scan", Config{}) DB.Create(&user) var u User b.ResetTimer() for x := 0; x < b.N; x++ { DB.Raw("select * from users where id = ?", user.ID).Scan(&u) } } func BenchmarkScanSlice(b *testing.B) { DB.Exec("delete from users") for i := 0; i < 10_000; i++ { user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) DB.Create(&user) } var u []User b.ResetTimer() for x := 0; x < b.N; x++ { DB.Raw("select * from users").Scan(&u) } } func BenchmarkScanSlicePointer(b *testing.B) { DB.Exec("delete from users") for i := 0; i < 10_000; i++ { user := *GetUser(fmt.Sprintf("scan-%d", i), Config{}) DB.Create(&user) } var u []*User b.ResetTimer() for x := 0; x < b.N; x++ { DB.Raw("select * from users").Scan(&u) } } func BenchmarkUpdate(b *testing.B) { user := *GetUser("find", Config{}) DB.Create(&user) for x := 0; x < b.N; x++ { DB.Model(&user).Updates(map[string]interface{}{"Age": x}) } } func BenchmarkDelete(b *testing.B) { user := *GetUser("find", Config{}) for x := 0; x < b.N; x++ { user.ID = 0 DB.Create(&user) DB.Delete(&user) } } ================================================ FILE: tests/callbacks_test.go ================================================ package tests_test import ( "fmt" "reflect" "runtime" "strings" "testing" "gorm.io/gorm" ) func assertCallbacks(v interface{}, fnames []string) (result bool, msg string) { var ( got []string funcs = reflect.ValueOf(v).Elem().FieldByName("fns") ) for i := 0; i < funcs.Len(); i++ { got = append(got, getFuncName(funcs.Index(i))) } return fmt.Sprint(got) == fmt.Sprint(fnames), fmt.Sprintf("expects %v, got %v", fnames, got) } func getFuncName(fc interface{}) string { 
reflectValue, ok := fc.(reflect.Value) if !ok { reflectValue = reflect.ValueOf(fc) } fnames := strings.Split(runtime.FuncForPC(reflectValue.Pointer()).Name(), ".") return fnames[len(fnames)-1] } func c1(*gorm.DB) {} func c2(*gorm.DB) {} func c3(*gorm.DB) {} func c4(*gorm.DB) {} func c5(*gorm.DB) {} func c6(*gorm.DB) {} func TestCallbacks(t *testing.T) { type callback struct { name string before string after string remove bool replace bool err string match func(*gorm.DB) bool h func(*gorm.DB) } datas := []struct { callbacks []callback err string results []string }{ { callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5}}, results: []string{"c1", "c2", "c3", "c4", "c5"}, }, { callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4}, {h: c5, before: "c4"}}, results: []string{"c1", "c2", "c3", "c5", "c4"}, }, { callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5}}, results: []string{"c1", "c2", "c3", "c5", "c4"}, }, { callbacks: []callback{{h: c1}, {h: c2}, {h: c3}, {h: c4, after: "c5"}, {h: c5, before: "c4"}}, results: []string{"c1", "c2", "c3", "c5", "c4"}, }, { callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}}, results: []string{"c1", "c5", "c2", "c3", "c4"}, }, { callbacks: []callback{{h: c1, after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, results: []string{"c3", "c1", "c5", "c2", "c4"}, }, { callbacks: []callback{{h: c1, before: "c4", after: "c3"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, results: []string{"c3", "c1", "c5", "c2", "c4"}, }, { callbacks: []callback{{h: c1, before: "c3", after: "c4"}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c5"}, {h: c4}, {h: c5}}, err: "conflicting", }, { callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5}, {h: c2, remove: true}}, results: []string{"c1", "c3", "c4", "c5"}, }, { callbacks: []callback{{h: c1}, {name: "c", 
h: c2}, {h: c3}, {name: "c", h: c4, replace: true}}, results: []string{"c1", "c4", "c3"}, }, { callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3}, {h: c4}, {h: c5, before: "*"}}, results: []string{"c5", "c1", "c2", "c3", "c4"}, }, { callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "*"}, {h: c4}, {h: c5, before: "*"}}, results: []string{"c3", "c5", "c1", "c2", "c4"}, }, { callbacks: []callback{{h: c1}, {h: c2, before: "c4", after: "c5"}, {h: c3, before: "c4", after: "*"}, {h: c4, after: "*"}, {h: c5, before: "*"}}, results: []string{"c5", "c1", "c2", "c3", "c4"}, }, } for idx, data := range datas { db, err := gorm.Open(nil, nil) if err != nil { t.Fatal(err) } callbacks := db.Callback() for _, c := range data.callbacks { var v interface{} = callbacks.Create() callMethod := func(s interface{}, name string, args ...interface{}) { var argValues []reflect.Value for _, arg := range args { argValues = append(argValues, reflect.ValueOf(arg)) } results := reflect.ValueOf(s).MethodByName(name).Call(argValues) if len(results) > 0 { v = results[0].Interface() } } if c.name == "" { c.name = getFuncName(c.h) } if c.before != "" { callMethod(v, "Before", c.before) } if c.after != "" { callMethod(v, "After", c.after) } if c.match != nil { callMethod(v, "Match", c.match) } if c.remove { callMethod(v, "Remove", c.name) } else if c.replace { callMethod(v, "Replace", c.name, c.h) } else { callMethod(v, "Register", c.name, c.h) } if e, ok := v.(error); !ok || e != nil { err = e } } if len(data.err) > 0 && err == nil { t.Errorf("callbacks tests #%v should got error %v, but not", idx+1, data.err) } else if len(data.err) == 0 && err != nil { t.Errorf("callbacks tests #%v should not got error, but got %v", idx+1, err) } if ok, msg := assertCallbacks(callbacks.Create(), data.results); !ok { t.Errorf("callbacks tests #%v failed, got %v", idx+1, msg) } } } func TestPluginCallbacks(t *testing.T) { db, _ := gorm.Open(nil, nil) 
createCallback := db.Callback().Create() createCallback.Before("*").Register("plugin_1_fn1", c1) createCallback.After("*").Register("plugin_1_fn2", c2) if ok, msg := assertCallbacks(createCallback, []string{"c1", "c2"}); !ok { t.Errorf("callbacks tests failed, got %v", msg) } // plugin 2 createCallback.Before("*").Register("plugin_2_fn1", c3) if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2"}); !ok { t.Errorf("callbacks tests failed, got %v", msg) } createCallback.After("*").Register("plugin_2_fn2", c4) if ok, msg := assertCallbacks(createCallback, []string{"c3", "c1", "c2", "c4"}); !ok { t.Errorf("callbacks tests failed, got %v", msg) } // plugin 3 createCallback.Before("*").Register("plugin_3_fn1", c5) if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4"}); !ok { t.Errorf("callbacks tests failed, got %v", msg) } createCallback.After("*").Register("plugin_3_fn2", c6) if ok, msg := assertCallbacks(createCallback, []string{"c5", "c3", "c1", "c2", "c4", "c6"}); !ok { t.Errorf("callbacks tests failed, got %v", msg) } } func TestCallbacksGet(t *testing.T) { db, _ := gorm.Open(nil, nil) createCallback := db.Callback().Create() createCallback.Before("*").Register("c1", c1) if cb := createCallback.Get("c1"); reflect.DeepEqual(cb, c1) { t.Errorf("callbacks tests failed, got: %p, want: %p", cb, c1) } createCallback.Remove("c1") if cb := createCallback.Get("c2"); cb != nil { t.Errorf("callbacks test failed. 
got: %p, want: nil", cb) } } func TestCallbacksRemove(t *testing.T) { db, _ := gorm.Open(nil, nil) createCallback := db.Callback().Create() createCallback.Before("*").Register("c1", c1) createCallback.After("*").Register("c2", c2) createCallback.Before("c4").Register("c3", c3) createCallback.After("c2").Register("c4", c4) // callbacks: []string{"c1", "c3", "c4", "c2"} createCallback.Remove("c1") if ok, msg := assertCallbacks(createCallback, []string{"c3", "c4", "c2"}); !ok { t.Errorf("callbacks tests failed, got %v", msg) } createCallback.Remove("c4") if ok, msg := assertCallbacks(createCallback, []string{"c3", "c2"}); !ok { t.Errorf("callbacks tests failed, got %v", msg) } createCallback.Remove("c2") if ok, msg := assertCallbacks(createCallback, []string{"c3"}); !ok { t.Errorf("callbacks tests failed, got %v", msg) } createCallback.Remove("c3") if ok, msg := assertCallbacks(createCallback, []string{}); !ok { t.Errorf("callbacks tests failed, got %v", msg) } } ================================================ FILE: tests/chainable_api_test.go ================================================ package tests import ( "context" "testing" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/schema" ) // testDialector is a minimal Dialector implementation used only for unit tests in-memory. 
type testDialector struct{} func (d testDialector) Name() string { return "test" } func (d testDialector) Initialize(*gorm.DB) error { return nil } func (d testDialector) Migrator(db *gorm.DB) gorm.Migrator { return nil } func (d testDialector) DataTypeOf(*schema.Field) string { return "" } func (d testDialector) DefaultValueOf(*schema.Field) clause.Expression { return clause.Expr{} } func (d testDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { // write a simple placeholder writer.WriteByte('?') } func (d testDialector) QuoteTo(writer clause.Writer, s string) { writer.WriteString(s) } func (d testDialector) Explain(sql string, vars ...interface{}) string { return sql } // newTestDB returns a minimal *DB with an initialized Statement suitable for unit tests func newTestDB() *gorm.DB { d := testDialector{} cfg := &gorm.Config{Dialector: d} db := &gorm.DB{Config: cfg} stmt := &gorm.Statement{ DB: db, Clauses: map[string]clause.Clause{}, Preloads: map[string][]interface{}{}, Context: context.Background(), Vars: make([]interface{}, 0), } db.Statement = stmt return db } func TestChainableAPI(t *testing.T) { db := newTestDB() // Model m := &struct{ ID int }{} tx := db.Model(m) if tx.Statement.Model != m { t.Fatalf("Model not set, got %v", tx.Statement.Model) } // Table tx = tx.Table("users") if tx.Statement.Table != "users" { t.Fatalf("Table not set, got %v", tx.Statement.Table) } if tx.Statement.TableExpr == nil { t.Fatalf("TableExpr expected to be set") } // Distinct + Select tx = tx.Distinct("name", "age") if !tx.Statement.Distinct { t.Fatalf("Distinct expected true") } if len(tx.Statement.Selects) != 2 || tx.Statement.Selects[0] != "name" { t.Fatalf("Selects expected [name age], got %v", tx.Statement.Selects) } // Where tx = tx.Where("age = ?", 20) c, ok := tx.Statement.Clauses["WHERE"] if !ok { t.Fatalf("WHERE clause expected") } if where, ok := c.Expression.(clause.Where); !ok || len(where.Exprs) == 0 { t.Fatalf("WHERE expressions 
expected, got %v", c.Expression) } // Order tx = tx.Order("name DESC") if _, ok := tx.Statement.Clauses["ORDER BY"]; !ok { t.Fatalf("ORDER BY clause expected") } // Limit / Offset tx = tx.Limit(10).Offset(5) if cl, ok := tx.Statement.Clauses["LIMIT"]; !ok { t.Fatalf("LIMIT clause expected") } else { if limit, ok := cl.Expression.(clause.Limit); !ok || limit.Limit == nil || *limit.Limit != 10 || limit.Offset != 5 { t.Fatalf("LIMIT/Offset values unexpected: %v", cl.Expression) } } // Joins tx = tx.Joins("JOIN accounts ON accounts.user_id = users.id") if len(tx.Statement.Joins) == 0 { t.Fatalf("Joins expected") } if tx.Statement.Joins[0].Name != "JOIN accounts ON accounts.user_id = users.id" { t.Fatalf("Join name mismatch: %v", tx.Statement.Joins[0].Name) } // Preload tx = tx.Preload("Orders", "state != ?", "cancelled") args, ok := tx.Statement.Preloads["Orders"] if !ok || len(args) != 2 { t.Fatalf("Preload expected with args, got %v", tx.Statement.Preloads) } // Scopes: just ensure calling Scopes doesn't panic and returns a DB tx = tx.Scopes(func(d *gorm.DB) *gorm.DB { return d.Where("status = ?", "ok") }) if tx == nil { t.Fatalf("Scopes returned nil") } // Unscoped tx = tx.Unscoped() if !tx.Statement.Unscoped { t.Fatalf("Unscoped expected to be true") } // Raw tx = tx.Raw("SELECT ? 
as x", 1) if tx.Statement.SQL.Len() == 0 { t.Fatalf("Raw SQL expected to be built") } if len(tx.Statement.Vars) != 1 || tx.Statement.Vars[0] != 1 { t.Fatalf("Raw Vars expected to contain 1, got %v", tx.Statement.Vars) } } ================================================ FILE: tests/compose.yml ================================================ services: mysql: image: 'mysql:latest' ports: - "127.0.0.1:9910:3306" environment: - MYSQL_DATABASE=gorm - MYSQL_USER=gorm - MYSQL_PASSWORD=gorm - MYSQL_RANDOM_ROOT_PASSWORD="yes" postgres: image: 'postgres:latest' ports: - "127.0.0.1:9920:5432" environment: - TZ=Asia/Shanghai - POSTGRES_DB=gorm - POSTGRES_USER=gorm - POSTGRES_PASSWORD=gorm mssql: image: '${MSSQL_IMAGE}:latest' ports: - "127.0.0.1:9930:1433" environment: - TZ=Asia/Shanghai - ACCEPT_EULA=Y - MSSQL_SA_PASSWORD=LoremIpsum86 tidb: image: 'pingcap/tidb:v6.5.0' ports: - "127.0.0.1:9940:4000" command: /tidb-server -store unistore -path "" -lease 0s > tidb.log 2>&1 & gaussdb: image: 'opengauss/opengauss:7.0.0-RC1.B023' hostname: opengauss-server ports: - "127.0.0.1:9950:5432" environment: - TZ=Asia/Shanghai - GS_PASSWORD=Gaussdb@123 - GS_CLUSTER_NAME=opengauss_cluster - PGDATA=/var/lib/opengauss/data entrypoint: "" command: |- /bin/sh -c 'set -euo pipefail; /usr/local/bin/entrypoint.sh gaussdb & counter=1; while [ "$$counter" -le 20 ]; do if su - omm -c "gsql -U omm -d postgres -c \"SELECT 1;\""; then echo "Creating database gorm..."; su - omm -c "gsql -U omm -d postgres -c \"CREATE DATABASE gorm DBCOMPATIBILITY '\'PG\'';\""; echo "Database initialized successfully"; break; fi; echo "Waiting for database to be ready... 
($$counter/12)"; sleep 5; counter=$$(($$counter + 1)); done; # timeout handling if [ $$counter -gt 20 ]; then echo "Error: Database failed to start within timeout"; exit 1; fi; # keep the container running: wait for the database process in the foreground wait ' ================================================ FILE: tests/connection_test.go ================================================ package tests_test import ( "testing" "gorm.io/driver/mysql" "gorm.io/gorm" ) func TestWithSingleConnection(t *testing.T) { expectedName := "test" var actualName string setSQL, getSQL := getSetSQL(DB.Dialector.Name()) if len(setSQL) == 0 || len(getSQL) == 0 { return } err := DB.Connection(func(tx *gorm.DB) error { if err := tx.Exec(setSQL, expectedName).Error; err != nil { return err } if err := tx.Raw(getSQL).Scan(&actualName).Error; err != nil { return err } return nil }) if err != nil { t.Errorf("WithSingleConnection should work, but got err %v", err) } if actualName != expectedName { t.Errorf("WithSingleConnection() method should get correct value, expect: %v, got %v", expectedName, actualName) } } func getSetSQL(driverName string) (string, string) { switch driverName { case mysql.Dialector{}.Name(): return "SET @testName := ?", "SELECT @testName" default: return "", "" } } ================================================ FILE: tests/connpool_test.go ================================================ package tests_test import ( "context" "database/sql" "os" "reflect" "testing" "gorm.io/driver/mysql" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) type wrapperTx struct { *sql.Tx conn *wrapperConnPool } func (c *wrapperTx) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { c.conn.got = append(c.conn.got, query) return c.Tx.PrepareContext(ctx, query) } func (c *wrapperTx) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { c.conn.got = append(c.conn.got, query) return c.Tx.ExecContext(ctx, query, args...) 
} func (c *wrapperTx) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { c.conn.got = append(c.conn.got, query) return c.Tx.QueryContext(ctx, query, args...) } func (c *wrapperTx) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { c.conn.got = append(c.conn.got, query) return c.Tx.QueryRowContext(ctx, query, args...) } type wrapperConnPool struct { db *sql.DB got []string expect []string } func (c *wrapperConnPool) Ping() error { return c.db.Ping() } // If you use BeginTx returned *sql.Tx as shown below then you can't record queries in a transaction. // // func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { // return c.db.BeginTx(ctx, opts) // } // // You should use BeginTx returned gorm.Tx which could wrap *sql.Tx then you can record all queries. func (c *wrapperConnPool) BeginTx(ctx context.Context, opts *sql.TxOptions) (gorm.ConnPool, error) { tx, err := c.db.BeginTx(ctx, opts) if err != nil { return nil, err } return &wrapperTx{Tx: tx, conn: c}, nil } func (c *wrapperConnPool) PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) { c.got = append(c.got, query) return c.db.PrepareContext(ctx, query) } func (c *wrapperConnPool) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { c.got = append(c.got, query) return c.db.ExecContext(ctx, query, args...) } func (c *wrapperConnPool) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) { c.got = append(c.got, query) return c.db.QueryContext(ctx, query, args...) } func (c *wrapperConnPool) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { c.got = append(c.got, query) return c.db.QueryRowContext(ctx, query, args...) 
} func TestConnPoolWrapper(t *testing.T) { dialect := os.Getenv("GORM_DIALECT") if dialect != "mysql" { t.SkipNow() } dbDSN := os.Getenv("GORM_DSN") if dbDSN == "" { dbDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" } nativeDB, err := sql.Open("mysql", dbDSN) if err != nil { t.Fatalf("Should open db success, but got %v", err) } conn := &wrapperConnPool{ db: nativeDB, expect: []string{ "SELECT VERSION()", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", "INSERT INTO `users` (`created_at`,`updated_at`,`deleted_at`,`name`,`age`,`birthday`,`company_id`,`manager_id`,`active`) VALUES (?,?,?,?,?,?,?,?,?)", "SELECT * FROM `users` WHERE name = ? AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", "SELECT * FROM `users` WHERE name = ? 
AND `users`.`deleted_at` IS NULL ORDER BY `users`.`id` LIMIT ?", }, } defer func() { if !reflect.DeepEqual(conn.got, conn.expect) { t.Errorf("expect %#v but got %#v", conn.expect, conn.got) } }() db, err := gorm.Open(mysql.New(mysql.Config{Conn: conn, DisableWithReturning: true})) db.Logger = DB.Logger if err != nil { t.Fatalf("Should open db success, but got %v", err) } tx := db.Begin() user := *GetUser("transaction", Config{}) if err = tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise, but got %v", err) } if err = tx.First(&User{}, "name = ?", "transaction").Error; err != nil { t.Fatalf("Should find saved record, but got %v", err) } user1 := *GetUser("transaction1-1", Config{}) if err = tx.Save(&user1).Error; err != nil { t.Fatalf("No error should raise, but got %v", err) } if err = tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { t.Fatalf("Should find saved record, but got %v", err) } if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { t.Fatalf("Should return the underlying sql.Tx") } tx.Rollback() if err = db.First(&User{}, "name = ?", "transaction").Error; err == nil { t.Fatalf("Should not find record after rollback, but got %v", err) } txDB := db.Where("fake_name = ?", "fake_name") tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() user2 := *GetUser("transaction-2", Config{}) if err = tx2.Save(&user2).Error; err != nil { t.Fatalf("No error should raise, but got %v", err) } if err = tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { t.Fatalf("Should find saved record, but got %v", err) } tx2.Commit() if err = db.First(&User{}, "name = ?", "transaction-2").Error; err != nil { t.Fatalf("Should be able to find committed record, but got %v", err) } } ================================================ FILE: tests/count_test.go ================================================ package tests_test import ( "regexp" "sort" "strings" "testing" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) func TestCountWithGroup(t *testing.T) { DB.Create([]Company{ {Name: "company_count_group_a"}, {Name: "company_count_group_a"}, {Name: "company_count_group_a"}, {Name: "company_count_group_b"}, {Name: "company_count_group_c"}, }) var count1 int64 if err := DB.Model(&Company{}).Where("name = ?", "company_count_group_a").Group("name").Count(&count1).Error; err != nil { t.Errorf("Count should work, but got err %v", err) } if count1 != 1 { t.Errorf("Count with group should be 1, but got count: %v", count1) } var count2 int64 if err := DB.Model(&Company{}).Where("name in ?", []string{"company_count_group_b", "company_count_group_c"}).Group("name").Count(&count2).Error; err != nil { t.Errorf("Count should work, but got err %v", err) } if count2 != 2 { t.Errorf("Count with group should be 2, but got count: %v", count2) } } func TestCount(t *testing.T) { var ( user1 = *GetUser("count-1", Config{}) user2 = *GetUser("count-2", Config{}) user3 = *GetUser("count-3", Config{}) users []User count, count1, count2 int64 ) DB.Save(&user1).Save(&user2).Save(&user3) if err := DB.Where("name = ?", user1.Name).Or("name = ?", user3.Name).Find(&users).Count(&count).Error; err != nil { t.Errorf("Count should work, but got err %v", err) } if count != int64(len(users)) { t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) } if err := DB.Model(&User{}).Where("name = ?", user1.Name).Or("name = ?", user3.Name).Count(&count).Find(&users).Error; err != nil { t.Errorf("Count should work, but got err %v", err) } if count != int64(len(users)) { t.Errorf("Count() method should get correct value, expect: %v, got %v", count, len(users)) } DB.Model(&User{}).Where("name = ?", user1.Name).Count(&count1).Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) if count1 != 1 || count2 != 3 { t.Errorf("multiple count in chain should works") } tx := DB.Model(&User{}).Where("name = ?", user1.Name).Session(&gorm.Session{}) 
tx.Count(&count1) tx.Or("name in ?", []string{user2.Name, user3.Name}).Count(&count2) if count1 != 1 || count2 != 3 { t.Errorf("count after new session should works") } var count3 int64 if err := DB.Model(&User{}).Where("name in ?", []string{user2.Name, user2.Name, user3.Name}).Group("id").Count(&count3).Error; err != nil { t.Errorf("Error happened when count with group, but got %v", err) } if count3 != 2 { t.Errorf("Should get correct count for count with group, but got %v", count3) } dryDB := DB.Session(&gorm.Session{DryRun: true}) result := dryDB.Table("users").Select("name").Count(&count) if !regexp.MustCompile(`SELECT COUNT\(.name.\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) } result = dryDB.Table("users").Distinct("name").Count(&count) if !regexp.MustCompile(`SELECT COUNT\(DISTINCT\(.name.\)\) FROM .*users.*`).MatchString(result.Statement.SQL.String()) { t.Fatalf("Build count with select, but got %v", result.Statement.SQL.String()) } var count4 int64 if err := DB.Table("users").Joins("LEFT JOIN companies on companies.name = users.name").Where("users.name = ?", user1.Name).Count(&count4).Error; err != nil || count4 != 1 { t.Errorf("count with join, got error: %v, count %v", err, count4) } var count5 int64 if err := DB.Table("users").Where("users.name = ?", user1.Name).Order("name").Count(&count5).Error; err != nil || count5 != 1 { t.Errorf("count with join, got error: %v, count %v", err, count) } var count6 int64 if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( "(CASE WHEN name=? THEN ? ELSE ? 
END) as name", "count-1", "main", "other", ).Count(&count6).Find(&users).Error; err != nil || count6 != 3 { t.Fatalf("Count should work, but got err %v", err) } expects := []User{{Name: "main"}, {Name: "other"}, {Name: "other"}} sort.SliceStable(users, func(i, j int) bool { return strings.Compare(users[i].Name, users[j].Name) < 0 }) AssertEqual(t, users, expects) var count7 int64 if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( "(CASE WHEN name=? THEN ? ELSE ? END) as name, age", "count-1", "main", "other", ).Count(&count7).Find(&users).Error; err != nil || count7 != 3 { t.Fatalf("Count should work, but got err %v", err) } expects = []User{{Name: "main", Age: 18}, {Name: "other", Age: 18}, {Name: "other", Age: 18}} sort.SliceStable(users, func(i, j int) bool { return strings.Compare(users[i].Name, users[j].Name) < 0 }) AssertEqual(t, users, expects) var count8 int64 if err := DB.Model(&User{}).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Select( "(CASE WHEN age=18 THEN 1 ELSE 2 END) as age", "name", ).Count(&count8).Find(&users).Error; err != nil || count8 != 3 { t.Fatalf("Count should work, but got err %v", err) } expects = []User{{Name: "count-1", Age: 1}, {Name: "count-2", Age: 1}, {Name: "count-3", Age: 1}} sort.SliceStable(users, func(i, j int) bool { return strings.Compare(users[i].Name, users[j].Name) < 0 }) AssertEqual(t, users, expects) var count9 int64 if err := DB.Scopes(func(tx *gorm.DB) *gorm.DB { return tx.Table("users") }).Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count9).Find(&users).Error; err != nil || count9 != 3 { t.Fatalf("Count should work, but got err %v", err) } var count10 int64 if err := DB.Model(&User{}).Select("*").Where("name in ?", []string{user1.Name, user2.Name, user3.Name}).Count(&count10).Error; err != nil || count10 != 3 { t.Fatalf("Count should be 3, but got count: %v err %v", count10, err) } var count11 int64 sameUsers := 
make([]*User, 0) for i := 0; i < 3; i++ { sameUsers = append(sameUsers, GetUser("count-4", Config{})) } DB.Create(sameUsers) if err := DB.Model(&User{}).Where("name = ?", "count-4").Group("name").Count(&count11).Error; err != nil || count11 != 1 { t.Fatalf("Count should be 1, but got count: %v err %v", count11, err) } var count12 int64 if err := DB.Table("users"). Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). Preload("Toys", func(db *gorm.DB) *gorm.DB { return db.Table("toys").Select("name") }).Count(&count12).Error; err == nil { t.Errorf("error should raise when using preload without schema") } var count13 int64 if err := DB.Model(User{}). Where("name in ?", []string{user1.Name, user2.Name, user3.Name}). Preload("Toys", func(db *gorm.DB) *gorm.DB { return db.Table("toys").Select("name") }).Count(&count13).Error; err != nil { t.Errorf("no error should raise when using count with preload, but got %v", err) } } ================================================ FILE: tests/create_test.go ================================================ package tests_test import ( "errors" "fmt" "regexp" "testing" "time" "github.com/jinzhu/now" "gorm.io/gorm" "gorm.io/gorm/clause" . 
"gorm.io/gorm/utils/tests" ) func TestCreate(t *testing.T) { u1 := *GetUser("create", Config{}) if results := DB.Create(&u1); results.Error != nil { t.Fatalf("errors happened when create: %v", results.Error) } else if results.RowsAffected != 1 { t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) } if u1.ID == 0 { t.Errorf("user's primary key should has value after create, got : %v", u1.ID) } if u1.CreatedAt.IsZero() { t.Errorf("user's created at should be not zero") } if u1.UpdatedAt.IsZero() { t.Errorf("user's updated at should be not zero") } var newUser User if err := DB.Where("id = ?", u1.ID).First(&newUser).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { CheckUser(t, newUser, u1) } type user struct { ID int `gorm:"primaryKey;->:false"` Name string Age int } var u2 user if results := DB.Create(&u2); results.Error != nil { t.Fatalf("errors happened when create: %v", results.Error) } else if results.RowsAffected != 1 { t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) } if u2.ID != 0 { t.Errorf("don't have the permission to read primary key from db, but got %v", u2.ID) } } func TestCreateInBatches(t *testing.T) { users := []User{ *GetUser("create_in_batches_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), *GetUser("create_in_batches_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), *GetUser("create_in_batches_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), *GetUser("create_in_batches_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), *GetUser("create_in_batches_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), *GetUser("create_in_batches_6", Config{Account: true, Pets: 4, Toys: 3, 
Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), } result := DB.CreateInBatches(&users, 2) if result.RowsAffected != int64(len(users)) { t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) } for _, user := range users { if user.ID == 0 { t.Fatalf("failed to fill user's ID, got %v", user.ID) } else { var newUser User if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { CheckUser(t, newUser, user) } } } } func TestCreateInBatchesWithDefaultSize(t *testing.T) { users := []User{ *GetUser("create_with_default_batch_size_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), *GetUser("create_with_default_batch_sizs_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), *GetUser("create_with_default_batch_sizs_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), *GetUser("create_with_default_batch_sizs_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), *GetUser("create_with_default_batch_sizs_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), *GetUser("create_with_default_batch_sizs_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), } result := DB.Session(&gorm.Session{CreateBatchSize: 2}).Create(&users) if result.RowsAffected != int64(len(users)) { t.Errorf("affected rows should be %v, but got %v", len(users), result.RowsAffected) } for _, user := range users { if user.ID == 0 { t.Fatalf("failed to fill user's ID, got %v", user.ID) } else { var newUser User if err := DB.Where("id = ?", user.ID).Preload(clause.Associations).First(&newUser).Error; err != nil { 
t.Fatalf("errors happened when query: %v", err) } else { CheckUser(t, newUser, user) } } } } func TestCreateFromMap(t *testing.T) { if err := DB.Model(&User{}).Create(map[string]interface{}{"Name": "create_from_map", "Age": 18}).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) } var result User if err := DB.Where("name = ?", "create_from_map").First(&result).Error; err != nil || result.Age != 18 { t.Fatalf("failed to create from map, got error %v", err) } if err := DB.Model(&User{}).Create(map[string]interface{}{"name": "create_from_map_1", "age": 18}).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) } var result1 User if err := DB.Where("name = ?", "create_from_map_1").First(&result1).Error; err != nil || result1.Age != 18 { t.Fatalf("failed to create from map, got error %v", err) } datas := []map[string]interface{}{ {"Name": "create_from_map_2", "Age": 19}, {"name": "create_from_map_3", "Age": 20}, } if err := DB.Model(&User{}).Create(&datas).Error; err != nil { t.Fatalf("failed to create data from slice of map, got error: %v", err) } var result2 User if err := DB.Where("name = ?", "create_from_map_2").First(&result2).Error; err != nil || result2.Age != 19 { t.Fatalf("failed to query data after create from slice of map, got error %v", err) } var result3 User if err := DB.Where("name = ?", "create_from_map_3").First(&result3).Error; err != nil || result3.Age != 20 { t.Fatalf("failed to query data after create from slice of map, got error %v", err) } } func TestCreateWithAssociations(t *testing.T) { user := *GetUser("create_with_associations", Config{ Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1, }) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckUser(t, user, user) var user2 User 
DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) } func TestBulkCreateWithAssociations(t *testing.T) { users := []User{ *GetUser("bulk_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), *GetUser("bulk_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), *GetUser("bulk_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), *GetUser("bulk_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), *GetUser("bulk_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), *GetUser("bulk_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), *GetUser("bulk_7", Config{Account: true, Pets: 1, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1}), *GetUser("bulk_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}), } if results := DB.Create(&users); results.Error != nil { t.Fatalf("errors happened when create: %v", results.Error) } else if results.RowsAffected != int64(len(users)) { t.Fatalf("rows affected expects: %v, got %v", len(users), results.RowsAffected) } var userIDs []uint for _, user := range users { userIDs = append(userIDs, user.ID) CheckUser(t, user, user) } var users2 []User DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&users2, "id IN ?", userIDs) for idx, user := range users2 { CheckUser(t, user, users[idx]) } } func TestBulkCreatePtrDataWithAssociations(t *testing.T) { users := 
[]*User{ GetUser("bulk_ptr_1", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 0, Languages: 1, Friends: 1}), GetUser("bulk_ptr_2", Config{Account: false, Pets: 2, Toys: 4, Company: false, Manager: false, Team: 1, Languages: 3, Friends: 5}), GetUser("bulk_ptr_3", Config{Account: true, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 4, Languages: 0, Friends: 1}), GetUser("bulk_ptr_4", Config{Account: true, Pets: 3, Toys: 0, Company: false, Manager: true, Team: 0, Languages: 3, Friends: 0}), GetUser("bulk_ptr_5", Config{Account: false, Pets: 0, Toys: 3, Company: true, Manager: false, Team: 1, Languages: 3, Friends: 1}), GetUser("bulk_ptr_6", Config{Account: true, Pets: 4, Toys: 3, Company: false, Manager: true, Team: 1, Languages: 3, Friends: 0}), GetUser("bulk_ptr_7", Config{Account: true, Pets: 1, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1}), GetUser("bulk_ptr_8", Config{Account: false, Pets: 0, Toys: 0, Company: false, Manager: false, Team: 0, Languages: 0, Friends: 0}), } if err := DB.Create(&users).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } var userIDs []uint for _, user := range users { userIDs = append(userIDs, user.ID) CheckUser(t, *user, *user) } var users2 []User DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").Find(&users2, "id IN ?", userIDs) for idx, user := range users2 { CheckUser(t, user, *users[idx]) } } func TestPolymorphicHasOne(t *testing.T) { t.Run("Struct", func(t *testing.T) { pet := Pet{ Name: "PolymorphicHasOne", Toy: Toy{Name: "Toy-PolymorphicHasOne"}, } if err := DB.Create(&pet).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckPet(t, pet, pet) var pet2 Pet DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID) CheckPet(t, pet2, pet) }) t.Run("Slice", func(t *testing.T) { pets := []Pet{{ Name: "PolymorphicHasOne-Slice-1", Toy: 
Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, }, { Name: "PolymorphicHasOne-Slice-2", Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-2"}, }, { Name: "PolymorphicHasOne-Slice-3", Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-3"}, }} if err := DB.Create(&pets).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } var petIDs []uint for _, pet := range pets { petIDs = append(petIDs, pet.ID) CheckPet(t, pet, pet) } var pets2 []Pet DB.Preload("Toy").Find(&pets2, "id IN ?", petIDs) for idx, pet := range pets2 { CheckPet(t, pet, pets[idx]) } }) t.Run("SliceOfPtr", func(t *testing.T) { pets := []*Pet{{ Name: "PolymorphicHasOne-Slice-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-1"}, }, { Name: "PolymorphicHasOne-Slice-2", Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-2"}, }, { Name: "PolymorphicHasOne-Slice-3", Toy: Toy{Name: "Toy-PolymorphicHasOne-Slice-3"}, }} if err := DB.Create(&pets).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } for _, pet := range pets { CheckPet(t, *pet, *pet) } }) t.Run("Array", func(t *testing.T) { pets := [...]Pet{{ Name: "PolymorphicHasOne-Array-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, }, { Name: "PolymorphicHasOne-Array-2", Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-2"}, }, { Name: "PolymorphicHasOne-Array-3", Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-3"}, }} if err := DB.Create(&pets).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } for _, pet := range pets { CheckPet(t, pet, pet) } }) t.Run("ArrayPtr", func(t *testing.T) { pets := [...]*Pet{{ Name: "PolymorphicHasOne-Array-1", Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-1"}, }, { Name: "PolymorphicHasOne-Array-2", Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-2"}, }, { Name: "PolymorphicHasOne-Array-3", Toy: Toy{Name: "Toy-PolymorphicHasOne-Array-3"}, }} if err := DB.Create(&pets).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } for _, pet := range pets { CheckPet(t, *pet, *pet) } }) } func 
TestCreateEmptyStruct(t *testing.T) { type EmptyStruct struct { ID uint } DB.Migrator().DropTable(&EmptyStruct{}) if err := DB.AutoMigrate(&EmptyStruct{}); err != nil { t.Errorf("no error should happen when auto migrate, but got %v", err) } if err := DB.Create(&EmptyStruct{}).Error; err != nil { t.Errorf("No error should happen when creating user, but got %v", err) } } func TestCreateEmptySlice(t *testing.T) { data := []User{} if err := DB.Create(&data).Error; err != gorm.ErrEmptySlice { t.Errorf("no data should be created, got %v", err) } sliceMap := []map[string]interface{}{} if err := DB.Model(&User{}).Create(&sliceMap).Error; err != gorm.ErrEmptySlice { t.Errorf("no data should be created, got %v", err) } } func TestCreateInvalidSlice(t *testing.T) { users := []*User{ GetUser("invalid_slice_1", Config{}), GetUser("invalid_slice_2", Config{}), nil, } if err := DB.Create(&users).Error; !errors.Is(err, gorm.ErrInvalidData) { t.Errorf("should returns error invalid data when creating from slice that contains invalid data") } } func TestCreateWithExistingTimestamp(t *testing.T) { user := User{Name: "CreateUserExistingTimestamp"} curTime := now.MustParse("2016-01-01") user.CreatedAt = curTime user.UpdatedAt = curTime DB.Save(&user) AssertEqual(t, user.CreatedAt, curTime) AssertEqual(t, user.UpdatedAt, curTime) var newUser User DB.First(&newUser, user.ID) AssertEqual(t, newUser.CreatedAt, curTime) AssertEqual(t, newUser.UpdatedAt, curTime) } func TestCreateWithNowFuncOverride(t *testing.T) { user := User{Name: "CreateUserTimestampOverride"} curTime := now.MustParse("2016-01-01") NEW := DB.Session(&gorm.Session{ NowFunc: func() time.Time { return curTime }, }) NEW.Save(&user) AssertEqual(t, user.CreatedAt, curTime) AssertEqual(t, user.UpdatedAt, curTime) var newUser User NEW.First(&newUser, user.ID) AssertEqual(t, newUser.CreatedAt, curTime) AssertEqual(t, newUser.UpdatedAt, curTime) } func TestCreateWithNoGORMPrimaryKey(t *testing.T) { type JoinTable struct { UserID 
uint FriendID uint } DB.Migrator().DropTable(&JoinTable{}) if err := DB.AutoMigrate(&JoinTable{}); err != nil { t.Errorf("no error should happen when auto migrate, but got %v", err) } jt := JoinTable{UserID: 1, FriendID: 2} err := DB.Create(&jt).Error if err != nil { t.Errorf("No error should happen when create a record without a GORM primary key. But in the database this primary key exists and is the union of 2 or more fields\n But got: %s", err) } } func TestSelectWithCreate(t *testing.T) { user := *GetUser("select_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) DB.Select("Account", "Toys", "Manager", "ManagerID", "Languages", "Name", "CreatedAt", "Age", "Active").Create(&user) var user2 User DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&user2, user.ID) user.Birthday = nil user.Pets = nil user.Company = Company{} user.Team = nil user.Friends = nil CheckUser(t, user2, user) } func TestOmitWithCreate(t *testing.T) { user := *GetUser("omit_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) DB.Omit("Account", "Toys", "Manager", "Birthday").Create(&user) var result User DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result, user.ID) user.Birthday = nil user.Account = Account{} user.Toys = nil user.Manager = nil CheckUser(t, result, user) user2 := *GetUser("omit_create", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4}) DB.Omit(clause.Associations).Create(&user2) var result2 User DB.Preload(clause.Associations).First(&result2, user2.ID) user2.Account = Account{} user2.Toys = nil user2.Manager = nil user2.Company = Company{} user2.Pets = nil user2.Team = nil user2.Languages 
= nil user2.Friends = nil CheckUser(t, result2, user2) } func TestFirstOrCreateNotExistsTable(t *testing.T) { company := Company{Name: "first_or_create_if_not_exists_table"} if err := DB.Table("not_exists").FirstOrCreate(&company).Error; err == nil { t.Errorf("not exists table, but err is nil") } } func TestFirstOrCreateWithPrimaryKey(t *testing.T) { company := Company{ID: 100, Name: "company100_with_primarykey"} DB.FirstOrCreate(&company) if company.ID != 100 { t.Errorf("invalid primary key after creating, got %v", company.ID) } companies := []Company{ {ID: 101, Name: "company101_with_primarykey"}, {ID: 102, Name: "company102_with_primarykey"}, } DB.Create(&companies) if companies[0].ID != 101 || companies[1].ID != 102 { t.Errorf("invalid primary key after creating, got %v, %v", companies[0].ID, companies[1].ID) } } func TestCreateFromSubQuery(t *testing.T) { user := User{Name: "jinzhu"} DB.Create(&user) subQuery := DB.Table("users").Where("name=?", user.Name).Select("id") result := DB.Session(&gorm.Session{DryRun: true}).Model(&Pet{}).Create([]map[string]interface{}{ { "name": "cat", "user_id": gorm.Expr("(?)", DB.Table("(?) as tmp", subQuery).Select("@uid:=id")), }, { "name": "dog", "user_id": gorm.Expr("@uid"), }, }) if !regexp.MustCompile(`INSERT INTO .pets. \(.name.,.user_id.\) .*VALUES \(.+,\(SELECT @uid:=id FROM \(SELECT id FROM .users. 
WHERE name=.+\) as tmp\)\),\(.+,@uid\)`).MatchString(result.Statement.SQL.String()) { t.Errorf("invalid insert SQL, got %v", result.Statement.SQL.String()) } } func TestCreateNilPointer(t *testing.T) { var user *User err := DB.Create(user).Error if err == nil || err != gorm.ErrInvalidValue { t.Fatalf("it is not ErrInvalidValue") } } func TestFirstOrCreateRowsAffected(t *testing.T) { user := User{Name: "TestFirstOrCreateRowsAffected"} res := DB.FirstOrCreate(&user, "name = ?", user.Name) if res.Error != nil || res.RowsAffected != 1 { t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) } res = DB.FirstOrCreate(&user, "name = ?", user.Name) if res.Error != nil || res.RowsAffected != 0 { t.Fatalf("first or create rows affect err:%v rows:%d", res.Error, res.RowsAffected) } } func TestCreateWithAutoIncrementCompositeKey(t *testing.T) { type CompositeKeyProduct struct { ProductID int `gorm:"primaryKey;autoIncrement:true;"` // primary key LanguageCode int `gorm:"primaryKey;"` // primary key Code string Name string } if err := DB.Migrator().DropTable(&CompositeKeyProduct{}); err != nil { t.Fatalf("failed to migrate, got error %v", err) } if err := DB.AutoMigrate(&CompositeKeyProduct{}); err != nil { t.Fatalf("failed to migrate, got error %v", err) } prod := &CompositeKeyProduct{ LanguageCode: 56, Code: "Code56", Name: "ProductName56", } if err := DB.Create(&prod).Error; err != nil { t.Fatalf("failed to create, got error %v", err) } newProd := &CompositeKeyProduct{} if err := DB.First(&newProd).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { AssertObjEqual(t, newProd, prod, "ProductID", "LanguageCode", "Code", "Name") } } func TestCreateOnConflictWithDefaultNull(t *testing.T) { type OnConflictUser struct { ID string Name string `gorm:"default:null"` Email string Mobile string `gorm:"default:'133xxxx'"` } err := DB.Migrator().DropTable(&OnConflictUser{}) AssertEqual(t, err, nil) err = 
DB.AutoMigrate(&OnConflictUser{}) AssertEqual(t, err, nil) u := OnConflictUser{ ID: "on-conflict-user-id", Name: "on-conflict-user-name", Email: "on-conflict-user-email", Mobile: "on-conflict-user-mobile", } err = DB.Create(&u).Error AssertEqual(t, err, nil) u.Name = "on-conflict-user-name-2" u.Email = "on-conflict-user-email-2" u.Mobile = "" err = DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&u).Error AssertEqual(t, err, nil) var u2 OnConflictUser err = DB.Where("id = ?", u.ID).First(&u2).Error AssertEqual(t, err, nil) AssertEqual(t, u2.Name, "on-conflict-user-name-2") AssertEqual(t, u2.Email, "on-conflict-user-email-2") AssertEqual(t, u2.Mobile, "133xxxx") } func TestCreateFromMapWithoutPK(t *testing.T) { if !isMysql() { t.Skipf("This test case skipped, because of only supporting for mysql") } // case 1: one record, create from map[string]interface{} mapValue1 := map[string]interface{}{"name": "create_from_map_with_schema1", "age": 1} if err := DB.Model(&User{}).Create(mapValue1).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) } if _, ok := mapValue1["id"]; !ok { t.Fatal("failed to create data from map with table, returning map has no primary key") } var result1 User if err := DB.Where("name = ?", "create_from_map_with_schema1").First(&result1).Error; err != nil || result1.Age != 1 { t.Fatalf("failed to create from map, got error %v", err) } var idVal int64 _, ok := mapValue1["id"].(uint) if ok { t.Skipf("This test case skipped, because the db supports returning") } idVal, ok = mapValue1["id"].(int64) if !ok { t.Fatal("ret result missing id") } if int64(result1.ID) != idVal { t.Fatal("failed to create data from map with table, @id != id") } // case2: one record, create from *map[string]interface{} mapValue2 := map[string]interface{}{"name": "create_from_map_with_schema2", "age": 1} if err := DB.Model(&User{}).Create(&mapValue2).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) } if _, ok 
:= mapValue2["id"]; !ok { t.Fatal("failed to create data from map with table, returning map has no primary key") } var result2 User if err := DB.Where("name = ?", "create_from_map_with_schema2").First(&result2).Error; err != nil || result2.Age != 1 { t.Fatalf("failed to create from map, got error %v", err) } _, ok = mapValue2["id"].(uint) if ok { t.Skipf("This test case skipped, because the db supports returning") } idVal, ok = mapValue2["id"].(int64) if !ok { t.Fatal("ret result missing id") } if int64(result2.ID) != idVal { t.Fatal("failed to create data from map with table, @id != id") } // case 3: records values := []map[string]interface{}{ {"name": "create_from_map_with_schema11", "age": 1}, {"name": "create_from_map_with_schema12", "age": 1}, } beforeLen := len(values) if err := DB.Model(&User{}).Create(&values).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) } // mariadb with returning, values will be appended with id map if len(values) == beforeLen*2 { t.Skipf("This test case skipped, because the db supports returning") } for i := range values { v, ok := values[i]["id"] if !ok { t.Fatal("failed to create data from map with table, returning map has no primary key") } var result User if err := DB.Where("name = ?", fmt.Sprintf("create_from_map_with_schema1%d", i+1)).First(&result).Error; err != nil || result.Age != 1 { t.Fatalf("failed to create from map, got error %v", err) } if int64(result.ID) != v.(int64) { t.Fatal("failed to create data from map with table, @id != id") } } } func TestCreateFromMapWithTable(t *testing.T) { tableDB := DB.Table("users") supportLastInsertID := isMysql() || isSqlite() // case 1: create from map[string]interface{} record := map[string]interface{}{"name": "create_from_map_with_table", "age": 18} if err := tableDB.Create(record).Error; err != nil { t.Fatalf("failed to create data from map with table, got error: %v", err) } if _, ok := record["@id"]; !ok && supportLastInsertID { t.Fatal("failed 
to create data from map with table, returning map has no key '@id'") } var res map[string]interface{} if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table").Find(&res).Error; err != nil || res["age"] != int64(18) { t.Fatalf("failed to create from map, got error %v", err) } if _, ok := record["@id"]; ok && fmt.Sprint(res["id"]) != fmt.Sprint(record["@id"]) { t.Fatalf("failed to create data from map with table, @id != id, got %v, expect %v", res["id"], record["@id"]) } // case 2: create from *map[string]interface{} record1 := map[string]interface{}{"name": "create_from_map_with_table_1", "age": 18} tableDB2 := DB.Table("users") if err := tableDB2.Create(&record1).Error; err != nil { t.Fatalf("failed to create data from map, got error: %v", err) } if _, ok := record1["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } var res1 map[string]interface{} if err := tableDB2.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_1").Find(&res1).Error; err != nil || res1["age"] != int64(18) { t.Fatalf("failed to create from map, got error %v", err) } if _, ok := record1["@id"]; ok && fmt.Sprint(res1["id"]) != fmt.Sprint(record1["@id"]) { t.Fatal("failed to create data from map with table, @id != id") } // case 3: create from []map[string]interface{} records := []map[string]interface{}{ {"name": "create_from_map_with_table_2", "age": 19}, {"name": "create_from_map_with_table_3", "age": 20}, } tableDB = DB.Table("users") if err := tableDB.Create(&records).Error; err != nil { t.Fatalf("failed to create data from slice of map, got error: %v", err) } if _, ok := records[0]["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key '@id'") } if _, ok := records[1]["@id"]; !ok && supportLastInsertID { t.Fatal("failed to create data from map with table, returning map has no key 
'@id'") } var res2 map[string]interface{} if err := tableDB.Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_2").Find(&res2).Error; err != nil || res2["age"] != int64(19) { t.Fatalf("failed to query data after create from slice of map, got error %v", err) } var res3 map[string]interface{} if err := DB.Table("users").Select([]string{"id", "name", "age"}).Where("name = ?", "create_from_map_with_table_3").Find(&res3).Error; err != nil || res3["age"] != int64(20) { t.Fatalf("failed to query data after create from slice of map, got error %v", err) } if _, ok := records[0]["@id"]; ok && fmt.Sprint(res2["id"]) != fmt.Sprint(records[0]["@id"]) { t.Errorf("failed to create data from map with table, @id != id, got %v, expect %v", res2["id"], records[0]["@id"]) } if _, ok := records[1]["id"]; ok && fmt.Sprint(res3["id"]) != fmt.Sprint(records[1]["@id"]) { t.Errorf("failed to create data from map with table, @id != id") } } ================================================ FILE: tests/customize_field_test.go ================================================ package tests_test import ( "testing" "time" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) func TestCustomizeColumn(t *testing.T) { type CustomizeColumn struct { ID int64 `gorm:"column:mapped_id; primary_key:yes"` Name string `gorm:"column:mapped_name"` Date *time.Time `gorm:"column:mapped_time"` } DB.Migrator().DropTable(&CustomizeColumn{}) DB.AutoMigrate(&CustomizeColumn{}) expected := "foo" now := time.Now() cc := CustomizeColumn{ID: 666, Name: expected, Date: &now} if count := DB.Create(&cc).RowsAffected; count != 1 { t.Error("There should be one record be affected when create record") } var cc1 CustomizeColumn DB.First(&cc1, "mapped_name = ?", "foo") if cc1.Name != expected { t.Errorf("Failed to query CustomizeColumn") } cc.Name = "bar" DB.Save(&cc) var cc2 CustomizeColumn DB.First(&cc2, "mapped_id = ?", 666) if cc2.Name != "bar" { t.Errorf("Failed to query CustomizeColumn") } } func TestCustomColumnAndIgnoredFieldClash(t *testing.T) { // Make sure an ignored field does not interfere with another field's custom // column name that matches the ignored field. 
type CustomColumnAndIgnoredFieldClash struct { Body string `gorm:"-"` RawBody string `gorm:"column:body"` } DB.Migrator().DropTable(&CustomColumnAndIgnoredFieldClash{}) if err := DB.AutoMigrate(&CustomColumnAndIgnoredFieldClash{}); err != nil { t.Errorf("Should not raise error: %v", err) } } func TestCustomizeField(t *testing.T) { type CustomizeFieldStruct struct { gorm.Model Name string FieldAllowCreate string `gorm:"<-:create"` FieldAllowUpdate string `gorm:"<-:update"` FieldAllowSave string `gorm:"<-"` FieldAllowSave2 string `gorm:"<-:create,update"` FieldAllowSave3 string `gorm:"->:false;<-:create"` FieldReadonly string `gorm:"->"` FieldIgnore string `gorm:"-"` AutoUnixCreateTime int32 `gorm:"autocreatetime"` AutoUnixMilliCreateTime int `gorm:"autocreatetime:milli"` AutoUnixNanoCreateTime int64 `gorm:"autocreatetime:nano"` AutoUnixUpdateTime uint32 `gorm:"autoupdatetime"` AutoUnixMilliUpdateTime int `gorm:"autoupdatetime:milli"` AutoUnixNanoUpdateTime uint64 `gorm:"autoupdatetime:nano"` } DB.Migrator().DropTable(&CustomizeFieldStruct{}) if err := DB.AutoMigrate(&CustomizeFieldStruct{}); err != nil { t.Errorf("Failed to migrate, got error: %v", err) } if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "FieldIgnore") { t.Errorf("FieldIgnore should not be created") } if DB.Migrator().HasColumn(&CustomizeFieldStruct{}, "field_ignore") { t.Errorf("FieldIgnore should not be created") } generateStruct := func(name string) *CustomizeFieldStruct { return &CustomizeFieldStruct{ Name: name, FieldAllowCreate: name + "_allow_create", FieldAllowUpdate: name + "_allow_update", FieldAllowSave: name + "_allow_save", FieldAllowSave2: name + "_allow_save2", FieldAllowSave3: name + "_allow_save3", FieldReadonly: name + "_allow_readonly", FieldIgnore: name + "_allow_ignore", } } create := generateStruct("create") DB.Create(&create) var result CustomizeFieldStruct DB.Find(&result, "name = ?", "create") AssertObjEqual(t, result, create, "Name", "FieldAllowCreate", "FieldAllowSave", 
"FieldAllowSave2") if result.FieldAllowUpdate != "" || result.FieldReadonly != "" || result.FieldIgnore != "" || result.FieldAllowSave3 != "" { t.Fatalf("invalid result: %#v", result) } if int(result.AutoUnixCreateTime) != int(result.AutoUnixUpdateTime) || result.AutoUnixCreateTime == 0 { t.Fatalf("invalid create/update unix time: %#v", result) } if int(result.AutoUnixMilliCreateTime) != int(result.AutoUnixMilliUpdateTime) || result.AutoUnixMilliCreateTime == 0 || int(result.AutoUnixMilliCreateTime)/int(result.AutoUnixCreateTime) < 1e3 { t.Fatalf("invalid create/update unix milli time: %#v", result) } if int(result.AutoUnixNanoCreateTime) != int(result.AutoUnixNanoUpdateTime) || result.AutoUnixNanoCreateTime == 0 || int(result.AutoUnixNanoCreateTime)/int(result.AutoUnixCreateTime) < 1e6 { t.Fatalf("invalid create/update unix nano time: %#v", result) } result.FieldAllowUpdate = "field_allow_update_updated" result.FieldReadonly = "field_readonly_updated" result.FieldIgnore = "field_ignore_updated" DB.Save(&result) var result2 CustomizeFieldStruct DB.Find(&result2, "name = ?", "create") if result2.FieldAllowUpdate != result.FieldAllowUpdate || result2.FieldReadonly != "" || result2.FieldIgnore != "" { t.Fatalf("invalid updated result: %#v", result2) } if err := DB.Where(CustomizeFieldStruct{Name: create.Name, FieldReadonly: create.FieldReadonly, FieldIgnore: create.FieldIgnore}).First(&CustomizeFieldStruct{}).Error; err == nil { t.Fatalf("Should failed to find result") } if err := DB.Table("customize_field_structs").Where("1 = 1").UpdateColumn("field_readonly", "readonly").Error; err != nil { t.Fatalf("failed to update field_readonly column") } if err := DB.Where(CustomizeFieldStruct{Name: create.Name, FieldReadonly: "readonly", FieldIgnore: create.FieldIgnore}).First(&CustomizeFieldStruct{}).Error; err != nil { t.Fatalf("Should find result") } var result3 CustomizeFieldStruct DB.Find(&result3, "name = ?", "create") if result3.FieldReadonly != "readonly" { 
t.Fatalf("invalid updated result: %#v", result3) } var result4 CustomizeFieldStruct if err := DB.First(&result4, "field_allow_save3 = ?", create.FieldAllowSave3).Error; err != nil { t.Fatalf("failed to query with inserted field, got error %v", err) } AssertEqual(t, result3, result4) createWithDefaultTime := generateStruct("create_with_default_time") createWithDefaultTime.AutoUnixCreateTime = 100 createWithDefaultTime.AutoUnixUpdateTime = 100 createWithDefaultTime.AutoUnixMilliCreateTime = 100 createWithDefaultTime.AutoUnixMilliUpdateTime = 100 createWithDefaultTime.AutoUnixNanoCreateTime = 100 createWithDefaultTime.AutoUnixNanoUpdateTime = 100 DB.Create(&createWithDefaultTime) var createWithDefaultTimeResult CustomizeFieldStruct DB.Find(&createWithDefaultTimeResult, "name = ?", createWithDefaultTime.Name) if int(createWithDefaultTimeResult.AutoUnixCreateTime) != int(createWithDefaultTimeResult.AutoUnixUpdateTime) || createWithDefaultTimeResult.AutoUnixCreateTime != 100 { t.Fatalf("invalid create/update unix time: %#v", createWithDefaultTimeResult) } if int(createWithDefaultTimeResult.AutoUnixMilliCreateTime) != int(createWithDefaultTimeResult.AutoUnixMilliUpdateTime) || createWithDefaultTimeResult.AutoUnixMilliCreateTime != 100 { t.Fatalf("invalid create/update unix milli time: %#v", createWithDefaultTimeResult) } if int(createWithDefaultTimeResult.AutoUnixNanoCreateTime) != int(createWithDefaultTimeResult.AutoUnixNanoUpdateTime) || createWithDefaultTimeResult.AutoUnixNanoCreateTime != 100 { t.Fatalf("invalid create/update unix nano time: %#v", createWithDefaultTimeResult) } } ================================================ FILE: tests/default_value_test.go ================================================ package tests_test import ( "testing" "time" "gorm.io/gorm" ) func TestDefaultValue(t *testing.T) { type Harumph struct { gorm.Model Email string `gorm:"not null;index:,unique"` Name string `gorm:"notNull;default:foo"` Name2 string `gorm:"size:233;not 
null;default:'foo'"` Name3 string `gorm:"size:233;notNull;default:''"` Age int `gorm:"default:18"` Created time.Time `gorm:"default:2000-01-02"` Enabled bool `gorm:"default:true"` } DB.Migrator().DropTable(&Harumph{}) if err := DB.AutoMigrate(&Harumph{}); err != nil { t.Fatalf("Failed to migrate with default value, got error: %v", err) } harumph := Harumph{Email: "hello@gorm.io"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) } else if harumph.Name != "foo" || harumph.Name2 != "foo" || harumph.Name3 != "" || harumph.Age != 18 || !harumph.Enabled || harumph.Created.Format("20060102") != "20000102" { t.Fatalf("Failed to create data with default value, got: %+v", harumph) } var result Harumph if err := DB.First(&result, "email = ?", "hello@gorm.io").Error; err != nil { t.Fatalf("Failed to find created data, got error: %v", err) } else if result.Name != "foo" || result.Name2 != "foo" || result.Name3 != "" || result.Age != 18 || !result.Enabled || result.Created.Format("20060102") != "20000102" { t.Fatalf("Failed to find created data with default data, got %+v", result) } type Harumph2 struct { ID int `gorm:"default:0"` Email string `gorm:"not null;index:,unique"` Name string `gorm:"notNull;default:foo"` Name2 string `gorm:"size:233;not null;default:'foo'"` Name3 string `gorm:"size:233;notNull;default:''"` Age int `gorm:"default:18"` Created time.Time `gorm:"default:2000-01-02"` Enabled bool `gorm:"default:true"` } harumph2 := Harumph2{ID: 2, Email: "hello2@gorm.io"} if err := DB.Table("harumphs").Create(&harumph2).Error; err != nil { t.Fatalf("Failed to create data with default value, got error: %v", err) } else if harumph2.ID != 2 || harumph2.Name != "foo" || harumph2.Name2 != "foo" || harumph2.Name3 != "" || harumph2.Age != 18 || !harumph2.Enabled || harumph2.Created.Format("20060102") != "20000102" { t.Fatalf("Failed to create data with default value, got: %+v", harumph2) } } 
================================================ FILE: tests/delete_test.go ================================================ package tests_test import ( "errors" "testing" "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) func TestDelete(t *testing.T) { users := []User{*GetUser("delete", Config{}), *GetUser("delete", Config{}), *GetUser("delete", Config{})} if err := DB.Create(&users).Error; err != nil { t.Errorf("errors happened when create: %v", err) } for _, user := range users { if user.ID == 0 { t.Fatalf("user's primary key should has value after create, got : %v", user.ID) } } if res := DB.Delete(&users[1]); res.Error != nil || res.RowsAffected != 1 { t.Errorf("errors happened when delete: %v, affected: %v", res.Error, res.RowsAffected) } var result User if err := DB.Where("id = ?", users[1].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("should returns record not found error, but got %v", err) } for _, user := range []User{users[0], users[2]} { result = User{} if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) } } for _, user := range []User{users[0], users[2]} { result = User{} if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil { t.Errorf("no error should returns when query %v, but got %v", user.ID, err) } } if err := DB.Delete(&users[0]).Error; err != nil { t.Errorf("errors happened when delete: %v", err) } if err := DB.Delete(&User{}).Error; err != gorm.ErrMissingWhereClause { t.Errorf("errors happened when delete: %v", err) } if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("should returns record not found error, but got %v", err) } } func TestDeleteWithTable(t *testing.T) { type UserWithDelete struct { gorm.Model Name string } DB.Table("deleted_users").Migrator().DropTable(UserWithDelete{}) 
DB.Table("deleted_users").AutoMigrate(UserWithDelete{}) user := UserWithDelete{Name: "delete1"} DB.Table("deleted_users").Create(&user) var result UserWithDelete if err := DB.Table("deleted_users").First(&result).Error; err != nil { t.Errorf("failed to find deleted user, got error %v", err) } AssertEqual(t, result, user) if err := DB.Table("deleted_users").Delete(&result).Error; err != nil { t.Errorf("failed to delete user, got error %v", err) } var result2 UserWithDelete if err := DB.Table("deleted_users").First(&result2, user.ID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("should raise record not found error, but got error %v", err) } var result3 UserWithDelete if err := DB.Table("deleted_users").Unscoped().First(&result3, user.ID).Error; err != nil { t.Fatalf("failed to find record, got error %v", err) } if err := DB.Table("deleted_users").Unscoped().Delete(&result).Error; err != nil { t.Errorf("failed to delete user with unscoped, got error %v", err) } var result4 UserWithDelete if err := DB.Table("deleted_users").Unscoped().First(&result4, user.ID).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("should raise record not found error, but got error %v", err) } } func TestInlineCondDelete(t *testing.T) { user1 := *GetUser("inline_delete_1", Config{}) user2 := *GetUser("inline_delete_2", Config{}) DB.Save(&user1).Save(&user2) if DB.Delete(&User{}, user1.ID).Error != nil { t.Errorf("No error should happen when delete a record") } else if err := DB.Where("name = ?", user1.Name).First(&User{}).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("User can't be found after delete") } if err := DB.Delete(&User{}, "name = ?", user2.Name).Error; err != nil { t.Errorf("No error should happen when delete a record, err=%s", err) } else if err := DB.Where("name = ?", user2.Name).First(&User{}).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("User can't be found after delete") } } func TestBlockGlobalDelete(t *testing.T) { if err := 
DB.Delete(&User{}).Error; err == nil || !errors.Is(err, gorm.ErrMissingWhereClause) { t.Errorf("should returns missing WHERE clause while deleting error") } if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Delete(&User{}).Error; err != nil { t.Errorf("should returns no error while enable global update, but got err %v", err) } } func TestDeleteWithAssociations(t *testing.T) { user := GetUser("delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) if err := DB.Create(user).Error; err != nil { t.Fatalf("failed to create user, got error %v", err) } if err := DB.Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { t.Fatalf("failed to delete user, got error %v", err) } for key, value := range map[string]int64{"Account": 1, "Pets": 2, "Toys": 4, "Company": 1, "Manager": 1, "Team": 1, "Languages": 0, "Friends": 0} { if count := DB.Unscoped().Model(&user).Association(key).Count(); count != value { t.Errorf("user's %v expects: %v, got %v", key, value, count) } } for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { if count := DB.Model(&user).Association(key).Count(); count != value { t.Errorf("user's %v expects: %v, got %v", key, value, count) } } } func TestDeleteAssociationsWithUnscoped(t *testing.T) { user := GetUser("unscoped_delete_with_associations", Config{Account: true, Pets: 2, Toys: 4, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 1}) if err := DB.Create(user).Error; err != nil { t.Fatalf("failed to create user, got error %v", err) } if err := DB.Unscoped().Select(clause.Associations, "Pets.Toy").Delete(&user).Error; err != nil { t.Fatalf("failed to delete user, got error %v", err) } for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { if count := 
DB.Unscoped().Model(&user).Association(key).Count(); count != value { t.Errorf("user's %v expects: %v, got %v", key, value, count) } } for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 1, "Manager": 1, "Team": 0, "Languages": 0, "Friends": 0} { if count := DB.Model(&user).Association(key).Count(); count != value { t.Errorf("user's %v expects: %v, got %v", key, value, count) } } } func TestDeleteSliceWithAssociations(t *testing.T) { users := []User{ *GetUser("delete_slice_with_associations1", Config{Account: true, Pets: 4, Toys: 1, Company: true, Manager: true, Team: 1, Languages: 1, Friends: 4}), *GetUser("delete_slice_with_associations2", Config{Account: true, Pets: 3, Toys: 2, Company: true, Manager: true, Team: 2, Languages: 2, Friends: 3}), *GetUser("delete_slice_with_associations3", Config{Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 2}), *GetUser("delete_slice_with_associations4", Config{Account: true, Pets: 1, Toys: 4, Company: true, Manager: true, Team: 4, Languages: 4, Friends: 1}), } if err := DB.Create(users).Error; err != nil { t.Fatalf("failed to create user, got error %v", err) } if err := DB.Select(clause.Associations).Delete(&users).Error; err != nil { t.Fatalf("failed to delete user, got error %v", err) } for key, value := range map[string]int64{"Account": 4, "Pets": 10, "Toys": 10, "Company": 4, "Manager": 4, "Team": 10, "Languages": 0, "Friends": 0} { if count := DB.Unscoped().Model(&users).Association(key).Count(); count != value { t.Errorf("user's %v expects: %v, got %v", key, value, count) } } for key, value := range map[string]int64{"Account": 0, "Pets": 0, "Toys": 0, "Company": 4, "Manager": 4, "Team": 0, "Languages": 0, "Friends": 0} { if count := DB.Model(&users).Association(key).Count(); count != value { t.Errorf("user's %v expects: %v, got %v", key, value, count) } } } // only sqlite, postgres, gaussdb, sqlserver support returning func 
TestSoftDeleteReturning(t *testing.T) { if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" { return } users := []*User{ GetUser("delete-returning-1", Config{}), GetUser("delete-returning-2", Config{}), GetUser("delete-returning-3", Config{}), } DB.Create(&users) var results []User DB.Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Delete(&results) if len(results) != 2 { t.Errorf("failed to return delete data, got %v", results) } var count int64 DB.Model(&User{}).Where("name IN ?", []string{users[0].Name, users[1].Name, users[2].Name}).Count(&count) if count != 1 { t.Errorf("failed to delete data, current count %v", count) } } func TestDeleteReturning(t *testing.T) { if DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlserver" { return } companies := []Company{ {Name: "delete-returning-1"}, {Name: "delete-returning-2"}, {Name: "delete-returning-3"}, } DB.Create(&companies) var results []Company DB.Where("name IN ?", []string{companies[0].Name, companies[1].Name}).Clauses(clause.Returning{}).Delete(&results) if len(results) != 2 { t.Errorf("failed to return delete data, got %v", results) } var count int64 DB.Model(&Company{}).Where("name IN ?", []string{companies[0].Name, companies[1].Name, companies[2].Name}).Count(&count) if count != 1 { t.Errorf("failed to delete data, current count %v", count) } } ================================================ FILE: tests/distinct_test.go ================================================ package tests_test import ( "regexp" "testing" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) func TestDistinct(t *testing.T) { users := []User{ *GetUser("distinct", Config{}), *GetUser("distinct", Config{}), *GetUser("distinct", Config{}), *GetUser("distinct-2", Config{}), *GetUser("distinct-3", Config{}), } users[0].Age = 20 if err := DB.Create(&users).Error; err != nil { t.Fatalf("errors happened when create users: %v", err) } var names []string DB.Table("users").Where("name like ?", "distinct%").Order("name").Pluck("name", &names) AssertEqual(t, names, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) var names1 []string DB.Model(&User{}).Where("name like ?", "distinct%").Distinct().Order("name").Pluck("Name", &names1) AssertEqual(t, names1, []string{"distinct", "distinct-2", "distinct-3"}) var names2 []string DB.Scopes(func(db *gorm.DB) *gorm.DB { return db.Table("users") }).Where("name like ?", "distinct%").Order("name").Pluck("name", &names2) AssertEqual(t, names2, []string{"distinct", "distinct", "distinct", "distinct-2", "distinct-3"}) var results []User if err := DB.Distinct("name", "age").Where("name like ?", "distinct%").Order("name, age desc").Find(&results).Error; err != nil { t.Errorf("failed to query users, got error: %v", err) } expects := []User{ {Name: "distinct", Age: 20}, {Name: "distinct", Age: 18}, {Name: "distinct-2", Age: 18}, {Name: "distinct-3", Age: 18}, } if len(results) != 4 { t.Fatalf("invalid results length found, expects: %v, got %v", len(expects), len(results)) } for idx, expect := range expects { AssertObjEqual(t, results[idx], expect, "Name", "Age") } var count int64 if err := DB.Model(&User{}).Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 5 { t.Errorf("failed to query users count, got error: %v, count: %v", err, count) } if err := DB.Model(&User{}).Distinct("name").Where("name like ?", "distinct%").Count(&count).Error; err != nil || count != 3 { t.Errorf("failed to query users count, got error: %v, count %v", err, count) } dryDB := 
DB.Session(&gorm.Session{DryRun: true}) r := dryDB.Distinct("u.id, u.*").Table("user_speaks as s").Joins("inner join users as u on u.id = s.user_id").Where("s.language_code ='US' or s.language_code ='ES'").Find(&User{}) if !regexp.MustCompile(`SELECT DISTINCT u\.id, u\.\* FROM user_speaks as s inner join users as u`).MatchString(r.Statement.SQL.String()) { t.Fatalf("Build Distinct with u.*, but got %v", r.Statement.SQL.String()) } } ================================================ FILE: tests/embedded_struct_test.go ================================================ package tests_test import ( "database/sql/driver" "encoding/json" "errors" "reflect" "testing" "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) func TestEmbeddedStruct(t *testing.T) { type ReadOnly struct { ReadOnly *bool } type BasePost struct { Id int64 Title string URL string ReadOnly } type Author struct { ID string Name string Email string } type HNPost struct { BasePost Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct Upvotes int32 } type EngadgetPost struct { BasePost BasePost `gorm:"Embedded"` Author *Author `gorm:"Embedded;EmbeddedPrefix:author_"` // Embedded struct ImageUrl string } DB.Migrator().DropTable(&HNPost{}, &EngadgetPost{}) if err := DB.Migrator().AutoMigrate(&HNPost{}, &EngadgetPost{}); err != nil { t.Fatalf("failed to auto migrate, got error: %v", err) } for _, name := range []string{"author_id", "author_name", "author_email"} { if !DB.Migrator().HasColumn(&EngadgetPost{}, name) { t.Errorf("should has prefixed column %v", name) } } stmt := gorm.Statement{DB: DB} if err := stmt.Parse(&EngadgetPost{}); err != nil { t.Fatalf("failed to parse embedded struct") } else if len(stmt.Schema.PrimaryFields) != 1 { t.Errorf("should have only one primary field with embedded struct, but got %v", len(stmt.Schema.PrimaryFields)) } for _, name := range []string{"user_id", "user_name", "user_email"} { if !DB.Migrator().HasColumn(&HNPost{}, name) { t.Errorf("should has prefixed column %v", 
name)
		}
	}

	// Save errors are checked so a failed insert is reported here rather than
	// as a confusing lookup failure further down.
	if err := DB.Save(&HNPost{BasePost: BasePost{Title: "news"}}).Error; err != nil {
		t.Fatalf("failed to save HNPost, got error: %v", err)
	}
	if err := DB.Save(&HNPost{BasePost: BasePost{Title: "hn_news"}}).Error; err != nil {
		t.Fatalf("failed to save HNPost, got error: %v", err)
	}

	var news HNPost
	if err := DB.First(&news, "title = ?", "hn_news").Error; err != nil {
		t.Errorf("no error should happen when query with embedded struct, but got %v", err)
	} else if news.Title != "hn_news" {
		t.Errorf("embedded struct's value should be scanned correctly")
	}

	if err := DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_news"}, Author: &Author{Name: "Edward"}}).Error; err != nil {
		t.Fatalf("failed to save EngadgetPost, got error: %v", err)
	}
	if err := DB.Save(&EngadgetPost{BasePost: BasePost{Title: "engadget_article"}, Author: &Author{Name: "George"}}).Error; err != nil {
		t.Fatalf("failed to save EngadgetPost, got error: %v", err)
	}

	var egNews EngadgetPost
	if err := DB.First(&egNews, "title = ?", "engadget_news").Error; err != nil {
		t.Errorf("no error should happen when query with embedded struct, but got %v", err)
	} else if egNews.BasePost.Title != "engadget_news" {
		t.Errorf("embedded struct's value should be scanned correctly")
	}

	var egPosts []EngadgetPost
	if err := DB.Order("author_name asc").Find(&egPosts).Error; err != nil {
		t.Fatalf("no error should happen when query with embedded struct, but got %v", err)
	}

	expectAuthors := []string{"Edward", "George"}
	// Guard the positional lookup below: extra rows would index past
	// expectAuthors and panic, missing rows would silently skip assertions.
	if len(egPosts) != len(expectAuthors) {
		t.Fatalf("expected %d engadget posts, got %d", len(expectAuthors), len(egPosts))
	}
	for i, post := range egPosts {
		t.Log(i, post.Author)
		want := expectAuthors[i]
		if post.Author == nil {
			t.Errorf("expected author %s got nil", want)
			continue
		}
		if post.Author.Name != want {
			t.Errorf("expected author %s got %s", want, post.Author.Name)
		}
	}
}

// TestEmbeddedPointerTypeStruct checks embedding through pointer types: a nil
// embedded *Author must scan back as nil, and a populated one (including a
// Scanner/Valuer field and a time field) must round-trip.
func TestEmbeddedPointerTypeStruct(t *testing.T) {
	type BasePost struct {
		Id    int64
		Title string
		URL   string
	}

	type Author struct {
		ID          string
		Name        string
		Email       string
		Age         int
		Content     Content
		ContentPtr  *Content
		Birthday    time.Time
		BirthdayPtr *time.Time
	}

	type HNPost struct {
		*BasePost
		Upvotes int32
		*Author `gorm:"EmbeddedPrefix:user_"` // Embedded struct
	}

	DB.Migrator().DropTable(&HNPost{})
	if err := DB.Migrator().AutoMigrate(&HNPost{}); err != nil {
		t.Fatalf("failed to auto migrate, got error: %v", err)
	}

	if err := DB.Create(&HNPost{BasePost: &BasePost{Title: "embedded_pointer_type"}}).Error; err != nil {
		t.Fatalf("failed to create HNPost, got error: %v", err)
	}

	var hnPost HNPost
	// Fatal on error: hnPost.Title below is promoted through the embedded
	// *BasePost, which stays nil if the lookup fails.
	if err := DB.First(&hnPost, "title = ?", "embedded_pointer_type").Error; err != nil {
		t.Fatalf("No error should happen when find embedded pointer type, but got %v", err)
	}

	if hnPost.Title != "embedded_pointer_type" {
		t.Errorf("Should find correct value for embedded pointer type")
	}

	if hnPost.Author != nil {
		t.Errorf("Expected to get back a nil Author but got: %v", hnPost.Author)
	}

	now := time.Now().Round(time.Second)
	NewPost := HNPost{
		BasePost: &BasePost{Title: "embedded_pointer_type2"},
		Author: &Author{
			Name:        "test",
			Content:     Content{"test"},
			ContentPtr:  nil,
			Birthday:    now,
			BirthdayPtr: nil,
		},
	}
	if err := DB.Create(&NewPost).Error; err != nil {
		t.Fatalf("failed to create HNPost, got error: %v", err)
	}

	hnPost = HNPost{}
	if err := DB.First(&hnPost, "title = ?", NewPost.Title).Error; err != nil {
		t.Fatalf("No error should happen when find embedded pointer type, but got %v", err)
	}

	if hnPost.Title != NewPost.Title {
		t.Errorf("Should find correct value for embedded pointer type")
	}

	// Author is a pointer; stop here instead of panicking on the field reads below.
	if hnPost.Author == nil {
		t.Fatalf("Expected to get Author %+v but got nil", NewPost.Author)
	}

	if hnPost.Author.Name != NewPost.Author.Name {
		t.Errorf("Expected to get Author name %v but got: %v", NewPost.Author.Name, hnPost.Author.Name)
	}

	if !reflect.DeepEqual(NewPost.Author.Content, hnPost.Author.Content) {
		t.Errorf("Expected to get Author content %v but got: %v", NewPost.Author.Content, hnPost.Author.Content)
	}

	if hnPost.Author.ContentPtr != nil {
		t.Errorf("Expected to get nil Author contentPtr but got: %v", hnPost.Author.ContentPtr)
	}

	if NewPost.Author.Birthday.UnixMilli() != hnPost.Author.Birthday.UnixMilli() {
		t.Errorf("Expected to get Author birthday with %+v but got: %+v", NewPost.Author.Birthday, hnPost.Author.Birthday)
	}

	if hnPost.Author.BirthdayPtr != nil {
		t.Errorf("Expected to get nil Author birthdayPtr but got: %+v", hnPost.Author.BirthdayPtr)
	}
}

// Content is a column type stored as its JSON encoding (see Value/Scan).
type Content struct {
	Content interface{} `gorm:"type:String"`
}

// Value implements driver.Valuer by returning the JSON encoding of c as a string.
func (c Content) Value() (driver.Value, error) {
	// mssql driver with issue on handling null bytes https://github.com/denisenkom/go-mssqldb/issues/530,
	// so the JSON is returned as a string rather than []byte.
	b, err := json.Marshal(c)
	if err != nil {
		return nil, err
	}
	return string(b), nil
}

// Scan implements sql.Scanner by JSON-decoding a string or []byte source into c.
func (c *Content) Scan(src interface{}) error {
	var value Content
	str, ok :=
src.(string) if !ok { byt, ok := src.([]byte) if !ok { return errors.New("Embedded.Scan byte assertion failed") } if err := json.Unmarshal(byt, &value); err != nil { return err } } else { if err := json.Unmarshal([]byte(str), &value); err != nil { return err } } *c = value return nil } func TestEmbeddedScanValuer(t *testing.T) { type HNPost struct { gorm.Model Content } DB.Migrator().DropTable(&HNPost{}) if err := DB.Migrator().AutoMigrate(&HNPost{}); err != nil { t.Fatalf("failed to auto migrate, got error: %v", err) } hnPost := HNPost{Content: Content{Content: "hello world"}} if err := DB.Create(&hnPost).Error; err != nil { t.Errorf("Failed to create got error %v", err) } } func TestEmbeddedRelations(t *testing.T) { type EmbUser struct { gorm.Model Name string Age uint Languages []Language `gorm:"many2many:EmbUserSpeak;"` } type AdvancedUser struct { EmbUser `gorm:"embedded"` Advanced bool } DB.Migrator().DropTable(&AdvancedUser{}) if err := DB.AutoMigrate(&AdvancedUser{}); err != nil { if DB.Dialector.Name() != "sqlite" { t.Errorf("Failed to auto migrate advanced user, got error %v", err) } } } func TestEmbeddedTagSetting(t *testing.T) { type Tag1 struct { Id int64 `gorm:"autoIncrement"` } type Tag2 struct { Id int64 } type EmbeddedTag struct { Tag1 Tag1 `gorm:"Embedded;"` Tag2 Tag2 `gorm:"Embedded;EmbeddedPrefix:t2_"` Name string } DB.Migrator().DropTable(&EmbeddedTag{}) err := DB.Migrator().AutoMigrate(&EmbeddedTag{}) AssertEqual(t, err, nil) t1 := EmbeddedTag{Name: "embedded_tag"} err = DB.Save(&t1).Error AssertEqual(t, err, nil) if t1.Tag1.Id == 0 { t.Errorf("embedded struct's primary field should be rewritten") } } ================================================ FILE: tests/error_translator_test.go ================================================ package tests_test import ( "errors" "testing" "gorm.io/gorm" "gorm.io/gorm/utils/tests" ) func TestDialectorWithErrorTranslatorSupport(t *testing.T) { // it shouldn't translate error when the TranslateError flag 
is false translatedErr := errors.New("translated error") untranslatedErr := errors.New("some random error") db, _ := gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}) err := db.AddError(untranslatedErr) if !errors.Is(err, untranslatedErr) { t.Fatalf("expected err: %v got err: %v", untranslatedErr, err) } // it should translate error when the TranslateError flag is true db, _ = gorm.Open(tests.DummyDialector{TranslatedErr: translatedErr}, &gorm.Config{TranslateError: true}) err = db.AddError(untranslatedErr) if !errors.Is(err, translatedErr) { t.Fatalf("expected err: %v got err: %v", translatedErr, err) } } func TestSupportedDialectorWithErrDuplicatedKey(t *testing.T) { type City struct { gorm.Model Name string `gorm:"unique"` } db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) if err != nil { t.Fatalf("failed to connect database, got error %v", err) } dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true} if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { return } DB.Migrator().DropTable(&City{}) if err = db.AutoMigrate(&City{}); err != nil { t.Fatalf("failed to migrate cities table, got error: %v", err) } err = db.Create(&City{Name: "Kabul"}).Error if err != nil { t.Fatalf("failed to create record: %v", err) } err = db.Create(&City{Name: "Kabul"}).Error if !errors.Is(err, gorm.ErrDuplicatedKey) { t.Fatalf("expected err: %v got err: %v", gorm.ErrDuplicatedKey, err) } } func TestSupportedDialectorWithErrForeignKeyViolated(t *testing.T) { tidbSkip(t, "not support the foreign key feature") type City struct { gorm.Model Name string `gorm:"unique"` } type Museum struct { gorm.Model Name string `gorm:"unique"` CityID uint City City `gorm:"Constraint:OnUpdate:CASCADE,OnDelete:CASCADE;FOREIGNKEY:CityID;References:ID"` } db, err := OpenTestConnection(&gorm.Config{TranslateError: true}) if err != nil { t.Fatalf("failed to connect database, got error %v", err) 
} dialectors := map[string]bool{"sqlite": true, "postgres": true, "gaussdb": true, "mysql": true, "sqlserver": true} if supported, found := dialectors[db.Dialector.Name()]; !(found && supported) { return } DB.Migrator().DropTable(&City{}, &Museum{}) if err = db.AutoMigrate(&City{}, &Museum{}); err != nil { t.Fatalf("failed to migrate countries & cities tables, got error: %v", err) } city := City{Name: "Amsterdam"} err = db.Create(&city).Error if err != nil { t.Fatalf("failed to create city: %v", err) } err = db.Create(&Museum{Name: "Eye Filmmuseum", CityID: city.ID}).Error if err != nil { t.Fatalf("failed to create museum: %v", err) } err = db.Create(&Museum{Name: "Dungeon", CityID: 123}).Error if !errors.Is(err, gorm.ErrForeignKeyViolated) { t.Fatalf("expected err: %v got err: %v", gorm.ErrForeignKeyViolated, err) } } ================================================ FILE: tests/gaussdb_test.go ================================================ package tests_test import ( "testing" "time" "github.com/google/uuid" "github.com/lib/pq" "gorm.io/gorm" "gorm.io/gorm/clause" . 
"gorm.io/gorm/utils/tests" ) func TestGaussDBReturningIDWhichHasStringType(t *testing.T) { t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function") if DB.Dialector.Name() != "gaussdb" { t.Skip() } type Yasuo struct { // TODO: function gen_random_uuid() does not exist ID string `gorm:"default:gen_random_uuid()"` Name string CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` } if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { t.Errorf("Failed to create extension pgcrypto, got error %v", err) } DB.Migrator().DropTable(&Yasuo{}) if err := DB.AutoMigrate(&Yasuo{}); err != nil { t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) } yasuo := Yasuo{Name: "jinzhu"} if err := DB.Create(&yasuo).Error; err != nil { t.Fatalf("should be able to create data, but got %v", err) } if yasuo.ID == "" { t.Fatal("should be able to has ID, but got zero value") } var result Yasuo if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } yasuo.Name = "jinzhu1" if err := DB.Save(&yasuo).Error; err != nil { t.Errorf("Failed to update date, got error %v", err) } if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" { t.Errorf("No error should happen, but got %v", err) } } func TestGaussDB(t *testing.T) { t.Skipf("This test case skipped, because of gaussdb not support pgcrypto extension and gen_random_uuid() function") if DB.Dialector.Name() != "gaussdb" { t.Skip() } type Harumph struct { gorm.Model Name string `gorm:"check:name_checker,name <> ''"` // TODO: function gen_random_uuid() does not exist Test 
uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` Things pq.StringArray `gorm:"type:text[]"` } if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { t.Errorf("Failed to create extension pgcrypto, got error %v", err) } DB.Migrator().DropTable(&Harumph{}) if err := DB.AutoMigrate(&Harumph{}); err != nil { t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) } harumph := Harumph{} if err := DB.Create(&harumph).Error; err == nil { t.Fatalf("should failed to create data, name can't be blank") } harumph = Harumph{Name: "jinzhu"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("should be able to create data, but got %v", err) } var result Harumph if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } harumph.Name = "jinzhu1" if err := DB.Save(&harumph).Error; err != nil { t.Errorf("Failed to update date, got error %v", err) } if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { t.Errorf("No error should happen, but got %v", err) } DB.Migrator().DropTable("log_usage") if err := DB.Exec(` CREATE TABLE public.log_usage ( log_id bigint NOT NULL ); ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY ( SEQUENCE NAME public.log_usage_log_id_seq START WITH 1 INCREMENT BY 1 NO MINVALUE NO MAXVALUE CACHE 1 ); `).Error; err != nil { t.Fatalf("failed to create table, got error %v", err) } columns, err := DB.Migrator().ColumnTypes("log_usage") if err != nil { t.Fatalf("failed to get columns, got error %v", err) } hasLogID := false 
for _, column := range columns { if column.Name() == "log_id" { hasLogID = true autoIncrement, ok := column.AutoIncrement() if !ok || !autoIncrement { t.Fatalf("column log_id should be auto incrementment") } } } if !hasLogID { t.Fatalf("failed to found column log_id") } } func TestGaussDBMany2ManyWithDefaultValueUUID(t *testing.T) { t.Skipf("This test case skipped, because of gaussdb does not have 'uuid-ossp' extension") if DB.Dialector.Name() != "gaussdb" { t.Skip() } if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil { t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err) } DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories") DB.AutoMigrate(&Post{}, &Category{}) post := Post{ Title: "Hello World", Categories: []*Category{ {Title: "Coding"}, {Title: "Golang"}, }, } if err := DB.Create(&post).Error; err != nil { t.Errorf("Failed, got error: %v", err) } } func TestGaussDBOnConstraint(t *testing.T) { t.Skipf("This test case skipped, because of gaussdb not support 'ON CONSTRAINT' statement") if DB.Dialector.Name() != "gaussdb" { t.Skip() } type Thing struct { gorm.Model SomeID string OtherID string Data string } DB.Migrator().DropTable(&Thing{}) DB.Migrator().CreateTable(&Thing{}) if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil { t.Error(err) } thing := Thing{ SomeID: "1234", OtherID: "1234", Data: "something", } DB.Create(&thing) thing2 := Thing{ SomeID: "1234", OtherID: "1234", Data: "something else", } result := DB.Clauses(clause.OnConflict{ OnConstraint: "some_id_other_id_unique", UpdateAll: true, }).Create(&thing2) if result.Error != nil { t.Errorf("creating second thing: %v", result.Error) } var things []Thing if err := DB.Find(&things).Error; err != nil { t.Errorf("Failed, got error: %v", err) } if len(things) > 1 { t.Errorf("expected 1 thing got more") } } func TestGaussDBAlterColumnDataType(t *testing.T) { if 
DB.Dialector.Name() != "gaussdb" { t.Skip() } DB.Migrator().DropTable(&Company{}) DB.AutoMigrate(Company{}) if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil { t.Fatalf("failed to alter column from string to int, got error %v", err) } DB.AutoMigrate(Company{}) } ================================================ FILE: tests/generics_test.go ================================================ package tests_test import ( "context" "errors" "fmt" "reflect" "regexp" "sort" "strconv" "strings" "sync" "testing" "github.com/google/uuid" "gorm.io/driver/mysql" "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) func TestGenericsCreate(t *testing.T) { ctx := context.Background() user := User{Name: "TestGenericsCreate", Age: 18} err := gorm.G[User](DB).Create(ctx, &user) if err != nil { t.Fatalf("Create failed: %v", err) } if user.ID == 0 { t.Fatalf("no primary key found for %v", user) } if u, err := gorm.G[User](DB).Where("name = ?", user.Name).First(ctx); err != nil { t.Fatalf("failed to find user, got error: %v", err) } else if u.Name != user.Name || u.ID != user.ID { t.Errorf("found invalid user, got %v, expect %v", u, user) } if u, err := gorm.G[User](DB).Where("name = ?", user.Name).Take(ctx); err != nil { t.Fatalf("failed to find user, got error: %v", err) } else if u.Name != user.Name || u.ID != user.ID { t.Errorf("found invalid user, got %v, expect %v", u, user) } if u, err := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).First(ctx); err != nil { t.Fatalf("failed to find user, got error: %v", err) } else if u.Name != user.Name || u.Age != 0 { t.Errorf("found invalid user, got %v, expect %v", u, user) } if u, err := gorm.G[User](DB).Omit("name").Where("name = ?", user.Name).First(ctx); err != nil { t.Fatalf("failed to find user, got error: %v", err) } else if u.Name != "" || u.Age != user.Age { t.Errorf("found invalid user, got %v, expect %v", u, user) } result := struct { ID int Name string }{} if err 
:= gorm.G[User](DB).Where("name = ?", user.Name).Scan(ctx, &result); err != nil { t.Fatalf("failed to scan user, got error: %v", err) } else if result.Name != user.Name || uint(result.ID) != user.ID { t.Errorf("found invalid user, got %v, expect %v", result, user) } mapResult, err := gorm.G[map[string]interface{}](DB).Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "user_name"}).Take(ctx) if v := mapResult["user_name"]; fmt.Sprint(v) != user.Name { t.Errorf("failed to find map results, got %v, err %v", mapResult, err) } selectOnly := User{Name: "GenericsCreateSelectOnly", Age: 99} if err := gorm.G[User](DB).Select("name").Create(ctx, &selectOnly); err != nil { t.Fatalf("failed to create with Select, got error: %v", err) } if selectOnly.ID == 0 { t.Fatalf("no primary key found for select-only user: %v", selectOnly) } if stored, err := gorm.G[User](DB).Where("id = ?", selectOnly.ID).First(ctx); err != nil { t.Fatalf("failed to reload select-only user, got error: %v", err) } else if stored.Name != selectOnly.Name || stored.Age != 0 { t.Errorf("unexpected select-only user state, got %#v", stored) } omitAge := User{Name: "GenericsCreateOmitAge", Age: 88} if err := gorm.G[User](DB).Omit("age").Create(ctx, &omitAge); err != nil { t.Fatalf("failed to create with Omit, got error: %v", err) } if omitAge.ID == 0 { t.Fatalf("no primary key found for omit-age user: %v", omitAge) } if stored, err := gorm.G[User](DB).Where("id = ?", omitAge.ID).First(ctx); err != nil { t.Fatalf("failed to reload omit-age user, got error: %v", err) } else if stored.Name != omitAge.Name || stored.Age != 0 { t.Errorf("unexpected omit-age user state, got %#v", stored) } } func TestGenericsCreateInBatches(t *testing.T) { batch := []User{ {Name: "GenericsCreateInBatches1"}, {Name: "GenericsCreateInBatches2"}, {Name: "GenericsCreateInBatches3"}, } ctx := context.Background() if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, 2); err != nil { 
t.Fatalf("CreateInBatches failed: %v", err)
	}

	for _, u := range batch {
		if u.ID == 0 {
			t.Fatalf("no primary key found for %v", u)
		}
	}

	count, err := gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Count(ctx, "*")
	if err != nil {
		t.Fatalf("Count failed: %v", err)
	}
	if count != 3 {
		t.Errorf("expected 3 records, got %d", count)
	}

	// Query errors are checked so a failure is not reported as a length mismatch.
	found, err := gorm.G[User](DB).Raw("SELECT * FROM users WHERE name LIKE ?", "GenericsCreateInBatches%").Find(ctx)
	if err != nil {
		t.Fatalf("Raw Find failed: %v", err)
	}
	if len(found) != len(batch) {
		t.Errorf("expected %d from Raw Find, got %d", len(batch), len(found))
	}

	found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Limit(2).Find(ctx)
	if err != nil {
		t.Fatalf("Find with Limit failed: %v", err)
	}
	if len(found) != 2 {
		t.Errorf("expected %d from Find with Limit, got %d", 2, len(found))
	}

	found, err = gorm.G[User](DB).Where("name like ?", "GenericsCreateInBatches%").Offset(2).Limit(2).Find(ctx)
	if err != nil {
		t.Fatalf("Find with Offset and Limit failed: %v", err)
	}
	if len(found) != 1 {
		t.Errorf("expected %d from Find with Offset and Limit, got %d", 1, len(found))
	}
}

// TestGenericsExecAndUpdate inserts rows with Exec, then changes them with
// Update and Updates and verifies the stored values.
func TestGenericsExecAndUpdate(t *testing.T) {
	ctx := context.Background()

	name := "GenericsExec"
	if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO users(name) VALUES(?)", name); err != nil {
		t.Fatalf("Exec insert failed: %v", err)
	}

	name2 := "GenericsExec2"
	if err := gorm.G[User](DB).Exec(ctx, "INSERT INTO ?(name) VALUES(?)", clause.Table{Name: clause.CurrentTable}, name2); err != nil {
		t.Fatalf("Exec insert failed: %v", err)
	}

	u, err := gorm.G[User](DB).Table("users as u").Where("u.name = ?", name).First(ctx)
	if err != nil {
		t.Fatalf("failed to find user, got error: %v", err)
	} else if u.Name != name || u.ID == 0 {
		t.Errorf("found invalid user, got %v", u)
	}

	name += "Update"
	rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Update(ctx, "name", name)
	if err != nil {
		t.Fatalf("Update failed: %v", err)
	}
	if rows != 1 {
		t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
	}

	nu, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
	if err != nil {
		t.Fatalf("failed to find user, got error: %v", err)
	} else if nu.Name != name || u.ID != nu.ID {
		t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
	}

	rows, err = gorm.G[User](DB).Where("id = ?", u.ID).Updates(ctx, User{Name: "GenericsExecUpdates", Age: 18})
	if err != nil {
		t.Fatalf("Updates failed: %v", err)
	}
	if rows != 1 {
		t.Fatalf("failed to get affected rows, got %d, should be %d", rows, 1)
	}

	nu, err = gorm.G[User](DB).Where("id = ?", u.ID).Last(ctx)
	if err != nil {
		t.Fatalf("failed to find user, got error: %v", err)
	} else if nu.Name != "GenericsExecUpdates" || nu.Age != 18 || u.ID != nu.ID {
		t.Fatalf("found invalid user, got %v, expect %v", nu.ID, u.ID)
	}
}

// TestGenericsRow checks Row and Rows for raw SQL, Select-restricted and
// full-model queries.
func TestGenericsRow(t *testing.T) {
	ctx := context.Background()

	user := User{Name: "GenericsRow"}
	if err := gorm.G[User](DB).Create(ctx, &user); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	rawSQLUserRow := gorm.G[User](DB).Raw("SELECT name FROM ? WHERE id = ?", clause.Table{Name: clause.CurrentTable}, user.ID).Row(ctx)
	var name string
	if err := rawSQLUserRow.Scan(&name); err != nil {
		t.Fatalf("rawSQLUserRow scan failed: %v", err)
	}
	if name != user.Name {
		t.Errorf("expected %s, got %s", user.Name, name)
	}

	var scannedUserName string
	selectUserRow := gorm.G[User](DB).Select("name").Where("name = ?", user.Name).Row(ctx)
	if err := selectUserRow.Scan(&scannedUserName); err != nil {
		t.Fatalf("selectUserRow scan failed: %v", err)
	}
	// Compare the value scanned from the Select row; re-checking name would
	// only re-verify the raw query above.
	if scannedUserName != user.Name {
		t.Errorf("expected %s, got %s", user.Name, scannedUserName)
	}

	user2 := User{Name: "GenericsRow2"}
	if err := gorm.G[User](DB).Create(ctx, &user2); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	// Each *sql.Rows is closed via defer so an early t.Fatalf does not leave
	// it open, and Err() is checked to catch errors that end iteration early.
	rawSQLUserRows, err := gorm.G[User](DB).Raw("SELECT name FROM users WHERE id IN ?", []uint{user.ID, user2.ID}).Rows(ctx)
	if err != nil {
		t.Fatalf("rawSQLUserRows failed: %v", err)
	}
	defer rawSQLUserRows.Close()

	count := 0
	for rawSQLUserRows.Next() {
		var name string
		if err := rawSQLUserRows.Scan(&name); err != nil {
			t.Fatalf("rawSQLUserRows.Scan failed: %v", err)
		}
		count++
	}
	if err := rawSQLUserRows.Err(); err != nil {
		t.Fatalf("rawSQLUserRows iteration failed: %v", err)
	}
	if count != 2 {
		t.Errorf("expected 2 rows, got %d", count)
	}

	selectNameUserRows, err := gorm.G[User](DB).Select("name").Where("id IN ?", []uint{user.ID, user2.ID}).Rows(ctx)
	if err != nil {
		t.Fatalf("selectNameUserRows failed: %v", err)
	}
	defer selectNameUserRows.Close()

	count = 0
	for selectNameUserRows.Next() {
		var name string
		if err := selectNameUserRows.Scan(&name); err != nil {
			t.Fatalf("selectNameUserRows.Scan failed: %v", err)
		}
		count++
	}
	if err := selectNameUserRows.Err(); err != nil {
		t.Fatalf("selectNameUserRows iteration failed: %v", err)
	}
	if count != 2 {
		t.Errorf("expected 2 rows, got %d", count)
	}

	fullUserRows, err := gorm.G[User](DB).Where("id IN ?", []uint{user.ID, user2.ID}).Rows(ctx)
	if err != nil {
		t.Fatalf("Rows failed: %v", err)
	}
	defer fullUserRows.Close()

	count = 0
	for fullUserRows.Next() {
		var scannedUser User
		if err := DB.ScanRows(fullUserRows, &scannedUser); err != nil {
			t.Fatalf("DB.ScanRows failed: %v", err)
		}
		count++
	}
	if err := fullUserRows.Err(); err != nil {
		t.Fatalf("fullUserRows iteration failed: %v", err)
	}
	if count != 2 {
		t.Errorf("expected 2 rows, got %d", count)
	}
}

// TestGenericsDelete deletes a user by id and checks it can no longer be found.
func TestGenericsDelete(t *testing.T) {
	ctx := context.Background()

	u := User{Name: "GenericsDelete"}
	if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
		t.Fatalf("Create failed: %v", err)
	}

	rows, err := gorm.G[User](DB).Where("id = ?", u.ID).Delete(ctx)
	if err != nil {
		t.Fatalf("Delete failed: %v", err)
	}
	if rows != 1 {
		t.Errorf("expected 1 row deleted, got %d", rows)
	}

	_, err = gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
	if !errors.Is(err, gorm.ErrRecordNotFound) {
		t.Fatalf("User after delete failed: %v", err)
	}
}

// TestGenericsFindInBatches checks that FindInBatches never yields more than
// the batch size and visits every matching row.
func TestGenericsFindInBatches(t *testing.T) {
	ctx := context.Background()

	users := []User{
		{Name: "GenericsFindBatchA"},
		{Name: "GenericsFindBatchB"},
		{Name: "GenericsFindBatchC"},
		{Name: "GenericsFindBatchD"},
		{Name: "GenericsFindBatchE"},
	}
	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
		t.Fatalf("CreateInBatches failed: %v", err)
	}

	total := 0
	err := gorm.G[User](DB).Where("name like ?", "GenericsFindBatch%").FindInBatches(ctx, 2, func(chunk []User, batch int) error {
		if len(chunk) > 2 {
			t.Errorf("batch size exceed 2: got %d", len(chunk))
		}

		total += len(chunk)
		return nil
	})
	if err != nil {
		t.Fatalf("FindInBatches failed: %v", err)
	}

	if total != len(users) {
		t.Errorf("expected total %d, got %d", len(users), total)
	}
}

// TestGenericsScopes checks Scopes, Not and Or on the generics API.
func TestGenericsScopes(t *testing.T) {
	ctx := context.Background()

	users := []User{{Name: "GenericsScopes1"}, {Name: "GenericsScopes2"}, {Name: "GenericsScopes3"}}
	err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users))
	if err != nil {
		t.Fatalf("CreateInBatches failed: %v", err)
	}

	filterName1 := func(stmt *gorm.Statement) {
		stmt.Where("name = ?", "GenericsScopes1")
	}

	results, err := gorm.G[User](DB).Scopes(filterName1).Find(ctx)
	if err != nil {
		t.Fatalf("Scopes failed: %v", err)
	}
	if len(results) != 1 || results[0].Name != "GenericsScopes1" {
		t.Fatalf("Scopes expected 1, got %d", len(results))
	}

	notResult, err := gorm.G[User](DB).Where("name like ?", "GenericsScopes%").Not("name = ?", "GenericsScopes1").Order("name").Find(ctx)
	if err != nil {
		t.Fatalf("Not failed: %v", err)
	}
	if len(notResult) != 2 {
		t.Fatalf("expected 2 results, got %d", len(notResult))
	} else if notResult[0].Name != "GenericsScopes2" || notResult[1].Name != "GenericsScopes3" {
		t.Fatalf("expected names 'GenericsScopes2' and 'GenericsScopes3', got %s and %s", notResult[0].Name, notResult[1].Name)
	}

	// The Or branch reports its own result set and expected names.
	orResult, err := gorm.G[User](DB).Or("name = ?", "GenericsScopes1").Or("name = ?", "GenericsScopes2").Order("name").Find(ctx)
	if err != nil {
		t.Fatalf("Or failed: %v", err)
	}
	if len(orResult) != 2 {
		t.Fatalf("expected 2 results, got %d", len(orResult))
	} else if orResult[0].Name != "GenericsScopes1" || orResult[1].Name != "GenericsScopes2" {
		t.Fatalf("expected names 'GenericsScopes1' and 'GenericsScopes2', got %s and %s", orResult[0].Name, orResult[1].Name)
	}
}

// TestGenericsJoins checks inner, left, aliased and subquery joins on the
// Company association.
func TestGenericsJoins(t *testing.T) {
	ctx := context.Background()
	db := gorm.G[User](DB)

	u := User{Name: "GenericsJoins", Company: Company{Name: "GenericsCompany"}}
	u2 := User{Name: "GenericsJoins_2", Company: Company{Name: "GenericsCompany_2"}}
	u3 := User{Name: "GenericsJoins_3", Company: Company{Name: "GenericsCompany_3"}}
	if err := db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10); err != nil {
		t.Fatalf("CreateInBatches failed: %v", err)
	}

	// Inner JOIN + WHERE
	result, err := db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
		db.Where("?.name = ?", joinTable, u.Company.Name)
		return nil
}).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } if result.Name != u.Name || result.Company.Name != u.Company.Name { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } // Inner JOIN + WHERE with map result, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { db.Where(map[string]any{"name": u.Company.Name}) return nil }).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } if result.Name != u.Name || result.Company.Name != u.Company.Name { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } // Left JOIN w/o WHERE result, err = db.Joins(clause.LeftJoin.Association("Company"), nil).Where(map[string]any{"name": u.Name}).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } if result.Name != u.Name || result.Company.Name != u.Company.Name { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } // Left JOIN + Alias WHERE result, err = db.Joins(clause.LeftJoin.Association("Company").As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { if joinTable.Name != "t" { t.Fatalf("Join table should be t, but got %v", joinTable.Name) } db.Where("?.name = ?", joinTable, u.Company.Name) return nil }).Where(map[string]any{"name": u.Name}).First(ctx) if err != nil { t.Fatalf("Joins failed: %v", err) } if result.Name != u.Name || result.Company.Name != u.Company.Name { t.Fatalf("Joins expected %s, got %+v", u.Name, result) } // Raw Subquery JOIN + WHERE result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB)).As("t"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error { if joinTable.Name != "t" { t.Fatalf("Join table should be t, but got %v", joinTable.Name) } db.Where("?.name = ?", joinTable, u.Company.Name) return nil }, ).Where(map[string]any{"name": u2.Name}).First(ctx) if err != nil { t.Fatalf("Raw subquery join failed: %v", err) } if result.Name != u2.Name || 
result.Company.Name != u.Company.Name || result.Company.ID == 0 {
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
	}

	// Raw Subquery JOIN + WHERE + Select
	result, err = db.Joins(clause.LeftJoin.AssociationFrom("Company", gorm.G[Company](DB).Select("Name")).As("t"),
		func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
			if joinTable.Name != "t" {
				t.Fatalf("Join table should be t, but got %v", joinTable.Name)
			}
			db.Where("?.name = ?", joinTable, u.Company.Name)
			return nil
		},
	).Where(map[string]any{"name": u2.Name}).First(ctx)
	if err != nil {
		t.Fatalf("Raw subquery join failed: %v", err)
	}
	if result.Name != u2.Name || result.Company.Name != u.Company.Name || result.Company.ID != 0 {
		t.Fatalf("Joins expected %s, got %+v", u.Name, result)
	}

	_, err = db.Joins(clause.Has("Company"), func(db gorm.JoinBuilder, joinTable clause.Table, curTable clause.Table) error {
		return errors.New("join error")
	}).First(ctx)
	if err == nil {
		t.Fatalf("Joins should got error, but got nil")
	}
}

// TestGenericsNestedJoins loads users together with Manager, Manager.Company,
// Manager.NamedPet.Toy, NamedPet.Toy and an aliased NamedPet via chained
// LEFT JOINs, then compares every loaded level against the inserted fixtures.
func TestGenericsNestedJoins(t *testing.T) {
	users := []User{
		{
			Name: "generics-nested-joins-1",
			Manager: &User{
				Name: "generics-nested-joins-manager-1",
				Company: Company{
					Name: "generics-nested-joins-manager-company-1",
				},
				NamedPet: &Pet{
					Name: "generics-nested-joins-manager-namepet-1",
					Toy: Toy{
						Name: "generics-nested-joins-manager-namepet-toy-1",
					},
				},
			},
			NamedPet: &Pet{Name: "generics-nested-joins-namepet-1", Toy: Toy{Name: "generics-nested-joins-namepet-toy-1"}},
		},
		{
			Name:     "generics-nested-joins-2",
			Manager:  GetUser("generics-nested-joins-manager-2", Config{Company: true, NamedPet: true}),
			NamedPet: &Pet{Name: "generics-nested-joins-namepet-2", Toy: Toy{Name: "generics-nested-joins-namepet-toy-2"}},
		},
	}

	ctx := context.Background()
	db := gorm.G[User](DB)
	// A failed insert leaves every ID at zero and makes the join checks
	// below meaningless, so stop here instead of reporting confusing diffs.
	if err := db.CreateInBatches(ctx, &users, 100); err != nil {
		t.Fatalf("CreateInBatches failed: %v", err)
	}

	var userIDs []uint
	for _, user := range users {
		userIDs = append(userIDs, user.ID)
	}

	users2, err := db.Joins(clause.LeftJoin.Association("Manager"), nil).
		Joins(clause.LeftJoin.Association("Manager.Company"), nil).
		Joins(clause.LeftJoin.Association("Manager.NamedPet.Toy"), nil).
		Joins(clause.LeftJoin.Association("NamedPet.Toy"), nil).
		Joins(clause.LeftJoin.Association("NamedPet").As("t"), nil).
		Where(map[string]any{"id": userIDs}).Find(ctx)

	if err != nil {
		t.Fatalf("Failed to load with joins, got error: %v", err)
	} else if len(users2) != len(users) {
		t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
	}

	// Sort both sides by ID so they can be compared pairwise.
	sort.Slice(users2, func(i, j int) bool { return users2[i].ID > users2[j].ID })
	sort.Slice(users, func(i, j int) bool { return users[i].ID > users[j].ID })

	for idx, user := range users {
		// user
		CheckUser(t, user, users2[idx])
		if users2[idx].Manager == nil {
			t.Fatalf("Failed to load Manager")
		}
		// manager
		CheckUser(t, *user.Manager, *users2[idx].Manager)
		// user pet
		if users2[idx].NamedPet == nil {
			t.Fatalf("Failed to load NamedPet")
		}
		CheckPet(t, *user.NamedPet, *users2[idx].NamedPet)
		// manager pet
		if users2[idx].Manager.NamedPet == nil {
			t.Fatalf("Failed to load Manager's NamedPet")
		}
		CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet)
	}
}

// TestGenericsPreloads covers Preload with no builder, with a filtering
// builder, with a builder that returns an error, and with LimitPerRecord
// (optionally combined with Order) on the Pets and Friends associations.
func TestGenericsPreloads(t *testing.T) {
	ctx := context.Background()
	db := gorm.G[User](DB)

	u := *GetUser("GenericsPreloads_1", Config{Company: true, Pets: 3, Friends: 7})
	u2 := *GetUser("GenericsPreloads_2", Config{Company: true, Pets: 5, Friends: 5})
	u3 := *GetUser("GenericsPreloads_3", Config{Company: true, Pets: 7, Friends: 3})
	names := []string{u.Name, u2.Name, u3.Name}

	if err := db.CreateInBatches(ctx, &[]User{u3, u, u2}, 10); err != nil {
		t.Fatalf("CreateInBatches failed: %v", err)
	}

	result, err := db.Preload("Company", nil).Preload("Pets", nil).Where("name = ?", u.Name).First(ctx)
	if err != nil {
		t.Fatalf("Preload failed: %v", err)
	}

	if result.Name != u.Name || result.Company.Name != u.Company.Name || len(result.Pets) != len(u.Pets) {
		t.Fatalf("Preload expected %s, got %+v", u.Name, result)
	}

	results, err := db.Preload("Company", func(db gorm.PreloadBuilder) error {
		db.Where("name = ?", u.Company.Name)
		return nil
	}).Where("name in ?", names).Find(ctx)
	if err != nil {
		t.Fatalf("Preload failed: %v", err)
	}
	for _, result := range results {
		if result.Name == u.Name {
			if result.Company.Name != u.Company.Name {
				t.Fatalf("Preload user %v company should be %v, but got %+v", u.Name, u.Company.Name, result.Company.Name)
			}
		} else if result.Company.Name != "" {
			t.Fatalf("Preload other company should not loaded, user %v company expect %v but got %+v", u.Name, u.Company.Name, result.Company.Name)
		}
	}

	_, err = db.Preload("Company", func(db gorm.PreloadBuilder) error {
		return errors.New("preload error")
	}).Where("name in ?", names).Find(ctx)
	if err == nil {
		t.Fatalf("Preload should failed, but got nil")
	}

	if DB.Dialector.Name() == "mysql" {
		// mysql 5.7 doesn't support row_number()
		if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
			return
		}
	}

	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
		db.LimitPerRecord(5)
		return nil
	}).Where("name in ?", names).Find(ctx)
	// Without this check a failing query yields empty results and the loop
	// below passes vacuously.
	if err != nil {
		t.Fatalf("Preload with LimitPerRecord failed: %v", err)
	}

	for _, result := range results {
		if result.Name == u.Name {
			if len(result.Pets) != len(u.Pets) {
				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
			}
		} else if len(result.Pets) != 5 {
			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
		}
	}

	if DB.Dialector.Name() == "sqlserver" {
		// sqlserver doesn't support order by in subquery
		return
	}
	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
		db.Order("name desc").LimitPerRecord(5)
		return nil
	}).Where("name in ?", names).Find(ctx)
	if err != nil {
		t.Fatalf("Preload with Order and LimitPerRecord failed: %v", err)
	}

	for _, result := range results {
		if result.Name == u.Name {
			if len(result.Pets) != len(u.Pets) {
				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
			}
		} else if len(result.Pets) != 5 {
			t.Fatalf("Preload user %v pets should be 5, but got %+v", result.Name, result.Pets)
		}
		for i := 1; i < len(result.Pets); i++ {
			if result.Pets[i-1].Name < result.Pets[i].Name {
				t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
			}
		}
	}

	results, err = db.Preload("Pets", func(db gorm.PreloadBuilder) error {
		db.Order("name").LimitPerRecord(5)
		return nil
	}).Preload("Friends", func(db gorm.PreloadBuilder) error {
		db.Order("name")
		return nil
	}).Where("name in ?", names).Find(ctx)
	if err != nil {
		t.Fatalf("Preload Pets and Friends failed: %v", err)
	}

	for _, result := range results {
		if result.Name == u.Name {
			if len(result.Pets) != len(u.Pets) {
				t.Fatalf("Preload user %v pets should be %v, but got %+v", u.Name, u.Pets, result.Pets)
			}
			if len(result.Friends) != len(u.Friends) {
				t.Fatalf("Preload user %v friends should be %d, but got %d", u.Name, len(u.Friends), len(result.Friends))
			}
		} else if len(result.Pets) != 5 || len(result.Friends) == 0 {
			t.Fatalf("Preload user %v should have 5 pets and some friends, but got %d pets, %d friends", result.Name, len(result.Pets), len(result.Friends))
		}
		for i := 1; i < len(result.Pets); i++ {
			if result.Pets[i-1].Name > result.Pets[i].Name {
				t.Fatalf("Preload user %v pets not ordered correctly, last %v, cur %v", result.Name, result.Pets[i-1], result.Pets[i])
			}
		}
		// Friends were preloaded with Order("name"), so check their order
		// independently of Pets.
		for i := 1; i < len(result.Friends); i++ {
			if result.Friends[i-1].Name > result.Friends[i].Name {
				t.Fatalf("Preload user %v friends not ordered correctly, last %v, cur %v", result.Name, result.Friends[i-1].Name, result.Friends[i].Name)
			}
		}
	}
}

// TestGenericsNestedPreloads preloads Pets.Toy and Friends.Pets for a freshly
// created user. Where the dialect supports it, it also checks that
// LimitPerRecord(3) caps the nested Friends.Pets preload.
func TestGenericsNestedPreloads(t *testing.T) {
	user := *GetUser("generics_nested_preload", Config{Pets: 2})
	user.Friends = []*User{GetUser("generics_nested_preload", Config{Pets: 5})}

	ctx := context.Background()
	db := gorm.G[User](DB)

	for idx, pet := range user.Pets {
		pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)}
	}

	if err := db.Create(ctx, &user); err != nil {
		t.Fatalf("errors happened when create: %v", err)
	}

	user2, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
		return nil
	}).Where(user.ID).Take(ctx)
	if err != nil {
		t.Fatalf("failed to nested preload user, got error: %v", err)
	}
	CheckUser(t, user2, user)
	// Inspect the loaded user2, not the fixture, so a missing preload is caught.
	if len(user2.Pets) == 0 || len(user2.Friends) == 0 || len(user2.Friends[0].Pets) == 0 {
		t.Fatalf("failed to nested preload")
	}

	if DB.Dialector.Name() == "mysql" {
		// mysql 5.7 doesn't support row_number()
		if strings.HasPrefix(DB.Dialector.(*mysql.Dialector).ServerVersion, "5.7") {
			return
		}
	}
	if DB.Dialector.Name() == "sqlserver" {
		// sqlserver doesn't support order by in subquery
		return
	}

	user3, err := db.Preload("Pets.Toy", nil).Preload("Friends.Pets", func(db gorm.PreloadBuilder) error {
		db.LimitPerRecord(3)
		return nil
	}).Where(user.ID).Take(ctx)
	if err != nil {
		t.Fatalf("failed to nested preload user, got error: %v", err)
	}
	CheckUser(t, user3, user)

	if len(user3.Friends) != 1 || len(user3.Friends[0].Pets) != 3 {
		t.Errorf("failed to nested preload with limit per record")
	}
}

// TestGenericsDistinct checks that Distinct("name") collapses duplicate
// names to one row each.
func TestGenericsDistinct(t *testing.T) {
	ctx := context.Background()

	batch := []User{
		{Name: "GenericsDistinctDup"},
		{Name: "GenericsDistinctDup"},
		{Name: "GenericsDistinctUnique"},
	}
	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
		t.Fatalf("CreateInBatches failed: %v", err)
	}

	results, err := gorm.G[User](DB).Where("name like ?", "GenericsDistinct%").Distinct("name").Find(ctx)
	if err != nil {
		t.Fatalf("Distinct Find failed: %v", err)
	}

	if len(results) != 2 {
		t.Errorf("expected 2 distinct names, got %d", len(results))
	}

	var names []string
	for _, u := range results {
		names = append(names, u.Name)
	}
	sort.Strings(names)
	expected := []string{"GenericsDistinctDup", "GenericsDistinctUnique"}
	if !reflect.DeepEqual(names, expected) {
		t.Errorf("expected names %v, got %v", expected, names)
	}
}

// TestGenericsSetCreate inserts a row built only from Set assignments and
// reads it back.
func TestGenericsSetCreate(t *testing.T) {
	ctx := context.Background()

	name := "GenericsSetCreate"
	age := uint(21)

	err := gorm.G[User](DB).Set(
		clause.Assignment{Column: clause.Column{Name: "name"}, Value: name},
		clause.Assignment{Column: clause.Column{Name: "age"}, Value: age},
	).Create(ctx)
	if err != nil {
		t.Fatalf("Set Create failed: %v", err)
	}

	u, err := gorm.G[User](DB).Where("name = ?", name).First(ctx)
	if err != nil {
		t.Fatalf("failed to find created user: %v", err)
	}
	if u.ID == 0 || u.Name != name || u.Age != age {
		t.Fatalf("created user mismatch, got %+v", u)
	}
}

// TestGenericsSetUpdate updates an existing row through Where(...).Set(...).Update
// and verifies both the affected-row count and the persisted values.
func TestGenericsSetUpdate(t *testing.T) {
	ctx := context.Background()

	// prepare
	u := User{Name: "GenericsSetUpdate_Before", Age: 30}
	if err := gorm.G[User](DB).Create(ctx, &u); err != nil {
		t.Fatalf("prepare user failed: %v", err)
	}

	// update with Set after chain
	newName := "GenericsSetUpdate_After"
	newAge := uint(31)
	rows, err := gorm.G[User](DB).
		Where("id = ?", u.ID).
		Set(
			clause.Assignment{Column: clause.Column{Name: "name"}, Value: newName},
			clause.Assignment{Column: clause.Column{Name: "age"}, Value: newAge},
		).
		Update(ctx)
	if err != nil {
		t.Fatalf("Set Update failed: %v", err)
	}
	if rows != 1 {
		t.Fatalf("expected 1 row affected, got %d", rows)
	}

	nu, err := gorm.G[User](DB).Where("id = ?", u.ID).First(ctx)
	if err != nil {
		t.Fatalf("failed to query updated user: %v", err)
	}
	if nu.Name != newName || nu.Age != newAge {
		t.Fatalf("updated user mismatch, got %+v", nu)
	}
}

// TestGenericsGroupHaving checks that Group("name").Having("COUNT(id) > ?", 1)
// returns only the duplicated name.
func TestGenericsGroupHaving(t *testing.T) {
	ctx := context.Background()

	batch := []User{
		{Name: "GenericsGroupHavingMulti"},
		{Name: "GenericsGroupHavingMulti"},
		{Name: "GenericsGroupHavingSingle"},
	}
	if err := gorm.G[User](DB).CreateInBatches(ctx, &batch, len(batch)); err != nil {
		t.Fatalf("CreateInBatches failed: %v", err)
	}

	grouped, err := gorm.G[User](DB).Select("name").Where("name like ?", "GenericsGroupHaving%").Group("name").Having("COUNT(id) > ?", 1).Find(ctx)
	if err != nil {
		t.Fatalf("Group+Having Find failed: %v", err)
	}

	if len(grouped) != 1 {
		t.Errorf("expected 1 group with count>1, got %d", len(grouped))
	} else if grouped[0].Name != "GenericsGroupHavingMulti" {
		t.Errorf("expected group name 'GenericsGroupHavingMulti', got '%s'", grouped[0].Name)
	}
}

// TestGenericsSubQuery passes a generics query as the argument of an
// `IN (?)` condition, both with a plain filter and with an Or branch.
func TestGenericsSubQuery(t *testing.T) {
	ctx := context.Background()
	users := []User{
		{Name: "GenericsSubquery_1", Age: 10},
		{Name: "GenericsSubquery_2", Age: 20},
		{Name: "GenericsSubquery_3", Age: 30},
		{Name: "GenericsSubquery_4", Age: 40},
	}

	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, len(users)); err != nil {
		t.Fatalf("CreateInBatches failed: %v", err)
	}

	results, err := gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name LIKE ?", "GenericsSubquery%")).Find(ctx)
	if err != nil {
		t.Fatalf("got error: %v", err)
	}

	if len(results) != 4 {
		t.Errorf("Four users should be found, instead found %d", len(results))
	}

	results, err = gorm.G[User](DB).Where("name IN (?)", gorm.G[User](DB).Select("name").Where("name IN ?", []string{"GenericsSubquery_1", "GenericsSubquery_2"}).Or("name = ?", "GenericsSubquery_3")).Find(ctx)
	if err != nil {
		t.Fatalf("got error: %v", err)
	}

	if len(results) != 3 {
		t.Errorf("Three users should be found, instead found %d", len(results))
	}
}

// TestGenericsUpsert exercises clause.OnConflict through gorm.G. DoNothing
// must not duplicate the row, and DoUpdates on "code" must overwrite "name".
func TestGenericsUpsert(t *testing.T) {
	ctx := context.Background()
	lang := Language{Code: "upsert", Name: "Upsert"}

	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang); err != nil {
		t.Fatalf("failed to upsert, got %v", err)
	}

	lang2 := Language{Code: "upsert", Name: "Upsert"}
	if err := gorm.G[Language](DB, clause.OnConflict{DoNothing: true}).Create(ctx, &lang2); err != nil {
		t.Fatalf("failed to upsert, got %v", err)
	}

	langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx)
	if err != nil {
		t.Errorf("no error should happen when find languages with code, but got %v", err)
	} else if len(langs) != 1 {
		t.Errorf("should only find only 1 languages, but got %+v", langs)
	}

	lang3 := Language{Code: "upsert", Name: "Upsert"}
	if err := gorm.G[Language](DB, clause.OnConflict{
		Columns:   []clause.Column{{Name: "code"}},
		DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}),
	}).Create(ctx, &lang3); err != nil {
		t.Fatalf("failed to upsert, got %v", err)
	}

	if langs, err := gorm.G[Language](DB).Where("code = ?", lang.Code).Find(ctx); err != nil {
		t.Errorf("no error should happen when find languages with code, but got %v", err)
	} else if len(langs) != 1 {
		t.Errorf("should only find only 1 languages, but got %+v", langs)
	} else if langs[0].Name != "upsert-new" {
		t.Errorf("should update name on conflict, but got name %+v", langs[0].Name)
	}
}

// TestGenericsWithResult checks that gorm.WithResult captures RowsAffected
// from a CreateInBatches call.
func TestGenericsWithResult(t *testing.T) {
	ctx := context.Background()
	users := []User{{Name: "TestGenericsWithResult", Age: 18}, {Name: "TestGenericsWithResult2", Age: 18}}

	result := gorm.WithResult()
	err := gorm.G[User](DB, result).CreateInBatches(ctx, &users, 2)
	if err != nil {
		t.Errorf("failed to create users WithResult: %v", err)
	}

	if result.RowsAffected != 2 {
		t.Errorf("failed to get affected rows, got %d, should be %d", result.RowsAffected, 2)
	}
}

// TestGenericsReuse shares one partially built generics query across
// goroutines. Each goroutine chains further conditions onto it, and none of
// them may see the others' conditions.
func TestGenericsReuse(t *testing.T) {
	ctx := context.Background()
	users := []User{{Name: "TestGenericsReuse1", Age: 18}, {Name: "TestGenericsReuse2", Age: 18}}

	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2); err != nil {
		// Later lookups rely on the generated IDs.
		t.Fatalf("failed to create users: %v", err)
	}

	reusedb := gorm.G[User](DB).Where("name like ?", "TestGenericsReuse%")

	sg := sync.WaitGroup{}
	for i := 0; i < 5; i++ {
		sg.Add(1)

		go func() {
			// Deferred so the WaitGroup is released on every exit path.
			defer sg.Done()

			// Only t.Errorf (never t.Fatalf) may be used off the test goroutine.
			if u1, err := reusedb.Where("id = ?", users[0].ID).First(ctx); err != nil {
				t.Errorf("failed to find user, got error: %v", err)
			} else if u1.Name != users[0].Name || u1.ID != users[0].ID {
				t.Errorf("found invalid user, got %v, expect %v", u1, users[0])
			}

			if u2, err := reusedb.Where("id = ?", users[1].ID).First(ctx); err != nil {
				t.Errorf("failed to find user, got error: %v", err)
			} else if u2.Name != users[1].Name || u2.ID != users[1].ID {
				t.Errorf("found invalid user, got %v, expect %v", u2, users[1])
			}

			if found, err := reusedb.Where("id IN ?", []uint{users[0].ID, users[1].ID}).Find(ctx); err != nil {
				t.Errorf("failed to find user, got error: %v", err)
			} else if len(found) != 2 {
				t.Errorf("should find 2 users, but got %d", len(found))
			}
		}()
	}
	sg.Wait()
}

// TestGenericsWithTransaction checks that rows written through gorm.G on a
// transaction are visible inside it and gone after Rollback.
func TestGenericsWithTransaction(t *testing.T) {
	ctx := context.Background()
	tx := DB.Begin()
	if tx.Error != nil {
		t.Fatalf("failed to begin transaction: %v", tx.Error)
	}

	users := []User{{Name: "TestGenericsTransaction", Age: 18}, {Name: "TestGenericsTransaction2", Age: 18}}
	if err := gorm.G[User](tx).CreateInBatches(ctx, &users, 2); err != nil {
		// Roll back before aborting so the transaction is not left open.
		tx.Rollback()
		t.Fatalf("CreateInBatches failed: %v", err)
	}

	count, err := gorm.G[User](tx).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
	if err != nil {
		tx.Rollback()
		t.Fatalf("Count failed: %v", err)
	}
	if count != 2 {
		t.Errorf("expected 2 records, got %d", count)
	}

	if err := tx.Rollback().Error; err != nil {
		t.Fatalf("failed to rollback transaction: %v", err)
	}

	count2, err := gorm.G[User](DB).Where("name like ?", "TestGenericsTransaction%").Count(ctx, "*")
	if err != nil {
		t.Fatalf("Count failed: %v", err)
	}
	if count2 != 0 {
		t.Errorf("expected 0 records after rollback, got %d", count2)
	}
}

// TestGenericsToSQL checks that DB.ToSQL captures the statement issued by a
// generics query run inside its callback.
func TestGenericsToSQL(t *testing.T) {
	ctx := context.Background()
	sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB {
		gorm.G[User](tx).Limit(10).Find(ctx)
		return tx
	})

	if !regexp.MustCompile("SELECT \\* FROM .users..* 10").MatchString(sql) {
		t.Errorf("ToSQL: got wrong sql with Generics API %v", sql)
	}
}

// TestGenericsScanUUID scans a single selected string column into a
// []uuid.UUID and checks the values round-trip in age order.
func TestGenericsScanUUID(t *testing.T) {
	ctx := context.Background()
	users := []User{
		{Name: uuid.NewString(), Age: 21},
		{Name: uuid.NewString(), Age: 22},
		{Name: uuid.NewString(), Age: 23},
	}

	if err := gorm.G[User](DB).CreateInBatches(ctx, &users, 2); err != nil {
		t.Fatalf("CreateInBatches failed: %v", err)
	}

	userIds := []uuid.UUID{}
	// Check the scanned slice, not the fixture, before indexing userIds below.
	if err := gorm.G[User](DB).Select("name").Where("id in ?", []uint{users[0].ID, users[1].ID, users[2].ID}).Order("age").Scan(ctx, &userIds); err != nil || len(userIds) != 3 {
		t.Fatalf("Scan failed: %v, userids %v", err, userIds)
	}

	if userIds[0].String() != users[0].Name || userIds[1].String() != users[1].Name || userIds[2].String() != users[2].Name {
		t.Fatalf("wrong uuid scanned")
	}
}

// TestGenericsCount is a smoke test for the Count API.
func TestGenericsCount(t *testing.T) {
	ctx := context.Background()

	// Just test that the API can be called
	_, err := gorm.G[User](DB).Count(ctx, "*")
	if err != nil {
		t.Fatalf("Count failed: %v", err)
	}
}

// TestGenericsUpdate is a smoke test for the single-column Update API.
func TestGenericsUpdate(t *testing.T) {
	ctx := context.Background()

	// Just test that the API can be called
	_, err := gorm.G[User](DB).Where("id = ?", 1).Update(ctx, "name", "test")
	if err != nil {
		t.Fatalf("Update failed: %v", err)
	}
}

// TestGenericsUpdates is a smoke test for the struct-based Updates API.
func TestGenericsUpdates(t *testing.T) {
	ctx := context.Background()

	// Just test that the API can be called
	_, err := gorm.G[User](DB).Where("id = ?", 1).Updates(ctx, User{Name: "test"})
	if err != nil {
		t.Fatalf("Updates failed: %v", err)
	}
}

// TestGenericsDeleteAPI is a smoke test for the Delete API.
func TestGenericsDeleteAPI(t *testing.T) {
	ctx := context.Background()

	// Just test that the API can be called
	_, err := gorm.G[User](DB).Where("id = ?", 1).Delete(ctx)
	if err != nil {
		t.Fatalf("Delete failed: %v", err)
	}
}

// TestGenericsAssociation checks that clause.Association reports no plain
// Assignments and exactly one AssociationAssignment carrying its name, and
// that each operation type round-trips through the struct.
func TestGenericsAssociation(t *testing.T) {
	// Test basic Association creation
	assoc := clause.Association{
		Association: "Orders",
		Type:        clause.OpCreate,
		Set: []clause.Assignment{
			{Column: clause.Column{Name: "amount"}, Value: 100},
			{Column: clause.Column{Name: "state"}, Value: "new"},
		},
	}

	// Verify it implements Assigner interface
	assignments := assoc.Assignments()
	if len(assignments) != 0 {
		t.Errorf("Association.Assignments() should return empty slice, got %v", assignments)
	}

	// Verify it implements AssociationAssigner interface
	assocAssignments := assoc.AssociationAssignments()
	if len(assocAssignments) != 1 {
		// Fatal: the index below would panic on an empty slice.
		t.Fatalf("Association.AssociationAssignments() should return slice with one element, got %v", assocAssignments)
	}
	if assocAssignments[0].Association != "Orders" {
		t.Errorf("Association.AssociationAssignments()[0].Association should be 'Orders', got %v", assocAssignments[0].Association)
	}

	// Test different association operation types
	operations := []struct {
		Type     clause.AssociationOpType
		TypeName string
	}{
		{clause.OpUnlink, "OpUnlink"},
		{clause.OpDelete, "OpDelete"},
		{clause.OpUpdate, "OpUpdate"},
		{clause.OpCreate, "OpCreate"},
	}

	for _, op := range operations {
		assoc := clause.Association{
			Association: "Orders",
			Type:        op.Type,
		}
		if assoc.Type != op.Type {
			t.Errorf("Association type should be %s, got %v", op.TypeName, assoc.Type)
		}
	}
}
func TestGenericsAssociationSlice(t *testing.T) { // Test that a slice of Association can be used associations := []clause.Association{ {Association: "Orders", Type: clause.OpDelete}, {Association: "Profiles", Type: clause.OpUpdate}, } // In practice, each Association would be processed individually // since []clause.Association doesn't implement AssociationAssigner directly for i, assoc := range associations { assigns := assoc.AssociationAssignments() if len(assigns) != 1 { t.Errorf("Association %d should return one assignment, got %v", i, len(assigns)) } if assigns[0].Association != assoc.Association { t.Errorf("Association %d name should be %s, got %v", i, assoc.Association, assigns[0].Association) } } } ================================================ FILE: tests/go.mod ================================================ module gorm.io/gorm/tests go 1.24.0 require ( github.com/google/uuid v1.6.0 github.com/jinzhu/now v1.1.5 github.com/lib/pq v1.10.9 github.com/stretchr/testify v1.11.1 gorm.io/driver/gaussdb v0.1.0 gorm.io/driver/mysql v1.6.0 gorm.io/driver/postgres v1.6.0 gorm.io/driver/sqlite v1.6.0 gorm.io/driver/sqlserver v1.6.1 gorm.io/gorm v1.31.0 ) require ( filippo.io/edwards25519 v1.1.0 // indirect github.com/HuaweiCloudDeveloper/gaussdb-go v1.0.0-rc1 // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-sql-driver/mysql v1.9.3 // indirect github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 // indirect github.com/golang-sql/sqlexp v0.1.0 // indirect github.com/jackc/pgpassfile v1.0.0 // indirect github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 // indirect github.com/jackc/pgx/v5 v5.7.6 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/jinzhu/inflection v1.0.0 // indirect github.com/mattn/go-sqlite3 v1.14.32 // indirect github.com/microsoft/go-mssqldb v1.9.3 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/tjfoc/gmsm v1.4.1 // indirect golang.org/x/crypto v0.43.0 // 
indirect golang.org/x/sync v0.17.0 // indirect golang.org/x/text v0.30.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) replace gorm.io/gorm => ../ ================================================ FILE: tests/gorm_test.go ================================================ package tests_test import ( "testing" "gorm.io/driver/mysql" "gorm.io/gorm" ) func TestOpen(t *testing.T) { dsn := "gorm:gorm@tcp(localhost:9910)/gorm?loc=Asia%2FHongKong" // invalid loc _, err := gorm.Open(mysql.Open(dsn), &gorm.Config{}) if err == nil { t.Fatalf("should returns error but got nil") } } func TestReturningWithNullToZeroValues(t *testing.T) { dialect := DB.Dialector.Name() switch dialect { case "mysql", "sqlserver": // these dialects do not support the "returning" clause return default: // This user struct will leverage the existing users table, but override // the Name field to default to null. type user struct { gorm.Model Name string `gorm:"default:null"` } u1 := user{} if results := DB.Create(&u1); results.Error != nil { t.Fatalf("errors happened on create: %v", results.Error) } else if results.RowsAffected != 1 { t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) } else if u1.ID == 0 { t.Fatalf("ID expects : not equal 0, got %v", u1.ID) } got := user{} results := DB.First(&got, "id = ?", u1.ID) if results.Error != nil { t.Fatalf("errors happened on first: %v", results.Error) } else if results.RowsAffected != 1 { t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) } else if got.ID != u1.ID { t.Fatalf("first expects: %v, got %v", u1, got) } results = DB.Select("id, name").Find(&got) if results.Error != nil { t.Fatalf("errors happened on first: %v", results.Error) } else if results.RowsAffected != 1 { t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) } else if got.ID != u1.ID { t.Fatalf("select expects: %v, got %v", u1, got) } u1.Name = "jinzhu" if results := DB.Save(&u1); results.Error != nil { t.Fatalf("errors happened 
on update: %v", results.Error) } else if results.RowsAffected != 1 { t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) } u1 = user{} // important to reinitialize this before creating it again u2 := user{} db := DB.Session(&gorm.Session{CreateBatchSize: 10}) if results := db.Create([]*user{&u1, &u2}); results.Error != nil { t.Fatalf("errors happened on create: %v", results.Error) } else if results.RowsAffected != 2 { t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) } else if u1.ID == 0 { t.Fatalf("ID expects : not equal 0, got %v", u1.ID) } else if u2.ID == 0 { t.Fatalf("ID expects : not equal 0, got %v", u2.ID) } var gotUsers []user results = DB.Where("id in (?, ?)", u1.ID, u2.ID).Order("id asc").Select("id, name").Find(&gotUsers) if results.Error != nil { t.Fatalf("errors happened on first: %v", results.Error) } else if results.RowsAffected != 2 { t.Fatalf("rows affected expects: %v, got %v", 2, results.RowsAffected) } else if gotUsers[0].ID != u1.ID { t.Fatalf("select expects: %v, got %v", u1.ID, gotUsers[0].ID) } else if gotUsers[1].ID != u2.ID { t.Fatalf("select expects: %v, got %v", u2.ID, gotUsers[1].ID) } u1.Name = "Jinzhu" u2.Name = "Zhang" if results := DB.Save([]*user{&u1, &u2}); results.Error != nil { t.Fatalf("errors happened on update: %v", results.Error) } else if results.RowsAffected != 2 { t.Fatalf("rows affected expects: %v, got %v", 1, results.RowsAffected) } } } ================================================ FILE: tests/group_by_test.go ================================================ package tests_test import ( "testing" . 
"gorm.io/gorm/utils/tests" ) func TestGroupBy(t *testing.T) { users := []User{{ Name: "groupby", Age: 10, Birthday: Now(), Active: true, }, { Name: "groupby", Age: 20, Birthday: Now(), }, { Name: "groupby", Age: 30, Birthday: Now(), Active: true, }, { Name: "groupby1", Age: 110, Birthday: Now(), }, { Name: "groupby1", Age: 220, Birthday: Now(), Active: true, }, { Name: "groupby1", Age: 330, Birthday: Now(), Active: true, }} if err := DB.Create(&users).Error; err != nil { t.Errorf("errors happened when create: %v", err) } var name string var total int if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("name").Row().Scan(&name, &total); err != nil { t.Errorf("no error should happen, but got %v", err) } if name != "groupby" || total != 60 { t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) } if err := DB.Model(&User{}).Select("name, sum(age)").Where("name = ?", "groupby").Group("users.name").Row().Scan(&name, &total); err != nil { t.Errorf("no error should happen, but got %v", err) } if name != "groupby" || total != 60 { t.Errorf("name should be groupby, but got %v, total should be 60, but got %v", name, total) } if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Row().Scan(&name, &total); err != nil { t.Errorf("no error should happen, but got %v", err) } if name != "groupby1" || total != 660 { t.Errorf("name should be groupby, but got %v, total should be 660, but got %v", name, total) } result := struct { Name string Total int64 }{} if err := DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Find(&result).Error; err != nil { t.Errorf("no error should happen, but got %v", err) } if result.Name != "groupby1" || result.Total != 660 { t.Errorf("name should be groupby, total should be 660, but got %+v", result) } if err := 
DB.Model(&User{}).Select("name, sum(age) as total").Where("name LIKE ?", "groupby%").Group("name").Having("name = ?", "groupby1").Scan(&result).Error; err != nil { t.Errorf("no error should happen, but got %v", err) } if result.Name != "groupby1" || result.Total != 660 { t.Errorf("name should be groupby, total should be 660, but got %+v", result) } var active bool if err := DB.Model(&User{}).Select("name, active, sum(age)").Where("name = ? and active = ?", "groupby", true).Group("name").Group("active").Row().Scan(&name, &active, &total); err != nil { t.Errorf("no error should happen, but got %v", err) } if name != "groupby" || active != true || total != 40 { t.Errorf("group by two columns, name %v, age %v, active: %v", name, total, active) } if DB.Dialector.Name() == "mysql" { if err := DB.Model(&User{}).Select("name, age as total").Where("name LIKE ?", "groupby%").Having("total > ?", 300).Scan(&result).Error; err != nil { t.Errorf("no error should happen, but got %v", err) } if result.Name != "groupby1" || result.Total != 330 { t.Errorf("name should be groupby, total should be 660, but got %+v", result) } } } ================================================ FILE: tests/helper_test.go ================================================ package tests_test import ( "os" "sort" "strconv" "strings" "testing" "time" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) type Config struct { Account bool Pets int Toys int Company bool Manager bool Team int Languages int Friends int NamedPet bool Tools int } func GetUser(name string, config Config) *User { var ( birthday = time.Now().Round(time.Second) user = User{ Name: name, Age: 18, Birthday: &birthday, } ) if config.Account { user.Account = Account{Number: name + "_account"} } for i := 0; i < config.Pets; i++ { user.Pets = append(user.Pets, &Pet{Name: name + "_pet_" + strconv.Itoa(i+1)}) } for i := 0; i < config.Toys; i++ { user.Toys = append(user.Toys, Toy{Name: name + "_toy_" + strconv.Itoa(i+1)}) } for i := 0; i < config.Tools; i++ { user.Tools = append(user.Tools, Tools{Name: name + "_tool_" + strconv.Itoa(i+1)}) } if config.Company { user.Company = Company{Name: "company-" + name} } if config.Manager { user.Manager = GetUser(name+"_manager", Config{}) } for i := 0; i < config.Team; i++ { user.Team = append(user.Team, *GetUser(name+"_team_"+strconv.Itoa(i+1), Config{})) } for i := 0; i < config.Languages; i++ { name := name + "_locale_" + strconv.Itoa(i+1) language := Language{Code: name, Name: name} user.Languages = append(user.Languages, language) } for i := 0; i < config.Friends; i++ { user.Friends = append(user.Friends, GetUser(name+"_friend_"+strconv.Itoa(i+1), Config{})) } if config.NamedPet { user.NamedPet = &Pet{Name: name + "_namepet"} } return &user } func CheckPetUnscoped(t *testing.T, pet Pet, expect Pet) { doCheckPet(t, pet, expect, true) } func CheckPet(t *testing.T, pet Pet, expect Pet) { doCheckPet(t, pet, expect, false) } func doCheckPet(t *testing.T, pet Pet, expect Pet, unscoped bool) { if pet.ID != 0 { var newPet Pet if err := db(unscoped).Where("id = ?", pet.ID).First(&newPet).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { AssertObjEqual(t, newPet, pet, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") AssertObjEqual(t, newPet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", 
"UserID", "Name") } } AssertObjEqual(t, pet, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Name") AssertObjEqual(t, pet.Toy, expect.Toy, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "OwnerID", "OwnerType") if expect.Toy.Name != "" && expect.Toy.OwnerType != "pets" { t.Errorf("toys's OwnerType, expect: %v, got %v", "pets", expect.Toy.OwnerType) } } func CheckUserUnscoped(t *testing.T, user User, expect User) { doCheckUser(t, user, expect, true) } func CheckUser(t *testing.T, user User, expect User) { doCheckUser(t, user, expect, false) } func doCheckUser(t *testing.T, user User, expect User, unscoped bool) { if user.ID != 0 { var newUser User if err := db(unscoped).Where("id = ?", user.ID).First(&newUser).Error; err != nil { t.Fatalf("errors happened when query: %v", err) } else { AssertObjEqual(t, newUser, user, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") } } AssertObjEqual(t, user, expect, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") t.Run("Account", func(t *testing.T) { AssertObjEqual(t, user.Account, expect.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") if user.Account.Number != "" { if !user.Account.UserID.Valid { t.Errorf("Account's foreign key should be saved") } else { var account Account db(unscoped).First(&account, "user_id = ?", user.ID) AssertObjEqual(t, account, user.Account, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "UserID", "Number") } } }) t.Run("Pets", func(t *testing.T) { if len(user.Pets) != len(expect.Pets) { t.Fatalf("pets should equal, expect: %v, got %v", len(expect.Pets), len(user.Pets)) } sort.Slice(user.Pets, func(i, j int) bool { return user.Pets[i].ID > user.Pets[j].ID }) sort.Slice(expect.Pets, func(i, j int) bool { return expect.Pets[i].ID > expect.Pets[j].ID }) for idx, pet := range user.Pets { if pet == nil || expect.Pets[idx] == nil { t.Errorf("pets#%v 
should equal, expect: %v, got %v", idx, expect.Pets[idx], pet) } else { doCheckPet(t, *pet, *expect.Pets[idx], unscoped) } } }) t.Run("Toys", func(t *testing.T) { if len(user.Toys) != len(expect.Toys) { t.Fatalf("toys should equal, expect: %v, got %v", len(expect.Toys), len(user.Toys)) } sort.Slice(user.Toys, func(i, j int) bool { return user.Toys[i].ID > user.Toys[j].ID }) sort.Slice(expect.Toys, func(i, j int) bool { return expect.Toys[i].ID > expect.Toys[j].ID }) for idx, toy := range user.Toys { if toy.OwnerType != "users" { t.Errorf("toys's OwnerType, expect: %v, got %v", "users", toy.OwnerType) } AssertObjEqual(t, toy, expect.Toys[idx], "ID", "CreatedAt", "UpdatedAt", "Name", "OwnerID", "OwnerType") } }) t.Run("Company", func(t *testing.T) { AssertObjEqual(t, user.Company, expect.Company, "ID", "Name") }) t.Run("Manager", func(t *testing.T) { if user.Manager != nil { if user.ManagerID == nil { t.Errorf("Manager's foreign key should be saved") } else { var manager User db(unscoped).First(&manager, "id = ?", *user.ManagerID) AssertObjEqual(t, manager, user.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") AssertObjEqual(t, manager, expect.Manager, "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") } } else if user.ManagerID != nil { t.Errorf("Manager should not be created for zero value, got: %+v", user.ManagerID) } }) t.Run("Team", func(t *testing.T) { if len(user.Team) != len(expect.Team) { t.Fatalf("Team should equal, expect: %v, got %v", len(expect.Team), len(user.Team)) } sort.Slice(user.Team, func(i, j int) bool { return user.Team[i].ID > user.Team[j].ID }) sort.Slice(expect.Team, func(i, j int) bool { return expect.Team[i].ID > expect.Team[j].ID }) for idx, team := range user.Team { AssertObjEqual(t, team, expect.Team[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") } 
}) t.Run("Languages", func(t *testing.T) { if len(user.Languages) != len(expect.Languages) { t.Fatalf("Languages should equal, expect: %v, got %v", len(expect.Languages), len(user.Languages)) } sort.Slice(user.Languages, func(i, j int) bool { return strings.Compare(user.Languages[i].Code, user.Languages[j].Code) > 0 }) sort.Slice(expect.Languages, func(i, j int) bool { return strings.Compare(expect.Languages[i].Code, expect.Languages[j].Code) > 0 }) for idx, language := range user.Languages { AssertObjEqual(t, language, expect.Languages[idx], "Code", "Name") } }) t.Run("Friends", func(t *testing.T) { if len(user.Friends) != len(expect.Friends) { t.Fatalf("Friends should equal, expect: %v, got %v", len(expect.Friends), len(user.Friends)) } sort.Slice(user.Friends, func(i, j int) bool { return user.Friends[i].ID > user.Friends[j].ID }) sort.Slice(expect.Friends, func(i, j int) bool { return expect.Friends[i].ID > expect.Friends[j].ID }) for idx, friend := range user.Friends { AssertObjEqual(t, friend, expect.Friends[idx], "ID", "CreatedAt", "UpdatedAt", "DeletedAt", "Name", "Age", "Birthday", "CompanyID", "ManagerID", "Active") } }) } func tidbSkip(t *testing.T, reason string) { if isTiDB() { t.Skipf("This test case skipped, because of TiDB '%s'", reason) } } func isTiDB() bool { return os.Getenv("GORM_DIALECT") == "tidb" } func isMysql() bool { return os.Getenv("GORM_DIALECT") == "mysql" } func isSqlite() bool { return os.Getenv("GORM_DIALECT") == "sqlite" } func db(unscoped bool) *gorm.DB { if unscoped { return DB.Unscoped() } else { return DB } } ================================================ FILE: tests/hooks_test.go ================================================ package tests_test import ( "errors" "log" "os" "reflect" "strings" "testing" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) type Product struct { gorm.Model Name string Code string Price float64 AfterFindCallTimes int64 BeforeCreateCallTimes int64 AfterCreateCallTimes int64 BeforeUpdateCallTimes int64 AfterUpdateCallTimes int64 BeforeSaveCallTimes int64 AfterSaveCallTimes int64 BeforeDeleteCallTimes int64 AfterDeleteCallTimes int64 } func (s *Product) BeforeCreate(tx *gorm.DB) (err error) { if s.Code == "Invalid" { err = errors.New("invalid product") } s.BeforeCreateCallTimes = s.BeforeCreateCallTimes + 1 return } func (s *Product) BeforeUpdate(tx *gorm.DB) (err error) { if s.Code == "dont_update" { err = errors.New("can't update") } s.BeforeUpdateCallTimes = s.BeforeUpdateCallTimes + 1 return } func (s *Product) BeforeSave(tx *gorm.DB) (err error) { if s.Code == "dont_save" { err = errors.New("can't save") } s.BeforeSaveCallTimes = s.BeforeSaveCallTimes + 1 return } func (s *Product) AfterFind(tx *gorm.DB) (err error) { s.AfterFindCallTimes = s.AfterFindCallTimes + 1 return } func (s *Product) AfterCreate(tx *gorm.DB) (err error) { return tx.Model(s).UpdateColumn("AfterCreateCallTimes", s.AfterCreateCallTimes+1).Error } func (s *Product) AfterUpdate(tx *gorm.DB) (err error) { s.AfterUpdateCallTimes = s.AfterUpdateCallTimes + 1 return } func (s *Product) AfterSave(tx *gorm.DB) (err error) { if s.Code == "after_save_error" { err = errors.New("can't save") } s.AfterSaveCallTimes = s.AfterSaveCallTimes + 1 return } func (s *Product) BeforeDelete(tx *gorm.DB) (err error) { if s.Code == "dont_delete" { err = errors.New("can't delete") } s.BeforeDeleteCallTimes = s.BeforeDeleteCallTimes + 1 return } func (s *Product) AfterDelete(tx *gorm.DB) (err error) { if s.Code == "after_delete_error" { err = errors.New("can't delete") } s.AfterDeleteCallTimes = s.AfterDeleteCallTimes + 1 return } func (s *Product) GetCallTimes() []int64 { return []int64{s.BeforeCreateCallTimes, s.BeforeSaveCallTimes, s.BeforeUpdateCallTimes, s.AfterCreateCallTimes, s.AfterSaveCallTimes, 
s.AfterUpdateCallTimes, s.BeforeDeleteCallTimes, s.AfterDeleteCallTimes, s.AfterFindCallTimes} } func TestRunCallbacks(t *testing.T) { DB.Migrator().DropTable(&Product{}) DB.AutoMigrate(&Product{}) p := Product{Code: "unique_code", Price: 100} DB.Save(&p) if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 1, 0, 0, 0, 0}) { t.Fatalf("Callbacks should be invoked successfully, %v", p.GetCallTimes()) } DB.Where("Code = ?", "unique_code").First(&p) if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 1, 0, 1, 0, 0, 0, 0, 1}) { t.Fatalf("After callbacks values are not saved, %v", p.GetCallTimes()) } p.Price = 200 DB.Save(&p) if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 1, 1, 0, 0, 1}) { t.Fatalf("After update callbacks should be invoked successfully, %v", p.GetCallTimes()) } var products []Product DB.Find(&products, "code = ?", "unique_code") if products[0].AfterFindCallTimes != 2 { t.Fatalf("AfterFind callbacks should work with slice, called %v", products[0].AfterFindCallTimes) } DB.Where("Code = ?", "unique_code").First(&p) if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 0, 0, 2}) { t.Fatalf("After update callbacks values are not saved, %v", p.GetCallTimes()) } DB.Delete(&p) if !reflect.DeepEqual(p.GetCallTimes(), []int64{1, 2, 1, 1, 0, 0, 1, 1, 2}) { t.Fatalf("After delete callbacks should be invoked successfully, %v", p.GetCallTimes()) } if DB.Where("Code = ?", "unique_code").First(&p).Error == nil { t.Fatalf("Can't find a deleted record") } beforeCallTimes := p.AfterFindCallTimes if DB.Where("Code = ?", "unique_code").Find(&p).Error != nil { t.Fatalf("Find don't raise error when record not found") } if p.AfterFindCallTimes != beforeCallTimes { t.Fatalf("AfterFind should not be called") } } func TestCallbacksWithErrors(t *testing.T) { DB.Migrator().DropTable(&Product{}) DB.AutoMigrate(&Product{}) p := Product{Code: "Invalid", Price: 100} if DB.Save(&p).Error == nil { t.Fatalf("An error from before create callbacks happened when 
create with invalid value") } if DB.Where("code = ?", "Invalid").First(&Product{}).Error == nil { t.Fatalf("Should not save record that have errors") } if DB.Save(&Product{Code: "dont_save", Price: 100}).Error == nil { t.Fatalf("An error from after create callbacks happened when create with invalid value") } p2 := Product{Code: "update_callback", Price: 100} DB.Save(&p2) p2.Code = "dont_update" if DB.Save(&p2).Error == nil { t.Fatalf("An error from before update callbacks happened when update with invalid value") } if DB.Where("code = ?", "update_callback").First(&Product{}).Error != nil { t.Fatalf("Record Should not be updated due to errors happened in before update callback") } if DB.Where("code = ?", "dont_update").First(&Product{}).Error == nil { t.Fatalf("Record Should not be updated due to errors happened in before update callback") } p2.Code = "dont_save" if DB.Save(&p2).Error == nil { t.Fatalf("An error from before save callbacks happened when update with invalid value") } p3 := Product{Code: "dont_delete", Price: 100} DB.Save(&p3) if DB.Delete(&p3).Error == nil { t.Fatalf("An error from before delete callbacks happened when delete") } if DB.Where("Code = ?", "dont_delete").First(&p3).Error != nil { t.Fatalf("An error from before delete callbacks happened") } p4 := Product{Code: "after_save_error", Price: 100} DB.Save(&p4) if err := DB.First(&Product{}, "code = ?", "after_save_error").Error; err == nil { t.Fatalf("Record should be reverted if get an error in after save callback") } p5 := Product{Code: "after_delete_error", Price: 100} DB.Save(&p5) if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { t.Fatalf("Record should be found") } DB.Delete(&p5) if err := DB.First(&Product{}, "code = ?", "after_delete_error").Error; err != nil { t.Fatalf("Record shouldn't be deleted because of an error happened in after delete callback") } } type Product2 struct { gorm.Model Name string Code string Price int64 Owner string } func (s 
Product2) BeforeCreate(tx *gorm.DB) (err error) { if !strings.HasSuffix(s.Name, "_clone") { newProduft := s newProduft.Price *= 2 newProduft.Name += "_clone" err = tx.Create(&newProduft).Error } if s.Name == "Invalid" { return errors.New("invalid") } return nil } func (s *Product2) BeforeUpdate(tx *gorm.DB) (err error) { tx.Statement.Where("owner != ?", "admin") return } func TestUseDBInHooks(t *testing.T) { DB.Migrator().DropTable(&Product2{}) DB.AutoMigrate(&Product2{}) product := Product2{Name: "Invalid", Price: 100} if err := DB.Create(&product).Error; err == nil { t.Fatalf("should returns error %v when creating product, but got nil", err) } product2 := Product2{Name: "Nice", Price: 100} if err := DB.Create(&product2).Error; err != nil { t.Fatalf("Failed to create product, got error: %v", err) } var result Product2 if err := DB.First(&result, "name = ?", "Nice").Error; err != nil { t.Fatalf("Failed to query product, got error: %v", err) } var resultClone Product2 if err := DB.First(&resultClone, "name = ?", "Nice_clone").Error; err != nil { t.Fatalf("Failed to find cloned product, got error: %v", err) } result.Price *= 2 result.Name += "_clone" AssertObjEqual(t, result, resultClone, "Price", "Name") DB.Model(&result).Update("Price", 500) var result2 Product2 DB.First(&result2, "name = ?", "Nice") if result2.Price != 500 { t.Errorf("Failed to update product's price, expects: %v, got %v", 500, result2.Price) } product3 := Product2{Name: "Nice2", Price: 600, Owner: "admin"} if err := DB.Create(&product3).Error; err != nil { t.Fatalf("Failed to create product, got error: %v", err) } var result3 Product2 if err := DB.First(&result3, "name = ?", "Nice2").Error; err != nil { t.Fatalf("Failed to query product, got error: %v", err) } DB.Model(&result3).Update("Price", 800) var result4 Product2 DB.First(&result4, "name = ?", "Nice2") if result4.Price != 600 { t.Errorf("Admin product's price should not be changed, expects: %v, got %v", 600, result4.Price) } } type 
Product3 struct { gorm.Model Name string Code string Price int64 Owner string } func (s Product3) BeforeCreate(tx *gorm.DB) (err error) { tx.Statement.SetColumn("Price", s.Price+100) return nil } func (s Product3) BeforeUpdate(tx *gorm.DB) (err error) { if tx.Statement.Changed() { tx.Statement.SetColumn("Price", s.Price+10) } if tx.Statement.Changed("Code") { s.Price += 20 tx.Statement.SetColumn("Price", s.Price+30) } return nil } func TestSetColumn(t *testing.T) { DB.Migrator().DropTable(&Product3{}) DB.AutoMigrate(&Product3{}) product := Product3{Name: "Product", Price: 0} DB.Create(&product) if product.Price != 100 { t.Errorf("invalid price after create, got %+v", product) } DB.Model(&product).Select("code", "price").Updates(map[string]interface{}{"code": "L1212"}) if product.Price != 150 || product.Code != "L1212" { t.Errorf("invalid data after update, got %+v", product) } // Code not changed, price should not change DB.Model(&product).Updates(map[string]interface{}{"Name": "Product New"}) if product.Name != "Product New" || product.Price != 160 || product.Code != "L1212" { t.Errorf("invalid data after update, got %+v", product) } // Code changed, but not selected, price should not change DB.Model(&product).Select("Name", "Price").Updates(map[string]interface{}{"Name": "Product New2", "code": "L1213"}) if product.Name != "Product New2" || product.Price != 170 || product.Code != "L1212" { t.Errorf("invalid data after update, got %+v", product) } // Code changed, price should changed DB.Model(&product).Select("Name", "Code", "Price").Updates(map[string]interface{}{"Name": "Product New3", "code": "L1213"}) if product.Name != "Product New3" || product.Price != 220 || product.Code != "L1213" { t.Errorf("invalid data after update, got %+v", product) } var result Product3 DB.First(&result, product.ID) AssertEqual(t, result, product) // Select to change Code, but nothing updated, price should not change DB.Model(&product).Select("code").Updates(Product3{Name: "L1214", 
Code: "L1213"}) if product.Price != 220 || product.Code != "L1213" || product.Name != "Product New3" { t.Errorf("invalid data after update, got %+v", product) } DB.Model(&product).Updates(Product3{Code: "L1214"}) if product.Price != 270 || product.Code != "L1214" { t.Errorf("invalid data after update, got %+v", product) } // Code changed, price should changed DB.Model(&product).Select("Name", "Code", "Price").Updates(Product3{Name: "Product New4", Code: ""}) if product.Name != "Product New4" || product.Price != 320 || product.Code != "" { t.Errorf("invalid data after update, got %+v", product) } DB.Model(&product).UpdateColumns(Product3{Code: "L1215"}) if product.Price != 320 || product.Code != "L1215" { t.Errorf("invalid data after update, got %+v", product) } DB.Model(&product).Session(&gorm.Session{SkipHooks: true}).Updates(Product3{Code: "L1216"}) if product.Price != 320 || product.Code != "L1216" { t.Errorf("invalid data after update, got %+v", product) } var result2 Product3 DB.First(&result2, product.ID) AssertEqual(t, result2, product) product2 := Product3{Name: "Product", Price: 0} DB.Session(&gorm.Session{SkipHooks: true}).Create(&product2) if product2.Price != 0 { t.Errorf("invalid price after create without hooks, got %+v", product2) } } func TestHooksForSlice(t *testing.T) { DB.Migrator().DropTable(&Product3{}) DB.AutoMigrate(&Product3{}) products := []*Product3{ {Name: "Product-1", Price: 100}, {Name: "Product-2", Price: 200}, {Name: "Product-3", Price: 300}, } DB.Create(&products) for idx, value := range []int64{200, 300, 400} { if products[idx].Price != value { t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) } } DB.Model(&products).Update("Name", "product-name") // will set all product's price to last product's price + 10 for idx, value := range []int64{410, 410, 410} { if products[idx].Price != value { t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products[idx].Price) } } 
products2 := []Product3{ {Name: "Product-1", Price: 100}, {Name: "Product-2", Price: 200}, {Name: "Product-3", Price: 300}, } DB.Create(&products2) for idx, value := range []int64{200, 300, 400} { if products2[idx].Price != value { t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) } } DB.Model(&products2).Update("Name", "product-name") // will set all product's price to last product's price + 10 for idx, value := range []int64{410, 410, 410} { if products2[idx].Price != value { t.Errorf("invalid price for product #%v, expects: %v, got %v", idx, value, products2[idx].Price) } } } type Product4 struct { gorm.Model Name string Code string Price int64 Owner string Item ProductItem } type ProductItem struct { gorm.Model Code string Product4ID uint AfterFindCallTimes int } func (pi ProductItem) BeforeCreate(*gorm.DB) error { if pi.Code == "invalid" { return errors.New("invalid item") } return nil } func (pi *ProductItem) AfterFind(*gorm.DB) error { pi.AfterFindCallTimes = pi.AfterFindCallTimes + 1 return nil } func TestFailedToSaveAssociationShouldRollback(t *testing.T) { DB.Migrator().DropTable(&Product4{}, &ProductItem{}) DB.AutoMigrate(&Product4{}, &ProductItem{}) product := Product4{Name: "Product-1", Price: 100, Item: ProductItem{Code: "invalid"}} if err := DB.Create(&product).Error; err == nil { t.Errorf("should got failed to save, but error is nil") } if DB.First(&Product4{}, "name = ?", product.Name).Error == nil { t.Errorf("should got RecordNotFound, but got nil") } product = Product4{Name: "Product-2", Price: 100, Item: ProductItem{Code: "valid"}} if err := DB.Create(&product).Error; err != nil { t.Errorf("should create product, but got error %v", err) } if err := DB.First(&Product4{}, "name = ?", product.Name).Error; err != nil { t.Errorf("should find product, but got error %v", err) } var productWithItem Product4 if err := DB.Session(&gorm.Session{SkipHooks: true}).Preload("Item").First(&productWithItem, "name = 
?", product.Name).Error; err != nil { t.Errorf("should find product, but got error %v", err) } if productWithItem.Item.AfterFindCallTimes != 0 { t.Fatalf("AfterFind should not be called times:%d", productWithItem.Item.AfterFindCallTimes) } } type Product5 struct { gorm.Model Name string } var beforeUpdateCall int func (p *Product5) BeforeUpdate(*gorm.DB) error { beforeUpdateCall = beforeUpdateCall + 1 return nil } func TestUpdateCallbacks(t *testing.T) { DB.Migrator().DropTable(&Product5{}) DB.AutoMigrate(&Product5{}) p := Product5{Name: "unique_code"} DB.Model(&Product5{}).Create(&p) err := DB.Model(&Product5{}).Where("id", p.ID).Update("name", "update_name_1").Error if err != nil { t.Fatalf("should update success, but got err %v", err) } if beforeUpdateCall != 1 { t.Fatalf("before update should be called") } err = DB.Model(Product5{}).Where("id", p.ID).Update("name", "update_name_2").Error if !errors.Is(err, gorm.ErrInvalidValue) { t.Fatalf("should got RecordNotFound, but got %v", err) } if beforeUpdateCall != 1 { t.Fatalf("before update should not be called") } err = DB.Model([1]*Product5{&p}).Update("name", "update_name_3").Error if err != nil { t.Fatalf("should update success, but got err %v", err) } if beforeUpdateCall != 2 { t.Fatalf("before update should be called") } err = DB.Model([1]Product5{p}).Update("name", "update_name_4").Error if !errors.Is(err, gorm.ErrInvalidValue) { t.Fatalf("should got RecordNotFound, but got %v", err) } if beforeUpdateCall != 2 { t.Fatalf("before update should not be called") } } type Product6 struct { gorm.Model Name string Item *ProductItem2 } type ProductItem2 struct { gorm.Model Product6ID uint } func (p *Product6) BeforeDelete(tx *gorm.DB) error { if err := tx.Delete(&p.Item).Error; err != nil { return err } return nil } func TestPropagateUnscoped(t *testing.T) { _DB, err := OpenTestConnection(&gorm.Config{ PropagateUnscoped: true, }) if err != nil { log.Printf("failed to connect database, got error %v", err) os.Exit(1) } 
_DB.Migrator().DropTable(&Product6{}, &ProductItem2{}) _DB.AutoMigrate(&Product6{}, &ProductItem2{}) p := Product6{ Name: "unique_code", Item: &ProductItem2{}, } _DB.Model(&Product6{}).Create(&p) if err := _DB.Unscoped().Delete(&p).Error; err != nil { t.Fatalf("unscoped did not propagate") } } ================================================ FILE: tests/joins_table_test.go ================================================ package tests_test import ( "testing" "time" "gorm.io/gorm" "gorm.io/gorm/clause" ) type Person struct { ID int Name string Addresses []Address `gorm:"many2many:person_addresses;"` DeletedAt gorm.DeletedAt } type Address struct { ID uint Name string } type PersonAddress struct { PersonID int AddressID int CreatedAt time.Time DeletedAt gorm.DeletedAt } func TestOverrideJoinTable(t *testing.T) { DB.Migrator().DropTable(&Person{}, &Address{}, &PersonAddress{}) if err := DB.SetupJoinTable(&Person{}, "Addresses", &PersonAddress{}); err != nil { t.Fatalf("Failed to setup join table for person, got error %v", err) } if err := DB.AutoMigrate(&Person{}, &Address{}); err != nil { t.Fatalf("Failed to migrate, got %v", err) } address1 := Address{Name: "address 1"} address2 := Address{Name: "address 2"} person := Person{Name: "person", Addresses: []Address{address1, address2}} DB.Create(&person) var addresses1 []Address if err := DB.Model(&person).Association("Addresses").Find(&addresses1); err != nil || len(addresses1) != 2 { t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses1)) } if err := DB.Model(&person).Association("Addresses").Delete(&person.Addresses[0]); err != nil { t.Fatalf("Failed to delete address, got error %v", err) } if len(person.Addresses) != 1 { t.Fatalf("Should have one address left") } if DB.Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 1 { t.Fatalf("Should found one address") } var addresses2 []Address if err := DB.Model(&person).Association("Addresses").Find(&addresses2); err != nil || 
len(addresses2) != 1 { t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses2)) } if DB.Model(&person).Association("Addresses").Count() != 1 { t.Fatalf("Should found one address") } var addresses3 []Address if err := DB.Unscoped().Model(&person).Association("Addresses").Find(&addresses3); err != nil || len(addresses3) != 2 { t.Fatalf("Failed to find address, got error %v, length: %v", err, len(addresses3)) } if DB.Unscoped().Find(&[]PersonAddress{}, "person_id = ?", person.ID).RowsAffected != 2 { t.Fatalf("Should found soft deleted addresses with unscoped") } if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { t.Fatalf("Should found soft deleted addresses with unscoped") } DB.Model(&person).Association("Addresses").Clear() if DB.Model(&person).Association("Addresses").Count() != 0 { t.Fatalf("Should deleted all addresses") } if DB.Unscoped().Model(&person).Association("Addresses").Count() != 2 { t.Fatalf("Should found soft deleted addresses with unscoped") } DB.Unscoped().Model(&person).Association("Addresses").Clear() if DB.Unscoped().Model(&person).Association("Addresses").Count() != 0 { t.Fatalf("address should be deleted when clear with unscoped") } address2_1 := Address{Name: "address 2-1"} address2_2 := Address{Name: "address 2-2"} person2 := Person{Name: "person_2", Addresses: []Address{address2_1, address2_2}} DB.Create(&person2) if err := DB.Select(clause.Associations).Delete(&person2).Error; err != nil { t.Fatalf("failed to delete person, got error: %v", err) } if count := DB.Unscoped().Model(&person2).Association("Addresses").Count(); count != 2 { t.Errorf("person's addresses expects 2, got %v", count) } if count := DB.Model(&person2).Association("Addresses").Count(); count != 0 { t.Errorf("person's addresses expects 2, got %v", count) } } ================================================ FILE: tests/joins_test.go ================================================ package tests_test import ( "fmt" "regexp" 
"sort"
	"testing"

	"github.com/stretchr/testify/assert"
	"gorm.io/gorm"
	. "gorm.io/gorm/utils/tests"
)

// TestJoins loads has-one/belongs-to relations with Joins and compares the
// result with the created user.
func TestJoins(t *testing.T) {
	user := *GetUser("joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false})

	DB.Create(&user)

	var user2 User
	if err := DB.Joins("NamedPet").Joins("Company").Joins("Manager").Joins("Account").First(&user2, "users.name = ?", user.Name).Error; err != nil {
		t.Fatalf("Failed to load with joins, got error: %v", err)
	}

	CheckUser(t, user2, user)
}

// TestJoinsForSlice is TestJoins over several users loaded with Find.
func TestJoinsForSlice(t *testing.T) {
	users := []User{
		*GetUser("slice-joins-1", Config{Company: true, Manager: true, Account: true}),
		*GetUser("slice-joins-2", Config{Company: true, Manager: true, Account: true}),
		*GetUser("slice-joins-3", Config{Company: true, Manager: true, Account: true}),
	}

	DB.Create(&users)

	userIDs := make([]uint, 0, len(users))
	for _, user := range users {
		userIDs = append(userIDs, user.ID)
	}

	var users2 []User
	if err := DB.Joins("Company").Joins("Manager").Joins("Account").Find(&users2, "users.id IN ?", userIDs).Error; err != nil {
		t.Fatalf("Failed to load with joins, got error: %v", err)
	} else if len(users2) != len(users) {
		t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users))
	}

	sort.Slice(users2, func(i, j int) bool {
		return users2[i].ID > users2[j].ID
	})

	sort.Slice(users, func(i, j int) bool {
		return users[i].ID > users[j].ID
	})

	for idx, user := range users {
		CheckUser(t, user, users2[idx])
	}
}

// TestJoinConds covers raw-SQL Joins with bound arguments, multiple joins,
// struct conditions on identically named columns, and the SQL generated for
// ordered joins and subquery joins.
func TestJoinConds(t *testing.T) {
	user := *GetUser("joins-conds", Config{Account: true, Pets: 3})
	DB.Save(&user)

	var users1 []User
	DB.Joins("inner join pets on pets.user_id = users.id").Where("users.name = ?", user.Name).Find(&users1)
	if len(users1) != 3 {
		t.Errorf("should find three users using inner join, but got %v", len(users1))
	}

	var users2 []User
	DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Where("users.name = ?", user.Name).First(&users2)
	if len(users2) != 1 {
		t.Errorf("should find one users using left join with conditions, but got %v", len(users2))
	}

	var users3 []User
	DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where("users.name = ?", user.Name).First(&users3)
	if len(users3) != 1 {
		t.Errorf("should find one users using multiple left join conditions, but got %v", len(users3))
	}

	var users4 []User
	DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number+"non-exist").Where("users.name = ?", user.Name).First(&users4)
	if len(users4) != 0 {
		t.Errorf("should find no user when searching with unexisting credit card, but got %v", len(users4))
	}

	var users5 []User
	db5 := DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5)
	if db5.Error != nil {
		t.Errorf("Should not raise error for join where identical fields in different tables. Error: %s", db5.Error.Error())
	}

	var users6 []User
	DB.Joins("inner join pets on pets.user_id = users.id AND pets.name = @Name", user.Pets[0]).Where("users.name = ?", user.Name).First(&users6)
	if len(users6) != 1 {
		t.Errorf("should find one users using left join with conditions, but got %v", len(users6))
	}

	dryDB := DB.Session(&gorm.Session{DryRun: true})
	stmt := dryDB.Joins("left join pets on pets.user_id = users.id AND pets.name = ?", user.Pets[0].Name).Joins("join accounts on accounts.user_id = users.id AND accounts.number = ?", user.Account.Number).Where(User{Model: gorm.Model{ID: 1}}).Where(Account{Model: gorm.Model{ID: 1}}).Not(Pet{Model: gorm.Model{ID: 1}}).Find(&users5).Statement

	if !regexp.MustCompile("SELECT .* FROM .users. left join pets.*join accounts.*").MatchString(stmt.SQL.String()) {
		t.Errorf("joins should be ordered, but got %v", stmt.SQL.String())
	}

	iv := DB.Table(`table_invoices`).Select(`seller, SUM(total) as total, SUM(paid) as paid, SUM(balance) as balance`).Group(`seller`)
	stmt = dryDB.Table(`table_employees`).Select(`id, name, iv.total, iv.paid, iv.balance`).Joins(`LEFT JOIN (?) AS iv ON iv.seller = table_employees.id`, iv).Scan(&user).Statement
	if !regexp.MustCompile("SELECT id, name, iv.total, iv.paid, iv.balance FROM .table_employees. LEFT JOIN \\(SELECT seller, SUM\\(total\\) as total, SUM\\(paid\\) as paid, SUM\\(balance\\) as balance FROM .table_invoices. GROUP BY .seller.\\) AS iv ON iv.seller = table_employees.id").MatchString(stmt.SQL.String()) {
		t.Errorf("joins should be ordered, but got %v", stmt.SQL.String())
	}
}

// TestJoinOn checks that a *gorm.DB condition passed to Joins selects which
// NamedPet gets joined.
func TestJoinOn(t *testing.T) {
	user := *GetUser("joins-on", Config{Pets: 2})
	DB.Save(&user)

	var user1 User
	onQuery := DB.Where(&Pet{Name: "joins-on_pet_1"})

	if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
		t.Fatalf("Failed to load with joins on, got error: %v", err)
	}

	AssertEqual(t, user1.NamedPet.Name, "joins-on_pet_1")

	onQuery2 := DB.Where(&Pet{Name: "joins-on_pet_2"})
	var user2 User
	if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
		t.Fatalf("Failed to load with joins on, got error: %v", err)
	}
	AssertEqual(t, user2.NamedPet.Name, "joins-on_pet_2")
}

// TestJoinsWithSelect scans a raw-SQL join with a custom Select into a plain
// struct and compares pet names with the created pets.
func TestJoinsWithSelect(t *testing.T) {
	type result struct {
		ID    uint
		PetID uint
		Name  string
	}

	user := *GetUser("joins_with_select", Config{Pets: 2})
	DB.Save(&user)

	var results []result

	DB.Table("users").Select("users.id, pets.id as pet_id, pets.name").Joins("left join pets on pets.user_id = users.id").Where("users.name = ?", "joins_with_select").Scan(&results)

	sort.Slice(results, func(i, j int) bool {
		return results[i].PetID > results[j].PetID
	})

	// Sort the expected pets by the same key (ID descending) so that
	// results[k] and user.Pets[k] describe the same pet. The previous code
	// re-sorted `results` using a comparator over user.Pets by mistake.
	sort.Slice(user.Pets, func(i, j int) bool {
		return user.Pets[i].ID > user.Pets[j].ID
	})

	if len(results) != 2 || results[0].Name != user.Pets[0].Name || results[1].Name != user.Pets[1].Name {
		t.Errorf("Should find all two pets with Join select, got %+v", results)
	}
}

// TestJoinWithOmit checks that Omit removes the omitted column from a query
// that also uses a raw-SQL join.
func TestJoinWithOmit(t *testing.T) {
	user := *GetUser("joins_with_omit", Config{Pets: 2})
	DB.Save(&user)

	results := make([]*User, 0)

	// A query error must fail the test; returning silently would let a
	// broken query pass.
	if err := DB.Table("users").Omit("name").Where("users.name = ?", "joins_with_omit").Joins("left join pets on pets.user_id = users.id").Find(&results).Error; err != nil {
		t.Fatalf("Failed to query users with join omit, got error: %v", err)
	}

	if len(results) != 2 || results[0].Name != "" || results[1].Name != "" {
		t.Errorf("Should find all two pets with Join omit and should not find user's name, got %+v", results)
	}
}

// TestJoinCount checks that a query with Joins can run Count and then First
// and still load the joined Company.
func TestJoinCount(t *testing.T) {
	companyA := Company{Name: "A"}
	companyB := Company{Name: "B"}
	DB.Create(&companyA)
	DB.Create(&companyB)

	user := User{Name: "kingGo", CompanyID: &companyB.ID}
	DB.Create(&user)

	query := DB.Model(&User{}).Joins("Company")

	var total int64
	if err := query.Count(&total).Error; err != nil {
		t.Fatalf("Failed to count with joins, got error: %v", err)
	}

	var result User

	if err := query.First(&result, user.ID).Error; err != nil {
		t.Fatalf("Failed, got error: %v", err)
	}

	if result.ID != user.ID {
		t.Fatalf("result's id, %d, doesn't match user's id, %d", result.ID, user.ID)
	}
	// should find company
	if result.Company.ID != *user.CompanyID {
		t.Fatalf("result's id, %d, doesn't match user's company id, %d", result.Company.ID, *user.CompanyID)
	}
}

// TestJoinWithSoftDeleted checks that soft-deleted related rows are not
// loaded by Joins.
func TestJoinWithSoftDeleted(t *testing.T) {
	user := GetUser("TestJoinWithSoftDeletedUser", Config{Account: true, NamedPet: true})
	DB.Create(&user)

	var user1 User
	DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user1, user.ID)
	if user1.NamedPet == nil || user1.Account.ID == 0 {
		t.Fatalf("joins NamedPet and Account should not empty:%v", user1)
	}

	// Account should be empty once soft-deleted
	DB.Delete(&user1.Account)

	var user2 User
	DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user2, user.ID)
	if user2.NamedPet == nil || user2.Account.ID != 0 {
		t.Fatalf("joins Account should be empty:%v", user2)
	}

	// NamedPet should also be empty once soft-deleted
	DB.Delete(&user1.NamedPet)

	var user3 User
	DB.Model(&User{}).Joins("NamedPet").Joins("Account").First(&user3, user.ID)
	if user3.NamedPet != nil || user3.Account.ID != 0 {
		t.Fatalf("joins NamedPet and Account should be empty:%v", user3)
	}
}

// TestInnerJoins checks InnerJoins, including that an absent NamedPet turns
// the result into ErrRecordNotFound, and mixing Joins with InnerJoins.
func TestInnerJoins(t *testing.T) {
	user := *GetUser("inner-joins-1", Config{Company: true, Manager: true, Account: true, NamedPet: false})

	DB.Create(&user)

	var user2 User
	var err error
	err = DB.InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error
	AssertEqual(t, err, nil)
	CheckUser(t, user2, user)

	// inner join and NamedPet is nil
	err = DB.InnerJoins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user2, "users.name = ?", user.Name).Error
	AssertEqual(t, err, gorm.ErrRecordNotFound)

	// mixed inner join and left join
	var user3 User
	err = DB.Joins("NamedPet").InnerJoins("Company").InnerJoins("Manager").InnerJoins("Account").First(&user3, "users.name = ?", user.Name).Error
	AssertEqual(t, err, nil)
	CheckUser(t, user3, user)
}

// TestJoinWithSameColumnName scans a multi-table join whose tables share
// column names into a struct that embeds several models.
func TestJoinWithSameColumnName(t *testing.T) {
	user := GetUser("TestJoinWithSameColumnName", Config{
		Languages: 1,
		Pets:      1,
	})
	DB.Create(user)
	type UserSpeak struct {
		UserID       uint
		LanguageCode string
	}
	type Result struct {
		User
		UserSpeak
		Language
		Pet
	}

	results := make([]Result, 0, 1)
	DB.Select("users.*, user_speaks.*, languages.*, pets.*").Table("users").Joins("JOIN user_speaks ON user_speaks.user_id = users.id").
		Joins("JOIN languages ON languages.code = user_speaks.language_code").
		Joins("LEFT OUTER JOIN pets ON pets.user_id = users.id").Find(&results)

	if len(results) == 0 {
		t.Fatalf("no record find")
	} else if results[0].Pet.UserID == nil || *(results[0].Pet.UserID) != user.ID {
		t.Fatalf("wrong user id in pet")
	} else if results[0].Pet.Name != user.Pets[0].Name {
		t.Fatalf("wrong pet name")
	}
}

// TestJoinArgsWithDB checks how Where, Omit and Select on the *gorm.DB
// passed to Joins affect which joined fields are loaded.
func TestJoinArgsWithDB(t *testing.T) {
	user := *GetUser("joins-args-db", Config{Pets: 2})
	DB.Save(&user)

	// test where
	var user1 User
	onQuery := DB.Where(&Pet{Name: "joins-args-db_pet_2"})
	if err := DB.Joins("NamedPet", onQuery).Where("users.name = ?", user.Name).First(&user1).Error; err != nil {
		t.Fatalf("Failed to load with joins on, got error: %v", err)
	}

	AssertEqual(t, user1.NamedPet.Name, "joins-args-db_pet_2")

	// test where and omit
	onQuery2 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Omit("Name")
	var user2 User
	if err := DB.Joins("NamedPet", onQuery2).Where("users.name = ?", user.Name).First(&user2).Error; err != nil {
		t.Fatalf("Failed to load with joins on, got error: %v", err)
	}
	AssertEqual(t, user2.NamedPet.ID, user1.NamedPet.ID)
	AssertEqual(t, user2.NamedPet.Name, "")

	// test where and select
	onQuery3 := DB.Where(&Pet{Name: "joins-args-db_pet_2"}).Select("Name")
	var user3 User
	if err := DB.Joins("NamedPet", onQuery3).Where("users.name = ?", user.Name).First(&user3).Error; err != nil {
		t.Fatalf("Failed to load with joins on, got error: %v", err)
	}
	AssertEqual(t, user3.NamedPet.ID, 0)
	AssertEqual(t, user3.NamedPet.Name, "joins-args-db_pet_2")

	// test select
	onQuery4 := DB.Select("ID")
	var user4 User
	if err := DB.Joins("NamedPet", onQuery4).Where("users.name = ?", user.Name).First(&user4).Error; err != nil {
		t.Fatalf("Failed to load with joins on, got error: %v", err)
	}
	if user4.NamedPet.ID == 0 {
		t.Fatal("Pet ID can not be empty")
	}
	AssertEqual(t, user4.NamedPet.Name, "")
}

func TestNestedJoins(t *testing.T) {
	users := []User{
		{
			Name: "nested-joins-1",
			Manager: &User{
				Name: "nested-joins-manager-1",
				Company: Company{
					Name: "nested-joins-manager-company-1",
				},
				NamedPet: &Pet{
					Name: "nested-joins-manager-namepet-1",
					Toy: Toy{
						Name: "nested-joins-manager-namepet-toy-1",
					},
				},
			},
			NamedPet: &Pet{Name: "nested-joins-namepet-1", Toy: Toy{Name: "nested-joins-namepet-toy-1"}},
		},
		{
			Name:     "nested-joins-2",
			Manager:  GetUser("nested-joins-manager-2", Config{Company: true, NamedPet: true}),
			NamedPet: &Pet{Name: "nested-joins-namepet-2", Toy: Toy{Name: "nested-joins-namepet-toy-2"}},
		},
	}

	DB.Create(&users)

	var userIDs []uint
	for _, user := range users {
		userIDs = append(userIDs, user.ID)
	}

	var users2 []User
	if err := DB.
		Joins("Manager").
		Joins("Manager.Company").
		Joins("Manager.NamedPet").
		Joins("Manager.NamedPet.Toy").
		Joins("NamedPet").
		Joins("NamedPet.Toy").
Find(&users2, "users.id IN ?", userIDs).Error; err != nil { t.Fatalf("Failed to load with joins, got error: %v", err) } else if len(users2) != len(users) { t.Fatalf("Failed to load join users, got: %v, expect: %v", len(users2), len(users)) } sort.Slice(users2, func(i, j int) bool { return users2[i].ID > users2[j].ID }) sort.Slice(users, func(i, j int) bool { return users[i].ID > users[j].ID }) for idx, user := range users { // user CheckUser(t, user, users2[idx]) if users2[idx].Manager == nil { t.Fatalf("Failed to load Manager") } // manager CheckUser(t, *user.Manager, *users2[idx].Manager) // user pet if users2[idx].NamedPet == nil { t.Fatalf("Failed to load NamedPet") } CheckPet(t, *user.NamedPet, *users2[idx].NamedPet) // manager pet if users2[idx].Manager.NamedPet == nil { t.Fatalf("Failed to load NamedPet") } CheckPet(t, *user.Manager.NamedPet, *users2[idx].Manager.NamedPet) } } func TestJoinsPreload_Issue7013(t *testing.T) { manager := &User{Name: "Manager"} DB.Create(manager) var userIDs []uint for i := 0; i < 21; i++ { user := &User{Name: fmt.Sprintf("User%d", i), ManagerID: &manager.ID} DB.Create(user) userIDs = append(userIDs, user.ID) } var entries []User assert.NotPanics(t, func() { assert.NoError(t, DB.Preload("Manager.Team"). Joins("Manager.Company"). Find(&entries).Error) }) } func TestJoinsPreload_Issue7013_RelationEmpty(t *testing.T) { type ( Furniture struct { gorm.Model OwnerID *uint } Owner struct { gorm.Model Furnitures []Furniture CompanyID *uint Company Company } Building struct { gorm.Model Name string OwnerID *uint Owner Owner } ) DB.Migrator().DropTable(&Building{}, &Owner{}, &Furniture{}) DB.Migrator().AutoMigrate(&Building{}, &Owner{}, &Furniture{}) home := &Building{Name: "relation_empty"} DB.Create(home) var entries []Building assert.NotPanics(t, func() { assert.NoError(t, DB.Preload("Owner.Furnitures"). Joins("Owner.Company"). 
Find(&entries).Error) }) AssertEqual(t, entries, []Building{{Model: home.Model, Name: "relation_empty", Owner: Owner{Company: Company{}}}}) } func TestJoinsPreload_Issue7013_NoEntries(t *testing.T) { var entries []User assert.NotPanics(t, func() { assert.NoError(t, DB.Preload("Manager.Team"). Joins("Manager.Company"). Where("1 <> 1"). Find(&entries).Error) }) AssertEqual(t, len(entries), 0) } ================================================ FILE: tests/lru_test.go ================================================ package tests_test import ( "crypto/rand" "fmt" "math" "math/big" "reflect" "sync" "testing" "time" "gorm.io/gorm/internal/lru" ) func TestLRU_Add_ExistingKey_UpdatesValueAndExpiresAt(t *testing.T) { lru := lru.NewLRU[string, int](10, nil, time.Hour) lru.Add("key1", 1) lru.Add("key1", 2) if value, ok := lru.Get("key1"); !ok || value != 2 { t.Errorf("Expected value to be updated to 2, got %v", value) } } func TestLRU_Add_NewKey_AddsEntry(t *testing.T) { lru := lru.NewLRU[string, int](10, nil, time.Hour) lru.Add("key1", 1) if value, ok := lru.Get("key1"); !ok || value != 1 { t.Errorf("Expected key1 to be added with value 1, got %v", value) } } func TestLRU_Add_ExceedsSize_RemovesOldest(t *testing.T) { lru := lru.NewLRU[string, int](2, nil, time.Hour) lru.Add("key1", 1) lru.Add("key2", 2) lru.Add("key3", 3) if _, ok := lru.Get("key1"); ok { t.Errorf("Expected key1 to be removed, but it still exists") } } func TestLRU_Add_UnlimitedSize_NoEviction(t *testing.T) { lru := lru.NewLRU[string, int](0, nil, time.Hour) lru.Add("key1", 1) lru.Add("key2", 2) lru.Add("key3", 3) if _, ok := lru.Get("key1"); !ok { t.Errorf("Expected key1 to exist, but it was evicted") } } func TestLRU_Add_Eviction(t *testing.T) { lru := lru.NewLRU[string, int](0, nil, time.Second*2) lru.Add("key1", 1) lru.Add("key2", 2) lru.Add("key3", 3) time.Sleep(time.Second * 3) if lru.Cap() != 0 { t.Errorf("Expected lru to be empty, but it was not") } } func BenchmarkLRU_Rand_NoExpire(b *testing.B) { l 
:= lru.NewLRU[int64, int64](8192, nil, 0)

	trace := make([]int64, b.N*2)
	for i := 0; i < b.N*2; i++ {
		trace[i] = getRand(b) % 32768
	}

	b.ResetTimer()

	var hit, miss int
	for i := 0; i < 2*b.N; i++ {
		if i%2 == 0 {
			l.Add(trace[i], trace[i])
		} else {
			if _, ok := l.Get(trace[i]); ok {
				hit++
			} else {
				miss++
			}
		}
	}
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}

// BenchmarkLRU_Freq_NoExpire fills a cache without TTL from a trace whose even
// entries come from a narrower key range. It then reads the same trace back
// and logs the hit ratio.
func BenchmarkLRU_Freq_NoExpire(b *testing.B) {
	l := lru.NewLRU[int64, int64](8192, nil, 0)

	trace := make([]int64, b.N*2)
	for i := 0; i < b.N*2; i++ {
		if i%2 == 0 {
			trace[i] = getRand(b) % 16384
		} else {
			trace[i] = getRand(b) % 32768
		}
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		l.Add(trace[i], trace[i])
	}
	var hit, miss int
	for i := 0; i < b.N; i++ {
		if _, ok := l.Get(trace[i]); ok {
			hit++
		} else {
			miss++
		}
	}
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}

// BenchmarkLRU_Rand_WithExpire is BenchmarkLRU_Rand_NoExpire with a 10ms TTL.
func BenchmarkLRU_Rand_WithExpire(b *testing.B) {
	l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)

	trace := make([]int64, b.N*2)
	for i := 0; i < b.N*2; i++ {
		trace[i] = getRand(b) % 32768
	}

	b.ResetTimer()

	var hit, miss int
	for i := 0; i < 2*b.N; i++ {
		if i%2 == 0 {
			l.Add(trace[i], trace[i])
		} else {
			if _, ok := l.Get(trace[i]); ok {
				hit++
			} else {
				miss++
			}
		}
	}
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}

// BenchmarkLRU_Freq_WithExpire is BenchmarkLRU_Freq_NoExpire with a 10ms TTL.
func BenchmarkLRU_Freq_WithExpire(b *testing.B) {
	l := lru.NewLRU[int64, int64](8192, nil, time.Millisecond*10)

	trace := make([]int64, b.N*2)
	for i := 0; i < b.N*2; i++ {
		if i%2 == 0 {
			trace[i] = getRand(b) % 16384
		} else {
			trace[i] = getRand(b) % 32768
		}
	}

	b.ResetTimer()

	for i := 0; i < b.N; i++ {
		l.Add(trace[i], trace[i])
	}
	var hit, miss int
	for i := 0; i < b.N; i++ {
		if _, ok := l.Get(trace[i]); ok {
			hit++
		} else {
			miss++
		}
	}
	b.Logf("hit: %d miss: %d ratio: %f", hit, miss, float64(hit)/float64(hit+miss))
}

// TestLRUNoPurge exercises a cache without TTL:
//   - Peek and Contains for present and absent keys.
//   - Keys.
//   - The evicted counts returned by Resize.
func TestLRUNoPurge(t *testing.T) {
	lc := lru.NewLRU[string, string](10, nil, 0)

	lc.Add("key1", "val1")
	if lc.Len() != 1 {
		t.Fatalf("length differs from expected")
	}

	v, ok := lc.Peek("key1")
	if v != "val1" {
		t.Fatalf("value differs from expected")
	}

	if !ok {
		t.Fatalf("should be true")
	}

	if !lc.Contains("key1") {
		t.Fatalf("should contain key1")
	}
	if lc.Contains("key2") {
		t.Fatalf("should not contain key2")
	}

	v, ok = lc.Peek("key2")
	if v != "" {
		t.Fatalf("should be empty")
	}
	if ok {
		t.Fatalf("should be false")
	}

	if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) {
		t.Fatalf("value differs from expected")
	}

	if lc.Resize(0) != 0 {
		t.Fatalf("evicted count differs from expected")
	}
	if lc.Resize(2) != 0 {
		t.Fatalf("evicted count differs from expected")
	}
	lc.Add("key2", "val2")
	if lc.Resize(1) != 1 {
		t.Fatalf("evicted count differs from expected")
	}
}

// TestLRUEdgeCases checks that a nil pointer value is stored and returned as
// present. It also checks that overwriting that key with a non-nil pointer
// returns the new pointer.
func TestLRUEdgeCases(t *testing.T) {
	lc := lru.NewLRU[string, *string](2, nil, 0)

	// Adding a nil value
	lc.Add("key1", nil)

	value, exists := lc.Get("key1")
	if value != nil || !exists {
		t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
	}

	// Adding an entry with the same key but different value
	newVal := "val1"
	lc.Add("key1", &newVal)

	value, exists = lc.Get("key1")
	if value != &newVal || !exists {
		t.Fatalf("unexpected value or existence flag for key1: value=%v, exists=%v", value, exists)
	}
}

// TestLRU_Values checks that Values returns values in insertion order.
func TestLRU_Values(t *testing.T) {
	lc := lru.NewLRU[string, string](3, nil, 0)

	lc.Add("key1", "val1")
	lc.Add("key2", "val2")
	lc.Add("key3", "val3")

	values := lc.Values()
	if !reflect.DeepEqual(values, []string{"val1", "val2", "val3"}) {
		t.Fatalf("values differs from expected")
	}
}

// func TestExpirableMultipleClose(_ *testing.T) {
// 	lc :=lru.NewLRU[string, string](10, nil, 0)
// 	lc.Close()
// 	// should not panic
// 	lc.Close()
// }

// TestLRUWithPurge checks TTL expiry with an eviction callback. An entry
// survives before its 150ms TTL and is gone after it. The callback records
// the expired key/value pair, and later additions behave normally.
func TestLRUWithPurge(t *testing.T) {
	var evicted []string
	lc := lru.NewLRU(10, func(key string, value string) { evicted = append(evicted, key, value) }, 150*time.Millisecond)

	k, v, ok := lc.GetOldest()
	if k != "" {
		t.Fatalf("should be empty")
	}
	if v != "" {
		t.Fatalf("should be empty")
	}
	if ok {
		t.Fatalf("should be false")
	}

	lc.Add("key1", "val1")

	time.Sleep(100 * time.Millisecond) // not enough to expire
	if lc.Len() != 1 {
		t.Fatalf("length differs from expected")
	}

	v, ok = lc.Get("key1")
	if v != "val1" {
		t.Fatalf("value differs from expected")
	}
	if !ok {
		t.Fatalf("should be true")
	}

	time.Sleep(200 * time.Millisecond) // expire
	v, ok = lc.Get("key1")
	if ok {
		t.Fatalf("should be false")
	}
	// v is a string, so a miss yields "" rather than nil.
	if v != "" {
		t.Fatalf("should be empty")
	}

	if lc.Len() != 0 {
		t.Fatalf("length differs from expected")
	}
	if !reflect.DeepEqual(evicted, []string{"key1", "val1"}) {
		t.Fatalf("value differs from expected")
	}

	// add new entry
	lc.Add("key2", "val2")
	if lc.Len() != 1 {
		t.Fatalf("length differs from expected")
	}

	k, v, ok = lc.GetOldest()
	if k != "key2" {
		t.Fatalf("value differs from expected")
	}
	if v != "val2" {
		t.Fatalf("value differs from expected")
	}
	if !ok {
		t.Fatalf("should be true")
	}
}

// TestLRUWithPurgeEnforcedBySize adds 100 keys to a size-10 cache with a long
// TTL. Each key must be readable right after insertion, and the cache must
// end holding exactly 10 entries.
func TestLRUWithPurgeEnforcedBySize(t *testing.T) {
	lc := lru.NewLRU[string, string](10, nil, time.Hour)

	for i := 0; i < 100; i++ {
		lc.Add(fmt.Sprintf("key%d", i), fmt.Sprintf("val%d", i))
		v, ok := lc.Get(fmt.Sprintf("key%d", i))
		if v != fmt.Sprintf("val%d", i) {
			t.Fatalf("value differs from expected")
		}
		if !ok {
			t.Fatalf("should be true")
		}
		if lc.Len() > 20 {
			t.Fatalf("length should not exceed 20")
		}
	}

	if lc.Len() != 10 {
		t.Fatalf("length differs from expected")
	}
}

// TestLRUConcurrency adds keys from 1000 goroutines (100 distinct keys, each
// added 10 times) and checks the final length is 100.
func TestLRUConcurrency(t *testing.T) {
	lc := lru.NewLRU[string, string](0, nil, 0)
	wg := sync.WaitGroup{}
	wg.Add(1000)
	for i := 0; i < 1000; i++ {
		go func(i int) {
			lc.Add(fmt.Sprintf("key-%d", i/10), fmt.Sprintf("val-%d", i/10))
			wg.Done()
		}(i)
	}
	wg.Wait()
	if lc.Len() != 100 {
		t.Fatalf("length differs from expected")
	}
}

// TestLRUInvalidateAndEvict checks that Remove invokes the eviction callback
// and that the removed key is no longer readable.
func TestLRUInvalidateAndEvict(t *testing.T) {
	var evicted int
	lc := lru.NewLRU(-1, func(_, _ string) { evicted++ }, 0)

	lc.Add("key1", "val1")
	lc.Add("key2", "val2")

	val, ok := lc.Get("key1")
	if !ok {
		t.Fatalf("should be true")
	}
	if val != "val1" {
		t.Fatalf("value differs from expected")
	}
	if evicted !=
0 { t.Fatalf("value differs from expected") } lc.Remove("key1") if evicted != 1 { t.Fatalf("value differs from expected") } val, ok = lc.Get("key1") if val != "" { t.Fatalf("should be empty") } if ok { t.Fatalf("should be false") } } func TestLoadingExpired(t *testing.T) { lc := lru.NewLRU[string, string](0, nil, time.Millisecond*5) lc.Add("key1", "val1") if lc.Len() != 1 { t.Fatalf("length differs from expected") } v, ok := lc.Peek("key1") if v != "val1" { t.Fatalf("value differs from expected") } if !ok { t.Fatalf("should be true") } v, ok = lc.Get("key1") if v != "val1" { t.Fatalf("value differs from expected") } if !ok { t.Fatalf("should be true") } for { result, ok := lc.Get("key1") if ok && result == "" { t.Fatalf("ok should return a result") } if !ok { break } } time.Sleep(time.Millisecond * 100) // wait for expiration reaper if lc.Len() != 0 { t.Fatalf("length differs from expected") } v, ok = lc.Peek("key1") if v != "" { t.Fatalf("should be empty") } if ok { t.Fatalf("should be false") } v, ok = lc.Get("key1") if v != "" { t.Fatalf("should be empty") } if ok { t.Fatalf("should be false") } } func TestLRURemoveOldest(t *testing.T) { lc := lru.NewLRU[string, string](2, nil, 0) if lc.Cap() != 2 { t.Fatalf("expect cap is 2") } k, v, ok := lc.RemoveOldest() if k != "" { t.Fatalf("should be empty") } if v != "" { t.Fatalf("should be empty") } if ok { t.Fatalf("should be false") } ok = lc.Remove("non_existent") if ok { t.Fatalf("should be false") } lc.Add("key1", "val1") if lc.Len() != 1 { t.Fatalf("length differs from expected") } v, ok = lc.Get("key1") if !ok { t.Fatalf("should be true") } if v != "val1" { t.Fatalf("value differs from expected") } if !reflect.DeepEqual(lc.Keys(), []string{"key1"}) { t.Fatalf("value differs from expected") } if lc.Len() != 1 { t.Fatalf("length differs from expected") } lc.Add("key2", "val2") if !reflect.DeepEqual(lc.Keys(), []string{"key1", "key2"}) { t.Fatalf("value differs from expected") } if lc.Len() != 2 { t.Fatalf("length 
differs from expected") } k, v, ok = lc.RemoveOldest() if k != "key1" { t.Fatalf("value differs from expected") } if v != "val1" { t.Fatalf("value differs from expected") } if !ok { t.Fatalf("should be true") } if !reflect.DeepEqual(lc.Keys(), []string{"key2"}) { t.Fatalf("value differs from expected") } if lc.Len() != 1 { t.Fatalf("length differs from expected") } } func getRand(tb testing.TB) int64 { out, err := rand.Int(rand.Reader, big.NewInt(math.MaxInt64)) if err != nil { tb.Fatal(err) } return out.Int64() } ================================================ FILE: tests/main_test.go ================================================ package tests_test import ( "testing" . "gorm.io/gorm/utils/tests" ) func TestExceptionsWithInvalidSql(t *testing.T) { if name := DB.Dialector.Name(); name == "sqlserver" { t.Skip("skip sqlserver due to it will raise data race for invalid sql") } var columns []string if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { t.Errorf("Should got error with invalid SQL") } if DB.Model(&User{}).Where("sdsd.zaaa = ?", "sd;;;aa").Pluck("aaa", &columns).Error == nil { t.Errorf("Should got error with invalid SQL") } if DB.Where("sdsd.zaaa = ?", "sd;;;aa").Find(&User{}).Error == nil { t.Errorf("Should got error with invalid SQL") } var count1, count2 int64 DB.Model(&User{}).Count(&count1) if count1 <= 0 { t.Errorf("Should find some users") } if DB.Where("name = ?", "jinzhu; delete * from users").First(&User{}).Error == nil { t.Errorf("Should got error with invalid SQL") } DB.Model(&User{}).Count(&count2) if count1 != count2 { t.Errorf("No user should not be deleted by invalid SQL") } } func TestSetAndGet(t *testing.T) { if value, ok := DB.Set("hello", "world").Get("hello"); !ok { t.Errorf("Should be able to get setting after set") } else if value.(string) != "world" { t.Errorf("Set value should not be changed") } if _, ok := DB.Get("non_existing"); ok { t.Errorf("Get non existing key should return error") } } 
================================================ FILE: tests/migrate_test.go ================================================ package tests_test import ( "context" "database/sql" "fmt" "math/rand" "reflect" "strconv" "strings" "testing" "time" "github.com/stretchr/testify/assert" "gorm.io/driver/gaussdb" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/migrator" "gorm.io/gorm/schema" "gorm.io/gorm/utils" . "gorm.io/gorm/utils/tests" ) func TestMigrate(t *testing.T) { allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Tools{}, &Man{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) DB.Migrator().DropTable("user_speaks", "user_friends", "ccc") if err := DB.Migrator().DropTable(allModels...); err != nil { t.Fatalf("Failed to drop table, got error %v", err) } if err := DB.AutoMigrate(allModels...); err != nil { t.Fatalf("Failed to auto migrate, got error %v", err) } if tables, err := DB.Migrator().GetTables(); err != nil { t.Fatalf("Failed to get database all tables, but got error %v", err) } else { for _, t1 := range []string{"users", "accounts", "pets", "companies", "toys", "languages", "tools"} { hasTable := false for _, t2 := range tables { if t2 == t1 { hasTable = true break } } if !hasTable { t.Fatalf("Failed to get table %v when GetTables", t1) } } } for _, m := range allModels { if !DB.Migrator().HasTable(m) { t.Fatalf("Failed to create table for %#v", m) } } DB.Scopes(func(db *gorm.DB) *gorm.DB { return db.Table("ccc") }).Migrator().CreateTable(&Company{}) if !DB.Migrator().HasTable("ccc") { t.Errorf("failed to create table ccc") } for _, indexes := range [][2]string{ {"user_speaks", "fk_user_speaks_user"}, {"user_speaks", "fk_user_speaks_language"}, {"user_friends", "fk_user_friends_user"}, {"user_friends", "fk_user_friends_friends"}, {"accounts", "fk_users_account"}, {"users", "fk_users_team"}, {"users", 
"fk_users_company"}, } { if !DB.Migrator().HasConstraint(indexes[0], indexes[1]) { t.Fatalf("Failed to find index for many2many for %v %v", indexes[0], indexes[1]) } } } func TestAutoMigrateInt8PGAndGaussDB(t *testing.T) { if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" { return } type Smallint int8 type MigrateInt struct { Int8 Smallint } tracer := Tracer{ Logger: DB.Config.Logger, Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { sql, _ := fc() if strings.HasPrefix(sql, "ALTER TABLE \"migrate_ints\" ALTER COLUMN \"int8\" TYPE smallint") { t.Fatalf("shouldn't execute ALTER COLUMN TYPE if such type is already existed in DB schema: sql: %s", sql) } }, } DB.Migrator().DropTable(&MigrateInt{}) // The first AutoMigrate to make table with field with correct type if err := DB.AutoMigrate(&MigrateInt{}); err != nil { t.Fatalf("Failed to auto migrate: error: %v", err) } // make new session to set custom logger tracer session := DB.Session(&gorm.Session{Logger: tracer}) // The second AutoMigrate to catch an error if err := session.AutoMigrate(&MigrateInt{}); err != nil { t.Fatalf("Failed to auto migrate: error: %v", err) } } func TestAutoMigrateSelfReferential(t *testing.T) { type MigratePerson struct { ID uint Name string ManagerID *uint Manager *MigratePerson } DB.Migrator().DropTable(&MigratePerson{}) if err := DB.AutoMigrate(&MigratePerson{}); err != nil { t.Fatalf("Failed to auto migrate, but got error %v", err) } if !DB.Migrator().HasConstraint("migrate_people", "fk_migrate_people_manager") { t.Fatalf("Failed to find has one constraint between people and managers") } } func TestAutoMigrateNullable(t *testing.T) { type MigrateNullableColumn struct { ID uint Bonus float64 `gorm:"not null"` Stock float64 } DB.Migrator().DropTable(&MigrateNullableColumn{}) DB.AutoMigrate(&MigrateNullableColumn{}) type MigrateNullableColumn2 struct { ID uint Bonus float64 Stock float64 `gorm:"not null"` } if 
err := DB.Table("migrate_nullable_columns").AutoMigrate(&MigrateNullableColumn2{}); err != nil {
		t.Fatalf("failed to auto migrate, got error: %v", err)
	}

	columnTypes, err := DB.Table("migrate_nullable_columns").Migrator().ColumnTypes(&MigrateNullableColumn{})
	if err != nil {
		t.Fatalf("failed to get column types, got error: %v", err)
	}

	for _, columnType := range columnTypes {
		switch columnType.Name() {
		case "bonus":
			// allow to change non-nullable to nullable
			if nullable, _ := columnType.Nullable(); !nullable {
				t.Fatalf("bonus's nullable should be true, but got %t", nullable)
			}
		case "stock":
			// do not allow to change nullable to non-nullable
			if nullable, _ := columnType.Nullable(); !nullable {
				t.Fatalf("stock's nullable should be true, but got %t", nullable)
			}
		}
	}
}

// TestSmartMigrateColumn checks that AutoMigrate alters column size/precision
// in two passes: to 128/2/2, then to 256/3/3. A field tagged `-:migration`
// must keep its previous length. Dialects outside fullSupported may report 0
// for sizes they do not expose, which is tolerated.
func TestSmartMigrateColumn(t *testing.T) {
	fullSupported := map[string]bool{"mysql": true, "postgres": true, "gaussdb": true}[DB.Dialector.Name()]

	type UserMigrateColumn struct {
		ID       uint
		Name     string
		Salary   float64
		Birthday time.Time `gorm:"precision:4"`
	}

	DB.Migrator().DropTable(&UserMigrateColumn{})

	DB.AutoMigrate(&UserMigrateColumn{})

	type UserMigrateColumn2 struct {
		ID                  uint
		Name                string    `gorm:"size:128"`
		Salary              float64   `gorm:"precision:2"`
		Birthday            time.Time `gorm:"precision:2"`
		NameIgnoreMigration string    `gorm:"size:100"`
	}

	if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
		t.Fatalf("failed to auto migrate, got error: %v", err)
	}

	columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
	if err != nil {
		t.Fatalf("failed to get column types, got error: %v", err)
	}

	for _, columnType := range columnTypes {
		switch columnType.Name() {
		case "name":
			if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 {
				t.Fatalf("name's length should be 128, but got %v", length)
			}
		case "salary":
			if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
				t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
			}
		case "birthday":
			if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
				t.Fatalf("birthday's precision should be 2, but got %v", precision)
			}
		}
	}

	type UserMigrateColumn3 struct {
		ID                  uint
		Name                string    `gorm:"size:256"`
		Salary              float64   `gorm:"precision:3"`
		Birthday            time.Time `gorm:"precision:3"`
		NameIgnoreMigration string    `gorm:"size:128;-:migration"`
	}

	if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
		t.Fatalf("failed to auto migrate, got error: %v", err)
	}

	columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
	if err != nil {
		t.Fatalf("failed to get column types, got error: %v", err)
	}

	for _, columnType := range columnTypes {
		switch columnType.Name() {
		case "name":
			if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 {
				t.Fatalf("name's length should be 256, but got %v", length)
			}
		case "salary":
			if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
				t.Fatalf("salary's precision should be 3, but got %v", precision)
			}
		case "birthday":
			if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
				t.Fatalf("birthday's precision should be 3, but got %v", precision)
			}
		case "name_ignore_migration":
			if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 100 {
				t.Fatalf("name_ignore_migration's length should still be 100 but got %v", length)
			}
		}
	}
}

// TestSmartMigrateColumnGaussDB runs the same two-pass size/precision
// migration as TestSmartMigrateColumn. Here only mysql and gaussdb are
// treated as fully supported.
func TestSmartMigrateColumnGaussDB(t *testing.T) {
	fullSupported := map[string]bool{"mysql": true, "gaussdb": true}[DB.Dialector.Name()]

	type UserMigrateColumn struct {
		ID       uint
		Name     string
		Salary   float64
		Birthday time.Time `gorm:"precision:4"`
	}

	DB.Migrator().DropTable(&UserMigrateColumn{})

	DB.AutoMigrate(&UserMigrateColumn{})

	type UserMigrateColumn2 struct {
		ID                  uint
		Name                string    `gorm:"size:128"`
		Salary              float64   `gorm:"precision:2"`
		Birthday            time.Time `gorm:"precision:2"`
		NameIgnoreMigration string    `gorm:"size:100"`
	}

	if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
		t.Fatalf("failed to auto migrate, got error: %v", err)
	}

	columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
	if err != nil {
		t.Fatalf("failed to get column types, got error: %v", err)
	}

	for _, columnType := range columnTypes {
		switch columnType.Name() {
		case "name":
			if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 128 {
				t.Fatalf("name's length should be 128, but got %v", length)
			}
		case "salary":
			if precision, o, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
				t.Fatalf("salary's precision should be 2, but got %v %v", precision, o)
			}
		case "birthday":
			if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 2 {
				t.Fatalf("birthday's precision should be 2, but got %v", precision)
			}
		}
	}

	type UserMigrateColumn3 struct {
		ID                  uint
		Name                string    `gorm:"size:256"`
		Salary              float64   `gorm:"precision:3"`
		Birthday            time.Time `gorm:"precision:3"`
		NameIgnoreMigration string    `gorm:"size:128;-:migration"`
	}

	if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn3{}); err != nil {
		t.Fatalf("failed to auto migrate, got error: %v", err)
	}

	columnTypes, err = DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn{})
	if err != nil {
		t.Fatalf("failed to get column types, got error: %v", err)
	}

	for _, columnType := range columnTypes {
		switch columnType.Name() {
		case "name":
			if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 256 {
				t.Fatalf("name's length should be 256, but got %v", length)
			}
		case "salary":
			if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
				t.Fatalf("salary's precision should be 3, but got %v", precision)
			}
		case "birthday":
			if precision, _, _ := columnType.DecimalSize(); (fullSupported || precision != 0) && precision != 3 {
				t.Fatalf("birthday's precision should be 3, but got %v", precision)
			}
		case "name_ignore_migration":
			if length, _ := columnType.Length(); (fullSupported || length != 0) && length != 100 {
				t.Fatalf("name_ignore_migration's length should still be 100 but got %v", length)
			}
		}
	}
}

// TestMigrateWithColumnComment checks that a column comment containing
// non-ASCII text migrates without error.
func TestMigrateWithColumnComment(t *testing.T) {
	type UserWithColumnComment struct {
		gorm.Model
		Name string `gorm:"size:111;comment:this is a 字段"`
	}

	if err := DB.Migrator().DropTable(&UserWithColumnComment{}); err != nil {
		t.Fatalf("Failed to drop table, got error %v", err)
	}

	if err := DB.AutoMigrate(&UserWithColumnComment{}); err != nil {
		t.Fatalf("Failed to auto migrate, but got error %v", err)
	}
}

// TestMigrateWithIndexComment checks, on mysql only, that an index comment
// containing non-ASCII text migrates without error.
func TestMigrateWithIndexComment(t *testing.T) {
	if DB.Dialector.Name() != "mysql" {
		t.Skip()
	}

	type UserWithIndexComment struct {
		gorm.Model
		Name string `gorm:"size:111;index:,comment:这是一个index"`
	}

	if err := DB.Migrator().DropTable(&UserWithIndexComment{}); err != nil {
		t.Fatalf("Failed to drop table, got error %v", err)
	}

	if err := DB.AutoMigrate(&UserWithIndexComment{}); err != nil {
		t.Fatalf("Failed to auto migrate, but got error %v", err)
	}
}

// TestMigrateWithUniqueIndex checks that named composite and single-column
// unique indexes are created, and survive a second AutoMigrate.
func TestMigrateWithUniqueIndex(t *testing.T) {
	type UserWithUniqueIndex struct {
		ID    int
		Name  string    `gorm:"size:20;index:idx_name,unique"`
		Date  time.Time `gorm:"index:idx_name,unique"`
		UName string    `gorm:"uniqueIndex;size:255"`
	}

	DB.Migrator().DropTable(&UserWithUniqueIndex{})
	if err := DB.AutoMigrate(&UserWithUniqueIndex{}); err != nil {
		t.Fatalf("failed to migrate, got %v", err)
	}

	if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_name") {
		t.Errorf("Failed to find created index")
	}

	if !DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_user_with_unique_indices_u_name") {
		t.Errorf("Failed to find created index")
	}

	if err := DB.AutoMigrate(&UserWithUniqueIndex{}); err != nil {
		t.Fatalf("failed to migrate, got %v", err)
	}

	if
!DB.Migrator().HasIndex(&UserWithUniqueIndex{}, "idx_user_with_unique_indices_u_name") {
		t.Errorf("Failed to find created index")
	}
}

// TestMigrateTable checks RenameTable (the table must exist under the new
// name) and DropTable (the renamed table must be gone).
func TestMigrateTable(t *testing.T) {
	type TableStruct struct {
		gorm.Model
		Name string
	}

	DB.Migrator().DropTable(&TableStruct{})
	DB.AutoMigrate(&TableStruct{})

	if !DB.Migrator().HasTable(&TableStruct{}) {
		t.Fatalf("should found created table")
	}

	type NewTableStruct struct {
		gorm.Model
		Name string
	}

	if err := DB.Migrator().RenameTable(&TableStruct{}, &NewTableStruct{}); err != nil {
		t.Fatalf("Failed to rename table, got error %v", err)
	}

	if !DB.Migrator().HasTable("new_table_structs") {
		t.Fatal("should found renamed table")
	}

	DB.Migrator().DropTable("new_table_structs")

	if DB.Migrator().HasTable(&NewTableStruct{}) {
		t.Fatal("should not found dropped table")
	}
}

// TestMigrateWithQuotedIndex checks, on mysql only, that an index named after
// a reserved word migrates without error.
func TestMigrateWithQuotedIndex(t *testing.T) {
	if DB.Dialector.Name() != "mysql" {
		t.Skip()
	}

	type QuotedIndexStruct struct {
		gorm.Model
		Name string `gorm:"size:255;index:AS"` // AS is one of MySQL reserved words
	}

	if err := DB.Migrator().DropTable(&QuotedIndexStruct{}); err != nil {
		t.Fatalf("Failed to drop table, got error %v", err)
	}

	if err := DB.AutoMigrate(&QuotedIndexStruct{}); err != nil {
		t.Fatalf("Failed to auto migrate, but got error %v", err)
	}
}

// TestMigrateIndexes exercises DropIndex, CreateIndex, HasIndex and
// RenameIndex on a single-column index addressed both by field name and by
// index name.
func TestMigrateIndexes(t *testing.T) {
	type IndexStruct struct {
		gorm.Model
		Name string `gorm:"size:255;index"`
	}

	DB.Migrator().DropTable(&IndexStruct{})
	DB.AutoMigrate(&IndexStruct{})

	if err := DB.Migrator().DropIndex(&IndexStruct{}, "Name"); err != nil {
		t.Fatalf("Failed to drop index for user's name, got err %v", err)
	}

	if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil {
		t.Fatalf("Got error when tried to create index: %+v", err)
	}

	if !DB.Migrator().HasIndex(&IndexStruct{}, "Name") {
		t.Fatalf("Failed to find index for user's name")
	}

	if err := DB.Migrator().DropIndex(&IndexStruct{}, "Name"); err != nil {
		t.Fatalf("Failed to drop index for user's name, got err %v", err)
	}

	if DB.Migrator().HasIndex(&IndexStruct{}, "Name") {
		t.Fatalf("Should not find index for user's name after delete")
	}

	if err := DB.Migrator().CreateIndex(&IndexStruct{}, "Name"); err != nil {
		t.Fatalf("Got error when tried to create index: %+v", err)
	}

	if err := DB.Migrator().RenameIndex(&IndexStruct{}, "idx_index_structs_name", "idx_users_name_1"); err != nil {
		t.Fatalf("no error should happen when rename index, but got %v", err)
	}

	if !DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") {
		t.Fatalf("Should find index for user's name after rename")
	}

	if err := DB.Migrator().DropIndex(&IndexStruct{}, "idx_users_name_1"); err != nil {
		t.Fatalf("Failed to drop index for user's name, got err %v", err)
	}

	if DB.Migrator().HasIndex(&IndexStruct{}, "idx_users_name_1") {
		t.Fatalf("Should not find index for user's name after delete")
	}
}

// TestTiDBMigrateColumns is the TiDB-only variant of the column migration
// test. It covers:
//   - AlterColumn and AutoMigrate changing size, default and comment.
//   - ColumnTypes reflecting those changes.
//   - AddColumn, DropColumn and RenameColumn.
func TestTiDBMigrateColumns(t *testing.T) {
	if !isTiDB() {
		t.Skip()
	}

	// TiDB can't change column constraint and has auto_random feature
	type ColumnStruct struct {
		ID    int    `gorm:"primarykey;default:auto_random()"`
		Name  string
		Age   int    `gorm:"default:18;comment:my age"`
		Code  string `gorm:"unique;comment:my code;"`
		Code2 string
		Code3 string `gorm:"unique"`
	}

	DB.Migrator().DropTable(&ColumnStruct{})

	if err := DB.AutoMigrate(&ColumnStruct{}); err != nil {
		t.Errorf("Failed to migrate, got %v", err)
	}

	type ColumnStruct2 struct {
		ID    int    `gorm:"primarykey;default:auto_random()"`
		Name  string `gorm:"size:100"`
		Code  string `gorm:"unique;comment:my code2;default:hello"`
		Code2 string `gorm:"comment:my code2;default:hello"`
	}

	if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil {
		t.Fatalf("no error should happened when alter column, but got %v", err)
	}

	if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil {
		t.Fatalf("no error should happened when auto migrate column, but got %v", err)
	}

	if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil {
		t.Fatalf("no error should be returned for ColumnTypes, but got %v", err)
	} else {
		stmt := &gorm.Statement{DB: DB}
		stmt.Parse(&ColumnStruct2{})

		for _, columnType := range columnTypes {
			switch columnType.Name() {
			case "id":
				if v, ok := columnType.PrimaryKey(); !ok || !v {
					t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
				}
			case "name":
				dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
				if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
					t.Fatalf("column name type should be correct, name: %v, type: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
				}
				if length, ok := columnType.Length(); !ok || length != 100 {
					t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType)
				}
			case "age":
				if v, ok := columnType.DefaultValue(); !ok || v != "18" {
					t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType)
				}
				if v, ok := columnType.Comment(); !ok || v != "my age" {
					t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
				}
			case "code":
				if v, ok := columnType.Unique(); !ok || !v {
					t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType)
				}
				if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
					t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
				}
				if v, ok := columnType.Comment(); !ok || v != "my code2" {
					t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
				}
			case "code2":
				// Code2 string `gorm:"comment:my code2;default:hello"`
				if v, ok := columnType.DefaultValue(); !ok || v != "hello" {
					t.Fatalf("column code2 default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v)
				}
				if v, ok := columnType.Comment(); !ok || v != "my code2" {
					t.Fatalf("column code2 comment should be correct, name: %v, column: %#v", columnType.Name(), columnType)
				}
			}
		}
	}

	type NewColumnStruct struct {
		gorm.Model
		Name    string
		NewName string
	}

	if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil {
		t.Fatalf("Failed to add column, got %v", err)
	}

	if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") {
		t.Fatalf("Failed to find added column")
	}

	if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil {
		t.Fatalf("Failed to drop column, got %v", err)
	}

	if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") {
		t.Fatalf("Found deleted column")
	}

	if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil {
		t.Fatalf("Failed to add column, got %v", err)
	}

	if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil {
		t.Fatalf("Failed to rename column, got %v", err)
	}

	if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") {
		t.Fatalf("Failed to find renamed column")
	}

	if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil {
		t.Fatalf("Failed to drop column, got %v", err)
	}

	if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") {
		t.Fatalf("Found deleted column")
	}
}

func TestMigrateColumns(t *testing.T) {
	tidbSkip(t, "use another test case")

	sqlite := DB.Dialector.Name() == "sqlite"
	sqlserver := DB.Dialector.Name() == "sqlserver"

	type ColumnStruct struct {
		gorm.Model
		Name  string
		Age   int    `gorm:"default:18;comment:my age"`
		Code  string `gorm:"unique;comment:my code;"`
		Code2 string
		Code3 string `gorm:"unique"`
	}

	DB.Migrator().DropTable(&ColumnStruct{})

	if err := DB.AutoMigrate(&ColumnStruct{});
err != nil { t.Errorf("Failed to migrate, got %v", err) } type ColumnStruct2 struct { gorm.Model Name string `gorm:"size:100"` Code string `gorm:"unique;comment:my code2;default:hello"` Code2 string `gorm:"unique"` // Code3 string } if err := DB.Table("column_structs").Migrator().AlterColumn(&ColumnStruct{}, "Name"); err != nil { t.Fatalf("no error should happened when alter column, but got %v", err) } if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil { t.Fatalf("no error should happened when auto migrate column, but got %v", err) } if columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{}); err != nil { t.Fatalf("no error should returns for ColumnTypes") } else { stmt := &gorm.Statement{DB: DB} stmt.Parse(&ColumnStruct2{}) for _, columnType := range columnTypes { switch columnType.Name() { case "id": if v, ok := columnType.PrimaryKey(); !ok || !v { t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType) } case "name": dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name())) if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) { t.Fatalf("column name type should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType) } if length, ok := columnType.Length(); !sqlite && (!ok || length != 100) { t.Fatalf("column name length should be correct, name: %v, length: %v, expects: %v, column: %#v", columnType.Name(), length, 100, columnType) } case "age": if v, ok := columnType.DefaultValue(); !ok || v != "18" { t.Fatalf("column age default value should be correct, name: %v, column: %#v", columnType.Name(), columnType) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my age") { t.Fatalf("column age comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) } case "code": if v, ok := columnType.Unique(); 
!ok || !v { t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } if v, ok := columnType.DefaultValue(); !sqlserver && (!ok || v != "hello") { t.Fatalf("column code default value should be correct, name: %v, column: %#v, default value: %v", columnType.Name(), columnType, v) } if v, ok := columnType.Comment(); !sqlite && !sqlserver && (!ok || v != "my code2") { t.Fatalf("column code comment should be correct, name: %v, column: %#v", columnType.Name(), columnType) } case "code2": if v, ok := columnType.Unique(); !sqlserver && (!ok || !v) { t.Fatalf("column code2 unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) } case "code3": // TODO // if v, ok := columnType.Unique(); !ok || v { // t.Fatalf("column code unique should be correct, name: %v, column: %#v", columnType.Name(), columnType) // } } } } type NewColumnStruct struct { gorm.Model Name string NewName string } if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { t.Fatalf("Failed to add column, got %v", err) } if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { t.Fatalf("Failed to find added column") } if err := DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "NewName"); err != nil { t.Fatalf("Failed to add column, got %v", err) } if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "NewName") { t.Fatalf("Found deleted column") } if err := DB.Table("column_structs").Migrator().AddColumn(&NewColumnStruct{}, "NewName"); err != nil { t.Fatalf("Failed to add column, got %v", err) } if err := DB.Table("column_structs").Migrator().RenameColumn(&NewColumnStruct{}, "NewName", "new_new_name"); err != nil { t.Fatalf("Failed to add column, got %v", err) } if !DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { t.Fatalf("Failed to found renamed column") } if err := 
DB.Table("column_structs").Migrator().DropColumn(&NewColumnStruct{}, "new_new_name"); err != nil { t.Fatalf("Failed to add column, got %v", err) } if DB.Table("column_structs").Migrator().HasColumn(&NewColumnStruct{}, "new_new_name") { t.Fatalf("Found deleted column") } } func TestMigrateConstraint(t *testing.T) { names := []string{"Account", "fk_users_account", "Pets", "fk_users_pets", "Company", "fk_users_company", "Team", "fk_users_team", "Languages", "fk_users_languages"} for _, name := range names { if !DB.Migrator().HasConstraint(&User{}, name) { DB.Migrator().CreateConstraint(&User{}, name) } if err := DB.Migrator().DropConstraint(&User{}, name); err != nil { t.Fatalf("failed to drop constraint %v, got error %v", name, err) } if DB.Migrator().HasConstraint(&User{}, name) { t.Fatalf("constraint %v should been deleted", name) } if err := DB.Migrator().CreateConstraint(&User{}, name); err != nil { t.Fatalf("failed to create constraint %v, got error %v", name, err) } if !DB.Migrator().HasConstraint(&User{}, name) { t.Fatalf("failed to found constraint %v", name) } } } type DynamicUser struct { gorm.Model Name string CompanyID string `gorm:"index"` } // To test auto migrate crate indexes for dynamic table name // https://github.com/go-gorm/gorm/issues/4752 func TestMigrateIndexesWithDynamicTableName(t *testing.T) { // Create primary table if err := DB.AutoMigrate(&DynamicUser{}); err != nil { t.Fatalf("AutoMigrate create table error: %#v", err) } // Create sub tables for _, v := range []string{"01", "02", "03"} { tableName := "dynamic_users_" + v m := DB.Scopes(func(db *gorm.DB) *gorm.DB { return db.Table(tableName) }).Migrator() if err := m.AutoMigrate(&DynamicUser{}); err != nil { t.Fatalf("AutoMigrate create table error: %#v", err) } if !m.HasTable(tableName) { t.Fatalf("AutoMigrate expected %#v exist, but not.", tableName) } if !m.HasIndex(&DynamicUser{}, "CompanyID") { t.Fatalf("Should have index on %s", "CompanyI.") } if !m.HasIndex(&DynamicUser{}, 
"DeletedAt") {
			t.Fatalf("Should have index on deleted_at.")
		}
	}
}

// check column order after migration, flaky test
// https://github.com/go-gorm/gorm/issues/4351
func TestMigrateColumnOrder(t *testing.T) {
	type UserMigrateColumn struct {
		ID uint
	}
	DB.Migrator().DropTable(&UserMigrateColumn{})
	if err := DB.AutoMigrate(&UserMigrateColumn{}); err != nil {
		t.Fatalf("failed to auto migrate, got error: %v", err)
	}

	type UserMigrateColumn2 struct {
		ID  uint
		F1  string
		F2  string
		F3  string
		F4  string
		F5  string
		F6  string
		F7  string
		F8  string
		F9  string
		F10 string
		F11 string
		F12 string
		F13 string
		F14 string
		F15 string
		F16 string
		F17 string
		F18 string
		F19 string
		F20 string
		F21 string
		F22 string
		F23 string
		F24 string
		F25 string
		F26 string
		F27 string
		F28 string
		F29 string
		F30 string
		F31 string
		F32 string
		F33 string
		F34 string
		F35 string
	}
	if err := DB.Table("user_migrate_columns").AutoMigrate(&UserMigrateColumn2{}); err != nil {
		t.Fatalf("failed to auto migrate, got error: %v", err)
	}

	columnTypes, err := DB.Table("user_migrate_columns").Migrator().ColumnTypes(&UserMigrateColumn2{})
	if err != nil {
		t.Fatalf("failed to get column types, got error: %v", err)
	}

	typ := reflect.TypeOf(UserMigrateColumn2{})
	numField := typ.NumField()
	if numField != len(columnTypes) {
		t.Fatalf("column's number not match struct and ddl, %d != %d", numField, len(columnTypes))
	}

	// The database must report the columns in struct field order.
	namer := schema.NamingStrategy{}
	for i := 0; i < numField; i++ {
		expectName := namer.ColumnName("", typ.Field(i).Name)
		if columnTypes[i].Name() != expectName {
			t.Fatalf("column order not match struct and ddl, idx %d: %s != %s",
				i, columnTypes[i].Name(), expectName)
		}
	}
}

// https://github.com/go-gorm/gorm/issues/5047
func TestMigrateSerialColumn(t *testing.T) {
	if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
		return
	}

	type Event struct {
		ID  uint `gorm:"primarykey"`
		UID uint32
	}

	type Event1 struct {
		ID  uint   `gorm:"primarykey"`
		UID uint32 `gorm:"not null;autoIncrement"`
	}

	type Event2 struct {
		ID  uint   `gorm:"primarykey"`
		UID uint16 `gorm:"not null;autoIncrement"`
	}

	var err error
	err = DB.Migrator().DropTable(&Event{})
	if err != nil {
		t.Errorf("DropTable err:%v", err)
	}

	// create sequence
	err = DB.Table("events").AutoMigrate(&Event1{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	// delete sequence
	err = DB.Table("events").AutoMigrate(&Event{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	// update sequence
	err = DB.Table("events").AutoMigrate(&Event1{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}
	err = DB.Table("events").AutoMigrate(&Event2{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	for i := 0; i < 3; i++ {
		if err = DB.Table("events").Save(&Event2{}).Error; err != nil {
			t.Fatalf("Save err:%v", err)
		}
	}

	events := make([]*Event, 0)
	if err = DB.Table("events").Find(&events).Error; err != nil {
		t.Fatalf("Find err:%v", err)
	}

	// UID is filled by the sequence, so it must track the primary key.
	AssertEqual(t, 3, len(events))
	for _, v := range events {
		AssertEqual(t, v.ID, v.UID)
	}
}

// https://github.com/go-gorm/gorm/issues/5300
func TestMigrateWithSpecialName(t *testing.T) {
	var err error
	err = DB.AutoMigrate(&Coupon{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}
	err = DB.Table("coupon_product_1").AutoMigrate(&CouponProduct{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}
	err = DB.Table("coupon_product_2").AutoMigrate(&CouponProduct{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}

	AssertEqual(t, true, DB.Migrator().HasTable("coupons"))
	AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_1"))
	AssertEqual(t, true, DB.Migrator().HasTable("coupon_product_2"))
}

// https://github.com/go-gorm/gorm/issues/4760
// TestMigrateAutoIncrement checks that deleting every row does not make the
// auto-increment counter reuse the last id.
func TestMigrateAutoIncrement(t *testing.T) {
	type AutoIncrementStruct struct {
		ID     int64   `gorm:"primarykey;autoIncrement"`
		Field1 uint32  `gorm:"column:field1"`
		Field2 float32 `gorm:"column:field2"`
	}

	if err := DB.AutoMigrate(&AutoIncrementStruct{}); err != nil {
		t.Fatalf("AutoMigrate err: %v", err)
	}

	const ROWS = 10
	for idx := 0; idx < ROWS; idx++ {
		if err := DB.Create(&AutoIncrementStruct{}).Error; err != nil {
			t.Fatalf("create auto_increment_struct fail, err: %v", err)
		}
	}

	rows := make([]*AutoIncrementStruct, 0, ROWS)
	if err := DB.Order("id ASC").Find(&rows).Error; err != nil {
		t.Fatalf("find auto_increment_struct fail, err: %v", err)
	}
	if len(rows) == 0 {
		t.Fatalf("find auto_increment_struct returned no rows")
	}

	ids := make([]int64, 0, len(rows))
	for _, row := range rows {
		ids = append(ids, row.ID)
	}
	lastID := ids[len(ids)-1]

	if err := DB.Where("id IN (?)", ids).Delete(&AutoIncrementStruct{}).Error; err != nil {
		t.Fatalf("delete auto_increment_struct fail, err: %v", err)
	}

	newRow := &AutoIncrementStruct{}
	if err := DB.Create(newRow).Error; err != nil {
		t.Fatalf("create auto_increment_struct fail, err: %v", err)
	}

	AssertEqual(t, newRow.ID, lastID+1)
}

// https://github.com/go-gorm/gorm/issues/5320
func TestPrimarykeyID(t *testing.T) {
	if DB.Dialector.Name() != "postgres" {
		return
	}

	type MissPKLanguage struct {
		ID   string `gorm:"type:uuid;default:uuid_generate_v4()"`
		Name string
	}

	type MissPKUser struct {
		ID              string           `gorm:"type:uuid;default:uuid_generate_v4()"`
		MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"`
	}

	var err error
	err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{})
	if err != nil {
		t.Fatalf("DropTable err:%v", err)
	}

	DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`)

	err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}

	// patch
	err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}
}

func TestPrimarykeyIDGaussDB(t *testing.T) {
	t.Skipf("This test case skipped, because of gaussdb not support uuid-ossp plugin (SQLSTATE 58P01)")
	if DB.Dialector.Name() != "gaussdb" {
		return
	}

	type MissPKLanguage struct {
		ID   string `gorm:"type:uuid;default:uuid_generate_v4()"`
		Name string
	}

	type MissPKUser struct {
		ID              string           `gorm:"type:uuid;default:uuid_generate_v4()"`
		MissPKLanguages []MissPKLanguage `gorm:"many2many:miss_pk_user_languages;"`
	}

	var err error
	err = DB.Migrator().DropTable(&MissPKUser{}, &MissPKLanguage{})
	if err != nil {
		t.Fatalf("DropTable err:%v", err)
	}

	// TODO: ERROR: could not open extension control file: No such file or directory (SQLSTATE 58P01)
	DB.Exec(`CREATE EXTENSION IF NOT EXISTS "uuid-ossp";`)

	err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}

	// patch
	err = DB.AutoMigrate(&MissPKUser{}, &MissPKLanguage{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}
}

// TestCurrentTimestamp checks that re-migrating a unique CURRENT_TIMESTAMP
// column on MySQL keeps the named unique constraint and creates no extra
// time_at / time_at_2 indexes.
func TestCurrentTimestamp(t *testing.T) {
	if DB.Dialector.Name() != "mysql" {
		return
	}
	type CurrentTimestampTest struct {
		ID     string     `gorm:"primary_key"`
		TimeAt *time.Time `gorm:"type:datetime;not null;default:CURRENT_TIMESTAMP;unique"`
	}
	var err error
	err = DB.Migrator().DropTable(&CurrentTimestampTest{})
	if err != nil {
		t.Errorf("DropTable err:%v", err)
	}
	err = DB.AutoMigrate(&CurrentTimestampTest{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}

	err = DB.AutoMigrate(&CurrentTimestampTest{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}
	AssertEqual(t, true, DB.Migrator().HasConstraint(&CurrentTimestampTest{}, "uni_current_timestamp_tests_time_at"))
	AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at"))
	AssertEqual(t, false, DB.Migrator().HasIndex(&CurrentTimestampTest{}, "time_at_2"))
}

// TestUniqueColumn walks a MySQL unique column through several default-value
// transitions (none, NULL, '', '123', back to NULL). It checks the reported
// default after each step, and checks that the unique constraint is not
// duplicated as extra indexes.
func TestUniqueColumn(t *testing.T) {
	if DB.Dialector.Name() != "mysql" {
		return
	}

	type UniqueTest struct {
		ID   string `gorm:"primary_key"`
		Name string `gorm:"unique"`
	}

	type UniqueTest2 struct {
		ID   string `gorm:"primary_key"`
		Name string `gorm:"unique;default:NULL"`
	}

	type UniqueTest3 struct {
		ID   string `gorm:"primary_key"`
		Name string `gorm:"unique;default:''"`
	}

	type UniqueTest4 struct {
		ID   string `gorm:"primary_key"`
		Name string `gorm:"unique;default:'123'"`
	}

	var err error
	err = DB.Migrator().DropTable(&UniqueTest{})
	if err != nil {
		t.Errorf("DropTable err:%v", err)
	}

	err = DB.AutoMigrate(&UniqueTest{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}

	// null -> null
	err = DB.AutoMigrate(&UniqueTest{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}

	ct, err := findColumnType(&UniqueTest{}, "name")
	if err != nil {
		t.Fatalf("findColumnType err:%v", err)
	}

	value, ok := ct.DefaultValue()
	AssertEqual(t, "", value)
	AssertEqual(t, false, ok)

	// null -> null
	err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}

	// not trigger alert column
	AssertEqual(t, true, DB.Migrator().HasConstraint(&UniqueTest{}, "uni_unique_tests_name"))
	AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name"))
	AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_1"))
	AssertEqual(t, false, DB.Migrator().HasIndex(&UniqueTest{}, "name_2"))

	ct, err = findColumnType(&UniqueTest{}, "name")
	if err != nil {
		t.Fatalf("findColumnType err:%v", err)
	}

	value, ok = ct.DefaultValue()
	AssertEqual(t, "", value)
	AssertEqual(t, false, ok)

	tidbSkip(t, "can't change column constraint")

	// null -> empty string
	err = DB.Table("unique_tests").AutoMigrate(&UniqueTest3{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}

	ct, err = findColumnType(&UniqueTest{}, "name")
	if err != nil {
		t.Fatalf("findColumnType err:%v", err)
	}

	value, ok = ct.DefaultValue()
	AssertEqual(t, "", value)
	AssertEqual(t, true, ok)

	//  empty string -> 123
	err = DB.Table("unique_tests").AutoMigrate(&UniqueTest4{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}

	ct, err = findColumnType(&UniqueTest{}, "name")
	if err != nil {
		t.Fatalf("findColumnType err:%v", err)
	}

	value, ok = ct.DefaultValue()
	AssertEqual(t, "123", value)
	AssertEqual(t, true, ok)

	//  123 -> null
	err = DB.Table("unique_tests").AutoMigrate(&UniqueTest2{})
	if err != nil {
		t.Fatalf("AutoMigrate err:%v", err)
	}

	ct, err = findColumnType(&UniqueTest{}, "name")
	if err != nil {
		t.Fatalf("findColumnType err:%v", err)
	}

	value, ok = ct.DefaultValue()
	AssertEqual(t, "", value)
	AssertEqual(t, false, ok)
}

// findColumnType returns the database column named columnName for dest, which
// may be a model pointer or a table name. It returns an error when the
// column types cannot be read or when no column with that name exists, so a
// nil ColumnType is never returned together with a nil error.
func findColumnType(dest interface{}, columnName string) (
	foundColumn gorm.ColumnType, err error,
) {
	columnTypes, err := DB.Migrator().ColumnTypes(dest)
	if err != nil {
		return nil, fmt.Errorf("ColumnTypes err:%w", err)
	}

	for _, c := range columnTypes {
		if c.Name() == columnName {
			return c, nil
		}
	}
	return nil, fmt.Errorf("column %q not found", columnName)
}

// TestInvalidCachedPlanSimpleProtocol re-migrates the same "objects" table
// with different column sets through a fresh postgres connection.
func TestInvalidCachedPlanSimpleProtocol(t *testing.T) {
	if DB.Dialector.Name() != "postgres" {
		return
	}

	db, err := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{})
	if err != nil {
		// db is unusable without a connection; stop here instead of panicking below.
		t.Fatalf("Open err:%v", err)
	}

	type Object1 struct{}
	type Object2 struct {
		Field1 string
	}
	type Object3 struct {
		Field2 string
	}
	db.Migrator().DropTable("objects")

	err = db.Table("objects").AutoMigrate(&Object1{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	err = db.Table("objects").AutoMigrate(&Object2{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	err = db.Table("objects").AutoMigrate(&Object3{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}
}

// TODO: ERROR: must have at least one column (SQLSTATE 0A000)
func TestInvalidCachedPlanSimpleProtocolGaussDB(t *testing.T) {
	t.Skipf("This test case skipped, because of gaussdb not support creaing empty table(SQLSTATE 0A000)")
	if DB.Dialector.Name() != "gaussdb" {
		return
	}

	db, err := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{})
	if err != nil {
		// db is unusable without a connection; stop here instead of panicking below.
		t.Fatalf("Open err:%v", err)
	}

	type Object1 struct{}
	type Object2 struct {
		Field1 string
	}
	type Object3 struct {
		Field2 string
	}
	db.Migrator().DropTable("objects")

	err = db.Table("objects").AutoMigrate(&Object1{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	err = db.Table("objects").AutoMigrate(&Object2{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	err = db.Table("objects").AutoMigrate(&Object3{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}
}

func TestDifferentTypeWithoutDeclaredLength(t *testing.T) {
	type DiffType struct {
		ID   uint
		Name string `gorm:"type:varchar(20)"`
	}

	type DiffType1 struct {
		ID   uint
		Name string `gorm:"type:text"`
	}

	var err error
	DB.Migrator().DropTable(&DiffType{})

	err = DB.AutoMigrate(&DiffType{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	ct, err :=
findColumnType(&DiffType{}, "name")
	if err != nil {
		// ct is nil on error; stop before dereferencing it.
		t.Fatalf("findColumnType err:%v", err)
	}

	AssertEqual(t, "varchar", strings.ToLower(ct.DatabaseTypeName()))

	err = DB.Table("diff_types").AutoMigrate(&DiffType1{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	ct, err = findColumnType(&DiffType{}, "name")
	if err != nil {
		t.Fatalf("findColumnType err:%v", err)
	}

	AssertEqual(t, "text", strings.ToLower(ct.DatabaseTypeName()))
}

// TestMigrateArrayTypeModel checks the database type names reported for
// array columns on postgres and gaussdb.
func TestMigrateArrayTypeModel(t *testing.T) {
	if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
		return
	}

	type ArrayTypeModel struct {
		ID              uint
		Number          string     `gorm:"type:varchar(51);NOT NULL"`
		TextArray       []string   `gorm:"type:text[];NOT NULL"`
		NestedTextArray [][]string `gorm:"type:text[][]"`
		NestedIntArray  [][]int64  `gorm:"type:integer[3][3]"`
	}

	var err error
	DB.Migrator().DropTable(&ArrayTypeModel{})

	err = DB.AutoMigrate(&ArrayTypeModel{})
	AssertEqual(t, nil, err)

	ct, err := findColumnType(&ArrayTypeModel{}, "number")
	AssertEqual(t, nil, err)
	AssertEqual(t, "varchar", ct.DatabaseTypeName())

	ct, err = findColumnType(&ArrayTypeModel{}, "text_array")
	AssertEqual(t, nil, err)
	AssertEqual(t, "text[]", ct.DatabaseTypeName())

	// Nested array types are reported with a single pair of brackets.
	ct, err = findColumnType(&ArrayTypeModel{}, "nested_text_array")
	AssertEqual(t, nil, err)
	AssertEqual(t, "text[]", ct.DatabaseTypeName())

	ct, err = findColumnType(&ArrayTypeModel{}, "nested_int_array")
	AssertEqual(t, nil, err)
	AssertEqual(t, "integer[]", ct.DatabaseTypeName())
}

// mockMigrator wraps a gorm.Migrator and makes AlterColumn return an error
// after delegating, so any call to it is surfaced to the caller.
type mockMigrator struct {
	gorm.Migrator
}

func (mm mockMigrator) AlterColumn(dst interface{}, field string) error {
	err := mm.Migrator.AlterColumn(dst, field)
	if err != nil {
		return err
	}
	return fmt.Errorf("trigger alter column error, field: %s", field)
}

// TestMigrateDonotAlterColumn re-migrates an unchanged model and expects no
// error.
//
// NOTE(review): mockM.AutoMigrate is the promoted method of the embedded
// migrator, so it runs with the wrapped migrator as its receiver. Go
// embedding has no virtual dispatch, so the overridden AlterColumn above is
// only reached if the wrapped AutoMigrate calls back through mockM. Confirm
// whether this test actually detects an unwanted ALTER.
func TestMigrateDonotAlterColumn(t *testing.T) {
	wrapMockMigrator := func(m gorm.Migrator) mockMigrator {
		return mockMigrator{
			Migrator: m,
		}
	}
	m := DB.Migrator()
	mockM := wrapMockMigrator(m)

	type NotTriggerUpdate struct {
		ID  uint
		F1  uint16
		F2  uint32
		F3  int
		F4  int64
		F5  string
		F6  float32
		F7  float64
		F8  time.Time
		F9  bool
		F10 []byte
	}

	var err error
	err = mockM.DropTable(&NotTriggerUpdate{})
	AssertEqual(t, err, nil)
	err = mockM.AutoMigrate(&NotTriggerUpdate{})
	AssertEqual(t, err, nil)
	err = mockM.AutoMigrate(&NotTriggerUpdate{})
	AssertEqual(t, err, nil)
}

// TestMigrateSameEmbeddedFieldName checks that two embedded structs with the
// same field name but different prefixes migrate into distinct columns.
func TestMigrateSameEmbeddedFieldName(t *testing.T) {
	type UserStat struct {
		GroundDestroyCount int
	}
	type GameUser struct {
		gorm.Model
		StatAb UserStat `gorm:"embedded;embeddedPrefix:stat_ab_"`
	}
	type UserStat1 struct {
		GroundDestroyCount string
	}
	type GroundRate struct {
		GroundDestroyCount int
	}
	type GameUser1 struct {
		gorm.Model
		StatAb       UserStat1  `gorm:"embedded;embeddedPrefix:stat_ab_"`
		GroundRateRb GroundRate `gorm:"embedded;embeddedPrefix:rate_ground_rb_"`
	}

	DB.Migrator().DropTable(&GameUser{})
	err := DB.AutoMigrate(&GameUser{})
	AssertEqual(t, nil, err)

	err = DB.Table("game_users").AutoMigrate(&GameUser1{})
	AssertEqual(t, nil, err)

	_, err = findColumnType(&GameUser{}, "stat_ab_ground_destroy_count")
	AssertEqual(t, nil, err)

	_, err = findColumnType(&GameUser{}, "rate_ground_rb_ground_destroy_count")
	AssertEqual(t, nil, err)
}

// TestMigrateWithDefaultValue toggles a column between a NULL default and the
// string 'null'. It checks that re-migrating an unchanged 'null' default
// emits no ALTER TABLE.
func TestMigrateWithDefaultValue(t *testing.T) {
	if DB.Dialector.Name() == "sqlserver" {
		// sqlserver driver treats NULL and 'NULL' the same
		t.Skip("skip sqlserver")
	}

	type NullModel struct {
		ID      uint
		Content string `gorm:"default:null"`
	}

	type NullStringModel struct {
		ID      uint
		Content string `gorm:"default:'null'"`
		Active  bool   `gorm:"default:false"`
	}

	tableName := "null_string_model"

	DB.Migrator().DropTable(tableName)

	err := DB.Table(tableName).AutoMigrate(&NullModel{})
	AssertEqual(t, err, nil)

	// default null -> 'null'
	err = DB.Table(tableName).AutoMigrate(&NullStringModel{})
	AssertEqual(t, err, nil)

	columnType, err := findColumnType(tableName, "content")
	AssertEqual(t, err, nil)

	defVal, ok := columnType.DefaultValue()
	AssertEqual(t, defVal, "null")
	AssertEqual(t, ok, true)

	columnType2, err := findColumnType(tableName, "active")
	AssertEqual(t, err, nil)

	defVal, ok = columnType2.DefaultValue()
	bv, _ := strconv.ParseBool(defVal)
	AssertEqual(t, bv, false)
	AssertEqual(t, ok, true)

	// default 'null' -> 'null'
	session := DB.Session(&gorm.Session{Logger: Tracer{
		Logger: DB.Config.Logger,
		Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
			sql, _ := fc()
			if strings.HasPrefix(sql, "ALTER TABLE") {
				t.Errorf("shouldn't execute: sql=%s", sql)
			}
		},
	}})
	err = session.Table(tableName).AutoMigrate(&NullStringModel{})
	AssertEqual(t, err, nil)

	columnType, err = findColumnType(tableName, "content")
	AssertEqual(t, err, nil)

	defVal, ok = columnType.DefaultValue()
	AssertEqual(t, defVal, "null")
	AssertEqual(t, ok, true)

	// default 'null' -> null
	err = DB.Table(tableName).AutoMigrate(&NullModel{})
	AssertEqual(t, err, nil)

	columnType, err = findColumnType(tableName, "content")
	AssertEqual(t, err, nil)

	defVal, ok = columnType.DefaultValue()
	AssertEqual(t, defVal, "")
	AssertEqual(t, ok, false)
}

// TestMigrateMySQLWithCustomizedTypes checks that migrating over a
// hand-written MySQL table whose types already satisfy the model emits no
// ALTER TABLE.
func TestMigrateMySQLWithCustomizedTypes(t *testing.T) {
	if DB.Dialector.Name() != "mysql" {
		t.Skip()
	}

	type MyTable struct {
		Def string `gorm:"size:512;index:idx_def,unique"`
		Abc string `gorm:"size:65000000"`
	}

	DB.Migrator().DropTable("my_tables")

	sql := "CREATE TABLE `my_tables` (`def` varchar(512),`abc` longtext,UNIQUE INDEX `idx_def` (`def`))"
	if err := DB.Exec(sql).Error; err != nil {
		t.Errorf("Failed, got error: %v", err)
	}

	session := DB.Session(&gorm.Session{Logger: Tracer{
		Logger: DB.Config.Logger,
		Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
			sql, _ := fc()
			if strings.HasPrefix(sql, "ALTER TABLE") {
				t.Errorf("shouldn't execute: sql=%s", sql)
			}
		},
	}})

	if err := session.AutoMigrate(&MyTable{}); err != nil {
		t.Errorf("Failed, got error: %v", err)
	}
}

// TestMigrateIgnoreRelations checks that IgnoreRelationshipsWhenMigrating
// skips related tables, and that `gorm:"-:migration"` relations are always
// skipped.
func TestMigrateIgnoreRelations(t *testing.T) {
	type RelationModel1 struct {
		ID uint
	}
	type RelationModel2 struct {
		ID uint
	}
	type RelationModel3 struct {
		ID               uint
		RelationModel1ID uint
		RelationModel1   *RelationModel1
		RelationModel2ID uint
		RelationModel2   *RelationModel2 `gorm:"-:migration"`
	}

	var err error
	_ = DB.Migrator().DropTable(&RelationModel1{}, &RelationModel2{}, &RelationModel3{})

	tx := DB.Session(&gorm.Session{})
	tx.IgnoreRelationshipsWhenMigrating = true

	err = tx.AutoMigrate(&RelationModel3{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	// RelationModel3 should be existed
	_, err = findColumnType(&RelationModel3{}, "id")
	AssertEqual(t, nil, err)

	// RelationModel1 should not be existed
	_, err = findColumnType(&RelationModel1{}, "id")
	if err == nil {
		t.Errorf("RelationModel1 should not be migrated")
	}

	// RelationModel2 should not be existed
	_, err = findColumnType(&RelationModel2{}, "id")
	if err == nil {
		t.Errorf("RelationModel2 should not be migrated")
	}

	tx.IgnoreRelationshipsWhenMigrating = false

	err = tx.AutoMigrate(&RelationModel3{})
	if err != nil {
		t.Errorf("AutoMigrate err:%v", err)
	}

	// RelationModel3 should be existed
	_, err = findColumnType(&RelationModel3{}, "id")
	AssertEqual(t, nil, err)

	// RelationModel1 should be existed
	_, err = findColumnType(&RelationModel1{}, "id")
	AssertEqual(t, nil, err)

	// RelationModel2 should not be existed
	_, err = findColumnType(&RelationModel2{}, "id")
	if err == nil {
		t.Errorf("RelationModel2 should not be migrated")
	}
}

// TestMigrateView covers CreateView (a nil query must be rejected with
// ErrSubQueryRequired), querying the created view, and DropView.
func TestMigrateView(t *testing.T) {
	DB.Save(GetUser("joins-args-db", Config{Pets: 2}))

	if err := DB.Migrator().CreateView("invalid_users_pets", gorm.ViewOption{Query: nil}); err != gorm.ErrSubQueryRequired {
		t.Fatalf("no view should be created, got %v", err)
	}

	query := DB.Model(&User{}).
		Select("users.id as users_id, users.name as users_name, pets.id as pets_id, pets.name as pets_name").
		Joins("inner join pets on pets.user_id = users.id")

	if err := DB.Migrator().CreateView("users_pets", gorm.ViewOption{Query: query}); err != nil {
		t.Fatalf("Failed to create view, got %v", err)
	}

	var count int64
	if err := DB.Table("users_pets").Count(&count).Error; err != nil {
		t.Fatalf("should find created view, got %v", err)
	}

	if err := DB.Migrator().DropView("users_pets"); err != nil {
		t.Fatalf("Failed to drop view, got %v", err)
	}

	query = DB.Model(&User{}).Where("age > ?", 20)
	if err := DB.Migrator().CreateView("users_view", gorm.ViewOption{Query: query}); err != nil {
		t.Fatalf("Failed to create view, got %v", err)
	}
	if err := DB.Migrator().DropView("users_view"); err != nil {
		t.Fatalf("Failed to drop view, got %v", err)
	}
}

// TestMigrateExistingBoolColumnPGAndGaussDB converts existing string and
// smallint columns to bool on postgres/gaussdb. It checks that the resulting
// database types match the dialect's type for the bool fields.
func TestMigrateExistingBoolColumnPGAndGaussDB(t *testing.T) {
	if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" {
		return
	}

	type ColumnStruct struct {
		gorm.Model
		Name         string
		StringBool   string
		SmallintBool int `gorm:"type:smallint"`
	}

	type ColumnStruct2 struct {
		gorm.Model
		Name         string
		StringBool   bool // change existing boolean column from string to boolean
		SmallintBool bool // change existing boolean column from smallint or other to boolean
	}

	DB.Migrator().DropTable(&ColumnStruct{})

	if err := DB.AutoMigrate(&ColumnStruct{}); err != nil {
		t.Errorf("Failed to migrate, got %v", err)
	}

	if err := DB.Table("column_structs").AutoMigrate(&ColumnStruct2{}); err != nil {
		t.Fatalf("no error should happened when auto migrate column, but got %v", err)
	}

	columnTypes, err := DB.Migrator().ColumnTypes(&ColumnStruct{})
	if err != nil {
		t.Fatalf("no error should returns for ColumnTypes, but got %v", err)
	}

	stmt := &gorm.Statement{DB: DB}
	if err = stmt.Parse(&ColumnStruct2{}); err != nil {
		t.Fatalf("failed to parse ColumnStruct2, got %v", err)
	}

	for _, columnType := range columnTypes {
		switch columnType.Name() {
		case "id":
			if v, ok := columnType.PrimaryKey(); !ok || !v {
				t.Fatalf("column id primary key should be correct, name: %v, column: %#v", columnType.Name(), columnType)
			}
		case "string_bool", "smallint_bool":
			dataType := DB.Dialector.DataTypeOf(stmt.Schema.LookUpField(columnType.Name()))
			if !strings.Contains(strings.ToUpper(dataType), strings.ToUpper(columnType.DatabaseTypeName())) {
				t.Fatalf("column type should be correct, name: %v, type: %v, expects: %v, column: %#v", columnType.Name(), columnType.DatabaseTypeName(), dataType, columnType)
			}
		}
	}
}

// TestTableType checks the schema, name, type and comment reported by
// Migrator.TableType (MySQL only).
func TestTableType(t *testing.T) {
	// currently it is only supported for mysql driver
	if !isMysql() {
		return
	}

	const (
		tblName    = "cities"
		tblSchema  = "gorm"
		tblType    = "BASE TABLE"
		tblComment = "foobar comment"
	)

	type City struct {
		gorm.Model
		Name string `gorm:"unique"`
	}

	DB.Migrator().DropTable(&City{})

	if err := DB.Set("gorm:table_options", fmt.Sprintf("ENGINE InnoDB COMMENT '%s'", tblComment)).AutoMigrate(&City{}); err != nil {
		t.Fatalf("failed to migrate cities tables, got error: %v", err)
	}

	tableType, err := DB.Table(tblName).Migrator().TableType(&City{})
	if err != nil {
		t.Fatalf("failed to get table type, got error %v", err)
	}

	if tableType.Schema() != tblSchema {
		t.Fatalf("expected tblSchema to be %s but got %s", tblSchema, tableType.Schema())
	}

	if tableType.Name() != tblName {
		t.Fatalf("expected table name to be %s but got %s", tblName, tableType.Name())
	}

	if tableType.Type() != tblType {
		t.Fatalf("expected table type to be %s but got %s", tblType, tableType.Type())
	}

	comment, ok := tableType.Comment()
	if !ok || comment != tblComment {
		t.Fatalf("expected comment %s got %s", tblComment, comment)
	}
}

func TestMigrateWithUniqueIndexAndUnique(t *testing.T) {
	const table = "unique_struct"

	checkField := func(model interface{}, fieldName string, unique
bool, uniqueIndex string) { stmt := &gorm.Statement{DB: DB} err := stmt.Parse(model) if err != nil { t.Fatalf("%v: failed to parse schema, got error: %v", utils.FileWithLineNum(), err) } _ = stmt.Schema.ParseIndexes() field := stmt.Schema.LookUpField(fieldName) if field == nil { t.Fatalf("%v: failed to find column %q", utils.FileWithLineNum(), fieldName) } if field.Unique != unique { t.Fatalf("%v: %q column %q unique should be %v but got %v", utils.FileWithLineNum(), stmt.Schema.Table, fieldName, unique, field.Unique) } if field.UniqueIndex != uniqueIndex { t.Fatalf("%v: %q column %q uniqueIndex should be %v but got %v", utils.FileWithLineNum(), stmt.Schema, fieldName, uniqueIndex, field.UniqueIndex) } } type ( // not unique UniqueStruct1 struct { Name string `gorm:"size:10"` } UniqueStruct2 struct { Name string `gorm:"size:20"` } ) checkField(&UniqueStruct1{}, "name", false, "") checkField(&UniqueStruct2{}, "name", false, "") type ( // unique UniqueStruct3 struct { Name string `gorm:"size:30;unique"` } UniqueStruct4 struct { Name string `gorm:"size:40;unique"` } ) checkField(&UniqueStruct3{}, "name", true, "") checkField(&UniqueStruct4{}, "name", true, "") type ( // uniqueIndex UniqueStruct5 struct { Name string `gorm:"size:50;uniqueIndex"` } UniqueStruct6 struct { Name string `gorm:"size:60;uniqueIndex"` } UniqueStruct7 struct { Name string `gorm:"size:70;uniqueIndex:idx_us6_all_names"` NickName string `gorm:"size:70;uniqueIndex:idx_us6_all_names"` } ) checkField(&UniqueStruct5{}, "name", false, "idx_unique_struct5_name") checkField(&UniqueStruct6{}, "name", false, "idx_unique_struct6_name") checkField(&UniqueStruct7{}, "name", false, "") checkField(&UniqueStruct7{}, "nick_name", false, "") checkField(&UniqueStruct7{}, "nick_name", false, "") type UniqueStruct8 struct { // unique and uniqueIndex Name string `gorm:"size:60;unique;index:my_us8_index,unique;"` } checkField(&UniqueStruct8{}, "name", true, "my_us8_index") type TestCase struct { name string from, to 
interface{} checkFunc func(t *testing.T) } checkColumnType := func(t *testing.T, fieldName string, unique bool) { columnTypes, err := DB.Migrator().ColumnTypes(table) if err != nil { t.Fatalf("%v: failed to get column types, got error: %v", utils.FileWithLineNum(), err) } var found gorm.ColumnType for _, columnType := range columnTypes { if columnType.Name() == fieldName { found = columnType } } if found == nil { t.Fatalf("%v: failed to find column type %q", utils.FileWithLineNum(), fieldName) } if actualUnique, ok := found.Unique(); !ok || actualUnique != unique { t.Fatalf("%v: column %q unique should be %v but got %v", utils.FileWithLineNum(), fieldName, unique, actualUnique) } } checkIndex := func(t *testing.T, expected []gorm.Index) { indexes, err := DB.Migrator().GetIndexes(table) if err != nil { t.Fatalf("%v: failed to get indexes, got error: %v", utils.FileWithLineNum(), err) } assert.ElementsMatch(t, expected, indexes) } uniqueIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.IndexName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} myIndex := &migrator.Index{TableName: table, NameValue: "my_us8_index", ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} mulIndex := &migrator.Index{TableName: table, NameValue: "idx_us6_all_names", ColumnList: []string{"name", "nick_name"}, PrimaryKeyValue: sql.NullBool{Bool: false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} var checkNotUnique, checkUnique, checkUniqueIndex, checkMyIndex, checkMulIndex func(t *testing.T) // UniqueAffectedByUniqueIndex is true if DB.Dialector.Name() == "mysql" { uniqueConstraintIndex := &migrator.Index{TableName: table, NameValue: DB.Config.NamingStrategy.UniqueName(table, "name"), ColumnList: []string{"name"}, PrimaryKeyValue: sql.NullBool{Bool: 
false, Valid: true}, UniqueValue: sql.NullBool{Bool: true, Valid: true}} checkNotUnique = func(t *testing.T) { checkColumnType(t, "name", false) checkIndex(t, nil) } checkUnique = func(t *testing.T) { checkColumnType(t, "name", true) checkIndex(t, []gorm.Index{uniqueConstraintIndex}) } checkUniqueIndex = func(t *testing.T) { checkColumnType(t, "name", true) checkIndex(t, []gorm.Index{uniqueIndex}) } checkMyIndex = func(t *testing.T) { checkColumnType(t, "name", true) checkIndex(t, []gorm.Index{uniqueConstraintIndex, myIndex}) } checkMulIndex = func(t *testing.T) { checkColumnType(t, "name", false) checkColumnType(t, "nick_name", false) checkIndex(t, []gorm.Index{mulIndex}) } } else { checkNotUnique = func(t *testing.T) { checkColumnType(t, "name", false) } checkUnique = func(t *testing.T) { checkColumnType(t, "name", true) } checkUniqueIndex = func(t *testing.T) { checkColumnType(t, "name", false) checkIndex(t, []gorm.Index{uniqueIndex}) } checkMyIndex = func(t *testing.T) { checkColumnType(t, "name", true) if !DB.Migrator().HasIndex(table, myIndex.Name()) { t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), myIndex.Name()) } } checkMulIndex = func(t *testing.T) { checkColumnType(t, "name", false) checkColumnType(t, "nick_name", false) if !DB.Migrator().HasIndex(table, mulIndex.Name()) { t.Errorf("%v: should has index %s but not", utils.FileWithLineNum(), mulIndex.Name()) } } } tests := []TestCase{ {name: "notUnique to notUnique", from: &UniqueStruct1{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, {name: "notUnique to unique", from: &UniqueStruct1{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, {name: "notUnique to uniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, {name: "notUnique to uniqueAndUniqueIndex", from: &UniqueStruct1{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, {name: "unique to unique", from: &UniqueStruct3{}, to: &UniqueStruct4{}, checkFunc: checkUnique}, {name: "unique to 
uniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct5{}, checkFunc: checkUniqueIndex}, {name: "unique to uniqueAndUniqueIndex", from: &UniqueStruct3{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, {name: "uniqueIndex to uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct6{}, checkFunc: checkUniqueIndex}, {name: "uniqueIndex to uniqueAndUniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct8{}, checkFunc: checkMyIndex}, {name: "uniqueIndex to multi uniqueIndex", from: &UniqueStruct5{}, to: &UniqueStruct7{}, checkFunc: checkMulIndex}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { if err := DB.Migrator().DropTable(table); err != nil { t.Fatalf("failed to drop table, got error: %v", err) } if err := DB.Table(table).AutoMigrate(test.from); err != nil { t.Fatalf("failed to migrate table, got error: %v", err) } if err := DB.Table(table).AutoMigrate(test.to); err != nil { t.Fatalf("failed to migrate table, got error: %v", err) } test.checkFunc(t) }) } if DB.Dialector.Name() != "sqlserver" { // In SQLServer, If an index or constraint depends on the column, // this column will not be able to run ALTER // see https://stackoverflow.com/questions/19460912/the-object-df-is-dependent-on-column-changing-int-to-double/19461205#19461205 // may we need to create another PR to fix it, see https://github.com/go-gorm/sqlserver/pull/106 tests = []TestCase{ {name: "unique to notUnique", from: &UniqueStruct3{}, to: &UniqueStruct1{}, checkFunc: checkNotUnique}, {name: "uniqueIndex to notUnique", from: &UniqueStruct5{}, to: &UniqueStruct2{}, checkFunc: checkNotUnique}, {name: "uniqueIndex to unique", from: &UniqueStruct5{}, to: &UniqueStruct3{}, checkFunc: checkUnique}, } } if DB.Dialector.Name() == "mysql" { compatibilityTests := []TestCase{ {name: "oldUnique to notUnique", to: UniqueStruct1{}, checkFunc: checkNotUnique}, {name: "oldUnique to unique", to: UniqueStruct3{}, checkFunc: checkUnique}, {name: "oldUnique to uniqueIndex", to: UniqueStruct5{}, 
checkFunc: checkUniqueIndex}, {name: "oldUnique to uniqueAndUniqueIndex", to: UniqueStruct8{}, checkFunc: checkMyIndex}, } for _, test := range compatibilityTests { t.Run(test.name, func(t *testing.T) { if err := DB.Migrator().DropTable(table); err != nil { t.Fatalf("failed to drop table, got error: %v", err) } if err := DB.Exec("CREATE TABLE ? (`name` varchar(10) UNIQUE)", clause.Table{Name: table}).Error; err != nil { t.Fatalf("failed to create table, got error: %v", err) } if err := DB.Table(table).AutoMigrate(test.to); err != nil { t.Fatalf("failed to migrate table, got error: %v", err) } test.checkFunc(t) }) } } } func testAutoMigrateDecimal(t *testing.T, model1, model2 any) []string { tracer := Tracer{ Logger: DB.Config.Logger, Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { sql, _ := fc() if strings.HasPrefix(sql, "ALTER TABLE ") { t.Fatalf("shouldn't execute ALTER COLUMN TYPE if decimal is not change: sql: %s", sql) } }, } session := DB.Session(&gorm.Session{Logger: tracer}) DB.Migrator().DropTable(model1) var modifySql []string if err := session.AutoMigrate(model1); err != nil { t.Fatalf("failed to auto migrate, got error: %v", err) } if err := session.AutoMigrate(model1); err != nil { t.Fatalf("failed to auto migrate, got error: %v", err) } tracer2 := Tracer{ Logger: DB.Config.Logger, Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { sql, _ := fc() modifySql = append(modifySql, sql) }, } session2 := DB.Session(&gorm.Session{Logger: tracer2}) err := session2.Table("migrate_decimal_columns").Migrator().AutoMigrate(model2) if err != nil { t.Fatalf("failed to get column types, got error: %v", err) } return modifySql } func decimalColumnsTest[T, T2 any](t *testing.T, expectedSql []string) { var t1 T var t2 T2 modSql := testAutoMigrateDecimal(t, t1, t2) var alterSQL []string for _, sql := range modSql { if strings.HasPrefix(sql, "ALTER TABLE ") { 
alterSQL = append(alterSQL, sql) } } if len(alterSQL) != 3 { t.Fatalf("decimal changed error,expected: %+v,got: %+v.", expectedSql, alterSQL) } for i := range alterSQL { if alterSQL[i] != expectedSql[i] { t.Fatalf("decimal changed error,expected: %+v,got: %+v.", expectedSql, alterSQL) } } } func TestAutoMigrateDecimal(t *testing.T) { if DB.Dialector.Name() == "sqlserver" { // database/sql will replace numeric to decimal. so only support decimal. type MigrateDecimalColumn struct { RecID1 int64 `gorm:"column:recid1;type:decimal(9,0);not null" json:"recid1"` RecID2 int64 `gorm:"column:recid2;type:decimal(8);not null" json:"recid2"` RecID3 int64 `gorm:"column:recid3;type:decimal(8,1);not null" json:"recid3"` } type MigrateDecimalColumn2 struct { RecID1 int64 `gorm:"column:recid1;type:decimal(8);not null" json:"recid1"` RecID2 int64 `gorm:"column:recid2;type:decimal(9,1);not null" json:"recid2"` RecID3 int64 `gorm:"column:recid3;type:decimal(9,2);not null" json:"recid3"` } expectedSql := []string{ `ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid1" decimal(8) NOT NULL`, `ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid2" decimal(9,1) NOT NULL`, `ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid3" decimal(9,2) NOT NULL`, } decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql) } else if DB.Dialector.Name() == "postgres" || DB.Dialector.Name() == "gaussdb" { type MigrateDecimalColumn struct { RecID1 int64 `gorm:"column:recid1;type:numeric(9,0);not null" json:"recid1"` RecID2 int64 `gorm:"column:recid2;type:numeric(8);not null" json:"recid2"` RecID3 int64 `gorm:"column:recid3;type:numeric(8,1);not null" json:"recid3"` } type MigrateDecimalColumn2 struct { RecID1 int64 `gorm:"column:recid1;type:numeric(8);not null" json:"recid1"` RecID2 int64 `gorm:"column:recid2;type:numeric(9,1);not null" json:"recid2"` RecID3 int64 `gorm:"column:recid3;type:numeric(9,2);not null" json:"recid3"` } expectedSql := []string{ `ALTER TABLE 
"migrate_decimal_columns" ALTER COLUMN "recid1" TYPE numeric(8) USING "recid1"::numeric(8)`, `ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid2" TYPE numeric(9,1) USING "recid2"::numeric(9,1)`, `ALTER TABLE "migrate_decimal_columns" ALTER COLUMN "recid3" TYPE numeric(9,2) USING "recid3"::numeric(9,2)`, } decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql) } else if DB.Dialector.Name() == "mysql" { type MigrateDecimalColumn struct { RecID1 int64 `gorm:"column:recid1;type:decimal(9,0);not null" json:"recid1"` RecID2 int64 `gorm:"column:recid2;type:decimal(8);not null" json:"recid2"` RecID3 int64 `gorm:"column:recid3;type:decimal(8,1);not null" json:"recid3"` } type MigrateDecimalColumn2 struct { RecID1 int64 `gorm:"column:recid1;type:decimal(8);not null" json:"recid1"` RecID2 int64 `gorm:"column:recid2;type:decimal(9,1);not null" json:"recid2"` RecID3 int64 `gorm:"column:recid3;type:decimal(9,2);not null" json:"recid3"` } expectedSql := []string{ "ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid1` decimal(8) NOT NULL", "ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid2` decimal(9,1) NOT NULL", "ALTER TABLE `migrate_decimal_columns` MODIFY COLUMN `recid3` decimal(9,2) NOT NULL", } decimalColumnsTest[MigrateDecimalColumn, MigrateDecimalColumn2](t, expectedSql) } } ================================================ FILE: tests/multi_primary_keys_test.go ================================================ package tests_test import ( "reflect" "sort" "testing" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) type Blog struct { ID uint `gorm:"primary_key"` Locale string `gorm:"primary_key"` Subject string Body string Tags []Tag `gorm:"many2many:blog_tags;"` SharedTags []Tag `gorm:"many2many:shared_blog_tags;ForeignKey:id;References:id"` LocaleTags []Tag `gorm:"many2many:locale_blog_tags;ForeignKey:id,locale;References:id"` } type Tag struct { ID uint `gorm:"primary_key"` Locale string `gorm:"primary_key"` Value string Blogs []*Blog `gorm:"many2many:blog_tags"` } func compareTags(tags []Tag, contents []string) bool { var tagContents []string for _, tag := range tags { tagContents = append(tagContents, tag.Value) } sort.Strings(tagContents) sort.Strings(contents) return reflect.DeepEqual(tagContents, contents) } func TestManyToManyWithMultiPrimaryKeys(t *testing.T) { if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } if name := DB.Dialector.Name(); name == "postgres" || name == "mysql" || name == "gaussdb" { stmt := gorm.Statement{DB: DB} stmt.Parse(&Blog{}) stmt.Schema.LookUpField("ID").Unique = true stmt.Parse(&Tag{}) stmt.Schema.LookUpField("ID").Unique = true // postgers only allow unique constraint matching given keys } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { t.Fatalf("Failed to auto migrate, got error: %v", err) } blog := Blog{ Locale: "ZH", Subject: "subject", Body: "body", Tags: []Tag{ {Locale: "ZH", Value: "tag1"}, {Locale: "ZH", Value: "tag2"}, }, } DB.Save(&blog) if !compareTags(blog.Tags, []string{"tag1", "tag2"}) { t.Fatalf("Blog should has two tags") } // Append tag3 := &Tag{Locale: "ZH", Value: "tag3"} DB.Model(&blog).Association("Tags").Append([]*Tag{tag3}) if !compareTags(blog.Tags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Blog should has three tags after Append") } if count := 
DB.Model(&blog).Association("Tags").Count(); count != 3 { t.Fatalf("Blog should has 3 tags after Append, got %v", count) } var tags []Tag DB.Model(&blog).Association("Tags").Find(&tags) if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Should find 3 tags") } var blog1 Blog DB.Preload("Tags").Find(&blog1) if !compareTags(blog1.Tags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Preload many2many relations") } // Replace tag5 := &Tag{Locale: "ZH", Value: "tag5"} tag6 := &Tag{Locale: "ZH", Value: "tag6"} DB.Model(&blog).Association("Tags").Replace(tag5, tag6) var tags2 []Tag DB.Model(&blog).Association("Tags").Find(&tags2) if !compareTags(tags2, []string{"tag5", "tag6"}) { t.Fatalf("Should find 2 tags after Replace") } if DB.Model(&blog).Association("Tags").Count() != 2 { t.Fatalf("Blog should has three tags after Replace") } // Delete DB.Model(&blog).Association("Tags").Delete(tag5) var tags3 []Tag DB.Model(&blog).Association("Tags").Find(&tags3) if !compareTags(tags3, []string{"tag6"}) { t.Fatalf("Should find 1 tags after Delete") } if DB.Model(&blog).Association("Tags").Count() != 1 { t.Fatalf("Blog should has three tags after Delete") } DB.Model(&blog).Association("Tags").Delete(tag3) var tags4 []Tag DB.Model(&blog).Association("Tags").Find(&tags4) if !compareTags(tags4, []string{"tag6"}) { t.Fatalf("Tag should not be deleted when Delete with a unrelated tag") } // Clear DB.Model(&blog).Association("Tags").Clear() if DB.Model(&blog).Association("Tags").Count() != 0 { t.Fatalf("All tags should be cleared") } } func TestManyToManyWithCustomizedForeignKeys(t *testing.T) { if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } if name := DB.Dialector.Name(); name == "postgres" { t.Skip("skip postgres due to it only allow unique constraint matching given keys") } if name := DB.Dialector.Name(); name == "gaussdb" { t.Skip("skip 
gaussdb due to it only allow unique constraint matching given keys") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { t.Fatalf("Failed to auto migrate, got error: %v", err) } blog := Blog{ Locale: "ZH", Subject: "subject", Body: "body", SharedTags: []Tag{ {Locale: "ZH", Value: "tag1"}, {Locale: "ZH", Value: "tag2"}, }, } DB.Save(&blog) blog2 := Blog{ ID: blog.ID, Locale: "EN", } DB.Create(&blog2) if !compareTags(blog.SharedTags, []string{"tag1", "tag2"}) { t.Fatalf("Blog should has two tags") } // Append tag3 := &Tag{Locale: "ZH", Value: "tag3"} DB.Model(&blog).Association("SharedTags").Append([]*Tag{tag3}) if !compareTags(blog.SharedTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Blog should has three tags after Append") } if DB.Model(&blog).Association("SharedTags").Count() != 3 { t.Fatalf("Blog should has three tags after Append") } if DB.Model(&blog2).Association("SharedTags").Count() != 3 { t.Fatalf("Blog should has three tags after Append") } var tags []Tag DB.Model(&blog).Association("SharedTags").Find(&tags) if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Should find 3 tags") } DB.Model(&blog2).Association("SharedTags").Find(&tags) if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Should find 3 tags") } var blog1 Blog DB.Preload("SharedTags").Find(&blog1) if !compareTags(blog1.SharedTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Preload many2many relations") } tag4 := &Tag{Locale: "ZH", Value: "tag4"} DB.Model(&blog2).Association("SharedTags").Append(tag4) DB.Model(&blog).Association("SharedTags").Find(&tags) if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { t.Fatalf("Should find 3 tags") } DB.Model(&blog2).Association("SharedTags").Find(&tags) if !compareTags(tags, []string{"tag1", "tag2", "tag3", "tag4"}) { t.Fatalf("Should find 3 tags") } // Replace tag5 := &Tag{Locale: "ZH", Value: "tag5"} 
tag6 := &Tag{Locale: "ZH", Value: "tag6"} DB.Model(&blog2).Association("SharedTags").Replace(tag5, tag6) var tags2 []Tag DB.Model(&blog).Association("SharedTags").Find(&tags2) if !compareTags(tags2, []string{"tag5", "tag6"}) { t.Fatalf("Should find 2 tags after Replace") } DB.Model(&blog2).Association("SharedTags").Find(&tags2) if !compareTags(tags2, []string{"tag5", "tag6"}) { t.Fatalf("Should find 2 tags after Replace") } if DB.Model(&blog).Association("SharedTags").Count() != 2 { t.Fatalf("Blog should has three tags after Replace") } // Delete DB.Model(&blog).Association("SharedTags").Delete(tag5) var tags3 []Tag DB.Model(&blog).Association("SharedTags").Find(&tags3) if !compareTags(tags3, []string{"tag6"}) { t.Fatalf("Should find 1 tags after Delete") } if DB.Model(&blog).Association("SharedTags").Count() != 1 { t.Fatalf("Blog should has three tags after Delete") } DB.Model(&blog2).Association("SharedTags").Delete(tag3) var tags4 []Tag DB.Model(&blog).Association("SharedTags").Find(&tags4) if !compareTags(tags4, []string{"tag6"}) { t.Fatalf("Tag should not be deleted when Delete with a unrelated tag") } // Clear DB.Model(&blog2).Association("SharedTags").Clear() if DB.Model(&blog).Association("SharedTags").Count() != 0 { t.Fatalf("All tags should be cleared") } } func TestManyToManyWithCustomizedForeignKeys2(t *testing.T) { if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } if name := DB.Dialector.Name(); name == "postgres" || name == "mysql" { t.Skip("skip postgres due to it only allow unique constraint matching given keys") } if name := DB.Dialector.Name(); name == "gaussdb" { t.Skip("skip gaussdb due to it only allow unique constraint matching given keys") } DB.Migrator().DropTable(&Blog{}, &Tag{}, "blog_tags", "locale_blog_tags", "shared_blog_tags") if err := DB.AutoMigrate(&Blog{}, &Tag{}); err != nil { t.Fatalf("Failed to auto 
migrate, got error: %v", err) } blog := Blog{ Locale: "ZH", Subject: "subject", Body: "body", LocaleTags: []Tag{ {Locale: "ZH", Value: "tag1"}, {Locale: "ZH", Value: "tag2"}, }, } DB.Save(&blog) blog2 := Blog{ ID: blog.ID, Locale: "EN", } DB.Create(&blog2) // Append tag3 := &Tag{Locale: "ZH", Value: "tag3"} DB.Model(&blog).Association("LocaleTags").Append([]*Tag{tag3}) if !compareTags(blog.LocaleTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Blog should has three tags after Append") } if DB.Model(&blog).Association("LocaleTags").Count() != 3 { t.Fatalf("Blog should has three tags after Append") } if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { t.Fatalf("EN Blog should has 0 tags after ZH Blog Append") } var tags []Tag DB.Model(&blog).Association("LocaleTags").Find(&tags) if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Should find 3 tags") } DB.Model(&blog2).Association("LocaleTags").Find(&tags) if len(tags) != 0 { t.Fatalf("Should find 0 tags for EN Blog") } var blog1 Blog DB.Preload("LocaleTags").Find(&blog1, "locale = ? 
AND id = ?", "ZH", blog.ID) if !compareTags(blog1.LocaleTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Preload many2many relations") } tag4 := &Tag{Locale: "ZH", Value: "tag4"} DB.Model(&blog2).Association("LocaleTags").Append(tag4) DB.Model(&blog).Association("LocaleTags").Find(&tags) if !compareTags(tags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("Should find 3 tags for EN Blog") } DB.Model(&blog2).Association("LocaleTags").Find(&tags) if !compareTags(tags, []string{"tag4"}) { t.Fatalf("Should find 1 tags for EN Blog, but got %v", tags) } // Replace tag5 := &Tag{Locale: "ZH", Value: "tag5"} tag6 := &Tag{Locale: "ZH", Value: "tag6"} DB.Model(&blog2).Association("LocaleTags").Replace(tag5, tag6) var tags2 []Tag DB.Model(&blog).Association("LocaleTags").Find(&tags2) if !compareTags(tags2, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") } var blog11 Blog DB.Preload("LocaleTags").First(&blog11, "id = ? AND locale = ?", blog.ID, blog.Locale) if !compareTags(blog11.LocaleTags, []string{"tag1", "tag2", "tag3"}) { t.Fatalf("CN Blog's tags should not be changed after EN Blog Replace") } DB.Model(&blog2).Association("LocaleTags").Find(&tags2) if !compareTags(tags2, []string{"tag5", "tag6"}) { t.Fatalf("Should find 2 tags after Replace") } var blog21 Blog DB.Preload("LocaleTags").First(&blog21, "id = ? 
AND locale = ?", blog2.ID, blog2.Locale) if !compareTags(blog21.LocaleTags, []string{"tag5", "tag6"}) { t.Fatalf("EN Blog's tags should be changed after Replace") } if DB.Model(&blog).Association("LocaleTags").Count() != 3 { t.Fatalf("ZH Blog should has three tags after Replace") } if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { t.Fatalf("EN Blog should has two tags after Replace") } // Delete DB.Model(&blog).Association("LocaleTags").Delete(tag5) if DB.Model(&blog).Association("LocaleTags").Count() != 3 { t.Fatalf("ZH Blog should has three tags after Delete with EN's tag") } if DB.Model(&blog2).Association("LocaleTags").Count() != 2 { t.Fatalf("EN Blog should has two tags after ZH Blog Delete with EN's tag") } DB.Model(&blog2).Association("LocaleTags").Delete(tag5) if DB.Model(&blog).Association("LocaleTags").Count() != 3 { t.Fatalf("ZH Blog should has three tags after EN Blog Delete with EN's tag") } if DB.Model(&blog2).Association("LocaleTags").Count() != 1 { t.Fatalf("EN Blog should has 1 tags after EN Blog Delete with EN's tag") } // Clear DB.Model(&blog2).Association("LocaleTags").Clear() if DB.Model(&blog).Association("LocaleTags").Count() != 3 { t.Fatalf("ZH Blog's tags should not be cleared when clear EN Blog's tags") } if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { t.Fatalf("EN Blog's tags should be cleared when clear EN Blog's tags") } DB.Model(&blog).Association("LocaleTags").Clear() if DB.Model(&blog).Association("LocaleTags").Count() != 0 { t.Fatalf("ZH Blog's tags should be cleared when clear ZH Blog's tags") } if DB.Model(&blog2).Association("LocaleTags").Count() != 0 { t.Fatalf("EN Blog's tags should be cleared") } } func TestCompositePrimaryKeysAssociations(t *testing.T) { type Label struct { BookID *uint `gorm:"primarykey"` Name string `gorm:"primarykey"` Value string } type Book struct { ID int Name string Labels []Label } DB.Migrator().DropTable(&Label{}, &Book{}) if err := DB.AutoMigrate(&Label{}, &Book{}); err != 
nil { t.Fatalf("failed to migrate, got %v", err) } book := Book{ Name: "my book", Labels: []Label{ {Name: "region", Value: "emea"}, }, } DB.Create(&book) var result Book if err := DB.Preload("Labels").First(&result, book.ID).Error; err != nil { t.Fatalf("failed to preload, got error %v", err) } AssertEqual(t, book, result) } ================================================ FILE: tests/named_argument_test.go ================================================ package tests_test import ( "database/sql" "errors" "testing" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) func TestNamedArg(t *testing.T) { type NamedUser struct { gorm.Model Name1 string Name2 string Name3 string } DB.Migrator().DropTable(&NamedUser{}) DB.AutoMigrate(&NamedUser{}) namedUser := NamedUser{Name1: "jinzhu1", Name2: "jinzhu2", Name3: "jinzhu3"} DB.Create(&namedUser) var result NamedUser DB.First(&result, "name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")) AssertEqual(t, result, namedUser) var result2 NamedUser DB.Where("name1 = @name OR name2 = @name OR name3 = @name", sql.Named("name", "jinzhu2")).First(&result2) AssertEqual(t, result2, namedUser) var result3 NamedUser DB.Where("name1 = @name OR name2 = @name OR name3 = @name", map[string]interface{}{"name": "jinzhu2"}).First(&result3) AssertEqual(t, result3, namedUser) var result4 NamedUser if err := DB.Raw("SELECT * FROM named_users WHERE name1 = @name OR name2 = @name2 OR name3 = @name", sql.Named("name", "jinzhu-none"), sql.Named("name2", "jinzhu2")).Find(&result4).Error; err != nil { t.Errorf("failed to update with named arg") } AssertEqual(t, result4, namedUser) if err := DB.Exec("UPDATE named_users SET name1 = @name, name2 = @name2, name3 = @name", sql.Named("name", "jinzhu-new"), sql.Named("name2", "jinzhu-new2")).Error; err != nil { t.Errorf("failed to update with named arg") } namedUser.Name1 = "jinzhu-new" namedUser.Name2 = "jinzhu-new2" namedUser.Name3 = "jinzhu-new" var result5 NamedUser if err := 
DB.Raw("SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2", map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result5).Error; err != nil { t.Errorf("failed to update with named arg") } AssertEqual(t, result5, namedUser) var result6 NamedUser if err := DB.Raw(`SELECT * FROM named_users WHERE (name1 = @name AND name3 = @name) AND name2 = @name2`, map[string]interface{}{"name": "jinzhu-new", "name2": "jinzhu-new2"}).Find(&result6).Error; err != nil { t.Errorf("failed to update with named arg") } AssertEqual(t, result6, namedUser) var result7 NamedUser if err := DB.Where("name1 = @name OR name2 = @name", sql.Named("name", "jinzhu-new")).Where("name3 = 'jinzhu-new3'").First(&result7).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("should return record not found error, but got %v", err) } DB.Delete(&namedUser) var result8 NamedUser if err := DB.Where("name1 = @name OR name2 = @name", map[string]interface{}{"name": "jinzhu-new"}).First(&result8).Error; err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("should return record not found error, but got %v", err) } } ================================================ FILE: tests/named_polymorphic_test.go ================================================ package tests_test import ( "testing" . "gorm.io/gorm/utils/tests" ) type Hamster struct { Id int Name string PreferredToy Toy `gorm:"polymorphic:Owner;polymorphicValue:hamster_preferred"` OtherToy Toy `gorm:"polymorphic:Owner;polymorphicValue:hamster_other"` } func TestNamedPolymorphic(t *testing.T) { DB.Migrator().DropTable(&Hamster{}) DB.AutoMigrate(&Hamster{}) hamster := Hamster{Name: "Mr. 
Hammond", PreferredToy: Toy{Name: "bike"}, OtherToy: Toy{Name: "treadmill"}} DB.Save(&hamster) hamster2 := Hamster{} DB.Preload("PreferredToy").Preload("OtherToy").Find(&hamster2, hamster.Id) if hamster2.PreferredToy.ID != hamster.PreferredToy.ID || hamster2.PreferredToy.Name != hamster.PreferredToy.Name { t.Errorf("Hamster's preferred toy failed to preload") } if hamster2.OtherToy.ID != hamster.OtherToy.ID || hamster2.OtherToy.Name != hamster.OtherToy.Name { t.Errorf("Hamster's other toy failed to preload") } // clear to omit Toy.ID in count hamster2.PreferredToy = Toy{} hamster2.OtherToy = Toy{} if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { t.Errorf("Hamster's preferred toy count should be 1") } if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { t.Errorf("Hamster's other toy count should be 1") } // Query hamsterToy := Toy{} DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) if hamsterToy.Name != hamster.PreferredToy.Name { t.Errorf("Should find has one polymorphic association") } hamsterToy = Toy{} DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) if hamsterToy.Name != hamster.OtherToy.Name { t.Errorf("Should find has one polymorphic association") } // Append DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ Name: "bike 2", }) DB.Model(&hamster).Association("OtherToy").Append(&Toy{ Name: "treadmill 2", }) hamsterToy = Toy{} DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) if hamsterToy.Name != "bike 2" { t.Errorf("Should update has one polymorphic association with Append") } hamsterToy = Toy{} DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) if hamsterToy.Name != "treadmill 2" { t.Errorf("Should update has one polymorphic association with Append") } if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { t.Errorf("Hamster's toys count should be 1 after Append") } if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { t.Errorf("Hamster's toys count 
should be 1 after Append") } // Replace DB.Model(&hamster).Association("PreferredToy").Replace(&Toy{ Name: "bike 3", }) DB.Model(&hamster).Association("OtherToy").Replace(&Toy{ Name: "treadmill 3", }) hamsterToy = Toy{} DB.Model(&hamster).Association("PreferredToy").Find(&hamsterToy) if hamsterToy.Name != "bike 3" { t.Errorf("Should update has one polymorphic association with Replace") } hamsterToy = Toy{} DB.Model(&hamster).Association("OtherToy").Find(&hamsterToy) if hamsterToy.Name != "treadmill 3" { t.Errorf("Should update has one polymorphic association with Replace") } if DB.Model(&hamster2).Association("PreferredToy").Count() != 1 { t.Errorf("hamster's toys count should be 1 after Replace") } if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { t.Errorf("hamster's toys count should be 1 after Replace") } // Clear DB.Model(&hamster).Association("PreferredToy").Append(&Toy{ Name: "bike 2", }) DB.Model(&hamster).Association("OtherToy").Append(&Toy{ Name: "treadmill 2", }) if DB.Model(&hamster).Association("PreferredToy").Count() != 1 { t.Errorf("Hamster's toys should be added with Append") } if DB.Model(&hamster).Association("OtherToy").Count() != 1 { t.Errorf("Hamster's toys should be added with Append") } DB.Model(&hamster).Association("PreferredToy").Clear() if DB.Model(&hamster2).Association("PreferredToy").Count() != 0 { t.Errorf("Hamster's preferred toy should be cleared with Clear") } if DB.Model(&hamster2).Association("OtherToy").Count() != 1 { t.Errorf("Hamster's other toy should be still available") } DB.Model(&hamster).Association("OtherToy").Clear() if DB.Model(&hamster).Association("OtherToy").Count() != 0 { t.Errorf("Hamster's other toy should be cleared with Clear") } } ================================================ FILE: tests/non_std_test.go ================================================ package tests_test import ( "testing" "time" ) type Animal struct { Counter uint64 `gorm:"primary_key:yes"` Name string `gorm:"DEFAULT:'galeone'"` 
From string // test reserved sql keyword as field name Age *time.Time unexported string // unexported value CreatedAt time.Time UpdatedAt time.Time } func TestNonStdPrimaryKeyAndDefaultValues(t *testing.T) { DB.Migrator().DropTable(&Animal{}) if err := DB.AutoMigrate(&Animal{}); err != nil { t.Fatalf("no error should happen when migrate but got %v", err) } animal := Animal{Name: "Ferdinand"} DB.Save(&animal) updatedAt1 := animal.UpdatedAt DB.Save(&animal).Update("name", "Francis") if updatedAt1.Format(time.RFC3339Nano) == animal.UpdatedAt.Format(time.RFC3339Nano) { t.Errorf("UpdatedAt should be updated") } var animals []Animal DB.Find(&animals) if count := DB.Model(Animal{}).Where("1=1").Update("CreatedAt", time.Now().Add(2*time.Hour)).RowsAffected; count != int64(len(animals)) { t.Error("RowsAffected should be correct when do batch update") } animal = Animal{From: "somewhere"} // No name fields, should be filled with the default value (galeone) DB.Save(&animal).Update("From", "a nice place") // The name field should be untouched DB.First(&animal, animal.Counter) if animal.Name != "galeone" { t.Errorf("Name fields shouldn't be changed if untouched, but got %v", animal.Name) } // When changing a field with a default value, the change must occur animal.Name = "amazing horse" DB.Save(&animal) DB.First(&animal, animal.Counter) if animal.Name != "amazing horse" { t.Errorf("Update a filed with a default value should occur. But got %v\n", animal.Name) } // When changing a field with a default value with blank value animal.Name = "" DB.Save(&animal) DB.First(&animal, animal.Counter) if animal.Name != "" { t.Errorf("Update a filed to blank with a default value should occur. But got %v\n", animal.Name) } } ================================================ FILE: tests/postgres_test.go ================================================ package tests_test import ( "testing" "time" "github.com/google/uuid" "github.com/lib/pq" "gorm.io/gorm" "gorm.io/gorm/clause" . 
"gorm.io/gorm/utils/tests" ) func TestPostgresReturningIDWhichHasStringType(t *testing.T) { if DB.Dialector.Name() != "postgres" { t.Skip() } type Yasuo struct { ID string `gorm:"default:gen_random_uuid()"` Name string CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` } if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err != nil { t.Errorf("Failed to create extension pgcrypto, got error %v", err) } DB.Migrator().DropTable(&Yasuo{}) if err := DB.AutoMigrate(&Yasuo{}); err != nil { t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) } yasuo := Yasuo{Name: "jinzhu"} if err := DB.Create(&yasuo).Error; err != nil { t.Fatalf("should be able to create data, but got %v", err) } if yasuo.ID == "" { t.Fatal("should be able to has ID, but got zero value") } var result Yasuo if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } if err := DB.Where("id = $1", yasuo.ID).First(&Yasuo{}).Error; err != nil || yasuo.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } yasuo.Name = "jinzhu1" if err := DB.Save(&yasuo).Error; err != nil { t.Errorf("Failed to update date, got error %v", err) } if err := DB.First(&result, "id = ?", yasuo.ID).Error; err != nil || yasuo.Name != "jinzhu1" { t.Errorf("No error should happen, but got %v", err) } } func TestPostgres(t *testing.T) { if DB.Dialector.Name() != "postgres" { t.Skip() } type Harumph struct { gorm.Model Name string `gorm:"check:name_checker,name <> ''"` Test uuid.UUID `gorm:"type:uuid;not null;default:gen_random_uuid()"` CreatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE"` UpdatedAt time.Time `gorm:"type:TIMESTAMP WITHOUT TIME ZONE;default:current_timestamp"` Things pq.StringArray `gorm:"type:text[]"` } if err := DB.Exec("CREATE EXTENSION IF NOT EXISTS pgcrypto;").Error; err 
!= nil { t.Errorf("Failed to create extension pgcrypto, got error %v", err) } DB.Migrator().DropTable(&Harumph{}) if err := DB.AutoMigrate(&Harumph{}); err != nil { t.Fatalf("Failed to migrate for uuid default value, got error: %v", err) } harumph := Harumph{} if err := DB.Create(&harumph).Error; err == nil { t.Fatalf("should failed to create data, name can't be blank") } harumph = Harumph{Name: "jinzhu"} if err := DB.Create(&harumph).Error; err != nil { t.Fatalf("should be able to create data, but got %v", err) } var result Harumph if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } if err := DB.Where("id = $1", harumph.ID).First(&Harumph{}).Error; err != nil || harumph.Name != "jinzhu" { t.Errorf("No error should happen, but got %v", err) } harumph.Name = "jinzhu1" if err := DB.Save(&harumph).Error; err != nil { t.Errorf("Failed to update date, got error %v", err) } if err := DB.First(&result, "id = ?", harumph.ID).Error; err != nil || harumph.Name != "jinzhu1" { t.Errorf("No error should happen, but got %v", err) } DB.Migrator().DropTable("log_usage") if err := DB.Exec(` CREATE TABLE public.log_usage ( log_id bigint NOT NULL ); ALTER TABLE public.log_usage ALTER COLUMN log_id ADD GENERATED BY DEFAULT AS IDENTITY ( SEQUENCE NAME public.log_usage_log_id_seq START WITH 1 INCREMENT BY 1 NO MINVALUE NO MAXVALUE CACHE 1 ); `).Error; err != nil { t.Fatalf("failed to create table, got error %v", err) } columns, err := DB.Migrator().ColumnTypes("log_usage") if err != nil { t.Fatalf("failed to get columns, got error %v", err) } hasLogID := false for _, column := range columns { if column.Name() == "log_id" { hasLogID = true autoIncrement, ok := column.AutoIncrement() if !ok || !autoIncrement { t.Fatalf("column log_id should be auto incrementment") } } } if !hasLogID { t.Fatalf("failed to found column log_id") } } type Post struct { ID uuid.UUID 
`gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Categories []*Category `gorm:"Many2Many:post_categories"` } type Category struct { ID uuid.UUID `gorm:"primary_key;type:uuid;default:uuid_generate_v4();"` Title string Posts []*Post `gorm:"Many2Many:post_categories"` } func TestMany2ManyWithDefaultValueUUID(t *testing.T) { if DB.Dialector.Name() != "postgres" { t.Skip() } if err := DB.Exec(`create extension if not exists "uuid-ossp"`).Error; err != nil { t.Fatalf("Failed to create 'uuid-ossp' extension, but got error %v", err) } DB.Migrator().DropTable(&Post{}, &Category{}, "post_categories") DB.AutoMigrate(&Post{}, &Category{}) post := Post{ Title: "Hello World", Categories: []*Category{ {Title: "Coding"}, {Title: "Golang"}, }, } if err := DB.Create(&post).Error; err != nil { t.Errorf("Failed, got error: %v", err) } } func TestPostgresOnConstraint(t *testing.T) { if DB.Dialector.Name() != "postgres" { t.Skip() } type Thing struct { gorm.Model SomeID string OtherID string Data string } DB.Migrator().DropTable(&Thing{}) DB.Migrator().CreateTable(&Thing{}) if err := DB.Exec("ALTER TABLE things ADD CONSTRAINT some_id_other_id_unique UNIQUE (some_id, other_id)").Error; err != nil { t.Error(err) } thing := Thing{ SomeID: "1234", OtherID: "1234", Data: "something", } DB.Create(&thing) thing2 := Thing{ SomeID: "1234", OtherID: "1234", Data: "something else", } result := DB.Clauses(clause.OnConflict{ OnConstraint: "some_id_other_id_unique", UpdateAll: true, }).Create(&thing2) if result.Error != nil { t.Errorf("creating second thing: %v", result.Error) } var things []Thing if err := DB.Find(&things).Error; err != nil { t.Errorf("Failed, got error: %v", err) } if len(things) > 1 { t.Errorf("expected 1 thing got more") } } type CompanyNew struct { ID int Name int } func TestAlterColumnDataType(t *testing.T) { DB.AutoMigrate(Company{}) if err := DB.Table("companies").Migrator().AlterColumn(CompanyNew{}, "name"); err != nil { t.Fatalf("failed to alter 
column from string to int, got error %v", err) } DB.AutoMigrate(Company{}) } ================================================ FILE: tests/preload_suits_test.go ================================================ package tests_test import ( "database/sql" "encoding/json" "reflect" "sort" "sync/atomic" "testing" "gorm.io/gorm" ) func toJSONString(v interface{}) []byte { r, _ := json.Marshal(v) return r } func TestNestedPreload1(t *testing.T) { type ( Level1 struct { ID uint Value string Level2ID uint } Level2 struct { ID uint Level1 Level1 Level3ID uint } Level3 struct { ID uint Name string Level2 Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := Level3{Level2: Level2{Level1: Level1{Value: "value"}}} if err := DB.Create(&want).Error; err != nil { t.Error(err) } var got Level3 if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } if err := DB.Preload("Level2").Preload("Level2.Level1").First(&got, "name = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } func TestNestedPreload2(t *testing.T) { type ( Level1 struct { ID uint Value string Level2ID uint } Level2 struct { ID uint Level1s []*Level1 Level3ID uint } Level3 struct { ID uint Name string Level2s []Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := Level3{ Level2s: []Level2{ { Level1s: []*Level1{ {Value: "value1"}, {Value: "value2"}, }, }, { Level1s: []*Level1{ {Value: "value3"}, }, }, }, } if err := DB.Create(&want).Error; err != nil { t.Error(err) } var got Level3 if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", 
toJSONString(got), toJSONString(want)) } } func TestNestedPreload3(t *testing.T) { type ( Level1 struct { ID uint Value string Level2ID uint } Level2 struct { ID uint Level1 Level1 Level3ID uint } Level3 struct { Name string ID uint Level2s []Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := Level3{ Level2s: []Level2{ {Level1: Level1{Value: "value1"}}, {Level1: Level1{Value: "value2"}}, }, } if err := DB.Create(&want).Error; err != nil { t.Error(err) } var got Level3 if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } func TestNestedPreload4(t *testing.T) { type ( Level1 struct { ID uint Value string Level2ID uint } Level2 struct { ID uint Level1s []Level1 Level3ID uint } Level3 struct { ID uint Name string Level2 Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := Level3{ Level2: Level2{ Level1s: []Level1{ {Value: "value1"}, {Value: "value2"}, }, }, } if err := DB.Create(&want).Error; err != nil { t.Error(err) } var got Level3 if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } // Slice: []Level3 func TestNestedPreload5(t *testing.T) { type ( Level1 struct { ID uint Value string Level2ID uint } Level2 struct { ID uint Level1 Level1 Level3ID uint } Level3 struct { ID uint Name string Level2 Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := make([]Level3, 2) want[0] = Level3{Level2: Level2{Level1: Level1{Value: "value"}}} if err := DB.Create(&want[0]).Error; err 
!= nil { t.Error(err) } want[1] = Level3{Level2: Level2{Level1: Level1{Value: "value2"}}} if err := DB.Create(&want[1]).Error; err != nil { t.Error(err) } var got []Level3 if err := DB.Preload("Level2").Preload("Level2.Level1").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } func TestNestedPreload6(t *testing.T) { type ( Level1 struct { ID uint Value string Level2ID uint } Level2 struct { ID uint Level1s []Level1 Level3ID uint } Level3 struct { ID uint Name string Level2s []Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := make([]Level3, 2) want[0] = Level3{ Level2s: []Level2{ { Level1s: []Level1{ {Value: "value1"}, {Value: "value2"}, }, }, { Level1s: []Level1{ {Value: "value3"}, }, }, }, } if err := DB.Create(&want[0]).Error; err != nil { t.Error(err) } want[1] = Level3{ Level2s: []Level2{ { Level1s: []Level1{ {Value: "value3"}, {Value: "value4"}, }, }, { Level1s: []Level1{ {Value: "value5"}, }, }, }, } if err := DB.Create(&want[1]).Error; err != nil { t.Error(err) } var got []Level3 if err := DB.Preload("Level2s.Level1s").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } func TestNestedPreload7(t *testing.T) { type ( Level1 struct { ID uint Value string Level2ID uint } Level2 struct { ID uint Level1 Level1 Level3ID uint } Level3 struct { ID uint Name string Level2s []Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := make([]Level3, 2) want[0] = Level3{ Level2s: []Level2{ {Level1: Level1{Value: "value1"}}, {Level1: Level1{Value: "value2"}}, }, } if err := DB.Create(&want[0]).Error; err != nil { t.Error(err) } want[1] = Level3{ Level2s: 
[]Level2{ {Level1: Level1{Value: "value3"}}, {Level1: Level1{Value: "value4"}}, }, } if err := DB.Create(&want[1]).Error; err != nil { t.Error(err) } var got []Level3 if err := DB.Preload("Level2s.Level1").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } func TestNestedPreload8(t *testing.T) { type ( Level1 struct { ID uint Value string Level2ID uint } Level2 struct { ID uint Level1s []Level1 Level3ID uint } Level3 struct { ID uint Name string Level2 Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := make([]Level3, 2) want[0] = Level3{ Level2: Level2{ Level1s: []Level1{ {Value: "value1"}, {Value: "value2"}, }, }, } if err := DB.Create(&want[0]).Error; err != nil { t.Error(err) } want[1] = Level3{ Level2: Level2{ Level1s: []Level1{ {Value: "value3"}, {Value: "value4"}, }, }, } if err := DB.Create(&want[1]).Error; err != nil { t.Error(err) } var got []Level3 if err := DB.Preload("Level2.Level1s").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } func TestNestedPreload9(t *testing.T) { type ( Level0 struct { ID uint Value string Level1ID uint } Level1 struct { ID uint Value string Level2ID *uint Level2_1ID *uint Level0s []Level0 `json:",omitempty"` } Level2 struct { ID uint Level1s []Level1 Level3ID uint } Level2_1 struct { ID uint Level1s []Level1 `json:",omitempty"` Level3ID uint } Level3 struct { ID uint Name string Level2 Level2 Level2_1 Level2_1 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}) if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}, &Level2_1{}, &Level0{}); err != nil { t.Error(err) } want := make([]Level3, 2) want[0] = Level3{ Level2: Level2{ Level1s: []Level1{ {Value: "value1"}, {Value: "value2"}, 
}, }, Level2_1: Level2_1{ Level1s: []Level1{ { Value: "value1-1", Level0s: []Level0{{Value: "Level0-1"}}, }, { Value: "value2-2", Level0s: []Level0{{Value: "Level0-2"}}, }, }, }, } if err := DB.Create(&want[0]).Error; err != nil { t.Error(err) } want[1] = Level3{ Level2: Level2{ Level1s: []Level1{ {Value: "value3"}, {Value: "value4"}, }, }, Level2_1: Level2_1{ Level1s: []Level1{ { Value: "value3-3", Level0s: []Level0{}, }, { Value: "value4-4", Level0s: []Level0{}, }, }, }, } if err := DB.Create(&want[1]).Error; err != nil { t.Error(err) } var got []Level3 if err := DB.Preload("Level2").Preload("Level2.Level1s").Preload("Level2_1").Preload("Level2_1.Level1s").Preload("Level2_1.Level1s.Level0s").Find(&got).Error; err != nil { t.Error(err) } if string(toJSONString(got)) != string(toJSONString(want)) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } type LevelA1 struct { ID uint Value string } type LevelA2 struct { ID uint Value string LevelA3s []*LevelA3 `json:",omitempty"` } type LevelA3 struct { ID uint Value string LevelA1ID sql.NullInt64 LevelA1 *LevelA1 LevelA2ID sql.NullInt64 LevelA2 *LevelA2 } func TestNestedPreload10(t *testing.T) { DB.Migrator().DropTable(&LevelA3{}, &LevelA2{}, &LevelA1{}) if err := DB.AutoMigrate(&LevelA1{}, &LevelA2{}, &LevelA3{}); err != nil { t.Error(err) } levelA1 := &LevelA1{Value: "foo"} if err := DB.Save(levelA1).Error; err != nil { t.Error(err) } want := []*LevelA2{ { Value: "bar", LevelA3s: []*LevelA3{ { Value: "qux", LevelA1: levelA1, }, }, }, { Value: "bar 2", LevelA3s: []*LevelA3{}, }, } for _, levelA2 := range want { if err := DB.Save(levelA2).Error; err != nil { t.Error(err) } } var got []*LevelA2 if err := DB.Preload("LevelA3s.LevelA1").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(toJSONString(got), toJSONString(want)) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } type LevelB1 struct { ID uint Value string LevelB3s []*LevelB3 } type LevelB2 struct { ID 
uint Value string } type LevelB3 struct { ID uint Value string LevelB1ID sql.NullInt64 LevelB1 *LevelB1 LevelB2s []*LevelB2 `gorm:"many2many:levelb1_levelb3_levelb2s" json:",omitempty"` } func TestNestedPreload11(t *testing.T) { DB.Migrator().DropTable(&LevelB3{}, &LevelB2{}, &LevelB1{}) if err := DB.AutoMigrate(&LevelB1{}, &LevelB2{}, &LevelB3{}); err != nil { t.Error(err) } levelB1 := &LevelB1{Value: "foo"} if err := DB.Create(levelB1).Error; err != nil { t.Error(err) } levelB3 := &LevelB3{ Value: "bar", LevelB1ID: sql.NullInt64{Valid: true, Int64: int64(levelB1.ID)}, LevelB2s: []*LevelB2{}, } if err := DB.Create(levelB3).Error; err != nil { t.Error(err) } levelB1.LevelB3s = []*LevelB3{levelB3} want := []*LevelB1{levelB1} var got []*LevelB1 if err := DB.Preload("LevelB3s.LevelB2s").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(toJSONString(got), toJSONString(want)) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } type LevelC1 struct { ID uint Value string LevelC2ID uint } type LevelC2 struct { ID uint Value string LevelC1 LevelC1 } type LevelC3 struct { ID uint Value string LevelC2ID uint LevelC2 LevelC2 } func TestNestedPreload12(t *testing.T) { DB.Migrator().DropTable(&LevelC3{}, &LevelC2{}, &LevelC1{}) if err := DB.AutoMigrate(&LevelC1{}, &LevelC2{}, &LevelC3{}); err != nil { t.Error(err) } level2 := LevelC2{ Value: "c2", LevelC1: LevelC1{ Value: "c1", }, } DB.Create(&level2) want := []LevelC3{ { Value: "c3-1", LevelC2: level2, }, { Value: "c3-2", LevelC2: level2, }, } for i := range want { if err := DB.Create(&want[i]).Error; err != nil { t.Error(err) } } var got []LevelC3 if err := DB.Preload("LevelC2").Preload("LevelC2.LevelC1").Find(&got).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } func TestManyToManyPreloadWithMultiPrimaryKeys(t *testing.T) { if name := DB.Dialector.Name(); name == "sqlite" || name == "sqlserver" { 
t.Skip("skip sqlite, sqlserver due to it doesn't support multiple primary keys with auto increment") } if name := DB.Dialector.Name(); name == "mysql" { t.Skip("skip mysql due to it only allow unique constraint matching given keys") } type ( Level1 struct { ID uint `gorm:"primary_key;"` LanguageCode string `gorm:"primary_key"` Value string } Level2 struct { ID uint `gorm:"primary_key;"` LanguageCode string `gorm:"primary_key"` Value string Level1s []Level1 `gorm:"many2many:levels;"` } ) DB.Migrator().DropTable(&Level2{}, &Level1{}) DB.Migrator().DropTable("levels") if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil { t.Error(err) } want := Level2{Value: "Bob", LanguageCode: "ru", Level1s: []Level1{ {Value: "ru", LanguageCode: "ru"}, {Value: "en", LanguageCode: "en"}, }} if err := DB.Save(&want).Error; err != nil { t.Error(err) } want2 := Level2{Value: "Tom", LanguageCode: "zh", Level1s: []Level1{ {Value: "zh", LanguageCode: "zh"}, {Value: "de", LanguageCode: "de"}, }} if err := DB.Save(&want2).Error; err != nil { t.Error(err) } var got Level2 if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } var got2 Level2 if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got2, want2) { t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) } var got3 []Level2 if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got3, []Level2{got, got2}) { t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2})) } var got4 []Level2 if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { t.Error(err) } var ruLevel1 Level1 var zhLevel1 Level1 DB.First(&ruLevel1, 
"value = ?", "ru") DB.First(&zhLevel1, "value = ?", "zh") got.Level1s = []Level1{ruLevel1} got2.Level1s = []Level1{zhLevel1} if !reflect.DeepEqual(got4, []Level2{got, got2}) { t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2})) } if err := DB.Preload("Level1s").Find(&got4, "value IN (?)", []string{"non-existing"}).Error; err != nil { t.Error(err) } } func TestManyToManyPreloadForNestedPointer(t *testing.T) { type ( Level1 struct { ID uint Value string } Level2 struct { ID uint Value string Level1s []*Level1 `gorm:"many2many:levels;"` } Level3 struct { ID uint Value string Level2ID sql.NullInt64 Level2 *Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) DB.Migrator().DropTable("levels") if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := Level3{ Value: "Bob", Level2: &Level2{ Value: "Foo", Level1s: []*Level1{ {Value: "ru"}, {Value: "en"}, }, }, } if err := DB.Save(&want).Error; err != nil { t.Error(err) } want2 := Level3{ Value: "Tom", Level2: &Level2{ Value: "Bar", Level1s: []*Level1{ {Value: "zh"}, {Value: "de"}, }, }, } if err := DB.Save(&want2).Error; err != nil { t.Error(err) } var got Level3 if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Bob").Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } var got2 Level3 if err := DB.Preload("Level2.Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got2, want2) { t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2)) } var got3 []Level3 if err := DB.Preload("Level2.Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got3, []Level3{got, got2}) { t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level3{got, got2})) } var got4 []Level3 if err := DB.Preload("Level2.Level1s", "value IN (?)", 
[]string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil { t.Error(err) } var got5 Level3 DB.Preload("Level2.Level1s").Find(&got5, "value = ?", "bogus") var ruLevel1 Level1 var zhLevel1 Level1 DB.First(&ruLevel1, "value = ?", "ru") DB.First(&zhLevel1, "value = ?", "zh") got.Level2.Level1s = []*Level1{&ruLevel1} got2.Level2.Level1s = []*Level1{&zhLevel1} if !reflect.DeepEqual(got4, []Level3{got, got2}) { t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level3{got, got2})) } } func TestNestedManyToManyPreload(t *testing.T) { type ( Level1 struct { ID uint Value string } Level2 struct { ID uint Value string Level1s []*Level1 `gorm:"many2many:level1_level2;"` } Level3 struct { ID uint Value string Level2s []Level2 `gorm:"many2many:level2_level3;"` } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, "level1_level2", "level2_level3") if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := Level3{ Value: "Level3", Level2s: []Level2{ { Value: "Bob", Level1s: []*Level1{ {Value: "ru"}, {Value: "en"}, }, }, { Value: "Tom", Level1s: []*Level1{ {Value: "zh"}, {Value: "de"}, }, }, }, } if err := DB.Save(&want).Error; err != nil { t.Error(err) } var got Level3 if err := DB.Preload("Level2s").Preload("Level2s.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } if err := DB.Preload("Level2s.Level1s").First(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } func TestNestedManyToManyPreload2(t *testing.T) { type ( Level1 struct { ID uint Value string } Level2 struct { ID uint Value string Level1s []*Level1 `gorm:"many2many:level1_level2;"` } Level3 struct { ID uint Value string Level2ID sql.NullInt64 Level2 *Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) DB.Migrator().DropTable("level1_level2") if err 
:= DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } want := Level3{ Value: "Level3", Level2: &Level2{ Value: "Bob", Level1s: []*Level1{ {Value: "ru"}, {Value: "en"}, }, }, } if err := DB.Save(&want).Error; err != nil { t.Error(err) } var got Level3 if err := DB.Preload("Level2.Level1s").Find(&got, "value = ?", "Level3").Error; err != nil { t.Error(err) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } if err := DB.Preload("Level2.Level1s").First(&got, "value = ?", "not_found").Error; err != gorm.ErrRecordNotFound { t.Error(err) } } func TestNestedManyToManyPreload3(t *testing.T) { type ( Level1 struct { ID uint Value string } Level2 struct { ID uint Value string Level1s []*Level1 `gorm:"many2many:level1_level2;"` } Level3 struct { ID uint Value string Level2ID sql.NullInt64 Level2 *Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, "level1_level2") if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } level1Zh := &Level1{Value: "zh"} level1Ru := &Level1{Value: "ru"} level1En := &Level1{Value: "en"} level21 := &Level2{ Value: "Level2-1", Level1s: []*Level1{level1Zh, level1Ru}, } level22 := &Level2{ Value: "Level2-2", Level1s: []*Level1{level1Zh, level1En}, } wants := []*Level3{ { Value: "Level3-1", Level2: level21, }, { Value: "Level3-2", Level2: level22, }, { Value: "Level3-3", Level2: level21, }, } for _, want := range wants { if err := DB.Save(want).Error; err != nil { t.Error(err) } } var gots []*Level3 if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { return db.Order("level1.id ASC") }).Find(&gots).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(gots, wants) { t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) } } func TestNestedManyToManyPreload3ForStruct(t *testing.T) { type ( Level1 struct { ID uint Value string } Level2 struct { ID uint Value string Level1s []Level1 
`gorm:"many2many:level1_level2;"` } Level3 struct { ID uint Value string Level2ID sql.NullInt64 Level2 Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}) DB.Migrator().DropTable("level1_level2") if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } level1Zh := Level1{Value: "zh"} level1Ru := Level1{Value: "ru"} level1En := Level1{Value: "en"} level21 := Level2{ Value: "Level2-1", Level1s: []Level1{level1Zh, level1Ru}, } level22 := Level2{ Value: "Level2-2", Level1s: []Level1{level1Zh, level1En}, } wants := []*Level3{ { Value: "Level3-1", Level2: level21, }, { Value: "Level3-2", Level2: level22, }, { Value: "Level3-3", Level2: level21, }, } for _, want := range wants { if err := DB.Save(want).Error; err != nil { t.Error(err) } } var gots []*Level3 if err := DB.Preload("Level2.Level1s", func(db *gorm.DB) *gorm.DB { return db.Order("level1.id ASC") }).Find(&gots).Error; err != nil { t.Error(err) } if !reflect.DeepEqual(gots, wants) { t.Errorf("got %s; want %s", toJSONString(gots), toJSONString(wants)) } } func TestNestedManyToManyPreload4(t *testing.T) { type ( Level4 struct { ID uint Value string Level3ID uint } Level3 struct { ID uint Value string Level4s []*Level4 } Level2 struct { ID uint Value string Level3s []*Level3 `gorm:"many2many:level2_level3;"` } Level1 struct { ID uint Value string Level2s []*Level2 `gorm:"many2many:level1_level2;"` } ) DB.Migrator().DropTable("level1_level2", "level2_level3") DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) dummy := Level1{ Value: "Level1", Level2s: []*Level2{{ Value: "Level2", Level3s: []*Level3{{ Value: "Level3", Level4s: []*Level4{{ Value: "Level4", }}, }}, }}, } if err := DB.AutoMigrate(&Level4{}, &Level3{}, &Level2{}, &Level1{}); err != nil { t.Error(err) } if err := DB.Save(&dummy).Error; err != nil { t.Error(err) } var level1 Level1 if err := DB.Preload("Level2s").Preload("Level2s.Level3s").Preload("Level2s.Level3s.Level4s").First(&level1).Error; 
err != nil {
		t.Error(err)
	}
}

// TestManyToManyPreloadForPointer preloads a many2many association declared as a
// slice of pointers. It covers single records and slices, with and without
// extra preload conditions.
func TestManyToManyPreloadForPointer(t *testing.T) {
	type (
		Level1 struct {
			ID    uint
			Value string
		}
		Level2 struct {
			ID      uint
			Value   string
			Level1s []*Level1 `gorm:"many2many:levels;"`
		}
	)

	DB.Migrator().DropTable("levels", &Level2{}, &Level1{})

	if err := DB.AutoMigrate(&Level2{}, &Level1{}); err != nil {
		t.Error(err)
	}

	want := Level2{Value: "Bob", Level1s: []*Level1{
		{Value: "ru"},
		{Value: "en"},
	}}
	if err := DB.Save(&want).Error; err != nil {
		t.Error(err)
	}

	want2 := Level2{Value: "Tom", Level1s: []*Level1{
		{Value: "zh"},
		{Value: "de"},
	}}
	if err := DB.Save(&want2).Error; err != nil {
		t.Error(err)
	}

	var got Level2
	if err := DB.Preload("Level1s").Find(&got, "value = ?", "Bob").Error; err != nil {
		t.Error(err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
	}

	var got2 Level2
	if err := DB.Preload("Level1s").Find(&got2, "value = ?", "Tom").Error; err != nil {
		t.Error(err)
	}

	if !reflect.DeepEqual(got2, want2) {
		t.Errorf("got %s; want %s", toJSONString(got2), toJSONString(want2))
	}

	var got3 []Level2
	if err := DB.Preload("Level1s").Find(&got3, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
		t.Error(err)
	}

	if !reflect.DeepEqual(got3, []Level2{got, got2}) {
		t.Errorf("got %s; want %s", toJSONString(got3), toJSONString([]Level2{got, got2}))
	}

	var got4 []Level2
	if err := DB.Preload("Level1s", "value IN (?)", []string{"zh", "ru"}).Find(&got4, "value IN (?)", []string{"Bob", "Tom"}).Error; err != nil {
		t.Error(err)
	}

	// Preloading for a record that does not exist must report ErrRecordNotFound.
	var got5 Level2
	if err := DB.Preload("Level1s").First(&got5, "value = ?", "bogus").Error; err != gorm.ErrRecordNotFound {
		t.Errorf("expected ErrRecordNotFound when preloading a missing record, got %v", err)
	}

	var ruLevel1 Level1
	var zhLevel1 Level1
	DB.First(&ruLevel1, "value = ?", "ru")
	DB.First(&zhLevel1, "value = ?", "zh")

	// With the preload condition, only the matching Level1 is attached to each Level2.
	got.Level1s = []*Level1{&ruLevel1}
	got2.Level1s = []*Level1{&zhLevel1}
	if !reflect.DeepEqual(got4, []Level2{got, got2}) {
		t.Errorf("got %s; want %s", toJSONString(got4), toJSONString([]Level2{got, got2}))
	}
}

// TestNilPointerSlice preloads a nested pointer chain across two records, one of
// which has a nil Level2. It checks that both records come back unchanged.
func TestNilPointerSlice(t *testing.T) {
	type (
		Level3 struct {
			ID    uint
			Value string
		}
		Level2 struct {
			ID       uint
			Value    string
			Level3ID uint
			Level3   *Level3
		}
		Level1 struct {
			ID       uint
			Value    string
			Level2ID *uint
			Level2   *Level2
		}
	)

	DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{})

	if err := DB.AutoMigrate(&Level3{}, &Level2{}, &Level1{}); err != nil {
		t.Error(err)
	}

	want := Level1{
		Value: "Bob",
		Level2: &Level2{
			Value: "en",
			Level3: &Level3{
				Value: "native",
			},
		},
	}
	if err := DB.Save(&want).Error; err != nil {
		t.Error(err)
	}

	want2 := Level1{
		Value:  "Tom",
		Level2: nil,
	}
	if err := DB.Save(&want2).Error; err != nil {
		t.Fatalf("Got error %v", err)
	}

	var got []Level1
	if err := DB.Preload("Level2").Preload("Level2.Level3").Find(&got).Error; err != nil {
		t.Error(err)
	}

	// got[0] and got[1] are indexed below; stop here instead of panicking.
	if len(got) != 2 {
		t.Fatalf("got %v items, expected 2", len(got))
	}

	if !reflect.DeepEqual(got[0], want) && !reflect.DeepEqual(got[1], want) {
		t.Fatalf("got %s; want array containing %s", toJSONString(got), toJSONString(want))
	}

	if !reflect.DeepEqual(got[0], want2) && !reflect.DeepEqual(got[1], want2) {
		t.Errorf("got %s; want array containing %s", toJSONString(got), toJSONString(want2))
	}
}

// TestNilPointerSlice2 preloads a deep chain (Level1 -> Level2 -> many2many
// Level3s -> Level4) for a record whose associations are all unset.
func TestNilPointerSlice2(t *testing.T) {
	type (
		Level4 struct {
			ID uint
		}
		Level3 struct {
			ID       uint
			Level4ID sql.NullInt64 `sql:"index"`
			Level4   *Level4
		}
		Level2 struct {
			ID      uint
			Level3s []*Level3 `gorm:"many2many:level2_level3s"`
		}
		Level1 struct {
			ID       uint
			Level2ID sql.NullInt64 `sql:"index"`
			Level2   *Level2
		}
	)

	DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{})

	if err := DB.AutoMigrate(new(Level4), new(Level3), new(Level2), new(Level1)); err != nil {
		t.Error(err)
	}

	want := new(Level1)
	if err := DB.Save(want).Error; err != nil {
		t.Error(err)
	}

	got := new(Level1)
	err := DB.Preload("Level2.Level3s.Level4").Last(&got).Error
	if err != nil {
		t.Error(err)
	}

	if !reflect.DeepEqual(got, want) {
		t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want))
	}
}

// TestPrefixedPreloadDuplication preloads a three-level chain in which two
// Level1 rows share the same Level3. It checks that the shared Level4s are
// attached to both results.
func TestPrefixedPreloadDuplication(t *testing.T) {
	type (
		Level4 struct {
			ID       uint
			Name     string
			Level3ID uint
		}
		Level3 struct {
			ID
uint Name string Level4s []*Level4 `json:",omitempty"` } Level2 struct { ID uint Name string Level3ID sql.NullInt64 `sql:"index"` Level3 *Level3 } Level1 struct { ID uint Name string Level2ID sql.NullInt64 `sql:"index"` Level2 *Level2 } ) DB.Migrator().DropTable(&Level3{}, &Level2{}, &Level1{}, &Level4{}) if err := DB.AutoMigrate(new(Level3), new(Level4), new(Level2), new(Level1)); err != nil { t.Error(err) } lvl := &Level3{} if err := DB.Save(lvl).Error; err != nil { t.Error(err) } sublvl1 := &Level4{Level3ID: lvl.ID} if err := DB.Save(sublvl1).Error; err != nil { t.Error(err) } sublvl2 := &Level4{Level3ID: lvl.ID} if err := DB.Save(sublvl2).Error; err != nil { t.Error(err) } lvl.Level4s = []*Level4{sublvl1, sublvl2} want1 := Level1{ Level2: &Level2{ Level3: lvl, }, } if err := DB.Save(&want1).Error; err != nil { t.Error(err) } want2 := Level1{ Level2: &Level2{ Level3: lvl, }, } if err := DB.Save(&want2).Error; err != nil { t.Error(err) } want := []Level1{want1, want2} var got []Level1 err := DB.Preload("Level2.Level3.Level4s").Find(&got).Error if err != nil { t.Error(err) } for _, level1 := range append(got, want...) 
{ sort.Slice(level1.Level2.Level3.Level4s, func(i, j int) bool { return level1.Level2.Level3.Level4s[i].ID > level1.Level2.Level3.Level4s[j].ID }) } if !reflect.DeepEqual(got, want) { t.Errorf("got %s; want %s", toJSONString(got), toJSONString(want)) } } func TestPreloadManyToManyCallbacks(t *testing.T) { type ( Level2 struct { ID uint Name string } Level1 struct { ID uint Name string Level2s []Level2 `gorm:"many2many:level1_level2s"` } ) DB.Migrator().DropTable("level1_level2s", &Level2{}, &Level1{}) if err := DB.AutoMigrate(new(Level1), new(Level2)); err != nil { t.Error(err) } lvl := Level1{ Name: "l1", Level2s: []Level2{ {Name: "l2-1"}, {Name: "l2-2"}, }, } DB.Save(&lvl) var called int64 DB.Callback().Query().After("gorm:query").Register("TestPreloadManyToManyCallbacks", func(_ *gorm.DB) { atomic.AddInt64(&called, 1) }) DB.Preload("Level2s").First(&Level1{}, "id = ?", lvl.ID) if called != 3 { t.Errorf("Wanted callback to be called 3 times but got %d", called) } } ================================================ FILE: tests/preload_test.go ================================================ package tests_test import ( "context" "encoding/json" "regexp" "sort" "strconv" "sync" "testing" "time" "gorm.io/gorm" "gorm.io/gorm/clause" . 
"gorm.io/gorm/utils/tests" ) func TestPreloadWithAssociations(t *testing.T) { user := *GetUser("preload_with_associations", Config{ Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1, }) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } CheckUser(t, user, user) var user2 User DB.Preload(clause.Associations).Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) user3 := *GetUser("preload_with_associations_new", Config{ Account: true, Pets: 2, Toys: 3, Company: true, Manager: true, Team: 4, Languages: 3, Friends: 1, }) DB.Preload(clause.Associations).Find(&user3, "id = ?", user.ID) CheckUser(t, user3, user) } func TestNestedPreload(t *testing.T) { user := *GetUser("nested_preload", Config{Pets: 2}) for idx, pet := range user.Pets { pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(idx+1)} } if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } var user2 User DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) var user3 User DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) CheckUser(t, user3, user) var user4 *User DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID) CheckUser(t, *user4, user) } func TestNestedPreloadForSlice(t *testing.T) { users := []User{ *GetUser("slice_nested_preload_1", Config{Pets: 2}), *GetUser("slice_nested_preload_2", Config{Pets: 0}), *GetUser("slice_nested_preload_3", Config{Pets: 3}), } for _, user := range users { for idx, pet := range user.Pets { pet.Toy = Toy{Name: user.Name + "_toy_nested_preload_" + strconv.Itoa(idx+1)} } } if err := DB.Create(&users).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } var userIDs []uint for _, user := range users { userIDs = append(userIDs, user.ID) } var users2 []User DB.Preload("Pets.Toy").Find(&users2, "id IN ?", userIDs) for idx, user := range users2 { 
CheckUser(t, user, users[idx]) } } func TestPreloadWithConds(t *testing.T) { users := []User{ *GetUser("slice_nested_preload_1", Config{Account: true}), *GetUser("slice_nested_preload_2", Config{Account: false}), *GetUser("slice_nested_preload_3", Config{Account: true}), } if err := DB.Create(&users).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } var userIDs []uint for _, user := range users { userIDs = append(userIDs, user.ID) } var users2 []User DB.Preload("Account", clause.Eq{Column: "number", Value: users[0].Account.Number}).Find(&users2, "id IN ?", userIDs) sort.Slice(users2, func(i, j int) bool { return users2[i].ID < users2[j].ID }) for idx, user := range users2[1:2] { if user.Account.Number != "" { t.Errorf("No account should found for user %v but got %v", idx+2, user.Account.Number) } } CheckUser(t, users2[0], users[0]) var users3 []User if err := DB.Preload("Account", func(tx *gorm.DB) *gorm.DB { return tx.Table("accounts AS a").Select("a.*") }).Find(&users3, "id IN ?", userIDs).Error; err != nil { t.Errorf("failed to query, got error %v", err) } sort.Slice(users3, func(i, j int) bool { return users2[i].ID < users2[j].ID }) for i, u := range users3 { CheckUser(t, u, users[i]) } var user4 User DB.Delete(&users3[0].Account) if err := DB.Preload(clause.Associations).Take(&user4, "id = ?", users3[0].ID).Error; err != nil || user4.Account.ID != 0 { t.Errorf("failed to query, got error %v, account: %#v", err, user4.Account) } if err := DB.Preload(clause.Associations, func(tx *gorm.DB) *gorm.DB { return tx.Unscoped() }).Take(&user4, "id = ?", users3[0].ID).Error; err != nil || user4.Account.ID == 0 { t.Errorf("failed to query, got error %v, account: %#v", err, user4.Account) } } func TestNestedPreloadWithConds(t *testing.T) { users := []User{ *GetUser("slice_nested_preload_1", Config{Pets: 2}), *GetUser("slice_nested_preload_2", Config{Pets: 0}), *GetUser("slice_nested_preload_3", Config{Pets: 3}), } for _, user := range users { for idx, 
pet := range user.Pets { pet.Toy = Toy{Name: user.Name + "_toy_nested_preload_" + strconv.Itoa(idx+1)} } } if err := DB.Create(&users).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } var userIDs []uint for _, user := range users { userIDs = append(userIDs, user.ID) } var users2 []User DB.Preload("Pets.Toy", "name like ?", `%preload_3`).Find(&users2, "id IN ?", userIDs) for idx, user := range users2[0:2] { for _, pet := range user.Pets { if pet.Toy.Name != "" { t.Errorf("No toy should for user %v's pet %v but got %v", idx+1, pet.Name, pet.Toy.Name) } } } if len(users2[2].Pets) != 3 { t.Errorf("Invalid pet toys found for user 3 got %v", len(users2[2].Pets)) } else { sort.Slice(users2[2].Pets, func(i, j int) bool { return users2[2].Pets[i].ID < users2[2].Pets[j].ID }) for _, pet := range users2[2].Pets[0:2] { if pet.Toy.Name != "" { t.Errorf("No toy should for user %v's pet %v but got %v", 3, pet.Name, pet.Toy.Name) } } CheckPet(t, *users2[2].Pets[2], *users[2].Pets[2]) } } func TestPreloadEmptyData(t *testing.T) { user := *GetUser("user_without_associations", Config{}) DB.Create(&user) DB.Preload("Team").Preload("Languages").Preload("Friends").First(&user, "name = ?", user.Name) if r, err := json.Marshal(&user); err != nil { t.Errorf("failed to marshal users, got error %v", err) } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { t.Errorf("json marshal is not empty slice, got %v", string(r)) } var results []User DB.Preload("Team").Preload("Languages").Preload("Friends").Find(&results, "name = ?", user.Name) if r, err := json.Marshal(&results); err != nil { t.Errorf("failed to marshal users, got error %v", err) } else if !regexp.MustCompile(`"Team":\[\],"Languages":\[\],"Friends":\[\]`).MatchString(string(r)) { t.Errorf("json marshal is not empty slice, got %v", string(r)) } } func TestPreloadGoroutine(t *testing.T) { var wg sync.WaitGroup wg.Add(10) for i := 0; i < 10; i++ { go func() { defer 
wg.Done() var user2 []User tx := DB.Where("id = ?", 1).Session(&gorm.Session{}) if err := tx.Preload("Team").Find(&user2).Error; err != nil { t.Error(err) } }() } wg.Wait() } func TestPreloadWithDiffModel(t *testing.T) { user := *GetUser("preload_with_diff_model", Config{Account: true}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } var result struct { Something string User } DB.Model(User{}).Preload("Account", clause.Eq{Column: "number", Value: user.Account.Number}).Select( "users.*, 'yo' as something").First(&result, "name = ?", user.Name) CheckUser(t, user, result.User) } func TestNestedPreloadWithUnscoped(t *testing.T) { user := *GetUser("nested_preload", Config{Pets: 1}) pet := user.Pets[0] pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(1)} pet.Toy = Toy{Name: "toy_nested_preload_" + strconv.Itoa(2)} if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } var user2 User DB.Preload("Pets.Toy").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) DB.Delete(&pet) var user3 User DB.Preload(clause.Associations+"."+clause.Associations).Find(&user3, "id = ?", user.ID) if len(user3.Pets) != 0 { t.Fatalf("User.Pet[0] was deleted and should not exist.") } var user4 *User DB.Preload("Pets.Toy").Find(&user4, "id = ?", user.ID) if len(user4.Pets) != 0 { t.Fatalf("User.Pet[0] was deleted and should not exist.") } var user5 User DB.Unscoped().Preload(clause.Associations+"."+clause.Associations).Find(&user5, "id = ?", user.ID) CheckUserUnscoped(t, user5, user) var user6 *User DB.Unscoped().Preload("Pets.Toy").Find(&user6, "id = ?", user.ID) CheckUserUnscoped(t, *user6, user) } func TestNestedPreloadWithNestedJoin(t *testing.T) { type ( Preload struct { ID uint Value string NestedID uint } Join struct { ID uint Value string NestedID uint } Nested struct { ID uint Preloads []*Preload Join Join ValueID uint } Value struct { ID uint Name string Nested Nested } ) 
DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{}) DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{}) value1 := Value{ Name: "value", Nested: Nested{ Preloads: []*Preload{ {Value: "p1"}, {Value: "p2"}, }, Join: Join{Value: "j1"}, }, } value2 := Value{ Name: "value2", Nested: Nested{ Preloads: []*Preload{ {Value: "p3"}, {Value: "p4"}, {Value: "p5"}, }, Join: Join{Value: "j2"}, }, } values := []*Value{&value1, &value2} if err := DB.Create(&values).Error; err != nil { t.Errorf("failed to create value, got err: %v", err) } var find1 Value err := DB.Joins("Nested").Joins("Nested.Join").Preload("Nested.Preloads").First(&find1, value1.ID).Error if err != nil { t.Errorf("failed to find value, got err: %v", err) } AssertEqual(t, find1, value1) var find2 Value // Joins will automatically add Nested queries. err = DB.Joins("Nested.Join").Preload("Nested.Preloads").First(&find2, value2.ID).Error if err != nil { t.Errorf("failed to find value, got err: %v", err) } AssertEqual(t, find2, value2) var finds []Value err = DB.Joins("Nested.Join").Joins("Nested").Preload("Nested.Preloads").Find(&finds).Error if err != nil { t.Errorf("failed to find value, got err: %v", err) } AssertEqual(t, len(finds), 2) AssertEqual(t, finds[0], value1) AssertEqual(t, finds[1], value2) } func TestMergeNestedPreloadWithNestedJoin(t *testing.T) { users := []User{ { Name: "TestMergeNestedPreloadWithNestedJoin-1", Manager: &User{ Name: "Alexis Manager", Tools: []Tools{ {Name: "Alexis Tool 1"}, {Name: "Alexis Tool 2"}, }, }, }, { Name: "TestMergeNestedPreloadWithNestedJoin-2", Manager: &User{ Name: "Jinzhu Manager", Tools: []Tools{ {Name: "Jinzhu Tool 1"}, {Name: "Jinzhu Tool 2"}, }, }, }, } DB.Create(&users) query := make([]string, 0) sess := DB.Session(&gorm.Session{Logger: Tracer{ Logger: DB.Config.Logger, Test: func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { sql, _ := fc() query = append(query, sql) }, }}) var result 
[]User err := sess. Joins("Manager"). Preload("Manager.Tools"). Where("users.name Like ?", "TestMergeNestedPreloadWithNestedJoin%"). Find(&result).Error if err != nil { t.Fatalf("failed to preload and find users: %v", err) } AssertEqual(t, result, users) AssertEqual(t, len(query), 2) // Check preload queries are merged if !regexp.MustCompile(`SELECT \* FROM .*tools.* WHERE .*IN.*`).MatchString(query[0]) { t.Fatalf("Expected first query to preload manager tools, got: %s", query[0]) } } func TestNestedPreloadWithPointerJoin(t *testing.T) { type ( Preload struct { ID uint Value string JoinID uint } Join struct { ID uint Value string Preload Preload NestedID uint } Nested struct { ID uint Join Join ValueID uint } Value struct { ID uint Name string Nested *Nested } ) DB.Migrator().DropTable(&Preload{}, &Join{}, &Nested{}, &Value{}) DB.Migrator().AutoMigrate(&Preload{}, &Join{}, &Nested{}, &Value{}) value := Value{ Name: "value", Nested: &Nested{ Join: Join{ Value: "j1", Preload: Preload{ Value: "p1", }, }, }, } if err := DB.Create(&value).Error; err != nil { t.Errorf("failed to create value, got err: %v", err) } var find1 Value err := DB.Table("values").Joins("Nested").Joins("Nested.Join").Preload("Nested.Join.Preload").First(&find1).Error if err != nil { t.Errorf("failed to find value, got err: %v", err) } AssertEqual(t, find1, value) } func TestEmbedPreload(t *testing.T) { type Country struct { ID int `gorm:"primaryKey"` Name string } type EmbeddedAddress struct { ID int Name string CountryID *int Country *Country } type NestedAddress struct { EmbeddedAddress } type Org struct { ID int PostalAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:postal_address_"` VisitingAddress EmbeddedAddress `gorm:"embedded;embeddedPrefix:visiting_address_"` AddressID *int Address *EmbeddedAddress NestedAddress NestedAddress `gorm:"embedded;embeddedPrefix:nested_address_"` } DB.Migrator().DropTable(&Org{}, &EmbeddedAddress{}, &Country{}) DB.AutoMigrate(&Org{}, &EmbeddedAddress{}, 
&Country{}) org := Org{ PostalAddress: EmbeddedAddress{Name: "a1", Country: &Country{Name: "c1"}}, VisitingAddress: EmbeddedAddress{Name: "a2", Country: &Country{Name: "c2"}}, Address: &EmbeddedAddress{Name: "a3", Country: &Country{Name: "c3"}}, NestedAddress: NestedAddress{ EmbeddedAddress: EmbeddedAddress{Name: "a4", Country: &Country{Name: "c4"}}, }, } if err := DB.Create(&org).Error; err != nil { t.Errorf("failed to create org, got err: %v", err) } tests := []struct { name string preloads map[string][]interface{} expect Org }{ { name: "address country", preloads: map[string][]interface{}{"Address.Country": {}}, expect: Org{ ID: org.ID, PostalAddress: EmbeddedAddress{ ID: org.PostalAddress.ID, Name: org.PostalAddress.Name, CountryID: org.PostalAddress.CountryID, Country: nil, }, VisitingAddress: EmbeddedAddress{ ID: org.VisitingAddress.ID, Name: org.VisitingAddress.Name, CountryID: org.VisitingAddress.CountryID, Country: nil, }, AddressID: org.AddressID, Address: org.Address, NestedAddress: NestedAddress{EmbeddedAddress{ ID: org.NestedAddress.ID, Name: org.NestedAddress.Name, CountryID: org.NestedAddress.CountryID, Country: nil, }}, }, }, { name: "postal address country", preloads: map[string][]interface{}{"PostalAddress.Country": {}}, expect: Org{ ID: org.ID, PostalAddress: org.PostalAddress, VisitingAddress: EmbeddedAddress{ ID: org.VisitingAddress.ID, Name: org.VisitingAddress.Name, CountryID: org.VisitingAddress.CountryID, Country: nil, }, AddressID: org.AddressID, Address: nil, NestedAddress: NestedAddress{EmbeddedAddress{ ID: org.NestedAddress.ID, Name: org.NestedAddress.Name, CountryID: org.NestedAddress.CountryID, Country: nil, }}, }, }, { name: "nested address country", preloads: map[string][]interface{}{"NestedAddress.Country": {}}, expect: Org{ ID: org.ID, PostalAddress: EmbeddedAddress{ ID: org.PostalAddress.ID, Name: org.PostalAddress.Name, CountryID: org.PostalAddress.CountryID, Country: nil, }, VisitingAddress: EmbeddedAddress{ ID: 
org.VisitingAddress.ID, Name: org.VisitingAddress.Name, CountryID: org.VisitingAddress.CountryID, Country: nil, }, AddressID: org.AddressID, Address: nil, NestedAddress: org.NestedAddress, }, }, { name: "associations", preloads: map[string][]interface{}{ clause.Associations: {}, // clause.Associations won’t preload nested associations "Address.Country": {}, }, expect: org, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { actual := Org{} tx := DB.Where("id = ?", org.ID).Session(&gorm.Session{}) for name, args := range test.preloads { tx = tx.Preload(name, args...) } if err := tx.Find(&actual).Error; err != nil { t.Errorf("failed to find org, got err: %v", err) } AssertEqual(t, actual, test.expect) }) } } ================================================ FILE: tests/prepared_stmt_test.go ================================================ package tests_test import ( "context" "errors" "sync" "testing" "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) func TestPreparedStmt(t *testing.T) { tx := DB.Session(&gorm.Session{PrepareStmt: true}) if _, ok := tx.ConnPool.(*gorm.PreparedStmtDB); !ok { t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") } ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond) defer cancel() txCtx := tx.WithContext(ctx) user := *GetUser("prepared_stmt", Config{}) txCtx.Create(&user) var result1 User if err := txCtx.Find(&result1, user.ID).Error; err != nil { t.Fatalf("no error should happen but got %v", err) } time.Sleep(time.Second) var result2 User if err := tx.Find(&result2, user.ID).Error; err != nil { t.Fatalf("no error should happen but got %v", err) } user2 := *GetUser("prepared_stmt2", Config{}) if err := txCtx.Create(&user2).Error; err == nil { t.Fatalf("should failed to create with timeout context") } if err := tx.Create(&user2).Error; err != nil { t.Fatalf("failed to create, got error %v", err) } var result3 User if err := tx.Find(&result3, 
user2.ID).Error; err != nil { t.Fatalf("no error should happen but got %v", err) } } func TestPreparedStmtFromTransaction(t *testing.T) { db := DB.Session(&gorm.Session{PrepareStmt: true, SkipDefaultTransaction: true}) tx := db.Begin() defer func() { if r := recover(); r != nil { tx.Rollback() } }() if err := tx.Error; err != nil { t.Errorf("Failed to start transaction, got error %v\n", err) } if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { tx.Rollback() t.Errorf("Failed to run one transaction, got error %v\n", err) } if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { tx.Rollback() t.Errorf("Failed to run one transaction, got error %v\n", err) } if err := tx.Commit().Error; err != nil { t.Errorf("Failed to commit transaction, got error %v\n", err) } if result := db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) } tx2 := db.Begin() if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) } tx2.Commit() } func TestPreparedStmtLruFromTransaction(t *testing.T) { db, _ := OpenTestConnection(&gorm.Config{PrepareStmt: true, PrepareStmtMaxSize: 10, PrepareStmtTTL: 20 * time.Second}) tx := db.Begin() defer func() { if r := recover(); r != nil { tx.Rollback() } }() if err := tx.Error; err != nil { t.Errorf("Failed to start transaction, got error %v\n", err) } if err := tx.Where("name=?", "zzjin").Delete(&User{}).Error; err != nil { tx.Rollback() t.Errorf("Failed to run one transaction, got error %v\n", err) } if err := tx.Create(&User{Name: "zzjin"}).Error; err != nil { tx.Rollback() t.Errorf("Failed to run one transaction, got error %v\n", err) } if err := tx.Commit().Error; err != nil { t.Errorf("Failed to commit transaction, got error %v\n", err) } if result := 
db.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 1 { t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) } tx2 := db.Begin() if result := tx2.Where("name=?", "zzjin").Delete(&User{}); result.Error != nil || result.RowsAffected != 0 { t.Fatalf("Failed, got error: %v, rows affected: %v", result.Error, result.RowsAffected) } tx2.Commit() // Attempt to convert the connection pool of tx to the *gorm.PreparedStmtDB type. // If the conversion is successful, ok will be true and conn will be the converted object; // otherwise, ok will be false and conn will be nil. conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) // Get the number of statement keys stored in the PreparedStmtDB. lens := len(conn.Stmts.Keys()) // Check if the number of stored statement keys is 0. if lens == 0 { // If the number is 0, it means there are no statements stored in the LRU cache. // The test fails and an error message is output. t.Fatalf("lru should not be empty") } // Wait for 40 seconds to give the statements in the cache enough time to expire. time.Sleep(time.Second * 40) // Assert whether the connection pool of tx is successfully converted to the *gorm.PreparedStmtDB type. AssertEqual(t, ok, true) // Assert whether the number of statement keys stored in the PreparedStmtDB is 0 after 40 seconds. // If it is not 0, it means the statements in the cache have not expired as expected. 
AssertEqual(t, len(conn.Stmts.Keys()), 0) } func TestPreparedStmtDeadlock(t *testing.T) { tx, err := OpenTestConnection(&gorm.Config{}) AssertEqual(t, err, nil) sqlDB, _ := tx.DB() sqlDB.SetMaxOpenConns(1) tx = tx.Session(&gorm.Session{PrepareStmt: true}) wg := sync.WaitGroup{} for i := 0; i < 100; i++ { wg.Add(1) go func() { user := User{Name: "jinzhu"} tx.Create(&user) var result User tx.First(&result) wg.Done() }() } wg.Wait() conn, ok := tx.ConnPool.(*gorm.PreparedStmtDB) AssertEqual(t, ok, true) AssertEqual(t, len(conn.Stmts.Keys()), 2) for _, stmt := range conn.Stmts.Keys() { if stmt == "" { t.Fatalf("stmt cannot bee nil") } } AssertEqual(t, sqlDB.Stats().InUse, 0) } func TestPreparedStmtInTransaction(t *testing.T) { user := User{Name: "jinzhu"} if err := DB.Transaction(func(tx *gorm.DB) error { tx.Session(&gorm.Session{PrepareStmt: true}).Create(&user) return errors.New("test") }); err == nil { t.Error(err) } var result User if err := DB.First(&result, user.ID).Error; err == nil { t.Errorf("Failed, got error: %v", err) } } func TestPreparedStmtClose(t *testing.T) { tx := DB.Session(&gorm.Session{PrepareStmt: true}) user := *GetUser("prepared_stmt_close", Config{}) tx = tx.Create(&user) pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) if !ok { t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") } pdb.Mux.Lock() if len(pdb.Stmts.Keys()) == 0 { pdb.Mux.Unlock() t.Fatalf("prepared stmt can not be empty") } pdb.Mux.Unlock() pdb.Close() pdb.Mux.Lock() defer pdb.Mux.Unlock() if len(pdb.Stmts.Keys()) != 0 { t.Fatalf("prepared stmt should be empty") } } func isUsingClosedConnError(err error) bool { // https://github.com/golang/go/blob/e705a2d16e4ece77e08e80c168382cdb02890f5b/src/database/sql/sql.go#L2717 return err.Error() == "sql: statement is closed" } // TestPreparedStmtConcurrentClose test calling close and executing SQL concurrently // this test making sure that the gorm would not get a Segmentation Fault, and the only 
error cause by this is using a closed Stmt func TestPreparedStmtConcurrentClose(t *testing.T) { name := "prepared_stmt_concurrent_close" user := *GetUser(name, Config{}) createTx := DB.Session(&gorm.Session{}).Create(&user) if createTx.Error != nil { t.Fatalf("failed to prepare record due to %s, test cannot be continue", createTx.Error) } // create a new connection to keep away from other tests tx, err := OpenTestConnection(&gorm.Config{PrepareStmt: true}) if err != nil { t.Fatalf("failed to open test connection due to %s", err) } pdb, ok := tx.ConnPool.(*gorm.PreparedStmtDB) if !ok { t.Fatalf("should assign PreparedStatement Manager back to database when using PrepareStmt mode") } loopCount := 100 var wg sync.WaitGroup var unexpectedError bool writerFinish := make(chan struct{}) wg.Add(1) go func(id uint) { defer wg.Done() defer close(writerFinish) for j := 0; j < loopCount; j++ { var tmp User err := tx.Session(&gorm.Session{}).First(&tmp, id).Error if err == nil || isUsingClosedConnError(err) { continue } t.Errorf("failed to read user of id %d due to %s, there should not be error", id, err) unexpectedError = true break } }(user.ID) wg.Add(1) go func() { defer wg.Done() <-writerFinish pdb.Close() }() wg.Wait() if unexpectedError { t.Fatalf("should is a unexpected error") } } ================================================ FILE: tests/query_test.go ================================================ package tests_test import ( "database/sql" "database/sql/driver" "fmt" "reflect" "regexp" "sort" "strconv" "strings" "testing" "time" "gorm.io/gorm" "gorm.io/gorm/clause" . 
"gorm.io/gorm/utils/tests" ) func TestFind(t *testing.T) { users := []User{ *GetUser("find", Config{}), *GetUser("find", Config{}), *GetUser("find", Config{}), } if err := DB.Create(&users).Error; err != nil { t.Fatalf("errors happened when create users: %v", err) } t.Run("First", func(t *testing.T) { var first User if err := DB.Where("name = ?", "find").First(&first).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { CheckUser(t, first, users[0]) } }) t.Run("Last", func(t *testing.T) { var last User if err := DB.Where("name = ?", "find").Last(&last).Error; err != nil { t.Errorf("errors happened when query last: %v", err) } else { CheckUser(t, last, users[2]) } }) var all []User if err := DB.Where("name = ?", "find").Find(&all).Error; err != nil || len(all) != 3 { t.Errorf("errors happened when query find: %v, length: %v", err, len(all)) } else { for idx, user := range users { t.Run("FindAll#"+strconv.Itoa(idx+1), func(t *testing.T) { CheckUser(t, all[idx], user) }) } } t.Run("FirstMap", func(t *testing.T) { first := map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").First(first).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) switch name { case "Name": if _, ok := first[dbName].(string); !ok { t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) } case "Age": if _, ok := first[dbName].(uint); !ok { t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) } case "Birthday": if _, ok := first[dbName].(*time.Time); !ok { t.Errorf("invalid data type for %v, got %#v", dbName, first[dbName]) } } reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) }) } } }) t.Run("FirstMapWithTable", func(t *testing.T) { first := 
map[string]interface{}{} if err := DB.Table("users").Where("name = ?", "find").Find(first).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) resultType := reflect.ValueOf(first[dbName]).Type().Name() switch name { case "Name": if !strings.Contains(resultType, "string") { t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) } case "Age": if !strings.Contains(resultType, "int") { t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) } case "Birthday": if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, first[dbName]) } } reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) }) } } }) t.Run("FirstPtrMap", func(t *testing.T) { first := map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").First(&first).Error; err != nil { t.Errorf("errors happened when query first: %v", err) } else { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) reflectValue := reflect.Indirect(reflect.ValueOf(users[0])) AssertEqual(t, first[dbName], reflectValue.FieldByName(name).Interface()) }) } } }) t.Run("FirstSliceOfMap", func(t *testing.T) { allMap := []map[string]interface{}{} if err := DB.Model(&User{}).Where("name = ?", "find").Find(&allMap).Error; err != nil { t.Errorf("errors happened when query find: %v", err) } else { for idx, user := range users { t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := 
DB.NamingStrategy.ColumnName("", name) switch name { case "Name": if _, ok := allMap[idx][dbName].(string); !ok { t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) } case "Age": if _, ok := allMap[idx][dbName].(uint); !ok { t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) } case "Birthday": if _, ok := allMap[idx][dbName].(*time.Time); !ok { t.Errorf("invalid data type for %v, got %#v", dbName, allMap[idx][dbName]) } } reflectValue := reflect.Indirect(reflect.ValueOf(user)) AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) }) } }) } } }) t.Run("FindSliceOfMapWithTable", func(t *testing.T) { allMap := []map[string]interface{}{} if err := DB.Table("users").Where("name = ?", "find").Find(&allMap).Error; err != nil { t.Errorf("errors happened when query find: %v", err) } else { for idx, user := range users { t.Run("FindAllMap#"+strconv.Itoa(idx+1), func(t *testing.T) { for _, name := range []string{"Name", "Age", "Birthday"} { t.Run(name, func(t *testing.T) { dbName := DB.NamingStrategy.ColumnName("", name) resultType := reflect.ValueOf(allMap[idx][dbName]).Type().Name() switch name { case "Name": if !strings.Contains(resultType, "string") { t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) } case "Age": if !strings.Contains(resultType, "int") { t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) } case "Birthday": if !strings.Contains(resultType, "Time") && !(DB.Dialector.Name() == "sqlite" && strings.Contains(resultType, "string")) { t.Errorf("invalid data type for %v, got %v %#v", dbName, resultType, allMap[idx][dbName]) } } reflectValue := reflect.Indirect(reflect.ValueOf(user)) AssertEqual(t, allMap[idx][dbName], reflectValue.FieldByName(name).Interface()) }) } }) } } }) var models []User if err := DB.Where("name in (?)", []string{"find"}).Find(&models).Error; err != nil || len(models) != 3 { 
t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models))
	} else {
		for idx, user := range users {
			t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
				CheckUser(t, models[idx], user)
			})
		}
	}

	// test array
	var models2 [3]User
	if err := DB.Where("name in (?)", []string{"find"}).Find(&models2).Error; err != nil {
		t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models2))
	} else {
		for idx, user := range users {
			t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
				CheckUser(t, models2[idx], user)
			})
		}
	}

	// test smaller array
	var models3 [2]User
	if err := DB.Where("name in (?)", []string{"find"}).Find(&models3).Error; err != nil {
		t.Errorf("errors happened when query find with in clause: %v, length: %v", err, len(models3))
	} else {
		for idx, user := range users[:2] {
			t.Run("FindWithInClause#"+strconv.Itoa(idx+1), func(t *testing.T) {
				CheckUser(t, models3[idx], user)
			})
		}
	}

	var none []User
	if err := DB.Where("name in (?)", []string{}).Find(&none).Error; err != nil || len(none) != 0 {
		t.Errorf("errors happened when query find with in clause and zero length parameter: %v, length: %v", err, len(none))
	}
}

// TestQueryWithAssociation creates a user with many associations and checks
// that the same struct (by pointer and by value) can be used as a Where
// condition without an error.
func TestQueryWithAssociation(t *testing.T) {
	user := *GetUser("query_with_association", Config{Account: true, Pets: 2, Toys: 1, Company: true, Manager: true, Team: 2, Languages: 1, Friends: 3})

	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("errors happened when create user: %v", err)
	}

	// NOTE(review): presumably zeroed so the timestamps drop out of the struct
	// condition (gorm skips zero-valued fields there).
	user.CreatedAt = time.Time{}
	user.UpdatedAt = time.Time{}
	if err := DB.Where(&user).First(&User{}).Error; err != nil {
		t.Errorf("search with struct with association should returns no error, but got %v", err)
	}

	if err := DB.Where(user).First(&User{}).Error; err != nil {
		t.Errorf("search with struct with association should returns no error, but got %v", err)
	}
}

// TestFindInBatches walks six matching users two at a time, renames each batch
// inside the callback and saves it, then checks per-batch sizes, the sum of the
// batch numbers passed to the callback (expected 6) and the persisted renames.
func TestFindInBatches(t *testing.T) {
	users := []User{
		*GetUser("find_in_batches", Config{}),
		*GetUser("find_in_batches", Config{}),
		*GetUser("find_in_batches", Config{}),
		*GetUser("find_in_batches", Config{}),
		*GetUser("find_in_batches", Config{}),
		*GetUser("find_in_batches", Config{}),
	}

	// Fail fast on setup errors instead of reporting misleading counts later.
	if err := DB.Create(&users).Error; err != nil {
		t.Fatalf("failed to create users, got error %v", err)
	}

	var (
		results    []User
		totalBatch int
	)

	if result := DB.Table("users as u").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
		totalBatch += batch

		if tx.RowsAffected != 2 {
			t.Errorf("Incorrect affected rows, expects: 2, got %v", tx.RowsAffected)
		}

		if len(results) != 2 {
			t.Errorf("Incorrect users length, expects: 2, got %v", len(results))
		}

		for idx := range results {
			results[idx].Name = results[idx].Name + "_new"
		}

		if err := tx.Save(results).Error; err != nil {
			t.Fatalf("failed to save users, got error %v", err)
		}

		return nil
	}); result.Error != nil || result.RowsAffected != 6 {
		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
	}

	if totalBatch != 6 {
		t.Errorf("incorrect total batch, expects: %v, got %v", 6, totalBatch)
	}

	var count int64
	if err := DB.Model(&User{}).Where("name = ?", "find_in_batches_new").Count(&count).Error; err != nil {
		t.Fatalf("failed to count updated users, got error %v", err)
	}
	if count != 6 {
		t.Errorf("incorrect count after update, expects: %v, got %v", 6, count)
	}
}

// TestFindInBatchesWithOffsetLimit checks that FindInBatches honours Offset and
// Limit set on the chain: the total rows returned, the last batch number, and
// (for offset+limit) the exact rows collected across all batches.
func TestFindInBatchesWithOffsetLimit(t *testing.T) {
	users := []User{
		*GetUser("find_in_batches_with_offset_limit", Config{}),
		*GetUser("find_in_batches_with_offset_limit", Config{}),
		*GetUser("find_in_batches_with_offset_limit", Config{}),
		*GetUser("find_in_batches_with_offset_limit", Config{}),
		*GetUser("find_in_batches_with_offset_limit", Config{}),
		*GetUser("find_in_batches_with_offset_limit", Config{}),
		*GetUser("find_in_batches_with_offset_limit", Config{}),
		*GetUser("find_in_batches_with_offset_limit", Config{}),
		*GetUser("find_in_batches_with_offset_limit", Config{}),
		*GetUser("find_in_batches_with_offset_limit", Config{}),
	}

	if err := DB.Create(&users).Error; err != nil {
		t.Fatalf("failed to create users, got error %v", err)
	}

	var (
		sub, results []User
		lastBatch    int
	)

	// offset limit
	if result := DB.Offset(3).Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub, 2, func(tx *gorm.DB, batch int) error {
		results = append(results, sub...)
		lastBatch = batch
		return nil
	}); result.Error != nil || result.RowsAffected != 5 {
		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
	}

	if lastBatch != 3 {
		t.Fatalf("incorrect last batch, expected: %v, got: %v", 3, lastBatch)
	}

	targetUsers := users[3:8]
	// Guard the element-wise comparison: indexing results[i] would panic if
	// fewer rows than expected were collected.
	if len(results) != len(targetUsers) {
		t.Fatalf("incorrect results length, expected: %v, got: %v", len(targetUsers), len(results))
	}
	for i := range targetUsers {
		AssertEqual(t, results[i], targetUsers[i])
	}

	var sub1 []User
	// limit < batchSize
	if result := DB.Limit(5).Where("name = ?", users[0].Name).FindInBatches(&sub1, 10, func(tx *gorm.DB, batch int) error {
		return nil
	}); result.Error != nil || result.RowsAffected != 5 {
		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
	}

	var sub2 []User
	// only offset
	if result := DB.Offset(3).Where("name = ?", users[0].Name).FindInBatches(&sub2, 2, func(tx *gorm.DB, batch int) error {
		return nil
	}); result.Error != nil || result.RowsAffected != 7 {
		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
	}

	var sub3 []User
	if result := DB.Limit(4).Where("name = ?", users[0].Name).FindInBatches(&sub3, 2, func(tx *gorm.DB, batch int) error {
		return nil
	}); result.Error != nil || result.RowsAffected != 4 {
		t.Errorf("Failed to batch find, got error %v, rows affected: %v", result.Error, result.RowsAffected)
	}
}

func TestFindInBatchesWithError(t *testing.T) {
	if name := DB.Dialector.Name(); name == "sqlserver" {
		t.Skip("skip sqlserver due to it will raise data race for invalid sql")
	}

	users := []User{
		*GetUser("find_in_batches_with_error", Config{}),
		*GetUser("find_in_batches_with_error", Config{}),
		*GetUser("find_in_batches_with_error", Config{}),
		*GetUser("find_in_batches_with_error", Config{}),
		*GetUser("find_in_batches_with_error", Config{}),
		*GetUser("find_in_batches_with_error", Config{}),
	}
	DB.Create(&users)

	var (
		results    []User
		totalBatch int
	)

	if result :=
DB.Table("wrong_table").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
		totalBatch += batch
		return nil
	}); result.Error == nil || result.RowsAffected > 0 {
		t.Fatal("expected errors to have occurred, but nothing happened")
	}

	if totalBatch != 0 {
		t.Fatalf("incorrect total batch, expected: %v, got: %v", 0, totalBatch)
	}

	if result := DB.Omit("id").Where("name = ?", users[0].Name).FindInBatches(&results, 2, func(tx *gorm.DB, batch int) error {
		totalBatch += batch
		return nil
	}); result.Error != gorm.ErrPrimaryKeyRequired {
		t.Fatal("expected errors to have occurred, but nothing happened")
	}
}

// TestFillSmallerStruct loads a saved user into a struct declaring only a
// subset of User's fields, then inspects the dry-run SQL to see which columns
// gorm lists for different destination types.
func TestFillSmallerStruct(t *testing.T) {
	type SimpleUser struct {
		ID        int64
		Name      string
		UpdatedAt time.Time
		CreatedAt time.Time
	}

	user := User{Name: "SmallerUser", Age: 100}
	DB.Save(&user)

	var byName SimpleUser
	if err := DB.Table("users").Where("name = ?", user.Name).First(&byName).Error; err != nil {
		t.Fatalf("Failed to query smaller user, got error %v", err)
	}
	AssertObjEqual(t, user, byName, "Name", "ID", "UpdatedAt", "CreatedAt")

	var idOnly SimpleUser
	if err := DB.Model(&User{}).Select("id").First(&idOnly, user.ID).Error; err != nil {
		t.Fatalf("Failed to query smaller user, got error %v", err)
	}
	AssertObjEqual(t, user, idOnly, "ID")

	var idOnlyList []SimpleUser
	if err := DB.Model(&User{}).Select("id").Find(&idOnlyList, user.ID).Error; err != nil || len(idOnlyList) != 1 {
		t.Fatalf("Failed to query smaller user, got error %v", err)
	}
	AssertObjEqual(t, user, idOnlyList[0], "ID")

	// dryUsers starts a fresh dry-run query on the User model; only the
	// generated SQL text is inspected below.
	dryUsers := func() *gorm.DB {
		return DB.Session(&gorm.Session{DryRun: true}).Model(&User{})
	}

	// Smaller struct as destination: its field names must appear in the SELECT.
	if built := dryUsers().Find(&idOnlyList, user.ID).Statement.SQL.String(); !regexp.MustCompile("SELECT .*id.*name.*updated_at.*created_at.* FROM .*users").MatchString(built) {
		t.Fatalf("SQL should include selected names, but got %v", built)
	}

	// The model itself as destination (value, slice, slice of pointers): no
	// explicit column list is expected.
	for _, dest := range []interface{}{&User{}, &[]User{}, &[]*User{}} {
		if built := dryUsers().Find(dest, user.ID).Statement.SQL.String(); regexp.MustCompile("SELECT .*name.* FROM .*users").MatchString(built) {
			t.Fatalf("SQL should not include selected names, but got %v", built)
		}
	}
}

// TestFillSmallerStructWithAllFields repeats the column-list checks in a
// QueryFields dry-run session, where the expected SQL uses table-qualified
// column names and never a * wildcard.
func TestFillSmallerStructWithAllFields(t *testing.T) {
	type SimpleUser struct {
		ID        int64
		Name      string
		UpdatedAt time.Time
		CreatedAt time.Time
	}

	user := User{Name: "SmallerUser", Age: 100}
	DB.Save(&user)

	var simpleUsers []SimpleUser
	dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true})

	if built := dryDB.Model(&User{}).Find(&simpleUsers, user.ID).Statement.SQL.String(); !regexp.MustCompile("SELECT .users.*id.*users.*name.*users.*updated_at.*users.*created_at.* FROM .*users").MatchString(built) {
		t.Fatalf("SQL should include selected names, but got %v", built)
	}

	wildcard := regexp.MustCompile("SELECT \\* FROM .*users")
	for _, dest := range []interface{}{&User{}, &[]User{}, &[]*User{}} {
		if built := dryDB.Model(&User{}).Find(dest, user.ID).Statement.SQL.String(); wildcard.MatchString(built) {
			t.Fatalf("SQL should not include a * wildcard, but got %v", built)
		}
	}
}

// TestNot builds dry-run queries with Not (and a few related Where forms) and
// matches each generated statement against the expected SQL shape.
func TestNot(t *testing.T) {
	dryDB := DB.Session(&gorm.Session{DryRun: true})

	// Builders run lazily and in order; the loop stops at the first mismatch.
	cases := []struct {
		build   func() *gorm.DB
		pattern string
	}{
		{
			func() *gorm.DB { return dryDB.Not(map[string]interface{}{"name": "jinzhu"}).Find(&User{}) },
			"SELECT \\* FROM .*users.* WHERE .*name.* <> .+",
		},
		{
			func() *gorm.DB { return dryDB.Where("name = ?", "jinzhu1").Not("name = ?", "jinzhu2").Find(&User{}) },
			"SELECT \\* FROM .*users.* WHERE .*name.* = .+ AND NOT.*name.* = .+",
		},
		{
			func() *gorm.DB {
				return dryDB.Where(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{})
			},
			"SELECT \\* FROM .*users.* WHERE .*name.* IN \\(.+,.+\\)",
		},
		{
			func() *gorm.DB { return dryDB.Not("name = ?", "jinzhu").Find(&User{}) },
			"SELECT \\* FROM .*users.* WHERE NOT.*name.* = .+",
		},
		{
			func() *gorm.DB { return dryDB.Not(map[string]interface{}{"name": []string{}}).Find(&User{}) },
			"SELECT \\* FROM .*users.* WHERE .*name.* IS NOT NULL",
		},
		{
			func() *gorm.DB {
				return dryDB.Not(map[string]interface{}{"name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{})
			},
			"SELECT \\* FROM .*users.* WHERE .*name.* NOT IN \\(.+,.+\\)",
		},
		{
			func() *gorm.DB { return dryDB.Not([]int64{1, 2}).First(&User{}) },
			"SELECT \\* FROM .*users.* WHERE .*id.* NOT IN \\(.+,.+\\)",
		},
		{
			func() *gorm.DB { return dryDB.Not([]int64{}).First(&User{}) },
			"SELECT \\* FROM .*users.* WHERE .users.\\..deleted_at. IS NULL ORDER BY",
		},
		{
			func() *gorm.DB { return dryDB.Not(User{Name: "jinzhu", Age: 18}).First(&User{}) },
			"SELECT \\* FROM .*users.* WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+",
		},
		{
			func() *gorm.DB {
				return dryDB.Not(DB.Where("manager IS NULL").Where("age >= ?", 20)).Find(&User{})
			},
			"SELECT \\* FROM .*users.* WHERE NOT \\(manager IS NULL AND age >= .+\\) AND .users.\\..deleted_at. IS NULL",
		},
		{
			func() *gorm.DB {
				return dryDB.Not(DB.Where("manager IS NULL").Or("age >= ?", 20)).Find(&User{})
			},
			`SELECT \* FROM .*users.* WHERE NOT \(manager IS NULL OR age >= .+\) AND .users.\..deleted_at. IS NULL`,
		},
	}

	for _, c := range cases {
		built := c.build().Statement.SQL.String()
		if !regexp.MustCompile(c.pattern).MatchString(built) {
			t.Fatalf("Build NOT condition, but got %v", built)
		}
	}
}

// TestNotWithAllFields runs Not-style queries in a QueryFields dry-run session
// and expects every statement to start with the full, table-qualified column
// list (userQuery) followed by the matching WHERE clause.
func TestNotWithAllFields(t *testing.T) {
	dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true})
	userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name" +
		".*users.*age.*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* "

	cases := []struct {
		build func() *gorm.DB
		where string
	}{
		{
			func() *gorm.DB { return dryDB.Not(map[string]interface{}{"users.name": "jinzhu"}).Find(&User{}) },
			"WHERE .*users.*name.* <> .+",
		},
		{
			func() *gorm.DB {
				return dryDB.Where("users.name = ?", "jinzhu1").Not("users.name = ?", "jinzhu2").Find(&User{})
			},
			"WHERE .*users.*name.* = .+ AND NOT .*users.*name.* = .+",
		},
		{
			func() *gorm.DB {
				return dryDB.Where(map[string]interface{}{"users.name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{})
			},
			"WHERE .*users.*name.* IN \\(.+,.+\\)",
		},
		{
			func() *gorm.DB { return dryDB.Not("users.name = ?", "jinzhu").Find(&User{}) },
			"WHERE NOT .*users.*name.* = .+",
		},
		{
			func() *gorm.DB {
				return dryDB.Not(map[string]interface{}{"users.name": []string{"jinzhu", "jinzhu 2"}}).Find(&User{})
			},
			"WHERE .*users.*name.* NOT IN \\(.+,.+\\)",
		},
		{
			func() *gorm.DB { return dryDB.Not([]int64{1, 2}).First(&User{}) },
			"WHERE .*users.*id.* NOT IN \\(.+,.+\\)",
		},
		{
			func() *gorm.DB { return dryDB.Not([]int64{}).First(&User{}) },
			"WHERE .users.\\..deleted_at. IS NULL ORDER BY",
		},
		{
			func() *gorm.DB { return dryDB.Not(User{Name: "jinzhu", Age: 18}).First(&User{}) },
			"WHERE .*users.*..*name.* <> .+ AND .*users.*..*age.* <> .+",
		},
	}

	for _, c := range cases {
		built := c.build().Statement.SQL.String()
		if !regexp.MustCompile(userQuery + c.where).MatchString(built) {
			t.Fatalf("Build NOT condition, but got %v", built)
		}
	}
}

func TestOr(t *testing.T) {
	dryDB := DB.Session(&gorm.Session{DryRun: true})

	var count int64
	result := dryDB.Model(&User{}).Or("role = ?", "admin").Count(&count)
	if !regexp.MustCompile("SELECT count\\(\\*\\) FROM .*users.* WHERE role = .+ AND .*users.*\\..*deleted_at.* IS NULL").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin")).Find(&User{})
	if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND .*role.* = .+").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Where("role = ?", "admin").Where(DB.Or("role = ?", "super_admin").Or("role = ?", "admin")).Find(&User{})
	if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ AND (.*role.* = .+ OR .*role.* = .+)").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
	}

	sub := dryDB.Clauses(clause.Where{
		Exprs: []clause.Expression{
			clause.OrConditions{
				Exprs: []clause.Expression{
					clause.Expr{SQL: "role = ?", Vars: []interface{}{"super_admin"}},
					clause.Expr{SQL: "role = ?", Vars: []interface{}{"admin"}},
				},
			},
		},
	})
	result =
dryDB.Where(sub).Find(&User{})
	if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{})
	if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Where("name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{})
	if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*name.* AND .*age.*\\)").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Where("name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{})
	if !regexp.MustCompile("SELECT \\* FROM .*users.* WHERE .*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
	}
}

// TestOrWithAllFields checks Or conditions in a QueryFields dry-run session,
// where the SELECT list must contain the table-qualified columns in userQuery.
func TestOrWithAllFields(t *testing.T) {
	dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true})
	userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name" +
		".*users.*age.*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* "

	result := dryDB.Where("role = ?", "admin").Or("role = ?", "super_admin").Find(&User{})
	if !regexp.MustCompile(userQuery + "WHERE .*role.* = .+ OR .*role.* = .+").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Where("users.name = ?", "jinzhu").Or(User{Name: "jinzhu 2", Age: 18}).Find(&User{})
	if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ OR \\(.*users.*name.* AND .*users.*age.*\\)").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Where("users.name = ?", "jinzhu").Or(map[string]interface{}{"name": "jinzhu 2", "age": 18}).Find(&User{})
	if !regexp.MustCompile(userQuery + "WHERE .*users.*name.* = .+ OR \\(.*age.* AND .*name.*\\)").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build OR condition, but got %v", result.Statement.SQL.String())
	}
}

// Int64 is a custom integer type used by TestPluck to check that plucked
// values pass through a sql.Scanner. Scan stores the source value plus one;
// Value reports the stored value minus one.
type Int64 int64

// Value implements driver.Valuer. It returns a plain int64, one of the types a
// driver.Value is documented to hold; the named Int64 type is not one of them.
func (v Int64) Value() (driver.Value, error) {
	return int64(v - 1), nil
}

// Scan implements sql.Scanner for int64 sources. Any other source type yields
// an error instead of panicking on a failed type assertion.
func (f *Int64) Scan(v interface{}) error {
	y, ok := v.(int64)
	if !ok {
		return fmt.Errorf("cannot scan %T into Int64", v)
	}
	*f = Int64(y + 1)
	return nil
}

// TestPluck plucks names, ids and created_at values of three users into slices
// of several destination types (plain, custom Scanner, pointer, sql.Null*) and
// compares them with the created users.
func TestPluck(t *testing.T) {
	users := []*User{
		GetUser("pluck-user1", Config{}),
		GetUser("pluck-user2", Config{}),
		GetUser("pluck-user3", Config{}),
	}

	if err := DB.Create(&users).Error; err != nil {
		t.Fatalf("failed to create users, got error %v", err)
	}

	var names []string
	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("name", &names).Error; err != nil {
		t.Errorf("got error when pluck name: %v", err)
	}

	var names2 []string
	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name desc").Pluck("name", &names2).Error; err != nil {
		t.Errorf("got error when pluck name: %v", err)
	}
	sort.Slice(names2, func(i, j int) bool { return names2[i] < names2[j] })
	AssertEqual(t, names, names2)

	// The plucks below are compared index-by-index with users. The names
	// pluck-user1..3 sort in slice order, so ordering by name keeps the
	// comparison independent of the database's unspecified default row order.
	var ids []int
	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("id", &ids).Error; err != nil {
		t.Errorf("got error when pluck id: %v", err)
	}

	var ids2 []Int64
	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("id", &ids2).Error; err != nil {
		t.Errorf("got error when pluck id: %v", err)
	}

	for idx, name := range names {
		if name != users[idx].Name {
			t.Errorf("Unexpected result on pluck name, got %+v", names)
		}
	}

	for idx, id := range ids {
		if int(id) != int(users[idx].ID) {
			t.Errorf("Unexpected result on pluck id, got %+v", ids)
		}
	}

	// Int64.Scan adds one to every scanned value.
	for idx, id := range ids2 {
		if int(id) != int(users[idx].ID+1) {
			t.Errorf("Unexpected result on pluck id, got %+v", ids2)
		}
	}

	var times []time.Time
	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("created_at", &times).Error; err != nil {
		t.Errorf("got error when pluck time: %v", err)
	}

	for idx, tv := range times {
		AssertEqual(t, tv, users[idx].CreatedAt)
	}

	var ptrtimes []*time.Time
	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("created_at", &ptrtimes).Error; err != nil {
		t.Errorf("got error when pluck time: %v", err)
	}

	for idx, tv := range ptrtimes {
		AssertEqual(t, tv, users[idx].CreatedAt)
	}

	var nulltimes []sql.NullTime
	if err := DB.Model(User{}).Where("name like ?", "pluck-user%").Order("name").Pluck("created_at", &nulltimes).Error; err != nil {
		t.Errorf("got error when pluck time: %v", err)
	}

	for idx, tv := range nulltimes {
		AssertEqual(t, tv.Time, users[idx].CreatedAt)
	}
}

func TestSelect(t *testing.T) {
	user := User{Name: "SelectUser1"}
	DB.Save(&user)

	var result User
	DB.Where("name = ?", user.Name).Select("name").Find(&result)
	if result.ID != 0 {
		t.Errorf("Should not have ID because only selected name, %+v", result.ID)
	}

	if user.Name != result.Name {
		t.Errorf("Should have user Name when selected it")
	}

	var result2 User
	DB.Where("name = ?", user.Name).Select("name as name").Find(&result2)
	if result2.ID != 0 {
		t.Errorf("Should not have ID because only selected name, %+v", result2.ID)
	}

	if user.Name != result2.Name {
		t.Errorf("Should have user Name when selected it")
	}

	dryDB := DB.Session(&gorm.Session{DryRun: true})
	r := dryDB.Select("name", "age").Find(&User{})
	if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) {
		t.Fatalf("Build Select with strings, but got %v", r.Statement.SQL.String())
	}

	r = dryDB.Select([]string{"name", "age"}).Find(&User{})
	if !regexp.MustCompile("SELECT .*name.*,.*age.* FROM .*users.*").MatchString(r.Statement.SQL.String()) {
		t.Fatalf("Build Select with slice, but got %v", r.Statement.SQL.String())
	}

	// SELECT COALESCE(age,'42') FROM users;
	r =
dryDB.Table("users").Select("COALESCE(age,?)", 42).Find(&User{})
	if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
		t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String())
	}

	// named arguments
	r = dryDB.Table("users").Select("COALESCE(age, @default)", sql.Named("default", 42)).Find(&User{})
	if !regexp.MustCompile(`SELECT COALESCE\(age,.*\) FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
		t.Fatalf("Build Select with func, but got %v", r.Statement.SQL.String())
	}

	// This query really executes; the returned *sql.Rows must be closed so the
	// underlying pooled connection is released.
	rows, err := DB.Table("users").Select("COALESCE(age,?)", "42").Rows()
	if err != nil {
		t.Fatalf("Failed, got error: %v", err)
	}
	rows.Close()

	r = dryDB.Select("u.*").Table("users as u").First(&User{}, user.ID)
	if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
		t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String())
	}

	r = dryDB.Select("count(*)").Select("u.*").Table("users as u").First(&User{}, user.ID)
	if !regexp.MustCompile(`SELECT u\.\* FROM .*users.*`).MatchString(r.Statement.SQL.String()) {
		t.Fatalf("Build Select with u.*, but got %v", r.Statement.SQL.String())
	}
}

// TestOmit checks that Omit("name") leaves Name empty while the other columns,
// including the primary key and Age, are still loaded.
func TestOmit(t *testing.T) {
	user := User{Name: "OmitUser1", Age: 20}
	DB.Save(&user)

	var result User
	DB.Where("name = ?", user.Name).Omit("name").Find(&result)
	if result.ID == 0 {
		t.Errorf("Should have ID because only name is omitted, got %+v", result.ID)
	}

	if result.Name != "" || result.Age != 20 {
		t.Errorf("User Name should be omitted, got %v, Age should be ok, got %v", result.Name, result.Age)
	}
}

// TestOmitWithAllFields repeats TestOmit in a QueryFields session, then checks
// that Omit("name, age") drops both columns from the table-qualified SELECT list.
func TestOmitWithAllFields(t *testing.T) {
	user := User{Name: "OmitUser1", Age: 20}
	DB.Save(&user)

	var userResult User
	DB.Session(&gorm.Session{QueryFields: true}).Where("users.name = ?", user.Name).Omit("name").Find(&userResult)
	if userResult.ID == 0 {
		t.Errorf("Should have ID because only name is omitted, got %+v", userResult.ID)
	}

	if userResult.Name != "" || userResult.Age != 20 {
		t.Errorf("User Name should be omitted, got %v, Age should be ok, got %v", userResult.Name, userResult.Age)
	}

	dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true})
	userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*birthday" +
		".*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* "

	result := dryDB.Omit("name, age").Find(&User{})
	if !regexp.MustCompile(userQuery).MatchString(result.Statement.SQL.String()) {
		t.Fatalf("SQL must include table name and selected fields, got %v", result.Statement.SQL.String())
	}
}

// TestMapColumns scans a users row into a custom struct while remapping the
// "name" column onto the Nickname field.
func TestMapColumns(t *testing.T) {
	user := User{Name: "MapColumnsUser", Age: 12}
	DB.Save(&user)

	type result struct {
		Name     string
		Nickname string
		Age      uint
	}

	var res result
	if err := DB.Table("users").Where("name = ?", user.Name).MapColumns(map[string]string{"name": "nickname"}).Scan(&res).Error; err != nil {
		t.Fatalf("failed to scan with mapped columns, got error %v", err)
	}
	if res.Nickname != user.Name {
		t.Errorf("Expected res.Nickname to be %s, but got %s", user.Name, res.Nickname)
	}
	if res.Name != "" {
		t.Errorf("Expected res.Name to be empty, but got %s", res.Name)
	}
	if res.Age != user.Age {
		t.Errorf("Expected res.Age to be %d, but got %d", user.Age, res.Age)
	}
}

// TestPluckWithSelect plucks a computed alias (age + 1) produced by Select.
func TestPluckWithSelect(t *testing.T) {
	users := []User{
		{Name: "pluck_with_select_1", Age: 25},
		{Name: "pluck_with_select_2", Age: 26},
	}
	if err := DB.Create(&users).Error; err != nil {
		t.Fatalf("failed to create users, got error %v", err)
	}

	var userAges []int
	err := DB.Model(&User{}).Where("name like ?", "pluck_with_select%").Select("age + 1 as user_age").Pluck("user_age", &userAges).Error
	if err != nil {
		t.Fatalf("got error when pluck user_age: %v", err)
	}

	sort.Ints(userAges)

	AssertEqual(t, userAges, []int{26, 27})
}

// TestSelectWithVariables selects a bound expression under an alias and checks
// the resulting column name.
func TestSelectWithVariables(t *testing.T) {
	DB.Save(&User{Name: "select_with_variables"})

	rows, err := DB.Table("users").Where("name = ?", "select_with_variables").Select("? as fake", gorm.Expr("name")).Rows()
	if err != nil {
		// rows is nil on error; calling Next on it would panic.
		t.Fatalf("failed to query rows, got error %v", err)
	}
	defer rows.Close()

	if !rows.Next() {
		t.Errorf("Should have returned at least one row")
	} else {
		columns, colErr := rows.Columns()
		if colErr != nil {
			t.Fatalf("failed to get columns, got error %v", colErr)
		}
		AssertEqual(t, columns, []string{"fake"})
	}
}

// TestSelectWithArrayInput selects columns given as a []string.
func TestSelectWithArrayInput(t *testing.T) {
	DB.Save(&User{Name: "select_with_array", Age: 42})

	var user User
	DB.Select([]string{"name", "age"}).Where("age = 42 AND name = ?", "select_with_array").First(&user)

	if user.Name != "select_with_array" || user.Age != 42 {
		t.Errorf("Should have selected both age and name")
	}
}

// TestCustomizedTypePrimaryKey queries by a primary key whose type is a named
// uint, both as an inline condition and as a bound parameter.
func TestCustomizedTypePrimaryKey(t *testing.T) {
	type ID uint
	type CustomizedTypePrimaryKey struct {
		ID   ID
		Name string
	}

	// Best-effort reset; the error is deliberately ignored (presumably because
	// the table may not exist yet).
	DB.Migrator().DropTable(&CustomizedTypePrimaryKey{})
	if err := DB.AutoMigrate(&CustomizedTypePrimaryKey{}); err != nil {
		t.Fatalf("failed to migrate, got error %v", err)
	}

	p1 := CustomizedTypePrimaryKey{Name: "p1"}
	p2 := CustomizedTypePrimaryKey{Name: "p2"}
	p3 := CustomizedTypePrimaryKey{Name: "p3"}
	DB.Create(&p1)
	DB.Create(&p2)
	DB.Create(&p3)

	var p CustomizedTypePrimaryKey

	if err := DB.First(&p, p2.ID).Error; err != nil {
		t.Errorf("No error should returns, but got %v", err)
	}

	AssertEqual(t, p, p2)

	if err := DB.First(&p, "id = ?", p2.ID).Error; err != nil {
		t.Errorf("No error should happen when querying with customized type for primary key, got err %v", err)
	}

	AssertEqual(t, p, p2)
}

// TestStringPrimaryKeyForNumericValueStartingWithZero looks up a string primary
// key whose value ("00501") looks numeric and has a leading zero.
func TestStringPrimaryKeyForNumericValueStartingWithZero(t *testing.T) {
	type AddressByZipCode struct {
		ZipCode string `gorm:"primary_key"`
		Address string
	}

	DB.Migrator().DropTable(&AddressByZipCode{})
	if err := DB.AutoMigrate(&AddressByZipCode{}); err != nil {
		t.Fatalf("failed to migrate, got error %v", err)
	}

	address := AddressByZipCode{ZipCode: "00501", Address: "Holtsville"}
	DB.Create(&address)

	var result AddressByZipCode
	DB.First(&result, "00501")

	AssertEqual(t, result, address)
}

func TestSearchWithEmptyChain(t *testing.T) {
	user := User{Name: "search_with_empty_chain", Age: 1}
	DB.Create(&user)

	var result User
	if
DB.Where("").Where("").First(&result).Error != nil {
		t.Errorf("Should not raise any error if searching with empty strings")
	}

	result = User{}
	if DB.Where(&User{}).Where("name = ?", user.Name).First(&result).Error != nil {
		t.Errorf("Should not raise any error if searching with empty struct")
	}

	result = User{}
	if DB.Where(map[string]interface{}{}).Where("name = ?", user.Name).First(&result).Error != nil {
		t.Errorf("Should not raise any error if searching with empty map")
	}
}

// TestOrder checks the ORDER BY clause generated for empty, nil,
// comma-separated, chained and clause.OrderBy expression orderings.
func TestOrder(t *testing.T) {
	dryDB := DB.Session(&gorm.Session{DryRun: true})

	// An empty or nil order must not add ORDER BY: the statement is expected
	// to end with the "... IS NULL" condition.
	result := dryDB.Order("").Find(&User{})
	if !regexp.MustCompile("SELECT \\* FROM .*users.* IS NULL$").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Order(nil).Find(&User{})
	if !regexp.MustCompile("SELECT \\* FROM .*users.* IS NULL$").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Order("age desc, name").Find(&User{})
	if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc, name").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Order("age desc").Order("name").Find(&User{})
	if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY age desc,name").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String())
	}

	stmt := dryDB.Clauses(clause.OrderBy{
		Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
	}).Find(&User{}).Statement

	// Explain renders the statement with its vars substituted, so the expanded
	// FIELD(...) list can be matched literally.
	explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
	if !regexp.MustCompile("SELECT \\* FROM .*users.* ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) {
		t.Fatalf("Build Order condition, but got %v", explainedSQL)
	}
}

// TestOrderWithAllFields repeats the ORDER BY checks in a QueryFields session,
// where the SELECT list must contain the table-qualified columns in userQuery.
func TestOrderWithAllFields(t *testing.T) {
	dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true})
	userQuery := "SELECT .*users.*id.*users.*created_at.*users.*updated_at.*users.*deleted_at.*users.*name.*users.*age" +
		".*users.*birthday.*users.*company_id.*users.*manager_id.*users.*active.* FROM .*users.* "

	// Anchor on "ORDER BY" like the other patterns, so the ordering really has
	// to come from an ORDER BY clause.
	result := dryDB.Order("users.age desc, users.name").Find(&User{})
	if !regexp.MustCompile(userQuery + "ORDER BY users.age desc, users.name").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String())
	}

	result = dryDB.Order("users.age desc").Order("users.name").Find(&User{})
	if !regexp.MustCompile(userQuery + "ORDER BY users.age desc,users.name").MatchString(result.Statement.SQL.String()) {
		t.Fatalf("Build Order condition, but got %v", result.Statement.SQL.String())
	}

	stmt := dryDB.Clauses(clause.OrderBy{
		Expression: clause.Expr{SQL: "FIELD(id,?)", Vars: []interface{}{[]int{1, 2, 3}}, WithoutParentheses: true},
	}).Find(&User{}).Statement

	explainedSQL := dryDB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...)
	if !regexp.MustCompile(userQuery + "ORDER BY FIELD\\(id,1,2,3\\)").MatchString(explainedSQL) {
		t.Fatalf("Build Order condition, but got %v", explainedSQL)
	}
}

// TestLimit chains Find calls with different limits on one query: each later
// Limit replaces the earlier one, and Limit(-1) is expected to remove it
// (more than 5 rows).
func TestLimit(t *testing.T) {
	users := []User{
		{Name: "LimitUser1", Age: 1},
		{Name: "LimitUser2", Age: 10},
		{Name: "LimitUser3", Age: 20},
		{Name: "LimitUser4", Age: 10},
		{Name: "LimitUser5", Age: 20},
		{Name: "LimitUser6", Age: 20},
	}
	DB.Create(&users)

	var users1, users2, users3 []User
	DB.Order("age desc").Limit(3).Find(&users1).Limit(5).Find(&users2).Limit(-1).Find(&users3)

	if len(users1) != 3 || len(users2) != 5 || len(users3) <= 5 {
		t.Errorf("Limit should works, users1 %v users2 %v users3 %v", len(users1), len(users2), len(users3))
	}
}

// TestOffset chains Find calls with different offsets, with and without a
// limit. Offset(-1) is expected to clear the offset again (users4 == users1).
func TestOffset(t *testing.T) {
	for i := 0; i < 20; i++ {
		DB.Save(&User{Name: fmt.Sprintf("OffsetUser%v", i)})
	}
	var users1, users2, users3, users4 []User

	DB.Limit(100).Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)

	if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
		t.Errorf("Offset should work, got users1 %v users2 %v users3 %v users4 %v", len(users1), len(users2), len(users3), len(users4))
	}

	DB.Where("name like ?", "OffsetUser%").Order("age desc").Find(&users1).Offset(3).Find(&users2).Offset(5).Find(&users3).Offset(-1).Find(&users4)
	if (len(users1) != len(users4)) || (len(users1)-len(users2) != 3) || (len(users1)-len(users3) != 5) {
		t.Errorf("Offset should work without limit, got users1 %v users2 %v users3 %v users4 %v", len(users1), len(users2), len(users3), len(users4))
	}
}

func TestSearchWithMap(t *testing.T) {
	users := []User{
		*GetUser("map_search_user1", Config{}),
		*GetUser("map_search_user2", Config{}),
		*GetUser("map_search_user3", Config{}),
		*GetUser("map_search_user4", Config{Company: true}),
	}
	DB.Create(&users)

	var user User
	DB.First(&user, map[string]interface{}{"name": users[0].Name})
	CheckUser(t, user, users[0])

	user = User{}
	DB.First(&user, map[string]interface{}{"users.name": users[0].Name})
	CheckUser(t, user, users[0])

	user = User{}
	DB.Where(map[string]interface{}{"name": users[1].Name}).First(&user)
CheckUser(t, user, users[1])

	var results []User
	DB.Where(map[string]interface{}{"name": users[2].Name}).Find(&results)
	if len(results) != 1 {
		t.Fatalf("Search all records with inline map")
	}

	CheckUser(t, results[0], users[2])

	var results2 []User
	DB.Find(&results2, map[string]interface{}{"name": users[3].Name, "company_id": nil})
	if len(results2) != 0 {
		t.Errorf("Search all records with inline map containing null value finding 0 records")
	}

	DB.Find(&results2, map[string]interface{}{"name": users[0].Name, "company_id": nil})
	if len(results2) != 1 {
		t.Errorf("Search all records with inline map containing null value finding 1 record")
	}

	DB.Find(&results2, map[string]interface{}{"name": users[3].Name, "company_id": users[3].CompanyID})
	if len(results2) != 1 {
		t.Errorf("Search all records with inline multiple value map")
	}
}

// TestSearchWithStruct checks the WHERE clause built from struct conditions.
// Zero-valued fields are skipped unless they are listed explicitly.
func TestSearchWithStruct(t *testing.T) {
	dryRunDB := DB.Session(&gorm.Session{DryRun: true})

	result := dryRunDB.Where(User{Name: "jinzhu"}).Find(&User{})
	if !regexp.MustCompile(`WHERE .users.\..name. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
	}

	result = dryRunDB.Where(User{Name: "jinzhu", Age: 18}).Find(&User{})
	if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
	}

	// Age is zero here but listed explicitly, so it is still expected in WHERE.
	result = dryRunDB.Where(User{Name: "jinzhu"}, "name", "Age").Find(&User{})
	if !regexp.MustCompile(`WHERE \(.users.\..name. = .{1,3} AND .users.\..age. = .{1,3}\) AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
	}

	result = dryRunDB.Where(User{Name: "jinzhu"}, "age").Find(&User{})
	if !regexp.MustCompile(`WHERE .users.\..age. = .{1,3} AND .users.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) {
		t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String())
	}
}

// TestSubQuery uses *gorm.DB values as sub-queries inside IN (?) and a scalar
// comparison. Query errors are checked so a SQL failure is reported as such,
// not as a wrong row count.
func TestSubQuery(t *testing.T) {
	users := []User{
		{Name: "subquery_1", Age: 10},
		{Name: "subquery_2", Age: 20},
		{Name: "subquery_3", Age: 30},
		{Name: "subquery_4", Age: 40},
	}

	if err := DB.Create(&users).Error; err != nil {
		t.Fatalf("failed to create users, got error %v", err)
	}

	if err := DB.Select("*").Where("name IN (?)", DB.Select("name").Table("users").Where("name LIKE ?", "subquery_%")).Find(&users).Error; err != nil {
		t.Fatalf("got error: %v", err)
	}

	if len(users) != 4 {
		t.Errorf("Four users should be found, instead found %d", len(users))
	}

	if err := DB.Select("*").Where("name LIKE ?", "subquery%").Where("age >= (?)", DB.
		Select("AVG(age)").Table("users").Where("name LIKE ?", "subquery%")).Find(&users).Error; err != nil {
		t.Fatalf("got error: %v", err)
	}

	if len(users) != 2 {
		t.Errorf("Two users should be found, instead found %d", len(users))
	}
}

// TestSubQueryWithRaw nests Raw and chained sub-queries inside Raw count
// queries and checks the resulting counts.
func TestSubQueryWithRaw(t *testing.T) {
	users := []User{
		{Name: "subquery_raw_1", Age: 10},
		{Name: "subquery_raw_2", Age: 20},
		{Name: "subquery_raw_3", Age: 30},
		{Name: "subquery_raw_4", Age: 40},
	}

	if err := DB.Create(&users).Error; err != nil {
		t.Fatalf("failed to create users, got error %v", err)
	}

	var count int64
	err := DB.Raw("select count(*) from (?) tmp where 1 = ? AND name IN (?)", DB.Raw("select name from users where age >= ? and name in (?)", 10, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"}), 1, DB.Raw("select name from users where age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_2", "subquery_raw_3"})).Scan(&count).Error
	if err != nil {
		t.Errorf("Expected to get no errors, but got %v", err)
	}

	if count != 2 {
		t.Errorf("Row count must be 2, instead got %d", count)
	}

	err = DB.Raw("select count(*) from (?) tmp",
		DB.Table("users").
			Select("name").
			Where("age >= ? and name in (?)", 20, []string{"subquery_raw_1", "subquery_raw_3"}).
			Group("name"),
	).Count(&count).Error

	if err != nil {
		t.Errorf("Expected to get no errors, but got %v", err)
	}

	if count != 1 {
		t.Errorf("Row count must be 1, instead got %d", count)
	}

	err = DB.Raw("select count(*) from (?) tmp",
		DB.Table("users").
			Select("name").
			Where("name LIKE ?", "subquery_raw%").
			Not("age <= ?", 10).Not("name IN (?)", []string{"subquery_raw_1", "subquery_raw_3"}).
			Group("name"),
	).Count(&count).Error

	if err != nil {
		t.Errorf("Expected to get no errors, but got %v", err)
	}

	if count != 2 {
		t.Errorf("Row count must be 2, instead got %d", count)
	}
}

// TestSubQueryWithHaving uses a sub-query inside a HAVING clause and expects
// two name groups whose average age exceeds the overall average.
func TestSubQueryWithHaving(t *testing.T) {
	users := []User{
		{Name: "subquery_having_1", Age: 10},
		{Name: "subquery_having_2", Age: 20},
		{Name: "subquery_having_3", Age: 30},
		{Name: "subquery_having_4", Age: 40},
	}

	if err := DB.Create(&users).Error; err != nil {
		t.Fatalf("failed to create users, got error %v", err)
	}

	var results []User
	if err := DB.Select("AVG(age) as avgage").Where("name LIKE ?", "subquery_having%").Group("name").Having("AVG(age) > (?)", DB.
		Select("AVG(age)").Where("name LIKE ?", "subquery_having%").Table("users")).Find(&results).Error; err != nil {
		t.Fatalf("got error: %v", err)
	}

	if len(results) != 2 {
		t.Errorf("Two user group should be found, instead found %d", len(results))
	}
}

// TestScanNullValue sets age to NULL and checks that the row still scans into
// a User, both as a single struct and as part of a slice.
func TestScanNullValue(t *testing.T) {
	user := GetUser("scan_null_value", Config{})
	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("failed to create user, got error %v", err)
	}

	if err := DB.Model(&user).Update("age", nil).Error; err != nil {
		t.Fatalf("failed to update column age for struct, got error %v", err)
	}

	var result User
	if err := DB.First(&result, "id = ?", user.ID).Error; err != nil {
		t.Fatalf("failed to query struct data with null age, got error %v", err)
	}

	AssertEqual(t, result, user)

	users := []User{
		*GetUser("scan_null_value_for_slice_1", Config{}),
		*GetUser("scan_null_value_for_slice_2", Config{}),
		*GetUser("scan_null_value_for_slice_3", Config{}),
	}

	if err := DB.Create(&users).Error; err != nil {
		t.Fatalf("failed to create users, got error %v", err)
	}

	if err := DB.Model(&users[0]).Update("age", nil).Error; err != nil {
		t.Fatalf("failed to update column age for struct, got error %v", err)
	}

	var results []User
	if err := DB.Find(&results, "name like ?", "scan_null_value_for_slice%").Error; err != nil {
		t.Fatalf("failed to query slice data with null age, got error %v", err)
	}
}

func TestQueryWithTableAndConditions(t *testing.T) {
	result := DB.Session(&gorm.Session{DryRun: true}).Table("user").Find(&User{}, User{Name: "jinzhu"})

	if
!regexp.MustCompile(`SELECT \* FROM .user. WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } } func TestQueryWithTableAndConditionsAndAllFields(t *testing.T) { result := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}).Table("user").Find(&User{}, User{Name: "jinzhu"}) userQuery := "SELECT .*user.*id.*user.*created_at.*user.*updated_at.*user.*deleted_at.*user.*name.*user.*age" + ".*user.*birthday.*user.*company_id.*user.*manager_id.*user.*active.* FROM .user. " if !regexp.MustCompile(userQuery + `WHERE .user.\..name. = .+ AND .user.\..deleted_at. IS NULL`).MatchString(result.Statement.SQL.String()) { t.Errorf("invalid query SQL, got %v", result.Statement.SQL.String()) } } type DoubleInt64 struct { data int64 } func (t *DoubleInt64) Scan(val interface{}) error { switch v := val.(type) { case int64: t.data = v * 2 return nil default: return fmt.Errorf("DoubleInt64 cant not scan with:%v", v) } } // https://github.com/go-gorm/gorm/issues/5091 func TestQueryScannerWithSingleColumn(t *testing.T) { user := User{Name: "scanner_raw_1", Age: 10} DB.Create(&user) var result1 DoubleInt64 if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Pluck( "age", &result1).Error; err != nil { t.Errorf("Failed, got error: %v", err) } AssertEqual(t, result1.data, 20) var result2 DoubleInt64 if err := DB.Model(&User{}).Where("name LIKE ?", "scanner_raw_%").Limit(1).Select( "age").Scan(&result2).Error; err != nil { t.Errorf("Failed, got error: %v", err) } AssertEqual(t, result2.data, 20) } func TestQueryResetNullValue(t *testing.T) { type QueryResetItem struct { ID string `gorm:"type:varchar(5)"` Name string } type QueryResetNullValue struct { ID int Name string `gorm:"default:NULL"` Flag bool `gorm:"default:NULL"` Number1 int64 `gorm:"default:NULL"` Number2 uint64 `gorm:"default:NULL"` Number3 float64 `gorm:"default:NULL"` Now 
*time.Time `gorm:"default:NULL"` Item1Id string Item1 *QueryResetItem `gorm:"references:ID"` Item2Id string Item2 *QueryResetItem `gorm:"references:ID"` } DB.Migrator().DropTable(&QueryResetNullValue{}, &QueryResetItem{}) DB.AutoMigrate(&QueryResetNullValue{}, &QueryResetItem{}) now := time.Now() q1 := QueryResetNullValue{ Name: "name", Flag: true, Number1: 100, Number2: 200, Number3: 300.1, Now: &now, Item1: &QueryResetItem{ ID: "u_1_1", Name: "item_1_1", }, Item2: &QueryResetItem{ ID: "u_1_2", Name: "item_1_2", }, } q2 := QueryResetNullValue{ Item1: &QueryResetItem{ ID: "u_2_1", Name: "item_2_1", }, Item2: &QueryResetItem{ ID: "u_2_2", Name: "item_2_2", }, } var err error err = DB.Create(&q1).Error if err != nil { t.Errorf("failed to create:%v", err) } err = DB.Create(&q2).Error if err != nil { t.Errorf("failed to create:%v", err) } var qs []QueryResetNullValue err = DB.Joins("Item1").Joins("Item2").Find(&qs).Error if err != nil { t.Errorf("failed to find:%v", err) } if len(qs) != 2 { t.Fatalf("find count not equal:%d", len(qs)) } AssertEqual(t, q1, qs[0]) AssertEqual(t, q2, qs[1]) } func TestQueryError(t *testing.T) { type P struct{} var p1 P err := DB.Take(&p1, 1).Error AssertEqual(t, err, gorm.ErrModelAccessibleFieldsRequired) var p2 interface{} err = DB.Table("ps").Clauses(clause.Eq{Column: clause.Column{ Table: clause.CurrentTable, Name: clause.PrimaryKey, }, Value: 1}).Scan(&p2).Error AssertEqual(t, err, gorm.ErrModelValueRequired) } func TestQueryScanToArray(t *testing.T) { err := DB.Create(&User{Name: "testname1", Age: 10}).Error if err != nil { t.Fatal(err) } users := [2]*User{{Name: "1"}, {Name: "2"}} err = DB.Model(&User{}).Where("name = ?", "testname1").Find(&users).Error if err != nil { t.Fatal(err) } if users[0] == nil || users[0].Name != "testname1" { t.Error("users[0] not covered") } if users[1] != nil { t.Error("users[1] should be empty") } } ================================================ FILE: tests/scan_test.go 
================================================ package tests_test import ( "reflect" "sort" "strings" "testing" "time" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) type PersonAddressInfo struct { Person *Person `gorm:"embedded"` Address *Address `gorm:"embedded"` } func TestScan(t *testing.T) { user1 := User{Name: "ScanUser1", Age: 1} user2 := User{Name: "ScanUser2", Age: 10} user3 := User{Name: "ScanUser3", Age: 20} DB.Save(&user1).Save(&user2).Save(&user3) type result struct { ID uint Name string Age int } var res result DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&res) if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } var resPointer *result if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resPointer).Error; err != nil { t.Fatalf("Failed to query with pointer of value, got error %v", err) } else if resPointer.ID != user3.ID || resPointer.Name != user3.Name || resPointer.Age != int(user3.Age) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } DB.Table("users").Select("id, name, age").Where("id = ?", user2.ID).Scan(&res) if res.ID != user2.ID || res.Name != user2.Name || res.Age != int(user2.Age) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user2) } DB.Model(&User{Model: gorm.Model{ID: user3.ID}}).Select("id, name, age").Scan(&res) if res.ID != user3.ID || res.Name != user3.Name || res.Age != int(user3.Age) { t.Fatalf("Scan into struct should work, got %#v, should %#v", res, user3) } doubleAgeRes := &result{} if err := DB.Table("users").Select("age + age as age").Where("id = ?", user3.ID).Scan(&doubleAgeRes).Error; err != nil { t.Errorf("Scan to pointer of pointer") } if doubleAgeRes.Age != int(res.Age)*2 { t.Errorf("Scan double age as age, expect: %v, got %v", res.Age*2, doubleAgeRes.Age) } var results []result DB.Table("users").Select("name, 
age").Where("id in ?", []uint{user2.ID, user3.ID}).Scan(&results) sort.Slice(results, func(i, j int) bool { return strings.Compare(results[i].Name, results[j].Name) <= -1 }) if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { t.Errorf("Scan into struct map, got %#v", results) } type ID uint64 var id ID DB.Raw("select id from users where id = ?", user2.ID).Scan(&id) if uint(id) != user2.ID { t.Errorf("Failed to scan to customized data type") } var resInt interface{} resInt = &User{} if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt).Error; err != nil { t.Fatalf("Failed to query with pointer of value, got error %v", err) } else if resInt.(*User).ID != user3.ID || resInt.(*User).Name != user3.Name || resInt.(*User).Age != user3.Age { t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt, user3) } var resInt2 interface{} resInt2 = &User{} if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt2).Error; err != nil { t.Fatalf("Failed to query with pointer of value, got error %v", err) } else if resInt2.(*User).ID != user3.ID || resInt2.(*User).Name != user3.Name || resInt2.(*User).Age != user3.Age { t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt2, user3) } var resInt3 interface{} resInt3 = []User{} if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Find(&resInt3).Error; err != nil { t.Fatalf("Failed to query with pointer of value, got error %v", err) } else if rus := resInt3.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt3, user3) } var resInt4 interface{} resInt4 = []User{} if err := DB.Table("users").Select("id, name, age").Where("id = ?", user3.ID).Scan(&resInt4).Error; err != nil { t.Fatalf("Failed to query with pointer of value, got error %v", err) } else if rus := 
resInt4.([]User); len(rus) == 0 || rus[0].ID != user3.ID || rus[0].Name != user3.Name || rus[0].Age != user3.Age { t.Fatalf("Scan into struct should work, got %#v, should %#v", resInt4, user3) } var resInt5 interface{} resInt5 = []User{} if err := DB.Table("users").Select("id, name, age").Where("id IN ?", []uint{user1.ID, user2.ID, user3.ID}).Find(&resInt5).Error; err != nil { t.Fatalf("Failed to query with pointer of value, got error %v", err) } else if rus := resInt5.([]User); len(rus) != 3 { t.Fatalf("Scan into struct should work, got %+v, len %v", resInt5, len(rus)) } } func TestScanRows(t *testing.T) { user1 := User{Name: "ScanRowsUser1", Age: 1} user2 := User{Name: "ScanRowsUser2", Age: 10} user3 := User{Name: "ScanRowsUser3", Age: 20} DB.Save(&user1).Save(&user2).Save(&user3) rows, err := DB.Table("users").Where("name = ? or name = ?", user2.Name, user3.Name).Select("name, age").Rows() if err != nil { t.Errorf("No error should happen, got %v", err) } type Result struct { Name string Age int } var results []Result for rows.Next() { var result Result if err := DB.ScanRows(rows, &result); err != nil { t.Errorf("should get no error, but got %v", err) } results = append(results, result) } sort.Slice(results, func(i, j int) bool { return strings.Compare(results[i].Name, results[j].Name) <= -1 }) if !reflect.DeepEqual(results, []Result{{Name: "ScanRowsUser2", Age: 10}, {Name: "ScanRowsUser3", Age: 20}}) { t.Errorf("Should find expected results, got %+v", results) } var ages int if err := DB.Table("users").Where("name = ? 
or name = ?", user2.Name, user3.Name).Select("SUM(age)").Scan(&ages).Error; err != nil || ages != 30 { t.Fatalf("failed to scan ages, got error %v, ages: %v", err, ages) } var name string if err := DB.Table("users").Where("name = ?", user2.Name).Select("name").Scan(&name).Error; err != nil || name != user2.Name { t.Fatalf("failed to scan name, got error %v, name: %v", err, name) } } func TestScanRowsNullValuesScanToFieldDefault(t *testing.T) { DB.Save(&User{}) rows, err := DB.Table("users"). Select(` NULL AS bool_field, NULL AS int_field, NULL AS int8_field, NULL AS int16_field, NULL AS int32_field, NULL AS int64_field, NULL AS uint_field, NULL AS uint8_field, NULL AS uint16_field, NULL AS uint32_field, NULL AS uint64_field, NULL AS float32_field, NULL AS float64_field, NULL AS string_field, NULL AS time_field, NULL AS time_ptr_field, NULL AS embedded_int_field, NULL AS nested_embedded_int_field, NULL AS embedded_ptr_int_field `).Rows() if err != nil { t.Errorf("No error should happen, got %v", err) } type NestedEmbeddedStruct struct { NestedEmbeddedIntField int NestedEmbeddedIntFieldWithDefault int `gorm:"default:2"` } type EmbeddedStruct struct { EmbeddedIntField int NestedEmbeddedStruct `gorm:"embedded"` } type EmbeddedPtrStruct struct { EmbeddedPtrIntField int *NestedEmbeddedStruct `gorm:"embedded"` } type Result struct { BoolField bool IntField int Int8Field int8 Int16Field int16 Int32Field int32 Int64Field int64 UIntField uint UInt8Field uint8 UInt16Field uint16 UInt32Field uint32 UInt64Field uint64 Float32Field float32 Float64Field float64 StringField string TimeField time.Time TimePtrField *time.Time EmbeddedStruct `gorm:"embedded"` *EmbeddedPtrStruct `gorm:"embedded"` } currTime := time.Now() reusedVar := Result{ BoolField: true, IntField: 1, Int8Field: 1, Int16Field: 1, Int32Field: 1, Int64Field: 1, UIntField: 1, UInt8Field: 1, UInt16Field: 1, UInt32Field: 1, UInt64Field: 1, Float32Field: 1.1, Float64Field: 1.1, StringField: "hello", TimeField: currTime, 
TimePtrField: &currTime, EmbeddedStruct: EmbeddedStruct{EmbeddedIntField: 1, NestedEmbeddedStruct: NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}}, EmbeddedPtrStruct: &EmbeddedPtrStruct{EmbeddedPtrIntField: 1, NestedEmbeddedStruct: &NestedEmbeddedStruct{NestedEmbeddedIntField: 1, NestedEmbeddedIntFieldWithDefault: 2}}, } for rows.Next() { if err := DB.ScanRows(rows, &reusedVar); err != nil { t.Errorf("should get no error, but got %v", err) } } if !reflect.DeepEqual(reusedVar, Result{}) { t.Errorf("Should find zero values in struct fields, got %+v\n", reusedVar) } } func TestScanToEmbedded(t *testing.T) { person1 := Person{Name: "person 1"} person2 := Person{Name: "person 2"} DB.Save(&person1).Save(&person2) address1 := Address{Name: "address 1"} address2 := Address{Name: "address 2"} DB.Save(&address1).Save(&address2) DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address1.ID)}) DB.Create(&PersonAddress{PersonID: person1.ID, AddressID: int(address2.ID)}) DB.Create(&PersonAddress{PersonID: person2.ID, AddressID: int(address1.ID)}) var personAddressInfoList []*PersonAddressInfo if err := DB.Select("people.*, addresses.*"). Table("people"). Joins("inner join person_addresses on people.id = person_addresses.person_id"). Joins("inner join addresses on person_addresses.address_id = addresses.id"). 
Find(&personAddressInfoList).Error; err != nil { t.Errorf("Failed to run join query, got error: %v", err) } personMatched := false addressMatched := false for _, info := range personAddressInfoList { if info.Person == nil { t.Fatalf("Failed, expected not nil, got person nil") } if info.Address == nil { t.Fatalf("Failed, expected not nil, got address nil") } if info.Person.ID == person1.ID { personMatched = true if info.Person.Name != person1.Name { t.Errorf("Failed, expected %v, got %v", person1.Name, info.Person.Name) } } if info.Address.ID == address1.ID { addressMatched = true if info.Address.Name != address1.Name { t.Errorf("Failed, expected %v, got %v", address1.Name, info.Address.Name) } } } if !personMatched { t.Errorf("Failed, no person matched") } if !addressMatched { t.Errorf("Failed, no address matched") } personDupField := Person{ID: person1.ID} if err := DB.Select("people.id, people.*"). First(&personDupField).Error; err != nil { t.Errorf("Failed to run join query, got error: %v", err) } AssertEqual(t, person1, personDupField) user := User{ Name: "TestScanToEmbedded_1", Manager: &User{ Name: "TestScanToEmbedded_1_m1", Manager: &User{Name: "TestScanToEmbedded_1_m1_m1"}, }, } DB.Create(&user) type UserScan struct { ID uint Name string ManagerID *uint } var user2 UserScan err := DB.Raw("SELECT * FROM users INNER JOIN users Manager ON users.manager_id = Manager.id WHERE users.id = ?", user.ID).Scan(&user2).Error AssertEqual(t, err, nil) } ================================================ FILE: tests/scanner_valuer_test.go ================================================ package tests_test import ( "context" "database/sql" "database/sql/driver" "encoding/json" "errors" "fmt" "reflect" "regexp" "strconv" "testing" "time" "gorm.io/gorm" "gorm.io/gorm/clause" . 
"gorm.io/gorm/utils/tests" ) func TestScannerValuer(t *testing.T) { DB.Migrator().DropTable(&ScannerValuerStruct{}) if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } data := ScannerValuerStruct{ Name: sql.NullString{String: "name", Valid: true}, Gender: &sql.NullString{String: "M", Valid: true}, Age: sql.NullInt64{Int64: 18, Valid: true}, Male: sql.NullBool{Bool: true, Valid: true}, Height: sql.NullFloat64{Float64: 1.8888, Valid: true}, Birthday: sql.NullTime{Time: time.Now(), Valid: true}, Allergen: NullString{sql.NullString{String: "Allergen", Valid: true}}, Password: EncryptedData("pass1"), Bytes: []byte("byte"), Num: 18, Strings: StringsSlice{"a", "b", "c"}, Structs: StructsSlice{ {"name1", "value1"}, {"name2", "value2"}, }, Role: Role{Name: "admin"}, ExampleStruct: ExampleStruct{"name", "value1"}, ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err != nil { t.Fatalf("No error should happened when create scanner valuer struct, but got %v", err) } var result ScannerValuerStruct if err := DB.Find(&result, "id = ?", data.ID).Error; err != nil { t.Fatalf("no error should happen when query scanner, valuer struct, but got %v", err) } if result.ExampleStructPtr.Val != "value2" { t.Errorf(`ExampleStructPtr.Val should equal to "value2", but got %v`, result.ExampleStructPtr.Val) } if result.ExampleStruct.Val != "value1" { t.Errorf(`ExampleStruct.Val should equal to "value1", but got %#v`, result.ExampleStruct) } AssertObjEqual(t, data, result, "Name", "Gender", "Age", "Male", "Height", "Birthday", "Password", "Bytes", "Num", "Strings", "Structs") } func TestScannerValuerWithFirstOrCreate(t *testing.T) { DB.Migrator().DropTable(&ScannerValuerStruct{}) if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { t.Errorf("no error should happen when migrate scanner, valuer struct") } data := 
ScannerValuerStruct{ Name: sql.NullString{String: "name", Valid: true}, Gender: &sql.NullString{String: "M", Valid: true}, Age: sql.NullInt64{Int64: 18, Valid: true}, ExampleStruct: ExampleStruct{"name", "value1"}, ExampleStructPtr: &ExampleStruct{"name", "value2"}, } var result ScannerValuerStruct tx := DB.Where(data).FirstOrCreate(&result) if tx.RowsAffected != 1 { t.Errorf("RowsAffected should be 1 after create some record") } if tx.Error != nil { t.Errorf("Should not raise any error, but got %v", tx.Error) } AssertObjEqual(t, result, data, "Name", "Gender", "Age") if err := DB.Where(data).Assign(ScannerValuerStruct{Age: sql.NullInt64{Int64: 18, Valid: true}}).FirstOrCreate(&result).Error; err != nil { t.Errorf("Should not raise any error, but got %v", err) } if result.Age.Int64 != 18 { t.Errorf("should update age to 18") } var result2 ScannerValuerStruct if err := DB.First(&result2, result.ID).Error; err != nil { t.Errorf("got error %v when query with %v", err, result.ID) } AssertObjEqual(t, result2, result, "ID", "CreatedAt", "UpdatedAt", "Name", "Gender", "Age") } func TestInvalidValuer(t *testing.T) { DB.Migrator().DropTable(&ScannerValuerStruct{}) if err := DB.Migrator().AutoMigrate(&ScannerValuerStruct{}); err != nil { t.Errorf("no error should happen when migrate scanner, valuer struct") } data := ScannerValuerStruct{ Password: EncryptedData("xpass1"), ExampleStruct: ExampleStruct{"name", "value1"}, ExampleStructPtr: &ExampleStruct{"name", "value2"}, } if err := DB.Create(&data).Error; err == nil { t.Errorf("Should failed to create data with invalid data") } data.Password = EncryptedData("pass1") if err := DB.Create(&data).Error; err != nil { t.Errorf("Should got no error when creating data, but got %v", err) } if err := DB.Model(&data).Update("password", EncryptedData("xnewpass")).Error; err == nil { t.Errorf("Should failed to update data with invalid data") } if err := DB.Model(&data).Update("password", EncryptedData("newpass")).Error; err != nil { 
t.Errorf("Should got no error update data with valid data, but got %v", err) } AssertEqual(t, data.Password, EncryptedData("newpass")) } type ScannerValuerStruct struct { gorm.Model Name sql.NullString Gender *sql.NullString Age sql.NullInt64 Male sql.NullBool Height sql.NullFloat64 Birthday sql.NullTime Allergen NullString Password EncryptedData Bytes []byte Num Num Strings StringsSlice Structs StructsSlice Role Role UserID *sql.NullInt64 User User EmptyTime EmptyTime ExampleStruct ExampleStruct ExampleStructPtr *ExampleStruct } type EncryptedData []byte func (data *EncryptedData) Scan(value interface{}) error { if b, ok := value.([]byte); ok { if len(b) < 3 || b[0] != '*' || b[1] != '*' || b[2] != '*' { return errors.New("Too short") } *data = append((*data)[0:], b[3:]...) return nil } else if s, ok := value.(string); ok { *data = []byte(s[3:]) return nil } return errors.New("Bytes expected") } func (data EncryptedData) Value() (driver.Value, error) { if len(data) > 0 && data[0] == 'x' { // needed to test failures return nil, errors.New("Should not start with 'x'") } // prepend asterisks return append([]byte("***"), data...), nil } type Num int64 func (i *Num) Scan(src interface{}) error { switch s := src.(type) { case []byte: n, _ := strconv.Atoi(string(s)) *i = Num(n) case int64: *i = Num(s) default: return errors.New("Cannot scan NamedInt from " + reflect.ValueOf(src).String()) } return nil } type StringsSlice []string func (l StringsSlice) Value() (driver.Value, error) { bytes, err := json.Marshal(l) return string(bytes), err } func (l *StringsSlice) Scan(input interface{}) error { switch value := input.(type) { case string: return json.Unmarshal([]byte(value), l) case []byte: return json.Unmarshal(value, l) default: return errors.New("not supported") } } type ExampleStruct struct { Name string Val string } func (ExampleStruct) GormDataType() string { return "bytes" } func (s ExampleStruct) Value() (driver.Value, error) { if len(s.Name) == 0 { return nil, nil 
} // for test, has no practical meaning s.Name = "" return json.Marshal(s) } func (s *ExampleStruct) Scan(src interface{}) error { switch value := src.(type) { case string: return json.Unmarshal([]byte(value), s) case []byte: return json.Unmarshal(value, s) default: return errors.New("not supported") } } type StructsSlice []ExampleStruct func (l StructsSlice) Value() (driver.Value, error) { bytes, err := json.Marshal(l) return string(bytes), err } func (l *StructsSlice) Scan(input interface{}) error { switch value := input.(type) { case string: return json.Unmarshal([]byte(value), l) case []byte: return json.Unmarshal(value, l) default: return errors.New("not supported") } } type Role struct { Name string `gorm:"size:256"` } func (role *Role) Scan(value interface{}) error { if b, ok := value.([]uint8); ok { role.Name = string(b) } else { role.Name = value.(string) } return nil } func (role Role) Value() (driver.Value, error) { return role.Name, nil } func (role Role) IsAdmin() bool { return role.Name == "admin" } type EmptyTime struct { time.Time } func (t *EmptyTime) Scan(v interface{}) error { nullTime := sql.NullTime{} err := nullTime.Scan(v) t.Time = nullTime.Time return err } func (t EmptyTime) Value() (driver.Value, error) { return time.Now() /* pass tests, mysql 8 doesn't support 0000-00-00 by default */, nil } type NullString struct { sql.NullString } type Point struct { X, Y int } func (point Point) GormDataType() string { return "geo" } func (point Point) GormValue(ctx context.Context, db *gorm.DB) clause.Expr { return clause.Expr{ SQL: "ST_PointFromText(?)", Vars: []interface{}{fmt.Sprintf("POINT(%d %d)", point.X, point.Y)}, } } func TestGORMValuer(t *testing.T) { type UserWithPoint struct { Name string Point Point } dryRunDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryRunDB.Create(&UserWithPoint{ Name: "jinzhu", Point: Point{X: 100, Y: 100}, }).Statement if stmt.SQL.String() == "" || len(stmt.Vars) != 2 { t.Errorf("Failed to generate sql, got 
%v", stmt.SQL.String()) } if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) } if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } stmt = dryRunDB.Model(UserWithPoint{}).Create(map[string]interface{}{ "Name": "jinzhu", "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, }).Statement if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.name.,.point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) } if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } stmt = dryRunDB.Table("user_with_points").Create(&map[string]interface{}{ "Name": "jinzhu", "Point": clause.Expr{SQL: "ST_PointFromText(?)", Vars: []interface{}{"POINT(100 100)"}}, }).Statement if !regexp.MustCompile(`INSERT INTO .user_with_points. \(.Name.,.Point.\) VALUES \(.+,ST_PointFromText\(.+\)\)`).MatchString(stmt.SQL.String()) { t.Errorf("insert with sql.Expr, but got %v", stmt.SQL.String()) } if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } stmt = dryRunDB.Session(&gorm.Session{ AllowGlobalUpdate: true, }).Model(&UserWithPoint{}).Updates(UserWithPoint{ Name: "jinzhu", Point: Point{X: 100, Y: 100}, }).Statement if !regexp.MustCompile(`UPDATE .user_with_points. 
SET .name.=.+,.point.=ST_PointFromText\(.+\)`).MatchString(stmt.SQL.String()) { t.Errorf("update with sql.Expr, but got %v", stmt.SQL.String()) } if !reflect.DeepEqual([]interface{}{"jinzhu", "POINT(100 100)"}, stmt.Vars) { t.Errorf("generated vars is not equal, got %v", stmt.Vars) } } ================================================ FILE: tests/scopes_test.go ================================================ package tests_test import ( "context" "testing" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) func NameIn1And2(d *gorm.DB) *gorm.DB { return d.Where("name in (?)", []string{"ScopeUser1", "ScopeUser2"}) } func NameIn2And3(d *gorm.DB) *gorm.DB { return d.Where("name in (?)", []string{"ScopeUser2", "ScopeUser3"}) } func NameIn(names []string) func(d *gorm.DB) *gorm.DB { return func(d *gorm.DB) *gorm.DB { return d.Where("name in (?)", names) } } func TestScopes(t *testing.T) { users := []*User{ GetUser("ScopeUser1", Config{}), GetUser("ScopeUser2", Config{}), GetUser("ScopeUser3", Config{}), } DB.Create(&users) var users1, users2, users3 []User DB.Scopes(NameIn1And2).Find(&users1) if len(users1) != 2 { t.Errorf("Should found two users's name in 1, 2, but got %v", len(users1)) } DB.Scopes(NameIn1And2, NameIn2And3).Find(&users2) if len(users2) != 1 { t.Errorf("Should found one user's name is 2, but got %v", len(users2)) } DB.Scopes(NameIn([]string{users[0].Name, users[2].Name})).Find(&users3) if len(users3) != 2 { t.Errorf("Should found two users's name in 1, 3, but got %v", len(users3)) } db := DB.Scopes(func(tx *gorm.DB) *gorm.DB { return tx.Table("custom_table") }).Session(&gorm.Session{}) db.AutoMigrate(&User{}) if db.Find(&User{}).Statement.Table != "custom_table" { t.Errorf("failed to call Scopes") } result := DB.Scopes(NameIn1And2, func(tx *gorm.DB) *gorm.DB { return tx.Session(&gorm.Session{}) }).Find(&users1) if result.RowsAffected != 2 { t.Errorf("Should found two users's name in 1, 2, but got %v", result.RowsAffected) } var maxId int64 userTable := func(db 
*gorm.DB) *gorm.DB { return db.WithContext(context.Background()).Table("users") } if err := DB.Scopes(userTable).Select("max(id)").Scan(&maxId).Error; err != nil { t.Errorf("select max(id)") } } func TestComplexScopes(t *testing.T) { tests := []struct { name string queryFn func(tx *gorm.DB) *gorm.DB expected string }{ { name: "depth_1", queryFn: func(tx *gorm.DB) *gorm.DB { return tx.Scopes( func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, func(d *gorm.DB) *gorm.DB { return d.Where(DB.Or("b = 2").Or("c = 3")) }, ).Find(&Language{}) }, expected: `SELECT * FROM "languages" WHERE a = 1 AND (b = 2 OR c = 3)`, }, { name: "depth_1_pre_cond", queryFn: func(tx *gorm.DB) *gorm.DB { return tx.Where("z = 0").Scopes( func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, func(d *gorm.DB) *gorm.DB { return d.Or(DB.Where("b = 2").Or("c = 3")) }, ).Find(&Language{}) }, expected: `SELECT * FROM "languages" WHERE z = 0 AND a = 1 OR (b = 2 OR c = 3)`, }, { name: "depth_2", queryFn: func(tx *gorm.DB) *gorm.DB { return tx.Scopes( func(d *gorm.DB) *gorm.DB { return d.Model(&Language{}) }, func(d *gorm.DB) *gorm.DB { return d. Or(DB.Scopes( func(d *gorm.DB) *gorm.DB { return d.Where("a = 1") }, func(d *gorm.DB) *gorm.DB { return d.Where("b = 2") }, )). Or("c = 3") }, func(d *gorm.DB) *gorm.DB { return d.Where("d = 4") }, ).Find(&Language{}) }, expected: `SELECT * FROM "languages" WHERE d = 4 OR c = 3 OR (a = 1 AND b = 2)`, }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { assertEqualSQL(t, test.expected, DB.ToSQL(test.queryFn)) }) } } ================================================ FILE: tests/serializer_test.go ================================================ package tests_test import ( "bytes" "context" "fmt" "reflect" "strings" "testing" "time" "gorm.io/gorm" "gorm.io/gorm/schema" . 
"gorm.io/gorm/utils/tests" ) type SerializerStruct struct { gorm.Model Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` Roles2 *Roles `gorm:"serializer:json"` Roles3 *Roles `gorm:"serializer:json;not null"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type UpdatedTime *int64 `gorm:"serializer:unixtime;type:datetime"` // store time in db, use int as field type CustomSerializerString string `gorm:"serializer:custom"` EncryptedString EncryptedString } type SerializerPostgresStruct struct { gorm.Model Name []byte `gorm:"json"` Roles Roles `gorm:"serializer:json"` Roles2 *Roles `gorm:"serializer:json"` Roles3 *Roles `gorm:"serializer:json;not null"` Contracts map[string]interface{} `gorm:"serializer:json"` JobInfo Job `gorm:"type:bytes;serializer:gob"` CreatedTime int64 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type UpdatedTime *int64 `gorm:"serializer:unixtime;type:timestamptz"` // store time in db, use int as field type CustomSerializerString string `gorm:"serializer:custom"` EncryptedString EncryptedString } func (*SerializerPostgresStruct) TableName() string { return "serializer_structs" } func adaptorSerializerModel(s *SerializerStruct) interface{} { if DB.Dialector.Name() == "postgres" || DB.Dialector.Name() == "gaussdb" { sps := SerializerPostgresStruct(*s) return &sps } return s } type Roles []string type Job struct { Title string Number int Location string IsIntern bool } type EncryptedString string func (es *EncryptedString) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { switch value := dbValue.(type) { case []byte: *es = EncryptedString(bytes.TrimPrefix(value, []byte("hello"))) case string: *es = EncryptedString(strings.TrimPrefix(value, "hello")) default: return fmt.Errorf("unsupported data 
%#v", dbValue) } return nil } func (es EncryptedString) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { return "hello" + string(es), nil } type CustomSerializer struct { prefix []byte } func NewCustomSerializer(prefix string) *CustomSerializer { return &CustomSerializer{prefix: []byte(prefix)} } func (c *CustomSerializer) Scan(ctx context.Context, field *schema.Field, dst reflect.Value, dbValue interface{}) (err error) { switch value := dbValue.(type) { case []byte: err = field.Set(ctx, dst, bytes.TrimPrefix(value, c.prefix)) case string: err = field.Set(ctx, dst, strings.TrimPrefix(value, string(c.prefix))) default: err = fmt.Errorf("unsupported data %#v", dbValue) } return err } func (c *CustomSerializer) Value(ctx context.Context, field *schema.Field, dst reflect.Value, fieldValue interface{}) (interface{}, error) { return fmt.Sprintf("%s%s", c.prefix, fieldValue), nil } func TestSerializer(t *testing.T) { schema.RegisterSerializer("custom", NewCustomSerializer("hello")) DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) updatedAt := createdAt.Unix() data := SerializerStruct{ Name: []byte("jinzhu"), Roles: []string{"r1", "r2"}, Contracts: map[string]interface{}{"name": "jinzhu", "age": 10}, EncryptedString: EncryptedString("pass"), CreatedTime: createdAt.Unix(), UpdatedTime: &updatedAt, JobInfo: Job{ Title: "programmer", Number: 9920, Location: "Kenmawr", IsIntern: false, }, CustomSerializerString: "world", } if err := DB.Create(&data).Error; err != nil { t.Fatalf("failed to create data, got error %v", err) } var result SerializerStruct if err := DB.Where("roles2 IS NULL AND roles3 = ?", "").First(&result, data.ID).Error; err != nil { 
t.Fatalf("failed to query data, got error %v", err) } AssertEqual(t, result, data) if err := DB.Model(&result).Update("roles", "").Error; err != nil { t.Fatalf("failed to update data's roles, got error %v", err) } if err := DB.First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } } func TestSerializerZeroValue(t *testing.T) { schema.RegisterSerializer("custom", NewCustomSerializer("hello")) DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } data := SerializerStruct{} if err := DB.Create(&data).Error; err != nil { t.Fatalf("failed to create data, got error %v", err) } var result SerializerStruct if err := DB.First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } AssertEqual(t, result, data) if err := DB.Model(&result).Update("roles", "").Error; err != nil { t.Fatalf("failed to update data's roles, got error %v", err) } if err := DB.First(&result, data.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } } func TestSerializerAssignFirstOrCreate(t *testing.T) { schema.RegisterSerializer("custom", NewCustomSerializer("hello")) DB.Migrator().DropTable(adaptorSerializerModel(&SerializerStruct{})) if err := DB.Migrator().AutoMigrate(adaptorSerializerModel(&SerializerStruct{})); err != nil { t.Fatalf("no error should happen when migrate scanner, valuer struct, got error %v", err) } createdAt := time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) data := SerializerStruct{ Name: []byte("ag9920"), Roles: []string{"r1", "r2"}, Contracts: map[string]interface{}{"name": "jing1", "age": 11}, EncryptedString: EncryptedString("pass"), CreatedTime: createdAt.Unix(), JobInfo: Job{ Title: "programmer", Number: 9920, Location: "Shadyside", IsIntern: false, }, CustomSerializerString: 
"world", } // first time insert record out := SerializerStruct{} if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) } var result SerializerStruct if err := DB.First(&result, out.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } AssertEqual(t, result, out) // update record data.Roles = append(data.Roles, "r3") data.JobInfo.Location = "Gates Hillman Complex" if err := DB.Assign(data).FirstOrCreate(&out).Error; err != nil { t.Fatalf("failed to FirstOrCreate Assigned data, got error %v", err) } if err := DB.First(&result, out.ID).Error; err != nil { t.Fatalf("failed to query data, got error %v", err) } AssertEqual(t, result.Roles, data.Roles) AssertEqual(t, result.JobInfo.Location, data.JobInfo.Location) } // Test for: panic when serializer field with any type is nil func TestSerializerWithAnyType(t *testing.T) { type ProductWithAny struct { gorm.Model Name string Data any `gorm:"serializer:json"` } DB.Migrator().DropTable(&ProductWithAny{}) if err := DB.AutoMigrate(&ProductWithAny{}); err != nil { t.Fatalf("failed to migrate ProductWithAny, got error %v", err) } // Test creating record with nil any field product := ProductWithAny{Name: "Product 1"} if err := DB.Create(&product).Error; err != nil { t.Fatalf("failed to create product with nil any field, got error %v", err) } // Test updating/saving record with nil any field (should not panic) product.Name = "Product 1 (Updated)" if err := DB.Save(&product).Error; err != nil { t.Fatalf("failed to save product with nil any field, got error %v", err) } // Verify the record was saved correctly var result ProductWithAny if err := DB.First(&result, product.ID).Error; err != nil { t.Fatalf("failed to query product, got error %v", err) } if result.Name != "Product 1 (Updated)" { t.Errorf("expected name to be 'Product 1 (Updated)', got %s", result.Name) } if result.Data != nil { t.Errorf("expected Data to be nil, got 
%v", result.Data) } // Test with non-nil value dataValue := map[string]interface{}{"key": "value"} product2 := ProductWithAny{Name: "Product 2", Data: dataValue} if err := DB.Create(&product2).Error; err != nil { t.Fatalf("failed to create product with non-nil any field, got error %v", err) } var result2 ProductWithAny if err := DB.First(&result2, product2.ID).Error; err != nil { t.Fatalf("failed to query product2, got error %v", err) } if result2.Data == nil { t.Error("expected Data to be non-nil") } } ================================================ FILE: tests/soft_delete_test.go ================================================ package tests_test import ( "database/sql" "encoding/json" "errors" "regexp" "testing" "github.com/jinzhu/now" "gorm.io/gorm" . "gorm.io/gorm/utils/tests" ) func TestSoftDelete(t *testing.T) { user := *GetUser("SoftDelete", Config{}) DB.Save(&user) var count int64 var age uint if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) } if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age { t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) } if err := DB.Delete(&user).Error; err != nil { t.Fatalf("No error should happen when soft delete user, but got %v", err) } if sql.NullTime(user.DeletedAt).Time.IsZero() { t.Fatalf("user's deleted at is zero") } sql := DB.Session(&gorm.Session{DryRun: true}).Delete(&user).Statement.SQL.String() if !regexp.MustCompile(`UPDATE .users. SET .deleted_at.=.* WHERE .users.\..id. = .* AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = DB.Session(&gorm.Session{DryRun: true}).Table("user u").Select("name").Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`SELECT .name. FROM user u WHERE .u.\..deleted_at. 
IS NULL`).MatchString(sql) { t.Errorf("Table with escape character, got %v", sql) } if DB.First(&User{}, "name = ?", user.Name).Error == nil { t.Errorf("Can't find a soft deleted record") } count = 0 if DB.Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 0 { t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) } age = 0 if DB.Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != 0 { t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) } if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; err != nil { t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) } count = 0 if DB.Unscoped().Model(&User{}).Where("name = ?", user.Name).Count(&count).Error != nil || count != 1 { t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) } age = 0 if DB.Unscoped().Model(&User{}).Select("age").Where("name = ?", user.Name).Scan(&age).Error != nil || age != user.Age { t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, age) } DB.Unscoped().Delete(&user) if err := DB.Unscoped().First(&User{}, "name = ?", user.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("Can't find permanently deleted record") } } func TestDeletedAtUnMarshal(t *testing.T) { expected := &gorm.Model{} b, _ := json.Marshal(expected) result := &gorm.Model{} _ = json.Unmarshal(b, result) if result.DeletedAt != expected.DeletedAt { t.Errorf("Failed, result.DeletedAt: %v is not same as expected.DeletedAt: %v", result.DeletedAt, expected.DeletedAt) } } func TestDeletedAtOneOr(t *testing.T) { actualSQL := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Or("id = ?", 1).Find(&User{}) }) if !regexp.MustCompile(` WHERE id = 1 AND .users.\..deleted_at. 
IS NULL`).MatchString(actualSQL) { t.Fatalf("invalid sql generated, got %v", actualSQL) } } func TestSoftDeleteZeroValue(t *testing.T) { type SoftDeleteBook struct { ID uint Name string Pages uint DeletedAt gorm.DeletedAt `gorm:"zeroValue:'1970-01-01 00:00:01'"` } DB.Migrator().DropTable(&SoftDeleteBook{}) if err := DB.AutoMigrate(&SoftDeleteBook{}); err != nil { t.Fatalf("failed to auto migrate soft delete table") } book := SoftDeleteBook{Name: "jinzhu", Pages: 10} DB.Save(&book) var count int64 if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 1 { t.Errorf("Count soft deleted record, expects: %v, got: %v", 1, count) } var pages uint if DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages { t.Errorf("Pages soft deleted record, expects: %v, got: %v", 0, pages) } if err := DB.Delete(&book).Error; err != nil { t.Fatalf("No error should happen when soft delete user, but got %v", err) } zeroTime, _ := now.Parse("1970-01-01 00:00:01") if book.DeletedAt.Time.Equal(zeroTime) { t.Errorf("book's deleted at should not be zero, DeletedAt: %v", book.DeletedAt) } if DB.First(&SoftDeleteBook{}, "name = ?", book.Name).Error == nil { t.Errorf("Can't find a soft deleted record") } count = 0 if DB.Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count != 0 { t.Errorf("Count soft deleted record, expects: %v, got: %v", 0, count) } pages = 0 if err := DB.Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error; err != nil || pages != 0 { t.Fatalf("Age soft deleted record, expects: %v, got: %v, err %v", 0, pages, err) } if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; err != nil { t.Errorf("Should find soft deleted record with Unscoped, but got err %s", err) } count = 0 if DB.Unscoped().Model(&SoftDeleteBook{}).Where("name = ?", book.Name).Count(&count).Error != nil || count 
!= 1 { t.Errorf("Count soft deleted record, expects: %v, count: %v", 1, count) } pages = 0 if DB.Unscoped().Model(&SoftDeleteBook{}).Select("pages").Where("name = ?", book.Name).Scan(&pages).Error != nil || pages != book.Pages { t.Errorf("Age soft deleted record, expects: %v, got: %v", 0, pages) } DB.Unscoped().Delete(&book) if err := DB.Unscoped().First(&SoftDeleteBook{}, "name = ?", book.Name).Error; !errors.Is(err, gorm.ErrRecordNotFound) { t.Errorf("Can't find permanently deleted record") } } ================================================ FILE: tests/sql_builder_test.go ================================================ package tests_test import ( "regexp" "strings" "testing" "time" "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) func TestRow(t *testing.T) { user1 := User{Name: "RowUser1", Age: 1} user2 := User{Name: "RowUser2", Age: 10} user3 := User{Name: "RowUser3", Age: 20} DB.Save(&user1).Save(&user2).Save(&user3) row := DB.Table("users").Where("name = ?", user2.Name).Select("age").Row() var age int64 if err := row.Scan(&age); err != nil { t.Fatalf("Failed to scan age, got %v", err) } if age != 10 { t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) } table := "gorm.users" if DB.Dialector.Name() != "mysql" || isTiDB() { table = "users" // other databases doesn't support select with `database.table` } DB.Table(table).Where(map[string]interface{}{"name": user2.Name}).Update("age", 20) row = DB.Table(table+" as u").Where("u.name = ?", user2.Name).Select("age").Row() if err := row.Scan(&age); err != nil { t.Fatalf("Failed to scan age, got %v", err) } if age != 20 { t.Errorf("Scan with Row, age expects: %v, got %v", user2.Age, age) } } func TestRows(t *testing.T) { user1 := User{Name: "RowsUser1", Age: 1} user2 := User{Name: "RowsUser2", Age: 10} user3 := User{Name: "RowsUser3", Age: 20} DB.Save(&user1).Save(&user2).Save(&user3) rows, err := DB.Table("users").Where("name = ? 
or name = ?", user2.Name, user3.Name).Select("name, age").Rows() if err != nil { t.Errorf("Not error should happen, got %v", err) } count := 0 for rows.Next() { var name string var age int64 rows.Scan(&name, &age) count++ } if count != 2 { t.Errorf("Should found two records") } } func TestRaw(t *testing.T) { user1 := User{Name: "ExecRawSqlUser1", Age: 1} user2 := User{Name: "ExecRawSqlUser2", Age: 10} user3 := User{Name: "ExecRawSqlUser3", Age: 20} DB.Save(&user1).Save(&user2).Save(&user3) type result struct { Name string Email string } var results []result DB.Raw("SELECT name, age FROM users WHERE name = ? or name = ?", user2.Name, user3.Name).Scan(&results) if len(results) != 2 || results[0].Name != user2.Name || results[1].Name != user3.Name { t.Errorf("Raw with scan") } rows, _ := DB.Raw("select name, age from users where name = ?", user3.Name).Rows() count := 0 for rows.Next() { count++ } if count != 1 { t.Errorf("Raw with Rows should find one record with name 3") } DB.Exec("update users set name=? where name in (?)", "jinzhu-raw", []string{user1.Name, user2.Name, user3.Name}) if DB.Where("name in (?)", []string{user1.Name, user2.Name, user3.Name}).First(&User{}).Error != gorm.ErrRecordNotFound { t.Error("Raw sql to update records") } DB.Exec("update users set age=? where name = ?", gorm.Expr("age * ? 
+ ?", 2, 10), "jinzhu-raw") var age int DB.Raw("select sum(age) from users where name = ?", "jinzhu-raw").Scan(&age) if age != ((1+10+20)*2 + 30) { t.Errorf("Invalid age, got %v", age) } } func TestRowsWithGroup(t *testing.T) { users := []User{ {Name: "having_user_1", Age: 1}, {Name: "having_user_2", Age: 10}, {Name: "having_user_1", Age: 20}, {Name: "having_user_1", Age: 30}, } DB.Create(&users) rows, err := DB.Select("name, count(*) as total").Table("users").Group("name").Having("name IN ?", []string{users[0].Name, users[1].Name}).Rows() if err != nil { t.Fatalf("got error %v", err) } defer rows.Close() for rows.Next() { var name string var total int64 rows.Scan(&name, &total) if name == users[0].Name && total != 3 { t.Errorf("Should have one user having name %v", users[0].Name) } else if name == users[1].Name && total != 1 { t.Errorf("Should have two users having name %v", users[1].Name) } } } func TestQueryRaw(t *testing.T) { users := []*User{ GetUser("row_query_user", Config{}), GetUser("row_query_user", Config{}), GetUser("row_query_user", Config{}), } DB.Create(&users) var user User DB.Raw("select * from users WHERE id = ?", users[1].ID).First(&user) CheckUser(t, user, *users[1]) } func TestDryRun(t *testing.T) { user := *GetUser("dry-run", Config{}) dryRunDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryRunDB.Create(&user).Statement if stmt.SQL.String() == "" || len(stmt.Vars) != 9 { t.Errorf("Failed to generate sql, got %v", stmt.SQL.String()) } stmt2 := dryRunDB.Find(&user, "id = ?", user.ID).Statement if stmt2.SQL.String() == "" || len(stmt2.Vars) != 1 { t.Errorf("Failed to generate sql, got %v", stmt2.SQL.String()) } } type ageInt int8 func (ageInt) String() string { return "age" } type ageBool bool func (ageBool) String() string { return "age" } type ageUint64 uint64 func (ageUint64) String() string { return "age" } type ageFloat float64 func (ageFloat) String() string { return "age" } func TestExplainSQL(t *testing.T) { user := 
*GetUser("explain-sql", Config{}) dryRunDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageInt(8)}).Statement sql := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) if !regexp.MustCompile(`.*age.*=8,`).MatchString(sql) { t.Errorf("Failed to generate sql, got %v", sql) } stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageUint64(10241024)}).Statement sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) if !regexp.MustCompile(`.*age.*=10241024,`).MatchString(sql) { t.Errorf("Failed to generate sql, got %v", sql) } stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageBool(false)}).Statement sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) if !regexp.MustCompile(`.*age.*=false,`).MatchString(sql) { t.Errorf("Failed to generate sql, got %v", sql) } stmt = dryRunDB.Model(&user).Where("id = ?", 1).Updates(map[string]interface{}{"age": ageFloat(0.12345678)}).Statement sql = DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) if !regexp.MustCompile(`.*age.*=0.123457,`).MatchString(sql) { t.Errorf("Failed to generate sql, got %v", sql) } } func TestGroupConditions(t *testing.T) { type Pizza struct { ID uint Name string Size string } dryRunDB := DB.Session(&gorm.Session{DryRun: true}) stmt := dryRunDB.Where( DB.Where("pizza = ?", "pepperoni").Where(DB.Where("size = ?", "small").Or("size = ?", "medium")), ).Or( DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"), ).Find(&Pizza{}).Statement execStmt := dryRunDB.Exec("WHERE (pizza = ? AND (size = ? OR size = ?)) OR (pizza = ? AND size = ?)", "pepperoni", "small", "medium", "hawaiian", "xlarge").Statement result := DB.Dialector.Explain(stmt.SQL.String(), stmt.Vars...) expects := DB.Dialector.Explain(execStmt.SQL.String(), execStmt.Vars...) 
if !strings.HasSuffix(result, expects) { t.Errorf("expects: %v, got %v", expects, result) } stmt2 := dryRunDB.Where( DB.Scopes(NameIn1And2), ).Or( DB.Where("pizza = ?", "hawaiian").Where("size = ?", "xlarge"), ).Find(&Pizza{}).Statement execStmt2 := dryRunDB.Exec(`WHERE name in ? OR (pizza = ? AND size = ?)`, []string{"ScopeUser1", "ScopeUser2"}, "hawaiian", "xlarge").Statement result2 := DB.Dialector.Explain(stmt2.SQL.String(), stmt2.Vars...) expects2 := DB.Dialector.Explain(execStmt2.SQL.String(), execStmt2.Vars...) if !strings.HasSuffix(result2, expects2) { t.Errorf("expects: %v, got %v", expects2, result2) } } func TestCombineStringConditions(t *testing.T) { dryRunDB := DB.Session(&gorm.Session{DryRun: true}) sql := dryRunDB.Where("a = ? or b = ?", "a", "b").Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ?", "c").Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR c = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Or("c = ? and d = ?", "c", "d").Or("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`WHERE \(\(a = .+ or b = .+\) OR \(c = .+ and d = .+\) OR \(e = .+ and f = .+\)\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ? 
and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ?", "c").Not("e = ? and f = ?", "e", "f").Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND c = .+ AND NOT \(e = .+ and f = .+\) AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Where("c = ? and d = ?", "c", "d").Not("e = ?", "e").Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`WHERE \(a = .+ or b = .+\) AND \(c = .+ and d = .+\) AND NOT e = .+ AND .users.\..deleted_at. IS NULL`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Where("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Or("a = ? or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`WHERE a = .+ or b = .+$`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } sql = dryRunDB.Not("a = ? 
or b = ?", "a", "b").Unscoped().Find(&User{}).Statement.SQL.String() if !regexp.MustCompile(`WHERE NOT \(a = .+ or b = .+\)$`).MatchString(sql) { t.Fatalf("invalid sql generated, got %v", sql) } } func TestFromWithJoins(t *testing.T) { var result User newDB := DB.Session(&gorm.Session{NewDB: true, DryRun: true}).Table("users") newDB.Clauses( clause.From{ Tables: []clause.Table{{Name: "users"}}, Joins: []clause.Join{ { Table: clause.Table{Name: "companies", Raw: false}, ON: clause.Where{ Exprs: []clause.Expression{ clause.Eq{ Column: clause.Column{ Table: "users", Name: "company_id", }, Value: clause.Column{ Table: "companies", Name: "id", }, }, }, }, }, }, }, ) newDB.Joins("inner join rgs on rgs.id = user.id") stmt := newDB.First(&result).Statement str := stmt.SQL.String() if !strings.Contains(str, "rgs.id = user.id") { t.Errorf("The second join condition is over written instead of combining") } if !strings.Contains(str, "`users`.`company_id` = `companies`.`id`") && !strings.Contains(str, "\"users\".\"company_id\" = \"companies\".\"id\"") { t.Errorf("The first join condition is over written instead of combining") } } func TestToSQL(t *testing.T) { // By default DB.DryRun should false if DB.DryRun { t.Fatal("Failed expect DB.DryRun to be false") } if DB.Dialector.Name() == "sqlserver" { t.Skip("Skip SQL Server for this test, because it too difference with other dialects.") } date, _ := time.ParseInLocation("2006-01-02", "2021-10-18", time.Local) // find sql := DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Where("id = ?", 100).Limit(10).Order("age desc").Find(&[]User{}) }) assertEqualSQL(t, `SELECT * FROM "users" WHERE id = 100 AND "users"."deleted_at" IS NULL ORDER BY age desc LIMIT 10`, sql) // after model changed if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } if DB.Statement.SQL.String() != "" { t.Fatal("Failed expect DB.Statement.SQL to be empty") } // first sql = DB.ToSQL(func(tx 
*gorm.DB) *gorm.DB { return tx.Model(&User{}).Where(&User{Name: "foo", Age: 20}).Limit(10).Offset(5).Order("name ASC").First(&User{}) }) assertEqualSQL(t, `SELECT * FROM "users" WHERE ("users"."name" = 'foo' AND "users"."age" = 20) AND "users"."deleted_at" IS NULL ORDER BY name ASC,"users"."id" LIMIT 1 OFFSET 5`, sql) // last and unscoped sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Unscoped().Where(&User{Name: "bar", Age: 12}).Limit(10).Offset(5).Order("name ASC").Last(&User{}) }) assertEqualSQL(t, `SELECT * FROM "users" WHERE "users"."name" = 'bar' AND "users"."age" = 12 ORDER BY name ASC,"users"."id" DESC LIMIT 1 OFFSET 5`, sql) // create user := &User{Name: "foo", Age: 20} user.CreatedAt = date user.UpdatedAt = date sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Create(user) }) assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false) RETURNING "id"`, sql) // save user = &User{Name: "foo", Age: 20} user.CreatedAt = date user.UpdatedAt = date sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Save(user) }) assertEqualSQL(t, `INSERT INTO "users" ("created_at","updated_at","deleted_at","name","age","birthday","company_id","manager_id","active") VALUES ('2021-10-18 00:00:00','2021-10-18 00:00:00',NULL,'foo',20,NULL,NULL,NULL,false) RETURNING "id"`, sql) // updates user = &User{Name: "bar", Age: 22} user.CreatedAt = date user.UpdatedAt = date sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Where("id = ?", 100).Updates(user) }) assertEqualSQL(t, `UPDATE "users" SET "created_at"='2021-10-18 00:00:00',"updated_at"='2021-10-18 19:50:09.438',"name"='bar',"age"=22 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) // update sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Where("id = ?", 
100).Update("name", "Foo bar") }) assertEqualSQL(t, `UPDATE "users" SET "name"='Foo bar',"updated_at"='2021-10-18 19:50:09.438' WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) // UpdateColumn sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Where("id = ?", 100).UpdateColumn("name", "Foo bar") }) assertEqualSQL(t, `UPDATE "users" SET "name"='Foo bar' WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) // UpdateColumns sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&User{}).Where("id = ?", 100).UpdateColumns(User{Name: "Foo", Age: 100}) }) assertEqualSQL(t, `UPDATE "users" SET "name"='Foo',"age"=100 WHERE id = 100 AND "users"."deleted_at" IS NULL`, sql) // after model changed if DB.Statement.DryRun || DB.DryRun { t.Fatal("Failed expect DB.DryRun and DB.Statement.ToSQL to be false") } // UpdateColumns sql = DB.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Raw("SELECT * FROM users ?", clause.OrderBy{ Columns: []clause.OrderByColumn{{Column: clause.Column{Name: "id", Raw: true}, Desc: true}}, }) }) assertEqualSQL(t, `SELECT * FROM users ORDER BY id DESC`, sql) } // assertEqualSQL for assert that the sql is equal, this method will ignore quote, and dialect specials. func assertEqualSQL(t *testing.T, expected string, actually string) { t.Helper() // replace SQL quote, convert into postgresql like "" expected = replaceQuoteInSQL(expected) actually = replaceQuoteInSQL(actually) // ignore updated_at value, because it's generated in Gorm internal, can't to mock value on update. 
updatedAtRe := regexp.MustCompile(`(?i)"updated_at"=".+?"`) actually = updatedAtRe.ReplaceAllString(actually, `"updated_at"=?`) expected = updatedAtRe.ReplaceAllString(expected, `"updated_at"=?`) // ignore RETURNING "id" (only in PostgreSQL) returningRe := regexp.MustCompile(`(?i)RETURNING "id"`) actually = returningRe.ReplaceAllString(actually, ``) expected = returningRe.ReplaceAllString(expected, ``) actually = strings.TrimSpace(actually) expected = strings.TrimSpace(expected) if actually != expected { t.Fatalf("\nexpected: %s\nactually: %s", expected, actually) } } func replaceQuoteInSQL(sql string) string { // convert single quote into double quote sql = strings.ReplaceAll(sql, `'`, `"`) // convert dialect special quote into double quote switch DB.Dialector.Name() { case "postgres", "gaussdb": sql = strings.ReplaceAll(sql, `"`, `"`) case "mysql", "sqlite": sql = strings.ReplaceAll(sql, "`", `"`) case "sqlserver": sql = strings.ReplaceAll(sql, `'`, `"`) } return sql } ================================================ FILE: tests/submodel_test.go ================================================ package tests_test import ( "testing" "gorm.io/gorm" ) type Man struct { ID int Age int Name string Detail string } // Panic-safe BeforeUpdate hook that checks for Changed("age") func (m *Man) BeforeUpdate(tx *gorm.DB) (err error) { if !tx.Statement.Changed("age") { return nil } return nil } func TestSubModel(t *testing.T) { man := Man{Age: 18, Name: "random-name"} if err := DB.Create(&man).Error; err != nil { t.Fatalf("unexpected error: %v", err) } if err := DB.Model(&man).Where("id = ?", man.ID).Updates(struct { Age int }{Age: 20}).Error; err != nil { t.Fatalf("unexpected error: %v", err) } var result = struct { ID int Age int }{} if err := DB.Model(&man).Where("id = ?", man.ID).Find(&result).Error; err != nil { t.Fatalf("unexpected error: %v", err) } if result.ID != man.ID || result.Age != 20 { t.Fatalf("expected ID %d and Age 20, got ID %d and age", result.ID, 
result.Age) } } ================================================ FILE: tests/table_test.go ================================================ package tests_test import ( "regexp" "sync" "testing" "gorm.io/driver/gaussdb" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/schema" "gorm.io/gorm/utils/tests" . "gorm.io/gorm/utils/tests" ) type UserWithTable struct { gorm.Model Name string } func (UserWithTable) TableName() string { return "gorm.user" } func TestTable(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true}) r := dryDB.Table("`user`").Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM `user`").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Table("user as u").Select("name").Find(&User{}).Statement if !regexp.MustCompile("SELECT .name. FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Table("`people`").Table("`user`").Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM `user`").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Table("people as p").Table("user as u").Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Table("people as p").Table("user").Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM .user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Table("gorm.people").Table("user").Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM .user. WHERE .user.\\..deleted_at. 
IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Select("name").Find(&UserWithTable{}).Statement if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Create(&UserWithTable{}).Statement if DB.Dialector.Name() != "sqlite" { if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } else { if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name"), DB.Model(&Pet{}).Select("name")).Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. 
IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Where("name = ?", 1).Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name").Where("name = ?", 2), DB.Model(&Pet{}).Where("name = ?", 4).Select("name")).Where("name = ?", 3).Find(&User{}).Statement if !regexp.MustCompile("SELECT \\* FROM \\(SELECT .name. FROM .users. WHERE name = .+ AND .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE name = .+ AND .pets.\\..deleted_at. IS NULL\\) as p WHERE name = .+ AND name = .+ AND .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) } func TestTableWithAllFields(t *testing.T) { dryDB := DB.Session(&gorm.Session{DryRun: true, QueryFields: true}) userQuery := "SELECT .*user.*id.*user.*created_at.*user.*updated_at.*user.*deleted_at.*user.*name.*user.*age" + ".*user.*birthday.*user.*company_id.*user.*manager_id.*user.*active.* " r := dryDB.Table("`user`").Find(&User{}).Statement if !regexp.MustCompile(userQuery + "FROM `user`").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Table("user as u").Select("name").Find(&User{}).Statement if !regexp.MustCompile("SELECT .name. FROM user as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Table("gorm.user").Select("name").Find(&User{}).Statement if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Select("name").Find(&UserWithTable{}).Statement if !regexp.MustCompile("SELECT .name. FROM .gorm.\\..user. 
WHERE .user.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Create(&UserWithTable{}).Statement if DB.Dialector.Name() != "sqlite" { if !regexp.MustCompile(`INSERT INTO .gorm.\..user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } else { if !regexp.MustCompile(`INSERT INTO .user. (.*name.*) VALUES (.*)`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } userQueryCharacter := "SELECT .*u.*id.*u.*created_at.*u.*updated_at.*u.*deleted_at.*u.*name.*u.*age.*u.*birthday" + ".*u.*company_id.*u.*manager_id.*u.*active.* " r = dryDB.Table("(?) as u", DB.Model(&User{}).Select("name")).Find(&User{}).Statement if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name"), DB.Model(&Pet{}).Select("name")).Find(&User{}).Statement if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. FROM .pets. WHERE .pets.\\..deleted_at. IS NULL\\) as p WHERE .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } r = dryDB.Where("name = ?", 1).Table("(?) as u, (?) as p", DB.Model(&User{}).Select("name").Where("name = ?", 2), DB.Model(&Pet{}).Where("name = ?", 4).Select("name")).Where("name = ?", 3).Find(&User{}).Statement if !regexp.MustCompile(userQueryCharacter + "FROM \\(SELECT .name. FROM .users. WHERE name = .+ AND .users.\\..deleted_at. IS NULL\\) as u, \\(SELECT .name. 
FROM .pets. WHERE name = .+ AND .pets.\\..deleted_at. IS NULL\\) as p WHERE name = .+ AND name = .+ AND .u.\\..deleted_at. IS NULL").MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } AssertEqual(t, r.Statement.Vars, []interface{}{2, 4, 1, 3}) } type UserWithTableNamer struct { gorm.Model Name string } func (UserWithTableNamer) TableName(namer schema.Namer) string { return namer.TableName("user") } func TestTableWithNamer(t *testing.T) { db, _ := gorm.Open(tests.DummyDialector{}, &gorm.Config{ NamingStrategy: schema.NamingStrategy{ TablePrefix: "t_", }, }) sql := db.ToSQL(func(tx *gorm.DB) *gorm.DB { return tx.Model(&UserWithTableNamer{}).Find(&UserWithTableNamer{}) }) if !regexp.MustCompile("SELECT \\* FROM `t_users`").MatchString(sql) { t.Errorf("Table with namer, got %v", sql) } } func TestPostgresTableWithIdentifierLength(t *testing.T) { if DB.Dialector.Name() != "postgres" { return } type LongString struct { ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"` } t.Run("default", func(t *testing.T) { db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{}) user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) if err != nil { t.Fatalf("failed to parse user unique, got error %v", err) } constraints := user.ParseUniqueConstraints() if len(constraints) != 1 { t.Fatalf("failed to find unique constraint, got %v", constraints) } for key := range constraints { if len(key) != 63 { t.Errorf("failed to find unique constraint, got %v", constraints) } } }) t.Run("naming strategy", func(t *testing.T) { db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{ NamingStrategy: schema.NamingStrategy{}, }) user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) if err != nil { t.Fatalf("failed to parse user unique, got error %v", err) } constraints := user.ParseUniqueConstraints() if len(constraints) != 1 { t.Fatalf("failed to find 
unique constraint, got %v", constraints) } for key := range constraints { if len(key) != 63 { t.Errorf("failed to find unique constraint, got %v", constraints) } } }) t.Run("namer", func(t *testing.T) { uname := "custom_unique_name" db, _ := gorm.Open(postgres.Open(postgresDSN), &gorm.Config{ NamingStrategy: mockUniqueNamingStrategy{ UName: uname, }, }) user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) if err != nil { t.Fatalf("failed to parse user unique, got error %v", err) } constraints := user.ParseUniqueConstraints() if len(constraints) != 1 { t.Fatalf("failed to find unique constraint, got %v", constraints) } for key := range constraints { if key != uname { t.Errorf("failed to find unique constraint, got %v", constraints) } } }) } func TestGaussDBTableWithIdentifierLength(t *testing.T) { if DB.Dialector.Name() != "gaussdb" { return } type LongString struct { ThisIsAVeryVeryVeryVeryVeryVeryVeryVeryVeryLongString string `gorm:"unique"` } t.Run("default", func(t *testing.T) { db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{}) user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) if err != nil { t.Fatalf("failed to parse user unique, got error %v", err) } constraints := user.ParseUniqueConstraints() if len(constraints) != 1 { t.Fatalf("failed to find unique constraint, got %v", constraints) } for key := range constraints { if len(key) != 63 { t.Errorf("failed to find unique constraint, got %v", constraints) } } }) t.Run("naming strategy", func(t *testing.T) { db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{ NamingStrategy: schema.NamingStrategy{}, }) user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) if err != nil { t.Fatalf("failed to parse user unique, got error %v", err) } constraints := user.ParseUniqueConstraints() if len(constraints) != 1 { t.Fatalf("failed to find unique constraint, got %v", constraints) } for key := range constraints { if len(key) != 
63 { t.Errorf("failed to find unique constraint, got %v", constraints) } } }) t.Run("namer", func(t *testing.T) { uname := "custom_unique_name" db, _ := gorm.Open(gaussdb.Open(gaussdbDSN), &gorm.Config{ NamingStrategy: mockUniqueNamingStrategy{ UName: uname, }, }) user, err := schema.Parse(&LongString{}, &sync.Map{}, db.Config.NamingStrategy) if err != nil { t.Fatalf("failed to parse user unique, got error %v", err) } constraints := user.ParseUniqueConstraints() if len(constraints) != 1 { t.Fatalf("failed to find unique constraint, got %v", constraints) } for key := range constraints { if key != uname { t.Errorf("failed to find unique constraint, got %v", constraints) } } }) } type mockUniqueNamingStrategy struct { UName string schema.NamingStrategy } func (a mockUniqueNamingStrategy) UniqueName(table, column string) string { return a.UName } ================================================ FILE: tests/tests_all.sh ================================================ #!/bin/bash -e dialects=("sqlite" "mysql" "postgres" "gaussdb" "sqlserver" "tidb") if [[ $(pwd) == *"gorm/tests"* ]]; then cd .. fi if [ -d tests ] then cd tests go get -u -t ./... go mod download go mod tidy cd .. fi # SqlServer for Mac M1 if [[ -z $GITHUB_ACTION && -d tests ]]; then cd tests if [[ $(uname -a) == *" arm64" ]]; then MSSQL_IMAGE=mcr.microsoft.com/azure-sql-edge docker compose up -d --wait go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest || true for query in \ "IF DB_ID('gorm') IS NULL CREATE DATABASE gorm" \ "IF SUSER_ID (N'gorm') IS NULL CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86';" \ "IF USER_ID (N'gorm') IS NULL CREATE USER gorm FROM LOGIN gorm; ALTER SERVER ROLE sysadmin ADD MEMBER [gorm];" do SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 -Q "$query" > /dev/null || true done else MSSQL_IMAGE=mcr.microsoft.com/mssql/server docker compose up -d --wait fi cd .. 
fi for dialect in "${dialects[@]}" ; do if [ "$GORM_DIALECT" = "" ] || [ "$GORM_DIALECT" = "${dialect}" ] then echo "testing ${dialect}..." if [ "$GORM_VERBOSE" = "" ] then GORM_DIALECT=${dialect} go test -race -count=1 ./... if [ -d tests ] then cd tests GORM_DIALECT=${dialect} go test -race -count=1 ./... cd .. fi else GORM_DIALECT=${dialect} go test -race -count=1 -v ./... if [ -d tests ] then cd tests GORM_DIALECT=${dialect} go test -race -count=1 -v ./... cd .. fi fi fi done ================================================ FILE: tests/tests_test.go ================================================ //go:debug x509negativeserial=1 package tests_test import ( "log" "math/rand" "os" "path/filepath" "time" "gorm.io/driver/gaussdb" "gorm.io/driver/mysql" "gorm.io/driver/postgres" "gorm.io/driver/sqlite" "gorm.io/driver/sqlserver" "gorm.io/gorm" "gorm.io/gorm/logger" . "gorm.io/gorm/utils/tests" ) var DB *gorm.DB var ( mysqlDSN = "gorm:gorm@tcp(localhost:9910)/gorm?charset=utf8&parseTime=True&loc=Local" postgresDSN = "user=gorm password=gorm dbname=gorm host=localhost port=9920 sslmode=disable TimeZone=Asia/Shanghai" gaussdbDSN = "user=gaussdb password=Gaussdb@123 dbname=gorm host=localhost port=9950 sslmode=disable TimeZone=Asia/Shanghai" sqlserverDSN = "sqlserver://sa:LoremIpsum86@localhost:9930?database=master" tidbDSN = "root:@tcp(localhost:9940)/test?charset=utf8&parseTime=True&loc=Local" ) func init() { var err error if DB, err = OpenTestConnection(&gorm.Config{}); err != nil { log.Printf("failed to connect database, got error %v", err) os.Exit(1) } else { sqlDB, err := DB.DB() if err != nil { log.Printf("failed to connect database, got error %v", err) os.Exit(1) } err = sqlDB.Ping() if err != nil { log.Printf("failed to ping sqlDB, got error %v", err) os.Exit(1) } RunMigrations() } } func OpenTestConnection(cfg *gorm.Config) (db *gorm.DB, err error) { dbDSN := os.Getenv("GORM_DSN") switch os.Getenv("GORM_DIALECT") { case "mysql": log.Println("testing mysql...") 
if dbDSN == "" { dbDSN = mysqlDSN } db, err = gorm.Open(mysql.Open(dbDSN), cfg) case "postgres": log.Println("testing postgres...") if dbDSN == "" { dbDSN = postgresDSN } db, err = gorm.Open(postgres.New(postgres.Config{ DSN: dbDSN, PreferSimpleProtocol: true, }), cfg) case "gaussdb": log.Println("testing gaussdb...") if dbDSN == "" { dbDSN = gaussdbDSN } db, err = gorm.Open(gaussdb.New(gaussdb.Config{ DSN: dbDSN, PreferSimpleProtocol: true, }), cfg) case "sqlserver": // go install github.com/microsoft/go-sqlcmd/cmd/sqlcmd@latest // SQLCMDPASSWORD=LoremIpsum86 sqlcmd -U sa -S localhost:9930 // CREATE DATABASE gorm; // GO // CREATE LOGIN gorm WITH PASSWORD = 'LoremIpsum86'; // CREATE USER gorm FROM LOGIN gorm; // ALTER SERVER ROLE sysadmin ADD MEMBER [gorm]; // GO log.Println("testing sqlserver...") if dbDSN == "" { dbDSN = sqlserverDSN } db, err = gorm.Open(sqlserver.Open(dbDSN), cfg) case "tidb": log.Println("testing tidb...") if dbDSN == "" { dbDSN = tidbDSN } db, err = gorm.Open(mysql.Open(dbDSN), cfg) default: log.Println("testing sqlite3...") db, err = gorm.Open(sqlite.Open(filepath.Join(os.TempDir(), "gorm.db")), cfg) if err == nil { db.Exec("PRAGMA foreign_keys = ON") } } if err != nil { return } if debug := os.Getenv("DEBUG"); debug == "true" { db.Logger = db.Logger.LogMode(logger.Info) } else if debug == "false" { db.Logger = db.Logger.LogMode(logger.Silent) } return } func RunMigrations() { var err error allModels := []interface{}{&User{}, &Account{}, &Pet{}, &Company{}, &Toy{}, &Language{}, &Coupon{}, &CouponProduct{}, &Order{}, &Parent{}, &Child{}, &Tools{}} rand.Seed(time.Now().UnixNano()) rand.Shuffle(len(allModels), func(i, j int) { allModels[i], allModels[j] = allModels[j], allModels[i] }) DB.Migrator().DropTable("user_friends", "user_speaks") if err = DB.Migrator().DropTable(allModels...); err != nil { log.Printf("Failed to drop table, got error %v\n", err) os.Exit(1) } if err = DB.AutoMigrate(allModels...); err != nil { log.Printf("Failed to auto 
migrate, but got error %v\n", err) os.Exit(1) } for _, m := range allModels { if !DB.Migrator().HasTable(m) { log.Printf("Failed to create table for %#v\n", m) os.Exit(1) } } } ================================================ FILE: tests/tracer_test.go ================================================ package tests_test import ( "context" "time" "gorm.io/gorm/logger" ) type Tracer struct { Logger logger.Interface Test func(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) } func (S Tracer) LogMode(level logger.LogLevel) logger.Interface { return S.Logger.LogMode(level) } func (S Tracer) Info(ctx context.Context, s string, i ...interface{}) { S.Logger.Info(ctx, s, i...) } func (S Tracer) Warn(ctx context.Context, s string, i ...interface{}) { S.Logger.Warn(ctx, s, i...) } func (S Tracer) Error(ctx context.Context, s string, i ...interface{}) { S.Logger.Error(ctx, s, i...) } func (S Tracer) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) { S.Logger.Trace(ctx, begin, fc, err) S.Test(ctx, begin, fc, err) } ================================================ FILE: tests/transaction_test.go ================================================ package tests_test import ( "context" "errors" "testing" "time" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) func TestTransaction(t *testing.T) { tx := DB.Begin() user := *GetUser("transaction", Config{}) if err := tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise, but got %v", err) } if err := tx.First(&User{}, "name = ?", "transaction").Error; err != nil { t.Fatalf("Should find saved record, but got %v", err) } user1 := *GetUser("transaction1-1", Config{}) if err := tx.Save(&user1).Error; err != nil { t.Fatalf("No error should raise, but got %v", err) } if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { t.Fatalf("Should find saved record, but got %v", err) } if sqlTx, ok := tx.Statement.ConnPool.(gorm.TxCommitter); !ok || sqlTx == nil { t.Fatalf("Should return the underlying sql.Tx") } tx.Rollback() if err := DB.First(&User{}, "name = ?", "transaction").Error; err == nil { t.Fatalf("Should not find record after rollback, but got %v", err) } txDB := DB.Where("fake_name = ?", "fake_name") tx2 := txDB.Session(&gorm.Session{NewDB: true}).Begin() user2 := *GetUser("transaction-2", Config{}) if err := tx2.Save(&user2).Error; err != nil { t.Fatalf("No error should raise, but got %v", err) } if err := tx2.First(&User{}, "name = ?", "transaction-2").Error; err != nil { t.Fatalf("Should find saved record, but got %v", err) } tx2.Commit() if err := DB.First(&User{}, "name = ?", "transaction-2").Error; err != nil { t.Fatalf("Should be able to find committed record, but got %v", err) } t.Run("this is test nested transaction and prepareStmt coexist case", func(t *testing.T) { // enable prepare statement tx3 := DB.Session(&gorm.Session{PrepareStmt: true}) if err := tx3.Transaction(func(tx4 *gorm.DB) error { // nested transaction return tx4.Transaction(func(tx5 *gorm.DB) error { return tx5.First(&User{}, "name = ?", "transaction-2").Error }) }); err != nil { t.Fatalf("prepare statement and nested transaction coexist: %v", err) } }) } func TestCancelTransaction(t *testing.T) { ctx := context.Background() ctx, cancelFunc 
:= context.WithCancel(ctx) cancelFunc() user := *GetUser("cancel_transaction", Config{}) DB.Create(&user) err := DB.WithContext(ctx).Transaction(func(tx *gorm.DB) error { var result User tx.First(&result, user.ID) return nil }) if err == nil { t.Fatalf("Transaction should get error when using cancelled context") } } func TestTransactionWithBlock(t *testing.T) { assertPanic := func(f func()) { defer func() { if r := recover(); r == nil { t.Fatalf("The code did not panic") } }() f() } // rollback err := DB.Transaction(func(tx *gorm.DB) error { user := *GetUser("transaction-block", Config{}) if err := tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise") } if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } return errors.New("the error message") }) if err != nil && err.Error() != "the error message" { t.Fatalf("Transaction return error will equal the block returns error") } if err := DB.First(&User{}, "name = ?", "transaction-block").Error; err == nil { t.Fatalf("Should not find record after rollback") } // commit DB.Transaction(func(tx *gorm.DB) error { user := *GetUser("transaction-block-2", Config{}) if err := tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise") } if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } return nil }) if err := DB.First(&User{}, "name = ?", "transaction-block-2").Error; err != nil { t.Fatalf("Should be able to find committed record") } // panic will rollback assertPanic(func() { DB.Transaction(func(tx *gorm.DB) error { user := *GetUser("transaction-block-3", Config{}) if err := tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise") } if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } panic("force panic") }) }) if err := DB.First(&User{}, "name = ?", "transaction-block-3").Error; err == nil { t.Fatalf("Should not find 
record after panic rollback") } } func TestTransactionRaiseErrorOnRollbackAfterCommit(t *testing.T) { tx := DB.Begin() user := User{Name: "transaction"} if err := tx.Save(&user).Error; err != nil { t.Fatalf("No error should raise") } if err := tx.Commit().Error; err != nil { t.Fatalf("Commit should not raise error") } if err := tx.Rollback().Error; err == nil { t.Fatalf("Rollback after commit should raise error") } } func TestTransactionWithSavePoint(t *testing.T) { tx := DB.Begin() user := *GetUser("transaction-save-point", Config{}) tx.Create(&user) if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := tx.SavePoint("save_point1").Error; err != nil { t.Fatalf("Failed to save point, got error %v", err) } user1 := *GetUser("transaction-save-point-1", Config{}) tx.Create(&user1) if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := tx.RollbackTo("save_point1").Error; err != nil { t.Fatalf("Failed to save point, got error %v", err) } if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { t.Fatalf("Should not find rollbacked record") } if err := tx.SavePoint("save_point2").Error; err != nil { t.Fatalf("Failed to save point, got error %v", err) } user2 := *GetUser("transaction-save-point-2", Config{}) tx.Create(&user2) if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := tx.Commit().Error; err != nil { t.Fatalf("Failed to commit, got error %v", err) } if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { t.Fatalf("Should not find rollbacked record") } if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should find saved record") } } func TestNestedTransactionWithBlock(t *testing.T) { var ( user = 
*GetUser("transaction-nested", Config{}) user1 = *GetUser("transaction-nested-1", Config{}) user2 = *GetUser("transaction-nested-2", Config{}) ) if err := DB.Transaction(func(tx *gorm.DB) error { tx.Create(&user) if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := tx.Transaction(func(tx1 *gorm.DB) error { tx1.Create(&user1) if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { t.Fatalf("Should find saved record") } return errors.New("rollback") }); err == nil { t.Fatalf("nested transaction should returns error") } if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { t.Fatalf("Should not find rollbacked record") } if err := tx.Transaction(func(tx2 *gorm.DB) error { tx2.Create(&user2) if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should find saved record") } return nil }); err != nil { t.Fatalf("nested transaction returns error: %v", err) } if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should find saved record") } return nil }); err != nil { t.Fatalf("no error should return, but got %v", err) } if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { t.Fatalf("Should not find rollbacked record") } if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should find saved record") } } func TestDeeplyNestedTransactionWithBlockAndWrappedCallback(t *testing.T) { transaction := func(ctx context.Context, db *gorm.DB, callback func(ctx context.Context, db *gorm.DB) error) error { return db.WithContext(ctx).Transaction(func(tx *gorm.DB) error { return callback(ctx, tx) }) } var ( user = *GetUser("transaction-nested", Config{}) user1 = *GetUser("transaction-nested-1", Config{}) user2 = *GetUser("transaction-nested-2", Config{}) ) if err := 
transaction(context.Background(), DB, func(ctx context.Context, tx *gorm.DB) error { tx.Create(&user) if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := transaction(ctx, tx, func(ctx context.Context, tx1 *gorm.DB) error { tx1.Create(&user1) if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := transaction(ctx, tx1, func(ctx context.Context, tx2 *gorm.DB) error { tx2.Create(&user2) if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should find saved record") } return errors.New("inner rollback") }); err == nil { t.Fatalf("nested transaction has no error") } return errors.New("rollback") }); err == nil { t.Fatalf("nested transaction should returns error") } if err := tx.First(&User{}, "name = ?", user1.Name).Error; err == nil { t.Fatalf("Should not find rollbacked record") } if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should find saved record") } return nil }); err != nil { t.Fatalf("no error should return, but got %v", err) } if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := DB.First(&User{}, "name = ?", user1.Name).Error; err == nil { t.Fatalf("Should not find rollbacked parent record") } if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should not find rollbacked nested record") } } func TestDisabledNestedTransaction(t *testing.T) { var ( user = *GetUser("transaction-nested", Config{}) user1 = *GetUser("transaction-nested-1", Config{}) user2 = *GetUser("transaction-nested-2", Config{}) ) if err := DB.Session(&gorm.Session{DisableNestedTransaction: true}).Transaction(func(tx *gorm.DB) error { tx.Create(&user) if err := tx.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := tx.Transaction(func(tx1 *gorm.DB) 
error { tx1.Create(&user1) if err := tx1.First(&User{}, "name = ?", user1.Name).Error; err != nil { t.Fatalf("Should find saved record") } return errors.New("rollback") }); err == nil { t.Fatalf("nested transaction should returns error") } if err := tx.First(&User{}, "name = ?", user1.Name).Error; err != nil { t.Fatalf("Should not rollback record if disabled nested transaction support") } if err := tx.Transaction(func(tx2 *gorm.DB) error { tx2.Create(&user2) if err := tx2.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should find saved record") } return nil }); err != nil { t.Fatalf("nested transaction returns error: %v", err) } if err := tx.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should find saved record") } return nil }); err != nil { t.Fatalf("no error should return, but got %v", err) } if err := DB.First(&User{}, "name = ?", user.Name).Error; err != nil { t.Fatalf("Should find saved record") } if err := DB.First(&User{}, "name = ?", user1.Name).Error; err != nil { t.Fatalf("Should not rollback record if disabled nested transaction support") } if err := DB.First(&User{}, "name = ?", user2.Name).Error; err != nil { t.Fatalf("Should find saved record") } } func TestTransactionOnClosedConn(t *testing.T) { DB, err := OpenTestConnection(&gorm.Config{}) if err != nil { t.Fatalf("failed to connect database, got error %v", err) } rawDB, _ := DB.DB() rawDB.Close() if err := DB.Transaction(func(tx *gorm.DB) error { return nil }); err == nil { t.Errorf("should returns error when commit with closed conn, got error %v", err) } if err := DB.Session(&gorm.Session{PrepareStmt: true}).Transaction(func(tx *gorm.DB) error { return nil }); err == nil { t.Errorf("should returns error when commit with closed conn, got error %v", err) } } func TestTransactionWithHooks(t *testing.T) { user := GetUser("tTestTransactionWithHooks", Config{Account: true}) DB.Create(&user) var err error err = DB.Transaction(func(tx *gorm.DB) error { return 
tx.Model(&User{}).Limit(1).Transaction(func(tx2 *gorm.DB) error { return tx2.Scan(&User{}).Error }) }) if err != nil { t.Error(err) } // method with hooks err = DB.Transaction(func(tx1 *gorm.DB) error { // callMethod do tx2 := tx1.Find(&User{}).Session(&gorm.Session{NewDB: true}) // trx in hooks return tx2.Transaction(func(tx3 *gorm.DB) error { return tx3.Where("user_id", user.ID).Delete(&Account{}).Error }) }) if err != nil { t.Error(err) } } func TestTransactionWithDefaultTimeout(t *testing.T) { db, err := OpenTestConnection(&gorm.Config{DefaultTransactionTimeout: 2 * time.Second}) if err != nil { t.Fatalf("failed to connect database, got error %v", err) } tx := db.Begin() time.Sleep(3 * time.Second) if err = tx.Find(&User{}).Error; err == nil { t.Errorf("should return error when transaction timeout, got error %v", err) } } ================================================ FILE: tests/update_belongs_to_test.go ================================================ package tests_test import ( "testing" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) func TestUpdateBelongsTo(t *testing.T) { user := *GetUser("update-belongs-to", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } user.Company = Company{Name: "company-belongs-to-association"} user.Manager = &User{Name: "manager-belongs-to-association"} if err := DB.Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user2 User DB.Preload("Company").Preload("Manager").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) user.Company.Name += "new" user.Manager.Name += "new" if err := DB.Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user3 User DB.Preload("Company").Preload("Manager").Find(&user3, "id = ?", user.ID) CheckUser(t, user2, user3) if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user4 User DB.Preload("Company").Preload("Manager").Find(&user4, "id = ?", user.ID) CheckUser(t, user4, user) user.Company.Name += "new2" user.Manager.Name += "new2" if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Select("`Company`").Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user5 User DB.Preload("Company").Preload("Manager").Find(&user5, "id = ?", user.ID) if user5.Manager.Name != user4.Manager.Name { t.Errorf("should not update user's manager") } else { user.Manager.Name = user4.Manager.Name } CheckUser(t, user, user5) } ================================================ FILE: tests/update_has_many_test.go ================================================ package tests_test import ( "testing" "gorm.io/gorm" . 
"gorm.io/gorm/utils/tests" ) func TestUpdateHasManyAssociations(t *testing.T) { user := *GetUser("update-has-many", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } user.Pets = []*Pet{{Name: "pet1"}, {Name: "pet2"}} if err := DB.Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user2 User DB.Preload("Pets").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) for _, pet := range user.Pets { pet.Name += "new" } if err := DB.Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user3 User DB.Preload("Pets").Find(&user3, "id = ?", user.ID) CheckUser(t, user2, user3) if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user4 User DB.Preload("Pets").Find(&user4, "id = ?", user.ID) CheckUser(t, user4, user) t.Run("Polymorphic", func(t *testing.T) { user := *GetUser("update-has-many", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } user.Toys = []Toy{{Name: "toy1"}, {Name: "toy2"}} if err := DB.Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user2 User DB.Preload("Toys").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) for idx := range user.Toys { user.Toys[idx].Name += "new" } if err := DB.Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user3 User DB.Preload("Toys").Find(&user3, "id = ?", user.ID) CheckUser(t, user2, user3) if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user4 User DB.Preload("Toys").Find(&user4, "id = ?", user.ID) CheckUser(t, user4, user) }) } ================================================ FILE: tests/update_has_one_test.go ================================================ 
package tests_test

import (
	"database/sql"
	"testing"
	"time"

	"gorm.io/gorm"
	. "gorm.io/gorm/utils/tests"
)

// TestUpdateHasOne checks Save/Updates on has-one associations. It expects
// that a plain Save does not rewrite an existing Account row, that a
// FullSaveAssociations session does (and bumps its UpdatedAt), that the same
// holds for a polymorphic has-one (Pet.Toy), and that a field tagged
// `<-:create` is not overwritten by a full-association Updates.
func TestUpdateHasOne(t *testing.T) {
	user := *GetUser("update-has-one", Config{})

	if err := DB.Create(&user).Error; err != nil {
		t.Fatalf("errors happened when create: %v", err)
	}

	user.Account = Account{Number: "account-has-one-association"}

	if err := DB.Save(&user).Error; err != nil {
		t.Fatalf("errors happened when update: %v", err)
	}

	var user2 User
	DB.Preload("Account").Find(&user2, "id = ?", user.ID)
	CheckUser(t, user2, user)

	user.Account.Number += "new"
	if err := DB.Save(&user).Error; err != nil {
		t.Fatalf("errors happened when update: %v", err)
	}

	var user3 User
	DB.Preload("Account").Find(&user3, "id = ?", user.ID)
	CheckUser(t, user2, user3)

	// The plain Save above is asserted not to have changed the account row, so
	// user2 still carries its stored timestamp. RFC3339 formatting only has
	// one-second resolution, hence the sleep before the full save.
	lastUpdatedAt := user2.Account.UpdatedAt
	time.Sleep(time.Second)

	if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil {
		t.Fatalf("errors happened when update: %v", err)
	}

	var user4 User
	DB.Preload("Account").Find(&user4, "id = ?", user.ID)

	// Report the timestamp that was actually compared (user4, not user3).
	if lastUpdatedAt.Format(time.RFC3339) == user4.Account.UpdatedAt.Format(time.RFC3339) {
		t.Fatalf("updated at should be updated, but not, old: %v, new %v", lastUpdatedAt.Format(time.RFC3339), user4.Account.UpdatedAt.Format(time.RFC3339))
	}
	user.Account.UpdatedAt = user4.Account.UpdatedAt
	CheckUser(t, user4, user)

	t.Run("Polymorphic", func(t *testing.T) {
		pet := Pet{Name: "create"}

		if err := DB.Create(&pet).Error; err != nil {
			t.Fatalf("errors happened when create: %v", err)
		}

		pet.Toy = Toy{Name: "Update-HasOneAssociation-Polymorphic"}

		if err := DB.Save(&pet).Error; err != nil {
			t.Fatalf("errors happened when update: %v", err)
		}

		var pet2 Pet
		DB.Preload("Toy").Find(&pet2, "id = ?", pet.ID)
		CheckPet(t, pet2, pet)

		pet.Toy.Name += "new"
		if err := DB.Save(&pet).Error; err != nil {
			t.Fatalf("errors happened when update: %v", err)
		}

		var pet3 Pet
		DB.Preload("Toy").Find(&pet3, "id = ?", pet.ID)
		CheckPet(t, pet2, pet3)

		if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&pet).Error; err != nil {
			t.Fatalf("errors happened when update: %v", err)
		}

		var pet4 Pet
		DB.Preload("Toy").Find(&pet4, "id = ?", pet.ID)
		CheckPet(t, pet4, pet)
	})

	t.Run("Restriction", func(t *testing.T) {
		type CustomizeAccount struct {
			gorm.Model
			UserID  sql.NullInt64
			Number  string `gorm:"<-:create"`
			Number2 string
		}

		type CustomizeUser struct {
			gorm.Model
			Name    string
			Account CustomizeAccount `gorm:"foreignkey:UserID"`
		}

		DB.Migrator().DropTable(&CustomizeUser{})
		DB.Migrator().DropTable(&CustomizeAccount{})

		if err := DB.AutoMigrate(&CustomizeUser{}); err != nil {
			t.Fatalf("failed to migrate, got error: %v", err)
		}
		if err := DB.AutoMigrate(&CustomizeAccount{}); err != nil {
			t.Fatalf("failed to migrate, got error: %v", err)
		}

		number := "number-has-one-associations"
		cusUser := CustomizeUser{
			Name: "update-has-one-associations",
			Account: CustomizeAccount{
				Number:  number,
				Number2: number,
			},
		}

		if err := DB.Create(&cusUser).Error; err != nil {
			t.Fatalf("errors happened when create: %v", err)
		}
		cusUser.Account.Number += "-update"
		cusUser.Account.Number2 += "-update"
		if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Updates(&cusUser).Error; err != nil {
			t.Fatalf("errors happened when update: %v", err)
		}

		// Number is create-only (`<-:create`), so it must keep its original
		// value, while the unrestricted Number2 picks up the update.
		var account2 CustomizeAccount
		DB.Find(&account2, "user_id = ?", cusUser.ID)
		AssertEqual(t, account2.Number, number)
		AssertEqual(t, account2.Number2, cusUser.Account.Number2)
	})
}

================================================
FILE: tests/update_many2many_test.go
================================================
package tests_test

import (
	"testing"

	"gorm.io/gorm"
	.
"gorm.io/gorm/utils/tests" ) func TestUpdateMany2ManyAssociations(t *testing.T) { user := *GetUser("update-many2many", Config{}) if err := DB.Create(&user).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } user.Languages = []Language{{Code: "zh-CN", Name: "Chinese"}, {Code: "en", Name: "English"}} for _, lang := range user.Languages { DB.Create(&lang) } user.Friends = []*User{{Name: "friend-1"}, {Name: "friend-2"}} if err := DB.Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user2 User DB.Preload("Languages").Preload("Friends").Find(&user2, "id = ?", user.ID) CheckUser(t, user2, user) for idx := range user.Friends { user.Friends[idx].Name += "new" } for idx := range user.Languages { user.Languages[idx].Name += "new" } if err := DB.Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user3 User DB.Preload("Languages").Preload("Friends").Find(&user3, "id = ?", user.ID) CheckUser(t, user2, user3) if err := DB.Session(&gorm.Session{FullSaveAssociations: true}).Save(&user).Error; err != nil { t.Fatalf("errors happened when update: %v", err) } var user4 User DB.Preload("Languages").Preload("Friends").Find(&user4, "id = ?", user.ID) CheckUser(t, user4, user) } ================================================ FILE: tests/update_test.go ================================================ package tests_test import ( "errors" "regexp" "sort" "strings" "testing" "time" "gorm.io/gorm" "gorm.io/gorm/clause" "gorm.io/gorm/utils" . 
"gorm.io/gorm/utils/tests"
)

// TestUpdate covers Update, Updates (map and struct) and Save on a single
// record. After each step it checks that UpdatedAt moved and that the
// neighbouring records were left untouched.
func TestUpdate(t *testing.T) {
	var (
		users = []*User{
			GetUser("update-1", Config{}),
			GetUser("update-2", Config{}),
			GetUser("update-3", Config{}),
		}
		user          = users[1]
		lastUpdatedAt time.Time
	)

	checkUpdatedAtChanged := func(name string, n time.Time) {
		if n.UnixNano() == lastUpdatedAt.UnixNano() {
			t.Errorf("%v: user's updated at should be changed, but got %v, was %v", name, n, lastUpdatedAt)
		}
		lastUpdatedAt = n
	}

	checkOtherData := func(name string) {
		var first, last User
		if err := DB.Where("id = ?", users[0].ID).First(&first).Error; err != nil {
			t.Errorf("errors happened when query before user: %v", err)
		}
		CheckUser(t, first, *users[0])

		if err := DB.Where("id = ?", users[2].ID).First(&last).Error; err != nil {
			t.Errorf("errors happened when query after user: %v", err)
		}
		CheckUser(t, last, *users[2])
	}

	if err := DB.Create(&users).Error; err != nil {
		t.Fatalf("errors happened when create: %v", err)
	} else if user.ID == 0 {
		t.Fatalf("user's primary value should not zero, %v", user.ID)
	} else if user.UpdatedAt.IsZero() {
		t.Fatalf("user's updated at should not zero, %v", user.UpdatedAt)
	}
	lastUpdatedAt = user.UpdatedAt

	if err := DB.Model(user).Update("Age", 10).Error; err != nil {
		t.Errorf("errors happened when update: %v", err)
	} else if user.Age != 10 {
		t.Errorf("Age should equals to 10, but got %v", user.Age)
	}
	checkUpdatedAtChanged("Update", user.UpdatedAt)
	checkOtherData("Update")

	var result User
	if err := DB.Where("id = ?", user.ID).First(&result).Error; err != nil {
		t.Errorf("errors happened when query: %v", err)
	} else {
		CheckUser(t, result, *user)
	}

	values := map[string]interface{}{"Active": true, "age": 5}
	if res := DB.Model(user).Updates(values); res.Error != nil {
		t.Errorf("errors happened when update: %v", res.Error)
	} else if res.RowsAffected != 1 {
		t.Errorf("rows affected should be 1, but got : %v", res.RowsAffected)
	} else if user.Age != 5 {
		t.Errorf("Age should equals to 5, but got %v", user.Age)
	} else if !user.Active {
		t.Errorf("Active should be true, but got %v", user.Active)
	}
	checkUpdatedAtChanged("Updates with map", user.UpdatedAt)
	checkOtherData("Updates with map")

	var result2 User
	if err := DB.Where("id = ?", user.ID).First(&result2).Error; err != nil {
		t.Errorf("errors happened when query: %v", err)
	} else {
		CheckUser(t, result2, *user)
	}

	if err := DB.Model(user).Updates(User{Age: 2}).Error; err != nil {
		t.Errorf("errors happened when update: %v", err)
	} else if user.Age != 2 {
		t.Errorf("Age should equals to 2, but got %v", user.Age)
	}
	checkUpdatedAtChanged("Updates with struct", user.UpdatedAt)
	checkOtherData("Updates with struct")

	var result3 User
	if err := DB.Where("id = ?", user.ID).First(&result3).Error; err != nil {
		t.Errorf("errors happened when query: %v", err)
	} else {
		CheckUser(t, result3, *user)
	}

	user.Active = false
	user.Age = 1
	if err := DB.Save(user).Error; err != nil {
		t.Errorf("errors happened when update: %v", err)
	} else if user.Age != 1 {
		t.Errorf("Age should equals to 1, but got %v", user.Age)
	} else if user.Active {
		t.Errorf("Active should equals to false, but got %v", user.Active)
	}
	checkUpdatedAtChanged("Save", user.UpdatedAt)
	checkOtherData("Save")

	var result4 User
	if err := DB.Where("id = ?", user.ID).First(&result4).Error; err != nil {
		t.Errorf("errors happened when query: %v", err)
	} else {
		CheckUser(t, result4, *user)
	}

	if rowsAffected := DB.Model([]User{result4}).Where("age > 0").Update("name", "jinzhu").RowsAffected; rowsAffected != 1 {
		t.Errorf("should only update one record, but got %v", rowsAffected)
	}

	// All three users have age > 0, so every one of them must be updated.
	if rowsAffected := DB.Model(users).Where("age > 0").Update("name", "jinzhu").RowsAffected; rowsAffected != 3 {
		t.Errorf("should update three records, but got %v", rowsAffected)
	}
}

// TestUpdates covers Updates with a map, with a struct on a plain table, and
// with a gorm.Expr value.
func TestUpdates(t *testing.T) {
	users := []*User{
		GetUser("updates_01", Config{}),
		GetUser("updates_02", Config{}),
	}

	DB.Create(&users)
	lastUpdatedAt := users[0].UpdatedAt

	// update with map
	if res := DB.Model(users[0]).Updates(map[string]interface{}{"name": "updates_01_newname", "age": 100}); res.Error != nil || res.RowsAffected != 1 {
		t.Errorf("Failed to update users")
	}

	if users[0].Name != "updates_01_newname" || users[0].Age != 100 {
		t.Errorf("Record should be updated also with map")
	}

	if users[0].UpdatedAt.UnixNano() == lastUpdatedAt.UnixNano() {
		t.Errorf("User's updated at should be changed, but got %v, was %v", users[0].UpdatedAt, lastUpdatedAt)
	}

	// user2 should not be updated
	var user1, user2 User
	DB.First(&user1, users[0].ID)
	DB.First(&user2, users[1].ID)
	CheckUser(t, user1, *users[0])
	CheckUser(t, user2, *users[1])

	// update with struct; sleep so the second-resolution comparison below can
	// observe a change
	time.Sleep(1 * time.Second)
	DB.Table("users").Where("name in ?", []string{users[1].Name}).Updates(User{Name: "updates_02_newname"})

	var user3 User
	if err := DB.First(&user3, "name = ?", "updates_02_newname").Error; err != nil {
		t.Errorf("User2's name should be updated")
	}

	if user2.UpdatedAt.Format(time.RFC1123Z) == user3.UpdatedAt.Format(time.RFC1123Z) {
		t.Errorf("User's updated at should be changed, old %v, new %v", user2.UpdatedAt.Format(time.RFC1123Z), user3.UpdatedAt.Format(time.RFC1123Z))
	}

	// update with gorm exprs
	if err := DB.Model(&user3).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
	}
	var user4 User
	DB.First(&user4, user3.ID)

	user3.Age += 100
	AssertObjEqual(t, user4, user3, "UpdatedAt", "Age")
}

// TestUpdateColumn checks UpdateColumn/UpdateColumns with maps, structs and
// expressions, and that they leave UpdatedAt untouched.
func TestUpdateColumn(t *testing.T) {
	users := []*User{
		GetUser("update_column_01", Config{}),
		GetUser("update_column_02", Config{}),
	}

	DB.Create(&users)
	lastUpdatedAt := users[1].UpdatedAt

	// update with map
	DB.Model(users[1]).UpdateColumns(map[string]interface{}{"name": "update_column_02_newname", "age": 100})
	if users[1].Name != "update_column_02_newname" || users[1].Age != 100 {
		t.Errorf("user 2 should be updated with update column")
	}
	AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano())

	// user 1 must be untouched; user 2 must match its in-memory copy
	var user1, user2 User
	DB.First(&user1, users[0].ID)
	DB.First(&user2, users[1].ID)
	CheckUser(t, user1, *users[0])
	CheckUser(t, user2, *users[1])

	DB.Model(users[1]).UpdateColumn("name", "update_column_02_newnew").UpdateColumn("age", 19)
	AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano())
	if users[1].Name != "update_column_02_newnew" {
		t.Errorf("user 2's name should be updated, but got %v", users[1].Name)
	}

	if users[1].Age != 19 {
		t.Errorf("user 2's age should be updated, but got %v", users[1].Age)
	}

	DB.Model(users[1]).UpdateColumn("age", gorm.Expr("age + 100 - 50"))
	var user3 User
	DB.First(&user3, users[1].ID)

	users[1].Age += 50
	CheckUser(t, user3, *users[1])

	// update with struct
	DB.Model(users[1]).UpdateColumns(User{Name: "update_column_02_newnew2", Age: 200})
	if users[1].Name != "update_column_02_newnew2" || users[1].Age != 200 {
		t.Errorf("user 2 should be updated with update column")
	}
	AssertEqual(t, lastUpdatedAt.UnixNano(), users[1].UpdatedAt.UnixNano())

	// user 1 must be untouched; user 2 must match its in-memory copy
	var user5, user6 User
	DB.First(&user5, users[0].ID)
	DB.First(&user6, users[1].ID)
	CheckUser(t, user5, *users[0])
	CheckUser(t, user6, *users[1])
}

// TestBlockGlobalUpdate checks that an update without conditions is refused
// unless the session sets AllowGlobalUpdate.
func TestBlockGlobalUpdate(t *testing.T) {
	// errors.Is(nil, target) is false, so a nil error also fails this check.
	if err := DB.Model(&User{}).Update("name", "jinzhu").Error; !errors.Is(err, gorm.ErrMissingWhereClause) {
		t.Errorf("should returns missing WHERE clause while updating error, got err %v", err)
	}

	if err := DB.Session(&gorm.Session{AllowGlobalUpdate: true}).Model(&User{}).Update("name", "jinzhu").Error; err != nil {
		t.Errorf("should returns no error while enable global update, but got err %v", err)
	}
}

// TestSelectWithUpdate checks that Select restricts which fields and
// associations Save and Updates write.
func TestSelectWithUpdate(t *testing.T) {
	user := *GetUser("select_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
	DB.Create(&user)

	var result User
	DB.First(&result, user.ID)

	user2 := *GetUser("select_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
	result.Name = user2.Name
	result.Age = 50
	result.Account = user2.Account
	result.Pets = user2.Pets
	result.Toys = user2.Toys
	result.Company = user2.Company
	result.Manager = user2.Manager
	result.Team = user2.Team
	result.Languages = user2.Languages
	result.Friends = user2.Friends

	DB.Select("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result)

	var result2 User
	DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)

	result.Languages = append(user.Languages, result.Languages...)
	result.Toys = append(user.Toys, result.Toys...)

	sort.Slice(result.Languages, func(i, j int) bool {
		return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0
	})

	sort.Slice(result.Toys, func(i, j int) bool {
		return result.Toys[i].ID < result.Toys[j].ID
	})

	sort.Slice(result2.Languages, func(i, j int) bool {
		return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0
	})

	sort.Slice(result2.Toys, func(i, j int) bool {
		return result2.Toys[i].ID < result2.Toys[j].ID
	})

	AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages")

	DB.Model(&result).Select("Name", "Age").Updates(User{Name: "update_with_select"})
	if result.Age != 0 || result.Name != "update_with_select" {
		t.Fatalf("Failed to update struct with select, got %+v", result)
	}
	AssertObjEqual(t, result, user, "UpdatedAt")

	var result3 User
	DB.First(&result3, result.ID)
	AssertObjEqual(t, result, result3, "Name", "Age", "UpdatedAt")

	DB.Model(&result).Select("Name", "Age", "UpdatedAt").Updates(User{Name: "update_with_select"})

	// "was" is the previous timestamp (still held by user), "got" the new one.
	if utils.AssertEqual(result.UpdatedAt, user.UpdatedAt) {
		t.Fatalf("Update struct should update UpdatedAt, was %+v, got %+v", user.UpdatedAt, result.UpdatedAt)
	}

	AssertObjEqual(t, result, User{Name: "update_with_select"}, "Name", "Age")
}

// TestSelectWithUpdateWithMap checks Updates with a map while name and
// updated_at are omitted.
func TestSelectWithUpdateWithMap(t *testing.T) {
	user := *GetUser("select_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
	DB.Create(&user)

	var result User
	DB.First(&result, user.ID)

	user2 := *GetUser("select_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
	updateValues := map[string]interface{}{
		"Name":      user2.Name,
		"Age":       50,
		"Account":   user2.Account,
		"Pets":      user2.Pets,
		"Toys":      user2.Toys,
		"Company":   user2.Company,
		"Manager":   user2.Manager,
		"Team":      user2.Team,
		"Languages": user2.Languages,
		"Friends":   user2.Friends,
	}

	DB.Model(&result).Omit("name", "updated_at").Updates(updateValues)

	var result2 User
	DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)

	result.Languages = append(user.Languages, result.Languages...)
	result.Toys = append(user.Toys, result.Toys...)

	sort.Slice(result.Languages, func(i, j int) bool {
		return strings.Compare(result.Languages[i].Code, result.Languages[j].Code) > 0
	})

	sort.Slice(result.Toys, func(i, j int) bool {
		return result.Toys[i].ID < result.Toys[j].ID
	})

	sort.Slice(result2.Languages, func(i, j int) bool {
		return strings.Compare(result2.Languages[i].Code, result2.Languages[j].Code) > 0
	})

	sort.Slice(result2.Toys, func(i, j int) bool {
		return result2.Toys[i].ID < result2.Toys[j].ID
	})

	AssertObjEqual(t, result2, result, "Name", "Account", "Toys", "Manager", "ManagerID", "Languages")
}

// TestWithUpdateWithInvalidMap checks that Updates rejects a map type it does
// not support with gorm.ErrInvalidData.
func TestWithUpdateWithInvalidMap(t *testing.T) {
	user := *GetUser("update_with_invalid_map", Config{})
	DB.Create(&user)

	if err := DB.Model(&user).Updates(map[string]string{"name": "jinzhu"}).Error; !errors.Is(err, gorm.ErrInvalidData) {
		t.Errorf("should returns error for unsupported updating data")
	}
}

// TestOmitWithUpdate checks that Omit excludes fields and associations from
// Save.
func TestOmitWithUpdate(t *testing.T) {
	user := *GetUser("omit_update", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
	DB.Create(&user)

	var result User
	DB.First(&result, user.ID)

	user2 := *GetUser("omit_update_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
	result.Name = user2.Name
	result.Age = 50
	result.Account = user2.Account
	result.Pets = user2.Pets
	result.Toys = user2.Toys
	result.Company = user2.Company
	result.Manager = user2.Manager
	result.Team = user2.Team
	result.Languages = user2.Languages
	result.Friends = user2.Friends

	DB.Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Save(&result)

	var result2 User
	DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)

	result.Pets = append(user.Pets, result.Pets...)
	result.Team = append(user.Team, result.Team...)
	result.Friends = append(user.Friends, result.Friends...)

	sort.Slice(result.Pets, func(i, j int) bool {
		return result.Pets[i].ID < result.Pets[j].ID
	})
	sort.Slice(result.Team, func(i, j int) bool {
		return result.Team[i].ID < result.Team[j].ID
	})
	sort.Slice(result.Friends, func(i, j int) bool {
		return result.Friends[i].ID < result.Friends[j].ID
	})
	sort.Slice(result2.Pets, func(i, j int) bool {
		return result2.Pets[i].ID < result2.Pets[j].ID
	})
	sort.Slice(result2.Team, func(i, j int) bool {
		return result2.Team[i].ID < result2.Team[j].ID
	})
	sort.Slice(result2.Friends, func(i, j int) bool {
		return result2.Friends[i].ID < result2.Friends[j].ID
	})

	AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends")
}

// TestOmitWithUpdateWithMap checks that Omit excludes fields and associations
// from Updates with a map.
func TestOmitWithUpdateWithMap(t *testing.T) {
	user := *GetUser("omit_update_map", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
	DB.Create(&user)

	var result User
	DB.First(&result, user.ID)

	user2 := *GetUser("omit_update_map_new", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
	updateValues := map[string]interface{}{
		"Name":      user2.Name,
		"Age":       50,
		"Account":   user2.Account,
		"Pets":      user2.Pets,
		"Toys":      user2.Toys,
		"Company":   user2.Company,
		"Manager":   user2.Manager,
		"Team":      user2.Team,
		"Languages": user2.Languages,
		"Friends":   user2.Friends,
	}

	DB.Model(&result).Omit("Name", "Account", "Toys", "Manager", "ManagerID", "Languages").Updates(updateValues)

	var result2 User
	DB.Preload("Account").Preload("Pets").Preload("Toys").Preload("Company").Preload("Manager").Preload("Team").Preload("Languages").Preload("Friends").First(&result2, user.ID)

	result.Pets = append(user.Pets, result.Pets...)
	result.Team = append(user.Team, result.Team...)
	result.Friends = append(user.Friends, result.Friends...)

	sort.Slice(result.Pets, func(i, j int) bool {
		return result.Pets[i].ID < result.Pets[j].ID
	})
	sort.Slice(result.Team, func(i, j int) bool {
		return result.Team[i].ID < result.Team[j].ID
	})
	sort.Slice(result.Friends, func(i, j int) bool {
		return result.Friends[i].ID < result.Friends[j].ID
	})
	sort.Slice(result2.Pets, func(i, j int) bool {
		return result2.Pets[i].ID < result2.Pets[j].ID
	})
	sort.Slice(result2.Team, func(i, j int) bool {
		return result2.Team[i].ID < result2.Team[j].ID
	})
	sort.Slice(result2.Friends, func(i, j int) bool {
		return result2.Friends[i].ID < result2.Friends[j].ID
	})

	AssertObjEqual(t, result2, result, "Age", "Pets", "Company", "CompanyID", "Team", "Friends")
}

// TestSelectWithUpdateColumn checks that Select("Name") with Updates writes
// only the name, and that UpdatedAt is still refreshed.
func TestSelectWithUpdateColumn(t *testing.T) {
	user := *GetUser("select_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
	DB.Create(&user)

	updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}

	var result User
	DB.First(&result, user.ID)

	time.Sleep(time.Second)
	lastUpdatedAt := result.UpdatedAt
	DB.Model(&result).Select("Name").Updates(updateValues)

	var result2 User
	DB.First(&result2, user.ID)

	if lastUpdatedAt.Format(time.RFC3339Nano) == result2.UpdatedAt.Format(time.RFC3339Nano) {
		t.Errorf("UpdatedAt should be changed")
	}

	if result2.Name == user.Name || result2.Age != user.Age {
		t.Errorf("Should only update users with name column")
	}
}

// TestOmitWithUpdateColumn checks that Omit("Name") with UpdateColumns writes
// only the remaining column (Age).
func TestOmitWithUpdateColumn(t *testing.T) {
	user := *GetUser("omit_with_update_column", Config{Account: true, Pets: 3, Toys: 3, Company: true, Manager: true, Team: 3, Languages: 3, Friends: 4})
	DB.Create(&user)

	updateValues := map[string]interface{}{"Name": "new_name", "Age": 50}

	var result User
	DB.First(&result, user.ID)
	DB.Model(&result).Omit("Name").UpdateColumns(updateValues)

	var result2 User
	DB.First(&result2, user.ID)

	// Name is omitted, so only the age column may change.
	if result2.Name != user.Name || result2.Age == user.Age {
		t.Errorf("Should only update users with age column")
	}
}

// TestUpdateColumnsSkipsAssociations updates a single column through
// UpdateColumns while the model's in-memory Account has been modified.
func TestUpdateColumnsSkipsAssociations(t *testing.T) {
	user := *GetUser("update_column_skips_association", Config{})
	DB.Create(&user)

	// Update a single field of the user while the account number is changed in memory.
	newAge := uint(100)
	user.Account.Number = "new_account_number"
	db := DB.Model(&user).UpdateColumns(User{Age: newAge})
	if db.RowsAffected != 1 {
		t.Errorf("Expected RowsAffected=1 but instead RowsAffected=%v", db.RowsAffected)
	}

	// Verify that Age now=`newAge`.
	result := &User{}
	result.ID = user.ID
	DB.Preload("Account").First(result)

	if result.Age != newAge {
		t.Errorf("Expected freshly queried user to have Age=%v but instead found Age=%v", newAge, result.Age)
	}

	// NOTE(review): user.Account.Number holds the modified in-memory value,
	// so this check passes only when the stored number equals
	// "new_account_number". That seems at odds with the test name and the
	// message below; confirm which expectation is intended.
	if result.Account.Number != user.Account.Number {
		t.Errorf("account number should not been changed, expects: %v, got %v", user.Account.Number, result.Account.Number)
	}
}

// TestUpdatesWithBlankValues checks that zero-valued struct fields are not
// written by Updates.
func TestUpdatesWithBlankValues(t *testing.T) {
	user := *GetUser("updates_with_blank_value", Config{})
	DB.Save(&user)

	var user2 User
	user2.ID = user.ID
	DB.Model(&user2).Updates(&User{Age: 100})

	var result User
	DB.First(&result, user.ID)

	if result.Name != user.Name || result.Age != 100 {
		t.Errorf("user's name should not be updated")
	}
}

// TestUpdatesTableWithIgnoredValues checks that a `gorm:"-"` field is never
// persisted by Updates.
func TestUpdatesTableWithIgnoredValues(t *testing.T) {
	type ElementWithIgnoredField struct {
		Id           int64
		Value        string
		IgnoredField int64 `gorm:"-"`
	}
	DB.Migrator().DropTable(&ElementWithIgnoredField{})
	DB.AutoMigrate(&ElementWithIgnoredField{})

	elem := ElementWithIgnoredField{Value: "foo", IgnoredField: 10}
	DB.Save(&elem)

	DB.Model(&ElementWithIgnoredField{}).
		Where("id = ?", elem.Id).
		Updates(&ElementWithIgnoredField{Value: "bar", IgnoredField: 100})

	var result ElementWithIgnoredField
	if err := DB.First(&result, elem.Id).Error; err != nil {
		t.Errorf("error getting an element from database: %s", err.Error())
	}

	if result.IgnoredField != 0 {
		t.Errorf("element's ignored field should not be updated")
	}
}

// TestUpdateFromSubQuery checks Update with a sub-query value, both through a
// model and through a raw table.
func TestUpdateFromSubQuery(t *testing.T) {
	user := *GetUser("update_from_sub_query", Config{Company: true})
	if err := DB.Create(&user).Error; err != nil {
		t.Errorf("failed to create user, got error: %v", err)
	}

	if err := DB.Model(&user).Update("name", DB.Model(&Company{}).Select("name").Where("companies.id = users.company_id")).Error; err != nil {
		t.Errorf("failed to update with sub query, got error %v", err)
	}

	var result User
	DB.First(&result, user.ID)

	if result.Name != user.Company.Name {
		t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name)
	}

	DB.Model(&user.Company).Update("Name", "new company name")
	if err := DB.Table("users").Where("1 = 1").Update("name", DB.Table("companies").Select("name").Where("companies.id = users.company_id")).Error; err != nil {
		t.Errorf("failed to update with sub query, got error %v", err)
	}

	DB.First(&result, user.ID)
	if result.Name != "new company name" {
		t.Errorf("name should be %v, but got %v", user.Company.Name, result.Name)
	}
}

// TestIdempotentSave checks that saving an unchanged record repeatedly keeps
// its primary key.
func TestIdempotentSave(t *testing.T) {
	create := Company{
		Name: "company_idempotent",
	}
	DB.Create(&create)

	var company Company
	if err := DB.Find(&company, "id = ?", create.ID).Error; err != nil {
		t.Fatalf("failed to find created company, got err: %v", err)
	}

	if err := DB.Save(&company).Error; err != nil || company.ID != create.ID {
		t.Errorf("failed to save company, got err: %v", err)
	}

	if err := DB.Save(&company).Error; err != nil || company.ID != create.ID {
		t.Errorf("failed to save company, got err: %v", err)
	}
}

// TestSave covers Save on single records and slices, UpdatedAt refresh, the
// generated SQL with and without Unscoped, and Save with a Model that differs
// from the value being saved.
func TestSave(t *testing.T) {
	user := *GetUser("save", Config{})
	DB.Create(&user)

	if err := DB.First(&User{}, "name = ?", "save").Error; err != nil {
		t.Fatalf("failed to find created user")
	}

	user.Name = "save2"
	DB.Save(&user)

	var result User
	if err := DB.First(&result, "name = ?", "save2").Error; err != nil || result.ID != user.ID {
		t.Fatalf("failed to find updated user")
	}

	user2 := *GetUser("save2", Config{})
	DB.Create(&user2)

	time.Sleep(time.Second)
	user1UpdatedAt := result.UpdatedAt
	user2UpdatedAt := user2.UpdatedAt
	users := []*User{&result, &user2}
	DB.Save(&users)

	if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) {
		t.Fatalf("user's updated at should be changed, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt)
	}

	if user2UpdatedAt.Format(time.RFC1123Z) == user2.UpdatedAt.Format(time.RFC1123Z) {
		t.Fatalf("user's updated at should be changed, expects: %+v, got: %+v", user2UpdatedAt, user2.UpdatedAt)
	}

	DB.First(&result)
	if user1UpdatedAt.Format(time.RFC1123Z) == result.UpdatedAt.Format(time.RFC1123Z) {
		t.Fatalf("user's updated at should be changed after reload, expects: %+v, got: %+v", user1UpdatedAt, result.UpdatedAt)
	}

	DB.First(&user2)
	if user2UpdatedAt.Format(time.RFC1123Z) == user2.UpdatedAt.Format(time.RFC1123Z) {
		t.Fatalf("user2's updated at should be changed after reload, expects: %+v, got: %+v", user2UpdatedAt, user2.UpdatedAt)
	}

	dryDB := DB.Session(&gorm.Session{DryRun: true})
	stmt := dryDB.Save(&user).Statement
	if !regexp.MustCompile(`.users.\..deleted_at. IS NULL`).MatchString(stmt.SQL.String()) {
		t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
	}

	dryDB = DB.Session(&gorm.Session{DryRun: true})
	stmt = dryDB.Unscoped().Save(&user).Statement
	if !regexp.MustCompile(`WHERE .id. = [^ ]+$`).MatchString(stmt.SQL.String()) {
		t.Fatalf("invalid updating SQL, got %v", stmt.SQL.String())
	}

	user3 := *GetUser("save3", Config{})
	DB.Create(&user3)

	if err := DB.First(&User{}, "name = ?", "save3").Error; err != nil {
		t.Fatalf("failed to find created user")
	}

	user3.Name = "save3_"
	if err := DB.Model(User{Model: user3.Model}).Save(&user3).Error; err != nil {
		t.Fatalf("failed to save user, got %v", err)
	}

	var result2 User
	if err := DB.First(&result2, "name = ?", "save3_").Error; err != nil || result2.ID != user3.ID {
		t.Fatalf("failed to find updated user, got %v", err)
	}

	if err := DB.Model(User{Model: user3.Model}).Save(&struct {
		gorm.Model
		Placeholder string
		Name        string
	}{
		Model:       user3.Model,
		Placeholder: "placeholder",
		Name:        "save3__",
	}).Error; err != nil {
		t.Fatalf("failed to update user, got %v", err)
	}

	var result3 User
	if err := DB.First(&result3, "name = ?", "save3__").Error; err != nil || result3.ID != user3.ID {
		t.Fatalf("failed to find updated user")
	}
}

// TestSaveWithPrimaryValue checks that Save on a record with a preset
// (non-auto) primary key creates it first and updates it afterwards, on both
// the default and a custom table.
func TestSaveWithPrimaryValue(t *testing.T) {
	lang := Language{Code: "save", Name: "save"}
	if result := DB.Save(&lang); result.RowsAffected != 1 {
		t.Errorf("should create language, rows affected: %v", result.RowsAffected)
	}

	var result Language
	DB.First(&result, "code = ?", "save")
	AssertEqual(t, result, lang)

	lang.Name = "save name2"
	if result := DB.Save(&lang); result.RowsAffected != 1 {
		t.Errorf("should update language")
	}

	var result2 Language
	DB.First(&result2, "code = ?", "save")
	AssertEqual(t, result2, lang)

	DB.Table("langs").Migrator().DropTable(&Language{})
	DB.Table("langs").AutoMigrate(&Language{})

	if err := DB.Table("langs").Save(&lang).Error; err != nil {
		t.Errorf("no error should happen when creating data, but got %v", err)
	}

	var result3 Language
	if err := DB.Table("langs").First(&result3, "code = ?", lang.Code).Error; err != nil || result3.Name != lang.Name {
		t.Errorf("failed to find created record, got error: %v, result: %+v", err, result3)
	}

	// This second Save targets an existing row, i.e. it is an update.
	lang.Name += "name2"
	if err := DB.Table("langs").Save(&lang).Error; err != nil {
		t.Errorf("no error should happen when updating data, but got %v", err)
	}

	var result4 Language
	if err := DB.Table("langs").First(&result4, "code = ?", lang.Code).Error; err != nil || result4.Name != lang.Name {
		t.Errorf("failed to find updated record, got error: %v, result: %+v", err, result4)
	}
}

// only sqlite, postgres, gaussdb, sqlserver support returning
func TestUpdateReturning(t *testing.T) {
	switch DB.Dialector.Name() {
	case "sqlite", "postgres", "gaussdb", "sqlserver":
	default:
		return
	}

	users := []*User{
		GetUser("update-returning-1", Config{}),
		GetUser("update-returning-2", Config{}),
		GetUser("update-returning-3", Config{}),
	}
	DB.Create(&users)

	var results []User
	DB.Model(&results).Where("name IN ?", []string{users[0].Name, users[1].Name}).Clauses(clause.Returning{}).Update("age", 88)
	// Fatal rather than Error: the code below indexes results[0] and results[1].
	if len(results) != 2 || results[0].Age != 88 || results[1].Age != 88 {
		t.Fatalf("failed to return updated data, got %v", results)
	}

	if err := DB.Model(&results[0]).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
	}

	if err := DB.Model(&results[1]).Clauses(clause.Returning{Columns: []clause.Column{{Name: "age"}}}).Updates(map[string]interface{}{"age": gorm.Expr("age + ?", 100)}).Error; err != nil {
		t.Errorf("Not error should happen when updating with gorm expr, but got %v", err)
	}

	if results[1].Age-results[0].Age != 100 {
		t.Errorf("failed to return updated age column")
	}
}

// TestUpdateWithDiffSchema checks Updates with a value of a different struct
// type than the model.
func TestUpdateWithDiffSchema(t *testing.T) {
	user := GetUser("update-diff-schema-1", Config{})
	DB.Create(&user)

	type UserTemp struct {
		Name string
	}

	err := DB.Model(&user).Updates(&UserTemp{Name: "update-diff-schema-2"}).Error
	AssertEqual(t, err, nil)
	AssertEqual(t, "update-diff-schema-2", user.Name)
}

// TokenOwner is a test model that owns a single Token via Token.UserID.
type TokenOwner struct {
	ID    int
	Name  string
	Token Token `gorm:"foreignKey:UserID"`
}
func (t *TokenOwner) BeforeSave(tx *gorm.DB) error { t.Name += "_name" return nil } type Token struct { UserID int `gorm:"primary_key"` Content string `gorm:"type:varchar(100)"` } func (t *Token) BeforeSave(tx *gorm.DB) error { t.Content += "_encrypted" return nil } func TestSaveWithHooks(t *testing.T) { DB.Migrator().DropTable(&Token{}, &TokenOwner{}) DB.AutoMigrate(&Token{}, &TokenOwner{}) saveTokenOwner := func(owner *TokenOwner) (*TokenOwner, error) { var newOwner TokenOwner if err := DB.Transaction(func(tx *gorm.DB) error { if err := tx.Session(&gorm.Session{FullSaveAssociations: true}).Save(owner).Error; err != nil { return err } if err := tx.Preload("Token").First(&newOwner, owner.ID).Error; err != nil { return err } return nil }); err != nil { return nil, err } return &newOwner, nil } owner := TokenOwner{ Name: "user", Token: Token{Content: "token"}, } o1, err := saveTokenOwner(&owner) if err != nil { t.Errorf("failed to save token owner, got error: %v", err) } if o1.Name != "user_name" { t.Errorf(`owner name should be "user_name", but got: "%s"`, o1.Name) } if o1.Token.Content != "token_encrypted" { t.Errorf(`token content should be "token_encrypted", but got: "%s"`, o1.Token.Content) } owner = TokenOwner{ ID: owner.ID, Name: "user", Token: Token{Content: "token2"}, } o2, err := saveTokenOwner(&owner) if err != nil { t.Errorf("failed to save token owner, got error: %v", err) } if o2.Name != "user_name" { t.Errorf(`owner name should be "user_name", but got: "%s"`, o2.Name) } if o2.Token.Content != "token2_encrypted" { t.Errorf(`token content should be "token2_encrypted", but got: "%s"`, o2.Token.Content) } } // only postgres, gaussdb, sqlserver, sqlite support update from func TestUpdateFrom(t *testing.T) { if DB.Dialector.Name() != "postgres" && DB.Dialector.Name() != "gaussdb" && DB.Dialector.Name() != "sqlite" && DB.Dialector.Name() != "sqlserver" { return } users := []*User{ GetUser("update-from-1", Config{Account: true}), GetUser("update-from-2", 
Config{Account: true}), GetUser("update-from-3", Config{}), } if err := DB.Create(&users).Error; err != nil { t.Fatalf("errors happened when create: %v", err) } else if users[0].ID == 0 { t.Fatalf("user's primary value should not zero, %v", users[0].ID) } else if users[0].UpdatedAt.IsZero() { t.Fatalf("user's updated at should not zero, %v", users[0].UpdatedAt) } if rowsAffected := DB.Model(&User{}).Clauses(clause.From{Tables: []clause.Table{{Name: "accounts"}}}).Where("accounts.user_id = users.id AND accounts.number = ? AND accounts.deleted_at IS NULL", users[0].Account.Number).Update("name", "franco").RowsAffected; rowsAffected != 1 { t.Errorf("should only update one record, but got %v", rowsAffected) } var result User if err := DB.Where("id = ?", users[0].ID).First(&result).Error; err != nil { t.Errorf("errors happened when query before user: %v", err) } else if result.UpdatedAt.UnixNano() == users[0].UpdatedAt.UnixNano() { t.Errorf("user's updated at should be changed, but got %v, was %v", result.UpdatedAt, users[0].UpdatedAt) } else if result.Name != "franco" { t.Errorf("user's name should be updated") } if rowsAffected := DB.Model(&User{}).Clauses(clause.From{Tables: []clause.Table{{Name: "accounts"}}}).Where("accounts.user_id = users.id AND accounts.number IN ? 
AND accounts.deleted_at IS NULL", []string{users[0].Account.Number, users[1].Account.Number}).Update("name", gorm.Expr("accounts.number")).RowsAffected; rowsAffected != 2 { t.Errorf("should update two records, but got %v", rowsAffected) } var results []User if err := DB.Preload("Account").Find(&results, []uint{users[0].ID, users[1].ID}).Error; err != nil { t.Errorf("Not error should happen when finding users, but got %v", err) } for _, user := range results { if user.Name != user.Account.Number { t.Errorf("user's name should be equal to the account's number %v, but got %v", user.Account.Number, user.Name) } } } ================================================ FILE: tests/upsert_test.go ================================================ package tests_test import ( "regexp" "testing" "time" "gorm.io/gorm" "gorm.io/gorm/clause" . "gorm.io/gorm/utils/tests" ) func TestUpsert(t *testing.T) { lang := Language{Code: "upsert", Name: "Upsert"} if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang).Error; err != nil { t.Fatalf("failed to upsert, got %v", err) } lang2 := Language{Code: "upsert", Name: "Upsert"} if err := DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&lang2).Error; err != nil { t.Fatalf("failed to upsert, got %v", err) } var langs []Language if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { t.Errorf("no error should happen when find languages with code, but got %v", err) } else if len(langs) != 1 { t.Errorf("should only find only 1 languages, but got %+v", langs) } lang3 := Language{Code: "upsert", Name: "Upsert"} if err := DB.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "code"}}, DoUpdates: clause.Assignments(map[string]interface{}{"name": "upsert-new"}), }).Create(&lang3).Error; err != nil { t.Fatalf("failed to upsert, got %v", err) } if err := DB.Find(&langs, "code = ?", lang.Code).Error; err != nil { t.Errorf("no error should happen when find languages with code, but got %v", err) } else if len(langs) 
!= 1 { t.Errorf("should only find only 1 languages, but got %+v", langs) } else if langs[0].Name != "upsert-new" { t.Errorf("should update name on conflict, but got name %+v", langs[0].Name) } lang = Language{Code: "upsert", Name: "Upsert-Newname"} if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&lang).Error; err != nil { t.Fatalf("failed to upsert, got %v", err) } var result Language if err := DB.Find(&result, "code = ?", lang.Code).Error; err != nil || result.Name != lang.Name { t.Fatalf("failed to upsert, got name %v", result.Name) } if name := DB.Dialector.Name(); name != "sqlserver" { type RestrictedLanguage struct { Code string `gorm:"primarykey"` Name string Lang string `gorm:"<-:create"` } r := DB.Session(&gorm.Session{DryRun: true}).Clauses(clause.OnConflict{UpdateAll: true}).Create(&RestrictedLanguage{Code: "upsert_code", Name: "upsert_name", Lang: "upsert_lang"}) if !regexp.MustCompile(`INTO .restricted_languages. .*\(.code.,.name.,.lang.\) .* (SET|UPDATE) .name.=.*.name.\W*$`).MatchString(r.Statement.SQL.String()) { t.Errorf("Table with escape character, got %v", r.Statement.SQL.String()) } } user := *GetUser("upsert_on_conflict", Config{}) user.Age = 20 if err := DB.Create(&user).Error; err != nil { t.Errorf("failed to create user, got error %v", err) } var user2 User DB.First(&user2, user.ID) user2.Age = 30 time.Sleep(time.Second) if err := DB.Clauses(clause.OnConflict{UpdateAll: true}).Create(&user2).Error; err != nil { t.Fatalf("failed to onconflict create user, got error %v", err) } else { var user3 User DB.First(&user3, user.ID) if user3.UpdatedAt.UnixNano() == user2.UpdatedAt.UnixNano() { t.Fatalf("failed to update user's updated_at, old: %v, new: %v", user2.UpdatedAt, user3.UpdatedAt) } } } func TestUpsertSlice(t *testing.T) { langs := []Language{ {Code: "upsert-slice1", Name: "Upsert-slice1"}, {Code: "upsert-slice2", Name: "Upsert-slice2"}, {Code: "upsert-slice3", Name: "Upsert-slice3"}, } DB.Clauses(clause.OnConflict{DoNothing: 
true}).Create(&langs) var langs2 []Language if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil { t.Errorf("no error should happen when find languages with code, but got %v", err) } else if len(langs2) != 3 { t.Errorf("should only find only 3 languages, but got %+v", langs2) } DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) var langs3 []Language if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil { t.Errorf("no error should happen when find languages with code, but got %v", err) } else if len(langs3) != 3 { t.Errorf("should only find only 3 languages, but got %+v", langs3) } for idx, lang := range langs { lang.Name = lang.Name + "_new" langs[idx] = lang } if err := DB.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "code"}}, DoUpdates: clause.AssignmentColumns([]string{"name"}), }).Create(&langs).Error; err != nil { t.Fatalf("failed to upsert, got %v", err) } for _, lang := range langs { var results []Language if err := DB.Find(&results, "code = ?", lang.Code).Error; err != nil { t.Errorf("no error should happen when find languages with code, but got %v", err) } else if len(results) != 1 { t.Errorf("should only find only 1 languages, but got %+v", langs) } else if results[0].Name != lang.Name { t.Errorf("should update name on conflict, but got name %+v", results[0].Name) } } } func TestUpsertSliceWithReturning(t *testing.T) { langs := []Language{ {Code: "upsert-slice1", Name: "Upsert-slice1"}, {Code: "upsert-slice2", Name: "Upsert-slice2"}, {Code: "upsert-slice3", Name: "Upsert-slice3"}, } DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) var langs2 []Language if err := DB.Find(&langs2, "code LIKE ?", "upsert-slice%").Error; err != nil { t.Errorf("no error should happen when find languages with code, but got %v", err) } else if len(langs2) != 3 { t.Errorf("should only find only 3 languages, but got %+v", langs2) } DB.Clauses(clause.OnConflict{DoNothing: true}).Create(&langs) var 
langs3 []Language if err := DB.Find(&langs3, "code LIKE ?", "upsert-slice%").Error; err != nil { t.Errorf("no error should happen when find languages with code, but got %v", err) } else if len(langs3) != 3 { t.Errorf("should only find only 3 languages, but got %+v", langs3) } for idx, lang := range langs { lang.Name = lang.Name + "_new" langs[idx] = lang } if err := DB.Clauses(clause.OnConflict{ Columns: []clause.Column{{Name: "code"}}, DoUpdates: clause.AssignmentColumns([]string{"name"}), }, clause.Returning{}).CreateInBatches(&langs, len(langs)).Error; err != nil { t.Fatalf("failed to upsert, got %v", err) } for _, lang := range langs { var results []Language if err := DB.Find(&results, "code = ?", lang.Code).Error; err != nil { t.Errorf("no error should happen when find languages with code, but got %v", err) } else if len(results) != 1 { t.Errorf("should only find only 1 languages, but got %+v", langs) } else if results[0].Name != lang.Name { t.Errorf("should update name on conflict, but got name %+v", results[0].Name) } } } func TestUpsertWithSave(t *testing.T) { langs := []Language{ {Code: "upsert-save-1", Name: "Upsert-save-1"}, {Code: "upsert-save-2", Name: "Upsert-save-2"}, } if err := DB.Save(&langs).Error; err != nil { t.Errorf("Failed to create, got error %v", err) } for _, lang := range langs { var result Language if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { t.Errorf("Failed to query lang, got error %v", err) } else { AssertEqual(t, result, lang) } } for idx, lang := range langs { lang.Name += "_new" langs[idx] = lang } if err := DB.Save(&langs).Error; err != nil { t.Errorf("Failed to upsert, got error %v", err) } for _, lang := range langs { var result Language if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { t.Errorf("Failed to query lang, got error %v", err) } else { AssertEqual(t, result, lang) } } lang := Language{Code: "upsert-save-3", Name: "Upsert-save-3"} if err := DB.Save(&lang).Error; err != 
nil { t.Errorf("Failed to create, got error %v", err) } var result Language if err := DB.First(&result, "code = ?", lang.Code).Error; err != nil { t.Errorf("Failed to query lang, got error %v", err) } else { AssertEqual(t, result, lang) } lang.Name += "_new" if err := DB.Save(&lang).Error; err != nil { t.Errorf("Failed to create, got error %v", err) } var result2 Language if err := DB.First(&result2, "code = ?", lang.Code).Error; err != nil { t.Errorf("Failed to query lang, got error %v", err) } else { AssertEqual(t, result2, lang) } } func TestFindOrInitialize(t *testing.T) { var user1, user2, user3, user4, user5, user6 User if err := DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user1).Error; err != nil { t.Errorf("no error should happen when FirstOrInit, but got %v", err) } if user1.Name != "find or init" || user1.ID != 0 || user1.Age != 33 { t.Errorf("user should be initialized with search value") } DB.Where(User{Name: "find or init", Age: 33}).FirstOrInit(&user2) if user2.Name != "find or init" || user2.ID != 0 || user2.Age != 33 { t.Errorf("user should be initialized with search value") } DB.FirstOrInit(&user3, map[string]interface{}{"name": "find or init 2"}) if user3.Name != "find or init 2" || user3.ID != 0 { t.Errorf("user should be initialized with inline search value") } DB.Where(&User{Name: "find or init"}).Attrs(User{Age: 44}).FirstOrInit(&user4) if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { t.Errorf("user should be initialized with search value and attrs") } DB.Where(&User{Name: "find or init"}).Assign("age", 44).FirstOrInit(&user4) if user4.Name != "find or init" || user4.ID != 0 || user4.Age != 44 { t.Errorf("user should be initialized with search value and assign attrs") } DB.Save(&User{Name: "find or init", Age: 33}) DB.Where(&User{Name: "find or init"}).Attrs("age", 44).FirstOrInit(&user5) if user5.Name != "find or init" || user5.ID == 0 || user5.Age != 33 { t.Errorf("user should be found and not initialized 
by Attrs") } DB.Where(&User{Name: "find or init", Age: 33}).FirstOrInit(&user6) if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 33 { t.Errorf("user should be found with FirstOrInit") } DB.Where(&User{Name: "find or init"}).Assign(User{Age: 44}).FirstOrInit(&user6) if user6.Name != "find or init" || user6.ID == 0 || user6.Age != 44 { t.Errorf("user should be found and updated with assigned attrs") } } func TestFindOrCreate(t *testing.T) { var user1, user2, user3, user4, user5, user6, user7, user8 User if err := DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user1).Error; err != nil { t.Errorf("no error should happen when FirstOrInit, but got %v", err) } if user1.Name != "find or create" || user1.ID == 0 || user1.Age != 33 { t.Errorf("user should be created with search value") } DB.Where(&User{Name: "find or create", Age: 33}).FirstOrCreate(&user2) if user1.ID != user2.ID || user2.Name != "find or create" || user2.ID == 0 || user2.Age != 33 { t.Errorf("user should be created with search value") } DB.FirstOrCreate(&user3, map[string]interface{}{"name": "find or create 2"}) if user3.Name != "find or create 2" || user3.ID == 0 { t.Errorf("user should be created with inline search value") } DB.Where(&User{Name: "find or create 3"}).Attrs("age", 44).FirstOrCreate(&user4) if user4.Name != "find or create 3" || user4.ID == 0 || user4.Age != 44 { t.Errorf("user should be created with search value and attrs") } updatedAt1 := user4.UpdatedAt DB.Where(&User{Name: "find or create 3"}).Assign("age", 55).FirstOrCreate(&user4) if user4.Age != 55 { t.Errorf("Failed to set change to 55, got %v", user4.Age) } if updatedAt1.Format(time.RFC3339Nano) == user4.UpdatedAt.Format(time.RFC3339Nano) { t.Errorf("UpdateAt should be changed when update values with assign") } DB.Where(&User{Name: "find or create 4"}).Assign(User{Age: 44}).FirstOrCreate(&user4) if user4.Name != "find or create 4" || user4.ID == 0 || user4.Age != 44 { t.Errorf("user should be 
created with search value and assigned attrs") } DB.Where(&User{Name: "find or create"}).Attrs("age", 44).FirstOrInit(&user5) if user5.Name != "find or create" || user5.ID == 0 || user5.Age != 33 { t.Errorf("user should be found and not initialized by Attrs") } DB.Where(&User{Name: "find or create"}).Assign(User{Age: 44}).FirstOrCreate(&user6) if user6.Name != "find or create" || user6.ID == 0 || user6.Age != 44 { t.Errorf("user should be found and updated with assigned attrs") } DB.Where(&User{Name: "find or create"}).Find(&user7) if user7.Name != "find or create" || user7.ID == 0 || user7.Age != 44 { t.Errorf("user should be found and updated with assigned attrs") } DB.Where(&User{Name: "find or create embedded struct"}).Assign(User{Age: 44, Account: Account{Number: "1231231231"}, Pets: []*Pet{{Name: "first_or_create_pet1"}, {Name: "first_or_create_pet2"}}}).FirstOrCreate(&user8) if err := DB.Where("name = ?", "first_or_create_pet1").First(&Pet{}).Error; err != nil { t.Errorf("has many association should be saved") } if err := DB.Where("number = ?", "1231231231").First(&Account{}).Error; err != nil { t.Errorf("belongs to association should be saved") } } func TestUpdateWithMissWhere(t *testing.T) { type User struct { ID uint `gorm:"column:id;<-:create"` Name string `gorm:"column:name"` } user := User{ID: 1, Name: "king"} tx := DB.Session(&gorm.Session{DryRun: true}).Save(&user) if err := tx.Error; err != nil { t.Fatalf("failed to update user,missing where condition,err=%+v", err) } if !regexp.MustCompile("WHERE .id. 
= [^ ]+$").MatchString(tx.Statement.SQL.String()) { t.Fatalf("invalid updating SQL, got %v", tx.Statement.SQL.String()) } } ================================================ FILE: utils/tests/dummy_dialecter.go ================================================ package tests import ( "gorm.io/gorm" "gorm.io/gorm/callbacks" "gorm.io/gorm/clause" "gorm.io/gorm/logger" "gorm.io/gorm/schema" ) type DummyDialector struct { TranslatedErr error } func (DummyDialector) Name() string { return "dummy" } func (DummyDialector) Initialize(db *gorm.DB) error { callbacks.RegisterDefaultCallbacks(db, &callbacks.Config{ CreateClauses: []string{"INSERT", "VALUES", "ON CONFLICT", "RETURNING"}, UpdateClauses: []string{"UPDATE", "SET", "WHERE", "RETURNING"}, DeleteClauses: []string{"DELETE", "FROM", "WHERE", "RETURNING"}, LastInsertIDReversed: true, }) return nil } func (DummyDialector) DefaultValueOf(field *schema.Field) clause.Expression { return clause.Expr{SQL: "DEFAULT"} } func (DummyDialector) Migrator(*gorm.DB) gorm.Migrator { return nil } func (DummyDialector) BindVarTo(writer clause.Writer, stmt *gorm.Statement, v interface{}) { writer.WriteByte('?') } func (DummyDialector) QuoteTo(writer clause.Writer, str string) { var ( underQuoted, selfQuoted bool continuousBacktick int8 shiftDelimiter int8 ) for _, v := range []byte(str) { switch v { case '`': continuousBacktick++ if continuousBacktick == 2 { writer.WriteString("``") continuousBacktick = 0 } case '.': if continuousBacktick > 0 || !selfQuoted { shiftDelimiter = 0 underQuoted = false continuousBacktick = 0 writer.WriteByte('`') } writer.WriteByte(v) continue default: if shiftDelimiter-continuousBacktick <= 0 && !underQuoted { writer.WriteByte('`') underQuoted = true if selfQuoted = continuousBacktick > 0; selfQuoted { continuousBacktick -= 1 } } for ; continuousBacktick > 0; continuousBacktick -= 1 { writer.WriteString("``") } writer.WriteByte(v) } shiftDelimiter++ } if continuousBacktick > 0 && !selfQuoted { 
writer.WriteString("``") } writer.WriteByte('`') } func (DummyDialector) Explain(sql string, vars ...interface{}) string { return logger.ExplainSQL(sql, nil, `"`, vars...) } func (DummyDialector) DataTypeOf(*schema.Field) string { return "" } func (d DummyDialector) Translate(err error) error { return d.TranslatedErr } ================================================ FILE: utils/tests/models.go ================================================ package tests import ( "database/sql" "time" "gorm.io/gorm" ) // User has one `Account` (has one), many `Pets` (has many) and `Toys` (has many - polymorphic) // He works in a Company (belongs to), he has a Manager (belongs to - single-table), and also managed a Team (has many - single-table) // He speaks many languages (many to many) and has many friends (many to many - single-table) // His pet also has one Toy (has one - polymorphic) // NamedPet is a reference to a named `Pet` (has one) type User struct { gorm.Model Name string Age uint Birthday *time.Time Account Account Pets []*Pet NamedPet *Pet Toys []Toy `gorm:"polymorphic:Owner"` Tools []Tools `gorm:"polymorphicType:Type;polymorphicId:CustomID"` CompanyID *int Company Company ManagerID *uint Manager *User Team []User `gorm:"foreignkey:ManagerID"` Languages []Language `gorm:"many2many:UserSpeak;"` Friends []*User `gorm:"many2many:user_friends;"` Active bool } type Account struct { gorm.Model UserID sql.NullInt64 Number string } type Pet struct { gorm.Model UserID *uint Name string Toy Toy `gorm:"polymorphic:Owner;"` } type Toy struct { gorm.Model Name string OwnerID string OwnerType string } type Tools struct { gorm.Model Name string CustomID string Type string } type Company struct { ID int Name string } type Language struct { Code string `gorm:"primarykey"` Name string } type Coupon struct { ID int `gorm:"primarykey; size:255"` AppliesToProduct []*CouponProduct `gorm:"foreignKey:CouponId;constraint:OnDelete:CASCADE"` AmountOff uint32 `gorm:"column:amount_off"` 
PercentOff float32 `gorm:"column:percent_off"` } type CouponProduct struct { CouponId int `gorm:"primarykey;size:255"` ProductId string `gorm:"primarykey;size:255"` Desc string } type Order struct { gorm.Model Num string Coupon *Coupon CouponID string } type Parent struct { gorm.Model FavChildID uint FavChild *Child Children []*Child } type Child struct { gorm.Model Name string ParentID *uint Parent *Parent } ================================================ FILE: utils/tests/utils.go ================================================ package tests import ( "database/sql/driver" "fmt" "go/ast" "reflect" "testing" "time" "gorm.io/gorm/utils" ) func AssertObjEqual(t *testing.T, r, e interface{}, names ...string) { for _, name := range names { rv := reflect.Indirect(reflect.ValueOf(r)) ev := reflect.Indirect(reflect.ValueOf(e)) if rv.IsValid() != ev.IsValid() { t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), r, e) return } got := rv.FieldByName(name).Interface() expect := ev.FieldByName(name).Interface() t.Run(name, func(t *testing.T) { AssertEqual(t, got, expect) }) } } func AssertEqual(t *testing.T, got, expect interface{}) { if !reflect.DeepEqual(got, expect) { isEqual := func() { if curTime, ok := got.(time.Time); ok { format := "2006-01-02T15:04:05Z07:00" if curTime.Round(time.Second).UTC().Format(format) != expect.(time.Time).Round(time.Second).UTC().Format(format) && curTime.Truncate(time.Second).UTC().Format(format) != expect.(time.Time).Truncate(time.Second).UTC().Format(format) { t.Errorf("%v: expect: %v, got %v after time round", utils.FileWithLineNum(), expect.(time.Time), curTime) } } else if fmt.Sprint(got) != fmt.Sprint(expect) { t.Errorf("%v: expect: %#v, got %#v", utils.FileWithLineNum(), expect, got) } } if fmt.Sprint(got) == fmt.Sprint(expect) { return } if reflect.Indirect(reflect.ValueOf(got)).IsValid() != reflect.Indirect(reflect.ValueOf(expect)).IsValid() { t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) 
return } if valuer, ok := got.(driver.Valuer); ok { got, _ = valuer.Value() } if valuer, ok := expect.(driver.Valuer); ok { expect, _ = valuer.Value() } if got != nil { got = reflect.Indirect(reflect.ValueOf(got)).Interface() } if expect != nil { expect = reflect.Indirect(reflect.ValueOf(expect)).Interface() } if reflect.ValueOf(got).IsValid() != reflect.ValueOf(expect).IsValid() { t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) return } if reflect.ValueOf(got).Kind() == reflect.Slice { if reflect.ValueOf(expect).Kind() == reflect.Slice { if reflect.ValueOf(got).Len() == reflect.ValueOf(expect).Len() { for i := 0; i < reflect.ValueOf(got).Len(); i++ { name := fmt.Sprintf(reflect.ValueOf(got).Type().Name()+" #%v", i) t.Run(name, func(t *testing.T) { AssertEqual(t, reflect.ValueOf(got).Index(i).Interface(), reflect.ValueOf(expect).Index(i).Interface()) }) } } else { name := reflect.ValueOf(got).Type().Elem().Name() t.Errorf("%v expects length: %v, got %v (expects: %+v, got %+v)", name, reflect.ValueOf(expect).Len(), reflect.ValueOf(got).Len(), expect, got) } return } } if reflect.ValueOf(got).Kind() == reflect.Struct { if reflect.ValueOf(expect).Kind() == reflect.Struct { if reflect.ValueOf(got).NumField() == reflect.ValueOf(expect).NumField() { exported := false for i := 0; i < reflect.ValueOf(got).NumField(); i++ { if fieldStruct := reflect.ValueOf(got).Type().Field(i); ast.IsExported(fieldStruct.Name) { exported = true field := reflect.ValueOf(got).Field(i) t.Run(fieldStruct.Name, func(t *testing.T) { AssertEqual(t, field.Interface(), reflect.ValueOf(expect).Field(i).Interface()) }) } } if exported { return } } } } if reflect.ValueOf(got).Type().ConvertibleTo(reflect.ValueOf(expect).Type()) { got = reflect.ValueOf(got).Convert(reflect.ValueOf(expect).Type()).Interface() isEqual() } else if reflect.ValueOf(expect).Type().ConvertibleTo(reflect.ValueOf(got).Type()) { expect = 
reflect.ValueOf(got).Convert(reflect.ValueOf(got).Type()).Interface() isEqual() } else { t.Errorf("%v: expect: %+v, got %+v", utils.FileWithLineNum(), expect, got) return } } } func Now() *time.Time { now := time.Now() return &now } ================================================ FILE: utils/utils.go ================================================ package utils import ( "database/sql/driver" "fmt" "path/filepath" "reflect" "runtime" "strconv" "strings" "unicode" ) var gormSourceDir string func init() { _, file, _, _ := runtime.Caller(0) // compatible solution to get gorm source directory with various operating systems gormSourceDir = sourceDir(file) } func sourceDir(file string) string { dir := filepath.Dir(file) dir = filepath.Dir(dir) s := filepath.Dir(dir) if filepath.Base(s) != "gorm.io" { s = dir } return filepath.ToSlash(s) + "/" } // CallerFrame retrieves the first relevant stack frame outside of GORM's internal implementation files. // It skips: // - GORM's core source files (identified by gormSourceDir prefix) // - Exclude test files (*_test.go) // - go-gorm/gen's Generated files (*.gen.go) func CallerFrame() runtime.Frame { pcs := [13]uintptr{} // the third caller usually from gorm internal len := runtime.Callers(3, pcs[:]) frames := runtime.CallersFrames(pcs[:len]) for i := 0; i < len; i++ { // second return value is "more", not "ok" frame, _ := frames.Next() if (!strings.HasPrefix(frame.File, gormSourceDir) || strings.HasSuffix(frame.File, "_test.go")) && !strings.HasSuffix(frame.File, ".gen.go") { return frame } } return runtime.Frame{} } // FileWithLineNum return the file name and line number of the current file func FileWithLineNum() string { frame := CallerFrame() if frame.PC != 0 { return string(strconv.AppendInt(append([]byte(frame.File), ':'), int64(frame.Line), 10)) } return "" } func IsInvalidDBNameChar(c rune) bool { return !unicode.IsLetter(c) && !unicode.IsNumber(c) && c != '.' 
&& c != '*' && c != '_' && c != '$' && c != '@' } // CheckTruth check string true or not func CheckTruth(vals ...string) bool { for _, val := range vals { if val != "" && !strings.EqualFold(val, "false") { return true } } return false } func ToStringKey(values ...interface{}) string { results := make([]string, len(values)) for idx, value := range values { if valuer, ok := value.(driver.Valuer); ok { value, _ = valuer.Value() } switch v := value.(type) { case string: results[idx] = v case []byte: results[idx] = string(v) case uint: results[idx] = strconv.FormatUint(uint64(v), 10) default: results[idx] = "nil" vv := reflect.ValueOf(v) if vv.IsValid() && !vv.IsZero() { results[idx] = fmt.Sprint(reflect.Indirect(vv).Interface()) } } } return strings.Join(results, "_") } func Contains(elems []string, elem string) bool { for _, e := range elems { if elem == e { return true } } return false } func AssertEqual(x, y interface{}) bool { if reflect.DeepEqual(x, y) { return true } if x == nil || y == nil { return false } xval := reflect.ValueOf(x) yval := reflect.ValueOf(y) if xval.Kind() == reflect.Ptr && xval.IsNil() || yval.Kind() == reflect.Ptr && yval.IsNil() { return false } if valuer, ok := x.(driver.Valuer); ok { x, _ = valuer.Value() } if valuer, ok := y.(driver.Valuer); ok { y, _ = valuer.Value() } return reflect.DeepEqual(x, y) } func ToString(value interface{}) string { switch v := value.(type) { case string: return v case int: return strconv.FormatInt(int64(v), 10) case int8: return strconv.FormatInt(int64(v), 10) case int16: return strconv.FormatInt(int64(v), 10) case int32: return strconv.FormatInt(int64(v), 10) case int64: return strconv.FormatInt(v, 10) case uint: return strconv.FormatUint(uint64(v), 10) case uint8: return strconv.FormatUint(uint64(v), 10) case uint16: return strconv.FormatUint(uint64(v), 10) case uint32: return strconv.FormatUint(uint64(v), 10) case uint64: return strconv.FormatUint(v, 10) } return "" } const nestedRelationSplit = "__" // 
NestedRelationName nested relationships like `Manager__Company` func NestedRelationName(prefix, name string) string { return prefix + nestedRelationSplit + name } // SplitNestedRelationName Split nested relationships to `[]string{"Manager","Company"}` func SplitNestedRelationName(name string) []string { return strings.Split(name, nestedRelationSplit) } // JoinNestedRelationNames nested relationships like `Manager__Company` func JoinNestedRelationNames(relationNames []string) string { return strings.Join(relationNames, nestedRelationSplit) } // RTrimSlice Right trims the given slice by given length func RTrimSlice[T any](v []T, trimLen int) []T { if trimLen >= len(v) { // trimLen greater than slice len means fully sliced return v[:0] } if trimLen < 0 { // negative trimLen is ignored return v[:] } return v[:len(v)-trimLen] } ================================================ FILE: utils/utils_test.go ================================================ package utils import ( "database/sql" "database/sql/driver" "errors" "math" "strings" "testing" "time" ) func TestIsInvalidDBNameChar(t *testing.T) { for _, db := range []string{"db", "dbName", "db_name", "db1", "1dbname", "db$name"} { if fields := strings.FieldsFunc(db, IsInvalidDBNameChar); len(fields) != 1 { t.Fatalf("failed to parse db name %v", db) } } } func TestCheckTruth(t *testing.T) { checkTruthTests := []struct { v string out bool }{ {"123", true}, {"true", true}, {"", false}, {"false", false}, {"False", false}, {"FALSE", false}, {"\u0046alse", false}, } for _, test := range checkTruthTests { t.Run(test.v, func(t *testing.T) { if out := CheckTruth(test.v); out != test.out { t.Errorf("CheckTruth(%s) want: %t, got: %t", test.v, test.out, out) } }) } } func TestToStringKey(t *testing.T) { cases := []struct { values []interface{} key string }{ {[]interface{}{"a"}, "a"}, {[]interface{}{1, 2, 3}, "1_2_3"}, {[]interface{}{1, nil, 3}, "1_nil_3"}, {[]interface{}{[]interface{}{1, 2, 3}}, "[1 2 3]"}, 
{[]interface{}{[]interface{}{"1", "2", "3"}}, "[1 2 3]"}, {[]interface{}{[]interface{}{"1", nil, "3"}}, "[1 3]"}, } for _, c := range cases { if key := ToStringKey(c.values...); key != c.key { t.Errorf("%v: expected %v, got %v", c.values, c.key, key) } } } func TestContains(t *testing.T) { containsTests := []struct { name string elems []string elem string out bool }{ {"exists", []string{"1", "2", "3"}, "1", true}, {"not exists", []string{"1", "2", "3"}, "4", false}, } for _, test := range containsTests { t.Run(test.name, func(t *testing.T) { if out := Contains(test.elems, test.elem); test.out != out { t.Errorf("Contains(%v, %s) want: %t, got: %t", test.elems, test.elem, test.out, out) } }) } } type ModifyAt sql.NullTime // Value return a Unix time. func (n ModifyAt) Value() (driver.Value, error) { if !n.Valid { return nil, nil } return n.Time.Unix(), nil } func TestAssertEqual(t *testing.T) { now := time.Now() assertEqualTests := []struct { name string src, dst interface{} out bool }{ {"error equal", errors.New("1"), errors.New("1"), true}, {"error not equal", errors.New("1"), errors.New("2"), false}, {"driver.Valuer equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now, Valid: true}, true}, {"driver.Valuer not equal", ModifyAt{Time: now, Valid: true}, ModifyAt{Time: now.Add(time.Second), Valid: true}, false}, {"driver.Valuer equal (ptr to nil ptr)", (*ModifyAt)(nil), &ModifyAt{}, false}, } for _, test := range assertEqualTests { t.Run(test.name, func(t *testing.T) { if out := AssertEqual(test.src, test.dst); test.out != out { t.Errorf("AssertEqual(%v, %v) want: %t, got: %t", test.src, test.dst, test.out, out) } }) } } func TestToString(t *testing.T) { tests := []struct { name string in interface{} out string }{ {"int", math.MaxInt64, "9223372036854775807"}, {"int8", int8(math.MaxInt8), "127"}, {"int16", int16(math.MaxInt16), "32767"}, {"int32", int32(math.MaxInt32), "2147483647"}, {"int64", int64(math.MaxInt64), "9223372036854775807"}, {"uint", 
uint(math.MaxUint64), "18446744073709551615"}, {"uint8", uint8(math.MaxUint8), "255"}, {"uint16", uint16(math.MaxUint16), "65535"}, {"uint32", uint32(math.MaxUint32), "4294967295"}, {"uint64", uint64(math.MaxUint64), "18446744073709551615"}, {"string", "abc", "abc"}, {"other", true, ""}, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { if out := ToString(test.in); test.out != out { t.Fatalf("ToString(%v) want: %s, got: %s", test.in, test.out, out) } }) } } func TestRTrimSlice(t *testing.T) { tests := []struct { name string input []int trimLen int expected []int }{ { name: "Trim two elements from end", input: []int{1, 2, 3, 4, 5}, trimLen: 2, expected: []int{1, 2, 3}, }, { name: "Trim entire slice", input: []int{1, 2, 3}, trimLen: 3, expected: []int{}, }, { name: "Trim length greater than slice length", input: []int{1, 2, 3}, trimLen: 5, expected: []int{}, }, { name: "Zero trim length", input: []int{1, 2, 3}, trimLen: 0, expected: []int{1, 2, 3}, }, { name: "Trim one element from end", input: []int{1, 2, 3}, trimLen: 1, expected: []int{1, 2}, }, { name: "Empty slice", input: []int{}, trimLen: 2, expected: []int{}, }, { name: "Negative trim length (should be treated as zero)", input: []int{1, 2, 3}, trimLen: -1, expected: []int{1, 2, 3}, }, } for _, testcase := range tests { t.Run(testcase.name, func(t *testing.T) { result := RTrimSlice(testcase.input, testcase.trimLen) if !AssertEqual(result, testcase.expected) { t.Errorf("RTrimSlice(%v, %d) = %v; want %v", testcase.input, testcase.trimLen, result, testcase.expected) } }) } } ================================================ FILE: utils/utils_unix_test.go ================================================ //go:build unix // +build unix package utils import ( "testing" ) func TestSourceDir(t *testing.T) { cases := []struct { file string want string }{ { file: "/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go", want: "/Users/name/go/pkg/mod/gorm.io/", }, { file: 
"/go/work/proj/gorm/utils/utils.go", want: "/go/work/proj/gorm/", }, { file: "/go/work/proj/gorm_alias/utils/utils.go", want: "/go/work/proj/gorm_alias/", }, { file: "/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go", want: "/go/work/proj/my.gorm.io/gorm@v1.2.3/", }, } for _, c := range cases { s := sourceDir(c.file) if s != c.want { t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) } } } ================================================ FILE: utils/utils_windows_test.go ================================================ package utils import ( "testing" ) func TestSourceDir(t *testing.T) { cases := []struct { file string want string }{ { file: `C:/Users/name/go/pkg/mod/gorm.io/gorm@v1.2.3/utils/utils.go`, want: `C:/Users/name/go/pkg/mod/gorm.io/`, }, { file: `C:/go/work/proj/gorm/utils/utils.go`, want: `C:/go/work/proj/gorm/`, }, { file: `C:/go/work/proj/gorm_alias/utils/utils.go`, want: `C:/go/work/proj/gorm_alias/`, }, { file: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/utils/utils.go`, want: `C:/go/work/proj/my.gorm.io/gorm@v1.2.3/`, }, } for _, c := range cases { s := sourceDir(c.file) if s != c.want { t.Fatalf("%s: expected %s, got %s", c.file, c.want, s) } } }